diff --git a/tlsconfig/certreloader.go b/tlsconfig/certreloader.go index 7c0514b0..f2edce78 100644 --- a/tlsconfig/certreloader.go +++ b/tlsconfig/certreloader.go @@ -40,12 +40,21 @@ func NewCertReloader(certPath, keyPath string) (*CertReloader, error) { } // Cert returns the TLS certificate most recently read by the CertReloader. +// This method works as a direct utility method for tls.Config#Cert. func (cr *CertReloader) Cert(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { cr.Lock() defer cr.Unlock() return cr.certificate, nil } +// ClientCert returns the TLS certificate most recently read by the CertReloader. +// This method works as a direct utility method for tls.Config#ClientCert. +func (cr *CertReloader) ClientCert(certRequestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { + cr.Lock() + defer cr.Unlock() + return cr.certificate, nil +} + // LoadCert loads a TLS certificate from the CertReloader's specified filepath. // Call this after writing a new certificate to the disk (e.g. after renewing a certificate) func (cr *CertReloader) LoadCert() error { diff --git a/tlsconfig/tlsconfig.go b/tlsconfig/tlsconfig.go index fd255175..46490373 100644 --- a/tlsconfig/tlsconfig.go +++ b/tlsconfig/tlsconfig.go @@ -12,15 +12,16 @@ import ( // Config is the user provided parameters to create a tls.Config type TLSParameters struct { - Cert string - Key string - GetCertificate *CertReloader - ClientCAs []string - RootCAs []string - ServerName string - CurvePreferences []tls.CurveID - MinVersion uint16 // min tls version. If zero, TLS1.0 is defined as minimum. - MaxVersion uint16 // max tls version. If zero, last TLS version is used defined as limit (currently TLS1.3) + Cert string + Key string + GetCertificate *CertReloader + GetClientCertificate *CertReloader + ClientCAs []string + RootCAs []string + ServerName string + CurvePreferences []tls.CurveID + MinVersion uint16 // min tls version. If zero, TLS1.0 is defined as minimum. + MaxVersion uint16 // max tls version. If zero, last TLS version is used defined as limit (currently TLS1.3) } // GetConfig returns a TLS configuration according to the Config set by the user. @@ -43,6 +44,11 @@ func GetConfig(p *TLSParameters) (*tls.Config, error) { tlsconfig.GetCertificate = p.GetCertificate.Cert } + if p.GetClientCertificate != nil { + // GetClientCertificate is called when using an HTTP client library and mTLS is required. + tlsconfig.GetClientCertificate = p.GetClientCertificate.ClientCert + } + if len(p.ClientCAs) > 0 { // set of root certificate authorities that servers use if required to verify a client certificate // by the policy in ClientAuth