TUN-2110: Implement custom deserialization logic for OriginConfig

pull/126/head
Chung-Ting Huang 4 years ago
parent 5feba7e3a9
commit bdd70e798a

40
Gopkg.lock generated

@ -73,15 +73,9 @@
version = "v1.2.0"
[[projects]]
digest = "1:6a503e232df389d94ebb97dfb22d4ae463b6e2f351660613e11d9e42f57ab6df"
digest = "1:6f70106e7bc1c803e8a0a4519e09c12d154771acfa2559206e97b033bbd1dd38"
name = "github.com/coreos/go-oidc"
packages = [
"http",
"jose",
"key",
"oauth2",
"oidc",
]
packages = ["jose"]
pruneopts = "UT"
revision = "a93f71fdfe73d2c0f5413c0565eea0af6523a6df"
@ -93,18 +87,6 @@
revision = "95778dfbb74eb7e4dbaf43bf7d71809650ef8076"
version = "v19"
[[projects]]
digest = "1:6fda0d7f5e52b081e075775b1ecebf1ea0c923e7be33604ed0225ae078e701b5"
name = "github.com/coreos/pkg"
packages = [
"health",
"httputil",
"timeutil",
]
pruneopts = "UT"
revision = "97fdf19511ea361ae1c100dd393cc47f8dcfa1e1"
version = "v4"
[[projects]]
digest = "1:ffe9824d294da03b391f44e1ae8281281b4afc1bdaa9588c9097785e3af10cec"
name = "github.com/davecgh/go-spew"
@ -212,14 +194,6 @@
pruneopts = "UT"
revision = "8e809c8a86450a29b90dcc9efbf062d0fe6d9746"
[[projects]]
digest = "1:75ab90ae3f5d876167e60f493beadfe66f0ed861a710f283fb06c86437a09538"
name = "github.com/jonboulle/clockwork"
packages = ["."]
pruneopts = "UT"
revision = "2eee05ed794112d45db504eb05aa693efd2b8b09"
version = "v0.1.0"
[[projects]]
digest = "1:31e761d97c76151dde79e9d28964a812c46efc5baee4085b86f68f0c654450de"
name = "github.com/konsorten/go-windows-terminal-sequences"
@ -291,6 +265,14 @@
revision = "af06845cf3004701891bf4fdb884bfe4920b3727"
version = "v1.1.0"
[[projects]]
digest = "1:53bc4cd4914cd7cd52139990d5170d6dc99067ae31c56530621b18b35fc30318"
name = "github.com/mitchellh/mapstructure"
packages = ["."]
pruneopts = "UT"
revision = "3536a929edddb9a5b34bd6861dc4a9647cb459fe"
version = "v1.1.2"
[[projects]]
digest = "1:11e62d6050198055e6cd87ed57e5d8c669e84f839c16e16f192374d913d1a70d"
name = "github.com/opentracing/opentracing-go"
@ -577,7 +559,6 @@
"github.com/coredns/coredns/plugin/pkg/rcode",
"github.com/coredns/coredns/request",
"github.com/coreos/go-oidc/jose",
"github.com/coreos/go-oidc/oidc",
"github.com/coreos/go-systemd/daemon",
"github.com/elgs/gosqljson",
"github.com/equinox-io/equinox",
@ -591,6 +572,7 @@
"github.com/mattn/go-colorable",
"github.com/miekg/dns",
"github.com/mitchellh/go-homedir",
"github.com/mitchellh/mapstructure",
"github.com/pkg/errors",
"github.com/prometheus/client_golang/prometheus",
"github.com/prometheus/client_golang/prometheus/promhttp",

@ -85,3 +85,7 @@
[[constraint]]
name = "github.com/google/uuid"
version = "=1.1.1"
[[constraint]]
name = "github.com/mitchellh/mapstructure"
version = "1.1.2"

