From b59fd4b7d85d7faa389b3619e0b42934729ae359 Mon Sep 17 00:00:00 2001 From: Chung-Ting Huang Date: Thu, 15 Nov 2018 09:43:50 -0600 Subject: [PATCH] TUN-1196: Allow TLS config client CA and root CA to be constructed from multiple certificates --- cmd/cloudflared/tunnel/cmd.go | 2 +- cmd/cloudflared/tunnel/configuration.go | 90 ++++++- cmd/cloudflared/tunnel/configuration_test.go | 214 ++++++++++++++++ tlsconfig/certreloader.go | 29 +-- tlsconfig/cloudflare_ca.go | 28 ++- tlsconfig/testcert.pem | 13 + tlsconfig/testcert2.pem | 13 + tlsconfig/testkey.pem | 10 + tlsconfig/tlsconfig.go | 179 +++++--------- tlsconfig/tlsconfig_test.go | 246 +++++-------------- websocket/websocket_test.go | 13 +- 11 files changed, 491 insertions(+), 346 deletions(-) create mode 100644 cmd/cloudflared/tunnel/configuration_test.go create mode 100644 tlsconfig/testcert.pem create mode 100644 tlsconfig/testcert2.pem create mode 100644 tlsconfig/testkey.pem diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index b6eb3aef..f2b21df4 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -438,7 +438,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag { }), altsrc.NewStringFlag(&cli.StringFlag{ Name: "cacert", - Usage: "Certificate Authority authenticating the Cloudflare tunnel connection.", + Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.", EnvVars: []string{"TUNNEL_CACERT"}, Hidden: true, }), diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index c08a9e5d..6647a73f 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -198,12 +198,18 @@ func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, version st return nil, errors.Wrap(err, "unable to connect to the origin") } + toEdgeTLSConfig, err := createTunnelConfig(c) + if err != nil { + logger.WithError(err).Error("unable to create TLS config to connect with edge") + return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") + } + return &origin.TunnelConfig{ EdgeAddrs: c.StringSlice("edge"), OriginUrl: originURL, Hostname: hostname, OriginCert: originCert, - TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")), + TlsConfig: toEdgeTLSConfig, ClientTlsConfig: httpTransport.TLSClientConfig, Retries: c.Uint("retries"), HeartbeatInterval: c.Duration("heartbeat-interval"), @@ -240,7 +246,7 @@ func loadCertPool(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) } } - originCertPool, err := tlsconfig.LoadOriginCertPool(originCustomCAPool) + originCertPool, err := loadOriginCertPool(originCustomCAPool) if err != nil { return nil, errors.Wrap(err, "error loading the certificate pool") } @@ -253,6 +259,86 @@ func loadCertPool(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) return originCertPool, nil } +func loadOriginCertPool(originCAPoolPEM []byte) (*x509.CertPool, error) { + // Get the global pool + certPool, err := loadGlobalCertPool() + if err != nil { + return nil, err + } + + // Then, add any custom origin CA pool the user may have passed + if originCAPoolPEM != nil { + if !certPool.AppendCertsFromPEM(originCAPoolPEM) { + logger.Warn("could not append the provided origin CA to the cloudflared certificate pool") + } + } + + return certPool, nil +} + +func loadGlobalCertPool() (*x509.CertPool, error) { + // First, obtain the system certificate pool + certPool, err := x509.SystemCertPool() + if err != nil { + if runtime.GOOS != "windows" { + logger.WithError(err).Warn("error obtaining the system certificates") + } + certPool = x509.NewCertPool() + } + + // Next, append the Cloudflare CAs into the system pool + cfRootCA, err := tlsconfig.GetCloudflareRootCA() + if err != nil { + return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool") + } + for _, cert := range cfRootCA { + certPool.AddCert(cert) + } + + // Finally, add the Hello certificate into the pool (since it's self-signed) + helloCert, err := tlsconfig.GetHelloCertificateX509() + if err != nil { + return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool") + } + certPool.AddCert(helloCert) + + return certPool, nil +} + +func createTunnelConfig(c *cli.Context) (*tls.Config, error) { + var rootCAs []string + if c.String("cacert") != "" { + rootCAs = append(rootCAs, c.String("cacert")) + } + edgeAddrs := c.StringSlice("edge") + + userConfig := &tlsconfig.TLSParameters{RootCAs: rootCAs} + tlsConfig, err := tlsconfig.GetConfig(userConfig) + if err != nil { + return nil, err + } + if tlsConfig.RootCAs == nil { + rootCAPool := x509.NewCertPool() + cfRootCA, err := tlsconfig.GetCloudflareRootCA() + if err != nil { + return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool") + } + for _, cert := range cfRootCA { + rootCAPool.AddCert(cert) + } + tlsConfig.RootCAs = rootCAPool + tlsConfig.ServerName = "cftunnel.com" + } else if len(edgeAddrs) > 0 { + // Set for development environments and for testing specific origintunneld instances + tlsConfig.ServerName, _, _ = net.SplitHostPort(edgeAddrs[0]) + } + + if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify { + return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config") + } + return tlsConfig, nil +} + func isRunningFromTerminal() bool { return terminal.IsTerminal(int(os.Stdout.Fd())) } diff --git a/cmd/cloudflared/tunnel/configuration_test.go b/cmd/cloudflared/tunnel/configuration_test.go new file mode 100644 index 00000000..c18c6c09 --- /dev/null +++ b/cmd/cloudflared/tunnel/configuration_test.go @@ -0,0 +1,214 @@ +// +build ignore +// TODO: Remove the above build tag and include this test when we start compiling with Golang 1.10.0+ + +package tunnel + +import ( + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Generated using `openssl req -newkey rsa:512 -nodes -x509 -days 3650` +var samplePEM = []byte(` +-----BEGIN CERTIFICATE----- +MIIB4DCCAYoCCQCb/H0EUrdXEjANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv +dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE +AwwIVGVzdCBPbmUwHhcNMTgwNDI2MTYxMDUxWhcNMjgwNDIzMTYxMDUxWjB3MQsw +CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG +A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn +eTERMA8GA1UEAwwIVGVzdCBPbmUwXDANBgkqhkiG9w0BAQEFAANLADBIAkEAwVQD +K0SJ25UFLznm2pU3zhzMEvpDEofHVNnCjk4mlDrtVop7PkKZ8pDEmuQANltUrxC8 +yHBE2wXMv+GlH+bDtwIDAQABMA0GCSqGSIb3DQEBCwUAA0EAjVYQzozIFPkt/HRY +uUoZ8zEHIDICb0syFf5VAjm9AgTwIPzUmD+c5vl6LWDnxq7L45nLCzhhQ6YmiwDz +X7Wcyg== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIB4DCCAYoCCQDZfCdAJ+mwzDANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv +dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE +AwwIVGVzdCBUd28wHhcNMTgwNDI2MTYxMTIwWhcNMjgwNDIzMTYxMTIwWjB3MQsw +CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG +A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn +eTERMA8GA1UEAwwIVGVzdCBUd28wXDANBgkqhkiG9w0BAQEFAANLADBIAkEAoHKp +ROVK3zCSsH7ocYeyRAML4V7SFAbZcb4WIwDnE08oMBVRkQVcW5tqEkvG3RiClfzV +wZIJ3CfqKIeSNSDU9wIDAQABMA0GCSqGSIb3DQEBCwUAA0EAJw2gUbnPiq4C2p5b +iWzlA9Q7aKo+VQ4H7IZS7tTccr59nVjvH/TG3eWujpnocr4TOqW9M3CK1DF9mUGP +3pQ3Jg== +-----END CERTIFICATE----- +`) + +var systemCertPoolSubjects []*pkix.Name + +type certificateFixture struct { + ou string + cn string +} + +func TestMain(m *testing.M) { + systemCertPool, err := x509.SystemCertPool() + if isUnrecoverableError(err) { + os.Exit(1) + } + + if systemCertPool == nil { + // On Windows, let's just assume the system cert pool was empty + systemCertPool = x509.NewCertPool() + } + + systemCertPoolSubjects, err = getCertPoolSubjects(systemCertPool) + if err != nil { + os.Exit(1) + } + + os.Exit(m.Run()) +} + +func TestLoadOriginCertPoolJustSystemPool(t *testing.T) { + certPoolSubjects := loadCertPoolSubjects(t, nil) + extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects) + + // Remove extra subjects from the cert pool + var filteredSystemCertPoolSubjects []*pkix.Name + + t.Log(extraSubjects) + +OUTER: + for _, subject := range certPoolSubjects { + for _, extraSubject := range extraSubjects { + if subject == extraSubject { + t.Log(extraSubject) + continue OUTER + } + } + + filteredSystemCertPoolSubjects = append(filteredSystemCertPoolSubjects, subject) + } + + assert.Equal(t, len(filteredSystemCertPoolSubjects), len(systemCertPoolSubjects)) + + difference := subjectSubtract(systemCertPoolSubjects, filteredSystemCertPoolSubjects) + assert.Equal(t, 0, len(difference)) +} + +func TestLoadOriginCertPoolCFCertificates(t *testing.T) { + certPoolSubjects := loadCertPoolSubjects(t, nil) + + extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects) + + expected := []*certificateFixture{ + {ou: "CloudFlare Origin SSL ECC Certificate Authority"}, + {ou: "CloudFlare Origin SSL Certificate Authority"}, + {cn: "origin-pull.cloudflare.net"}, + {cn: "Argo Tunnel Sample Hello Server Certificate"}, + } + + assertFixturesMatchSubjects(t, expected, extraSubjects) +} + +func TestLoadOriginCertPoolWithExtraPEMs(t *testing.T) { + certPoolWithoutPEMSubjects := loadCertPoolSubjects(t, nil) + certPoolWithPEMSubjects := loadCertPoolSubjects(t, samplePEM) + + difference := subjectSubtract(certPoolWithoutPEMSubjects, certPoolWithPEMSubjects) + + assert.Equal(t, 2, len(difference)) + + expected := []*certificateFixture{ + {cn: "Test One"}, + {cn: "Test Two"}, + } + + assertFixturesMatchSubjects(t, expected, difference) +} + +func loadCertPoolSubjects(t *testing.T, originCAPoolPEM []byte) []*pkix.Name { + certPool, err := loadOriginCertPool(originCAPoolPEM) + if isUnrecoverableError(err) { + t.Fatal(err) + } + assert.NotEmpty(t, certPool.Subjects()) + certPoolSubjects, err := getCertPoolSubjects(certPool) + if err != nil { + t.Fatal(err) + } + + return certPoolSubjects +} + +func assertFixturesMatchSubjects(t *testing.T, fixtures []*certificateFixture, subjects []*pkix.Name) { + assert.Equal(t, len(fixtures), len(subjects)) + + for _, fixture := range fixtures { + found := false + for _, subject := range subjects { + found = found || fixtureMatchesSubjectPredicate(fixture, subject) + } + + if !found { + t.Fail() + } + } +} + +func fixtureMatchesSubjectPredicate(fixture *certificateFixture, subject *pkix.Name) bool { + cnMatch := true + if fixture.cn != "" { + cnMatch = fixture.cn == subject.CommonName + } + + ouMatch := true + if fixture.ou != "" { + ouMatch = len(subject.OrganizationalUnit) > 0 && fixture.ou == subject.OrganizationalUnit[0] + } + + return cnMatch && ouMatch +} + +func subjectSubtract(left []*pkix.Name, right []*pkix.Name) []*pkix.Name { + var difference []*pkix.Name + + var found bool + for _, r := range right { + found = false + for _, l := range left { + if (*l).String() == (*r).String() { + found = true + } + } + + if !found { + difference = append(difference, r) + } + } + + return difference +} + +func getCertPoolSubjects(certPool *x509.CertPool) ([]*pkix.Name, error) { + var subjects []*pkix.Name + + for _, subject := range certPool.Subjects() { + var sequence pkix.RDNSequence + _, err := asn1.Unmarshal(subject, &sequence) + if err != nil { + return nil, err + } + + name := pkix.Name{} + name.FillFromRDNSequence(&sequence) + + subjects = append(subjects, &name) + } + + return subjects, nil +} + +func isUnrecoverableError(err error) bool { + return err != nil && err.Error() != "crypto/x509: system root pool is not available on Windows" +} diff --git a/tlsconfig/certreloader.go b/tlsconfig/certreloader.go index 78e3d279..d359c2e6 100644 --- a/tlsconfig/certreloader.go +++ b/tlsconfig/certreloader.go @@ -2,14 +2,10 @@ package tlsconfig import ( "crypto/tls" - "errors" "fmt" "sync" - tunnellog "github.com/cloudflare/cloudflared/log" "github.com/getsentry/raven-go" - log "github.com/sirupsen/logrus" - "gopkg.in/urfave/cli.v2" ) // CertReloader can load and reload a TLS certificate from a particular filepath. @@ -21,18 +17,14 @@ type CertReloader struct { keyPath string } -// NewCertReloader makes a CertReloader, memorizing the filepaths in the context/flags. -func NewCertReloader(c *cli.Context, f CLIFlags) (*CertReloader, error) { - if !c.IsSet(f.Cert) { - return nil, errors.New("CertReloader: cert not provided") - } - if !c.IsSet(f.Key) { - return nil, errors.New("CertReloader: key not provided") - } +// NewCertReloader makes a CertReloader. It loads the cert during initialization to make sure certPath and keyPath are valid +func NewCertReloader(certPath, keyPath string) (*CertReloader, error) { cr := new(CertReloader) - cr.certPath = c.String(f.Cert) - cr.keyPath = c.String(f.Key) - cr.LoadCert() + cr.certPath = certPath + cr.keyPath = keyPath + if err := cr.LoadCert(); err != nil { + return nil, err + } return cr, nil } @@ -45,18 +37,17 @@ func (cr *CertReloader) Cert(clientHello *tls.ClientHelloInfo) (*tls.Certificate // LoadCert loads a TLS certificate from the CertReloader's specified filepath. // Call this after writing a new certificate to the disk (e.g. after renewing a certificate) -func (cr *CertReloader) LoadCert() { +func (cr *CertReloader) LoadCert() error { cr.Lock() defer cr.Unlock() - log.SetFormatter(&tunnellog.JSONFormatter{}) - log.Info("Reloading certificate") cert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath) // Keep the old certificate if there's a problem reading the new one. if err != nil { raven.CaptureError(fmt.Errorf("Error parsing X509 key pair: %v", err), nil) - return + return err } cr.certificate = &cert + return nil } diff --git a/tlsconfig/cloudflare_ca.go b/tlsconfig/cloudflare_ca.go index 202be9eb..e70cc503 100644 --- a/tlsconfig/cloudflare_ca.go +++ b/tlsconfig/cloudflare_ca.go @@ -2,6 +2,7 @@ package tlsconfig import ( "crypto/x509" + "encoding/pem" ) // TODO: remove the Origin CA root certs when migrated to Authenticated Origin Pull certs @@ -85,11 +86,26 @@ QzMmZpRpIBB321ZBlcnlxiTJvWxvbCPHKHj20VwwAz7LONF59s84ZsOqfoBv8gKM s0s5dsq5zpLeaw== -----END CERTIFICATE-----`) -func GetCloudflareRootCA() *x509.CertPool { - ca := x509.NewCertPool() - if !ca.AppendCertsFromPEM([]byte(cloudflareRootCA)) { - // should never happen - panic("failure loading Cloudflare origin CA pem") +func GetCloudflareRootCA() ([]*x509.Certificate, error) { + var certs []*x509.Certificate + pemBlocks := cloudflareRootCA + for len(pemBlocks) > 0 { + var block *pem.Block + block, pemBlocks = pem.Decode(pemBlocks) + if block == nil { + break + } + if block.Type != "CERTIFICATE" { + continue + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, err + } + + certs = append(certs, cert) } - return ca + + return certs, nil } diff --git a/tlsconfig/testcert.pem b/tlsconfig/testcert.pem new file mode 100644 index 00000000..96a8fd3e --- /dev/null +++ b/tlsconfig/testcert.pem @@ -0,0 +1,13 @@ +-----BEGIN CERTIFICATE----- +MIICBjCCAbCgAwIBAgIJAPKk4bYMrSFMMA0GCSqGSIb3DQEBCwUAMF0xCzAJBgNV +BAYTAlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQK +DBBDbG91ZGZsYXJlLCBJbmMuMRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcNMTgxMTE1 +MjA1NzU3WhcNMjgxMTEyMjA1NzU3WjBdMQswCQYDVQQGEwJVUzEOMAwGA1UECAwF +VGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xvdWRmbGFyZSwgSW5j +LjESMBAGA1UEAwwJbG9jYWxob3N0MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAOQN +pTRn5wLf8SSI5x2kpDvbdDy7lfhamJ2En4Q+wy1cSKp8bn8/oyhVF7QTsimDGTI4 +45pV9nDfNJPYB3IW0x0CAwEAAaNTMFEwHQYDVR0OBBYEFE4jIa97mIEiYFa02X++ +uu5mCEn+MB8GA1UdIwQYMBaAFE4jIa97mIEiYFa02X++uu5mCEn+MA8GA1UdEwEB +/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADQQAkE+pDee0o5cNcZRUszy8sTQzB1Wlp +J6ucfmo16crqRaK7uGvhkMyibIc4D8z2Cxw3aI3IMMFoIIlYoYKiUcbd +-----END CERTIFICATE----- \ No newline at end of file diff --git a/tlsconfig/testcert2.pem b/tlsconfig/testcert2.pem new file mode 100644 index 00000000..500c14e0 --- /dev/null +++ b/tlsconfig/testcert2.pem @@ -0,0 +1,13 @@ +-----BEGIN CERTIFICATE----- +MIICBjCCAbCgAwIBAgIJAN6cXRTbJtFnMA0GCSqGSIb3DQEBCwUAMF0xCzAJBgNV +BAYTAlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRgwFgYDVQQK +DA9DbG91ZGZsYXJlLCBJbmMxEzARBgNVBAMMCmxvY2FsaG9zdDIwHhcNMTgxMTE1 +MjExMTU4WhcNMjgxMTEyMjExMTU4WjBdMQswCQYDVQQGEwJVUzEOMAwGA1UECAwF +VGV4YXMxDzANBgNVBAcMBkF1c3RpbjEYMBYGA1UECgwPQ2xvdWRmbGFyZSwgSW5j +MRMwEQYDVQQDDApsb2NhbGhvc3QyMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAKQx +IMZ6QgXoul2ITF/7sly4fW2Ol+a/AYw42zCWhVqOXv8AhY21I0Q8lkRR6wOroQwZ +O7jKKOcE5TnR/NRcZr8CAwEAAaNTMFEwHQYDVR0OBBYEFONKxLZc2RUD0KTHkAz4 +8nrb5688MB8GA1UdIwQYMBaAFONKxLZc2RUD0KTHkAz48nrb5688MA8GA1UdEwEB +/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADQQA56pwhvGpNPjyLcWfJHu/vI3ZjdoLB +LnrkRaMjJmv0H0Beh4upJhoz8u6lhMACerKQrrdQhEPB2u+maFrEBtmN +-----END CERTIFICATE----- \ No newline at end of file diff --git a/tlsconfig/testkey.pem b/tlsconfig/testkey.pem new file mode 100644 index 00000000..a0d861d2 --- /dev/null +++ b/tlsconfig/testkey.pem @@ -0,0 +1,10 @@ +-----BEGIN PRIVATE KEY----- +MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEA5A2lNGfnAt/xJIjn +HaSkO9t0PLuV+FqYnYSfhD7DLVxIqnxufz+jKFUXtBOyKYMZMjjjmlX2cN80k9gH +chbTHQIDAQABAkAoeDtu91lJa1AxuZG58vOqI6GW/Xr5naojmdts7m5YaAhDa7DE +zJUp4d8SP5cGBf1/PB3x6Cu9UviFNQ16wmzJAiEA8gUm4UYpWZD4Ze2l/xb+BK8D +IglSUIy1VxW+X1G55wMCIQDxOfXiFzPqnv/e5avKGv6CU11Dhmbi1OpiyybZTjGz +XwIhAM3bE/cJdqJ4bNBGE6umIupY8pFA3IMnLBempwbsvPOBAiEAgzJ+5OSxu92W +VGidsmJUIhWtF9i1hJFAmVLcYjwBFAkCICtjP/vv0qOZWk4mAAn2zz9UVWp45DSR +p/FA8V77ohXD +-----END PRIVATE KEY----- diff --git a/tlsconfig/tlsconfig.go b/tlsconfig/tlsconfig.go index cb0cbe98..8b402fce 100644 --- a/tlsconfig/tlsconfig.go +++ b/tlsconfig/tlsconfig.go @@ -6,150 +6,81 @@ import ( "crypto/tls" "crypto/x509" "io/ioutil" - "net" - "github.com/cloudflare/cloudflared/log" "github.com/pkg/errors" - "gopkg.in/urfave/cli.v2" - "runtime" ) -var logger = log.CreateLogger() - -// CLIFlags names the flags used to configure TLS for a command or subsystem. -// The nil value for a field means the flag is ignored. -type CLIFlags struct { - Cert string - Key string - ClientCert string - RootCA string +// Config is the user provided parameters to create a tls.Config +type TLSParameters struct { + Cert string + Key string + GetCertificate *CertReloader + ClientCAs []string + RootCAs []string + ServerName string + CurvePreferences []tls.CurveID } -// GetConfig returns a TLS configuration according to the flags defined in f and -// set by the user. -func (f CLIFlags) GetConfig(c *cli.Context) *tls.Config { - config := &tls.Config{} - - if c.IsSet(f.Cert) && c.IsSet(f.Key) { - cert, err := tls.LoadX509KeyPair(c.String(f.Cert), c.String(f.Key)) +// GetConfig returns a TLS configuration according to the Config set by the user. +func GetConfig(p *TLSParameters) (*tls.Config, error) { + tlsconfig := &tls.Config{} + if p.GetCertificate != nil { + tlsconfig.GetCertificate = p.GetCertificate.Cert + tlsconfig.BuildNameToCertificate() + } else if p.Cert != "" && p.Key != "" { + cert, err := tls.LoadX509KeyPair(p.Cert, p.Key) if err != nil { - logger.WithError(err).Fatal("Error parsing X509 key pair") + return nil, errors.Wrap(err, "Error parsing X509 key pair") } - config.Certificates = []tls.Certificate{cert} - config.BuildNameToCertificate() + tlsconfig.Certificates = []tls.Certificate{cert} + tlsconfig.BuildNameToCertificate() } - return f.finishGettingConfig(c, config) -} -func (f CLIFlags) GetConfigReloadableCert(c *cli.Context, cr *CertReloader) *tls.Config { - config := &tls.Config{ - GetCertificate: cr.Cert, - } - config.BuildNameToCertificate() - return f.finishGettingConfig(c, config) -} - -func (f CLIFlags) finishGettingConfig(c *cli.Context, config *tls.Config) *tls.Config { - if c.IsSet(f.ClientCert) { + if len(p.ClientCAs) > 0 { // set of root certificate authorities that servers use if required to verify a client certificate // by the policy in ClientAuth - config.ClientCAs = LoadCert(c.String(f.ClientCert)) + clientCAs, err := LoadCert(p.ClientCAs) + if err != nil { + return nil, errors.Wrap(err, "Error loading client CAs") + } + tlsconfig.ClientCAs = clientCAs // server's policy for TLS Client Authentication. Default is no client cert - config.ClientAuth = tls.RequireAndVerifyClientCert + tlsconfig.ClientAuth = tls.RequireAndVerifyClientCert } - // set of root certificate authorities that clients use when verifying server certificates - if c.IsSet(f.RootCA) { - config.RootCAs = LoadCert(c.String(f.RootCA)) + + if len(p.RootCAs) > 0 { + rootCAs, err := LoadCert(p.RootCAs) + if err != nil { + return nil, errors.Wrap(err, "Error loading root CAs") + } + tlsconfig.RootCAs = rootCAs } - // we optimize CurveP256 - config.CurvePreferences = []tls.CurveID{tls.CurveP256} - return config + + if p.ServerName != "" { + tlsconfig.ServerName = p.ServerName + } + + if len(p.CurvePreferences) > 0 { + tlsconfig.CurvePreferences = p.CurvePreferences + } else { + // Cloudflare optimize CurveP256 + tlsconfig.CurvePreferences = []tls.CurveID{tls.CurveP256} + } + + return tlsconfig, nil } // LoadCert creates a CertPool containing all certificates in a PEM-format file. -func LoadCert(certPath string) *x509.CertPool { - caCert, err := ioutil.ReadFile(certPath) - if err != nil { - logger.WithError(err).Fatalf("Error reading certificate %s", certPath) - } +func LoadCert(certPaths []string) (*x509.CertPool, error) { ca := x509.NewCertPool() - if !ca.AppendCertsFromPEM(caCert) { - logger.WithError(err).Fatalf("Error parsing certificate %s", certPath) - } - return ca -} - -func LoadGlobalCertPool() (*x509.CertPool, error) { - success := false - - // First, obtain the system certificate pool - certPool, systemCertPoolErr := x509.SystemCertPool() - if systemCertPoolErr != nil { - if runtime.GOOS != "windows" { - logger.Warnf("error obtaining the system certificates: %s", systemCertPoolErr) + for _, certPath := range certPaths { + caCert, err := ioutil.ReadFile(certPath) + if err != nil { + return nil, errors.Wrapf(err, "Error reading certificate %s", certPath) } - certPool = x509.NewCertPool() - } else { - success = true - } - - // Next, append the Cloudflare CA pool into the system pool - if !certPool.AppendCertsFromPEM(cloudflareRootCA) { - logger.Warn("could not append the CF certificate to the cloudflared certificate pool") - } else { - success = true - } - - if success != true { // Obtaining any of the CAs has failed; this is a fatal error - return nil, errors.New("error loading any of the CAs into the global certificate pool") - } - - // Finally, add the Hello certificate into the pool (since it's self-signed) - helloCertificate, err := GetHelloCertificateX509() - if err != nil { - logger.Warn("error obtaining the Hello server certificate") - } - - certPool.AddCert(helloCertificate) - - return certPool, nil -} - -func LoadOriginCertPool(originCAPoolPEM []byte) (*x509.CertPool, error) { - success := false - - // Get the global pool - certPool, globalPoolErr := LoadGlobalCertPool() - if globalPoolErr != nil { - certPool = x509.NewCertPool() - } else { - success = true - } - - // Then, add any custom origin CA pool the user may have passed - if originCAPoolPEM != nil { - if !certPool.AppendCertsFromPEM(originCAPoolPEM) { - logger.Warn("could not append the provided origin CA to the cloudflared certificate pool") - } else { - success = true + if !ca.AppendCertsFromPEM(caCert) { + return nil, errors.Wrapf(err, "Error parsing certificate %s", certPath) } } - - if success != true { - return nil, errors.New("error loading any of the CAs into the origin certificate pool") - } - - return certPool, nil -} - -func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config { - tlsConfig := CLIFlags{RootCA: "cacert"}.GetConfig(c) - if tlsConfig.RootCAs == nil { - tlsConfig.RootCAs = GetCloudflareRootCA() - tlsConfig.ServerName = "cftunnel.com" - } else if len(addrs) > 0 { - // Set for development environments and for testing specific origintunneld instances - tlsConfig.ServerName, _, _ = net.SplitHostPort(addrs[0]) - } - return tlsConfig + return ca, nil } diff --git a/tlsconfig/tlsconfig_test.go b/tlsconfig/tlsconfig_test.go index b4baac0b..2f8067da 100644 --- a/tlsconfig/tlsconfig_test.go +++ b/tlsconfig/tlsconfig_test.go @@ -1,214 +1,84 @@ -// +build ignore // TODO: Remove the above build tag and include this test when we start compiling with Golang 1.10.0+ package tlsconfig import ( - "crypto/x509" - "crypto/x509/pkix" - "encoding/asn1" - "os" + "crypto/tls" "testing" "github.com/stretchr/testify/assert" ) -// Generated using `openssl req -newkey rsa:512 -nodes -x509 -days 3650` -var samplePEM = []byte(` ------BEGIN CERTIFICATE----- -MIIB4DCCAYoCCQCb/H0EUrdXEjANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV -UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv -dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE -AwwIVGVzdCBPbmUwHhcNMTgwNDI2MTYxMDUxWhcNMjgwNDIzMTYxMDUxWjB3MQsw -CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG -A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn -eTERMA8GA1UEAwwIVGVzdCBPbmUwXDANBgkqhkiG9w0BAQEFAANLADBIAkEAwVQD -K0SJ25UFLznm2pU3zhzMEvpDEofHVNnCjk4mlDrtVop7PkKZ8pDEmuQANltUrxC8 -yHBE2wXMv+GlH+bDtwIDAQABMA0GCSqGSIb3DQEBCwUAA0EAjVYQzozIFPkt/HRY -uUoZ8zEHIDICb0syFf5VAjm9AgTwIPzUmD+c5vl6LWDnxq7L45nLCzhhQ6YmiwDz -X7Wcyg== ------END CERTIFICATE----- ------BEGIN CERTIFICATE----- -MIIB4DCCAYoCCQDZfCdAJ+mwzDANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV -UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv -dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE -AwwIVGVzdCBUd28wHhcNMTgwNDI2MTYxMTIwWhcNMjgwNDIzMTYxMTIwWjB3MQsw -CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG -A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn -eTERMA8GA1UEAwwIVGVzdCBUd28wXDANBgkqhkiG9w0BAQEFAANLADBIAkEAoHKp -ROVK3zCSsH7ocYeyRAML4V7SFAbZcb4WIwDnE08oMBVRkQVcW5tqEkvG3RiClfzV -wZIJ3CfqKIeSNSDU9wIDAQABMA0GCSqGSIb3DQEBCwUAA0EAJw2gUbnPiq4C2p5b -iWzlA9Q7aKo+VQ4H7IZS7tTccr59nVjvH/TG3eWujpnocr4TOqW9M3CK1DF9mUGP -3pQ3Jg== ------END CERTIFICATE----- -`) +// testcert.pem and testcert2.pem are Generated using `openssl req -newkey rsa:512 -nodes -x509 -days 3650` +const ( + testcertCommonName = "localhost" +) -var systemCertPoolSubjects []*pkix.Name +func TestGetFromEmptyConfig(t *testing.T) { + c := &TLSParameters{} -type certificateFixture struct { - ou string - cn string + tlsConfig, err := GetConfig(c) + assert.NoError(t, err) + assert.Empty(t, tlsConfig.Certificates) + + assert.Empty(t, tlsConfig.NameToCertificate) + + assert.Nil(t, tlsConfig.ClientCAs) + assert.Equal(t, tls.NoClientCert, tlsConfig.ClientAuth) + + assert.Nil(t, tlsConfig.RootCAs) + + assert.Len(t, tlsConfig.CurvePreferences, 1) + assert.Equal(t, tls.CurveP256, tlsConfig.CurvePreferences[0]) } -func TestMain(m *testing.M) { - systemCertPool, err := x509.SystemCertPool() - if isUnrecoverableError(err) { - os.Exit(1) - } +func TestGetConfig(t *testing.T) { + cert, err := tls.LoadX509KeyPair("testcert.pem", "testkey.pem") + assert.NoError(t, err) - if systemCertPool == nil { - // On Windows, let's just assume the system cert pool was empty - systemCertPool = x509.NewCertPool() + c := &TLSParameters{ + Cert: "testcert.pem", + Key: "testkey.pem", + ClientCAs: []string{"testcert.pem", "testcert2.pem"}, + RootCAs: []string{"testcert.pem", "testcert2.pem"}, + ServerName: "test", + CurvePreferences: []tls.CurveID{tls.CurveP384}, } + tlsConfig, err := GetConfig(c) + assert.NoError(t, err) + assert.Len(t, tlsConfig.Certificates, 1) + assert.Equal(t, cert, tlsConfig.Certificates[0]) - systemCertPoolSubjects, err = getCertPoolSubjects(systemCertPool) - if err != nil { - os.Exit(1) - } + assert.Equal(t, cert, *tlsConfig.NameToCertificate[testcertCommonName]) - os.Exit(m.Run()) + assert.NotNil(t, tlsConfig.ClientCAs) + assert.Equal(t, tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) + + assert.NotNil(t, tlsConfig.RootCAs) + + assert.Len(t, tlsConfig.CurvePreferences, 1) + assert.Equal(t, tls.CurveP384, tlsConfig.CurvePreferences[0]) } -func TestLoadOriginCertPoolJustSystemPool(t *testing.T) { - certPoolSubjects := loadCertPoolSubjects(t, nil) - extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects) +func TestCertReloader(t *testing.T) { + expectedCert, err := tls.LoadX509KeyPair("testcert.pem", "testkey.pem") + assert.NoError(t, err) - // Remove extra subjects from the cert pool - var filteredSystemCertPoolSubjects []*pkix.Name + certReloader, err := NewCertReloader("testcert.pem", "testkey.pem") + assert.NoError(t, err) - t.Log(extraSubjects) + chi := &tls.ClientHelloInfo{ServerName: testcertCommonName} + cert, err := certReloader.Cert(chi) + assert.NoError(t, err) + assert.Equal(t, expectedCert, *cert) -OUTER: - for _, subject := range certPoolSubjects { - for _, extraSubject := range extraSubjects { - if subject == extraSubject { - t.Log(extraSubject) - continue OUTER - } - } - - filteredSystemCertPoolSubjects = append(filteredSystemCertPoolSubjects, subject) + c := &TLSParameters{ + GetCertificate: certReloader, } + tlsConfig, err := GetConfig(c) + assert.NoError(t, err) - assert.Equal(t, len(filteredSystemCertPoolSubjects), len(systemCertPoolSubjects)) - - difference := subjectSubtract(systemCertPoolSubjects, filteredSystemCertPoolSubjects) - assert.Equal(t, 0, len(difference)) -} - -func TestLoadOriginCertPoolCFCertificates(t *testing.T) { - certPoolSubjects := loadCertPoolSubjects(t, nil) - - extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects) - - expected := []*certificateFixture{ - {ou: "CloudFlare Origin SSL ECC Certificate Authority"}, - {ou: "CloudFlare Origin SSL Certificate Authority"}, - {cn: "origin-pull.cloudflare.net"}, - {cn: "Argo Tunnel Sample Hello Server Certificate"}, - } - - assertFixturesMatchSubjects(t, expected, extraSubjects) -} - -func TestLoadOriginCertPoolWithExtraPEMs(t *testing.T) { - certPoolWithoutPEMSubjects := loadCertPoolSubjects(t, nil) - certPoolWithPEMSubjects := loadCertPoolSubjects(t, samplePEM) - - difference := subjectSubtract(certPoolWithoutPEMSubjects, certPoolWithPEMSubjects) - - assert.Equal(t, 2, len(difference)) - - expected := []*certificateFixture{ - {cn: "Test One"}, - {cn: "Test Two"}, - } - - assertFixturesMatchSubjects(t, expected, difference) -} - -func loadCertPoolSubjects(t *testing.T, originCAPoolPEM []byte) []*pkix.Name { - certPool, err := LoadOriginCertPool(originCAPoolPEM) - if isUnrecoverableError(err) { - t.Fatal(err) - } - assert.NotEmpty(t, certPool.Subjects()) - certPoolSubjects, err := getCertPoolSubjects(certPool) - if err != nil { - t.Fatal(err) - } - - return certPoolSubjects -} - -func assertFixturesMatchSubjects(t *testing.T, fixtures []*certificateFixture, subjects []*pkix.Name) { - assert.Equal(t, len(fixtures), len(subjects)) - - for _, fixture := range fixtures { - found := false - for _, subject := range subjects { - found = found || fixtureMatchesSubjectPredicate(fixture, subject) - } - - if !found { - t.Fail() - } - } -} - -func fixtureMatchesSubjectPredicate(fixture *certificateFixture, subject *pkix.Name) bool { - cnMatch := true - if fixture.cn != "" { - cnMatch = fixture.cn == subject.CommonName - } - - ouMatch := true - if fixture.ou != "" { - ouMatch = len(subject.OrganizationalUnit) > 0 && fixture.ou == subject.OrganizationalUnit[0] - } - - return cnMatch && ouMatch -} - -func subjectSubtract(left []*pkix.Name, right []*pkix.Name) []*pkix.Name { - var difference []*pkix.Name - - var found bool - for _, r := range right { - found = false - for _, l := range left { - if (*l).String() == (*r).String() { - found = true - } - } - - if !found { - difference = append(difference, r) - } - } - - return difference -} - -func getCertPoolSubjects(certPool *x509.CertPool) ([]*pkix.Name, error) { - var subjects []*pkix.Name - - for _, subject := range certPool.Subjects() { - var sequence pkix.RDNSequence - _, err := asn1.Unmarshal(subject, &sequence) - if err != nil { - return nil, err - } - - name := pkix.Name{} - name.FillFromRDNSequence(&sequence) - - subjects = append(subjects, &name) - } - - return subjects, nil -} - -func isUnrecoverableError(err error) bool { - return err != nil && err.Error() != "crypto/x509: system root pool is not available on Windows" + cert, err = tlsConfig.GetCertificate(chi) + assert.NoError(t, err) + assert.Equal(t, expectedCert, *cert) } diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go index 8383a422..1b5eef48 100644 --- a/websocket/websocket_test.go +++ b/websocket/websocket_test.go @@ -2,18 +2,17 @@ package websocket import ( "crypto/tls" + "crypto/x509" "io" "math/rand" "net/http" "testing" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - - "golang.org/x/net/websocket" - "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/tlsconfig" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "golang.org/x/net/websocket" ) const ( @@ -40,8 +39,10 @@ func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request { } func websocketClientTLSConfig(t *testing.T) *tls.Config { - certPool, err := tlsconfig.LoadOriginCertPool(nil) + certPool := x509.NewCertPool() + helloCert, err := tlsconfig.GetHelloCertificateX509() assert.NoError(t, err) + certPool.AddCert(helloCert) assert.NotNil(t, certPool) return &tls.Config{RootCAs: certPool} }