Coverage for src/kwai/modules/club/repositories/contact_db_repository.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that defines a contact repository for a database.""" 

2 

3from dataclasses import dataclass 

4 

5from sql_smith.functions import on 

6 

7from kwai.core.db.database import Database 

8from kwai.core.db.table_row import JoinedTableRow 

9from kwai.core.domain.entity import Entity 

10from kwai.modules.club.domain.contact import ContactEntity, ContactIdentifier 

11from kwai.modules.club.repositories._tables import ContactRow, CountryRow 

12from kwai.modules.club.repositories.contact_repository import ( 

13 ContactNotFoundException, 

14 ContactRepository, 

15) 

16 

17 

@dataclass(kw_only=True, frozen=True, slots=True)
class ContactQueryRow(JoinedTableRow):
    """Joined row that carries the contact and country parts of a contact query."""

    # Row with the columns of the contact table.
    contact: ContactRow
    # Row with the columns of the country table joined to the contact.
    country: CountryRow

    def create_entity(self) -> ContactEntity:
        """Build a Contact entity from the contact row and its country row."""
        country = self.country.create_country()
        return self.contact.create_entity(country)

28 

29 

class ContactDbRepository(ContactRepository):
    """Contact repository implementation that is backed by a database."""

    def __init__(self, database: Database):
        """Store the database used by all repository operations."""
        self._database = database

    async def create(self, contact: ContactEntity) -> ContactEntity:
        """Insert the contact and return a copy that carries the new id."""
        inserted_id = await self._database.insert(
            ContactRow.__table_name__,
            ContactRow.persist(contact),
        )
        identifier = ContactIdentifier(inserted_id)
        return Entity.replace(contact, id_=identifier)

    async def delete(self, contact: ContactEntity):
        """Delete the row of the contact table that matches the contact's id."""
        contact_id = contact.id.value
        await self._database.delete(contact_id, ContactRow.__table_name__)

    async def update(self, contact: ContactEntity):
        """Update the row of the contact table that matches the contact's id."""
        contact_id = contact.id.value
        await self._database.update(
            contact_id,
            ContactRow.__table_name__,
            ContactRow.persist(contact),
        )

    async def get(self, id_: ContactIdentifier) -> ContactEntity:
        """Fetch the contact with the given id, joined with its country.

        Raises:
            ContactNotFoundException: when the query returns no row.
        """
        query = Database.create_query_factory().select()
        # The builder calls are kept in their original order; the result of the
        # chain is discarded and `query` itself is executed below.
        selection = query.from_(ContactRow.__table_name__).columns(
            *ContactQueryRow.get_aliases()
        )
        joined = selection.inner_join(
            CountryRow.__table_name__,
            on(CountryRow.column("id"), ContactRow.column("country_id")),
        )
        joined.where(ContactRow.field("id").eq(id_.value))

        row = await self._database.fetch_one(query)
        if not row:
            raise ContactNotFoundException(f"Contact with {id_} not found")

        return ContactQueryRow.map(row).create_entity()