Coverage for src/kwai/modules/training/trainings/training_definition_db_repository.py: 98%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements a training definition repository for a database.""" 

2 

3from typing import Any, AsyncIterator 

4 

5from kwai.core.db.database import Database 

6from kwai.core.db.rows import OwnersTable 

7from kwai.core.domain.entity import Entity 

8from kwai.modules.training.teams.team_tables import TeamsTable 

9from kwai.modules.training.trainings.training_definition import ( 

10 TrainingDefinitionEntity, 

11 TrainingDefinitionIdentifier, 

12) 

13from kwai.modules.training.trainings.training_definition_db_query import ( 

14 TrainingDefinitionDbQuery, 

15) 

16from kwai.modules.training.trainings.training_definition_query import ( 

17 TrainingDefinitionQuery, 

18) 

19from kwai.modules.training.trainings.training_definition_repository import ( 

20 TrainingDefinitionNotFoundException, 

21 TrainingDefinitionRepository, 

22) 

23from kwai.modules.training.trainings.training_tables import ( 

24 TrainingDefinitionRow, 

25 TrainingDefinitionsTable, 

26) 

27 

28 

def _create_entity(row: dict[str, Any]):
    """Build a training definition entity from a database row.

    Args:
        row: The row to convert. It is indexed with the aliased ``team_id``
            column name of the training definitions table and is handed to
            the teams, owners and training definitions table wrappers.

    Returns:
        Whatever ``TrainingDefinitionsTable(row).create_entity`` returns. The
        team passed to it is None when the aliased ``team_id`` column is None.
    """
    team_id_column = TrainingDefinitionsTable.alias_name("team_id")
    # Only build a team entity when the row actually carries a team id.
    team = TeamsTable(row).create_entity() if row[team_id_column] is not None else None
    build_definition = TrainingDefinitionsTable(row).create_entity
    return build_definition(team=team, owner=OwnersTable(row).create_owner())

37 

38 

class TrainingDefinitionDbRepository(TrainingDefinitionRepository):
    """Database-backed implementation of the training definition repository."""

    def __init__(self, database: Database) -> None:
        """Initialize the repository.

        Args:
            database: The database for this repository
        """
        self._database = database

    def create_query(self) -> TrainingDefinitionQuery:  # noqa
        """Return a new database query bound to this repository's database."""
        return TrainingDefinitionDbQuery(self._database)

    async def get_by_id(
        self, id_: TrainingDefinitionIdentifier
    ) -> TrainingDefinitionEntity:
        """Fetch the training definition with the given id.

        Args:
            id_: The identifier to filter the query on.

        Returns:
            The entity built from the fetched row.

        Raises:
            TrainingDefinitionNotFoundException: When the query yields no
                (truthy) row.
        """
        lookup = self.create_query()
        lookup.filter_by_id(id_)

        found_row = await lookup.fetch_one()
        if not found_row:
            raise TrainingDefinitionNotFoundException(
                f"Training definition with id {id_} does not exist."
            )
        return _create_entity(found_row)

    async def get_all(
        self,
        query: TrainingDefinitionQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingDefinitionEntity]:
        """Yield an entity for every row fetched by the query.

        Args:
            query: The query to run; a fresh one from ``create_query`` is used
                when None.
            limit: Passed through to the query's ``fetch``.
            offset: Passed through to the query's ``fetch``.

        Yields:
            One entity per fetched row.
        """
        active_query = self.create_query() if query is None else query
        async for fetched_row in active_query.fetch(limit, offset):
            yield _create_entity(fetched_row)

    async def create(
        self, training_definition: TrainingDefinitionEntity
    ) -> TrainingDefinitionEntity:
        """Insert the training definition and commit.

        Args:
            training_definition: The entity to persist.

        Returns:
            The result of ``Entity.replace`` on the given entity, with ``id_``
            set to an identifier wrapping the value returned by the insert.
        """
        table = TrainingDefinitionsTable.table_name
        record = TrainingDefinitionRow.persist(training_definition)
        inserted_id = await self._database.insert(table, record)
        await self._database.commit()

        identifier = TrainingDefinitionIdentifier(inserted_id)
        return Entity.replace(training_definition, id_=identifier)

    async def update(self, training_definition: TrainingDefinitionEntity):
        """Update the stored row of the training definition and commit.

        Args:
            training_definition: The entity whose id value selects the row and
                whose persisted form supplies the new data.
        """
        row_id = training_definition.id.value
        table = TrainingDefinitionsTable.table_name
        record = TrainingDefinitionRow.persist(training_definition)
        await self._database.update(row_id, table, record)
        await self._database.commit()

    async def delete(self, training_definition: TrainingDefinitionEntity):
        """Delete the stored row of the training definition and commit.

        Args:
            training_definition: The entity whose id value selects the row.
        """
        row_id = training_definition.id.value
        table = TrainingDefinitionsTable.table_name
        await self._database.delete(row_id, table)
        await self._database.commit()