This repository has been archived on 2023-08-14. You can view files and clone it, but cannot push or open issues or pull requests.
dex/server/rotation.go

250 lines
7.3 KiB
Go

package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/hex"
"errors"
"fmt"
"io"
"time"
"gopkg.in/square/go-jose.v2"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage"
)
var (
	// errAlreadyRotated is returned from the UpdateKeys callback in rotate when
	// the stored keys are no longer due for rotation, meaning another server
	// instance rotated them after this instance read them.
	errAlreadyRotated = errors.New("keys already rotated by another server instance")
)
// rotationStrategy bundles everything key rotation needs to know: how to mint
// a new signing key, how often to replace it, and how long a replaced key is
// still published for signature verification.
type rotationStrategy struct {
	// rotationFrequency is the interval between rotations. idTokenValidFor is
	// how long a key stays around for validating signatures after it has been
	// rotated out.
	rotationFrequency, idTokenValidFor time.Duration

	// key mints a new signing key. Keys are always RSA: although cryptopasta
	// recommends ECDSA, not every client supports it (for example
	// github.com/coreos/go-oidc/oidc).
	key func() (*rsa.PrivateKey, error)
}
// staticRotationStrategy returns a strategy that always yields the given key
// and, for all practical purposes, never rotates it.
func staticRotationStrategy(key *rsa.PrivateKey) rotationStrategy {
	// A century stands in for "never": it is simpler than carrying a separate
	// no-rotation flag through the rotation code.
	const century = 100 * 365 * 24 * time.Hour

	fixedKey := func() (*rsa.PrivateKey, error) {
		return key, nil
	}

	return rotationStrategy{
		rotationFrequency: century,
		idTokenValidFor:   century,
		key:               fixedKey,
	}
}
// defaultRotationStrategy returns a strategy that mints a fresh 2048-bit RSA
// key every rotationFrequency and keeps each rotated-out public key available
// for idTokenValidFor.
func defaultRotationStrategy(rotationFrequency, idTokenValidFor time.Duration) rotationStrategy {
	var strategy rotationStrategy
	strategy.rotationFrequency = rotationFrequency
	strategy.idTokenValidFor = idTokenValidFor
	strategy.key = func() (*rsa.PrivateKey, error) {
		const bits = 2048
		return rsa.GenerateKey(rand.Reader, bits)
	}
	return strategy
}
// keyRotator rotates the signing keys kept in the embedded storage according
// to strategy. See rotate.
type keyRotator struct {
	// Storage is where the key set (signing key pair, verification keys and
	// NextRotation) is read with GetKeys and written with UpdateKeys.
	storage.Storage

	// strategy supplies new keys, the rotation interval, and how long a
	// rotated-out key is kept for verification.
	strategy rotationStrategy
	// now is the clock used both to decide whether rotation is due and to
	// compute new expiry and next-rotation times.
	now func() time.Time
	// logger receives rotation progress messages.
	logger log.Logger
}
// startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled.
//
// The method blocks until after the first attempt to rotate keys has completed. That way
// healthy storages will return from this call with valid keys.
//
// After that, a rotation attempt runs every 30 seconds. Each attempt is cheap
// unless the stored NextRotation time has passed. Losing the race to another
// server instance (errAlreadyRotated) is expected when several instances share
// a storage. It is therefore logged as informational on every attempt, not only
// the first, and not as an error.
func (s *Server) startKeyRotation(ctx context.Context, strategy rotationStrategy, now func() time.Time) {
	rotator := keyRotator{
		Storage:  s.storage,
		strategy: strategy,
		now:      now,
		logger:   s.logger,
	}

	// tryRotate makes one rotation attempt and logs its outcome.
	tryRotate := func() {
		err := rotator.rotate()
		switch {
		case err == nil:
			// Rotated, or rotation not yet due.
		case errors.Is(err, errAlreadyRotated):
			// errors.Is also matches the sentinel if the storage wrapped the
			// error returned from the UpdateKeys callback.
			s.logger.Infof("Key rotation not needed: %v", err)
		default:
			s.logger.Errorf("failed to rotate keys: %v", err)
		}
	}

	// Try to rotate immediately so properly configured storages will have keys.
	tryRotate()

	go func() {
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				tryRotate()
			}
		}
	}()
}
// rotate replaces the signing key if the stored NextRotation time has passed.
//
// When rotation is due, it first generates a new RSA key (via the strategy)
// and a random 20-byte hex key ID, outside any storage transaction. Then, in
// a single UpdateKeys call, it:
//   - drops verification keys whose Expiry has passed,
//   - demotes the current public signing key to a verification key that
//     expires idTokenValidFor from now (the private part is discarded),
//   - installs the new key pair and sets NextRotation to rotationFrequency
//     from now.
//
// It returns nil when rotation is not yet due or has succeeded. If another
// instance rotated between GetKeys and UpdateKeys, it returns
// errAlreadyRotated, possibly wrapped by the storage implementation. Any
// other storage, key-generation or randomness failure is returned as an error.
func (k keyRotator) rotate() error {
	keys, err := k.GetKeys()
	if err != nil && !errors.Is(err, storage.ErrNotFound) {
		return fmt.Errorf("get keys: %w", err)
	}
	// With no stored keys (ErrNotFound), keys is presumably the zero value.
	// Its NextRotation is then the zero time, so rotation proceeds below.
	if k.now().Before(keys.NextRotation) {
		return nil
	}

	k.logger.Infof("keys expired, rotating")

	// Generate the key outside of a storage transaction.
	key, err := k.strategy.key()
	if err != nil {
		return fmt.Errorf("generate key: %w", err)
	}

	b := make([]byte, 20)
	if _, err := io.ReadFull(rand.Reader, b); err != nil {
		// Report an RNG failure to the caller instead of crashing the server.
		return fmt.Errorf("generate key ID: %w", err)
	}
	keyID := hex.EncodeToString(b)

	priv := &jose.JSONWebKey{
		Key:       key,
		KeyID:     keyID,
		Algorithm: "RS256",
		Use:       "sig",
	}
	pub := &jose.JSONWebKey{
		Key:       key.Public(),
		KeyID:     keyID,
		Algorithm: "RS256",
		Use:       "sig",
	}

	var nextRotation time.Time
	err = k.Storage.UpdateKeys(func(current storage.Keys) (storage.Keys, error) {
		// Take one timestamp so expiry and next rotation use the same instant.
		tNow := k.now()

		// If you are running multiple instances of dex, another instance
		// could have already rotated the keys.
		if tNow.Before(current.NextRotation) {
			return storage.Keys{}, errAlreadyRotated
		}

		// Remove any verification keys that have expired (filtered in place).
		kept := current.VerificationKeys[:0]
		for _, vk := range current.VerificationKeys {
			if !tNow.After(vk.Expiry) {
				kept = append(kept, vk)
			}
		}
		current.VerificationKeys = kept

		if current.SigningKeyPub != nil {
			// Move current signing key to a verification only key, throwing
			// away the private part. Keep it for at least the amount of time an
			// ID Token is valid for, so it won't expire before the ID Tokens it
			// signed.
			current.VerificationKeys = append(current.VerificationKeys, storage.VerificationKey{
				PublicKey: current.SigningKeyPub,
				Expiry:    tNow.Add(k.strategy.idTokenValidFor),
			})
		}

		nextRotation = tNow.Add(k.strategy.rotationFrequency)
		current.SigningKey = priv
		current.SigningKeyPub = pub
		current.NextRotation = nextRotation
		return current, nil
	})
	if err != nil {
		return err
	}

	k.logger.Infof("keys rotated, next rotation: %s", nextRotation)
	return nil
}
// RefreshTokenPolicy holds the refresh token rotation and lifetime settings
// and answers per-token expiry and reuse questions. Each of the three
// durations disables its check when left at zero.
type RefreshTokenPolicy struct {
	// rotateRefreshTokens enables refresh token rotation.
	rotateRefreshTokens bool

	// absoluteLifetime is the interval from token creation to the end of its life.
	absoluteLifetime time.Duration
	// validIfNotUsedFor is the interval from the last token update to the end of its life.
	validIfNotUsedFor time.Duration
	// reuseInterval is the window within which the old refresh token may still be reused.
	reuseInterval time.Duration

	// now is the clock used for all of the checks above.
	now func() time.Time
	// logger receives configuration messages.
	logger log.Logger
}
// NewRefreshTokenPolicy builds a RefreshTokenPolicy from configuration values.
//
// disableRotation turns refresh token rotation off when true. The resulting
// policy's RotationEnabled reports !disableRotation.
//
// validIfNotUsedFor, absoluteLifetime and reuseInterval are duration strings
// in time.ParseDuration syntax (for example "720h"). An empty string leaves
// the corresponding duration at zero, which disables that check. A value
// that does not parse produces an error naming the offending setting.
func NewRefreshTokenPolicy(logger log.Logger, disableRotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) {
	r := RefreshTokenPolicy{now: time.Now, logger: logger}
	var err error

	if validIfNotUsedFor != "" {
		r.validIfNotUsedFor, err = time.ParseDuration(validIfNotUsedFor)
		if err != nil {
			return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %v", validIfNotUsedFor, err)
		}
		logger.Infof("config refresh tokens valid if not used for: %v", validIfNotUsedFor)
	}

	if absoluteLifetime != "" {
		r.absoluteLifetime, err = time.ParseDuration(absoluteLifetime)
		if err != nil {
			return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %v", absoluteLifetime, err)
		}
		logger.Infof("config refresh tokens absolute lifetime: %v", absoluteLifetime)
	}

	if reuseInterval != "" {
		r.reuseInterval, err = time.ParseDuration(reuseInterval)
		if err != nil {
			return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %v", reuseInterval, err)
		}
		logger.Infof("config refresh tokens reuse interval: %v", reuseInterval)
	}

	// The flag expresses "disable", so rotation is on when it is false.
	r.rotateRefreshTokens = !disableRotation
	logger.Infof("config refresh tokens rotation enabled: %v", r.rotateRefreshTokens)
	return &r, nil
}
// RotationEnabled reports whether refresh token rotation is enabled for this policy.
func (r *RefreshTokenPolicy) RotationEnabled() bool { return r.rotateRefreshTokens }
// CompletelyExpired reports whether a refresh token has outlived the
// configured absolute lifetime, i.e. whether the current time is strictly
// after createdAt + absoluteLifetime.
//
// absoluteLifetime is measured from token creation, so createdAt should be
// the token's creation time, not its last-use time. The check is disabled,
// always reporting false, when absoluteLifetime is zero.
func (r *RefreshTokenPolicy) CompletelyExpired(createdAt time.Time) bool {
	if r.absoluteLifetime == 0 {
		return false // absolute-lifetime expiration disabled
	}
	return r.now().After(createdAt.Add(r.absoluteLifetime))
}
// ExpiredBecauseUnused reports whether more than validIfNotUsedFor has
// elapsed since lastUsed. A zero validIfNotUsedFor disables the check, and
// the result is then always false.
func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool {
	if r.validIfNotUsedFor != 0 {
		deadline := lastUsed.Add(r.validIfNotUsedFor)
		return r.now().After(deadline)
	}
	return false // expiration disabled
}
// AllowedToReuse reports whether a refresh token last used at lastUsed may
// still be reused. That is the case when the current time is no later than
// lastUsed + reuseInterval; the boundary instant itself still counts as
// allowed. A zero reuseInterval disables reuse, and the result is then
// always false.
func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool {
	if r.reuseInterval == 0 {
		return false // reuse disabled
	}
	return !r.now().After(lastUsed.Add(r.reuseInterval))
}