Coverage for src/kwai/modules/training/trainings/training_db_query.py: 100%
64 statements
« prev ^ index » next coverage.py v7.7.1, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2024-01-01 00:00 +0000
1"""Module that implements a training query for a database."""
from typing import Any, AsyncIterator

from sql_smith.functions import alias, criteria, express, func, group, literal, on

from kwai.core.db.database import Database
from kwai.core.db.database_query import DatabaseQuery
from kwai.core.db.rows import OwnersTable
from kwai.core.domain.value_objects.timestamp import Timestamp
from kwai.modules.training.coaches.coach import CoachEntity
from kwai.modules.training.teams.team import TeamEntity
from kwai.modules.training.teams.team_tables import TeamsTable
from kwai.modules.training.trainings.training import TrainingIdentifier
from kwai.modules.training.trainings.training_definition import TrainingDefinitionEntity
from kwai.modules.training.trainings.training_query import TrainingQuery
from kwai.modules.training.trainings.training_tables import (
    TrainingCoachRow,
    TrainingContentsTable,
    TrainingDefinitionsTable,
    TrainingsTable,
    TrainingTeamsTable,
)
class TrainingDbQuery(TrainingQuery, DatabaseQuery):
    """A database query for trainings.

    Two queries are maintained:

    * ``self._query`` (not assigned in this class; presumably created by
      ``DatabaseQuery.__init__`` — TODO confirm) selects only training ids.
      All filters, the limit/offset and the ordering are applied to it, and
      it is attached to the main query as the CTE named ``limited``.
    * ``self._main_query`` joins the ``limited`` CTE back to the trainings
      table and selects the training, definition, definition owner, team,
      content and content owner columns for the selected trainings.
    """

    def __init__(self, database: Database):
        """Create the query.

        Args:
            database: Database used to build and execute the queries.
        """
        # _main_query must exist before the base initializer runs, because it
        # is (re)built in init(), which presumably is called from
        # DatabaseQuery.__init__ (TODO confirm).
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self):
        """Build the CTE query and the main query that selects full rows."""
        # This query will be used as CTE, so only joins the tables that are needed
        # for counting and limiting results.
        self._query.from_(TrainingsTable.table_name).left_join(
            TrainingDefinitionsTable.table_name,
            on(
                TrainingsTable.column("definition_id"),
                TrainingDefinitionsTable.column("id"),
            ),
        )
        self._main_query = (
            self._main_query.from_(TrainingsTable.table_name)
            .columns(
                *(
                    self.columns
                    + TeamsTable.aliases()
                    + OwnersTable.aliases("definition_owners")
                    + TrainingContentsTable.aliases()
                    + OwnersTable.aliases()
                )
            )
            .with_("limited", self._query)
            # Restrict the main query to the training ids selected by the CTE.
            .right_join("limited", on("limited.id", TrainingsTable.column("id")))
            .left_join(
                TrainingDefinitionsTable.table_name,
                on(
                    TrainingsTable.column("definition_id"),
                    TrainingDefinitionsTable.column("id"),
                ),
            )
            # The owners table is joined twice: once (aliased) for the owner of
            # the definition, once (below) for the owner of each content row.
            .left_join(
                alias(OwnersTable.table_name, "definition_owners"),
                on(TrainingDefinitionsTable.column("user_id"), "definition_owners.id"),
            )
            .left_join(
                TeamsTable.table_name,
                on(TeamsTable.column("id"), TrainingDefinitionsTable.column("team_id")),
            )
            # NOTE(review): inner joins — a training without a content row (or
            # a content row without an owner) is dropped from the result;
            # presumably every training has at least one content row.
            .join(
                TrainingContentsTable.table_name,
                on(
                    TrainingContentsTable.column("training_id"),
                    TrainingsTable.column("id"),
                ),
            )
            .join(
                OwnersTable.table_name,
                on(OwnersTable.column("id"), TrainingContentsTable.column("user_id")),
            )
        )

    @property
    def columns(self):
        """Aliased columns of the trainings and training definitions tables."""
        return TrainingsTable.aliases() + TrainingDefinitionsTable.aliases()

    @property
    def count_column(self) -> str:
        """Column counted by the base query (the training id)."""
        return TrainingsTable.column("id")

    def filter_by_id(self, id_: TrainingIdentifier) -> "TrainingQuery":
        """Restrict the query to the training with the given id.

        Returns:
            This query, for chaining.
        """
        self._query.and_where(TrainingsTable.field("id").eq(id_.value))
        return self

    def filter_by_year_month(
        self, year: int, month: int | None = None
    ) -> "TrainingQuery":
        """Restrict the query to trainings starting in the given year/month.

        Args:
            year: Year of the training start date.
            month: Month of the training start date; when None, only the year
                is compared.

        Returns:
            This query, for chaining.
        """
        condition = criteria(
            "{} = {}", func("YEAR", TrainingsTable.column("start_date")), literal(year)
        )
        if month is not None:
            condition = condition.and_(
                criteria(
                    "{} = {}",
                    func("MONTH", TrainingsTable.column("start_date")),
                    literal(month),
                )
            )
        # Group so the year/month conditions stay together when combined with
        # other filters.
        self._query.and_where(group(condition))
        return self

    def filter_by_dates(self, start: Timestamp, end: Timestamp) -> "TrainingQuery":
        """Restrict the query to trainings whose start date is between start and end.

        Returns:
            This query, for chaining.
        """
        self._query.and_where(
            TrainingsTable.field("start_date").between(str(start), str(end))
        )
        return self

    def filter_by_coach(self, coach: CoachEntity) -> "TrainingQuery":
        """Restrict the query to trainings the given coach is linked to.

        Returns:
            This query, for chaining.
        """
        # Subquery: ids of all trainings linked to this coach.
        inner_select = (
            self._database.create_query_factory()
            .select()
            .columns(TrainingCoachRow.column("training_id"))
            .from_(TrainingCoachRow.__table_name__)
            .where(TrainingCoachRow.field("coach_id").eq(coach.id.value))
        )
        condition = TrainingsTable.field("id").in_(express("{}", inner_select))
        self._query.and_where(group(condition))
        return self

    def filter_by_team(self, team: TeamEntity) -> "TrainingQuery":
        """Restrict the query to trainings the given team is linked to.

        Returns:
            This query, for chaining.
        """
        # Subquery: ids of all trainings linked to this team.
        inner_select = (
            self._database.create_query_factory()
            .select()
            .columns(TrainingTeamsTable.column("training_id"))
            .from_(TrainingTeamsTable.table_name)
            .where(TrainingTeamsTable.field("team_id").eq(team.id.value))
        )
        condition = TrainingsTable.field("id").in_(express("{}", inner_select))
        self._query.and_where(group(condition))
        return self

    def filter_by_definition(
        self, definition: TrainingDefinitionEntity
    ) -> "TrainingQuery":
        """Restrict the query to trainings created from the given definition.

        Returns:
            This query, for chaining.
        """
        self._query.and_where(
            TrainingsTable.field("definition_id").eq(definition.id.value)
        )
        return self

    def filter_active(self) -> "TrainingQuery":
        """Restrict the query to trainings whose ``active`` column equals 1.

        Returns:
            This query, for chaining.
        """
        self._query.and_where(TrainingsTable.field("active").eq(1))
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Execute the query and return the resulting rows.

        Limit and offset are applied to the ``limited`` CTE (training ids), so
        they count trainings rather than joined rows.

        Args:
            limit: Maximum number of trainings, or None for no limit.
            offset: Number of trainings to skip, or None.

        Returns:
            The result of ``Database.fetch`` for the main query, annotated as an
            async iterator of row dicts.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(TrainingsTable.column("id"))
        # Without a deterministic ORDER BY, LIMIT/OFFSET may select an
        # arbitrary set of trainings per page. Ordering the CTE by id (as the
        # sole key, or as a tie-breaker after order_by_date()'s start_date)
        # keeps the selected page consistent with the main query's ordering.
        self._query.order_by(TrainingsTable.column("id"))
        self._main_query.order_by(TrainingsTable.column("id"))

        return self._database.fetch(self._main_query)

    def order_by_date(self) -> "TrainingQuery":
        """Order the trainings by start date, ascending.

        The ordering is applied to both the CTE (so limit/offset pick the
        right trainings) and the main query (so the rows come out in order).

        Returns:
            This query, for chaining.
        """
        self._query.order_by(TrainingsTable.column("start_date"), "ASC")
        self._main_query.order_by(TrainingsTable.column("start_date"), "ASC")
        return self