Coverage for src/kwai/modules/teams/repositories/member_db_repository.py: 97%
65 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module for defining a team member repository for a database."""
3from dataclasses import dataclass
4from typing import AsyncGenerator, Self
6from sql_smith.functions import express, on
8from kwai.core.db.database import Database
9from kwai.core.db.database_query import DatabaseQuery
10from kwai.core.db.table_row import JoinedTableRow
11from kwai.core.domain.value_objects.date import Date
12from kwai.core.domain.value_objects.name import Name
13from kwai.core.domain.value_objects.unique_id import UniqueId
14from kwai.modules.club.domain.value_objects import Birthdate, Gender, License
15from kwai.modules.teams.domain.team import TeamIdentifier
16from kwai.modules.teams.domain.team_member import MemberEntity, MemberIdentifier
17from kwai.modules.teams.repositories._tables import (
18 CountryRow,
19 MemberPersonRow,
20 MemberRow,
21 TeamMemberRow,
22)
23from kwai.modules.teams.repositories.member_repository import (
24 MemberNotFoundException,
25 MemberQuery,
26 MemberRepository,
27)
@dataclass(kw_only=True, frozen=True, slots=True)
class MemberQueryRow(JoinedTableRow):
    """A data transfer object for the member query.

    Bundles the member, person and country rows that together describe one
    team member.
    """

    # Row of the member table (id, uuid, license data, active flag).
    member: MemberRow
    # Row of the person linked to the member (names, birthdate, gender).
    person: MemberPersonRow
    # Row of the country used as the nationality of the person.
    country: CountryRow

    def create_entity(self) -> MemberEntity:
        """Create a team member entity from a row.

        Returns:
            A member entity populated from the member, person and country rows.
        """
        member_row = self.member
        person_row = self.person

        # Build the value objects one by one, in the order the entity lists them.
        identifier = MemberIdentifier(member_row.id)
        unique_id = UniqueId.create_from_string(member_row.uuid)
        full_name = Name(
            first_name=person_row.firstname,
            last_name=person_row.lastname,
        )
        member_license = License(
            number=member_row.license,
            end_date=Date.create_from_date(member_row.license_end_date),
        )
        date_of_birth = Birthdate(Date.create_from_date(person_row.birthdate))
        member_gender = Gender(person_row.gender)
        country_of_origin = self.country.create_country()

        return MemberEntity(
            id_=identifier,
            uuid=unique_id,
            name=full_name,
            license=member_license,
            birthdate=date_of_birth,
            gender=member_gender,
            nationality=country_of_origin,
            # The active column is compared against 1 to obtain a boolean.
            active_in_club=member_row.active == 1,
        )
class MemberDbQuery(MemberQuery, DatabaseQuery):
    """A team member query for a database."""

    def __init__(self, database: Database):
        """Initialize the query.

        Args:
            database: The database that is handed over to the base class.
        """
        super().__init__(database)

    def init(self):
        """Set up the select: members joined with their person and country."""
        person_link = on(MemberPersonRow.column("id"), MemberRow.column("person_id"))
        country_link = on(
            CountryRow.column("id"), MemberPersonRow.column("nationality_id")
        )
        (
            self._query.from_(MemberRow.__table_name__)
            .inner_join(MemberPersonRow.__table_name__, person_link)
            .inner_join(CountryRow.__table_name__, country_link)
        )

    @property
    def columns(self):
        """Return the column aliases defined by MemberQueryRow."""
        return MemberQueryRow.get_aliases()

    @property
    def count_column(self):
        """Return the id column of the member table, used for counting."""
        return MemberRow.column("id")

    def filter_by_id(self, id_: MemberIdentifier) -> Self:
        """Add a condition on the id of the member.

        Args:
            id_: The identifier of the member to look for.

        Returns:
            This query, so that calls can be chained.
        """
        id_matches = MemberRow.field("id").eq(id_.value)
        self._query.and_where(id_matches)
        return self

    def filter_by_birthdate(
        self, start_date: Date, end_date: Date | None = None
    ) -> Self:
        """Add a condition on the birthdate of the person.

        Args:
            start_date: The lower bound for the birthdate.
            end_date: The upper bound for the birthdate. When None, only the
                lower bound is applied (birthdate >= start_date).

        Returns:
            This query, so that calls can be chained.
        """
        birthdate = MemberPersonRow.field("birthdate")
        birthdate_matches = (
            birthdate.gte(start_date)
            if end_date is None
            else birthdate.between(start_date, end_date)
        )
        self._query.and_where(birthdate_matches)
        return self

    def filter_by_uuid(self, uuid: UniqueId) -> Self:
        """Add a condition on the unique id of the member.

        Args:
            uuid: The unique id to look for; compared using its string form.

        Returns:
            This query, so that calls can be chained.
        """
        uuid_matches = MemberRow.field("uuid").eq(str(uuid))
        self._query.and_where(uuid_matches)
        return self

    def filter_by_team(self, team_id: TeamIdentifier, in_team: bool = True) -> Self:
        """Add a condition on the membership of a team.

        Args:
            team_id: The identifier of the team.
            in_team: When True, keep the members whose id is listed for the
                team. When False, keep the members whose id is not listed.

        Returns:
            This query, so that calls can be chained.
        """
        # Sub select with the ids of all members linked to the given team.
        team_member_ids = (
            self._database.create_query_factory()
            .select()
            .columns(TeamMemberRow.column("member_id"))
            .from_(TeamMemberRow.__table_name__)
            .where(TeamMemberRow.field("team_id").eq(team_id.value))
        )
        member_id = MemberRow.field("id")
        membership = (
            member_id.in_(express("{}", team_member_ids))
            if in_team
            else member_id.not_in(express("{}", team_member_ids))
        )
        self._query.and_where(membership)
        return self
class MemberDbRepository(MemberRepository):
    """A member repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used to create and execute member queries.
        """
        self._database = database

    def create_query(self) -> MemberQuery:
        """Create a member query that runs against the database."""
        return MemberDbQuery(self._database)

    async def get(self, query: MemberQuery | None = None) -> MemberEntity:
        """Return the first member that matches the query.

        Args:
            query: The query to execute. When None, a new query without
                filters is used.

        Returns:
            The first member entity produced by the query.

        Raises:
            MemberNotFoundException: Raised when the query yields no member.
        """
        team_member_iterator = self.get_all(query)
        try:
            return await anext(team_member_iterator)
        except StopAsyncIteration:
            raise MemberNotFoundException("Member not found") from None
        finally:
            # Only the first row is consumed. Close the generator explicitly
            # so it does not stay suspended in the fetch loop until it is
            # garbage collected.
            await team_member_iterator.aclose()

    async def get_all(
        self,
        query: MemberQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncGenerator[MemberEntity, None]:
        """Yield all members that match the query.

        Args:
            query: The query to execute. When None, a new query without
                filters is used.
            limit: Passed to the fetch of the query.
            offset: Passed to the fetch of the query.

        Yields:
            A member entity for each row fetched by the query.
        """
        # Test for None explicitly: a query passed by the caller must never be
        # replaced with an unfiltered one because it happens to be falsy.
        if query is None:
            query = self.create_query()

        async for row in query.fetch(limit, offset):
            yield MemberQueryRow.map(row).create_entity()