@ -12,7 +12,6 @@ import (
"syscall"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/connection"
@ -386,6 +385,17 @@ func startDeclarativeTunnel(ctx context.Context,
logger.WithError(err)
return err
}
reverseProxyConfig, err := pogs.NewReverseProxyConfig(
c.String("hostname"),
reverseProxyOrigin,
c.Uint64("retries"),
c.Duration("proxy-connection-timeout"),
c.Uint64("compression-quality"),
)
if err != nil {
logger.WithError(err).Error("Cannot initialize default client config because reverse proxy config is invalid")
return err
}
defaultClientConfig := &pogs.ClientConfig{
Version: pogs.InitVersion(),
SupervisorConfig: &pogs.SupervisorConfig{
@ -399,13 +409,8 @@ func startDeclarativeTunnel(ctx context.Context,
Timeout: c.Duration("dial-edge-timeout"),
MaxFailedHeartbeats: c.Uint64("heartbeat-count"),
},
DoHProxyConfigs: []*pogs.DoHProxyConfig{},
ReverseProxyConfigs: []*pogs.ReverseProxyConfig{
{
TunnelHostname: h2mux.TunnelHostname(c.String("hostname")),
Origin: reverseProxyOrigin,
},
},
DoHProxyConfigs: []*pogs.DoHProxyConfig{},
ReverseProxyConfigs: []*pogs.ReverseProxyConfig{reverseProxyConfig},
}
autoupdater := updater.NewAutoUpdater(defaultClientConfig.SupervisorConfig.AutoUpdateFrequency, listeners)

@ -86,7 +86,7 @@ func (s *StreamHandler) UpdateConfig(newConfig []*pogs.ReverseProxyConfig) (fail
s.tunnelHostnameMapper.DeleteAll()
for _, tunnelConfig := range newConfig {
tunnelHostname := tunnelConfig.TunnelHostname
originSerice, err := tunnelConfig.Origin.Service()
originSerice, err := tunnelConfig.OriginConfigUnmarshaler.OriginConfig.Service()
if err != nil {
s.logger.WithField("tunnelHostname", tunnelHostname).WithError(err).Error("Invalid origin service config")
failedConfigs = append(failedConfigs, &pogs.FailedConfig{

@ -49,8 +49,10 @@ func TestServeRequest(t *testing.T) {
reverseProxyConfigs := []*pogs.ReverseProxyConfig{
{
TunnelHostname: testTunnelHostname,
Origin: &pogs.HTTPOriginConfig{
URLString: httpServer.URL,
OriginConfigUnmarshaler: &pogs.OriginConfigUnmarshaler{
OriginConfig: &pogs.HTTPOriginConfig{
URLString: httpServer.URL,
},
},
},
}
@ -97,8 +99,10 @@ func TestServeBadRequest(t *testing.T) {
reverseProxyConfigs := []*pogs.ReverseProxyConfig{
{
TunnelHostname: testTunnelHostname,
Origin: &pogs.HTTPOriginConfig{
URLString: "",
OriginConfigUnmarshaler: &pogs.OriginConfigUnmarshaler{
OriginConfig: &pogs.HTTPOriginConfig{
URLString: "",
},
},
},
}

@ -14,6 +14,7 @@ import (
"github.com/cloudflare/cloudflared/originservice"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/pkg/errors"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs"
@ -27,11 +28,11 @@ import (
// ClientConfig is a collection of FallibleConfig that determines how cloudflared should function
type ClientConfig struct {
Version Version
SupervisorConfig *SupervisorConfig
EdgeConnectionConfig *EdgeConnectionConfig
DoHProxyConfigs []*DoHProxyConfig
ReverseProxyConfigs []*ReverseProxyConfig
Version Version `json:"version"`
SupervisorConfig *SupervisorConfig `json:"supervisor_config"`
EdgeConnectionConfig *EdgeConnectionConfig `json:"edge_connection_config"`
DoHProxyConfigs []*DoHProxyConfig `json:"doh_proxy_configs"`
ReverseProxyConfigs []*ReverseProxyConfig `json:"reverse_proxy_configs"`
}
// Version type models the version of a ClientConfig
@ -56,9 +57,9 @@ type FallibleConfig interface {
// SupervisorConfig specifies config of components managed by Supervisor other than ConnectionManager
type SupervisorConfig struct {
AutoUpdateFrequency time.Duration
MetricsUpdateFrequency time.Duration
GracePeriod time.Duration
AutoUpdateFrequency time.Duration `json:"auto_update_frequency"`
MetricsUpdateFrequency time.Duration `json:"metrics_update_frequency"`
GracePeriod time.Duration `json:"grace_period"`
}
// FailReason impelents FallibleConfig interface for SupervisorConfig
@ -68,11 +69,11 @@ func (sc *SupervisorConfig) FailReason(err error) string {
// EdgeConnectionConfig specifies what parameters and how may connections should ConnectionManager establish with edge
type EdgeConnectionConfig struct {
NumHAConnections uint8
HeartbeatInterval time.Duration
Timeout time.Duration
MaxFailedHeartbeats uint64
UserCredentialPath string
NumHAConnections uint8 `json:"num_ha_connections"`
HeartbeatInterval time.Duration `json:"heartbeat_interval"`
Timeout time.Duration `json:"timeout"`
MaxFailedHeartbeats uint64 `json:"max_failed_heartbeats"`
UserCredentialPath string `json:"user_credential_path"`
}
// FailReason impelents FallibleConfig interface for EdgeConnectionConfig
@ -82,9 +83,9 @@ func (cmc *EdgeConnectionConfig) FailReason(err error) string {
// DoHProxyConfig is configuration for DNS over HTTPS service
type DoHProxyConfig struct {
ListenHost string
ListenPort uint16
Upstreams []string
ListenHost string `json:"listen_host"`
ListenPort uint16 `json:"listen_port"`
Upstreams []string `json:"upstreams"`
}
// FailReason impelents FallibleConfig interface for DoHProxyConfig
@ -94,11 +95,11 @@ func (dpc *DoHProxyConfig) FailReason(err error) string {
// ReverseProxyConfig how and for what hostnames can this cloudflared proxy
type ReverseProxyConfig struct {
TunnelHostname h2mux.TunnelHostname
Origin OriginConfig
Retries uint64
ConnectionTimeout time.Duration
CompressionQuality uint64
TunnelHostname h2mux.TunnelHostname `json:"tunnel_hostname"`
OriginConfigUnmarshaler *OriginConfigUnmarshaler `json:"origin_config"`
Retries uint64 `json:"retries"`
ConnectionTimeout time.Duration `json:"connection_timeout"`
CompressionQuality uint64 `json:"compression_quality"`
}
func NewReverseProxyConfig(
@ -109,14 +110,14 @@ func NewReverseProxyConfig(
compressionQuality uint64,
) (*ReverseProxyConfig, error) {
if originConfig == nil {
return nil, fmt.Errorf("NewReverseProxyConfig: originConfig was null")
return nil, fmt.Errorf("NewReverseProxyConfig: originConfigUnmarshaler was null")
}
return &ReverseProxyConfig{
TunnelHostname: h2mux.TunnelHostname(tunnelHostname),
Origin: originConfig,
Retries: retries,
ConnectionTimeout: connectionTimeout,
CompressionQuality: compressionQuality,
TunnelHostname: h2mux.TunnelHostname(tunnelHostname),
OriginConfigUnmarshaler: &OriginConfigUnmarshaler{originConfig},
Retries: retries,
ConnectionTimeout: connectionTimeout,
CompressionQuality: compressionQuality,
}, nil
}
@ -133,19 +134,40 @@ type OriginConfig interface {
isOriginConfig()
}
type originType int
const (
httpType originType = iota
wsType
helloWorldType
)
func (ot originType) String() string {
switch ot {
case httpType:
return "Http"
case wsType:
return "WebSocket"
case helloWorldType:
return "HelloWorld"
default:
return "unknown"
}
}
type HTTPOriginConfig struct {
URLString string `capnp:"urlString"`
TCPKeepAlive time.Duration `capnp:"tcpKeepAlive"`
DialDualStack bool
TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout"`
TLSVerify bool `capnp:"tlsVerify"`
OriginCAPool string
OriginServerName string
MaxIdleConnections uint64
IdleConnectionTimeout time.Duration
ProxyConnectionTimeout time.Duration
ExpectContinueTimeout time.Duration
ChunkedEncoding bool
URLString string `capnp:"urlString" mapstructure:"url_string"`
TCPKeepAlive time.Duration `capnp:"tcpKeepAlive" mapstructure:"tcp_keep_alive"`
DialDualStack bool `mapstructure:"dial_dual_stack"`
TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout" mapstructure:"tls_handshake_timeout"`
TLSVerify bool `capnp:"tlsVerify" mapstructure:"tls_verify"`
OriginCAPool string `mapstructure:"origin_ca_pool"`
OriginServerName string `mapstructure:"origin_server_name"`
MaxIdleConnections uint64 `mapstructure:"max_idle_connections"`
IdleConnectionTimeout time.Duration `mapstructure:"idle_connection_timeout"`
ProxyConnectionTimeout time.Duration `mapstructure:"proxy_connection_timeout"`
ExpectContinueTimeout time.Duration `mapstructure:"expect_continue_timeout"`
ChunkedEncoding bool `mapstructure:"chunked_encoding"`
}
func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
@ -187,10 +209,10 @@ func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
func (_ *HTTPOriginConfig) isOriginConfig() {}
type WebSocketOriginConfig struct {
URLString string `capnp:"urlString"`
TLSVerify bool `capnp:"tlsVerify"`
OriginCAPool string
OriginServerName string
URLString string `capnp:"urlString" mapstructure:"url_string"`
TLSVerify bool `capnp:"tlsVerify" mapstructure:"tls_verify"`
OriginCAPool string `mapstructure:"origin_ca_pool"`
OriginServerName string `mapstructure:"origin_server_name"`
}
func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error) {
@ -449,7 +471,7 @@ func UnmarshalDoHProxyConfig(s tunnelrpc.DoHProxyConfig) (*DoHProxyConfig, error
func MarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig, p *ReverseProxyConfig) error {
s.SetTunnelHostname(p.TunnelHostname.String())
switch config := p.Origin.(type) {
switch config := p.OriginConfigUnmarshaler.OriginConfig.(type) {
case *HTTPOriginConfig:
ss, err := s.Origin().NewHttp()
if err != nil {
@ -500,7 +522,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
if err != nil {
return nil, err
}
p.Origin = config
p.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{config}
case tunnelrpc.ReverseProxyConfig_origin_Which_websocket:
ss, err := s.Origin().Websocket()
if err != nil {
@ -510,7 +532,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
if err != nil {
return nil, err
}
p.Origin = config
p.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{config}
case tunnelrpc.ReverseProxyConfig_origin_Which_helloWorld:
ss, err := s.Origin().HelloWorld()
if err != nil {
@ -520,7 +542,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
if err != nil {
return nil, err
}
p.Origin = config
p.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{config}
}
p.Retries = s.Retries()
p.ConnectionTimeout = time.Duration(s.ConnectionTimeout())

@ -41,13 +41,13 @@ func TestClientConfig(t *testing.T) {
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleHTTPOriginConfig()
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfig()}
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleHTTPOriginConfig(unixPathOverride)
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfigUnixPath()}
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleWebSocketOriginConfig()
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleWebSocketOriginConfig()}
}),
}
}
@ -142,13 +142,13 @@ func TestReverseProxyConfig(t *testing.T) {
testCases := []*ReverseProxyConfig{
sampleReverseProxyConfig(),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleHTTPOriginConfig()
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfig()}
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleHTTPOriginConfig(unixPathOverride)
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfigUnixPath()}
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleWebSocketOriginConfig()
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleWebSocketOriginConfig()}
}),
}
for i, testCase := range testCases {
@ -285,11 +285,11 @@ func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig {
// applies any number of overrides to it, and returns it.
func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReverseProxyConfig {
sample := &ReverseProxyConfig{
TunnelHostname: "hijk.example.com",
Origin: &HelloWorldOriginConfig{},
Retries: 18,
ConnectionTimeout: 5 * time.Second,
CompressionQuality: 4,
TunnelHostname: "hijk.example.com",
OriginConfigUnmarshaler: &OriginConfigUnmarshaler{&HelloWorldOriginConfig{}},
Retries: 18,
ConnectionTimeout: 5 * time.Second,
CompressionQuality: 4,
}
sample.ensureNoZeroFields()
for _, f := range overrides {
@ -298,11 +298,9 @@ func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReversePr
return sample
}
// sampleHTTPOriginConfig initializes a new HTTPOriginConfig literal,
// applies any number of overrides to it, and returns it.
func sampleHTTPOriginConfig(overrides ...func(*HTTPOriginConfig)) *HTTPOriginConfig {
sample := &HTTPOriginConfig{
URLString: "https://example.com",
URLString: "https.example.com",
TCPKeepAlive: 7 * time.Second,
DialDualStack: true,
TLSHandshakeTimeout: 11 * time.Second,
@ -322,14 +320,28 @@ func sampleHTTPOriginConfig(overrides ...func(*HTTPOriginConfig)) *HTTPOriginCon
return sample
}
// unixPathOverride sets the URLString of the given HTTPOriginConfig to be a
// Unix socket (i.e. `unix:` scheme plus a file path)
func unixPathOverride(sample *HTTPOriginConfig) {
sample.URLString = "unix:/var/lib/file.sock"
func sampleHTTPOriginConfigUnixPath(overrides ...func(*HTTPOriginConfig)) *HTTPOriginConfig {
sample := &HTTPOriginConfig{
URLString: "unix:/var/lib/file.sock",
TCPKeepAlive: 7 * time.Second,
DialDualStack: true,
TLSHandshakeTimeout: 11 * time.Second,
TLSVerify: true,
OriginCAPool: "/etc/cert.pem",
OriginServerName: "secure.example.com",
MaxIdleConnections: 19,
IdleConnectionTimeout: 17 * time.Second,
ProxyConnectionTimeout: 15 * time.Second,
ExpectContinueTimeout: 21 * time.Second,
ChunkedEncoding: true,
}
sample.ensureNoZeroFields()
for _, f := range overrides {
f(sample)
}
return sample
}
// sampleWebSocketOriginConfig initializes a new WebSocketOriginConfig
// struct, applies any number of overrides to it, and returns it.
func sampleWebSocketOriginConfig(overrides ...func(*WebSocketOriginConfig)) *WebSocketOriginConfig {
sample := &WebSocketOriginConfig{
URLString: "ssh://example.com",

@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
)
@ -39,3 +40,43 @@ func (su *ScopeUnmarshaler) UnmarshalJSON(b []byte) error {
return fmt.Errorf("JSON should have been an object with one root key, either 'system_name' or 'group'")
}
type OriginConfigUnmarshaler struct {
OriginConfig OriginConfig
}
func (ocu *OriginConfigUnmarshaler) UnmarshalJSON(b []byte) error {
var originJSON map[string]interface{}
if err := json.Unmarshal(b, &originJSON); err != nil {
return errors.Wrapf(err, "cannot unmarshal %s into originJSON", string(b))
}
if originConfig, ok := originJSON[httpType.String()]; ok {
httpOriginConfig := &HTTPOriginConfig{}
if err := mapstructure.Decode(originConfig, httpOriginConfig); err != nil {
return errors.Wrapf(err, "cannot decode %+v into HTTPOriginConfig", originConfig)
}
ocu.OriginConfig = httpOriginConfig
return nil
}
if originConfig, ok := originJSON[wsType.String()]; ok {
wsOriginConfig := &WebSocketOriginConfig{}
if err := mapstructure.Decode(originConfig, wsOriginConfig); err != nil {
return errors.Wrapf(err, "cannot decode %+v into WebSocketOriginConfig", originConfig)
}
ocu.OriginConfig = wsOriginConfig
return nil
}
if originConfig, ok := originJSON[helloWorldType.String()]; ok {
helloWorldOriginConfig := &HelloWorldOriginConfig{}
if err := mapstructure.Decode(originConfig, helloWorldOriginConfig); err != nil {
return errors.Wrapf(err, "cannot decode %+v into HelloWorldOriginConfig", originConfig)
}
ocu.OriginConfig = helloWorldOriginConfig
return nil
}
return fmt.Errorf("cannot unmarshal %s into OriginConfig", string(b))
}

@ -1,6 +1,13 @@
package pogs
import "testing"
import (
"encoding/json"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestScopeUnmarshaler_UnmarshalJSON(t *testing.T) {
type fields struct {
@ -55,6 +62,180 @@ func TestScopeUnmarshaler_UnmarshalJSON(t *testing.T) {
}
}
func TestUnmarshalOrigin(t *testing.T) {
tests := []struct {
jsonLiteral string
exceptedOriginConfig OriginConfig
}{
{
jsonLiteral: `{
"Http":{
"url_string":"https://127.0.0.1:8080",
"tcp_keep_alive":30000000000,
"dial_dual_stack":true,
"tls_handshake_timeout":10000000000,
"tls_verify":true,
"origin_ca_pool":"",
"origin_server_name":"",
"max_idle_connections":100,
"idle_connection_timeout":90000000000,
"proxy_connection_timeout":90000000000,
"expect_continue_timeout":90000000000,
"chunked_encoding":true
}
}`,
exceptedOriginConfig: &HTTPOriginConfig{
URLString: "https://127.0.0.1:8080",
TCPKeepAlive: time.Second * 30,
DialDualStack: true,
TLSHandshakeTimeout: time.Second * 10,
TLSVerify: true,
OriginCAPool: "",
OriginServerName: "",
MaxIdleConnections: 100,
IdleConnectionTimeout: time.Second * 90,
ProxyConnectionTimeout: time.Second * 90,
ExpectContinueTimeout: time.Second * 90,
ChunkedEncoding: true,
},
},
{
jsonLiteral: `{
"WebSocket":{
"url_string":"https://127.0.0.1:9090",
"tls_verify":true,
"origin_ca_pool":"",
"origin_server_name":"ws.example.com"
}
}`,
exceptedOriginConfig: &WebSocketOriginConfig{
URLString: "https://127.0.0.1:9090",
TLSVerify: true,
OriginCAPool: "",
OriginServerName: "ws.example.com",
},
},
{
jsonLiteral: `{
"HelloWorld": {}
}`,
exceptedOriginConfig: &HelloWorldOriginConfig{},
},
}
for _, test := range tests {
originConfigJSON := strings.ReplaceAll(strings.ReplaceAll(test.jsonLiteral, "\n", ""), "\t", "")
var OriginConfigUnmarshaler OriginConfigUnmarshaler
err := json.Unmarshal([]byte(originConfigJSON), &OriginConfigUnmarshaler)
assert.NoError(t, err)
assert.Equal(t, test.exceptedOriginConfig, OriginConfigUnmarshaler.OriginConfig)
}
}
func TestUnmarshalClientConfig(t *testing.T) {
prettyClientConfigJSON := `{
"version":10,
"supervisor_config":{
"auto_update_frequency":86400000000000,
"metrics_update_frequency":300000000000,
"grace_period":30000000000
},
"edge_connection_config":{
"num_ha_connections":4,
"heartbeat_interval":5000000000,
"timeout":30000000000,
"max_failed_heartbeats":5,
"user_credential_path":"~/.cloudflared/cert.pem"
},
"doh_proxy_configs":[{
"listen_host": "localhost",
"listen_port": 53,
"upstreams": ["https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"]
}],
"reverse_proxy_configs":[{
"tunnel_hostname":"sdfjadk33.cftunnel.com",
"origin_config":{
"Http":{
"url_string":"https://127.0.0.1:8080",
"tcp_keep_alive":30000000000,
"dial_dual_stack":true,
"tls_handshake_timeout":10000000000,
"tls_verify":true,
"origin_ca_pool":"",
"origin_server_name":"",
"max_idle_connections":100,
"idle_connection_timeout":90000000000,
"proxy_connection_timeout":90000000000,
"expect_continue_timeout":90000000000,
"chunked_encoding":true
}
},
"retries":5,
"connection_timeout":30,
"compression_quality":0
}]
}`
// replace new line and tab
clientConfigJSON := strings.ReplaceAll(strings.ReplaceAll(prettyClientConfigJSON, "\n", ""), "\t", "")
var clientConfig ClientConfig
err := json.Unmarshal([]byte(clientConfigJSON), &clientConfig)
assert.NoError(t, err)
assert.Equal(t, Version(10), clientConfig.Version)
supervisorConfig := SupervisorConfig{
AutoUpdateFrequency: time.Hour * 24,
MetricsUpdateFrequency: time.Second * 300,
GracePeriod: time.Second * 30,
}
assert.Equal(t, supervisorConfig, *clientConfig.SupervisorConfig)
edgeConnectionConfig := EdgeConnectionConfig{
NumHAConnections: 4,
HeartbeatInterval: time.Second * 5,
Timeout: time.Second * 30,
MaxFailedHeartbeats: 5,
UserCredentialPath: "~/.cloudflared/cert.pem",
}
assert.Equal(t, edgeConnectionConfig, *clientConfig.EdgeConnectionConfig)
dohProxyConfig := DoHProxyConfig{
ListenHost: "localhost",
ListenPort: 53,
Upstreams: []string{"https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"},
}
assert.Len(t, clientConfig.DoHProxyConfigs, 1)
assert.Equal(t, dohProxyConfig, *clientConfig.DoHProxyConfigs[0])
reverseProxyConfig := ReverseProxyConfig{
TunnelHostname: "sdfjadk33.cftunnel.com",
OriginConfigUnmarshaler: &OriginConfigUnmarshaler{
OriginConfig: &HTTPOriginConfig{
URLString: "https://127.0.0.1:8080",
TCPKeepAlive: time.Second * 30,
DialDualStack: true,
TLSHandshakeTimeout: time.Second * 10,
TLSVerify: true,
OriginCAPool: "",
OriginServerName: "",
MaxIdleConnections: 100,
IdleConnectionTimeout: time.Second * 90,
ProxyConnectionTimeout: time.Second * 90,
ExpectContinueTimeout: time.Second * 90,
ChunkedEncoding: true,
},
},
Retries: 5,
ConnectionTimeout: 30,
CompressionQuality: 0,
}
assert.Len(t, clientConfig.ReverseProxyConfigs, 1)
assert.Equal(t, reverseProxyConfig, *clientConfig.ReverseProxyConfigs[0])
}
func eqScope(s1, s2 Scope) bool {
return s1.Value() == s2.Value() && s1.PostgresType() == s2.PostgresType()
}

@ -1,7 +0,0 @@
package http
import "net/http"
type Client interface {
Do(*http.Request) (*http.Response, error)
}

@ -1,2 +0,0 @@
// Package http is DEPRECATED. Use net/http instead.
package http

@ -1,161 +0,0 @@
package http
import (
"encoding/base64"
"encoding/json"
"errors"
"log"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
)
func WriteError(w http.ResponseWriter, code int, msg string) {
e := struct {
Error string `json:"error"`
}{
Error: msg,
}
b, err := json.Marshal(e)
if err != nil {
log.Printf("go-oidc: failed to marshal %#v: %v", e, err)
code = http.StatusInternalServerError
b = []byte(`{"error":"server_error"}`)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
w.Write(b)
}
// BasicAuth parses a username and password from the request's
// Authorization header. This was pulled from golang master:
// https://codereview.appspot.com/76540043
func BasicAuth(r *http.Request) (username, password string, ok bool) {
auth := r.Header.Get("Authorization")
if auth == "" {
return
}
if !strings.HasPrefix(auth, "Basic ") {
return
}
c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic "))
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return cs[:s], cs[s+1:], true
}
func cacheControlMaxAge(hdr string) (time.Duration, bool, error) {
for _, field := range strings.Split(hdr, ",") {
parts := strings.SplitN(strings.TrimSpace(field), "=", 2)
k := strings.ToLower(strings.TrimSpace(parts[0]))
if k != "max-age" {
continue
}
if len(parts) == 1 {
return 0, false, errors.New("max-age has no value")
}
v := strings.TrimSpace(parts[1])
if v == "" {
return 0, false, errors.New("max-age has empty value")
}
age, err := strconv.Atoi(v)
if err != nil {
return 0, false, err
}
if age <= 0 {
return 0, false, nil
}
return time.Duration(age) * time.Second, true, nil
}
return 0, false, nil
}
func expires(date, expires string) (time.Duration, bool, error) {
if date == "" || expires == "" {
return 0, false, nil
}
var te time.Time
var err error
if expires == "0" {
return 0, false, nil
}
te, err = time.Parse(time.RFC1123, expires)
if err != nil {
return 0, false, err
}
td, err := time.Parse(time.RFC1123, date)
if err != nil {
return 0, false, err
}
ttl := te.Sub(td)
// headers indicate data already expired, caller should not
// have to care about this case
if ttl <= 0 {
return 0, false, nil
}
return ttl, true, nil
}
func Cacheable(hdr http.Header) (time.Duration, bool, error) {
ttl, ok, err := cacheControlMaxAge(hdr.Get("Cache-Control"))
if err != nil || ok {
return ttl, ok, err
}
return expires(hdr.Get("Date"), hdr.Get("Expires"))
}
// MergeQuery appends additional query values to an existing URL.
func MergeQuery(u url.URL, q url.Values) url.URL {
uv := u.Query()
for k, vs := range q {
for _, v := range vs {
uv.Add(k, v)
}
}
u.RawQuery = uv.Encode()
return u
}
// NewResourceLocation appends a resource id to the end of the requested URL path.
func NewResourceLocation(reqURL *url.URL, id string) string {
var u url.URL
u = *reqURL
u.Path = path.Join(u.Path, id)
u.RawQuery = ""
u.Fragment = ""
return u.String()
}
// CopyRequest returns a clone of the provided *http.Request.
// The returned object is a shallow copy of the struct and a
// deep copy of its Header field.
func CopyRequest(r *http.Request) *http.Request {
r2 := *r
r2.Header = make(http.Header)
for k, s := range r.Header {
r2.Header[k] = s
}
return &r2
}

@ -1,29 +0,0 @@
package http
import (
"errors"
"net/url"
)
// ParseNonEmptyURL checks that a string is a parsable URL which is also not empty
// since `url.Parse("")` does not return an error. Must contian a scheme and a host.
func ParseNonEmptyURL(u string) (*url.URL, error) {
if u == "" {
return nil, errors.New("url is empty")
}
ur, err := url.Parse(u)
if err != nil {
return nil, err
}
if ur.Scheme == "" {
return nil, errors.New("url scheme is empty")
}
if ur.Host == "" {
return nil, errors.New("url host is empty")
}
return ur, nil
}

@ -1,2 +0,0 @@
// Package key is DEPRECATED. Use github.com/coreos/go-oidc instead.
package key

@ -1,153 +0,0 @@
package key
import (
"crypto/rand"
"crypto/rsa"
"encoding/hex"
"encoding/json"
"io"
"time"
"github.com/coreos/go-oidc/jose"
)
func NewPublicKey(jwk jose.JWK) *PublicKey {
return &PublicKey{jwk: jwk}
}
type PublicKey struct {
jwk jose.JWK
}
func (k *PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(&k.jwk)
}
func (k *PublicKey) UnmarshalJSON(data []byte) error {
var jwk jose.JWK
if err := json.Unmarshal(data, &jwk); err != nil {
return err
}
k.jwk = jwk
return nil
}
func (k *PublicKey) ID() string {
return k.jwk.ID
}
func (k *PublicKey) Verifier() (jose.Verifier, error) {
return jose.NewVerifierRSA(k.jwk)
}
type PrivateKey struct {
KeyID string
PrivateKey *rsa.PrivateKey
}
func (k *PrivateKey) ID() string {
return k.KeyID
}
func (k *PrivateKey) Signer() jose.Signer {
return jose.NewSignerRSA(k.ID(), *k.PrivateKey)
}
func (k *PrivateKey) JWK() jose.JWK {
return jose.JWK{
ID: k.KeyID,
Type: "RSA",
Alg: "RS256",
Use: "sig",
Exponent: k.PrivateKey.PublicKey.E,
Modulus: k.PrivateKey.PublicKey.N,
}
}
type KeySet interface {
ExpiresAt() time.Time
}
type PublicKeySet struct {
keys []PublicKey
index map[string]*PublicKey
expiresAt time.Time
}
func NewPublicKeySet(jwks []jose.JWK, exp time.Time) *PublicKeySet {
keys := make([]PublicKey, len(jwks))
index := make(map[string]*PublicKey)
for i, jwk := range jwks {
keys[i] = *NewPublicKey(jwk)
index[keys[i].ID()] = &keys[i]
}
return &PublicKeySet{
keys: keys,
index: index,
expiresAt: exp,
}
}
func (s *PublicKeySet) ExpiresAt() time.Time {
return s.expiresAt
}
func (s *PublicKeySet) Keys() []PublicKey {
return s.keys
}
func (s *PublicKeySet) Key(id string) *PublicKey {
return s.index[id]
}
type PrivateKeySet struct {
keys []*PrivateKey
ActiveKeyID string
expiresAt time.Time
}
func NewPrivateKeySet(keys []*PrivateKey, exp time.Time) *PrivateKeySet {
return &PrivateKeySet{
keys: keys,
ActiveKeyID: keys[0].ID(),
expiresAt: exp.UTC(),
}
}
func (s *PrivateKeySet) Keys() []*PrivateKey {
return s.keys
}
func (s *PrivateKeySet) ExpiresAt() time.Time {
return s.expiresAt
}
func (s *PrivateKeySet) Active() *PrivateKey {
for i, k := range s.keys {
if k.ID() == s.ActiveKeyID {
return s.keys[i]
}
}
return nil
}
type GeneratePrivateKeyFunc func() (*PrivateKey, error)
func GeneratePrivateKey() (*PrivateKey, error) {
pk, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
keyID := make([]byte, 20)
if _, err := io.ReadFull(rand.Reader, keyID); err != nil {
return nil, err
}
k := PrivateKey{
KeyID: hex.EncodeToString(keyID),
PrivateKey: pk,
}
return &k, nil
}

@ -1,99 +0,0 @@
package key
import (
"errors"
"time"
"github.com/jonboulle/clockwork"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/pkg/health"
)
type PrivateKeyManager interface {
ExpiresAt() time.Time
Signer() (jose.Signer, error)
JWKs() ([]jose.JWK, error)
PublicKeys() ([]PublicKey, error)
WritableKeySetRepo
health.Checkable
}
func NewPrivateKeyManager() PrivateKeyManager {
return &privateKeyManager{
clock: clockwork.NewRealClock(),
}
}
type privateKeyManager struct {
keySet *PrivateKeySet
clock clockwork.Clock
}
func (m *privateKeyManager) ExpiresAt() time.Time {
if m.keySet == nil {
return m.clock.Now().UTC()
}
return m.keySet.ExpiresAt()
}
func (m *privateKeyManager) Signer() (jose.Signer, error) {
if err := m.Healthy(); err != nil {
return nil, err
}
return m.keySet.Active().Signer(), nil
}
func (m *privateKeyManager) JWKs() ([]jose.JWK, error) {
if err := m.Healthy(); err != nil {
return nil, err
}
keys := m.keySet.Keys()
jwks := make([]jose.JWK, len(keys))
for i, k := range keys {
jwks[i] = k.JWK()
}
return jwks, nil
}
func (m *privateKeyManager) PublicKeys() ([]PublicKey, error) {
jwks, err := m.JWKs()
if err != nil {
return nil, err
}
keys := make([]PublicKey, len(jwks))
for i, jwk := range jwks {
keys[i] = *NewPublicKey(jwk)
}
return keys, nil
}
func (m *privateKeyManager) Healthy() error {
if m.keySet == nil {
return errors.New("private key manager uninitialized")
}
if len(m.keySet.Keys()) == 0 {
return errors.New("private key manager zero keys")
}
if m.keySet.ExpiresAt().Before(m.clock.Now().UTC()) {
return errors.New("private key manager keys expired")
}
return nil
}
func (m *privateKeyManager) Set(keySet KeySet) error {
privKeySet, ok := keySet.(*PrivateKeySet)
if !ok {
return errors.New("unable to cast to PrivateKeySet")
}
m.keySet = privKeySet
return nil
}

@ -1,55 +0,0 @@
package key
import (
"errors"
"sync"
)
var ErrorNoKeys = errors.New("no keys found")
type WritableKeySetRepo interface {
Set(KeySet) error
}
type ReadableKeySetRepo interface {
Get() (KeySet, error)
}
type PrivateKeySetRepo interface {
WritableKeySetRepo
ReadableKeySetRepo
}
func NewPrivateKeySetRepo() PrivateKeySetRepo {
return &memPrivateKeySetRepo{}
}
type memPrivateKeySetRepo struct {
mu sync.RWMutex
pks PrivateKeySet
}
func (r *memPrivateKeySetRepo) Set(ks KeySet) error {
pks, ok := ks.(*PrivateKeySet)
if !ok {
return errors.New("unable to cast to PrivateKeySet")
} else if pks == nil {
return errors.New("nil KeySet")
}
r.mu.Lock()
defer r.mu.Unlock()
r.pks = *pks
return nil
}
func (r *memPrivateKeySetRepo) Get() (KeySet, error) {
r.mu.RLock()
defer r.mu.RUnlock()
if r.pks.keys == nil {
return nil, ErrorNoKeys
}
return KeySet(&r.pks), nil
}

@ -1,159 +0,0 @@
package key
import (
"errors"
"log"
"time"
ptime "github.com/coreos/pkg/timeutil"
"github.com/jonboulle/clockwork"
)
var (
ErrorPrivateKeysExpired = errors.New("private keys have expired")
)
func NewPrivateKeyRotator(repo PrivateKeySetRepo, ttl time.Duration) *PrivateKeyRotator {
return &PrivateKeyRotator{
repo: repo,
ttl: ttl,
keep: 2,
generateKey: GeneratePrivateKey,
clock: clockwork.NewRealClock(),
}
}
type PrivateKeyRotator struct {
repo PrivateKeySetRepo
generateKey GeneratePrivateKeyFunc
clock clockwork.Clock
keep int
ttl time.Duration
}
func (r *PrivateKeyRotator) expiresAt() time.Time {
return r.clock.Now().UTC().Add(r.ttl)
}
func (r *PrivateKeyRotator) Healthy() error {
pks, err := r.privateKeySet()
if err != nil {
return err
}
if r.clock.Now().After(pks.ExpiresAt()) {
return ErrorPrivateKeysExpired
}
return nil
}
func (r *PrivateKeyRotator) privateKeySet() (*PrivateKeySet, error) {
ks, err := r.repo.Get()
if err != nil {
return nil, err
}
pks, ok := ks.(*PrivateKeySet)
if !ok {
return nil, errors.New("unable to cast to PrivateKeySet")
}
return pks, nil
}
func (r *PrivateKeyRotator) nextRotation() (time.Duration, error) {
pks, err := r.privateKeySet()
if err == ErrorNoKeys {
return 0, nil
}
if err != nil {
return 0, err
}
now := r.clock.Now()
// Ideally, we want to rotate after half the TTL has elapsed.
idealRotationTime := pks.ExpiresAt().Add(-r.ttl / 2)
// If we are past the ideal rotation time, rotate immediatly.
return max(0, idealRotationTime.Sub(now)), nil
}
func max(a, b time.Duration) time.Duration {
if a > b {
return a
}
return b
}
func (r *PrivateKeyRotator) Run() chan struct{} {
attempt := func() {
k, err := r.generateKey()
if err != nil {
log.Printf("go-oidc: failed generating signing key: %v", err)
return
}
exp := r.expiresAt()
if err := rotatePrivateKeys(r.repo, k, r.keep, exp); err != nil {
log.Printf("go-oidc: key rotation failed: %v", err)
return
}
}
stop := make(chan struct{})
go func() {
for {
var nextRotation time.Duration
var sleep time.Duration
var err error
for {
if nextRotation, err = r.nextRotation(); err == nil {
break
}
sleep = ptime.ExpBackoff(sleep, time.Minute)
log.Printf("go-oidc: error getting nextRotation, retrying in %v: %v", sleep, err)
time.Sleep(sleep)
}
select {
case <-r.clock.After(nextRotation):
attempt()
case <-stop:
return
}
}
}()
return stop
}
func rotatePrivateKeys(repo PrivateKeySetRepo, k *PrivateKey, keep int, exp time.Time) error {
ks, err := repo.Get()
if err != nil && err != ErrorNoKeys {
return err
}
var keys []*PrivateKey
if ks != nil {
pks, ok := ks.(*PrivateKeySet)
if !ok {
return errors.New("unable to cast to PrivateKeySet")
}
keys = pks.Keys()
}
keys = append([]*PrivateKey{k}, keys...)
if l := len(keys); l > keep {
keys = keys[0:keep]
}
nks := PrivateKeySet{
keys: keys,
ActiveKeyID: k.ID(),
expiresAt: exp,
}
return repo.Set(KeySet(&nks))
}