Coverage for src/kwai/modules/identity/tokens/access_token_db_repository.py: 82%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements an access token repository for a database.""" 

2 

3from typing import Any, AsyncIterator 

4 

5from kwai.core.db.database import Database 

6from kwai.modules.identity.tokens.access_token import ( 

7 AccessTokenEntity, 

8 AccessTokenIdentifier, 

9) 

10from kwai.modules.identity.tokens.access_token_db_query import AccessTokenDbQuery 

11from kwai.modules.identity.tokens.access_token_query import AccessTokenQuery 

12from kwai.modules.identity.tokens.access_token_repository import ( 

13 AccessTokenNotFoundException, 

14 AccessTokenRepository, 

15) 

16from kwai.modules.identity.tokens.token_identifier import TokenIdentifier 

17from kwai.modules.identity.tokens.token_tables import AccessTokenRow 

18from kwai.modules.identity.users.user_tables import UserAccountRow 

19 

20 

def _create_entity(row: dict[str, Any]) -> AccessTokenEntity:
    """Build an access token entity from a single result row.

    The same row is mapped twice: once as an access token row and once as a
    user account row. The entity created from the user account row is handed
    to the access token row when it creates the access token entity.

    Args:
        row: The row data to map.

    Returns:
        The access token entity created from the row.
    """
    token_row = AccessTokenRow.map(row)
    user_account = UserAccountRow.map(row).create_entity()
    return token_row.create_entity(user_account)

26 

27 

class AccessTokenDbRepository(AccessTokenRepository):
    """Access token repository that stores its entities in a database."""

    def __init__(self, database: Database):
        """Keep a reference to the database used by all operations.

        Args:
            database: The database passed to queries and used for
                insert/update.
        """
        self._database = database

    def create_query(self) -> AccessTokenQuery:
        """Return a new database query bound to this repository's database."""
        return AccessTokenDbQuery(self._database)

    async def get(self, id_: AccessTokenIdentifier) -> AccessTokenEntity:
        """Fetch the access token with the given id.

        Args:
            id_: The identifier of the access token; its ``value`` is used
                as the filter.

        Returns:
            The entity created from the row that was found.

        Raises:
            AccessTokenNotFoundException: When the query yields no
                (or a falsy) row.
        """
        id_query = self.create_query()
        id_query.filter_by_id(id_.value)

        found = await id_query.fetch_one()
        if not found:
            raise AccessTokenNotFoundException

        return _create_entity(found)

    async def get_by_identifier(self, identifier: TokenIdentifier) -> AccessTokenEntity:
        """Fetch the access token with the given token identifier.

        Args:
            identifier: The token identifier used as the filter.

        Returns:
            The entity created from the row that was found.

        Raises:
            AccessTokenNotFoundException: When the query yields no
                (or a falsy) row.
        """
        identifier_query = self.create_query()
        identifier_query.filter_by_token_identifier(identifier)

        found = await identifier_query.fetch_one()
        if not found:
            raise AccessTokenNotFoundException

        return _create_entity(found)

    async def get_all(
        self,
        query: AccessTokenQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[AccessTokenEntity]:
        """Yield an access token entity for each row fetched by a query.

        Args:
            query: The query to execute. When it is None (or falsy), a new
                query from ``create_query`` is used instead.
            limit: Passed through to the query's ``fetch``.
            offset: Passed through to the query's ``fetch``.

        Yields:
            One entity per fetched row.
        """
        # Fall back to an unfiltered query when the caller did not pass one.
        active_query = query or self.create_query()
        async for fetched_row in active_query.fetch(limit, offset):
            yield _create_entity(fetched_row)

    async def create(self, access_token: AccessTokenEntity) -> AccessTokenEntity:
        """Insert the access token into the database.

        Args:
            access_token: The entity to persist.

        Returns:
            The result of ``access_token.set_id`` called with an identifier
            that wraps the value returned by the insert.
        """
        table_name = AccessTokenRow.__table_name__
        row_data = AccessTokenRow.persist(access_token)
        inserted_id = await self._database.insert(table_name, row_data)
        return access_token.set_id(AccessTokenIdentifier(inserted_id))

    async def update(self, access_token: AccessTokenEntity):
        """Write the current state of the access token to the database.

        Args:
            access_token: The entity to persist; ``access_token.id.value``
                is passed as the first argument of the database update.
        """
        token_id = access_token.id.value
        table_name = AccessTokenRow.__table_name__
        row_data = AccessTokenRow.persist(access_token)
        await self._database.update(token_id, table_name, row_data)

75 access_token.id.value, 

76 AccessTokenRow.__table_name__, 

77 AccessTokenRow.persist(access_token), 

78 )