Coverage for src/kwai/modules/club/repositories/member_db_repository.py: 100%
46 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module for defining a member repository using a database."""
3from typing import AsyncGenerator
5from sql_smith.functions import field
7from kwai.core.db.database import Database
8from kwai.core.domain.entity import Entity
9from kwai.modules.club.domain.file_upload import FileUploadEntity
10from kwai.modules.club.domain.member import MemberEntity, MemberIdentifier
11from kwai.modules.club.repositories._tables import MemberRow, MemberUploadRow
12from kwai.modules.club.repositories.member_db_query import MemberDbQuery, MemberQueryRow
13from kwai.modules.club.repositories.member_query import MemberQuery
14from kwai.modules.club.repositories.member_repository import (
15 MemberNotFoundException,
16 MemberRepository,
17)
18from kwai.modules.club.repositories.person_db_repository import PersonDbRepository
class MemberDbRepository(MemberRepository):
    """A member repository using a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> MemberQuery:
        """Create a new member query bound to this repository's database.

        Returns:
            A MemberDbQuery that runs against the repository's database.
        """
        return MemberDbQuery(self._database)

    async def get_all(
        self,
        query: MemberQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncGenerator[MemberEntity, None]:
        """Yield the members selected by a query.

        Args:
            query: The query to run. When None, a default query created with
                create_query() is used.
            limit: Forwarded unchanged to query.fetch (presumably the maximum
                number of rows to return).
            offset: Forwarded unchanged to query.fetch (presumably the number
                of rows to skip).

        Yields:
            One MemberEntity per row returned by the query.
        """
        # Explicit None check: only a missing query is replaced, never a
        # query object that happens to be falsy.
        if query is None:
            query = self.create_query()

        async for row in query.fetch(limit, offset):
            yield MemberQueryRow.map(row).create_entity()

    async def get(self, query: MemberQuery | None = None) -> MemberEntity:
        """Return the first member selected by a query.

        Args:
            query: The query to run. When None, a default query is used.

        Returns:
            The first member produced by get_all for the query.

        Raises:
            MemberNotFoundException: When the query yields no member.
        """
        member_iterator = self.get_all(query)
        try:
            return await anext(member_iterator)
        except StopAsyncIteration:
            raise MemberNotFoundException("Member not found") from None

    async def create(self, member: MemberEntity) -> MemberEntity:
        """Insert a new member into the database.

        When the member's person has no id yet, the person is created first
        and attached to the member before the member row is persisted
        (presumably because MemberRow.persist stores the person's id).

        Args:
            member: The member to save.

        Returns:
            A copy of the member carrying the id assigned by the insert
            (and the newly created person, when one was created).
        """
        # When there is no person id, create the person first.
        if member.person.id.is_empty():
            person = await PersonDbRepository(self._database).create(member.person)
            member = Entity.replace(member, person=person)

        new_id = await self._database.insert(
            MemberRow.__table_name__, MemberRow.persist(member)
        )

        return Entity.replace(member, id_=MemberIdentifier(new_id))

    async def update(self, member: MemberEntity) -> None:
        """Update an existing member and its person information.

        Args:
            member: The member to update; its id selects the row to change.
        """
        # Update the member row.
        await self._database.update(
            member.id.value, MemberRow.__table_name__, MemberRow.persist(member)
        )
        # Update the person information.
        await PersonDbRepository(self._database).update(member.person)

    async def delete(self, member: MemberEntity) -> None:
        """Delete a member and its person.

        Args:
            member: The member to delete.
        """
        # Pass the raw id value, like update() does; the identifier object
        # itself is not a plain query parameter.
        # The member row is removed before the person, the reverse of the
        # order used in create(), so the member never outlives its person.
        await self._database.delete(member.id.value, MemberRow.__table_name__)
        await PersonDbRepository(self._database).delete(member.person)

    def _create_member_upload_query(self, upload: FileUploadEntity):
        """Build a subquery selecting the member ids recorded for an upload.

        Args:
            upload: The upload whose id is matched against import_id.

        Returns:
            A select query over the member upload table, usable with in_ /
            not_in.
        """
        return (
            self._database.create_query_factory()
            .select("member_id")
            .from_(MemberUploadRow.__table_name__)
            .where(field("import_id").eq(upload.id.value))
        )

    async def activate_members(self, upload: FileUploadEntity) -> None:
        """Set active to 1 for every member recorded for the given upload.

        Args:
            upload: The upload whose members are activated.
        """
        update_query = (
            self._database.create_query_factory()
            .update(MemberRow.__table_name__, {"active": 1})
            .where(field("id").in_(self._create_member_upload_query(upload)))
        )
        await self._database.execute(update_query)

    async def deactivate_members(self, upload: FileUploadEntity) -> None:
        """Set active to 0 for every member NOT recorded for the given upload.

        Args:
            upload: The upload whose members are kept; all others are
                deactivated.
        """
        update_query = (
            self._database.create_query_factory()
            .update(MemberRow.__table_name__, {"active": 0})
            .where(field("id").not_in(self._create_member_upload_query(upload)))
        )
        await self._database.execute(update_query)