parent
5feba7e3a9
commit
bdd70e798a
@ -1,7 +0,0 @@
|
||||
package http
|
||||
|
||||
import "net/http"
|
||||
|
||||
type Client interface {
|
||||
Do(*http.Request) (*http.Response, error)
|
||||
}
|
@ -1,2 +0,0 @@
|
||||
// Package http is DEPRECATED. Use net/http instead.
|
||||
package http
|
@ -1,161 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func WriteError(w http.ResponseWriter, code int, msg string) {
|
||||
e := struct {
|
||||
Error string `json:"error"`
|
||||
}{
|
||||
Error: msg,
|
||||
}
|
||||
b, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
log.Printf("go-oidc: failed to marshal %#v: %v", e, err)
|
||||
code = http.StatusInternalServerError
|
||||
b = []byte(`{"error":"server_error"}`)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
// BasicAuth parses a username and password from the request's
|
||||
// Authorization header. This was pulled from golang master:
|
||||
// https://codereview.appspot.com/76540043
|
||||
func BasicAuth(r *http.Request) (username, password string, ok bool) {
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(auth, "Basic ") {
|
||||
return
|
||||
}
|
||||
c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic "))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cs := string(c)
|
||||
s := strings.IndexByte(cs, ':')
|
||||
if s < 0 {
|
||||
return
|
||||
}
|
||||
return cs[:s], cs[s+1:], true
|
||||
}
|
||||
|
||||
func cacheControlMaxAge(hdr string) (time.Duration, bool, error) {
|
||||
for _, field := range strings.Split(hdr, ",") {
|
||||
parts := strings.SplitN(strings.TrimSpace(field), "=", 2)
|
||||
k := strings.ToLower(strings.TrimSpace(parts[0]))
|
||||
if k != "max-age" {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(parts) == 1 {
|
||||
return 0, false, errors.New("max-age has no value")
|
||||
}
|
||||
|
||||
v := strings.TrimSpace(parts[1])
|
||||
if v == "" {
|
||||
return 0, false, errors.New("max-age has empty value")
|
||||
}
|
||||
|
||||
age, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
if age <= 0 {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
return time.Duration(age) * time.Second, true, nil
|
||||
}
|
||||
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
func expires(date, expires string) (time.Duration, bool, error) {
|
||||
if date == "" || expires == "" {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
var te time.Time
|
||||
var err error
|
||||
if expires == "0" {
|
||||
return 0, false, nil
|
||||
}
|
||||
te, err = time.Parse(time.RFC1123, expires)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
td, err := time.Parse(time.RFC1123, date)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
ttl := te.Sub(td)
|
||||
|
||||
// headers indicate data already expired, caller should not
|
||||
// have to care about this case
|
||||
if ttl <= 0 {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
return ttl, true, nil
|
||||
}
|
||||
|
||||
func Cacheable(hdr http.Header) (time.Duration, bool, error) {
|
||||
ttl, ok, err := cacheControlMaxAge(hdr.Get("Cache-Control"))
|
||||
if err != nil || ok {
|
||||
return ttl, ok, err
|
||||
}
|
||||
|
||||
return expires(hdr.Get("Date"), hdr.Get("Expires"))
|
||||
}
|
||||
|
||||
// MergeQuery appends additional query values to an existing URL.
|
||||
func MergeQuery(u url.URL, q url.Values) url.URL {
|
||||
uv := u.Query()
|
||||
for k, vs := range q {
|
||||
for _, v := range vs {
|
||||
uv.Add(k, v)
|
||||
}
|
||||
}
|
||||
u.RawQuery = uv.Encode()
|
||||
return u
|
||||
}
|
||||
|
||||
// NewResourceLocation appends a resource id to the end of the requested URL path.
|
||||
func NewResourceLocation(reqURL *url.URL, id string) string {
|
||||
var u url.URL
|
||||
u = *reqURL
|
||||
u.Path = path.Join(u.Path, id)
|
||||
u.RawQuery = ""
|
||||
u.Fragment = ""
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// CopyRequest returns a clone of the provided *http.Request.
|
||||
// The returned object is a shallow copy of the struct and a
|
||||
// deep copy of its Header field.
|
||||
func CopyRequest(r *http.Request) *http.Request {
|
||||
r2 := *r
|
||||
r2.Header = make(http.Header)
|
||||
for k, s := range r.Header {
|
||||
r2.Header[k] = s
|
||||
}
|
||||
return &r2
|
||||
}
|
@ -1,29 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// ParseNonEmptyURL checks that a string is a parsable URL which is also not empty
|
||||
// since `url.Parse("")` does not return an error. Must contian a scheme and a host.
|
||||
func ParseNonEmptyURL(u string) (*url.URL, error) {
|
||||
if u == "" {
|
||||
return nil, errors.New("url is empty")
|
||||
}
|
||||
|
||||
ur, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ur.Scheme == "" {
|
||||
return nil, errors.New("url scheme is empty")
|
||||
}
|
||||
|
||||
if ur.Host == "" {
|
||||
return nil, errors.New("url host is empty")
|
||||
}
|
||||
|
||||
return ur, nil
|
||||
}
|
@ -1,2 +0,0 @@
|
||||
// Package key is DEPRECATED. Use github.com/coreos/go-oidc instead.
|
||||
package key
|
@ -1,153 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
func NewPublicKey(jwk jose.JWK) *PublicKey {
|
||||
return &PublicKey{jwk: jwk}
|
||||
}
|
||||
|
||||
type PublicKey struct {
|
||||
jwk jose.JWK
|
||||
}
|
||||
|
||||
func (k *PublicKey) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(&k.jwk)
|
||||
}
|
||||
|
||||
func (k *PublicKey) UnmarshalJSON(data []byte) error {
|
||||
var jwk jose.JWK
|
||||
if err := json.Unmarshal(data, &jwk); err != nil {
|
||||
return err
|
||||
}
|
||||
k.jwk = jwk
|
||||
return nil
|
||||
}
|
||||
|
||||
func (k *PublicKey) ID() string {
|
||||
return k.jwk.ID
|
||||
}
|
||||
|
||||
func (k *PublicKey) Verifier() (jose.Verifier, error) {
|
||||
return jose.NewVerifierRSA(k.jwk)
|
||||
}
|
||||
|
||||
type PrivateKey struct {
|
||||
KeyID string
|
||||
PrivateKey *rsa.PrivateKey
|
||||
}
|
||||
|
||||
func (k *PrivateKey) ID() string {
|
||||
return k.KeyID
|
||||
}
|
||||
|
||||
func (k *PrivateKey) Signer() jose.Signer {
|
||||
return jose.NewSignerRSA(k.ID(), *k.PrivateKey)
|
||||
}
|
||||
|
||||
func (k *PrivateKey) JWK() jose.JWK {
|
||||
return jose.JWK{
|
||||
ID: k.KeyID,
|
||||
Type: "RSA",
|
||||
Alg: "RS256",
|
||||
Use: "sig",
|
||||
Exponent: k.PrivateKey.PublicKey.E,
|
||||
Modulus: k.PrivateKey.PublicKey.N,
|
||||
}
|
||||
}
|
||||
|
||||
type KeySet interface {
|
||||
ExpiresAt() time.Time
|
||||
}
|
||||
|
||||
type PublicKeySet struct {
|
||||
keys []PublicKey
|
||||
index map[string]*PublicKey
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func NewPublicKeySet(jwks []jose.JWK, exp time.Time) *PublicKeySet {
|
||||
keys := make([]PublicKey, len(jwks))
|
||||
index := make(map[string]*PublicKey)
|
||||
for i, jwk := range jwks {
|
||||
keys[i] = *NewPublicKey(jwk)
|
||||
index[keys[i].ID()] = &keys[i]
|
||||
}
|
||||
return &PublicKeySet{
|
||||
keys: keys,
|
||||
index: index,
|
||||
expiresAt: exp,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PublicKeySet) ExpiresAt() time.Time {
|
||||
return s.expiresAt
|
||||
}
|
||||
|
||||
func (s *PublicKeySet) Keys() []PublicKey {
|
||||
return s.keys
|
||||
}
|
||||
|
||||
func (s *PublicKeySet) Key(id string) *PublicKey {
|
||||
return s.index[id]
|
||||
}
|
||||
|
||||
type PrivateKeySet struct {
|
||||
keys []*PrivateKey
|
||||
ActiveKeyID string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func NewPrivateKeySet(keys []*PrivateKey, exp time.Time) *PrivateKeySet {
|
||||
return &PrivateKeySet{
|
||||
keys: keys,
|
||||
ActiveKeyID: keys[0].ID(),
|
||||
expiresAt: exp.UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PrivateKeySet) Keys() []*PrivateKey {
|
||||
return s.keys
|
||||
}
|
||||
|
||||
func (s *PrivateKeySet) ExpiresAt() time.Time {
|
||||
return s.expiresAt
|
||||
}
|
||||
|
||||
func (s *PrivateKeySet) Active() *PrivateKey {
|
||||
for i, k := range s.keys {
|
||||
if k.ID() == s.ActiveKeyID {
|
||||
return s.keys[i]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type GeneratePrivateKeyFunc func() (*PrivateKey, error)
|
||||
|
||||
func GeneratePrivateKey() (*PrivateKey, error) {
|
||||
pk, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyID := make([]byte, 20)
|
||||
if _, err := io.ReadFull(rand.Reader, keyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
k := PrivateKey{
|
||||
KeyID: hex.EncodeToString(keyID),
|
||||
PrivateKey: pk,
|
||||
}
|
||||
|
||||
return &k, nil
|
||||
}
|
@ -1,99 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/pkg/health"
|
||||
)
|
||||
|
||||
type PrivateKeyManager interface {
|
||||
ExpiresAt() time.Time
|
||||
Signer() (jose.Signer, error)
|
||||
JWKs() ([]jose.JWK, error)
|
||||
PublicKeys() ([]PublicKey, error)
|
||||
|
||||
WritableKeySetRepo
|
||||
health.Checkable
|
||||
}
|
||||
|
||||
func NewPrivateKeyManager() PrivateKeyManager {
|
||||
return &privateKeyManager{
|
||||
clock: clockwork.NewRealClock(),
|
||||
}
|
||||
}
|
||||
|
||||
type privateKeyManager struct {
|
||||
keySet *PrivateKeySet
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func (m *privateKeyManager) ExpiresAt() time.Time {
|
||||
if m.keySet == nil {
|
||||
return m.clock.Now().UTC()
|
||||
}
|
||||
|
||||
return m.keySet.ExpiresAt()
|
||||
}
|
||||
|
||||
func (m *privateKeyManager) Signer() (jose.Signer, error) {
|
||||
if err := m.Healthy(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.keySet.Active().Signer(), nil
|
||||
}
|
||||
|
||||
func (m *privateKeyManager) JWKs() ([]jose.JWK, error) {
|
||||
if err := m.Healthy(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := m.keySet.Keys()
|
||||
jwks := make([]jose.JWK, len(keys))
|
||||
for i, k := range keys {
|
||||
jwks[i] = k.JWK()
|
||||
}
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
func (m *privateKeyManager) PublicKeys() ([]PublicKey, error) {
|
||||
jwks, err := m.JWKs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys := make([]PublicKey, len(jwks))
|
||||
for i, jwk := range jwks {
|
||||
keys[i] = *NewPublicKey(jwk)
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (m *privateKeyManager) Healthy() error {
|
||||
if m.keySet == nil {
|
||||
return errors.New("private key manager uninitialized")
|
||||
}
|
||||
|
||||
if len(m.keySet.Keys()) == 0 {
|
||||
return errors.New("private key manager zero keys")
|
||||
}
|
||||
|
||||
if m.keySet.ExpiresAt().Before(m.clock.Now().UTC()) {
|
||||
return errors.New("private key manager keys expired")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *privateKeyManager) Set(keySet KeySet) error {
|
||||
privKeySet, ok := keySet.(*PrivateKeySet)
|
||||
if !ok {
|
||||
return errors.New("unable to cast to PrivateKeySet")
|
||||
}
|
||||
|
||||
m.keySet = privKeySet
|
||||
return nil
|
||||
}
|
@ -1,55 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var ErrorNoKeys = errors.New("no keys found")
|
||||
|
||||
type WritableKeySetRepo interface {
|
||||
Set(KeySet) error
|
||||
}
|
||||
|
||||
type ReadableKeySetRepo interface {
|
||||
Get() (KeySet, error)
|
||||
}
|
||||
|
||||
type PrivateKeySetRepo interface {
|
||||
WritableKeySetRepo
|
||||
ReadableKeySetRepo
|
||||
}
|
||||
|
||||
func NewPrivateKeySetRepo() PrivateKeySetRepo {
|
||||
return &memPrivateKeySetRepo{}
|
||||
}
|
||||
|
||||
type memPrivateKeySetRepo struct {
|
||||
mu sync.RWMutex
|
||||
pks PrivateKeySet
|
||||
}
|
||||
|
||||
func (r *memPrivateKeySetRepo) Set(ks KeySet) error {
|
||||
pks, ok := ks.(*PrivateKeySet)
|
||||
if !ok {
|
||||
return errors.New("unable to cast to PrivateKeySet")
|
||||
} else if pks == nil {
|
||||
return errors.New("nil KeySet")
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.pks = *pks
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memPrivateKeySetRepo) Get() (KeySet, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if r.pks.keys == nil {
|
||||
return nil, ErrorNoKeys
|
||||
}
|
||||
return KeySet(&r.pks), nil
|
||||
}
|
@ -1,159 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
ptime "github.com/coreos/pkg/timeutil"
|
||||
"github.com/jonboulle/clockwork"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorPrivateKeysExpired = errors.New("private keys have expired")
|
||||
)
|
||||
|
||||
func NewPrivateKeyRotator(repo PrivateKeySetRepo, ttl time.Duration) *PrivateKeyRotator {
|
||||
return &PrivateKeyRotator{
|
||||
repo: repo,
|
||||
ttl: ttl,
|
||||
|
||||
keep: 2,
|
||||
generateKey: GeneratePrivateKey,
|
||||
clock: clockwork.NewRealClock(),
|
||||
}
|
||||
}
|
||||
|
||||
type PrivateKeyRotator struct {
|
||||
repo PrivateKeySetRepo
|
||||
generateKey GeneratePrivateKeyFunc
|
||||
clock clockwork.Clock
|
||||
keep int
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func (r *PrivateKeyRotator) expiresAt() time.Time {
|
||||
return r.clock.Now().UTC().Add(r.ttl)
|
||||
}
|
||||
|
||||
func (r *PrivateKeyRotator) Healthy() error {
|
||||
pks, err := r.privateKeySet()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.clock.Now().After(pks.ExpiresAt()) {
|
||||
return ErrorPrivateKeysExpired
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *PrivateKeyRotator) privateKeySet() (*PrivateKeySet, error) {
|
||||
ks, err := r.repo.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pks, ok := ks.(*PrivateKeySet)
|
||||
if !ok {
|
||||
return nil, errors.New("unable to cast to PrivateKeySet")
|
||||
}
|
||||
return pks, nil
|
||||
}
|
||||
|
||||
func (r *PrivateKeyRotator) nextRotation() (time.Duration, error) {
|
||||
pks, err := r.privateKeySet()
|
||||
if err == ErrorNoKeys {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
now := r.clock.Now()
|
||||
|
||||
// Ideally, we want to rotate after half the TTL has elapsed.
|
||||
idealRotationTime := pks.ExpiresAt().Add(-r.ttl / 2)
|
||||
|
||||
// If we are past the ideal rotation time, rotate immediatly.
|
||||
return max(0, idealRotationTime.Sub(now)), nil
|
||||
}
|
||||
|
||||
func max(a, b time.Duration) time.Duration {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *PrivateKeyRotator) Run() chan struct{} {
|
||||
attempt := func() {
|
||||
k, err := r.generateKey()
|
||||
if err != nil {
|
||||
log.Printf("go-oidc: failed generating signing key: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
exp := r.expiresAt()
|
||||
if err := rotatePrivateKeys(r.repo, k, r.keep, exp); err != nil {
|
||||
log.Printf("go-oidc: key rotation failed: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
stop := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
var nextRotation time.Duration
|
||||
var sleep time.Duration
|
||||
var err error
|
||||
for {
|
||||
if nextRotation, err = r.nextRotation(); err == nil {
|
||||
break
|
||||
}
|
||||
sleep = ptime.ExpBackoff(sleep, time.Minute)
|
||||
log.Printf("go-oidc: error getting nextRotation, retrying in %v: %v", sleep, err)
|
||||
time.Sleep(sleep)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-r.clock.After(nextRotation):
|
||||
attempt()
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return stop
|
||||
}
|
||||
|
||||
func rotatePrivateKeys(repo PrivateKeySetRepo, k *PrivateKey, keep int, exp time.Time) error {
|
||||
ks, err := repo.Get()
|
||||
if err != nil && err != ErrorNoKeys {
|
||||
return err
|
||||
}
|
||||
|
||||
var keys []*PrivateKey
|
||||
if ks != nil {
|
||||
pks, ok := ks.(*PrivateKeySet)
|
||||
if !ok {
|
||||
return errors.New("unable to cast to PrivateKeySet")
|
||||
}
|
||||
keys = pks.Keys()
|
||||
}
|
||||
|
||||
keys = append([]*PrivateKey{k}, keys...)
|
||||
if l := len(keys); l > keep {
|
||||
keys = keys[0:keep]
|
||||
}
|
||||
|
||||
nks := PrivateKeySet{
|
||||
keys: keys,
|
||||
ActiveKeyID: k.ID(),
|
||||
expiresAt: exp,
|
||||
}
|
||||
|
||||
return repo.Set(KeySet(&nks))
|
||||
}
|
@ -1,91 +0,0 @@
|
||||