Coverage for kwai/modules/identity/user_recoveries/user_recovery_db_repository.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module that implements a user recovery repository interface for a database.""" 

2from typing import Any 

3 

4from sql_smith.functions import on 

5 

6from kwai.core.db.database import Database 

7from kwai.core.domain.entity import Entity 

8from kwai.core.domain.value_objects.unique_id import UniqueId 

9from kwai.modules.identity.user_recoveries.user_recovery import ( 

10 UserRecoveryEntity, 

11 UserRecoveryIdentifier, 

12) 

13from kwai.modules.identity.user_recoveries.user_recovery_repository import ( 

14 UserRecoveryNotFoundException, 

15 UserRecoveryRepository, 

16) 

17from kwai.modules.identity.user_recoveries.user_recovery_tables import ( 

18 UserRecoveriesTable, 

19 UserRecoveryRow, 

20) 

21from kwai.modules.identity.users.user_tables import UsersTable 

22 

23 

def _create_entity(row: dict[str, Any]) -> UserRecoveryEntity:
    """Build a user recovery entity from a joined recovery/user row.

    The same row is read twice: once through the user recoveries table
    mapping and once through the users table mapping. The resulting user
    entity is attached to the recovery entity.

    Args:
        row: A row fetched from the database that contains the aliased
            columns of both the user recoveries and the users table.

    Returns:
        The user recovery entity with its user entity attached.
    """
    recovery_table = UserRecoveriesTable(row)
    user = UsersTable(row).create_entity()
    return recovery_table.create_entity(user)

27 

28 

class UserRecoveryDbRepository(UserRecoveryRepository):
    """Database-backed implementation of the user recovery repository.

    Each mutating operation (create, update, delete) commits on the
    injected database before returning.
    """

    def __init__(self, database: Database):
        """Keep a reference to the database used for every operation.

        Args:
            database: The database wrapper used for queries and commands.
        """
        self._database = database

    async def create(self, user_recovery: UserRecoveryEntity) -> UserRecoveryEntity:
        """Insert a user recovery and return it with the generated id.

        Args:
            user_recovery: The entity to persist.

        Returns:
            A copy of the entity whose id wraps the id returned by the insert.
        """
        table = UserRecoveriesTable.table_name
        values = UserRecoveryRow.persist(user_recovery)
        inserted_id = await self._database.insert(table, values)
        await self._database.commit()
        return Entity.replace(
            user_recovery, id_=UserRecoveryIdentifier(inserted_id)
        )

    async def update(self, user_recovery: UserRecoveryEntity):
        """Overwrite the stored row of an existing user recovery.

        The row is addressed by `user_recovery.id.value`. The transaction is
        committed afterwards.

        Args:
            user_recovery: The entity holding the new values.
        """
        recovery_id = user_recovery.id.value
        values = UserRecoveryRow.persist(user_recovery)
        await self._database.update(
            recovery_id, UserRecoveriesTable.table_name, values
        )
        await self._database.commit()

    async def get_by_uuid(self, uuid: UniqueId) -> UserRecoveryEntity:
        """Load the user recovery with the given unique id.

        The recovery is joined with its user (recovery `user_id` = user `id`),
        so the returned entity carries the user as well.

        Args:
            uuid: The unique id of the recovery. It is compared as a string.

        Returns:
            The matching user recovery entity.

        Raises:
            UserRecoveryNotFoundException: When no row matches the uuid.
        """
        selected_columns = UserRecoveriesTable.aliases() + UsersTable.aliases()
        join_condition = on(
            UserRecoveriesTable.column("user_id"), UsersTable.column("id")
        )
        uuid_matches = UserRecoveriesTable.field("uuid").eq(str(uuid))
        query = (
            self._database.create_query_factory()
            .select()
            .columns(*selected_columns)
            .from_(UserRecoveriesTable.table_name)
            .join(UsersTable.table_name, join_condition)
            .and_where(uuid_matches)
        )

        row = await self._database.fetch_one(query)
        if not row:
            raise UserRecoveryNotFoundException()
        return _create_entity(row)

    async def delete(self, user_recovery: UserRecoveryEntity):
        """Remove the stored row of a user recovery and commit.

        Args:
            user_recovery: The entity to delete. It is addressed by
                `user_recovery.id.value`.
        """
        recovery_id = user_recovery.id.value
        await self._database.delete(recovery_id, UserRecoveriesTable.table_name)
        await self._database.commit()