Coverage for src/kwai/modules/teams/repositories/team_db_repository.py: 96%
82 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that implements a team repository for a database."""
from contextlib import aclosing
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Self

from sql_smith.functions import field, on

from kwai.core.db.database import Database
from kwai.core.db.database_query import DatabaseQuery
from kwai.core.db.table_row import JoinedTableRow
from kwai.core.domain.entity import Entity
from kwai.core.domain.value_objects.date import Date
from kwai.core.domain.value_objects.name import Name
from kwai.core.domain.value_objects.unique_id import UniqueId
from kwai.core.functions import async_groupby
from kwai.modules.club.domain.value_objects import Birthdate, Gender, License
from kwai.modules.teams.domain.team import TeamEntity, TeamIdentifier
from kwai.modules.teams.domain.team_member import (
    MemberEntity,
    MemberIdentifier,
    TeamMember,
)
from kwai.modules.teams.repositories._tables import (
    CountryRow,
    MemberPersonRow,
    MemberRow,
    TeamMemberRow,
    TeamRow,
)
from kwai.modules.teams.repositories.team_repository import (
    TeamNotFoundException,
    TeamQuery,
    TeamRepository,
)
@dataclass(kw_only=True, frozen=True, slots=True)
class MemberPersonCountryMixin:
    """Dataclass mixin holding the member, person and country rows of a member.

    It knows how to turn those three rows into a member entity.
    """

    member: MemberRow
    member_person: MemberPersonRow
    country: CountryRow

    def create_member_entity(self) -> MemberEntity:
        """Build a member entity from the member, person and country rows.

        Returns:
            A MemberEntity. It is marked as active in the club only when the
            member row's ``active`` column equals 1.
        """
        member_row = self.member
        person_row = self.member_person

        # Values are built in the same order as the entity's keyword arguments.
        identifier = MemberIdentifier(member_row.id)
        full_name = Name(
            first_name=person_row.firstname,
            last_name=person_row.lastname,
        )
        member_license = License(
            number=member_row.license,
            end_date=Date.create_from_date(member_row.license_end_date),
        )
        birthdate = Birthdate(date=Date.create_from_date(person_row.birthdate))
        nationality = self.country.create_country()
        gender = Gender(person_row.gender)
        uuid = UniqueId.create_from_string(member_row.uuid)
        is_active = member_row.active == 1

        return MemberEntity(
            id_=identifier,
            name=full_name,
            license=member_license,
            birthdate=birthdate,
            nationality=nationality,
            gender=gender,
            uuid=uuid,
            active_in_club=is_active,
        )
@dataclass(kw_only=True, frozen=True, slots=True)
class TeamQueryRow(MemberPersonCountryMixin, JoinedTableRow):
    """A data transfer object for the team query.

    One instance represents one row of the joined team query: the team row,
    its team_member row, and (from the mixin) the member, person and country
    rows.
    """

    team: TeamRow
    team_member: TeamMemberRow

    @classmethod
    def create_entity(cls, rows: list[dict[str, Any]]) -> TeamEntity:
        """Create a team entity from a group of rows.

        The team itself is built from the first row. Every row that carries a
        member contributes one team member, keyed on the member's uuid.
        Presumably all rows belong to the same team (the repository groups
        them on ``team_id``).

        Args:
            rows: The raw rows of one team. Must not be empty.

        Returns:
            The team entity including its team members.

        Raises:
            IndexError: When ``rows`` is empty.
        """
        # Map every row exactly once; the first mapped row also provides the team.
        mapped_rows = [cls.map(row) for row in rows]

        team_members = {}
        for mapped_row in mapped_rows:
            # With a left join, a team without members still produces a row,
            # but its member columns are empty.
            if mapped_row.member.id is None:
                continue

            member = mapped_row.create_member_entity()
            team_members[member.uuid] = mapped_row.team_member.create_team_member(
                member
            )

        return mapped_rows[0].team.create_entity(team_members)
class TeamDbQuery(TeamQuery, DatabaseQuery):
    """A team query for a database.

    Teams are left-joined with their team members, the members, the members'
    person data and the person's nationality. Teams without members are
    therefore still selected.
    """

    def __init__(self, database: Database):
        """Create a team query on the given database."""
        super().__init__(database)

    def init(self):
        """Set up the FROM clause and the left joins of the team query."""
        query = self._query.from_(TeamRow.__table_name__)
        # Each (table, join condition) pair is applied as a LEFT JOIN, in order.
        # NOTE(review): no ORDER BY is added here, while TeamDbRepository.get_all
        # groups rows on team_id. Confirm rows of one team are always adjacent.
        for table_name, condition in (
            (
                TeamMemberRow.__table_name__,
                on(TeamRow.column("id"), TeamMemberRow.column("team_id")),
            ),
            (
                MemberRow.__table_name__,
                on(MemberRow.column("id"), TeamMemberRow.column("member_id")),
            ),
            (
                MemberPersonRow.__table_name__,
                on(MemberPersonRow.column("id"), MemberRow.column("person_id")),
            ),
            (
                CountryRow.__table_name__,
                on(CountryRow.column("id"), MemberPersonRow.column("nationality_id")),
            ),
        ):
            query = query.left_join(table_name, condition)

    @property
    def columns(self):
        """Return the selected columns: the aliases of all TeamQueryRow fields."""
        return TeamQueryRow.get_aliases()

    @property
    def count_column(self) -> str:
        """Return the column the query counts on: the team id."""
        return TeamRow.column("id")

    def filter_by_id(self, id_: TeamIdentifier) -> Self:
        """Restrict the query to the team with the given id."""
        condition = TeamRow.field("id").eq(id_.value)
        self._query.and_where(condition)
        return self
class TeamDbRepository(TeamRepository):
    """A team repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used to read and persist teams.
        """
        self._database = database

    def create_query(self) -> TeamQuery:
        """Create a new team query on this repository's database."""
        return TeamDbQuery(self._database)

    async def get(self, query: TeamQuery | None = None) -> TeamEntity:
        """Return the first team yielded for the query.

        Args:
            query: The query to use. When omitted, get_all uses a fresh query
                from create_query.

        Returns:
            The first team entity.

        Raises:
            TeamNotFoundException: When the query yields no team.
        """
        # Only the first team is consumed. aclosing closes the get_all
        # generator as soon as this block exits, instead of leaving the
        # suspended generator (and the fetch it is iterating) to be finalized
        # later by the event loop's async-generator hooks.
        async with aclosing(self.get_all(query)) as team_iterator:
            try:
                return await anext(team_iterator)
            except StopAsyncIteration:
                raise TeamNotFoundException("Team not found") from None

    async def get_all(
        self,
        query: TeamQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncGenerator[TeamEntity, None]:
        """Yield the team entities for the query.

        The database query left-joins team members, so one team can span
        several rows. Rows are grouped on the ``team_id`` column, and each
        group is turned into one team entity.

        Args:
            query: The query to use. When omitted, a fresh query from
                create_query is used.
            limit: Forwarded to ``query.fetch``.
            offset: Forwarded to ``query.fetch``.

        Yields:
            One TeamEntity per group of rows.
        """
        if query is None:
            query = self.create_query()

        # NOTE(review): async_groupby presumably groups *consecutive* rows,
        # like itertools.groupby. Confirm rows of one team are always adjacent.
        # NOTE(review): if fetch applies limit/offset as SQL LIMIT/OFFSET, they
        # count joined team/member rows rather than teams. Confirm.
        group_by_column = "team_id"
        row_iterator = query.fetch(limit=limit, offset=offset)
        async for _, group in async_groupby(
            row_iterator, key=lambda row: row[group_by_column]
        ):
            yield TeamQueryRow.create_entity(group)

    async def create(self, team: TeamEntity) -> TeamEntity:
        """Insert the team and return it with the id assigned by the database.

        Args:
            team: The team to insert.

        Returns:
            A copy of the team whose id is the id returned by the insert.
        """
        new_team_id = await self._database.insert(
            TeamRow.__table_name__, TeamRow.persist(team)
        )
        return Entity.replace(team, id_=TeamIdentifier(new_team_id))

    async def delete(self, team: TeamEntity) -> None:
        """Delete the team's team-member rows, then the team itself.

        Args:
            team: The team to delete.
        """
        # Team members go first, presumably so no team_member row still
        # references the deleted team.
        delete_team_members_query = (
            self._database.create_query_factory()
            .delete(TeamMemberRow.__table_name__)
            .where(field("team_id").eq(team.id.value))
        )
        await self._database.execute(delete_team_members_query)
        await self._database.delete(team.id.value, TeamRow.__table_name__)

    async def update(self, team: TeamEntity):
        """Update the team row with the current state of the entity.

        Args:
            team: The team to update; its id selects the row.
        """
        await self._database.update(
            team.id.value, TeamRow.__table_name__, TeamRow.persist(team)
        )

    async def add_team_member(self, team: TeamEntity, member: TeamMember):
        """Insert a team-member row linking the member to the team.

        Args:
            team: The team the member joins.
            member: The team member to add.
        """
        team_member_row = TeamMemberRow.persist(team, member)
        await self._database.insert(TeamMemberRow.__table_name__, team_member_row)