542 lines
15 KiB
Go
542 lines
15 KiB
Go
/*-
|
|
* Copyright 2014 Square Inc.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package jose
|
|
|
|
import (
|
|
"crypto/ecdsa"
|
|
"crypto/rsa"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"gopkg.in/square/go-jose.v2/json"
|
|
)
|
|
|
|
// Encrypter represents an encrypter which produces an encrypted JWE object.
|
|
type Encrypter interface {
|
|
Encrypt(plaintext []byte) (*JSONWebEncryption, error)
|
|
EncryptWithAuthData(plaintext []byte, aad []byte) (*JSONWebEncryption, error)
|
|
Options() EncrypterOptions
|
|
}
|
|
|
|
// A generic content cipher
|
|
type contentCipher interface {
|
|
keySize() int
|
|
encrypt(cek []byte, aad, plaintext []byte) (*aeadParts, error)
|
|
decrypt(cek []byte, aad []byte, parts *aeadParts) ([]byte, error)
|
|
}
|
|
|
|
// A key generator (for generating/getting a CEK)
|
|
type keyGenerator interface {
|
|
keySize() int
|
|
genKey() ([]byte, rawHeader, error)
|
|
}
|
|
|
|
// A generic key encrypter
|
|
type keyEncrypter interface {
|
|
encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) // Encrypt a key
|
|
}
|
|
|
|
// A generic key decrypter
|
|
type keyDecrypter interface {
|
|
decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) // Decrypt a key
|
|
}
|
|
|
|
// A generic encrypter based on the given key encrypter and content cipher.
|
|
type genericEncrypter struct {
|
|
contentAlg ContentEncryption
|
|
compressionAlg CompressionAlgorithm
|
|
cipher contentCipher
|
|
recipients []recipientKeyInfo
|
|
keyGenerator keyGenerator
|
|
extraHeaders map[HeaderKey]interface{}
|
|
}
|
|
|
|
type recipientKeyInfo struct {
|
|
keyID string
|
|
keyAlg KeyAlgorithm
|
|
keyEncrypter keyEncrypter
|
|
}
|
|
|
|
// EncrypterOptions represents options that can be set on new encrypters.
|
|
type EncrypterOptions struct {
|
|
Compression CompressionAlgorithm
|
|
|
|
// Optional map of additional keys to be inserted into the protected header
|
|
// of a JWS object. Some specifications which make use of JWS like to insert
|
|
// additional values here. All values must be JSON-serializable.
|
|
ExtraHeaders map[HeaderKey]interface{}
|
|
}
|
|
|
|
// WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it
|
|
// if necessary. It returns itself and so can be used in a fluent style.
|
|
func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions {
|
|
if eo.ExtraHeaders == nil {
|
|
eo.ExtraHeaders = map[HeaderKey]interface{}{}
|
|
}
|
|
eo.ExtraHeaders[k] = v
|
|
return eo
|
|
}
|
|
|
|
// WithContentType adds a content type ("cty") header and returns the updated
|
|
// EncrypterOptions.
|
|
func (eo *EncrypterOptions) WithContentType(contentType ContentType) *EncrypterOptions {
|
|
return eo.WithHeader(HeaderContentType, contentType)
|
|
}
|
|
|
|
// WithType adds a type ("typ") header and returns the updated EncrypterOptions.
|
|
func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions {
|
|
return eo.WithHeader(HeaderType, typ)
|
|
}
|
|
|
|
// Recipient represents an algorithm/key to encrypt messages to.
|
|
//
|
|
// PBES2Count and PBES2Salt correspond with the "p2c" and "p2s" headers used
|
|
// on the password-based encryption algorithms PBES2-HS256+A128KW,
|
|
// PBES2-HS384+A192KW, and PBES2-HS512+A256KW. If they are not provided a safe
|
|
// default of 100000 will be used for the count and a 128-bit random salt will
|
|
// be generated.
|
|
type Recipient struct {
|
|
Algorithm KeyAlgorithm
|
|
Key interface{}
|
|
KeyID string
|
|
PBES2Count int
|
|
PBES2Salt []byte
|
|
}
|
|
|
|
// NewEncrypter creates an appropriate encrypter based on the key type
|
|
func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions) (Encrypter, error) {
|
|
encrypter := &genericEncrypter{
|
|
contentAlg: enc,
|
|
recipients: []recipientKeyInfo{},
|
|
cipher: getContentCipher(enc),
|
|
}
|
|
if opts != nil {
|
|
encrypter.compressionAlg = opts.Compression
|
|
encrypter.extraHeaders = opts.ExtraHeaders
|
|
}
|
|
|
|
if encrypter.cipher == nil {
|
|
return nil, ErrUnsupportedAlgorithm
|
|
}
|
|
|
|
var keyID string
|
|
var rawKey interface{}
|
|
switch encryptionKey := rcpt.Key.(type) {
|
|
case JSONWebKey:
|
|
keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
|
|
case *JSONWebKey:
|
|
keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
|
|
case OpaqueKeyEncrypter:
|
|
keyID, rawKey = encryptionKey.KeyID(), encryptionKey
|
|
default:
|
|
rawKey = encryptionKey
|
|
}
|
|
|
|
switch rcpt.Algorithm {
|
|
case DIRECT:
|
|
// Direct encryption mode must be treated differently
|
|
if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) {
|
|
return nil, ErrUnsupportedKeyType
|
|
}
|
|
if encrypter.cipher.keySize() != len(rawKey.([]byte)) {
|
|
return nil, ErrInvalidKeySize
|
|
}
|
|
encrypter.keyGenerator = staticKeyGenerator{
|
|
key: rawKey.([]byte),
|
|
}
|
|
recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte))
|
|
recipientInfo.keyID = keyID
|
|
if rcpt.KeyID != "" {
|
|
recipientInfo.keyID = rcpt.KeyID
|
|
}
|
|
encrypter.recipients = []recipientKeyInfo{recipientInfo}
|
|
return encrypter, nil
|
|
case ECDH_ES:
|
|
// ECDH-ES (w/o key wrapping) is similar to DIRECT mode
|
|
typeOf := reflect.TypeOf(rawKey)
|
|
if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) {
|
|
return nil, ErrUnsupportedKeyType
|
|
}
|
|
encrypter.keyGenerator = ecKeyGenerator{
|
|
size: encrypter.cipher.keySize(),
|
|
algID: string(enc),
|
|
publicKey: rawKey.(*ecdsa.PublicKey),
|
|
}
|
|
recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey))
|
|
recipientInfo.keyID = keyID
|
|
if rcpt.KeyID != "" {
|
|
recipientInfo.keyID = rcpt.KeyID
|
|
}
|
|
encrypter.recipients = []recipientKeyInfo{recipientInfo}
|
|
return encrypter, nil
|
|
default:
|
|
// Can just add a standard recipient
|
|
encrypter.keyGenerator = randomKeyGenerator{
|
|
size: encrypter.cipher.keySize(),
|
|
}
|
|
err := encrypter.addRecipient(rcpt)
|
|
return encrypter, err
|
|
}
|
|
}
|
|
|
|
// NewMultiEncrypter creates a multi-encrypter based on the given parameters
|
|
func NewMultiEncrypter(enc ContentEncryption, rcpts []Recipient, opts *EncrypterOptions) (Encrypter, error) {
|
|
cipher := getContentCipher(enc)
|
|
|
|
if cipher == nil {
|
|
return nil, ErrUnsupportedAlgorithm
|
|
}
|
|
if rcpts == nil || len(rcpts) == 0 {
|
|
return nil, fmt.Errorf("square/go-jose: recipients is nil or empty")
|
|
}
|
|
|
|
encrypter := &genericEncrypter{
|
|
contentAlg: enc,
|
|
recipients: []recipientKeyInfo{},
|
|
cipher: cipher,
|
|
keyGenerator: randomKeyGenerator{
|
|
size: cipher.keySize(),
|
|
},
|
|
}
|
|
|
|
if opts != nil {
|
|
encrypter.compressionAlg = opts.Compression
|
|
}
|
|
|
|
for _, recipient := range rcpts {
|
|
err := encrypter.addRecipient(recipient)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return encrypter, nil
|
|
}
|
|
|
|
func (ctx *genericEncrypter) addRecipient(recipient Recipient) (err error) {
|
|
var recipientInfo recipientKeyInfo
|
|
|
|
switch recipient.Algorithm {
|
|
case DIRECT, ECDH_ES:
|
|
return fmt.Errorf("square/go-jose: key algorithm '%s' not supported in multi-recipient mode", recipient.Algorithm)
|
|
}
|
|
|
|
recipientInfo, err = makeJWERecipient(recipient.Algorithm, recipient.Key)
|
|
if recipient.KeyID != "" {
|
|
recipientInfo.keyID = recipient.KeyID
|
|
}
|
|
|
|
switch recipient.Algorithm {
|
|
case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
|
|
if sr, ok := recipientInfo.keyEncrypter.(*symmetricKeyCipher); ok {
|
|
sr.p2c = recipient.PBES2Count
|
|
sr.p2s = recipient.PBES2Salt
|
|
}
|
|
}
|
|
|
|
if err == nil {
|
|
ctx.recipients = append(ctx.recipients, recipientInfo)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKeyInfo, error) {
|
|
switch encryptionKey := encryptionKey.(type) {
|
|
case *rsa.PublicKey:
|
|
return newRSARecipient(alg, encryptionKey)
|
|
case *ecdsa.PublicKey:
|
|
return newECDHRecipient(alg, encryptionKey)
|
|
case []byte:
|
|
return newSymmetricRecipient(alg, encryptionKey)
|
|
case string:
|
|
return newSymmetricRecipient(alg, []byte(encryptionKey))
|
|
case *JSONWebKey:
|
|
recipient, err := makeJWERecipient(alg, encryptionKey.Key)
|
|
recipient.keyID = encryptionKey.KeyID
|
|
return recipient, err
|
|
}
|
|
if encrypter, ok := encryptionKey.(OpaqueKeyEncrypter); ok {
|
|
return newOpaqueKeyEncrypter(alg, encrypter)
|
|
}
|
|
return recipientKeyInfo{}, ErrUnsupportedKeyType
|
|
}
|
|
|
|
// newDecrypter creates an appropriate decrypter based on the key type
|
|
func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
|
|
switch decryptionKey := decryptionKey.(type) {
|
|
case *rsa.PrivateKey:
|
|
return &rsaDecrypterSigner{
|
|
privateKey: decryptionKey,
|
|
}, nil
|
|
case *ecdsa.PrivateKey:
|
|
return &ecDecrypterSigner{
|
|
privateKey: decryptionKey,
|
|
}, nil
|
|
case []byte:
|
|
return &symmetricKeyCipher{
|
|
key: decryptionKey,
|
|
}, nil
|
|
case string:
|
|
return &symmetricKeyCipher{
|
|
key: []byte(decryptionKey),
|
|
}, nil
|
|
case JSONWebKey:
|
|
return newDecrypter(decryptionKey.Key)
|
|
case *JSONWebKey:
|
|
return newDecrypter(decryptionKey.Key)
|
|
}
|
|
if okd, ok := decryptionKey.(OpaqueKeyDecrypter); ok {
|
|
return &opaqueKeyDecrypter{decrypter: okd}, nil
|
|
}
|
|
return nil, ErrUnsupportedKeyType
|
|
}
|
|
|
|
// Implementation of encrypt method producing a JWE object.
|
|
func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) {
|
|
return ctx.EncryptWithAuthData(plaintext, nil)
|
|
}
|
|
|
|
// Implementation of encrypt method producing a JWE object.
|
|
func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) {
|
|
obj := &JSONWebEncryption{}
|
|
obj.aad = aad
|
|
|
|
obj.protected = &rawHeader{}
|
|
err := obj.protected.set(headerEncryption, ctx.contentAlg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
obj.recipients = make([]recipientInfo, len(ctx.recipients))
|
|
|
|
if len(ctx.recipients) == 0 {
|
|
return nil, fmt.Errorf("square/go-jose: no recipients to encrypt to")
|
|
}
|
|
|
|
cek, headers, err := ctx.keyGenerator.genKey()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
obj.protected.merge(&headers)
|
|
|
|
for i, info := range ctx.recipients {
|
|
recipient, err := info.keyEncrypter.encryptKey(cek, info.keyAlg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = recipient.header.set(headerAlgorithm, info.keyAlg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if info.keyID != "" {
|
|
err = recipient.header.set(headerKeyID, info.keyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
obj.recipients[i] = recipient
|
|
}
|
|
|
|
if len(ctx.recipients) == 1 {
|
|
// Move per-recipient headers into main protected header if there's
|
|
// only a single recipient.
|
|
obj.protected.merge(obj.recipients[0].header)
|
|
obj.recipients[0].header = nil
|
|
}
|
|
|
|
if ctx.compressionAlg != NONE {
|
|
plaintext, err = compress(ctx.compressionAlg, plaintext)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = obj.protected.set(headerCompression, ctx.compressionAlg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
for k, v := range ctx.extraHeaders {
|
|
b, err := json.Marshal(v)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
(*obj.protected)[k] = makeRawMessage(b)
|
|
}
|
|
|
|
authData := obj.computeAuthData()
|
|
parts, err := ctx.cipher.encrypt(cek, authData, plaintext)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
obj.iv = parts.iv
|
|
obj.ciphertext = parts.ciphertext
|
|
obj.tag = parts.tag
|
|
|
|
return obj, nil
|
|
}
|
|
|
|
func (ctx *genericEncrypter) Options() EncrypterOptions {
|
|
return EncrypterOptions{
|
|
Compression: ctx.compressionAlg,
|
|
ExtraHeaders: ctx.extraHeaders,
|
|
}
|
|
}
|
|
|
|
// Decrypt and validate the object and return the plaintext. Note that this
|
|
// function does not support multi-recipient, if you desire multi-recipient
|
|
// decryption use DecryptMulti instead.
|
|
func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
|
|
headers := obj.mergedHeaders(nil)
|
|
|
|
if len(obj.recipients) > 1 {
|
|
return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one")
|
|
}
|
|
|
|
critical, err := headers.getCritical()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("square/go-jose: invalid crit header")
|
|
}
|
|
|
|
if len(critical) > 0 {
|
|
return nil, fmt.Errorf("square/go-jose: unsupported crit header")
|
|
}
|
|
|
|
decrypter, err := newDecrypter(decryptionKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cipher := getContentCipher(headers.getEncryption())
|
|
if cipher == nil {
|
|
return nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(headers.getEncryption()))
|
|
}
|
|
|
|
generator := randomKeyGenerator{
|
|
size: cipher.keySize(),
|
|
}
|
|
|
|
parts := &aeadParts{
|
|
iv: obj.iv,
|
|
ciphertext: obj.ciphertext,
|
|
tag: obj.tag,
|
|
}
|
|
|
|
authData := obj.computeAuthData()
|
|
|
|
var plaintext []byte
|
|
recipient := obj.recipients[0]
|
|
recipientHeaders := obj.mergedHeaders(&recipient)
|
|
|
|
cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
|
|
if err == nil {
|
|
// Found a valid CEK -- let's try to decrypt.
|
|
plaintext, err = cipher.decrypt(cek, authData, parts)
|
|
}
|
|
|
|
if plaintext == nil {
|
|
return nil, ErrCryptoFailure
|
|
}
|
|
|
|
// The "zip" header parameter may only be present in the protected header.
|
|
if comp := obj.protected.getCompression(); comp != "" {
|
|
plaintext, err = decompress(comp, plaintext)
|
|
}
|
|
|
|
return plaintext, err
|
|
}
|
|
|
|
// DecryptMulti decrypts and validates the object and returns the plaintexts,
|
|
// with support for multiple recipients. It returns the index of the recipient
|
|
// for which the decryption was successful, the merged headers for that recipient,
|
|
// and the plaintext.
|
|
func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) {
|
|
globalHeaders := obj.mergedHeaders(nil)
|
|
|
|
critical, err := globalHeaders.getCritical()
|
|
if err != nil {
|
|
return -1, Header{}, nil, fmt.Errorf("square/go-jose: invalid crit header")
|
|
}
|
|
|
|
if len(critical) > 0 {
|
|
return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported crit header")
|
|
}
|
|
|
|
decrypter, err := newDecrypter(decryptionKey)
|
|
if err != nil {
|
|
return -1, Header{}, nil, err
|
|
}
|
|
|
|
encryption := globalHeaders.getEncryption()
|
|
cipher := getContentCipher(encryption)
|
|
if cipher == nil {
|
|
return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(encryption))
|
|
}
|
|
|
|
generator := randomKeyGenerator{
|
|
size: cipher.keySize(),
|
|
}
|
|
|
|
parts := &aeadParts{
|
|
iv: obj.iv,
|
|
ciphertext: obj.ciphertext,
|
|
tag: obj.tag,
|
|
}
|
|
|
|
authData := obj.computeAuthData()
|
|
|
|
index := -1
|
|
var plaintext []byte
|
|
var headers rawHeader
|
|
|
|
for i, recipient := range obj.recipients {
|
|
recipientHeaders := obj.mergedHeaders(&recipient)
|
|
|
|
cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
|
|
if err == nil {
|
|
// Found a valid CEK -- let's try to decrypt.
|
|
plaintext, err = cipher.decrypt(cek, authData, parts)
|
|
if err == nil {
|
|
index = i
|
|
headers = recipientHeaders
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if plaintext == nil || err != nil {
|
|
return -1, Header{}, nil, ErrCryptoFailure
|
|
}
|
|
|
|
// The "zip" header parameter may only be present in the protected header.
|
|
if comp := obj.protected.getCompression(); comp != "" {
|
|
plaintext, err = decompress(comp, plaintext)
|
|
}
|
|
|
|
sanitized, err := headers.sanitized()
|
|
if err != nil {
|
|
return -1, Header{}, nil, fmt.Errorf("square/go-jose: failed to sanitize header: %v", err)
|
|
}
|
|
|
|
return index, sanitized, plaintext, err
|
|
}
|