diff --git a/tlsconfig/certreloader.go b/tlsconfig/certreloader.go new file mode 100644 index 00000000..e9f36e9c --- /dev/null +++ b/tlsconfig/certreloader.go @@ -0,0 +1,62 @@ +package tlsconfig + +import ( + "crypto/tls" + "errors" + "fmt" + "sync" + + tunnellog "github.com/cloudflare/cloudflared/log" + "github.com/getsentry/raven-go" + log "github.com/sirupsen/logrus" + "gopkg.in/urfave/cli.v2" +) + +// CertReloader can load and reload a TLS certificate from a particular filepath. +// Hooks into tls.Config's GetCertificate to allow a TLS server to update its certificate without restarting. +type CertReloader struct { + sync.Mutex + certificate *tls.Certificate + certPath string + keyPath string +} + +// NewCertReloader makes a CertReloader, memorizing the filepaths in the context/flags. +func NewCertReloader(c *cli.Context, f CLIFlags) (*CertReloader, error) { + if !c.IsSet(f.Cert) { + return nil, errors.New("CertReloader: cert not provided") + } + if !c.IsSet(f.Key) { + return nil, errors.New("CertReloader: key not provided") + } + cr := new(CertReloader) + cr.certPath = c.String(f.Cert) + cr.keyPath = c.String(f.Key) + cr.LoadCert() + return cr, nil +} + +// Cert returns the TLS certificate most recently read by the CertReloader. +func (cr *CertReloader) Cert(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + cr.Lock() + defer cr.Unlock() + return cr.certificate, nil +} + +// LoadCert loads a TLS certificate from the CertReloader's specified filepath. +// Call this after writing a new certificate to the disk (e.g. after renewing a certificate) +func (cr *CertReloader) LoadCert() { + cr.Lock() + defer cr.Unlock() + + log.SetFormatter(&tunnellog.JSONFormatter{}) + log.Info("Reloading certificate") + cert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath) + + // Keep the old certificate if there's a problem reading the new one. + if err != nil { + raven.CaptureErrorAndWait(fmt.Errorf("Error parsing X509 key pair: %v", err), nil) + return + } + cr.certificate = &cert +} diff --git a/tlsconfig/tlsconfig.go b/tlsconfig/tlsconfig.go index 2ecda972..61d87f87 100644 --- a/tlsconfig/tlsconfig.go +++ b/tlsconfig/tlsconfig.go @@ -37,6 +37,18 @@ func (f CLIFlags) GetConfig(c *cli.Context) *tls.Config { config.Certificates = []tls.Certificate{cert} config.BuildNameToCertificate() } + return f.finishGettingConfig(c, config) +} + +func (f CLIFlags) GetConfigReloadableCert(c *cli.Context, cr *CertReloader) *tls.Config { + config := &tls.Config{ + GetCertificate: cr.Cert, + } + config.BuildNameToCertificate() + return f.finishGettingConfig(c, config) +} + +func (f CLIFlags) finishGettingConfig(c *cli.Context, config *tls.Config) *tls.Config { if c.IsSet(f.ClientCert) { // set of root certificate authorities that servers use if required to verify a client certificate // by the policy in ClientAuth diff --git a/tunnelrpc/pogs/tunnelrpc.go b/tunnelrpc/pogs/tunnelrpc.go index 8d297a2f..a2ba5fb2 100644 --- a/tunnelrpc/pogs/tunnelrpc.go +++ b/tunnelrpc/pogs/tunnelrpc.go @@ -66,8 +66,8 @@ func UnmarshalRegistrationOptions(s tunnelrpc.RegistrationOptions) (*Registratio } type Tag struct { - Name string - Value string + Name string `json:"name"` + Value string `json:"value"` } type ServerInfo struct {