Coverage for kwai/modules/training/trainings/training_db_repository.py: 93%

107 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module for implementing a training repository for a database.""" 

2from typing import AsyncIterator 

3 

4from sql_smith.functions import field 

5 

6from kwai.core.db.database import Database, Record 

7from kwai.core.db.rows import OwnersTable 

8from kwai.core.domain.entity import Entity 

9from kwai.modules.training.teams.team import TeamEntity 

10from kwai.modules.training.trainings.training import TrainingEntity, TrainingIdentifier 

11from kwai.modules.training.trainings.training_coach_db_query import TrainingCoachDbQuery 

12from kwai.modules.training.trainings.training_db_query import TrainingDbQuery 

13from kwai.modules.training.trainings.training_query import TrainingQuery 

14from kwai.modules.training.trainings.training_repository import ( 

15 TrainingNotFoundException, 

16 TrainingRepository, 

17) 

18from kwai.modules.training.trainings.training_tables import ( 

19 TrainingCoachesTable, 

20 TrainingCoachRow, 

21 TrainingContentRow, 

22 TrainingContentsTable, 

23 TrainingDefinitionsTable, 

24 TrainingRow, 

25 TrainingsTable, 

26 TrainingTeamRow, 

27 TrainingTeamsTable, 

28) 

29from kwai.modules.training.trainings.training_team_db_query import TrainingTeamDbQuery 

30from kwai.modules.training.trainings.value_objects import TrainingCoach 

31 

32 

def _create_entity(rows: list[Record]) -> TrainingEntity:
    """Build one training entity from the rows that belong to the same training.

    The training columns and the (optional) definition columns are read from
    the first row only. Every row contributes one content item, together with
    the owner created from that same row.

    Args:
        rows: The records of a single training. The first element is always
            accessed, so the list must not be empty.

    Returns:
        The training entity with its contents and, when the definition id
        column of the first row is not None, its definition.
    """
    first_row = rows[0]

    # A None definition id means no definition is attached to this training.
    definition = None
    if first_row[TrainingDefinitionsTable.alias_name("id")] is not None:
        definitions_table = TrainingDefinitionsTable(first_row)
        definition_owner = OwnersTable(first_row, "definition_owners").create_owner()
        definition = definitions_table.create_entity(definition_owner)

    trainings_table = TrainingsTable(first_row)

    contents = []
    for row in rows:
        contents_table = TrainingContentsTable(row)
        contents.append(contents_table.create_content(OwnersTable(row).create_owner()))

    return trainings_table.create_entity(contents, definition=definition)

48 

49 

class TrainingDbRepository(TrainingRepository):
    """A training repository for a database.

    Trainings are stored in the trainings table; their contents, coaches and
    teams are stored in separate tables that reference the training through a
    ``training_id`` column.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> TrainingQuery:
        """Create a training query that runs against this repository's database."""
        return TrainingDbQuery(self._database)

    async def get_by_id(self, id: TrainingIdentifier) -> TrainingEntity:
        """Get the training with the given id.

        Args:
            id: The identifier of the training.

        Returns:
            The training entity.

        Raises:
            TrainingNotFoundException: When no training is returned for the id.
        """
        query = self.create_query()
        query.filter_by_id(id)

        # NOTE(review): the limit of 1 is handed to the row fetch in get_all;
        # if the query returns one row per content, a training with several
        # contents may be truncated - confirm against TrainingDbQuery.
        row_iterator = self.get_all(query, 1)
        try:
            entity = await anext(row_iterator)
        except StopAsyncIteration:
            raise TrainingNotFoundException(
                f"Training with id {id} does not exist"
            ) from None
        return entity

    async def get_all(
        self,
        training_query: TrainingQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingEntity]:
        """Yield all trainings that match the query.

        Consecutive rows with the same training id are grouped into one
        entity. All trainings are collected before the first one is yielded,
        because the coaches and teams are fetched with one query each for the
        whole set of trainings.

        Args:
            training_query: The query to execute. When None, a new query
                without filters is created.
            limit: Passed to the fetch of the query.
            offset: Passed to the fetch of the query.

        Yields:
            The training entities, with coaches and teams attached when the
            training has any.
        """
        if training_query is None:
            training_query = self.create_query()

        trainings: dict[TrainingIdentifier, TrainingEntity] = {}
        group_by_column = TrainingsTable.alias_name("id")

        # Collect consecutive rows with the same training id into one group
        # and turn each completed group into an entity.
        group: list[Record] = []
        current_key = None
        async for record in training_query.fetch(limit, offset):
            key = record[group_by_column]
            if group and key != current_key:
                training = _create_entity(group)
                trainings[training.id] = training
                group = []
            group.append(record)
            current_key = key

        if not group:
            # The query returned no rows, so there is nothing to yield.
            return

        # The last group is still pending when the rows are exhausted.
        training = _create_entity(group)
        trainings[training.id] = training

        training_ids = list(trainings)

        # Get the coaches of all the trainings.
        coach_query = TrainingCoachDbQuery(self._database).filter_by_trainings(
            *training_ids
        )
        coaches: dict[
            TrainingIdentifier, list[TrainingCoach]
        ] = await coach_query.fetch_coaches()

        # Get the teams of all trainings.
        team_query = TrainingTeamDbQuery(self._database).filter_by_trainings(
            *training_ids
        )
        teams: dict[
            TrainingIdentifier, list[TeamEntity]
        ] = await team_query.fetch_teams()

        for training in trainings.values():
            training_coaches = coaches.get(training.id, [])
            training_teams = teams.get(training.id, [])
            if training_coaches or training_teams:
                yield Entity.replace(
                    training, coaches=training_coaches, teams=training_teams
                )
            else:
                yield training

    async def create(self, training: TrainingEntity) -> TrainingEntity:
        """Insert a new training with its contents, coaches and teams.

        Args:
            training: The training to store.

        Returns:
            A copy of the training that carries the identifier returned by
            the insert.
        """
        new_id = await self._database.insert(
            TrainingsTable.table_name, TrainingRow.persist(training)
        )
        result = Entity.replace(training, id_=TrainingIdentifier(new_id))

        # The related rows need the new id, so they are created from result.
        await self._insert_contents(result)
        await self._insert_coaches(result)
        await self._insert_teams(result)

        await self._database.commit()

        return result

    async def update(self, training: TrainingEntity) -> None:
        """Update a training and replace its contents, coaches and teams.

        Args:
            training: The training with the new state.
        """
        # Update the training
        await self._database.update(
            training.id.value,
            TrainingsTable.table_name,
            TrainingRow.persist(training),
        )

        # Update the text, first delete, then insert again.
        await self._delete_contents(training)
        await self._insert_contents(training)

        # Update coaches, first delete, then insert again.
        await self._delete_coaches(training)
        await self._insert_coaches(training)

        # Update teams, first delete, then insert again.
        await self._delete_teams(training)
        await self._insert_teams(training)

        await self._database.commit()

    async def _insert_contents(self, training: TrainingEntity) -> None:
        """Insert the text contents of the training."""
        content_rows = [
            TrainingContentRow.persist(training, content)
            for content in training.content
        ]
        # Same guard as for coaches and teams: skip the insert without rows.
        if content_rows:
            await self._database.insert(
                TrainingContentsTable.table_name, *content_rows
            )

    async def _insert_coaches(self, training: TrainingEntity) -> None:
        """Insert the related coaches."""
        training_coach_rows = [
            TrainingCoachRow.persist(training, training_coach)
            for training_coach in training.coaches
        ]
        if training_coach_rows:
            await self._database.insert(
                TrainingCoachesTable.table_name, *training_coach_rows
            )

    async def _insert_teams(self, training: TrainingEntity) -> None:
        """Insert the related teams."""
        training_team_rows = [
            TrainingTeamRow.persist(training, team) for team in training.teams
        ]
        if training_team_rows:
            await self._database.insert(
                TrainingTeamsTable.table_name, *training_team_rows
            )

    async def _delete_coaches(self, training: TrainingEntity) -> None:
        """Delete coaches of the training."""
        delete_coaches_query = (
            self._database.create_query_factory()
            .delete(TrainingCoachesTable.table_name)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_coaches_query)

    async def _delete_contents(self, training: TrainingEntity) -> None:
        """Delete text contents of the training."""
        delete_contents_query = (
            self._database.create_query_factory()
            .delete(TrainingContentsTable.table_name)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_contents_query)

    async def _delete_teams(self, training: TrainingEntity) -> None:
        """Delete the teams of the training."""
        delete_teams_query = (
            self._database.create_query_factory()
            .delete(TrainingTeamsTable.table_name)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_teams_query)

    async def delete(self, training: TrainingEntity) -> None:
        """Delete a training together with its contents, coaches and teams.

        Args:
            training: The training to remove.
        """
        # Remove the rows that reference the training before the training row
        # itself, so the order also holds if the schema enforces foreign keys
        # without cascading deletes (the schema is not visible here).
        await self._delete_contents(training)
        await self._delete_coaches(training)
        await self._delete_teams(training)

        await self._database.delete(training.id.value, TrainingsTable.table_name)

        await self._database.commit()