diff --git a/server/handlers.go b/server/handlers.go index 5f8caf11..11dcdd07 100755 --- a/server/handlers.go +++ b/server/handlers.go @@ -1,6 +1,7 @@ package server import ( + "crypto/hmac" "crypto/sha256" "crypto/subtle" "encoding/base64" @@ -499,7 +500,15 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth s.logger.Infof("login successful: connector %q, username=%q, preferred_username=%q, email=%q, groups=%q", authReq.ConnectorID, claims.Username, claims.PreferredUsername, email, claims.Groups) - returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID + // TODO: if s.skipApproval or !authReq.ForceApprovalPrompt, we can skip the redirect to /approval and go ahead and send code + + // an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original + // flow would be unable to poll for the result at the /approval endpoint + h := hmac.New(sha256.New, authReq.HMACKey) + h.Write([]byte(authReq.ID)) + mac := h.Sum(nil) + + returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID + "&hmac=" + base64.RawURLEncoding.EncodeToString(mac) _, ok := conn.(connector.RefreshConnector) if !ok { return returnURL, nil @@ -544,6 +553,17 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth } func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { + macEncoded := r.FormValue("hmac") + if macEncoded == "" { + s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request") + return + } + mac, err := base64.RawURLEncoding.DecodeString(macEncoded) + if err != nil { + s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request") + return + } + authReq, err := s.storage.GetAuthRequest(r.FormValue("req")) if err != nil { s.logger.Errorf("Failed to get auth request: %v", err) @@ -556,6 +576,16 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { return } + // build expected hmac with secret key + h := hmac.New(sha256.New, authReq.HMACKey) + h.Write([]byte(authReq.ID)) + expectedMAC := h.Sum(nil) + // constant time comparison + if !hmac.Equal(mac, expectedMAC) { + s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request") + return + } + switch r.Method { case http.MethodGet: if s.skipApproval { diff --git a/server/oauth2.go b/server/oauth2.go index 998bf02a..67223fa1 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -2,6 +2,7 @@ package server import ( "context" + "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" @@ -576,6 +577,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques CodeChallenge: codeChallenge, CodeChallengeMethod: codeChallengeMethod, }, + HMACKey: storage.NewHMACKey(crypto.SHA256), }, nil } diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 9d9766eb..1b45b76c 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -104,7 +104,8 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { EmailVerified: true, Groups: []string{"a", "b"}, }, - PKCE: codeChallenge, + PKCE: codeChallenge, + HMACKey: []byte("hmac_key"), } identity := storage.Claims{Email: "foobar"} @@ -137,6 +138,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { EmailVerified: true, Groups: []string{"a"}, }, + HMACKey: []byte("hmac_key"), } if err := s.CreateAuthRequest(a2); err != nil { @@ -817,6 +819,7 @@ func testGC(t *testing.T, s storage.Storage) { EmailVerified: true, Groups: []string{"a", "b"}, }, + HMACKey: []byte("hmac_key"), } if err := s.CreateAuthRequest(a); err != nil { diff --git a/storage/ent/client/authrequest.go b/storage/ent/client/authrequest.go index bde37adc..d68fd438 100644 --- a/storage/ent/client/authrequest.go +++ b/storage/ent/client/authrequest.go @@ -31,6 +31,7 @@ func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error { SetExpiry(authRequest.Expiry.UTC()). SetConnectorID(authRequest.ConnectorID). SetConnectorData(authRequest.ConnectorData). + SetHmacKey(authRequest.HMACKey). Save(context.TODO()) if err != nil { return convertDBError("create auth request: %w", err) @@ -94,6 +95,7 @@ func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthReq SetExpiry(newAuthRequest.Expiry.UTC()). SetConnectorID(newAuthRequest.ConnectorID). SetConnectorData(newAuthRequest.ConnectorData). + SetHmacKey(newAuthRequest.HMACKey). Save(context.TODO()) if err != nil { return rollback(tx, "update auth request uploading: %w", err) diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go index 256bb73d..397d4d30 100644 --- a/storage/ent/client/types.go +++ b/storage/ent/client/types.go @@ -45,6 +45,7 @@ func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest { CodeChallenge: a.CodeChallenge, CodeChallengeMethod: a.CodeChallengeMethod, }, + HMACKey: a.HmacKey, } } diff --git a/storage/ent/db/authrequest.go b/storage/ent/db/authrequest.go index ecef32d7..095427ae 100644 --- a/storage/ent/db/authrequest.go +++ b/storage/ent/db/authrequest.go @@ -55,6 +55,8 @@ type AuthRequest struct { CodeChallenge string `json:"code_challenge,omitempty"` // CodeChallengeMethod holds the value of the "code_challenge_method" field. CodeChallengeMethod string `json:"code_challenge_method,omitempty"` + // HmacKey holds the value of the "hmac_key" field. + HmacKey []byte `json:"hmac_key,omitempty"` } // scanValues returns the types for scanning values from sql.Rows. @@ -62,7 +64,7 @@ func (*AuthRequest) scanValues(columns []string) ([]interface{}, error) { values := make([]interface{}, len(columns)) for i := range columns { switch columns[i] { - case authrequest.FieldScopes, authrequest.FieldResponseTypes, authrequest.FieldClaimsGroups, authrequest.FieldConnectorData: + case authrequest.FieldScopes, authrequest.FieldResponseTypes, authrequest.FieldClaimsGroups, authrequest.FieldConnectorData, authrequest.FieldHmacKey: values[i] = new([]byte) case authrequest.FieldForceApprovalPrompt, authrequest.FieldLoggedIn, authrequest.FieldClaimsEmailVerified: values[i] = new(sql.NullBool) @@ -211,6 +213,12 @@ func (ar *AuthRequest) assignValues(columns []string, values []interface{}) erro } else if value.Valid { ar.CodeChallengeMethod = value.String } + case authrequest.FieldHmacKey: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field hmac_key", values[i]) + } else if value != nil { + ar.HmacKey = *value + } } } return nil @@ -297,6 +305,8 @@ func (ar *AuthRequest) String() string { builder.WriteString(", ") builder.WriteString("code_challenge_method=") builder.WriteString(ar.CodeChallengeMethod) + builder.WriteString(", hmac_key=") + builder.WriteString(fmt.Sprintf("%v", ar.HmacKey)) builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/authrequest/authrequest.go b/storage/ent/db/authrequest/authrequest.go index 528d26b6..537f631e 100644 --- a/storage/ent/db/authrequest/authrequest.go +++ b/storage/ent/db/authrequest/authrequest.go @@ -45,6 +45,8 @@ const ( FieldCodeChallenge = "code_challenge" // FieldCodeChallengeMethod holds the string denoting the code_challenge_method field in the database. FieldCodeChallengeMethod = "code_challenge_method" + // FieldHmacKey holds the string denoting the hmac_key field in the database. + FieldHmacKey = "hmac_key" // Table holds the table name of the authrequest in the database. Table = "auth_requests" ) @@ -71,6 +73,7 @@ var Columns = []string{ FieldExpiry, FieldCodeChallenge, FieldCodeChallengeMethod, + FieldHmacKey, } // ValidColumn reports if the column name is valid (part of the table columns). diff --git a/storage/ent/db/authrequest/where.go b/storage/ent/db/authrequest/where.go index e31cdcee..1fd1d4e4 100644 --- a/storage/ent/db/authrequest/where.go +++ b/storage/ent/db/authrequest/where.go @@ -192,6 +192,13 @@ func CodeChallengeMethod(v string) predicate.AuthRequest { }) } +// HmacKey applies equality check predicate on the "hmac_key" field. It's identical to HmacKeyEQ. +func HmacKey(v []byte) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldHmacKey), v)) + }) +} + // ClientIDEQ applies the EQ predicate on the "client_id" field. func ClientIDEQ(v string) predicate.AuthRequest { return predicate.AuthRequest(func(s *sql.Selector) { @@ -1507,6 +1514,82 @@ func CodeChallengeMethodContainsFold(v string) predicate.AuthRequest { }) } +// HmacKeyEQ applies the EQ predicate on the "hmac_key" field. +func HmacKeyEQ(v []byte) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldHmacKey), v)) + }) +} + +// HmacKeyNEQ applies the NEQ predicate on the "hmac_key" field. +func HmacKeyNEQ(v []byte) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldHmacKey), v)) + }) +} + +// HmacKeyIn applies the In predicate on the "hmac_key" field. +func HmacKeyIn(vs ...[]byte) predicate.AuthRequest { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.AuthRequest(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldHmacKey), v...)) + }) +} + +// HmacKeyNotIn applies the NotIn predicate on the "hmac_key" field. +func HmacKeyNotIn(vs ...[]byte) predicate.AuthRequest { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.AuthRequest(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldHmacKey), v...)) + }) +} + +// HmacKeyGT applies the GT predicate on the "hmac_key" field. +func HmacKeyGT(v []byte) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldHmacKey), v)) + }) +} + +// HmacKeyGTE applies the GTE predicate on the "hmac_key" field. +func HmacKeyGTE(v []byte) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldHmacKey), v)) + }) +} + +// HmacKeyLT applies the LT predicate on the "hmac_key" field. +func HmacKeyLT(v []byte) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldHmacKey), v)) + }) +} + +// HmacKeyLTE applies the LTE predicate on the "hmac_key" field. +func HmacKeyLTE(v []byte) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldHmacKey), v)) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.AuthRequest) predicate.AuthRequest { return predicate.AuthRequest(func(s *sql.Selector) { diff --git a/storage/ent/db/authrequest_create.go b/storage/ent/db/authrequest_create.go index 63e0357d..c353c182 100644 --- a/storage/ent/db/authrequest_create.go +++ b/storage/ent/db/authrequest_create.go @@ -158,6 +158,12 @@ func (arc *AuthRequestCreate) SetNillableCodeChallengeMethod(s *string) *AuthReq return arc } +// SetHmacKey sets the "hmac_key" field. +func (arc *AuthRequestCreate) SetHmacKey(b []byte) *AuthRequestCreate { + arc.mutation.SetHmacKey(b) + return arc +} + // SetID sets the "id" field. func (arc *AuthRequestCreate) SetID(s string) *AuthRequestCreate { arc.mutation.SetID(s) @@ -302,6 +308,9 @@ func (arc *AuthRequestCreate) check() error { if _, ok := arc.mutation.CodeChallengeMethod(); !ok { return &ValidationError{Name: "code_challenge_method", err: errors.New(`db: missing required field "AuthRequest.code_challenge_method"`)} } + if _, ok := arc.mutation.HmacKey(); !ok { + return &ValidationError{Name: "hmac_key", err: errors.New(`db: missing required field "AuthRequest.hmac_key"`)} + } if v, ok := arc.mutation.ID(); ok { if err := authrequest.IDValidator(v); err != nil { return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "AuthRequest.id": %w`, err)} @@ -495,6 +504,14 @@ func (arc *AuthRequestCreate) createSpec() (*AuthRequest, *sqlgraph.CreateSpec) }) _node.CodeChallengeMethod = value } + if value, ok := arc.mutation.HmacKey(); ok { + _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ + Type: field.TypeBytes, + Value: value, + Column: authrequest.FieldHmacKey, + }) + _node.HmacKey = value + } return _node, _spec } diff --git a/storage/ent/db/authrequest_update.go b/storage/ent/db/authrequest_update.go index 22dbb24b..f0bf9b34 100644 --- a/storage/ent/db/authrequest_update.go +++ b/storage/ent/db/authrequest_update.go @@ -190,6 +190,12 @@ func (aru *AuthRequestUpdate) SetNillableCodeChallengeMethod(s *string) *AuthReq return aru } +// SetHmacKey sets the "hmac_key" field. +func (aru *AuthRequestUpdate) SetHmacKey(b []byte) *AuthRequestUpdate { + aru.mutation.SetHmacKey(b) + return aru +} + // Mutation returns the AuthRequestMutation object of the builder. func (aru *AuthRequestUpdate) Mutation() *AuthRequestMutation { return aru.mutation @@ -424,6 +430,13 @@ func (aru *AuthRequestUpdate) sqlSave(ctx context.Context) (n int, err error) { Column: authrequest.FieldCodeChallengeMethod, }) } + if value, ok := aru.mutation.HmacKey(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeBytes, + Value: value, + Column: authrequest.FieldHmacKey, + }) + } if n, err = sqlgraph.UpdateNodes(ctx, aru.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authrequest.Label} @@ -605,6 +618,12 @@ func (aruo *AuthRequestUpdateOne) SetNillableCodeChallengeMethod(s *string) *Aut return aruo } +// SetHmacKey sets the "hmac_key" field. +func (aruo *AuthRequestUpdateOne) SetHmacKey(b []byte) *AuthRequestUpdateOne { + aruo.mutation.SetHmacKey(b) + return aruo +} + // Mutation returns the AuthRequestMutation object of the builder. func (aruo *AuthRequestUpdateOne) Mutation() *AuthRequestMutation { return aruo.mutation @@ -869,6 +888,13 @@ func (aruo *AuthRequestUpdateOne) sqlSave(ctx context.Context) (_node *AuthReque Column: authrequest.FieldCodeChallengeMethod, }) } + if value, ok := aruo.mutation.HmacKey(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeBytes, + Value: value, + Column: authrequest.FieldHmacKey, + }) + } _node = &AuthRequest{config: aruo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index ced57b32..d3295a0c 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -55,6 +55,7 @@ var ( {Name: "expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "hmac_key", Type: field.TypeBytes}, } // AuthRequestsTable holds the schema information for the "auth_requests" table. AuthRequestsTable = &schema.Table{ diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index af12a242..85e1af72 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -1205,6 +1205,7 @@ type AuthRequestMutation struct { expiry *time.Time code_challenge *string code_challenge_method *string + hmac_key *[]byte clearedFields map[string]struct{} done bool oldValue func(context.Context) (*AuthRequest, error) @@ -2051,6 +2052,42 @@ func (m *AuthRequestMutation) ResetCodeChallengeMethod() { m.code_challenge_method = nil } +// SetHmacKey sets the "hmac_key" field. +func (m *AuthRequestMutation) SetHmacKey(b []byte) { + m.hmac_key = &b +} + +// HmacKey returns the value of the "hmac_key" field in the mutation. +func (m *AuthRequestMutation) HmacKey() (r []byte, exists bool) { + v := m.hmac_key + if v == nil { + return + } + return *v, true +} + +// OldHmacKey returns the old "hmac_key" field's value of the AuthRequest entity. +// If the AuthRequest object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthRequestMutation) OldHmacKey(ctx context.Context) (v []byte, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldHmacKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldHmacKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldHmacKey: %w", err) + } + return oldValue.HmacKey, nil +} + +// ResetHmacKey resets all changes to the "hmac_key" field. +func (m *AuthRequestMutation) ResetHmacKey() { + m.hmac_key = nil +} + // Where appends a list predicates to the AuthRequestMutation builder. func (m *AuthRequestMutation) Where(ps ...predicate.AuthRequest) { m.predicates = append(m.predicates, ps...) @@ -2070,7 +2107,7 @@ func (m *AuthRequestMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AuthRequestMutation) Fields() []string { - fields := make([]string, 0, 19) + fields := make([]string, 0, 20) if m.client_id != nil { fields = append(fields, authrequest.FieldClientID) } @@ -2128,6 +2165,9 @@ func (m *AuthRequestMutation) Fields() []string { if m.code_challenge_method != nil { fields = append(fields, authrequest.FieldCodeChallengeMethod) } + if m.hmac_key != nil { + fields = append(fields, authrequest.FieldHmacKey) + } return fields } @@ -2174,6 +2214,8 @@ func (m *AuthRequestMutation) Field(name string) (ent.Value, bool) { return m.CodeChallenge() case authrequest.FieldCodeChallengeMethod: return m.CodeChallengeMethod() + case authrequest.FieldHmacKey: + return m.HmacKey() } return nil, false } @@ -2221,6 +2263,8 @@ func (m *AuthRequestMutation) OldField(ctx context.Context, name string) (ent.Va return m.OldCodeChallenge(ctx) case authrequest.FieldCodeChallengeMethod: return m.OldCodeChallengeMethod(ctx) + case authrequest.FieldHmacKey: + return m.OldHmacKey(ctx) } return nil, fmt.Errorf("unknown AuthRequest field %s", name) } @@ -2363,6 +2407,13 @@ func (m *AuthRequestMutation) SetField(name string, value ent.Value) error { } m.SetCodeChallengeMethod(v) return nil + case authrequest.FieldHmacKey: + v, ok := value.([]byte) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetHmacKey(v) + return nil } return fmt.Errorf("unknown AuthRequest field %s", name) } @@ -2496,6 +2547,9 @@ func (m *AuthRequestMutation) ResetField(name string) error { case authrequest.FieldCodeChallengeMethod: m.ResetCodeChallengeMethod() return nil + case authrequest.FieldHmacKey: + m.ResetHmacKey() + return nil } return fmt.Errorf("unknown AuthRequest field %s", name) } diff --git a/storage/ent/schema/authrequest.go b/storage/ent/schema/authrequest.go index 7d41e830..2b75927b 100644 --- a/storage/ent/schema/authrequest.go +++ b/storage/ent/schema/authrequest.go @@ -27,7 +27,8 @@ create table auth_request expiry timestamp not null, claims_preferred_username text default '' not null, code_challenge text default '' not null, - code_challenge_method text default '' not null + code_challenge_method text default '' not null, + hmac_key blob ); */ @@ -86,6 +87,7 @@ func (AuthRequest) Fields() []ent.Field { field.Text("code_challenge_method"). SchemaType(textSchema). Default(""), + field.Bytes("hmac_key"), } } diff --git a/storage/etcd/types.go b/storage/etcd/types.go index 1174a2d2..91199ab6 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -84,6 +84,8 @@ type AuthRequest struct { CodeChallenge string `json:"code_challenge,omitempty"` CodeChallengeMethod string `json:"code_challenge_method,omitempty"` + + HMACKey []byte `json:"hmac_key"` } func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { @@ -103,6 +105,7 @@ func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { ConnectorData: a.ConnectorData, CodeChallenge: a.PKCE.CodeChallenge, CodeChallengeMethod: a.PKCE.CodeChallengeMethod, + HMACKey: a.HMACKey, } } @@ -125,6 +128,7 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest { CodeChallenge: a.CodeChallenge, CodeChallengeMethod: a.CodeChallengeMethod, }, + HMACKey: a.HMACKey, } } diff --git a/storage/health.go b/storage/health.go index 8a2f5a3d..1b6e22c6 100644 --- a/storage/health.go +++ b/storage/health.go @@ -2,6 +2,7 @@ package storage import ( "context" + "crypto" "fmt" "time" ) @@ -14,7 +15,8 @@ func NewCustomHealthCheckFunc(s Storage, now func() time.Time) func(context.Cont ClientID: NewID(), // Set a short expiry so if the delete fails this will be cleaned up quickly by garbage collection. - Expiry: now().Add(time.Minute), + Expiry: now().Add(time.Minute), + HMACKey: NewHMACKey(crypto.SHA256), } if err := s.CreateAuthRequest(a); err != nil { diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 5149e3ee..a5ec29af 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -356,6 +356,8 @@ type AuthRequest struct { CodeChallenge string `json:"code_challenge,omitempty"` CodeChallengeMethod string `json:"code_challenge_method,omitempty"` + + HMACKey []byte `json:"hmac_key"` } // AuthRequestList is a list of AuthRequests. @@ -384,6 +386,7 @@ func toStorageAuthRequest(req AuthRequest) storage.AuthRequest { CodeChallenge: req.CodeChallenge, CodeChallengeMethod: req.CodeChallengeMethod, }, + HMACKey: req.HMACKey, } return a } @@ -412,6 +415,7 @@ func (cli *client) fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { Claims: fromStorageClaims(a.Claims), CodeChallenge: a.PKCE.CodeChallenge, CodeChallengeMethod: a.PKCE.CodeChallengeMethod, + HMACKey: a.HMACKey, } return req } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index ac67bf28..1583c177 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -131,10 +131,11 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { claims_email, claims_email_verified, claims_groups, connector_id, connector_data, expiry, - code_challenge, code_challenge_method + code_challenge, code_challenge_method, + hmac_key ) values ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20 + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21 ); `, a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, @@ -144,6 +145,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { a.ConnectorID, a.ConnectorData, a.Expiry, a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, + a.HMACKey, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -175,8 +177,9 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) claims_groups = $14, connector_id = $15, connector_data = $16, expiry = $17, - code_challenge = $18, code_challenge_method = $19 - where id = $20; + code_challenge = $18, code_challenge_method = $19, + hmac_key = $20 + where id = $21; `, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, a.ForceApprovalPrompt, a.LoggedIn, @@ -185,7 +188,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) encoder(a.Claims.Groups), a.ConnectorID, a.ConnectorData, a.Expiry, - a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, + a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, a.HMACKey, r.ID, ) if err != nil { @@ -207,7 +210,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, expiry, - code_challenge, code_challenge_method + code_challenge, code_challenge_method, hmac_key from auth_request where id = $1; `, id).Scan( &a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State, @@ -216,7 +219,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups), &a.ConnectorID, &a.ConnectorData, &a.Expiry, - &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, + &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, &a.HMACKey, ) if err != nil { if err == sql.ErrNoRows { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 57720e17..83e9c20d 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -291,4 +291,11 @@ var migrations = []migration{ add column code_challenge_method text not null default '';`, }, }, + { + stmts: []string{ + ` + alter table auth_request + add column hmac_key bytea;`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index af39228a..198a70c8 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,6 +1,7 @@ package storage import ( + "crypto" "crypto/rand" "encoding/base32" "errors" @@ -47,6 +48,11 @@ func newSecureID(len int) string { return string(buff[0]%26+'a') + strings.TrimRight(encoding.EncodeToString(buff[1:]), "=") } +// NewHMACKey returns a random key which can be used in the computation of an HMAC +func NewHMACKey(h crypto.Hash) []byte { + return []byte(newSecureID(h.Size())) +} + // GCResult returns the number of objects deleted by garbage collection. type GCResult struct { AuthRequests int64 @@ -223,6 +229,9 @@ type AuthRequest struct { // PKCE CodeChallenge and CodeChallengeMethod PKCE PKCE + + // HMACKey is used when generating an AuthRequest-specific HMAC + HMACKey []byte } // AuthCode represents a code which can be exchanged for an OAuth2 token response.