diff --git a/server/handlers.go b/server/handlers.go index 5756f652..b059b98f 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "net/url" "path" @@ -1438,9 +1437,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { case http.MethodPost: err := r.ParseForm() if err != nil { - message := "Could not parse Device Request body" - s.logger.Errorf("%s : %v", message, err) - respondWithError(w, message, err) + s.logger.Errorf("Could not parse Device Request body: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound) return } @@ -1454,7 +1452,11 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { deviceCode := storage.NewDeviceCode() //make user code - userCode := storage.NewUserCode() + userCode, err := storage.NewUserCode() + if err != nil { + s.logger.Errorf("Error generating user code: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) + } //make a pkce verification code pkceCode := storage.NewID() @@ -1473,24 +1475,21 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { } if err := s.storage.CreateDeviceRequest(deviceReq); err != nil { - message := fmt.Sprintf("Failed to store device request %v", err) - s.logger.Errorf(message) - respondWithError(w, message, err) + s.logger.Errorf("Failed to store device request; %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) return } //Store the device token deviceToken := storage.DeviceToken{ DeviceCode: deviceCode, - Status: "pending", - Token: "", + Status: deviceTokenPending, Expiry: expireTime, } if err := s.storage.CreateDeviceToken(deviceToken); err != nil { - message := fmt.Sprintf("Failed to store device token %v", err) - s.logger.Errorf(message) - respondWithError(w, message, err) + s.logger.Errorf("Failed to store device token %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) return } @@ -1506,21 +1505,54 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { enc.SetIndent("", " ") enc.Encode(code) + default: + s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type") + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) + } +} + +func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.Method { + case http.MethodPost: + err := r.ParseForm() + if err != nil { + message := "Could not parse Device Token Request body" + s.logger.Warnf("%s : %v", message, err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) + return + } + + deviceCode := r.Form.Get("device_code") + if deviceCode == "" { + message := "No device code received" + s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest) + return + } + + grantType := r.PostFormValue("grant_type") + if grantType != grantTypeDeviceCode { + s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) + return + } + + //Grab the device token from the db + deviceToken, err := s.storage.GetDeviceToken(deviceCode) + if err != nil || s.now().After(deviceToken.Expiry) { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get device code: %v", err) + } + s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest) + return + } + + switch deviceToken.Status { + case deviceTokenPending: + s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized) + case deviceTokenComplete: + w.Write([]byte(deviceToken.Token)) + } default: s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") } } - -func respondWithError(w io.Writer, errorMessage string, err error) { - resp := struct { - Error string `json:"error"` - ErrorMessage string `json:"message"` - }{ - Error: err.Error(), - ErrorMessage: errorMessage, - } - - enc := json.NewEncoder(w) - enc.SetIndent("", " ") - enc.Encode(resp) -} diff --git a/server/oauth2.go b/server/oauth2.go index 05dd25d2..ddeffc3f 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -122,6 +122,7 @@ const ( grantTypeAuthorizationCode = "authorization_code" grantTypeRefreshToken = "refresh_token" grantTypePassword = "password" + grantTypeDeviceCode = "device_code" ) const ( @@ -130,6 +131,11 @@ const ( responseTypeIDToken = "id_token" // ID Token in url fragment ) +const ( + deviceTokenPending = "authorization_pending" + deviceTokenComplete = "complete" +) + func parseScopes(scopes []string) connector.Scopes { var s connector.Scopes for _, scope := range scopes { diff --git a/server/server.go b/server/server.go index 95f5359b..b86dac04 100644 --- a/server/server.go +++ b/server/server.go @@ -303,6 +303,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) handleFunc("/device/code", s.handleDeviceCode) + handleFunc("/device/token", s.handleDeviceToken) r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { // Strip the X-Remote-* headers to prevent security issues on // misconfigured authproxy connector setups. diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index c1bd318f..944d8a78 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -837,8 +837,13 @@ func testGC(t *testing.T, s storage.Storage) { t.Errorf("expected storage.ErrNotFound, got %v", err) } + userCode, err := storage.NewUserCode() + if err != nil { + t.Errorf("Unexpected Error: %v", err) + } + d := storage.DeviceRequest{ - UserCode: storage.NewUserCode(), + UserCode: userCode, DeviceCode: storage.NewID(), ClientID: "client1", Scopes: []string{"openid", "email"}, @@ -896,9 +901,9 @@ func testGC(t *testing.T, s storage.Storage) { t.Errorf("expected no device token garbage collection results, got %#v", result) } } - //if _, err := s.GetDeviceRequest(d.UserCode); err != nil { - // t.Errorf("expected to be able to get auth request after GC: %v", err) - //} + if _, err := s.GetDeviceToken(dt.DeviceCode); err != nil { + t.Errorf("expected to be able to get device token after GC: %v", err) + } } if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { t.Errorf("garbage collection failed: %v", err) @@ -906,12 +911,11 @@ func testGC(t *testing.T, s storage.Storage) { t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens) } - //TODO add this code back once Getters are written for device tokens - //if _, err := s.GetDeviceRequest(d.UserCode); err == nil { - // t.Errorf("expected device request to be GC'd") - //} else if err != storage.ErrNotFound { - // t.Errorf("expected storage.ErrNotFound, got %v", err) - //} + if _, err := s.GetDeviceToken(dt.DeviceCode); err == nil { + t.Errorf("expected device token to be GC'd") + } else if err != storage.ErrNotFound { + t.Errorf("expected storage.ErrNotFound, got %v", err) + } } // testTimezones tests that backends either fully support timezones or @@ -961,8 +965,12 @@ func testTimezones(t *testing.T, s storage.Storage) { } func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { + userCode, err := storage.NewUserCode() + if err != nil { + panic(err) + } d1 := storage.DeviceRequest{ - UserCode: storage.NewUserCode(), + UserCode: userCode, DeviceCode: storage.NewID(), ClientID: "client1", Scopes: []string{"openid", "email"}, @@ -975,7 +983,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { } // Attempt to create same DeviceRequest twice. - err := s.CreateDeviceRequest(d1) + err = s.CreateDeviceRequest(d1) mustBeErrAlreadyExists(t, "device request", err) //No manual deletes for device requests, will be handled by garbage collection routines diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index 27e337a4..bbb86651 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -591,6 +591,13 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t)) } +func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t) + return t, err +} + func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) { res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix()) if err != nil { diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index e87b9c01..20f9daac 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -641,3 +641,11 @@ func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error { func (cli *client) CreateDeviceToken(t storage.DeviceToken) error { return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t)) } + +func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { + var token DeviceToken + if err := cli.get(resourceDeviceToken, deviceCode, &token); err != nil { + return storage.DeviceToken{}, err + } + return toStorageDeviceToken(token), nil +} diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 5a61b92e..66fe5780 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -739,3 +739,12 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { } return req } + +func toStorageDeviceToken(t DeviceToken) storage.DeviceToken { + return storage.DeviceToken{ + DeviceCode: t.ObjectMeta.Name, + Status: t.Status, + Token: t.Token, + Expiry: t.Expiry, + } +} diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 29d4af27..32cfd415 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -503,3 +503,14 @@ func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) { }) return } + +func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) { + s.tx(func() { + var ok bool + if t, ok = s.deviceTokens[deviceCode]; !ok { + err = storage.ErrNotFound + return + } + }) + return +} diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 989d2db0..c52e67cf 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -922,3 +922,25 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { } return nil } + +func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { + return getDeviceToken(c, deviceCode) +} + +func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) { + err = q.QueryRow(` + select + status, token, expiry + from device_token where device_code = $1; + `, deviceCode).Scan( + &a.Status, &a.Token, &a.Expiry, + ) + if err != nil { + if err == sql.ErrNoRows { + return a, storage.ErrNotFound + } + return a, fmt.Errorf("select device token: %v", err) + } + a.DeviceCode = deviceCode + return a, nil +} diff --git a/storage/storage.go b/storage/storage.go index 7078ccf5..88ab71cd 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -5,7 +5,7 @@ import ( "encoding/base32" "errors" "io" - mrand "math/rand" + "math/big" "strings" "time" @@ -25,6 +25,9 @@ var ( // TODO(ericchiang): refactor ID creation onto the storage. var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567") +//Valid characters for user codes +const validUserCharacters = "BCDFGHJKLMNPQRSTVWXZ" + // NewDeviceCode returns a 32 char alphanumeric cryptographically secure string func NewDeviceCode() string { return newSecureID(32) @@ -79,6 +82,7 @@ type Storage interface { GetPassword(email string) (Password, error) GetOfflineSessions(userID string, connID string) (OfflineSessions, error) GetConnector(id string) (Connector, error) + GetDeviceToken(deviceCode string) (DeviceToken, error) ListClients() ([]Client, error) ListRefreshTokens() ([]RefreshToken, error) @@ -357,18 +361,24 @@ type Keys struct { NextRotation time.Time } -func NewUserCode() string { - mrand.Seed(time.Now().UnixNano()) - return randomString(4) + "-" + randomString(4) +// NewUserCode returns a randomized 8 character user code for the device flow. +// No vowels are included to prevent accidental generation of words +func NewUserCode() (string, error) { + code, err := randomString(8) + if err != nil { + return "", err + } + return code[:4] + "-" + code[4:], nil } -func randomString(n int) string { - var letter = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ") - b := make([]rune, n) - for i := range b { - b[i] = letter[mrand.Intn(len(letter))] +func randomString(n int) (string, error) { + v := big.NewInt(int64(len(validUserCharacters))) + bytes := make([]byte, n) + for i := 0; i < n; i++ { + c, _ := rand.Int(rand.Reader, v) + bytes[i] = validUserCharacters[c.Int64()] } - return string(b) + return string(bytes), nil } //DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user