storage: Add OfflineSession object to backend storage.

This commit is contained in:
rithu john
2017-01-31 16:11:59 -08:00
parent 49f446c1a7
commit d928ac0677
10 changed files with 580 additions and 32 deletions

View File

@@ -682,6 +682,75 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
// deleteToken determines if we need to delete the newly created refresh token
// due to a failure in updating/creating the OfflineSession object for the
// corresponding user.
var deleteToken bool
defer func() {
if deleteToken {
// Delete newly created refresh token from storage.
if err := s.storage.DeleteRefresh(refresh.ID); err != nil {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
}
}()
tokenRef := storage.RefreshTokenRef{
ID: refresh.ID,
ClientID: refresh.ClientID,
CreatedAt: refresh.CreatedAt,
LastUsed: refresh.LastUsed,
}
// Try to retrieve an existing OfflineSession object for the corresponding user.
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
offlineSessions := storage.OfflineSessions{
UserID: refresh.Claims.UserID,
ConnID: refresh.ConnectorID,
Refresh: make(map[string]*storage.RefreshTokenRef),
}
offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef
// Create a new OfflineSession object for the user and add a reference object for
// the newly recieved refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
} else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
}
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
}
}
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
}
@@ -815,6 +884,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return
}
lastUsed := s.now()
updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
if old.Token != refresh.Token {
return old, errors.New("refresh token claimed twice")
@@ -828,14 +898,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
old.Claims.EmailVerified = ident.EmailVerified
old.Claims.Groups = ident.Groups
old.ConnectorData = ident.ConnectorData
old.LastUsed = s.now()
old.LastUsed = lastUsed
return old, nil
}
// Update LastUsed time stamp in refresh token reference object
// in offline session for the user.
if err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if old.Refresh[refresh.ClientID].ID != refresh.ID {
return old, errors.New("refresh token invalid")
}
old.Refresh[refresh.ClientID].LastUsed = lastUsed
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
// Update refresh token in the storage.
if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil {
s.logger.Errorf("failed to update refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
}

View File

@@ -971,3 +971,108 @@ func TestKeyCacher(t *testing.T) {
}
}
}
type oauth2Client struct {
config *oauth2.Config
token *oauth2.Token
server *httptest.Server
}
// TestRefreshTokenFlow tests the refresh token code flow for oauth2. The test verifies
// that only valid refresh tokens can be used to refresh an expired token.
func TestRefreshTokenFlow(t *testing.T) {
state := "state"
now := func() time.Time { return time.Now() }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.Now = now
})
defer httpServer.Close()
p, err := oidc.NewProvider(ctx, httpServer.URL)
if err != nil {
t.Fatalf("failed to get provider: %v", err)
}
var oauth2Client oauth2Client
oauth2Client.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/callback" {
// User is visiting app first time. Redirect to dex.
http.Redirect(w, r, oauth2Client.config.AuthCodeURL(state), http.StatusSeeOther)
return
}
// User is at '/callback' so they were just redirected _from_ dex.
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
if desc := q.Get("error_description"); desc != "" {
t.Errorf("got error from server %s: %s", errType, desc)
} else {
t.Errorf("got error from server %s", errType)
}
w.WriteHeader(http.StatusInternalServerError)
return
}
// Grab code, exchange for token.
if code := q.Get("code"); code != "" {
token, err := oauth2Client.config.Exchange(ctx, code)
if err != nil {
t.Errorf("failed to exchange code for token: %v", err)
return
}
oauth2Client.token = token
}
// Ensure state matches.
if gotState := q.Get("state"); gotState != state {
t.Errorf("state did not match, want=%q got=%q", state, gotState)
}
w.WriteHeader(http.StatusOK)
return
}))
defer oauth2Client.server.Close()
// Register the client above with dex.
redirectURL := oauth2Client.server.URL + "/callback"
client := storage.Client{
ID: "testclient",
Secret: "testclientsecret",
RedirectURIs: []string{redirectURL},
}
if err := s.storage.CreateClient(client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
oauth2Client.config = &oauth2.Config{
ClientID: client.ID,
ClientSecret: client.Secret,
Endpoint: p.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "email", "offline_access"},
RedirectURL: redirectURL,
}
if _, err = http.Get(oauth2Client.server.URL + "/login"); err != nil {
t.Fatalf("get failed: %v", err)
}
tok := &oauth2.Token{
RefreshToken: oauth2Client.token.RefreshToken,
Expiry: time.Now().Add(-time.Hour),
}
// Login in again to recieve a new token.
if _, err = http.Get(oauth2Client.server.URL + "/login"); err != nil {
t.Fatalf("get failed: %v", err)
}
// try to refresh expired token with old refresh token.
newToken, err := oauth2Client.config.TokenSource(ctx, tok).Token()
if newToken != nil {
t.Errorf("Token refreshed with invalid refresh token.")
}
}