Coverage for src/kwai/modules/club/repositories/contact_db_repository.py: 100%
31 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that defines a contact repository for a database."""
3from dataclasses import dataclass
5from sql_smith.functions import on
7from kwai.core.db.database import Database
8from kwai.core.db.table_row import JoinedTableRow
9from kwai.core.domain.entity import Entity
10from kwai.modules.club.domain.contact import ContactEntity, ContactIdentifier
11from kwai.modules.club.repositories._tables import ContactRow, CountryRow
12from kwai.modules.club.repositories.contact_repository import (
13 ContactNotFoundException,
14 ContactRepository,
15)
@dataclass(kw_only=True, frozen=True, slots=True)
class ContactQueryRow(JoinedTableRow):
    """Joined row produced by a contact query.

    Holds the contact table row together with the country table row it
    was joined with.
    """

    contact: ContactRow
    country: CountryRow

    def create_entity(self) -> ContactEntity:
        """Build a ContactEntity from the joined rows.

        The country row is converted first and then handed to the contact
        row, which builds the entity.
        """
        contact_country = self.country.create_country()
        return self.contact.create_entity(contact_country)
class ContactDbRepository(ContactRepository):
    """Database-backed implementation of ContactRepository."""

    def __init__(self, database: Database):
        """Create the repository.

        Args:
            database: The database used for every insert, update, delete
                and query issued by this repository.
        """
        self._database = database

    async def create(self, contact: ContactEntity) -> ContactEntity:
        """Insert a contact and return it with its new identifier.

        Args:
            contact: The contact to store.

        Returns:
            The entity produced by ``Entity.replace`` with ``id_`` set to the
            identifier returned by the insert.
        """
        payload = ContactRow.persist(contact)
        generated_id = await self._database.insert(
            ContactRow.__table_name__, payload
        )
        return Entity.replace(contact, id_=ContactIdentifier(generated_id))

    async def delete(self, contact: ContactEntity):
        """Delete the contact table row that has the contact's id.

        Args:
            contact: The contact to remove.
        """
        contact_id = contact.id.value
        await self._database.delete(contact_id, ContactRow.__table_name__)

    async def update(self, contact: ContactEntity):
        """Overwrite the stored contact row with the entity's current values.

        Args:
            contact: The contact whose persisted row is updated.
        """
        contact_id = contact.id.value
        payload = ContactRow.persist(contact)
        await self._database.update(contact_id, ContactRow.__table_name__, payload)

    async def get(self, id_: ContactIdentifier) -> ContactEntity:
        """Load a contact, joined with its country, by identifier.

        Args:
            id_: Identifier of the contact to load.

        Returns:
            The contact entity built from the joined contact/country row.

        Raises:
            ContactNotFoundException: When no contact row matches ``id_``.
        """
        country_join = on(CountryRow.column("id"), ContactRow.column("country_id"))
        by_id = ContactRow.field("id").eq(id_.value)

        # The builder methods mutate ``query`` in place, so the chain's
        # return value is intentionally not used.
        query = Database.create_query_factory().select()
        (
            query.from_(ContactRow.__table_name__)
            .columns(*ContactQueryRow.get_aliases())
            .inner_join(CountryRow.__table_name__, country_join)
            .where(by_id)
        )

        row = await self._database.fetch_one(query)
        if not row:
            raise ContactNotFoundException(f"Contact with {id_} not found")
        return ContactQueryRow.map(row).create_entity()