Coverage for src/kwai/modules/portal/news/news_item_db_query.py: 91%

65 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements a NewsItemQuery for a database.""" 

2 

from typing import Any, AsyncIterator

4 

5from sql_smith.functions import criteria, express, func, group, literal, on 

6 

7from kwai.core.db.database import Database 

8from kwai.core.db.database_query import DatabaseQuery 

9from kwai.core.db.rows import OwnersTable 

10from kwai.core.domain.value_objects.timestamp import Timestamp 

11from kwai.core.domain.value_objects.unique_id import UniqueId 

12from kwai.modules.portal.applications.application_tables import ApplicationsTable 

13from kwai.modules.portal.news.news_item import NewsItemIdentifier 

14from kwai.modules.portal.news.news_item_query import NewsItemQuery 

15from kwai.modules.portal.news.news_tables import ( 

16 NewsItemsTable, 

17 NewsItemTextsTable, 

18) 

19 

20 

class NewsItemDbQuery(NewsItemQuery, DatabaseQuery):
    """A database query for news stories.

    The query is built from two parts:

    * ``self._query`` is rendered as the CTE named ``limited``. It only joins
      the tables needed for filtering, counting and limiting news items, and
      ``fetch`` reduces its columns to the news item id.
    * ``self._main_query`` selects the news item, application, owner and text
      columns and right-joins the ``limited`` CTE on the news item id.

    Because limit/offset (and, per the comment in ``init``, counting) work on
    the CTE, any filter that must influence pagination has to be applied to
    ``self._query``.
    """

    def __init__(self, database: Database):
        """Create the query.

        Args:
            database: The database used to create and execute the queries.
        """
        # The main query must exist before the base initializer runs.
        # NOTE(review): presumably DatabaseQuery.__init__ creates self._query
        # and calls init(), which uses self._main_query — confirm.
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self):
        """Set up the joins of the CTE and of the main query."""
        # This query will be used as CTE, so only join the tables that are
        # needed for counting and limiting the results.
        self._query.from_(NewsItemsTable.table_name).join(
            ApplicationsTable.table_name,
            on(
                ApplicationsTable.column("id"),
                NewsItemsTable.column("application_id"),
            ),
        )

        # The main query returns the full rows, but only for the news items
        # selected by the "limited" CTE. It yields one row per news item text
        # because of the join with the texts table.
        self._main_query = (
            self._main_query.from_(NewsItemsTable.table_name)
            .columns(*(self.columns + NewsItemTextsTable.aliases()))
            .with_("limited", self._query)
            .right_join("limited", on("limited.id", NewsItemsTable.column("id")))
            .join(
                ApplicationsTable.table_name,
                on(
                    ApplicationsTable.column("id"),
                    NewsItemsTable.column("application_id"),
                ),
            )
            .join(
                NewsItemTextsTable.table_name,
                on(NewsItemTextsTable.column("news_id"), NewsItemsTable.column("id")),
            )
            .join(
                OwnersTable.table_name,
                on(OwnersTable.column("id"), NewsItemTextsTable.column("user_id")),
            )
        )

    @property
    def columns(self):
        """Return the aliased news item, application and owner columns."""
        return (
            NewsItemsTable.aliases()
            + ApplicationsTable.aliases()
            + OwnersTable.aliases()
        )

    @property
    def count_column(self) -> str:
        """Return the column used for counting news items."""
        return NewsItemsTable.column("id")

    def filter_by_id(self, id_: NewsItemIdentifier) -> "NewsItemQuery":
        """Only select the news item with the given id."""
        self._query.and_where(NewsItemsTable.field("id").eq(id_.value))
        return self

    def filter_by_publication_date(
        self, year: int, month: int | None = None
    ) -> "NewsItemQuery":
        """Only select news items published in the given year (and month).

        Args:
            year: The year of the publication date.
            month: When given, the month of the publication date must match too.
        """
        condition = criteria(
            "{} = {}",
            func("YEAR", NewsItemsTable.column("publish_date")),
            literal(year),
        )
        if month is not None:
            # and_ returns a new criteria; it must be assigned, otherwise the
            # month condition is lost.
            condition = condition.and_(
                criteria(
                    "{} = {}",
                    func("MONTH", NewsItemsTable.column("publish_date")),
                    literal(month),
                )
            )
        self._query.and_where(condition)
        return self

    def filter_by_promoted(self) -> "NewsItemQuery":
        """Only select promoted news items whose promotion has not ended.

        The result is ordered by the promotion value.
        """
        now = str(Timestamp.create_now())
        condition = (
            NewsItemsTable.field("promotion")
            .gt(0)
            .and_(
                group(
                    NewsItemsTable.field("promotion_end_date")
                    .is_null()
                    .or_(NewsItemsTable.field("promotion_end_date").gt(now))
                )
            )
        )
        self._query.and_where(condition)
        # The CTE order decides which rows survive limit/offset; the main query
        # order decides the order of the returned rows. Both are needed.
        self._main_query.order_by(NewsItemsTable.column("promotion"))
        self._query.order_by(NewsItemsTable.column("promotion"))
        return self

    def filter_by_application(self, application: int | str) -> "NewsItemQuery":
        """Only select news items of an application.

        Args:
            application: The application name (str) or the application id (int).
        """
        if isinstance(application, str):
            self._query.and_where(ApplicationsTable.field("name").eq(application))
        else:
            self._query.and_where(ApplicationsTable.field("id").eq(application))

        return self

    def filter_by_active(self) -> "NewsItemQuery":
        """Only select active news items.

        A news item is active when it is enabled, its publish date has passed,
        and it has no end date or its end date is still in the future.
        """
        now = str(Timestamp.create_now())
        self._query.and_where(
            group(
                NewsItemsTable.field("enabled")
                .eq(True)
                .and_(NewsItemsTable.field("publish_date").lte(now))
                .and_(
                    # Grouped so the OR does not escape the AND chain above.
                    group(
                        NewsItemsTable.field("end_date")
                        .is_null()
                        .or_(NewsItemsTable.field("end_date").gt(now))
                    )
                )
            )
        )

        return self

    def filter_by_user(self, user: int | UniqueId) -> "NewsItemQuery":
        """Only select news items (and texts) written by the given user.

        Args:
            user: The owner id (int) or the owner uuid (UniqueId).
        """
        factory = self._database.create_query_factory()
        owner_ids = factory.select(OwnersTable.column("id")).from_(
            OwnersTable.table_name
        )
        if isinstance(user, UniqueId):
            owner_ids.where(OwnersTable.field("uuid").eq(str(user)))
        else:
            owner_ids.where(OwnersTable.field("id").eq(user))

        # Restrict the CTE too, so limit/offset and counting take the user
        # into account. The CTE has no texts join, so use a news id subquery.
        news_ids = (
            factory.select(NewsItemTextsTable.column("news_id"))
            .from_(NewsItemTextsTable.table_name)
            .where(
                NewsItemTextsTable.field("user_id").in_(express("{}", owner_ids))
            )
        )
        self._query.and_where(NewsItemsTable.field("id").in_(express("{}", news_ids)))

        # Keep only the texts written by the user in the returned rows.
        # Patterns use "{}" placeholders, like the criteria in this module.
        self._main_query.and_where(
            group(NewsItemTextsTable.field("user_id").in_(express("{}", owner_ids)))
        )
        return self

    def order_by_publication_date(self) -> "NewsItemQuery":
        """Order the news items by publication date, newest first."""
        self._main_query.order_by(NewsItemsTable.column("publish_date"), "DESC")
        # Also add the order to the CTE, so limit/offset select the newest items.
        self._query.order_by(NewsItemsTable.column("publish_date"), "DESC")
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Fetch the rows of the query.

        Limit and offset are applied to the CTE, so they count news items, not
        returned rows (the main query returns one row per news item text).

        Args:
            limit: The maximum number of news items.
            offset: The number of news items to skip.

        Returns:
            An async iterator over the rows returned by the database.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(NewsItemsTable.column("id"))
        # Use the (unique) id as last sort key of the CTE so paging with
        # limit/offset is deterministic when earlier sort keys tie.
        self._query.order_by(NewsItemsTable.column("id"))
        self._main_query.order_by(NewsItemsTable.column("id"))

        return self._database.fetch(self._main_query)