package tlsconfig

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"runtime"
	"sync"

	"github.com/getsentry/sentry-go"
	"github.com/pkg/errors"
	"github.com/rs/zerolog"
	"github.com/urfave/cli/v2"
)

const (
	// CaCertFlag names the CLI flag read by CreateTunnelConfig; when set, its
	// value becomes the sole RootCAs entry handed to GetConfig.
	CaCertFlag = "cacert"

	// OriginCAPoolFlag names the CLI flag that supplies the path of a PEM
	// certificate pool to trust for origin connections (see LoadOriginCA).
	OriginCAPoolFlag = "origin-ca-pool"
)

// CertReloader holds a TLS certificate loaded from a certificate/key file pair
// and can re-read that pair on demand. Its Cert and ClientCert methods have the
// shapes of tls.Config's GetCertificate and GetClientCertificate hooks, which
// lets a TLS endpoint pick up a renewed certificate without restarting.
type CertReloader struct {
	// Mutex guards certificate. It is embedded, so Lock and Unlock are part of
	// the type's exported method set.
	sync.Mutex

	// certificate is the key pair most recently loaded by LoadCert.
	certificate *tls.Certificate

	// certPath and keyPath are the files passed to tls.LoadX509KeyPair.
	certPath, keyPath string
}

// NewCertReloader returns a CertReloader for the given certificate and key
// files. The pair is loaded eagerly, so invalid or unreadable paths surface as
// an error here (with a nil reloader) rather than at handshake time.
func NewCertReloader(certPath, keyPath string) (*CertReloader, error) {
	reloader := &CertReloader{
		certPath: certPath,
		keyPath:  keyPath,
	}

	err := reloader.LoadCert()
	if err != nil {
		return nil, err
	}
	return reloader, nil
}

// Cert returns the TLS certificate most recently loaded by LoadCert.
//
// Its signature matches tls.Config.GetCertificate, so it can be assigned to
// that field directly. The ClientHelloInfo is ignored: every handshake is
// served the same certificate. The returned error is always nil.
//
// If no certificate has been loaded (only possible for a CertReloader not
// created via NewCertReloader), nil is returned, which crypto/tls treats as
// "no dynamic certificate" and falls back to the config's static certificates.
func (cr *CertReloader) Cert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	cr.Lock()
	defer cr.Unlock()
	return cr.certificate, nil
}

// ClientCert returns the TLS certificate most recently loaded by LoadCert.
//
// Its signature matches tls.Config.GetClientCertificate, so it can be assigned
// to that field directly. The CertificateRequestInfo is ignored. The returned
// error is always nil.
//
// crypto/tls requires GetClientCertificate to return a non-nil certificate, so
// when nothing has been loaded yet (e.g. a zero-value CertReloader) an empty
// tls.Certificate is returned, which makes the client send no certificate
// instead of violating that contract.
func (cr *CertReloader) ClientCert(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
	cr.Lock()
	defer cr.Unlock()
	if cr.certificate == nil {
		return &tls.Certificate{}, nil
	}
	return cr.certificate, nil
}

// LoadCert reads the certificate and key from the CertReloader's paths and, on
// success, makes them the certificate returned by Cert and ClientCert. Call it
// after writing a new certificate to disk (e.g. after renewing a certificate).
//
// On failure the previously loaded certificate is kept, the failure is
// reported to Sentry, and the error from tls.LoadX509KeyPair is returned
// unchanged. Concurrent calls are last-writer-wins.
func (cr *CertReloader) LoadCert() error {
	// Read and parse outside the lock: this touches the filesystem, and Cert /
	// ClientCert take the same mutex on every handshake. certPath and keyPath
	// are only assigned at construction in the code visible here
	// (NewCertReloader) — NOTE(review): keep it that way or guard them too.
	cert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
	if err != nil {
		// Keep the old certificate if there's a problem reading the new one.
		// %w preserves the cause so the error chain stays inspectable.
		sentry.CaptureException(fmt.Errorf("error parsing X509 key pair: %w", err))
		return err
	}

	cr.Lock()
	cr.certificate = &cert
	cr.Unlock()
	return nil
}

