Coverage for kwai/modules/news/stories/story_db_repository.py: 81%

62 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module that implements a story repository for a database.""" 

2from typing import Any, AsyncIterator 

3 

4from sql_smith.functions import field 

5 

6from kwai.core.db.database import Database 

7from kwai.core.db.rows import OwnersTable 

8from kwai.core.domain.entity import Entity 

9from kwai.modules.news.stories.story import StoryEntity, StoryIdentifier 

10from kwai.modules.news.stories.story_db_query import StoryDbQuery 

11from kwai.modules.news.stories.story_query import StoryQuery 

12from kwai.modules.news.stories.story_repository import ( 

13 StoryNotFoundException, 

14 StoryRepository, 

15) 

16from kwai.modules.news.stories.story_tables import ( 

17 ApplicationsTable, 

18 StoriesTable, 

19 StoryContentRow, 

20 StoryContentsTable, 

21 StoryRow, 

22) 

23 

24 

def _create_entity(rows: list[dict[str, Any]]) -> StoryEntity:
    """Build one story entity out of the rows fetched for a single story.

    The story and application columns are read from the first row only;
    every row (including the first) contributes one content item together
    with its author.

    Args:
        rows: the rows belonging to one story. They are expected to share
            the same story/application columns; must not be empty
            (``rows[0]`` is accessed).

    Returns:
        The story entity assembled from the rows.
    """
    first_row = rows[0]
    application = ApplicationsTable(first_row).create_application()

    contents = []
    for content_row in rows:
        author = OwnersTable(content_row).create_owner()
        contents.append(StoryContentsTable(content_row).create_content(author=author))

    return StoriesTable(first_row).create_entity(application, contents)

36 

37 

class StoryDbRepository(StoryRepository):
    """A story database repository.

    A story is stored as one row in the stories table plus one row per
    content item in the story contents table; content rows are linked to
    their story through the ``news_id`` column.

    Attributes:
        _database: the database for the repository.
    """

    def __init__(self, database: Database):
        self._database = database

    async def _delete_contents(self, story: StoryEntity) -> None:
        """Delete all content rows of the given story (does not commit).

        Shared by update and delete so both build the query the same way:
        the query factory is used synchronously and only ``execute`` is
        awaited (awaiting the query builder itself would fail).
        """
        delete_contents_query = (
            self._database.create_query_factory()
            .delete(StoryContentsTable.table_name)
            .where(field("news_id").eq(story.id.value))
        )
        await self._database.execute(delete_contents_query)

    async def create(self, story: StoryEntity) -> StoryEntity:
        """Save a new story together with its contents, then commit.

        Args:
            story: the story to save.

        Returns:
            The story entity with its id replaced by the identifier
            generated by the insert.
        """
        new_id = await self._database.insert(
            StoriesTable.table_name, StoryRow.persist(story)
        )
        result = Entity.replace(story, id_=StoryIdentifier(new_id))

        # Content rows are persisted with `result`, presumably so that they
        # reference the newly generated story id.
        content_rows = [
            StoryContentRow.persist(result, content) for content in story.content
        ]
        # Only insert when there is something to insert; a story without
        # content should not trigger an insert with zero rows.
        if content_rows:
            await self._database.insert(StoryContentsTable.table_name, *content_rows)

        await self._database.commit()
        return result

    async def update(self, story: StoryEntity) -> None:
        """Update a story and replace all of its contents, then commit.

        The existing content rows are deleted and the story's current
        contents are inserted again.

        Args:
            story: the story to update; ``story.id`` selects the row.
        """
        await self._database.update(
            story.id.value, StoriesTable.table_name, StoryRow.persist(story)
        )

        await self._delete_contents(story)

        content_rows = [
            StoryContentRow.persist(story, content) for content in story.content
        ]
        if content_rows:
            await self._database.insert(StoryContentsTable.table_name, *content_rows)

        await self._database.commit()

    async def delete(self, story: StoryEntity) -> None:
        """Delete a story and all of its contents, then commit.

        Args:
            story: the story to delete; ``story.id`` selects the rows.
        """
        # Remove the content rows (linked via news_id) before the story row.
        await self._delete_contents(story)
        await self._database.delete(story.id.value, StoriesTable.table_name)
        await self._database.commit()

    def create_query(self) -> StoryQuery:
        """Return a new query for selecting stories from this database."""
        return StoryDbQuery(self._database)

    async def get_by_id(self, id_: StoryIdentifier) -> StoryEntity:
        """Return the story with the given id.

        Args:
            id_: the identifier of the story.

        Returns:
            The matching story entity.

        Raises:
            StoryNotFoundException: when no story with the given id exists.
        """
        query = self.create_query()
        query.filter_by_id(id_)

        entity = await anext(self.get_all(query, 1), None)
        if entity is None:
            raise StoryNotFoundException(f"Story with {id_} does not exist.")

        return entity

    async def get_all(
        self,
        query: StoryQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[StoryEntity]:
        """Yield the stories selected by the query.

        A story with several content items comes back as several rows, so
        consecutive rows with the same story id are grouped and turned into
        one entity. This assumes the query returns the rows of one story
        next to each other (e.g. ordered by story id).
        NOTE(review): confirm that ordering in StoryDbQuery.

        Args:
            query: the query to run; a default query is created when omitted.
            limit: passed on unchanged to ``query.fetch``.
            offset: passed on unchanged to ``query.fetch``.

        Yields:
            One story entity per group of rows.
        """
        if query is None:
            query = self.create_query()

        group_by_column = StoriesTable.alias_name("id")

        group: list[dict[str, Any]] = []
        current_key = None
        async for row in query.fetch(limit, offset):
            new_key = row[group_by_column]
            # A new key closes the previous group (if any).
            if group and new_key != current_key:
                yield _create_entity(group)
                group = []
            group.append(row)
            current_key = new_key

        # Emit the last group; it is empty only when no rows were fetched.
        if group:
            yield _create_entity(group)