Coverage for kwai/modules/training/trainings/training_definition_db_repository.py: 100%

37 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module that implements a training definition repository for a database.""" 

2 

3 

4from typing import Any, AsyncIterator 

5 

6from kwai.core.db.database import Database 

7from kwai.core.db.rows import OwnersTable 

8from kwai.core.domain.entity import Entity 

9from kwai.modules.training.trainings.training_definition import ( 

10 TrainingDefinitionEntity, 

11 TrainingDefinitionIdentifier, 

12) 

13from kwai.modules.training.trainings.training_definition_db_query import ( 

14 TrainingDefinitionDbQuery, 

15) 

16from kwai.modules.training.trainings.training_definition_query import ( 

17 TrainingDefinitionQuery, 

18) 

19from kwai.modules.training.trainings.training_definition_repository import ( 

20 TrainingDefinitionNotFoundException, 

21 TrainingDefinitionRepository, 

22) 

23from kwai.modules.training.trainings.training_tables import ( 

24 TrainingDefinitionRow, 

25 TrainingDefinitionsTable, 

26) 

27 

28 

def _create_entity(row: dict[str, Any]):
    """Turn a fetched database row into a training definition entity.

    The same row is handed to both the training definitions table and the
    owners table. Presumably it is a joined row that carries the columns of
    both; confirm against TrainingDefinitionDbQuery.

    Args:
        row: A mapping of column names to values, as returned by the query.

    Returns:
        The entity built by TrainingDefinitionsTable.create_entity, with the
        owner built from the same row.
    """
    definitions_table = TrainingDefinitionsTable(row)
    owner = OwnersTable(row).create_owner()
    return definitions_table.create_entity(owner=owner)

33 

34 

class TrainingDefinitionDbRepository(TrainingDefinitionRepository):
    """Database-backed implementation of the training definition repository.

    Every write operation (create, update, delete) commits on the database
    before returning.
    """

    def __init__(self, database: Database) -> None:
        """Initialize the repository.

        Args:
            database: The database used for every query and write.
        """
        self._database = database

    def create_query(self) -> TrainingDefinitionQuery:
        """Return a new training definition query.

        Each call builds a fresh TrainingDefinitionDbQuery bound to this
        repository's database.
        """
        return TrainingDefinitionDbQuery(self._database)

    async def get_by_id(
        self, id_: TrainingDefinitionIdentifier
    ) -> TrainingDefinitionEntity:
        """Load a single training definition by its identifier.

        Args:
            id_: Identifier of the training definition to load.

        Returns:
            The matching training definition entity.

        Raises:
            TrainingDefinitionNotFoundException: When the query yields no row.
        """
        query = self.create_query()
        query.filter_by_id(id_)
        row = await query.fetch_one()
        if not row:
            raise TrainingDefinitionNotFoundException(
                f"Training definition with id {id_} does not exist."
            )
        return _create_entity(row)

    async def get_all(
        self,
        query: TrainingDefinitionQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingDefinitionEntity]:
        """Yield the training definitions returned by a query.

        Args:
            query: The query to run. When omitted, a new query from
                create_query() is used.
            limit: Forwarded unchanged to the query's fetch call.
            offset: Forwarded unchanged to the query's fetch call.

        Yields:
            One training definition entity per fetched row.
        """
        active_query = self.create_query() if query is None else query
        async for record in active_query.fetch(limit, offset):
            yield _create_entity(record)

    async def create(
        self, training_definition: TrainingDefinitionEntity
    ) -> TrainingDefinitionEntity:
        """Insert a training definition and commit.

        Args:
            training_definition: The entity to persist.

        Returns:
            The result of Entity.replace applied to the given entity, with
            id_ set to the identifier produced by the insert.
        """
        record = TrainingDefinitionRow.persist(training_definition)
        generated_id = await self._database.insert(
            TrainingDefinitionsTable.table_name, record
        )
        await self._database.commit()
        new_identifier = TrainingDefinitionIdentifier(generated_id)
        return Entity.replace(training_definition, id_=new_identifier)

    async def update(self, training_definition: TrainingDefinitionEntity):
        """Save a training definition's current state and commit.

        Args:
            training_definition: The entity to save; the value of its id
                selects the row to update.
        """
        row_id = training_definition.id.value
        record = TrainingDefinitionRow.persist(training_definition)
        await self._database.update(
            row_id, TrainingDefinitionsTable.table_name, record
        )
        await self._database.commit()

    async def delete(self, training_definition: TrainingDefinitionEntity):
        """Delete a training definition's row and commit.

        Args:
            training_definition: The entity to delete; the value of its id
                selects the row to remove.
        """
        row_id = training_definition.id.value
        await self._database.delete(row_id, TrainingDefinitionsTable.table_name)
        await self._database.commit()