Coverage for src/kwai/modules/training/teams/team_db_repository.py: 100%

26 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that defines a team repository for a database.""" 

2 

3from typing import AsyncIterator 

4 

5from kwai.core.db.database import Database 

6from kwai.modules.training.teams.team import TeamEntity, TeamIdentifier 

7from kwai.modules.training.teams.team_db_query import TeamDbQuery 

8from kwai.modules.training.teams.team_query import TeamQuery 

9from kwai.modules.training.teams.team_repository import TeamRepository 

10from kwai.modules.training.teams.team_tables import TeamsTable 

11 

12 

class TeamDbRepository(TeamRepository):
    """Team repository implementation that reads teams from a database.

    Each public lookup builds a new query with `create_query` and turns the
    fetched rows into team entities through `TeamsTable`.
    """

    def __init__(self, database: Database):
        """Keep a reference to the database used for all queries.

        Args:
            database: The database that queries created by this repository
                are bound to.
        """
        self._database = database

    def create_query(self) -> TeamQuery:
        """Return a new team query bound to this repository's database."""
        return TeamDbQuery(self._database)

    async def get_by_id(self, id: TeamIdentifier) -> TeamEntity:
        """Load a single team by its identifier.

        Args:
            id: The identifier used to filter the team query.

        Returns:
            The team entity created from the row returned by `fetch_one`.
        """
        team_query = self.create_query()
        team_query.filter_by_id(id)

        # NOTE(review): the fetched row is handed to TeamsTable unchecked; if
        # fetch_one can return None when no team matches, the failure will
        # surface inside TeamsTable rather than as a not-found error — confirm.
        record = await team_query.fetch_one()
        return TeamsTable(record).create_entity()

    async def get_all(self) -> AsyncIterator[TeamEntity]:
        """Yield a team entity for each row of an unfiltered team query.

        Yields:
            One team entity per fetched row.
        """
        async for team in self._to_entities(self.create_query()):
            yield team

    async def get_by_ids(self, *ids: TeamIdentifier) -> AsyncIterator[TeamEntity]:
        """Yield the teams selected by filtering the query on the given ids.

        Args:
            ids: The identifiers passed to the query's `filter_by_ids`.

        Yields:
            One team entity per fetched row.
        """
        team_query = self.create_query()
        team_query.filter_by_ids(*ids)

        async for team in self._to_entities(team_query):
            yield team

    @staticmethod
    async def _to_entities(query: TeamQuery) -> AsyncIterator[TeamEntity]:
        """Convert every row produced by ``query.fetch()`` into a team entity."""
        async for record in query.fetch():
            yield TeamsTable(record).create_entity()