Coverage for src/kwai/modules/identity/user_recoveries/user_recovery_db_repository.py: 100%

25 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements a user recovery repository interface for a database.""" 

2 

3from typing import Any 

4 

5from sql_smith.functions import on 

6 

7from kwai.core.db.database import Database 

8from kwai.core.domain.value_objects.unique_id import UniqueId 

9from kwai.modules.identity.user_recoveries.user_recovery import ( 

10 UserRecoveryEntity, 

11 UserRecoveryIdentifier, 

12) 

13from kwai.modules.identity.user_recoveries.user_recovery_repository import ( 

14 UserRecoveryNotFoundException, 

15 UserRecoveryRepository, 

16) 

17from kwai.modules.identity.user_recoveries.user_recovery_tables import ( 

18 UserRecoveryRow, 

19) 

20from kwai.modules.identity.users.user_tables import UserRow 

21 

22 

def _create_entity(row: dict[str, Any]) -> UserRecoveryEntity:
    """Turn a joined user recovery / user record into a user recovery entity.

    The record is mapped twice: once onto the user recovery table row and
    once onto the user table row. The resulting user entity is handed to the
    user recovery row so the recovery entity carries its user.
    """
    recovery_row = UserRecoveryRow.map(row)
    user_row = UserRow.map(row)
    return recovery_row.create_entity(user_row.create_entity())

26 

27 

class UserRecoveryDbRepository(UserRecoveryRepository):
    """Database-backed implementation of the user recovery repository."""

    def __init__(self, database: Database):
        """Create the repository.

        Args:
            database: The database used to run every query of this repository.
        """
        self._database = database

    async def create(self, user_recovery: UserRecoveryEntity) -> UserRecoveryEntity:
        """Insert a user recovery and return it with its database id.

        Args:
            user_recovery: The entity to store.

        Returns:
            The entity with its id set to the value returned by the insert.
        """
        data = UserRecoveryRow.persist(user_recovery)
        inserted_id = await self._database.insert(
            UserRecoveryRow.__table_name__, data
        )
        return user_recovery.set_id(UserRecoveryIdentifier(inserted_id))

    async def update(self, user_recovery: UserRecoveryEntity):
        """Write the current state of an existing user recovery to the database.

        Args:
            user_recovery: The entity to save; its id selects the row to update.
        """
        recovery_id = user_recovery.id.value
        data = UserRecoveryRow.persist(user_recovery)
        await self._database.update(
            recovery_id,
            UserRecoveryRow.__table_name__,
            data,
        )

    async def get_by_uuid(self, uuid: UniqueId) -> UserRecoveryEntity:
        """Load the user recovery with the given uuid, together with its user.

        Args:
            uuid: The unique id of the user recovery.

        Returns:
            The user recovery entity built from the joined recovery/user row.

        Raises:
            UserRecoveryNotFoundException: When no row matches the uuid.
        """
        selected_columns = UserRecoveryRow.get_aliases() + UserRow.get_aliases()
        join_condition = on(UserRecoveryRow.column("user_id"), UserRow.column("id"))
        query = (
            self._database.create_query_factory()
            .select()
            .columns(*selected_columns)
            .from_(UserRecoveryRow.__table_name__)
            .join(UserRow.__table_name__, join_condition)
            .and_where(UserRecoveryRow.field("uuid").eq(str(uuid)))
        )

        row = await self._database.fetch_one(query)
        if not row:
            raise UserRecoveryNotFoundException()
        return _create_entity(row)

    async def delete(self, user_recovery: UserRecoveryEntity):
        """Remove the given user recovery from the database.

        Args:
            user_recovery: The entity to remove; its id selects the row.
        """
        recovery_id = user_recovery.id.value
        await self._database.delete(recovery_id, UserRecoveryRow.__table_name__)