From 312ca7491e16107f23d0235be709c3e11473716e Mon Sep 17 00:00:00 2001
From: Eric Chiang
Date: Thu, 22 Dec 2016 15:56:09 -0800
Subject: [PATCH 1/4] storage: add extra fields to refresh token and update
method
---
storage/conformance/conformance.go | 26 +++++++++++--
storage/kubernetes/storage.go | 53 ++++++++++++-------------
storage/kubernetes/types.go | 42 ++++++++++++++++++++
storage/memory/memory.go | 18 ++++++++-
storage/sql/crud.go | 62 ++++++++++++++++++++++++++----
storage/sql/migrate.go | 10 +++++
storage/storage.go | 12 +++++-
7 files changed, 180 insertions(+), 43 deletions(-)
diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go
index 8cb911aa..0a6fe1c9 100644
--- a/storage/conformance/conformance.go
+++ b/storage/conformance/conformance.go
@@ -208,10 +208,14 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
id := storage.NewID()
refresh := storage.RefreshToken{
- RefreshToken: id,
- ClientID: "client_id",
- ConnectorID: "client_secret",
- Scopes: []string{"openid", "email", "profile"},
+ ID: id,
+ Token: "bar",
+ Nonce: "foo",
+ ClientID: "client_id",
+ ConnectorID: "client_secret",
+ Scopes: []string{"openid", "email", "profile"},
+ CreatedAt: time.Now().UTC().Round(time.Millisecond),
+ LastUsed: time.Now().UTC().Round(time.Millisecond),
Claims: storage.Claims{
UserID: "1",
Username: "jane",
@@ -238,6 +242,20 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
getAndCompare(id, refresh)
+ updatedAt := time.Now().UTC().Round(time.Millisecond)
+
+ updater := func(r storage.RefreshToken) (storage.RefreshToken, error) {
+ r.Token = "spam"
+ r.LastUsed = updatedAt
+ return r, nil
+ }
+ if err := s.UpdateRefreshToken(id, updater); err != nil {
+ t.Errorf("failed to udpate refresh token: %v", err)
+ }
+ refresh.Token = "spam"
+ refresh.LastUsed = updatedAt
+ getAndCompare(id, refresh)
+
if err := s.DeleteRefresh(id); err != nil {
t.Fatalf("failed to delete refresh request: %v", err)
}
diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go
index e744ab2d..102a7494 100644
--- a/storage/kubernetes/storage.go
+++ b/storage/kubernetes/storage.go
@@ -153,23 +153,7 @@ func (cli *client) CreatePassword(p storage.Password) error {
}
func (cli *client) CreateRefresh(r storage.RefreshToken) error {
- refresh := RefreshToken{
- TypeMeta: k8sapi.TypeMeta{
- Kind: kindRefreshToken,
- APIVersion: cli.apiVersion,
- },
- ObjectMeta: k8sapi.ObjectMeta{
- Name: r.RefreshToken,
- Namespace: cli.namespace,
- },
- ClientID: r.ClientID,
- ConnectorID: r.ConnectorID,
- Scopes: r.Scopes,
- Nonce: r.Nonce,
- Claims: fromStorageClaims(r.Claims),
- ConnectorData: r.ConnectorData,
- }
- return cli.post(resourceRefreshToken, refresh)
+ return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r))
}
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
@@ -239,19 +223,16 @@ func (cli *client) GetKeys() (storage.Keys, error) {
}
func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) {
- var r RefreshToken
- if err := cli.get(resourceRefreshToken, id, &r); err != nil {
+ r, err := cli.getRefreshToken(id)
+ if err != nil {
return storage.RefreshToken{}, err
}
- return storage.RefreshToken{
- RefreshToken: r.ObjectMeta.Name,
- ClientID: r.ClientID,
- ConnectorID: r.ConnectorID,
- Scopes: r.Scopes,
- Nonce: r.Nonce,
- Claims: toStorageClaims(r.Claims),
- ConnectorData: r.ConnectorData,
- }, nil
+ return toStorageRefreshToken(r), nil
+}
+
+func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) {
+ err = cli.get(resourceRefreshToken, id, &r)
+ return
}
func (cli *client) ListClients() ([]storage.Client, error) {
@@ -311,6 +292,22 @@ func (cli *client) DeletePassword(email string) error {
return cli.delete(resourcePassword, p.ObjectMeta.Name)
}
+func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
+ r, err := cli.getRefreshToken(id)
+ if err != nil {
+ return err
+ }
+ updated, err := updater(toStorageRefreshToken(r))
+ if err != nil {
+ return err
+ }
+ updated.ID = id
+
+ newToken := cli.fromStorageRefreshToken(updated)
+ newToken.ObjectMeta = r.ObjectMeta
+ return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken)
+}
+
func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
c, err := cli.getClient(id)
if err != nil {
diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go
index 9009c800..660f86d8 100644
--- a/storage/kubernetes/types.go
+++ b/storage/kubernetes/types.go
@@ -362,9 +362,14 @@ type RefreshToken struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
+ CreatedAt time.Time
+ LastUsed time.Time
+
ClientID string `json:"clientID"`
Scopes []string `json:"scopes,omitempty"`
+ Token string `json:"token,omitempty"`
+
Nonce string `json:"nonce,omitempty"`
Claims Claims `json:"claims,omitempty"`
@@ -379,6 +384,43 @@ type RefreshList struct {
RefreshTokens []RefreshToken `json:"items"`
}
+func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
+ return storage.RefreshToken{
+ ID: r.ObjectMeta.Name,
+ Token: r.Token,
+ CreatedAt: r.CreatedAt,
+ LastUsed: r.LastUsed,
+ ClientID: r.ClientID,
+ ConnectorID: r.ConnectorID,
+ ConnectorData: r.ConnectorData,
+ Scopes: r.Scopes,
+ Nonce: r.Nonce,
+ Claims: toStorageClaims(r.Claims),
+ }
+}
+
+func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
+ return RefreshToken{
+ TypeMeta: k8sapi.TypeMeta{
+ Kind: kindRefreshToken,
+ APIVersion: cli.apiVersion,
+ },
+ ObjectMeta: k8sapi.ObjectMeta{
+ Name: r.ID,
+ Namespace: cli.namespace,
+ },
+ Token: r.Token,
+ CreatedAt: r.CreatedAt,
+ LastUsed: r.LastUsed,
+ ClientID: r.ClientID,
+ ConnectorID: r.ConnectorID,
+ ConnectorData: r.ConnectorData,
+ Scopes: r.Scopes,
+ Nonce: r.Nonce,
+ Claims: fromStorageClaims(r.Claims),
+ }
+}
+
// Keys is a mirrored struct from storage with JSON struct tags and Kubernetes
// type metadata.
type Keys struct {
diff --git a/storage/memory/memory.go b/storage/memory/memory.go
index 6d609717..8bfbdce2 100644
--- a/storage/memory/memory.go
+++ b/storage/memory/memory.go
@@ -98,10 +98,10 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) {
func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) {
s.tx(func() {
- if _, ok := s.refreshTokens[r.RefreshToken]; ok {
+ if _, ok := s.refreshTokens[r.ID]; ok {
err = storage.ErrAlreadyExists
} else {
- s.refreshTokens[r.RefreshToken] = r
+ s.refreshTokens[r.ID] = r
}
})
return
@@ -324,3 +324,17 @@ func (s *memStorage) UpdatePassword(email string, updater func(p storage.Passwor
})
return
}
+
+func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.RefreshToken) (storage.RefreshToken, error)) (err error) {
+ s.tx(func() {
+ r, ok := s.refreshTokens[id]
+ if !ok {
+ err = storage.ErrNotFound
+ return
+ }
+ if r, err = updater(r); err == nil {
+ s.refreshTokens[id] = r
+ }
+ })
+ return
+}
diff --git a/storage/sql/crud.go b/storage/sql/crud.go
index e3270363..494f1c20 100644
--- a/storage/sql/crud.go
+++ b/storage/sql/crud.go
@@ -244,14 +244,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
- connector_id, connector_data
+ connector_id, connector_data,
+ token, created_at, last_used
)
- values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11);
+ values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14);
`,
- r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce,
+ r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
encoder(r.Claims.Groups),
r.ConnectorID, r.ConnectorData,
+ r.Token, r.CreatedAt, r.LastUsed,
)
if err != nil {
return fmt.Errorf("insert refresh_token: %v", err)
@@ -259,13 +261,57 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
return nil
}
+func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
+ return c.ExecTx(func(tx *trans) error {
+ r, err := getRefresh(tx, id)
+ if err != nil {
+ return err
+ }
+ if r, err = updater(r); err != nil {
+ return err
+ }
+ _, err = tx.Exec(`
+ update refresh_token
+ set
+ client_id = $1,
+ scopes = $2,
+ nonce = $3,
+ claims_user_id = $4,
+ claims_username = $5,
+ claims_email = $6,
+ claims_email_verified = $7,
+ claims_groups = $8,
+ connector_id = $9,
+ connector_data = $10,
+ token = $11,
+ created_at = $12,
+ last_used = $13
+ `,
+ r.ClientID, encoder(r.Scopes), r.Nonce,
+ r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
+ encoder(r.Claims.Groups),
+ r.ConnectorID, r.ConnectorData,
+ r.Token, r.CreatedAt, r.LastUsed,
+ )
+ if err != nil {
+ return fmt.Errorf("update refresh token: %v", err)
+ }
+ return nil
+ })
+}
+
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
- return scanRefresh(c.QueryRow(`
+ return getRefresh(c, id)
+}
+
+func getRefresh(q querier, id string) (storage.RefreshToken, error) {
+ return scanRefresh(q.QueryRow(`
select
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
- connector_id, connector_data
+ connector_id, connector_data,
+ token, created_at, last_used
from refresh_token where id = $1;
`, id))
}
@@ -276,7 +322,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
- connector_id, connector_data
+ connector_id, connector_data,
+ token, created_at, last_used
from refresh_token;
`)
if err != nil {
@@ -298,10 +345,11 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
err = s.Scan(
- &r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
+ &r.ID, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified,
decoder(&r.Claims.Groups),
&r.ConnectorID, &r.ConnectorData,
+ &r.Token, &r.CreatedAt, &r.LastUsed,
)
if err != nil {
if err == sql.ErrNoRows {
diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go
index 3bb410aa..b2b66d39 100644
--- a/storage/sql/migrate.go
+++ b/storage/sql/migrate.go
@@ -155,4 +155,14 @@ var migrations = []migration{
);
`,
},
+ {
+ stmt: `
+ alter table refresh_token
+ add column token text not null default '';
+ alter table refresh_token
+ add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';
+ alter table refresh_token
+ add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
+ `,
+ },
}
diff --git a/storage/storage.go b/storage/storage.go
index 22a9ea50..47f5dcc6 100644
--- a/storage/storage.go
+++ b/storage/storage.go
@@ -94,6 +94,7 @@ type Storage interface {
UpdateClient(id string, updater func(old Client) (Client, error)) error
UpdateKeys(updater func(old Keys) (Keys, error)) error
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
+ UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error
UpdatePassword(email string, updater func(p Password) (Password, error)) error
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
@@ -216,8 +217,15 @@ type AuthCode struct {
// RefreshToken is an OAuth2 refresh token which allows a client to request new
// tokens on the end user's behalf.
type RefreshToken struct {
- // The actual refresh token.
- RefreshToken string
+ ID string
+
+ // A single token that's rotated every time the refresh token is refreshed.
+ //
+ // May be empty.
+ Token string
+
+ CreatedAt time.Time
+ LastUsed time.Time
// Client this refresh token is valid for.
ClientID string
From f778b2d33be336926be112b9b203c49b914492bf Mon Sep 17 00:00:00 2001
From: Eric Chiang
Date: Thu, 22 Dec 2016 16:41:30 -0800
Subject: [PATCH 2/4] server: update refresh tokens instead of deleting and
creating another
The server implements a strategy called "Refresh Token Rotation" to
ensure refresh tokens can only be claimed once.
ref: https://tools.ietf.org/html/rfc6819#section-5.2.2.3
Previously "refresh_token" values in token responses where just the
ID of the internal refresh object. To implement rotation, when a
client redeemed a refresh token, the object would be deleted, a new
one created, and the new ID returned as the new "refresh_token".
However, this means there was no consistent ID for refresh tokens
internally, making things like foreign keys very hard to implement.
This is problematic for revocation features like showing all the
refresh tokens a user or client has out.
This PR updates the "refresh_token" to be an encoded protobuf
message, which holds the internal ID and a nonce. When a refresh
token is used, the nonce is updated to prevent reuse, but the ID
remains the same. Additionally it adds the timestamp of each
token's last use.
---
Makefile | 9 ++-
server/handlers.go | 124 ++++++++++++++++++++++++++----------
server/internal/codec.go | 25 ++++++++
server/internal/types.proto | 10 +++
server/server_test.go | 4 ++
5 files changed, 134 insertions(+), 38 deletions(-)
create mode 100644 server/internal/codec.go
create mode 100644 server/internal/types.proto
diff --git a/Makefile b/Makefile
index 8006982a..9519a186 100644
--- a/Makefile
+++ b/Makefile
@@ -55,7 +55,7 @@ fmt:
@go fmt $(shell go list ./... | grep -v '/vendor/')
lint:
- @for package in $(shell go list ./... | grep -v '/vendor/' | grep -v '/api'); do \
+ @for package in $(shell go list ./... | grep -v '/vendor/' | grep -v '/api' | grep -v '/server/internal'); do \
golint -set_exit_status $$package $$i || exit 1; \
done
@@ -81,12 +81,15 @@ aci: clean-release _output/bin/dex _output/images/library-alpine-3.4.aci
docker-image: clean-release _output/bin/dex
@sudo docker build -t $(DOCKER_IMAGE) .
-.PHONY: grpc
-grpc: api/api.pb.go
+.PHONY: proto
+proto: api/api.pb.go server/internal/types.pb.go
api/api.pb.go: api/api.proto bin/protoc bin/protoc-gen-go
@protoc --go_out=plugins=grpc:. api/*.proto
+server/internal/types.pb.go: server/internal/types.proto bin/protoc bin/protoc-gen-go
+ @protoc --go_out=. server/internal/*.proto
+
bin/protoc: scripts/get-protoc
@./scripts/get-protoc bin/protoc
diff --git a/server/handlers.go b/server/handlers.go
index 808f8031..ef264dfe 100644
--- a/server/handlers.go
+++ b/server/handlers.go
@@ -2,6 +2,7 @@ package server
import (
"encoding/json"
+ "errors"
"fmt"
"net/http"
"net/url"
@@ -16,6 +17,7 @@ import (
jose "gopkg.in/square/go-jose.v2"
"github.com/coreos/dex/connector"
+ "github.com/coreos/dex/server/internal"
"github.com/coreos/dex/storage"
)
@@ -645,20 +647,32 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
var refreshToken string
if reqRefresh {
refresh := storage.RefreshToken{
- RefreshToken: storage.NewID(),
+ ID: storage.NewID(),
+ Token: storage.NewID(),
ClientID: authCode.ClientID,
ConnectorID: authCode.ConnectorID,
Scopes: authCode.Scopes,
Claims: authCode.Claims,
Nonce: authCode.Nonce,
ConnectorData: authCode.ConnectorData,
+ CreatedAt: s.now(),
+ LastUsed: s.now(),
}
+ token := &internal.RefreshToken{
+ RefreshId: refresh.ID,
+ Token: refresh.Token,
+ }
+ if refreshToken, err = internal.Marshal(token); err != nil {
+ s.logger.Errorf("failed to marshal refresh token: %v", err)
+ s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
+ return
+ }
+
if err := s.storage.CreateRefresh(refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
- refreshToken = refresh.RefreshToken
}
s.writeAccessToken(w, idToken, refreshToken, expiry)
}
@@ -672,16 +686,37 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return
}
- refresh, err := s.storage.GetRefresh(code)
- if err != nil || refresh.ClientID != client.ID {
- if err != storage.ErrNotFound {
- s.logger.Errorf("failed to get auth code: %v", err)
- s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
- } else {
+ token := new(internal.RefreshToken)
+ if err := internal.Unmarshal(code, token); err != nil {
+ // For backward compatibility, assume the refresh_token is a raw refresh token ID
+ // if it fails to decode.
+ //
+ // Because refresh_token values that aren't unmarshable were generated by servers
+ // that don't have a Token value, we'll still reject any attempts to claim a
+ // refresh_token twice.
+ token = &internal.RefreshToken{RefreshId: code, Token: ""}
+ }
+
+ refresh, err := s.storage.GetRefresh(token.RefreshId)
+ if err != nil {
+ s.logger.Errorf("failed to get refresh token: %v", err)
+ if err == storage.ErrNotFound {
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
+ } else {
+ s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
}
return
}
+ if refresh.ClientID != client.ID {
+ s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID)
+ s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
+ return
+ }
+ if refresh.Token != token.Token {
+ s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
+ s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
+ return
+ }
// Per the OAuth2 spec, if the client has omitted the scopes, default to the original
// authorized scopes.
@@ -720,6 +755,14 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
+ ident := connector.Identity{
+ UserID: refresh.Claims.UserID,
+ Username: refresh.Claims.Username,
+ Email: refresh.Claims.Email,
+ EmailVerified: refresh.Claims.EmailVerified,
+ Groups: refresh.Claims.Groups,
+ ConnectorData: refresh.ConnectorData,
+ }
// Can the connector refresh the identity? If so, attempt to refresh the data
// in the connector.
@@ -727,52 +770,63 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
// TODO(ericchiang): We may want a strict mode where connectors that don't implement
// this interface can't perform refreshing.
if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
- ident := connector.Identity{
- UserID: refresh.Claims.UserID,
- Username: refresh.Claims.Username,
- Email: refresh.Claims.Email,
- EmailVerified: refresh.Claims.EmailVerified,
- Groups: refresh.Claims.Groups,
- ConnectorData: refresh.ConnectorData,
- }
- ident, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident)
+ newIdent, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident)
if err != nil {
s.logger.Errorf("failed to refresh identity: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
-
- // Update the claims of the refresh token.
- //
- // UserID intentionally ignored for now.
- refresh.Claims.Username = ident.Username
- refresh.Claims.Email = ident.Email
- refresh.Claims.EmailVerified = ident.EmailVerified
- refresh.Claims.Groups = ident.Groups
- refresh.ConnectorData = ident.ConnectorData
+ ident = newIdent
}
- idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce)
+ claims := storage.Claims{
+ UserID: ident.UserID,
+ Username: ident.Username,
+ Email: ident.Email,
+ EmailVerified: ident.EmailVerified,
+ Groups: ident.Groups,
+ }
+
+ idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce)
if err != nil {
s.logger.Errorf("failed to create ID token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
- // Refresh tokens are claimed exactly once. Delete the current token and
- // create a new one.
- if err := s.storage.DeleteRefresh(code); err != nil {
- s.logger.Errorf("failed to delete auth code: %v", err)
+ newToken := &internal.RefreshToken{
+ RefreshId: refresh.ID,
+ Token: storage.NewID(),
+ }
+ rawNewToken, err := internal.Marshal(newToken)
+ if err != nil {
+ s.logger.Errorf("failed to marshal refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
- refresh.RefreshToken = storage.NewID()
- if err := s.storage.CreateRefresh(refresh); err != nil {
- s.logger.Errorf("failed to create refresh token: %v", err)
+
+ updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
+ if old.Token != refresh.Token {
+ return old, errors.New("refresh token claimed twice")
+ }
+ old.Token = newToken.Token
+ // Update the claims of the refresh token.
+ //
+ // UserID intentionally ignored for now.
+ old.Claims.Username = ident.Username
+ old.Claims.Email = ident.Email
+ old.Claims.EmailVerified = ident.EmailVerified
+ old.Claims.Groups = ident.Groups
+ old.ConnectorData = ident.ConnectorData
+ old.LastUsed = s.now()
+ return old, nil
+ }
+ if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil {
+ s.logger.Errorf("failed to update refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
- s.writeAccessToken(w, idToken, refresh.RefreshToken, expiry)
+ s.writeAccessToken(w, idToken, rawNewToken, expiry)
}
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken string, expiry time.Time) {
diff --git a/server/internal/codec.go b/server/internal/codec.go
new file mode 100644
index 00000000..a92c26f9
--- /dev/null
+++ b/server/internal/codec.go
@@ -0,0 +1,25 @@
+package internal
+
+import (
+ "encoding/base64"
+
+ "github.com/golang/protobuf/proto"
+)
+
+// Marshal converts a protobuf message to a URL legal string.
+func Marshal(message proto.Message) (string, error) {
+ data, err := proto.Marshal(message)
+ if err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(data), nil
+}
+
+// Unmarshal decodes a protobuf message.
+func Unmarshal(s string, message proto.Message) error {
+ data, err := base64.RawURLEncoding.DecodeString(s)
+ if err != nil {
+ return err
+ }
+ return proto.Unmarshal(data, message)
+}
diff --git a/server/internal/types.proto b/server/internal/types.proto
new file mode 100644
index 00000000..442dbd95
--- /dev/null
+++ b/server/internal/types.proto
@@ -0,0 +1,10 @@
+syntax = "proto3";
+
+// Package internal holds protobuf types used by the server
+package internal;
+
+// RefreshToken is a message that holds refresh token data used by dex.
+message RefreshToken {
+ string refresh_id = 1;
+ string token = 2;
+}
diff --git a/server/server_test.go b/server/server_test.go
index 7c499c15..d848076f 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -237,6 +237,10 @@ func TestOAuth2CodeFlow(t *testing.T) {
if token.RefreshToken == newToken.RefreshToken {
return fmt.Errorf("old refresh token was the same as the new token %q", token.RefreshToken)
}
+
+ if _, err := config.TokenSource(ctx, token).Token(); err == nil {
+ return errors.New("was able to redeem the same refresh token twice")
+ }
return nil
},
},
From f4bbab50561ea899417dc6e74972efdeeaa15935 Mon Sep 17 00:00:00 2001
From: Eric Chiang
Date: Thu, 22 Dec 2016 16:50:31 -0800
Subject: [PATCH 3/4] server/internal: generate protobuf types
---
server/internal/types.pb.go | 59 +++++++++++++++++++++++++++++++++++++
1 file changed, 59 insertions(+)
create mode 100644 server/internal/types.pb.go
diff --git a/server/internal/types.pb.go b/server/internal/types.pb.go
new file mode 100644
index 00000000..791944f5
--- /dev/null
+++ b/server/internal/types.pb.go
@@ -0,0 +1,59 @@
+// Code generated by protoc-gen-go.
+// source: server/internal/types.proto
+// DO NOT EDIT!
+
+/*
+Package internal is a generated protocol buffer package.
+
+Package internal holds protobuf types used by the server
+
+It is generated from these files:
+ server/internal/types.proto
+
+It has these top-level messages:
+ RefreshToken
+*/
+package internal
+
+import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
+import math "math"
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
+
+// RefreshToken is a message that holds refresh token data used by dex.
+type RefreshToken struct {
+ RefreshId string `protobuf:"bytes,1,opt,name=refresh_id,json=refreshId" json:"refresh_id,omitempty"`
+ Token string `protobuf:"bytes,2,opt,name=token" json:"token,omitempty"`
+}
+
+func (m *RefreshToken) Reset() { *m = RefreshToken{} }
+func (m *RefreshToken) String() string { return proto.CompactTextString(m) }
+func (*RefreshToken) ProtoMessage() {}
+func (*RefreshToken) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
+
+func init() {
+ proto.RegisterType((*RefreshToken)(nil), "internal.RefreshToken")
+}
+
+func init() { proto.RegisterFile("server/internal/types.proto", fileDescriptor0) }
+
+var fileDescriptor0 = []byte{
+ // 112 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0x92, 0x2e, 0x4e, 0x2d, 0x2a,
+ 0x4b, 0x2d, 0xd2, 0xcf, 0xcc, 0x2b, 0x49, 0x2d, 0xca, 0x4b, 0xcc, 0xd1, 0x2f, 0xa9, 0x2c, 0x48,
+ 0x2d, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x80, 0x89, 0x2a, 0x39, 0x73, 0xf1, 0x04,
+ 0xa5, 0xa6, 0x15, 0xa5, 0x16, 0x67, 0x84, 0xe4, 0x67, 0xa7, 0xe6, 0x09, 0xc9, 0x72, 0x71, 0x15,
+ 0x41, 0xf8, 0xf1, 0x99, 0x29, 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x9c, 0x50, 0x11, 0xcf,
+ 0x14, 0x21, 0x11, 0x2e, 0xd6, 0x12, 0x90, 0x3a, 0x09, 0x26, 0xb0, 0x0c, 0x84, 0x93, 0xc4, 0x06,
+ 0x36, 0xd5, 0x18, 0x10, 0x00, 0x00, 0xff, 0xff, 0x9b, 0xd0, 0x5a, 0x1d, 0x74, 0x00, 0x00, 0x00,
+}
From ed20fee2b9691705c65aeedb9f17182a104b8328 Mon Sep 17 00:00:00 2001
From: Eric Chiang
Date: Thu, 22 Dec 2016 16:58:21 -0800
Subject: [PATCH 4/4] cmd/example-app: fix refreshing
---
cmd/example-app/main.go | 42 +++++++++++++++++++++---------------
cmd/example-app/templates.go | 9 ++++++--
2 files changed, 32 insertions(+), 19 deletions(-)
diff --git a/cmd/example-app/main.go b/cmd/example-app/main.go
index ffa21c29..3ec34e38 100644
--- a/cmd/example-app/main.go
+++ b/cmd/example-app/main.go
@@ -241,7 +241,7 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) {
authCodeURL := ""
scopes = append(scopes, "openid", "profile", "email")
- if r.FormValue("offline_acecss") != "yes" {
+ if r.FormValue("offline_access") != "yes" {
authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState)
} else if a.offlineAsScope {
scopes = append(scopes, "offline_access")
@@ -254,34 +254,42 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) {
}
func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
- if errMsg := r.FormValue("error"); errMsg != "" {
- http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
- return
- }
-
- if state := r.FormValue("state"); state != exampleAppState {
- http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest)
- return
- }
-
- code := r.FormValue("code")
- refresh := r.FormValue("refresh_token")
var (
err error
token *oauth2.Token
)
oauth2Config := a.oauth2Config(nil)
- switch {
- case code != "":
+ switch r.Method {
+ case "GET":
+ // Authorization redirect callback from OAuth2 auth flow.
+ if errMsg := r.FormValue("error"); errMsg != "" {
+ http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
+ return
+ }
+ code := r.FormValue("code")
+ if code == "" {
+ http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest)
+ return
+ }
+ if state := r.FormValue("state"); state != exampleAppState {
+ http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest)
+ return
+ }
token, err = oauth2Config.Exchange(a.ctx, code)
- case refresh != "":
+ case "POST":
+ // Form request from frontend to refresh a token.
+ refresh := r.FormValue("refresh_token")
+ if refresh == "" {
+ http.Error(w, fmt.Sprintf("no refresh_token in request: %q", r.Form), http.StatusBadRequest)
+ return
+ }
t := &oauth2.Token{
RefreshToken: refresh,
Expiry: time.Now().Add(-time.Hour),
}
token, err = oauth2Config.TokenSource(r.Context(), t).Token()
default:
- http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest)
+ http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
return
}
diff --git a/cmd/example-app/templates.go b/cmd/example-app/templates.go
index c0f9dfbd..a870d0f0 100644
--- a/cmd/example-app/templates.go
+++ b/cmd/example-app/templates.go
@@ -8,7 +8,7 @@ import (
var indexTmpl = template.Must(template.New("index.html").Parse(`
-
Claims:
{{ .Claims }}
+ {{ if .RefreshToken }}
Refresh Token:
{{ .RefreshToken }}
- Redeem refresh token
+
+ {{ end }}
`))