Coverage for kwai/modules/news/stories/story_db_query.py: 91%
64 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module that implements a StoryQuery for a database."""
from datetime import datetime, timezone
from typing import Any, AsyncIterator

from sql_smith.functions import criteria, express, func, group, literal, on

from kwai.core.db.database import Database
from kwai.core.db.database_query import DatabaseQuery
from kwai.core.db.rows import OwnersTable
from kwai.core.domain.value_objects.unique_id import UniqueId
from kwai.modules.news.stories.story import StoryIdentifier
from kwai.modules.news.stories.story_query import StoryQuery
from kwai.modules.news.stories.story_tables import (
    ApplicationsTable,
    StoriesTable,
    StoryContentsTable,
)
class StoryDbQuery(StoryQuery, DatabaseQuery):
    """A database query for news stories.

    Two queries are built:

    * ``self._query`` (provided by DatabaseQuery) selects only story ids and is
      used as the ``limited`` CTE. Most filters, the ordering, and the
      limit/offset are applied to it, so that only the tables needed for
      counting and limiting are joined.
    * ``self._main_query`` right-joins that CTE and joins applications, story
      contents and owners to fetch the full rows.
    """

    def __init__(self, database: Database):
        # _main_query must exist before the base initializer runs; presumably
        # DatabaseQuery.__init__ calls self.init(), which extends it.
        # TODO confirm against DatabaseQuery.
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    @staticmethod
    def _now() -> str:
        """Return the current UTC time formatted as 'YYYY-MM-DD HH:MM:SS'."""
        # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
        # datetime formats to exactly the same string.
        return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

    def init(self):
        """Set up the CTE query and the main query that joins it."""
        # This query will be used as CTE, so only join the tables that are needed
        # for counting and limiting the results.
        self._query.from_(StoriesTable.table_name).join(
            ApplicationsTable.table_name,
            on(
                ApplicationsTable.column("id"),
                StoriesTable.column("application_id"),
            ),
        )

        self._main_query = (
            self._main_query.from_(StoriesTable.table_name)
            .columns(*(self.columns + StoryContentsTable.aliases()))
            .with_("limited", self._query)
            # Only stories selected by the (filtered / limited) CTE are kept.
            .right_join("limited", on("limited.id", StoriesTable.column("id")))
            .join(
                ApplicationsTable.table_name,
                on(
                    ApplicationsTable.column("id"),
                    StoriesTable.column("application_id"),
                ),
            )
            .join(
                StoryContentsTable.table_name,
                on(StoryContentsTable.column("news_id"), StoriesTable.column("id")),
            )
            .join(
                OwnersTable.table_name,
                on(OwnersTable.column("id"), StoryContentsTable.column("user_id")),
            )
        )

    @property
    def columns(self):
        """Aliased story, application and owner columns of the main query."""
        return (
            StoriesTable.aliases() + ApplicationsTable.aliases() + OwnersTable.aliases()
        )

    @property
    def count_column(self) -> str:
        """Column used when counting results (the story id)."""
        return StoriesTable.column("id")

    def filter_by_id(self, id_: StoryIdentifier) -> "StoryQuery":
        """Keep only the story with the given id."""
        self._query.and_where(StoriesTable.field("id").eq(id_.value))
        return self

    def filter_by_publication_date(
        self, year: int, month: int | None = None
    ) -> "StoryQuery":
        """Keep stories published in the given year (and month, when given)."""
        condition = criteria(
            "{} = {}", func("YEAR", StoriesTable.column("publish_date")), literal(year)
        )
        if month is not None:
            # and_() returns the combined criteria instead of modifying
            # `condition` in place, so the result must be reassigned; otherwise
            # the month filter is silently dropped.
            condition = condition.and_(
                criteria(
                    "{} = {}",
                    func("MONTH", StoriesTable.column("publish_date")),
                    literal(month),
                )
            )
        self._query.and_where(condition)
        return self

    def filter_by_promoted(self) -> "StoryQuery":
        """Keep promoted stories whose promotion has not ended yet.

        A story is promoted when its promotion value is greater than 0 and its
        promotion end date is either unset or in the future. Results are
        ordered by promotion.
        """
        now = self._now()
        condition = (
            StoriesTable.field("promotion")
            .gt(0)
            .and_(
                group(
                    StoriesTable.field("promotion_end_date")
                    .is_null()
                    .or_(StoriesTable.field("promotion_end_date").gt(now))
                )
            )
        )
        self._query.and_where(condition)
        # Order both queries: the CTE order decides which rows survive the
        # limit; the main query order decides the order of the fetched rows.
        self._main_query.order_by(StoriesTable.column("promotion"))
        self._query.order_by(StoriesTable.column("promotion"))
        return self

    def filter_by_application(self, application: int | str) -> "StoryQuery":
        """Keep stories of an application, given by name (str) or id (int)."""
        if isinstance(application, str):
            self._query.and_where(ApplicationsTable.field("name").eq(application))
        else:
            self._query.and_where(ApplicationsTable.field("id").eq(application))

        return self

    def filter_by_active(self) -> "StoryQuery":
        """Keep stories that are currently active.

        A story is active when it is enabled, its publish date is not in the
        future, and its end date is either unset or still in the future.
        """
        now = self._now()
        # The end-date condition is grouped and AND-ed. Previously it was
        # OR-ed without grouping the first part, so (because AND binds tighter
        # than OR) a disabled story with a future end date matched, and a past
        # end date never hid an enabled, published story.
        self._query.and_where(
            group(
                StoriesTable.field("enabled")
                .eq(True)
                .and_(StoriesTable.field("publish_date").lte(now))
                .and_(
                    group(
                        StoriesTable.field("end_date")
                        .is_null()
                        .or_(StoriesTable.field("end_date").gt(now))
                    )
                )
            )
        )

        return self

    def filter_by_user(self, user: int | UniqueId) -> "StoryQuery":
        """Keep stories whose content was written by the given user.

        The user is given either by its UniqueId (matched on the owner uuid) or
        by its numeric id.
        """
        inner_select = (
            self._database.create_query_factory()
            .select(OwnersTable.column("id"))
            .from_(OwnersTable.table_name)
        )
        if isinstance(user, UniqueId):
            inner_select.where(OwnersTable.field("uuid").eq(str(user)))
        else:
            inner_select.where(OwnersTable.field("id").eq(user))

        # NOTE(review): this filter is applied to the main query only, while
        # fetch() applies limit/offset to the CTE (self._query). A page may
        # therefore contain fewer rows than `limit`. Confirm whether that is
        # intended.
        self._main_query.and_where(
            group(StoryContentsTable.field("user_id").in_(express("%s", inner_select)))
        )
        return self

    def order_by_publication_date(self) -> "StoryQuery":
        """Order stories by publish date, newest first."""
        self._main_query.order_by(StoriesTable.column("publish_date"), "DESC")
        # Also add the order to the CTE
        self._query.order_by(StoriesTable.column("publish_date"), "DESC")
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Fetch the stories.

        limit and offset are applied to the CTE, which selects only story ids.
        The main query is additionally ordered by story id. Note that this
        call mutates both queries.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(StoriesTable.column("id"))
        self._main_query.order_by(StoriesTable.column("id"))

        return self._database.fetch(self._main_query)