Coverage for src/kwai/modules/teams/repositories/member_db_repository.py: 97%

65 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module for defining a team member repository for a database.""" 

2 

3from dataclasses import dataclass 

4from typing import AsyncGenerator, Self 

5 

6from sql_smith.functions import express, on 

7 

8from kwai.core.db.database import Database 

9from kwai.core.db.database_query import DatabaseQuery 

10from kwai.core.db.table_row import JoinedTableRow 

11from kwai.core.domain.value_objects.date import Date 

12from kwai.core.domain.value_objects.name import Name 

13from kwai.core.domain.value_objects.unique_id import UniqueId 

14from kwai.modules.club.domain.value_objects import Birthdate, Gender, License 

15from kwai.modules.teams.domain.team import TeamIdentifier 

16from kwai.modules.teams.domain.team_member import MemberEntity, MemberIdentifier 

17from kwai.modules.teams.repositories._tables import ( 

18 CountryRow, 

19 MemberPersonRow, 

20 MemberRow, 

21 TeamMemberRow, 

22) 

23from kwai.modules.teams.repositories.member_repository import ( 

24 MemberNotFoundException, 

25 MemberQuery, 

26 MemberRepository, 

27) 

28 

29 

@dataclass(kw_only=True, frozen=True, slots=True)
class MemberQueryRow(JoinedTableRow):
    """Joined row produced by the team member query.

    Groups the member, person and country table rows that make up one result
    row, and knows how to turn them into a MemberEntity.
    """

    member: MemberRow  # columns from the member table
    person: MemberPersonRow  # columns from the person table (member.person_id)
    country: CountryRow  # nationality of the person (person.nationality_id)

    def create_entity(self) -> MemberEntity:
        """Build a team member entity from this joined row.

        The member is flagged as active in the club only when the member's
        ``active`` column equals 1.
        """
        member_row = self.member
        person_row = self.person

        identifier = MemberIdentifier(member_row.id)
        unique_id = UniqueId.create_from_string(member_row.uuid)
        full_name = Name(
            first_name=person_row.firstname,
            last_name=person_row.lastname,
        )
        member_license = License(
            number=member_row.license,
            end_date=Date.create_from_date(member_row.license_end_date),
        )
        birth = Birthdate(Date.create_from_date(person_row.birthdate))
        sex = Gender(person_row.gender)
        nation = self.country.create_country()
        is_active = member_row.active == 1

        return MemberEntity(
            id_=identifier,
            uuid=unique_id,
            name=full_name,
            license=member_license,
            birthdate=birth,
            gender=sex,
            nationality=nation,
            active_in_club=is_active,
        )

53 

54 

class MemberDbQuery(MemberQuery, DatabaseQuery):
    """Team member query implemented on top of a database.

    Members are selected together with their person row and the country of
    the person's nationality. Every ``filter_*`` method adds a condition to
    the underlying query and returns ``self``, so calls can be chained.
    """

    def __init__(self, database: Database):
        super().__init__(database)

    def init(self):
        """Set up the base query: member joined with person and country."""
        person_condition = on(
            MemberPersonRow.column("id"), MemberRow.column("person_id")
        )
        country_condition = on(
            CountryRow.column("id"), MemberPersonRow.column("nationality_id")
        )
        (
            self._query.from_(MemberRow.__table_name__)
            .inner_join(MemberPersonRow.__table_name__, person_condition)
            .inner_join(CountryRow.__table_name__, country_condition)
        )

    @property
    def columns(self):
        """Return the aliased columns selected for a MemberQueryRow."""
        return MemberQueryRow.get_aliases()

    @property
    def count_column(self):
        """Return the column used when counting results (member id)."""
        return MemberRow.column("id")

    def filter_by_id(self, id_: MemberIdentifier) -> Self:
        """Restrict the query to the member with the given id."""
        id_condition = MemberRow.field("id").eq(id_.value)
        self._query.and_where(id_condition)
        return self

    def filter_by_birthdate(
        self, start_date: Date, end_date: Date | None = None
    ) -> Self:
        """Restrict the query on the person's birthdate.

        Without an end date, members born on or after ``start_date`` match.
        With an end date, the birthdate must lie between both dates (via the
        query builder's ``between``).
        """
        birthdate_field = MemberPersonRow.field("birthdate")
        if end_date is not None:
            condition = birthdate_field.between(start_date, end_date)
        else:
            condition = birthdate_field.gte(start_date)
        self._query.and_where(condition)
        return self

    def filter_by_uuid(self, uuid: UniqueId) -> Self:
        """Restrict the query to the member with the given unique id."""
        uuid_condition = MemberRow.field("uuid").eq(str(uuid))
        self._query.and_where(uuid_condition)
        return self

    def filter_by_team(self, team_id: TeamIdentifier, in_team: bool = True) -> Self:
        """Restrict the query on membership of the given team.

        With ``in_team`` True, only members of the team match; otherwise only
        members that are not part of the team match.
        """
        factory = self._database.create_query_factory()
        team_member_ids = (
            factory.select()
            .columns(TeamMemberRow.column("member_id"))
            .from_(TeamMemberRow.__table_name__)
            .where(TeamMemberRow.field("team_id").eq(team_id.value))
        )
        member_id = MemberRow.field("id")
        subquery = express("{}", team_member_ids)
        self._query.and_where(
            member_id.in_(subquery) if in_team else member_id.not_in(subquery)
        )
        return self

111 

112 

class MemberDbRepository(MemberRepository):
    """A member repository for a database.

    Queries are created as MemberDbQuery instances, and each fetched row is
    mapped to a MemberEntity through MemberQueryRow.
    """

    def __init__(self, database: Database):
        """Create the repository.

        Args:
            database: The database on which queries are created.
        """
        self._database = database

    def create_query(self) -> MemberQuery:
        """Create a new member query for this repository's database."""
        return MemberDbQuery(self._database)

    async def get(self, query: MemberQuery | None = None) -> MemberEntity:
        """Return the first member produced by the query.

        Args:
            query: The query to use. When omitted, a new query created with
                create_query is used.

        Returns:
            The first member entity yielded by get_all.

        Raises:
            MemberNotFoundException: When the query yields no member.
        """
        members = self.get_all(query)
        try:
            return await anext(members)
        except StopAsyncIteration:
            raise MemberNotFoundException("Member not found") from None
        finally:
            # Only the first item is consumed. Close the generator explicitly
            # so it is finalized now instead of being left suspended for the
            # event loop's async-generator finalizer to clean up later.
            await members.aclose()

    async def get_all(
        self,
        query: MemberQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncGenerator[MemberEntity, None]:
        """Yield every member produced by the query.

        Args:
            query: The query to use. When omitted, a new query created with
                create_query is used.
            limit: Forwarded to ``query.fetch`` (presumably a maximum number
                of rows; None for no limit).
            offset: Forwarded to ``query.fetch`` (presumably the number of
                rows to skip; None for no offset).

        Yields:
            A member entity for each row returned by the query.
        """
        # Explicit None check: only a missing query should be replaced, not
        # one that merely evaluates as falsy.
        if query is None:
            query = self.create_query()

        async for row in query.fetch(limit, offset):
            yield MemberQueryRow.map(row).create_entity()
138 yield MemberQueryRow.map(row).create_entity()