Add HMAC protection on /approval endpoint

Signed-off-by: Bob Callaway <bcallaway@google.com>
This commit is contained in:
Bob Callaway
2022-07-06 07:11:37 -04:00
parent 454122ca22
commit fcfbb1ecb0
19 changed files with 274 additions and 14 deletions

View File

@@ -31,6 +31,7 @@ func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error {
SetExpiry(authRequest.Expiry.UTC()).
SetConnectorID(authRequest.ConnectorID).
SetConnectorData(authRequest.ConnectorData).
SetHmacKey(authRequest.HMACKey).
Save(context.TODO())
if err != nil {
return convertDBError("create auth request: %w", err)
@@ -94,6 +95,7 @@ func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthReq
SetExpiry(newAuthRequest.Expiry.UTC()).
SetConnectorID(newAuthRequest.ConnectorID).
SetConnectorData(newAuthRequest.ConnectorData).
SetHmacKey(newAuthRequest.HMACKey).
Save(context.TODO())
if err != nil {
return rollback(tx, "update auth request uploading: %w", err)

View File

@@ -45,6 +45,7 @@ func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest {
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
},
HMACKey: a.HmacKey,
}
}

View File

@@ -55,6 +55,8 @@ type AuthRequest struct {
CodeChallenge string `json:"code_challenge,omitempty"`
// CodeChallengeMethod holds the value of the "code_challenge_method" field.
CodeChallengeMethod string `json:"code_challenge_method,omitempty"`
// HmacKey holds the value of the "hmac_key" field.
HmacKey []byte `json:"hmac_key,omitempty"`
}
// scanValues returns the types for scanning values from sql.Rows.
@@ -62,7 +64,7 @@ func (*AuthRequest) scanValues(columns []string) ([]interface{}, error) {
values := make([]interface{}, len(columns))
for i := range columns {
switch columns[i] {
case authrequest.FieldScopes, authrequest.FieldResponseTypes, authrequest.FieldClaimsGroups, authrequest.FieldConnectorData:
case authrequest.FieldScopes, authrequest.FieldResponseTypes, authrequest.FieldClaimsGroups, authrequest.FieldConnectorData, authrequest.FieldHmacKey:
values[i] = new([]byte)
case authrequest.FieldForceApprovalPrompt, authrequest.FieldLoggedIn, authrequest.FieldClaimsEmailVerified:
values[i] = new(sql.NullBool)
@@ -211,6 +213,12 @@ func (ar *AuthRequest) assignValues(columns []string, values []interface{}) erro
} else if value.Valid {
ar.CodeChallengeMethod = value.String
}
case authrequest.FieldHmacKey:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field hmac_key", values[i])
} else if value != nil {
ar.HmacKey = *value
}
}
}
return nil
@@ -279,6 +287,8 @@ func (ar *AuthRequest) String() string {
builder.WriteString(ar.CodeChallenge)
builder.WriteString(", code_challenge_method=")
builder.WriteString(ar.CodeChallengeMethod)
builder.WriteString(", hmac_key=")
builder.WriteString(fmt.Sprintf("%v", ar.HmacKey))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -45,6 +45,8 @@ const (
FieldCodeChallenge = "code_challenge"
// FieldCodeChallengeMethod holds the string denoting the code_challenge_method field in the database.
FieldCodeChallengeMethod = "code_challenge_method"
// FieldHmacKey holds the string denoting the hmac_key field in the database.
FieldHmacKey = "hmac_key"
// Table holds the table name of the authrequest in the database.
Table = "auth_requests"
)
@@ -71,6 +73,7 @@ var Columns = []string{
FieldExpiry,
FieldCodeChallenge,
FieldCodeChallengeMethod,
FieldHmacKey,
}
// ValidColumn reports if the column name is valid (part of the table columns).

View File

@@ -204,6 +204,13 @@ func CodeChallengeMethod(v string) predicate.AuthRequest {
})
}
// HmacKey applies equality check predicate on the "hmac_key" field. It's identical to HmacKeyEQ.
func HmacKey(v []byte) predicate.AuthRequest {
return predicate.AuthRequest(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldHmacKey), v))
})
}
// ClientIDEQ applies the EQ predicate on the "client_id" field.
func ClientIDEQ(v string) predicate.AuthRequest {
return predicate.AuthRequest(func(s *sql.Selector) {
@@ -1675,6 +1682,82 @@ func CodeChallengeMethodContainsFold(v string) predicate.AuthRequest {
})
}
// HmacKeyEQ applies the EQ predicate on the "hmac_key" field.
func HmacKeyEQ(v []byte) predicate.AuthRequest {
return predicate.AuthRequest(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldHmacKey), v))
})
}
// HmacKeyNEQ applies the NEQ predicate on the "hmac_key" field.
func HmacKeyNEQ(v []byte) predicate.AuthRequest {
return predicate.AuthRequest(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldHmacKey), v))
})
}
// HmacKeyIn applies the In predicate on the "hmac_key" field.
func HmacKeyIn(vs ...[]byte) predicate.AuthRequest {
v := make([]interface{}, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.AuthRequest(func(s *sql.Selector) {
// if not arguments were provided, append the FALSE constants,
// since we can't apply "IN ()". This will make this predicate falsy.
if len(v) == 0 {
s.Where(sql.False())
return
}
s.Where(sql.In(s.C(FieldHmacKey), v...))
})
}
// HmacKeyNotIn applies the NotIn predicate on the "hmac_key" field.
func HmacKeyNotIn(vs ...[]byte) predicate.AuthRequest {
v := make([]interface{}, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.AuthRequest(func(s *sql.Selector) {
// if not arguments were provided, append the FALSE constants,
// since we can't apply "IN ()". This will make this predicate falsy.
if len(v) == 0 {
s.Where(sql.False())
return
}
s.Where(sql.NotIn(s.C(FieldHmacKey), v...))
})
}
// HmacKeyGT applies the GT predicate on the "hmac_key" field.
func HmacKeyGT(v []byte) predicate.AuthRequest {
return predicate.AuthRequest(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldHmacKey), v))
})
}
// HmacKeyGTE applies the GTE predicate on the "hmac_key" field.
func HmacKeyGTE(v []byte) predicate.AuthRequest {
return predicate.AuthRequest(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldHmacKey), v))
})
}
// HmacKeyLT applies the LT predicate on the "hmac_key" field.
func HmacKeyLT(v []byte) predicate.AuthRequest {
return predicate.AuthRequest(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldHmacKey), v))
})
}
// HmacKeyLTE applies the LTE predicate on the "hmac_key" field.
func HmacKeyLTE(v []byte) predicate.AuthRequest {
return predicate.AuthRequest(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldHmacKey), v))
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.AuthRequest) predicate.AuthRequest {
return predicate.AuthRequest(func(s *sql.Selector) {

View File

@@ -158,6 +158,12 @@ func (arc *AuthRequestCreate) SetNillableCodeChallengeMethod(s *string) *AuthReq
return arc
}
// SetHmacKey sets the "hmac_key" field.
func (arc *AuthRequestCreate) SetHmacKey(b []byte) *AuthRequestCreate {
arc.mutation.SetHmacKey(b)
return arc
}
// SetID sets the "id" field.
func (arc *AuthRequestCreate) SetID(s string) *AuthRequestCreate {
arc.mutation.SetID(s)
@@ -296,6 +302,9 @@ func (arc *AuthRequestCreate) check() error {
if _, ok := arc.mutation.CodeChallengeMethod(); !ok {
return &ValidationError{Name: "code_challenge_method", err: errors.New(`db: missing required field "AuthRequest.code_challenge_method"`)}
}
if _, ok := arc.mutation.HmacKey(); !ok {
return &ValidationError{Name: "hmac_key", err: errors.New(`db: missing required field "AuthRequest.hmac_key"`)}
}
if v, ok := arc.mutation.ID(); ok {
if err := authrequest.IDValidator(v); err != nil {
return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "AuthRequest.id": %w`, err)}
@@ -489,6 +498,14 @@ func (arc *AuthRequestCreate) createSpec() (*AuthRequest, *sqlgraph.CreateSpec)
})
_node.CodeChallengeMethod = value
}
if value, ok := arc.mutation.HmacKey(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeBytes,
Value: value,
Column: authrequest.FieldHmacKey,
})
_node.HmacKey = value
}
return _node, _spec
}

View File

@@ -190,6 +190,12 @@ func (aru *AuthRequestUpdate) SetNillableCodeChallengeMethod(s *string) *AuthReq
return aru
}
// SetHmacKey sets the "hmac_key" field.
func (aru *AuthRequestUpdate) SetHmacKey(b []byte) *AuthRequestUpdate {
aru.mutation.SetHmacKey(b)
return aru
}
// Mutation returns the AuthRequestMutation object of the builder.
func (aru *AuthRequestUpdate) Mutation() *AuthRequestMutation {
return aru.mutation
@@ -424,6 +430,13 @@ func (aru *AuthRequestUpdate) sqlSave(ctx context.Context) (n int, err error) {
Column: authrequest.FieldCodeChallengeMethod,
})
}
if value, ok := aru.mutation.HmacKey(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeBytes,
Value: value,
Column: authrequest.FieldHmacKey,
})
}
if n, err = sqlgraph.UpdateNodes(ctx, aru.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{authrequest.Label}
@@ -605,6 +618,12 @@ func (aruo *AuthRequestUpdateOne) SetNillableCodeChallengeMethod(s *string) *Aut
return aruo
}
// SetHmacKey sets the "hmac_key" field.
func (aruo *AuthRequestUpdateOne) SetHmacKey(b []byte) *AuthRequestUpdateOne {
aruo.mutation.SetHmacKey(b)
return aruo
}
// Mutation returns the AuthRequestMutation object of the builder.
func (aruo *AuthRequestUpdateOne) Mutation() *AuthRequestMutation {
return aruo.mutation
@@ -863,6 +882,13 @@ func (aruo *AuthRequestUpdateOne) sqlSave(ctx context.Context) (_node *AuthReque
Column: authrequest.FieldCodeChallengeMethod,
})
}
if value, ok := aruo.mutation.HmacKey(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeBytes,
Value: value,
Column: authrequest.FieldHmacKey,
})
}
_node = &AuthRequest{config: aruo.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

View File

@@ -55,6 +55,7 @@ var (
{Name: "expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}},
{Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}},
{Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}},
{Name: "hmac_key", Type: field.TypeBytes},
}
// AuthRequestsTable holds the schema information for the "auth_requests" table.
AuthRequestsTable = &schema.Table{

View File

@@ -1205,6 +1205,7 @@ type AuthRequestMutation struct {
expiry *time.Time
code_challenge *string
code_challenge_method *string
hmac_key *[]byte
clearedFields map[string]struct{}
done bool
oldValue func(context.Context) (*AuthRequest, error)
@@ -2051,6 +2052,42 @@ func (m *AuthRequestMutation) ResetCodeChallengeMethod() {
m.code_challenge_method = nil
}
// SetHmacKey sets the "hmac_key" field.
func (m *AuthRequestMutation) SetHmacKey(b []byte) {
m.hmac_key = &b
}
// HmacKey returns the value of the "hmac_key" field in the mutation.
func (m *AuthRequestMutation) HmacKey() (r []byte, exists bool) {
v := m.hmac_key
if v == nil {
return
}
return *v, true
}
// OldHmacKey returns the old "hmac_key" field's value of the AuthRequest entity.
// If the AuthRequest object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *AuthRequestMutation) OldHmacKey(ctx context.Context) (v []byte, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldHmacKey is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldHmacKey requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldHmacKey: %w", err)
}
return oldValue.HmacKey, nil
}
// ResetHmacKey resets all changes to the "hmac_key" field.
func (m *AuthRequestMutation) ResetHmacKey() {
m.hmac_key = nil
}
// Where appends a list predicates to the AuthRequestMutation builder.
func (m *AuthRequestMutation) Where(ps ...predicate.AuthRequest) {
m.predicates = append(m.predicates, ps...)
@@ -2070,7 +2107,7 @@ func (m *AuthRequestMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *AuthRequestMutation) Fields() []string {
fields := make([]string, 0, 19)
fields := make([]string, 0, 20)
if m.client_id != nil {
fields = append(fields, authrequest.FieldClientID)
}
@@ -2128,6 +2165,9 @@ func (m *AuthRequestMutation) Fields() []string {
if m.code_challenge_method != nil {
fields = append(fields, authrequest.FieldCodeChallengeMethod)
}
if m.hmac_key != nil {
fields = append(fields, authrequest.FieldHmacKey)
}
return fields
}
@@ -2174,6 +2214,8 @@ func (m *AuthRequestMutation) Field(name string) (ent.Value, bool) {
return m.CodeChallenge()
case authrequest.FieldCodeChallengeMethod:
return m.CodeChallengeMethod()
case authrequest.FieldHmacKey:
return m.HmacKey()
}
return nil, false
}
@@ -2221,6 +2263,8 @@ func (m *AuthRequestMutation) OldField(ctx context.Context, name string) (ent.Va
return m.OldCodeChallenge(ctx)
case authrequest.FieldCodeChallengeMethod:
return m.OldCodeChallengeMethod(ctx)
case authrequest.FieldHmacKey:
return m.OldHmacKey(ctx)
}
return nil, fmt.Errorf("unknown AuthRequest field %s", name)
}
@@ -2363,6 +2407,13 @@ func (m *AuthRequestMutation) SetField(name string, value ent.Value) error {
}
m.SetCodeChallengeMethod(v)
return nil
case authrequest.FieldHmacKey:
v, ok := value.([]byte)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetHmacKey(v)
return nil
}
return fmt.Errorf("unknown AuthRequest field %s", name)
}
@@ -2496,6 +2547,9 @@ func (m *AuthRequestMutation) ResetField(name string) error {
case authrequest.FieldCodeChallengeMethod:
m.ResetCodeChallengeMethod()
return nil
case authrequest.FieldHmacKey:
m.ResetHmacKey()
return nil
}
return fmt.Errorf("unknown AuthRequest field %s", name)
}

View File

@@ -27,7 +27,8 @@ create table auth_request
expiry timestamp not null,
claims_preferred_username text default '' not null,
code_challenge text default '' not null,
code_challenge_method text default '' not null
code_challenge_method text default '' not null,
hmac_key blob
);
*/
@@ -86,6 +87,7 @@ func (AuthRequest) Fields() []ent.Field {
field.Text("code_challenge_method").
SchemaType(textSchema).
Default(""),
field.Bytes("hmac_key"),
}
}