Coverage for kwai/modules/news/stories/story_db_query.py: 91%

64 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module that implements a StoryQuery for a database.""" 

from datetime import datetime, timezone
from typing import Any, AsyncIterator

from sql_smith.functions import criteria, express, func, group, literal, on

from kwai.core.db.database import Database
from kwai.core.db.database_query import DatabaseQuery
from kwai.core.db.rows import OwnersTable
from kwai.core.domain.value_objects.unique_id import UniqueId
from kwai.modules.news.stories.story import StoryIdentifier
from kwai.modules.news.stories.story_query import StoryQuery
from kwai.modules.news.stories.story_tables import (
    ApplicationsTable,
    StoriesTable,
    StoryContentsTable,
)

18 

19 

class StoryDbQuery(StoryQuery, DatabaseQuery):
    """A database query for news stories.

    Two queries are built:

    * ``self._query`` selects the stories that match the filters. It joins only
      the tables needed to filter, and it receives the LIMIT/OFFSET in
      :meth:`fetch`.
    * ``self._main_query`` uses ``self._query`` as the CTE ``limited`` and joins
      it back to the stories, applications, contents and owners tables to select
      all columns of the selected stories.
    """

    def __init__(self, database: Database):
        """Create the query.

        Args:
            database: The database used to build and execute the queries.
        """
        # _main_query must exist before the base initializer runs, because
        # init() extends it. NOTE(review): assumes DatabaseQuery.__init__ calls
        # init() — confirm.
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self):
        """Set up the FROM and JOIN clauses of the CTE and the main query."""
        # This query will be used as CTE, so only join the tables that are needed
        # for counting and limiting the results.
        self._query.from_(StoriesTable.table_name).join(
            ApplicationsTable.table_name,
            on(
                ApplicationsTable.column("id"),
                StoriesTable.column("application_id"),
            ),
        )

        # The main query selects every column, restricted to the story ids
        # returned by the "limited" CTE.
        self._main_query = (
            self._main_query.from_(StoriesTable.table_name)
            .columns(*(self.columns + StoryContentsTable.aliases()))
            .with_("limited", self._query)
            .right_join("limited", on("limited.id", StoriesTable.column("id")))
            .join(
                ApplicationsTable.table_name,
                on(
                    ApplicationsTable.column("id"),
                    StoriesTable.column("application_id"),
                ),
            )
            .join(
                StoryContentsTable.table_name,
                on(StoryContentsTable.column("news_id"), StoriesTable.column("id")),
            )
            .join(
                OwnersTable.table_name,
                on(OwnersTable.column("id"), StoryContentsTable.column("user_id")),
            )
        )

    @property
    def columns(self):
        """Return the aliased columns of the stories, applications and owners tables."""
        return (
            StoriesTable.aliases() + ApplicationsTable.aliases() + OwnersTable.aliases()
        )

    @property
    def count_column(self) -> str:
        """Return the column used to count the stories (the story id)."""
        return StoriesTable.column("id")

    @staticmethod
    def _now() -> str:
        """Return the current UTC time formatted as 'YYYY-MM-DD HH:MM:SS'."""
        # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
        # datetime produces the same formatted string.
        return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

    def filter_by_id(self, id_: StoryIdentifier) -> "StoryQuery":
        """Select only the story with the given id.

        Args:
            id_: The identifier of the story.
        """
        self._query.and_where(StoriesTable.field("id").eq(id_.value))
        return self

    def filter_by_publication_date(
        self, year: int, month: int | None = None
    ) -> "StoryQuery":
        """Select stories published in the given year and, optionally, month.

        Args:
            year: Only stories whose publish date falls in this year match.
            month: When given, the publish date must also fall in this month.
        """
        condition = criteria(
            "{} = {}", func("YEAR", StoriesTable.column("publish_date")), literal(year)
        )
        if month is not None:
            # and_() returns a new criteria instead of modifying `condition`;
            # without the assignment the month filter would be dropped.
            condition = condition.and_(
                criteria(
                    "{} = {}",
                    func("MONTH", StoriesTable.column("publish_date")),
                    literal(month),
                )
            )
        self._query.and_where(condition)
        return self

    def filter_by_promoted(self) -> "StoryQuery":
        """Select promoted stories whose promotion has not ended yet.

        The CTE is ordered on the promotion column.
        """
        now = self._now()
        condition = (
            StoriesTable.field("promotion")
            .gt(0)
            .and_(
                group(
                    StoriesTable.field("promotion_end_date")
                    .is_null()
                    .or_(StoriesTable.field("promotion_end_date").gt(now))
                )
            )
        )
        self._query.and_where(condition)
        self._query.order_by(StoriesTable.column("promotion"))
        return self

    def filter_by_application(self, application: int | str) -> "StoryQuery":
        """Select stories of an application.

        Args:
            application: The application name (str) or the application id (int).
        """
        if isinstance(application, str):
            self._query.and_where(ApplicationsTable.field("name").eq(application))
        else:
            self._query.and_where(ApplicationsTable.field("id").eq(application))

        return self

    def filter_by_active(self) -> "StoryQuery":
        """Select stories that are enabled, published and not yet ended.

        A story is active when it is enabled, its publish date is not in the
        future, and it has no end date or its end date is in the future.
        """
        now = self._now()
        self._query.and_where(
            group(
                StoriesTable.field("enabled")
                .eq(True)
                .and_(StoriesTable.field("publish_date").lte(now))
                .and_(
                    # Same "no end date or not ended yet" pattern as used for
                    # promotion_end_date in filter_by_promoted.
                    group(
                        StoriesTable.field("end_date")
                        .is_null()
                        .or_(StoriesTable.field("end_date").gt(now))
                    )
                )
            )
        )

        return self

    def filter_by_user(self, user: int | UniqueId) -> "StoryQuery":
        """Select stories whose content was written by the given user.

        Args:
            user: The owner's uuid (UniqueId) or the owner's id (int).
        """
        inner_select = (
            self._database.create_query_factory()
            .select(OwnersTable.column("id"))
            .from_(OwnersTable.table_name)
        )
        if isinstance(user, UniqueId):
            inner_select.where(OwnersTable.field("uuid").eq(str(user)))
        else:
            inner_select.where(OwnersTable.field("id").eq(user))

        # The contents table is only joined in the main query, so this filter
        # is applied there. NOTE(review): the CTE's LIMIT/OFFSET (see fetch)
        # is applied before this filter, so pages may contain fewer rows than
        # the limit — confirm whether that is intended.
        self._main_query.and_where(
            group(StoryContentsTable.field("user_id").in_(express("%s", inner_select)))
        )
        return self

    def order_by_publication_date(self) -> "StoryQuery":
        """Order the stories by publish date, newest first."""
        self._main_query.order_by(StoriesTable.column("publish_date"), "DESC")
        # Also add the order to the CTE, so the limited rows are the newest ones.
        self._query.order_by(StoriesTable.column("publish_date"), "DESC")
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Execute the main query and return its rows.

        Args:
            limit: Maximum number of stories; applied to the CTE.
            offset: Number of stories to skip; applied to the CTE.

        Returns:
            The result of ``Database.fetch`` for the main query.
        """
        # LIMIT/OFFSET go on the CTE so they apply to stories, not to the rows
        # produced by the joins of the main query.
        self._query.limit(limit)
        self._query.offset(offset)
        # The CTE only needs the story id, which the main query joins on.
        self._query.columns(StoriesTable.column("id"))
        self._main_query.order_by(StoriesTable.column("id"))

        return self._database.fetch(self._main_query)