Coverage for src/kwai/modules/training/trainings/training_db_repository.py: 92%
123 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module for implementing a training repository for a database."""
3from typing import AsyncIterator
5from sql_smith.functions import alias, express, field
7from kwai.core.db.database import Database, Record
8from kwai.core.db.rows import OwnersTable
9from kwai.core.domain.entity import Entity
10from kwai.modules.training.teams.team import TeamEntity
11from kwai.modules.training.teams.team_tables import TeamsTable
12from kwai.modules.training.trainings.training import TrainingEntity, TrainingIdentifier
13from kwai.modules.training.trainings.training_coach_db_query import TrainingCoachDbQuery
14from kwai.modules.training.trainings.training_db_query import TrainingDbQuery
15from kwai.modules.training.trainings.training_definition import TrainingDefinitionEntity
16from kwai.modules.training.trainings.training_query import TrainingQuery
17from kwai.modules.training.trainings.training_repository import (
18 TrainingNotFoundException,
19 TrainingRepository,
20)
21from kwai.modules.training.trainings.training_tables import (
22 TrainingCoachRow,
23 TrainingContentsTable,
24 TrainingDefinitionsTable,
25 TrainingRow,
26 TrainingsTable,
27 TrainingTeamRow,
28 TrainingTeamsTable,
29 TrainingTextRow,
30)
31from kwai.modules.training.trainings.training_team_db_query import TrainingTeamDbQuery
32from kwai.modules.training.trainings.value_objects import TrainingCoach
def _create_entity(rows: list[Record]) -> TrainingEntity:
    """Build one training entity from all the records of that training.

    The training-level columns (and the optional definition) are read from the
    first record; every record contributes exactly one text to the training.

    Args:
        rows: The records of a single training. Must not be empty.

    Returns:
        The training entity with its texts and, when available, its definition.
    """
    first_row = rows[0]

    definition: TrainingDefinitionEntity | None = None
    # A missing definition id means the training is not linked to a
    # definition (presumably the definition columns come from an outer join).
    if first_row[TrainingDefinitionsTable.alias_name("id")] is not None:
        definition = TrainingDefinitionsTable(first_row).create_entity(
            team=TeamsTable(first_row).create_entity(),
            owner=OwnersTable(first_row, "definition_owners").create_owner(),
        )

    texts = [
        TrainingContentsTable(record).create_text(OwnersTable(record).create_owner())
        for record in rows
    ]
    return TrainingsTable(first_row).create_entity(texts, definition=definition)
class TrainingDbRepository(TrainingRepository):
    """A training repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> TrainingQuery:
        """Create a new training query on the database of this repository."""
        return TrainingDbQuery(self._database)

    async def get_by_id(self, id: TrainingIdentifier) -> TrainingEntity:
        """Return the training with the given id.

        Args:
            id: The id of the training.

        Returns:
            The training, with its texts, definition, coaches and teams.

        Raises:
            TrainingNotFoundException: When no training with this id exists.
        """
        query = self.create_query()
        query.filter_by_id(id)

        trainings = self.get_all(query, 1)
        try:
            return await anext(trainings)
        except StopAsyncIteration:
            raise TrainingNotFoundException(
                f"Training with id {id} does not exist"
            ) from None

    async def get_all(
        self,
        training_query: TrainingQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingEntity]:
        """Yield all trainings selected by a query.

        The query returns one record per training text. Consecutive records with
        the same training id are grouped and turned into one training entity.
        Afterwards the coaches and the teams of all these trainings are fetched,
        with one extra query each, and attached to the entities.

        Args:
            training_query: The query to use. When omitted, a new query is
                created with create_query.
            limit: Passed as-is to the fetch of the query.
            offset: Passed as-is to the fetch of the query.

        Yields:
            The training entities, ordered by the first record of each training
            in the query result.
        """
        if training_query is None:
            training_query = self.create_query()

        group_by_column = TrainingsTable.alias_name("id")
        trainings: dict[TrainingIdentifier, TrainingEntity] = {}

        # Records of one training are expected to be consecutive (this relies
        # on the ordering of the query), so a new group starts each time the
        # training id changes.
        group: list[Record] = []
        current_key = None
        async for record in training_query.fetch(limit, offset):
            key = record[group_by_column]
            if group and key != current_key:
                training = _create_entity(group)
                trainings[training.id] = training
                group = []
            group.append(record)
            current_key = key

        if not group:
            # The query did not return any record.
            return

        training = _create_entity(group)
        trainings[training.id] = training

        training_ids = list(trainings)

        # Get the coaches of all the trainings.
        coach_query = TrainingCoachDbQuery(self._database).filter_by_trainings(
            *training_ids
        )
        coaches: dict[
            TrainingIdentifier, list[TrainingCoach]
        ] = await coach_query.fetch_coaches()

        # Get the teams of all the trainings.
        team_query = TrainingTeamDbQuery(self._database).filter_by_trainings(
            *training_ids
        )
        teams: dict[
            TrainingIdentifier, list[TeamEntity]
        ] = await team_query.fetch_teams()

        for training in trainings.values():
            training_coaches = coaches.get(training.id, [])
            training_teams = teams.get(training.id, [])
            if training_coaches or training_teams:
                yield Entity.replace(
                    training, coaches=training_coaches, teams=training_teams
                )
            else:
                yield training

    async def create(self, training: TrainingEntity) -> TrainingEntity:
        """Save a new training, together with its texts, coaches and teams.

        Args:
            training: The training to save.

        Returns:
            The training with the id generated by the database.
        """
        new_id = await self._database.insert(
            TrainingsTable.table_name, TrainingRow.persist(training)
        )
        result = Entity.replace(training, id_=TrainingIdentifier(new_id))

        await self._insert_contents(result)
        await self._insert_coaches(result)
        await self._insert_teams(result)

        await self._database.commit()

        return result

    async def update(self, training: TrainingEntity) -> None:
        """Update a training and replace its texts, coaches and teams.

        Args:
            training: The training to update.
        """
        await self._database.update(
            training.id.value,
            TrainingsTable.table_name,
            TrainingRow.persist(training),
        )

        # The related rows are replaced: first delete them, then insert again.
        await self._delete_contents(training)
        await self._insert_contents(training)

        await self._delete_coaches(training)
        await self._insert_coaches(training)

        await self._delete_teams(training)
        await self._insert_teams(training)

        await self._database.commit()

    async def _insert_contents(self, training: TrainingEntity) -> None:
        """Insert the text contents of the training, when there are any."""
        content_rows = [
            TrainingTextRow.persist(training, content) for content in training.texts
        ]
        # Same guard as for coaches and teams: skip the insert for no rows.
        if content_rows:
            await self._database.insert(
                TrainingContentsTable.table_name, *content_rows
            )

    async def _insert_coaches(self, training: TrainingEntity) -> None:
        """Insert the related coaches, when there are any."""
        training_coach_rows = [
            TrainingCoachRow.persist(training, training_coach)
            for training_coach in training.coaches
        ]
        if training_coach_rows:
            await self._database.insert(
                TrainingCoachRow.__table_name__, *training_coach_rows
            )

    async def _insert_teams(self, training: TrainingEntity) -> None:
        """Insert the related teams, when there are any."""
        training_team_rows = [
            TrainingTeamRow.persist(training, team) for team in training.teams
        ]
        if training_team_rows:
            await self._database.insert(
                TrainingTeamsTable.table_name, *training_team_rows
            )

    async def _delete_coaches(self, training: TrainingEntity) -> None:
        """Delete the coaches of the training."""
        delete_coaches_query = (
            self._database.create_query_factory()
            .delete(TrainingCoachRow.__table_name__)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_coaches_query)

    async def _delete_contents(self, training: TrainingEntity) -> None:
        """Delete the text contents of the training."""
        delete_contents_query = (
            self._database.create_query_factory()
            .delete(TrainingContentsTable.table_name)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_contents_query)

    async def _delete_teams(self, training: TrainingEntity) -> None:
        """Delete the teams of the training."""
        delete_teams_query = (
            self._database.create_query_factory()
            .delete(TrainingTeamsTable.table_name)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_teams_query)

    async def delete(self, training: TrainingEntity) -> None:
        """Delete a training, together with its texts, coaches and teams.

        Args:
            training: The training to delete.
        """
        # Dependent rows go first, so this also works when their tables have
        # non-cascading foreign keys to the trainings table (NOTE: the schema
        # is not visible here). Everything is committed at once.
        await self._delete_contents(training)
        await self._delete_coaches(training)
        await self._delete_teams(training)
        await self._database.delete(training.id.value, TrainingsTable.table_name)

        await self._database.commit()

    async def reset_definition(
        self, training_definition: TrainingDefinitionEntity, delete: bool = False
    ) -> None:
        """Delete or detach all trainings of a training definition.

        Args:
            training_definition: The definition whose trainings are affected.
            delete: When True, the trainings are deleted together with their
                teams, coaches and texts. When False (the default), the trainings
                are kept and only their link to the definition is removed.
        """
        definition_id = training_definition.id.value
        # Subquery with the ids of all trainings of the definition.
        trainings_query = (
            self._database.create_query_factory()
            .select(TrainingsTable.column("id"))
            .from_(TrainingsTable.table_name)
            .and_where(field("definition_id").eq(definition_id))
        )
        if delete:
            delete_teams = (
                self._database.create_query_factory()
                .delete(TrainingTeamsTable.table_name)
                .and_where(TrainingTeamsTable.field("training_id").in_(trainings_query))
            )
            await self._database.execute(delete_teams)

            delete_coaches = (
                self._database.create_query_factory()
                .delete(TrainingCoachRow.__table_name__)
                .and_where(TrainingCoachRow.field("training_id").in_(trainings_query))
            )
            await self._database.execute(delete_coaches)

            delete_contents = (
                self._database.create_query_factory()
                .delete(TrainingContentsTable.table_name)
                .and_where(
                    TrainingContentsTable.field("training_id").in_(trainings_query)
                )
            )
            await self._database.execute(delete_contents)

            # The trainings themselves are deleted last, because the statements
            # above find the affected trainings through trainings_query. A plain
            # condition is used instead of that subquery, because (as with the
            # update below) the table being modified can't be used in a sub query.
            delete_trainings = (
                self._database.create_query_factory()
                .delete(TrainingsTable.table_name)
                .where(field("definition_id").eq(definition_id))
            )
            await self._database.execute(delete_trainings)
        else:
            # Because it is not allowed to update the table that is used
            # in a sub query, we need to create a copy.
            copy_trainings_query = (
                self._database.create_query_factory()
                .select("t.id")
                .from_(alias(express("({})", trainings_query), "t"))
            )
            update_trainings = (
                self._database.create_query_factory()
                .update(TrainingsTable.table_name, {"definition_id": None})
                .where(TrainingsTable.field("id").in_(copy_trainings_query))
            )
            await self._database.execute(update_trainings)

        await self._database.commit()