Coverage for src/kwai/modules/identity/tokens/refresh_token_db_repository.py: 82%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements a refresh token repository for a database.""" 

2 

3from typing import AsyncIterator 

4 

5from kwai.core.db.database import Database 

6from kwai.modules.identity.tokens.refresh_token import ( 

7 RefreshTokenEntity, 

8 RefreshTokenIdentifier, 

9) 

10from kwai.modules.identity.tokens.refresh_token_db_query import RefreshTokenDbQuery 

11from kwai.modules.identity.tokens.refresh_token_query import RefreshTokenQuery 

12from kwai.modules.identity.tokens.refresh_token_repository import ( 

13 RefreshTokenNotFoundException, 

14 RefreshTokenRepository, 

15) 

16from kwai.modules.identity.tokens.token_identifier import TokenIdentifier 

17from kwai.modules.identity.tokens.token_tables import ( 

18 AccessTokenRow, 

19 RefreshTokenRow, 

20) 

21from kwai.modules.identity.users.user_tables import UserAccountRow 

22 

23 

def _create_entity(row) -> RefreshTokenEntity:
    """Build a refresh token entity from a single query row.

    The same row is handed to the refresh token, access token and user
    account row mappers. The resulting entities are nested: the user account
    goes into the access token, and the access token into the refresh token.

    Args:
        row: A row as returned by the refresh token query. It presumably
            carries the columns of all three tables (confirm against
            RefreshTokenDbQuery).

    Returns:
        The refresh token entity created from the row.
    """
    # Map in the same order as a nested call would evaluate: outermost first.
    refresh_token_row = RefreshTokenRow.map(row)
    access_token_row = AccessTokenRow.map(row)
    user_account = UserAccountRow.map(row).create_entity()
    access_token = access_token_row.create_entity(user_account)
    return refresh_token_row.create_entity(access_token)

29 

30 

class RefreshTokenDbRepository(RefreshTokenRepository):
    """Refresh token repository that reads from and writes to a database."""

    def __init__(self, database: Database):
        """Store the database used for all queries and statements."""
        self._database = database

    def create_query(self) -> RefreshTokenQuery:
        """Return a new database query for refresh tokens."""
        return RefreshTokenDbQuery(self._database)

    @staticmethod
    async def _fetch_single_entity(query: RefreshTokenQuery) -> RefreshTokenEntity:
        """Fetch one row with the query and turn it into an entity.

        Raises:
            RefreshTokenNotFoundException: When the query returns no row.
        """
        row = await query.fetch_one()
        if not row:
            raise RefreshTokenNotFoundException()
        return _create_entity(row)

    async def get_by_token_identifier(
        self, identifier: TokenIdentifier
    ) -> RefreshTokenEntity:
        """Return the refresh token with the given token identifier.

        Raises:
            RefreshTokenNotFoundException: When no matching row is found.
        """
        token_query = self.create_query()
        token_query.filter_by_token_identifier(identifier)
        return await self._fetch_single_entity(token_query)

    async def get(self, id_: RefreshTokenIdentifier) -> RefreshTokenEntity:
        """Return the refresh token with the given id.

        Raises:
            RefreshTokenNotFoundException: When no matching row is found.
        """
        id_query = self.create_query()
        id_query.filter_by_id(id_.value)
        return await self._fetch_single_entity(id_query)

    async def get_all(
        self,
        query: RefreshTokenDbQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[RefreshTokenEntity]:
        """Yield an entity for each row fetched by the query.

        Args:
            query: The query to run. A new query is created when none is given.
            limit: Passed on to the fetch method of the query.
            offset: Passed on to the fetch method of the query.
        """
        active_query = query if query else self.create_query()
        async for fetched_row in active_query.fetch(limit, offset):
            yield _create_entity(fetched_row)

    async def create(self, refresh_token: RefreshTokenEntity) -> RefreshTokenEntity:
        """Insert the refresh token and return it with the new id set.

        The value returned by the database insert is wrapped in a
        RefreshTokenIdentifier and passed to set_id of the entity.
        """
        table_name = RefreshTokenRow.__table_name__
        persisted_row = RefreshTokenRow.persist(refresh_token)
        inserted_id = await self._database.insert(table_name, persisted_row)
        return refresh_token.set_id(RefreshTokenIdentifier(inserted_id))

    async def update(self, refresh_token: RefreshTokenEntity):
        """Update the database row of the refresh token, selected by its id."""
        token_id = refresh_token.id.value
        table_name = RefreshTokenRow.__table_name__
        persisted_row = RefreshTokenRow.persist(refresh_token)
        await self._database.update(token_id, table_name, persisted_row)