Coverage for src/kwai/modules/teams/repositories/team_db_repository.py: 96%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements a team repository for a database.""" 

2 

3from dataclasses import dataclass 

4from typing import Any, AsyncGenerator, Self 

5 

6from sql_smith.functions import field, on 

7 

8from kwai.core.db.database import Database 

9from kwai.core.db.database_query import DatabaseQuery 

10from kwai.core.db.table_row import JoinedTableRow 

11from kwai.core.domain.entity import Entity 

12from kwai.core.domain.value_objects.date import Date 

13from kwai.core.domain.value_objects.name import Name 

14from kwai.core.domain.value_objects.unique_id import UniqueId 

15from kwai.core.functions import async_groupby 

16from kwai.modules.club.domain.value_objects import Birthdate, Gender, License 

17from kwai.modules.teams.domain.team import TeamEntity, TeamIdentifier 

18from kwai.modules.teams.domain.team_member import ( 

19 MemberEntity, 

20 MemberIdentifier, 

21 TeamMember, 

22) 

23from kwai.modules.teams.repositories._tables import ( 

24 CountryRow, 

25 MemberPersonRow, 

26 MemberRow, 

27 TeamMemberRow, 

28 TeamRow, 

29) 

30from kwai.modules.teams.repositories.team_repository import ( 

31 TeamNotFoundException, 

32 TeamQuery, 

33 TeamRepository, 

34) 

35 

36 

@dataclass(kw_only=True, frozen=True, slots=True)
class MemberPersonCountryMixin:
    """Mixin carrying the member, person and country rows of a joined row.

    A joined row class that includes this mixin can turn its member related
    columns into a MemberEntity with create_member_entity.
    """

    member: MemberRow
    member_person: MemberPersonRow
    country: CountryRow

    def create_member_entity(self) -> MemberEntity:
        """Build a member entity from the member, person and country rows.

        Returns:
            A MemberEntity. active_in_club is True only when the member row's
            active column equals 1.
        """
        member_row = self.member
        person_row = self.member_person

        # Value objects are created in the same order the original keyword
        # arguments were evaluated.
        member_id = MemberIdentifier(member_row.id)
        full_name = Name(
            first_name=person_row.firstname,
            last_name=person_row.lastname,
        )
        member_license = License(
            number=member_row.license,
            end_date=Date.create_from_date(member_row.license_end_date),
        )
        date_of_birth = Birthdate(date=Date.create_from_date(person_row.birthdate))
        nationality = self.country.create_country()
        gender = Gender(person_row.gender)
        member_uuid = UniqueId.create_from_string(member_row.uuid)

        return MemberEntity(
            id_=member_id,
            name=full_name,
            license=member_license,
            birthdate=date_of_birth,
            nationality=nationality,
            gender=gender,
            uuid=member_uuid,
            active_in_club=member_row.active == 1,
        )

65 

66 

@dataclass(kw_only=True, frozen=True, slots=True)
class TeamQueryRow(MemberPersonCountryMixin, JoinedTableRow):
    """A data transfer object for the team query.

    One instance holds the team, team member, member, person and country
    rows of a single result row of the team query.
    """

    team: TeamRow
    team_member: TeamMemberRow

    @classmethod
    def create_entity(cls, rows: list[dict[str, Any]]) -> TeamEntity:
        """Create a team entity from a group of rows that belong to one team.

        Args:
            rows: The raw rows of one team. Every row repeats the team
                columns; rows whose member id is None (presumably produced by
                the left joins when a team has no members) only contribute
                the team itself. Must not be empty.

        Returns:
            A team entity whose team members are keyed by member uuid.

        Raises:
            IndexError: When rows is empty.
        """
        # Map every row exactly once; previously the first row was mapped
        # twice (once for the team, once again inside the loop).
        mapped_rows = [cls.map(row) for row in rows]

        team_members = {}
        for mapped_row in mapped_rows:
            if mapped_row.member.id is None:
                continue

            member = mapped_row.create_member_entity()
            team_members[member.uuid] = mapped_row.team_member.create_team_member(
                member
            )

        # All rows of the group carry the same team columns, so the first
        # mapped row provides the team.
        return mapped_rows[0].team.create_entity(team_members)

89 

90 

class TeamDbQuery(TeamQuery, DatabaseQuery):
    """A team query for a database.

    Selects teams left-joined with their team members, the members, the
    persons of those members and the persons' country of nationality.
    """

    def __init__(self, database: Database):
        super().__init__(database)

    def init(self):
        """Set up the FROM clause and the left joins of the team query."""
        # Each entry: (table to join, first column for on(), second column
        # for on()). The on() argument order matches the original chain.
        left_joins = (
            (
                TeamMemberRow.__table_name__,
                TeamRow.column("id"),
                TeamMemberRow.column("team_id"),
            ),
            (
                MemberRow.__table_name__,
                MemberRow.column("id"),
                TeamMemberRow.column("member_id"),
            ),
            (
                MemberPersonRow.__table_name__,
                MemberPersonRow.column("id"),
                MemberRow.column("person_id"),
            ),
            (
                CountryRow.__table_name__,
                CountryRow.column("id"),
                MemberPersonRow.column("nationality_id"),
            ),
        )

        # Every builder call is made on the object returned by the previous
        # call, exactly like the fluent chain it replaces.
        builder = self._query.from_(TeamRow.__table_name__)
        for table_name, first_column, second_column in left_joins:
            builder = builder.left_join(table_name, on(first_column, second_column))

    @property
    def columns(self):
        """Columns to select, as provided by TeamQueryRow.get_aliases()."""
        return TeamQueryRow.get_aliases()

    @property
    def count_column(self) -> str:
        """Column used for counting: the team id column."""
        return TeamRow.column("id")

    def filter_by_id(self, id_: TeamIdentifier) -> Self:
        """Restrict the query to the team with the given id.

        Args:
            id_: The identifier of the team.

        Returns:
            This query, to allow chaining.
        """
        id_condition = TeamRow.field("id").eq(id_.value)
        self._query.and_where(id_condition)
        return self

123 

124 

class TeamDbRepository(TeamRepository):
    """A team repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used for all queries and mutations.
        """
        self._database = database

    def create_query(self) -> TeamQuery:
        """Create a new team query bound to this repository's database."""
        return TeamDbQuery(self._database)

    async def get(self, query: TeamQuery | None = None) -> TeamEntity:
        """Return the first team produced by the query.

        Args:
            query: The query to use. When None, a new query from
                create_query() is used.

        Returns:
            The first team entity yielded by get_all.

        Raises:
            TeamNotFoundException: When the query yields no team.
        """
        # Function-scope import keeps this change self-contained; aclosing
        # exists since Python 3.10 and this module already needs 3.11 (Self).
        from contextlib import aclosing

        # Only the first team is consumed, so close the async generator
        # explicitly instead of leaving it suspended until garbage collection.
        async with aclosing(self.get_all(query)) as team_iterator:
            try:
                return await anext(team_iterator)
            except StopAsyncIteration:
                raise TeamNotFoundException("Team not found") from None

    async def get_all(
        self,
        query: TeamQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncGenerator[TeamEntity, None]:
        """Yield a team entity for each group of rows with the same team_id.

        Args:
            query: The query to use. When None, a new query from
                create_query() is used.
            limit: Passed unchanged to query.fetch.
            offset: Passed unchanged to query.fetch.

        Yields:
            One TeamEntity per group of rows sharing the same team_id.
        """
        if query is None:
            query = self.create_query()

        group_by_column = "team_id"
        row_iterator = query.fetch(limit=limit, offset=offset)
        # NOTE(review): assuming async_groupby behaves like itertools.groupby
        # (groups consecutive equal keys), all rows of one team must arrive
        # together. TeamDbQuery adds no ORDER BY, so confirm the database
        # returns the joined rows grouped per team. Also verify whether
        # limit/offset count joined rows or teams in DatabaseQuery.fetch.
        async for _, group in async_groupby(
            row_iterator, key=lambda row: row[group_by_column]
        ):
            yield TeamQueryRow.create_entity(group)

    async def create(self, team: TeamEntity) -> TeamEntity:
        """Insert a team.

        Args:
            team: The team to insert.

        Returns:
            A copy of the team with the id returned by the insert.
        """
        new_team_id = await self._database.insert(
            TeamRow.__table_name__, TeamRow.persist(team)
        )
        return Entity.replace(team, id_=TeamIdentifier(new_team_id))

    async def delete(self, team: TeamEntity) -> None:
        """Delete a team together with its team member rows.

        The team member rows are deleted first, then the team row itself.

        Args:
            team: The team to delete.
        """
        # NOTE(review): these two statements are not wrapped in a transaction
        # here; confirm the caller (or Database) provides one.
        delete_team_members_query = (
            self._database.create_query_factory()
            .delete(TeamMemberRow.__table_name__)
            .where(field("team_id").eq(team.id.value))
        )
        await self._database.execute(delete_team_members_query)
        await self._database.delete(team.id.value, TeamRow.__table_name__)

    async def update(self, team: TeamEntity) -> None:
        """Update the team row with the current values of the entity.

        Args:
            team: The team to update; its id selects the row.
        """
        await self._database.update(
            team.id.value, TeamRow.__table_name__, TeamRow.persist(team)
        )

    async def add_team_member(self, team: TeamEntity, member: TeamMember) -> None:
        """Insert a team member row linking the member to the team.

        Args:
            team: The team the member is added to.
            member: The team member to add.
        """
        team_member_row = TeamMemberRow.persist(team, member)
        await self._database.insert(TeamMemberRow.__table_name__, team_member_row)