Coverage for src/kwai/modules/portal/pages/page_db_query.py: 90%
48 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that implements a PageQuery for a database."""
from typing import Any, AsyncIterator

from sql_smith.functions import express, group, on

from kwai.core.db.database import Database
from kwai.core.db.database_query import DatabaseQuery
from kwai.core.db.rows import OwnersTable
from kwai.core.domain.value_objects.unique_id import UniqueId
from kwai.modules.portal.applications.application_tables import ApplicationsTable
from kwai.modules.portal.pages.page import PageIdentifier
from kwai.modules.portal.pages.page_query import PageQuery
from kwai.modules.portal.pages.page_tables import (
    PageContentsTable,
    PagesTable,
)
class PageDbQuery(PageQuery, DatabaseQuery):
    """A database query for pages.

    Two queries are built:

    * ``self._query`` selects pages joined with their application. Most
      filters, and the limit/offset given to :meth:`fetch`, are applied to it.
      In :meth:`fetch` it is reduced to selecting only ``pages.id``.
    * ``self._main_query`` uses ``self._query`` as the CTE ``limited``. It
      joins the selected pages with their application, their content rows
      and the owner of each content row, and produces the rows that
      :meth:`fetch` returns.
    """

    def __init__(self, database: Database):
        # init() reads and replaces self._main_query, so it must exist before
        # the base initializer runs (presumably DatabaseQuery.__init__ is what
        # calls init() - confirm in DatabaseQuery).
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self):
        """Build the base (limited) query and the main query around it."""
        # Base query: pages joined with their application, so filters can
        # refer to columns of both tables.
        self._query.from_(PagesTable.table_name).join(
            ApplicationsTable.table_name,
            on(
                ApplicationsTable.column("id"),
                PagesTable.column("application_id"),
            ),
        )
        # Main query: joining on the "limited" CTE keeps only the pages whose
        # id was selected by the base query; each page is then combined with
        # its application, its content rows and the owner of each content.
        self._main_query = (
            self._main_query.from_(PagesTable.table_name)
            .columns(*(self.columns + PageContentsTable.aliases()))
            .with_("limited", self._query)
            .right_join("limited", on("limited.id", PagesTable.column("id")))
            .join(
                ApplicationsTable.table_name,
                on(
                    ApplicationsTable.column("id"),
                    PagesTable.column("application_id"),
                ),
            )
            .join(
                PageContentsTable.table_name,
                on(PageContentsTable.column("page_id"), PagesTable.column("id")),
            )
            .join(
                OwnersTable.table_name,
                on(OwnersTable.column("id"), PageContentsTable.column("user_id")),
            )
        )

    @property
    def columns(self):
        """Return the page, application and owner column aliases.

        The page contents aliases are appended separately in :meth:`init`.
        """
        return (
            PagesTable.aliases() + ApplicationsTable.aliases() + OwnersTable.aliases()
        )

    @property
    def count_column(self) -> str:
        """Return the column used for counting: the page id."""
        return PagesTable.column("id")

    def filter_by_id(self, id_: PageIdentifier) -> "PageQuery":
        """Only keep the page with the given id.

        Args:
            id_: The identifier of the page.

        Returns:
            This query, so calls can be chained.
        """
        self._query.and_where(PagesTable.field("id").eq(id_.value))
        return self

    def filter_by_application(self, application: int | str) -> "PageQuery":
        """Only keep pages that belong to the given application.

        Args:
            application: The application name (str) or the application id (int).

        Returns:
            This query, so calls can be chained.
        """
        if isinstance(application, str):
            self._query.and_where(ApplicationsTable.field("name").eq(application))
        else:
            self._query.and_where(ApplicationsTable.field("id").eq(application))
        return self

    def filter_by_active(self) -> "PageQuery":
        """Only keep pages whose ``enabled`` column equals 1.

        Returns:
            This query, so calls can be chained.
        """
        self._query.and_where(PagesTable.field("enabled").eq(1))
        return self

    def filter_by_user(self, user: int | UniqueId) -> "PageQuery":
        """Only keep content rows whose ``user_id`` belongs to the given owner.

        Unlike the other filters, this condition is added to the main query,
        because the base query does not join the contents table. It is
        therefore evaluated after the limit/offset of the base query.
        NOTE(review): it is presumably also not reflected in a count based on
        ``self._query`` - confirm this is intended.

        Args:
            user: The owner's id (int) or its unique id (UniqueId).

        Returns:
            This query, so calls can be chained.
        """
        inner_select = (
            self._database.create_query_factory()
            .select(OwnersTable.column("id"))
            .from_(OwnersTable.table_name)
        )
        if isinstance(user, UniqueId):
            inner_select.where(OwnersTable.field("uuid").eq(str(user)))
        else:
            inner_select.where(OwnersTable.field("id").eq(user))

        self._main_query.and_where(
            group(PageContentsTable.field("user_id").in_(express("%s", inner_select)))
        )
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Fetch the pages.

        Limit and offset are applied to the page ids selected by the
        ``limited`` CTE, so they count pages, not result rows.

        Args:
            limit: Maximum number of pages; passed to the query builder as-is
                (``None`` presumably means no limit).
            offset: Number of pages to skip; passed to the query builder as-is.

        Returns:
            An async iterator over the rows of the main query. A page can
            produce more than one row, one per joined content row.
        """
        # Order the limited page ids the same way as the final result.
        # Without an ORDER BY, the set of ids picked by LIMIT/OFFSET is
        # nondeterministic, so consecutive result pages could overlap or
        # skip pages.
        self._query.order_by(PagesTable.column("priority"))
        self._query.order_by(PagesTable.column("id"))
        self._query.limit(limit)
        self._query.offset(offset)
        # The CTE only needs the page id; the main query selects the data.
        self._query.columns(PagesTable.column("id"))
        self._main_query.order_by(PagesTable.column("priority"))
        self._main_query.order_by(PagesTable.column("id"))

        return self._database.fetch(self._main_query)