Coverage for src/kwai/modules/training/teams/team_db_repository.py: 100%
26 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that defines a team repository for a database."""
3from typing import AsyncIterator
5from kwai.core.db.database import Database
6from kwai.modules.training.teams.team import TeamEntity, TeamIdentifier
7from kwai.modules.training.teams.team_db_query import TeamDbQuery
8from kwai.modules.training.teams.team_query import TeamQuery
9from kwai.modules.training.teams.team_repository import TeamRepository
10from kwai.modules.training.teams.team_tables import TeamsTable
class TeamDbRepository(TeamRepository):
    """Team repository backed by a database.

    Every lookup creates a fresh TeamDbQuery on the wrapped database and
    turns each fetched row into a TeamEntity through TeamsTable.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database that every query of this repository uses.
        """
        self._database = database

    def create_query(self) -> TeamQuery:
        """Create a new team query bound to this repository's database.

        Returns:
            A TeamDbQuery for the repository's database.
        """
        return TeamDbQuery(self._database)

    @staticmethod
    def _row_to_entity(row) -> TeamEntity:
        """Map one fetched database row onto a team entity."""
        return TeamsTable(row).create_entity()

    async def get_by_id(self, id: TeamIdentifier) -> TeamEntity:
        """Get the team with the given id.

        Args:
            id: The identifier of the team to fetch.

        Returns:
            The team entity built from the row returned by the query.
        """
        # `id` shadows the builtin, but the name is part of the public
        # signature (callers may pass it by keyword), so it is kept.
        team_query = self.create_query()
        team_query.filter_by_id(id)
        team_row = await team_query.fetch_one()
        # NOTE(review): the row is not checked for absence here; what happens
        # when no team matches depends on fetch_one()/TeamsTable — confirm.
        return self._row_to_entity(team_row)

    async def get_all(self) -> AsyncIterator[TeamEntity]:
        """Iterate over all teams.

        Yields:
            A team entity for every row the unfiltered query returns.
        """
        team_query = self.create_query()
        async for team_row in team_query.fetch():
            yield self._row_to_entity(team_row)

    async def get_by_ids(self, *ids: TeamIdentifier) -> AsyncIterator[TeamEntity]:
        """Iterate over the teams with the given ids.

        Args:
            ids: The identifiers of the teams to fetch.

        Yields:
            A team entity for every row the id-filtered query returns.
        """
        team_query = self.create_query()
        team_query.filter_by_ids(*ids)
        async for team_row in team_query.fetch():
            yield self._row_to_entity(team_row)