Coverage for kwai/modules/identity/tokens/refresh_token_db_repository.py: 83%

42 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module that implements a refresh token repository for a database.""" 

2from typing import AsyncIterator 

3 

4from kwai.core.db.database import Database 

5from kwai.core.domain.entity import Entity 

6from kwai.modules.identity.tokens.refresh_token import ( 

7 RefreshTokenEntity, 

8 RefreshTokenIdentifier, 

9) 

10from kwai.modules.identity.tokens.refresh_token_db_query import RefreshTokenDbQuery 

11from kwai.modules.identity.tokens.refresh_token_query import RefreshTokenQuery 

12from kwai.modules.identity.tokens.refresh_token_repository import ( 

13 RefreshTokenNotFoundException, 

14 RefreshTokenRepository, 

15) 

16from kwai.modules.identity.tokens.token_identifier import TokenIdentifier 

17from kwai.modules.identity.tokens.token_tables import ( 

18 AccessTokensTable, 

19 RefreshTokenRow, 

20 RefreshTokensTable, 

21) 

22from kwai.modules.identity.users.user_tables import ( 

23 UserAccountsTable, 

24) 

25 

26 

def _create_entity(row) -> RefreshTokenEntity:
    """Build a refresh token entity from a single database row.

    The same row is handed to three table wrappers. The user account entity
    is built first and wrapped into an access token entity, which is then
    wrapped into the refresh token entity that is returned. Presumably the
    row is the result of a join over the three tables; verify against
    RefreshTokenDbQuery.
    """
    user_account = UserAccountsTable(row).create_entity()
    access_token = AccessTokensTable(row).create_entity(user_account)
    return RefreshTokensTable(row).create_entity(access_token)

32 

33 

class RefreshTokenDbRepository(RefreshTokenRepository):
    """Database repository for the refresh token entity.

    Args:
        database: The database used to query, insert and update refresh tokens.
    """

    def __init__(self, database: Database):
        self._database = database

    def create_query(self) -> RefreshTokenQuery:
        """Create a new refresh token query bound to this repository's database."""
        return RefreshTokenDbQuery(self._database)

    async def _get_one(self, query: RefreshTokenQuery) -> RefreshTokenEntity:
        """Fetch the first row of a prepared query and map it to an entity.

        Raises:
            RefreshTokenNotFoundException: when the query returns no row.
        """
        row = await query.fetch_one()
        if not row:
            raise RefreshTokenNotFoundException()
        return _create_entity(row)

    async def get_by_token_identifier(
        self, identifier: TokenIdentifier
    ) -> RefreshTokenEntity:
        """Get the refresh token with the given token identifier.

        Raises:
            RefreshTokenNotFoundException: when no refresh token matches.
        """
        query = self.create_query()
        query.filter_by_token_identifier(identifier)
        return await self._get_one(query)

    async def get(self, id_: RefreshTokenIdentifier) -> RefreshTokenEntity:
        """Get the refresh token with the given id.

        Raises:
            RefreshTokenNotFoundException: when no refresh token has this id.
        """
        query = self.create_query()
        query.filter_by_id(id_.value)
        return await self._get_one(query)

    async def get_all(
        self,
        query: RefreshTokenQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[RefreshTokenEntity]:
        """Yield refresh token entities for the rows returned by a query.

        Args:
            query: The query to run. When None, a new query from
                create_query() is used.
            limit: Forwarded to the query's fetch.
            offset: Forwarded to the query's fetch.
        """
        # Test for None explicitly: `query or ...` would silently swap a
        # falsy query object for a fresh, unfiltered query.
        if query is None:
            query = self.create_query()
        async for row in query.fetch(limit, offset):
            yield _create_entity(row)

    async def create(self, refresh_token: RefreshTokenEntity) -> RefreshTokenEntity:
        """Insert a refresh token and commit.

        Returns:
            The given entity with its id replaced by the id generated by the
            insert.
        """
        new_id = await self._database.insert(
            RefreshTokensTable.table_name, RefreshTokenRow.persist(refresh_token)
        )
        await self._database.commit()
        return Entity.replace(refresh_token, id_=RefreshTokenIdentifier(new_id))

    async def update(self, refresh_token: RefreshTokenEntity) -> None:
        """Persist the current state of an existing refresh token and commit."""
        await self._database.update(
            refresh_token.id.value,
            RefreshTokensTable.table_name,
            RefreshTokenRow.persist(refresh_token),
        )
        await self._database.commit()