diff --git a/server/handlers.go b/server/handlers.go index 5512d87f..8df7aa15 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -661,7 +661,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe return } - idToken, idTokenExpiry, err = s.newIDToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, authReq.ConnectorID) + idToken, idTokenExpiry, err = s.newIDToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -791,7 +791,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s return } - idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID) + idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1060,7 +1060,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID) + idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1243,7 +1243,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli } accessToken := storage.NewID() - idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, nonce, accessToken, connID) + idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, nonce, accessToken, "", connID) if err != nil { s.tokenErrHelper(w, errServerError, fmt.Sprintf("failed to create ID token: %v", err), http.StatusInternalServerError) return diff --git a/server/oauth2.go b/server/oauth2.go index 05dd25d2..1a1a56d1 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -217,11 +217,11 @@ func accessTokenHash(alg jose.SignatureAlgorithm, accessToken string) (string, e return "", fmt.Errorf("unsupported signature algorithm: %s", alg) } - hash := newHash() - if _, err := io.WriteString(hash, accessToken); err != nil { + hashFunc := newHash() + if _, err := io.WriteString(hashFunc, accessToken); err != nil { return "", fmt.Errorf("computing hash: %v", err) } - sum := hash.Sum(nil) + sum := hashFunc.Sum(nil) return base64.RawURLEncoding.EncodeToString(sum[:len(sum)/2]), nil } @@ -253,6 +253,7 @@ type idTokenClaims struct { Nonce string `json:"nonce,omitempty"` AccessTokenHash string `json:"at_hash,omitempty"` + CodeHash string `json:"c_hash,omitempty"` Email string `json:"email,omitempty"` EmailVerified *bool `json:"email_verified,omitempty"` @@ -271,11 +272,11 @@ type federatedIDClaims struct { } func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, err error) { - idToken, _, err := s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), connID) + idToken, _, err := s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), "", connID) return idToken, err } -func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, connID string) (idToken string, expiry time.Time, err error) { +func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) { keys, err := s.storage.GetKeys() if err != nil { s.logger.Errorf("Failed to get keys: %v", err) @@ -322,6 +323,15 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str tok.AccessTokenHash = atHash } + if code != "" { + cHash, err := accessTokenHash(signingAlg, code) + if err != nil { + s.logger.Errorf("error computing c_hash: %v", err) + return "", expiry, fmt.Errorf("error computing c_hash: #{err}") + } + tok.CodeHash = cHash + } + for _, scope := range scopes { switch { case scope == scopeEmail: