diff --git a/tlsconfig/certreloader.go b/tlsconfig/certreloader.go index 00ebab39..5ce83934 100644 --- a/tlsconfig/certreloader.go +++ b/tlsconfig/certreloader.go @@ -89,16 +89,35 @@ func LoadOriginCA(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) return originCertPool, nil } -func LoadCustomCertPool(customCertFilename string) (*x509.CertPool, error) { - pool := x509.NewCertPool() - customCAPoolPEM, err := ioutil.ReadFile(customCertFilename) +func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) { + // First, obtain the system certificate pool + certPool, err := x509.SystemCertPool() if err != nil { - return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s", customCertFilename)) + certPool = x509.NewCertPool() } - if !pool.AppendCertsFromPEM(customCAPoolPEM) { + + // Next, append the Cloudflare CAs into the system pool + cfRootCA, err := GetCloudflareRootCA() + if err != nil { + return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool") + } + for _, cert := range cfRootCA { + certPool.AddCert(cert) + } + + if originCAFilename == "" { + return certPool, nil + } + + customOriginCA, err := ioutil.ReadFile(originCAFilename) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s", originCAFilename)) + } + + if !certPool.AppendCertsFromPEM(customOriginCA) { return nil, fmt.Errorf("error appending custom CA to cert pool") } - return pool, nil + return certPool, nil } func CreateTunnelConfig(c *cli.Context) (*tls.Config, error) { diff --git a/tunnelrpc/pogs/config.go b/tunnelrpc/pogs/config.go index f63c30e9..570a8ed2 100644 --- a/tunnelrpc/pogs/config.go +++ b/tunnelrpc/pogs/config.go @@ -180,7 +180,7 @@ func (up *UnixPath) Addr() string { } func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) { - rootCAs, err := tlsconfig.LoadCustomCertPool(hc.OriginCAPool) + rootCAs, err := tlsconfig.LoadCustomOriginCA(hc.OriginCAPool) if err != nil { return nil, err } @@ -220,7 +220,7 @@ type WebSocketOriginConfig struct { } func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error) { - rootCAs, err := tlsconfig.LoadCustomCertPool(wsc.OriginCAPool) + rootCAs, err := tlsconfig.LoadCustomOriginCA(wsc.OriginCAPool) if err != nil { return nil, err }