// Commit 1cc26fab2f
// fix: prevent cross-site scripting for the device flow
// 705 lines, 21 KiB, Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"path"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/dexidp/dex/storage"
|
|
)
|
|
|
|
func TestDeviceVerificationURI(t *testing.T) {
|
|
t0 := time.Now()
|
|
|
|
now := func() time.Time { return t0 }
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
// Setup a dex server.
|
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
|
c.Issuer += "/non-root-path"
|
|
c.Now = now
|
|
})
|
|
defer httpServer.Close()
|
|
|
|
u, err := url.Parse(s.issuerURL.String())
|
|
if err != nil {
|
|
t.Fatalf("Could not parse issuer URL %v", err)
|
|
}
|
|
u.Path = path.Join(u.Path, "/device/auth/verify_code")
|
|
|
|
uri := s.getDeviceVerificationURI()
|
|
if uri != u.Path {
|
|
t.Errorf("Invalid verification URI. Expected %v got %v", u.Path, uri)
|
|
}
|
|
}
|
|
|
|
func TestHandleDeviceCode(t *testing.T) {
|
|
t0 := time.Now()
|
|
|
|
now := func() time.Time { return t0 }
|
|
|
|
tests := []struct {
|
|
testName string
|
|
clientID string
|
|
requestType string
|
|
scopes []string
|
|
expectedResponseCode int
|
|
expectedContentType string
|
|
expectedServerResponse string
|
|
}{
|
|
{
|
|
testName: "New Code",
|
|
clientID: "test",
|
|
requestType: "POST",
|
|
scopes: []string{"openid", "profile", "email"},
|
|
expectedResponseCode: http.StatusOK,
|
|
expectedContentType: "application/json",
|
|
},
|
|
{
|
|
testName: "Invalid request Type (GET)",
|
|
clientID: "test",
|
|
requestType: "GET",
|
|
scopes: []string{"openid", "profile", "email"},
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
expectedContentType: "application/json",
|
|
},
|
|
}
|
|
for _, tc := range tests {
|
|
t.Run(tc.testName, func(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Setup a dex server.
|
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
|
c.Issuer += "/non-root-path"
|
|
c.Now = now
|
|
})
|
|
defer httpServer.Close()
|
|
|
|
u, err := url.Parse(s.issuerURL.String())
|
|
if err != nil {
|
|
t.Fatalf("Could not parse issuer URL %v", err)
|
|
}
|
|
u.Path = path.Join(u.Path, "device/code")
|
|
|
|
data := url.Values{}
|
|
data.Set("client_id", tc.clientID)
|
|
for _, scope := range tc.scopes {
|
|
data.Add("scope", scope)
|
|
}
|
|
req, _ := http.NewRequest(tc.requestType, u.String(), bytes.NewBufferString(data.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
|
|
|
rr := httptest.NewRecorder()
|
|
s.ServeHTTP(rr, req)
|
|
if rr.Code != tc.expectedResponseCode {
|
|
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
|
|
}
|
|
|
|
if rr.Header().Get("content-type") != tc.expectedContentType {
|
|
t.Errorf("Unexpected Response Content Type. Expected %v got %v", tc.expectedContentType, rr.Header().Get("content-type"))
|
|
}
|
|
|
|
body, err := io.ReadAll(rr.Body)
|
|
if err != nil {
|
|
t.Errorf("Could read token response %v", err)
|
|
}
|
|
if tc.expectedResponseCode == http.StatusOK {
|
|
var resp deviceCodeResponse
|
|
if err := json.Unmarshal(body, &resp); err != nil {
|
|
t.Errorf("Unexpected Device Code Response Format %v", string(body))
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDeviceCallback(t *testing.T) {
|
|
t0 := time.Now()
|
|
|
|
now := func() time.Time { return t0 }
|
|
|
|
type formValues struct {
|
|
state string
|
|
code string
|
|
error string
|
|
}
|
|
|
|
// Base "Control" test values
|
|
baseFormValues := formValues{
|
|
state: "XXXX-XXXX",
|
|
code: "somecode",
|
|
}
|
|
baseAuthCode := storage.AuthCode{
|
|
ID: "somecode",
|
|
ClientID: "testclient",
|
|
RedirectURI: deviceCallbackURI,
|
|
Nonce: "",
|
|
Scopes: []string{"openid", "profile", "email"},
|
|
ConnectorID: "mock",
|
|
ConnectorData: nil,
|
|
Claims: storage.Claims{},
|
|
Expiry: now().Add(5 * time.Minute),
|
|
}
|
|
baseDeviceRequest := storage.DeviceRequest{
|
|
UserCode: "XXXX-XXXX",
|
|
DeviceCode: "devicecode",
|
|
ClientID: "testclient",
|
|
ClientSecret: "",
|
|
Scopes: []string{"openid", "profile", "email"},
|
|
Expiry: now().Add(5 * time.Minute),
|
|
}
|
|
baseDeviceToken := storage.DeviceToken{
|
|
DeviceCode: "devicecode",
|
|
Status: deviceTokenPending,
|
|
Token: "",
|
|
Expiry: now().Add(5 * time.Minute),
|
|
LastRequestTime: time.Time{},
|
|
PollIntervalSeconds: 0,
|
|
}
|
|
|
|
tests := []struct {
|
|
testName string
|
|
expectedResponseCode int
|
|
expectedServerResponse string
|
|
values formValues
|
|
testAuthCode storage.AuthCode
|
|
testDeviceRequest storage.DeviceRequest
|
|
testDeviceToken storage.DeviceToken
|
|
}{
|
|
{
|
|
testName: "Missing State",
|
|
values: formValues{
|
|
state: "",
|
|
code: "somecode",
|
|
error: "",
|
|
},
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Missing Code",
|
|
values: formValues{
|
|
state: "XXXX-XXXX",
|
|
code: "",
|
|
error: "",
|
|
},
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Error During Authorization",
|
|
values: formValues{
|
|
state: "XXXX-XXXX",
|
|
code: "somecode",
|
|
error: "Error Condition",
|
|
},
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
expectedServerResponse: "Error Condition: \n",
|
|
},
|
|
{
|
|
testName: "Expired Auth Code",
|
|
values: baseFormValues,
|
|
testAuthCode: storage.AuthCode{
|
|
ID: "somecode",
|
|
ClientID: "testclient",
|
|
RedirectURI: deviceCallbackURI,
|
|
Nonce: "",
|
|
Scopes: []string{"openid", "profile", "email"},
|
|
ConnectorID: "pic",
|
|
ConnectorData: nil,
|
|
Claims: storage.Claims{},
|
|
Expiry: now().Add(-5 * time.Minute),
|
|
},
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Invalid Auth Code",
|
|
values: baseFormValues,
|
|
testAuthCode: storage.AuthCode{
|
|
ID: "somecode",
|
|
ClientID: "testclient",
|
|
RedirectURI: deviceCallbackURI,
|
|
Nonce: "",
|
|
Scopes: []string{"openid", "profile", "email"},
|
|
ConnectorID: "pic",
|
|
ConnectorData: nil,
|
|
Claims: storage.Claims{},
|
|
Expiry: now().Add(5 * time.Minute),
|
|
},
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Expired Device Request",
|
|
values: baseFormValues,
|
|
testAuthCode: baseAuthCode,
|
|
testDeviceRequest: storage.DeviceRequest{
|
|
UserCode: "XXXX-XXXX",
|
|
DeviceCode: "devicecode",
|
|
ClientID: "testclient",
|
|
Scopes: []string{"openid", "profile", "email"},
|
|
Expiry: now().Add(-5 * time.Minute),
|
|
},
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Non-Existent User Code",
|
|
values: baseFormValues,
|
|
testAuthCode: baseAuthCode,
|
|
testDeviceRequest: storage.DeviceRequest{
|
|
UserCode: "ZZZZ-ZZZZ",
|
|
DeviceCode: "devicecode",
|
|
Scopes: []string{"openid", "profile", "email"},
|
|
Expiry: now().Add(5 * time.Minute),
|
|
},
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Bad Device Request Client",
|
|
values: baseFormValues,
|
|
testAuthCode: baseAuthCode,
|
|
testDeviceRequest: storage.DeviceRequest{
|
|
UserCode: "XXXX-XXXX",
|
|
DeviceCode: "devicecode",
|
|
Scopes: []string{"openid", "profile", "email"},
|
|
Expiry: now().Add(5 * time.Minute),
|
|
},
|
|
expectedResponseCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
testName: "Bad Device Request Secret",
|
|
values: baseFormValues,
|
|
testAuthCode: baseAuthCode,
|
|
testDeviceRequest: storage.DeviceRequest{
|
|
UserCode: "XXXX-XXXX",
|
|
DeviceCode: "devicecode",
|
|
ClientSecret: "foobar",
|
|
Scopes: []string{"openid", "profile", "email"},
|
|
Expiry: now().Add(5 * time.Minute),
|
|
},
|
|
expectedResponseCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
testName: "Expired Device Token",
|
|
values: baseFormValues,
|
|
testAuthCode: baseAuthCode,
|
|
testDeviceRequest: baseDeviceRequest,
|
|
testDeviceToken: storage.DeviceToken{
|
|
DeviceCode: "devicecode",
|
|
Status: deviceTokenPending,
|
|
Token: "",
|
|
Expiry: now().Add(-5 * time.Minute),
|
|
LastRequestTime: time.Time{},
|
|
PollIntervalSeconds: 0,
|
|
},
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Device Code Already Redeemed",
|
|
values: baseFormValues,
|
|
testAuthCode: baseAuthCode,
|
|
testDeviceRequest: baseDeviceRequest,
|
|
testDeviceToken: storage.DeviceToken{
|
|
DeviceCode: "devicecode",
|
|
Status: deviceTokenComplete,
|
|
Token: "",
|
|
Expiry: now().Add(5 * time.Minute),
|
|
LastRequestTime: time.Time{},
|
|
PollIntervalSeconds: 0,
|
|
},
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Successful Exchange",
|
|
values: baseFormValues,
|
|
testAuthCode: baseAuthCode,
|
|
testDeviceRequest: baseDeviceRequest,
|
|
testDeviceToken: baseDeviceToken,
|
|
expectedResponseCode: http.StatusOK,
|
|
},
|
|
{
|
|
testName: "Prevent cross-site scripting",
|
|
values: formValues{
|
|
state: "XXXX-XXXX",
|
|
code: "somecode",
|
|
error: "<script>console.log(window);</script>",
|
|
},
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
expectedServerResponse: "<script>console.log(window);</script>: \n",
|
|
},
|
|
}
|
|
for _, tc := range tests {
|
|
t.Run(tc.testName, func(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Setup a dex server.
|
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
|
// c.Issuer = c.Issuer + "/non-root-path"
|
|
c.Now = now
|
|
})
|
|
defer httpServer.Close()
|
|
|
|
if err := s.storage.CreateAuthCode(tc.testAuthCode); err != nil {
|
|
t.Fatalf("failed to create auth code: %v", err)
|
|
}
|
|
|
|
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
|
|
t.Fatalf("failed to create device request: %v", err)
|
|
}
|
|
|
|
if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
|
|
t.Fatalf("failed to create device token: %v", err)
|
|
}
|
|
|
|
client := storage.Client{
|
|
ID: "testclient",
|
|
Secret: "",
|
|
RedirectURIs: []string{deviceCallbackURI},
|
|
}
|
|
if err := s.storage.CreateClient(client); err != nil {
|
|
t.Fatalf("failed to create client: %v", err)
|
|
}
|
|
|
|
u, err := url.Parse(s.issuerURL.String())
|
|
if err != nil {
|
|
t.Fatalf("Could not parse issuer URL %v", err)
|
|
}
|
|
u.Path = path.Join(u.Path, "device/callback")
|
|
q := u.Query()
|
|
q.Set("state", tc.values.state)
|
|
q.Set("code", tc.values.code)
|
|
q.Set("error", tc.values.error)
|
|
u.RawQuery = q.Encode()
|
|
req, _ := http.NewRequest("GET", u.String(), nil)
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
|
|
|
rr := httptest.NewRecorder()
|
|
s.ServeHTTP(rr, req)
|
|
if rr.Code != tc.expectedResponseCode {
|
|
t.Errorf("%s: Unexpected Response Type. Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code)
|
|
}
|
|
|
|
if len(tc.expectedServerResponse) > 0 {
|
|
result, _ := io.ReadAll(rr.Body)
|
|
if string(result) != tc.expectedServerResponse {
|
|
t.Errorf("%s: Unexpected Response. Expected %q got %q", tc.testName, tc.expectedServerResponse, result)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDeviceTokenResponse(t *testing.T) {
|
|
t0 := time.Now()
|
|
|
|
now := func() time.Time { return t0 }
|
|
|
|
baseDeviceRequest := storage.DeviceRequest{
|
|
UserCode: "ABCD-WXYZ",
|
|
DeviceCode: "foo",
|
|
ClientID: "testclient",
|
|
Scopes: []string{"openid", "profile", "offline_access"},
|
|
Expiry: now().Add(5 * time.Minute),
|
|
}
|
|
|
|
tests := []struct {
|
|
testName string
|
|
testDeviceRequest storage.DeviceRequest
|
|
testDeviceToken storage.DeviceToken
|
|
testGrantType string
|
|
testDeviceCode string
|
|
expectedServerResponse string
|
|
expectedResponseCode int
|
|
}{
|
|
{
|
|
testName: "Valid but pending token",
|
|
testDeviceRequest: baseDeviceRequest,
|
|
testDeviceToken: storage.DeviceToken{
|
|
DeviceCode: "f00bar",
|
|
Status: deviceTokenPending,
|
|
Token: "",
|
|
Expiry: now().Add(5 * time.Minute),
|
|
LastRequestTime: time.Time{},
|
|
PollIntervalSeconds: 0,
|
|
},
|
|
testDeviceCode: "f00bar",
|
|
expectedServerResponse: deviceTokenPending,
|
|
expectedResponseCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
testName: "Invalid Grant Type",
|
|
testDeviceRequest: baseDeviceRequest,
|
|
testDeviceToken: storage.DeviceToken{
|
|
DeviceCode: "f00bar",
|
|
Status: deviceTokenPending,
|
|
Token: "",
|
|
Expiry: now().Add(5 * time.Minute),
|
|
LastRequestTime: time.Time{},
|
|
PollIntervalSeconds: 0,
|
|
},
|
|
testDeviceCode: "f00bar",
|
|
testGrantType: grantTypeAuthorizationCode,
|
|
expectedServerResponse: errInvalidGrant,
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Test Slow Down State",
|
|
testDeviceRequest: baseDeviceRequest,
|
|
testDeviceToken: storage.DeviceToken{
|
|
DeviceCode: "f00bar",
|
|
Status: deviceTokenPending,
|
|
Token: "",
|
|
Expiry: now().Add(5 * time.Minute),
|
|
LastRequestTime: now(),
|
|
PollIntervalSeconds: 10,
|
|
},
|
|
testDeviceCode: "f00bar",
|
|
expectedServerResponse: deviceTokenSlowDown,
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Test Expired Device Token",
|
|
testDeviceRequest: baseDeviceRequest,
|
|
testDeviceToken: storage.DeviceToken{
|
|
DeviceCode: "f00bar",
|
|
Status: deviceTokenPending,
|
|
Token: "",
|
|
Expiry: now().Add(-5 * time.Minute),
|
|
LastRequestTime: time.Time{},
|
|
PollIntervalSeconds: 0,
|
|
},
|
|
testDeviceCode: "f00bar",
|
|
expectedServerResponse: deviceTokenExpired,
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Test Non-existent Device Code",
|
|
testDeviceRequest: baseDeviceRequest,
|
|
testDeviceToken: storage.DeviceToken{
|
|
DeviceCode: "foo",
|
|
Status: deviceTokenPending,
|
|
Token: "",
|
|
Expiry: now().Add(-5 * time.Minute),
|
|
LastRequestTime: time.Time{},
|
|
PollIntervalSeconds: 0,
|
|
},
|
|
testDeviceCode: "bar",
|
|
expectedServerResponse: errInvalidRequest,
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Empty Device Code in Request",
|
|
testDeviceRequest: baseDeviceRequest,
|
|
testDeviceToken: storage.DeviceToken{
|
|
DeviceCode: "bar",
|
|
Status: deviceTokenPending,
|
|
Token: "",
|
|
Expiry: now().Add(-5 * time.Minute),
|
|
LastRequestTime: time.Time{},
|
|
PollIntervalSeconds: 0,
|
|
},
|
|
testDeviceCode: "",
|
|
expectedServerResponse: errInvalidRequest,
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
testName: "Claim validated token from Device Code",
|
|
testDeviceRequest: baseDeviceRequest,
|
|
testDeviceToken: storage.DeviceToken{
|
|
DeviceCode: "foo",
|
|
Status: deviceTokenComplete,
|
|
Token: "{\"access_token\": \"foobar\"}",
|
|
Expiry: now().Add(5 * time.Minute),
|
|
LastRequestTime: time.Time{},
|
|
PollIntervalSeconds: 0,
|
|
},
|
|
testDeviceCode: "foo",
|
|
expectedServerResponse: "{\"access_token\": \"foobar\"}",
|
|
expectedResponseCode: http.StatusOK,
|
|
},
|
|
}
|
|
for _, tc := range tests {
|
|
t.Run(tc.testName, func(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Setup a dex server.
|
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
|
c.Issuer += "/non-root-path"
|
|
c.Now = now
|
|
})
|
|
defer httpServer.Close()
|
|
|
|
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
|
|
t.Fatalf("Failed to store device token %v", err)
|
|
}
|
|
|
|
if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
|
|
t.Fatalf("Failed to store device token %v", err)
|
|
}
|
|
|
|
u, err := url.Parse(s.issuerURL.String())
|
|
if err != nil {
|
|
t.Fatalf("Could not parse issuer URL %v", err)
|
|
}
|
|
u.Path = path.Join(u.Path, "device/token")
|
|
|
|
data := url.Values{}
|
|
grantType := grantTypeDeviceCode
|
|
if tc.testGrantType != "" {
|
|
grantType = tc.testGrantType
|
|
}
|
|
data.Set("grant_type", grantType)
|
|
data.Set("device_code", tc.testDeviceCode)
|
|
req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
|
|
|
rr := httptest.NewRecorder()
|
|
s.ServeHTTP(rr, req)
|
|
if rr.Code != tc.expectedResponseCode {
|
|
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
|
|
}
|
|
|
|
body, err := io.ReadAll(rr.Body)
|
|
if err != nil {
|
|
t.Errorf("Could read token response %v", err)
|
|
}
|
|
if tc.expectedResponseCode == http.StatusBadRequest || tc.expectedResponseCode == http.StatusUnauthorized {
|
|
expectJSONErrorResponse(tc.testName, body, tc.expectedServerResponse, t)
|
|
} else if string(body) != tc.expectedServerResponse {
|
|
t.Errorf("Unexpected Server Response. Expected %v got %v", tc.expectedServerResponse, string(body))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// expectJSONErrorResponse decodes body as a JSON object and records a test
// error on t unless the object's "error" member equals expectedError.
// testCase is used only to label the mismatch message. If body cannot be
// decoded, a single decode error is recorded and no comparison is made.
func expectJSONErrorResponse(testCase string, body []byte, expectedError string, t *testing.T) {
	// Attribute failures to the calling test's line rather than this helper.
	t.Helper()

	jsonMap := make(map[string]interface{})
	if err := json.Unmarshal(body, &jsonMap); err != nil {
		t.Errorf("Unexpected error unmarshalling response: %v", err)
		// Without a decoded object the comparison below would only add a
		// second, redundant failure.
		return
	}
	if jsonMap["error"] != expectedError {
		t.Errorf("Test Case %s expected error %v, received %v", testCase, expectedError, jsonMap["error"])
	}
}
|
|
|
|
func TestVerifyCodeResponse(t *testing.T) {
|
|
t0 := time.Now()
|
|
|
|
now := func() time.Time { return t0 }
|
|
|
|
tests := []struct {
|
|
testName string
|
|
testDeviceRequest storage.DeviceRequest
|
|
userCode string
|
|
expectedResponseCode int
|
|
expectedRedirectPath string
|
|
}{
|
|
{
|
|
testName: "Unknown user code",
|
|
testDeviceRequest: storage.DeviceRequest{
|
|
UserCode: "ABCD-WXYZ",
|
|
DeviceCode: "f00bar",
|
|
ClientID: "testclient",
|
|
Scopes: []string{"openid", "profile", "offline_access"},
|
|
Expiry: now().Add(5 * time.Minute),
|
|
},
|
|
userCode: "CODE-TEST",
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
expectedRedirectPath: "",
|
|
},
|
|
{
|
|
testName: "Expired user code",
|
|
testDeviceRequest: storage.DeviceRequest{
|
|
UserCode: "ABCD-WXYZ",
|
|
DeviceCode: "f00bar",
|
|
ClientID: "testclient",
|
|
Scopes: []string{"openid", "profile", "offline_access"},
|
|
Expiry: now().Add(-5 * time.Minute),
|
|
},
|
|
userCode: "ABCD-WXYZ",
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
expectedRedirectPath: "",
|
|
},
|
|
{
|
|
testName: "No user code",
|
|
testDeviceRequest: storage.DeviceRequest{
|
|
UserCode: "ABCD-WXYZ",
|
|
DeviceCode: "f00bar",
|
|
ClientID: "testclient",
|
|
Scopes: []string{"openid", "profile", "offline_access"},
|
|
Expiry: now().Add(-5 * time.Minute),
|
|
},
|
|
userCode: "",
|
|
expectedResponseCode: http.StatusBadRequest,
|
|
expectedRedirectPath: "",
|
|
},
|
|
{
|
|
testName: "Valid user code, expect redirect to auth endpoint",
|
|
testDeviceRequest: storage.DeviceRequest{
|
|
UserCode: "ABCD-WXYZ",
|
|
DeviceCode: "f00bar",
|
|
ClientID: "testclient",
|
|
Scopes: []string{"openid", "profile", "offline_access"},
|
|
Expiry: now().Add(5 * time.Minute),
|
|
},
|
|
userCode: "ABCD-WXYZ",
|
|
expectedResponseCode: http.StatusFound,
|
|
expectedRedirectPath: "/auth",
|
|
},
|
|
}
|
|
for _, tc := range tests {
|
|
t.Run(tc.testName, func(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Setup a dex server.
|
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
|
c.Issuer += "/non-root-path"
|
|
c.Now = now
|
|
})
|
|
defer httpServer.Close()
|
|
|
|
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
|
|
t.Fatalf("Failed to store device token %v", err)
|
|
}
|
|
|
|
u, err := url.Parse(s.issuerURL.String())
|
|
if err != nil {
|
|
t.Fatalf("Could not parse issuer URL %v", err)
|
|
}
|
|
|
|
u.Path = path.Join(u.Path, "device/auth/verify_code")
|
|
data := url.Values{}
|
|
data.Set("user_code", tc.userCode)
|
|
req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
|
|
|
rr := httptest.NewRecorder()
|
|
s.ServeHTTP(rr, req)
|
|
if rr.Code != tc.expectedResponseCode {
|
|
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
|
|
}
|
|
|
|
u, err = url.Parse(s.issuerURL.String())
|
|
if err != nil {
|
|
t.Errorf("Could not parse issuer URL %v", err)
|
|
}
|
|
u.Path = path.Join(u.Path, tc.expectedRedirectPath)
|
|
|
|
location := rr.Header().Get("Location")
|
|
if rr.Code == http.StatusFound && !strings.HasPrefix(location, u.Path) {
|
|
t.Errorf("Invalid Redirect. Expected %v got %v", u.Path, location)
|
|
}
|
|
})
|
|
}
|
|
}
|