Coverage for src/kwai/modules/training/trainings/training_definition_db_repository.py: 98%
41 statements
Report generated by coverage.py v7.6.10 at 2024-01-01 00:00 +0000.
1"""Module that implements a training definition repository for a database."""
3from typing import Any, AsyncIterator
5from kwai.core.db.database import Database
6from kwai.core.db.rows import OwnersTable
7from kwai.core.domain.entity import Entity
8from kwai.modules.training.teams.team_tables import TeamsTable
9from kwai.modules.training.trainings.training_definition import (
10 TrainingDefinitionEntity,
11 TrainingDefinitionIdentifier,
12)
13from kwai.modules.training.trainings.training_definition_db_query import (
14 TrainingDefinitionDbQuery,
15)
16from kwai.modules.training.trainings.training_definition_query import (
17 TrainingDefinitionQuery,
18)
19from kwai.modules.training.trainings.training_definition_repository import (
20 TrainingDefinitionNotFoundException,
21 TrainingDefinitionRepository,
22)
23from kwai.modules.training.trainings.training_tables import (
24 TrainingDefinitionRow,
25 TrainingDefinitionsTable,
26)
def _create_entity(row: dict[str, Any]):
    """Turn a fetched row into a training definition entity.

    The row carries aliased columns for the training definition, its team and
    its owner. A team entity is only built when the aliased ``team_id`` column
    holds a value; otherwise the definition is created without a team.

    Args:
        row: A row as returned by a training definition query (presumably
            TrainingDefinitionDbQuery; confirm against callers).

    Returns:
        The training definition entity with its team (or None) and owner set.
    """
    team_id_column = TrainingDefinitionsTable.alias_name("team_id")
    team = None if row[team_id_column] is None else TeamsTable(row).create_entity()

    # Keep the original construction order: definition table first, then owner.
    definition_table = TrainingDefinitionsTable(row)
    owner = OwnersTable(row).create_owner()
    return definition_table.create_entity(team=team, owner=owner)
class TrainingDefinitionDbRepository(TrainingDefinitionRepository):
    """Database-backed implementation of the training definition repository."""

    def __init__(self, database: Database) -> None:
        """Initialize the repository.

        Args:
            database: The database used to query and persist training
                definitions.
        """
        self._database = database

    def create_query(self) -> TrainingDefinitionQuery:
        """Return a new training definition query bound to this database."""
        return TrainingDefinitionDbQuery(self._database)

    async def get_by_id(
        self, id_: TrainingDefinitionIdentifier
    ) -> TrainingDefinitionEntity:
        """Fetch a single training definition by its identifier.

        Args:
            id_: The identifier of the training definition to load.

        Returns:
            The matching training definition entity.

        Raises:
            TrainingDefinitionNotFoundException: When the query returns no
                (or a falsy) row.
        """
        definition_query = self.create_query()
        definition_query.filter_by_id(id_)

        found = await definition_query.fetch_one()
        if not found:
            raise TrainingDefinitionNotFoundException(
                f"Training definition with id {id_} does not exist."
            )
        return _create_entity(found)

    async def get_all(
        self,
        query: TrainingDefinitionQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingDefinitionEntity]:
        """Yield training definitions produced by a query.

        Args:
            query: The query to run. When omitted, a new query from
                ``create_query`` is used.
            limit: Passed through to the query's ``fetch`` as the row limit.
            offset: Passed through to the query's ``fetch`` as the row offset.

        Yields:
            One training definition entity per fetched row.
        """
        active_query = self.create_query() if query is None else query
        async for definition_row in active_query.fetch(limit, offset):
            yield _create_entity(definition_row)

    async def create(
        self, training_definition: TrainingDefinitionEntity
    ) -> TrainingDefinitionEntity:
        """Insert a training definition and commit.

        Args:
            training_definition: The entity to persist.

        Returns:
            The entity with its id replaced by the id returned from the insert.
        """
        values = TrainingDefinitionRow.persist(training_definition)
        inserted_id = await self._database.insert(
            TrainingDefinitionsTable.table_name, values
        )
        await self._database.commit()

        identifier = TrainingDefinitionIdentifier(inserted_id)
        return Entity.replace(training_definition, id_=identifier)

    async def update(self, training_definition: TrainingDefinitionEntity):
        """Write the entity's current state to its existing row and commit.

        Args:
            training_definition: The entity to update; its id selects the row.
        """
        definition_id = training_definition.id.value
        values = TrainingDefinitionRow.persist(training_definition)
        await self._database.update(
            definition_id, TrainingDefinitionsTable.table_name, values
        )
        await self._database.commit()

    async def delete(self, training_definition: TrainingDefinitionEntity):
        """Delete the entity's row and commit.

        Args:
            training_definition: The entity to delete; its id selects the row.
        """
        definition_id = training_definition.id.value
        await self._database.delete(definition_id, TrainingDefinitionsTable.table_name)
        await self._database.commit()