Coverage for kwai/modules/training/teams/team_db_repository.py: 100%
16 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module that defines a team repository for a database."""
2from typing import AsyncIterator
4from sql_smith.query import SelectQuery
6from kwai.core.db.database import Database
7from kwai.modules.training.teams.team import TeamEntity, TeamIdentifier
8from kwai.modules.training.teams.team_repository import TeamRepository
9from kwai.modules.training.teams.team_tables import TeamsTable
class TeamDbRepository(TeamRepository):
    """A team repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def _create_query(self) -> SelectQuery:
        """Create the base select query.

        The query selects all aliased team columns from the teams table;
        callers add their own filter conditions on top of it.
        """
        return (
            self._database.create_query_factory()
            .select()
            .from_(TeamsTable.table_name)
            .columns(*TeamsTable.aliases())
        )

    async def get_by_ids(self, *ids: TeamIdentifier) -> AsyncIterator[TeamEntity]:
        """Yield the teams with the given ids.

        Args:
            ids: The identifiers of the teams to fetch. When no identifiers
                are given, nothing is yielded and no query is executed.

        Yields:
            A team entity for every row of the teams table whose id is one
            of the given identifiers, in the order the database returns them.
        """
        if not ids:
            # With no ids the IN clause would be rendered as `IN ()`, which
            # is rejected by common SQL dialects (e.g. MySQL, PostgreSQL).
            # There is nothing to fetch anyway, so skip the query.
            return

        unpacked_ids = tuple(i.value for i in ids)
        query = self._create_query().and_where(
            TeamsTable.field("id").in_(*unpacked_ids)
        )

        async for row in self._database.fetch(query):
            yield TeamsTable(row).create_entity()