Coverage for kwai/modules/news/stories/story_db_repository.py: 81%
62 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module that implements a story repository for a database."""
2from typing import Any, AsyncIterator
4from sql_smith.functions import field
6from kwai.core.db.database import Database
7from kwai.core.db.rows import OwnersTable
8from kwai.core.domain.entity import Entity
9from kwai.modules.news.stories.story import StoryEntity, StoryIdentifier
10from kwai.modules.news.stories.story_db_query import StoryDbQuery
11from kwai.modules.news.stories.story_query import StoryQuery
12from kwai.modules.news.stories.story_repository import (
13 StoryNotFoundException,
14 StoryRepository,
15)
16from kwai.modules.news.stories.story_tables import (
17 ApplicationsTable,
18 StoriesTable,
19 StoryContentRow,
20 StoryContentsTable,
21 StoryRow,
22)
def _create_entity(rows: list[dict[str, Any]]) -> StoryEntity:
    """Build one story entity out of the rows that belong to that story.

    The story itself and its application are read from the first row only.
    Every row (including the first) contributes one content item, whose
    author is built from the owner columns of that same row.
    """
    first_row = rows[0]
    story_table = StoriesTable(first_row)
    application = ApplicationsTable(first_row).create_application()

    contents = []
    for content_row in rows:
        author = OwnersTable(content_row).create_owner()
        contents.append(StoryContentsTable(content_row).create_content(author=author))

    return story_table.create_entity(application, contents)
class StoryDbRepository(StoryRepository):
    """A story repository that stores stories in a database.

    Attributes:
        _database: the database for the repository.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: the database used for all queries and mutations.
        """
        self._database = database

    async def _delete_contents(self, story: StoryEntity) -> None:
        """Delete all content rows whose news_id matches the story's id.

        Building the query is synchronous (the factory result is chained
        directly); only executing it is awaited.
        """
        delete_contents_query = (
            self._database.create_query_factory()
            .delete(StoryContentsTable.table_name)
            .where(field("news_id").eq(story.id.value))
        )
        await self._database.execute(delete_contents_query)

    async def create(self, story: StoryEntity) -> StoryEntity:
        """Insert a new story and its contents, then commit.

        Args:
            story: the story to persist.

        Returns:
            A copy of the story with the identifier returned by the insert.
        """
        new_id = await self._database.insert(
            StoriesTable.table_name, StoryRow.persist(story)
        )
        result = Entity.replace(story, id_=StoryIdentifier(new_id))

        # Persist the contents against `result` so they use the new id.
        content_rows = [
            StoryContentRow.persist(result, content) for content in story.content
        ]
        if content_rows:
            # Skip the call for a story without contents; presumably insert()
            # needs at least one row — TODO confirm.
            await self._database.insert(StoryContentsTable.table_name, *content_rows)

        await self._database.commit()
        return result

    async def update(self, story: StoryEntity) -> None:
        """Update a story and replace all of its contents, then commit.

        The existing content rows are deleted and re-inserted from
        ``story.content``.

        Args:
            story: the story to update; its id selects the rows to change.
        """
        await self._database.update(
            story.id.value, StoriesTable.table_name, StoryRow.persist(story)
        )

        await self._delete_contents(story)

        content_rows = [
            StoryContentRow.persist(story, content) for content in story.content
        ]
        if content_rows:
            await self._database.insert(StoryContentsTable.table_name, *content_rows)

        await self._database.commit()

    async def delete(self, story: StoryEntity) -> None:
        """Delete a story together with its contents, then commit.

        Args:
            story: the story to delete.
        """
        # Contents are removed before the story row they belong to.
        await self._delete_contents(story)
        await self._database.delete(story.id.value, StoriesTable.table_name)
        await self._database.commit()

    def create_query(self) -> StoryQuery:
        """Return a new database query for stories."""
        return StoryDbQuery(self._database)

    async def get_by_id(self, id_: StoryIdentifier) -> StoryEntity:
        """Return the story with the given identifier.

        Args:
            id_: the identifier of the story.

        Raises:
            StoryNotFoundException: when no story with this id exists.
        """
        query = self.create_query()
        query.filter_by_id(id_)

        entity = await anext(self.get_all(query, 1), None)
        if entity is None:
            raise StoryNotFoundException(f"Story with {id_} does not exist.")

        return entity

    async def get_all(
        self,
        query: StoryQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[StoryEntity]:
        """Yield story entities for the rows returned by a query.

        Args:
            query: the query to run; when omitted, create_query() is used.
            limit: passed through to ``query.fetch``.
            offset: passed through to ``query.fetch``.

        Yields:
            One StoryEntity per run of consecutive rows sharing a story id.
        """
        if query is None:
            query = self.create_query()

        group_by_column = StoriesTable.alias_name("id")

        row_it = query.fetch(limit, offset)

        # Handle the first row; no rows means no stories.
        try:
            row = await anext(row_it)
        except StopAsyncIteration:
            return

        group = [row]
        current_key = row[group_by_column]

        # Rows of one story are expected to arrive consecutively (presumably
        # guaranteed by the query's ordering — TODO confirm); a change of key
        # closes the current group.
        async for row in row_it:
            new_key = row[group_by_column]
            if new_key != current_key:
                yield _create_entity(group)
                group = [row]
                current_key = new_key
            else:
                group.append(row)

        yield _create_entity(group)