Coverage for src/kwai/modules/identity/tokens/user_token_db_repository.py: 100%

15 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements a user token repository for a database.""" 

2 

3from sql_smith.functions import express, field 

4 

5from kwai.core.db.database import Database 

6from kwai.modules.identity.tokens.token_tables import AccessTokenRow, RefreshTokenRow 

7from kwai.modules.identity.tokens.user_token_repository import UserTokenRepository 

8from kwai.modules.identity.users.user_account import UserAccountEntity 

9 

10 

class UserTokenDbRepository(UserTokenRepository):
    """User token repository that works against a database.

    Revocation is carried out with UPDATE statements on the access token
    and refresh token tables.
    """

    def __init__(self, database: Database):
        """Keep a reference to the database used to execute the queries."""
        self._database = database

    async def revoke(self, user_account: UserAccountEntity):
        """Flag every access and refresh token of a user account as revoked.

        Two UPDATE statements are executed one after the other: the first
        sets ``revoked`` to 1 on the access tokens whose ``user_id`` equals
        the id of the account, the second does the same for the refresh
        tokens whose ``access_token_id`` refers to one of those access
        tokens.

        Args:
            user_account: The account whose tokens must be revoked.
        """
        # NOTE(review): no transaction is opened here, so the two statements
        # are only atomic if the caller arranges that -- confirm.
        user_id = user_account.id.value
        factory = Database.create_query_factory()

        # Step 1: revoke the access tokens owned by the user.
        access_token_owner = field(AccessTokenRow.column("user_id"))
        revoke_access_tokens = factory.update(
            AccessTokenRow.__table_name__, {"revoked": 1}
        ).where(access_token_owner.eq(user_id))
        await self._database.execute(revoke_access_tokens)

        # Step 2: revoke the refresh tokens attached to those access tokens.
        # The access token ids are not fetched; the select is embedded as a
        # subquery in the IN clause of the update.
        owned_access_token_ids = (
            factory.select(AccessTokenRow.column("id"))
            .from_(AccessTokenRow.__table_name__)
            .where(AccessTokenRow.field("user_id").eq(user_id))
        )
        refresh_token_parent = field(RefreshTokenRow.column("access_token_id"))
        attached_to_owned_tokens = refresh_token_parent.in_(
            express("%s", owned_access_token_ids)
        )
        revoke_refresh_tokens = factory.update(
            RefreshTokenRow.__table_name__, {"revoked": 1}
        ).where(attached_to_owned_tokens)
        await self._database.execute(revoke_refresh_tokens)