diff --git a/server/handlers.go b/server/handlers.go index 7b9f3a94..b8430ebc 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -155,14 +155,22 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { return } - authReq.Expiry = s.now().Add(time.Minute * 30) + // TODO(ericchiang): Create this authorization request later in the login flow + // so users don't hit "not found" database errors if they wait at the login + // screen too long. + // + // See: https://github.com/coreos/dex/issues/646 + authReq.Expiry = s.now().Add(24 * time.Hour) // Totally arbitrary value. if err := s.storage.CreateAuthRequest(authReq); err != nil { s.logger.Errorf("Failed to create authorization request: %v", err) s.renderError(w, http.StatusInternalServerError, "Failed to connect to the database.") return } + if len(s.connectors) == 1 { for id := range s.connectors { + // TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter + // on create the auth request. http.Redirect(w, r, s.absPath("/auth", id)+"?req="+authReq.ID, http.StatusFound) return } @@ -174,12 +182,14 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { connectorInfos[i] = connectorInfo{ ID: id, Name: conn.DisplayName, - URL: s.absPath("/auth", id), + // TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter + // on create the auth request. + URL: s.absPath("/auth", id) + "?req=" + authReq.ID, } i++ } - if err := s.templates.login(w, connectorInfos, authReq.ID); err != nil { + if err := s.templates.login(w, connectorInfos); err != nil { s.logger.Errorf("Server template error: %v", err) } } @@ -198,7 +208,11 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { authReq, err := s.storage.GetAuthRequest(authReqID) if err != nil { s.logger.Errorf("Failed to get auth request: %v", err) - s.renderError(w, http.StatusInternalServerError, "Database error.") + if err == storage.ErrNotFound { + s.renderError(w, http.StatusBadRequest, "Login session expired.") + } else { + s.renderError(w, http.StatusInternalServerError, "Database error.") + } return } scopes := parseScopes(authReq.Scopes) @@ -229,7 +243,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } http.Redirect(w, r, callbackURL, http.StatusFound) case connector.PasswordConnector: - if err := s.templates.password(w, authReqID, r.URL.String(), "", false); err != nil { + if err := s.templates.password(w, r.URL.String(), "", false); err != nil { s.logger.Errorf("Server template error: %v", err) } case connector.SAMLConnector: @@ -277,7 +291,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } if !ok { - if err := s.templates.password(w, authReqID, r.URL.String(), username, true); err != nil { + if err := s.templates.password(w, r.URL.String(), username, true); err != nil { s.logger.Errorf("Server template error: %v", err) } return diff --git a/server/templates.go b/server/templates.go index 6e1f1c8d..4c11e2c4 100644 --- a/server/templates.go +++ b/server/templates.go @@ -2,6 +2,7 @@ package server import ( "fmt" + "html/template" "io" "io/ioutil" "net/http" @@ -9,7 +10,6 @@ import ( "path/filepath" "sort" "strings" - "text/template" ) const ( @@ -181,23 +181,20 @@ func (n byName) Len() int { return len(n) } func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name } func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] } -func (t *templates) login(w http.ResponseWriter, connectors []connectorInfo, authReqID string) error { +func (t *templates) login(w http.ResponseWriter, connectors []connectorInfo) error { sort.Sort(byName(connectors)) - data := struct { Connectors []connectorInfo - AuthReqID string - }{connectors, authReqID} + }{connectors} return renderTemplate(w, t.loginTmpl, data) } -func (t *templates) password(w http.ResponseWriter, authReqID, callback, lastUsername string, lastWasInvalid bool) error { +func (t *templates) password(w http.ResponseWriter, postURL, lastUsername string, lastWasInvalid bool) error { data := struct { - AuthReqID string - PostURL string - Username string - Invalid bool - }{authReqID, string(callback), lastUsername, lastWasInvalid} + PostURL string + Username string + Invalid bool + }{postURL, lastUsername, lastWasInvalid} return renderTemplate(w, t.passwordTmpl, data) } diff --git a/web/templates/login.html b/web/templates/login.html index 10c5dbbb..56151a78 100644 --- a/web/templates/login.html +++ b/web/templates/login.html @@ -5,7 +5,7 @@