Coverage for kwai/modules/training/coaches/coach_db_repository.py: 100%
25 statements
coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module that defines a coach repository for a database."""
2from typing import AsyncIterator
4from sql_smith.functions import on
5from sql_smith.query import SelectQuery
7from kwai.core.db.database import Database
8from kwai.modules.training.coaches.coach import CoachEntity, CoachIdentifier
9from kwai.modules.training.coaches.coach_repository import (
10 CoachNotFoundException,
11 CoachRepository,
12)
13from kwai.modules.training.coaches.coach_tables import (
14 CoachesTable,
15 CoachRow,
16 PersonRow,
17 PersonsTable,
18)
def _create_entity(coach_row: CoachRow, person_row: PersonRow) -> CoachEntity:
    """Turn a coach row and a person row into a coach entity.

    Args:
        coach_row: The row read from the coaches table.
        person_row: The person row handed to ``coach_row.create_entity``.

    Returns:
        The entity produced by ``coach_row.create_entity``.
    """
    entity = coach_row.create_entity(person_row)
    return entity
class CoachDbRepository(CoachRepository):
    """A coach repository for a database.

    Coaches are read from the coaches table joined with the persons table,
    and every resulting row is turned into a CoachEntity.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def _create_query(self) -> SelectQuery:
        """Create the base select query.

        The query selects the aliased columns of both the coaches and the
        persons table, joining them on ``coaches.person_id = persons.id``.
        """
        return (
            self._database.create_query_factory()
            .select()
            .from_(CoachesTable.table_name)
            .columns(*(CoachesTable.aliases() + PersonsTable.aliases()))
            .join(
                PersonsTable.table_name,
                on(CoachesTable.column("person_id"), PersonsTable.column("id")),
            )
        )

    async def get_by_id(self, id: CoachIdentifier) -> CoachEntity:
        """Get the coach with the given id.

        Args:
            id: The identifier of the coach.

        Returns:
            The coach entity built from the coach row and its person row.

        Raises:
            CoachNotFoundException: When no row matches the given id.
        """
        query = self._create_query().and_where(CoachesTable.field("id").eq(id.value))
        row = await self._database.fetch_one(query)

        if not row:
            raise CoachNotFoundException(f"Coach with id {id} not found.")

        return _create_entity(CoachesTable(row), PersonsTable(row))

    async def get_by_ids(self, *ids: CoachIdentifier) -> AsyncIterator[CoachEntity]:
        """Get all coaches whose id is one of the given ids.

        Ids that match no row are simply not yielded. No ORDER BY is applied,
        so entities arrive in whatever order the database returns the rows.

        Args:
            ids: The identifiers of the coaches to fetch.

        Yields:
            A coach entity for each matching row.
        """
        if not ids:
            # An empty ``IN ()`` list is invalid SQL; with no ids there is
            # nothing to fetch, so yield nothing instead of querying.
            return

        unpacked_ids = tuple(i.value for i in ids)
        query = self._create_query().and_where(
            CoachesTable.field("id").in_(*unpacked_ids)
        )

        async for row in self._database.fetch(query):
            yield _create_entity(CoachesTable(row), PersonsTable(row))