Merge branch 'master' into patch-1

Niels Hofmans 2019-12-12 14:35:39 +01:00 committed by GitHub
commit 1f1c6a0b87
1373 changed files with 487044 additions and 128467 deletions

.gitignore

@ -9,4 +9,8 @@ guide/public
\#*\#
cscope.*
cloudflared
!cmd/cloudflared/
cloudflared.exe
!cmd/cloudflared/
.DS_Store
*-session.log
ssh_server_tests/.env


@ -1,7 +1,9 @@
# use a builder image for building cloudflared
FROM golang:1.13.3 as builder
ENV GO111MODULE=on
ENV CGO_ENABLED=0
ENV GOOS=linux
# switch to the right gopath directory
WORKDIR /go/src/github.com/cloudflare/cloudflared/
# copy our sources into the builder image

Gopkg.lock

@ -1,611 +0,0 @@
# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
[[projects]]
digest = "1:9f3b30d9f8e0d7040f729b82dcbc8f0dead820a133b3147ce355fc451f32d761"
name = "github.com/BurntSushi/toml"
packages = ["."]
pruneopts = "UT"
revision = "3012a1dbe2e4bd1391d42b32f0577cb7bbc7f005"
version = "v0.3.1"
[[projects]]
digest = "1:d6afaeed1502aa28e80a4ed0981d570ad91b2579193404256ce672ed0a609e0d"
name = "github.com/beorn7/perks"
packages = ["quantile"]
pruneopts = "UT"
revision = "4b2b341e8d7715fae06375aa633dbb6e91b3fb46"
version = "v1.0.0"
[[projects]]
digest = "1:fed1f537c2f1269fe475a8556c393fe466641682d73ef8fd0491cd3aa1e47bad"
name = "github.com/certifi/gocertifi"
packages = ["."]
pruneopts = "UT"
revision = "deb3ae2ef2610fde3330947281941c562861188b"
version = "2018.01.18"
[[projects]]
digest = "1:e5003c19d396d8b3cf1324ea0bf49b00f13e9466d0297d1268b641f1c617c3a2"
name = "github.com/cloudflare/brotli-go"
packages = ["."]
pruneopts = "T"
revision = "18c9f6c67e3dfc12e0ddaca748d2887f97a7ac28"
[[projects]]
digest = "1:6dbb2bbc7e6333e691c4d82fd86485f0695a35902fbb9b2df5f72e22ab0040f3"
name = "github.com/cloudflare/golibs"
packages = ["lrucache"]
pruneopts = "UT"
revision = "333127dbecfcc23a8db7d9a4f52785d23aff44a1"
[[projects]]
digest = "1:3f9506ee991cdee1f05bf0cd3e34b5cd922dc00d6a950fb4beb4e07ab1c4d3d1"
name = "github.com/coredns/coredns"
packages = [
"core/dnsserver",
"coremain",
"pb",
"plugin",
"plugin/cache",
"plugin/cache/freq",
"plugin/etcd/msg",
"plugin/metrics",
"plugin/metrics/vars",
"plugin/pkg/cache",
"plugin/pkg/dnstest",
"plugin/pkg/dnsutil",
"plugin/pkg/doh",
"plugin/pkg/edns",
"plugin/pkg/fuzz",
"plugin/pkg/log",
"plugin/pkg/nonwriter",
"plugin/pkg/rcode",
"plugin/pkg/response",
"plugin/pkg/trace",
"plugin/pkg/uniq",
"plugin/pkg/watch",
"plugin/test",
"request",
]
pruneopts = "UT"
revision = "2e322f6e8a54f18c6aef9c25a7c432c291a3d9f7"
version = "v1.2.0"
[[projects]]
digest = "1:6f70106e7bc1c803e8a0a4519e09c12d154771acfa2559206e97b033bbd1dd38"
name = "github.com/coreos/go-oidc"
packages = ["jose"]
pruneopts = "UT"
revision = "a93f71fdfe73d2c0f5413c0565eea0af6523a6df"
[[projects]]
digest = "1:1da3a221f0bc090792d3a2a080ff09008427c0e0f0533a4ed6abd8994421da73"
name = "github.com/coreos/go-systemd"
packages = ["daemon"]
pruneopts = "UT"
revision = "95778dfbb74eb7e4dbaf43bf7d71809650ef8076"
version = "v19"
[[projects]]
digest = "1:ffe9824d294da03b391f44e1ae8281281b4afc1bdaa9588c9097785e3af10cec"
name = "github.com/davecgh/go-spew"
packages = ["spew"]
pruneopts = "UT"
revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73"
version = "v1.1.1"
[[projects]]
branch = "master"
digest = "1:c013ffc6e15f9f898078f9d38441c68b228aa7b899659452170250ccb27f5f1e"
name = "github.com/elgs/gosqljson"
packages = ["."]
pruneopts = "UT"
revision = "027aa4915315a0b2825c0f025cea347829b974fa"
[[projects]]
digest = "1:d4268b2a09b1f736633577c4ac93f2a5356c73742fff5344e2451aeec60a7ad0"
name = "github.com/equinox-io/equinox"
packages = [
".",
"internal/go-update",
"internal/go-update/internal/binarydist",
"internal/go-update/internal/osext",
"internal/osext",
"proto",
]
pruneopts = "UT"
revision = "5205c98a6c11dc72747ce12fff6cd620a99fde05"
version = "v1.2.0"
[[projects]]
digest = "1:433763f10d88dba9b533a7ea2fe9f5ee11e57e00306eb97a1f6090fd978e8fa1"
name = "github.com/facebookgo/grace"
packages = ["gracenet"]
pruneopts = "UT"
revision = "75cf19382434e82df4dd84953f566b8ad23d6e9e"
[[projects]]
branch = "master"
digest = "1:50a46ab1d5edbbdd55125b4d37f1bf503d0807c26461f9ad7b358d6006641d09"
name = "github.com/flynn/go-shlex"
packages = ["."]
pruneopts = "UT"
revision = "3f9db97f856818214da2e1057f8ad84803971cff"
[[projects]]
digest = "1:d4623fc7bf7e281d9107367cc4a9e76ed3e86b1eec1a4e30630c870bef1fedd0"
name = "github.com/getsentry/raven-go"
packages = ["."]
pruneopts = "UT"
revision = "ed7bcb39ff10f39ab08e317ce16df282845852fa"
[[projects]]
branch = "master"
digest = "1:3e6afc3ed8a72949aa735c00fddc23427dc9384ccfd51cf0d91a412e668da632"
name = "github.com/golang-collections/collections"
packages = ["queue"]
pruneopts = "UT"
revision = "604e922904d35e97f98a774db7881f049cd8d970"
[[projects]]
digest = "1:239c4c7fd2159585454003d9be7207167970194216193a8a210b8d29576f19c9"
name = "github.com/golang/protobuf"
packages = [
"proto",
"ptypes",
"ptypes/any",
"ptypes/duration",
"ptypes/timestamp",
]
pruneopts = "UT"
revision = "b5d812f8a3706043e23a9cd5babf2e5423744d30"
version = "v1.3.1"
[[projects]]
digest = "1:582b704bebaa06b48c29b0cec224a6058a09c86883aaddabde889cd1a5f73e1b"
name = "github.com/google/uuid"
packages = ["."]
pruneopts = "UT"
revision = "0cd6bf5da1e1c83f8b45653022c74f71af0538a4"
version = "v1.1.1"
[[projects]]
digest = "1:d5f97fc268267ec1b61c3453058c738246fc3e746f14b1ae25161513b7367b0c"
name = "github.com/gorilla/mux"
packages = ["."]
pruneopts = "UT"
revision = "c5c6c98bc25355028a63748a498942a6398ccd22"
version = "v1.7.1"
[[projects]]
digest = "1:43dd08a10854b2056e615d1b1d22ac94559d822e1f8b6fcc92c1a1057e85188e"
name = "github.com/gorilla/websocket"
packages = ["."]
pruneopts = "UT"
revision = "ea4d1f681babbce9545c9c5f3d5194a789c89f5b"
version = "v1.2.0"
[[projects]]
branch = "master"
digest = "1:1a1206efd03a54d336dce7bb8719e74f2f8932f661cb9f57d5813a1d99c083d8"
name = "github.com/grpc-ecosystem/grpc-opentracing"
packages = ["go/otgrpc"]
pruneopts = "UT"
revision = "8e809c8a86450a29b90dcc9efbf062d0fe6d9746"
[[projects]]
digest = "1:31e761d97c76151dde79e9d28964a812c46efc5baee4085b86f68f0c654450de"
name = "github.com/konsorten/go-windows-terminal-sequences"
packages = ["."]
pruneopts = "UT"
revision = "f55edac94c9bbba5d6182a4be46d86a2c9b5b50e"
version = "v1.0.2"
[[projects]]
digest = "1:bc1c0be40c67b6b4aee09d7508d5a2a52c1c116b1fa43806dad2b0d6b4d4003b"
name = "github.com/lib/pq"
packages = [
".",
"oid",
"scram",
]
pruneopts = "UT"
revision = "51e2106eed1cea199c802d2a49e91e2491b02056"
version = "v1.1.0"
[[projects]]
digest = "1:2fa7b0155cd54479a755c629de26f888a918e13f8857a2c442205d825368e084"
name = "github.com/mattn/go-colorable"
packages = ["."]
pruneopts = "UT"
revision = "3a70a971f94a22f2fa562ffcc7a0eb45f5daf045"
version = "v0.1.1"
[[projects]]
digest = "1:e150b5fafbd7607e2d638e4e5cf43aa4100124e5593385147b0a74e2733d8b0d"
name = "github.com/mattn/go-isatty"
packages = ["."]
pruneopts = "UT"
revision = "c2a7a6ca930a4cd0bc33a3f298eb71960732a3a7"
version = "v0.0.7"
[[projects]]
digest = "1:ff5ebae34cfbf047d505ee150de27e60570e8c394b3b8fdbb720ff6ac71985fc"
name = "github.com/matttproud/golang_protobuf_extensions"
packages = ["pbutil"]
pruneopts = "UT"
revision = "c12348ce28de40eed0136aa2b644d0ee0650e56c"
version = "v1.0.1"
[[projects]]
digest = "1:75fa16a231ef40da3e462d651c20b9df20bde0777bdc1ac0982242c79057ee71"
name = "github.com/mholt/caddy"
packages = [
".",
"caddyfile",
"telemetry",
]
pruneopts = "UT"
revision = "d3b731e9255b72d4571a5aac125634cf1b6031dc"
[[projects]]
digest = "1:2b4b4b2e5544c2a11a486c1b631357aa2ddf766e50c1b2483cf809da2c511234"
name = "github.com/miekg/dns"
packages = ["."]
pruneopts = "UT"
revision = "73601d4aed9d844322611759d7f3619110b7c88e"
version = "v1.1.8"
[[projects]]
digest = "1:5d231480e1c64a726869bc4142d270184c419749d34f167646baa21008eb0a79"
name = "github.com/mitchellh/go-homedir"
packages = ["."]
pruneopts = "UT"
revision = "af06845cf3004701891bf4fdb884bfe4920b3727"
version = "v1.1.0"
[[projects]]
digest = "1:53bc4cd4914cd7cd52139990d5170d6dc99067ae31c56530621b18b35fc30318"
name = "github.com/mitchellh/mapstructure"
packages = ["."]
pruneopts = "UT"
revision = "3536a929edddb9a5b34bd6861dc4a9647cb459fe"
version = "v1.1.2"
[[projects]]
digest = "1:11e62d6050198055e6cd87ed57e5d8c669e84f839c16e16f192374d913d1a70d"
name = "github.com/opentracing/opentracing-go"
packages = [
".",
"ext",
"log",
]
pruneopts = "UT"
revision = "659c90643e714681897ec2521c60567dd21da733"
version = "v1.1.0"
[[projects]]
digest = "1:40e195917a951a8bf867cd05de2a46aaf1806c50cf92eebf4c16f78cd196f747"
name = "github.com/pkg/errors"
packages = ["."]
pruneopts = "UT"
revision = "645ef00459ed84a119197bfb8d8205042c6df63d"
version = "v0.8.0"
[[projects]]
digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe"
name = "github.com/pmezard/go-difflib"
packages = ["difflib"]
pruneopts = "UT"
revision = "792786c7400a136282c1664665ae0a8db921c6c2"
version = "v1.0.0"
[[projects]]
digest = "1:c968b29db5d68ec97de404b6d058d5937fa015a141b3b4f7a0d87d5f8226f04c"
name = "github.com/prometheus/client_golang"
packages = [
"prometheus",
"prometheus/promhttp",
]
pruneopts = "UT"
revision = "967789050ba94deca04a5e84cce8ad472ce313c1"
version = "v0.9.0-pre1"
[[projects]]
branch = "master"
digest = "1:2d5cd61daa5565187e1d96bae64dbbc6080dacf741448e9629c64fd93203b0d4"
name = "github.com/prometheus/client_model"
packages = ["go"]
pruneopts = "UT"
revision = "fd36f4220a901265f90734c3183c5f0c91daa0b8"
[[projects]]
digest = "1:35cf6bdf68db765988baa9c4f10cc5d7dda1126a54bd62e252dbcd0b1fc8da90"
name = "github.com/prometheus/common"
packages = [
"expfmt",
"internal/bitbucket.org/ww/goautoneg",
"model",
]
pruneopts = "UT"
revision = "a82f4c12f983cc2649298185f296632953e50d3e"
version = "v0.3.0"
[[projects]]
branch = "master"
digest = "1:49b09905e781d7775c086604cc00083e1832d0783f1f421b79f42657c457d029"
name = "github.com/prometheus/procfs"
packages = ["."]
pruneopts = "UT"
revision = "8368d24ba045f26503eb745b624d930cbe214c79"
[[projects]]
digest = "1:1a23fdd843129ef761ffe7651bc5fe7c5b09fbe933e92783ab06cc11c37b7b37"
name = "github.com/rifflock/lfshook"
packages = ["."]
pruneopts = "UT"
revision = "b9218ef580f59a2e72dad1aa33d660150445d05a"
version = "v2.4"
[[projects]]
digest = "1:04457f9f6f3ffc5fea48e71d62f2ca256637dee0a04d710288e27e05c8b41976"
name = "github.com/sirupsen/logrus"
packages = ["."]
pruneopts = "UT"
revision = "839c75faf7f98a33d445d181f3018b5c3409a45e"
version = "v1.4.2"
[[projects]]
digest = "1:7e8d267900c7fa7f35129a2a37596e38ed0f11ca746d6d9ba727980ee138f9f6"
name = "github.com/stretchr/testify"
packages = [
"assert",
"require",
]
pruneopts = "UT"
revision = "12b6f73e6084dad08a7c6e575284b177ecafbc71"
version = "v1.2.1"
[[projects]]
branch = "master"
digest = "1:a84d5ec8b40a827962ea250f2cf03434138ccae9d83fcac12fb49b70c70b80cc"
name = "golang.org/x/crypto"
packages = [
"curve25519",
"ed25519",
"ed25519/internal/edwards25519",
"internal/chacha20",
"internal/subtle",
"nacl/box",
"nacl/secretbox",
"poly1305",
"salsa20/salsa",
"ssh",
"ssh/terminal",
]
pruneopts = "UT"
revision = "f416ebab96af27ca70b6e5c23d6a0747530da626"
[[projects]]
branch = "master"
digest = "1:52d140f7ab52e491cc1cbc93e6637aa5e9a7f3beae7545d675b02e52ca9d7290"
name = "golang.org/x/net"
packages = [
"bpf",
"context",
"http/httpguts",
"http2",
"http2/hpack",
"idna",
"internal/iana",
"internal/socket",
"internal/timeseries",
"ipv4",
"ipv6",
"trace",
"websocket",
]
pruneopts = "UT"
revision = "1da14a5a36f220ea3f03470682b737b1dfd5de22"
[[projects]]
digest = "1:39ebcc2b11457b703ae9ee2e8cca0f68df21969c6102cb3b705f76cca0ea0239"
name = "golang.org/x/sync"
packages = ["errgroup"]
pruneopts = "UT"
revision = "1d60e4601c6fd243af51cc01ddf169918a5407ca"
[[projects]]
branch = "master"
digest = "1:77751d02e939d7078faedaeec10c09af575a09c528d84d18f2cb45a84bd1889a"
name = "golang.org/x/sys"
packages = [
"cpu",
"unix",
"windows",
"windows/registry",
"windows/svc",
"windows/svc/eventlog",
"windows/svc/mgr",
]
pruneopts = "UT"
revision = "12500544f89f9420afe9529ba8940bf72d294972"
[[projects]]
digest = "1:a2ab62866c75542dd18d2b069fec854577a20211d7c0ea6ae746072a1dccdd18"
name = "golang.org/x/text"
packages = [
"collate",
"collate/build",
"internal/colltab",
"internal/gen",
"internal/tag",
"internal/triegen",
"internal/ucd",
"language",
"secure/bidirule",
"transform",
"unicode/bidi",
"unicode/cldr",
"unicode/norm",
"unicode/rangetable",
]
pruneopts = "UT"
revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0"
version = "v0.3.0"
[[projects]]
branch = "master"
digest = "1:c3076e7defee87de1236f1814beb588f40a75544c60121e6eb38b3b3721783e2"
name = "google.golang.org/genproto"
packages = ["googleapis/rpc/status"]
pruneopts = "UT"
revision = "d1146b9035b912113a38af3b138eb2af567b2c67"
[[projects]]
digest = "1:31d87f39886fb38a2b6c097ff3b9f985d6960772170d64a68246f7790e955746"
name = "google.golang.org/grpc"
packages = [
".",
"balancer",
"balancer/base",
"balancer/roundrobin",
"binarylog/grpc_binarylog_v1",
"codes",
"connectivity",
"credentials",
"credentials/internal",
"encoding",
"encoding/proto",
"grpclog",
"internal",
"internal/backoff",
"internal/balancerload",
"internal/binarylog",
"internal/channelz",
"internal/envconfig",
"internal/grpcrand",
"internal/grpcsync",
"internal/syscall",
"internal/transport",
"keepalive",
"metadata",
"naming",
"peer",
"resolver",
"resolver/dns",
"resolver/passthrough",
"stats",
"status",
"tap",
]
pruneopts = "UT"
revision = "236199dd5f8031d698fb64091194aecd1c3895b2"
version = "v1.20.0"
[[projects]]
branch = "altsrc-parse-durations"
digest = "1:0370b1bceda03dbfade3abbde639a43f1113bab711ec760452e5c0dcc0c14787"
name = "gopkg.in/urfave/cli.v2"
packages = [
".",
"altsrc",
]
pruneopts = "UT"
revision = "d604b6ffeee878fbf084fd2761466b6649989cee"
source = "https://github.com/cbranch/cli"
[[projects]]
digest = "1:4d2e5a73dc1500038e504a8d78b986630e3626dc027bc030ba5c75da257cdb96"
name = "gopkg.in/yaml.v2"
packages = ["."]
pruneopts = "UT"
revision = "51d6538a90f86fe93ac480b35f37b2be17fef232"
version = "v2.2.2"
[[projects]]
digest = "1:8ffc3ddc31414c0a71220957bb723b16510d7fcb5b3880dc0da4cf6d39c31642"
name = "zombiezen.com/go/capnproto2"
packages = [
".",
"encoding/text",
"internal/fulfiller",
"internal/nodemap",
"internal/packed",
"internal/queue",
"internal/schema",
"internal/strquote",
"pogs",
"rpc",
"rpc/internal/refcount",
"schemas",
"server",
"std/capnp/rpc",
]
pruneopts = "UT"
revision = "7cfd211c19c7f5783c695f3654efa46f0df259c3"
source = "https://github.com/zombiezen/go-capnproto2"
version = "v2.17.1"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
input-imports = [
"github.com/cloudflare/brotli-go",
"github.com/cloudflare/golibs/lrucache",
"github.com/coredns/coredns/core/dnsserver",
"github.com/coredns/coredns/plugin",
"github.com/coredns/coredns/plugin/cache",
"github.com/coredns/coredns/plugin/metrics/vars",
"github.com/coredns/coredns/plugin/pkg/dnstest",
"github.com/coredns/coredns/plugin/pkg/rcode",
"github.com/coredns/coredns/request",
"github.com/coreos/go-oidc/jose",
"github.com/coreos/go-systemd/daemon",
"github.com/elgs/gosqljson",
"github.com/equinox-io/equinox",
"github.com/facebookgo/grace/gracenet",
"github.com/getsentry/raven-go",
"github.com/golang-collections/collections/queue",
"github.com/google/uuid",
"github.com/gorilla/mux",
"github.com/gorilla/websocket",
"github.com/lib/pq",
"github.com/mattn/go-colorable",
"github.com/miekg/dns",
"github.com/mitchellh/go-homedir",
"github.com/mitchellh/mapstructure",
"github.com/pkg/errors",
"github.com/prometheus/client_golang/prometheus",
"github.com/prometheus/client_golang/prometheus/promhttp",
"github.com/rifflock/lfshook",
"github.com/sirupsen/logrus",
"github.com/stretchr/testify/assert",
"github.com/stretchr/testify/require",
"golang.org/x/crypto/nacl/box",
"golang.org/x/crypto/ssh",
"golang.org/x/crypto/ssh/terminal",
"golang.org/x/net/context",
"golang.org/x/net/http2",
"golang.org/x/net/http2/hpack",
"golang.org/x/net/idna",
"golang.org/x/net/trace",
"golang.org/x/net/websocket",
"golang.org/x/sync/errgroup",
"golang.org/x/sys/windows",
"golang.org/x/sys/windows/svc",
"golang.org/x/sys/windows/svc/eventlog",
"golang.org/x/sys/windows/svc/mgr",
"gopkg.in/urfave/cli.v2",
"gopkg.in/urfave/cli.v2/altsrc",
"zombiezen.com/go/capnproto2",
"zombiezen.com/go/capnproto2/encoding/text",
"zombiezen.com/go/capnproto2/pogs",
"zombiezen.com/go/capnproto2/rpc",
"zombiezen.com/go/capnproto2/schemas",
"zombiezen.com/go/capnproto2/server",
"zombiezen.com/go/capnproto2/std/capnp/rpc",
]
solver-name = "gps-cdcl"
solver-version = 1


@ -1,91 +0,0 @@
[prune]
go-tests = true
unused-packages = true
[[prune.project]]
name = "github.com/cloudflare/brotli-go"
unused-packages = false
[[constraint]]
name = "github.com/facebookgo/grace"
revision = "75cf19382434e82df4dd84953f566b8ad23d6e9e"
[[constraint]]
name = "github.com/getsentry/raven-go"
revision = "ed7bcb39ff10f39ab08e317ce16df282845852fa"
[[constraint]]
name = "github.com/pkg/errors"
version = "=0.8.0"
[[constraint]]
name = "github.com/prometheus/client_golang"
version = "=0.9.0-pre1"
[[constraint]]
name = "github.com/sirupsen/logrus"
version = "=1.4.2"
[[constraint]]
name = "github.com/stretchr/testify"
version = "=1.2.1"
[[constraint]]
name = "golang.org/x/net"
branch = "master" # master required by github.com/miekg/dns
[[constraint]]
name = "golang.org/x/sync"
revision = "1d60e4601c6fd243af51cc01ddf169918a5407ca"
[[constraint]]
name = "gopkg.in/urfave/cli.v2"
source = "https://github.com/cbranch/cli"
branch = "altsrc-parse-durations"
[[constraint]]
name = "zombiezen.com/go/capnproto2"
source = "https://github.com/zombiezen/go-capnproto2"
version = "=2.17.1"
[[constraint]]
name = "github.com/gorilla/websocket"
version = "=1.2.0"
[[constraint]]
name = "github.com/coredns/coredns"
version = "=1.2.0"
[[constraint]]
name = "github.com/miekg/dns"
version = "=1.1.8"
[[constraint]]
name = "github.com/cloudflare/brotli-go"
revision = "18c9f6c67e3dfc12e0ddaca748d2887f97a7ac28"
[[override]]
name = "github.com/mholt/caddy"
revision = "d3b731e9255b72d4571a5aac125634cf1b6031dc"
[[constraint]]
name = "github.com/coreos/go-oidc"
revision = "a93f71fdfe73d2c0f5413c0565eea0af6523a6df"
[[constraint]]
name = "golang.org/x/crypto"
branch = "master" # master required by github.com/miekg/dns
[[constraint]]
name = "github.com/cloudflare/golibs"
revision = "333127dbecfcc23a8db7d9a4f52785d23aff44a1"
[[constraint]]
name = "github.com/google/uuid"
version = "=1.1.1"
[[constraint]]
name = "github.com/mitchellh/mapstructure"
version = "1.1.2"
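The deleted dep manifests above correspond to the switch to Go modules recorded later in this diff ("TUN-2502: Switch to go modules" in the changelog, plus the new -mod=vendor flags in the Makefile). For orientation, a minimal go.mod header sketch: the module path is the import path already used by the Makefile and Dockerfile, the go directive is assumed from the go=1.12.7-1 pin in cfsetup.yaml, and the require block is omitted here because it is maintained by the go tool.

module github.com/cloudflare/cloudflared

go 1.12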


@ -29,7 +29,7 @@ clean:
.PHONY: cloudflared
cloudflared: tunnel-deps
go build -v $(VERSION_FLAGS) $(IMPORT_PATH)/cmd/cloudflared
go build -v -mod=vendor $(VERSION_FLAGS) $(IMPORT_PATH)/cmd/cloudflared
.PHONY: container
container:
@ -37,7 +37,11 @@ container:
.PHONY: test
test: vet
go test -v -race $(VERSION_FLAGS) ./...
go test -v -mod=vendor -race $(VERSION_FLAGS) ./...
.PHONY: test-ssh-server
test-ssh-server:
docker-compose -f ssh_server_tests/docker-compose.yml up
.PHONY: cloudflared-deb
cloudflared-deb: cloudflared
@ -78,6 +82,6 @@ tunnelrpc/tunnelrpc.capnp.go: tunnelrpc/tunnelrpc.capnp
.PHONY: vet
vet:
go vet ./...
go vet -mod=vendor ./...
which go-sumtype # go get github.com/BurntSushi/go-sumtype
go-sumtype $$(go list ./...)
go-sumtype $$(go list -mod=vendor ./...)


@ -1,3 +1,103 @@
2019.11.3
- 2019-11-20 TUN-2562: Update Cloudflare Origin CA RSA root
2019.11.2
- 2019-11-18 TUN-2567: AuthOutcome can be turned back into AuthResponse
- 2019-11-18 TUN-2563: Exposes config_version metrics
2019.11.1
- 2019-11-12 Add db-connect, a SQL over HTTPS server
- 2019-11-12 TUN-2053: Add a /healthcheck endpoint to the metrics server
- 2019-11-13 TUN-2178: public API to create new h2mux.MuxedStreamRequest
- 2019-11-13 TUN-2490: respect original representation of HTTP request path
- 2019-11-18 TUN-2547: TunnelRPC definitions for Authenticate flow
- 2019-11-18 TUN-2551: TunnelRPC definitions for ReconnectTunnel flow
- 2019-11-05 TUN-2506: Expose active streams metrics
2019.11.0
- 2019-11-04 TUN-2502: Switch to go modules
- 2019-11-04 TUN-2500: Don't send client registration errors to Sentry
- 2019-11-04 TUN-2489: Delete stream from activestreammap when read and write are both closed
- 2019-11-05 TUN-2505: Terminate stream on receipt of RST_STREAM; MuxedStream.CloseWrite() should terminate the MuxedStream.Write() loop
- 2019-10-30 TUN-2451: Log invalid path
- 2019-10-22 TUN-2425: Enable cloudflared to serve multiple Hello World servers by having each of them create its own ServeMux
- 2019-10-22 AUTH-2173: Prepends access login url with scheme if one doesn't exist
- 2019-10-23 TUN-2460: Configure according to the ClientConfig received from a successful Connect
- 2019-10-23 AUTH-2177: Reads and writes error streams
2019.10.4
- 2019-10-21 TUN-2450: Remove Brew publishing formula
2019.10.3
- 2019-10-18 Fix #129: Excessive memory usage streaming large files (#142)
2019.10.2
- 2019-10-17 AUTH-2167: Adds CLI option for host key directory
2019.10.1
- 2019-10-17 Adds variable to fix windows build
2019.10.0
- 2019-10-11 AUTH-2105: Don't require --destination arg
- 2019-10-14 TUN-2344: log more details: http2.Framer.ErrorDetail() if available, connectionID
- 2019-10-16 AUTH-2159: Moves shutdownC close into error handling AUTH-2161: Lowers size of preamble length AUTH-2160: Fixes url parsing logic
- 2019-10-16 AUTH-2135: Adds support for IPv6 and tests
- 2019-10-02 AUTH-2105: Adds support for local forwarding. Refactor auditlogger creation. AUTH-2088: Adds dynamic destination routing
- 2019-10-09 AUTH-2114: Uses short lived cert auth for outgoing client connection
- 2019-09-30 AUTH-2089: Revise ssh server to function as a proxy
2019.9.2
- 2019-09-26 TUN-2355: Roll back TUN-2276
2019.9.1
- 2019-09-23 TUN-2334: remove tlsConfig.ServerName special case
- 2019-09-23 AUTH-2077: Quotes open browser command in windows
- 2019-09-11 AUTH-2050: Adds time.sleep to temporarily avoid hitting tunnel muxer deadlock issue
- 2019-09-10 AUTH-2056: Writes stderr to its own stream for non-pty connections
- 2019-09-16 TUN-2307: Capnp is the only serialization format used in tunnelpogs
- 2019-09-18 TUN-2315: Replace Scope with IntentLabel
- 2019-09-17 TUN-2309: Split ConnectResult into ConnectError and ConnectSuccess, each implementing its own capnp serialization logic
- 2019-09-18 AUTH-2052: Adds tests for SSH server
- 2019-09-18 AUTH-2067: Log commands correctly
- 2019-09-19 AUTH-2055: Verifies token at edge on access login
- 2019-09-04 TUN-2201: change SRV records used by cloudflared
- 2019-09-06 TUN-2280: Revert "TUN-2260: add name/group to CapnpConnectParameters, remove Scope"
- 2019-09-03 AUTH-1943 hooked up uploader to logger, added timestamp to session logs, add tests
- 2019-09-04 AUTH-2036: Refactor user retrieval, shutdown after ssh server stops, add custom version string
- 2019-09-06 AUTH-1942 added event log to ssh server
- 2019-09-04 AUTH-2037: Adds support for ssh port forwarding
- 2019-09-05 TUN-2276: Path encoding broken
2019.9.0
- 2019-09-05 TUN-2279: Revert path encoding fix
- 2019-08-30 AUTH-2021 - check error for failing tests
- 2019-08-29 AUTH-2030: Support both authorized_key and short lived cert authentication simultaneously without specifying at start time
- 2019-08-29 AUTH-2026: Adds support for non-pty sessions and inline command exec
- 2019-08-26 AUTH-1943: Adds session logging
- 2019-08-26 TUN-2162: Decomplect OpenStream to allow finer-grained timeouts
- 2019-08-29 TUN-2260: add name/group to CapnpConnectParameters, remove Scope
2019.8.4
- 2019-08-30 Fix #111: Add support for specifying a specific HTTP Host: header on the origin. (#114)
- 2019-08-22 TUN-2165: Add ClientConfig to tunnelrpc.ConnectResult
- 2019-08-20 AUTH-2014: Checks users login shell
- 2019-08-26 TUN-2243: Revert "STOR-519: Add db-connect, a SQL over HTTPS server"
- 2019-08-27 TUN-2244: Add NO_AUTOUPDATE env var
- 2019-08-22 AUTH-2018: Adds support for authorized keys and short lived certs
- 2019-08-28 AUTH-2022: Adds ssh timeout configuration
- 2019-08-28 TUN-1968: Gracefully diff StreamHandler.UpdateConfig
- 2019-08-26 AUTH-2021 - s3 bucket uploading for SSH logs
- 2019-08-19 AUTH-2004: Adds static host key support
- 2019-07-18 AUTH-1941: Adds initial SSH server implementation
2019.8.3
- 2019-08-20 STOR-519: Add db-connect, a SQL over HTTPS server
- 2019-08-20 Release 2019.8.2
- 2019-08-20 Revert "AUTH-1941: Adds initial SSH server implementation"
- 2019-08-11 TUN-2163: Add GrapQLType method to Scope interface
- 2019-08-06 TUN-2152: Requests with a query in the URL are erroneously escaped
- 2019-07-18 AUTH-1941: Adds initial SSH server implementation
2019.8.1
- 2019-08-05 TUN-2111: Implement custom serialization logic for FallibleConfig and OriginConfig
- 2019-08-06 Revert "TUN-1736: Missing headers when passing an invalid path"


@ -0,0 +1,106 @@
package awsuploader
import (
"os"
"path/filepath"
"time"
"github.com/sirupsen/logrus"
)
// DirectoryUploadManager is used to manage file uploads on an interval from a directory
type DirectoryUploadManager struct {
logger *logrus.Logger
uploader Uploader
rootDirectory string
sweepInterval time.Duration
ticker *time.Ticker
shutdownC chan struct{}
workQueue chan string
}
// NewDirectoryUploadManager creates a new DirectoryUploadManager
// uploader is an Uploader to use as an actual uploading engine
// directory is the directory to sweep for files to upload
// sweepInterval is how often to iterate the directory and upload the files within
func NewDirectoryUploadManager(logger *logrus.Logger, uploader Uploader, directory string, sweepInterval time.Duration, shutdownC chan struct{}) *DirectoryUploadManager {
workerCount := 10
manager := &DirectoryUploadManager{
logger: logger,
uploader: uploader,
rootDirectory: directory,
sweepInterval: sweepInterval,
shutdownC: shutdownC,
workQueue: make(chan string, workerCount),
}
//start workers
for i := 0; i < workerCount; i++ {
go manager.worker()
}
return manager
}
// Upload a file using the uploader
// This is useful for "out of band" uploads that need to be triggered immediately instead of waiting for the sweep
func (m *DirectoryUploadManager) Upload(filepath string) error {
return m.uploader.Upload(filepath)
}
// Start the upload ticker to walk the directories
func (m *DirectoryUploadManager) Start() {
m.ticker = time.NewTicker(m.sweepInterval)
go m.run()
}
func (m *DirectoryUploadManager) run() {
for {
select {
case <-m.shutdownC:
m.ticker.Stop()
return
case <-m.ticker.C:
m.sweep()
}
}
}
// sweep the directory and kick off uploads
func (m *DirectoryUploadManager) sweep() {
filepath.Walk(m.rootDirectory, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() {
return nil
}
//30 days ago
retentionTime := 30 * (time.Hour * 24)
checkTime := time.Now().Add(-time.Duration(retentionTime))
//delete the file if it is stale
if info.ModTime().Before(checkTime) {
os.Remove(path)
return nil
}
//add the upload to the work queue
go func() {
m.workQueue <- path
}()
return nil
})
}
// worker handles upload requests
func (m *DirectoryUploadManager) worker() {
for {
select {
case <-m.shutdownC:
return
case filepath := <-m.workQueue:
if err := m.Upload(filepath); err != nil {
m.logger.WithError(err).Error("Cannot upload file to s3 bucket")
} else {
os.Remove(filepath)
}
}
}
}


@ -0,0 +1,137 @@
package awsuploader
import (
"errors"
"io/ioutil"
"math/rand"
"os"
"path/filepath"
"testing"
"time"
"github.com/sirupsen/logrus"
)
type MockUploader struct {
shouldFail bool
}
func (m *MockUploader) Upload(filepath string) error {
if m.shouldFail {
return errors.New("upload set to fail")
}
return nil
}
func NewMockUploader(shouldFail bool) Uploader {
return &MockUploader{shouldFail: shouldFail}
}
func getDirectoryPath(t *testing.T) string {
dir, err := os.Getwd()
if err != nil {
t.Fatal("couldn't create the test directory!", err)
}
return filepath.Join(dir, "uploads")
}
func setupTestDirectory(t *testing.T) string {
path := getDirectoryPath(t)
os.RemoveAll(path)
time.Sleep(100 * time.Millisecond) //short way to wait for the OS to delete the folder
err := os.MkdirAll(path, os.ModePerm)
if err != nil {
t.Fatal("couldn't create the test directory!", err)
}
return path
}
func createUploadManager(t *testing.T, shouldFailUpload bool) *DirectoryUploadManager {
rootDirectory := setupTestDirectory(t)
uploader := NewMockUploader(shouldFailUpload)
logger := logrus.New()
shutdownC := make(chan struct{})
return NewDirectoryUploadManager(logger, uploader, rootDirectory, 1*time.Second, shutdownC)
}
func createFile(t *testing.T, fileName string) (*os.File, string) {
path := filepath.Join(getDirectoryPath(t), fileName)
f, err := os.Create(path)
if err != nil {
t.Fatal("unable to create file for sweep test", err)
}
return f, path
}
func TestUploadSuccess(t *testing.T) {
manager := createUploadManager(t, false)
path := filepath.Join(getDirectoryPath(t), "test_file")
if err := manager.Upload(path); err != nil {
t.Fatal("the upload request method failed", err)
}
}
func TestUploadFailure(t *testing.T) {
manager := createUploadManager(t, true)
path := filepath.Join(getDirectoryPath(t), "test_file")
if err := manager.Upload(path); err == nil {
t.Fatal("the upload request method should have failed and didn't", err)
}
}
func TestSweepSuccess(t *testing.T) {
manager := createUploadManager(t, false)
f, path := createFile(t, "test_file")
defer f.Close()
manager.Start()
time.Sleep(2 * time.Second)
if _, err := os.Stat(path); err == nil {
//the file should have been uploaded and deleted by the sweep
t.Fatal("the manager failed to delete the file")
}
}
func TestSweepFailure(t *testing.T) {
manager := createUploadManager(t, true)
f, path := createFile(t, "test_file")
defer f.Close()
manager.Start()
time.Sleep(2 * time.Second)
if _, serr := os.Stat(path); serr != nil {
//the file should still exist because the upload failed
t.Fatal("the file should not have been removed after a failed upload", serr)
}
}
func TestHighLoad(t *testing.T) {
manager := createUploadManager(t, false)
for i := 0; i < 30; i++ {
f, _ := createFile(t, randomString(6))
defer f.Close()
}
manager.Start()
time.Sleep(4 * time.Second)
directory := getDirectoryPath(t)
files, err := ioutil.ReadDir(directory)
if err != nil || len(files) > 0 {
t.Fatalf("the manager failed to upload all the files: %s files left: %d", err, len(files))
}
}
// LowerCase [a-z]
const randSet = "abcdefghijklmnopqrstuvwxyz"
// randomString returns a random string of length 'n' drawn from randSet
func randomString(n int) string {
b := make([]byte, n)
lsetLen := len(randSet)
for i := range b {
b[i] = randSet[rand.Intn(lsetLen)]
}
return string(b)
}


@ -0,0 +1,62 @@
package awsuploader
import (
"context"
"os"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
)
// FileUploader uploads files to an AWS S3-compatible bucket
type FileUploader struct {
storage *s3.S3
bucketName string
clientID string
secretID string
}
// NewFileUploader creates a new S3 compliant bucket uploader
func NewFileUploader(bucketName, region, accessKeyID, secretID, token, s3Host string) (*FileUploader, error) {
sess, err := session.NewSession(&aws.Config{
Region: aws.String(region),
Credentials: credentials.NewStaticCredentials(accessKeyID, secretID, token),
})
if err != nil {
return nil, err
}
var storage *s3.S3
if s3Host != "" {
storage = s3.New(sess, &aws.Config{Endpoint: aws.String(s3Host)})
} else {
storage = s3.New(sess)
}
return &FileUploader{
storage: storage,
bucketName: bucketName,
}, nil
}
// Upload a file to the bucket
func (u *FileUploader) Upload(filepath string) error {
info, err := os.Stat(filepath)
if err != nil {
return err
}
file, err := os.Open(filepath)
if err != nil {
return err
}
defer file.Close()
_, serr := u.storage.PutObjectWithContext(context.Background(), &s3.PutObjectInput{
Bucket: aws.String(u.bucketName),
Key: aws.String(info.Name()),
Body: file,
})
return serr
}


@ -0,0 +1,7 @@
package awsuploader
// UploadManager is used to manage file uploads on an interval
type UploadManager interface {
Upload(string) error
Start()
}

awsuploader/uploader.go

@ -0,0 +1,7 @@
package awsuploader
// Uploader defines the functions required to upload to a bucket
type Uploader interface {
//Upload a file to the bucket
Upload(string) error
}
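Taken together, these files are wired up in the tunnel command changes later in this diff: a FileUploader is handed to a DirectoryUploadManager, which sweeps a log directory on an interval and uploads (then removes) each file. A minimal usage sketch, assuming hypothetical bucket settings and directory; every literal below is a placeholder, and the empty token and s3Host arguments select plain AWS S3 as in NewFileUploader above.

package main

import (
	"time"

	"github.com/cloudflare/cloudflared/awsuploader"
	"github.com/sirupsen/logrus"
)

func main() {
	logger := logrus.New()
	shutdownC := make(chan struct{})

	// Hypothetical bucket settings; the last two arguments (token, s3Host) are left empty.
	uploader, err := awsuploader.NewFileUploader("example-bucket", "us-east-1",
		"EXAMPLE_ACCESS_KEY_ID", "EXAMPLE_SECRET", "", "")
	if err != nil {
		logger.WithError(err).Fatal("cannot create uploader")
	}

	// Sweep /var/log/example every 30 minutes; each swept file is uploaded and,
	// on success, removed by the manager's workers.
	manager := awsuploader.NewDirectoryUploadManager(logger, uploader, "/var/log/example", 30*time.Minute, shutdownC)
	manager.Start()

	// Block until something else closes shutdownC.
	<-shutdownC
}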


@ -114,7 +114,7 @@ func createWebsocketStream(options *StartOptions) (*cloudflaredWebsocket.Conn, e
wsConn, resp, err := cloudflaredWebsocket.ClientConnect(req, nil)
defer closeRespBody(resp)
if err != nil && isAccessResponse(resp) {
if err != nil && IsAccessResponse(resp) {
wsConn, err = createAccessAuthenticatedStream(options)
if err != nil {
return nil, err
@ -126,10 +126,10 @@ func createWebsocketStream(options *StartOptions) (*cloudflaredWebsocket.Conn, e
return &cloudflaredWebsocket.Conn{Conn: wsConn}, nil
}
// isAccessResponse checks the http Response to see if the url location
// IsAccessResponse checks the http Response to see if the url location
// contains the Access structure.
func isAccessResponse(resp *http.Response) bool {
if resp == nil || resp.StatusCode <= 300 {
func IsAccessResponse(resp *http.Response) bool {
if resp == nil || resp.StatusCode != http.StatusFound {
return false
}
@ -156,7 +156,7 @@ func createAccessAuthenticatedStream(options *StartOptions) (*websocket.Conn, er
return wsConn, nil
}
if !isAccessResponse(resp) {
if !IsAccessResponse(resp) {
return nil, err
}
@ -179,7 +179,7 @@ func createAccessAuthenticatedStream(options *StartOptions) (*websocket.Conn, er
// createAccessWebSocketStream builds an Access request and makes a connection
func createAccessWebSocketStream(options *StartOptions) (*websocket.Conn, *http.Response, error) {
req, err := buildAccessRequest(options)
req, err := BuildAccessRequest(options)
if err != nil {
return nil, nil, err
}
@ -187,7 +187,7 @@ func createAccessWebSocketStream(options *StartOptions) (*websocket.Conn, *http.
}
// buildAccessRequest builds an HTTP request with the Access token set
func buildAccessRequest(options *StartOptions) (*http.Request, error) {
func BuildAccessRequest(options *StartOptions) (*http.Request, error) {
req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
if err != nil {
return nil, err


@ -102,14 +102,14 @@ func TestIsAccessResponse(t *testing.T) {
ExpectedOut bool
}{
{"nil response", nil, false},
{"redirect with no location", &http.Response{StatusCode: http.StatusPermanentRedirect}, false},
{"redirect with no location", &http.Response{StatusCode: http.StatusFound}, false},
{"200 ok", &http.Response{StatusCode: http.StatusOK}, false},
{"redirect with location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: validLocationHeader}, true},
{"redirect with invalid location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: invalidLocationHeader}, false},
{"redirect with location", &http.Response{StatusCode: http.StatusFound, Header: validLocationHeader}, true},
{"redirect with invalid location", &http.Response{StatusCode: http.StatusFound, Header: invalidLocationHeader}, false},
}
for i, tc := range testCases {
if isAccessResponse(tc.In) != tc.ExpectedOut {
if IsAccessResponse(tc.In) != tc.ExpectedOut {
t.Fatalf("Failed case %d -- %s", i, tc.Description)
}
}


@ -1,5 +1,5 @@
pinned_go: &pinned_go go=1.12.7-1
build_dir: &build_dir /cfsetup_build/src/github.com/cloudflare/cloudflared/
build_dir: &build_dir /cfsetup_build
default-flavor: stretch
stretch: &stretch
build:
@ -8,7 +8,6 @@ stretch: &stretch
- *pinned_go
- build-essential
post-cache:
- export GOPATH=/cfsetup_build/
- export GOOS=linux
- export GOARCH=amd64
- make cloudflared
@ -20,7 +19,6 @@ stretch: &stretch
- fakeroot
- rubygem-fpm
post-cache:
- export GOPATH=/cfsetup_build/
- export GOOS=linux
- export GOARCH=amd64
- make cloudflared-deb
@ -30,7 +28,6 @@ stretch: &stretch
- *pinned_go
- build-essential
post-cache:
- export GOPATH=/cfsetup_build/
- export GOOS=linux
- export GOARCH=amd64
- make release
@ -41,7 +38,6 @@ stretch: &stretch
- crossbuild-essential-armhf
- gcc-arm-linux-gnueabihf
post-cache:
- export GOPATH=/cfsetup_build/
- export GOOS=linux
- export GOARCH=arm
- export CC=arm-linux-gnueabihf-gcc
@ -52,7 +48,6 @@ stretch: &stretch
- *pinned_go
- gcc-multilib
post-cache:
- export GOPATH=/cfsetup_build/
- export GOOS=linux
- export GOARCH=386
- make release
@ -62,7 +57,6 @@ stretch: &stretch
- *pinned_go
- gcc-mingw-w64
post-cache:
- export GOPATH=/cfsetup_build/
- export GOOS=windows
- export GOARCH=amd64
- export CC=x86_64-w64-mingw32-gcc
@ -73,7 +67,6 @@ stretch: &stretch
- *pinned_go
- gcc-mingw-w64
post-cache:
- export GOPATH=/cfsetup_build/
- export GOOS=windows
- export GOARCH=386
- export CC=i686-w64-mingw32-gcc-win32
@ -84,12 +77,22 @@ stretch: &stretch
- *pinned_go
- build-essential
post-cache:
- export GOPATH=/cfsetup_build/
- export GOOS=linux
- export GOARCH=amd64
- sudo chown -R $(whoami) /cfsetup_build/
- go get github.com/BurntSushi/go-sumtype
- export PATH="$GOPATH/bin:$PATH"
# cd to a non-module directory: https://github.com/golang/go/issues/24250
- (cd / && go get github.com/BurntSushi/go-sumtype)
- export PATH="$HOME/go/bin:$PATH"
- make test
jessie: *stretch
# cfsetup compose
default-stack: test_dbconnect
test_dbconnect:
compose:
up-args:
- --renew-anon-volumes
- --abort-on-container-exit
- --exit-code-from=cloudflared
files:
- dbconnect_tests/dbconnect.yaml


@ -23,7 +23,7 @@ func ssh(c *cli.Context) error {
if err != nil || rawHostName == "" {
return cli.ShowCommandHelp(c, "ssh")
}
originURL := "https://" + hostname
originURL := ensureURLScheme(hostname)
// get the headers from the cmdline and add them
headers := buildRequestHeaders(c.StringSlice(sshHeaderFlag))
@ -34,6 +34,11 @@ func ssh(c *cli.Context) error {
headers.Add("CF-Access-Client-Secret", c.String(sshTokenSecretFlag))
}
destination := c.String(sshDestinationFlag)
if destination != "" {
headers.Add("CF-Access-SSH-Destination", destination)
}
options := &carrier.StartOptions{
OriginURL: originURL,
Headers: headers,


@ -1,17 +1,20 @@
package access
import (
"errors"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"text/template"
"time"
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/cmd/cloudflared/shell"
"github.com/cloudflare/cloudflared/cmd/cloudflared/token"
"github.com/cloudflare/cloudflared/sshgen"
"github.com/cloudflare/cloudflared/validation"
"github.com/pkg/errors"
"golang.org/x/net/idna"
"github.com/cloudflare/cloudflared/log"
@ -21,6 +24,7 @@ import (
const (
sshHostnameFlag = "hostname"
sshDestinationFlag = "destination"
sshURLFlag = "url"
sshHeaderFlag = "header"
sshTokenIDFlag = "service-token-id"
@ -124,6 +128,10 @@ func Commands() []*cli.Command {
Name: sshHostnameFlag,
Usage: "specify the hostname of your application.",
},
&cli.StringFlag{
Name: sshDestinationFlag,
Usage: "specify the destination address of your SSH server.",
},
&cli.StringFlag{
Name: sshURLFlag,
Usage: "specify the host:port to forward data to Cloudflare edge.",
@ -183,14 +191,20 @@ func login(c *cli.Context) error {
raven.SetDSN(sentryDSN)
logger := log.CreateLogger()
args := c.Args()
appURL, err := url.Parse(args.First())
rawURL := ensureURLScheme(args.First())
appURL, err := url.Parse(rawURL)
if args.Len() < 1 || err != nil {
logger.Errorf("Please provide the url of the Access application\n")
return err
}
token, err := token.FetchToken(appURL)
if err != nil {
logger.Errorf("Failed to fetch token: %s\n", err)
if err := verifyTokenAtEdge(appURL, c); err != nil {
logger.WithError(err).Error("Could not verify token")
return err
}
token, err := token.GetTokenIfExists(appURL)
if err != nil || token == "" {
fmt.Fprintln(os.Stderr, "Unable to find token for provided application.")
return err
}
fmt.Fprintf(os.Stdout, "Successfully fetched your token:\n\n%s\n\n", string(token))
@ -198,6 +212,16 @@ func login(c *cli.Context) error {
return nil
}
// ensureURLScheme prepends https:// to a URL that has no scheme and rewrites http:// URLs to https://.
func ensureURLScheme(url string) string {
url = strings.Replace(strings.ToLower(url), "http://", "https://", 1)
if !strings.HasPrefix(url, "https://") {
url = fmt.Sprintf("https://%s", url)
}
return url
}
// curl provides a wrapper around curl, passing Access JWT along in request
func curl(c *cli.Context) error {
raven.SetDSN(sentryDSN)
@ -281,7 +305,7 @@ func sshGen(c *cli.Context) error {
return cli.ShowCommandHelp(c, "ssh-gen")
}
originURL, err := url.Parse("https://" + hostname)
originURL, err := url.Parse(ensureURLScheme(hostname))
if err != nil {
return err
}
@ -372,3 +396,59 @@ func isFileThere(candidate string) bool {
}
return true
}
// verifyTokenAtEdge checks for a token on disk, or generates a new one.
// Then makes a request to the origin with the token to ensure it is valid.
// Returns nil if token is valid.
func verifyTokenAtEdge(appUrl *url.URL, c *cli.Context) error {
headers := buildRequestHeaders(c.StringSlice(sshHeaderFlag))
if c.IsSet(sshTokenIDFlag) {
headers.Add("CF-Access-Client-Id", c.String(sshTokenIDFlag))
}
if c.IsSet(sshTokenSecretFlag) {
headers.Add("CF-Access-Client-Secret", c.String(sshTokenSecretFlag))
}
options := &carrier.StartOptions{OriginURL: appUrl.String(), Headers: headers}
if valid, err := isTokenValid(options); err != nil {
return err
} else if valid {
return nil
}
if err := token.RemoveTokenIfExists(appUrl); err != nil {
return err
}
if valid, err := isTokenValid(options); err != nil {
return err
} else if !valid {
return errors.New("failed to verify token")
}
return nil
}
// isTokenValid makes a request to the origin and returns true if the response was not a 302.
func isTokenValid(options *carrier.StartOptions) (bool, error) {
req, err := carrier.BuildAccessRequest(options)
if err != nil {
return false, errors.Wrap(err, "Could not create access request")
}
// Do not follow redirects
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
Timeout: time.Second * 5,
}
resp, err := client.Do(req)
if err != nil {
return false, err
}
defer resp.Body.Close()
// A redirect to login means the token was invalid.
return !carrier.IsAccessResponse(resp), nil
}


@ -0,0 +1,25 @@
package access
import "testing"
func Test_ensureURLScheme(t *testing.T) {
type args struct {
url string
}
tests := []struct {
name string
args args
want string
}{
{"no scheme", args{"localhost:123"}, "https://localhost:123"},
{"http scheme", args{"http://test"}, "https://test"},
{"https scheme", args{"https://test"}, "https://test"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ensureURLScheme(tt.args.url); got != tt.want {
t.Errorf("ensureURLScheme() = %v, want %v", got, tt.want)
}
})
}
}


@ -0,0 +1,11 @@
//+build darwin
package shell
import (
"os/exec"
)
func getBrowserCmd(url string) *exec.Cmd {
return exec.Command("open", url)
}


@ -0,0 +1,11 @@
//+build !windows,!darwin,!linux,!netbsd,!freebsd,!openbsd
package shell
import (
"os/exec"
)
func getBrowserCmd(url string) *exec.Cmd {
return nil
}


@ -0,0 +1,11 @@
//+build linux freebsd openbsd netbsd
package shell
import (
"os/exec"
)
func getBrowserCmd(url string) *exec.Cmd {
return exec.Command("xdg-open", url)
}


@ -0,0 +1,18 @@
//+build windows
package shell
import (
"fmt"
"os/exec"
"syscall"
)
func getBrowserCmd(url string) *exec.Cmd {
cmd := exec.Command("cmd")
// CmdLine is only defined when compiling for windows.
// Empty string is the cmd proc "Title". Needs to be included because the start command will interpret the first
// quoted string as that field and we want to quote the URL.
cmd.SysProcAttr = &syscall.SysProcAttr{CmdLine: fmt.Sprintf(`/c start "" "%s"`, url)}
return cmd
}


@ -4,25 +4,11 @@ import (
"io"
"os"
"os/exec"
"runtime"
)
// OpenBrowser opens the specified URL in the default browser of the user
func OpenBrowser(url string) error {
var cmd string
var args []string
switch runtime.GOOS {
case "windows":
cmd = "cmd"
args = []string{"/c", "start"}
case "darwin":
cmd = "open"
default: // "linux", "freebsd", "openbsd", "netbsd"
cmd = "xdg-open"
}
args = append(args, url)
return exec.Command(cmd, args...).Start()
return getBrowserCmd(url).Start()
}
// Run will kick off a shell task and pipe the results to the respective std pipes


@ -26,15 +26,15 @@ var logger = log.CreateLogger()
type lock struct {
lockFilePath string
backoff *origin.BackoffHandler
sigHandler *signalHandler
sigHandler *signalHandler
}
type signalHandler struct {
sigChannel chan os.Signal
signals []os.Signal
sigChannel chan os.Signal
signals []os.Signal
}
func (s *signalHandler) register(handler func()){
func (s *signalHandler) register(handler func()) {
s.sigChannel = make(chan os.Signal, 1)
signal.Notify(s.sigChannel, s.signals...)
go func(s *signalHandler) {
@ -59,8 +59,8 @@ func newLock(path string) *lock {
return &lock{
lockFilePath: lockPath,
backoff: &origin.BackoffHandler{MaxRetries: 7},
sigHandler: &signalHandler{
signals: []os.Signal{syscall.SIGINT, syscall.SIGTERM},
sigHandler: &signalHandler{
signals: []os.Signal{syscall.SIGINT, syscall.SIGTERM},
},
}
}
@ -68,8 +68,8 @@ func newLock(path string) *lock {
func (l *lock) Acquire() error {
// Intercept SIGINT and SIGTERM to release lock before exiting
l.sigHandler.register(func() {
l.deleteLockFile()
os.Exit(0)
l.deleteLockFile()
os.Exit(0)
})
// Check for a path.lock file


@ -18,7 +18,7 @@ import (
)
const (
baseStoreURL = "https://login.cloudflarewarp.com/"
baseStoreURL = "https://login.argotunnel.com/"
clientTimeout = time.Second * 60
)
@ -45,7 +45,7 @@ func Run(transferURL *url.URL, resourceName, key, value, path string, shouldEncr
if err != nil {
fmt.Fprintf(os.Stderr, "Please open the following URL and log in with your Cloudflare account:\n\n%s\n\nLeave cloudflared running to download the %s automatically.\n", requestURL, resourceName)
} else {
fmt.Fprintf(os.Stderr, "A browser window should have opened at the following URL:\n\n%s\n\nIf the browser failed to open, open it yourself and visit the URL above.\n", requestURL)
fmt.Fprintf(os.Stderr, "A browser window should have opened at the following URL:\n\n%s\n\nIf the browser failed to open, please visit the URL above directly in your browser.\n", requestURL)
}
var resourceData []byte


@ -7,39 +7,77 @@ import (
"net"
"net/url"
"os"
"reflect"
"runtime"
"runtime/trace"
"sync"
"syscall"
"time"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/supervisor"
"github.com/google/uuid"
"github.com/getsentry/raven-go"
"golang.org/x/crypto/ssh/terminal"
"github.com/cloudflare/cloudflared/awsuploader"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/cmd/sqlgateway"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/dbconnect"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/sshlog"
"github.com/cloudflare/cloudflared/sshserver"
"github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
"github.com/coreos/go-systemd/daemon"
"github.com/facebookgo/grace/gracenet"
"github.com/getsentry/raven-go"
"github.com/gliderlabs/ssh"
"github.com/google/uuid"
"github.com/pkg/errors"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
)
const sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
const (
sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
sshLogFileDirectory = "/usr/local/var/log/cloudflared/"
// sshPortFlag is the port on localhost the cloudflared ssh server will run on
sshPortFlag = "local-ssh-port"
// sshIdleTimeoutFlag defines the duration an SSH session can remain idle before being closed
sshIdleTimeoutFlag = "ssh-idle-timeout"
// sshMaxTimeoutFlag defines the max duration an SSH session can remain open
sshMaxTimeoutFlag = "ssh-max-timeout"
// bucketNameFlag is the bucket name to use for the SSH log uploader
bucketNameFlag = "bucket-name"
// regionNameFlag is the AWS region name to use for the SSH log uploader
regionNameFlag = "region-name"
// secretIDFlag is the Secret id of SSH log uploader
secretIDFlag = "secret-id"
// accessKeyIDFlag is the Access key id of SSH log uploader
accessKeyIDFlag = "access-key-id"
// sessionTokenIDFlag is the Session token of SSH log uploader
sessionTokenIDFlag = "session-token"
// s3URLFlag is the S3 URL for the SSH log uploader (e.g. to use a Google Cloud Storage bucket instead of AWS S3)
s3URLFlag = "s3-url-host"
// hostKeyPath is the path of the dir to save SSH host keys to
hostKeyPath = "host-key-path"
noIntentMsg = "The --intent argument is required. Cloudflared looks up an Intent to determine what configuration to use (i.e. which tunnels to start). If you don't have any Intents yet, you can use a placeholder Intent Label for now. Then, when you make an Intent with that label, cloudflared will get notified and open the tunnels you specified in that Intent."
)
var (
shutdownC chan struct{}
@ -99,43 +137,7 @@ func Commands() []*cli.Command {
ArgsUsage: " ", // can't be the empty string or we get the default output
Hidden: false,
},
{
Name: "db",
Action: func(c *cli.Context) error {
tags := make(map[string]string)
tags["hostname"] = c.String("hostname")
raven.SetTagsContext(tags)
fmt.Printf("\nSQL Database Password: ")
pass, err := terminal.ReadPassword(int(syscall.Stdin))
if err != nil {
logger.Error(err)
}
go sqlgateway.StartProxy(c, logger, string(pass))
raven.CapturePanic(func() { err = tunnel(c) }, nil)
if err != nil {
raven.CaptureError(err, nil)
}
return err
},
Before: Before,
Usage: "SQL Gateway is an SQL over HTTP reverse proxy",
Flags: []cli.Flag{
&cli.BoolFlag{
Name: "db",
Value: true,
Usage: "Enable the SQL Gateway Proxy",
},
&cli.StringFlag{
Name: "address",
Value: "",
Usage: "Database connection string: db://user:pass",
},
},
Hidden: true,
},
dbConnectCmd(),
}
var subcommands []*cli.Command
@ -327,6 +329,57 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
c.Set("url", "https://"+helloListener.Addr().String())
}
if c.IsSet("ssh-server") {
if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
msg := fmt.Sprintf("--ssh-server is not supported on %s", runtime.GOOS)
logger.Error(msg)
return errors.New(msg)
}
logger.Infof("ssh-server set")
logManager := sshlog.NewEmptyManager()
if c.IsSet(bucketNameFlag) && c.IsSet(regionNameFlag) && c.IsSet(accessKeyIDFlag) && c.IsSet(secretIDFlag) {
uploader, err := awsuploader.NewFileUploader(c.String(bucketNameFlag), c.String(regionNameFlag),
c.String(accessKeyIDFlag), c.String(secretIDFlag), c.String(sessionTokenIDFlag), c.String(s3URLFlag))
if err != nil {
msg := "Cannot create uploader for SSH Server"
logger.WithError(err).Error(msg)
return errors.Wrap(err, msg)
}
if err := os.MkdirAll(sshLogFileDirectory, 0700); err != nil {
msg := fmt.Sprintf("Cannot create SSH log file directory %s", sshLogFileDirectory)
logger.WithError(err).Errorf(msg)
return errors.Wrap(err, msg)
}
logManager = sshlog.New(sshLogFileDirectory)
uploadManager := awsuploader.NewDirectoryUploadManager(logger, uploader, sshLogFileDirectory, 30*time.Minute, shutdownC)
uploadManager.Start()
}
localServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
server, err := sshserver.New(logManager, logger, version, localServerAddress, c.String("hostname"), c.Path(hostKeyPath), shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
if err != nil {
msg := "Cannot create new SSH Server"
logger.WithError(err).Error(msg)
return errors.Wrap(err, msg)
}
wg.Add(1)
go func() {
defer wg.Done()
if err = server.Start(); err != nil && err != ssh.ErrServerClosed {
logger.WithError(err).Error("SSH server error")
// TODO: remove when declarative tunnels are implemented.
close(shutdownC)
}
}()
c.Set("url", "ssh://"+localServerAddress)
}
if host := hostnameFromURI(c.String("url")); host != "" {
listener, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
@ -432,21 +485,16 @@ func startDeclarativeTunnel(ctx context.Context,
return err
}
var scope pogs.Scope
if c.IsSet("group") == c.IsSet("system-name") {
err = fmt.Errorf("exactly one of --group or --system-name must be specified")
logger.WithError(err).Error("unable to determine scope")
return err
} else if c.IsSet("group") {
scope = pogs.NewGroup(c.String("group"))
} else {
scope = pogs.NewSystemName(c.String("system-name"))
intentLabel := c.String("intent")
if intentLabel == "" {
logger.Error("--intent was empty")
return fmt.Errorf(noIntentMsg)
}
cloudflaredConfig := &connection.CloudflaredConfig{
BuildInfo: buildInfo,
CloudflaredID: cloudflaredID,
Scope: scope,
IntentLabel: intentLabel,
Tags: tags,
}
@ -559,6 +607,60 @@ func addPortIfMissing(uri *url.URL, port int) string {
return fmt.Sprintf("%s:%d", uri.Hostname(), port)
}
func dbConnectCmd() *cli.Command {
cmd := dbconnect.Cmd()
// Append the tunnel flags so users can customize the daemon settings.
cmd.Flags = appendFlags(Flags(), cmd.Flags...)
// Override before to run tunnel validation before dbconnect validation.
cmd.Before = func(c *cli.Context) error {
err := Before(c)
if err == nil {
err = dbconnect.CmdBefore(c)
}
return err
}
// Override action to setup the Proxy, then if successful, start the tunnel daemon.
cmd.Action = func(c *cli.Context) error {
err := dbconnect.CmdAction(c)
if err == nil {
err = tunnel(c)
}
return err
}
return cmd
}
// appendFlags will append extra flags to a slice of flags.
//
// The cli package will panic if two flags exist with the same name,
// so if extraFlags contains a flag that was already defined, modify the
// original flags to use the extra version.
func appendFlags(flags []cli.Flag, extraFlags ...cli.Flag) []cli.Flag {
for _, extra := range extraFlags {
var found bool
// Check if an extra flag overrides an existing flag.
for i, flag := range flags {
if reflect.DeepEqual(extra.Names(), flag.Names()) {
flags[i] = extra
found = true
break
}
}
// Append the extra flag if it has nothing to override.
if !found {
flags = append(flags, extra)
}
}
return flags
}
func tunnelFlags(shouldHide bool) []cli.Flag {
return []cli.Flag{
&cli.StringFlag{
@ -574,10 +676,11 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "no-autoupdate",
Usage: "Disable periodic check for updates, restarting the server with the new version.",
Value: false,
Hidden: shouldHide,
Name: "no-autoupdate",
Usage: "Disable periodic check for updates, restarting the server with the new version.",
EnvVars: []string{"NO_AUTOUPDATE"},
Value: false,
Hidden: shouldHide,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "is-autoupdated",
@ -635,6 +738,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
EnvVars: []string{"TUNNEL_HOSTNAME"},
Hidden: shouldHide,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "http-host-header",
Usage: "Sets the HTTP Host header for the local webserver.",
EnvVars: []string{"TUNNEL_HTTP_HOST_HEADER"},
Hidden: shouldHide,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "origin-server-name",
Usage: "Hostname on the origin server certificate.",
@ -732,6 +841,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
Hidden: shouldHide,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "ssh-server",
Value: false,
Usage: "Run an SSH Server",
EnvVars: []string{"TUNNEL_SSH_SERVER"},
Hidden: true, // TODO: remove when feature is complete
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "pidfile",
Usage: "Write the application's PID to this file after first successful connection.",
@ -856,15 +972,15 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "system-name",
Usage: "Unique identifier for this cloudflared instance. It can be configured individually in the Declarative Tunnel UI. Mutually exclusive with `--group`.",
EnvVars: []string{"TUNNEL_SYSTEM_NAME"},
Name: "intent",
Usage: "The label of an Intent from which `cloudflared` should get its tunnels. Intents can be created in the Origin Registry UI.",
EnvVars: []string{"TUNNEL_INTENT"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "group",
Usage: "Name of a group of cloudflared instances, of which this instance should be an identical copy. They can be configured collectively in the Declarative Tunnel UI. Mutually exclusive with `--system-name`.",
EnvVars: []string{"TUNNEL_GROUP"},
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "use-reconnect-token",
Usage: "Test reestablishing connections with the new 'reconnect token' flow.",
EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"},
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
@ -874,5 +990,66 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
EnvVars: []string{"DIAL_EDGE_TIMEOUT"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: sshPortFlag,
Usage: "Localhost port that cloudflared SSH server will run on",
Value: "2222",
EnvVars: []string{"LOCAL_SSH_PORT"},
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: sshIdleTimeoutFlag,
Usage: "Connection timeout after no activity",
EnvVars: []string{"SSH_IDLE_TIMEOUT"},
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: sshMaxTimeoutFlag,
Usage: "Absolute connection timeout",
EnvVars: []string{"SSH_MAX_TIMEOUT"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: bucketNameFlag,
Usage: "Bucket name of where to upload SSH logs",
EnvVars: []string{"BUCKET_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: regionNameFlag,
Usage: "Region name of where to upload SSH logs",
EnvVars: []string{"REGION_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: accessKeyIDFlag,
Usage: "Access Key ID of where to upload SSH logs",
EnvVars: []string{"ACCESS_CLIENT_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: secretIDFlag,
Usage: "Secret ID of where to upload SSH logs",
EnvVars: []string{"SECRET_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: sessionTokenIDFlag,
Usage: "Session Token to use in the configuration of SSH logs uploading",
EnvVars: []string{"SESSION_TOKEN_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: s3URLFlag,
Usage: "S3 url of where to upload SSH logs",
EnvVars: []string{"S3_URL"},
Hidden: true,
}),
altsrc.NewPathFlag(&cli.PathFlag{
Name: hostKeyPath,
Usage: "Absolute path of directory to save SSH host keys in",
EnvVars: []string{"HOST_KEY_PATH"},
Hidden: true,
}),
}
}

View File

@ -203,11 +203,14 @@ func prepareTunnelConfig(
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: c.IsSet("no-tls-verify")},
}
dialContext := (&net.Dialer{
dialer := &net.Dialer{
Timeout: c.Duration("proxy-connect-timeout"),
KeepAlive: c.Duration("proxy-tcp-keepalive"),
DualStack: !c.Bool("proxy-no-happy-eyeballs"),
}).DialContext
}
if c.Bool("proxy-no-happy-eyeballs") {
dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
}
dialContext := dialer.DialContext
if c.IsSet("unix-socket") {
unixSocket, err := config.ValidateUnixSocket(c)
@ -253,6 +256,7 @@ func prepareTunnelConfig(
HTTPTransport: httpTransport,
HeartbeatInterval: c.Duration("heartbeat-interval"),
Hostname: hostname,
HTTPHostHeader: c.String("http-host-header"),
IncidentLookup: origin.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"),
IsFreeTunnel: isFreeTunnel,
@ -271,6 +275,7 @@ func prepareTunnelConfig(
TlsConfig: toEdgeTLSConfig,
TransportLogger: transportLogger,
UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"),
UseReconnectToken: c.Bool("use-reconnect-token"),
}, nil
}

View File

@ -15,7 +15,7 @@ import (
const (
baseLoginURL = "https://dash.cloudflare.com/argotunnel"
callbackStoreURL = "https://login.cloudflarewarp.com/"
callbackStoreURL = "https://login.argotunnel.com/"
)
func login(c *cli.Context) error {

View File

@ -1,148 +0,0 @@
package sqlgateway
import (
"database/sql"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"strings"
"time"
_ "github.com/lib/pq"
cli "gopkg.in/urfave/cli.v2"
"github.com/elgs/gosqljson"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
)
type Message struct {
Connection Connection `json:"connection"`
Command string `json:"command"`
Params []interface{} `json:"params"`
}
type Connection struct {
SSLMode string `json:"sslmode"`
Token string `json:"token"`
}
type Response struct {
Columns []string `json:"columns"`
Rows [][]string `json:"rows"`
Error string `json:"error"`
}
type Proxy struct {
Context *cli.Context
Router *mux.Router
Token string
User string
Password string
Driver string
Database string
Logger *logrus.Logger
}
func StartProxy(c *cli.Context, logger *logrus.Logger, password string) error {
proxy := NewProxy(c, logger, password)
logger.Infof("Starting SQL Gateway Proxy on port %s", strings.Split(c.String("url"), ":")[1])
err := http.ListenAndServe(":"+strings.Split(c.String("url"), ":")[1], proxy.Router)
if err != nil {
return err
}
return nil
}
func randID(n int, c *cli.Context) string {
charBytes := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")
b := make([]byte, n)
for i := range b {
b[i] = charBytes[rand.Intn(len(charBytes))]
}
return fmt.Sprintf("%s&%s", c.String("hostname"), b)
}
// db://user@dbname
func parseInfo(input string) (string, string, string) {
p1 := strings.Split(input, "://")
p2 := strings.Split(p1[1], "@")
return p1[0], p2[0], p2[1]
}
func NewProxy(c *cli.Context, logger *logrus.Logger, pass string) *Proxy {
rand.Seed(time.Now().UnixNano())
driver, user, dbname := parseInfo(c.String("address"))
proxy := Proxy{
Context: c,
Router: mux.NewRouter(),
Token: randID(64, c),
Logger: logger,
User: user,
Password: pass,
Database: dbname,
Driver: driver,
}
logger.Info(fmt.Sprintf(`
--------------------
SQL Gateway Proxy
Token: %s
--------------------
`, proxy.Token))
proxy.Router.HandleFunc("/", proxy.proxyRequest).Methods("POST")
return &proxy
}
func (proxy *Proxy) proxyRequest(rw http.ResponseWriter, req *http.Request) {
var message Message
response := Response{}
err := json.NewDecoder(req.Body).Decode(&message)
if err != nil {
proxy.Logger.Error(err)
http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest)
return
}
if message.Connection.Token != proxy.Token {
proxy.Logger.Error("Invalid token")
http.Error(rw, "400 - Invalid token", http.StatusBadRequest)
return
}
connStr := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=%s", proxy.User, proxy.Password, proxy.Database, message.Connection.SSLMode)
db, err := sql.Open(proxy.Driver, connStr)
defer db.Close()
if err != nil {
proxy.Logger.Error(err)
http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest)
return
} else {
proxy.Logger.Info("Forwarding SQL: ", message.Command)
rw.Header().Set("Content-Type", "application/json")
headers, data, err := gosqljson.QueryDbToArray(db, "lower", message.Command, message.Params...)
if err != nil {
proxy.Logger.Error(err)
http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest)
return
} else {
response = Response{headers, data, ""}
}
}
json.NewEncoder(rw).Encode(response)
}

View File

@ -2,38 +2,26 @@ package connection
import (
"context"
"net"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
rpc "zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/h2mux"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
openStreamTimeout = 30 * time.Second
)
type dialError struct {
cause error
}
func (e dialError) Error() string {
return e.cause.Error()
}
type Connection struct {
id uuid.UUID
muxer *h2mux.Muxer
}
func newConnection(muxer *h2mux.Muxer, edgeIP *net.TCPAddr) (*Connection, error) {
func newConnection(muxer *h2mux.Muxer) (*Connection, error) {
id, err := uuid.NewRandom()
if err != nil {
return nil, err
@ -50,32 +38,15 @@ func (c *Connection) Serve(ctx context.Context) error {
}
// Connect is used to establish connections with cloudflare's edge network
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (*pogs.ConnectResult, error) {
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
defer cancel()
rpcConn, err := c.newRPConn(openStreamCtx, logger)
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (tunnelpogs.ConnectResult, error) {
tsClient, err := NewRPCClient(ctx, c.muxer, logger.WithField("rpc", "connect"), openStreamTimeout)
if err != nil {
return nil, errors.Wrap(err, "cannot create new RPC connection")
}
defer rpcConn.Close()
tsClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx)}
defer tsClient.Close()
return tsClient.Connect(ctx, parameters)
}
func (c *Connection) Shutdown() {
c.muxer.Shutdown()
}
func (c *Connection) newRPConn(ctx context.Context, logger *logrus.Entry) (*rpc.Conn, error) {
stream, err := c.muxer.OpenRPCStream(ctx)
if err != nil {
return nil, err
}
return rpc.NewConn(
tunnelrpc.NewTransportLogger(logger.WithField("rpc", "connect"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("rpc", "connect")),
), nil
}

54
connection/dial.go Normal file
View File

@ -0,0 +1,54 @@
package connection
import (
"context"
"crypto/tls"
"net"
"time"
"github.com/pkg/errors"
)
// DialEdge makes a TLS connection to a Cloudflare edge node
func DialEdge(
ctx context.Context,
timeout time.Duration,
tlsConfig *tls.Config,
edgeTCPAddr *net.TCPAddr,
) (net.Conn, error) {
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
dialCtx, dialCancel := context.WithTimeout(ctx, timeout)
defer dialCancel()
dialer := net.Dialer{}
edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeTCPAddr.String())
if err != nil {
return nil, newDialError(err, "DialContext error")
}
tlsEdgeConn := tls.Client(edgeConn, tlsConfig)
tlsEdgeConn.SetDeadline(time.Now().Add(timeout))
if err = tlsEdgeConn.Handshake(); err != nil {
return nil, newDialError(err, "Handshake with edge error")
}
// clear the deadline on the conn; h2mux has its own timeouts
tlsEdgeConn.SetDeadline(time.Time{})
return tlsEdgeConn, nil
}
// DialError is an error returned from DialEdge
type DialError struct {
cause error
}
func newDialError(err error, message string) error {
return DialError{cause: errors.Wrap(err, message)}
}
func (e DialError) Error() string {
return e.cause.Error()
}
func (e DialError) Cause() error {
return e.cause
}
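// A minimal caller sketch for DialEdge, assuming an already-resolved edge
// address and TLS config are available; the function name and the 10s timeout
// are illustrative.
func exampleDialEdge(ctx context.Context, tlsConfig *tls.Config, edgeTCPAddr *net.TCPAddr) error {
	conn, err := DialEdge(ctx, 10*time.Second, tlsConfig, edgeTCPAddr)
	if err != nil {
		// DialEdge wraps failures in DialError; Cause exposes the underlying error.
		if dialErr, ok := err.(DialError); ok {
			return dialErr.Cause()
		}
		return err
	}
	defer conn.Close()
	// At this point the connection is ready for an h2mux handshake.
	return nil
}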

View File

@ -13,13 +13,13 @@ import (
)
const (
// Used to discover HA Warp servers
srvService = "warp"
// Used to discover HA origintunneld servers
srvService = "origintunneld"
srvProto = "tcp"
srvName = "cloudflarewarp.com"
srvName = "argotunnel.com"
// Used to fallback to DoT when we can't use the default resolver to
// discover HA Warp servers (GitHub issue #75).
// discover HA origintunneld servers (GitHub issue #75).
dotServerName = "cloudflare-dns.com"
dotServerAddr = "1.1.1.1:853"
dotTimeout = time.Duration(15 * time.Second)
@ -30,8 +30,8 @@ const (
var friendlyDNSErrorLines = []string{
`Please try the following things to diagnose this issue:`,
` 1. ensure that cloudflarewarp.com is returning "warp" service records.`,
` Run your system's equivalent of: dig srv _warp._tcp.cloudflarewarp.com`,
` 1. ensure that argotunnel.com is returning "origintunneld" service records.`,
` Run your system's equivalent of: dig srv _origintunneld._tcp.argotunnel.com`,
` 2. ensure that your DNS resolver is not returning compressed SRV records.`,
` See GitHub issue https://github.com/golang/go/issues/27546`,
` For example, you could use Cloudflare's 1.1.1.1 as your resolver:`,
@ -102,7 +102,7 @@ func EdgeDiscovery(logger *logrus.Entry) ([]*net.TCPAddr, error) {
// Try to fall back to DoT from Cloudflare directly.
//
// Note: Instead of DoT, we could also have used DoH. Either of these:
// - directly via the JSON API (https://1.1.1.1/dns-query?ct=application/dns-json&name=_warp._tcp.cloudflarewarp.com&type=srv)
// - directly via the JSON API (https://1.1.1.1/dns-query?ct=application/dns-json&name=_origintunneld._tcp.argotunnel.com&type=srv)
// - indirectly via `tunneldns.NewUpstreamHTTPS()`
// But both of these cases miss out on a key feature from the stdlib:
// "The returned records are sorted by priority and randomized by weight within a priority."
@ -119,7 +119,7 @@ func EdgeDiscovery(logger *logrus.Entry) ([]*net.TCPAddr, error) {
for _, s := range friendlyDNSErrorLines {
logger.Errorln(s)
}
return nil, errors.Wrap(err, "Could not lookup srv records on _warp._tcp.cloudflarewarp.com")
return nil, errors.Wrap(err, "Could not lookup srv records on _origintunneld._tcp.argotunnel.com")
}
// Accept the fallback results and keep going
addrs = fallbackAddrs

View File

@ -4,27 +4,32 @@ import (
"context"
"crypto/tls"
"fmt"
"net"
"sync"
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/streamhandler"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
quickStartLink = "https://developers.cloudflare.com/argo-tunnel/quickstart/"
faqLink = "https://developers.cloudflare.com/argo-tunnel/faq/"
quickStartLink = "https://developers.cloudflare.com/argo-tunnel/quickstart/"
faqLink = "https://developers.cloudflare.com/argo-tunnel/faq/"
defaultRetryAfter = time.Second * 5
packageNamespace = "connection"
edgeManagerSubsystem = "edgemanager"
)
// EdgeManager manages connections with the edge
type EdgeManager struct {
// streamHandler handles stream opened by the edge
streamHandler h2mux.MuxedStreamHandler
streamHandler *streamhandler.StreamHandler
// TLSConfig is the TLS configuration to connect with edge
tlsConfig *tls.Config
// cloudflaredConfig is the cloudflared configuration that is determined when the process first starts
@ -35,23 +40,36 @@ type EdgeManager struct {
state *edgeManagerState
logger *logrus.Entry
metrics *metrics
}
// EdgeConnectionManagerConfigurable is the configurable attributes of a EdgeConnectionManager
type metrics struct {
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
activeStreams prometheus.Gauge
}
func newMetrics(namespace, subsystem string) *metrics {
return &metrics{
activeStreams: h2mux.NewActiveStreamsMetrics(namespace, subsystem),
}
}
// EdgeManagerConfigurable holds the configurable attributes of an EdgeConnectionManager
type EdgeManagerConfigurable struct {
TunnelHostnames []h2mux.TunnelHostname
*pogs.EdgeConnectionConfig
*tunnelpogs.EdgeConnectionConfig
}
type CloudflaredConfig struct {
CloudflaredID uuid.UUID
Tags []pogs.Tag
Tags []tunnelpogs.Tag
BuildInfo *buildinfo.BuildInfo
Scope pogs.Scope
IntentLabel string
}
func NewEdgeManager(
streamHandler h2mux.MuxedStreamHandler,
streamHandler *streamhandler.StreamHandler,
edgeConnMgrConfigurable *EdgeManagerConfigurable,
userCredential []byte,
tlsConfig *tls.Config,
@ -66,6 +84,7 @@ func NewEdgeManager(
serviceDiscoverer: serviceDiscoverer,
state: newEdgeConnectionManagerState(edgeConnMgrConfigurable, userCredential),
logger: logger.WithField("subsystem", "connectionManager"),
metrics: newMetrics(packageNamespace, edgeManagerSubsystem),
}
}
@ -87,8 +106,12 @@ func (em *EdgeManager) Run(ctx context.Context) error {
// Create/delete connection one at a time, so we don't need to adjust for connections that are being created/deleted
// in shouldCreateConnection or shouldReduceConnection calculation
if em.state.shouldCreateConnection(em.serviceDiscoverer.AvailableAddrs()) {
if err := em.newConnection(ctx); err != nil {
em.logger.WithError(err).Error("cannot create new connection")
if connErr := em.newConnection(ctx); connErr != nil {
if !connErr.ShouldRetry {
em.logger.WithError(connErr).Error(em.noRetryMessage())
return connErr
}
em.logger.WithError(connErr).Error("cannot create new connection")
}
} else if em.state.shouldReduceConnection() {
if err := em.closeConnection(ctx); err != nil {
@ -103,13 +126,13 @@ func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurab
em.state.updateConfigurable(newConfigurable)
}
func (em *EdgeManager) newConnection(ctx context.Context) error {
edgeIP := em.serviceDiscoverer.Addr()
edgeConn, err := em.dialEdge(ctx, edgeIP)
if err != nil {
return errors.Wrap(err, "dial edge error")
}
func (em *EdgeManager) newConnection(ctx context.Context) *tunnelpogs.ConnectError {
edgeTCPAddr := em.serviceDiscoverer.Addr()
configurable := em.state.getConfigurable()
edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr)
if err != nil {
return retryConnection(fmt.Sprintf("dial edge error: %v", err))
}
// Establish a muxed connection with the edge
// Client mux handshake with agent server
muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
@ -119,40 +142,41 @@ func (em *EdgeManager) newConnection(ctx context.Context) error {
HeartbeatInterval: configurable.HeartbeatInterval,
MaxHeartbeats: configurable.MaxFailedHeartbeats,
Logger: em.logger.WithField("subsystem", "muxer"),
})
}, em.metrics.activeStreams)
if err != nil {
return errors.Wrap(err, "couldn't perform handshake with edge")
return retryConnection(fmt.Sprintf("couldn't perform handshake with edge: %v", err))
}
h2muxConn, err := newConnection(muxer, edgeIP)
h2muxConn, err := newConnection(muxer)
if err != nil {
return errors.Wrap(err, "couldn't create h2mux connection")
return retryConnection(fmt.Sprintf("couldn't create h2mux connection: %v", err))
}
go em.serveConn(ctx, h2muxConn)
connResult, err := h2muxConn.Connect(ctx, &pogs.ConnectParameters{
connResult, err := h2muxConn.Connect(ctx, &tunnelpogs.ConnectParameters{
CloudflaredID: em.cloudflaredConfig.CloudflaredID,
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
NumPreviousAttempts: 0,
OriginCert: em.state.getUserCredential(),
Scope: em.cloudflaredConfig.Scope,
IntentLabel: em.cloudflaredConfig.IntentLabel,
Tags: em.cloudflaredConfig.Tags,
}, em.logger)
if err != nil {
h2muxConn.Shutdown()
return errors.Wrap(err, "couldn't connect to edge")
return retryConnection(fmt.Sprintf("couldn't connect to edge: %v", err))
}
if connErr := connResult.Err; connErr != nil {
if !connErr.ShouldRetry {
return errors.Wrap(connErr, em.noRetryMessage())
}
return errors.Wrapf(connErr, "edge responded with RetryAfter=%v", connErr.RetryAfter)
if connErr := connResult.ConnectError(); connErr != nil {
return connErr
}
em.state.newConnection(h2muxConn)
em.logger.Infof("connected to %s", connResult.ServerInfo.LocationName)
em.logger.Infof("connected to %s", connResult.ConnectedTo())
if connResult.ClientConfig() != nil {
em.streamHandler.UseConfiguration(ctx, connResult.ClientConfig())
}
return nil
}
@ -171,28 +195,6 @@ func (em *EdgeManager) serveConn(ctx context.Context, conn *Connection) {
em.state.closeConnection(conn)
}
func (em *EdgeManager) dialEdge(ctx context.Context, edgeIP *net.TCPAddr) (*tls.Conn, error) {
timeout := em.state.getConfigurable().Timeout
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
dialCtx, dialCancel := context.WithTimeout(ctx, timeout)
defer dialCancel()
dialer := net.Dialer{DualStack: true}
edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeIP.String())
if err != nil {
return nil, dialError{cause: errors.Wrap(err, "DialContext error")}
}
tlsEdgeConn := tls.Client(edgeConn, em.tlsConfig)
tlsEdgeConn.SetDeadline(time.Now().Add(timeout))
if err = tlsEdgeConn.Handshake(); err != nil {
return nil, dialError{cause: errors.Wrap(err, "Handshake with edge error")}
}
// clear the deadline on the conn; h2mux has its own timeouts
tlsEdgeConn.SetDeadline(time.Time{})
return tlsEdgeConn, nil
}
func (em *EdgeManager) noRetryMessage() string {
messageTemplate := "cloudflared could not register an Argo Tunnel on your account. Please confirm the following before trying again:" +
"1. You have Argo Smart Routing enabled in your account, See Enable Argo section of %s." +
@ -282,3 +284,11 @@ func (ems *edgeManagerState) getUserCredential() []byte {
defer ems.RUnlock()
return ems.userCredential
}
func retryConnection(cause string) *tunnelpogs.ConnectError {
return &tunnelpogs.ConnectError{
Cause: cause,
RetryAfter: defaultRetryAfter,
ShouldRetry: true,
}
}

View File

@ -4,13 +4,15 @@ import (
"testing"
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
var (
@ -42,16 +44,12 @@ var (
}
)
type mockStreamHandler struct {
}
func (msh *mockStreamHandler) ServeStream(*h2mux.MuxedStream) error {
return nil
}
func mockEdgeManager() *EdgeManager {
newConfigChan := make(chan<- *pogs.ClientConfig)
useConfigResultChan := make(<-chan *pogs.UseConfigurationResult)
logger := logrus.New()
return NewEdgeManager(
&mockStreamHandler{},
streamhandler.NewStreamHandler(newConfigChan, useConfigResultChan, logger),
configurable,
[]byte{},
nil,

49
connection/rpc.go Normal file
View File

@ -0,0 +1,49 @@
package connection
import (
"context"
"fmt"
"time"
"github.com/sirupsen/logrus"
rpc "zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// NewRPCClient creates and returns a new RPC client, which will communicate
// using a stream on the given muxer
func NewRPCClient(
ctx context.Context,
muxer *h2mux.Muxer,
logger *logrus.Entry,
openStreamTimeout time.Duration,
) (client tunnelpogs.TunnelServer_PogsClient, err error) {
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
defer openStreamCancel()
stream, err := muxer.OpenRPCStream(openStreamCtx)
if err != nil {
return
}
if !isRPCStreamResponse(stream.Headers) {
stream.Close()
err = fmt.Errorf("rpc: bad response headers: %v", stream.Headers)
return
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger),
)
client = tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
return client, nil
}
func isRPCStreamResponse(headers []h2mux.Header) bool {
return len(headers) == 1 &&
headers[0].Name == ":status" &&
headers[0].Value == "200"
}
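// A hedged sketch of opening an RPC client over an existing muxer and issuing
// the Connect call used elsewhere in this package; the function name is
// illustrative and the 30s value mirrors openStreamTimeout.
func exampleRPCConnect(ctx context.Context, muxer *h2mux.Muxer, logger *logrus.Entry, params *tunnelpogs.ConnectParameters) error {
	client, err := NewRPCClient(ctx, muxer, logger.WithField("rpc", "connect"), 30*time.Second)
	if err != nil {
		return err
	}
	// Closing the client also closes the underlying RPC stream.
	defer client.Close()
	_, err = client.Connect(ctx, params)
	return err
}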

145
dbconnect/client.go Normal file
View File

@ -0,0 +1,145 @@
package dbconnect
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strings"
"time"
"unicode"
"unicode/utf8"
)
// Client is an interface to talk to any database.
//
// Currently, the only implementation is SQLClient, but its structure
// should be designed to handle a MongoClient or RedisClient in the future.
type Client interface {
Ping(context.Context) error
Submit(context.Context, *Command) (interface{}, error)
}
// NewClient creates a database client based on its URL scheme.
func NewClient(ctx context.Context, originURL *url.URL) (Client, error) {
return NewSQLClient(ctx, originURL)
}
// Command is a standard, non-vendor format for submitting database commands.
//
// When determining the scope of this struct, refer to the following litmus test:
// Could this (roughly) conform to SQL, Document-based, and Key-value command formats?
type Command struct {
Statement string `json:"statement"`
Arguments Arguments `json:"arguments,omitempty"`
Mode string `json:"mode,omitempty"`
Isolation string `json:"isolation,omitempty"`
Timeout time.Duration `json:"timeout,omitempty"`
}
// Validate enforces the contract of Command: non empty statement (both in length and logic),
// lowercase mode and isolation, non-zero timeout, and valid Arguments.
func (cmd *Command) Validate() error {
if cmd.Statement == "" {
return fmt.Errorf("cannot provide an empty statement")
}
if strings.Map(func(char rune) rune {
if char == ';' || unicode.IsSpace(char) {
return -1
}
return char
}, cmd.Statement) == "" {
return fmt.Errorf("cannot provide a statement with no logic: '%s'", cmd.Statement)
}
cmd.Mode = strings.ToLower(cmd.Mode)
cmd.Isolation = strings.ToLower(cmd.Isolation)
if cmd.Timeout.Nanoseconds() <= 0 {
cmd.Timeout = 24 * time.Hour
}
return cmd.Arguments.Validate()
}
// UnmarshalJSON converts a byte representation of JSON into a Command, which is also validated.
func (cmd *Command) UnmarshalJSON(data []byte) error {
// Alias is required to avoid infinite recursion from the default UnmarshalJSON.
type Alias Command
alias := &struct {
*Alias
}{
Alias: (*Alias)(cmd),
}
err := json.Unmarshal(data, &alias)
if err == nil {
err = cmd.Validate()
}
return err
}
// Arguments is a wrapper for either map-based or array-based Command arguments.
//
// Each field is mutually-exclusive and some Client implementations may not
// support both fields (eg. MySQL does not accept named arguments).
type Arguments struct {
Named map[string]interface{}
Positional []interface{}
}
// Validate enforces the contract of Arguments: non nil, mutually exclusive, and no empty or reserved keys.
func (args *Arguments) Validate() error {
if args.Named == nil {
args.Named = map[string]interface{}{}
}
if args.Positional == nil {
args.Positional = []interface{}{}
}
if len(args.Named) > 0 && len(args.Positional) > 0 {
return fmt.Errorf("both named and positional arguments cannot be specified: %+v and %+v", args.Named, args.Positional)
}
for key := range args.Named {
if key == "" {
return fmt.Errorf("named arguments cannot contain an empty key: %+v", args.Named)
}
if !utf8.ValidString(key) {
return fmt.Errorf("named argument does not conform to UTF-8 encoding: %s", key)
}
if strings.HasPrefix(key, "_") {
return fmt.Errorf("named argument cannot start with a reserved keyword '_': %s", key)
}
if unicode.IsNumber([]rune(key)[0]) {
return fmt.Errorf("named argument cannot start with a number: %s", key)
}
}
return nil
}
// UnmarshalJSON converts a byte representation of JSON into Arguments, which is also validated.
func (args *Arguments) UnmarshalJSON(data []byte) error {
var obj interface{}
err := json.Unmarshal(data, &obj)
if err != nil {
return err
}
named, ok := obj.(map[string]interface{})
if ok {
args.Named = named
} else {
positional, ok := obj.([]interface{})
if ok {
args.Positional = positional
} else {
return fmt.Errorf("arguments must either be an object {\"0\":\"val\"} or an array [\"val\"]: %s", string(data))
}
}
return args.Validate()
}
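// A hedged sketch of decoding a Command from JSON under the rules above: the
// statement must carry logic, mode and isolation are lowercased, and a zero
// timeout falls back to 24h. The statement and arguments are illustrative.
func exampleDecodeCommand() (*Command, error) {
	payload := []byte(`{"statement": "SELECT * FROM widgets WHERE id = ?", "arguments": [42], "mode": "query"}`)
	var cmd Command
	// UnmarshalJSON calls Validate, so an invalid payload fails right here.
	if err := json.Unmarshal(payload, &cmd); err != nil {
		return nil, err
	}
	return &cmd, nil
}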

183
dbconnect/client_test.go Normal file
View File

@ -0,0 +1,183 @@
package dbconnect
import (
"encoding/json"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCommandValidateEmpty(t *testing.T) {
stmts := []string{
"",
";",
" \n\t",
";\n;\t;",
}
for _, stmt := range stmts {
cmd := Command{Statement: stmt}
assert.Error(t, cmd.Validate(), stmt)
}
}
func TestCommandValidateMode(t *testing.T) {
modes := []string{
"",
"query",
"ExEc",
"PREPARE",
}
for _, mode := range modes {
cmd := Command{Statement: "Ok", Mode: mode}
assert.NoError(t, cmd.Validate(), mode)
assert.Equal(t, strings.ToLower(mode), cmd.Mode)
}
}
func TestCommandValidateIsolation(t *testing.T) {
isos := []string{
"",
"default",
"read_committed",
"SNAPshot",
}
for _, iso := range isos {
cmd := Command{Statement: "Ok", Isolation: iso}
assert.NoError(t, cmd.Validate(), iso)
assert.Equal(t, strings.ToLower(iso), cmd.Isolation)
}
}
func TestCommandValidateTimeout(t *testing.T) {
cmd := Command{Statement: "Ok", Timeout: 0}
assert.NoError(t, cmd.Validate())
assert.NotZero(t, cmd.Timeout)
cmd = Command{Statement: "Ok", Timeout: 1 * time.Second}
assert.NoError(t, cmd.Validate())
assert.Equal(t, 1*time.Second, cmd.Timeout)
}
func TestCommandValidateArguments(t *testing.T) {
cmd := Command{Statement: "Ok", Arguments: Arguments{
Named: map[string]interface{}{"key": "val"},
Positional: []interface{}{"val"},
}}
assert.Error(t, cmd.Validate())
}
func TestCommandUnmarshalJSON(t *testing.T) {
strs := []string{
"{\"statement\":\"Ok\"}",
"{\"statement\":\"Ok\",\"arguments\":[0, 3.14, \"apple\"],\"mode\":\"query\"}",
"{\"statement\":\"Ok\",\"isolation\":\"read_uncommitted\",\"timeout\":1000}",
}
for _, str := range strs {
var cmd Command
assert.NoError(t, json.Unmarshal([]byte(str), &cmd), str)
}
strs = []string{
"",
"\"",
"{}",
"{\"argument\":{\"key\":\"val\"}}",
"{\"statement\":[\"Ok\"]}",
}
for _, str := range strs {
var cmd Command
assert.Error(t, json.Unmarshal([]byte(str), &cmd), str)
}
}
func TestArgumentsValidateNotNil(t *testing.T) {
args := Arguments{}
assert.NoError(t, args.Validate())
assert.NotNil(t, args.Named)
assert.NotNil(t, args.Positional)
}
func TestArgumentsValidateMutuallyExclusive(t *testing.T) {
args := []Arguments{
Arguments{},
Arguments{Named: map[string]interface{}{"key": "val"}},
Arguments{Positional: []interface{}{"val"}},
}
for _, arg := range args {
assert.NoError(t, arg.Validate())
assert.False(t, len(arg.Named) > 0 && len(arg.Positional) > 0)
}
args = []Arguments{
Arguments{
Named: map[string]interface{}{"key": "val"},
Positional: []interface{}{"val"},
},
}
for _, arg := range args {
assert.Error(t, arg.Validate())
assert.True(t, len(arg.Named) > 0 && len(arg.Positional) > 0)
}
}
func TestArgumentsValidateKeys(t *testing.T) {
keys := []string{
"",
"_",
"_key",
"1",
"1key",
"\xf0\x28\x8c\xbc", // non-utf8
}
for _, key := range keys {
args := Arguments{Named: map[string]interface{}{key: "val"}}
assert.Error(t, args.Validate(), key)
}
}
func TestArgumentsUnmarshalJSON(t *testing.T) {
strs := []string{
"{}",
"{\"key\":\"val\"}",
"{\"key\":[1, 3.14, {\"key\":\"val\"}]}",
"[]",
"[\"key\",\"val\"]",
"[{}]",
}
for _, str := range strs {
var args Arguments
assert.NoError(t, json.Unmarshal([]byte(str), &args), str)
}
strs = []string{
"",
"\"",
"1",
"\"key\"",
"{\"key\",\"val\"}",
}
for _, str := range strs {
var args Arguments
assert.Error(t, json.Unmarshal([]byte(str), &args), str)
}
}

157
dbconnect/cmd.go Normal file
View File

@ -0,0 +1,157 @@
package dbconnect
import (
"context"
"log"
"net"
"strconv"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
)
// Cmd is the entrypoint command for dbconnect.
//
// The tunnel package is responsible for appending this to tunnel.Commands().
func Cmd() *cli.Command {
return &cli.Command{
Category: "Database Connect (ALPHA)",
Name: "db-connect",
Usage: "Access your SQL database from Cloudflare Workers or the browser",
ArgsUsage: " ",
Description: `
Creates a connection between your database and the Cloudflare edge.
Now you can execute SQL commands anywhere you can send HTTPS requests.
Connect your database with any of the following commands; you can also try the "playground" without a database:
cloudflared db-connect --hostname sql.mysite.com --url postgres://user:pass@localhost?sslmode=disable \
--auth-domain mysite.cloudflareaccess.com --application-aud my-access-policy-tag
cloudflared db-connect --hostname sql-dev.mysite.com --url mysql://localhost --insecure
cloudflared db-connect --playground
Requests should be authenticated using Cloudflare Access; learn more about how to enable it here:
https://developers.cloudflare.com/access/service-auth/service-token/
`,
Flags: []cli.Flag{
altsrc.NewStringFlag(&cli.StringFlag{
Name: "url",
Usage: "URL to the database (eg. postgres://user:pass@localhost?sslmode=disable)",
EnvVars: []string{"TUNNEL_URL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "hostname",
Usage: "Hostname to accept commands over HTTPS (eg. sql.mysite.com)",
EnvVars: []string{"TUNNEL_HOSTNAME"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "auth-domain",
Usage: "Cloudflare Access authentication domain for your account (eg. mysite.cloudflareaccess.com)",
EnvVars: []string{"TUNNEL_ACCESS_AUTH_DOMAIN"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "application-aud",
Usage: "Cloudflare Access application \"AUD\" to verify JWTs from requests",
EnvVars: []string{"TUNNEL_ACCESS_APPLICATION_AUD"},
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "insecure",
Usage: "Disable authentication, the database will be open to the Internet",
Value: false,
EnvVars: []string{"TUNNEL_ACCESS_INSECURE"},
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "playground",
Usage: "Run a temporary, in-memory SQLite3 database for testing",
Value: false,
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "loglevel",
Value: "debug", // Make it more verbose than the tunnel default 'info'.
EnvVars: []string{"TUNNEL_LOGLEVEL"},
Hidden: true,
}),
},
Before: CmdBefore,
Action: CmdAction,
Hidden: true,
}
}
// CmdBefore runs some validation checks before running the command.
func CmdBefore(c *cli.Context) error {
// Show the help text if no flags are specified.
if c.NumFlags() == 0 {
return cli.ShowSubcommandHelp(c)
}
// Hello-world and playground are synonymous with each other,
// unset hello-world to prevent tunnel from initializing the hello package.
if c.IsSet("hello-world") {
c.Set("playground", "true")
c.Set("hello-world", "false")
}
// Unix-socket database URLs are supported, and the logic is the same as for url.
if c.IsSet("unix-socket") {
c.Set("url", c.String("unix-socket"))
c.Set("unix-socket", "")
}
// When playground mode is enabled, run with an in-memory database.
if c.IsSet("playground") {
c.Set("url", "sqlite3::memory:?cache=shared")
c.Set("insecure", strconv.FormatBool(!c.IsSet("auth-domain") && !c.IsSet("application-aud")))
}
// At this point, insecure configurations are valid.
if c.Bool("insecure") {
return nil
}
// Ensure that secure configurations specify a hostname, domain, and tag for JWT validation.
if !c.IsSet("hostname") || !c.IsSet("auth-domain") || !c.IsSet("application-aud") {
log.Fatal("must specify --hostname, --auth-domain, and --application-aud unless you want to run in --insecure mode")
}
return nil
}
// CmdAction starts the Proxy and sets the url in cli.Context to point to the Proxy address.
func CmdAction(c *cli.Context) error {
// STOR-612: sync with context in tunnel daemon.
ctx := context.Background()
var proxy *Proxy
var err error
if c.Bool("insecure") {
proxy, err = NewInsecureProxy(ctx, c.String("url"))
} else {
proxy, err = NewSecureProxy(ctx, c.String("url"), c.String("auth-domain"), c.String("application-aud"))
}
if err != nil {
log.Fatal(err)
return err
}
listenerC := make(chan net.Listener)
defer close(listenerC)
// Since the Proxy should only talk to the tunnel daemon, find the next available
// localhost port and start to listen to requests.
go func() {
err := proxy.Start(ctx, "127.0.0.1:", listenerC)
if err != nil {
log.Fatal(err)
}
}()
// Block until the Proxy is online, retrieve its address, and change the url to point to it.
// This is effectively "handing over" control to the tunnel package so it can run the tunnel daemon.
c.Set("url", "https://"+(<-listenerC).Addr().String())
return nil
}

27
dbconnect/cmd_test.go Normal file
View File

@ -0,0 +1,27 @@
package dbconnect
import (
"testing"
"github.com/stretchr/testify/assert"
"gopkg.in/urfave/cli.v2"
)
func TestCmd(t *testing.T) {
tests := [][]string{
{"cloudflared", "db-connect", "--playground"},
{"cloudflared", "db-connect", "--playground", "--hostname", "sql.mysite.com"},
{"cloudflared", "db-connect", "--url", "sqlite3::memory:?cache=shared", "--insecure"},
{"cloudflared", "db-connect", "--url", "sqlite3::memory:?cache=shared", "--hostname", "sql.mysite.com", "--auth-domain", "mysite.cloudflareaccess.com", "--application-aud", "aud"},
}
app := &cli.App{
Name: "cloudflared",
Commands: []*cli.Command{Cmd()},
}
for _, test := range tests {
assert.NoError(t, app.Run(test))
}
}

271
dbconnect/proxy.go Normal file
View File

@ -0,0 +1,271 @@
package dbconnect
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/validation"
"github.com/gorilla/mux"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// Proxy is an HTTP server that proxies requests to a Client.
type Proxy struct {
client Client
accessValidator *validation.Access
logger *logrus.Logger
}
// NewInsecureProxy creates a Proxy that talks to a Client at an origin.
//
// In insecure mode, the Proxy will allow all Command requests.
func NewInsecureProxy(ctx context.Context, origin string) (*Proxy, error) {
originURL, err := url.Parse(origin)
if err != nil {
return nil, errors.Wrap(err, "must provide a valid database url")
}
client, err := NewClient(ctx, originURL)
if err != nil {
return nil, err
}
err = client.Ping(ctx)
if err != nil {
return nil, errors.Wrap(err, "could not connect to the database")
}
return &Proxy{client, nil, logrus.New()}, nil
}
// NewSecureProxy creates a Proxy that talks to a Client at an origin.
//
// In secure mode, the Proxy will reject any Command requests that are
// not authenticated by Cloudflare Access with a valid JWT.
func NewSecureProxy(ctx context.Context, origin, authDomain, applicationAUD string) (*Proxy, error) {
proxy, err := NewInsecureProxy(ctx, origin)
if err != nil {
return nil, err
}
validator, err := validation.NewAccessValidator(ctx, authDomain, authDomain, applicationAUD)
if err != nil {
return nil, err
}
proxy.accessValidator = validator
return proxy, err
}
// IsInsecure gets whether the Proxy will accept a Command from any source.
func (proxy *Proxy) IsInsecure() bool {
return proxy.accessValidator == nil
}
// IsAllowed checks whether a http.Request is allowed to receive data.
//
// By default, requests must pass through Cloudflare Access for authentication.
// If the proxy is explicitly set to insecure mode, all requests will be allowed.
func (proxy *Proxy) IsAllowed(r *http.Request, verbose ...bool) bool {
if proxy.IsInsecure() {
return true
}
// Access and Tunnel should prevent bad JWTs from even reaching the origin,
// but validate tokens anyway out of an abundance of caution.
err := proxy.accessValidator.ValidateRequest(r.Context(), r)
if err == nil {
return true
}
// Warn administrators that invalid JWTs are being rejected. This is indicative
// of either a misconfiguration of the CLI or a massive failure of upstream systems.
if len(verbose) > 0 {
proxy.httpLog(r, err).Error("Failed JWT authentication")
}
return false
}
// Start the Proxy at a given address and notify the listener channel when the server is online.
func (proxy *Proxy) Start(ctx context.Context, addr string, listenerC chan<- net.Listener) error {
// STOR-611: use a separate listener and consider web socket support.
httpListener, err := hello.CreateTLSListener(addr)
if err != nil {
return errors.Wrapf(err, "could not create listener at %s", addr)
}
errC := make(chan error)
defer close(errC)
// Starts the HTTP server and begins to serve requests.
go func() {
errC <- proxy.httpListen(ctx, httpListener)
}()
// Continually ping the server until it comes online or 10 attempts fail.
go func() {
var err error
for i := 0; i < 10; i++ {
_, err = http.Get("http://" + httpListener.Addr().String())
// Once no error was detected, notify the listener channel and return.
if err == nil {
listenerC <- httpListener
return
}
// Backoff between requests to ping the server.
<-time.After(1 * time.Second)
}
errC <- errors.Wrap(err, "took too long for the http server to start")
}()
return <-errC
}
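// A hedged caller sketch for Start, assuming an insecure Proxy backed by an
// in-memory SQLite database; the function name is illustrative, and the listen
// address and database URL mirror the ones used by CmdAction and playground mode.
func exampleStartProxy(ctx context.Context) (net.Listener, error) {
	proxy, err := NewInsecureProxy(ctx, "sqlite3::memory:?cache=shared")
	if err != nil {
		return nil, err
	}
	listenerC := make(chan net.Listener)
	go func() {
		if err := proxy.Start(ctx, "127.0.0.1:", listenerC); err != nil {
			proxy.logger.WithError(err).Error("database proxy stopped")
		}
	}()
	// Start pings its own listener until it is reachable, then notifies the channel.
	return <-listenerC, nil
}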
// httpListen starts the httpServer and blocks until the context closes.
func (proxy *Proxy) httpListen(ctx context.Context, listener net.Listener) error {
httpServer := &http.Server{
Addr: listener.Addr().String(),
Handler: proxy.httpRouter(),
ReadTimeout: 10 * time.Second,
WriteTimeout: 60 * time.Second,
IdleTimeout: 60 * time.Second,
}
go func() {
<-ctx.Done()
httpServer.Close()
listener.Close()
}()
return httpServer.Serve(listener)
}
// httpRouter creates a mux.Router for the Proxy.
func (proxy *Proxy) httpRouter() *mux.Router {
router := mux.NewRouter()
router.HandleFunc("/ping", proxy.httpPing()).Methods("GET", "HEAD")
router.HandleFunc("/submit", proxy.httpSubmit()).Methods("POST")
return router
}
// httpPing tests the connection to the database.
//
// By default, this endpoint is unauthenticated to allow for health checks.
// To enable authentication, Cloudflare Access must be enabled on this route.
func (proxy *Proxy) httpPing() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
err := proxy.client.Ping(ctx)
if err == nil {
proxy.httpRespond(w, r, http.StatusOK, "")
} else {
proxy.httpRespondErr(w, r, http.StatusInternalServerError, err)
}
}
}
// httpSubmit sends a command to the database and returns its response.
//
// By default, this endpoint will reject requests that do not pass through Cloudflare Access.
// To disable authentication, the --insecure flag must be specified in the command line.
func (proxy *Proxy) httpSubmit() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !proxy.IsAllowed(r, true) {
proxy.httpRespondErr(w, r, http.StatusForbidden, fmt.Errorf(""))
return
}
var cmd Command
err := json.NewDecoder(r.Body).Decode(&cmd)
if err != nil {
proxy.httpRespondErr(w, r, http.StatusBadRequest, err)
return
}
ctx := r.Context()
data, err := proxy.client.Submit(ctx, &cmd)
if err != nil {
proxy.httpRespondErr(w, r, http.StatusUnprocessableEntity, err)
return
}
w.Header().Set("Content-type", "application/json")
err = json.NewEncoder(w).Encode(data)
if err != nil {
proxy.httpRespondErr(w, r, http.StatusInternalServerError, err)
}
}
}
// httpRespond writes a status code and string response to the response writer.
func (proxy *Proxy) httpRespond(w http.ResponseWriter, r *http.Request, status int, message string) {
w.WriteHeader(status)
// Only expose the message detail of the response if the request is not HEAD
// and the user is authenticated. For example, this prevents an unauthenticated
// failed health check from accidentally leaking sensitive information about the Client.
if r.Method != http.MethodHead && proxy.IsAllowed(r) {
if message == "" {
message = http.StatusText(status)
}
fmt.Fprint(w, message)
}
}
// httpRespondErr is similar to httpRespond, except it formats errors to be more friendly.
func (proxy *Proxy) httpRespondErr(w http.ResponseWriter, r *http.Request, defaultStatus int, err error) {
status, err := httpError(defaultStatus, err)
proxy.httpRespond(w, r, status, err.Error())
if len(err.Error()) > 0 {
proxy.httpLog(r, err).Warn("Database proxy error")
}
}
// httpLog returns a logrus.Entry that is formatted to output a request Cf-ray.
func (proxy *Proxy) httpLog(r *http.Request, err error) *logrus.Entry {
return proxy.logger.WithContext(r.Context()).WithField("CF-RAY", r.Header.Get("Cf-ray")).WithError(err)
}
// httpError extracts common errors and returns a status code and friendly error.
func httpError(defaultStatus int, err error) (int, error) {
if err == nil {
return http.StatusNotImplemented, fmt.Errorf("error expected but found none")
}
if err == io.EOF {
return http.StatusBadRequest, fmt.Errorf("request body cannot be empty")
}
if err == context.DeadlineExceeded {
return http.StatusRequestTimeout, err
}
_, ok := err.(net.Error)
if ok {
return http.StatusRequestTimeout, err
}
if err == context.Canceled {
// Does not exist in Golang, but would be: http.StatusClientClosedWithoutResponse
return 444, err
}
return defaultStatus, err
}

238
dbconnect/proxy_test.go Normal file
View File

@ -0,0 +1,238 @@
package dbconnect
import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
)
func TestNewInsecureProxy(t *testing.T) {
origins := []string{
"",
":/",
"http://localhost",
"tcp://localhost:9000?debug=true",
"mongodb://127.0.0.1",
}
for _, origin := range origins {
proxy, err := NewInsecureProxy(context.Background(), origin)
assert.Error(t, err)
assert.Empty(t, proxy)
}
}
func TestProxyIsAllowed(t *testing.T) {
proxy := helperNewProxy(t)
req := httptest.NewRequest("GET", "https://1.1.1.1/ping", nil)
assert.True(t, proxy.IsAllowed(req))
proxy = helperNewProxy(t, true)
req.Header.Set("Cf-access-jwt-assertion", "xxx")
assert.False(t, proxy.IsAllowed(req))
}
func TestProxyStart(t *testing.T) {
proxy := helperNewProxy(t)
ctx := context.Background()
listenerC := make(chan net.Listener)
err := proxy.Start(ctx, "1.1.1.1:", listenerC)
assert.Error(t, err)
err = proxy.Start(ctx, "127.0.0.1:-1", listenerC)
assert.Error(t, err)
ctx, cancel := context.WithTimeout(ctx, 0)
defer cancel()
err = proxy.Start(ctx, "127.0.0.1:", listenerC)
assert.IsType(t, http.ErrServerClosed, err)
}
func TestProxyHTTPRouter(t *testing.T) {
proxy := helperNewProxy(t)
router := proxy.httpRouter()
tests := []struct {
path string
method string
valid bool
}{
{"", "GET", false},
{"/", "GET", false},
{"/ping", "GET", true},
{"/ping", "HEAD", true},
{"/ping", "POST", false},
{"/submit", "POST", true},
{"/submit", "GET", false},
{"/submit/extra", "POST", false},
}
for _, test := range tests {
match := &mux.RouteMatch{}
ok := router.Match(httptest.NewRequest(test.method, "https://1.1.1.1"+test.path, nil), match)
assert.True(t, ok == test.valid, test.path)
}
}
func TestProxyHTTPPing(t *testing.T) {
proxy := helperNewProxy(t)
server := httptest.NewServer(proxy.httpPing())
defer server.Close()
client := server.Client()
res, err := client.Get(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, int64(2), res.ContentLength)
res, err = client.Head(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, int64(-1), res.ContentLength)
}
func TestProxyHTTPSubmit(t *testing.T) {
proxy := helperNewProxy(t)
server := httptest.NewServer(proxy.httpSubmit())
defer server.Close()
client := server.Client()
tests := []struct {
input string
status int
output string
}{
{"", http.StatusBadRequest, "request body cannot be empty"},
{"{}", http.StatusBadRequest, "cannot provide an empty statement"},
{"{\"statement\":\"Ok\"}", http.StatusUnprocessableEntity, "cannot provide invalid sql mode: ''"},
{"{\"statement\":\"Ok\",\"mode\":\"query\"}", http.StatusUnprocessableEntity, "near \"Ok\": syntax error"},
{"{\"statement\":\"CREATE TABLE t (a INT);\",\"mode\":\"exec\"}", http.StatusOK, "{\"last_insert_id\":0,\"rows_affected\":0}\n"},
}
for _, test := range tests {
res, err := client.Post(server.URL, "application/json", strings.NewReader(test.input))
assert.NoError(t, err)
assert.Equal(t, test.status, res.StatusCode)
if res.StatusCode > http.StatusOK {
assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-type"))
} else {
assert.Equal(t, "application/json", res.Header.Get("Content-type"))
}
data, err := ioutil.ReadAll(res.Body)
defer res.Body.Close()
str := string(data)
assert.NoError(t, err)
assert.Equal(t, test.output, str)
}
}
func TestProxyHTTPSubmitForbidden(t *testing.T) {
proxy := helperNewProxy(t, true)
server := httptest.NewServer(proxy.httpSubmit())
defer server.Close()
client := server.Client()
res, err := client.Get(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, res.StatusCode)
assert.Zero(t, res.ContentLength)
}
func TestProxyHTTPRespond(t *testing.T) {
proxy := helperNewProxy(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
}))
defer server.Close()
client := server.Client()
res, err := client.Get(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusAccepted, res.StatusCode)
assert.Equal(t, int64(5), res.ContentLength)
data, err := ioutil.ReadAll(res.Body)
defer res.Body.Close()
assert.Equal(t, []byte("Hello"), data)
}
func TestProxyHTTPRespondForbidden(t *testing.T) {
proxy := helperNewProxy(t, true)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
}))
defer server.Close()
client := server.Client()
res, err := client.Get(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusAccepted, res.StatusCode)
assert.Equal(t, int64(0), res.ContentLength)
}
func TestHTTPError(t *testing.T) {
_, errTimeout := net.DialTimeout("tcp", "127.0.0.1", 0)
assert.Error(t, errTimeout)
tests := []struct {
input error
status int
output error
}{
{nil, http.StatusNotImplemented, fmt.Errorf("error expected but found none")},
{io.EOF, http.StatusBadRequest, fmt.Errorf("request body cannot be empty")},
{context.DeadlineExceeded, http.StatusRequestTimeout, nil},
{context.Canceled, 444, nil},
{errTimeout, http.StatusRequestTimeout, nil},
{fmt.Errorf(""), http.StatusInternalServerError, nil},
}
for _, test := range tests {
status, err := httpError(http.StatusInternalServerError, test.input)
assert.Error(t, err)
assert.Equal(t, test.status, status)
if test.output == nil {
test.output = test.input
}
assert.Equal(t, test.output, err)
}
}
func helperNewProxy(t *testing.T, secure ...bool) *Proxy {
t.Helper()
proxy, err := NewSecureProxy(context.Background(), "file::memory:?cache=shared", "test.cloudflareaccess.com", "")
assert.NoError(t, err)
assert.NotNil(t, proxy)
if len(secure) == 0 {
proxy.accessValidator = nil // Mark as insecure
}
return proxy
}

318
dbconnect/sql.go Normal file
View File

@ -0,0 +1,318 @@
package dbconnect
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/url"
"reflect"
"strings"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"github.com/xo/dburl"
// SQL drivers self-register with the database/sql package.
// https://github.com/golang/go/wiki/SQLDrivers
_ "github.com/denisenkom/go-mssqldb"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"github.com/kshvakov/clickhouse"
"github.com/lib/pq"
)
// SQLClient is a Client that talks to a SQL database.
type SQLClient struct {
Dialect string
driver *sqlx.DB
}
// NewSQLClient creates a SQL client based on its URL scheme.
func NewSQLClient(ctx context.Context, originURL *url.URL) (Client, error) {
res, err := dburl.Parse(originURL.String())
if err != nil {
helpText := fmt.Sprintf("supported drivers: %+q, see documentation for more details: %s", sql.Drivers(), "https://godoc.org/github.com/xo/dburl")
return nil, fmt.Errorf("could not parse sql database url '%s': %s\n%s", originURL, err.Error(), helpText)
}
// Establishes the driver, but does not test the connection.
driver, err := sqlx.Open(res.Driver, res.DSN)
if err != nil {
return nil, fmt.Errorf("could not open sql driver %s: %s\n%s", res.Driver, err.Error(), res.DSN)
}
// Closes the driver, will occur when the context finishes.
go func() {
<-ctx.Done()
driver.Close()
}()
return &SQLClient{driver.DriverName(), driver}, nil
}
// Ping verifies a connection to the database is still alive.
func (client *SQLClient) Ping(ctx context.Context) error {
return client.driver.PingContext(ctx)
}
// Submit queries or executes a command to the SQL database.
func (client *SQLClient) Submit(ctx context.Context, cmd *Command) (interface{}, error) {
txx, err := cmd.ValidateSQL(client.Dialect)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(ctx, cmd.Timeout)
defer cancel()
var res interface{}
// Get the next available sql.Conn and submit the Command.
err = sqlConn(ctx, client.driver, txx, func(conn *sql.Conn) error {
stmt := cmd.Statement
args := cmd.Arguments.Positional
if cmd.Mode == "query" {
res, err = sqlQuery(ctx, conn, stmt, args)
} else {
res, err = sqlExec(ctx, conn, stmt, args)
}
return err
})
return res, err
}
// ValidateSQL extends the contract of Command for SQL dialects:
// mode is conformed, arguments are []sql.NamedArg, and isolation is a sql.IsolationLevel.
//
// When the command should not be wrapped in a transaction, *sql.TxOptions and error will both be nil.
func (cmd *Command) ValidateSQL(dialect string) (*sql.TxOptions, error) {
err := cmd.Validate()
if err != nil {
return nil, err
}
mode, err := sqlMode(cmd.Mode)
if err != nil {
return nil, err
}
// Mutates Arguments to only use positional arguments with the type sql.NamedArg.
// This is required by the sql.Driver before submitting arguments.
cmd.Arguments.sql(dialect)
iso, err := sqlIsolation(cmd.Isolation)
if err != nil {
return nil, err
}
// When isolation is out-of-range, this indicates that no
// transaction should be executed and sql.TxOptions should be nil.
if iso < sql.LevelDefault {
return nil, nil
}
// In query mode, execute the transaction in read-only, unless it's Microsoft SQL
// which does not support that type of transaction.
readOnly := mode == "query" && dialect != "mssql"
return &sql.TxOptions{Isolation: iso, ReadOnly: readOnly}, nil
}
// sqlConn gets the next available sql.Conn in the connection pool and runs a function to use it.
//
// If the transaction options are nil, run the useIt function outside a transaction.
// This is potentially an unsafe operation if the command does not clean up its state.
func sqlConn(ctx context.Context, driver *sqlx.DB, txx *sql.TxOptions, useIt func(*sql.Conn) error) error {
conn, err := driver.Conn(ctx)
if err != nil {
return err
}
defer conn.Close()
// If transaction options are specified, begin and defer a rollback to catch errors.
var tx *sql.Tx
if txx != nil {
tx, err = conn.BeginTx(ctx, txx)
if err != nil {
return err
}
defer tx.Rollback()
}
err = useIt(conn)
// Check if useIt was successful and a transaction exists before committing.
if err == nil && tx != nil {
err = tx.Commit()
}
return err
}
// sqlQuery queries rows on a sql.Conn and returns an array of result objects.
func sqlQuery(ctx context.Context, conn *sql.Conn, stmt string, args []interface{}) ([]map[string]interface{}, error) {
rows, err := conn.QueryContext(ctx, stmt, args...)
if err == nil {
return sqlRows(rows)
}
return nil, err
}
// sqlExec executes a command on a sql.Conn and returns the result of the operation.
func sqlExec(ctx context.Context, conn *sql.Conn, stmt string, args []interface{}) (sqlResult, error) {
exec, err := conn.ExecContext(ctx, stmt, args...)
if err == nil {
return sqlResultFrom(exec), nil
}
return sqlResult{}, err
}
// sql mutates Arguments to contain a positional []sql.NamedArg.
//
// The actual return type is []interface{} due to the native Golang
// function signatures for sql.Exec and sql.Query being generic.
func (args *Arguments) sql(dialect string) {
result := args.Positional
for i, val := range result {
result[i] = sqlArg("", val, dialect)
}
for key, val := range args.Named {
result = append(result, sqlArg(key, val, dialect))
}
args.Positional = result
args.Named = map[string]interface{}{}
}
// sqlArg creates a sql.NamedArg from a key-value pair and an optional dialect.
//
// Certain dialects need to wrap objects, such as arrays, to conform to their driver requirements.
func sqlArg(key, val interface{}, dialect string) sql.NamedArg {
switch reflect.ValueOf(val).Kind() {
// PostgreSQL and Clickhouse require arrays to be wrapped before
// being inserted into the driver interface.
case reflect.Slice, reflect.Array:
switch dialect {
case "postgres":
val = pq.Array(val)
case "clickhouse":
val = clickhouse.Array(val)
}
}
return sql.Named(fmt.Sprint(key), val)
}
// sqlIsolation tries to match a string to a sql.IsolationLevel.
func sqlIsolation(str string) (sql.IsolationLevel, error) {
if str == "none" {
return sql.IsolationLevel(-1), nil
}
for iso := sql.LevelDefault; ; iso++ {
if iso > sql.LevelLinearizable {
return -1, fmt.Errorf("cannot provide an invalid sql isolation level: '%s'", str)
}
if str == "" || strings.EqualFold(iso.String(), strings.ReplaceAll(str, "_", " ")) {
return iso, nil
}
}
}
// sqlMode tries to match a string to a command mode: 'query' or 'exec' for now.
func sqlMode(str string) (string, error) {
switch str {
case "query", "exec":
return str, nil
default:
return "", fmt.Errorf("cannot provide invalid sql mode: '%s'", str)
}
}
// sqlRows scans through a SQL result set and returns an array of objects.
func sqlRows(rows *sql.Rows) ([]map[string]interface{}, error) {
columns, err := rows.Columns()
if err != nil {
return nil, errors.Wrap(err, "could not extract columns from result")
}
defer rows.Close()
types, err := rows.ColumnTypes()
if err != nil {
// Some drivers do not support type extraction, so fail silently and continue.
types = make([]*sql.ColumnType, len(columns))
}
values := make([]interface{}, len(columns))
pointers := make([]interface{}, len(columns))
var results []map[string]interface{}
for rows.Next() {
for i := range columns {
pointers[i] = &values[i]
}
rows.Scan(pointers...)
// Convert a row, an array of values, into an object where
// each key is the name of its respective column.
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = sqlValue(values[i], types[i])
}
results = append(results, entry)
}
return results, nil
}
// sqlValue handles special cases where sql.Rows does not return a "human-readable" object.
func sqlValue(val interface{}, col *sql.ColumnType) interface{} {
bytes, ok := val.([]byte)
if ok {
// Opportunistically check for embedded JSON and convert it to a first-class object.
var embedded interface{}
if json.Unmarshal(bytes, &embedded) == nil {
return embedded
}
// STOR-604: investigate a way to coerce PostgreSQL arrays '{a, b, ...}' into JSON.
// Although easy with strings, it becomes more difficult with special types like INET[].
return string(bytes)
}
return val
}
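// Sketch of the special-casing above (inputs chosen for illustration only):
//
//	sqlValue([]byte(`{"json":true}`), nil) // map[string]interface{}{"json": true}
//	sqlValue([]byte("random"), nil)        // "random"
//	sqlValue(int64(42), nil)               // int64(42), returned unchanged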
// sqlResult is a thin wrapper around sql.Result.
type sqlResult struct {
LastInsertId int64 `json:"last_insert_id"`
RowsAffected int64 `json:"rows_affected"`
}
// sqlResultFrom converts sql.Result into a JSON-marshalable sqlResult.
func sqlResultFrom(res sql.Result) sqlResult {
insertID, errID := res.LastInsertId()
rowsAffected, errRows := res.RowsAffected()
// If an error occurs when extracting the result, it is because the
// driver does not support that specific field. Instead of surfacing the
// error to the user, report the field as -1 in the response.
if errID != nil {
insertID = -1
}
if errRows != nil {
rowsAffected = -1
}
return sqlResult{insertID, rowsAffected}
}
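// Sketch of the resulting shape (assuming res is any sql.Result):
//
//	out, _ := json.Marshal(sqlResultFrom(res))
//	// e.g. {"last_insert_id":1,"rows_affected":2}, with -1 for unsupported fields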

336
dbconnect/sql_test.go Normal file
View File

@ -0,0 +1,336 @@
package dbconnect
import (
"context"
"database/sql"
"fmt"
"net/url"
"strings"
"testing"
"time"
"github.com/kshvakov/clickhouse"
"github.com/lib/pq"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
)
func TestNewSQLClient(t *testing.T) {
originURLs := []string{
"postgres://localhost",
"cockroachdb://localhost:1337",
"postgresql://user:pass@127.0.0.1",
"mysql://localhost",
"clickhouse://127.0.0.1:9000/?debug",
"sqlite3::memory:",
"file:test.db?cache=shared",
}
for _, originURL := range originURLs {
origin, _ := url.Parse(originURL)
_, err := NewSQLClient(context.Background(), origin)
assert.NoError(t, err, originURL)
}
originURLs = []string{
"",
"/",
"http://localhost",
"coolthing://user:pass@127.0.0.1",
}
for _, originURL := range originURLs {
origin, _ := url.Parse(originURL)
_, err := NewSQLClient(context.Background(), origin)
assert.Error(t, err, originURL)
}
}
func TestArgumentsSQL(t *testing.T) {
args := []Arguments{
Arguments{
Positional: []interface{}{
"val", 10, 3.14,
},
},
Arguments{
Named: map[string]interface{}{
"key": time.Unix(0, 0),
},
},
}
var nameType sql.NamedArg
for _, arg := range args {
arg.sql("")
for _, named := range arg.Positional {
assert.IsType(t, nameType, named)
}
}
}
func TestSQLArg(t *testing.T) {
tests := []struct {
key interface{}
val interface{}
dialect string
arg sql.NamedArg
}{
{"key", "val", "mssql", sql.Named("key", "val")},
{0, 1, "sqlite3", sql.Named("0", 1)},
{1, []string{"a", "b", "c"}, "postgres", sql.Named("1", pq.Array([]string{"a", "b", "c"}))},
{"in", []uint{0, 1}, "clickhouse", sql.Named("in", clickhouse.Array([]uint{0, 1}))},
{"", time.Unix(0, 0), "", sql.Named("", time.Unix(0, 0))},
}
for _, test := range tests {
arg := sqlArg(test.key, test.val, test.dialect)
assert.Equal(t, test.arg, arg, test.key)
}
}
func TestSQLIsolation(t *testing.T) {
tests := []struct {
str string
iso sql.IsolationLevel
}{
{"", sql.LevelDefault},
{"DEFAULT", sql.LevelDefault},
{"read_UNcommitted", sql.LevelReadUncommitted},
{"serializable", sql.LevelSerializable},
{"none", sql.IsolationLevel(-1)},
{"SNAP shot", -2},
{"blah", -2},
}
for _, test := range tests {
iso, err := sqlIsolation(test.str)
if test.iso < -1 {
assert.Error(t, err, test.str)
} else {
assert.NoError(t, err)
assert.Equal(t, test.iso, iso, test.str)
}
}
}
func TestSQLMode(t *testing.T) {
modes := []string{
"query",
"exec",
}
for _, mode := range modes {
actual, err := sqlMode(mode)
assert.NoError(t, err)
assert.Equal(t, strings.ToLower(mode), actual, mode)
}
modes = []string{
"",
"blah",
}
for _, mode := range modes {
_, err := sqlMode(mode)
assert.Error(t, err)
}
}
func helperRows(mockRows *sqlmock.Rows) *sql.Rows {
db, mock, _ := sqlmock.New()
mock.ExpectQuery("SELECT").WillReturnRows(mockRows)
rows, _ := db.Query("SELECT")
return rows
}
func TestSQLRows(t *testing.T) {
actual, err := sqlRows(helperRows(sqlmock.NewRows(
[]string{"name", "age", "dept"}).
AddRow("alice", 19, "prod")))
expected := []map[string]interface{}{map[string]interface{}{
"name": "alice",
"age": int64(19),
"dept": "prod"}}
assert.NoError(t, err)
assert.ElementsMatch(t, expected, actual)
}
func TestSQLValue(t *testing.T) {
tests := []struct {
input interface{}
output interface{}
}{
{"hello", "hello"},
{1, 1},
{false, false},
{[]byte("random"), "random"},
{[]byte("{\"json\":true}"), map[string]interface{}{"json": true}},
{[]byte("[]"), []interface{}{}},
}
for _, test := range tests {
assert.Equal(t, test.output, sqlValue(test.input, nil), test.input)
}
}
func TestSQLResultFrom(t *testing.T) {
res := sqlResultFrom(sqlmock.NewResult(1, 2))
assert.Equal(t, sqlResult{1, 2}, res)
res = sqlResultFrom(sqlmock.NewErrorResult(fmt.Errorf("")))
assert.Equal(t, sqlResult{-1, -1}, res)
}
func helperSQLite3(t *testing.T) (context.Context, Client) {
t.Helper()
ctx := context.Background()
url, _ := url.Parse("file::memory:?cache=shared")
sqlite3, err := NewSQLClient(ctx, url)
assert.NoError(t, err)
return ctx, sqlite3
}
func TestPing(t *testing.T) {
ctx, sqlite3 := helperSQLite3(t)
err := sqlite3.Ping(ctx)
assert.NoError(t, err)
}
func TestSubmit(t *testing.T) {
ctx, sqlite3 := helperSQLite3(t)
res, err := sqlite3.Submit(ctx, &Command{
Statement: "CREATE TABLE t (a INTEGER, b FLOAT, c TEXT, d BLOB);",
Mode: "exec",
})
assert.NoError(t, err)
assert.Equal(t, sqlResult{0, 0}, res)
res, err = sqlite3.Submit(ctx, &Command{
Statement: "SELECT * FROM t;",
Mode: "query",
})
assert.NoError(t, err)
assert.Empty(t, res)
res, err = sqlite3.Submit(ctx, &Command{
Statement: "INSERT INTO t VALUES (?, ?, ?, ?);",
Mode: "exec",
Arguments: Arguments{
Positional: []interface{}{
1,
3.14,
"text",
"blob",
},
},
})
assert.NoError(t, err)
assert.Equal(t, sqlResult{1, 1}, res)
res, err = sqlite3.Submit(ctx, &Command{
Statement: "UPDATE t SET c = NULL;",
Mode: "exec",
})
assert.NoError(t, err)
assert.Equal(t, sqlResult{1, 1}, res)
res, err = sqlite3.Submit(ctx, &Command{
Statement: "SELECT * FROM t WHERE a = ?;",
Mode: "query",
Arguments: Arguments{
Positional: []interface{}{1},
},
})
assert.NoError(t, err)
assert.Len(t, res, 1)
resf, ok := res.([]map[string]interface{})
assert.True(t, ok)
assert.EqualValues(t, map[string]interface{}{
"a": int64(1),
"b": 3.14,
"c": nil,
"d": "blob",
}, resf[0])
res, err = sqlite3.Submit(ctx, &Command{
Statement: "DROP TABLE t;",
Mode: "exec",
})
assert.NoError(t, err)
assert.Equal(t, sqlResult{1, 1}, res)
}
func TestSubmitTransaction(t *testing.T) {
ctx, sqlite3 := helperSQLite3(t)
res, err := sqlite3.Submit(ctx, &Command{
Statement: "BEGIN;",
Mode: "exec",
})
assert.Error(t, err)
assert.Empty(t, res)
res, err = sqlite3.Submit(ctx, &Command{
Statement: "BEGIN; CREATE TABLE tt (a INT); COMMIT;",
Mode: "exec",
Isolation: "none",
})
assert.NoError(t, err)
assert.Equal(t, sqlResult{0, 0}, res)
rows, err := sqlite3.Submit(ctx, &Command{
Statement: "SELECT * FROM tt;",
Mode: "query",
Isolation: "repeatable_read",
})
assert.NoError(t, err)
assert.Empty(t, rows)
}
func TestSubmitTimeout(t *testing.T) {
ctx, sqlite3 := helperSQLite3(t)
res, err := sqlite3.Submit(ctx, &Command{
Statement: "SELECT * FROM t;",
Mode: "query",
Timeout: 1 * time.Nanosecond,
})
assert.Error(t, err)
assert.Empty(t, res)
}
func TestSubmitMode(t *testing.T) {
ctx, sqlite3 := helperSQLite3(t)
res, err := sqlite3.Submit(ctx, &Command{
Statement: "SELECT * FROM t;",
Mode: "notanoption",
})
assert.Error(t, err)
assert.Empty(t, res)
}
func TestSubmitEmpty(t *testing.T) {
ctx, sqlite3 := helperSQLite3(t)
res, err := sqlite3.Submit(ctx, &Command{
Statement: "; ; ; ;",
Mode: "query",
})
assert.Error(t, err)
assert.Empty(t, res)
}

View File

@ -0,0 +1,78 @@
# docker-compose -f ./dbconnect_tests/dbconnect.yaml up --build --force-recreate --renew-anon-volumes --exit-code-from cloudflared
version: "2.3"
networks:
test-dbconnect-network:
driver: bridge
services:
cloudflared:
build:
context: ../
dockerfile: dev.Dockerfile
command: go test github.com/cloudflare/cloudflared/dbconnect_tests -v
depends_on:
postgres:
condition: service_healthy
mysql:
condition: service_healthy
mssql:
condition: service_healthy
clickhouse:
condition: service_healthy
environment:
DBCONNECT_INTEGRATION_TEST: "true"
POSTGRESQL_URL: postgres://postgres:secret@postgres/db?sslmode=disable
MYSQL_URL: mysql://root:secret@mysql/db?tls=false
MSSQL_URL: mssql://sa:secret12345!@mssql
CLICKHOUSE_URL: clickhouse://clickhouse:9000/db
networks:
- test-dbconnect-network
postgres:
image: postgres:11.4-alpine
environment:
POSTGRES_DB: db
POSTGRES_PASSWORD: secret
healthcheck:
test: ["CMD", "pg_isready", "-U", "postgres"]
start_period: 3s
interval: 1s
timeout: 3s
retries: 10
networks:
- test-dbconnect-network
mysql:
image: mysql:8.0
environment:
MYSQL_DATABASE: db
MYSQL_ROOT_PASSWORD: secret
healthcheck:
test: ["CMD", "mysqladmin", "ping"]
start_period: 3s
interval: 1s
timeout: 3s
retries: 10
networks:
- test-dbconnect-network
mssql:
image: mcr.microsoft.com/mssql/server:2017-CU8-ubuntu
environment:
ACCEPT_EULA: "Y"
SA_PASSWORD: secret12345!
healthcheck:
test: ["CMD", "/opt/mssql-tools/bin/sqlcmd", "-S", "localhost", "-U", "sa", "-P", "secret12345!", "-Q", "SELECT 1"]
start_period: 3s
interval: 1s
timeout: 3s
retries: 10
networks:
- test-dbconnect-network
clickhouse:
image: yandex/clickhouse-server:19.11
healthcheck:
test: ["CMD", "clickhouse-client", "--query", "SELECT 1"]
start_period: 3s
interval: 1s
timeout: 3s
retries: 10
networks:
- test-dbconnect-network
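# For local iteration it should also be possible to start only the database
# services, e.g. `docker-compose -f ./dbconnect_tests/dbconnect.yaml up postgres`,
# and then run the integration tests from the host with DBCONNECT_INTEGRATION_TEST
# and the corresponding *_URL environment variables exported.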

265
dbconnect_tests/sql_test.go Normal file
View File

@ -0,0 +1,265 @@
package dbconnect_tests
import (
"context"
"log"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/dbconnect"
)
func TestIntegrationPostgreSQL(t *testing.T) {
ctx, pq := helperNewSQLClient(t, "POSTGRESQL_URL")
err := pq.Ping(ctx)
assert.NoError(t, err)
_, err = pq.Submit(ctx, &dbconnect.Command{
Statement: "CREATE TABLE t (a TEXT, b UUID, c JSON, d INET[], e SERIAL);",
Mode: "exec",
})
assert.NoError(t, err)
_, err = pq.Submit(ctx, &dbconnect.Command{
Statement: "INSERT INTO t VALUES ($1, $2, $3, $4);",
Mode: "exec",
Arguments: dbconnect.Arguments{
Positional: []interface{}{
"text",
"6b8d686d-bd8e-43bc-b09a-cfcbbe702c10",
"{\"bool\":true,\"array\":[\"a\", 1, 3.14],\"embed\":{\"num\":21}}",
[]string{"1.1.1.1", "1.0.0.1"},
},
},
})
assert.NoError(t, err)
_, err = pq.Submit(ctx, &dbconnect.Command{
Statement: "UPDATE t SET b = NULL;",
Mode: "exec",
})
assert.NoError(t, err)
res, err := pq.Submit(ctx, &dbconnect.Command{
Statement: "SELECT * FROM t;",
Mode: "query",
})
assert.NoError(t, err)
assert.IsType(t, make([]map[string]interface{}, 0), res)
actual := res.([]map[string]interface{})[0]
expected := map[string]interface{}{
"a": "text",
"b": nil,
"c": map[string]interface{}{
"bool": true,
"array": []interface{}{"a", float64(1), 3.14},
"embed": map[string]interface{}{"num": float64(21)},
},
"d": "{1.1.1.1,1.0.0.1}",
"e": int64(1),
}
assert.EqualValues(t, expected, actual)
_, err = pq.Submit(ctx, &dbconnect.Command{
Statement: "DROP TABLE t;",
Mode: "exec",
})
assert.NoError(t, err)
}
func TestIntegrationMySQL(t *testing.T) {
ctx, my := helperNewSQLClient(t, "MYSQL_URL")
err := my.Ping(ctx)
assert.NoError(t, err)
_, err = my.Submit(ctx, &dbconnect.Command{
Statement: "CREATE TABLE t (a CHAR, b TINYINT, c FLOAT, d JSON, e YEAR);",
Mode: "exec",
})
assert.NoError(t, err)
_, err = my.Submit(ctx, &dbconnect.Command{
Statement: "INSERT INTO t VALUES (?, ?, ?, ?, ?);",
Mode: "exec",
Arguments: dbconnect.Arguments{
Positional: []interface{}{
"a",
10,
3.14,
"{\"bool\":true}",
2000,
},
},
})
assert.NoError(t, err)
_, err = my.Submit(ctx, &dbconnect.Command{
Statement: "ALTER TABLE t ADD COLUMN f GEOMETRY;",
Mode: "exec",
})
assert.NoError(t, err)
res, err := my.Submit(ctx, &dbconnect.Command{
Statement: "SELECT * FROM t;",
Mode: "query",
})
assert.NoError(t, err)
assert.IsType(t, make([]map[string]interface{}, 0), res)
actual := res.([]map[string]interface{})[0]
expected := map[string]interface{}{
"a": "a",
"b": float64(10),
"c": 3.14,
"d": map[string]interface{}{"bool": true},
"e": float64(2000),
"f": nil,
}
assert.EqualValues(t, expected, actual)
_, err = my.Submit(ctx, &dbconnect.Command{
Statement: "DROP TABLE t;",
Mode: "exec",
})
assert.NoError(t, err)
}
func TestIntegrationMSSQL(t *testing.T) {
ctx, ms := helperNewSQLClient(t, "MSSQL_URL")
err := ms.Ping(ctx)
assert.NoError(t, err)
_, err = ms.Submit(ctx, &dbconnect.Command{
Statement: "CREATE TABLE t (a BIT, b DECIMAL, c MONEY, d TEXT);",
Mode: "exec"})
assert.NoError(t, err)
_, err = ms.Submit(ctx, &dbconnect.Command{
Statement: "INSERT INTO t VALUES (?, ?, ?, ?);",
Mode: "exec",
Arguments: dbconnect.Arguments{
Positional: []interface{}{
0,
3,
"$0.99",
"text",
},
},
})
assert.NoError(t, err)
_, err = ms.Submit(ctx, &dbconnect.Command{
Statement: "UPDATE t SET d = NULL;",
Mode: "exec",
})
assert.NoError(t, err)
res, err := ms.Submit(ctx, &dbconnect.Command{
Statement: "SELECT * FROM t;",
Mode: "query",
})
assert.NoError(t, err)
assert.IsType(t, make([]map[string]interface{}, 0), res)
actual := res.([]map[string]interface{})[0]
expected := map[string]interface{}{
"a": false,
"b": float64(3),
"c": float64(0.99),
"d": nil,
}
assert.EqualValues(t, expected, actual)
_, err = ms.Submit(ctx, &dbconnect.Command{
Statement: "DROP TABLE t;",
Mode: "exec",
})
assert.NoError(t, err)
}
func TestIntegrationClickhouse(t *testing.T) {
ctx, ch := helperNewSQLClient(t, "CLICKHOUSE_URL")
err := ch.Ping(ctx)
assert.NoError(t, err)
_, err = ch.Submit(ctx, &dbconnect.Command{
Statement: "CREATE TABLE t (a UUID, b String, c Float64, d UInt32, e Int16, f Array(Enum8('a'=1, 'b'=2, 'c'=3))) engine=Memory;",
Mode: "exec",
})
assert.NoError(t, err)
_, err = ch.Submit(ctx, &dbconnect.Command{
Statement: "INSERT INTO t VALUES (?, ?, ?, ?, ?, ?);",
Mode: "exec",
Arguments: dbconnect.Arguments{
Positional: []interface{}{
"ec65f626-6f50-4c86-9628-6314ef1edacd",
"",
3.14,
314,
-144,
[]string{"a", "b", "c"},
},
},
})
assert.NoError(t, err)
res, err := ch.Submit(ctx, &dbconnect.Command{
Statement: "SELECT * FROM t;",
Mode: "query",
})
assert.NoError(t, err)
assert.IsType(t, make([]map[string]interface{}, 0), res)
actual := res.([]map[string]interface{})[0]
expected := map[string]interface{}{
"a": "ec65f626-6f50-4c86-9628-6314ef1edacd",
"b": "",
"c": float64(3.14),
"d": uint32(314),
"e": int16(-144),
"f": []string{"a", "b", "c"},
}
assert.EqualValues(t, expected, actual)
_, err = ch.Submit(ctx, &dbconnect.Command{
Statement: "DROP TABLE t;",
Mode: "exec",
})
assert.NoError(t, err)
}
func helperNewSQLClient(t *testing.T, env string) (context.Context, dbconnect.Client) {
_, ok := os.LookupEnv("DBCONNECT_INTEGRATION_TEST")
if ok {
t.Helper()
} else {
t.SkipNow()
}
val, ok := os.LookupEnv(env)
if !ok {
log.Fatalf("must provide database url as environment variable: %s", env)
}
parsed, err := url.Parse(val)
if err != nil {
log.Fatalf("cannot provide invalid database url: %s=%s", env, val)
}
ctx := context.Background()
client, err := dbconnect.NewSQLClient(ctx, parsed)
if err != nil {
log.Fatalf("could not start test client: %s", err)
}
return ctx, client
}

4
dev.Dockerfile Normal file
View File

@ -0,0 +1,4 @@
FROM golang:1.12 as builder
WORKDIR /go/src/github.com/cloudflare/cloudflared/
RUN apt-get update
COPY . .

73
go.mod Normal file
View File

@ -0,0 +1,73 @@
module github.com/cloudflare/cloudflared
go 1.12
require (
github.com/DATA-DOG/go-sqlmock v1.3.3
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239 // indirect
github.com/aws/aws-sdk-go v1.25.8
github.com/beorn7/perks v1.0.1 // indirect
github.com/certifi/gocertifi v0.0.0-20180118203423-deb3ae2ef261 // indirect
github.com/cloudflare/brotli-go v0.0.0-20191101163834-d34379f7ff93
github.com/cloudflare/golibs v0.0.0-20170913112048-333127dbecfc
github.com/coredns/coredns v1.2.0
github.com/coreos/go-oidc v0.0.0-20171002155002-a93f71fdfe73
github.com/coreos/go-systemd v0.0.0-20190620071333-e64a0ec8b42a
github.com/denisenkom/go-mssqldb v0.0.0-20191001013358-cfbb681360f0
github.com/equinox-io/equinox v1.2.0
github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51 // indirect
github.com/facebookgo/freeport v0.0.0-20150612182905-d4adf43b75b9 // indirect
github.com/facebookgo/grace v0.0.0-20180706040059-75cf19382434
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect
github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870 // indirect
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
github.com/getsentry/raven-go v0.0.0-20180517221441-ed7bcb39ff10
github.com/gliderlabs/ssh v0.0.0-20191009160644-63518b5243e0
github.com/go-sql-driver/mysql v1.4.1
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
github.com/google/uuid v1.1.1
github.com/gorilla/mux v1.7.3
github.com/gorilla/websocket v1.2.0
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect
github.com/jmoiron/sqlx v1.2.0
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/kshvakov/clickhouse v1.3.11
github.com/lib/pq v1.2.0
github.com/mattn/go-colorable v0.1.4
github.com/mattn/go-isatty v0.0.10 // indirect
github.com/mattn/go-sqlite3 v1.11.0
github.com/mholt/caddy v0.0.0-20180807230124-d3b731e9255b // indirect
github.com/miekg/dns v1.1.8
github.com/mitchellh/go-homedir v1.1.0
github.com/opentracing/opentracing-go v1.1.0 // indirect
github.com/philhofer/fwd v1.0.0 // indirect
github.com/pkg/errors v0.8.1
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/prometheus/client_golang v1.0.0
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 // indirect
github.com/prometheus/common v0.7.0 // indirect
github.com/prometheus/procfs v0.0.5 // indirect
github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.3.0
github.com/tinylib/msgp v1.1.0 // indirect
github.com/xo/dburl v0.0.0-20191005012637-293c3298d6c0
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
golang.org/x/net v0.0.0-20191007182048-72f939374954
golang.org/x/sync v0.0.0-20190423024810-112230192c58
golang.org/x/sys v0.0.0-20191008105621-543471e840be
golang.org/x/text v0.3.2 // indirect
google.golang.org/appengine v1.4.0 // indirect
google.golang.org/genproto v0.0.0-20191007204434-a023cd5227bd // indirect
google.golang.org/grpc v1.24.0 // indirect
gopkg.in/coreos/go-oidc.v2 v2.1.0
gopkg.in/square/go-jose.v2 v2.4.0 // indirect
gopkg.in/urfave/cli.v2 v2.0.0-20180128181224-d604b6ffeee8
gopkg.in/yaml.v2 v2.2.4 // indirect
zombiezen.com/go/capnproto2 v0.0.0-20180616160808-7cfd211c19c7
)
// ../../go/pkg/mod/github.com/coredns/coredns@v1.2.0/plugin/metrics/metrics.go:40:49: too many arguments in call to prometheus.NewProcessCollector
// have (int, string)
// want (prometheus.ProcessCollectorOpts)
replace github.com/prometheus/client_golang => github.com/prometheus/client_golang v0.9.0-pre1

210
go.sum Normal file
View File

@ -0,0 +1,210 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DATA-DOG/go-sqlmock v1.3.3 h1:CWUqKXe0s8A2z6qCgkP4Kru7wC11YoAnoupUKFDnH08=
github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239 h1:kFOfPq6dUM1hTo4JG6LR5AXSUEsOjtdm0kw0FtQtMJA=
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c=
github.com/aws/aws-sdk-go v1.25.8 h1:n7I+HUUXjun2CsX7JK+1hpRIkZrlKhd3nayeb+Xmavs=
github.com/aws/aws-sdk-go v1.25.8/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bkaradzic/go-lz4 v1.0.0 h1:RXc4wYsyz985CkXXeX04y4VnZFGG8Rd43pRaHsOXAKk=
github.com/bkaradzic/go-lz4 v1.0.0/go.mod h1:0YdlkowM3VswSROI7qDxhRvJ3sLhlFrRRwjwegp5jy4=
github.com/certifi/gocertifi v0.0.0-20180118203423-deb3ae2ef261 h1:6/yVvBsKeAw05IUj4AzvrxaCnDjN4nUqKjW9+w5wixg=
github.com/certifi/gocertifi v0.0.0-20180118203423-deb3ae2ef261/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudflare/brotli-go v0.0.0-20191101163834-d34379f7ff93 h1:QrGfkZDnMxcWHaYDdB7CmqS9i26OAnUj/xcus/abYkY=
github.com/cloudflare/brotli-go v0.0.0-20191101163834-d34379f7ff93/go.mod h1:QiTe66jFdP7cUKMCCf/WrvDyYdtdmdZfVcdoLbzaKVY=
github.com/cloudflare/golibs v0.0.0-20170913112048-333127dbecfc h1:Dvk3ySBsOm5EviLx6VCyILnafPcQinXGP5jbTdHUJgE=
github.com/cloudflare/golibs v0.0.0-20170913112048-333127dbecfc/go.mod h1:HlgKKR8V5a1wroIDDIz3/A+T+9Janfq+7n1P5sEFdi0=
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg=
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80=
github.com/coredns/coredns v1.2.0 h1:YEI38K2BJYzL/SxO2tZFD727T/C68DqVWkBQjT0sWPU=
github.com/coredns/coredns v1.2.0/go.mod h1:zASH/MVDgR6XZTbxvOnsZfffS+31vg6Ackf/wo1+AM0=
github.com/coreos/go-oidc v0.0.0-20171002155002-a93f71fdfe73 h1:7CNPV0LWRCa1FNmqg700pbXhzvmoaXKyfxWRkjRym7Q=
github.com/coreos/go-oidc v0.0.0-20171002155002-a93f71fdfe73/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/coreos/go-systemd v0.0.0-20190620071333-e64a0ec8b42a h1:W8b4lQ4tFF21aspRGoBuCNV6V2fFJBF+pm1J6OY8Lys=
github.com/coreos/go-systemd v0.0.0-20190620071333-e64a0ec8b42a/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/denisenkom/go-mssqldb v0.0.0-20191001013358-cfbb681360f0 h1:epsH3lb7KVbXHYk7LYGN5EiE0MxcevHU85CKITJ0wUY=
github.com/denisenkom/go-mssqldb v0.0.0-20191001013358-cfbb681360f0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/equinox-io/equinox v1.2.0 h1:bBS7Ou+Y7Jwgmy8TWSYxEh85WctuFn7FPlgbUzX4DBA=
github.com/equinox-io/equinox v1.2.0/go.mod h1:6s3HJB0PYUNgs0mxmI8fHdfVl3TQ25ieA/PVfr+eyVo=
github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51 h1:0JZ+dUmQeA8IIVUMzysrX4/AKuQwWhV2dYQuPZdvdSQ=
github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51/go.mod h1:Yg+htXGokKKdzcwhuNDwVvN+uBxDGXJ7G/VN1d8fa64=
github.com/facebookgo/freeport v0.0.0-20150612182905-d4adf43b75b9 h1:wWke/RUCl7VRjQhwPlR/v0glZXNYzBHdNUzf/Am2Nmg=
github.com/facebookgo/freeport v0.0.0-20150612182905-d4adf43b75b9/go.mod h1:uPmAp6Sws4L7+Q/OokbWDAK1ibXYhB3PXFP1kol5hPg=
github.com/facebookgo/grace v0.0.0-20180706040059-75cf19382434 h1:mOp33BLbcbJ8fvTAmZacbBiOASfxN+MLcLxymZCIrGE=
github.com/facebookgo/grace v0.0.0-20180706040059-75cf19382434/go.mod h1:KigFdumBXUPSwzLDbeuzyt0elrL7+CP7TKuhrhT4bcU=
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 h1:JWuenKqqX8nojtoVVWjGfOF9635RETekkoH6Cc9SX0A=
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052/go.mod h1:UbMTZqLaRiH3MsBH8va0n7s1pQYcu3uTb8G4tygF4Zg=
github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870 h1:E2s37DuLxFhQDg5gKsWoLBOB0n+ZW8s599zru8FJ2/Y=
github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870/go.mod h1:5tD+neXqOorC30/tWg0LCSkrqj/AR6gu8yY8/fpw1q0=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/getsentry/raven-go v0.0.0-20180517221441-ed7bcb39ff10 h1:YO10pIIBftO/kkTFdWhctH96grJ7qiy7bMdiZcIvPKs=
github.com/getsentry/raven-go v0.0.0-20180517221441-ed7bcb39ff10/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ=
github.com/gliderlabs/ssh v0.0.0-20191009160644-63518b5243e0 h1:gF8ngtda767ddth2SH0YSAhswhz6qUkvyI9EZFYCWJA=
github.com/gliderlabs/ssh v0.0.0-20191009160644-63518b5243e0/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 h1:zN2lZNZRflqFyxVaTIU61KNKQ9C0055u9CAfpmqUvo4=
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3/go.mod h1:nPpo7qLxd6XL3hWJG/O60sR8ZKfMCiIoNap5GvD12KU=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/websocket v1.2.0 h1:VJtLvh6VQym50czpZzx07z/kw9EgAxI3x1ZB8taTMQQ=
github.com/gorilla/websocket v1.2.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 h1:MJG/KsmcqMwFAkh8mTnAwhyKoB+sTAnY4CACC110tbU=
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645/go.mod h1:6iZfnjpejD4L/4DwD7NryNaJyCQdzwWwH2MWhCA90Kw=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA=
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kshvakov/clickhouse v1.3.11 h1:dtzTJY0fCA+MWkLyuKZaNPkmSwdX4gh8+Klic9NB1Lw=
github.com/kshvakov/clickhouse v1.3.11/go.mod h1:/SVBAcqF3u7rxQ9sTWCZwf8jzzvxiZGeQvtmSF2BBEc=
github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 h1:MtvEpTB6LX3vkb4ax0b5D2DHbNAUsen0Gx5wZoq3lV4=
github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.10 h1:qxFzApOv4WsAL965uUPIsXzAKCZxN2p9UqdhFS4ZW10=
github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84=
github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q=
github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mholt/caddy v0.0.0-20180807230124-d3b731e9255b h1:/BbY4n99iMazlr2igipph+hj0MwlZIWpcsP8Iy+na+s=
github.com/mholt/caddy v0.0.0-20180807230124-d3b731e9255b/go.mod h1:Wb1PlT4DAYSqOEd03MsqkdkXnTxA8v9pKjdpxbqM1kY=
github.com/miekg/dns v1.1.8 h1:1QYRAKU3lN5cRfLCkPU08hwvLJFhvjP6MqNMmQz6ZVI=
github.com/miekg/dns v1.1.8/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU=
github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
github.com/philhofer/fwd v1.0.0 h1:UbZqGr5Y38ApvM/V/jEljVxwocdweyH+vmYvRPBnbqQ=
github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
github.com/pierrec/lz4 v2.0.5+incompatible h1:2xWsjqPFWcplujydGg4WmhC/6fZqK42wMM8aXeqhl0I=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/prometheus/client_golang v0.9.0-pre1 h1:AWTOhsOI9qxeirTuA0A4By/1Es1+y9EcCGY6bBZ2fhM=
github.com/prometheus/client_golang v0.9.0-pre1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/common v0.7.0 h1:L+1lyG48J1zAQXA3RBX/nG/B3gjlHq0zTt2tlbJLyCY=
github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA=
github.com/prometheus/procfs v0.0.5 h1:3+auTFlqw+ZaQYJARz6ArODtkaIwtvBTx3N2NehQlL8=
github.com/prometheus/procfs v0.0.5/go.mod h1:4A/X28fw3Fc593LaREMrKMqOKvUAntwMDaekg4FpcdQ=
github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 h1:mZHayPoR0lNmnHyvtYjDeq0zlVHn9K/ZXoy17ylucdo=
github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5/go.mod h1:GEXHk5HgEKCvEIIrSpFI3ozzG5xOKA2DVlEX/gGnewM=
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/tinylib/msgp v1.1.0 h1:9fQd+ICuRIu/ue4vxJZu6/LzxN0HwMds2nq/0cFvxHU=
github.com/tinylib/msgp v1.1.0/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/xo/dburl v0.0.0-20191005012637-293c3298d6c0 h1:6DtWz8hNS4qbq0OCRPhdBMG9E2qKTSDKlwnP3dmZvuA=
github.com/xo/dburl v0.0.0-20191005012637-293c3298d6c0/go.mod h1:A47W3pdWONaZmXuLZgfKLAVgUY0qvfTRM5vVDKS40S4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191007182048-72f939374954 h1:JGZucVF/L/TotR719NbujzadOZ2AgnYlqphQGHDCKaU=
golang.org/x/net v0.0.0-20191007182048-72f939374954/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be h1:vEDujvNQGv4jgYKudGeI/+DAX4Jffq6hpD55MmoEvKs=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191008105621-543471e840be h1:QAcqgptGM8IQBC9K/RC4o+O9YmqEm0diQn9QmZw/0mU=
golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20191007204434-a023cd5227bd h1:84VQPzup3IpKLxuIAZjHMhVjJ8fZ4/i3yUnj3k6fUdw=
google.golang.org/genproto v0.0.0-20191007204434-a023cd5227bd/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.24.0 h1:vb/1TCsVn3DcJlQ0Gs1yB1pKI6Do2/QNwxdKqmc/b0s=
google.golang.org/grpc v1.24.0/go.mod h1:XDChyiUovWa60DnaeDeZmSW86xtLtjtZbwvSiRnRtcA=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/coreos/go-oidc.v2 v2.1.0 h1:E8PjVFdj/SLDKB0hvb70KTbMbYVHjqztiQdSkIg8E+I=
gopkg.in/coreos/go-oidc.v2 v2.1.0/go.mod h1:fYaTe2FS96wZZwR17YTDHwG+Mw6fmyqJNxN2eNCGPCI=
gopkg.in/square/go-jose.v2 v2.4.0 h1:0kXPskUMGAXXWJlP05ktEMOV0vmzFQUWw6d+aZJQU8A=
gopkg.in/square/go-jose.v2 v2.4.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/urfave/cli.v2 v2.0.0-20180128181224-d604b6ffeee8 h1:/pLAskKF+d5SawboKd8GB8ew4ClHDbt2c3K9EBFeRGU=
gopkg.in/urfave/cli.v2 v2.0.0-20180128181224-d604b6ffeee8/go.mod h1:cKXr3E0k4aosgycml1b5z33BVV6hai1Kh7uDgFOkbcs=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
zombiezen.com/go/capnproto2 v0.0.0-20180616160808-7cfd211c19c7 h1:CZoOFlTPbKfAShKYrMuUfYbnXexFT1rYRUX1SPnrdE4=
zombiezen.com/go/capnproto2 v0.0.0-20180616160808-7cfd211c19c7/go.mod h1:TMGa8HWGJkXiq4nHe9Zu/JgRF5oUtg4XizFC+Vexbec=

View File

@ -3,6 +3,7 @@ package h2mux
import (
"sync"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/net/http2"
)
@ -12,23 +13,28 @@ type activeStreamMap struct {
sync.RWMutex
// streams tracks open streams.
streams map[uint32]*MuxedStream
// streamsEmpty is a chan that should be closed when no more streams are open.
streamsEmpty chan struct{}
// nextStreamID is the next ID to use on our side of the connection.
// This is odd for clients, even for servers.
nextStreamID uint32
// maxPeerStreamID is the ID of the most recent stream opened by the peer.
maxPeerStreamID uint32
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
activeStreams prometheus.Gauge
// ignoreNewStreams is true when the connection is being shut down. New streams
// cannot be registered.
ignoreNewStreams bool
// streamsEmptyChan is closed once no more streams are open.
streamsEmptyChan chan struct{}
closeOnce sync.Once
}
func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap {
func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
m := &activeStreamMap{
streams: make(map[uint32]*MuxedStream),
streamsEmpty: make(chan struct{}),
nextStreamID: 1,
streams: make(map[uint32]*MuxedStream),
streamsEmptyChan: make(chan struct{}),
nextStreamID: 1,
activeStreams: activeStreams,
}
// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
if !useClientStreamNumbers {
@ -37,6 +43,12 @@ func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap {
return m
}
func (m *activeStreamMap) notifyStreamsEmpty() {
m.closeOnce.Do(func() {
close(m.streamsEmptyChan)
})
}
// Len returns the number of active streams.
func (m *activeStreamMap) Len() int {
m.RLock()
@ -63,6 +75,7 @@ func (m *activeStreamMap) Set(newStream *MuxedStream) bool {
return false
}
m.streams[newStream.streamID] = newStream
m.activeStreams.Inc()
return true
}
@ -70,31 +83,31 @@ func (m *activeStreamMap) Set(newStream *MuxedStream) bool {
func (m *activeStreamMap) Delete(streamID uint32) {
m.Lock()
defer m.Unlock()
delete(m.streams, streamID)
if len(m.streams) == 0 && m.streamsEmpty != nil {
close(m.streamsEmpty)
m.streamsEmpty = nil
if _, ok := m.streams[streamID]; ok {
delete(m.streams, streamID)
m.activeStreams.Dec()
}
if len(m.streams) == 0 {
m.notifyStreamsEmpty()
}
}
// Shutdown blocks new streams from being created. It returns a channel that receives an event
// once the last stream has closed, or nil if a shutdown is in progress.
func (m *activeStreamMap) Shutdown() <-chan struct{} {
// Shutdown blocks new streams from being created.
// It returns `done`, a channel that is closed once the last stream has closed,
// and `alreadyInProgress`, which reports whether a shutdown was already in progress.
func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bool) {
m.Lock()
defer m.Unlock()
if m.ignoreNewStreams {
// already shutting down
return nil
return m.streamsEmptyChan, true
}
m.ignoreNewStreams = true
done := make(chan struct{})
if len(m.streams) == 0 {
// nothing to shut down
close(done)
return done
m.notifyStreamsEmpty()
}
m.streamsEmpty = done
return done
return m.streamsEmptyChan, false
}
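// A minimal caller sketch (assuming some *activeStreamMap m); the boolean lets
// repeated shutdown attempts share one channel instead of receiving nil:
//
//	done, alreadyInProgress := m.Shutdown()
//	if !alreadyInProgress {
//		// first caller: new streams are now rejected
//	}
//	<-done // closed once the last stream is deleted, or immediately if none were open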
// AcquireLocalID acquires a new stream ID for a stream you're opening.
@ -162,4 +175,5 @@ func (m *activeStreamMap) Abort() {
stream.Close()
}
m.ignoreNewStreams = true
m.notifyStreamsEmpty()
}

View File

@ -0,0 +1,134 @@
package h2mux
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestShutdown(t *testing.T) {
const numStreams = 1000
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
// Add all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
stream := &MuxedStream{streamID: uint32(streamID)}
ok := m.Set(stream)
assert.True(t, ok)
}(i)
}
wg.Wait()
}
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
shutdownChan, alreadyInProgress := m.Shutdown()
select {
case <-shutdownChan:
assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed")
default:
}
assert.False(t, alreadyInProgress)
shutdownChan2, alreadyInProgress2 := m.Shutdown()
assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")
// Delete all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
m.Delete(uint32(streamID))
}(i)
}
wg.Wait()
}
assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")
select {
case <-shutdownChan:
default:
assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
}
}
type noopBuffer struct {
isClosed bool
}
func (t *noopBuffer) Read(p []byte) (n int, err error) { return len(p), nil }
func (t *noopBuffer) Write(p []byte) (n int, err error) { return len(p), nil }
func (t *noopBuffer) Reset() {}
func (t *noopBuffer) Len() int { return 0 }
func (t *noopBuffer) Close() error { t.isClosed = true; return nil }
func (t *noopBuffer) Closed() bool { return t.isClosed }
type noopReadyList struct{}
func (_ *noopReadyList) Signal(streamID uint32) {}
func TestAbort(t *testing.T) {
const numStreams = 1000
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
var openedStreams sync.Map
// Add all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
stream := &MuxedStream{
streamID: uint32(streamID),
readBuffer: &noopBuffer{},
writeBuffer: &noopBuffer{},
readyList: &noopReadyList{},
}
ok := m.Set(stream)
assert.True(t, ok)
openedStreams.Store(stream.streamID, stream)
}(i)
}
wg.Wait()
}
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
shutdownChan, alreadyInProgress := m.Shutdown()
select {
case <-shutdownChan:
assert.Fail(t, "before Abort(), shutdownChan shouldn't be closed")
default:
}
assert.False(t, alreadyInProgress)
m.Abort()
assert.Equal(t, numStreams, m.Len(), "Abort() shouldn't delete any streams")
openedStreams.Range(func(key interface{}, value interface{}) bool {
stream := value.(*MuxedStream)
readBuffer := stream.readBuffer.(*noopBuffer)
writeBuffer := stream.writeBuffer.(*noopBuffer)
return assert.True(t, readBuffer.isClosed && writeBuffer.isClosed, "Abort() should have closed all the streams")
})
select {
case <-shutdownChan:
default:
assert.Fail(t, "after Abort(), shutdownChan should have been closed")
}
// multiple aborts shouldn't cause any issues
m.Abort()
m.Abort()
m.Abort()
}

View File

@ -20,11 +20,12 @@ var (
ErrUnknownStream = MuxerProtocolError{"2002 unknown stream", http2.ErrCodeProtocol}
ErrInvalidStream = MuxerProtocolError{"2003 invalid stream", http2.ErrCodeProtocol}
ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"}
ErrConnectionClosed = MuxerApplicationError{"3001 connection closed"}
ErrConnectionDropped = MuxerApplicationError{"3002 connection dropped"}
ErrOpenStreamTimeout = MuxerApplicationError{"3003 open stream timeout"}
ErrResponseHeadersTimeout = MuxerApplicationError{"3004 timeout waiting for initial response headers"}
ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"}
ErrStreamRequestConnectionClosed = MuxerApplicationError{"3001 connection closed while opening stream"}
ErrConnectionDropped = MuxerApplicationError{"3002 connection dropped"}
ErrStreamRequestTimeout = MuxerApplicationError{"3003 open stream timeout"}
ErrResponseHeadersTimeout = MuxerApplicationError{"3004 timeout waiting for initial response headers"}
ErrResponseHeadersConnectionClosed = MuxerApplicationError{"3005 connection closed while waiting for initial response headers"}
ErrClosedStream = MuxerStreamError{"4000 stream closed", http2.ErrCodeStreamClosed}
)

View File

@ -542,7 +542,10 @@ func (w *h2DictWriter) Write(p []byte) (n int, err error) {
}
func (w *h2DictWriter) Close() error {
return w.comp.Close()
if w.comp != nil {
return w.comp.Close()
}
return nil
}
// From http2/hpack

View File

@ -1,13 +1,13 @@
package h2mux
import (
"bytes"
"context"
"io"
"strings"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
@ -20,7 +20,7 @@ const (
maxWindowSize uint32 = (1 << 31) - 1 // 2^31-1 = 2147483647, max window size in http2 spec
defaultTimeout time.Duration = 5 * time.Second
defaultRetries uint64 = 5
defaultWriteBufferMaxLen int = 1024 * 1024 * 512 // 500mb
defaultWriteBufferMaxLen int = 1024 * 1024 // 1mb
SettingMuxerMagic http2.SettingID = 0x42db
MuxerMagicOrigin uint32 = 0xa2e43c8b
@ -108,6 +108,7 @@ func Handshake(
w io.WriteCloser,
r io.ReadCloser,
config MuxerConfig,
activeStreamsMetrics prometheus.Gauge,
) (*Muxer, error) {
// Set default config values
if config.Timeout == 0 {
@ -131,7 +132,7 @@ func Handshake(
newStreamChan: make(chan MuxedStreamRequest),
abortChan: make(chan struct{}),
readyList: NewReadyList(),
streams: newActiveStreamMap(config.IsClient),
streams: newActiveStreamMap(config.IsClient, activeStreamsMetrics),
}
m.f.ReadMetaHeaders = hpack.NewDecoder(4096, func(hpack.HeaderField) {})
@ -352,9 +353,11 @@ func (m *Muxer) Serve(ctx context.Context) error {
}
// Shutdown is called to initiate the "happy path" of muxer termination.
func (m *Muxer) Shutdown() {
// It blocks new streams from being created.
// It returns a channel that is closed when the last stream has been closed.
func (m *Muxer) Shutdown() <-chan struct{} {
m.explicitShutdown.Fuse(true)
m.muxReader.Shutdown()
return m.muxReader.Shutdown()
}
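// Callers that previously fired and forgot Shutdown can now block on the
// returned channel, e.g. (sketch, assuming a *Muxer m):
//
//	<-m.Shutdown() // returns once the last open stream has been closed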
// IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
@ -388,72 +391,58 @@ func isConnectionClosedError(err error) bool {
// OpenStream opens a new data stream with the given headers.
// Called by proxy server and tunnel
func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader) (*MuxedStream, error) {
stream := &MuxedStream{
responseHeadersReceived: make(chan struct{}),
readBuffer: NewSharedBuffer(),
writeBuffer: &bytes.Buffer{},
writeBufferMaxLen: m.config.StreamWriteBufferMaxLen,
writeBufferHasSpace: make(chan struct{}, 1),
receiveWindow: m.config.DefaultWindowSize,
receiveWindowCurrentMax: m.config.DefaultWindowSize,
receiveWindowMax: m.config.MaxWindowSize,
sendWindow: m.config.DefaultWindowSize,
readyList: m.readyList,
writeHeaders: headers,
dictionaries: m.muxReader.dictionaries,
stream := m.NewStream(headers)
if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, body)); err != nil {
return nil, err
}
select {
// Will be received by mux writer
case <-ctx.Done():
return nil, ErrOpenStreamTimeout
case <-m.abortChan:
return nil, ErrConnectionClosed
case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: body}:
}
select {
case <-ctx.Done():
return nil, ErrResponseHeadersTimeout
case <-m.abortChan:
return nil, ErrConnectionClosed
case <-stream.responseHeadersReceived:
return stream, nil
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
return nil, err
}
return stream, nil
}
func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) {
stream := &MuxedStream{
responseHeadersReceived: make(chan struct{}),
readBuffer: NewSharedBuffer(),
writeBuffer: &bytes.Buffer{},
writeBufferMaxLen: m.config.StreamWriteBufferMaxLen,
writeBufferHasSpace: make(chan struct{}, 1),
receiveWindow: m.config.DefaultWindowSize,
receiveWindowCurrentMax: m.config.DefaultWindowSize,
receiveWindowMax: m.config.MaxWindowSize,
sendWindow: m.config.DefaultWindowSize,
readyList: m.readyList,
writeHeaders: RPCHeaders(),
dictionaries: m.muxReader.dictionaries,
stream := m.NewStream(RPCHeaders())
if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, nil)); err != nil {
return nil, err
}
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
return nil, err
}
return stream, nil
}
func (m *Muxer) NewStream(headers []Header) *MuxedStream {
return NewStream(m.config, headers, m.readyList, m.muxReader.dictionaries)
}
func (m *Muxer) MakeMuxedStreamRequest(ctx context.Context, request MuxedStreamRequest) error {
select {
case <-ctx.Done():
return ErrStreamRequestTimeout
case <-m.abortChan:
return ErrStreamRequestConnectionClosed
// Will be received by mux writer
case <-ctx.Done():
return nil, ErrOpenStreamTimeout
case <-m.abortChan:
return nil, ErrConnectionClosed
case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: nil}:
case m.newStreamChan <- request:
return nil
}
}
func (m *Muxer) CloseStreamRead(stream *MuxedStream) {
stream.CloseRead()
if stream.WriteClosed() {
m.streams.Delete(stream.streamID)
}
}
func (m *Muxer) AwaitResponseHeaders(ctx context.Context, stream *MuxedStream) error {
select {
case <-ctx.Done():
return nil, ErrResponseHeadersTimeout
return ErrResponseHeadersTimeout
case <-m.abortChan:
return nil, ErrConnectionClosed
return ErrResponseHeadersConnectionClosed
case <-stream.responseHeadersReceived:
return stream, nil
return nil
}
}
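// With the refactor above, OpenStream is a thin composition of three steps,
// which callers can also drive individually (sketch, assuming a *Muxer m plus
// the desired headers and body):
//
//	stream := m.NewStream(headers)
//	if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, body)); err != nil {
//		return nil, err
//	}
//	if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
//		return nil, err
//	}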

View File

@ -43,7 +43,7 @@ type DefaultMuxerPair struct {
doneC chan struct{}
}
func NewDefaultMuxerPair(t assert.TestingT, f MuxedStreamFunc) *DefaultMuxerPair {
func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc) *DefaultMuxerPair {
origin, edge := net.Pipe()
p := &DefaultMuxerPair{
OriginMuxConfig: MuxerConfig{
@ -55,6 +55,8 @@ func NewDefaultMuxerPair(t assert.TestingT, f MuxedStreamFunc) *DefaultMuxerPair
DefaultWindowSize: (1 << 8) - 1,
MaxWindowSize: (1 << 15) - 1,
StreamWriteBufferMaxLen: 1024,
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
},
OriginConn: origin,
EdgeMuxConfig: MuxerConfig{
@ -65,15 +67,17 @@ func NewDefaultMuxerPair(t assert.TestingT, f MuxedStreamFunc) *DefaultMuxerPair
DefaultWindowSize: (1 << 8) - 1,
MaxWindowSize: (1 << 15) - 1,
StreamWriteBufferMaxLen: 1024,
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
},
EdgeConn: edge,
doneC: make(chan struct{}),
}
assert.NoError(t, p.Handshake())
assert.NoError(t, p.Handshake(testName))
return p
}
func NewCompressedMuxerPair(t assert.TestingT, quality CompressionSetting, f MuxedStreamFunc) *DefaultMuxerPair {
func NewCompressedMuxerPair(t assert.TestingT, testName string, quality CompressionSetting, f MuxedStreamFunc) *DefaultMuxerPair {
origin, edge := net.Pipe()
p := &DefaultMuxerPair{
OriginMuxConfig: MuxerConfig{
@ -83,6 +87,8 @@ func NewCompressedMuxerPair(t assert.TestingT, quality CompressionSetting, f Mux
Name: "origin",
CompressionQuality: quality,
Logger: log.NewEntry(log.New()),
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
},
OriginConn: origin,
EdgeMuxConfig: MuxerConfig{
@ -91,24 +97,26 @@ func NewCompressedMuxerPair(t assert.TestingT, quality CompressionSetting, f Mux
Name: "edge",
CompressionQuality: quality,
Logger: log.NewEntry(log.New()),
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
},
EdgeConn: edge,
doneC: make(chan struct{}),
}
assert.NoError(t, p.Handshake())
assert.NoError(t, p.Handshake(testName))
return p
}
func (p *DefaultMuxerPair) Handshake() error {
func (p *DefaultMuxerPair) Handshake(testName string) error {
ctx, cancel := context.WithTimeout(context.Background(), testHandshakeTimeout)
defer cancel()
errGroup, _ := errgroup.WithContext(ctx)
errGroup.Go(func() (err error) {
p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig)
p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig, NewActiveStreamsMetrics(testName, "edge"))
return errors.Wrap(err, "edge handshake failure")
})
errGroup.Go(func() (err error) {
p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig)
p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig, NewActiveStreamsMetrics(testName, "origin"))
return errors.Wrap(err, "origin handshake failure")
})
@ -161,7 +169,7 @@ func TestHandshake(t *testing.T) {
f := func(stream *MuxedStream) error {
return nil
}
muxPair := NewDefaultMuxerPair(t, f)
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
AssertIfPipeReadable(t, muxPair.OriginConn)
AssertIfPipeReadable(t, muxPair.EdgeConn)
}
@ -191,7 +199,7 @@ func TestSingleStream(t *testing.T) {
}
return nil
})
muxPair := NewDefaultMuxerPair(t, f)
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
muxPair.Serve(t)
stream, err := muxPair.OpenEdgeMuxStream(
@ -262,7 +270,7 @@ func TestSingleStreamLargeResponseBody(t *testing.T) {
return nil
})
muxPair := NewDefaultMuxerPair(t, f)
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
muxPair.Serve(t)
stream, err := muxPair.OpenEdgeMuxStream(
@ -309,7 +317,7 @@ func TestMultipleStreams(t *testing.T) {
log.Debugf("Wrote body for stream %s", stream.Headers[0].Value)
return nil
})
muxPair := NewDefaultMuxerPair(t, f)
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
muxPair.Serve(t)
maxStreams := 64
@ -402,7 +410,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) {
}
return nil
})
muxPair := NewDefaultMuxerPair(t, f)
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
muxPair.Serve(t)
errGroup, _ := errgroup.WithContext(context.Background())
@ -461,7 +469,7 @@ func TestGracefulShutdown(t *testing.T) {
log.Debugf("Handler ends")
return nil
})
muxPair := NewDefaultMuxerPair(t, f)
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
muxPair.Serve(t)
stream, err := muxPair.OpenEdgeMuxStream(
@ -516,7 +524,7 @@ func TestUnexpectedShutdown(t *testing.T) {
}
return nil
})
muxPair := NewDefaultMuxerPair(t, f)
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
muxPair.Serve(t)
stream, err := muxPair.OpenEdgeMuxStream(
@ -564,7 +572,7 @@ func EchoHandler(stream *MuxedStream) error {
func TestOpenAfterDisconnect(t *testing.T) {
for i := 0; i < 3; i++ {
muxPair := NewDefaultMuxerPair(t, EchoHandler)
muxPair := NewDefaultMuxerPair(t, fmt.Sprintf("%s_%d", t.Name(), i), EchoHandler)
muxPair.Serve(t)
switch i {
@ -584,14 +592,14 @@ func TestOpenAfterDisconnect(t *testing.T) {
[]Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != ErrConnectionClosed {
t.Fatalf("unexpected error in OpenStream: %s", err)
if err != ErrStreamRequestConnectionClosed && err != ErrResponseHeadersConnectionClosed {
t.Fatalf("case %v: unexpected error in OpenStream: %v", i, err)
}
}
}
func TestHPACK(t *testing.T) {
muxPair := NewDefaultMuxerPair(t, EchoHandler)
muxPair := NewDefaultMuxerPair(t, t.Name(), EchoHandler)
muxPair.Serve(t)
stream, err := muxPair.OpenEdgeMuxStream(
@ -724,7 +732,7 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) {
return nil
})
muxPair := NewCompressedMuxerPair(t, q, f)
muxPair := NewCompressedMuxerPair(t, fmt.Sprintf("%s_%d", t.Name(), q), q, f)
muxPair.Serve(t)
var wg sync.WaitGroup
@ -918,7 +926,7 @@ func TestSampleSiteWithDictionaries(t *testing.T) {
assert.NoError(t, err)
for q := CompressionNone; q <= CompressionMax; q++ {
muxPair := NewCompressedMuxerPair(t, q, sampleSiteHandler(files))
muxPair := NewCompressedMuxerPair(t, fmt.Sprintf("%s_%d", t.Name(), q), q, sampleSiteHandler(files))
muxPair.Serve(t)
var wg sync.WaitGroup
@ -957,7 +965,7 @@ func TestLongSiteWithDictionaries(t *testing.T) {
files, err := loadSampleFiles(paths)
assert.NoError(t, err)
for q := CompressionNone; q <= CompressionMedium; q++ {
muxPair := NewCompressedMuxerPair(t, q, sampleSiteHandler(files))
muxPair := NewCompressedMuxerPair(t, fmt.Sprintf("%s_%d", t.Name(), q), q, sampleSiteHandler(files))
muxPair.Serve(t)
rand.Seed(time.Now().Unix())
@ -998,7 +1006,7 @@ func BenchmarkOpenStream(b *testing.B) {
})
return nil
})
muxPair := NewDefaultMuxerPair(b, f)
muxPair := NewDefaultMuxerPair(b, fmt.Sprintf("%s_%d", b.Name(), i), f)
muxPair.Serve(b)
b.StartTimer()
openStreams(b, muxPair, streams)
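A likely reason for threading testName through both pair constructors (an inference; the diff does not state it): NewActiveStreamsMetrics calls prometheus.MustRegister, which panics if a collector with the same namespace, subsystem and name is registered twice, so every muxer pair in the test binary needs its own namespace. A minimal, hedged illustration of that constraint:

package main

import (
    "fmt"

    "github.com/prometheus/client_golang/prometheus"
)

func activeStreamsGauge(namespace string) prometheus.Gauge {
    return prometheus.NewGauge(prometheus.GaugeOpts{
        Namespace: namespace,
        Subsystem: "edge",
        Name:      "active_streams",
    })
}

func main() {
    // First registration succeeds.
    if err := prometheus.Register(activeStreamsGauge("TestHandshake")); err != nil {
        fmt.Println("unexpected:", err)
    }
    // Same namespace/subsystem/name again: Register reports a duplicate, and
    // MustRegister would panic, hence the per-test names in the pairs above.
    if err := prometheus.Register(activeStreamsGauge("TestHandshake")); err != nil {
        fmt.Println("duplicate registration rejected:", err)
    }
}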

View File

@ -17,6 +17,12 @@ type ReadWriteClosedCloser interface {
Closed() bool
}
// MuxedStreamDataSignaller is a write-only *ReadyList
type MuxedStreamDataSignaller interface {
// Non-blocking: call this when data is ready to be sent for the given stream ID.
Signal(ID uint32)
}
// MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data.
type MuxedStream struct {
streamID uint32
@ -55,8 +61,8 @@ type MuxedStream struct {
// This is the amount of bytes that are in the peer's receive window
// (how much data we can send from this stream).
sendWindow uint32
// Reference to the muxer's readyList; signal this for stream data to be sent.
readyList *ReadyList
// The muxer's readyList
readyList MuxedStreamDataSignaller
// The headers that should be sent, and a flag so we only send them once.
headersSent bool
writeHeaders []Header
@ -88,6 +94,23 @@ func (th TunnelHostname) IsSet() bool {
return th != ""
}
func NewStream(config MuxerConfig, writeHeaders []Header, readyList MuxedStreamDataSignaller, dictionaries h2Dictionaries) *MuxedStream {
return &MuxedStream{
responseHeadersReceived: make(chan struct{}),
readBuffer: NewSharedBuffer(),
writeBuffer: &bytes.Buffer{},
writeBufferMaxLen: config.StreamWriteBufferMaxLen,
writeBufferHasSpace: make(chan struct{}, 1),
receiveWindow: config.DefaultWindowSize,
receiveWindowCurrentMax: config.DefaultWindowSize,
receiveWindowMax: config.MaxWindowSize,
sendWindow: config.DefaultWindowSize,
readyList: readyList,
writeHeaders: writeHeaders,
dictionaries: dictionaries,
}
}
func (s *MuxedStream) Read(p []byte) (n int, err error) {
var readBuffer ReadWriteClosedCloser
if s.dictionaries.read != nil {
@ -120,9 +143,10 @@ func (s *MuxedStream) Write(p []byte) (int, error) {
// If the buffer is full, block till there is more room.
// Use a loop to recheck the buffer size after the lock is reacquired.
for s.writeBufferMaxLen <= s.writeBuffer.Len() {
s.writeLock.Unlock()
<-s.writeBufferHasSpace
s.writeLock.Lock()
s.awaitWriteBufferHasSpace()
if s.writeEOF {
return totalWritten, io.EOF
}
}
amountToWrite := len(p) - totalWritten
spaceAvailable := s.writeBufferMaxLen - s.writeBuffer.Len()
@ -171,10 +195,19 @@ func (s *MuxedStream) CloseWrite() error {
if c, ok := s.writeBuffer.(io.Closer); ok {
c.Close()
}
// Allow MuxedStream::Write() to terminate its loop with err=io.EOF, if needed
s.notifyWriteBufferHasSpace()
// We need to send something over the wire, even if it's an END_STREAM with no data
s.writeNotify()
return nil
}
func (s *MuxedStream) WriteClosed() bool {
s.writeLock.Lock()
defer s.writeLock.Unlock()
return s.writeEOF
}
func (s *MuxedStream) WriteHeaders(headers []Header) error {
s.writeLock.Lock()
defer s.writeLock.Unlock()
@ -215,6 +248,23 @@ func (s *MuxedStream) TunnelHostname() TunnelHostname {
return s.tunnelHostname
}
// Block until a value is sent on writeBufferHasSpace.
// Must be called while holding writeLock
func (s *MuxedStream) awaitWriteBufferHasSpace() {
s.writeLock.Unlock()
<-s.writeBufferHasSpace
s.writeLock.Lock()
}
// Send a value on writeBufferHasSpace without blocking.
// Must be called while holding writeLock
func (s *MuxedStream) notifyWriteBufferHasSpace() {
select {
case s.writeBufferHasSpace <- struct{}{}:
default:
}
}
func (s *MuxedStream) getReceiveWindow() uint32 {
s.writeLock.Lock()
defer s.writeLock.Unlock()
@ -334,17 +384,13 @@ func (s *MuxedStream) getChunk() *streamChunk {
sendData: !s.sentEOF,
eof: s.writeEOF && uint32(s.writeBuffer.Len()) <= s.sendWindow,
}
// Copy at most s.sendWindow bytes, adjust the sendWindow accordingly
writeLen, _ := io.CopyN(&chunk.buffer, s.writeBuffer, int64(s.sendWindow))
s.sendWindow -= uint32(writeLen)
// Non-blocking channel send. This will allow MuxedStream::Write() to continue, if needed
// Allow MuxedStream::Write() to continue, if needed
if s.writeBuffer.Len() < s.writeBufferMaxLen {
select {
case s.writeBufferHasSpace <- struct{}{}:
default:
}
s.notifyWriteBufferHasSpace()
}
// When we write the chunk, we'll write the WINDOW_UPDATE frame if needed
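The awaitWriteBufferHasSpace/notifyWriteBufferHasSpace pair is a common Go idiom: a buffered channel of capacity one coalesces wake-ups, and a select with a default branch keeps the notify side non-blocking. A stripped-down, self-contained sketch of the same idiom (names are illustrative, not from the muxer):

package main

import (
    "fmt"
    "sync"
)

// spaceSignal coalesces "there is room now" notifications: because the channel
// has capacity one, repeated notifies while nobody is waiting collapse into one.
type spaceSignal struct {
    mu      sync.Mutex
    hasRoom chan struct{}
}

func newSpaceSignal() *spaceSignal {
    return &spaceSignal{hasRoom: make(chan struct{}, 1)}
}

// notify never blocks; if a wake-up is already pending it is a no-op.
func (s *spaceSignal) notify() {
    select {
    case s.hasRoom <- struct{}{}:
    default:
    }
}

// await releases the lock while blocked, then reacquires it, mirroring
// awaitWriteBufferHasSpace above. Must be called while holding mu.
func (s *spaceSignal) await() {
    s.mu.Unlock()
    <-s.hasRoom
    s.mu.Lock()
}

func main() {
    s := newSpaceSignal()
    s.notify()
    s.notify() // coalesced with the first notify
    s.mu.Lock()
    s.await() // returns immediately because a wake-up is pending
    s.mu.Unlock()
    fmt.Println("woke up once for two notifies")
}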

View File

@ -5,6 +5,7 @@ import (
"time"
"github.com/golang-collections/collections/queue"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
)
@ -299,3 +300,14 @@ func (r *rate) get() (curr, min, max uint64) {
defer r.lock.RUnlock()
return r.curr, r.min, r.max
}
func NewActiveStreamsMetrics(namespace, subsystem string) prometheus.Gauge {
activeStreams := prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "active_streams",
Help: "Number of active streams created by all muxers.",
})
prometheus.MustRegister(activeStreams)
return activeStreams
}
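How the gauge is driven is not shown in this diff; the usual pattern, and presumably what the muxer does, is to Inc when a stream is created and Dec when it is removed. A hedged sketch of that accounting:

package main

import "github.com/prometheus/client_golang/prometheus"

// trackStream wraps a unit of work with active-stream accounting. The Inc/Dec
// placement is an assumption about the muxer internals, which this diff omits.
func trackStream(active prometheus.Gauge, serve func() error) error {
    active.Inc()
    defer active.Dec()
    return serve()
}

func main() {
    active := prometheus.NewGauge(prometheus.GaugeOpts{
        Namespace: "cloudflared",
        Subsystem: "tunnel",
        Name:      "active_streams",
        Help:      "Number of active streams created by all muxers.",
    })
    prometheus.MustRegister(active)
    _ = trackStream(active, func() error { return nil })
}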

View File

@ -51,10 +51,12 @@ type MuxReader struct {
dictionaries h2Dictionaries
}
func (r *MuxReader) Shutdown() {
done := r.streams.Shutdown()
if done == nil {
return
// Shutdown blocks new streams from being created.
// It returns a channel that is closed once the last stream has closed.
func (r *MuxReader) Shutdown() <-chan struct{} {
done, alreadyInProgress := r.streams.Shutdown()
if alreadyInProgress {
return done
}
r.sendGoAway(http2.ErrCodeNo)
go func() {
@ -62,6 +64,7 @@ func (r *MuxReader) Shutdown() {
<-done
r.r.Close()
}()
return done
}
func (r *MuxReader) run(parentLogger *log.Entry) error {
@ -87,23 +90,28 @@ func (r *MuxReader) run(parentLogger *log.Entry) error {
for {
frame, err := r.f.ReadFrame()
if err != nil {
errLogger := logger.WithError(err)
if errorDetail := r.f.ErrorDetail(); errorDetail != nil {
errLogger = errLogger.WithField("errorDetail", errorDetail)
}
switch e := err.(type) {
case http2.StreamError:
logger.WithError(err).Warn("stream error")
errLogger.Warn("stream error")
r.streamError(e.StreamID, e.Code)
case http2.ConnectionError:
logger.WithError(err).Warn("connection error")
errLogger.Warn("connection error")
return r.connectionError(err)
default:
if isConnectionClosedError(err) {
if r.streams.Len() == 0 {
// don't log the error here -- that would just be extra noise
logger.Debug("shutting down")
return nil
}
logger.Warn("connection closed unexpectedly")
errLogger.Warn("connection closed unexpectedly")
return err
} else {
logger.WithError(err).Warn("frame read error")
errLogger.Warn("frame read error")
return r.connectionError(err)
}
}
@ -120,6 +128,9 @@ func (r *MuxReader) run(parentLogger *log.Entry) error {
if streamID == 0 {
return ErrInvalidStream
}
if stream, ok := r.streams.Get(streamID); ok {
stream.Close()
}
r.streams.Delete(streamID)
case *http2.PingFrame:
r.receivePingData(f)
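A hedged sketch of the caller side implied by the new Shutdown signature (illustrative only; the surrounding muxer code is not part of this diff): wait on the returned channel so the connection is only torn down after the last stream has drained.

// Illustrative helper, assumed to live in package h2mux next to MuxReader.
package h2mux

import "time"

func shutdownAndWait(r *MuxReader, timeout time.Duration) {
    done := r.Shutdown() // blocks new streams; safe to call more than once
    select {
    case <-done:
        // the last stream closed and the reader has closed the connection
    case <-time.After(timeout):
        // give up waiting; the caller can force-close the connection instead
    }
}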

View File

@ -59,7 +59,7 @@ func assertOpenStreamSucceed(t *testing.T, stream *MuxedStream, err error) {
func TestMissingHeaders(t *testing.T) {
originHandler := &mockOriginStreamHandler{}
muxPair := NewDefaultMuxerPair(t, originHandler.ServeStream)
muxPair := NewDefaultMuxerPair(t, t.Name(), originHandler.ServeStream)
muxPair.Serve(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
@ -83,7 +83,7 @@ func TestMissingHeaders(t *testing.T) {
func TestReceiveHeaderData(t *testing.T) {
originHandler := &mockOriginStreamHandler{}
muxPair := NewDefaultMuxerPair(t, originHandler.ServeStream)
muxPair := NewDefaultMuxerPair(t, t.Name(), originHandler.ServeStream)
muxPair.Serve(t)
reqHeaders := []Header{

View File

@ -54,6 +54,13 @@ type MuxedStreamRequest struct {
body io.Reader
}
func NewMuxedStreamRequest(stream *MuxedStream, body io.Reader) MuxedStreamRequest {
return MuxedStreamRequest{
stream: stream,
body: body,
}
}
func (r *MuxedStreamRequest) flushBody() {
io.Copy(r.stream, r.body)
r.stream.CloseWrite()

View File

@ -103,15 +103,16 @@ func StartHelloWorldServer(logger *logrus.Logger, listener net.Listener, shutdow
WriteBufferSize: 1024,
}
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
muxer := http.NewServeMux()
muxer.HandleFunc("/uptime", uptimeHandler(time.Now()))
muxer.HandleFunc("/ws", websocketHandler(logger, upgrader))
muxer.HandleFunc("/", rootHandler(serverName))
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer}
go func() {
<-shutdownC
httpServer.Close()
}()
http.HandleFunc("/uptime", uptimeHandler(time.Now()))
http.HandleFunc("/ws", websocketHandler(logger, upgrader))
http.HandleFunc("/", rootHandler(serverName))
err := httpServer.Serve(listener)
return err
}
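Switching from the package-level http.HandleFunc to a dedicated http.ServeMux keeps these routes off http.DefaultServeMux, which matters because registering the same pattern twice on a mux panics and because other components in the same process (such as the metrics endpoint below) also register handlers. A minimal sketch of the dedicated-mux pattern:

package main

import (
    "fmt"
    "log"
    "net"
    "net/http"
)

func serveHello(listener net.Listener) error {
    // A per-server mux: these patterns never touch http.DefaultServeMux, so the
    // server can be constructed repeatedly (e.g. in tests) without panicking.
    mux := http.NewServeMux()
    mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
        fmt.Fprintln(w, "Hello, World!")
    })
    server := &http.Server{Addr: listener.Addr().String(), Handler: mux}
    return server.Serve(listener)
}

func main() {
    l, err := net.Listen("tcp", "127.0.0.1:0")
    if err != nil {
        log.Fatal(err)
    }
    log.Fatal(serveHello(l))
}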

View File

@ -2,6 +2,7 @@ package metrics
import (
"context"
"fmt"
"net"
"net/http"
_ "net/http/pprof"
@ -33,6 +34,9 @@ func ServeMetrics(l net.Listener, shutdownC <-chan struct{}, logger *logrus.Logg
}
http.Handle("/metrics", promhttp.Handler())
http.Handle("/healthcheck", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "OK\n")
}))
wg.Add(1)
go func() {

View File

@ -92,3 +92,8 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
}
return b.BaseTime
}
// Retries returns the number of retries consumed so far.
func (b *BackoffHandler) Retries() int {
return int(b.retries)
}
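Retries is presumably exposed so callers can report how many attempts a backoff has already consumed; the supervisor change further down passes it as NumPreviousAttempts. A small hedged sketch of that pattern, with the retry loop simplified:

// Illustrative helper, assumed to live in package origin next to BackoffHandler.
package origin

import (
    "context"
    "fmt"
)

func attemptWithBackoff(ctx context.Context, do func(previousAttempts int) error) error {
    backoff := BackoffHandler{MaxRetries: 3}
    for {
        err := do(backoff.Retries()) // tell the callee how many times we already tried
        if err == nil {
            backoff.SetGracePeriod()
            return nil
        }
        duration, ok := backoff.GetBackoffDuration(ctx)
        if !ok {
            return fmt.Errorf("retries exhausted: %v", err)
        }
        fmt.Printf("retrying in %v\n", duration)
        backoff.Backoff(ctx)
    }
}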

View File

@ -9,6 +9,12 @@ import (
"github.com/prometheus/client_golang/prometheus"
)
const (
metricsNamespace = "cloudflared"
tunnelSubsystem = "tunnel"
muxerSubsystem = "muxer"
)
type muxerMetrics struct {
rtt *prometheus.GaugeVec
rttMin *prometheus.GaugeVec
@ -32,6 +38,7 @@ type muxerMetrics struct {
type TunnelMetrics struct {
haConnections prometheus.Gauge
activeStreams prometheus.Gauge
totalRequests prometheus.Counter
requestsPerTunnel *prometheus.CounterVec
// concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests
@ -63,8 +70,10 @@ type TunnelMetrics struct {
func newMuxerMetrics() *muxerMetrics {
rtt := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "rtt",
Help: "Round-trip time in millisecond",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "rtt",
Help: "Round-trip time in millisecond",
},
[]string{"connection_id"},
)
@ -72,8 +81,10 @@ func newMuxerMetrics() *muxerMetrics {
rttMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "rtt_min",
Help: "Shortest round-trip time in millisecond",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "rtt_min",
Help: "Shortest round-trip time in millisecond",
},
[]string{"connection_id"},
)
@ -81,8 +92,10 @@ func newMuxerMetrics() *muxerMetrics {
rttMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "rtt_max",
Help: "Longest round-trip time in millisecond",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "rtt_max",
Help: "Longest round-trip time in millisecond",
},
[]string{"connection_id"},
)
@ -90,8 +103,10 @@ func newMuxerMetrics() *muxerMetrics {
receiveWindowAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "receive_window_ave",
Help: "Average receive window size in bytes",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "receive_window_ave",
Help: "Average receive window size in bytes",
},
[]string{"connection_id"},
)
@ -99,8 +114,10 @@ func newMuxerMetrics() *muxerMetrics {
sendWindowAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "send_window_ave",
Help: "Average send window size in bytes",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "send_window_ave",
Help: "Average send window size in bytes",
},
[]string{"connection_id"},
)
@ -108,8 +125,10 @@ func newMuxerMetrics() *muxerMetrics {
receiveWindowMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "receive_window_min",
Help: "Smallest receive window size in bytes",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "receive_window_min",
Help: "Smallest receive window size in bytes",
},
[]string{"connection_id"},
)
@ -117,8 +136,10 @@ func newMuxerMetrics() *muxerMetrics {
receiveWindowMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "receive_window_max",
Help: "Largest receive window size in bytes",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "receive_window_max",
Help: "Largest receive window size in bytes",
},
[]string{"connection_id"},
)
@ -126,8 +147,10 @@ func newMuxerMetrics() *muxerMetrics {
sendWindowMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "send_window_min",
Help: "Smallest send window size in bytes",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "send_window_min",
Help: "Smallest send window size in bytes",
},
[]string{"connection_id"},
)
@ -135,8 +158,10 @@ func newMuxerMetrics() *muxerMetrics {
sendWindowMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "send_window_max",
Help: "Largest send window size in bytes",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "send_window_max",
Help: "Largest send window size in bytes",
},
[]string{"connection_id"},
)
@ -144,8 +169,10 @@ func newMuxerMetrics() *muxerMetrics {
inBoundRateCurr := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "inbound_bytes_per_sec_curr",
Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "inbound_bytes_per_sec_curr",
Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
},
[]string{"connection_id"},
)
@ -153,8 +180,10 @@ func newMuxerMetrics() *muxerMetrics {
inBoundRateMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "inbound_bytes_per_sec_min",
Help: "Minimum non-zero inbounding bytes per second",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "inbound_bytes_per_sec_min",
Help: "Minimum non-zero inbounding bytes per second",
},
[]string{"connection_id"},
)
@ -162,8 +191,10 @@ func newMuxerMetrics() *muxerMetrics {
inBoundRateMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "inbound_bytes_per_sec_max",
Help: "Maximum inbounding bytes per second",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "inbound_bytes_per_sec_max",
Help: "Maximum inbounding bytes per second",
},
[]string{"connection_id"},
)
@ -171,8 +202,10 @@ func newMuxerMetrics() *muxerMetrics {
outBoundRateCurr := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "outbound_bytes_per_sec_curr",
Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "outbound_bytes_per_sec_curr",
Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
},
[]string{"connection_id"},
)
@ -180,8 +213,10 @@ func newMuxerMetrics() *muxerMetrics {
outBoundRateMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "outbound_bytes_per_sec_min",
Help: "Minimum non-zero outbounding bytes per second",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "outbound_bytes_per_sec_min",
Help: "Minimum non-zero outbounding bytes per second",
},
[]string{"connection_id"},
)
@ -189,8 +224,10 @@ func newMuxerMetrics() *muxerMetrics {
outBoundRateMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "outbound_bytes_per_sec_max",
Help: "Maximum outbounding bytes per second",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "outbound_bytes_per_sec_max",
Help: "Maximum outbounding bytes per second",
},
[]string{"connection_id"},
)
@ -198,8 +235,10 @@ func newMuxerMetrics() *muxerMetrics {
compBytesBefore := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "comp_bytes_before",
Help: "Bytes sent via cross-stream compression, pre compression",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "comp_bytes_before",
Help: "Bytes sent via cross-stream compression, pre compression",
},
[]string{"connection_id"},
)
@ -207,8 +246,10 @@ func newMuxerMetrics() *muxerMetrics {
compBytesAfter := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "comp_bytes_after",
Help: "Bytes sent via cross-stream compression, post compression",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "comp_bytes_after",
Help: "Bytes sent via cross-stream compression, post compression",
},
[]string{"connection_id"},
)
@ -216,8 +257,10 @@ func newMuxerMetrics() *muxerMetrics {
compRateAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "comp_rate_ave",
Help: "Average outbound cross-stream compression ratio",
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "comp_rate_ave",
Help: "Average outbound cross-stream compression ratio",
},
[]string{"connection_id"},
)
@ -274,22 +317,30 @@ func convertRTTMilliSec(t time.Duration) float64 {
func NewTunnelMetrics() *TunnelMetrics {
haConnections := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "ha_connections",
Help: "Number of active ha connections",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "ha_connections",
Help: "Number of active ha connections",
})
prometheus.MustRegister(haConnections)
activeStreams := h2mux.NewActiveStreamsMetrics(metricsNamespace, tunnelSubsystem)
totalRequests := prometheus.NewCounter(
prometheus.CounterOpts{
Name: "total_requests",
Help: "Amount of requests proxied through all the tunnels",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "total_requests",
Help: "Amount of requests proxied through all the tunnels",
})
prometheus.MustRegister(totalRequests)
requestsPerTunnel := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "requests_per_tunnel",
Help: "Amount of requests proxied through each tunnel",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "requests_per_tunnel",
Help: "Amount of requests proxied through each tunnel",
},
[]string{"connection_id"},
)
@ -297,8 +348,10 @@ func NewTunnelMetrics() *TunnelMetrics {
concurrentRequestsPerTunnel := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "concurrent_requests_per_tunnel",
Help: "Concurrent requests proxied through each tunnel",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "concurrent_requests_per_tunnel",
Help: "Concurrent requests proxied through each tunnel",
},
[]string{"connection_id"},
)
@ -306,8 +359,10 @@ func NewTunnelMetrics() *TunnelMetrics {
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "max_concurrent_requests_per_tunnel",
Help: "Largest number of concurrent requests proxied through each tunnel so far",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "max_concurrent_requests_per_tunnel",
Help: "Largest number of concurrent requests proxied through each tunnel so far",
},
[]string{"connection_id"},
)
@ -315,15 +370,19 @@ func NewTunnelMetrics() *TunnelMetrics {
timerRetries := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "timer_retries",
Help: "Unacknowledged heart beats count",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "timer_retries",
Help: "Unacknowledged heart beats count",
})
prometheus.MustRegister(timerRetries)
responseByCode := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "response_by_code",
Help: "Count of responses by HTTP status code",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "response_by_code",
Help: "Count of responses by HTTP status code",
},
[]string{"status_code"},
)
@ -331,8 +390,10 @@ func NewTunnelMetrics() *TunnelMetrics {
responseCodePerTunnel := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "response_code_per_tunnel",
Help: "Count of responses by HTTP status code fore each tunnel",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "response_code_per_tunnel",
Help: "Count of responses by HTTP status code fore each tunnel",
},
[]string{"connection_id", "status_code"},
)
@ -340,8 +401,10 @@ func NewTunnelMetrics() *TunnelMetrics {
serverLocations := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "server_locations",
Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "server_locations",
Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.",
},
[]string{"connection_id", "location"},
)
@ -349,8 +412,10 @@ func NewTunnelMetrics() *TunnelMetrics {
rpcFail := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "tunnel_rpc_fail",
Help: "Count of RPC connection errors by type",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "tunnel_rpc_fail",
Help: "Count of RPC connection errors by type",
},
[]string{"error"},
)
@ -358,8 +423,10 @@ func NewTunnelMetrics() *TunnelMetrics {
registerFail := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "tunnel_register_fail",
Help: "Count of tunnel registration errors by type",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "tunnel_register_fail",
Help: "Count of tunnel registration errors by type",
},
[]string{"error"},
)
@ -367,8 +434,10 @@ func NewTunnelMetrics() *TunnelMetrics {
userHostnamesCounts := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "user_hostnames_counts",
Help: "Which user hostnames cloudflared is serving",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "user_hostnames_counts",
Help: "Which user hostnames cloudflared is serving",
},
[]string{"userHostname"},
)
@ -376,13 +445,16 @@ func NewTunnelMetrics() *TunnelMetrics {
registerSuccess := prometheus.NewCounter(
prometheus.CounterOpts{
Name: "tunnel_register_success",
Help: "Count of successful tunnel registrations",
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "tunnel_register_success",
Help: "Count of successful tunnel registrations",
})
prometheus.MustRegister(registerSuccess)
return &TunnelMetrics{
haConnections: haConnections,
activeStreams: activeStreams,
totalRequests: totalRequests,
requestsPerTunnel: requestsPerTunnel,
concurrentRequestsPerTunnel: concurrentRequestsPerTunnel,

View File

@ -2,16 +2,20 @@ package origin
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"sync"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/signal"
"github.com/google/uuid"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
@ -21,11 +25,23 @@ const (
resolveTTL = time.Hour
// Interval between registering new tunnels
registrationInterval = time.Second
subsystemRefreshAuth = "refresh_auth"
// Maximum exponent for 'Authenticate' exponential backoff
refreshAuthMaxBackoff = 10
// Waiting time before retrying a failed 'Authenticate' connection
refreshAuthRetryDuration = time.Second * 10
)
var (
errJWTUnset = errors.New("JWT unset")
errEventDigestUnset = errors.New("event digest unset")
)
type Supervisor struct {
config *TunnelConfig
edgeIPs []*net.TCPAddr
cloudflaredUUID uuid.UUID
config *TunnelConfig
edgeIPs []*net.TCPAddr
// nextUnusedEdgeIP is the index of the next addr in edgeIPs to try
nextUnusedEdgeIP int
lastResolve time.Time
@ -38,6 +54,12 @@ type Supervisor struct {
nextConnectedSignal chan struct{}
logger *logrus.Entry
jwtLock *sync.RWMutex
jwt []byte
eventDigestLock *sync.RWMutex
eventDigest []byte
}
type resolveResult struct {
@ -50,18 +72,21 @@ type tunnelError struct {
err error
}
func NewSupervisor(config *TunnelConfig) *Supervisor {
func NewSupervisor(config *TunnelConfig, u uuid.UUID) *Supervisor {
return &Supervisor{
cloudflaredUUID: u,
config: config,
tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{},
logger: config.Logger.WithField("subsystem", "supervisor"),
jwtLock: &sync.RWMutex{},
eventDigestLock: &sync.RWMutex{},
}
}
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
logger := s.config.Logger
if err := s.initialize(ctx, connectedSignal, u); err != nil {
if err := s.initialize(ctx, connectedSignal); err != nil {
return err
}
var tunnelsWaiting []int
@ -69,6 +94,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
var backoffTimer <-chan time.Time
tunnelsActive := s.config.HAConnections
refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
var refreshAuthBackoffTimer <-chan time.Time
if s.config.UseReconnectToken {
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
}
for {
select {
// Context cancelled
@ -104,10 +135,20 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
case <-backoffTimer:
backoffTimer = nil
for _, index := range tunnelsWaiting {
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u)
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
}
tunnelsActive += len(tunnelsWaiting)
tunnelsWaiting = nil
// Time to call Authenticate
case <-refreshAuthBackoffTimer:
newTimer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate)
if err != nil {
logger.WithError(err).Error("Authentication failed")
// Permanent failure. Leave the `select` without setting the
// channel to be non-null, so we'll never hit this case of the `select` again.
continue
}
refreshAuthBackoffTimer = newTimer
// Tunnel successfully connected
case <-s.nextConnectedSignal:
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
@ -128,7 +169,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
}
}
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
logger := s.logger
edgeIPs, err := s.resolveEdgeIPs()
@ -145,12 +186,12 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
s.lastResolve = time.Now()
// check entitlement and version too old error before attempting to register more tunnels
s.nextUnusedEdgeIP = s.config.HAConnections
go s.startFirstTunnel(ctx, connectedSignal, u)
go s.startFirstTunnel(ctx, connectedSignal)
select {
case <-ctx.Done():
<-s.tunnelErrors
// Error can't be nil. A nil error signals that initialization succeeded
return fmt.Errorf("context was canceled")
return ctx.Err()
case tunnelError := <-s.tunnelErrors:
return tunnelError.err
case <-connectedSignal.Wait():
@ -158,7 +199,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// At least one successful connection, so start the rest
for i := 1; i < s.config.HAConnections; i++ {
ch := signal.New(make(chan struct{}))
go s.startTunnel(ctx, i, ch, u)
go s.startTunnel(ctx, i, ch)
time.Sleep(registrationInterval)
}
return nil
@ -166,8 +207,8 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// startFirstTunnel starts the first tunnel connection. The resulting error will be sent on
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeeds
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
defer func() {
s.tunnelErrors <- tunnelError{index: 0, err: err}
}()
@ -183,19 +224,19 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return
// try the next address if it was a dialError(network problem) or
// dupConnRegisterTunnelError
case dialError, dupConnRegisterTunnelError:
case connection.DialError, dupConnRegisterTunnelError:
s.replaceEdgeIP(0)
default:
return
}
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
}
}
// startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u)
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID)
s.tunnelErrors <- tunnelError{index: index, err: err}
}
@ -253,3 +294,109 @@ func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
s.nextUnusedEdgeIP++
}
func (s *Supervisor) ReconnectToken() ([]byte, error) {
s.jwtLock.RLock()
defer s.jwtLock.RUnlock()
if s.jwt == nil {
return nil, errJWTUnset
}
return s.jwt, nil
}
func (s *Supervisor) SetReconnectToken(jwt []byte) {
s.jwtLock.Lock()
defer s.jwtLock.Unlock()
s.jwt = jwt
}
func (s *Supervisor) EventDigest() ([]byte, error) {
s.eventDigestLock.RLock()
defer s.eventDigestLock.RUnlock()
if s.eventDigest == nil {
return nil, errEventDigestUnset
}
return s.eventDigest, nil
}
func (s *Supervisor) SetEventDigest(eventDigest []byte) {
s.eventDigestLock.Lock()
defer s.eventDigestLock.Unlock()
s.eventDigest = eventDigest
}
func (s *Supervisor) refreshAuth(
ctx context.Context,
backoff *BackoffHandler,
authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
) (retryTimer <-chan time.Time, err error) {
logger := s.config.Logger.WithField("subsystem", subsystemRefreshAuth)
authOutcome, err := authenticate(ctx, backoff.Retries())
if err != nil {
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
logger.WithError(err).Warnf("Retrying in %v", duration)
return backoff.BackoffTimer(), nil
}
return nil, err
}
// clear backoff timer
backoff.SetGracePeriod()
switch outcome := authOutcome.(type) {
case tunnelpogs.AuthSuccess:
s.SetReconnectToken(outcome.JWT())
return timeAfter(outcome.RefreshAfter()), nil
case tunnelpogs.AuthUnknown:
return timeAfter(outcome.RefreshAfter()), nil
case tunnelpogs.AuthFail:
return nil, outcome
default:
return nil, fmt.Errorf("Unexpected outcome type %T", authOutcome)
}
}
func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
arbitraryEdgeIP := s.getEdgeIP(rand.Int())
edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP)
if err != nil {
return nil, err
}
defer edgeConn.Close()
handler := h2mux.MuxedStreamFunc(func(*h2mux.MuxedStream) error {
// This callback is invoked by h2mux when the edge initiates a stream.
return nil // noop
})
muxerConfig := s.config.muxerConfig(handler)
muxerConfig.Logger = muxerConfig.Logger.WithField("subsystem", subsystemRefreshAuth)
muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams)
if err != nil {
return nil, err
}
go muxer.Serve(ctx)
defer func() {
// If we don't wait for the muxer shutdown here, edgeConn.Close() runs before the muxer connections are done,
// and the user sees log noise: "error writing data", "connection closed unexpectedly"
<-muxer.Shutdown()
}()
tunnelServer, err := connection.NewRPCClient(ctx, muxer, s.logger.WithField("subsystem", subsystemRefreshAuth), openStreamTimeout)
if err != nil {
return nil, err
}
defer tunnelServer.Close()
const arbitraryConnectionID = uint8(0)
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
authResponse, err := tunnelServer.Authenticate(
ctx,
s.config.OriginCert,
s.config.Hostname,
registrationOptions,
)
if err != nil {
return nil, err
}
return authResponse.Outcome(), nil
}
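refreshAuth hands back a timer channel instead of sleeping so the Run loop can keep a single select; the tests below then stub package-level timeAfter and timeNow variables to observe the chosen delay. Their declarations are not shown in this diff; a minimal sketch of that test seam, under the assumption that they default to time.After and time.Now:

package main

import (
    "fmt"
    "time"
)

// Package-level seams: production code calls these, tests swap them out.
// (The exact declarations in package origin are an assumption.)
var (
    timeAfter = time.After
    timeNow   = time.Now
)

func scheduleRefresh(refreshAfter time.Duration) <-chan time.Time {
    return timeAfter(refreshAfter)
}

func main() {
    // In a test, capture the requested delay instead of really waiting.
    var captured time.Duration
    timeAfter = func(d time.Duration) <-chan time.Time {
        captured = d
        ch := make(chan time.Time, 1)
        ch <- timeNow()
        return ch
    }
    <-scheduleRefresh(19 * time.Hour)
    fmt.Println("refresh was scheduled for", captured)
}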

128
origin/supervisor_test.go Normal file
View File

@ -0,0 +1,128 @@
package origin
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
func TestRefreshAuthBackoff(t *testing.T) {
logger := logrus.New()
logger.Level = logrus.ErrorLevel
var wait time.Duration
timeAfter = func(d time.Duration) <-chan time.Time {
wait = d
return time.After(d)
}
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
backoff := &BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return nil, fmt.Errorf("authentication failure")
}
// authentication failures should consume the backoff
for i := uint(0); i < backoff.MaxRetries; i++ {
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
assert.NoError(t, err)
assert.NotNil(t, retryChan)
assert.Equal(t, (1<<i)*time.Second, wait)
}
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
assert.Error(t, err)
assert.Nil(t, retryChan)
// now we actually make contact with the remote server
_, _ = s.refreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
})
// The backoff timer should have been reset. To confirm this, make timeNow
// return a value after the backoff timer's grace period
timeNow = func() time.Time {
expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries)
return time.Now().Add(expectedGracePeriod * 2)
}
_, ok := backoff.GetBackoffDuration(context.Background())
assert.True(t, ok)
}
func TestRefreshAuthSuccess(t *testing.T) {
logger := logrus.New()
logger.Level = logrus.ErrorLevel
var wait time.Duration
timeAfter = func(d time.Duration) <-chan time.Time {
wait = d
return time.After(d)
}
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
backoff := &BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
}
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
assert.NoError(t, err)
assert.NotNil(t, retryChan)
assert.Equal(t, 19*time.Hour, wait)
token, err := s.ReconnectToken()
assert.NoError(t, err)
assert.Equal(t, []byte("jwt"), token)
}
func TestRefreshAuthUnknown(t *testing.T) {
logger := logrus.New()
logger.Level = logrus.ErrorLevel
var wait time.Duration
timeAfter = func(d time.Duration) <-chan time.Time {
wait = d
return time.After(d)
}
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
backoff := &BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
}
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
assert.NoError(t, err)
assert.NotNil(t, retryChan)
assert.Equal(t, 19*time.Hour, wait)
token, err := s.ReconnectToken()
assert.Equal(t, errJWTUnset, err)
assert.Nil(t, token)
}
func TestRefreshAuthFail(t *testing.T) {
logger := logrus.New()
logger.Level = logrus.ErrorLevel
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
backoff := &BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
}
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
assert.Error(t, err)
assert.Nil(t, retryChan)
token, err := s.ReconnectToken()
assert.Equal(t, errJWTUnset, err)
assert.Nil(t, token)
}

View File

@ -14,7 +14,14 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/streamhandler"
@ -22,20 +29,12 @@ import (
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
"github.com/cloudflare/cloudflared/websocket"
raven "github.com/getsentry/raven-go"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
_ "github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
rpc "zombiezen.com/go/capnproto2/rpc"
)
const (
dialTimeout = 15 * time.Second
openStreamTimeout = 30 * time.Second
muxerTimeout = 5 * time.Second
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
TagHeaderNamePrefix = "Cf-Warp-Tag-"
DuplicateConnectionError = "EDUPCONN"
@ -53,6 +52,7 @@ type TunnelConfig struct {
HTTPTransport http.RoundTripper
HeartbeatInterval time.Duration
Hostname string
HTTPHostHeader string
IncidentLookup IncidentLookup
IsAutoupdated bool
IsFreeTunnel bool
@ -73,14 +73,9 @@ type TunnelConfig struct {
WSGI bool
// OriginUrl may not be used if a user specifies a unix socket.
OriginUrl string
}
type dialError struct {
cause error
}
func (e dialError) Error() string {
return e.cause.Error()
// feature-flag to use new edge reconnect tokens
UseReconnectToken bool
}
type dupConnRegisterTunnelError struct{}
@ -119,6 +114,18 @@ func (e clientRegisterTunnelError) Error() string {
return e.cause.Error()
}
func (c *TunnelConfig) muxerConfig(handler h2mux.MuxedStreamHandler) h2mux.MuxerConfig {
return h2mux.MuxerConfig{
Timeout: muxerTimeout,
Handler: handler,
IsClient: true,
HeartbeatInterval: c.HeartbeatInterval,
MaxHeartbeats: c.MaxHeartbeats,
Logger: c.TransportLogger.WithFields(log.Fields{}),
CompressionQuality: h2mux.CompressionSetting(c.CompressionQuality),
}
}
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
policy := tunnelrpc.ExistingTunnelPolicy_balance
if c.HAConnections <= 1 && c.LBPool == "" {
@ -141,7 +148,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
}
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
return NewSupervisor(config).Run(ctx, connectedSignal, cloudflaredID)
return NewSupervisor(config, cloudflaredID).Run(ctx, connectedSignal)
}
func ServeTunnelLoop(ctx context.Context,
@ -151,7 +158,7 @@ func ServeTunnelLoop(ctx context.Context,
connectedSignal *signal.Signal,
u uuid.UUID,
) error {
logger := config.Logger
connectionLogger := config.Logger.WithField("connectionID", connectionID)
config.Metrics.incrementHaConnections()
defer config.Metrics.decrementHaConnections()
backoff := BackoffHandler{MaxRetries: config.Retries}
@ -164,10 +171,18 @@ func ServeTunnelLoop(ctx context.Context,
// Ensure the above goroutine will terminate if we return without connecting
defer connectedFuse.Fuse(false)
for {
err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff, u)
err, recoverable := ServeTunnel(
ctx,
config,
connectionLogger,
addr, connectionID,
connectedFuse,
&backoff,
u,
)
if recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
logger.Infof("Retrying in %s seconds", duration)
connectionLogger.Infof("Retrying in %s seconds", duration)
backoff.Backoff(ctx)
continue
}
@ -179,6 +194,7 @@ func ServeTunnelLoop(ctx context.Context,
func ServeTunnel(
ctx context.Context,
config *TunnelConfig,
logger *log.Entry,
addr *net.TCPAddr,
connectionID uint8,
connectedFuse *h2mux.BooleanFuse,
@ -198,18 +214,17 @@ func ServeTunnel(
}()
connectionTag := uint8ToString(connectionID)
logger := config.Logger.WithField("connectionID", connectionTag)
// additional tags to send other than hostname which is set in cloudflared main package
tags := make(map[string]string)
tags["ha"] = connectionTag
// Returns error from parsing the origin URL or handshake errors
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID)
if err != nil {
errLog := config.Logger.WithError(err)
errLog := logger.WithError(err)
switch err.(type) {
case dialError:
case connection.DialError:
errLog.Error("Unable to dial edge")
case h2mux.MuxerHandshakeError:
errLog.Error("Handshake failed with edge server")
@ -223,7 +238,7 @@ func ServeTunnel(
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP, u)
err := RegisterTunnel(serveCtx, handler.muxer, config, logger, connectionID, originLocalIP, u)
if err == nil {
connectedFuse.Fuse(true)
backoff.SetGracePeriod()
@ -259,6 +274,8 @@ func ServeTunnel(
err = errGroup.Wait()
if err != nil {
_ = newClientRegisterTunnelError(err, config.Metrics.regFail)
switch castedErr := err.(type) {
case dupConnRegisterTunnelError:
logger.Info("Already connected to this server, selecting a different one")
@ -273,154 +290,108 @@ func ServeTunnel(
return castedErr.cause, !castedErr.permanent
case clientRegisterTunnelError:
logger.WithError(castedErr.cause).Error("Register tunnel error on client side")
raven.CaptureError(castedErr.cause, tags)
return err, true
case muxerShutdownError:
logger.Infof("Muxer shutdown")
return err, true
default:
logger.WithError(err).Error("Serve tunnel error")
raven.CaptureError(err, tags)
return err, true
}
}
return nil, true
}
func IsRPCStreamResponse(headers []h2mux.Header) bool {
if len(headers) != 1 {
return false
}
if headers[0].Name != ":status" || headers[0].Value != "200" {
return false
}
return true
}
func RegisterTunnel(
ctx context.Context,
muxer *h2mux.Muxer,
config *TunnelConfig,
logger *log.Entry,
connectionID uint8,
originLocalIP string,
uuid uuid.UUID,
) error {
config.TransportLogger.Debug("initiating RPC stream to register")
stream, err := openStream(ctx, muxer)
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger.WithField("subsystem", "rpc-register"), openStreamTimeout)
if err != nil {
// RPC stream open error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
}
if !IsRPCStreamResponse(stream.Headers) {
// stream response error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(config.TransportLogger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(config.TransportLogger.WithField("subsystem", "rpc-transport")),
)
defer conn.Close()
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
defer tunnelServer.Close()
// Request server info without blocking tunnel registration; must use capnp library directly.
tsClient := tunnelrpc.TunnelServer{Client: ts.Client}
serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
return nil
})
registration, err := ts.RegisterTunnel(
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
registration := tunnelServer.RegisterTunnel(
ctx,
config.OriginCert,
config.Hostname,
config.RegistrationOptions(connectionID, originLocalIP, uuid),
)
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, config.Logger)
if err != nil {
if registrationErr := registration.DeserializeError(); registrationErr != nil {
// RegisterTunnel RPC failure
return newClientRegisterTunnelError(err, config.Metrics.regFail)
}
for _, logLine := range registration.LogLines {
config.Logger.Info(logLine)
return processRegisterTunnelError(registrationErr, config.Metrics)
}
if regErr := processRegisterTunnelError(registration.Err, registration.PermanentFailure, config.Metrics); regErr != nil {
return regErr
for _, logLine := range registration.LogLines {
logger.Info(logLine)
}
if registration.TunnelID != "" {
config.Metrics.tunnelsHA.AddTunnelID(connectionID, registration.TunnelID)
config.Logger.Infof("Each HA connection's tunnel IDs: %v", config.Metrics.tunnelsHA.String())
logger.Infof("Each HA connection's tunnel IDs: %v", config.Metrics.tunnelsHA.String())
}
// Print out the user's trial zone URL in a nice box (if they requested and got one)
if isTrialTunnel := config.Hostname == ""; isTrialTunnel {
if url, err := url.Parse(registration.Url); err == nil {
for _, line := range asciiBox(trialZoneMsg(url.String()), 2) {
config.Logger.Infoln(line)
logger.Infoln(line)
}
} else {
config.Logger.Errorln("Failed to connect tunnel, please try again.")
logger.Errorln("Failed to connect tunnel, please try again.")
return fmt.Errorf("empty URL in response from Cloudflare edge")
}
}
config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
config.Logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
config.Metrics.regSuccess.Inc()
return nil
}
func processRegisterTunnelError(err string, permanentFailure bool, metrics *TunnelMetrics) error {
if err == "" {
metrics.regSuccess.Inc()
return nil
}
metrics.regFail.WithLabelValues(err).Inc()
if err == DuplicateConnectionError {
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics) error {
if err.Error() == DuplicateConnectionError {
metrics.regFail.WithLabelValues("dup_edge_conn").Inc()
return dupConnRegisterTunnelError{}
}
metrics.regFail.WithLabelValues("server_error").Inc()
return serverRegisterTunnelError{
cause: fmt.Errorf("Server error: %s", err),
permanent: permanentFailure,
cause: fmt.Errorf("Server error: %s", err.Error()),
permanent: err.IsPermanent(),
}
}
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
logger.Debug("initiating RPC stream to unregister")
ctx := context.Background()
stream, err := openStream(ctx, muxer)
ts, err := connection.NewRPCClient(ctx, muxer, logger.WithField("subsystem", "rpc-unregister"), openStreamTimeout)
if err != nil {
// RPC stream open error
return err
}
if !IsRPCStreamResponse(stream.Headers) {
// stream response error
return err
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
)
defer conn.Close()
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
// gracePeriod is encoded in int64 using capnproto
return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
}
func openStream(ctx context.Context, muxer *h2mux.Muxer) (*h2mux.MuxedStream, error) {
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
defer cancel()
return muxer.OpenStream(openStreamCtx, []h2mux.Header{
{Name: ":method", Value: "RPC"},
{Name: ":scheme", Value: "capnp"},
{Name: ":path", Value: "*"},
}, nil)
}
func LogServerInfo(
promise tunnelrpc.ServerInfo_Promise,
connectionID uint8,
metrics *TunnelMetrics,
logger *log.Logger,
logger *log.Entry,
) {
serverInfoMessage, err := promise.Struct()
if err != nil {
@ -447,24 +418,25 @@ func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
}
type TunnelHandler struct {
originUrl string
muxer *h2mux.Muxer
httpClient http.RoundTripper
tlsConfig *tls.Config
tags []tunnelpogs.Tag
metrics *TunnelMetrics
originUrl string
httpHostHeader string
muxer *h2mux.Muxer
httpClient http.RoundTripper
tlsConfig *tls.Config
tags []tunnelpogs.Tag
metrics *TunnelMetrics
// connectionID is only used by metrics, and prometheus requires labels to be string
connectionID string
logger *log.Logger
noChunkedEncoding bool
}
var dialer = net.Dialer{DualStack: true}
var dialer = net.Dialer{}
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
func NewTunnelHandler(ctx context.Context,
config *TunnelConfig,
addr string,
addr *net.TCPAddr,
connectionID uint8,
) (*TunnelHandler, string, error) {
originURL, err := validation.ValidateUrl(config.OriginUrl)
@ -473,6 +445,7 @@ func NewTunnelHandler(ctx context.Context,
}
h := &TunnelHandler{
originUrl: originURL,
httpHostHeader: config.HTTPHostHeader,
httpClient: config.HTTPTransport,
tlsConfig: config.ClientTlsConfig,
tags: config.Tags,
@ -484,37 +457,18 @@ func NewTunnelHandler(ctx context.Context,
if h.httpClient == nil {
h.httpClient = http.DefaultTransport
}
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
// TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one)
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", addr)
dialCancel()
edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
if err != nil {
return nil, "", dialError{cause: errors.Wrap(err, "DialContext error")}
return nil, "", err
}
edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig)
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
err = edgeConn.Handshake()
if err != nil {
return nil, "", dialError{cause: errors.Wrap(err, "Handshake with edge error")}
}
// clear the deadline on the conn; h2mux has its own timeouts
edgeConn.SetDeadline(time.Time{})
// Establish a muxed connection with the edge
// Client mux handshake with agent server
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
Timeout: 5 * time.Second,
Handler: h,
IsClient: true,
HeartbeatInterval: config.HeartbeatInterval,
MaxHeartbeats: config.MaxHeartbeats,
Logger: config.TransportLogger.WithFields(log.Fields{}),
CompressionQuality: h2mux.CompressionSetting(config.CompressionQuality),
})
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, config.muxerConfig(h), h.metrics.activeStreams)
if err != nil {
return h, "", errors.New("TLS handshake error")
return nil, "", errors.Wrap(err, "Handshake with edge error")
}
return h, edgeConn.LocalAddr().String(), err
return h, edgeConn.LocalAddr().String(), nil
}
func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
@ -566,6 +520,11 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request,
}
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
if h.httpHostHeader != "" {
req.Header.Set("Host", h.httpHostHeader)
req.Host = h.httpHostHeader
}
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
if err != nil {
return nil, err
@ -594,6 +553,11 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request)
// Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive")
if h.httpHostHeader != "" {
req.Header.Set("Host", h.httpHostHeader)
req.Host = h.httpHostHeader
}
response, err := h.httpClient.RoundTrip(req)
if err != nil {
return nil, errors.Wrap(err, "Error proxying request to origin")
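Setting both req.Host and the Host header is worth noting: for outgoing requests, net/http writes req.Host (falling back to req.URL.Host) into the request line and ignores a Host entry in req.Header, so overriding only the header map would not change what the origin sees. A small self-contained illustration (the hostname internal.example.com is hypothetical):

package main

import (
    "fmt"
    "io"
    "log"
    "net/http"
    "net/http/httptest"
    "os"
)

func main() {
    // A stand-in origin that reports which Host value it received.
    origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        fmt.Fprintln(w, "origin saw Host:", r.Host)
    }))
    defer origin.Close()

    req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
    if err != nil {
        log.Fatal(err)
    }
    // Mirror the handler above: set both fields. Only req.Host reaches the wire.
    req.Header.Set("Host", "internal.example.com")
    req.Host = "internal.example.com"

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()
    io.Copy(os.Stdout, resp.Body) // prints: origin saw Host: internal.example.com
}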

View File

@ -1,31 +0,0 @@
#!/bin/bash
FILENAME=$1
VERSION=$2
TAP_ROOT=$3
URL="https://developers.cloudflare.com/argo-tunnel/dl/cloudflared-${VERSION}-darwin-amd64.tgz"
SHA256=$(sha256sum -b "${FILENAME}" | cut -b1-64)
cd "${TAP_ROOT}" || exit 1
git checkout -f master
git reset --hard origin/master
tee cloudflared.rb <<EOF
class Cloudflared < Formula
desc 'Argo Tunnel'
homepage 'https://developers.cloudflare.com/argo-tunnel/'
url '${URL}'
sha256 '${SHA256}'
version '${VERSION}'
def install
bin.install 'cloudflared'
end
end
EOF
git add cloudflared.rb
git config user.name "cloudflare-warp-bot"
git config user.email "warp-bot@cloudflare.com"
git commit -m "Release Argo Tunnel ${VERSION}"
git version
GIT_SSH_COMMAND="ssh -o UserKnownHostsFile=../github_known_hosts" git push -v origin master

View File

@ -0,0 +1,14 @@
FROM python:3-buster
RUN wget https://bin.equinox.io/c/VdrWdbjqyF/cloudflared-stable-linux-amd64.deb \
&& dpkg -i cloudflared-stable-linux-amd64.deb
RUN pip install pexpect
COPY tests.py .
COPY ssh /root/.ssh
RUN chmod 600 /root/.ssh/id_rsa
ARG SSH_HOSTNAME
RUN bash -c 'sed -i "s/{{hostname}}/${SSH_HOSTNAME}/g" /root/.ssh/authorized_keys_config'
RUN bash -c 'sed -i "s/{{hostname}}/${SSH_HOSTNAME}/g" /root/.ssh/short_lived_cert_config'

View File

@ -0,0 +1,23 @@
# Cloudflared SSH server smoke tests
Runs several tests in a Docker container against a server that is started out of band of these tests.
The cloudflared token also needs to be retrieved out of band.
The SSH server hostname and user need to be configured in a Docker environment file.
## Running tests
* Build cloudflared:
make cloudflared
* Start server:
sudo ./cloudflared tunnel --hostname HOSTNAME --ssh-server
* Fetch token:
./cloudflared access login HOSTNAME
* Create docker env file:
echo "SSH_HOSTNAME=HOSTNAME\nSSH_USER=USERNAME\n" > ssh_server_tests/.env
* Run tests:
make test-ssh-server

View File

@ -0,0 +1,18 @@
version: "3.1"
services:
ssh_test:
build:
context: .
args:
- SSH_HOSTNAME=${SSH_HOSTNAME}
volumes:
- "~/.cloudflared/:/root/.cloudflared"
env_file:
- .env
environment:
- AUTHORIZED_KEYS_SSH_CONFIG=/root/.ssh/authorized_keys_config
- SHORT_LIVED_CERT_SSH_CONFIG=/root/.ssh/short_lived_cert_config
- REMOTE_SCP_FILENAME=scp_test.txt
- ROOT_ONLY_TEST_FILE_PATH=~/permission_test.txt
entrypoint: "python tests.py"

View File

@ -0,0 +1,5 @@
Host *
AddressFamily inet
Host {{hostname}}
ProxyCommand /usr/local/bin/cloudflared access ssh --hostname %h

View File

@ -0,0 +1,49 @@
-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAACFwAAAAdzc2gtcn
NhAAAAAwEAAQAAAgEAvi26NDQ8cYTTztqPe9ZgF5HR/rIo5FoDgL5NbbZKW6h0txP9Fd8s
id9Bgmo+aGkeM327tPVVMQ6UFmdRksOCIDWQDjNLF8b6S+Fu95tvMKSbGreRoR32OvgZKV
I6KmOsF4z4GIv9naPplZswtKEUhSSI+/gPdAs9wfwalqZ77e82QJ727bYMeC3lzuoT+KBI
dYufJ4OQhLtpHrqhB5sn7s6+oCv/u85GSln5SIC18Hi2t9lW4tgb5tH8P0kEDDWGfPS5ok
qGi4kFTvwBXOCS2r4dhi5hRkpP7PqG4np0OCfvK5IRRJ27fCnj0loc+puZJAxnPMbuJr64
vwxRx78PM/V0PDUsl0P6aR/vbe0XmF9FGqbWf2Tar1p4r6C9/bMzcDz8seYT8hzLIHP3+R
l1hdlsTLm+1EzhaExKId+tjXegKGG4nU24h6qHEnRxLQDMwEsdkfj4E1pVypZJXVyNj99D
o84vi0EUnu7R4HmQb/C+Pu7qMDtLT3Zk7O5Mg4LQ+cTz9V0noYEAyG46nAB4U/nJzBnV1J
+aAdpioHmUAYhLYlQ9Kiy7LCJi92g9Wqa4wxMKxBbO5ZeH++p2p2lUi/oQNqx/2dLYFmy0
wxvJHbZIhAaFbOeCvHg1ucIAQznli2jOr2qoB+yKRRPAp/3NXnZg1v7ce2CkwiAD52wjtC
kAAAdILMJUeyzCVHsAAAAHc3NoLXJzYQAAAgEAvi26NDQ8cYTTztqPe9ZgF5HR/rIo5FoD
gL5NbbZKW6h0txP9Fd8sid9Bgmo+aGkeM327tPVVMQ6UFmdRksOCIDWQDjNLF8b6S+Fu95
tvMKSbGreRoR32OvgZKVI6KmOsF4z4GIv9naPplZswtKEUhSSI+/gPdAs9wfwalqZ77e82
QJ727bYMeC3lzuoT+KBIdYufJ4OQhLtpHrqhB5sn7s6+oCv/u85GSln5SIC18Hi2t9lW4t
gb5tH8P0kEDDWGfPS5okqGi4kFTvwBXOCS2r4dhi5hRkpP7PqG4np0OCfvK5IRRJ27fCnj
0loc+puZJAxnPMbuJr64vwxRx78PM/V0PDUsl0P6aR/vbe0XmF9FGqbWf2Tar1p4r6C9/b
MzcDz8seYT8hzLIHP3+Rl1hdlsTLm+1EzhaExKId+tjXegKGG4nU24h6qHEnRxLQDMwEsd
kfj4E1pVypZJXVyNj99Do84vi0EUnu7R4HmQb/C+Pu7qMDtLT3Zk7O5Mg4LQ+cTz9V0noY
EAyG46nAB4U/nJzBnV1J+aAdpioHmUAYhLYlQ9Kiy7LCJi92g9Wqa4wxMKxBbO5ZeH++p2
p2lUi/oQNqx/2dLYFmy0wxvJHbZIhAaFbOeCvHg1ucIAQznli2jOr2qoB+yKRRPAp/3NXn
Zg1v7ce2CkwiAD52wjtCkAAAADAQABAAACAQCbnVsyAFQ9J00Rg/HIiUATyTQlzq57O9SF
8jH1RiZOHedzLx32WaleH5rBFiJ+2RTnWUjQ57aP77fpJR2wk93UcT+w/vPBPwXsNUjRvx
Qan3ZzRCYbyiKDWiNslmYV7X0RwD36CAK8jTVDP7t48h2SXLTiSLaMY+5i3uD6yLu7k/O2
qNyw4jgN1rCmwQ8acD0aQec3NAZ7NcbsaBX/3Uutsup0scwOZtlJWZoLY5Z8cKpCgcsAz4
j1NHnNZvey7dFgSffj/ktdvf7kBH0w/GnuJ4aNF0Jte70u0kiw5TZYBQVFh74tgUu6a6SJ
qUbxIYUL5EJNjxGsDn+phHEemw3aMv0CwZG6Tqaionlna7bLsl9Bg1HTGclczVWx8uqC+M
6agLmkhYCHG0rVj8h5smjXAQXtmvIDVYDOlJZZoF9VAOCj6QfmJUH1NAGpCs1HDHbeOxGA
OLCh4d3F4rScPqhGdtSt4W13VFIvXn2Qqoz9ufepZsee1SZqpcerxywx2wN9ZAzu+X8lTN
i+TA2B3vWpqqucOEsp4JwDN+VMKZqKUGUDWcm/eHSaG6wq0q734LUlgM85TjaIg8QsNtWV
giB1nWwsYIuH4rsFNFGEwURYdGBcw6idH0GZ7I4RaIB5F9oOza1d601E0APHYrtnx9yOiK
nOtJ+5ZmVZovaDRfu1aQAAAQBU/EFaNUzoVhO04pS2L6BlByt963bOIsSJhdlEzek5AAli
eaf1S/PD6xWCc0IGY+GZE0HPbhsKYanjqOpWldcA2T7fzf4oz4vFBfUkPYo/MLSlLCYsDd
IH3wBkCssnfR5EkzNgxnOvq646Nl64BMvxwSIXGPktdq9ZALxViwricSRzCFURnh5vLHWU
wBzSgAA0UlZ9E64GtAv066+AoZCp83GhTLRC4o0naE2e/K4op4BCFHLrZ8eXmDRK3NJj80
Vkn+uhrk+SHmbjIhmS57Pv9p8TWyRvemph/nMUuZGKBUu2X+JQxggck0KigIrXjsmciCsM
BIM3mYDDfjYbyVhTAAABAQDkV8O1bWUsAIqk7RU+iDZojN5kaO+zUvj1TafX8QX1sY6pu4
Z2cfSEka1532BaehM95bQm7BCPw4cYg56XidmCQTZ9WaWqxVrOo48EKXUtZMZx6nKFOKlq
MT2XTMnGT9n7kFCfEjSVkAjuJ9ZTFLOaoXAaVRnxeHQwOKaup5KKP9GSzNIw328U+96s3V
WKHeT4pMjHBccgW/qX/tRRidZw5in5uBC9Ew5y3UACFTkNOnhUwVfyUNbBZJ2W36msQ3KD
AN7nOrQHqhd3NFyCEy2ovIAKVBacr/VEX6EsRUshIehJzz8EY9f3kXL7WT2QDoz2giPeBJ
HJdEpXt43UpszjAAABAQDVNpqNdHUlCs9XnbIvc6ZRrNh79wt65YFfvh/QEuA33KnA6Ri6
EgnV5IdUWXS/UFaYcm2udydrBpVIVifSYl3sioHBylpri23BEy38PKwVXvghUtfpN6dWGn
NZUG25fQPtIzqi+lo953ZjIj+Adi17AeVv4P4NiLrZeM9lXfWf2pEPOecxXs1IwAf9IiDQ
WepAwRLsu42eEnHA+DSJPZUkSbISfM5X345k0g6EVATX/yLL3CsqClPzPtsqjh6rbEfFg3
2OfIMcWV77gOlGWGQ+bUHc8kV6xJqV9QVacLWzfLvIqHF0wQMf8WLOVHEzkfiq4VjwhVqr
/+FFvljm5nSDAAAAEW1pa2VAQzAyWTUwVEdKR0g4AQ==
-----END OPENSSH PRIVATE KEY-----


@ -0,0 +1 @@
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQC+Lbo0NDxxhNPO2o971mAXkdH+sijkWgOAvk1ttkpbqHS3E/0V3yyJ30GCaj5oaR4zfbu09VUxDpQWZ1GSw4IgNZAOM0sXxvpL4W73m28wpJsat5GhHfY6+BkpUjoqY6wXjPgYi/2do+mVmzC0oRSFJIj7+A90Cz3B/BqWpnvt7zZAnvbttgx4LeXO6hP4oEh1i58ng5CEu2keuqEHmyfuzr6gK/+7zkZKWflIgLXweLa32Vbi2Bvm0fw/SQQMNYZ89LmiSoaLiQVO/AFc4JLavh2GLmFGSk/s+obienQ4J+8rkhFEnbt8KePSWhz6m5kkDGc8xu4mvri/DFHHvw8z9XQ8NSyXQ/ppH+9t7ReYX0UaptZ/ZNqvWnivoL39szNwPPyx5hPyHMsgc/f5GXWF2WxMub7UTOFoTEoh362Nd6AoYbidTbiHqocSdHEtAMzASx2R+PgTWlXKlkldXI2P30Ojzi+LQRSe7tHgeZBv8L4+7uowO0tPdmTs7kyDgtD5xPP1XSehgQDIbjqcAHhT+cnMGdXUn5oB2mKgeZQBiEtiVD0qLLssImL3aD1aprjDEwrEFs7ll4f76nanaVSL+hA2rH/Z0tgWbLTDG8kdtkiEBoVs54K8eDW5wgBDOeWLaM6vaqgH7IpFE8Cn/c1edmDW/tx7YKTCIAPnbCO0KQ== mike@C02Y50TGJGH8


@ -0,0 +1,11 @@
Host *
AddressFamily inet
Host {{hostname}}
ProxyCommand bash -c '/usr/local/bin/cloudflared access ssh-gen --hostname %h; ssh -F /root/.ssh/short_lived_cert_config -tt %r@cfpipe-{{hostname}} >&2 <&1'
Host cfpipe-{{hostname}}
HostName {{hostname}}
ProxyCommand /usr/local/bin/cloudflared access ssh --hostname %h
IdentityFile ~/.cloudflared/{{hostname}}-cf_key
CertificateFile ~/.cloudflared/{{hostname}}-cf_key-cert.pub

195
ssh_server_tests/tests.py Normal file

@ -0,0 +1,195 @@
"""
Cloudflared Integration tests
"""
import unittest
import subprocess
import os
import tempfile
from contextlib import contextmanager
from pexpect import pxssh
class TestSSHBase(unittest.TestCase):
"""
SSH test base class containing constants and helper funcs
"""
HOSTNAME = os.environ["SSH_HOSTNAME"]
SSH_USER = os.environ["SSH_USER"]
SSH_TARGET = f"{SSH_USER}@{HOSTNAME}"
AUTHORIZED_KEYS_SSH_CONFIG = os.environ["AUTHORIZED_KEYS_SSH_CONFIG"]
SHORT_LIVED_CERT_SSH_CONFIG = os.environ["SHORT_LIVED_CERT_SSH_CONFIG"]
SSH_OPTIONS = {"StrictHostKeyChecking": "no"}
@classmethod
def get_ssh_command(cls, pty=True):
"""
Return ssh command arg list. If pty is true, a PTY is forced for the session.
"""
cmd = [
"ssh",
"-o",
"StrictHostKeyChecking=no",
"-F",
cls.AUTHORIZED_KEYS_SSH_CONFIG,
cls.SSH_TARGET,
]
if not pty:
cmd += ["-T"]
else:
cmd += ["-tt"]
return cmd
@classmethod
@contextmanager
def ssh_session_manager(cls, *args, **kwargs):
"""
Context manager for interacting with a pxssh session.
Disables pty echo on the remote server and ensures session is terminated afterward.
"""
session = pxssh.pxssh(options=cls.SSH_OPTIONS)
session.login(
cls.HOSTNAME,
username=cls.SSH_USER,
original_prompt=r"[#@$]",
ssh_config=kwargs.get("ssh_config", cls.AUTHORIZED_KEYS_SSH_CONFIG),
ssh_tunnels=kwargs.get("ssh_tunnels", {}),
)
try:
session.sendline("stty -echo")
session.prompt()
yield session
finally:
session.logout()
@staticmethod
def get_command_output(session, cmd):
"""
Executes command on remote ssh server and waits for prompt.
Returns command output
"""
session.sendline(cmd)
session.prompt()
return session.before.decode().strip()
def exec_command(self, cmd, shell=False):
"""
Executes command locally. Raises Assertion error for non-zero return code.
Returns stdout and stderr
"""
proc = subprocess.Popen(
cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=shell
)
raw_out, raw_err = proc.communicate()
out = raw_out.decode()
err = raw_err.decode()
self.assertEqual(proc.returncode, 0, msg=f"stdout: {out} stderr: {err}")
return out.strip(), err.strip()
class TestSSHCommandExec(TestSSHBase):
"""
Tests inline ssh command exec
"""
# Name of file to be downloaded over SCP on remote server.
REMOTE_SCP_FILENAME = os.environ["REMOTE_SCP_FILENAME"]
@classmethod
def get_scp_base_command(cls):
return [
"scp",
"-o",
"StrictHostKeyChecking=no",
"-v",
"-F",
cls.AUTHORIZED_KEYS_SSH_CONFIG,
]
@unittest.skip(
"This creates files on the remote. Should be skipped until server is dockerized."
)
def test_verbose_scp_sink_mode(self):
with tempfile.NamedTemporaryFile() as fl:
self.exec_command(
self.get_scp_base_command() + [fl.name, f"{self.SSH_TARGET}:"]
)
def test_verbose_scp_source_mode(self):
with tempfile.TemporaryDirectory() as tmpdirname:
self.exec_command(
self.get_scp_base_command()
+ [f"{self.SSH_TARGET}:{self.REMOTE_SCP_FILENAME}", tmpdirname]
)
local_filename = os.path.join(tmpdirname, self.REMOTE_SCP_FILENAME)
self.assertTrue(os.path.exists(local_filename))
self.assertTrue(os.path.getsize(local_filename) > 0)
def test_pty_command(self):
base_cmd = self.get_ssh_command()
out, _ = self.exec_command(base_cmd + ["whoami"])
self.assertEqual(out.strip().lower(), self.SSH_USER.lower())
out, _ = self.exec_command(base_cmd + ["tty"])
self.assertNotEqual(out, "not a tty")
def test_non_pty_command(self):
base_cmd = self.get_ssh_command(pty=False)
out, _ = self.exec_command(base_cmd + ["whoami"])
self.assertEqual(out.strip().lower(), self.SSH_USER.lower())
out, _ = self.exec_command(base_cmd + ["tty"])
self.assertEqual(out, "not a tty")
class TestSSHShell(TestSSHBase):
"""
Tests interactive SSH shell
"""
# File path to a file on the remote server with root only read privileges.
ROOT_ONLY_TEST_FILE_PATH = os.environ["ROOT_ONLY_TEST_FILE_PATH"]
def test_ssh_pty(self):
with self.ssh_session_manager() as session:
# Test shell launched as correct user
username = self.get_command_output(session, "whoami")
self.assertEqual(username.lower(), self.SSH_USER.lower())
# Test USER env variable set
user_var = self.get_command_output(session, "echo $USER")
self.assertEqual(user_var.lower(), self.SSH_USER.lower())
# Test HOME env variable set to true user home.
home_env = self.get_command_output(session, "echo $HOME")
pwd = self.get_command_output(session, "pwd")
self.assertEqual(pwd, home_env)
# Test shell launched in correct user home dir.
self.assertIn(username, pwd)
# Ensure shell launched with correct user's permissions and privs.
# Can't read root-owned 0700 files.
output = self.get_command_output(
session, f"cat {self.ROOT_ONLY_TEST_FILE_PATH}"
)
self.assertIn("Permission denied", output)
def test_short_lived_cert_auth(self):
with self.ssh_session_manager(
ssh_config=self.SHORT_LIVED_CERT_SSH_CONFIG
) as session:
username = self.get_command_output(session, "whoami")
self.assertEqual(username.lower(), self.SSH_USER.lower())
unittest.main()


@ -8,7 +8,6 @@ import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"io/ioutil"
@ -20,6 +19,7 @@ import (
cfpath "github.com/cloudflare/cloudflared/cmd/cloudflared/path"
"github.com/coreos/go-oidc/jose"
homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh"
)
@ -73,48 +73,54 @@ func GenerateShortLivedCertificate(appURL *url.URL, token string) error {
// handleCertificateGeneration takes a JWT and uses it to build a signPayload
// to send to the Sign endpoint, along with the public key from the keypair it generated
func handleCertificateGeneration(token, fullName string) (string, error) {
pub, err := generateKeyPair(fullName)
if err != nil {
return "", err
}
return SignCert(token, string(pub))
}
func SignCert(token, pubKey string) (string, error) {
if token == "" {
return "", errors.New("invalid token")
}
jwt, err := jose.ParseJWT(token)
if err != nil {
return "", err
return "", errors.Wrap(err, "failed to parse JWT")
}
claims, err := jwt.Claims()
if err != nil {
return "", err
return "", errors.Wrap(err, "failed to retrieve JWT claims")
}
issuer, _, err := claims.StringClaim("iss")
if err != nil {
return "", err
}
pub, err := generateKeyPair(fullName)
if err != nil {
return "", err
return "", errors.Wrap(err, "failed to retrieve JWT iss")
}
buf, err := json.Marshal(&signPayload{
PublicKey: string(pub),
PublicKey: pubKey,
JWT: token,
Issuer: issuer,
})
if err != nil {
return "", err
return "", errors.Wrap(err, "failed to marshal signPayload")
}
var res *http.Response
if mockRequest != nil {
res, err = mockRequest(issuer+signEndpoint, "application/json", bytes.NewBuffer(buf))
} else {
res, err = http.Post(issuer+signEndpoint, "application/json", bytes.NewBuffer(buf))
client := http.Client{
Timeout: 10 * time.Second,
}
res, err = client.Post(issuer+signEndpoint, "application/json", bytes.NewBuffer(buf))
}
if err != nil {
return "", err
return "", errors.Wrap(err, "failed to send request")
}
defer res.Body.Close()
@ -130,9 +136,9 @@ func handleCertificateGeneration(token, fullName string) (string, error) {
var signRes signResponse
if err := decoder.Decode(&signRes); err != nil {
return "", err
return "", errors.Wrap(err, "failed to decode HTTP response")
}
return signRes.Certificate, err
return signRes.Certificate, nil
}
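// Example (illustrative sketch, not part of this change): how a caller might
// use SignCert directly. The key generation below assumes crypto/ecdsa,
// crypto/elliptic and crypto/rand are imported; gossh is golang.org/x/crypto/ssh
// as imported above. The token is whatever Access JWT the caller already holds.
func exampleSignCert(token string) (string, error) {
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return "", err
	}
	pub, err := gossh.NewPublicKey(&key.PublicKey)
	if err != nil {
		return "", err
	}
	// SignCert POSTs the public key and JWT to the issuer's sign endpoint and
	// returns the signed certificate in authorized_keys format.
	return SignCert(token, string(gossh.MarshalAuthorizedKey(pub)))
}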
// generateKeyPair creates a EC keypair (P256) and stores them in the homedir.

37
sshlog/empty_manager.go Normal file

@ -0,0 +1,37 @@
package sshlog
import (
"io"
"github.com/sirupsen/logrus"
)
// emptyManager implements the Manager interface but does nothing (used for testing and to disable logging unless log output is configured)
type emptyManager struct {
}
type emptyWriteCloser struct {
}
// NewEmptyManager creates a new instance of an empty log manager that does nothing
func NewEmptyManager() Manager {
return &emptyManager{}
}
func (m *emptyManager) NewLogger(name string, logger *logrus.Logger) (io.WriteCloser, error) {
return &emptyWriteCloser{}, nil
}
func (m *emptyManager) NewSessionLogger(name string, logger *logrus.Logger) (io.WriteCloser, error) {
return &emptyWriteCloser{}, nil
}
// emptyWriteCloser
func (w *emptyWriteCloser) Write(p []byte) (n int, err error) {
return len(p), nil
}
func (w *emptyWriteCloser) Close() error {
return nil
}
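// Illustrative sketch (an assumption, not part of this change): pick the empty
// manager when no log directory is configured, so callers can always write
// through the Manager interface without nil checks. New is defined in manager.go.
func newManagerFor(baseDirectory string) Manager {
	if baseDirectory == "" {
		return NewEmptyManager()
	}
	return New(baseDirectory)
}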

15
sshlog/go.capnp Normal file

@ -0,0 +1,15 @@
# Generate go.capnp.out with:
# capnp compile -o- go.capnp > go.capnp.out
# Must run inside this directory to preserve paths.
@0xd12a1c51fedd6c88;
annotation package(file) :Text;
annotation import(file) :Text;
annotation doc(struct, field, enum) :Text;
annotation tag(enumerant) :Text;
annotation notag(enumerant) :Void;
annotation customtype(field) :Text;
annotation name(struct, field, union, enum, enumerant, interface, method, param, annotation, const, group) :Text;
$package("capnp");

167
sshlog/logger.go Normal file

@ -0,0 +1,167 @@
package sshlog
import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/sirupsen/logrus"
)
const (
logTimeFormat = "2006-01-02T15-04-05.000"
megabyte = 1024 * 1024
defaultFileSizeLimit = 100 * megabyte
)
// Logger will buffer and write events to disk
type Logger struct {
sync.Mutex
filename string
file *os.File
writeBuffer *bufio.Writer
logger *logrus.Logger
flushInterval time.Duration
maxFileSize int64
done chan struct{}
once sync.Once
}
// NewLogger creates a Logger instance. A buffer is created that needs to be
// drained and closed when the caller is finished, so callers should invoke
// Close when done with the Logger. Writes are flushed to disk (fsync) every
// flushInterval. filename is the name of the logfile to be created. The
// logger parameter is a logrus logger used to report I/O and filesystem
// errors that shouldn't end execution of the Logger but are still useful to
// surface to the caller.
func NewLogger(filename string, logger *logrus.Logger, flushInterval time.Duration, maxFileSize int64) (*Logger, error) {
if logger == nil {
return nil, errors.New("logger can't be nil")
}
f, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(0600))
if err != nil {
return nil, err
}
l := &Logger{filename: filename,
file: f,
writeBuffer: bufio.NewWriter(f),
logger: logger,
flushInterval: flushInterval,
maxFileSize: maxFileSize,
done: make(chan struct{}),
}
go l.writer()
return l, nil
}
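// Usage sketch (illustrative only; the path and message are assumptions):
// create a logger that flushes every second and rotates at the default size
// limit, write a line, then Close so the remaining buffer is drained to disk.
func exampleLoggerUsage() error {
	l, err := NewLogger("/tmp/ssh-events.log", logrus.New(), time.Second, defaultFileSizeLimit)
	if err != nil {
		return err
	}
	defer l.Close()
	_, err = l.Write([]byte("session started\n"))
	return err
}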
// Writes to a log buffer. Implements the io.Writer interface.
func (l *Logger) Write(p []byte) (n int, err error) {
l.Lock()
defer l.Unlock()
return l.writeBuffer.Write(p)
}
// Close drains anything left in the buffer and cleans up any resources still
// in use.
func (l *Logger) Close() error {
l.once.Do(func() {
close(l.done)
})
if err := l.write(); err != nil {
return err
}
return l.file.Close()
}
// writer is the run loop that handles draining the write buffer and syncing
// data to disk.
func (l *Logger) writer() {
ticker := time.NewTicker(l.flushInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := l.write(); err != nil {
l.logger.Errorln(err)
}
case <-l.done:
return
}
}
}
// write performs the actual write and sync calls to disk and rotates the file
// if the size limit has been reached. Since the rotation happens at the end,
// the limit is soft: the file can grow past the maximum size because of the
// final buffer flush.
func (l *Logger) write() error {
l.Lock()
defer l.Unlock()
if l.writeBuffer.Buffered() <= 0 {
return nil
}
if err := l.writeBuffer.Flush(); err != nil {
return err
}
if err := l.file.Sync(); err != nil {
return err
}
if l.shouldRotate() {
return l.rotate()
}
return nil
}
// shouldRotate checks to see if the current file should be rotated to a new
// logfile.
func (l *Logger) shouldRotate() bool {
info, err := l.file.Stat()
if err != nil {
return false
}
return info.Size() >= l.maxFileSize
}
// rotate creates a new logfile with the existing filename and renames the
// existing file with a current timestamp.
func (l *Logger) rotate() error {
if err := l.file.Close(); err != nil {
return err
}
// move the existing file
newname := rotationName(l.filename)
if err := os.Rename(l.filename, newname); err != nil {
return fmt.Errorf("can't rename log file: %s", err)
}
f, err := os.OpenFile(l.filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(0600))
if err != nil {
return fmt.Errorf("failed to open new logfile %s", err)
}
l.file = f
l.writeBuffer = bufio.NewWriter(f)
return nil
}
// rotationName creates a new filename from the given name, inserting a timestamp
// between the filename and the extension.
func rotationName(name string) string {
dir := filepath.Dir(name)
filename := filepath.Base(name)
ext := filepath.Ext(filename)
prefix := filename[:len(filename)-len(ext)]
t := time.Now()
timestamp := t.Format(logTimeFormat)
return filepath.Join(dir, fmt.Sprintf("%s-%s%s", prefix, timestamp, ext))
}

89
sshlog/logger_test.go Normal file

@ -0,0 +1,89 @@
package sshlog
import (
"log"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/sirupsen/logrus"
)
const logFileName = "test-logger.log"
func createLogger(t *testing.T) *Logger {
os.Remove(logFileName)
l := logrus.New()
logger, err := NewLogger(logFileName, l, time.Millisecond, 1024)
if err != nil {
t.Fatal("couldn't create the logger!", err)
}
return logger
}
// AUTH-2115 TODO: fix this test
//func TestWrite(t *testing.T) {
// testStr := "hi"
// logger := createLogger(t)
// defer func() {
// logger.Close()
// os.Remove(logFileName)
// }()
//
// logger.Write([]byte(testStr))
// time.Sleep(2 * time.Millisecond)
// data, err := ioutil.ReadFile(logFileName)
// if err != nil {
// t.Fatal("couldn't read the log file!", err)
// }
// checkStr := string(data)
// if checkStr != testStr {
// t.Fatal("file data doesn't match!")
// }
//}
func TestFilenameRotation(t *testing.T) {
newName := rotationName("dir/bob/acoolloggername.log")
dir := filepath.Dir(newName)
if dir != "dir/bob" {
t.Fatal("rotation name doesn't respect the directory filepath:", newName)
}
filename := filepath.Base(newName)
if !strings.HasPrefix(filename, "acoolloggername") {
t.Fatal("rotation filename is wrong:", filename)
}
ext := filepath.Ext(newName)
if ext != ".log" {
t.Fatal("rotation file extension is wrong:", ext)
}
}
func TestRotation(t *testing.T) {
logger := createLogger(t)
for i := 0; i < 2000; i++ {
logger.Write([]byte("a string for testing rotation\n"))
}
logger.Close()
count := 0
filepath.Walk(".", func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() {
return nil
}
if strings.HasPrefix(info.Name(), "test-logger") {
log.Println("deleting: ", path)
os.Remove(path)
count++
}
return nil
})
if count < 2 {
t.Fatal("rotation didn't roll files:", count)
}
}

34
sshlog/manager.go Normal file

@ -0,0 +1,34 @@
package sshlog
import (
"io"
"path/filepath"
"time"
"github.com/sirupsen/logrus"
)
// Manager defines the interface for creating SSH event and session loggers
type Manager interface {
NewLogger(string, *logrus.Logger) (io.WriteCloser, error)
NewSessionLogger(string, *logrus.Logger) (io.WriteCloser, error)
}
type manager struct {
baseDirectory string
}
// New creates a new instance of a log manager
func New(baseDirectory string) Manager {
return &manager{
baseDirectory: baseDirectory,
}
}
func (m *manager) NewLogger(name string, logger *logrus.Logger) (io.WriteCloser, error) {
return NewLogger(filepath.Join(m.baseDirectory, name), logger, time.Second, defaultFileSizeLimit)
}
func (m *manager) NewSessionLogger(name string, logger *logrus.Logger) (io.WriteCloser, error) {
return NewSessionLogger(filepath.Join(m.baseDirectory, name), logger, time.Second, defaultFileSizeLimit)
}

9
sshlog/session_log.capnp Normal file

@ -0,0 +1,9 @@
using Go = import "go.capnp";
@0x8f43375162194466;
$Go.package("sshlog");
$Go.import("github.com/cloudflare/cloudflared/sshlog");
struct SessionLog {
timestamp @0 :Text;
content @1 :Data;
}

110
sshlog/session_log.capnp.go Normal file

@ -0,0 +1,110 @@
// Code generated by capnpc-go. DO NOT EDIT.
package sshlog
import (
capnp "zombiezen.com/go/capnproto2"
text "zombiezen.com/go/capnproto2/encoding/text"
schemas "zombiezen.com/go/capnproto2/schemas"
)
type SessionLog struct{ capnp.Struct }
// SessionLog_TypeID is the unique identifier for the type SessionLog.
const SessionLog_TypeID = 0xa13a07c504a5ab64
func NewSessionLog(s *capnp.Segment) (SessionLog, error) {
st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2})
return SessionLog{st}, err
}
func NewRootSessionLog(s *capnp.Segment) (SessionLog, error) {
st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2})
return SessionLog{st}, err
}
func ReadRootSessionLog(msg *capnp.Message) (SessionLog, error) {
root, err := msg.RootPtr()
return SessionLog{root.Struct()}, err
}
func (s SessionLog) String() string {
str, _ := text.Marshal(0xa13a07c504a5ab64, s.Struct)
return str
}
func (s SessionLog) Timestamp() (string, error) {
p, err := s.Struct.Ptr(0)
return p.Text(), err
}
func (s SessionLog) HasTimestamp() bool {
p, err := s.Struct.Ptr(0)
return p.IsValid() || err != nil
}
func (s SessionLog) TimestampBytes() ([]byte, error) {
p, err := s.Struct.Ptr(0)
return p.TextBytes(), err
}
func (s SessionLog) SetTimestamp(v string) error {
return s.Struct.SetText(0, v)
}
func (s SessionLog) Content() ([]byte, error) {
p, err := s.Struct.Ptr(1)
return []byte(p.Data()), err
}
func (s SessionLog) HasContent() bool {
p, err := s.Struct.Ptr(1)
return p.IsValid() || err != nil
}
func (s SessionLog) SetContent(v []byte) error {
return s.Struct.SetData(1, v)
}
// SessionLog_List is a list of SessionLog.
type SessionLog_List struct{ capnp.List }
// NewSessionLog creates a new list of SessionLog.
func NewSessionLog_List(s *capnp.Segment, sz int32) (SessionLog_List, error) {
l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}, sz)
return SessionLog_List{l}, err
}
func (s SessionLog_List) At(i int) SessionLog { return SessionLog{s.List.Struct(i)} }
func (s SessionLog_List) Set(i int, v SessionLog) error { return s.List.SetStruct(i, v.Struct) }
func (s SessionLog_List) String() string {
str, _ := text.MarshalList(0xa13a07c504a5ab64, s.List)
return str
}
// SessionLog_Promise is a wrapper for a SessionLog promised by a client call.
type SessionLog_Promise struct{ *capnp.Pipeline }
func (p SessionLog_Promise) Struct() (SessionLog, error) {
s, err := p.Pipeline.Struct()
return SessionLog{s}, err
}
const schema_8f43375162194466 = "x\xda\x120q`\x12d\x8dg`\x08dae\xfb" +
"\x9f\xb2z)\xcbQv\xab\x85\x0c\x82B\x8c\xff\xd3\\" +
"$\x93\x02\xcd\x9d\xfb\x19X\x99\xd8\x19\x18\x04E_\x09" +
"*\x82h\xd9r\x06\xc6\xff\xc5\xa9\xc5\xc5\x99\xf9y\xf1" +
"L9\xf9\xe9z\xc9\x89\x05y\x05V\xc1`!\xfe<" +
"\x9f\xfc\xf4\x00F\xc6@\x0ef\x16\x06\x06\x16F\x06\x06" +
"A\xcd \x06\x86@\x0df\xc6@\x13&FAFF" +
"\x11F\x90\xa0\xa1\x13\x03C\xa0\x0e3c\xa0\x05\x13\xe3" +
"\xff\x92\xcc\xdc\xd4\xe2\x92\xc4\\\x06\xc6\x02F\x1e\x06&" +
"F\x1e\x06\xc6\xfa\xe4\xfc\xbc\x92\xd4\xbc\x12F^\x06&" +
"F^\x06F@\x00\x00\x00\xff\xff\xdaK$\x1a"
func init() {
schemas.Register(schema_8f43375162194466,
0xa13a07c504a5ab64)
}

71
sshlog/session_logger.go Normal file

@ -0,0 +1,71 @@
package sshlog
import (
"time"
"github.com/sirupsen/logrus"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs"
)
// SessionLogger will buffer and write events to disk using capnp proto for session replay
type SessionLogger struct {
logger *Logger
encoder *capnp.Encoder
}
type sessionLogData struct {
Timestamp string // The UTC timestamp of when the log occurred
Content []byte // The shell output
}
// NewSessionLogger creates a new session logger by encapsulating a Logger object and writing capnp encoded messages to it
func NewSessionLogger(filename string, logger *logrus.Logger, flushInterval time.Duration, maxFileSize int64) (*SessionLogger, error) {
l, err := NewLogger(filename, logger, flushInterval, maxFileSize)
if err != nil {
return nil, err
}
sessionLogger := &SessionLogger{
logger: l,
encoder: capnp.NewEncoder(l),
}
return sessionLogger, nil
}
// Writes to a log buffer. Implements the io.Writer interface.
func (l *SessionLogger) Write(p []byte) (n int, err error) {
return l.writeSessionLog(&sessionLogData{
Timestamp: time.Now().UTC().Format(time.RFC3339),
Content: p,
})
}
// Close drains anything left in the buffer and cleans up any resources still
// in use.
func (l *SessionLogger) Close() error {
return l.logger.Close()
}
func (l *SessionLogger) writeSessionLog(p *sessionLogData) (int, error) {
msg, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
return 0, err
}
log, err := NewRootSessionLog(seg)
if err != nil {
return 0, err
}
log.SetTimestamp(p.Timestamp)
log.SetContent(p.Content)
if err := l.encoder.Encode(msg); err != nil {
return 0, err
}
return len(p.Content), nil
}
func unmarshalSessionLog(s SessionLog) (*sessionLogData, error) {
p := new(sessionLogData)
err := pogs.Extract(p, SessionLog_TypeID, s.Struct)
return p, err
}
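// Replay sketch (illustrative only; assumes the caller builds a *capnp.Decoder
// over the logfile, e.g. capnp.NewDecoder(file)): decode each message written
// by writeSessionLog and recover the typed entry with unmarshalSessionLog.
func replaySessionLog(dec *capnp.Decoder, handle func(*sessionLogData) error) error {
	for {
		msg, err := dec.Decode()
		if err != nil {
			// The decoder returns io.EOF once the whole file has been read;
			// callers can treat that as a clean stop.
			return err
		}
		sessionLog, err := ReadRootSessionLog(msg)
		if err != nil {
			return err
		}
		entry, err := unmarshalSessionLog(sessionLog)
		if err != nil {
			return err
		}
		if err := handle(entry); err != nil {
			return err
		}
	}
}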


@ -0,0 +1,69 @@
package sshlog
import (
"os"
"testing"
"time"
"github.com/sirupsen/logrus"
capnp "zombiezen.com/go/capnproto2"
)
const sessionLogFileName = "test-session-logger.log"
func createSessionLogger(t *testing.T) *SessionLogger {
os.Remove(sessionLogFileName)
l := logrus.New()
logger, err := NewSessionLogger(sessionLogFileName, l, time.Millisecond, 1024)
if err != nil {
t.Fatal("couldn't create the logger!", err)
}
return logger
}
func TestSessionLogWrite(t *testing.T) {
testStr := "hi"
logger := createSessionLogger(t)
defer func() {
logger.Close()
os.Remove(sessionLogFileName)
}()
logger.Write([]byte(testStr))
time.Sleep(2 * time.Millisecond)
f, err := os.Open(sessionLogFileName)
if err != nil {
t.Fatal("couldn't read the log file!", err)
}
defer f.Close()
msg, err := capnp.NewDecoder(f).Decode()
if err != nil {
t.Fatal("couldn't read the capnp msg file!", err)
}
sessionLog, err := ReadRootSessionLog(msg)
if err != nil {
t.Fatal("couldn't read the session log from the msg!", err)
}
timeStr, err := sessionLog.Timestamp()
if err != nil {
t.Fatal("couldn't read the Timestamp field!", err)
}
_, terr := time.Parse(time.RFC3339, timeStr)
if terr != nil {
t.Fatal("couldn't parse the Timestamp into the expected RFC3339 format", terr)
}
data, err := sessionLog.Content()
if err != nil {
t.Fatal("couldn't read the Content field!", err)
}
checkStr := string(data)
if checkStr != testStr {
t.Fatal("file data doesn't match!")
}
}

114
sshserver/host_keys.go Normal file

@ -0,0 +1,114 @@
//+build !windows
package sshserver
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
)
const (
rsaFilename = "ssh_host_rsa_key"
ecdsaFilename = "ssh_host_ecdsa_key"
)
var defaultHostKeyDir = filepath.Join(".cloudflared", "host_keys")
func (s *SSHProxy) configureHostKeys(hostKeyDir string) error {
if hostKeyDir == "" {
homeDir, err := os.UserHomeDir()
if err != nil {
return err
}
hostKeyDir = filepath.Join(homeDir, defaultHostKeyDir)
}
if _, err := os.Stat(hostKeyDir); os.IsNotExist(err) {
if err := os.MkdirAll(hostKeyDir, 0755); err != nil {
return errors.Wrap(err, fmt.Sprintf("Error creating %s directory", hostKeyDir))
}
}
if err := s.configureECDSAKey(hostKeyDir); err != nil {
return err
}
if err := s.configureRSAKey(hostKeyDir); err != nil {
return err
}
return nil
}
func (s *SSHProxy) configureRSAKey(basePath string) error {
keyPath := filepath.Join(basePath, rsaFilename)
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return errors.Wrap(err, "Error generating RSA host key")
}
privateKey := &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
}
if err = writePrivateKey(keyPath, privateKey); err != nil {
return err
}
s.logger.Debug("Created new RSA SSH host key: ", keyPath)
}
if err := s.SetOption(ssh.HostKeyFile(keyPath)); err != nil {
return errors.Wrap(err, "Could not set SSH RSA host key")
}
return nil
}
func (s *SSHProxy) configureECDSAKey(basePath string) error {
keyPath := filepath.Join(basePath, ecdsaFilename)
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return errors.Wrap(err, "Error generating ECDSA host key")
}
keyBytes, err := x509.MarshalECPrivateKey(key)
if err != nil {
return errors.Wrap(err, "Error marshalling ECDSA key")
}
privateKey := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: keyBytes,
}
if err = writePrivateKey(keyPath, privateKey); err != nil {
return err
}
s.logger.Debug("Created new ECDSA SSH host key: ", keyPath)
}
if err := s.SetOption(ssh.HostKeyFile(keyPath)); err != nil {
return errors.Wrap(err, "Could not set SSH ECDSA host key")
}
return nil
}
func writePrivateKey(keyPath string, privateKey *pem.Block) error {
if err := ioutil.WriteFile(keyPath, pem.EncodeToMemory(privateKey), 0600); err != nil {
return errors.Wrap(err, fmt.Sprintf("Error writing host key to %s", keyPath))
}
return nil
}
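// Illustrative check (an assumption, not part of this change; requires
// golang.org/x/crypto/ssh imported as gossh): confirm a key written by
// writePrivateKey round-trips through the SSH private key parser, e.g. in a test.
func hostKeyParses(keyPath string) error {
	data, err := ioutil.ReadFile(keyPath)
	if err != nil {
		return err
	}
	_, err = gossh.ParsePrivateKey(data)
	return err
}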


@ -0,0 +1,29 @@
package sshserver
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
func TestHasPort(t *testing.T) {
type testCase struct {
input string
expectedOutput string
}
tests := []testCase{
{"localhost", "localhost:22"},
{"other.addr:22", "other.addr:22"},
{"[2001:db8::1]:8080", "[2001:db8::1]:8080"},
{"[::1]", "[::1]:22"},
{"2001:0db8:3c4d:0015:0000:0000:1a2f:1234", "[2001:0db8:3c4d:0015:0000:0000:1a2f:1234]:22"},
{"::1", "[::1]:22"},
}
for _, test := range tests {
out, err := canonicalizeDest(test.input)
require.Nil(t, err)
assert.Equal(t, test.expectedOutput, out)
}
}

490
sshserver/sshserver_unix.go Normal file

@ -0,0 +1,490 @@
//+build !windows
package sshserver
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net"
"runtime"
"strings"
"time"
"github.com/cloudflare/cloudflared/sshgen"
"github.com/cloudflare/cloudflared/sshlog"
"github.com/gliderlabs/ssh"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
gossh "golang.org/x/crypto/ssh"
)
const (
auditEventStart = "session_start"
auditEventStop = "session_stop"
auditEventExec = "exec"
auditEventScp = "scp"
auditEventResize = "resize"
auditEventShell = "shell"
sshContextSessionID = "sessionID"
sshContextEventLogger = "eventLogger"
sshContextPreamble = "sshPreamble"
sshContextSSHClient = "sshClient"
SSHPreambleLength = 2
defaultSSHPort = "22"
)
type auditEvent struct {
Event string `json:"event,omitempty"`
EventType string `json:"event_type,omitempty"`
SessionID string `json:"session_id,omitempty"`
User string `json:"user,omitempty"`
Login string `json:"login,omitempty"`
Datetime string `json:"datetime,omitempty"`
Hostname string `json:"hostname,omitempty"`
Destination string `json:"destination,omitempty"`
}
// sshConn wraps the incoming net.Conn and a cleanup function
// This is done to allow the outgoing SSH client to be retrieved and closed when the conn itself is closed.
type sshConn struct {
net.Conn
cleanupFunc func()
}
// close calls the cleanupFunc before closing the conn
func (c sshConn) Close() error {
c.cleanupFunc()
return c.Conn.Close()
}
type SSHProxy struct {
ssh.Server
hostname string
logger *logrus.Logger
shutdownC chan struct{}
caCert ssh.PublicKey
logManager sshlog.Manager
}
type SSHPreamble struct {
Destination string
JWT string
}
// New creates a new SSHProxy and configures its host keys and authentication from the data provided
func New(logManager sshlog.Manager, logger *logrus.Logger, version, localAddress, hostname, hostKeyDir string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHProxy, error) {
sshProxy := SSHProxy{
hostname: hostname,
logger: logger,
shutdownC: shutdownC,
logManager: logManager,
}
sshProxy.Server = ssh.Server{
Addr: localAddress,
MaxTimeout: maxTimeout,
IdleTimeout: idleTimeout,
Version: fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS),
PublicKeyHandler: sshProxy.proxyAuthCallback,
ConnCallback: sshProxy.connCallback,
ChannelHandlers: map[string]ssh.ChannelHandler{
"default": sshProxy.channelHandler,
},
}
if err := sshProxy.configureHostKeys(hostKeyDir); err != nil {
return nil, err
}
return &sshProxy, nil
}
// Start the SSH proxy listener to start handling SSH connections from clients
func (s *SSHProxy) Start() error {
s.logger.Infof("Starting SSH server at %s", s.Addr)
go func() {
<-s.shutdownC
if err := s.Close(); err != nil {
s.logger.WithError(err).Error("Cannot close SSH server")
}
}()
return s.ListenAndServe()
}
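// Wiring sketch (illustrative only; the directory, addresses and timeouts
// below are example values, not defaults from this change): construct the
// proxy with New and run it until shutdownC is closed.
func exampleRunProxy(logger *logrus.Logger, shutdownC chan struct{}) error {
	proxy, err := New(
		sshlog.New("/var/log/cloudflared/ssh"), // audit/session log manager
		logger,
		"dev",             // version string embedded in the SSH banner
		"127.0.0.1:2222",  // local listen address
		"ssh.example.com", // hostname reported in audit events
		"",                // empty host key dir falls back to ~/.cloudflared/host_keys
		shutdownC,
		30*time.Second, // idle timeout
		time.Hour,      // max session lifetime
	)
	if err != nil {
		return err
	}
	return proxy.Start()
}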
// proxyAuthCallback attempts to connect to the ultimate SSH destination. If successful, it allows the incoming connection
// to connect to the proxy and saves the outgoing SSH client to the context. Otherwise, no connection to
// the proxy is allowed.
func (s *SSHProxy) proxyAuthCallback(ctx ssh.Context, key ssh.PublicKey) bool {
client, err := s.dialDestination(ctx)
if err != nil {
return false
}
ctx.SetValue(sshContextSSHClient, client)
return true
}
// connCallback reads the preamble sent from the proxy server and saves an audit event logger to the context.
// If any errors occur, the connection is terminated by returning nil from the callback.
func (s *SSHProxy) connCallback(ctx ssh.Context, conn net.Conn) net.Conn {
// AUTH-2050: This is a temporary workaround of a timing issue in the tunnel muxer to allow further testing.
// TODO: Remove this
time.Sleep(10 * time.Millisecond)
preamble, err := s.readPreamble(conn)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
s.logger.Warn("Could not establish session. Client likely does not have --destination set and is using old-style ssh config")
} else if err != io.EOF {
s.logger.WithError(err).Error("failed to read SSH preamble")
}
return nil
}
ctx.SetValue(sshContextPreamble, preamble)
logger, sessionID, err := s.auditLogger()
if err != nil {
s.logger.WithError(err).Error("failed to configure logger")
return nil
}
ctx.SetValue(sshContextEventLogger, logger)
ctx.SetValue(sshContextSessionID, sessionID)
// attempts to retrieve and close the outgoing ssh client when the incoming conn is closed.
// If no client exists, the conn is being closed before the PublicKeyCallback was called (where the client is created).
cleanupFunc := func() {
client, ok := ctx.Value(sshContextSSHClient).(*gossh.Client)
if ok && client != nil {
client.Close()
}
}
return sshConn{conn, cleanupFunc}
}
// channelHandler proxies incoming and outgoing SSH traffic back and forth over an SSH Channel
func (s *SSHProxy) channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
if newChan.ChannelType() != "session" && newChan.ChannelType() != "direct-tcpip" {
msg := fmt.Sprintf("channel type %s is not supported", newChan.ChannelType())
s.logger.Info(msg)
if err := newChan.Reject(gossh.UnknownChannelType, msg); err != nil {
s.logger.WithError(err).Error("Error rejecting SSH channel")
}
return
}
localChan, localChanReqs, err := newChan.Accept()
if err != nil {
s.logger.WithError(err).Error("Failed to accept session channel")
return
}
defer localChan.Close()
// client will be closed when the sshConn is closed
client, ok := ctx.Value(sshContextSSHClient).(*gossh.Client)
if !ok {
s.logger.Error("Could not retrieve client from context")
return
}
remoteChan, remoteChanReqs, err := client.OpenChannel(newChan.ChannelType(), newChan.ExtraData())
if err != nil {
s.logger.WithError(err).Error("Failed to open remote channel")
return
}
defer remoteChan.Close()
// Proxy ssh traffic back and forth between client and destination
s.proxyChannel(localChan, remoteChan, localChanReqs, remoteChanReqs, conn, ctx)
}
// proxyChannel couples two SSH channels and proxies SSH traffic and channel requests back and forth.
func (s *SSHProxy) proxyChannel(localChan, remoteChan gossh.Channel, localChanReqs, remoteChanReqs <-chan *gossh.Request, conn *gossh.ServerConn, ctx ssh.Context) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(localChan, remoteChan); err != nil {
s.logger.WithError(err).Error("remote to local copy error")
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(remoteChan, localChan); err != nil {
s.logger.WithError(err).Error("local to remote copy error")
}
done <- struct{}{}
}()
// stderr streams are used for non-pty sessions since they have distinct IO streams.
remoteStderr := remoteChan.Stderr()
localStderr := localChan.Stderr()
go func() {
if _, err := io.Copy(remoteStderr, localStderr); err != nil {
s.logger.WithError(err).Error("stderr local to remote copy error")
}
}()
go func() {
if _, err := io.Copy(localStderr, remoteStderr); err != nil {
s.logger.WithError(err).Error("stderr remote to local copy error")
}
}()
s.logAuditEvent(conn, "", auditEventStart, ctx)
defer s.logAuditEvent(conn, "", auditEventStop, ctx)
// Proxy channel requests
for {
select {
case req := <-localChanReqs:
if req == nil {
return
}
if err := s.forwardChannelRequest(remoteChan, req); err != nil {
s.logger.WithError(err).Error("Failed to forward request")
return
}
s.logChannelRequest(req, conn, ctx)
case req := <-remoteChanReqs:
if req == nil {
return
}
if err := s.forwardChannelRequest(localChan, req); err != nil {
s.logger.WithError(err).Error("Failed to forward request")
return
}
case <-done:
return
}
}
}
// readPreamble reads a preamble from the SSH connection before any SSH traffic is sent.
// This preamble is a JSON encoded struct containing the user's JWT and ultimate destination.
// The first 2 bytes contain the length of the preamble, which follows immediately.
func (s *SSHProxy) readPreamble(conn net.Conn) (*SSHPreamble, error) {
// Set conn read deadline while reading the preamble to prevent hangs if the preamble wasn't sent.
if err := conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)); err != nil {
return nil, errors.Wrap(err, "failed to set conn deadline")
}
defer func() {
if err := conn.SetReadDeadline(time.Time{}); err != nil {
s.logger.WithError(err).Error("Failed to unset conn read deadline")
}
}()
size := make([]byte, SSHPreambleLength)
if _, err := io.ReadFull(conn, size); err != nil {
return nil, err
}
payloadLength := binary.BigEndian.Uint16(size)
payload := make([]byte, payloadLength)
if _, err := io.ReadFull(conn, payload); err != nil {
return nil, err
}
var preamble SSHPreamble
err := json.Unmarshal(payload, &preamble)
if err != nil {
return nil, err
}
preamble.Destination, err = canonicalizeDest(preamble.Destination)
if err != nil {
return nil, err
}
return &preamble, nil
}
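// Client-side sketch (illustrative only, not part of this change): the peer
// that dials this proxy writes a 2-byte big-endian length followed by the JSON
// encoding of SSHPreamble, which is exactly what readPreamble expects above.
func writePreamble(conn net.Conn, preamble *SSHPreamble) error {
	payload, err := json.Marshal(preamble)
	if err != nil {
		return err
	}
	size := make([]byte, SSHPreambleLength)
	binary.BigEndian.PutUint16(size, uint16(len(payload)))
	if _, err := conn.Write(size); err != nil {
		return err
	}
	_, err = conn.Write(payload)
	return err
}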
// canonicalizeDest adds a default port if one doesn't exist
func canonicalizeDest(dest string) (string, error) {
_, _, err := net.SplitHostPort(dest)
// if host and port are split without error, a port exists.
if err != nil {
addrErr, ok := err.(*net.AddrError)
if !ok {
return "", err
}
// If the port is missing, append it.
if addrErr.Err == "missing port in address" {
return fmt.Sprintf("%s:%s", dest, defaultSSHPort), nil
}
// If there are too many colons and the address is IPv6, wrap it in brackets and append the port. Otherwise the address is invalid.
ip := net.ParseIP(dest)
if addrErr.Err == "too many colons in address" && ip != nil && ip.To4() == nil {
return fmt.Sprintf("[%s]:%s", dest, defaultSSHPort), nil
}
return "", addrErr
}
return dest, nil
}
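// Worked examples for canonicalizeDest, mirroring the cases exercised in
// sshserver/sshserver_test.go above:
//   "localhost"      -> "localhost:22"
//   "other.addr:22"  -> "other.addr:22"
//   "::1"            -> "[::1]:22"
//   "[::1]"          -> "[::1]:22"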
// dialDestination creates a new SSH client and dials the destination server
func (s *SSHProxy) dialDestination(ctx ssh.Context) (*gossh.Client, error) {
preamble, ok := ctx.Value(sshContextPreamble).(*SSHPreamble)
if !ok {
msg := "failed to retrieve SSH preamble from context"
s.logger.Error(msg)
return nil, errors.New(msg)
}
signer, err := s.genSSHSigner(preamble.JWT)
if err != nil {
s.logger.WithError(err).Error("Failed to generate signed short lived cert")
return nil, err
}
s.logger.Debugf("Short lived certificate for %s connecting to %s:\n\n%s", ctx.User(), preamble.Destination, gossh.MarshalAuthorizedKey(signer.PublicKey()))
clientConfig := &gossh.ClientConfig{
User: ctx.User(),
// AUTH-2103 TODO: proper host key check
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
Auth: []gossh.AuthMethod{gossh.PublicKeys(signer)},
ClientVersion: ctx.ServerVersion(),
}
client, err := gossh.Dial("tcp", preamble.Destination, clientConfig)
if err != nil {
s.logger.WithError(err).Info("Failed to connect to destination SSH server")
return nil, err
}
return client, nil
}
// genSSHSigner generates a key pair and sends the public key to the CA to be signed
func (s *SSHProxy) genSSHSigner(jwt string) (gossh.Signer, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, errors.Wrap(err, "failed to generate ecdsa key pair")
}
pub, err := gossh.NewPublicKey(&key.PublicKey)
if err != nil {
return nil, errors.Wrap(err, "failed to convert ecdsa public key to SSH public key")
}
pubBytes := gossh.MarshalAuthorizedKey(pub)
signedCertBytes, err := sshgen.SignCert(jwt, string(pubBytes))
if err != nil {
return nil, errors.Wrap(err, "failed to retrieve cert from SSHCAAPI")
}
signedPub, _, _, _, err := gossh.ParseAuthorizedKey([]byte(signedCertBytes))
if err != nil {
return nil, errors.Wrap(err, "failed to parse SSH public key")
}
cert, ok := signedPub.(*gossh.Certificate)
if !ok {
return nil, errors.Wrap(err, "failed to assert public key as certificate")
}
signer, err := gossh.NewSignerFromKey(key)
if err != nil {
return nil, errors.Wrap(err, "failed to create signer")
}
certSigner, err := gossh.NewCertSigner(cert, signer)
if err != nil {
return nil, errors.Wrap(err, "failed to create cert signer")
}
return certSigner, nil
}
// forwardChannelRequest sends request req to SSH channel sshChan, waits for reply, and sends the reply back.
func (s *SSHProxy) forwardChannelRequest(sshChan gossh.Channel, req *gossh.Request) error {
reply, err := sshChan.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
return errors.Wrap(err, "Failed to send request")
}
if err := req.Reply(reply, nil); err != nil {
return errors.Wrap(err, "Failed to reply to request")
}
return nil
}
// logChannelRequest creates an audit log for different types of channel requests
func (s *SSHProxy) logChannelRequest(req *gossh.Request, conn *gossh.ServerConn, ctx ssh.Context) {
var eventType string
var event string
switch req.Type {
case "exec":
var payload struct{ Value string }
if err := gossh.Unmarshal(req.Payload, &payload); err != nil {
s.logger.WithError(err).Errorf("Failed to unmarshal channel request payload: %s:%s", req.Type, req.Payload)
}
event = payload.Value
eventType = auditEventExec
if strings.HasPrefix(string(req.Payload), "scp") {
eventType = auditEventScp
}
case "shell":
eventType = auditEventShell
case "window-change":
eventType = auditEventResize
default:
return
}
s.logAuditEvent(conn, event, eventType, ctx)
}
func (s *SSHProxy) auditLogger() (io.WriteCloser, string, error) {
sessionUUID, err := uuid.NewRandom()
if err != nil {
return nil, "", errors.Wrap(err, "failed to create sessionID")
}
sessionID := sessionUUID.String()
writer, err := s.logManager.NewLogger(fmt.Sprintf("%s-event.log", sessionID), s.logger)
if err != nil {
return nil, "", errors.Wrap(err, "failed to create logger")
}
return writer, sessionID, nil
}
func (s *SSHProxy) logAuditEvent(conn *gossh.ServerConn, event, eventType string, ctx ssh.Context) {
sessionID, sessionIDOk := ctx.Value(sshContextSessionID).(string)
writer, writerOk := ctx.Value(sshContextEventLogger).(io.WriteCloser)
if !writerOk || !sessionIDOk {
s.logger.Error("Failed to retrieve audit logger from context")
return
}
var destination string
preamble, ok := ctx.Value(sshContextPreamble).(*SSHPreamble)
if ok {
destination = preamble.Destination
} else {
s.logger.Error("Failed to retrieve SSH preamble from context")
}
ae := auditEvent{
Event: event,
EventType: eventType,
SessionID: sessionID,
User: conn.User(),
Login: conn.User(),
Datetime: time.Now().UTC().Format(time.RFC3339),
Hostname: s.hostname,
Destination: destination,
}
data, err := json.Marshal(&ae)
if err != nil {
s.logger.WithError(err).Error("Failed to marshal audit event. malformed audit object")
return
}
line := string(data) + "\n"
if _, err := writer.Write([]byte(line)); err != nil {
s.logger.WithError(err).Error("Failed to write audit event.")
}
}
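// Example of a resulting audit line (field values are illustrative only; the
// keys come from the json tags on auditEvent above):
//   {"event":"ls -la","event_type":"exec","session_id":"<uuid>","user":"alice",
//    "login":"alice","datetime":"2019-12-12T13:35:39Z","hostname":"ssh.example.com",
//    "destination":"origin.internal:22"}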


@ -0,0 +1,29 @@
//+build windows
package sshserver
import (
"errors"
"time"
"github.com/cloudflare/cloudflared/sshlog"
"github.com/sirupsen/logrus"
)
const SSHPreambleLength = 2
type SSHServer struct{}
type SSHPreamble struct {
Destination string
JWT string
}
func New(_ sshlog.Manager, _ *logrus.Logger, _, _, _, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) {
return nil, errors.New("cloudflared ssh server is not supported on windows")
}
func (s *SSHServer) Start() error {
return errors.New("cloudflared ssh server is not supported on windows")
}

27
sshserver/testdata/ca vendored Normal file

@ -0,0 +1,27 @@
-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn
NhAAAAAwEAAQAAAQEA0c6EklYvC9B041qEGWDNuot6G4tTVm9LCQC0vA+v2n25ru9CINV6
8IljmXBORXBwfG6PdLhg0SEabZUbsNX5WrIVbGovcghKS6GRsqI5+Quhm+o8eG042JE/hB
oYdZ19TcMEyPOGzHsx0U/BSN9ZJWVCxqN51iI6qyhz9f6jlX2LQBFEvXlhxgF3owBEf8UC
Zt/UvbZdmeeyKNQElPmiVLIJEAPCueECp7a2mjCiP3zqjDvSeeGk4CelB/1qZZ4V2n7fvb
HZjAB5JJs4KXs5o8KgvQnqgQMxiLFZ4PATt4+mxEzh4JymppbqJOo2rYwOA3TAIEWWtYRV
/ZKJ0AyhhQAAA8gciO8XHIjvFwAAAAdzc2gtcnNhAAABAQDRzoSSVi8L0HTjWoQZYM26i3
obi1NWb0sJALS8D6/afbmu70Ig1XrwiWOZcE5FcHB8bo90uGDRIRptlRuw1flashVsai9y
CEpLoZGyojn5C6Gb6jx4bTjYkT+EGhh1nX1NwwTI84bMezHRT8FI31klZULGo3nWIjqrKH
P1/qOVfYtAEUS9eWHGAXejAER/xQJm39S9tl2Z57Io1ASU+aJUsgkQA8K54QKntraaMKI/
fOqMO9J54aTgJ6UH/WplnhXaft+9sdmMAHkkmzgpezmjwqC9CeqBAzGIsVng8BO3j6bETO
HgnKamluok6jatjA4DdMAgRZa1hFX9konQDKGFAAAAAwEAAQAAAQEApVzGdKhk8ETevAst
rurze6JPHcKUbr3NQE1EJi2fBvCtF0oQrtxTx54h2GAB8Q0MO6bQfsiL1ojm0ZQCfUBJBs
jxxb9zoccS98Vilo7ybm5SdBcMjkZX1am1jCMdQCZfCpk4/kGi7yvyOe1IhG01UBodpX5X
mwTjhN+fdjW7LSiW6cKPClN49CZKgmtvI27FCt+/TtMzdCXOiJxJ4yZCzCRhSgssV0gWI1
0VJr/MHirKUvv/qCLAuOBxIr9UgdduRZUpNX+KS2rfhFEbjnUqc/57aAakpQmuPB5I+s9G
DnrF0HSHpq7u1XC1SvYlnFBN/0A7Hw/MX2SaBFH7mc9AAQAAAIAFuTHr6O8tCvWEawfxC0
qiAPQ+Yy1vthq5uewmuQujMutUnc9JAUl32PdU1DbS7APC1Dg9XL7SyAB6A+ZpRJRAKgCY
SneAKE6hOytH+yM206aekrz6VuZiSpBqpfEqDibVAaZIO8sv/9dtZd6kWemxNErPQoKJey
Z7/cuWUWQovAAAAIEA6ugIlVj1irPmElyCCt5YfPv2x8Dl54ELoP/WsffsrPHNQog64hFd
ahD7Wq63TA566bN85fkx8OVU5TbbEQmkHgOEV6nDRY2YsBSqIOblA/KehtfdUIqZB0iNBh
Gn6TV/z6HwnSR3gKv4b66Gveek6LfRAG3mbsLCgyRAbYgn6YUAAACBAOSlf+n1eh6yjtvF
Zecq3Zslj7O8cUs17PQx4vQ7sXNCFrIZdevWPIn9sVrt7/hsTrXunDz6eXCeclB35KZe3H
WPVjRoD+xnr5+sXx2qXOnKCR0LdFybso6IR5bXAI6DNSNfP7D9LPEQ+R73Jk0jPuLYzocS
iM89KZiuGpzr01gBAAAAEW1pa2VAQzAyWTUwVEdKR0g4AQ==
-----END OPENSSH PRIVATE KEY-----

1
sshserver/testdata/ca.pub vendored Normal file

@ -0,0 +1 @@
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDRzoSSVi8L0HTjWoQZYM26i3obi1NWb0sJALS8D6/afbmu70Ig1XrwiWOZcE5FcHB8bo90uGDRIRptlRuw1flashVsai9yCEpLoZGyojn5C6Gb6jx4bTjYkT+EGhh1nX1NwwTI84bMezHRT8FI31klZULGo3nWIjqrKHP1/qOVfYtAEUS9eWHGAXejAER/xQJm39S9tl2Z57Io1ASU+aJUsgkQA8K54QKntraaMKI/fOqMO9J54aTgJ6UH/WplnhXaft+9sdmMAHkkmzgpezmjwqC9CeqBAzGIsVng8BO3j6bETOHgnKamluok6jatjA4DdMAgRZa1hFX9konQDKGF mike@C02Y50TGJGH8

49
sshserver/testdata/id_rsa vendored Normal file

@ -0,0 +1,49 @@
-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAACFwAAAAdzc2gtcn
NhAAAAAwEAAQAAAgEA60Kneo87qPsh+zErWFl7vx93c7fyTxbZ9lUNqafgXy/BLOCc/nQS
McosVSLsQrbHlhYzfmZEhTiubmuYUrHchmsn1ml1HIqP8T5aDgtNbLqYnS4H5oO4Sj1+XH
lQtU7n7zHXgca9SnMWt1Fhkx1mvkeiOKs0eq7hV2TuIZxfmbYfIVvJGwrL0uWzbSEE1gvx
gTXZHxEChIQyrNviljgi4u2MD/cIi6KMeYUnaTL1FxO9G4GIFiy7ueHRwOZPIFHgYm+Vrt
X7XafSF0///zCrC63zzWt/6A06hFepOz2VXvm7SdckaR7qMXAb7kipsc0+dKk9ggU7Fqpx
ZY5cVeZo9RlRVhRXGDy7mABA/FMwvv+qYCgJ3nlZbdKbaiPLQu8ScTlJ9sMI06/ZiEY04b
meZ0ASM52gaDGjrFbbnuHNf5XV/oreEUhtCrryFnoIxmKgHznGjZ55q77FtTHnrAKFmKFP
11s3MLIX9o4RgtriOtl4KenkIfUumgtrwY/UGjOaOQUOrVH1am54wkUiVEF0Qd3AD8KCl/
l/xT5+t6cOspZ9GIhwa2NBmRjN/wVGp+Yrb08Re3kxPCX9bs5iLe+kHN0vuFr7RDo+eUoi
SPhWl6FUqx2W9NZqekmEgKn3oKrfbGaMH1VLkaKWlzQ4xJzP0iadQbIXGryLEYASydemZt
sAAAdQ/ovjxf6L48UAAAAHc3NoLXJzYQAAAgEA60Kneo87qPsh+zErWFl7vx93c7fyTxbZ
9lUNqafgXy/BLOCc/nQSMcosVSLsQrbHlhYzfmZEhTiubmuYUrHchmsn1ml1HIqP8T5aDg
tNbLqYnS4H5oO4Sj1+XHlQtU7n7zHXgca9SnMWt1Fhkx1mvkeiOKs0eq7hV2TuIZxfmbYf
IVvJGwrL0uWzbSEE1gvxgTXZHxEChIQyrNviljgi4u2MD/cIi6KMeYUnaTL1FxO9G4GIFi
y7ueHRwOZPIFHgYm+VrtX7XafSF0///zCrC63zzWt/6A06hFepOz2VXvm7SdckaR7qMXAb
7kipsc0+dKk9ggU7FqpxZY5cVeZo9RlRVhRXGDy7mABA/FMwvv+qYCgJ3nlZbdKbaiPLQu
8ScTlJ9sMI06/ZiEY04bmeZ0ASM52gaDGjrFbbnuHNf5XV/oreEUhtCrryFnoIxmKgHznG
jZ55q77FtTHnrAKFmKFP11s3MLIX9o4RgtriOtl4KenkIfUumgtrwY/UGjOaOQUOrVH1am
54wkUiVEF0Qd3AD8KCl/l/xT5+t6cOspZ9GIhwa2NBmRjN/wVGp+Yrb08Re3kxPCX9bs5i
Le+kHN0vuFr7RDo+eUoiSPhWl6FUqx2W9NZqekmEgKn3oKrfbGaMH1VLkaKWlzQ4xJzP0i
adQbIXGryLEYASydemZtsAAAADAQABAAACABUYzBYEhDAaHSj+dsmcdKll8/tPko4fGXqq
k+gT4t4GVUdl+Q4kcIFAhQs5b4BoDava39FE8H4V4CaMxYMc6g6vy0nB+TuO/Wt/0OmTf+
TxMsBdoV29kCgwLYWzZ1Zq9geQK6g6nzzu5ymXRa3ApDcKC3UTfUhHKHQC3AvtjvEk0NPX
/EfNhwuph5aQsHNVbNnOb2MGznf9tuGjckVQUWiSLs47s+t5rykylJ8tb6cbIQk3a3G5nz
gDFSE8Rfo6/Wk2YnDkRX9XjlKC3Q0QWzZX6hYQvs6baRT3G3jxg9SZhn8PqPc4S34VdJvA
rl8AbcpeZuKi/3J/5F1cD9GwMNcl4gM87piF20/r9mMvC4zBAEgyF8WBi4OjSu0+ccsEsb
GSpxKK04OPTB7p8mLJ8hQUiREg5OuPEEcAoDSuHgdliE7nDHzuImbpTcAZcWhkJaUdBWI6
qcnGPARzxAOmuzkY8Gq0MtcWge5QxnLWJyrfy43M984Cvxql/maLUij4eTbMDDwV7Qx30V
P2tJp5+hOnitRwB6cQIg5N7/cTQdJ6eiFYuw0v3IfHjYmaolY8F3u38Zv2PPk50CorPRDG
esx0a9Elm2UKPb145MtHGZtLH2mayRnDjnxr25iLwgokI06tCLCNvbkYLA7wVpJn81eKmZ
tQBtbfqBSiDiLjCrehAAABAQDh8vmgPR95Kx1DeBxzMSja7TStP5V58JcUvLP4dAN6zqtt
rMuKDfJRSIVYGkU8vXlME5E6phWP7R5lYY7+kLDbeZrYQg1AxkK8y4fUYkCLBuEcCjzWDK
oqZQNskk4urbCdBIP6AR7d/LMCHBb9rk2iOuUeos6JHRKbPGP1lvH3hLkbH9CA0F41sz86
JFg6u/XaRQ2CyhS7y7SQ8dmaANGz9LGdIRqIoZ8Hfht8t1VRbM9fzSb3xoxUItbHpk9R9g
GZsHSryi7AtRmHt0uBrWIv6RbIY0epCbjdCLvHflbkPgwBM7UndgkOSIwQ4SQF8Fs+e9/r
hV05h0Y81vd1RZvOAAABAQD5EgW3SpmYzeMmiP7MKkfIlbZtwVmRu4anTzWxlk5pJ9GXnC
QoInULCipWAOeJbuLIgRWLU4VzhOUbYLNKQPXECARfgoto2VXoXZZ2q2O4aXaCpeyU6nE8
VKbp4nU1jEg5hWB3PRwZ8Pzs4A93/9mrpVzLmCT+LW9Rlnp6tTpqcUKGugg8vr64SSgqnV
ZFyQgHgw+ZGOG9w714urS3U97WNTeHXAs0p2YBOu5XW3JQ3jkRo7YyZF3+TtBxbgfHRZfH
O2mFcMBD3Sn4t+LAbgnLye3S2/WZf/gQwdVB7BgrVqguzQ2hGoOxNiwadkIDsxb6r/u3n6
2lScpHFDS0WnpRAAABAQDxzkV52VX6wAWkQe/2KFH9wTG0XFANmZUnnTPR8wd+b9E7HIr0
Mdd8iAHOhLRvTy8mih53GGBptXK7GdABMZtkqDErbXhuC8xbi9uRLEHiRe/oBfWr8vYIZY
awiw3/EqxaTv0HBMicdr2S31Bs2/mjrVuJH0wAaI9ueQnZizzjgWuzeNZMWq1qk0akUUdm
PDVd58yBkt8lKlkOG0LJAn6JEG9oH9XiTFShHzu1dQmoC2bKVHdxL8WCcYFVtmyoMRcLZq
u6d4nyKha02cYZB5hM3VcizJI5HY/A+H3fBkRR0hXgkU5R89w+8x9VSJkNVx+JGC7ziK4a
kUjfOmR5WBdrAAAAE3Rlc3RAY2xvdWRmbGFyZS5jb20BAgMEBQYH
-----END OPENSSH PRIVATE KEY-----

1
sshserver/testdata/id_rsa-cert.pub vendored Normal file

@ -0,0 +1 @@
ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgOsuFqKdzp/nC3wQfKVJBdHa8axtGryKplPkDjdSXT4kAAAADAQABAAACAQDrQqd6jzuo+yH7MStYWXu/H3dzt/JPFtn2VQ2pp+BfL8Es4Jz+dBIxyixVIuxCtseWFjN+ZkSFOK5ua5hSsdyGayfWaXUcio/xPloOC01supidLgfmg7hKPX5ceVC1TufvMdeBxr1Kcxa3UWGTHWa+R6I4qzR6ruFXZO4hnF+Zth8hW8kbCsvS5bNtIQTWC/GBNdkfEQKEhDKs2+KWOCLi7YwP9wiLoox5hSdpMvUXE70bgYgWLLu54dHA5k8gUeBib5Wu1ftdp9IXT///MKsLrfPNa3/oDTqEV6k7PZVe+btJ1yRpHuoxcBvuSKmxzT50qT2CBTsWqnFljlxV5mj1GVFWFFcYPLuYAED8UzC+/6pgKAneeVlt0ptqI8tC7xJxOUn2wwjTr9mIRjThuZ5nQBIznaBoMaOsVtue4c1/ldX+it4RSG0KuvIWegjGYqAfOcaNnnmrvsW1MeesAoWYoU/XWzcwshf2jhGC2uI62Xgp6eQh9S6aC2vBj9QaM5o5BQ6tUfVqbnjCRSJUQXRB3cAPwoKX+X/FPn63pw6yln0YiHBrY0GZGM3/BUan5itvTxF7eTE8Jf1uzmIt76Qc3S+4WvtEOj55SiJI+FaXoVSrHZb01mp6SYSAqfegqt9sZowfVUuRopaXNDjEnM/SJp1BshcavIsRgBLJ16Zm2wAAAAAAAAAAAAAAAQAAAA10ZXN0VXNlckB0ZXN0AAAADAAAAAh0ZXN0VXNlcgAAAAAAAAAA//////////8AAAAAAAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEA0c6EklYvC9B041qEGWDNuot6G4tTVm9LCQC0vA+v2n25ru9CINV68IljmXBORXBwfG6PdLhg0SEabZUbsNX5WrIVbGovcghKS6GRsqI5+Quhm+o8eG042JE/hBoYdZ19TcMEyPOGzHsx0U/BSN9ZJWVCxqN51iI6qyhz9f6jlX2LQBFEvXlhxgF3owBEf8UCZt/UvbZdmeeyKNQElPmiVLIJEAPCueECp7a2mjCiP3zqjDvSeeGk4CelB/1qZZ4V2n7fvbHZjAB5JJs4KXs5o8KgvQnqgQMxiLFZ4PATt4+mxEzh4JymppbqJOo2rYwOA3TAIEWWtYRV/ZKJ0AyhhQAAAQ8AAAAHc3NoLXJzYQAAAQC2lL+6JYTGOdz1zNnck6onrFcVpO2onCVAKP8HdLoCeH0/upIugaCocPKuzoURYEfiHQotviNeprE/2CyAroJ5VBdqWftEeHn3FFvBCQ1gwRQ7oci4C5n72t0vjWWE6WBylS0RqpJjr6EQ8a1vuwIqAQrEJPp2yNLjRH2WD7eicBh5f43VKOMr73DtyTh4xoF0C2sNBROudt58npTaYqRHQgoI25V/aCmuYBgM3wdAGcoEZGoSerMfhID7GcWkvemq2hF8mQsspG3zgnyQXk+ahagmefzxutDnr3KdrZ637La0/XwABvBZ9L4l5RiEilVI1Shl96F2qbBW2YZ64pUQ test@cloudflare.com

1
sshserver/testdata/id_rsa.pub vendored Normal file

@ -0,0 +1 @@
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDrQqd6jzuo+yH7MStYWXu/H3dzt/JPFtn2VQ2pp+BfL8Es4Jz+dBIxyixVIuxCtseWFjN+ZkSFOK5ua5hSsdyGayfWaXUcio/xPloOC01supidLgfmg7hKPX5ceVC1TufvMdeBxr1Kcxa3UWGTHWa+R6I4qzR6ruFXZO4hnF+Zth8hW8kbCsvS5bNtIQTWC/GBNdkfEQKEhDKs2+KWOCLi7YwP9wiLoox5hSdpMvUXE70bgYgWLLu54dHA5k8gUeBib5Wu1ftdp9IXT///MKsLrfPNa3/oDTqEV6k7PZVe+btJ1yRpHuoxcBvuSKmxzT50qT2CBTsWqnFljlxV5mj1GVFWFFcYPLuYAED8UzC+/6pgKAneeVlt0ptqI8tC7xJxOUn2wwjTr9mIRjThuZ5nQBIznaBoMaOsVtue4c1/ldX+it4RSG0KuvIWegjGYqAfOcaNnnmrvsW1MeesAoWYoU/XWzcwshf2jhGC2uI62Xgp6eQh9S6aC2vBj9QaM5o5BQ6tUfVqbnjCRSJUQXRB3cAPwoKX+X/FPn63pw6yln0YiHBrY0GZGM3/BUan5itvTxF7eTE8Jf1uzmIt76Qc3S+4WvtEOj55SiJI+FaXoVSrHZb01mp6SYSAqfegqt9sZowfVUuRopaXNDjEnM/SJp1BshcavIsRgBLJ16Zm2w== test@cloudflare.com

27
sshserver/testdata/other_ca vendored Normal file

@ -0,0 +1,27 @@
-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn
NhAAAAAwEAAQAAAQEAzBO7TXxbpk7sGQm/Wa29N/NFe5uuoEQGC5hxfihmcvVgeKeNKiSS
snxzCE1Y6SmNMoE4aQs92wtcn48GmxRwZSXbCqLq2CJrHfe9B2k3aPkJZpQkFMshcJGo7p
G0Vlo7dWAbYf99/YKddf290uLK7vxw9ty0pM1hXSXHNShv1b+bTQm/COMZ5jNsncjc1yBH
KGkFVHee9Dh4Z0xLlHipIyyNXXzI0RFYuHSNJz9GD310XQLIIroptr7+/7g6+sPPGsNlI+
95OScba1/PQ2b/qy+KyIwNIMSd9ziJy5xnO7Vo3LrqQrza1Pkn2i29PljUcbc/F0hhXNIq
ITdNWwVqsQAAA8iKllTIipZUyAAAAAdzc2gtcnNhAAABAQDME7tNfFumTuwZCb9Zrb0380
V7m66gRAYLmHF+KGZy9WB4p40qJJKyfHMITVjpKY0ygThpCz3bC1yfjwabFHBlJdsKourY
Imsd970HaTdo+QlmlCQUyyFwkajukbRWWjt1YBth/339gp11/b3S4sru/HD23LSkzWFdJc
c1KG/Vv5tNCb8I4xnmM2ydyNzXIEcoaQVUd570OHhnTEuUeKkjLI1dfMjREVi4dI0nP0YP
fXRdAsgiuim2vv7/uDr6w88aw2Uj73k5JxtrX89DZv+rL4rIjA0gxJ33OInLnGc7tWjcuu
pCvNrU+SfaLb0+WNRxtz8XSGFc0iohN01bBWqxAAAAAwEAAQAAAQAKEtNFEOVpQS4QUlXa
tGPJtj1wy4+EI7d0rRK1GoNsG0amzgZ+1Q1UuCXpe//uinmIy64gKUjlXhs1WRcHYqvlok
e8r6wN/Szybr8q9Xuht+FJ6fgZ+qjs6JPBKvoO5SdYNOVFIhpzABaLs3nCRiWkRFvDI8Pa
+rRap7m8mwFiOJtmdiIZYFxzw6xXwTsGCrWPKgTv3FKGZzXnCB9i7jC2vwT1MDYbcnzEH4
Ba4dxI8bp6WWEX0biRIXj3jCtLb5gisNTSxdZs254Syh75HEXunSh2YO+yVSWQtZj19ewW
6Rb1Z3x5rVfXcgSkg7gZd9EpbckIIg6+MFSH3wdGW6atAAAAgQDFXiMuNd4ZYwdyhjlM5n
nFqQDXGgnwyNdiIqAapoqTdF5aZwNnbTU0fCFaDMLCQAHgntcgCEsW9A4HzDzYhOABKElv
j973vXWF165wFiZwuKSfroq/6JH6CiIcjiqpszbnqSOzy1hq913RWILS6e9yMjxRv8PUjm
E+IkcnfcFUwAAAAIEA+jwI3ICe8PGEIezV2tvQFeQy2Z2wGslu1yvqfTYEztSmtygns3wn
ZBb+cBXCnpqUCtznG7hZhq7I4m1I47BYznULwwFiBTVtBASG5wNP7zeVKTVZ4SKprze+Fe
I/nUZDJ5Q26um7eDbhvZ/n95GY+fucMVHoSBfX1wE16XBfp88AAACBANDHcgC4qP2oyOw/
+p9HineMQd/ppG3fePe07jyZXLHLf0rByFveFgRAQ1m77O7FtP3fFKy3Y9nNy18LGq35ZK
Blsz2B23bO8NuffgAhchDG7KzKFXCo+AraIj5znp/znK5zIkaiiSOQaYywJ36EooYVpRtj
ep5ap6bBFDZ2e+V/AAAAEW1pa2VAQzAyWTUwVEdKR0g4AQ==
-----END OPENSSH PRIVATE KEY-----

1
sshserver/testdata/other_ca.pub vendored Normal file

@ -0,0 +1 @@
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDME7tNfFumTuwZCb9Zrb0380V7m66gRAYLmHF+KGZy9WB4p40qJJKyfHMITVjpKY0ygThpCz3bC1yfjwabFHBlJdsKourYImsd970HaTdo+QlmlCQUyyFwkajukbRWWjt1YBth/339gp11/b3S4sru/HD23LSkzWFdJcc1KG/Vv5tNCb8I4xnmM2ydyNzXIEcoaQVUd570OHhnTEuUeKkjLI1dfMjREVi4dI0nP0YPfXRdAsgiuim2vv7/uDr6w88aw2Uj73k5JxtrX89DZv+rL4rIjA0gxJ33OInLnGc7tWjcuupCvNrU+SfaLb0+WNRxtz8XSGFc0iohN01bBWqx mike@C02Y50TGJGH8


@ -35,26 +35,49 @@ func createRequest(stream *h2mux.MuxedStream, url *url.URL) (*http.Request, erro
return req, nil
}
// H2RequestHeadersToH1Request converts the HTTP/2 headers to an HTTP/1 Request
// object. This includes conversion of the pseudo-headers into their closest
// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
for _, header := range h2 {
switch header.Name {
case ":method":
h1.Method = header.Value
case ":scheme":
// noop - use the preexisting scheme from h1.URL
case ":authority":
// Otherwise the host header will be based on the origin URL
h1.Host = header.Value
case ":path":
u, err := url.Parse(header.Value)
// We don't want to be an "opinionated" proxy, so ideally we would use :path as-is.
// However, this HTTP/1 Request object belongs to the Go standard library,
// whose URL package makes some opinionated decisions about the encoding of
// URL characters: see the docs of https://godoc.org/net/url#URL,
// in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath,
// which is always used when computing url.URL.String(), whether we'd like it or not.
//
// Well, not *always*. We could circumvent this by using url.URL.Opaque. But
// that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque
// is treated differently when HTTP_PROXY is set!
// See https://github.com/golang/go/issues/5684#issuecomment-66080888
//
// This means we are subject to the behavior of net/url's function `shouldEscape`
// (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101
if header.Value == "*" {
h1.URL.Path = "*"
continue
}
// Due to the behavior of validation.ValidateUrl, h1.URL may
// already have a partial value, with or without a trailing slash.
base := h1.URL.String()
base = strings.TrimRight(base, "/")
// But we know :path begins with '/', because we handled '*' above - see RFC7540
url, err := url.Parse(base + header.Value)
if err != nil {
return fmt.Errorf("unparseable path")
return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value))
}
resolved := h1.URL.ResolveReference(u)
// prevent escaping base URL
if !strings.HasPrefix(resolved.String(), h1.URL.String()) {
return fmt.Errorf("invalid path")
}
h1.URL = resolved
h1.URL = url
case "content-length":
contentLength, err := strconv.ParseInt(header.Value, 10, 64)
if err != nil {

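The long comment in this hunk boils down to: whatever ends up in h1.URL gets re-encoded by net/url's EscapedPath() when the request is serialized, so the new code parses base + :path once and lets the stdlib settle the encoding. A minimal standalone sketch of that behaviour, using an illustrative origin URL rather than cloudflared's config types:

package main

import (
	"fmt"
	"net/url"
	"strings"
)

func main() {
	// An origin URL as it might come out of config validation; trimming the
	// trailing slash means concatenating a ":path" (which always starts with
	// '/') cannot create a double slash at the join point.
	origin, err := url.Parse("http://origin.example.com:8080/api/")
	if err != nil {
		panic(err)
	}
	base := strings.TrimRight(origin.String(), "/")

	// The :path value is appended verbatim and re-parsed, so net/url decides
	// once how the path is escaped; String() then round-trips whatever
	// EscapedPath() settles on.
	u, err := url.Parse(base + "/a%20b/c d")
	if err != nil {
		panic(err)
	}
	fmt.Println(u.String())
	// Prints http://origin.example.com:8080/api/a%20b/c%20d: the literal
	// space is percent-encoded because net/url's shouldEscape rejects it.
}
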
View File

@ -0,0 +1,441 @@
package streamhandler
import (
"fmt"
"math/rand"
"net/http"
"net/url"
"reflect"
"regexp"
"strings"
"testing"
"testing/quick"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{
h2mux.Header{
Name: "Mock header 1",
Value: "Mock value 1",
},
h2mux.Header{
Name: "Mock header 2",
Value: "Mock value 2",
},
},
request,
)
assert.Equal(t, http.Header{
"Mock header 1": []string{"Mock value 1"},
"Mock header 2": []string{"Mock value 2"},
}, request.Header)
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{},
request,
)
assert.Equal(t, http.Header{}, request.Header)
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{
h2mux.Header{
Name: ":path",
Value: "//bad_path/",
},
h2mux.Header{
Name: "Mock header",
Value: "Mock value",
},
},
request,
)
assert.Equal(t, http.Header{
"Mock header": []string{"Mock value"},
}, request.Header)
assert.Equal(t, "http://example.com//bad_path/", request.URL.String())
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
assert.NoError(t, err)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{
h2mux.Header{
Name: ":path",
Value: "/?query=mock%20value",
},
h2mux.Header{
Name: "Mock header",
Value: "Mock value",
},
},
request,
)
assert.Equal(t, http.Header{
"Mock header": []string{"Mock value"},
}, request.Header)
assert.Equal(t, "http://example.com/?query=mock%20value", request.URL.String())
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
assert.NoError(t, err)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{
h2mux.Header{
Name: ":path",
Value: "/mock%20path",
},
h2mux.Header{
Name: "Mock header",
Value: "Mock value",
},
},
request,
)
assert.Equal(t, http.Header{
"Mock header": []string{"Mock value"},
}, request.Header)
assert.Equal(t, "http://example.com/mock%20path", request.URL.String())
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) {
type testCase struct {
path string
want string
}
testCases := []testCase{
{
path: "",
want: "",
},
{
path: "/",
want: "/",
},
{
path: "//",
want: "//",
},
{
path: "/test",
want: "/test",
},
{
path: "//test",
want: "//test",
},
{
// https://github.com/cloudflare/cloudflared/issues/81
path: "//test/",
want: "//test/",
},
{
path: "/%2Ftest",
want: "/%2Ftest",
},
{
path: "//%20test",
want: "//%20test",
},
{
// https://github.com/cloudflare/cloudflared/issues/124
path: "/test?get=somthing%20a",
want: "/test?get=somthing%20a",
},
{
path: "/%20",
want: "/%20",
},
{
// stdlib's EscapedPath() will always percent-encode ' '
path: "/ ",
want: "/%20",
},
{
path: "/ a ",
want: "/%20a%20",
},
{
path: "/a%20b",
want: "/a%20b",
},
{
path: "/foo/bar;param?query#frag",
want: "/foo/bar;param?query#frag",
},
{
// stdlib's EscapedPath() will always percent-encode non-ASCII chars
path: "/a␠b",
want: "/a%E2%90%A0b",
},
{
path: "/a-umlaut-ä",
want: "/a-umlaut-%C3%A4",
},
{
path: "/a-umlaut-%C3%A4",
want: "/a-umlaut-%C3%A4",
},
{
path: "/a-umlaut-%c3%a4",
want: "/a-umlaut-%c3%a4",
},
{
// here the second '#' is treated as part of the fragment
path: "/a#b#c",
want: "/a#b%23c",
},
{
path: "/a#b␠c",
want: "/a#b%E2%90%A0c",
},
{
path: "/a#b%20c",
want: "/a#b%20c",
},
{
path: "/a#b c",
want: "/a#b%20c",
},
{
// stdlib's EscapedPath() will always percent-encode '\'
path: "/\\",
want: "/%5C",
},
{
path: "/a\\",
want: "/a%5C",
},
{
path: "/a,b.c.",
want: "/a,b.c.",
},
{
path: "/.",
want: "/.",
},
{
// stdlib's EscapedPath() will always percent-encode '`'
path: "/a`",
want: "/a%60",
},
{
path: "/a[0]",
want: "/a[0]",
},
{
path: "/?a[0]=5 &b[]=",
want: "/?a[0]=5 &b[]=",
},
{
path: "/?a=%22b%20%22",
want: "/?a=%22b%20%22",
},
}
for index, testCase := range testCases {
requestURL := "https://example.com"
request, err := http.NewRequest(http.MethodGet, requestURL, nil)
assert.NoError(t, err)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{
h2mux.Header{
Name: ":path",
Value: testCase.path,
},
h2mux.Header{
Name: "Mock header",
Value: "Mock value",
},
},
request,
)
assert.NoError(t, headersConversionErr)
assert.Equal(t,
http.Header{
"Mock header": []string{"Mock value"},
},
request.Header)
assert.Equal(t,
"https://example.com"+testCase.want,
request.URL.String(),
"Failed URL index: %v %#v", index, testCase)
}
}
func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) {
config := &quick.Config{
Values: func(args []reflect.Value, rand *rand.Rand) {
args[0] = reflect.ValueOf(randomHTTP2Path(t, rand))
},
}
type testOrigin struct {
url string
expectedScheme string
expectedBasePath string
}
testOrigins := []testOrigin{
{
url: "http://origin.hostname.example.com:8080",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080",
},
{
url: "http://origin.hostname.example.com:8080/",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080",
},
{
url: "http://origin.hostname.example.com:8080/api",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080/api",
},
{
url: "http://origin.hostname.example.com:8080/api/",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080/api",
},
{
url: "https://origin.hostname.example.com:8080/api",
expectedScheme: "https",
expectedBasePath: "https://origin.hostname.example.com:8080/api",
},
}
// use multiple schemes to demonstrate that the URL is based on the
// origin's scheme, not the :scheme header
for _, testScheme := range []string{"http", "https"} {
for _, testOrigin := range testOrigins {
assertion := func(testPath string) bool {
const expectedMethod = "POST"
const expectedHostname = "request.hostname.example.com"
h2 := []h2mux.Header{
h2mux.Header{Name: ":method", Value: expectedMethod},
h2mux.Header{Name: ":scheme", Value: testScheme},
h2mux.Header{Name: ":authority", Value: expectedHostname},
h2mux.Header{Name: ":path", Value: testPath},
}
h1, err := http.NewRequest("GET", testOrigin.url, nil)
require.NoError(t, err)
err = H2RequestHeadersToH1Request(h2, h1)
return assert.NoError(t, err) &&
assert.Equal(t, expectedMethod, h1.Method) &&
assert.Equal(t, expectedHostname, h1.Host) &&
assert.Equal(t, testOrigin.expectedScheme, h1.URL.Scheme) &&
assert.Equal(t, testOrigin.expectedBasePath+testPath, h1.URL.String())
}
err := quick.Check(assertion, config)
assert.NoError(t, err)
}
}
}
func randomASCIIPrintableChar(rand *rand.Rand) int {
// smallest printable ASCII char is 32, largest is 126
const startPrintable = 32
const endPrintable = 127
return startPrintable + rand.Intn(endPrintable-startPrintable)
}
// randomASCIIText generates an ASCII string, some of whose characters may be
// percent-encoded. Its "logical length" (ignoring percent-encoding) is
// between 1 and `maxLength`.
func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string {
length := minLength + rand.Intn(maxLength)
result := ""
for i := 0; i < length; i++ {
c := randomASCIIPrintableChar(rand)
// 1/4 chance of using percent encoding when not necessary
if c == '%' || rand.Intn(4) == 0 {
result += fmt.Sprintf("%%%02X", c)
} else {
result += string(c)
}
}
return result
}
// Calls `randomASCIIText` and ensures the result is a valid URL path,
// i.e. one that can pass unchanged through url.URL.String()
func randomHTTP1Path(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
text := randomASCIIText(rand, minLength, maxLength)
regexp, err := regexp.Compile("[^/;,]*")
require.NoError(t, err)
return "/" + regexp.ReplaceAllStringFunc(text, url.PathEscape)
}
// Calls `randomASCIIText` and ensures the result is a valid URL query,
// i.e. one that can pass unchanged through url.URL.String()
func randomHTTP1Query(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
text := randomASCIIText(rand, minLength, maxLength)
return "?" + strings.ReplaceAll(text, "#", "%23")
}
// Calls `randomASCIIText` and ensures the result is a valid URL fragment,
// i.e. one that can pass unchanged through url.URL.String()
func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
text := randomASCIIText(rand, minLength, maxLength)
url, err := url.Parse("#" + text)
require.NoError(t, err)
return url.String()
}
// Assemble a random :path pseudoheader that is legal by Go stdlib standards
// (i.e. all characters will satisfy "net/url".shouldEscape for their respective locations)
func randomHTTP2Path(t *testing.T, rand *rand.Rand) string {
result := randomHTTP1Path(t, rand, 1, 64)
if rand.Intn(2) == 1 {
result += randomHTTP1Query(t, rand, 1, 32)
}
if rand.Intn(2) == 1 {
result += randomHTTP1Fragment(t, rand, 1, 16)
}
return result
}
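
TestH2RequestHeadersToH1Request_QuickCheck relies on testing/quick with a custom Values generator instead of the default reflection-based inputs. A self-contained sketch of that generator pattern, with a toy property and path generator standing in for the real assertion and randomHTTP2Path:

package main

import (
	"fmt"
	"math/rand"
	"reflect"
	"strings"
	"testing/quick"
)

func main() {
	cfg := &quick.Config{
		// Values fills args with one randomly generated input per call,
		// mirroring how randomHTTP2Path feeds paths into the assertion.
		Values: func(args []reflect.Value, rand *rand.Rand) {
			args[0] = reflect.ValueOf("/" + strings.Repeat("a", rand.Intn(8)))
		},
	}
	// The property under test: every generated path starts with '/'.
	property := func(path string) bool {
		return strings.HasPrefix(path, "/")
	}
	if err := quick.Check(property, cfg); err != nil {
		fmt.Println("property failed:", err)
		return
	}
	fmt.Println("property held for all generated inputs")
}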

View File

@ -82,11 +82,18 @@ func (s *StreamHandler) UseConfiguration(ctx context.Context, config *pogs.Clien
// UpdateConfig replaces current originmapper mapping with mappings from newConfig
func (s *StreamHandler) UpdateConfig(newConfig []*pogs.ReverseProxyConfig) (failedConfigs []*pogs.FailedConfig) {
// TODO: TUN-1968: Gracefully apply new config
s.tunnelHostnameMapper.DeleteAll()
for _, tunnelConfig := range newConfig {
// Delete old configs that aren't in the `newConfig`
toRemove := s.tunnelHostnameMapper.ToRemove(newConfig)
for _, hostnameToRemove := range toRemove {
s.tunnelHostnameMapper.Delete(hostnameToRemove)
}
// Add new configs that weren't in the old mapper
toAdd := s.tunnelHostnameMapper.ToAdd(newConfig)
for _, tunnelConfig := range toAdd {
tunnelHostname := tunnelConfig.TunnelHostname
originSerice, err := tunnelConfig.OriginConfigJSONHandler.OriginConfig.Service()
originSerice, err := tunnelConfig.OriginConfig.Service()
if err != nil {
s.logger.WithField("tunnelHostname", tunnelHostname).WithError(err).Error("Invalid origin service config")
failedConfigs = append(failedConfigs, &pogs.FailedConfig{

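The switch from DeleteAll to ToRemove/ToAdd means hostnames present in both the old and new configuration keep their existing mapping instead of being torn down and rebuilt on every config push. A rough sketch of that reconcile pattern over a plain map; the types and helper below are stand-ins for illustration, not the actual TunnelHostnameMapper API:

package main

import "fmt"

type config struct {
	Hostname string
	Origin   string
}

// reconcile mutates current so it contains exactly the hostnames in desired,
// leaving entries that are already present untouched.
func reconcile(current map[string]string, desired []config) {
	want := make(map[string]config, len(desired))
	for _, c := range desired {
		want[c.Hostname] = c
	}
	// Remove hostnames that are no longer configured.
	for hostname := range current {
		if _, ok := want[hostname]; !ok {
			delete(current, hostname)
		}
	}
	// Add hostnames that are new in this configuration version.
	for hostname, c := range want {
		if _, ok := current[hostname]; !ok {
			current[hostname] = c.Origin
		}
	}
}

func main() {
	current := map[string]string{"a.example.com": "http://localhost:8000"}
	reconcile(current, []config{
		{Hostname: "a.example.com", Origin: "http://localhost:8000"},
		{Hostname: "b.example.com", Origin: "http://localhost:9000"},
	})
	fmt.Println(current)
}
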
View File

@ -49,10 +49,8 @@ func TestServeRequest(t *testing.T) {
reverseProxyConfigs := []*pogs.ReverseProxyConfig{
{
TunnelHostname: testTunnelHostname,
OriginConfigJSONHandler: &pogs.OriginConfigJSONHandler{
OriginConfig: &pogs.HTTPOriginConfig{
URLString: httpServer.URL,
},
OriginConfig: &pogs.HTTPOriginConfig{
URLString: httpServer.URL,
},
},
}
@ -99,10 +97,8 @@ func TestServeBadRequest(t *testing.T) {
reverseProxyConfigs := []*pogs.ReverseProxyConfig{
{
TunnelHostname: testTunnelHostname,
OriginConfigJSONHandler: &pogs.OriginConfigJSONHandler{
OriginConfig: &pogs.HTTPOriginConfig{
URLString: "",
},
OriginConfig: &pogs.HTTPOriginConfig{
URLString: "",
},
},
}
@ -145,7 +141,7 @@ type DefaultMuxerPair struct {
doneC chan struct{}
}
func NewDefaultMuxerPair(t assert.TestingT, h h2mux.MuxedStreamHandler) *DefaultMuxerPair {
func NewDefaultMuxerPair(t *testing.T, h h2mux.MuxedStreamHandler) *DefaultMuxerPair {
origin, edge := net.Pipe()
p := &DefaultMuxerPair{
OriginMuxConfig: h2mux.MuxerConfig{
@ -171,20 +167,20 @@ func NewDefaultMuxerPair(t assert.TestingT, h h2mux.MuxedStreamHandler) *Default
EdgeConn: edge,
doneC: make(chan struct{}),
}
assert.NoError(t, p.Handshake())
assert.NoError(t, p.Handshake(t.Name()))
return p
}
func (p *DefaultMuxerPair) Handshake() error {
func (p *DefaultMuxerPair) Handshake(testName string) error {
ctx, cancel := context.WithTimeout(context.Background(), testHandshakeTimeout)
defer cancel()
errGroup, _ := errgroup.WithContext(ctx)
errGroup.Go(func() (err error) {
p.EdgeMux, err = h2mux.Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig)
p.EdgeMux, err = h2mux.Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig, h2mux.NewActiveStreamsMetrics(testName, "edge"))
return errors.Wrap(err, "edge handshake failure")
})
errGroup.Go(func() (err error) {
p.OriginMux, err = h2mux.Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig)
p.OriginMux, err = h2mux.Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig, h2mux.NewActiveStreamsMetrics(testName, "origin"))
return errors.Wrap(err, "origin handshake failure")
})
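
NewDefaultMuxerPair drives both handshakes under an errgroup because net.Pipe is fully synchronous: a write on one end only completes once the other end reads it, so running either handshake alone would block. A stripped-down sketch of that shape, with a trivial byte exchange standing in for h2mux.Handshake:

package main

import (
	"fmt"
	"net"

	"golang.org/x/sync/errgroup"
)

func main() {
	origin, edge := net.Pipe()
	var g errgroup.Group

	// "Edge" side: send a greeting, then wait for the origin's reply.
	g.Go(func() error {
		if _, err := edge.Write([]byte{1}); err != nil {
			return err
		}
		buf := make([]byte, 1)
		_, err := edge.Read(buf)
		return err
	})

	// "Origin" side: mirror of the above. Because net.Pipe has no internal
	// buffering, neither goroutine can finish unless the other is running.
	g.Go(func() error {
		buf := make([]byte, 1)
		if _, err := origin.Read(buf); err != nil {
			return err
		}
		_, err := origin.Write([]byte{2})
		return err
	})

	if err := g.Wait(); err != nil {
		fmt.Println("handshake failed:", err)
		return
	}
	fmt.Println("both sides completed")
}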

View File

@ -16,6 +16,7 @@ import (
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
)
@ -28,6 +29,27 @@ type Supervisor struct {
useConfigResultChan chan<- *pogs.UseConfigurationResult
state *state
logger *logrus.Entry
metrics metrics
}
type metrics struct {
configVersion prometheus.Gauge
}
func newMetrics() metrics {
configVersion := prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: "supervisor",
Subsystem: "supervisor",
Name: "config_version",
Help: "Latest configuration version received from Cloudflare",
},
)
prometheus.MustRegister(
configVersion,
)
return metrics{
configVersion: configVersion,
}
}
func NewSupervisor(
@ -70,6 +92,7 @@ func NewSupervisor(
useConfigResultChan: useConfigResultChan,
state: newState(defaultClientConfig),
logger: logger.WithField("subsystem", "supervisor"),
metrics: newMetrics(),
}, nil
}
@ -131,6 +154,7 @@ func (s *Supervisor) notifySubsystemsNewConfig(newConfig *pogs.ClientConfig) *po
Success: true,
}
}
s.metrics.configVersion.Set(float64(newConfig.Version))
s.state.updateConfig(newConfig)
var tunnelHostnames []h2mux.TunnelHostname
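
With Namespace and Subsystem both set to "supervisor", the new gauge surfaces as supervisor_supervisor_config_version on the Prometheus endpoint and always holds the version of the most recently applied config. A standalone sketch of the same register-and-set pattern; the listen address and handler wiring below are illustrative, not how cloudflared mounts its metrics server:

package main

import (
	"net/http"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

func main() {
	configVersion := prometheus.NewGauge(prometheus.GaugeOpts{
		Namespace: "supervisor",
		Subsystem: "supervisor",
		Name:      "config_version",
		Help:      "Latest configuration version received from Cloudflare",
	})
	prometheus.MustRegister(configVersion)

	// Each time a new config is accepted, record its version; a gauge keeps
	// only the latest value, which is exactly "currently applied version".
	configVersion.Set(42)

	// Scraping /metrics now includes a line like:
	//   supervisor_supervisor_config_version 42
	http.Handle("/metrics", promhttp.Handler())
	_ = http.ListenAndServe(":2112", nil)
}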

Some files were not shown because too many files have changed in this diff.