Coverage for kwai/modules/identity/tokens/access_token_db_repository.py: 83%

42 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module that implements an access token repository for a database.""" 

2from typing import Any, AsyncIterator 

3 

4from kwai.core.db.database import Database 

5from kwai.core.domain.entity import Entity 

6from kwai.modules.identity.tokens.access_token import ( 

7 AccessTokenEntity, 

8 AccessTokenIdentifier, 

9) 

10from kwai.modules.identity.tokens.access_token_db_query import AccessTokenDbQuery 

11from kwai.modules.identity.tokens.access_token_query import AccessTokenQuery 

12from kwai.modules.identity.tokens.access_token_repository import ( 

13 AccessTokenNotFoundException, 

14 AccessTokenRepository, 

15) 

16from kwai.modules.identity.tokens.token_identifier import TokenIdentifier 

17from kwai.modules.identity.tokens.token_tables import ( 

18 AccessTokenRow, 

19 AccessTokensTable, 

20) 

21from kwai.modules.identity.users.user_tables import UserAccountsTable 

22 

23 

def _create_entity(row: dict[str, Any]) -> AccessTokenEntity:
    """Build an access token entity, including its user account, from one row.

    The same row is handed to both table wrappers: the user account entity is
    created first and then attached to the access token entity.
    """
    token_table = AccessTokensTable(row)
    user_account = UserAccountsTable(row).create_entity()
    return token_table.create_entity(user_account)

27 

28 

class AccessTokenDbRepository(AccessTokenRepository):
    """Database repository for the access token entity.

    Reads go through AccessTokenDbQuery; every write (create/update) is
    committed immediately on the injected database.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: Database used for queries, inserts, updates and commits.
        """
        self._database = database

    def create_query(self) -> AccessTokenQuery:
        """Create a new query bound to this repository's database."""
        return AccessTokenDbQuery(self._database)

    async def _get_one(self, query: AccessTokenQuery) -> AccessTokenEntity:
        """Fetch the single row matched by *query* and convert it to an entity.

        Args:
            query: An already-filtered query.

        Returns:
            The access token entity built from the fetched row.

        Raises:
            AccessTokenNotFoundException: when the query yields no row.
        """
        row = await query.fetch_one()
        # fetch_one yields a falsy value when nothing matched.
        if not row:
            raise AccessTokenNotFoundException
        return _create_entity(row)

    async def get(self, id_: AccessTokenIdentifier) -> AccessTokenEntity:
        """Get the access token with the given id.

        Args:
            id_: Identifier of the access token.

        Raises:
            AccessTokenNotFoundException: when no access token has this id.
        """
        query = self.create_query()
        query.filter_by_id(id_.value)
        return await self._get_one(query)

    async def get_by_identifier(self, identifier: TokenIdentifier) -> AccessTokenEntity:
        """Get the access token with the given token identifier.

        Args:
            identifier: The token identifier to look up.

        Raises:
            AccessTokenNotFoundException: when no access token matches.
        """
        query = self.create_query()
        query.filter_by_token_identifier(identifier)
        return await self._get_one(query)

    async def get_all(
        self,
        query: AccessTokenQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[AccessTokenEntity]:
        """Iterate over access tokens.

        Args:
            query: Query to run; when None, a fresh query from create_query()
                is used.
            limit: Maximum number of rows, forwarded to query.fetch.
            offset: Number of rows to skip, forwarded to query.fetch.

        Yields:
            One access token entity per fetched row.
        """
        # Explicit None check: a caller-supplied query must never be swapped
        # for a fresh one just because the query object happens to be falsy.
        if query is None:
            query = self.create_query()
        async for row in query.fetch(limit, offset):
            yield _create_entity(row)

    async def create(self, access_token: AccessTokenEntity) -> AccessTokenEntity:
        """Insert a new access token and commit.

        Args:
            access_token: The access token to store.

        Returns:
            The access token with its id replaced (via Entity.replace) by the
            id returned from the insert.
        """
        new_id = await self._database.insert(
            AccessTokensTable.table_name, AccessTokenRow.persist(access_token)
        )
        await self._database.commit()
        return Entity.replace(access_token, id_=AccessTokenIdentifier(new_id))

    async def update(self, access_token: AccessTokenEntity) -> None:
        """Persist the current state of an existing access token and commit.

        Args:
            access_token: The access token to update; its id selects the row.
        """
        await self._database.update(
            access_token.id.value,
            AccessTokensTable.table_name,
            AccessTokenRow.persist(access_token),
        )
        await self._database.commit()