Coverage for src/kwai/modules/identity/users/user_account_db_repository.py: 100%
38 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that implements a user account repository for a database."""
3from collections.abc import AsyncGenerator
4from dataclasses import replace
6from kwai.core.db.database import Database
7from kwai.core.domain.value_objects.email_address import EmailAddress
8from kwai.modules.identity.users.user import UserIdentifier
9from kwai.modules.identity.users.user_account import (
10 UserAccountEntity,
11 UserAccountIdentifier,
12)
13from kwai.modules.identity.users.user_account_db_query import UserAccountDbQuery
14from kwai.modules.identity.users.user_account_query import UserAccountQuery
15from kwai.modules.identity.users.user_account_repository import (
16 UserAccountNotFoundException,
17 UserAccountRepository,
18)
19from kwai.modules.identity.users.user_tables import (
20 UserAccountRow,
21)
class UserAccountDbRepository(UserAccountRepository):
    """User account repository for a database.

    Reads go through ``UserAccountDbQuery`` and rows are mapped to entities
    with ``UserAccountRow``; writes are issued directly on the ``Database``.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used for all queries and writes.
        """
        self._database = database

    def create_query(self) -> UserAccountQuery:
        """Return a new user account query bound to this repository's database."""
        return UserAccountDbQuery(self._database)

    async def get_all(
        self,
        query: UserAccountQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncGenerator[UserAccountEntity, None]:
        """Yield the user accounts returned by a query.

        Args:
            query: The query to run. When None, a fresh query from
                ``create_query()`` is used.
            limit: Passed through to ``query.fetch``.
            offset: Passed through to ``query.fetch``.

        Yields:
            One user account entity per fetched row.
        """
        # Compare against None explicitly: ``query or ...`` would silently
        # discard a caller-supplied query object that happens to be falsy.
        if query is None:
            query = self.create_query()
        async for row in query.fetch(limit, offset):
            yield UserAccountRow.map(row).create_entity()

    async def get_user_by_email(self, email: EmailAddress) -> UserAccountEntity:
        """Return the user account with the given email address.

        Args:
            email: The email address to look up.

        Returns:
            The matching user account entity.

        Raises:
            UserAccountNotFoundException: When the query returns no row.
        """
        query = self.create_query().filter_by_email(email)
        if row := await query.fetch_one():
            return UserAccountRow.map(row).create_entity()

        raise UserAccountNotFoundException()

    async def exists_with_email(self, email: EmailAddress) -> bool:
        """Return whether a user account with the given email address exists.

        Args:
            email: The email address to look up.

        Returns:
            True when ``get_user_by_email`` finds an account, False when it
            raises ``UserAccountNotFoundException``.
        """
        try:
            await self.get_user_by_email(email)
        except UserAccountNotFoundException:
            return False

        return True

    async def create(self, user_account: UserAccountEntity) -> UserAccountEntity:
        """Insert a new user account and return it with its generated id.

        Args:
            user_account: The user account to persist.

        Returns:
            A copy of the account whose id, and whose nested user's id, are
            both set to the id returned by the insert.
        """
        new_id = await self._database.insert(
            UserAccountRow.__table_name__, UserAccountRow.persist(user_account)
        )
        # NOTE(review): a single insert provides the id for both the user and
        # the account, presumably because both are stored in the same row —
        # confirm against UserAccountRow.
        user = user_account.user.set_id(UserIdentifier(new_id))
        return replace(user_account, user=user).set_id(UserAccountIdentifier(new_id))

    async def update(self, user_account: UserAccountEntity) -> None:
        """Write the current state of an existing user account to the database.

        Args:
            user_account: The user account to update; its ``id.value`` selects
                the row.
        """
        await self._database.update(
            user_account.id.value,
            UserAccountRow.__table_name__,
            UserAccountRow.persist(user_account),
        )

    async def delete(self, user_account: UserAccountEntity) -> None:
        """Delete a user account from the database.

        Args:
            user_account: The user account to delete; its ``id.value`` selects
                the row.
        """
        await self._database.delete(
            user_account.id.value, UserAccountRow.__table_name__
        )