Coverage for kwai/modules/identity/user_recoveries/user_recovery_db_repository.py: 100%
29 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module that implements a user recovery repository interface for a database."""
2from typing import Any
4from sql_smith.functions import on
6from kwai.core.db.database import Database
7from kwai.core.domain.entity import Entity
8from kwai.core.domain.value_objects.unique_id import UniqueId
9from kwai.modules.identity.user_recoveries.user_recovery import (
10 UserRecoveryEntity,
11 UserRecoveryIdentifier,
12)
13from kwai.modules.identity.user_recoveries.user_recovery_repository import (
14 UserRecoveryNotFoundException,
15 UserRecoveryRepository,
16)
17from kwai.modules.identity.user_recoveries.user_recovery_tables import (
18 UserRecoveriesTable,
19 UserRecoveryRow,
20)
21from kwai.modules.identity.users.user_tables import UsersTable
def _create_entity(row: dict[str, Any]) -> UserRecoveryEntity:
    """Build a user recovery entity from a database row.

    The same row is handed to both table wrappers: the users table wrapper
    produces the user entity, which is then passed to the user recoveries
    table wrapper to create the resulting entity.

    Args:
        row: The record fetched from the database.

    Returns:
        The user recovery entity created from the row.
    """
    recoveries_table = UserRecoveriesTable(row)
    user = UsersTable(row).create_entity()
    return recoveries_table.create_entity(user)
class UserRecoveryDbRepository(UserRecoveryRepository):
    """A user recovery repository that stores its data in a database.

    The create, update and delete methods commit on the database right after
    executing their statement.
    """

    def __init__(self, database: Database):
        """Keep the database that is used for all statements.

        Args:
            database: The database used to execute and commit the statements.
        """
        self._database = database

    async def create(self, user_recovery: UserRecoveryEntity) -> UserRecoveryEntity:
        """Insert a user recovery and commit.

        Args:
            user_recovery: The entity to persist.

        Returns:
            A copy of the entity that carries the id returned by the insert.
        """
        table_name = UserRecoveriesTable.table_name
        record = UserRecoveryRow.persist(user_recovery)
        inserted_id = await self._database.insert(table_name, record)
        await self._database.commit()
        identifier = UserRecoveryIdentifier(inserted_id)
        return Entity.replace(user_recovery, id_=identifier)

    async def update(self, user_recovery: UserRecoveryEntity):
        """Update the record of a user recovery and commit.

        Args:
            user_recovery: The entity whose id selects the record to update.
        """
        record_id = user_recovery.id.value
        table_name = UserRecoveriesTable.table_name
        record = UserRecoveryRow.persist(user_recovery)
        await self._database.update(record_id, table_name, record)
        await self._database.commit()

    async def get_by_uuid(self, uuid: UniqueId) -> UserRecoveryEntity:
        """Fetch the user recovery with the given unique id.

        The user recoveries table is joined with the users table, so the
        fetched row also contains the columns of the user.

        Args:
            uuid: The unique id to search for.

        Returns:
            The entity created from the fetched row.

        Raises:
            UserRecoveryNotFoundException: When no row was fetched.
        """
        select = self._database.create_query_factory().select()
        selected_columns = UserRecoveriesTable.aliases() + UsersTable.aliases()
        select = select.columns(*selected_columns)
        select = select.from_(UserRecoveriesTable.table_name)
        # Link each user recovery to its user.
        user_join = on(UserRecoveriesTable.column("user_id"), UsersTable.column("id"))
        select = select.join(UsersTable.table_name, user_join)
        uuid_matches = UserRecoveriesTable.field("uuid").eq(str(uuid))
        query = select.and_where(uuid_matches)

        row = await self._database.fetch_one(query)
        if not row:
            raise UserRecoveryNotFoundException()
        return _create_entity(row)

    async def delete(self, user_recovery: UserRecoveryEntity):
        """Delete the record of a user recovery and commit.

        Args:
            user_recovery: The entity whose id selects the record to delete.
        """
        record_id = user_recovery.id.value
        await self._database.delete(record_id, UserRecoveriesTable.table_name)
        await self._database.commit()