TUN-1196: Allow TLS config client CA and root CA to be constructed from multiple certificates
This commit is contained in:
parent
c85c8526e8
commit
b59fd4b7d8
|
@ -438,7 +438,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
}),
|
}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "cacert",
|
Name: "cacert",
|
||||||
Usage: "Certificate Authority authenticating the Cloudflare tunnel connection.",
|
Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.",
|
||||||
EnvVars: []string{"TUNNEL_CACERT"},
|
EnvVars: []string{"TUNNEL_CACERT"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -198,12 +198,18 @@ func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, version st
|
||||||
return nil, errors.Wrap(err, "unable to connect to the origin")
|
return nil, errors.Wrap(err, "unable to connect to the origin")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
toEdgeTLSConfig, err := createTunnelConfig(c)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Error("unable to create TLS config to connect with edge")
|
||||||
|
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
|
||||||
|
}
|
||||||
|
|
||||||
return &origin.TunnelConfig{
|
return &origin.TunnelConfig{
|
||||||
EdgeAddrs: c.StringSlice("edge"),
|
EdgeAddrs: c.StringSlice("edge"),
|
||||||
OriginUrl: originURL,
|
OriginUrl: originURL,
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
OriginCert: originCert,
|
OriginCert: originCert,
|
||||||
TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
|
TlsConfig: toEdgeTLSConfig,
|
||||||
ClientTlsConfig: httpTransport.TLSClientConfig,
|
ClientTlsConfig: httpTransport.TLSClientConfig,
|
||||||
Retries: c.Uint("retries"),
|
Retries: c.Uint("retries"),
|
||||||
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
||||||
|
@ -240,7 +246,7 @@ func loadCertPool(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
originCertPool, err := tlsconfig.LoadOriginCertPool(originCustomCAPool)
|
originCertPool, err := loadOriginCertPool(originCustomCAPool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error loading the certificate pool")
|
return nil, errors.Wrap(err, "error loading the certificate pool")
|
||||||
}
|
}
|
||||||
|
@ -253,6 +259,86 @@ func loadCertPool(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error)
|
||||||
return originCertPool, nil
|
return originCertPool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loadOriginCertPool(originCAPoolPEM []byte) (*x509.CertPool, error) {
|
||||||
|
// Get the global pool
|
||||||
|
certPool, err := loadGlobalCertPool()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then, add any custom origin CA pool the user may have passed
|
||||||
|
if originCAPoolPEM != nil {
|
||||||
|
if !certPool.AppendCertsFromPEM(originCAPoolPEM) {
|
||||||
|
logger.Warn("could not append the provided origin CA to the cloudflared certificate pool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return certPool, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadGlobalCertPool() (*x509.CertPool, error) {
|
||||||
|
// First, obtain the system certificate pool
|
||||||
|
certPool, err := x509.SystemCertPool()
|
||||||
|
if err != nil {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
logger.WithError(err).Warn("error obtaining the system certificates")
|
||||||
|
}
|
||||||
|
certPool = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, append the Cloudflare CAs into the system pool
|
||||||
|
cfRootCA, err := tlsconfig.GetCloudflareRootCA()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
|
||||||
|
}
|
||||||
|
for _, cert := range cfRootCA {
|
||||||
|
certPool.AddCert(cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, add the Hello certificate into the pool (since it's self-signed)
|
||||||
|
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool")
|
||||||
|
}
|
||||||
|
certPool.AddCert(helloCert)
|
||||||
|
|
||||||
|
return certPool, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTunnelConfig(c *cli.Context) (*tls.Config, error) {
|
||||||
|
var rootCAs []string
|
||||||
|
if c.String("cacert") != "" {
|
||||||
|
rootCAs = append(rootCAs, c.String("cacert"))
|
||||||
|
}
|
||||||
|
edgeAddrs := c.StringSlice("edge")
|
||||||
|
|
||||||
|
userConfig := &tlsconfig.TLSParameters{RootCAs: rootCAs}
|
||||||
|
tlsConfig, err := tlsconfig.GetConfig(userConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tlsConfig.RootCAs == nil {
|
||||||
|
rootCAPool := x509.NewCertPool()
|
||||||
|
cfRootCA, err := tlsconfig.GetCloudflareRootCA()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
|
||||||
|
}
|
||||||
|
for _, cert := range cfRootCA {
|
||||||
|
rootCAPool.AddCert(cert)
|
||||||
|
}
|
||||||
|
tlsConfig.RootCAs = rootCAPool
|
||||||
|
tlsConfig.ServerName = "cftunnel.com"
|
||||||
|
} else if len(edgeAddrs) > 0 {
|
||||||
|
// Set for development environments and for testing specific origintunneld instances
|
||||||
|
tlsConfig.ServerName, _, _ = net.SplitHostPort(edgeAddrs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
|
||||||
|
return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config")
|
||||||
|
}
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
func isRunningFromTerminal() bool {
|
func isRunningFromTerminal() bool {
|
||||||
return terminal.IsTerminal(int(os.Stdout.Fd()))
|
return terminal.IsTerminal(int(os.Stdout.Fd()))
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,214 @@
|
||||||
|
// +build ignore
|
||||||
|
// TODO: Remove the above build tag and include this test when we start compiling with Golang 1.10.0+
|
||||||
|
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/asn1"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Generated using `openssl req -newkey rsa:512 -nodes -x509 -days 3650`
|
||||||
|
var samplePEM = []byte(`
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIB4DCCAYoCCQCb/H0EUrdXEjANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV
|
||||||
|
UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv
|
||||||
|
dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE
|
||||||
|
AwwIVGVzdCBPbmUwHhcNMTgwNDI2MTYxMDUxWhcNMjgwNDIzMTYxMDUxWjB3MQsw
|
||||||
|
CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG
|
||||||
|
A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn
|
||||||
|
eTERMA8GA1UEAwwIVGVzdCBPbmUwXDANBgkqhkiG9w0BAQEFAANLADBIAkEAwVQD
|
||||||
|
K0SJ25UFLznm2pU3zhzMEvpDEofHVNnCjk4mlDrtVop7PkKZ8pDEmuQANltUrxC8
|
||||||
|
yHBE2wXMv+GlH+bDtwIDAQABMA0GCSqGSIb3DQEBCwUAA0EAjVYQzozIFPkt/HRY
|
||||||
|
uUoZ8zEHIDICb0syFf5VAjm9AgTwIPzUmD+c5vl6LWDnxq7L45nLCzhhQ6YmiwDz
|
||||||
|
X7Wcyg==
|
||||||
|
-----END CERTIFICATE-----
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIB4DCCAYoCCQDZfCdAJ+mwzDANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV
|
||||||
|
UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv
|
||||||
|
dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE
|
||||||
|
AwwIVGVzdCBUd28wHhcNMTgwNDI2MTYxMTIwWhcNMjgwNDIzMTYxMTIwWjB3MQsw
|
||||||
|
CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG
|
||||||
|
A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn
|
||||||
|
eTERMA8GA1UEAwwIVGVzdCBUd28wXDANBgkqhkiG9w0BAQEFAANLADBIAkEAoHKp
|
||||||
|
ROVK3zCSsH7ocYeyRAML4V7SFAbZcb4WIwDnE08oMBVRkQVcW5tqEkvG3RiClfzV
|
||||||
|
wZIJ3CfqKIeSNSDU9wIDAQABMA0GCSqGSIb3DQEBCwUAA0EAJw2gUbnPiq4C2p5b
|
||||||
|
iWzlA9Q7aKo+VQ4H7IZS7tTccr59nVjvH/TG3eWujpnocr4TOqW9M3CK1DF9mUGP
|
||||||
|
3pQ3Jg==
|
||||||
|
-----END CERTIFICATE-----
|
||||||
|
`)
|
||||||
|
|
||||||
|
var systemCertPoolSubjects []*pkix.Name
|
||||||
|
|
||||||
|
type certificateFixture struct {
|
||||||
|
ou string
|
||||||
|
cn string
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
systemCertPool, err := x509.SystemCertPool()
|
||||||
|
if isUnrecoverableError(err) {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemCertPool == nil {
|
||||||
|
// On Windows, let's just assume the system cert pool was empty
|
||||||
|
systemCertPool = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
|
||||||
|
systemCertPoolSubjects, err = getCertPoolSubjects(systemCertPool)
|
||||||
|
if err != nil {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Exit(m.Run())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadOriginCertPoolJustSystemPool(t *testing.T) {
|
||||||
|
certPoolSubjects := loadCertPoolSubjects(t, nil)
|
||||||
|
extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects)
|
||||||
|
|
||||||
|
// Remove extra subjects from the cert pool
|
||||||
|
var filteredSystemCertPoolSubjects []*pkix.Name
|
||||||
|
|
||||||
|
t.Log(extraSubjects)
|
||||||
|
|
||||||
|
OUTER:
|
||||||
|
for _, subject := range certPoolSubjects {
|
||||||
|
for _, extraSubject := range extraSubjects {
|
||||||
|
if subject == extraSubject {
|
||||||
|
t.Log(extraSubject)
|
||||||
|
continue OUTER
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredSystemCertPoolSubjects = append(filteredSystemCertPoolSubjects, subject)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, len(filteredSystemCertPoolSubjects), len(systemCertPoolSubjects))
|
||||||
|
|
||||||
|
difference := subjectSubtract(systemCertPoolSubjects, filteredSystemCertPoolSubjects)
|
||||||
|
assert.Equal(t, 0, len(difference))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadOriginCertPoolCFCertificates(t *testing.T) {
|
||||||
|
certPoolSubjects := loadCertPoolSubjects(t, nil)
|
||||||
|
|
||||||
|
extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects)
|
||||||
|
|
||||||
|
expected := []*certificateFixture{
|
||||||
|
{ou: "CloudFlare Origin SSL ECC Certificate Authority"},
|
||||||
|
{ou: "CloudFlare Origin SSL Certificate Authority"},
|
||||||
|
{cn: "origin-pull.cloudflare.net"},
|
||||||
|
{cn: "Argo Tunnel Sample Hello Server Certificate"},
|
||||||
|
}
|
||||||
|
|
||||||
|
assertFixturesMatchSubjects(t, expected, extraSubjects)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadOriginCertPoolWithExtraPEMs(t *testing.T) {
|
||||||
|
certPoolWithoutPEMSubjects := loadCertPoolSubjects(t, nil)
|
||||||
|
certPoolWithPEMSubjects := loadCertPoolSubjects(t, samplePEM)
|
||||||
|
|
||||||
|
difference := subjectSubtract(certPoolWithoutPEMSubjects, certPoolWithPEMSubjects)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, len(difference))
|
||||||
|
|
||||||
|
expected := []*certificateFixture{
|
||||||
|
{cn: "Test One"},
|
||||||
|
{cn: "Test Two"},
|
||||||
|
}
|
||||||
|
|
||||||
|
assertFixturesMatchSubjects(t, expected, difference)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadCertPoolSubjects(t *testing.T, originCAPoolPEM []byte) []*pkix.Name {
|
||||||
|
certPool, err := loadOriginCertPool(originCAPoolPEM)
|
||||||
|
if isUnrecoverableError(err) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
assert.NotEmpty(t, certPool.Subjects())
|
||||||
|
certPoolSubjects, err := getCertPoolSubjects(certPool)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return certPoolSubjects
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertFixturesMatchSubjects(t *testing.T, fixtures []*certificateFixture, subjects []*pkix.Name) {
|
||||||
|
assert.Equal(t, len(fixtures), len(subjects))
|
||||||
|
|
||||||
|
for _, fixture := range fixtures {
|
||||||
|
found := false
|
||||||
|
for _, subject := range subjects {
|
||||||
|
found = found || fixtureMatchesSubjectPredicate(fixture, subject)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fixtureMatchesSubjectPredicate(fixture *certificateFixture, subject *pkix.Name) bool {
|
||||||
|
cnMatch := true
|
||||||
|
if fixture.cn != "" {
|
||||||
|
cnMatch = fixture.cn == subject.CommonName
|
||||||
|
}
|
||||||
|
|
||||||
|
ouMatch := true
|
||||||
|
if fixture.ou != "" {
|
||||||
|
ouMatch = len(subject.OrganizationalUnit) > 0 && fixture.ou == subject.OrganizationalUnit[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return cnMatch && ouMatch
|
||||||
|
}
|
||||||
|
|
||||||
|
func subjectSubtract(left []*pkix.Name, right []*pkix.Name) []*pkix.Name {
|
||||||
|
var difference []*pkix.Name
|
||||||
|
|
||||||
|
var found bool
|
||||||
|
for _, r := range right {
|
||||||
|
found = false
|
||||||
|
for _, l := range left {
|
||||||
|
if (*l).String() == (*r).String() {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
difference = append(difference, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return difference
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCertPoolSubjects(certPool *x509.CertPool) ([]*pkix.Name, error) {
|
||||||
|
var subjects []*pkix.Name
|
||||||
|
|
||||||
|
for _, subject := range certPool.Subjects() {
|
||||||
|
var sequence pkix.RDNSequence
|
||||||
|
_, err := asn1.Unmarshal(subject, &sequence)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
name := pkix.Name{}
|
||||||
|
name.FillFromRDNSequence(&sequence)
|
||||||
|
|
||||||
|
subjects = append(subjects, &name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return subjects, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isUnrecoverableError(err error) bool {
|
||||||
|
return err != nil && err.Error() != "crypto/x509: system root pool is not available on Windows"
|
||||||
|
}
|
|
@ -2,14 +2,10 @@ package tlsconfig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
tunnellog "github.com/cloudflare/cloudflared/log"
|
|
||||||
"github.com/getsentry/raven-go"
|
"github.com/getsentry/raven-go"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"gopkg.in/urfave/cli.v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// CertReloader can load and reload a TLS certificate from a particular filepath.
|
// CertReloader can load and reload a TLS certificate from a particular filepath.
|
||||||
|
@ -21,18 +17,14 @@ type CertReloader struct {
|
||||||
keyPath string
|
keyPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCertReloader makes a CertReloader, memorizing the filepaths in the context/flags.
|
// NewCertReloader makes a CertReloader. It loads the cert during initialization to make sure certPath and keyPath are valid
|
||||||
func NewCertReloader(c *cli.Context, f CLIFlags) (*CertReloader, error) {
|
func NewCertReloader(certPath, keyPath string) (*CertReloader, error) {
|
||||||
if !c.IsSet(f.Cert) {
|
|
||||||
return nil, errors.New("CertReloader: cert not provided")
|
|
||||||
}
|
|
||||||
if !c.IsSet(f.Key) {
|
|
||||||
return nil, errors.New("CertReloader: key not provided")
|
|
||||||
}
|
|
||||||
cr := new(CertReloader)
|
cr := new(CertReloader)
|
||||||
cr.certPath = c.String(f.Cert)
|
cr.certPath = certPath
|
||||||
cr.keyPath = c.String(f.Key)
|
cr.keyPath = keyPath
|
||||||
cr.LoadCert()
|
if err := cr.LoadCert(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return cr, nil
|
return cr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,18 +37,17 @@ func (cr *CertReloader) Cert(clientHello *tls.ClientHelloInfo) (*tls.Certificate
|
||||||
|
|
||||||
// LoadCert loads a TLS certificate from the CertReloader's specified filepath.
|
// LoadCert loads a TLS certificate from the CertReloader's specified filepath.
|
||||||
// Call this after writing a new certificate to the disk (e.g. after renewing a certificate)
|
// Call this after writing a new certificate to the disk (e.g. after renewing a certificate)
|
||||||
func (cr *CertReloader) LoadCert() {
|
func (cr *CertReloader) LoadCert() error {
|
||||||
cr.Lock()
|
cr.Lock()
|
||||||
defer cr.Unlock()
|
defer cr.Unlock()
|
||||||
|
|
||||||
log.SetFormatter(&tunnellog.JSONFormatter{})
|
|
||||||
log.Info("Reloading certificate")
|
|
||||||
cert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
|
cert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
|
||||||
|
|
||||||
// Keep the old certificate if there's a problem reading the new one.
|
// Keep the old certificate if there's a problem reading the new one.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
raven.CaptureError(fmt.Errorf("Error parsing X509 key pair: %v", err), nil)
|
raven.CaptureError(fmt.Errorf("Error parsing X509 key pair: %v", err), nil)
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
cr.certificate = &cert
|
cr.certificate = &cert
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package tlsconfig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: remove the Origin CA root certs when migrated to Authenticated Origin Pull certs
|
// TODO: remove the Origin CA root certs when migrated to Authenticated Origin Pull certs
|
||||||
|
@ -85,11 +86,26 @@ QzMmZpRpIBB321ZBlcnlxiTJvWxvbCPHKHj20VwwAz7LONF59s84ZsOqfoBv8gKM
|
||||||
s0s5dsq5zpLeaw==
|
s0s5dsq5zpLeaw==
|
||||||
-----END CERTIFICATE-----`)
|
-----END CERTIFICATE-----`)
|
||||||
|
|
||||||
func GetCloudflareRootCA() *x509.CertPool {
|
func GetCloudflareRootCA() ([]*x509.Certificate, error) {
|
||||||
ca := x509.NewCertPool()
|
var certs []*x509.Certificate
|
||||||
if !ca.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
|
pemBlocks := cloudflareRootCA
|
||||||
// should never happen
|
for len(pemBlocks) > 0 {
|
||||||
panic("failure loading Cloudflare origin CA pem")
|
var block *pem.Block
|
||||||
|
block, pemBlocks = pem.Decode(pemBlocks)
|
||||||
|
if block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if block.Type != "CERTIFICATE" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
certs = append(certs, cert)
|
||||||
}
|
}
|
||||||
return ca
|
|
||||||
|
return certs, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIICBjCCAbCgAwIBAgIJAPKk4bYMrSFMMA0GCSqGSIb3DQEBCwUAMF0xCzAJBgNV
|
||||||
|
BAYTAlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQK
|
||||||
|
DBBDbG91ZGZsYXJlLCBJbmMuMRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcNMTgxMTE1
|
||||||
|
MjA1NzU3WhcNMjgxMTEyMjA1NzU3WjBdMQswCQYDVQQGEwJVUzEOMAwGA1UECAwF
|
||||||
|
VGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xvdWRmbGFyZSwgSW5j
|
||||||
|
LjESMBAGA1UEAwwJbG9jYWxob3N0MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAOQN
|
||||||
|
pTRn5wLf8SSI5x2kpDvbdDy7lfhamJ2En4Q+wy1cSKp8bn8/oyhVF7QTsimDGTI4
|
||||||
|
45pV9nDfNJPYB3IW0x0CAwEAAaNTMFEwHQYDVR0OBBYEFE4jIa97mIEiYFa02X++
|
||||||
|
uu5mCEn+MB8GA1UdIwQYMBaAFE4jIa97mIEiYFa02X++uu5mCEn+MA8GA1UdEwEB
|
||||||
|
/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADQQAkE+pDee0o5cNcZRUszy8sTQzB1Wlp
|
||||||
|
J6ucfmo16crqRaK7uGvhkMyibIc4D8z2Cxw3aI3IMMFoIIlYoYKiUcbd
|
||||||
|
-----END CERTIFICATE-----
|
|
@ -0,0 +1,13 @@
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIICBjCCAbCgAwIBAgIJAN6cXRTbJtFnMA0GCSqGSIb3DQEBCwUAMF0xCzAJBgNV
|
||||||
|
BAYTAlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRgwFgYDVQQK
|
||||||
|
DA9DbG91ZGZsYXJlLCBJbmMxEzARBgNVBAMMCmxvY2FsaG9zdDIwHhcNMTgxMTE1
|
||||||
|
MjExMTU4WhcNMjgxMTEyMjExMTU4WjBdMQswCQYDVQQGEwJVUzEOMAwGA1UECAwF
|
||||||
|
VGV4YXMxDzANBgNVBAcMBkF1c3RpbjEYMBYGA1UECgwPQ2xvdWRmbGFyZSwgSW5j
|
||||||
|
MRMwEQYDVQQDDApsb2NhbGhvc3QyMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAKQx
|
||||||
|
IMZ6QgXoul2ITF/7sly4fW2Ol+a/AYw42zCWhVqOXv8AhY21I0Q8lkRR6wOroQwZ
|
||||||
|
O7jKKOcE5TnR/NRcZr8CAwEAAaNTMFEwHQYDVR0OBBYEFONKxLZc2RUD0KTHkAz4
|
||||||
|
8nrb5688MB8GA1UdIwQYMBaAFONKxLZc2RUD0KTHkAz48nrb5688MA8GA1UdEwEB
|
||||||
|
/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADQQA56pwhvGpNPjyLcWfJHu/vI3ZjdoLB
|
||||||
|
LnrkRaMjJmv0H0Beh4upJhoz8u6lhMACerKQrrdQhEPB2u+maFrEBtmN
|
||||||
|
-----END CERTIFICATE-----
|
|
@ -0,0 +1,10 @@
|
||||||
|
-----BEGIN PRIVATE KEY-----
|
||||||
|
MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEA5A2lNGfnAt/xJIjn
|
||||||
|
HaSkO9t0PLuV+FqYnYSfhD7DLVxIqnxufz+jKFUXtBOyKYMZMjjjmlX2cN80k9gH
|
||||||
|
chbTHQIDAQABAkAoeDtu91lJa1AxuZG58vOqI6GW/Xr5naojmdts7m5YaAhDa7DE
|
||||||
|
zJUp4d8SP5cGBf1/PB3x6Cu9UviFNQ16wmzJAiEA8gUm4UYpWZD4Ze2l/xb+BK8D
|
||||||
|
IglSUIy1VxW+X1G55wMCIQDxOfXiFzPqnv/e5avKGv6CU11Dhmbi1OpiyybZTjGz
|
||||||
|
XwIhAM3bE/cJdqJ4bNBGE6umIupY8pFA3IMnLBempwbsvPOBAiEAgzJ+5OSxu92W
|
||||||
|
VGidsmJUIhWtF9i1hJFAmVLcYjwBFAkCICtjP/vv0qOZWk4mAAn2zz9UVWp45DSR
|
||||||
|
p/FA8V77ohXD
|
||||||
|
-----END PRIVATE KEY-----
|
|
@ -6,150 +6,81 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/log"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"gopkg.in/urfave/cli.v2"
|
|
||||||
"runtime"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.CreateLogger()
|
// Config is the user provided parameters to create a tls.Config
|
||||||
|
type TLSParameters struct {
|
||||||
// CLIFlags names the flags used to configure TLS for a command or subsystem.
|
Cert string
|
||||||
// The nil value for a field means the flag is ignored.
|
Key string
|
||||||
type CLIFlags struct {
|
GetCertificate *CertReloader
|
||||||
Cert string
|
ClientCAs []string
|
||||||
Key string
|
RootCAs []string
|
||||||
ClientCert string
|
ServerName string
|
||||||
RootCA string
|
CurvePreferences []tls.CurveID
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConfig returns a TLS configuration according to the flags defined in f and
|
// GetConfig returns a TLS configuration according to the Config set by the user.
|
||||||
// set by the user.
|
func GetConfig(p *TLSParameters) (*tls.Config, error) {
|
||||||
func (f CLIFlags) GetConfig(c *cli.Context) *tls.Config {
|
tlsconfig := &tls.Config{}
|
||||||
config := &tls.Config{}
|
if p.GetCertificate != nil {
|
||||||
|
tlsconfig.GetCertificate = p.GetCertificate.Cert
|
||||||
if c.IsSet(f.Cert) && c.IsSet(f.Key) {
|
tlsconfig.BuildNameToCertificate()
|
||||||
cert, err := tls.LoadX509KeyPair(c.String(f.Cert), c.String(f.Key))
|
} else if p.Cert != "" && p.Key != "" {
|
||||||
|
cert, err := tls.LoadX509KeyPair(p.Cert, p.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Fatal("Error parsing X509 key pair")
|
return nil, errors.Wrap(err, "Error parsing X509 key pair")
|
||||||
}
|
}
|
||||||
config.Certificates = []tls.Certificate{cert}
|
tlsconfig.Certificates = []tls.Certificate{cert}
|
||||||
config.BuildNameToCertificate()
|
tlsconfig.BuildNameToCertificate()
|
||||||
}
|
}
|
||||||
return f.finishGettingConfig(c, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f CLIFlags) GetConfigReloadableCert(c *cli.Context, cr *CertReloader) *tls.Config {
|
if len(p.ClientCAs) > 0 {
|
||||||
config := &tls.Config{
|
|
||||||
GetCertificate: cr.Cert,
|
|
||||||
}
|
|
||||||
config.BuildNameToCertificate()
|
|
||||||
return f.finishGettingConfig(c, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f CLIFlags) finishGettingConfig(c *cli.Context, config *tls.Config) *tls.Config {
|
|
||||||
if c.IsSet(f.ClientCert) {
|
|
||||||
// set of root certificate authorities that servers use if required to verify a client certificate
|
// set of root certificate authorities that servers use if required to verify a client certificate
|
||||||
// by the policy in ClientAuth
|
// by the policy in ClientAuth
|
||||||
config.ClientCAs = LoadCert(c.String(f.ClientCert))
|
clientCAs, err := LoadCert(p.ClientCAs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error loading client CAs")
|
||||||
|
}
|
||||||
|
tlsconfig.ClientCAs = clientCAs
|
||||||
// server's policy for TLS Client Authentication. Default is no client cert
|
// server's policy for TLS Client Authentication. Default is no client cert
|
||||||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
tlsconfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
}
|
}
|
||||||
// set of root certificate authorities that clients use when verifying server certificates
|
|
||||||
if c.IsSet(f.RootCA) {
|
if len(p.RootCAs) > 0 {
|
||||||
config.RootCAs = LoadCert(c.String(f.RootCA))
|
rootCAs, err := LoadCert(p.RootCAs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error loading root CAs")
|
||||||
|
}
|
||||||
|
tlsconfig.RootCAs = rootCAs
|
||||||
}
|
}
|
||||||
// we optimize CurveP256
|
|
||||||
config.CurvePreferences = []tls.CurveID{tls.CurveP256}
|
if p.ServerName != "" {
|
||||||
return config
|
tlsconfig.ServerName = p.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p.CurvePreferences) > 0 {
|
||||||
|
tlsconfig.CurvePreferences = p.CurvePreferences
|
||||||
|
} else {
|
||||||
|
// Cloudflare optimize CurveP256
|
||||||
|
tlsconfig.CurvePreferences = []tls.CurveID{tls.CurveP256}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsconfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadCert creates a CertPool containing all certificates in a PEM-format file.
|
// LoadCert creates a CertPool containing all certificates in a PEM-format file.
|
||||||
func LoadCert(certPath string) *x509.CertPool {
|
func LoadCert(certPaths []string) (*x509.CertPool, error) {
|
||||||
caCert, err := ioutil.ReadFile(certPath)
|
|
||||||
if err != nil {
|
|
||||||
logger.WithError(err).Fatalf("Error reading certificate %s", certPath)
|
|
||||||
}
|
|
||||||
ca := x509.NewCertPool()
|
ca := x509.NewCertPool()
|
||||||
if !ca.AppendCertsFromPEM(caCert) {
|
for _, certPath := range certPaths {
|
||||||
logger.WithError(err).Fatalf("Error parsing certificate %s", certPath)
|
caCert, err := ioutil.ReadFile(certPath)
|
||||||
}
|
if err != nil {
|
||||||
return ca
|
return nil, errors.Wrapf(err, "Error reading certificate %s", certPath)
|
||||||
}
|
|
||||||
|
|
||||||
func LoadGlobalCertPool() (*x509.CertPool, error) {
|
|
||||||
success := false
|
|
||||||
|
|
||||||
// First, obtain the system certificate pool
|
|
||||||
certPool, systemCertPoolErr := x509.SystemCertPool()
|
|
||||||
if systemCertPoolErr != nil {
|
|
||||||
if runtime.GOOS != "windows" {
|
|
||||||
logger.Warnf("error obtaining the system certificates: %s", systemCertPoolErr)
|
|
||||||
}
|
}
|
||||||
certPool = x509.NewCertPool()
|
if !ca.AppendCertsFromPEM(caCert) {
|
||||||
} else {
|
return nil, errors.Wrapf(err, "Error parsing certificate %s", certPath)
|
||||||
success = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, append the Cloudflare CA pool into the system pool
|
|
||||||
if !certPool.AppendCertsFromPEM(cloudflareRootCA) {
|
|
||||||
logger.Warn("could not append the CF certificate to the cloudflared certificate pool")
|
|
||||||
} else {
|
|
||||||
success = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if success != true { // Obtaining any of the CAs has failed; this is a fatal error
|
|
||||||
return nil, errors.New("error loading any of the CAs into the global certificate pool")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, add the Hello certificate into the pool (since it's self-signed)
|
|
||||||
helloCertificate, err := GetHelloCertificateX509()
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("error obtaining the Hello server certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
certPool.AddCert(helloCertificate)
|
|
||||||
|
|
||||||
return certPool, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadOriginCertPool(originCAPoolPEM []byte) (*x509.CertPool, error) {
|
|
||||||
success := false
|
|
||||||
|
|
||||||
// Get the global pool
|
|
||||||
certPool, globalPoolErr := LoadGlobalCertPool()
|
|
||||||
if globalPoolErr != nil {
|
|
||||||
certPool = x509.NewCertPool()
|
|
||||||
} else {
|
|
||||||
success = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then, add any custom origin CA pool the user may have passed
|
|
||||||
if originCAPoolPEM != nil {
|
|
||||||
if !certPool.AppendCertsFromPEM(originCAPoolPEM) {
|
|
||||||
logger.Warn("could not append the provided origin CA to the cloudflared certificate pool")
|
|
||||||
} else {
|
|
||||||
success = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return ca, nil
|
||||||
if success != true {
|
|
||||||
return nil, errors.New("error loading any of the CAs into the origin certificate pool")
|
|
||||||
}
|
|
||||||
|
|
||||||
return certPool, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config {
|
|
||||||
tlsConfig := CLIFlags{RootCA: "cacert"}.GetConfig(c)
|
|
||||||
if tlsConfig.RootCAs == nil {
|
|
||||||
tlsConfig.RootCAs = GetCloudflareRootCA()
|
|
||||||
tlsConfig.ServerName = "cftunnel.com"
|
|
||||||
} else if len(addrs) > 0 {
|
|
||||||
// Set for development environments and for testing specific origintunneld instances
|
|
||||||
tlsConfig.ServerName, _, _ = net.SplitHostPort(addrs[0])
|
|
||||||
}
|
|
||||||
return tlsConfig
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,214 +1,84 @@
|
||||||
// +build ignore
|
|
||||||
// TODO: Remove the above build tag and include this test when we start compiling with Golang 1.10.0+
|
// TODO: Remove the above build tag and include this test when we start compiling with Golang 1.10.0+
|
||||||
|
|
||||||
package tlsconfig
|
package tlsconfig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/tls"
|
||||||
"crypto/x509/pkix"
|
|
||||||
"encoding/asn1"
|
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Generated using `openssl req -newkey rsa:512 -nodes -x509 -days 3650`
|
// testcert.pem and testcert2.pem are Generated using `openssl req -newkey rsa:512 -nodes -x509 -days 3650`
|
||||||
var samplePEM = []byte(`
|
const (
|
||||||
-----BEGIN CERTIFICATE-----
|
testcertCommonName = "localhost"
|
||||||
MIIB4DCCAYoCCQCb/H0EUrdXEjANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV
|
)
|
||||||
UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv
|
|
||||||
dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE
|
|
||||||
AwwIVGVzdCBPbmUwHhcNMTgwNDI2MTYxMDUxWhcNMjgwNDIzMTYxMDUxWjB3MQsw
|
|
||||||
CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG
|
|
||||||
A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn
|
|
||||||
eTERMA8GA1UEAwwIVGVzdCBPbmUwXDANBgkqhkiG9w0BAQEFAANLADBIAkEAwVQD
|
|
||||||
K0SJ25UFLznm2pU3zhzMEvpDEofHVNnCjk4mlDrtVop7PkKZ8pDEmuQANltUrxC8
|
|
||||||
yHBE2wXMv+GlH+bDtwIDAQABMA0GCSqGSIb3DQEBCwUAA0EAjVYQzozIFPkt/HRY
|
|
||||||
uUoZ8zEHIDICb0syFf5VAjm9AgTwIPzUmD+c5vl6LWDnxq7L45nLCzhhQ6YmiwDz
|
|
||||||
X7Wcyg==
|
|
||||||
-----END CERTIFICATE-----
|
|
||||||
-----BEGIN CERTIFICATE-----
|
|
||||||
MIIB4DCCAYoCCQDZfCdAJ+mwzDANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV
|
|
||||||
UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv
|
|
||||||
dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE
|
|
||||||
AwwIVGVzdCBUd28wHhcNMTgwNDI2MTYxMTIwWhcNMjgwNDIzMTYxMTIwWjB3MQsw
|
|
||||||
CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG
|
|
||||||
A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn
|
|
||||||
eTERMA8GA1UEAwwIVGVzdCBUd28wXDANBgkqhkiG9w0BAQEFAANLADBIAkEAoHKp
|
|
||||||
ROVK3zCSsH7ocYeyRAML4V7SFAbZcb4WIwDnE08oMBVRkQVcW5tqEkvG3RiClfzV
|
|
||||||
wZIJ3CfqKIeSNSDU9wIDAQABMA0GCSqGSIb3DQEBCwUAA0EAJw2gUbnPiq4C2p5b
|
|
||||||
iWzlA9Q7aKo+VQ4H7IZS7tTccr59nVjvH/TG3eWujpnocr4TOqW9M3CK1DF9mUGP
|
|
||||||
3pQ3Jg==
|
|
||||||
-----END CERTIFICATE-----
|
|
||||||
`)
|
|
||||||
|
|
||||||
var systemCertPoolSubjects []*pkix.Name
|
func TestGetFromEmptyConfig(t *testing.T) {
|
||||||
|
c := &TLSParameters{}
|
||||||
|
|
||||||
type certificateFixture struct {
|
tlsConfig, err := GetConfig(c)
|
||||||
ou string
|
assert.NoError(t, err)
|
||||||
cn string
|
assert.Empty(t, tlsConfig.Certificates)
|
||||||
|
|
||||||
|
assert.Empty(t, tlsConfig.NameToCertificate)
|
||||||
|
|
||||||
|
assert.Nil(t, tlsConfig.ClientCAs)
|
||||||
|
assert.Equal(t, tls.NoClientCert, tlsConfig.ClientAuth)
|
||||||
|
|
||||||
|
assert.Nil(t, tlsConfig.RootCAs)
|
||||||
|
|
||||||
|
assert.Len(t, tlsConfig.CurvePreferences, 1)
|
||||||
|
assert.Equal(t, tls.CurveP256, tlsConfig.CurvePreferences[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestGetConfig(t *testing.T) {
|
||||||
systemCertPool, err := x509.SystemCertPool()
|
cert, err := tls.LoadX509KeyPair("testcert.pem", "testkey.pem")
|
||||||
if isUnrecoverableError(err) {
|
assert.NoError(t, err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
if systemCertPool == nil {
|
c := &TLSParameters{
|
||||||
// On Windows, let's just assume the system cert pool was empty
|
Cert: "testcert.pem",
|
||||||
systemCertPool = x509.NewCertPool()
|
Key: "testkey.pem",
|
||||||
|
ClientCAs: []string{"testcert.pem", "testcert2.pem"},
|
||||||
|
RootCAs: []string{"testcert.pem", "testcert2.pem"},
|
||||||
|
ServerName: "test",
|
||||||
|
CurvePreferences: []tls.CurveID{tls.CurveP384},
|
||||||
}
|
}
|
||||||
|
tlsConfig, err := GetConfig(c)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, tlsConfig.Certificates, 1)
|
||||||
|
assert.Equal(t, cert, tlsConfig.Certificates[0])
|
||||||
|
|
||||||
systemCertPoolSubjects, err = getCertPoolSubjects(systemCertPool)
|
assert.Equal(t, cert, *tlsConfig.NameToCertificate[testcertCommonName])
|
||||||
if err != nil {
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
os.Exit(m.Run())
|
assert.NotNil(t, tlsConfig.ClientCAs)
|
||||||
|
assert.Equal(t, tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth)
|
||||||
|
|
||||||
|
assert.NotNil(t, tlsConfig.RootCAs)
|
||||||
|
|
||||||
|
assert.Len(t, tlsConfig.CurvePreferences, 1)
|
||||||
|
assert.Equal(t, tls.CurveP384, tlsConfig.CurvePreferences[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadOriginCertPoolJustSystemPool(t *testing.T) {
|
func TestCertReloader(t *testing.T) {
|
||||||
certPoolSubjects := loadCertPoolSubjects(t, nil)
|
expectedCert, err := tls.LoadX509KeyPair("testcert.pem", "testkey.pem")
|
||||||
extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Remove extra subjects from the cert pool
|
certReloader, err := NewCertReloader("testcert.pem", "testkey.pem")
|
||||||
var filteredSystemCertPoolSubjects []*pkix.Name
|
assert.NoError(t, err)
|
||||||
|
|
||||||
t.Log(extraSubjects)
|
chi := &tls.ClientHelloInfo{ServerName: testcertCommonName}
|
||||||
|
cert, err := certReloader.Cert(chi)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expectedCert, *cert)
|
||||||
|
|
||||||
OUTER:
|
c := &TLSParameters{
|
||||||
for _, subject := range certPoolSubjects {
|
GetCertificate: certReloader,
|
||||||
for _, extraSubject := range extraSubjects {
|
|
||||||
if subject == extraSubject {
|
|
||||||
t.Log(extraSubject)
|
|
||||||
continue OUTER
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
filteredSystemCertPoolSubjects = append(filteredSystemCertPoolSubjects, subject)
|
|
||||||
}
|
}
|
||||||
|
tlsConfig, err := GetConfig(c)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, len(filteredSystemCertPoolSubjects), len(systemCertPoolSubjects))
|
cert, err = tlsConfig.GetCertificate(chi)
|
||||||
|
assert.NoError(t, err)
|
||||||
difference := subjectSubtract(systemCertPoolSubjects, filteredSystemCertPoolSubjects)
|
assert.Equal(t, expectedCert, *cert)
|
||||||
assert.Equal(t, 0, len(difference))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadOriginCertPoolCFCertificates(t *testing.T) {
|
|
||||||
certPoolSubjects := loadCertPoolSubjects(t, nil)
|
|
||||||
|
|
||||||
extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects)
|
|
||||||
|
|
||||||
expected := []*certificateFixture{
|
|
||||||
{ou: "CloudFlare Origin SSL ECC Certificate Authority"},
|
|
||||||
{ou: "CloudFlare Origin SSL Certificate Authority"},
|
|
||||||
{cn: "origin-pull.cloudflare.net"},
|
|
||||||
{cn: "Argo Tunnel Sample Hello Server Certificate"},
|
|
||||||
}
|
|
||||||
|
|
||||||
assertFixturesMatchSubjects(t, expected, extraSubjects)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadOriginCertPoolWithExtraPEMs(t *testing.T) {
|
|
||||||
certPoolWithoutPEMSubjects := loadCertPoolSubjects(t, nil)
|
|
||||||
certPoolWithPEMSubjects := loadCertPoolSubjects(t, samplePEM)
|
|
||||||
|
|
||||||
difference := subjectSubtract(certPoolWithoutPEMSubjects, certPoolWithPEMSubjects)
|
|
||||||
|
|
||||||
assert.Equal(t, 2, len(difference))
|
|
||||||
|
|
||||||
expected := []*certificateFixture{
|
|
||||||
{cn: "Test One"},
|
|
||||||
{cn: "Test Two"},
|
|
||||||
}
|
|
||||||
|
|
||||||
assertFixturesMatchSubjects(t, expected, difference)
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadCertPoolSubjects(t *testing.T, originCAPoolPEM []byte) []*pkix.Name {
|
|
||||||
certPool, err := LoadOriginCertPool(originCAPoolPEM)
|
|
||||||
if isUnrecoverableError(err) {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
assert.NotEmpty(t, certPool.Subjects())
|
|
||||||
certPoolSubjects, err := getCertPoolSubjects(certPool)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return certPoolSubjects
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertFixturesMatchSubjects(t *testing.T, fixtures []*certificateFixture, subjects []*pkix.Name) {
|
|
||||||
assert.Equal(t, len(fixtures), len(subjects))
|
|
||||||
|
|
||||||
for _, fixture := range fixtures {
|
|
||||||
found := false
|
|
||||||
for _, subject := range subjects {
|
|
||||||
found = found || fixtureMatchesSubjectPredicate(fixture, subject)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
t.Fail()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func fixtureMatchesSubjectPredicate(fixture *certificateFixture, subject *pkix.Name) bool {
|
|
||||||
cnMatch := true
|
|
||||||
if fixture.cn != "" {
|
|
||||||
cnMatch = fixture.cn == subject.CommonName
|
|
||||||
}
|
|
||||||
|
|
||||||
ouMatch := true
|
|
||||||
if fixture.ou != "" {
|
|
||||||
ouMatch = len(subject.OrganizationalUnit) > 0 && fixture.ou == subject.OrganizationalUnit[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
return cnMatch && ouMatch
|
|
||||||
}
|
|
||||||
|
|
||||||
func subjectSubtract(left []*pkix.Name, right []*pkix.Name) []*pkix.Name {
|
|
||||||
var difference []*pkix.Name
|
|
||||||
|
|
||||||
var found bool
|
|
||||||
for _, r := range right {
|
|
||||||
found = false
|
|
||||||
for _, l := range left {
|
|
||||||
if (*l).String() == (*r).String() {
|
|
||||||
found = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
difference = append(difference, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return difference
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCertPoolSubjects(certPool *x509.CertPool) ([]*pkix.Name, error) {
|
|
||||||
var subjects []*pkix.Name
|
|
||||||
|
|
||||||
for _, subject := range certPool.Subjects() {
|
|
||||||
var sequence pkix.RDNSequence
|
|
||||||
_, err := asn1.Unmarshal(subject, &sequence)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
name := pkix.Name{}
|
|
||||||
name.FillFromRDNSequence(&sequence)
|
|
||||||
|
|
||||||
subjects = append(subjects, &name)
|
|
||||||
}
|
|
||||||
|
|
||||||
return subjects, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isUnrecoverableError(err error) bool {
|
|
||||||
return err != nil && err.Error() != "crypto/x509: system root pool is not available on Windows"
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,18 +2,17 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"golang.org/x/net/websocket"
|
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/hello"
|
"github.com/cloudflare/cloudflared/hello"
|
||||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/net/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -40,8 +39,10 @@ func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
|
||||||
}
|
}
|
||||||
|
|
||||||
func websocketClientTLSConfig(t *testing.T) *tls.Config {
|
func websocketClientTLSConfig(t *testing.T) *tls.Config {
|
||||||
certPool, err := tlsconfig.LoadOriginCertPool(nil)
|
certPool := x509.NewCertPool()
|
||||||
|
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
certPool.AddCert(helloCert)
|
||||||
assert.NotNil(t, certPool)
|
assert.NotNil(t, certPool)
|
||||||
return &tls.Config{RootCAs: certPool}
|
return &tls.Config{RootCAs: certPool}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue