Coverage for kwai/modules/training/trainings/training_db_query.py: 100%

63 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module that implements a training query for a database.""" 

from datetime import datetime
from typing import Any, AsyncIterator

from sql_smith.functions import alias, criteria, express, func, group, literal, on

from kwai.core.db.database import Database
from kwai.core.db.database_query import DatabaseQuery
from kwai.core.db.rows import OwnersTable
from kwai.modules.training.coaches.coach import CoachEntity
from kwai.modules.training.teams.team import TeamEntity
from kwai.modules.training.trainings.training import TrainingIdentifier
from kwai.modules.training.trainings.training_definition import TrainingDefinitionEntity
from kwai.modules.training.trainings.training_query import TrainingQuery
from kwai.modules.training.trainings.training_tables import (
    TrainingCoachesTable,
    TrainingContentsTable,
    TrainingDefinitionsTable,
    TrainingsTable,
    TrainingTeamsTable,
)

22 

23 

class TrainingDbQuery(TrainingQuery, DatabaseQuery):
    """A database query for trainings.

    Two queries work together:

    * ``self._query`` selects from the trainings table, left joined with the
      training definitions. All filters, the limit and the offset are applied
      to this query. (``self._query`` is not created in this class; presumably
      it is set up by ``DatabaseQuery`` -- TODO confirm.)
    * ``self._main_query`` uses ``self._query`` as the CTE ``limited`` and
      joins it with the definitions, the contents and the owners. This is the
      query that is executed by :meth:`fetch`.
    """

    def __init__(self, database: Database):
        """Initialize the query.

        Args:
            database: The database used to create and execute the queries.
        """
        # The main query must exist before the base initializer runs:
        # presumably the base class calls init(), which extends the main
        # query (TODO confirm against DatabaseQuery).
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self):
        """Set up the tables and joins of the CTE query and the main query."""
        # This query will be used as CTE, so only joins the tables that are needed
        # for counting and limiting results.
        self._query.from_(TrainingsTable.table_name).left_join(
            TrainingDefinitionsTable.table_name,
            on(
                TrainingsTable.column("definition_id"),
                TrainingDefinitionsTable.column("id"),
            ),
        )
        # The main query selects all columns and is restricted to the trainings
        # selected by the CTE through the join with "limited".
        self._main_query = (
            self._main_query.from_(TrainingsTable.table_name)
            .columns(
                *(
                    self.columns
                    + OwnersTable.aliases("definition_owners")
                    + TrainingContentsTable.aliases()
                    + OwnersTable.aliases()
                )
            )
            .with_("limited", self._query)
            .right_join("limited", on("limited.id", TrainingsTable.column("id")))
            .left_join(
                TrainingDefinitionsTable.table_name,
                on(
                    TrainingsTable.column("definition_id"),
                    TrainingDefinitionsTable.column("id"),
                ),
            )
            # The owners table is joined twice: once aliased as
            # "definition_owners" for the owner of the definition and once
            # (below) for the owner of the content.
            .left_join(
                alias(OwnersTable.table_name, "definition_owners"),
                on(TrainingDefinitionsTable.column("user_id"), "definition_owners.id"),
            )
            .join(
                TrainingContentsTable.table_name,
                on(
                    TrainingContentsTable.column("training_id"),
                    TrainingsTable.column("id"),
                ),
            )
            .join(
                OwnersTable.table_name,
                on(OwnersTable.column("id"), TrainingContentsTable.column("user_id")),
            )
        )

    @property
    def columns(self):
        """Return the aliased columns of the trainings and definitions tables."""
        return TrainingsTable.aliases() + TrainingDefinitionsTable.aliases()

    @property
    def count_column(self) -> str:
        """Return the column used for counting: the id of the training."""
        return TrainingsTable.column("id")

    def filter_by_id(self, id_: TrainingIdentifier) -> "TrainingQuery":
        """Add a filter on the id of the training.

        Args:
            id_: The identifier of the training.

        Returns:
            This query, to allow chaining.
        """
        self._query.and_where(TrainingsTable.field("id").eq(id_.value))
        return self

    def filter_by_year_month(
        self, year: int, month: int | None = None
    ) -> "TrainingQuery":
        """Add a filter on the year and, optionally, the month of the start date.

        Args:
            year: The year of the start date.
            month: The month of the start date. When None, only the year is used.

        Returns:
            This query, to allow chaining.
        """
        condition = criteria(
            "{} = {}", func("YEAR", TrainingsTable.column("start_date")), literal(year)
        )
        if month is not None:
            condition = condition.and_(
                criteria(
                    "{} = {}",
                    func("MONTH", TrainingsTable.column("start_date")),
                    literal(month),
                )
            )
        # Group the condition, so it is added as one unit to the where clause.
        self._query.and_where(group(condition))
        return self

    def filter_by_dates(self, start: datetime, end: datetime) -> "TrainingQuery":
        """Add a filter on the start date, using a BETWEEN condition.

        Args:
            start: The first bound of the range.
            end: The second bound of the range.

        Returns:
            This query, to allow chaining.
        """
        self._query.and_where(TrainingsTable.field("start_date").between(start, end))
        return self

    def filter_by_coach(self, coach: CoachEntity) -> "TrainingQuery":
        """Add a filter that only selects trainings linked to the given coach.

        Args:
            coach: The coach to filter on.

        Returns:
            This query, to allow chaining.
        """
        # Sub-select with the ids of all trainings of the coach.
        inner_select = (
            self._database.create_query_factory()
            .select()
            .columns(TrainingCoachesTable.column("training_id"))
            .from_(TrainingCoachesTable.table_name)
            .where(TrainingCoachesTable.field("coach_id").eq(coach.id.value))
        )
        condition = TrainingsTable.field("id").in_(express("{}", inner_select))
        self._query.and_where(group(condition))
        return self

    def filter_by_team(self, team: TeamEntity) -> "TrainingQuery":
        """Add a filter that only selects trainings linked to the given team.

        Args:
            team: The team to filter on.

        Returns:
            This query, to allow chaining.
        """
        # Sub-select with the ids of all trainings of the team.
        inner_select = (
            self._database.create_query_factory()
            .select()
            .columns(TrainingTeamsTable.column("training_id"))
            .from_(TrainingTeamsTable.table_name)
            .where(TrainingTeamsTable.field("team_id").eq(team.id.value))
        )
        condition = TrainingsTable.field("id").in_(express("{}", inner_select))
        self._query.and_where(group(condition))
        return self

    def filter_by_definition(
        self, definition: TrainingDefinitionEntity
    ) -> "TrainingQuery":
        """Add a filter on the training definition.

        Args:
            definition: The definition to filter on.

        Returns:
            This query, to allow chaining.
        """
        self._query.and_where(
            TrainingsTable.field("definition_id").eq(definition.id.value)
        )
        return self

    def filter_active(self) -> "TrainingQuery":
        """Add a filter that only selects trainings with ``active`` equal to 1.

        Returns:
            This query, to allow chaining.
        """
        self._query.and_where(TrainingsTable.field("active").eq(1))
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Execute the main query.

        The limit and offset are applied to the CTE query, not to the main
        query.

        Args:
            limit: The limit passed to the CTE query.
            offset: The offset passed to the CTE query.

        Returns:
            The result of the ``fetch`` method of the database for the main
            query.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        # The CTE only needs the id: the main query joins on "limited.id".
        self._query.columns(TrainingsTable.column("id"))
        self._main_query.order_by(TrainingsTable.column("id"))

        return self._database.fetch(self._main_query)

    def order_by_date(self) -> "TrainingQuery":
        """Order both the CTE query and the main query descending on start date.

        Returns:
            This query, to allow chaining.
        """
        self._query.order_by(TrainingsTable.column("start_date"), "DESC")
        self._main_query.order_by(TrainingsTable.column("start_date"), "DESC")
        return self