Coverage for src/kwai/modules/portal/news/news_item_db_query.py: 91%
65 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that implements a NewsItemQuery for a database."""
from typing import Any, AsyncIterator

from sql_smith.functions import criteria, express, func, group, literal, on

from kwai.core.db.database import Database
from kwai.core.db.database_query import DatabaseQuery
from kwai.core.db.rows import OwnersTable
from kwai.core.domain.value_objects.timestamp import Timestamp
from kwai.core.domain.value_objects.unique_id import UniqueId
from kwai.modules.portal.applications.application_tables import ApplicationsTable
from kwai.modules.portal.news.news_item import NewsItemIdentifier
from kwai.modules.portal.news.news_item_query import NewsItemQuery
from kwai.modules.portal.news.news_tables import (
    NewsItemsTable,
    NewsItemTextsTable,
)
class NewsItemDbQuery(NewsItemQuery, DatabaseQuery):
    """A database query for news stories.

    The query is built from two parts:

    * ``self._query`` (inherited from DatabaseQuery) only joins the tables that
      are needed to filter, count and limit news items. In ``fetch`` it is
      reduced to the news item id and embedded as a CTE named ``limited``.
    * ``self._main_query`` joins the ``limited`` CTE back to the news items,
      applications, texts and owners tables to load the full rows.

    Because LIMIT/OFFSET are set on the CTE (see ``fetch``), they apply to news
    items before text and owner rows are joined in.
    """

    def __init__(self, database: Database):
        # _main_query must exist before the parent constructor runs: init()
        # extends it, and init() is presumably invoked by DatabaseQuery.__init__
        # (nothing in this class calls it directly).
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self):
        """Set up the CTE query and the main query that joins its rows."""
        # This query will be used as CTE, so only join the tables that are needed
        # for counting and limiting the results.
        self._query.from_(NewsItemsTable.table_name).join(
            ApplicationsTable.table_name,
            on(
                ApplicationsTable.column("id"),
                NewsItemsTable.column("application_id"),
            ),
        )

        self._main_query = (
            self._main_query.from_(NewsItemsTable.table_name)
            .columns(*(self.columns + NewsItemTextsTable.aliases()))
            .with_("limited", self._query)
            # Restrict the main query to the ids selected by the CTE.
            .right_join("limited", on("limited.id", NewsItemsTable.column("id")))
            .join(
                ApplicationsTable.table_name,
                on(
                    ApplicationsTable.column("id"),
                    NewsItemsTable.column("application_id"),
                ),
            )
            .join(
                NewsItemTextsTable.table_name,
                on(NewsItemTextsTable.column("news_id"), NewsItemsTable.column("id")),
            )
            .join(
                OwnersTable.table_name,
                on(OwnersTable.column("id"), NewsItemTextsTable.column("user_id")),
            )
        )

    @property
    def columns(self):
        """Aliased columns of the news item, its application and the owner."""
        return (
            NewsItemsTable.aliases()
            + ApplicationsTable.aliases()
            + OwnersTable.aliases()
        )

    @property
    def count_column(self) -> str:
        """Column used to count the matching news items."""
        return NewsItemsTable.column("id")

    def filter_by_id(self, id_: NewsItemIdentifier) -> "NewsItemQuery":
        """Only keep the news item with the given id."""
        self._query.and_where(NewsItemsTable.field("id").eq(id_.value))
        return self

    def filter_by_publication_date(
        self, year: int, month: int | None = None
    ) -> "NewsItemQuery":
        """Only keep news items published in the given year (and month).

        Args:
            year: Year of the publish date.
            month: Optional month of the publish date; ignored when None.
        """
        condition = criteria(
            "{} = {}",
            func("YEAR", NewsItemsTable.column("publish_date")),
            literal(year),
        )
        if month is not None:
            # and_() returns the combined criteria (its result is used the same
            # way in filter_by_promoted). It must be reassigned; discarding it
            # silently dropped the month filter.
            condition = condition.and_(
                criteria(
                    "{} = {}",
                    func("MONTH", NewsItemsTable.column("publish_date")),
                    literal(month),
                )
            )
        self._query.and_where(condition)
        return self

    def filter_by_promoted(self) -> "NewsItemQuery":
        """Only keep promoted news items whose promotion has not ended.

        Results are ordered by the promotion column.
        """
        now = str(Timestamp.create_now())
        condition = (
            NewsItemsTable.field("promotion")
            .gt(0)
            .and_(
                group(
                    NewsItemsTable.field("promotion_end_date")
                    .is_null()
                    .or_(NewsItemsTable.field("promotion_end_date").gt(now))
                )
            )
        )
        self._query.and_where(condition)
        # The CTE ordering decides which rows survive LIMIT/OFFSET. The outer
        # query does not inherit that order, so order it too, just like
        # order_by_publication_date does.
        self._query.order_by(NewsItemsTable.column("promotion"))
        self._main_query.order_by(NewsItemsTable.column("promotion"))
        return self

    def filter_by_application(self, application: int | str) -> "NewsItemQuery":
        """Only keep news items of an application.

        Args:
            application: The application name (str) or id (int).
        """
        if isinstance(application, str):
            self._query.and_where(ApplicationsTable.field("name").eq(application))
        else:
            self._query.and_where(ApplicationsTable.field("id").eq(application))

        return self

    def filter_by_active(self) -> "NewsItemQuery":
        """Only keep news items that are currently visible.

        A news item is active when it is enabled, its publish date has passed,
        and it has no end date or its end date is still in the future.
        """
        now = str(Timestamp.create_now())
        # The end-date check is grouped and AND-ed. An ungrouped OR (AND binds
        # tighter than OR in SQL) would let any item with a future end date
        # match, even when disabled or unpublished. It mirrors the
        # promotion_end_date check in filter_by_promoted.
        self._query.and_where(
            group(
                NewsItemsTable.field("enabled")
                .eq(True)
                .and_(NewsItemsTable.field("publish_date").lte(now))
                .and_(
                    group(
                        NewsItemsTable.field("end_date")
                        .is_null()
                        .or_(NewsItemsTable.field("end_date").gt(now))
                    )
                )
            )
        )

        return self

    def filter_by_user(self, user: int | UniqueId) -> "NewsItemQuery":
        """Only keep news items whose text is owned by the given user.

        Args:
            user: The owner's uuid (UniqueId) or numeric id (int).
        """
        inner_select = (
            self._database.create_query_factory()
            .select(OwnersTable.column("id"))
            .from_(OwnersTable.table_name)
        )
        if isinstance(user, UniqueId):
            inner_select.where(OwnersTable.field("uuid").eq(str(user)))
        else:
            inner_select.where(OwnersTable.field("id").eq(user))

        # NOTE(review): this condition is added to the main query only. The CTE
        # (self._query), which carries LIMIT/OFFSET in fetch() and is joined
        # with the texts table only here, does not see it. Paging and counting
        # may therefore ignore the user filter; confirm this is intended.
        self._main_query.and_where(
            group(NewsItemTextsTable.field("user_id").in_(express("%s", inner_select)))
        )
        return self

    def order_by_publication_date(self) -> "NewsItemQuery":
        """Order the news items by publish date, newest first."""
        self._main_query.order_by(NewsItemsTable.column("publish_date"), "DESC")
        # Also add the order to the CTE
        self._query.order_by(NewsItemsTable.column("publish_date"), "DESC")
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Fetch the news item rows.

        Limit and offset are applied to the id-only CTE, so they count news
        items. Each call mutates the underlying queries.

        Args:
            limit: Maximum number of news items; None passes None to limit().
            offset: Number of news items to skip; None passes None to offset().

        Returns:
            An async iterator over the rows returned by the database.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        # The CTE only needs the id to join back to the main query.
        self._query.columns(NewsItemsTable.column("id"))
        # Presumably appended after any ordering added earlier, as a final
        # tie-breaker for a deterministic order.
        self._main_query.order_by(NewsItemsTable.column("id"))

        return self._database.fetch(self._main_query)