Coverage for src/kwai/modules/portal/pages/page_db_repository.py: 100%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements a page repository for a database.""" 

2 

3from typing import Any, AsyncIterator 

4 

5from sql_smith.functions import field 

6 

7from kwai.core.db.database import Database 

8from kwai.core.db.rows import OwnersTable 

9from kwai.core.domain.entity import Entity 

10from kwai.core.functions import async_groupby 

11from kwai.modules.portal.applications.application_tables import ApplicationsTable 

12from kwai.modules.portal.pages.page import PageEntity, PageIdentifier 

13from kwai.modules.portal.pages.page_db_query import PageDbQuery 

14from kwai.modules.portal.pages.page_query import PageQuery 

15from kwai.modules.portal.pages.page_repository import ( 

16 PageNotFoundException, 

17 PageRepository, 

18) 

19from kwai.modules.portal.pages.page_tables import ( 

20 PageContentsTable, 

21 PageRow, 

22 PagesTable, 

23 PageTextRow, 

24) 

25 

26 

def _create_entity(rows: list[dict[str, Any]]) -> PageEntity:
    """Create a page entity from a group of rows that belong to one page.

    The page itself and its application are built from the first row; every
    row in the group contributes one text, together with the owner that
    authored it.

    Args:
        rows: The fetched rows of a single page. Must not be empty
            (``rows[0]`` is accessed unconditionally).

    Returns:
        The page entity with its application and texts.
    """
    first_row = rows[0]
    return PagesTable(first_row).create_entity(
        ApplicationsTable(first_row).create_entity(),
        [
            PageContentsTable(row).create_text(author=OwnersTable(row).create_owner())
            for row in rows
        ],
    )

36 

37 

class PageDbRepository(PageRepository):
    """Page repository for a database.

    A page is stored as one row in the pages table. Each of its texts is
    stored as a row in the page contents table that references the page
    through its ``page_id`` column.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used to read and write pages.
        """
        self._database = database

    async def _insert_content_rows(self, content_rows: list[PageTextRow]) -> None:
        """Insert the given page content rows, skipping an empty batch."""
        # NOTE(review): guarding the empty case avoids depending on how
        # Database.insert behaves when called with zero rows — confirm whether
        # pages without texts are allowed at all.
        if content_rows:
            await self._database.insert(PageContentsTable.table_name, *content_rows)

    async def _delete_contents(self, page: PageEntity) -> None:
        """Delete every content row that references the given page."""
        delete_contents_query = (
            self._database.create_query_factory()
            .delete(PageContentsTable.table_name)
            .where(field("page_id").eq(page.id.value))
        )
        await self._database.execute(delete_contents_query)

    async def create(self, page: PageEntity) -> PageEntity:
        """Insert a new page and its texts, then commit.

        Args:
            page: The page to store.

        Returns:
            A copy of the page carrying the identifier returned by the insert.
        """
        new_id = await self._database.insert(
            PagesTable.table_name, PageRow.persist(page)
        )
        result = Entity.replace(page, id_=PageIdentifier(new_id))

        # Content rows are persisted from `result` so they reference the id
        # that was just assigned to the page.
        await self._insert_content_rows(
            [PageTextRow.persist(result, content) for content in page.texts]
        )

        await self._database.commit()
        return result

    async def update(self, page: PageEntity) -> None:
        """Update a page and replace all of its texts, then commit.

        The existing content rows are deleted and the current texts of the
        page are inserted again.

        Args:
            page: The page to update; its id selects the rows to change.
        """
        await self._database.update(
            page.id.value, PagesTable.table_name, PageRow.persist(page)
        )

        await self._delete_contents(page)
        await self._insert_content_rows(
            [PageTextRow.persist(page, content) for content in page.texts]
        )
        await self._database.commit()

    async def delete(self, page: PageEntity) -> None:
        """Delete a page together with its texts, then commit.

        Args:
            page: The page to delete.
        """
        # Contents are removed before the page row they reference.
        await self._delete_contents(page)
        await self._database.delete(page.id.value, PagesTable.table_name)
        await self._database.commit()

    def create_query(self) -> PageQuery:
        """Create a new database query for pages."""
        return PageDbQuery(self._database)

    async def get_by_id(self, id_: PageIdentifier) -> PageEntity:
        """Return the page with the given id.

        Args:
            id_: The identifier of the page.

        Returns:
            The page entity.

        Raises:
            PageNotFoundException: When no page with this id is returned.
        """
        query = self.create_query()
        query.filter_by_id(id_)

        # NOTE(review): limit=1 is forwarded to query.fetch. All texts of the
        # page are only returned if PageDbQuery applies the limit to pages
        # rather than to joined content rows — confirm.
        entity = await anext(self.get_all(query, 1), None)
        if entity is None:
            raise PageNotFoundException(f"Page with {id_} does not exist.")

        return entity

    async def get_all(
        self,
        query: PageQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[PageEntity]:
        """Yield the pages returned by a query.

        Args:
            query: The query to run; a new, unfiltered query when omitted.
            limit: Passed unchanged to ``query.fetch``.
            offset: Passed unchanged to ``query.fetch``.

        Yields:
            One page entity per group of rows sharing the same page id.
        """
        if query is None:
            query = self.create_query()

        group_by_column = PagesTable.alias_name("id")

        # Presumably async_groupby, like itertools.groupby, only groups
        # consecutive rows, so the query must return the rows of one page
        # next to each other — verify.
        row_iterator = query.fetch(limit, offset)
        async for _, group in async_groupby(
            row_iterator, key=lambda row: row[group_by_column]
        ):
            yield _create_entity(group)