Coverage for kwai/modules/training/trainings/training_db_query.py: 100%
63 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module that implements a training query for a database."""
from datetime import datetime
from typing import Any, AsyncIterator

from sql_smith.functions import alias, criteria, express, func, group, literal, on

from kwai.core.db.database import Database
from kwai.core.db.database_query import DatabaseQuery
from kwai.core.db.rows import OwnersTable
from kwai.modules.training.coaches.coach import CoachEntity
from kwai.modules.training.teams.team import TeamEntity
from kwai.modules.training.trainings.training import TrainingIdentifier
from kwai.modules.training.trainings.training_definition import TrainingDefinitionEntity
from kwai.modules.training.trainings.training_query import TrainingQuery
from kwai.modules.training.trainings.training_tables import (
    TrainingCoachesTable,
    TrainingContentsTable,
    TrainingDefinitionsTable,
    TrainingsTable,
    TrainingTeamsTable,
)
class TrainingDbQuery(TrainingQuery, DatabaseQuery):
    """A database query for trainings.

    Two queries are built:

    * ``self._query`` (presumably created by the ``DatabaseQuery`` base class)
      selects only the ids of the trainings that match the filters. It joins
      just the tables needed for filtering, counting and limiting, and is
      embedded in the main query as the CTE ``limited``.
    * ``self._main_query`` right-joins the ``limited`` CTE and joins all the
      tables needed to build complete rows: the training definition, the owner
      of the definition, the training contents and the owners of the contents.

    All ``filter_*`` methods add conditions to ``self._query`` and return
    ``self`` so calls can be chained.
    """

    def __init__(self, database: Database):
        # init() (presumably invoked from DatabaseQuery.__init__) extends
        # self._main_query, so it has to exist before the base constructor runs.
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self):
        """Build the base joins of the CTE query and of the main query."""
        # This query will be used as CTE, so only joins the tables that are needed
        # for counting and limiting results.
        self._query.from_(TrainingsTable.table_name).left_join(
            TrainingDefinitionsTable.table_name,
            on(
                TrainingsTable.column("definition_id"),
                TrainingDefinitionsTable.column("id"),
            ),
        )
        self._main_query = (
            self._main_query.from_(TrainingsTable.table_name)
            .columns(
                *(
                    self.columns
                    + OwnersTable.aliases("definition_owners")
                    + TrainingContentsTable.aliases()
                    + OwnersTable.aliases()
                )
            )
            .with_("limited", self._query)
            # Only rows whose training id is selected by the CTE are kept.
            .right_join("limited", on("limited.id", TrainingsTable.column("id")))
            .left_join(
                TrainingDefinitionsTable.table_name,
                on(
                    TrainingsTable.column("definition_id"),
                    TrainingDefinitionsTable.column("id"),
                ),
            )
            # The owners table is joined twice, so the definition owner gets
            # its own alias.
            .left_join(
                alias(OwnersTable.table_name, "definition_owners"),
                on(TrainingDefinitionsTable.column("user_id"), "definition_owners.id"),
            )
            # NOTE(review): this is an inner join, while the CTE does not join
            # contents; a training without contents would be counted by the CTE
            # but not returned here — confirm every training has content.
            .join(
                TrainingContentsTable.table_name,
                on(
                    TrainingContentsTable.column("training_id"),
                    TrainingsTable.column("id"),
                ),
            )
            .join(
                OwnersTable.table_name,
                on(OwnersTable.column("id"), TrainingContentsTable.column("user_id")),
            )
        )

    @property
    def columns(self):
        """Aliased columns of the trainings and training definitions tables."""
        return TrainingsTable.aliases() + TrainingDefinitionsTable.aliases()

    @property
    def count_column(self) -> str:
        """The column used to count trainings (presumably by the base class)."""
        return TrainingsTable.column("id")

    def filter_by_id(self, id_: TrainingIdentifier) -> "TrainingQuery":
        """Keep only the training with the given id."""
        self._query.and_where(TrainingsTable.field("id").eq(id_.value))
        return self

    def filter_by_year_month(
        self, year: int, month: int | None = None
    ) -> "TrainingQuery":
        """Keep trainings whose start date falls in the given year (and month).

        Args:
            year: Year the training must start in.
            month: When not None, the month (of ``year``) the training must
                start in.
        """
        # literal() inlines its value into the SQL text instead of binding it
        # as a parameter, so coerce to int: nothing but a number can then end
        # up in the statement.
        condition = criteria(
            "{} = {}",
            func("YEAR", TrainingsTable.column("start_date")),
            literal(int(year)),
        )
        if month is not None:
            condition = condition.and_(
                criteria(
                    "{} = {}",
                    func("MONTH", TrainingsTable.column("start_date")),
                    literal(int(month)),
                )
            )
        # Group the condition so it combines safely with other filters.
        self._query.and_where(group(condition))
        return self

    def filter_by_dates(self, start: datetime, end: datetime) -> "TrainingQuery":
        """Keep trainings whose start date lies between ``start`` and ``end``.

        SQL BETWEEN includes both bounds.
        """
        self._query.and_where(TrainingsTable.field("start_date").between(start, end))
        return self

    def _filter_by_link_table(self, table, column_name: str, value) -> None:
        """Keep trainings whose id is linked to ``value`` in a link table.

        Adds ``trainings.id IN (SELECT training_id FROM table WHERE
        column_name = value)`` to the CTE query.

        Args:
            table: A link table class (e.g. TrainingCoachesTable) that has a
                ``training_id`` column.
            column_name: The column of ``table`` compared with ``value``.
            value: The value ``column_name`` must be equal to.
        """
        inner_select = (
            self._database.create_query_factory()
            .select()
            .columns(table.column("training_id"))
            .from_(table.table_name)
            .where(table.field(column_name).eq(value))
        )
        condition = TrainingsTable.field("id").in_(express("{}", inner_select))
        self._query.and_where(group(condition))

    def filter_by_coach(self, coach: CoachEntity) -> "TrainingQuery":
        """Keep trainings to which the given coach is assigned."""
        self._filter_by_link_table(TrainingCoachesTable, "coach_id", coach.id.value)
        return self

    def filter_by_team(self, team: TeamEntity) -> "TrainingQuery":
        """Keep trainings linked to the given team."""
        self._filter_by_link_table(TrainingTeamsTable, "team_id", team.id.value)
        return self

    def filter_by_definition(
        self, definition: TrainingDefinitionEntity
    ) -> "TrainingQuery":
        """Keep trainings that are based on the given training definition."""
        self._query.and_where(
            TrainingsTable.field("definition_id").eq(definition.id.value)
        )
        return self

    def filter_active(self) -> "TrainingQuery":
        """Keep only trainings whose ``active`` column equals 1."""
        self._query.and_where(TrainingsTable.field("active").eq(1))
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Execute the main query and return its rows.

        Args:
            limit: Maximum number of trainings to return.
            offset: Number of trainings to skip.

        Returns:
            The rows produced by the database for the main query.
        """
        # Limit/offset go on the CTE, which selects one row per training; the
        # joins of the main query can yield several rows per training, so
        # limiting there would cut trainings in half.
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(TrainingsTable.column("id"))
        self._main_query.order_by(TrainingsTable.column("id"))

        return self._database.fetch(self._main_query)

    def order_by_date(self) -> "TrainingQuery":
        """Order trainings by start date, most recent first."""
        # The CTE must be ordered so limit/offset pick the right trainings;
        # the main query must be ordered so the rows come back in that order.
        self._query.order_by(TrainingsTable.column("start_date"), "DESC")
        self._main_query.order_by(TrainingsTable.column("start_date"), "DESC")
        return self