// LoadOriginCA returns the certificate pool used to verify origin servers.
//
// The pool comes from loadOriginCertPool. It holds the system roots (or an
// empty pool if those are unavailable), the Cloudflare root CAs and the Hello
// server certificate. When originCAPoolFilename is non-empty, it also holds the
// PEM certificates read from that file. An unreadable file is returned as an
// error; a file with no usable certificates is only logged.
//
// On Windows, when no file is given, an informational message suggests the
// OriginCAPoolFlag CLI flag.
func LoadOriginCA(originCAPoolFilename string, log *zerolog.Logger) (*x509.CertPool, error) {
	var originCustomCAPool []byte

	if originCAPoolFilename != "" {
		var err error
		originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename)
		if err != nil {
			return nil, errors.Wrapf(err, "unable to read the file %s for --%s", originCAPoolFilename, OriginCAPoolFlag)
		}
	}

	originCertPool, err := loadOriginCertPool(originCustomCAPool, log)
	if err != nil {
		return nil, errors.Wrap(err, "error loading the certificate pool")
	}

	// Windows users should be notified that they can use the flag
	if runtime.GOOS == "windows" && originCAPoolFilename == "" {
		log.Info().Msgf("cloudflared does not support loading the system root certificate pool on Windows. Please use --%s <PATH> to specify the path to the certificate pool", OriginCAPoolFlag)
	}

	return originCertPool, nil
}

// LoadCustomOriginCA returns a certificate pool holding the system roots and
// the Cloudflare root CAs. When originCAFilename is non-empty, the pool also
// holds the PEM certificates read from that file.
//
// If the system pool cannot be obtained, an empty pool is used instead and the
// failure is not reported. An unreadable file, or one from which no
// certificate could be parsed, is returned as an error.
func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) {
	// Start from the system pool; fall back to an empty one so the Cloudflare
	// roots and the custom CA can still be added.
	certPool, err := x509.SystemCertPool()
	if err != nil {
		certPool = x509.NewCertPool()
	}

	// Next, append the Cloudflare CAs into the pool
	cfRootCA, err := GetCloudflareRootCA()
	if err != nil {
		return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
	}
	for _, cert := range cfRootCA {
		certPool.AddCert(cert)
	}

	if originCAFilename == "" {
		return certPool, nil
	}

	customOriginCA, err := ioutil.ReadFile(originCAFilename)
	if err != nil {
		return nil, errors.Wrapf(err, "unable to read the file %s", originCAFilename)
	}

	// AppendCertsFromPEM reports false when no certificate could be parsed.
	if !certPool.AppendCertsFromPEM(customOriginCA) {
		return nil, fmt.Errorf("error appending custom CA to cert pool")
	}
	return certPool, nil
}

func CreateTunnelConfig(c *cli.Context, serverName string) (*tls.Config, error) {
	var rootCAs []string
	if c.String(CaCertFlag) != "" {
		rootCAs = append(rootCAs, c.String(CaCertFlag))
	}

	userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName}
	tlsConfig, err := GetConfig(userConfig)
	if err != nil {
		return nil, err
	}

	if tlsConfig.RootCAs == nil {
		rootCAPool, err := x509.SystemCertPool()
		if err != nil {
			return nil, errors.Wrap(err, "unable to get x509 system cert pool")
		}
		cfRootCA, err := GetCloudflareRootCA()
		if err != nil {
			return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
		}
		for _, cert := range cfRootCA {
			rootCAPool.AddCert(cert)
		}
		tlsConfig.RootCAs = rootCAPool
	}

	if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
		return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config")
	}
	return tlsConfig, nil
}

// loadOriginCertPool starts from loadGlobalCertPool and, when originCAPoolPEM
// is non-nil, appends the PEM certificates it contains. A PEM blob that yields
// no certificates is logged at info level rather than treated as an error.
func loadOriginCertPool(originCAPoolPEM []byte, log *zerolog.Logger) (*x509.CertPool, error) {
	pool, err := loadGlobalCertPool(log)
	if err != nil {
		return nil, err
	}

	// nil means the caller supplied no custom pool at all.
	if originCAPoolPEM == nil {
		return pool, nil
	}

	if appended := pool.AppendCertsFromPEM(originCAPoolPEM); !appended {
		log.Info().Msg("could not append the provided origin CA to the cloudflared certificate pool")
	}
	return pool, nil
}

// loadGlobalCertPool returns the base pool used by loadOriginCertPool. It
// contains the system roots (or an empty pool if they cannot be loaded), the
// Cloudflare root CAs, and the self-signed Hello server certificate.
func loadGlobalCertPool(log *zerolog.Logger) (*x509.CertPool, error) {
	pool, sysErr := x509.SystemCertPool()
	if sysErr != nil {
		// Not logged on Windows; see https://github.com/golang/go/issues/16736
		if runtime.GOOS != "windows" {
			log.Err(sysErr).Msg("error obtaining the system certificates")
		}
		pool = x509.NewCertPool()
	}

	cfRoots, err := GetCloudflareRootCA()
	if err != nil {
		return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
	}
	for _, root := range cfRoots {
		pool.AddCert(root)
	}

	// The Hello server certificate is self-signed, so it must be trusted explicitly.
	hello, err := GetHelloCertificateX509()
	if err != nil {
		return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool")
	}
	pool.AddCert(hello)

	return pool, nil
}