Merge remote-tracking branch 'upstream/master' into issue2289
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/dexidp/dex/api/v2"
|
||||
"github.com/dexidp/dex/pkg/log"
|
||||
@@ -41,7 +42,7 @@ func newAPI(s storage.Storage, logger log.Logger, t *testing.T) *apiClient {
|
||||
|
||||
// Dial will retry automatically if the serv.Serve() goroutine
|
||||
// hasn't started yet.
|
||||
conn, err := grpc.Dial(l.Addr().String(), grpc.WithInsecure())
|
||||
conn, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/pkg/log"
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
@@ -140,6 +141,10 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||
// https://tools.ietf.org/html/rfc8628#section-3.2
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
|
||||
// Response type should be application/json according to
|
||||
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetEscapeHTML(false)
|
||||
enc.SetIndent("", " ")
|
||||
@@ -152,7 +157,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceTokenDeprecated(w http.ResponseWriter, r *http.Request) {
|
||||
s.logger.Warn(`The deprecated "/device/token" endpoint was called. It will be removed, use "/token" instead.`)
|
||||
log.Deprecated(s.logger, `The /device/token endpoint was called. It will be removed, use /token instead.`)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.Method {
|
||||
|
@@ -52,6 +52,7 @@ func TestHandleDeviceCode(t *testing.T) {
|
||||
requestType string
|
||||
scopes []string
|
||||
expectedResponseCode int
|
||||
expectedContentType string
|
||||
expectedServerResponse string
|
||||
}{
|
||||
{
|
||||
@@ -60,6 +61,7 @@ func TestHandleDeviceCode(t *testing.T) {
|
||||
requestType: "POST",
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
expectedResponseCode: http.StatusOK,
|
||||
expectedContentType: "application/json",
|
||||
},
|
||||
{
|
||||
testName: "Invalid request Type (GET)",
|
||||
@@ -67,6 +69,7 @@ func TestHandleDeviceCode(t *testing.T) {
|
||||
requestType: "GET",
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
expectedContentType: "application/json",
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
@@ -101,6 +104,10 @@ func TestHandleDeviceCode(t *testing.T) {
|
||||
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
|
||||
}
|
||||
|
||||
if rr.Header().Get("content-type") != tc.expectedContentType {
|
||||
t.Errorf("Unexpected Response Content Type. Expected %v got %v", tc.expectedContentType, rr.Header().Get("content-type"))
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(rr.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Could read token response %v", err)
|
||||
|
14
server/handlers.go
Normal file → Executable file
14
server/handlers.go
Normal file → Executable file
@@ -95,7 +95,6 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
|
||||
UserInfo: s.absURL("/userinfo"),
|
||||
DeviceEndpoint: s.absURL("/device/code"),
|
||||
Subjects: []string{"public"},
|
||||
GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode},
|
||||
IDTokenAlgs: []string{string(jose.RS256)},
|
||||
CodeChallengeAlgs: []string{codeChallengeMethodS256, codeChallengeMethodPlain},
|
||||
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
|
||||
@@ -111,6 +110,8 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
|
||||
}
|
||||
sort.Strings(d.ResponseTypes)
|
||||
|
||||
d.GrantTypes = s.supportedGrantTypes
|
||||
|
||||
data, err := json.MarshalIndent(d, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal discovery data: %v", err)
|
||||
@@ -1122,10 +1123,17 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
|
||||
Groups: identity.Groups,
|
||||
}
|
||||
|
||||
accessToken := storage.NewID()
|
||||
accessToken, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("password grant failed to create new access token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, nonce, accessToken, "", connID)
|
||||
if err != nil {
|
||||
s.tokenErrHelper(w, errServerError, fmt.Sprintf("failed to create ID token: %v", err), http.StatusInternalServerError)
|
||||
s.logger.Errorf("password grant failed to create new ID token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
@@ -128,6 +128,7 @@ const (
|
||||
const (
|
||||
grantTypeAuthorizationCode = "authorization_code"
|
||||
grantTypeRefreshToken = "refresh_token"
|
||||
grantTypeImplicit = "implicit"
|
||||
grantTypePassword = "password"
|
||||
grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
)
|
||||
|
@@ -66,17 +66,18 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref
|
||||
|
||||
refresh, err := s.storage.GetRefresh(token.RefreshId)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to get refresh token: %v", err)
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get refresh token: %v", err)
|
||||
return nil, newInternalServerError()
|
||||
}
|
||||
|
||||
return nil, invalidErr
|
||||
}
|
||||
|
||||
if refresh.ClientID != clientID {
|
||||
s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID)
|
||||
return nil, invalidErr
|
||||
// According to https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 Dex should respond with an
|
||||
// invalid grant error if token has already been claimed by another client.
|
||||
return nil, &refreshError{msg: errInvalidGrant, desc: invalidErr.desc, code: http.StatusBadRequest}
|
||||
}
|
||||
|
||||
if refresh.Token != token.Token {
|
||||
@@ -227,16 +228,13 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora
|
||||
|
||||
lastUsed := s.now()
|
||||
|
||||
rerr := s.updateOfflineSession(refresh, ident, lastUsed)
|
||||
if rerr != nil {
|
||||
return nil, rerr
|
||||
}
|
||||
|
||||
refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
|
||||
if s.refreshTokenPolicy.RotationEnabled() {
|
||||
if old.Token != token.Token {
|
||||
if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == token.Token {
|
||||
newToken.Token = old.Token
|
||||
// Do not update last used time for offline session if token is allowed to be reused
|
||||
lastUsed = old.LastUsed
|
||||
return old, nil
|
||||
}
|
||||
return old, errors.New("refresh token claimed twice")
|
||||
@@ -268,6 +266,11 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora
|
||||
return nil, newInternalServerError()
|
||||
}
|
||||
|
||||
rerr := s.updateOfflineSession(refresh, ident, lastUsed)
|
||||
if rerr != nil {
|
||||
return nil, rerr
|
||||
}
|
||||
|
||||
return newToken, nil
|
||||
}
|
||||
|
||||
|
27
server/server.go
Normal file → Executable file
27
server/server.go
Normal file → Executable file
@@ -11,6 +11,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -37,6 +38,7 @@ import (
|
||||
"github.com/dexidp/dex/connector/linkedin"
|
||||
"github.com/dexidp/dex/connector/microsoft"
|
||||
"github.com/dexidp/dex/connector/mock"
|
||||
"github.com/dexidp/dex/connector/oauth"
|
||||
"github.com/dexidp/dex/connector/oidc"
|
||||
"github.com/dexidp/dex/connector/openshift"
|
||||
"github.com/dexidp/dex/connector/saml"
|
||||
@@ -169,6 +171,8 @@ type Server struct {
|
||||
|
||||
supportedResponseTypes map[string]bool
|
||||
|
||||
supportedGrantTypes []string
|
||||
|
||||
now func() time.Time
|
||||
|
||||
idTokensValidFor time.Duration
|
||||
@@ -209,16 +213,29 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||
c.SupportedResponseTypes = []string{responseTypeCode}
|
||||
}
|
||||
|
||||
supported := make(map[string]bool)
|
||||
supportedGrant := []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode} // default
|
||||
supportedRes := make(map[string]bool)
|
||||
|
||||
for _, respType := range c.SupportedResponseTypes {
|
||||
switch respType {
|
||||
case responseTypeCode, responseTypeIDToken, responseTypeToken:
|
||||
case responseTypeCode, responseTypeIDToken:
|
||||
// continue
|
||||
case responseTypeToken:
|
||||
// response_type=token is an implicit flow, let's add it to the discovery info
|
||||
// https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.1
|
||||
supportedGrant = append(supportedGrant, grantTypeImplicit)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported response_type %q", respType)
|
||||
}
|
||||
supported[respType] = true
|
||||
supportedRes[respType] = true
|
||||
}
|
||||
|
||||
if c.PasswordConnector != "" {
|
||||
supportedGrant = append(supportedGrant, grantTypePassword)
|
||||
}
|
||||
|
||||
sort.Strings(supportedGrant)
|
||||
|
||||
webFS := web.FS()
|
||||
if c.Web.Dir != "" {
|
||||
webFS = os.DirFS(c.Web.Dir)
|
||||
@@ -249,7 +266,8 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||
issuerURL: *issuerURL,
|
||||
connectors: make(map[string]Connector),
|
||||
storage: newKeyCacher(c.Storage, now),
|
||||
supportedResponseTypes: supported,
|
||||
supportedResponseTypes: supportedRes,
|
||||
supportedGrantTypes: supportedGrant,
|
||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
|
||||
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
|
||||
@@ -528,6 +546,7 @@ var ConnectorsConfig = map[string]func() ConnectorConfig{
|
||||
"gitlab": func() ConnectorConfig { return new(gitlab.Config) },
|
||||
"google": func() ConnectorConfig { return new(google.Config) },
|
||||
"oidc": func() ConnectorConfig { return new(oidc.Config) },
|
||||
"oauth": func() ConnectorConfig { return new(oauth.Config) },
|
||||
"saml": func() ConnectorConfig { return new(saml.Config) },
|
||||
"authproxy": func() ConnectorConfig { return new(authproxy.Config) },
|
||||
"linkedin": func() ConnectorConfig { return new(linkedin.Config) },
|
||||
|
@@ -481,6 +481,47 @@ func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "refresh with different client id",
|
||||
scopes: []string{"openid", "email"},
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||
v := url.Values{}
|
||||
v.Add("client_id", clientID)
|
||||
v.Add("client_secret", clientSecret)
|
||||
v.Add("grant_type", "refresh_token")
|
||||
v.Add("refresh_token", "existedrefrestoken")
|
||||
v.Add("scope", "oidc email")
|
||||
resp, err := http.PostForm(p.Endpoint().TokenURL, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
return fmt.Errorf("expected status code %d, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
|
||||
var respErr struct {
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description"`
|
||||
}
|
||||
|
||||
if err = json.NewDecoder(resp.Body).Decode(&respErr); err != nil {
|
||||
return fmt.Errorf("cannot decode token response: %v", err)
|
||||
}
|
||||
|
||||
if respErr.Error != errInvalidGrant {
|
||||
return fmt.Errorf("expected error %q, got %q", errInvalidGrant, respErr.Error)
|
||||
}
|
||||
|
||||
expectedMsg := "Refresh token is invalid or has already been claimed by another client."
|
||||
if respErr.Description != expectedMsg {
|
||||
return fmt.Errorf("expected error description %q, got %q", expectedMsg, respErr.Description)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
// This test ensures that the connector.RefreshConnector interface is being
|
||||
// used when clients request a refresh token.
|
||||
@@ -792,6 +833,13 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
if err := s.storage.CreateRefresh(storage.RefreshToken{
|
||||
ID: "existedrefrestoken",
|
||||
ClientID: "unexcistedclientid",
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to create existed refresh token: %v", err)
|
||||
}
|
||||
|
||||
// Create the OAuth2 config.
|
||||
oauth2Config = &oauth2.Config{
|
||||
ClientID: client.ID,
|
||||
@@ -1570,6 +1618,13 @@ func TestOAuth2DeviceFlow(t *testing.T) {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
if err := s.storage.CreateRefresh(storage.RefreshToken{
|
||||
ID: "existedrefrestoken",
|
||||
ClientID: "unexcistedclientid",
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to create existed refresh token: %v", err)
|
||||
}
|
||||
|
||||
// Grab the issuer that we'll reuse for the different endpoints to hit
|
||||
issuer, err := url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
@@ -1680,3 +1735,42 @@ func TestOAuth2DeviceFlow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerSupportedGrants(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config func(c *Config)
|
||||
resGrants []string
|
||||
}{
|
||||
{
|
||||
name: "Simple",
|
||||
config: func(c *Config) {},
|
||||
resGrants: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode},
|
||||
},
|
||||
{
|
||||
name: "With password connector",
|
||||
config: func(c *Config) { c.PasswordConnector = "local" },
|
||||
resGrants: []string{grantTypeAuthorizationCode, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode},
|
||||
},
|
||||
{
|
||||
name: "With token response",
|
||||
config: func(c *Config) { c.SupportedResponseTypes = append(c.SupportedResponseTypes, responseTypeToken) },
|
||||
resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypeRefreshToken, grantTypeDeviceCode},
|
||||
},
|
||||
{
|
||||
name: "All",
|
||||
config: func(c *Config) {
|
||||
c.PasswordConnector = "local"
|
||||
c.SupportedResponseTypes = append(c.SupportedResponseTypes, responseTypeToken)
|
||||
},
|
||||
resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, srv := newTestServer(context.TODO(), t, tc.config)
|
||||
require.Equal(t, srv.supportedGrantTypes, tc.resGrants)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -8,7 +8,6 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
@@ -109,7 +108,7 @@ func loadWebConfig(c webConfig) (http.Handler, http.Handler, *templates, error)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("read static dir: %v", err)
|
||||
}
|
||||
themeFiles, err := fs.Sub(c.webFS, filepath.Join("themes", c.theme))
|
||||
themeFiles, err := fs.Sub(c.webFS, path.Join("themes", c.theme))
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("read themes dir: %v", err)
|
||||
}
|
||||
@@ -133,7 +132,7 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
filenames = append(filenames, filepath.Join(templatesDir, file.Name()))
|
||||
filenames = append(filenames, path.Join(templatesDir, file.Name()))
|
||||
}
|
||||
if len(filenames) == 0 {
|
||||
return nil, fmt.Errorf("no files in template dir %q", templatesDir)
|
||||
@@ -239,6 +238,9 @@ var scopeDescriptions = map[string]string{
|
||||
"offline_access": "Have offline access",
|
||||
"profile": "View basic profile information",
|
||||
"email": "View your email address",
|
||||
// 'groups' is not a standard OIDC scope, and Dex only returns groups only if the upstream provider does too.
|
||||
// This warning is added for convenience to show that the user may expose some sensitive data to the application.
|
||||
"groups": "View your groups",
|
||||
}
|
||||
|
||||
type connectorInfo struct {
|
||||
|
Reference in New Issue
Block a user