storage: Surface "already exists" errors.

This commit is contained in:
rithu john
2017-02-21 15:00:22 -08:00
parent 7e9dc836eb
commit 3df1db1864
7 changed files with 119 additions and 18 deletions

View File

@@ -8,6 +8,13 @@ import (
"github.com/Sirupsen/logrus"
"github.com/coreos/dex/storage"
"github.com/lib/pq"
sqlite3 "github.com/mattn/go-sqlite3"
)
const (
// postgres error codes
pgErrUniqueViolation = "23505" // unique_violation
)
// SQLite3 options for creating an SQL db.
@@ -35,7 +42,16 @@ func (s *SQLite3) open(logger logrus.FieldLogger) (*conn, error) {
// doesn't support this, so limit the number of connections to 1.
db.SetMaxOpenConns(1)
}
c := &conn{db, flavorSQLite3, logger}
errCheck := func(err error) bool {
sqlErr, ok := err.(sqlite3.Error)
if !ok {
return false
}
return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey
}
c := &conn{db, flavorSQLite3, logger, errCheck}
if _, err := c.migrate(); err != nil {
return nil, fmt.Errorf("failed to perform migrations: %v", err)
}
@@ -114,7 +130,16 @@ func (p *Postgres) open(logger logrus.FieldLogger) (*conn, error) {
if err != nil {
return nil, err
}
c := &conn{db, flavorPostgres, logger}
errCheck := func(err error) bool {
sqlErr, ok := err.(*pq.Error)
if !ok {
return false
}
return sqlErr.Code == pgErrUniqueViolation
}
c := &conn{db, flavorPostgres, logger, errCheck}
if _, err := c.migrate(); err != nil {
return nil, fmt.Errorf("failed to perform migrations: %v", err)
}

View File

@@ -125,6 +125,9 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
a.Expiry,
)
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert auth request: %v", err)
}
return nil
@@ -212,7 +215,14 @@ func (c *conn) CreateAuthCode(a storage.AuthCode) error {
a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
a.ConnectorID, a.ConnectorData, a.Expiry,
)
return err
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert auth code: %v", err)
}
return nil
}
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
@@ -256,6 +266,9 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
r.Token, r.CreatedAt, r.LastUsed,
)
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert refresh_token: %v", err)
}
return nil
@@ -477,6 +490,9 @@ func (c *conn) CreateClient(cli storage.Client) error {
cli.Public, cli.Name, cli.LogoURL,
)
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert client: %v", err)
}
return nil
@@ -544,6 +560,9 @@ func (c *conn) CreatePassword(p storage.Password) error {
p.Email, p.Hash, p.Username, p.UserID,
)
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert password: %v", err)
}
return nil
@@ -636,6 +655,9 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
s.UserID, s.ConnID, encoder(s.Refresh),
)
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert offline session: %v", err)
}
return nil

View File

@@ -6,6 +6,7 @@ import (
"testing"
"github.com/Sirupsen/logrus"
sqlite3 "github.com/mattn/go-sqlite3"
)
func TestMigrate(t *testing.T) {
@@ -21,7 +22,15 @@ func TestMigrate(t *testing.T) {
Level: logrus.DebugLevel,
}
c := &conn{db, flavorSQLite3, logger}
errCheck := func(err error) bool {
sqlErr, ok := err.(sqlite3.Error)
if !ok {
return false
}
return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique
}
c := &conn{db, flavorSQLite3, logger, errCheck}
for _, want := range []int{len(migrations), 0} {
got, err := c.migrate()
if err != nil {

View File

@@ -131,9 +131,10 @@ func (c *conn) translateArgs(args []interface{}) []interface{} {
// conn is the main database connection.
type conn struct {
db *sql.DB
flavor flavor
logger logrus.FieldLogger
db *sql.DB
flavor flavor
logger logrus.FieldLogger
alreadyExistsCheck func(err error) bool
}
func (c *conn) Close() error {