Coverage for src/kwai/modules/portal/news/news_item_db_repository.py: 90%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements a news item repository for a database.""" 

2 

3from typing import Any, AsyncIterator 

4 

5from sql_smith.functions import field 

6 

7from kwai.core.db.database import Database 

8from kwai.core.db.rows import OwnersTable 

9from kwai.core.domain.entity import Entity 

10from kwai.modules.portal.applications.application_tables import ApplicationsTable 

11from kwai.modules.portal.news.news_item import NewsItemEntity, NewsItemIdentifier 

12from kwai.modules.portal.news.news_item_db_query import NewsItemDbQuery 

13from kwai.modules.portal.news.news_item_query import NewsItemQuery 

14from kwai.modules.portal.news.news_item_repository import ( 

15 NewsItemNotFoundException, 

16 NewsItemRepository, 

17) 

18from kwai.modules.portal.news.news_tables import ( 

19 NewsItemRow, 

20 NewsItemsTable, 

21 NewsItemTextRow, 

22 NewsItemTextsTable, 

23) 

24 

25 

def _create_entity(rows: list[dict[str, Any]]) -> NewsItemEntity:
    """Build one news item entity from all rows that belong to it.

    The news item and its application are read from the first row. Every row
    contributes one text, whose author is read from the owner columns of that
    same row.

    Args:
        rows: the rows of a single news item; at least one row is required.

    Returns:
        The news item entity with its application and texts.
    """
    first_row = rows[0]
    application = ApplicationsTable(first_row).create_entity()
    texts = [
        NewsItemTextsTable(text_row).create_text(
            author=OwnersTable(text_row).create_owner()
        )
        for text_row in rows
    ]
    return NewsItemsTable(first_row).create_entity(application, texts)

35 

36 

class NewsItemDbRepository(NewsItemRepository):
    """A news item database repository.

    News items are stored in the news items table. Their texts are stored as
    separate rows in the news item texts table, matched on the `news_id` column.

    Attributes:
        _database: the database for the repository.
    """

    def __init__(self, database: Database):
        self._database = database

    async def create(self, news_item: NewsItemEntity) -> NewsItemEntity:
        """Insert a news item and its texts, then commit.

        Args:
            news_item: the news item to save.

        Returns:
            The news item with its id replaced by the identifier returned from
            the insert.
        """
        new_id = await self._database.insert(
            NewsItemsTable.table_name, NewsItemRow.persist(news_item)
        )
        result = Entity.replace(news_item, id_=NewsItemIdentifier(new_id))

        # Persist the texts against `result`, which carries the new id.
        text_rows = [
            NewsItemTextRow.persist(result, content) for content in news_item.texts
        ]
        await self._insert_text_rows(text_rows)

        await self._database.commit()
        return result

    async def update(self, news_item: NewsItemEntity) -> None:
        """Update a news item and replace all of its texts, then commit.

        All existing text rows of the news item are deleted and the current
        texts are inserted again. Texts removed from the entity therefore
        disappear from the database.

        Args:
            news_item: the news item to update.
        """
        await self._database.update(
            news_item.id.value,
            NewsItemsTable.table_name,
            NewsItemRow.persist(news_item),
        )

        await self._delete_texts(news_item)

        text_rows = [
            NewsItemTextRow.persist(news_item, content) for content in news_item.texts
        ]
        await self._insert_text_rows(text_rows)
        await self._database.commit()

    async def delete(self, news_item: NewsItemEntity) -> None:
        """Delete a news item together with its texts, then commit.

        Args:
            news_item: the news item to delete.
        """
        # Remove the texts (matched on `news_id`) before the news item itself.
        await self._delete_texts(news_item)
        await self._database.delete(news_item.id.value, NewsItemsTable.table_name)
        await self._database.commit()

    async def _delete_texts(self, news_item: NewsItemEntity) -> None:
        """Delete every text row whose `news_id` matches the news item's id."""
        delete_texts_query = (
            self._database.create_query_factory()
            .delete(NewsItemTextsTable.table_name)
            .where(field("news_id").eq(news_item.id.value))
        )
        await self._database.execute(delete_texts_query)

    async def _insert_text_rows(self, text_rows: list[NewsItemTextRow]) -> None:
        """Insert the given text rows; do nothing when the list is empty.

        The guard means a news item without texts never calls
        `Database.insert` with an empty row list.
        """
        if text_rows:
            await self._database.insert(NewsItemTextsTable.table_name, *text_rows)

    def create_query(self) -> NewsItemQuery:
        """Create a database query for news items."""
        return NewsItemDbQuery(self._database)

    async def get_by_id(self, id_: NewsItemIdentifier) -> NewsItemEntity:
        """Get the news item with the given identifier.

        Args:
            id_: the identifier of the news item.

        Returns:
            The news item entity.

        Raises:
            NewsItemNotFoundException: when no news item with this id exists.
        """
        query = self.create_query()
        query.filter_by_id(id_)

        entity = await anext(self.get_all(query, 1), None)
        if entity is None:
            raise NewsItemNotFoundException(f"News item with {id_} does not exist.")

        return entity

    async def get_all(
        self,
        query: NewsItemQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[NewsItemEntity]:
        """Yield the news items selected by a query.

        Each fetched row carries one text of a news item. Consecutive rows
        with the same news item id are collected into a group, and each group
        becomes one entity.

        Args:
            query: the query to use; when None, a new query from
                `create_query` is used.
            limit: passed unchanged to `query.fetch`.
            offset: passed unchanged to `query.fetch`.

        Yields:
            News item entities, one per group of rows.
        """
        if query is None:
            query = self.create_query()

        group_by_column = NewsItemsTable.alias_name("id")

        # NOTE(review): grouping is only correct when the query returns all
        # rows of one news item contiguously (e.g. ordered by id) — confirm in
        # NewsItemDbQuery.
        row_it = query.fetch(limit, offset)

        # Handle the first row; no rows means no news items.
        try:
            row = await anext(row_it)
        except StopAsyncIteration:
            return

        group = [row]
        current_key = row[group_by_column]

        # Process all other rows; a key change closes the current group.
        async for row in row_it:
            new_key = row[group_by_column]
            if new_key != current_key:
                yield _create_entity(group)
                group = [row]
                current_key = new_key
            else:
                group.append(row)

        # The last group is never closed by a key change, so emit it here.
        yield _create_entity(group)