Coverage for kwai/modules/training/trainings/training_definition_db_repository.py: 100%
37 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module that implements a training definition repository for a database."""
4from typing import Any, AsyncIterator
6from kwai.core.db.database import Database
7from kwai.core.db.rows import OwnersTable
8from kwai.core.domain.entity import Entity
9from kwai.modules.training.trainings.training_definition import (
10 TrainingDefinitionEntity,
11 TrainingDefinitionIdentifier,
12)
13from kwai.modules.training.trainings.training_definition_db_query import (
14 TrainingDefinitionDbQuery,
15)
16from kwai.modules.training.trainings.training_definition_query import (
17 TrainingDefinitionQuery,
18)
19from kwai.modules.training.trainings.training_definition_repository import (
20 TrainingDefinitionNotFoundException,
21 TrainingDefinitionRepository,
22)
23from kwai.modules.training.trainings.training_tables import (
24 TrainingDefinitionRow,
25 TrainingDefinitionsTable,
26)
def _create_entity(row: dict[str, Any]):
    """Build a training definition entity, including its owner, from one row.

    The same row is wrapped twice: once by TrainingDefinitionsTable, which
    creates the entity, and once by OwnersTable, which creates the owner that
    is attached to that entity.

    Args:
        row: A row as returned by a training definition query.

    Returns:
        The entity produced by TrainingDefinitionsTable.create_entity.
    """
    definitions_table = TrainingDefinitionsTable(row)
    owner = OwnersTable(row).create_owner()
    return definitions_table.create_entity(owner=owner)
class TrainingDefinitionDbRepository(TrainingDefinitionRepository):
    """A training definition repository backed by a database.

    Reads are executed through TrainingDefinitionDbQuery. Every write
    (create, update, delete) is followed by a commit on the injected database.
    """

    def __init__(self, database: Database) -> None:
        """Store the database used by this repository.

        Args:
            database: The database for this repository
        """
        self._database = database

    def create_query(self) -> TrainingDefinitionQuery:
        """Create a new database query for training definitions."""
        return TrainingDefinitionDbQuery(self._database)

    async def get_by_id(
        self, id_: TrainingDefinitionIdentifier
    ) -> TrainingDefinitionEntity:
        """Return the training definition with the given id.

        Args:
            id_: The identifier to look up.

        Raises:
            TrainingDefinitionNotFoundException: When the query returns no row.
        """
        lookup = self.create_query()
        lookup.filter_by_id(id_)

        found = await lookup.fetch_one()
        if not found:
            raise TrainingDefinitionNotFoundException(
                f"Training definition with id {id_} does not exist."
            )
        return _create_entity(found)

    async def get_all(
        self,
        query: TrainingDefinitionQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingDefinitionEntity]:
        """Yield the training definitions returned by a query, one at a time.

        Args:
            query: The query to run; when omitted, a new query from
                create_query() is used.
            limit: Forwarded unchanged to query.fetch.
            offset: Forwarded unchanged to query.fetch.
        """
        active_query = self.create_query() if query is None else query
        async for record in active_query.fetch(limit, offset):
            yield _create_entity(record)

    async def create(
        self, training_definition: TrainingDefinitionEntity
    ) -> TrainingDefinitionEntity:
        """Insert a new training definition and commit.

        Returns:
            The result of Entity.replace with id_ set to the identifier of
            the inserted row.
        """
        row_values = TrainingDefinitionRow.persist(training_definition)
        inserted_id = await self._database.insert(
            TrainingDefinitionsTable.table_name, row_values
        )
        await self._database.commit()
        return Entity.replace(
            training_definition, id_=TrainingDefinitionIdentifier(inserted_id)
        )

    async def update(self, training_definition: TrainingDefinitionEntity):
        """Write the entity's data to the row identified by its id and commit."""
        db = self._database
        await db.update(
            training_definition.id.value,
            TrainingDefinitionsTable.table_name,
            TrainingDefinitionRow.persist(training_definition),
        )
        await db.commit()

    async def delete(self, training_definition: TrainingDefinitionEntity):
        """Delete the row identified by the entity's id and commit."""
        db = self._database
        await db.delete(
            training_definition.id.value, TrainingDefinitionsTable.table_name
        )
        await db.commit()