storage: Surface "already exists" errors.
This commit is contained in:
@@ -8,6 +8,13 @@ import (
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/coreos/dex/storage"
|
||||
"github.com/lib/pq"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
const (
|
||||
// postgres error codes
|
||||
pgErrUniqueViolation = "23505" // unique_violation
|
||||
)
|
||||
|
||||
// SQLite3 options for creating an SQL db.
|
||||
@@ -35,7 +42,16 @@ func (s *SQLite3) open(logger logrus.FieldLogger) (*conn, error) {
|
||||
// doesn't support this, so limit the number of connections to 1.
|
||||
db.SetMaxOpenConns(1)
|
||||
}
|
||||
c := &conn{db, flavorSQLite3, logger}
|
||||
|
||||
errCheck := func(err error) bool {
|
||||
sqlErr, ok := err.(sqlite3.Error)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey
|
||||
}
|
||||
|
||||
c := &conn{db, flavorSQLite3, logger, errCheck}
|
||||
if _, err := c.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
||||
}
|
||||
@@ -114,7 +130,16 @@ func (p *Postgres) open(logger logrus.FieldLogger) (*conn, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &conn{db, flavorPostgres, logger}
|
||||
|
||||
errCheck := func(err error) bool {
|
||||
sqlErr, ok := err.(*pq.Error)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return sqlErr.Code == pgErrUniqueViolation
|
||||
}
|
||||
|
||||
c := &conn{db, flavorPostgres, logger, errCheck}
|
||||
if _, err := c.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
||||
}
|
||||
|
@@ -125,6 +125,9 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
||||
a.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert auth request: %v", err)
|
||||
}
|
||||
return nil
|
||||
@@ -212,7 +215,14 @@ func (c *conn) CreateAuthCode(a storage.AuthCode) error {
|
||||
a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
|
||||
a.ConnectorID, a.ConnectorData, a.Expiry,
|
||||
)
|
||||
return err
|
||||
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert auth code: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
|
||||
@@ -256,6 +266,9 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
||||
r.Token, r.CreatedAt, r.LastUsed,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert refresh_token: %v", err)
|
||||
}
|
||||
return nil
|
||||
@@ -477,6 +490,9 @@ func (c *conn) CreateClient(cli storage.Client) error {
|
||||
cli.Public, cli.Name, cli.LogoURL,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert client: %v", err)
|
||||
}
|
||||
return nil
|
||||
@@ -544,6 +560,9 @@ func (c *conn) CreatePassword(p storage.Password) error {
|
||||
p.Email, p.Hash, p.Username, p.UserID,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert password: %v", err)
|
||||
}
|
||||
return nil
|
||||
@@ -636,6 +655,9 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
||||
s.UserID, s.ConnID, encoder(s.Refresh),
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert offline session: %v", err)
|
||||
}
|
||||
return nil
|
||||
|
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
@@ -21,7 +22,15 @@ func TestMigrate(t *testing.T) {
|
||||
Level: logrus.DebugLevel,
|
||||
}
|
||||
|
||||
c := &conn{db, flavorSQLite3, logger}
|
||||
errCheck := func(err error) bool {
|
||||
sqlErr, ok := err.(sqlite3.Error)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique
|
||||
}
|
||||
|
||||
c := &conn{db, flavorSQLite3, logger, errCheck}
|
||||
for _, want := range []int{len(migrations), 0} {
|
||||
got, err := c.migrate()
|
||||
if err != nil {
|
||||
|
@@ -131,9 +131,10 @@ func (c *conn) translateArgs(args []interface{}) []interface{} {
|
||||
|
||||
// conn is the main database connection.
|
||||
type conn struct {
|
||||
db *sql.DB
|
||||
flavor flavor
|
||||
logger logrus.FieldLogger
|
||||
db *sql.DB
|
||||
flavor flavor
|
||||
logger logrus.FieldLogger
|
||||
alreadyExistsCheck func(err error) bool
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
|
Reference in New Issue
Block a user