Coverage for src/kwai/modules/training/trainings/training_db_repository.py: 92%

123 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module for implementing a training repository for a database.""" 

2 

3from typing import AsyncIterator 

4 

5from sql_smith.functions import alias, express, field 

6 

7from kwai.core.db.database import Database, Record 

8from kwai.core.db.rows import OwnersTable 

9from kwai.core.domain.entity import Entity 

10from kwai.modules.training.teams.team import TeamEntity 

11from kwai.modules.training.teams.team_tables import TeamsTable 

12from kwai.modules.training.trainings.training import TrainingEntity, TrainingIdentifier 

13from kwai.modules.training.trainings.training_coach_db_query import TrainingCoachDbQuery 

14from kwai.modules.training.trainings.training_db_query import TrainingDbQuery 

15from kwai.modules.training.trainings.training_definition import TrainingDefinitionEntity 

16from kwai.modules.training.trainings.training_query import TrainingQuery 

17from kwai.modules.training.trainings.training_repository import ( 

18 TrainingNotFoundException, 

19 TrainingRepository, 

20) 

21from kwai.modules.training.trainings.training_tables import ( 

22 TrainingCoachRow, 

23 TrainingContentsTable, 

24 TrainingDefinitionsTable, 

25 TrainingRow, 

26 TrainingsTable, 

27 TrainingTeamRow, 

28 TrainingTeamsTable, 

29 TrainingTextRow, 

30) 

31from kwai.modules.training.trainings.training_team_db_query import TrainingTeamDbQuery 

32from kwai.modules.training.trainings.value_objects import TrainingCoach 

33 

34 

def _create_entity(rows: list[Record]) -> TrainingEntity:
    """Build one training entity from the rows that belong to a single training.

    Only the first row is read for the training and definition columns; every
    row contributes one text (with the owner of that text).

    Args:
        rows: The rows of one training. Must not be empty.

    Returns:
        The training entity with its texts and, when linked, its definition.
    """
    first_row = rows[0]

    # A NULL definition id means the training is not linked to a definition.
    definition = None
    if first_row[TrainingDefinitionsTable.alias_name("id")] is not None:
        definition_team = TeamsTable(first_row).create_entity()
        definition_owner = OwnersTable(first_row, "definition_owners").create_owner()
        definition = TrainingDefinitionsTable(first_row).create_entity(
            team=definition_team,
            owner=definition_owner,
        )

    texts = [
        TrainingContentsTable(row).create_text(OwnersTable(row).create_owner())
        for row in rows
    ]
    return TrainingsTable(first_row).create_entity(texts, definition=definition)

51 

52 

class TrainingDbRepository(TrainingRepository):
    """A training repository for a database.

    Trainings are stored in the trainings table; their texts, coaches and teams
    live in separate tables linked by a ``training_id`` column.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> TrainingQuery:
        """Create a new training query that runs against this repository's database."""
        return TrainingDbQuery(self._database)

    async def get_by_id(self, id: TrainingIdentifier) -> TrainingEntity:
        """Get the training with the given id.

        Args:
            id: The id of the training.

        Returns:
            The training entity, including its texts, coaches and teams.

        Raises:
            TrainingNotFoundException: when no training with the given id exists.
        """
        query = self.create_query()
        query.filter_by_id(id)

        try:
            entity = await anext(self.get_all(query, 1))
        except StopAsyncIteration:
            raise TrainingNotFoundException(
                f"Training with id {id} does not exist"
            ) from None
        return entity

    async def get_all(
        self,
        training_query: TrainingQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingEntity] | None:
        """Yield all trainings that match the query.

        The query can return several rows per training (one per text). Only
        consecutive rows with the same training id are merged into one entity,
        so this relies on the query returning the rows of a training next to
        each other (NOTE(review): presumably ordered by training id — confirm in
        TrainingDbQuery). Coaches and teams are then loaded with one extra query
        each, for all fetched trainings at once.

        Args:
            training_query: The query to use. When None, all trainings are fetched.
            limit: Passed to the query's fetch.
            offset: Passed to the query's fetch.

        Yields:
            Training entities, with coaches and teams attached.
        """
        if training_query is None:
            training_query = self.create_query()

        group_by_column = TrainingsTable.alias_name("id")
        trainings: dict[TrainingIdentifier, TrainingEntity] = {}
        group: list[Record] = []
        current_key = None

        async for record in training_query.fetch(limit, offset):
            key = record[group_by_column]
            if group and key != current_key:
                # A new training starts: turn the finished group into an entity.
                training = _create_entity(group)
                trainings[training.id] = training
                group = []
            group.append(record)
            current_key = key

        if not group:
            # The query returned no rows at all.
            return

        training = _create_entity(group)
        trainings[training.id] = training

        training_ids = list(trainings)

        # Get the coaches of all the trainings.
        coach_query = TrainingCoachDbQuery(self._database).filter_by_trainings(
            *training_ids
        )
        coaches: dict[
            TrainingIdentifier, list[TrainingCoach]
        ] = await coach_query.fetch_coaches()

        # Get the teams of all the trainings.
        team_query = TrainingTeamDbQuery(self._database).filter_by_trainings(
            *training_ids
        )
        teams: dict[
            TrainingIdentifier, list[TeamEntity]
        ] = await team_query.fetch_teams()

        for training in trainings.values():
            training_coaches = coaches.get(training.id, [])
            training_teams = teams.get(training.id, [])
            if training_coaches or training_teams:
                yield Entity.replace(
                    training, coaches=training_coaches, teams=training_teams
                )
            else:
                yield training

    async def create(self, training: TrainingEntity) -> TrainingEntity:
        """Save a new training with its texts, coaches and teams, then commit.

        Args:
            training: The training to save.

        Returns:
            A copy of the training with the id assigned by the database.
        """
        new_id = await self._database.insert(
            TrainingsTable.table_name, TrainingRow.persist(training)
        )
        result = Entity.replace(training, id_=TrainingIdentifier(new_id))

        await self._insert_contents(result)
        await self._insert_coaches(result)
        await self._insert_teams(result)

        await self._database.commit()

        return result

    async def update(self, training: TrainingEntity) -> None:
        """Update the training and replace its texts, coaches and teams, then commit.

        Args:
            training: The training to update.
        """
        await self._database.update(
            training.id.value,
            TrainingsTable.table_name,
            TrainingRow.persist(training),
        )

        # Related rows are replaced wholesale: delete the old ones, insert the new.
        await self._delete_contents(training)
        await self._insert_contents(training)

        await self._delete_coaches(training)
        await self._insert_coaches(training)

        await self._delete_teams(training)
        await self._insert_teams(training)

        await self._database.commit()

    async def _insert_contents(self, training: TrainingEntity) -> None:
        """Insert the text rows of the training (skipped when it has no texts)."""
        content_rows = [
            TrainingTextRow.persist(training, content) for content in training.texts
        ]
        if content_rows:
            await self._database.insert(
                TrainingContentsTable.table_name, *content_rows
            )

    async def _insert_coaches(self, training: TrainingEntity) -> None:
        """Insert the related coaches (skipped when there are none)."""
        training_coach_rows = [
            TrainingCoachRow.persist(training, training_coach)
            for training_coach in training.coaches
        ]
        if training_coach_rows:
            await self._database.insert(
                TrainingCoachRow.__table_name__, *training_coach_rows
            )

    async def _insert_teams(self, training: TrainingEntity) -> None:
        """Insert the related teams (skipped when there are none)."""
        training_team_rows = [
            TrainingTeamRow.persist(training, team) for team in training.teams
        ]
        if training_team_rows:
            await self._database.insert(
                TrainingTeamsTable.table_name, *training_team_rows
            )

    async def _delete_related(self, table_name: str, training: TrainingEntity) -> None:
        """Delete all rows of ``table_name`` whose training_id is the training's id."""
        delete_query = (
            self._database.create_query_factory()
            .delete(table_name)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_query)

    async def _delete_coaches(self, training: TrainingEntity) -> None:
        """Delete coaches of the training."""
        await self._delete_related(TrainingCoachRow.__table_name__, training)

    async def _delete_contents(self, training: TrainingEntity) -> None:
        """Delete text contents of the training."""
        await self._delete_related(TrainingContentsTable.table_name, training)

    async def _delete_teams(self, training: TrainingEntity) -> None:
        """Delete the teams of the training."""
        await self._delete_related(TrainingTeamsTable.table_name, training)

    async def delete(self, training: TrainingEntity) -> None:
        """Delete the training and all its related rows, then commit.

        Args:
            training: The training to delete.
        """
        # Remove the dependent rows first, so that a foreign key on training_id
        # (if the schema defines one) cannot block deleting the training itself.
        await self._delete_contents(training)
        await self._delete_coaches(training)
        await self._delete_teams(training)
        await self._database.delete(training.id.value, TrainingsTable.table_name)

        await self._database.commit()

    async def reset_definition(
        self, training_definition: TrainingDefinitionEntity, delete: bool = False
    ) -> None:
        """Delete or detach all trainings that belong to a training definition.

        Args:
            training_definition: The definition whose trainings are affected.
            delete: When True, the trainings are deleted together with their
                teams, coaches and texts. When False, the trainings are kept
                but their definition_id is set to NULL.
        """
        definition_id = training_definition.id.value
        trainings_query = (
            self._database.create_query_factory()
            .select(TrainingsTable.column("id"))
            .from_(TrainingsTable.table_name)
            .and_where(field("definition_id").eq(definition_id))
        )
        if delete:
            # The child rows are located through trainings_query, so they must
            # be removed before the trainings themselves are deleted.
            delete_teams = (
                self._database.create_query_factory()
                .delete(TrainingTeamsTable.table_name)
                .and_where(TrainingTeamsTable.field("training_id").in_(trainings_query))
            )
            await self._database.execute(delete_teams)

            delete_coaches = (
                self._database.create_query_factory()
                .delete(TrainingCoachRow.__table_name__)
                .and_where(TrainingCoachRow.field("training_id").in_(trainings_query))
            )
            await self._database.execute(delete_coaches)

            delete_contents = (
                self._database.create_query_factory()
                .delete(TrainingContentsTable.table_name)
                .and_where(
                    TrainingContentsTable.field("training_id").in_(trainings_query)
                )
            )
            await self._database.execute(delete_contents)

            delete_trainings = (
                self._database.create_query_factory()
                .delete(TrainingsTable.table_name)
                .where(field("definition_id").eq(definition_id))
            )
            await self._database.execute(delete_trainings)
        else:
            # Because it is not allowed to update the table that is used
            # in a sub query, we need to create a copy.
            copy_trainings_query = (
                self._database.create_query_factory()
                .select("t.id")
                .from_(alias(express("({})", trainings_query), "t"))
            )
            update_trainings = (
                self._database.create_query_factory()
                .update(TrainingsTable.table_name, {"definition_id": None})
                .where(TrainingsTable.field("id").in_(copy_trainings_query))
            )
            await self._database.execute(update_trainings)

        await self._database.commit()