215 lines
5.8 KiB
Go
215 lines
5.8 KiB
Go
// +build ignore
|
|
// TODO: Remove the above build tag and include this test when we start compiling with Golang 1.10.0+
|
|
|
|
package tunnel
|
|
|
|
import (
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/asn1"
|
|
"os"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// samplePEM holds two PEM-encoded certificates, concatenated in one literal.
// TestLoadOriginCertPoolWithExtraPEMs passes it to loadOriginCertPool and
// expects exactly two extra subjects, with common names "Test One" and
// "Test Two".
// Generated using `openssl req -newkey rsa:512 -nodes -x509 -days 3650`
// NOTE(review): these use 512-bit RSA keys; if a newer Go release refuses to
// parse or load such certificates, regenerate with a larger key — confirm.
|
|
var samplePEM = []byte(`
|
|
-----BEGIN CERTIFICATE-----
|
|
MIIB4DCCAYoCCQCb/H0EUrdXEjANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV
|
|
UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv
|
|
dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE
|
|
AwwIVGVzdCBPbmUwHhcNMTgwNDI2MTYxMDUxWhcNMjgwNDIzMTYxMDUxWjB3MQsw
|
|
CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG
|
|
A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn
|
|
eTERMA8GA1UEAwwIVGVzdCBPbmUwXDANBgkqhkiG9w0BAQEFAANLADBIAkEAwVQD
|
|
K0SJ25UFLznm2pU3zhzMEvpDEofHVNnCjk4mlDrtVop7PkKZ8pDEmuQANltUrxC8
|
|
yHBE2wXMv+GlH+bDtwIDAQABMA0GCSqGSIb3DQEBCwUAA0EAjVYQzozIFPkt/HRY
|
|
uUoZ8zEHIDICb0syFf5VAjm9AgTwIPzUmD+c5vl6LWDnxq7L45nLCzhhQ6YmiwDz
|
|
X7Wcyg==
|
|
-----END CERTIFICATE-----
|
|
-----BEGIN CERTIFICATE-----
|
|
MIIB4DCCAYoCCQDZfCdAJ+mwzDANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV
|
|
UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv
|
|
dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE
|
|
AwwIVGVzdCBUd28wHhcNMTgwNDI2MTYxMTIwWhcNMjgwNDIzMTYxMTIwWjB3MQsw
|
|
CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG
|
|
A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn
|
|
eTERMA8GA1UEAwwIVGVzdCBUd28wXDANBgkqhkiG9w0BAQEFAANLADBIAkEAoHKp
|
|
ROVK3zCSsH7ocYeyRAML4V7SFAbZcb4WIwDnE08oMBVRkQVcW5tqEkvG3RiClfzV
|
|
wZIJ3CfqKIeSNSDU9wIDAQABMA0GCSqGSIb3DQEBCwUAA0EAJw2gUbnPiq4C2p5b
|
|
iWzlA9Q7aKo+VQ4H7IZS7tTccr59nVjvH/TG3eWujpnocr4TOqW9M3CK1DF9mUGP
|
|
3pQ3Jg==
|
|
-----END CERTIFICATE-----
|
|
`)
|
|
|
|
// systemCertPoolSubjects is filled in once by TestMain with the decoded
// subjects of the system certificate pool. It is the baseline the tests
// compare loadOriginCertPool's output against.
var systemCertPoolSubjects []*pkix.Name

// certificateFixture describes an expected certificate subject. Empty fields
// act as wildcards. A non-empty ou is compared with the subject's first
// OrganizationalUnit, and a non-empty cn with its CommonName.
type certificateFixture struct {
	ou, cn string
}
|
|
|
|
func TestMain(m *testing.M) {
|
|
systemCertPool, err := x509.SystemCertPool()
|
|
if isUnrecoverableError(err) {
|
|
os.Exit(1)
|
|
}
|
|
|
|
if systemCertPool == nil {
|
|
// On Windows, let's just assume the system cert pool was empty
|
|
systemCertPool = x509.NewCertPool()
|
|
}
|
|
|
|
systemCertPoolSubjects, err = getCertPoolSubjects(systemCertPool)
|
|
if err != nil {
|
|
os.Exit(1)
|
|
}
|
|
|
|
os.Exit(m.Run())
|
|
}
|
|
|
|
func TestLoadOriginCertPoolJustSystemPool(t *testing.T) {
|
|
certPoolSubjects := loadCertPoolSubjects(t, nil)
|
|
extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects)
|
|
|
|
// Remove extra subjects from the cert pool
|
|
var filteredSystemCertPoolSubjects []*pkix.Name
|
|
|
|
t.Log(extraSubjects)
|
|
|
|
OUTER:
|
|
for _, subject := range certPoolSubjects {
|
|
for _, extraSubject := range extraSubjects {
|
|
if subject == extraSubject {
|
|
t.Log(extraSubject)
|
|
continue OUTER
|
|
}
|
|
}
|
|
|
|
filteredSystemCertPoolSubjects = append(filteredSystemCertPoolSubjects, subject)
|
|
}
|
|
|
|
assert.Equal(t, len(filteredSystemCertPoolSubjects), len(systemCertPoolSubjects))
|
|
|
|
difference := subjectSubtract(systemCertPoolSubjects, filteredSystemCertPoolSubjects)
|
|
assert.Equal(t, 0, len(difference))
|
|
}
|
|
|
|
func TestLoadOriginCertPoolCFCertificates(t *testing.T) {
|
|
certPoolSubjects := loadCertPoolSubjects(t, nil)
|
|
|
|
extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects)
|
|
|
|
expected := []*certificateFixture{
|
|
{ou: "CloudFlare Origin SSL ECC Certificate Authority"},
|
|
{ou: "CloudFlare Origin SSL Certificate Authority"},
|
|
{cn: "origin-pull.cloudflare.net"},
|
|
{cn: "Argo Tunnel Sample Hello Server Certificate"},
|
|
}
|
|
|
|
assertFixturesMatchSubjects(t, expected, extraSubjects)
|
|
}
|
|
|
|
func TestLoadOriginCertPoolWithExtraPEMs(t *testing.T) {
|
|
certPoolWithoutPEMSubjects := loadCertPoolSubjects(t, nil)
|
|
certPoolWithPEMSubjects := loadCertPoolSubjects(t, samplePEM)
|
|
|
|
difference := subjectSubtract(certPoolWithoutPEMSubjects, certPoolWithPEMSubjects)
|
|
|
|
assert.Equal(t, 2, len(difference))
|
|
|
|
expected := []*certificateFixture{
|
|
{cn: "Test One"},
|
|
{cn: "Test Two"},
|
|
}
|
|
|
|
assertFixturesMatchSubjects(t, expected, difference)
|
|
}
|
|
|
|
func loadCertPoolSubjects(t *testing.T, originCAPoolPEM []byte) []*pkix.Name {
|
|
certPool, err := loadOriginCertPool(originCAPoolPEM)
|
|
if isUnrecoverableError(err) {
|
|
t.Fatal(err)
|
|
}
|
|
assert.NotEmpty(t, certPool.Subjects())
|
|
certPoolSubjects, err := getCertPoolSubjects(certPool)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return certPoolSubjects
|
|
}
|
|
|
|
func assertFixturesMatchSubjects(t *testing.T, fixtures []*certificateFixture, subjects []*pkix.Name) {
|
|
assert.Equal(t, len(fixtures), len(subjects))
|
|
|
|
for _, fixture := range fixtures {
|
|
found := false
|
|
for _, subject := range subjects {
|
|
found = found || fixtureMatchesSubjectPredicate(fixture, subject)
|
|
}
|
|
|
|
if !found {
|
|
t.Fail()
|
|
}
|
|
}
|
|
}
|
|
|
|
func fixtureMatchesSubjectPredicate(fixture *certificateFixture, subject *pkix.Name) bool {
|
|
cnMatch := true
|
|
if fixture.cn != "" {
|
|
cnMatch = fixture.cn == subject.CommonName
|
|
}
|
|
|
|
ouMatch := true
|
|
if fixture.ou != "" {
|
|
ouMatch = len(subject.OrganizationalUnit) > 0 && fixture.ou == subject.OrganizationalUnit[0]
|
|
}
|
|
|
|
return cnMatch && ouMatch
|
|
}
|
|
|
|
// subjectSubtract returns the subjects in right that have no equal subject in
// left. Despite the argument order, this is "right minus left". Subjects are
// compared by their pkix.Name.String() rendering. The result keeps right's
// order and any duplicates in right, reuses right's pointers, and is nil when
// every subject in right is present in left.
func subjectSubtract(left []*pkix.Name, right []*pkix.Name) []*pkix.Name {
	// Render each left subject once so membership is a map lookup. The nested
	// scan this replaces re-stringified every (left, right) pair.
	inLeft := make(map[string]struct{}, len(left))
	for _, l := range left {
		inLeft[l.String()] = struct{}{}
	}

	var difference []*pkix.Name
	for _, r := range right {
		if _, ok := inLeft[r.String()]; !ok {
			difference = append(difference, r)
		}
	}

	return difference
}
|
|
|
|
// getCertPoolSubjects decodes the DER-encoded subject of every certificate in
// certPool into a *pkix.Name. Results follow the order certPool.Subjects()
// reports them. If any subject fails to decode, the asn1 error is returned
// and no partial result is produced. Bytes after the RDN sequence are ignored.
//
// NOTE(review): x509.CertPool.Subjects is deprecated since Go 1.18. For pools
// returned by x509.SystemCertPool it no longer includes the system roots, so
// confirm the baseline these tests build is meaningful on the Go version in use.
func getCertPoolSubjects(certPool *x509.CertPool) ([]*pkix.Name, error) {
	rawSubjects := certPool.Subjects()
	// The number of subjects is known up front, so allocate once.
	subjects := make([]*pkix.Name, 0, len(rawSubjects))

	for _, raw := range rawSubjects {
		var sequence pkix.RDNSequence
		if _, err := asn1.Unmarshal(raw, &sequence); err != nil {
			return nil, err
		}

		name := new(pkix.Name)
		name.FillFromRDNSequence(&sequence)
		subjects = append(subjects, name)
	}

	return subjects, nil
}
|
|
|
|
// isUnrecoverableError reports whether err should abort the caller. A nil
// error is never fatal. The one tolerated failure is the error whose message
// is exactly "crypto/x509: system root pool is not available on Windows".
// Any other error is treated as unrecoverable.
func isUnrecoverableError(err error) bool {
	const windowsNoSystemRoots = "crypto/x509: system root pool is not available on Windows"

	if err == nil {
		return false
	}
	return err.Error() != windowsNoSystemRoots
}
|