Coverage for src/kwai/modules/portal/pages/page_db_query.py: 90%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements a PageQuery for a database.""" 

2 

from typing import Any, AsyncIterator

from sql_smith.functions import express, group, on

from kwai.core.db.database import Database
from kwai.core.db.database_query import DatabaseQuery
from kwai.core.db.rows import OwnersTable
from kwai.core.domain.value_objects.unique_id import UniqueId
from kwai.modules.portal.applications.application_tables import ApplicationsTable
from kwai.modules.portal.pages.page import PageIdentifier
from kwai.modules.portal.pages.page_query import PageQuery
from kwai.modules.portal.pages.page_tables import (
    PageContentsTable,
    PagesTable,
)

18 

19 

class PageDbQuery(PageQuery, DatabaseQuery):
    """A database query for pages.

    Two queries are built:

    * ``self._query`` selects from pages joined with their application.
      Filters that narrow the set of pages are added to it. ``fetch``
      restricts its columns to the page id and applies limit/offset to it.
    * ``self._main_query`` uses ``self._query`` as the ``limited`` CTE. It
      joins the selected pages with their application, contents and content
      owners to produce the rows returned by ``fetch``.
    """

    def __init__(self, database: Database):
        # The main query must exist before super().__init__() runs, because
        # init() (presumably invoked from DatabaseQuery.__init__) extends it.
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self):
        """Set up the page query and the main query that wraps it."""
        self._query.from_(PagesTable.table_name).join(
            ApplicationsTable.table_name,
            on(
                ApplicationsTable.column("id"),
                PagesTable.column("application_id"),
            ),
        )
        self._main_query = (
            self._main_query.from_(PagesTable.table_name)
            .columns(*(self.columns + PageContentsTable.aliases()))
            # fetch() mutates self._query after it is registered here as the
            # CTE; this relies on the builder keeping a reference, not a copy.
            .with_("limited", self._query)
            .right_join("limited", on("limited.id", PagesTable.column("id")))
            .join(
                ApplicationsTable.table_name,
                on(
                    ApplicationsTable.column("id"),
                    PagesTable.column("application_id"),
                ),
            )
            .join(
                PageContentsTable.table_name,
                on(PageContentsTable.column("page_id"), PagesTable.column("id")),
            )
            .join(
                OwnersTable.table_name,
                on(OwnersTable.column("id"), PageContentsTable.column("user_id")),
            )
        )

    @property
    def columns(self):
        """Aliased columns of the pages, applications and owners tables."""
        return (
            PagesTable.aliases() + ApplicationsTable.aliases() + OwnersTable.aliases()
        )

    @property
    def count_column(self) -> str:
        """The column used for counting: the page id."""
        return PagesTable.column("id")

    def filter_by_id(self, id_: PageIdentifier) -> "PageQuery":
        """Restrict the query to the page with the given id.

        Args:
            id_: The identifier of the page.

        Returns:
            This query, for chaining.
        """
        self._query.and_where(PagesTable.field("id").eq(id_.value))
        return self

    def filter_by_application(self, application: int | str) -> "PageQuery":
        """Restrict the query to pages of an application.

        Args:
            application: The application name (str) or application id (int).

        Returns:
            This query, for chaining.
        """
        if isinstance(application, str):
            self._query.and_where(ApplicationsTable.field("name").eq(application))
        else:
            self._query.and_where(ApplicationsTable.field("id").eq(application))
        return self

    def filter_by_active(self) -> "PageQuery":
        """Restrict the query to enabled pages (``enabled = 1``).

        Returns:
            This query, for chaining.
        """
        self._query.and_where(PagesTable.field("enabled").eq(1))
        return self

    def filter_by_user(self, user: int | UniqueId) -> "PageQuery":
        """Restrict the returned content rows to those owned by a user.

        Args:
            user: The owner id (int) or the owner's unique id (UniqueId).

        Returns:
            This query, for chaining.
        """
        inner_select = (
            self._database.create_query_factory()
            .select(OwnersTable.column("id"))
            .from_(OwnersTable.table_name)
        )
        if isinstance(user, UniqueId):
            inner_select.where(OwnersTable.field("uuid").eq(str(user)))
        else:
            inner_select.where(OwnersTable.field("id").eq(user))

        # NOTE(review): this condition is added to the main query, not to
        # self._query. It is evaluated after limit/offset have picked the page
        # ids, so a window can yield fewer pages than `limit`. Confirm that
        # this is intended.
        self._main_query.and_where(
            group(PageContentsTable.field("user_id").in_(express("%s", inner_select)))
        )
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Fetch the matching pages.

        limit/offset are applied to the page ids in the ``limited`` CTE, so
        they count pages, not content rows.

        Args:
            limit: Maximum number of pages. The value is passed unchanged to
                the query builder (presumably None means no limit).
            offset: Number of pages to skip. The value is passed unchanged to
                the query builder.

        Returns:
            An async iterator over the rows of the main query. Each row
            combines a page with its application, one of its contents and
            the owner of that content.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(PagesTable.column("id"))
        # LIMIT/OFFSET without ORDER BY selects an arbitrary subset of rows.
        # Order the page ids the same way as the final result so that
        # consecutive windows are stable and do not overlap.
        self._query.order_by(PagesTable.column("priority"))
        self._query.order_by(PagesTable.column("id"))
        self._main_query.order_by(PagesTable.column("priority"))
        self._main_query.order_by(PagesTable.column("id"))

        return self._database.fetch(self._main_query)