package key

import (
	"errors"
	"log"
	"time"

	ptime "github.com/coreos/pkg/timeutil"
	"github.com/jonboulle/clockwork"
)

var (
	// ErrorPrivateKeysExpired is returned by PrivateKeyRotator.Healthy when the
	// expiry time of the stored private key set has passed.
	ErrorPrivateKeysExpired = errors.New("private keys have expired")
)

// NewPrivateKeyRotator returns a PrivateKeyRotator that reads and writes key
// sets through repo and gives each newly stored set an expiry of ttl from the
// moment it is rotated in. The rotator retains at most two keys (the new
// active key and the one before it), creates keys with GeneratePrivateKey and
// reads time from the real wall clock.
func NewPrivateKeyRotator(repo PrivateKeySetRepo, ttl time.Duration) *PrivateKeyRotator {
	return &PrivateKeyRotator{
		repo: repo,
		ttl:  ttl,

		keep:        2,
		generateKey: GeneratePrivateKey,
		clock:       clockwork.NewRealClock(),
	}
}

// PrivateKeyRotator replaces the active private key held in a
// PrivateKeySetRepo on a schedule driven by the key set's TTL. Construct it
// with NewPrivateKeyRotator; the zero value has no repo and is not usable.
type PrivateKeyRotator struct {
	repo        PrivateKeySetRepo      // source and destination of the key set
	generateKey GeneratePrivateKeyFunc // creates each new key
	clock       clockwork.Clock        // time source (an interface, so it can be swapped out)
	keep        int                    // maximum number of keys kept in a stored set
	ttl         time.Duration          // lifetime given to each newly stored set
}

// expiresAt returns the expiry time, in UTC, for a key set created now.
func (r *PrivateKeyRotator) expiresAt() time.Time {
	return r.clock.Now().UTC().Add(r.ttl)
}

// Healthy reports whether the stored key set is still usable. It returns any
// error encountered while reading the key set from the repo,
// ErrorPrivateKeysExpired if the set's expiry time has passed, and nil
// otherwise.
func (r *PrivateKeyRotator) Healthy() error {
	pks, err := r.privateKeySet()
	if err != nil {
		return err
	}

	if r.clock.Now().After(pks.ExpiresAt()) {
		return ErrorPrivateKeysExpired
	}

	return nil
}

// privateKeySet loads the current key set from the repo. The repo's error is
// returned unchanged, and an error is returned if the stored KeySet is not a
// *PrivateKeySet.
func (r *PrivateKeyRotator) privateKeySet() (*PrivateKeySet, error) {
	ks, err := r.repo.Get()
	if err != nil {
		return nil, err
	}

	pks, ok := ks.(*PrivateKeySet)
	if !ok {
		return nil, errors.New("unable to cast to PrivateKeySet")
	}
	return pks, nil
}

// nextRotation returns how long to wait before the next rotation. It returns
// zero, meaning rotate now, when the repo reports ErrorNoKeys or when the
// ideal rotation time has already passed. Any other error from reading the
// key set is returned.
func (r *PrivateKeyRotator) nextRotation() (time.Duration, error) {
	pks, err := r.privateKeySet()
	if err == ErrorNoKeys {
		return 0, nil
	}
	if err != nil {
		return 0, err
	}

	now := r.clock.Now()

	// Ideally, we want to rotate after half the TTL has elapsed.
	idealRotationTime := pks.ExpiresAt().Add(-r.ttl / 2)

	// If we are past the ideal rotation time, rotate immediately.
	return max(0, idealRotationTime.Sub(now)), nil
}

// max returns the larger of a and b.
func max(a, b time.Duration) time.Duration {
	if a > b {
		return a
	}
	return b
}

// Run starts a background goroutine that rotates the key set whenever
// nextRotation says a rotation is due. It returns a channel that controls the
// goroutine: closing it (or sending on it) makes the goroutine return.
//
// Errors from reading the current key set, and failed rotation attempts, are
// logged and retried with exponential backoff capped at one minute. Every wait
// uses r.clock and ends early when the stop channel fires.
func (r *PrivateKeyRotator) Run() chan struct{} {
	stop := make(chan struct{})

	// wait blocks for d on r.clock. It reports false if stop fired first, in
	// which case the caller should return.
	wait := func(d time.Duration) bool {
		select {
		case <-r.clock.After(d):
			return true
		case <-stop:
			return false
		}
	}

	// attempt generates a new key and stores it as the active key, logging and
	// returning any failure.
	attempt := func() error {
		k, err := r.generateKey()
		if err != nil {
			log.Printf("go-oidc: failed generating signing key: %v", err)
			return err
		}

		exp := r.expiresAt()
		if err := rotatePrivateKeys(r.repo, k, r.keep, exp); err != nil {
			log.Printf("go-oidc: key rotation failed: %v", err)
			return err
		}
		return nil
	}

	go func() {
		// failDelay grows while consecutive rotation attempts fail. Without it,
		// a persistent failure would spin: after the timer has fired the set is
		// still due for rotation, so nextRotation keeps returning zero.
		var failDelay time.Duration
		for {
			var nextRotation time.Duration
			var sleep time.Duration
			var err error
			for {
				if nextRotation, err = r.nextRotation(); err == nil {
					break
				}
				sleep = ptime.ExpBackoff(sleep, time.Minute)
				log.Printf("go-oidc: error getting nextRotation, retrying in %v: %v", sleep, err)
				// Honor stop while backing off, so the goroutine does not
				// leak when the repo stays unavailable.
				if !wait(sleep) {
					return
				}
			}

			if !wait(nextRotation) {
				return
			}

			if err = attempt(); err != nil {
				failDelay = ptime.ExpBackoff(failDelay, time.Minute)
				if !wait(failDelay) {
					return
				}
				continue
			}
			failDelay = 0
		}
	}()

	return stop
}

// rotatePrivateKeys stores a new key set in repo. The new set's first key is
// k, which also becomes the active key. It is followed by the keys already
// stored in the repo, truncated to at most keep keys in total, and it expires
// at exp. A repo that reports ErrorNoKeys is treated as empty; any other read
// error, or a stored KeySet that is not a *PrivateKeySet, is returned.
func rotatePrivateKeys(repo PrivateKeySetRepo, k *PrivateKey, keep int, exp time.Time) error {
	ks, err := repo.Get()
	if err != nil && err != ErrorNoKeys {
		return err
	}

	var keys []*PrivateKey
	if ks != nil {
		pks, ok := ks.(*PrivateKeySet)
		if !ok {
			return errors.New("unable to cast to PrivateKeySet")
		}
		keys = pks.Keys()
	}

	// Prepend the new key into a fresh slice, then drop everything past keep.
	keys = append([]*PrivateKey{k}, keys...)
	if len(keys) > keep {
		keys = keys[:keep]
	}

	nks := PrivateKeySet{
		keys:        keys,
		ActiveKeyID: k.ID(),
		expiresAt:   exp,
	}

	return repo.Set(KeySet(&nks))
}