Make /device/token deprecation warning more concise
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
parent
9ed5cc00cf
commit
3bd0e91a68
@ -151,9 +151,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleDeviceTokenGrant(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleDeviceTokenDeprecated(w http.ResponseWriter, r *http.Request) {
|
||||||
s.logger.Warn(`Request to the deprecated "/device/token" endpoint was received.`)
|
s.logger.Warn(`The deprecated "/device/token" endpoint was called. It will be removed, use "/token" instead.`)
|
||||||
s.logger.Warn(`The "/device/token" endpoint will be removed in a future release.`)
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
|
@ -321,7 +321,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
|||||||
handleFunc("/device/auth/verify_code", s.verifyUserCode)
|
handleFunc("/device/auth/verify_code", s.verifyUserCode)
|
||||||
handleFunc("/device/code", s.handleDeviceCode)
|
handleFunc("/device/code", s.handleDeviceCode)
|
||||||
// TODO(nabokihms): "/device/token" endpoint is deprecated, consider using /token endpoint instead
|
// TODO(nabokihms): "/device/token" endpoint is deprecated, consider using /token endpoint instead
|
||||||
handleFunc("/device/token", s.handleDeviceTokenGrant)
|
handleFunc("/device/token", s.handleDeviceTokenDeprecated)
|
||||||
handleFunc(deviceCallbackURI, s.handleDeviceCallback)
|
handleFunc(deviceCallbackURI, s.handleDeviceCallback)
|
||||||
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Strip the X-Remote-* headers to prevent security issues on
|
// Strip the X-Remote-* headers to prevent security issues on
|
||||||
|
@ -1497,143 +1497,164 @@ func TestOAuth2DeviceFlow(t *testing.T) {
|
|||||||
var conn *mock.Callback
|
var conn *mock.Callback
|
||||||
idTokensValidFor := time.Second * 30
|
idTokensValidFor := time.Second * 30
|
||||||
|
|
||||||
for _, tc := range makeOAuth2Tests(clientID, clientSecret, now).tests {
|
tests := makeOAuth2Tests(clientID, clientSecret, now)
|
||||||
func() {
|
testCases := []struct {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
name string
|
||||||
defer cancel()
|
tokenEndpoint string
|
||||||
|
oauth2Tests oauth2Tests
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Actual token endpoint for devices",
|
||||||
|
tokenEndpoint: "/token",
|
||||||
|
oauth2Tests: tests,
|
||||||
|
},
|
||||||
|
// TODO(nabokihms): delete temporary tests after removing the deprecated token endpoint support
|
||||||
|
{
|
||||||
|
name: "Deprecated token endpoint for devices",
|
||||||
|
tokenEndpoint: "/device/token",
|
||||||
|
oauth2Tests: tests,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// Setup a dex server.
|
for _, testCase := range testCases {
|
||||||
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
for _, tc := range testCase.oauth2Tests.tests {
|
||||||
c.Issuer += "/non-root-path"
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
c.Now = now
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
c.IDTokensValidFor = idTokensValidFor
|
defer cancel()
|
||||||
|
|
||||||
|
// Setup a dex server.
|
||||||
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||||
|
c.Issuer += "/non-root-path"
|
||||||
|
c.Now = now
|
||||||
|
c.IDTokensValidFor = idTokensValidFor
|
||||||
|
})
|
||||||
|
defer httpServer.Close()
|
||||||
|
|
||||||
|
mockConn := s.connectors["mock"]
|
||||||
|
conn = mockConn.Connector.(*mock.Callback)
|
||||||
|
|
||||||
|
p, err := oidc.NewProvider(ctx, httpServer.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get provider: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the Clients to the test server
|
||||||
|
client := storage.Client{
|
||||||
|
ID: clientID,
|
||||||
|
RedirectURIs: []string{deviceCallbackURI},
|
||||||
|
Public: true,
|
||||||
|
}
|
||||||
|
if err := s.storage.CreateClient(client); err != nil {
|
||||||
|
t.Fatalf("failed to create client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grab the issuer that we'll reuse for the different endpoints to hit
|
||||||
|
issuer, err := url.Parse(s.issuerURL.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could not parse issuer URL %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a new Device Request
|
||||||
|
codeURL, _ := url.Parse(issuer.String())
|
||||||
|
codeURL.Path = path.Join(codeURL.Path, "device/code")
|
||||||
|
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("client_id", clientID)
|
||||||
|
data.Add("scope", strings.Join(requestedScopes, " "))
|
||||||
|
resp, err := http.PostForm(codeURL.String(), data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could not request device code: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
responseBody, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could read device code response %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||||
|
}
|
||||||
|
if resp.Header.Get("Cache-Control") != "no-store" {
|
||||||
|
t.Errorf("Cache-Control header doesn't exist in Device Code Response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the code response
|
||||||
|
var deviceCode deviceCodeResponse
|
||||||
|
if err := json.Unmarshal(responseBody, &deviceCode); err != nil {
|
||||||
|
t.Errorf("Unexpected Device Code Response Format %v", string(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock the user hitting the verification URI and posting the form
|
||||||
|
verifyURL, _ := url.Parse(issuer.String())
|
||||||
|
verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code")
|
||||||
|
urlData := url.Values{}
|
||||||
|
urlData.Set("user_code", deviceCode.UserCode)
|
||||||
|
resp, err = http.PostForm(verifyURL.String(), urlData)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error Posting Form: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
responseBody, err = ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could read verification response %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hit the Token Endpoint, and try and get an access token
|
||||||
|
tokenURL, _ := url.Parse(issuer.String())
|
||||||
|
tokenURL.Path = path.Join(tokenURL.Path, testCase.tokenEndpoint)
|
||||||
|
v := url.Values{}
|
||||||
|
v.Add("grant_type", grantTypeDeviceCode)
|
||||||
|
v.Add("device_code", deviceCode.DeviceCode)
|
||||||
|
resp, err = http.PostForm(tokenURL.String(), v)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could not request device token: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
responseBody, err = ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could read device token response %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the response
|
||||||
|
var tokenRes accessTokenResponse
|
||||||
|
if err := json.Unmarshal(responseBody, &tokenRes); err != nil {
|
||||||
|
t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: tokenRes.AccessToken,
|
||||||
|
TokenType: tokenRes.TokenType,
|
||||||
|
RefreshToken: tokenRes.RefreshToken,
|
||||||
|
}
|
||||||
|
raw := make(map[string]interface{})
|
||||||
|
json.Unmarshal(responseBody, &raw) // no error checks for optional fields
|
||||||
|
token = token.WithExtra(raw)
|
||||||
|
if secs := tokenRes.ExpiresIn; secs > 0 {
|
||||||
|
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run token tests to validate info is correct
|
||||||
|
// Create the OAuth2 config.
|
||||||
|
oauth2Config := &oauth2.Config{
|
||||||
|
ClientID: client.ID,
|
||||||
|
ClientSecret: client.Secret,
|
||||||
|
Endpoint: p.Endpoint(),
|
||||||
|
Scopes: requestedScopes,
|
||||||
|
RedirectURL: deviceCallbackURI,
|
||||||
|
}
|
||||||
|
if len(tc.scopes) != 0 {
|
||||||
|
oauth2Config.Scopes = tc.scopes
|
||||||
|
}
|
||||||
|
err = tc.handleToken(ctx, p, oauth2Config, token, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%s: %v", tc.name, err)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
defer httpServer.Close()
|
}
|
||||||
|
|
||||||
mockConn := s.connectors["mock"]
|
|
||||||
conn = mockConn.Connector.(*mock.Callback)
|
|
||||||
|
|
||||||
p, err := oidc.NewProvider(ctx, httpServer.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get provider: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the Clients to the test server
|
|
||||||
client := storage.Client{
|
|
||||||
ID: clientID,
|
|
||||||
RedirectURIs: []string{deviceCallbackURI},
|
|
||||||
Public: true,
|
|
||||||
}
|
|
||||||
if err := s.storage.CreateClient(client); err != nil {
|
|
||||||
t.Fatalf("failed to create client: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Grab the issuer that we'll reuse for the different endpoints to hit
|
|
||||||
issuer, err := url.Parse(s.issuerURL.String())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Could not parse issuer URL %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send a new Device Request
|
|
||||||
codeURL, _ := url.Parse(issuer.String())
|
|
||||||
codeURL.Path = path.Join(codeURL.Path, "device/code")
|
|
||||||
|
|
||||||
data := url.Values{}
|
|
||||||
data.Set("client_id", clientID)
|
|
||||||
data.Add("scope", strings.Join(requestedScopes, " "))
|
|
||||||
resp, err := http.PostForm(codeURL.String(), data)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Could not request device code: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
responseBody, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Could read device code response %v", err)
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
|
||||||
}
|
|
||||||
if resp.Header.Get("Cache-Control") != "no-store" {
|
|
||||||
t.Errorf("Cache-Control header doesn't exist in Device Code Response")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the code response
|
|
||||||
var deviceCode deviceCodeResponse
|
|
||||||
if err := json.Unmarshal(responseBody, &deviceCode); err != nil {
|
|
||||||
t.Errorf("Unexpected Device Code Response Format %v", string(responseBody))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mock the user hitting the verification URI and posting the form
|
|
||||||
verifyURL, _ := url.Parse(issuer.String())
|
|
||||||
verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code")
|
|
||||||
urlData := url.Values{}
|
|
||||||
urlData.Set("user_code", deviceCode.UserCode)
|
|
||||||
resp, err = http.PostForm(verifyURL.String(), urlData)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error Posting Form: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
responseBody, err = ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Could read verification response %v", err)
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hit the Token Endpoint, and try and get an access token
|
|
||||||
tokenURL, _ := url.Parse(issuer.String())
|
|
||||||
tokenURL.Path = path.Join(tokenURL.Path, "/token")
|
|
||||||
v := url.Values{}
|
|
||||||
v.Add("grant_type", grantTypeDeviceCode)
|
|
||||||
v.Add("device_code", deviceCode.DeviceCode)
|
|
||||||
resp, err = http.PostForm(tokenURL.String(), v)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Could not request device token: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
responseBody, err = ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Could read device token response %v", err)
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the response
|
|
||||||
var tokenRes accessTokenResponse
|
|
||||||
if err := json.Unmarshal(responseBody, &tokenRes); err != nil {
|
|
||||||
t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody))
|
|
||||||
}
|
|
||||||
|
|
||||||
token := &oauth2.Token{
|
|
||||||
AccessToken: tokenRes.AccessToken,
|
|
||||||
TokenType: tokenRes.TokenType,
|
|
||||||
RefreshToken: tokenRes.RefreshToken,
|
|
||||||
}
|
|
||||||
raw := make(map[string]interface{})
|
|
||||||
json.Unmarshal(responseBody, &raw) // no error checks for optional fields
|
|
||||||
token = token.WithExtra(raw)
|
|
||||||
if secs := tokenRes.ExpiresIn; secs > 0 {
|
|
||||||
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run token tests to validate info is correct
|
|
||||||
// Create the OAuth2 config.
|
|
||||||
oauth2Config := &oauth2.Config{
|
|
||||||
ClientID: client.ID,
|
|
||||||
ClientSecret: client.Secret,
|
|
||||||
Endpoint: p.Endpoint(),
|
|
||||||
Scopes: requestedScopes,
|
|
||||||
RedirectURL: deviceCallbackURI,
|
|
||||||
}
|
|
||||||
if len(tc.scopes) != 0 {
|
|
||||||
oauth2Config.Scopes = tc.scopes
|
|
||||||
}
|
|
||||||
err = tc.handleToken(ctx, p, oauth2Config, token, conn)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("%s: %v", tc.name, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user