add PKCE support to device code flow (#2575)
Signed-off-by: Bob Callaway <bobcallaway@users.noreply.github.com>
This commit is contained in:
		| @@ -73,6 +73,17 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { | ||||
| 		clientID := r.Form.Get("client_id") | ||||
| 		clientSecret := r.Form.Get("client_secret") | ||||
| 		scopes := strings.Fields(r.Form.Get("scope")) | ||||
| 		codeChallenge := r.Form.Get("code_challenge") | ||||
| 		codeChallengeMethod := r.Form.Get("code_challenge_method") | ||||
|  | ||||
| 		if codeChallengeMethod == "" { | ||||
| 			codeChallengeMethod = codeChallengeMethodPlain | ||||
| 		} | ||||
| 		if codeChallengeMethod != codeChallengeMethodS256 && codeChallengeMethod != codeChallengeMethodPlain { | ||||
| 			description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod) | ||||
| 			s.tokenErrHelper(w, errInvalidRequest, description, http.StatusBadRequest) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes) | ||||
|  | ||||
| @@ -108,6 +119,10 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { | ||||
| 			Expiry:              expireTime, | ||||
| 			LastRequestTime:     s.now(), | ||||
| 			PollIntervalSeconds: 0, | ||||
| 			PKCE: storage.PKCE{ | ||||
| 				CodeChallenge:       codeChallenge, | ||||
| 				CodeChallengeMethod: codeChallengeMethod, | ||||
| 			}, | ||||
| 		} | ||||
|  | ||||
| 		if err := s.storage.CreateDeviceToken(deviceToken); err != nil { | ||||
| @@ -236,6 +251,30 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { | ||||
| 			s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized) | ||||
| 		} | ||||
| 	case deviceTokenComplete: | ||||
| 		codeChallengeFromStorage := deviceToken.PKCE.CodeChallenge | ||||
| 		providedCodeVerifier := r.Form.Get("code_verifier") | ||||
|  | ||||
| 		switch { | ||||
| 		case providedCodeVerifier != "" && codeChallengeFromStorage != "": | ||||
| 			calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, deviceToken.PKCE.CodeChallengeMethod) | ||||
| 			if err != nil { | ||||
| 				s.logger.Error(err) | ||||
| 				s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) | ||||
| 				return | ||||
| 			} | ||||
| 			if codeChallengeFromStorage != calculatedCodeChallenge { | ||||
| 				s.tokenErrHelper(w, errInvalidGrant, "Invalid code_verifier.", http.StatusBadRequest) | ||||
| 				return | ||||
| 			} | ||||
| 		case providedCodeVerifier != "": | ||||
| 			// Received no code_challenge on /auth, but a code_verifier on /token | ||||
| 			s.tokenErrHelper(w, errInvalidRequest, "No PKCE flow started. Cannot check code_verifier.", http.StatusBadRequest) | ||||
| 			return | ||||
| 		case codeChallengeFromStorage != "": | ||||
| 			// Received PKCE request on /auth, but no code_verifier on /token | ||||
| 			s.tokenErrHelper(w, errInvalidGrant, "Expecting parameter code_verifier in PKCE flow.", http.StatusBadRequest) | ||||
| 			return | ||||
| 		} | ||||
| 		w.Write([]byte(deviceToken.Token)) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -49,6 +49,7 @@ func TestHandleDeviceCode(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		testName               string | ||||
| 		clientID               string | ||||
| 		codeChallengeMethod    string | ||||
| 		requestType            string | ||||
| 		scopes                 []string | ||||
| 		expectedResponseCode   int | ||||
| @@ -71,6 +72,24 @@ func TestHandleDeviceCode(t *testing.T) { | ||||
| 			expectedResponseCode: http.StatusBadRequest, | ||||
| 			expectedContentType:  "application/json", | ||||
| 		}, | ||||
| 		{ | ||||
| 			testName:             "New Code with valid PKCE", | ||||
| 			clientID:             "test", | ||||
| 			requestType:          "POST", | ||||
| 			scopes:               []string{"openid", "profile", "email"}, | ||||
| 			codeChallengeMethod:  "S256", | ||||
| 			expectedResponseCode: http.StatusOK, | ||||
| 			expectedContentType:  "application/json", | ||||
| 		}, | ||||
| 		{ | ||||
| 			testName:             "Invalid code challenge method", | ||||
| 			clientID:             "test", | ||||
| 			requestType:          "POST", | ||||
| 			codeChallengeMethod:  "invalid", | ||||
| 			scopes:               []string{"openid", "profile", "email"}, | ||||
| 			expectedResponseCode: http.StatusBadRequest, | ||||
| 			expectedContentType:  "application/json", | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tc := range tests { | ||||
| 		t.Run(tc.testName, func(t *testing.T) { | ||||
| @@ -92,6 +111,7 @@ func TestHandleDeviceCode(t *testing.T) { | ||||
|  | ||||
| 			data := url.Values{} | ||||
| 			data.Set("client_id", tc.clientID) | ||||
| 			data.Set("code_challenge_method", tc.codeChallengeMethod) | ||||
| 			for _, scope := range tc.scopes { | ||||
| 				data.Add("scope", scope) | ||||
| 			} | ||||
| @@ -401,6 +421,13 @@ func TestDeviceTokenResponse(t *testing.T) { | ||||
|  | ||||
| 	now := func() time.Time { return t0 } | ||||
|  | ||||
| 	// Base PKCE values | ||||
| 	// base64-urlencoded, sha256 digest of code_verifier | ||||
| 	codeChallenge := "L7ZqsT_zNwvrH8E7J0CqPHx1wgBaFiaE-fAZcKUUAbc" | ||||
| 	codeChallengeMethod := "S256" | ||||
| 	// "random" string between 43 & 128 ASCII characters | ||||
| 	codeVerifier := "66114650f56cc45dee7ee03c49f048ddf9aa53cbf5b09985832fa4f790ff2604" | ||||
|  | ||||
| 	baseDeviceRequest := storage.DeviceRequest{ | ||||
| 		UserCode:   "ABCD-WXYZ", | ||||
| 		DeviceCode: "foo", | ||||
| @@ -415,6 +442,7 @@ func TestDeviceTokenResponse(t *testing.T) { | ||||
| 		testDeviceToken        storage.DeviceToken | ||||
| 		testGrantType          string | ||||
| 		testDeviceCode         string | ||||
| 		testCodeVerifier       string | ||||
| 		expectedServerResponse string | ||||
| 		expectedResponseCode   int | ||||
| 	}{ | ||||
| @@ -524,6 +552,101 @@ func TestDeviceTokenResponse(t *testing.T) { | ||||
| 			expectedServerResponse: "{\"access_token\": \"foobar\"}", | ||||
| 			expectedResponseCode:   http.StatusOK, | ||||
| 		}, | ||||
| 		{ | ||||
| 			testName: "Successful Exchange with PKCE", | ||||
| 			testDeviceToken: storage.DeviceToken{ | ||||
| 				DeviceCode:          "foo", | ||||
| 				Status:              deviceTokenComplete, | ||||
| 				Token:               "{\"access_token\": \"foobar\"}", | ||||
| 				Expiry:              now().Add(5 * time.Minute), | ||||
| 				LastRequestTime:     time.Time{}, | ||||
| 				PollIntervalSeconds: 0, | ||||
| 				PKCE: storage.PKCE{ | ||||
| 					CodeChallenge:       codeChallenge, | ||||
| 					CodeChallengeMethod: codeChallengeMethod, | ||||
| 				}, | ||||
| 			}, | ||||
| 			testDeviceCode:         "foo", | ||||
| 			testCodeVerifier:       codeVerifier, | ||||
| 			testDeviceRequest:      baseDeviceRequest, | ||||
| 			expectedServerResponse: "{\"access_token\": \"foobar\"}", | ||||
| 			expectedResponseCode:   http.StatusOK, | ||||
| 		}, | ||||
| 		{ | ||||
| 			testName: "Test Exchange started with PKCE but without verifier provided", | ||||
| 			testDeviceToken: storage.DeviceToken{ | ||||
| 				DeviceCode:          "foo", | ||||
| 				Status:              deviceTokenComplete, | ||||
| 				Token:               "{\"access_token\": \"foobar\"}", | ||||
| 				Expiry:              now().Add(5 * time.Minute), | ||||
| 				LastRequestTime:     time.Time{}, | ||||
| 				PollIntervalSeconds: 0, | ||||
| 				PKCE: storage.PKCE{ | ||||
| 					CodeChallenge:       codeChallenge, | ||||
| 					CodeChallengeMethod: codeChallengeMethod, | ||||
| 				}, | ||||
| 			}, | ||||
| 			testDeviceCode:         "foo", | ||||
| 			testDeviceRequest:      baseDeviceRequest, | ||||
| 			expectedServerResponse: errInvalidGrant, | ||||
| 			expectedResponseCode:   http.StatusBadRequest, | ||||
| 		}, | ||||
| 		{ | ||||
| 			testName: "Test Exchange not started with PKCE but verifier provided", | ||||
| 			testDeviceToken: storage.DeviceToken{ | ||||
| 				DeviceCode:          "foo", | ||||
| 				Status:              deviceTokenComplete, | ||||
| 				Token:               "{\"access_token\": \"foobar\"}", | ||||
| 				Expiry:              now().Add(5 * time.Minute), | ||||
| 				LastRequestTime:     time.Time{}, | ||||
| 				PollIntervalSeconds: 0, | ||||
| 			}, | ||||
| 			testDeviceCode:         "foo", | ||||
| 			testCodeVerifier:       codeVerifier, | ||||
| 			testDeviceRequest:      baseDeviceRequest, | ||||
| 			expectedServerResponse: errInvalidRequest, | ||||
| 			expectedResponseCode:   http.StatusBadRequest, | ||||
| 		}, | ||||
| 		{ | ||||
| 			testName: "Test with PKCE but incorrect verifier provided", | ||||
| 			testDeviceToken: storage.DeviceToken{ | ||||
| 				DeviceCode:          "foo", | ||||
| 				Status:              deviceTokenComplete, | ||||
| 				Token:               "{\"access_token\": \"foobar\"}", | ||||
| 				Expiry:              now().Add(5 * time.Minute), | ||||
| 				LastRequestTime:     time.Time{}, | ||||
| 				PollIntervalSeconds: 0, | ||||
| 				PKCE: storage.PKCE{ | ||||
| 					CodeChallenge:       codeChallenge, | ||||
| 					CodeChallengeMethod: codeChallengeMethod, | ||||
| 				}, | ||||
| 			}, | ||||
| 			testDeviceCode:         "foo", | ||||
| 			testCodeVerifier:       "invalid", | ||||
| 			testDeviceRequest:      baseDeviceRequest, | ||||
| 			expectedServerResponse: errInvalidGrant, | ||||
| 			expectedResponseCode:   http.StatusBadRequest, | ||||
| 		}, | ||||
| 		{ | ||||
| 			testName: "Test with PKCE but incorrect challenge provided", | ||||
| 			testDeviceToken: storage.DeviceToken{ | ||||
| 				DeviceCode:          "foo", | ||||
| 				Status:              deviceTokenComplete, | ||||
| 				Token:               "{\"access_token\": \"foobar\"}", | ||||
| 				Expiry:              now().Add(5 * time.Minute), | ||||
| 				LastRequestTime:     time.Time{}, | ||||
| 				PollIntervalSeconds: 0, | ||||
| 				PKCE: storage.PKCE{ | ||||
| 					CodeChallenge:       "invalid", | ||||
| 					CodeChallengeMethod: codeChallengeMethod, | ||||
| 				}, | ||||
| 			}, | ||||
| 			testDeviceCode:         "foo", | ||||
| 			testCodeVerifier:       codeVerifier, | ||||
| 			testDeviceRequest:      baseDeviceRequest, | ||||
| 			expectedServerResponse: errInvalidGrant, | ||||
| 			expectedResponseCode:   http.StatusBadRequest, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tc := range tests { | ||||
| 		t.Run(tc.testName, func(t *testing.T) { | ||||
| @@ -558,6 +681,9 @@ func TestDeviceTokenResponse(t *testing.T) { | ||||
| 			} | ||||
| 			data.Set("grant_type", grantType) | ||||
| 			data.Set("device_code", tc.testDeviceCode) | ||||
| 			if tc.testCodeVerifier != "" { | ||||
| 				data.Set("code_verifier", tc.testCodeVerifier) | ||||
| 			} | ||||
| 			req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode())) | ||||
| 			req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") | ||||
|  | ||||
|   | ||||
| @@ -890,6 +890,10 @@ func testGC(t *testing.T, s storage.Storage) { | ||||
| 		Expiry:              expiry, | ||||
| 		LastRequestTime:     time.Now(), | ||||
| 		PollIntervalSeconds: 0, | ||||
| 		PKCE: storage.PKCE{ | ||||
| 			CodeChallenge:       "challenge", | ||||
| 			CodeChallengeMethod: "S256", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	if err := s.CreateDeviceToken(dt); err != nil { | ||||
| @@ -989,6 +993,11 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { | ||||
| } | ||||
|  | ||||
| func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { | ||||
| 	codeChallenge := storage.PKCE{ | ||||
| 		CodeChallenge:       "code_challenge_test", | ||||
| 		CodeChallengeMethod: "plain", | ||||
| 	} | ||||
|  | ||||
| 	// Create a Token | ||||
| 	d1 := storage.DeviceToken{ | ||||
| 		DeviceCode:          storage.NewID(), | ||||
| @@ -997,6 +1006,7 @@ func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { | ||||
| 		Expiry:              neverExpire, | ||||
| 		LastRequestTime:     time.Now(), | ||||
| 		PollIntervalSeconds: 0, | ||||
| 		PKCE:                codeChallenge, | ||||
| 	} | ||||
|  | ||||
| 	if err := s.CreateDeviceToken(d1); err != nil { | ||||
| @@ -1029,4 +1039,7 @@ func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { | ||||
| 	if got.Token != "token data" { | ||||
| 		t.Fatalf("update failed, wanted token %v got %v", "token data", got.Token) | ||||
| 	} | ||||
| 	if !reflect.DeepEqual(got.PKCE, codeChallenge) { | ||||
| 		t.Fatalf("storage does not support PKCE, wanted challenge=%#v got %#v", codeChallenge, got.PKCE) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -17,6 +17,8 @@ func (d *Database) CreateDeviceToken(token storage.DeviceToken) error { | ||||
| 		SetExpiry(token.Expiry.UTC()). | ||||
| 		SetLastRequest(token.LastRequestTime.UTC()). | ||||
| 		SetStatus(token.Status). | ||||
| 		SetCodeChallenge(token.PKCE.CodeChallenge). | ||||
| 		SetCodeChallengeMethod(token.PKCE.CodeChallengeMethod). | ||||
| 		Save(context.TODO()) | ||||
| 	if err != nil { | ||||
| 		return convertDBError("create device token: %w", err) | ||||
| @@ -63,6 +65,8 @@ func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage | ||||
| 		SetExpiry(newToken.Expiry.UTC()). | ||||
| 		SetLastRequest(newToken.LastRequestTime.UTC()). | ||||
| 		SetStatus(newToken.Status). | ||||
| 		SetCodeChallenge(newToken.PKCE.CodeChallenge). | ||||
| 		SetCodeChallengeMethod(newToken.PKCE.CodeChallengeMethod). | ||||
| 		Save(context.TODO()) | ||||
| 	if err != nil { | ||||
| 		return rollback(tx, "update device token uploading: %w", err) | ||||
|   | ||||
| @@ -164,5 +164,9 @@ func toStorageDeviceToken(t *db.DeviceToken) storage.DeviceToken { | ||||
| 		Expiry:              t.Expiry, | ||||
| 		LastRequestTime:     t.LastRequest, | ||||
| 		PollIntervalSeconds: t.PollInterval, | ||||
| 		PKCE: storage.PKCE{ | ||||
| 			CodeChallenge:       t.CodeChallenge, | ||||
| 			CodeChallengeMethod: t.CodeChallengeMethod, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -28,6 +28,10 @@ type DeviceToken struct { | ||||
| 	LastRequest time.Time `json:"last_request,omitempty"` | ||||
| 	// PollInterval holds the value of the "poll_interval" field. | ||||
| 	PollInterval int `json:"poll_interval,omitempty"` | ||||
| 	// CodeChallenge holds the value of the "code_challenge" field. | ||||
| 	CodeChallenge string `json:"code_challenge,omitempty"` | ||||
| 	// CodeChallengeMethod holds the value of the "code_challenge_method" field. | ||||
| 	CodeChallengeMethod string `json:"code_challenge_method,omitempty"` | ||||
| } | ||||
|  | ||||
| // scanValues returns the types for scanning values from sql.Rows. | ||||
| @@ -39,7 +43,7 @@ func (*DeviceToken) scanValues(columns []string) ([]interface{}, error) { | ||||
| 			values[i] = new([]byte) | ||||
| 		case devicetoken.FieldID, devicetoken.FieldPollInterval: | ||||
| 			values[i] = new(sql.NullInt64) | ||||
| 		case devicetoken.FieldDeviceCode, devicetoken.FieldStatus: | ||||
| 		case devicetoken.FieldDeviceCode, devicetoken.FieldStatus, devicetoken.FieldCodeChallenge, devicetoken.FieldCodeChallengeMethod: | ||||
| 			values[i] = new(sql.NullString) | ||||
| 		case devicetoken.FieldExpiry, devicetoken.FieldLastRequest: | ||||
| 			values[i] = new(sql.NullTime) | ||||
| @@ -100,6 +104,18 @@ func (dt *DeviceToken) assignValues(columns []string, values []interface{}) erro | ||||
| 			} else if value.Valid { | ||||
| 				dt.PollInterval = int(value.Int64) | ||||
| 			} | ||||
| 		case devicetoken.FieldCodeChallenge: | ||||
| 			if value, ok := values[i].(*sql.NullString); !ok { | ||||
| 				return fmt.Errorf("unexpected type %T for field code_challenge", values[i]) | ||||
| 			} else if value.Valid { | ||||
| 				dt.CodeChallenge = value.String | ||||
| 			} | ||||
| 		case devicetoken.FieldCodeChallengeMethod: | ||||
| 			if value, ok := values[i].(*sql.NullString); !ok { | ||||
| 				return fmt.Errorf("unexpected type %T for field code_challenge_method", values[i]) | ||||
| 			} else if value.Valid { | ||||
| 				dt.CodeChallengeMethod = value.String | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| @@ -142,6 +158,10 @@ func (dt *DeviceToken) String() string { | ||||
| 	builder.WriteString(dt.LastRequest.Format(time.ANSIC)) | ||||
| 	builder.WriteString(", poll_interval=") | ||||
| 	builder.WriteString(fmt.Sprintf("%v", dt.PollInterval)) | ||||
| 	builder.WriteString(", code_challenge=") | ||||
| 	builder.WriteString(dt.CodeChallenge) | ||||
| 	builder.WriteString(", code_challenge_method=") | ||||
| 	builder.WriteString(dt.CodeChallengeMethod) | ||||
| 	builder.WriteByte(')') | ||||
| 	return builder.String() | ||||
| } | ||||
|   | ||||
| @@ -19,6 +19,10 @@ const ( | ||||
| 	FieldLastRequest = "last_request" | ||||
| 	// FieldPollInterval holds the string denoting the poll_interval field in the database. | ||||
| 	FieldPollInterval = "poll_interval" | ||||
| 	// FieldCodeChallenge holds the string denoting the code_challenge field in the database. | ||||
| 	FieldCodeChallenge = "code_challenge" | ||||
| 	// FieldCodeChallengeMethod holds the string denoting the code_challenge_method field in the database. | ||||
| 	FieldCodeChallengeMethod = "code_challenge_method" | ||||
| 	// Table holds the table name of the devicetoken in the database. | ||||
| 	Table = "device_tokens" | ||||
| ) | ||||
| @@ -32,6 +36,8 @@ var Columns = []string{ | ||||
| 	FieldExpiry, | ||||
| 	FieldLastRequest, | ||||
| 	FieldPollInterval, | ||||
| 	FieldCodeChallenge, | ||||
| 	FieldCodeChallengeMethod, | ||||
| } | ||||
|  | ||||
| // ValidColumn reports if the column name is valid (part of the table columns). | ||||
| @@ -49,4 +55,8 @@ var ( | ||||
| 	DeviceCodeValidator func(string) error | ||||
| 	// StatusValidator is a validator for the "status" field. It is called by the builders before save. | ||||
| 	StatusValidator func(string) error | ||||
| 	// DefaultCodeChallenge holds the default value on creation for the "code_challenge" field. | ||||
| 	DefaultCodeChallenge string | ||||
| 	// DefaultCodeChallengeMethod holds the default value on creation for the "code_challenge_method" field. | ||||
| 	DefaultCodeChallengeMethod string | ||||
| ) | ||||
|   | ||||
| @@ -134,6 +134,20 @@ func PollInterval(v int) predicate.DeviceToken { | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallenge applies equality check predicate on the "code_challenge" field. It's identical to CodeChallengeEQ. | ||||
| func CodeChallenge(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.EQ(s.C(FieldCodeChallenge), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethod applies equality check predicate on the "code_challenge_method" field. It's identical to CodeChallengeMethodEQ. | ||||
| func CodeChallengeMethod(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.EQ(s.C(FieldCodeChallengeMethod), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // DeviceCodeEQ applies the EQ predicate on the "device_code" field. | ||||
| func DeviceCodeEQ(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| @@ -674,6 +688,228 @@ func PollIntervalLTE(v int) predicate.DeviceToken { | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeEQ applies the EQ predicate on the "code_challenge" field. | ||||
| func CodeChallengeEQ(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.EQ(s.C(FieldCodeChallenge), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeNEQ applies the NEQ predicate on the "code_challenge" field. | ||||
| func CodeChallengeNEQ(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.NEQ(s.C(FieldCodeChallenge), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeIn applies the In predicate on the "code_challenge" field. | ||||
| func CodeChallengeIn(vs ...string) predicate.DeviceToken { | ||||
| 	v := make([]interface{}, len(vs)) | ||||
| 	for i := range v { | ||||
| 		v[i] = vs[i] | ||||
| 	} | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		// if not arguments were provided, append the FALSE constants, | ||||
| 		// since we can't apply "IN ()". This will make this predicate falsy. | ||||
| 		if len(v) == 0 { | ||||
| 			s.Where(sql.False()) | ||||
| 			return | ||||
| 		} | ||||
| 		s.Where(sql.In(s.C(FieldCodeChallenge), v...)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeNotIn applies the NotIn predicate on the "code_challenge" field. | ||||
| func CodeChallengeNotIn(vs ...string) predicate.DeviceToken { | ||||
| 	v := make([]interface{}, len(vs)) | ||||
| 	for i := range v { | ||||
| 		v[i] = vs[i] | ||||
| 	} | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		// if not arguments were provided, append the FALSE constants, | ||||
| 		// since we can't apply "IN ()". This will make this predicate falsy. | ||||
| 		if len(v) == 0 { | ||||
| 			s.Where(sql.False()) | ||||
| 			return | ||||
| 		} | ||||
| 		s.Where(sql.NotIn(s.C(FieldCodeChallenge), v...)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeGT applies the GT predicate on the "code_challenge" field. | ||||
| func CodeChallengeGT(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.GT(s.C(FieldCodeChallenge), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeGTE applies the GTE predicate on the "code_challenge" field. | ||||
| func CodeChallengeGTE(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.GTE(s.C(FieldCodeChallenge), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeLT applies the LT predicate on the "code_challenge" field. | ||||
| func CodeChallengeLT(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.LT(s.C(FieldCodeChallenge), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeLTE applies the LTE predicate on the "code_challenge" field. | ||||
| func CodeChallengeLTE(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.LTE(s.C(FieldCodeChallenge), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeContains applies the Contains predicate on the "code_challenge" field. | ||||
| func CodeChallengeContains(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.Contains(s.C(FieldCodeChallenge), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeHasPrefix applies the HasPrefix predicate on the "code_challenge" field. | ||||
| func CodeChallengeHasPrefix(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.HasPrefix(s.C(FieldCodeChallenge), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeHasSuffix applies the HasSuffix predicate on the "code_challenge" field. | ||||
| func CodeChallengeHasSuffix(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.HasSuffix(s.C(FieldCodeChallenge), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeEqualFold applies the EqualFold predicate on the "code_challenge" field. | ||||
| func CodeChallengeEqualFold(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.EqualFold(s.C(FieldCodeChallenge), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeContainsFold applies the ContainsFold predicate on the "code_challenge" field. | ||||
| func CodeChallengeContainsFold(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.ContainsFold(s.C(FieldCodeChallenge), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodEQ applies the EQ predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodEQ(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.EQ(s.C(FieldCodeChallengeMethod), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodNEQ applies the NEQ predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodNEQ(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.NEQ(s.C(FieldCodeChallengeMethod), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodIn applies the In predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodIn(vs ...string) predicate.DeviceToken { | ||||
| 	v := make([]interface{}, len(vs)) | ||||
| 	for i := range v { | ||||
| 		v[i] = vs[i] | ||||
| 	} | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		// if not arguments were provided, append the FALSE constants, | ||||
| 		// since we can't apply "IN ()". This will make this predicate falsy. | ||||
| 		if len(v) == 0 { | ||||
| 			s.Where(sql.False()) | ||||
| 			return | ||||
| 		} | ||||
| 		s.Where(sql.In(s.C(FieldCodeChallengeMethod), v...)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodNotIn applies the NotIn predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodNotIn(vs ...string) predicate.DeviceToken { | ||||
| 	v := make([]interface{}, len(vs)) | ||||
| 	for i := range v { | ||||
| 		v[i] = vs[i] | ||||
| 	} | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		// if not arguments were provided, append the FALSE constants, | ||||
| 		// since we can't apply "IN ()". This will make this predicate falsy. | ||||
| 		if len(v) == 0 { | ||||
| 			s.Where(sql.False()) | ||||
| 			return | ||||
| 		} | ||||
| 		s.Where(sql.NotIn(s.C(FieldCodeChallengeMethod), v...)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodGT applies the GT predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodGT(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.GT(s.C(FieldCodeChallengeMethod), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodGTE applies the GTE predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodGTE(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.GTE(s.C(FieldCodeChallengeMethod), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodLT applies the LT predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodLT(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.LT(s.C(FieldCodeChallengeMethod), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodLTE applies the LTE predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodLTE(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.LTE(s.C(FieldCodeChallengeMethod), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodContains applies the Contains predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodContains(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.Contains(s.C(FieldCodeChallengeMethod), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodHasPrefix applies the HasPrefix predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodHasPrefix(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.HasPrefix(s.C(FieldCodeChallengeMethod), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodHasSuffix applies the HasSuffix predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodHasSuffix(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.HasSuffix(s.C(FieldCodeChallengeMethod), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodEqualFold applies the EqualFold predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodEqualFold(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.EqualFold(s.C(FieldCodeChallengeMethod), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethodContainsFold applies the ContainsFold predicate on the "code_challenge_method" field. | ||||
| func CodeChallengeMethodContainsFold(v string) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
| 		s.Where(sql.ContainsFold(s.C(FieldCodeChallengeMethod), v)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // And groups predicates with the AND operator between them. | ||||
| func And(predicates ...predicate.DeviceToken) predicate.DeviceToken { | ||||
| 	return predicate.DeviceToken(func(s *sql.Selector) { | ||||
|   | ||||
| @@ -56,6 +56,34 @@ func (dtc *DeviceTokenCreate) SetPollInterval(i int) *DeviceTokenCreate { | ||||
| 	return dtc | ||||
| } | ||||
|  | ||||
| // SetCodeChallenge sets the "code_challenge" field. | ||||
| func (dtc *DeviceTokenCreate) SetCodeChallenge(s string) *DeviceTokenCreate { | ||||
| 	dtc.mutation.SetCodeChallenge(s) | ||||
| 	return dtc | ||||
| } | ||||
|  | ||||
| // SetNillableCodeChallenge sets the "code_challenge" field if the given value is not nil. | ||||
| func (dtc *DeviceTokenCreate) SetNillableCodeChallenge(s *string) *DeviceTokenCreate { | ||||
| 	if s != nil { | ||||
| 		dtc.SetCodeChallenge(*s) | ||||
| 	} | ||||
| 	return dtc | ||||
| } | ||||
|  | ||||
| // SetCodeChallengeMethod sets the "code_challenge_method" field. | ||||
| func (dtc *DeviceTokenCreate) SetCodeChallengeMethod(s string) *DeviceTokenCreate { | ||||
| 	dtc.mutation.SetCodeChallengeMethod(s) | ||||
| 	return dtc | ||||
| } | ||||
|  | ||||
| // SetNillableCodeChallengeMethod sets the "code_challenge_method" field if the given value is not nil. | ||||
| func (dtc *DeviceTokenCreate) SetNillableCodeChallengeMethod(s *string) *DeviceTokenCreate { | ||||
| 	if s != nil { | ||||
| 		dtc.SetCodeChallengeMethod(*s) | ||||
| 	} | ||||
| 	return dtc | ||||
| } | ||||
|  | ||||
| // Mutation returns the DeviceTokenMutation object of the builder. | ||||
| func (dtc *DeviceTokenCreate) Mutation() *DeviceTokenMutation { | ||||
| 	return dtc.mutation | ||||
| @@ -67,6 +95,7 @@ func (dtc *DeviceTokenCreate) Save(ctx context.Context) (*DeviceToken, error) { | ||||
| 		err  error | ||||
| 		node *DeviceToken | ||||
| 	) | ||||
| 	dtc.defaults() | ||||
| 	if len(dtc.hooks) == 0 { | ||||
| 		if err = dtc.check(); err != nil { | ||||
| 			return nil, err | ||||
| @@ -124,6 +153,18 @@ func (dtc *DeviceTokenCreate) ExecX(ctx context.Context) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // defaults sets the default values of the builder before save. | ||||
| func (dtc *DeviceTokenCreate) defaults() { | ||||
| 	if _, ok := dtc.mutation.CodeChallenge(); !ok { | ||||
| 		v := devicetoken.DefaultCodeChallenge | ||||
| 		dtc.mutation.SetCodeChallenge(v) | ||||
| 	} | ||||
| 	if _, ok := dtc.mutation.CodeChallengeMethod(); !ok { | ||||
| 		v := devicetoken.DefaultCodeChallengeMethod | ||||
| 		dtc.mutation.SetCodeChallengeMethod(v) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // check runs all checks and user-defined validators on the builder. | ||||
| func (dtc *DeviceTokenCreate) check() error { | ||||
| 	if _, ok := dtc.mutation.DeviceCode(); !ok { | ||||
| @@ -151,6 +192,12 @@ func (dtc *DeviceTokenCreate) check() error { | ||||
| 	if _, ok := dtc.mutation.PollInterval(); !ok { | ||||
| 		return &ValidationError{Name: "poll_interval", err: errors.New(`db: missing required field "DeviceToken.poll_interval"`)} | ||||
| 	} | ||||
| 	if _, ok := dtc.mutation.CodeChallenge(); !ok { | ||||
| 		return &ValidationError{Name: "code_challenge", err: errors.New(`db: missing required field "DeviceToken.code_challenge"`)} | ||||
| 	} | ||||
| 	if _, ok := dtc.mutation.CodeChallengeMethod(); !ok { | ||||
| 		return &ValidationError{Name: "code_challenge_method", err: errors.New(`db: missing required field "DeviceToken.code_challenge_method"`)} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -226,6 +273,22 @@ func (dtc *DeviceTokenCreate) createSpec() (*DeviceToken, *sqlgraph.CreateSpec) | ||||
| 		}) | ||||
| 		_node.PollInterval = value | ||||
| 	} | ||||
| 	if value, ok := dtc.mutation.CodeChallenge(); ok { | ||||
| 		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ | ||||
| 			Type:   field.TypeString, | ||||
| 			Value:  value, | ||||
| 			Column: devicetoken.FieldCodeChallenge, | ||||
| 		}) | ||||
| 		_node.CodeChallenge = value | ||||
| 	} | ||||
| 	if value, ok := dtc.mutation.CodeChallengeMethod(); ok { | ||||
| 		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ | ||||
| 			Type:   field.TypeString, | ||||
| 			Value:  value, | ||||
| 			Column: devicetoken.FieldCodeChallengeMethod, | ||||
| 		}) | ||||
| 		_node.CodeChallengeMethod = value | ||||
| 	} | ||||
| 	return _node, _spec | ||||
| } | ||||
|  | ||||
| @@ -243,6 +306,7 @@ func (dtcb *DeviceTokenCreateBulk) Save(ctx context.Context) ([]*DeviceToken, er | ||||
| 	for i := range dtcb.builders { | ||||
| 		func(i int, root context.Context) { | ||||
| 			builder := dtcb.builders[i] | ||||
| 			builder.defaults() | ||||
| 			var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { | ||||
| 				mutation, ok := m.(*DeviceTokenMutation) | ||||
| 				if !ok { | ||||
|   | ||||
| @@ -77,6 +77,34 @@ func (dtu *DeviceTokenUpdate) AddPollInterval(i int) *DeviceTokenUpdate { | ||||
| 	return dtu | ||||
| } | ||||
|  | ||||
| // SetCodeChallenge sets the "code_challenge" field. | ||||
| func (dtu *DeviceTokenUpdate) SetCodeChallenge(s string) *DeviceTokenUpdate { | ||||
| 	dtu.mutation.SetCodeChallenge(s) | ||||
| 	return dtu | ||||
| } | ||||
|  | ||||
| // SetNillableCodeChallenge sets the "code_challenge" field if the given value is not nil. | ||||
| func (dtu *DeviceTokenUpdate) SetNillableCodeChallenge(s *string) *DeviceTokenUpdate { | ||||
| 	if s != nil { | ||||
| 		dtu.SetCodeChallenge(*s) | ||||
| 	} | ||||
| 	return dtu | ||||
| } | ||||
|  | ||||
| // SetCodeChallengeMethod sets the "code_challenge_method" field. | ||||
| func (dtu *DeviceTokenUpdate) SetCodeChallengeMethod(s string) *DeviceTokenUpdate { | ||||
| 	dtu.mutation.SetCodeChallengeMethod(s) | ||||
| 	return dtu | ||||
| } | ||||
|  | ||||
| // SetNillableCodeChallengeMethod sets the "code_challenge_method" field if the given value is not nil. | ||||
| func (dtu *DeviceTokenUpdate) SetNillableCodeChallengeMethod(s *string) *DeviceTokenUpdate { | ||||
| 	if s != nil { | ||||
| 		dtu.SetCodeChallengeMethod(*s) | ||||
| 	} | ||||
| 	return dtu | ||||
| } | ||||
|  | ||||
| // Mutation returns the DeviceTokenMutation object of the builder. | ||||
| func (dtu *DeviceTokenUpdate) Mutation() *DeviceTokenMutation { | ||||
| 	return dtu.mutation | ||||
| @@ -230,6 +258,20 @@ func (dtu *DeviceTokenUpdate) sqlSave(ctx context.Context) (n int, err error) { | ||||
| 			Column: devicetoken.FieldPollInterval, | ||||
| 		}) | ||||
| 	} | ||||
| 	if value, ok := dtu.mutation.CodeChallenge(); ok { | ||||
| 		_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ | ||||
| 			Type:   field.TypeString, | ||||
| 			Value:  value, | ||||
| 			Column: devicetoken.FieldCodeChallenge, | ||||
| 		}) | ||||
| 	} | ||||
| 	if value, ok := dtu.mutation.CodeChallengeMethod(); ok { | ||||
| 		_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ | ||||
| 			Type:   field.TypeString, | ||||
| 			Value:  value, | ||||
| 			Column: devicetoken.FieldCodeChallengeMethod, | ||||
| 		}) | ||||
| 	} | ||||
| 	if n, err = sqlgraph.UpdateNodes(ctx, dtu.driver, _spec); err != nil { | ||||
| 		if _, ok := err.(*sqlgraph.NotFoundError); ok { | ||||
| 			err = &NotFoundError{devicetoken.Label} | ||||
| @@ -298,6 +340,34 @@ func (dtuo *DeviceTokenUpdateOne) AddPollInterval(i int) *DeviceTokenUpdateOne { | ||||
| 	return dtuo | ||||
| } | ||||
|  | ||||
| // SetCodeChallenge sets the "code_challenge" field. | ||||
| func (dtuo *DeviceTokenUpdateOne) SetCodeChallenge(s string) *DeviceTokenUpdateOne { | ||||
| 	dtuo.mutation.SetCodeChallenge(s) | ||||
| 	return dtuo | ||||
| } | ||||
|  | ||||
| // SetNillableCodeChallenge sets the "code_challenge" field if the given value is not nil. | ||||
| func (dtuo *DeviceTokenUpdateOne) SetNillableCodeChallenge(s *string) *DeviceTokenUpdateOne { | ||||
| 	if s != nil { | ||||
| 		dtuo.SetCodeChallenge(*s) | ||||
| 	} | ||||
| 	return dtuo | ||||
| } | ||||
|  | ||||
| // SetCodeChallengeMethod sets the "code_challenge_method" field. | ||||
| func (dtuo *DeviceTokenUpdateOne) SetCodeChallengeMethod(s string) *DeviceTokenUpdateOne { | ||||
| 	dtuo.mutation.SetCodeChallengeMethod(s) | ||||
| 	return dtuo | ||||
| } | ||||
|  | ||||
| // SetNillableCodeChallengeMethod sets the "code_challenge_method" field if the given value is not nil. | ||||
| func (dtuo *DeviceTokenUpdateOne) SetNillableCodeChallengeMethod(s *string) *DeviceTokenUpdateOne { | ||||
| 	if s != nil { | ||||
| 		dtuo.SetCodeChallengeMethod(*s) | ||||
| 	} | ||||
| 	return dtuo | ||||
| } | ||||
|  | ||||
| // Mutation returns the DeviceTokenMutation object of the builder. | ||||
| func (dtuo *DeviceTokenUpdateOne) Mutation() *DeviceTokenMutation { | ||||
| 	return dtuo.mutation | ||||
| @@ -475,6 +545,20 @@ func (dtuo *DeviceTokenUpdateOne) sqlSave(ctx context.Context) (_node *DeviceTok | ||||
| 			Column: devicetoken.FieldPollInterval, | ||||
| 		}) | ||||
| 	} | ||||
| 	if value, ok := dtuo.mutation.CodeChallenge(); ok { | ||||
| 		_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ | ||||
| 			Type:   field.TypeString, | ||||
| 			Value:  value, | ||||
| 			Column: devicetoken.FieldCodeChallenge, | ||||
| 		}) | ||||
| 	} | ||||
| 	if value, ok := dtuo.mutation.CodeChallengeMethod(); ok { | ||||
| 		_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ | ||||
| 			Type:   field.TypeString, | ||||
| 			Value:  value, | ||||
| 			Column: devicetoken.FieldCodeChallengeMethod, | ||||
| 		}) | ||||
| 	} | ||||
| 	_node = &DeviceToken{config: dtuo.config} | ||||
| 	_spec.Assign = _node.assignValues | ||||
| 	_spec.ScanValues = _node.scanValues | ||||
|   | ||||
| @@ -101,6 +101,8 @@ var ( | ||||
| 		{Name: "expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, | ||||
| 		{Name: "last_request", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, | ||||
| 		{Name: "poll_interval", Type: field.TypeInt}, | ||||
| 		{Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, | ||||
| 		{Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, | ||||
| 	} | ||||
| 	// DeviceTokensTable holds the schema information for the "device_tokens" table. | ||||
| 	DeviceTokensTable = &schema.Table{ | ||||
|   | ||||
| @@ -3643,6 +3643,8 @@ type DeviceTokenMutation struct { | ||||
| 	last_request          *time.Time | ||||
| 	poll_interval         *int | ||||
| 	addpoll_interval      *int | ||||
| 	code_challenge        *string | ||||
| 	code_challenge_method *string | ||||
| 	clearedFields         map[string]struct{} | ||||
| 	done                  bool | ||||
| 	oldValue              func(context.Context) (*DeviceToken, error) | ||||
| @@ -3996,6 +3998,78 @@ func (m *DeviceTokenMutation) ResetPollInterval() { | ||||
| 	m.addpoll_interval = nil | ||||
| } | ||||
|  | ||||
| // SetCodeChallenge sets the "code_challenge" field. | ||||
| func (m *DeviceTokenMutation) SetCodeChallenge(s string) { | ||||
| 	m.code_challenge = &s | ||||
| } | ||||
|  | ||||
| // CodeChallenge returns the value of the "code_challenge" field in the mutation. | ||||
| func (m *DeviceTokenMutation) CodeChallenge() (r string, exists bool) { | ||||
| 	v := m.code_challenge | ||||
| 	if v == nil { | ||||
| 		return | ||||
| 	} | ||||
| 	return *v, true | ||||
| } | ||||
|  | ||||
| // OldCodeChallenge returns the old "code_challenge" field's value of the DeviceToken entity. | ||||
| // If the DeviceToken object wasn't provided to the builder, the object is fetched from the database. | ||||
| // An error is returned if the mutation operation is not UpdateOne, or the database query fails. | ||||
| func (m *DeviceTokenMutation) OldCodeChallenge(ctx context.Context) (v string, err error) { | ||||
| 	if !m.op.Is(OpUpdateOne) { | ||||
| 		return v, errors.New("OldCodeChallenge is only allowed on UpdateOne operations") | ||||
| 	} | ||||
| 	if m.id == nil || m.oldValue == nil { | ||||
| 		return v, errors.New("OldCodeChallenge requires an ID field in the mutation") | ||||
| 	} | ||||
| 	oldValue, err := m.oldValue(ctx) | ||||
| 	if err != nil { | ||||
| 		return v, fmt.Errorf("querying old value for OldCodeChallenge: %w", err) | ||||
| 	} | ||||
| 	return oldValue.CodeChallenge, nil | ||||
| } | ||||
|  | ||||
| // ResetCodeChallenge resets all changes to the "code_challenge" field. | ||||
| func (m *DeviceTokenMutation) ResetCodeChallenge() { | ||||
| 	m.code_challenge = nil | ||||
| } | ||||
|  | ||||
| // SetCodeChallengeMethod sets the "code_challenge_method" field. | ||||
| func (m *DeviceTokenMutation) SetCodeChallengeMethod(s string) { | ||||
| 	m.code_challenge_method = &s | ||||
| } | ||||
|  | ||||
| // CodeChallengeMethod returns the value of the "code_challenge_method" field in the mutation. | ||||
| func (m *DeviceTokenMutation) CodeChallengeMethod() (r string, exists bool) { | ||||
| 	v := m.code_challenge_method | ||||
| 	if v == nil { | ||||
| 		return | ||||
| 	} | ||||
| 	return *v, true | ||||
| } | ||||
|  | ||||
| // OldCodeChallengeMethod returns the old "code_challenge_method" field's value of the DeviceToken entity. | ||||
| // If the DeviceToken object wasn't provided to the builder, the object is fetched from the database. | ||||
| // An error is returned if the mutation operation is not UpdateOne, or the database query fails. | ||||
| func (m *DeviceTokenMutation) OldCodeChallengeMethod(ctx context.Context) (v string, err error) { | ||||
| 	if !m.op.Is(OpUpdateOne) { | ||||
| 		return v, errors.New("OldCodeChallengeMethod is only allowed on UpdateOne operations") | ||||
| 	} | ||||
| 	if m.id == nil || m.oldValue == nil { | ||||
| 		return v, errors.New("OldCodeChallengeMethod requires an ID field in the mutation") | ||||
| 	} | ||||
| 	oldValue, err := m.oldValue(ctx) | ||||
| 	if err != nil { | ||||
| 		return v, fmt.Errorf("querying old value for OldCodeChallengeMethod: %w", err) | ||||
| 	} | ||||
| 	return oldValue.CodeChallengeMethod, nil | ||||
| } | ||||
|  | ||||
| // ResetCodeChallengeMethod resets all changes to the "code_challenge_method" field. | ||||
| func (m *DeviceTokenMutation) ResetCodeChallengeMethod() { | ||||
| 	m.code_challenge_method = nil | ||||
| } | ||||
|  | ||||
| // Where appends a list predicates to the DeviceTokenMutation builder. | ||||
| func (m *DeviceTokenMutation) Where(ps ...predicate.DeviceToken) { | ||||
| 	m.predicates = append(m.predicates, ps...) | ||||
| @@ -4015,7 +4089,7 @@ func (m *DeviceTokenMutation) Type() string { | ||||
| // order to get all numeric fields that were incremented/decremented, call | ||||
| // AddedFields(). | ||||
| func (m *DeviceTokenMutation) Fields() []string { | ||||
| 	fields := make([]string, 0, 6) | ||||
| 	fields := make([]string, 0, 8) | ||||
| 	if m.device_code != nil { | ||||
| 		fields = append(fields, devicetoken.FieldDeviceCode) | ||||
| 	} | ||||
| @@ -4034,6 +4108,12 @@ func (m *DeviceTokenMutation) Fields() []string { | ||||
| 	if m.poll_interval != nil { | ||||
| 		fields = append(fields, devicetoken.FieldPollInterval) | ||||
| 	} | ||||
| 	if m.code_challenge != nil { | ||||
| 		fields = append(fields, devicetoken.FieldCodeChallenge) | ||||
| 	} | ||||
| 	if m.code_challenge_method != nil { | ||||
| 		fields = append(fields, devicetoken.FieldCodeChallengeMethod) | ||||
| 	} | ||||
| 	return fields | ||||
| } | ||||
|  | ||||
| @@ -4054,6 +4134,10 @@ func (m *DeviceTokenMutation) Field(name string) (ent.Value, bool) { | ||||
| 		return m.LastRequest() | ||||
| 	case devicetoken.FieldPollInterval: | ||||
| 		return m.PollInterval() | ||||
| 	case devicetoken.FieldCodeChallenge: | ||||
| 		return m.CodeChallenge() | ||||
| 	case devicetoken.FieldCodeChallengeMethod: | ||||
| 		return m.CodeChallengeMethod() | ||||
| 	} | ||||
| 	return nil, false | ||||
| } | ||||
| @@ -4075,6 +4159,10 @@ func (m *DeviceTokenMutation) OldField(ctx context.Context, name string) (ent.Va | ||||
| 		return m.OldLastRequest(ctx) | ||||
| 	case devicetoken.FieldPollInterval: | ||||
| 		return m.OldPollInterval(ctx) | ||||
| 	case devicetoken.FieldCodeChallenge: | ||||
| 		return m.OldCodeChallenge(ctx) | ||||
| 	case devicetoken.FieldCodeChallengeMethod: | ||||
| 		return m.OldCodeChallengeMethod(ctx) | ||||
| 	} | ||||
| 	return nil, fmt.Errorf("unknown DeviceToken field %s", name) | ||||
| } | ||||
| @@ -4126,6 +4214,20 @@ func (m *DeviceTokenMutation) SetField(name string, value ent.Value) error { | ||||
| 		} | ||||
| 		m.SetPollInterval(v) | ||||
| 		return nil | ||||
| 	case devicetoken.FieldCodeChallenge: | ||||
| 		v, ok := value.(string) | ||||
| 		if !ok { | ||||
| 			return fmt.Errorf("unexpected type %T for field %s", value, name) | ||||
| 		} | ||||
| 		m.SetCodeChallenge(v) | ||||
| 		return nil | ||||
| 	case devicetoken.FieldCodeChallengeMethod: | ||||
| 		v, ok := value.(string) | ||||
| 		if !ok { | ||||
| 			return fmt.Errorf("unexpected type %T for field %s", value, name) | ||||
| 		} | ||||
| 		m.SetCodeChallengeMethod(v) | ||||
| 		return nil | ||||
| 	} | ||||
| 	return fmt.Errorf("unknown DeviceToken field %s", name) | ||||
| } | ||||
| @@ -4217,6 +4319,12 @@ func (m *DeviceTokenMutation) ResetField(name string) error { | ||||
| 	case devicetoken.FieldPollInterval: | ||||
| 		m.ResetPollInterval() | ||||
| 		return nil | ||||
| 	case devicetoken.FieldCodeChallenge: | ||||
| 		m.ResetCodeChallenge() | ||||
| 		return nil | ||||
| 	case devicetoken.FieldCodeChallengeMethod: | ||||
| 		m.ResetCodeChallengeMethod() | ||||
| 		return nil | ||||
| 	} | ||||
| 	return fmt.Errorf("unknown DeviceToken field %s", name) | ||||
| } | ||||
|   | ||||
| @@ -142,6 +142,14 @@ func init() { | ||||
| 	devicetokenDescStatus := devicetokenFields[1].Descriptor() | ||||
| 	// devicetoken.StatusValidator is a validator for the "status" field. It is called by the builders before save. | ||||
| 	devicetoken.StatusValidator = devicetokenDescStatus.Validators[0].(func(string) error) | ||||
| 	// devicetokenDescCodeChallenge is the schema descriptor for code_challenge field. | ||||
| 	devicetokenDescCodeChallenge := devicetokenFields[6].Descriptor() | ||||
| 	// devicetoken.DefaultCodeChallenge holds the default value on creation for the code_challenge field. | ||||
| 	devicetoken.DefaultCodeChallenge = devicetokenDescCodeChallenge.Default.(string) | ||||
| 	// devicetokenDescCodeChallengeMethod is the schema descriptor for code_challenge_method field. | ||||
| 	devicetokenDescCodeChallengeMethod := devicetokenFields[7].Descriptor() | ||||
| 	// devicetoken.DefaultCodeChallengeMethod holds the default value on creation for the code_challenge_method field. | ||||
| 	devicetoken.DefaultCodeChallengeMethod = devicetokenDescCodeChallengeMethod.Default.(string) | ||||
| 	keysFields := schema.Keys{}.Fields() | ||||
| 	_ = keysFields | ||||
| 	// keysDescID is the schema descriptor for id field. | ||||
|   | ||||
| @@ -13,7 +13,9 @@ create table device_token | ||||
|     token                 blob, | ||||
|     expiry                timestamp       not null, | ||||
|     last_request          timestamp       not null, | ||||
|     poll_interval integer   not null | ||||
|     poll_interval         integer         not null, | ||||
|     code_challenge        text default '' not null, | ||||
|     code_challenge_method text default '' not null | ||||
| ); | ||||
| */ | ||||
|  | ||||
| @@ -38,6 +40,12 @@ func (DeviceToken) Fields() []ent.Field { | ||||
| 		field.Time("last_request"). | ||||
| 			SchemaType(timeSchema), | ||||
| 		field.Int("poll_interval"), | ||||
| 		field.Text("code_challenge"). | ||||
| 			SchemaType(textSchema). | ||||
| 			Default(""), | ||||
| 		field.Text("code_challenge_method"). | ||||
| 			SchemaType(textSchema). | ||||
| 			Default(""), | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -605,8 +605,11 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { | ||||
| func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) { | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) | ||||
| 	defer cancel() | ||||
| 	err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t) | ||||
| 	return t, err | ||||
| 	var dt DeviceToken | ||||
| 	if err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &dt); err == nil { | ||||
| 		t = toStorageDeviceToken(dt) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) { | ||||
|   | ||||
| @@ -281,6 +281,8 @@ type DeviceToken struct { | ||||
| 	Expiry              time.Time `json:"expiry"` | ||||
| 	LastRequestTime     time.Time `json:"last_request"` | ||||
| 	PollIntervalSeconds int       `json:"poll_interval"` | ||||
| 	CodeChallenge       string    `json:"code_challenge,omitempty"` | ||||
| 	CodeChallengeMethod string    `json:"code_challenge_method,omitempty"` | ||||
| } | ||||
|  | ||||
| func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { | ||||
| @@ -291,6 +293,8 @@ func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { | ||||
| 		Expiry:              t.Expiry, | ||||
| 		LastRequestTime:     t.LastRequestTime, | ||||
| 		PollIntervalSeconds: t.PollIntervalSeconds, | ||||
| 		CodeChallenge:       t.PKCE.CodeChallenge, | ||||
| 		CodeChallengeMethod: t.PKCE.CodeChallengeMethod, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -302,5 +306,9 @@ func toStorageDeviceToken(t DeviceToken) storage.DeviceToken { | ||||
| 		Expiry:              t.Expiry, | ||||
| 		LastRequestTime:     t.LastRequestTime, | ||||
| 		PollIntervalSeconds: t.PollIntervalSeconds, | ||||
| 		PKCE: storage.PKCE{ | ||||
| 			CodeChallenge:       t.CodeChallenge, | ||||
| 			CodeChallengeMethod: t.CodeChallengeMethod, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -802,6 +802,8 @@ type DeviceToken struct { | ||||
| 	Expiry              time.Time `json:"expiry"` | ||||
| 	LastRequestTime     time.Time `json:"last_request"` | ||||
| 	PollIntervalSeconds int       `json:"poll_interval"` | ||||
| 	CodeChallenge       string    `json:"code_challenge,omitempty"` | ||||
| 	CodeChallengeMethod string    `json:"code_challenge_method,omitempty"` | ||||
| } | ||||
|  | ||||
| // DeviceTokenList is a list of DeviceTokens. | ||||
| @@ -826,6 +828,8 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { | ||||
| 		Expiry:              t.Expiry, | ||||
| 		LastRequestTime:     t.LastRequestTime, | ||||
| 		PollIntervalSeconds: t.PollIntervalSeconds, | ||||
| 		CodeChallenge:       t.PKCE.CodeChallenge, | ||||
| 		CodeChallengeMethod: t.PKCE.CodeChallengeMethod, | ||||
| 	} | ||||
| 	return req | ||||
| } | ||||
| @@ -838,5 +842,9 @@ func toStorageDeviceToken(t DeviceToken) storage.DeviceToken { | ||||
| 		Expiry:              t.Expiry, | ||||
| 		LastRequestTime:     t.LastRequestTime, | ||||
| 		PollIntervalSeconds: t.PollIntervalSeconds, | ||||
| 		PKCE: storage.PKCE{ | ||||
| 			CodeChallenge:       t.CodeChallenge, | ||||
| 			CodeChallengeMethod: t.CodeChallengeMethod, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -927,12 +927,12 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { | ||||
| func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { | ||||
| 	_, err := c.Exec(` | ||||
| 		insert into device_token ( | ||||
| 			device_code, status, token, expiry, last_request, poll_interval | ||||
| 			device_code, status, token, expiry, last_request, poll_interval, code_challenge, code_challenge_method | ||||
| 		) | ||||
| 		values ( | ||||
| 			$1, $2, $3, $4, $5, $6 | ||||
| 			$1, $2, $3, $4, $5, $6, $7, $8 | ||||
| 		);`, | ||||
| 		t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds, | ||||
| 		t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds, t.PKCE.CodeChallenge, t.PKCE.CodeChallengeMethod, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		if c.alreadyExistsCheck(err) { | ||||
| @@ -972,10 +972,10 @@ func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { | ||||
| func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) { | ||||
| 	err = q.QueryRow(` | ||||
| 		select | ||||
|             status, token, expiry, last_request, poll_interval | ||||
|             status, token, expiry, last_request, poll_interval, code_challenge, code_challenge_method | ||||
| 		from device_token where device_code = $1; | ||||
| 	`, deviceCode).Scan( | ||||
| 		&a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds, | ||||
| 		&a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds, &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		if err == sql.ErrNoRows { | ||||
| @@ -1002,11 +1002,13 @@ func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.Dev | ||||
| 				status = $1,  | ||||
| 				token = $2, | ||||
| 				last_request = $3, | ||||
| 				poll_interval = $4 | ||||
| 				poll_interval = $4, | ||||
| 				code_challenge = $5, | ||||
| 				code_challenge_method = $6 | ||||
| 			where | ||||
| 				device_code = $5 | ||||
| 				device_code = $7 | ||||
| 		`, | ||||
| 			r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.DeviceCode, | ||||
| 			r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.PKCE.CodeChallenge, r.PKCE.CodeChallengeMethod, r.DeviceCode, | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("update device token: %v", err) | ||||
|   | ||||
| @@ -281,4 +281,14 @@ var migrations = []migration{ | ||||
| 				add column obsolete_token text default '';`, | ||||
| 		}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		stmts: []string{ | ||||
| 			` | ||||
| 			alter table device_token | ||||
| 				add column code_challenge text not null default '';`, | ||||
| 			` | ||||
| 			alter table device_token | ||||
| 				add column code_challenge_method text not null default '';`, | ||||
| 		}, | ||||
| 	}, | ||||
| } | ||||
|   | ||||
| @@ -427,4 +427,5 @@ type DeviceToken struct { | ||||
| 	Expiry              time.Time | ||||
| 	LastRequestTime     time.Time | ||||
| 	PollIntervalSeconds int | ||||
| 	PKCE                PKCE | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user