Coverage for kwai/modules/training/trainings/training_db_repository.py: 93%
107 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module for implementing a training repository for a database."""
2from typing import AsyncIterator
4from sql_smith.functions import field
6from kwai.core.db.database import Database, Record
7from kwai.core.db.rows import OwnersTable
8from kwai.core.domain.entity import Entity
9from kwai.modules.training.teams.team import TeamEntity
10from kwai.modules.training.trainings.training import TrainingEntity, TrainingIdentifier
11from kwai.modules.training.trainings.training_coach_db_query import TrainingCoachDbQuery
12from kwai.modules.training.trainings.training_db_query import TrainingDbQuery
13from kwai.modules.training.trainings.training_query import TrainingQuery
14from kwai.modules.training.trainings.training_repository import (
15 TrainingNotFoundException,
16 TrainingRepository,
17)
18from kwai.modules.training.trainings.training_tables import (
19 TrainingCoachesTable,
20 TrainingCoachRow,
21 TrainingContentRow,
22 TrainingContentsTable,
23 TrainingDefinitionsTable,
24 TrainingRow,
25 TrainingsTable,
26 TrainingTeamRow,
27 TrainingTeamsTable,
28)
29from kwai.modules.training.trainings.training_team_db_query import TrainingTeamDbQuery
30from kwai.modules.training.trainings.value_objects import TrainingCoach
def _create_entity(rows: list[Record]) -> TrainingEntity:
    """Turn the records of one training into a training entity.

    The training and its (optional) definition are read from the first
    record only; every record contributes one content, together with the
    owner of that content.

    Args:
        rows: The records of a single training; must contain at least one
            record.

    Returns:
        The training entity. Its definition is None when the definition id
        column of the first record is None.
    """
    first = rows[0]
    has_definition = first[TrainingDefinitionsTable.alias_name("id")] is not None
    definition = (
        TrainingDefinitionsTable(first).create_entity(
            OwnersTable(first, "definition_owners").create_owner()
        )
        if has_definition
        else None
    )
    training_table = TrainingsTable(first)
    contents = [
        TrainingContentsTable(record).create_content(
            OwnersTable(record).create_owner()
        )
        for record in rows
    ]
    return training_table.create_entity(contents, definition=definition)
class TrainingDbRepository(TrainingRepository):
    """A training repository for a database.

    The training query yields one record per training content (each record
    becomes one content in `_create_entity`). Records are grouped per
    training id and turned into entities; coaches and teams are loaded with
    separate queries for all selected trainings at once.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> TrainingQuery:
        """Create a new query for trainings stored in the database."""
        return TrainingDbQuery(self._database)

    async def get_by_id(self, id: TrainingIdentifier) -> TrainingEntity:
        """Get the training with the given id.

        Args:
            id: The id of the training.

        Returns:
            The training, with its contents, coaches and teams.

        Raises:
            TrainingNotFoundException: when no training exists with this id.
        """
        query = self.create_query()
        query.filter_by_id(id)

        # No limit is passed: the id filter already selects one training. If
        # the fetch limit counted records instead of trainings, a limit of 1
        # would drop all but the first content row of the training.
        try:
            return await anext(self.get_all(query))
        except StopAsyncIteration:
            raise TrainingNotFoundException(
                f"Training with id {id} does not exist"
            ) from None

    async def get_all(
        self,
        training_query: TrainingQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingEntity]:
        """Get all trainings selected by the query.

        Args:
            training_query: The query to use. When omitted, a new query from
                `create_query` is used.
            limit: Passed unchanged to the query's fetch.
            offset: Passed unchanged to the query's fetch.

        NOTE(review): whether limit/offset count trainings or content records
        depends on the query's fetch implementation — verify.

        Yields:
            One training entity per training id, in the order the trainings
            first appear in the query result. Coaches and teams are attached
            when the training has any.
        """
        if training_query is None:
            training_query = self.create_query()

        # Group the records per training id. A dict keeps first-appearance
        # order and also merges records of one training that are not
        # returned consecutively.
        group_by_column = TrainingsTable.alias_name("id")
        groups: dict[object, list[Record]] = {}
        async for record in training_query.fetch(limit, offset):
            groups.setdefault(record[group_by_column], []).append(record)

        if not groups:
            return

        trainings: dict[TrainingIdentifier, TrainingEntity] = {}
        for rows in groups.values():
            training = _create_entity(rows)
            trainings[training.id] = training

        training_ids = list(trainings.keys())

        # Get the coaches of all the trainings.
        coach_query = TrainingCoachDbQuery(self._database).filter_by_trainings(
            *training_ids
        )
        coaches: dict[
            TrainingIdentifier, list[TrainingCoach]
        ] = await coach_query.fetch_coaches()

        # Get the teams of all trainings.
        team_query = TrainingTeamDbQuery(self._database).filter_by_trainings(
            *training_ids
        )
        teams: dict[
            TrainingIdentifier, list[TeamEntity]
        ] = await team_query.fetch_teams()

        for training in trainings.values():
            training_coaches = coaches.get(training.id, [])
            training_teams = teams.get(training.id, [])
            if training_coaches or training_teams:
                yield Entity.replace(
                    training, coaches=training_coaches, teams=training_teams
                )
            else:
                yield training

    async def create(self, training: TrainingEntity) -> TrainingEntity:
        """Save a new training together with its contents, coaches and teams.

        Args:
            training: The training to save.

        Returns:
            The training with the id returned by the training insert.
        """
        new_id = await self._database.insert(
            TrainingsTable.table_name, TrainingRow.persist(training)
        )
        result = Entity.replace(training, id_=TrainingIdentifier(new_id))

        await self._insert_contents(result)
        await self._insert_coaches(result)
        await self._insert_teams(result)

        await self._database.commit()

        return result

    async def update(self, training: TrainingEntity) -> None:
        """Update the training and replace its contents, coaches and teams.

        Args:
            training: The training to update.
        """
        # Update the training
        await self._database.update(
            training.id.value,
            TrainingsTable.table_name,
            TrainingRow.persist(training),
        )

        # Related rows are replaced: first delete them, then insert again.
        await self._delete_contents(training)
        await self._insert_contents(training)

        await self._delete_coaches(training)
        await self._insert_coaches(training)

        await self._delete_teams(training)
        await self._insert_teams(training)

        await self._database.commit()

    async def delete(self, training: TrainingEntity) -> None:
        """Delete the training together with its contents, coaches and teams.

        Args:
            training: The training to delete.
        """
        # Remove the rows that reference the training before the training
        # row itself, so they never outlive it (and a foreign key on
        # training_id, if the schema has one, is not violated).
        await self._delete_contents(training)
        await self._delete_coaches(training)
        await self._delete_teams(training)
        await self._database.delete(training.id.value, TrainingsTable.table_name)

        await self._database.commit()

    async def _insert_contents(self, training: TrainingEntity):
        """Insert the text contents of the training, if it has any."""
        content_rows = [
            TrainingContentRow.persist(training, content)
            for content in training.content
        ]
        # Skip the insert for a training without content, like the
        # coach/team helpers do, instead of inserting zero rows.
        if content_rows:
            await self._database.insert(
                TrainingContentsTable.table_name, *content_rows
            )

    async def _insert_coaches(self, training: TrainingEntity):
        """Insert the related coaches, if there are any."""
        training_coach_rows = [
            TrainingCoachRow.persist(training, training_coach)
            for training_coach in training.coaches
        ]
        if training_coach_rows:
            await self._database.insert(
                TrainingCoachesTable.table_name, *training_coach_rows
            )

    async def _insert_teams(self, training: TrainingEntity):
        """Insert the related teams, if there are any."""
        training_team_rows = [
            TrainingTeamRow.persist(training, team) for team in training.teams
        ]
        if training_team_rows:
            await self._database.insert(
                TrainingTeamsTable.table_name, *training_team_rows
            )

    async def _delete_related(self, table_name: str, training: TrainingEntity):
        """Delete all rows of a table whose training_id is the training's id."""
        delete_query = (
            self._database.create_query_factory()
            .delete(table_name)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_query)

    async def _delete_coaches(self, training: TrainingEntity):
        """Delete coaches of the training."""
        await self._delete_related(TrainingCoachesTable.table_name, training)

    async def _delete_contents(self, training: TrainingEntity):
        """Delete text contents of the training."""
        await self._delete_related(TrainingContentsTable.table_name, training)

    async def _delete_teams(self, training: TrainingEntity):
        """Delete the teams of the training."""
        await self._delete_related(TrainingTeamsTable.table_name, training)