Coverage for src/kwai/modules/training/trainings/training_db_query.py: 100%

64 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2024-01-01 00:00 +0000

1"""Module that implements a training query for a database.""" 

2 

from typing import Any, AsyncIterator

from sql_smith.functions import alias, criteria, express, func, group, literal, on

from kwai.core.db.database import Database
from kwai.core.db.database_query import DatabaseQuery
from kwai.core.db.rows import OwnersTable
from kwai.core.domain.value_objects.timestamp import Timestamp
from kwai.modules.training.coaches.coach import CoachEntity
from kwai.modules.training.teams.team import TeamEntity
from kwai.modules.training.teams.team_tables import TeamsTable
from kwai.modules.training.trainings.training import TrainingIdentifier
from kwai.modules.training.trainings.training_definition import TrainingDefinitionEntity
from kwai.modules.training.trainings.training_query import TrainingQuery
from kwai.modules.training.trainings.training_tables import (
    TrainingCoachRow,
    TrainingContentsTable,
    TrainingDefinitionsTable,
    TrainingsTable,
    TrainingTeamsTable,
)

24 

25 

class TrainingDbQuery(TrainingQuery, DatabaseQuery):
    """A database query for trainings.

    Two queries are built:

    * ``self._query`` (inherited from ``DatabaseQuery``) only selects training
      ids. All filters, the limit and the offset are applied to it, and it is
      used as the CTE named ``limited``.
    * ``self._main_query`` joins the ``limited`` CTE with the trainings table
      and its related tables (definition, definition owner, team, contents and
      contents owner) to fetch the complete rows.

    Because the related tables are joined in the main query, limit and offset
    are applied to trainings and not to the joined rows (joining the contents
    table on ``training_id`` can produce more than one row per training).
    """

    def __init__(self, database: Database):
        """Create a training query.

        Args:
            database: The database used to build and execute the queries.
        """
        # The main query must exist before the base initializer runs, because
        # init() extends it (presumably init() is invoked from
        # DatabaseQuery.__init__ — verify against the base class).
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self):
        """Set up the CTE query and the main query."""
        # This query will be used as CTE, so only join the tables that are
        # needed for counting and limiting results.
        self._query.from_(TrainingsTable.table_name).left_join(
            TrainingDefinitionsTable.table_name,
            on(
                TrainingsTable.column("definition_id"),
                TrainingDefinitionsTable.column("id"),
            ),
        )
        self._main_query = (
            self._main_query.from_(TrainingsTable.table_name)
            .columns(
                *(
                    self.columns
                    + TeamsTable.aliases()
                    + OwnersTable.aliases("definition_owners")
                    + TrainingContentsTable.aliases()
                    + OwnersTable.aliases()
                )
            )
            .with_("limited", self._query)
            # Restrict the main query to the training ids selected by the CTE.
            .right_join("limited", on("limited.id", TrainingsTable.column("id")))
            .left_join(
                TrainingDefinitionsTable.table_name,
                on(
                    TrainingsTable.column("definition_id"),
                    TrainingDefinitionsTable.column("id"),
                ),
            )
            # The owners table is joined twice, so the definition owner gets
            # its own alias.
            .left_join(
                alias(OwnersTable.table_name, "definition_owners"),
                on(TrainingDefinitionsTable.column("user_id"), "definition_owners.id"),
            )
            .left_join(
                TeamsTable.table_name,
                on(TeamsTable.column("id"), TrainingDefinitionsTable.column("team_id")),
            )
            .join(
                TrainingContentsTable.table_name,
                on(
                    TrainingContentsTable.column("training_id"),
                    TrainingsTable.column("id"),
                ),
            )
            .join(
                OwnersTable.table_name,
                on(OwnersTable.column("id"), TrainingContentsTable.column("user_id")),
            )
        )

    @property
    def columns(self):
        """Return the training and training definition column aliases."""
        return TrainingsTable.aliases() + TrainingDefinitionsTable.aliases()

    @property
    def count_column(self) -> str:
        """Return the column used for counting trainings."""
        return TrainingsTable.column("id")

    def filter_by_id(self, id_: TrainingIdentifier) -> "TrainingQuery":
        """Only keep the training with the given id."""
        self._query.and_where(TrainingsTable.field("id").eq(id_.value))
        return self

    def filter_by_year_month(
        self, year: int, month: int | None = None
    ) -> "TrainingQuery":
        """Only keep trainings that start in the given year (and month).

        Args:
            year: The year of the start date.
            month: The month of the start date; when None, only the year
                is checked.
        """
        condition = criteria(
            "{} = {}", func("YEAR", TrainingsTable.column("start_date")), literal(year)
        )
        if month is not None:
            condition = condition.and_(
                criteria(
                    "{} = {}",
                    func("MONTH", TrainingsTable.column("start_date")),
                    literal(month),
                )
            )
        self._query.and_where(group(condition))
        return self

    def filter_by_dates(self, start: Timestamp, end: Timestamp) -> "TrainingQuery":
        """Only keep trainings whose start date is between start and end.

        The check uses SQL BETWEEN, so both bounds are inclusive.
        """
        self._query.and_where(
            TrainingsTable.field("start_date").between(str(start), str(end))
        )
        return self

    def filter_by_coach(self, coach: CoachEntity) -> "TrainingQuery":
        """Only keep trainings the given coach is linked to."""
        # Subquery: ids of all trainings linked to this coach.
        inner_select = (
            self._database.create_query_factory()
            .select()
            .columns(TrainingCoachRow.column("training_id"))
            .from_(TrainingCoachRow.__table_name__)
            .where(TrainingCoachRow.field("coach_id").eq(coach.id.value))
        )
        condition = TrainingsTable.field("id").in_(express("{}", inner_select))
        self._query.and_where(group(condition))
        return self

    def filter_by_team(self, team: TeamEntity) -> "TrainingQuery":
        """Only keep trainings the given team is linked to."""
        # Subquery: ids of all trainings linked to this team.
        inner_select = (
            self._database.create_query_factory()
            .select()
            .columns(TrainingTeamsTable.column("training_id"))
            .from_(TrainingTeamsTable.table_name)
            .where(TrainingTeamsTable.field("team_id").eq(team.id.value))
        )
        condition = TrainingsTable.field("id").in_(express("{}", inner_select))
        self._query.and_where(group(condition))
        return self

    def filter_by_definition(
        self, definition: TrainingDefinitionEntity
    ) -> "TrainingQuery":
        """Only keep trainings created from the given definition."""
        self._query.and_where(
            TrainingsTable.field("definition_id").eq(definition.id.value)
        )
        return self

    def filter_active(self) -> "TrainingQuery":
        """Only keep trainings whose active column equals 1."""
        self._query.and_where(TrainingsTable.field("active").eq(1))
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Fetch the trainings.

        Limit and offset are applied to the ``limited`` CTE, so they count
        trainings and not joined rows.

        Args:
            limit: Maximum number of trainings to return (optional).
            offset: Number of trainings to skip (optional).

        Returns:
            An async iterator over the rows of the main query.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(TrainingsTable.column("id"))
        # Without an explicit order, LIMIT/OFFSET on the CTE picks an arbitrary
        # set of trainings, which makes pagination unstable. Order by id as a
        # tiebreaker (it comes after start_date when order_by_date was called).
        self._query.order_by(TrainingsTable.column("id"))
        self._main_query.order_by(TrainingsTable.column("id"))

        return self._database.fetch(self._main_query)

    def order_by_date(self) -> "TrainingQuery":
        """Order the trainings by start date, ascending."""
        # Order both the CTE (so limit/offset select the earliest trainings)
        # and the main query (so the returned rows are in date order).
        self._query.order_by(TrainingsTable.column("start_date"), "ASC")
        self._main_query.order_by(TrainingsTable.column("start_date"), "ASC")
        return self