Coverage for src/kwai/modules/training/coaches/coach_db_repository.py: 96%
28 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that defines a coach repository for a database."""
3from typing import AsyncIterator
5from kwai.core.db.database import Database
6from kwai.modules.training.coaches._tables import (
7 CoachRow,
8 PersonRow,
9)
10from kwai.modules.training.coaches.coach import CoachEntity, CoachIdentifier
11from kwai.modules.training.coaches.coach_db_query import CoachDbQuery, CoachQueryRow
12from kwai.modules.training.coaches.coach_query import CoachQuery
13from kwai.modules.training.coaches.coach_repository import (
14 CoachNotFoundException,
15 CoachRepository,
16)
def _create_entity(coach_row: CoachRow, person_row: PersonRow) -> CoachEntity:
    """Build a coach entity from a coach row and a person row.

    Args:
        coach_row: The coach row that creates the entity.
        person_row: The person row handed to ``coach_row.create_entity``.

    Returns:
        The entity produced by ``coach_row.create_entity(person_row)``.
    """
    entity = coach_row.create_entity(person_row)
    return entity
class CoachDbRepository(CoachRepository):
    """A coach repository for a database.

    Coaches are read through a ``CoachDbQuery``. Every row the query yields
    is mapped with ``CoachQueryRow.map`` and converted into a ``CoachEntity``.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> CoachQuery:
        """Create a new coach query bound to this repository's database."""
        return CoachDbQuery(self._database)

    @staticmethod
    def _row_to_entity(row) -> CoachEntity:
        """Map a raw query row to a coach entity (shared by all getters)."""
        return CoachQueryRow.map(row).create_entity()

    async def get_by_id(self, id: CoachIdentifier) -> CoachEntity:
        """Get the coach with the given id.

        Args:
            id: The identifier of the coach to fetch.

        Returns:
            The coach entity built from the matching row.

        Raises:
            CoachNotFoundException: When the query returns no row for the id.
        """
        query = self.create_query().filter_by_id(id)
        row = await query.fetch_one()

        if not row:
            raise CoachNotFoundException(f"Coach with id {id} not found.")

        return self._row_to_entity(row)

    async def get_by_ids(self, *ids: CoachIdentifier) -> AsyncIterator[CoachEntity]:
        """Yield the coaches matched by a query filtered on the given ids.

        Args:
            ids: The identifiers of the coaches to fetch.

        Yields:
            A coach entity for each row the filtered query returns.
        """
        query = self.create_query().filter_by_ids(*ids)

        async for row in query.fetch():
            yield self._row_to_entity(row)

    async def get_all(
        self, query: CoachQuery | None = None
    ) -> AsyncIterator[CoachEntity]:
        """Yield all coaches matched by a query.

        Args:
            query: The query to run. When omitted (``None``), a new unfiltered
                query from ``create_query`` is used.

        Yields:
            A coach entity for each row the query returns.
        """
        # Test for None explicitly: a caller-supplied query must never be
        # replaced just because the query object happens to evaluate falsy.
        if query is None:
            query = self.create_query()

        async for row in query.fetch():
            yield self._row_to_entity(row)