Coverage for src/kwai/modules/identity/tokens/user_token_db_repository.py: 100%
15 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that implements a user token repository for a database."""
3from sql_smith.functions import express, field
5from kwai.core.db.database import Database
6from kwai.modules.identity.tokens.token_tables import AccessTokenRow, RefreshTokenRow
7from kwai.modules.identity.tokens.user_token_repository import UserTokenRepository
8from kwai.modules.identity.users.user_account import UserAccountEntity
class UserTokenDbRepository(UserTokenRepository):
    """Database-backed implementation of the user token repository."""

    def __init__(self, database: Database):
        """Keep a reference to the database that will execute the queries.

        Args:
            database: The database used to run the revocation statements.
        """
        self._database = database

    async def revoke(self, user_account: UserAccountEntity):
        """Flag the tokens that belong to the given user account as revoked.

        Two UPDATE statements are executed, in this order:

        1. ``revoked`` is set to 1 on every access token row whose ``user_id``
           equals the id of the user account.
        2. ``revoked`` is set to 1 on every refresh token row whose
           ``access_token_id`` is the id of one of the user's access tokens.

        Args:
            user_account: The user account whose tokens must be revoked.
        """
        factory = Database.create_query_factory()
        user_id = user_account.id.value

        # Step 1: revoke all access tokens owned by the user.
        owned_by_user = field(AccessTokenRow.column("user_id")).eq(user_id)
        revoke_access_tokens = factory.update(
            AccessTokenRow.__table_name__, {"revoked": 1}
        ).where(owned_by_user)
        await self._database.execute(revoke_access_tokens)

        # Step 2: revoke the refresh tokens attached to those access tokens.
        # The ids of the user's access tokens are looked up with a sub-select
        # that is embedded in the IN condition of the UPDATE statement.
        user_access_token_ids = (
            factory.select(AccessTokenRow.column("id"))
            .from_(AccessTokenRow.__table_name__)
            .where(AccessTokenRow.field("user_id").eq(user_id))
        )
        linked_to_user_tokens = field(
            RefreshTokenRow.column("access_token_id")
        ).in_(express("%s", user_access_token_ids))
        revoke_refresh_tokens = factory.update(
            RefreshTokenRow.__table_name__, {"revoked": 1}
        ).where(linked_to_user_tokens)
        await self._database.execute(revoke_refresh_tokens)