Coverage for src/kwai/modules/club/repositories/member_db_repository.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module for defining a member repository using a database.""" 

2 

3from typing import AsyncGenerator 

4 

5from sql_smith.functions import field 

6 

7from kwai.core.db.database import Database 

8from kwai.core.domain.entity import Entity 

9from kwai.modules.club.domain.file_upload import FileUploadEntity 

10from kwai.modules.club.domain.member import MemberEntity, MemberIdentifier 

11from kwai.modules.club.repositories._tables import MemberRow, MemberUploadRow 

12from kwai.modules.club.repositories.member_db_query import MemberDbQuery, MemberQueryRow 

13from kwai.modules.club.repositories.member_query import MemberQuery 

14from kwai.modules.club.repositories.member_repository import ( 

15 MemberNotFoundException, 

16 MemberRepository, 

17) 

18from kwai.modules.club.repositories.person_db_repository import PersonDbRepository 

19 

20 

class MemberDbRepository(MemberRepository):
    """A member repository using a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> MemberQuery:
        """Create a database query for selecting members."""
        return MemberDbQuery(self._database)

    async def get_all(
        self,
        query: MemberQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncGenerator[MemberEntity, None]:
        """Yield the members selected by a query.

        Args:
            query: The query to execute. When None, a new query from
                create_query is used.
            limit: Passed as the limit to query.fetch.
            offset: Passed as the offset to query.fetch.

        Yields:
            A MemberEntity created from each fetched row.
        """
        if query is None:
            query = self.create_query()

        async for row in query.fetch(limit, offset):
            yield MemberQueryRow.map(row).create_entity()

    async def get(self, query: MemberQuery | None = None) -> MemberEntity:
        """Return the first member selected by a query.

        Args:
            query: The query to execute. When None, a new query is used.

        Returns:
            The first member yielded by get_all.

        Raises:
            MemberNotFoundException: When the query yields no member.
        """
        member_iterator = self.get_all(query)
        try:
            return await anext(member_iterator)
        except StopAsyncIteration:
            raise MemberNotFoundException("Member not found") from None
        finally:
            # Only the first member is needed: close the generator now
            # instead of leaving it suspended until garbage collection.
            await member_iterator.aclose()

    async def create(self, member: MemberEntity) -> MemberEntity:
        """Insert a new member.

        When the member's person has no id yet, the person is created first
        and the member is updated to refer to the created person before the
        member row is inserted.

        Args:
            member: The member to save.

        Returns:
            The member with the id returned by the insert.
        """
        if member.person.id.is_empty():
            person = await PersonDbRepository(self._database).create(member.person)
            member = Entity.replace(member, person=person)

        new_id = await self._database.insert(
            MemberRow.__table_name__, MemberRow.persist(member)
        )

        return Entity.replace(member, id_=MemberIdentifier(new_id))

    async def update(self, member: MemberEntity) -> None:
        """Update the member row and the member's person information.

        Args:
            member: The member to update.
        """
        await self._database.update(
            member.id.value, MemberRow.__table_name__, MemberRow.persist(member)
        )
        await PersonDbRepository(self._database).update(member.person)

    async def delete(self, member: MemberEntity) -> None:
        """Delete the member and its person.

        The member row is removed before the person: `create` stores the
        person first and lets the member refer to it, so deleting in the
        reverse order cannot violate a member-to-person foreign key (if the
        schema defines one).

        Args:
            member: The member to delete.
        """
        # Pass the raw id value, like `update` does, instead of the
        # identifier object.
        await self._database.delete(member.id.value, MemberRow.__table_name__)
        await PersonDbRepository(self._database).delete(member.person)

    def _create_member_upload_query(self, upload: FileUploadEntity):
        """Create a subquery selecting the member ids linked to an upload."""
        return (
            self._database.create_query_factory()
            .select("member_id")
            .from_(MemberUploadRow.__table_name__)
            .where(field("import_id").eq(upload.id.value))
        )

    async def activate_members(self, upload: FileUploadEntity) -> None:
        """Set active to 1 for every member linked to the upload.

        Args:
            upload: The file upload whose members are activated.
        """
        update_query = (
            self._database.create_query_factory()
            .update(MemberRow.__table_name__, {"active": 1})
            .where(field("id").in_(self._create_member_upload_query(upload)))
        )
        await self._database.execute(update_query)

    async def deactivate_members(self, upload: FileUploadEntity) -> None:
        """Set active to 0 for every member NOT linked to the upload.

        Args:
            upload: The file upload whose members stay untouched.
        """
        update_query = (
            self._database.create_query_factory()
            .update(MemberRow.__table_name__, {"active": 0})
            .where(field("id").not_in(self._create_member_upload_query(upload)))
        )
        await self._database.execute(update_query)