Coverage for src/kwai/modules/identity/tokens/refresh_token_db_repository.py: 85%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements a refresh token repository for a database.""" 

2 

3from typing import AsyncIterator 

4 

5from kwai.core.db.database import Database 

6from kwai.modules.identity.tokens.refresh_token import ( 

7 RefreshTokenEntity, 

8 RefreshTokenIdentifier, 

9) 

10from kwai.modules.identity.tokens.refresh_token_db_query import RefreshTokenDbQuery 

11from kwai.modules.identity.tokens.refresh_token_query import RefreshTokenQuery 

12from kwai.modules.identity.tokens.refresh_token_repository import ( 

13 RefreshTokenNotFoundException, 

14 RefreshTokenRepository, 

15) 

16from kwai.modules.identity.tokens.token_identifier import TokenIdentifier 

17from kwai.modules.identity.tokens.token_tables import ( 

18 AccessTokenRow, 

19 RefreshTokenRow, 

20) 

21from kwai.modules.identity.users.user_tables import UserAccountRow 

22 

23 

def _create_entity(row) -> RefreshTokenEntity:
    """Build a refresh token entity from a single database row.

    The same row is mapped by three row classes (refresh token, access token,
    user account). The resulting entities are nested: the user account entity
    is passed to the access token entity, which in turn is passed to the
    refresh token entity.
    """
    refresh_token_row = RefreshTokenRow.map(row)
    access_token_row = AccessTokenRow.map(row)
    user_account = UserAccountRow.map(row).create_entity()
    access_token = access_token_row.create_entity(user_account)
    return refresh_token_row.create_entity(access_token)

29 

30 

class RefreshTokenDbRepository(RefreshTokenRepository):
    """Database repository for the refresh token entity.

    Reads are performed through a RefreshTokenDbQuery; inserts, updates and
    deletes go directly to the refresh token table via the injected Database.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used to query and persist refresh tokens.
        """
        self._database = database

    def create_query(self) -> RefreshTokenQuery:
        """Create a new refresh token query bound to this repository's database."""
        return RefreshTokenDbQuery(self._database)

    async def get_by_token_identifier(
        self, identifier: TokenIdentifier
    ) -> RefreshTokenEntity:
        """Get the refresh token with the given token identifier.

        Args:
            identifier: The token identifier to look up.

        Returns:
            The refresh token entity built from the first matching row.

        Raises:
            RefreshTokenNotFoundException: When no row matches the identifier.
        """
        query = self.create_query()
        query.filter_by_token_identifier(identifier)

        row = await query.fetch_one()
        if row:
            return _create_entity(row)

        raise RefreshTokenNotFoundException(
            f"Token with identifier {identifier} not found"
        )

    async def get(self, id_: RefreshTokenIdentifier) -> RefreshTokenEntity:
        """Get the refresh token with the given id.

        Args:
            id_: The id of the refresh token.

        Returns:
            The refresh token entity built from the matching row.

        Raises:
            RefreshTokenNotFoundException: When no row matches the id.
        """
        query = self.create_query()
        query.filter_by_id(id_)
        row = await query.fetch_one()
        if row:
            return _create_entity(row)

        raise RefreshTokenNotFoundException(f"Token with id {id_} not found")

    async def get_all(
        self,
        query: RefreshTokenQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[RefreshTokenEntity]:
        """Iterate over refresh tokens.

        Args:
            query: Optional query to use. When None, a fresh query from
                create_query() is used (no filters are applied here).
            limit: Passed through as the limit argument of the query's fetch.
            offset: Passed through as the offset argument of the query's fetch.

        Yields:
            One refresh token entity per fetched row.
        """
        # Compare against None explicitly: a caller-supplied query must never
        # be replaced just because the object happens to evaluate as falsy.
        if query is None:
            query = self.create_query()
        async for row in query.fetch(limit, offset):
            yield _create_entity(row)

    async def create(self, refresh_token: RefreshTokenEntity) -> RefreshTokenEntity:
        """Insert a new refresh token into the refresh token table.

        Args:
            refresh_token: The entity to persist.

        Returns:
            The result of set_id on the given entity, using the id returned
            by the database insert.
        """
        new_id = await self._database.insert(
            RefreshTokenRow.__table_name__, RefreshTokenRow.persist(refresh_token)
        )
        return refresh_token.set_id(RefreshTokenIdentifier(new_id))

    async def update(self, refresh_token: RefreshTokenEntity) -> None:
        """Write the entity's current state to the row with the entity's id.

        Args:
            refresh_token: The entity to update; its id selects the row.
        """
        await self._database.update(
            refresh_token.id.value,
            RefreshTokenRow.__table_name__,
            RefreshTokenRow.persist(refresh_token),
        )

    async def delete(self, refresh_token: RefreshTokenEntity) -> None:
        """Delete the row with the entity's id from the refresh token table.

        Args:
            refresh_token: The entity to delete; its id selects the row.
        """
        await self._database.delete(
            refresh_token.id.value, RefreshTokenRow.__table_name__
        )