Coverage for src/kwai/modules/training/coaches/coach_db_repository.py: 96%

28 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that defines a coach repository for a database.""" 

2 

3from typing import AsyncIterator 

4 

5from kwai.core.db.database import Database 

6from kwai.modules.training.coaches._tables import ( 

7 CoachRow, 

8 PersonRow, 

9) 

10from kwai.modules.training.coaches.coach import CoachEntity, CoachIdentifier 

11from kwai.modules.training.coaches.coach_db_query import CoachDbQuery, CoachQueryRow 

12from kwai.modules.training.coaches.coach_query import CoachQuery 

13from kwai.modules.training.coaches.coach_repository import ( 

14 CoachNotFoundException, 

15 CoachRepository, 

16) 

17 

18 

def _create_entity(coach_row: CoachRow, person_row: PersonRow) -> CoachEntity:
    """Build a coach entity from a coach row and its associated person row.

    NOTE(review): this helper is not referenced anywhere in this module's
    visible code; confirm whether it is still needed.

    Args:
        coach_row: The coach row that performs the conversion.
        person_row: The person row passed along to the coach row.

    Returns:
        Whatever ``coach_row.create_entity(person_row)`` produces.
    """
    entity = coach_row.create_entity(person_row)
    return entity

21 

22 

class CoachDbRepository(CoachRepository):
    """A coach repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> CoachQuery:
        """Create the coach query."""
        return CoachDbQuery(self._database)

    async def get_by_id(self, id: CoachIdentifier) -> CoachEntity:
        """Get the coach with the given id.

        Args:
            id: The identifier of the coach to look up.

        Returns:
            The coach entity built from the matching row.

        Raises:
            CoachNotFoundException: When the query returns no row for the id.
        """
        query = self.create_query().filter_by_id(id)
        row = await query.fetch_one()

        if not row:
            raise CoachNotFoundException(f"Coach with id {id} not found.")

        return CoachQueryRow.map(row).create_entity()

    async def get_by_ids(self, *ids: CoachIdentifier) -> AsyncIterator[CoachEntity]:
        """Yield the coaches matching the given ids.

        Args:
            ids: The identifiers of the coaches to look up.

        Yields:
            A coach entity for each row returned by the query.
        """
        query = self.create_query().filter_by_ids(*ids)

        async for row in query.fetch():
            yield CoachQueryRow.map(row).create_entity()

    async def get_all(
        self, query: CoachQuery | None = None
    ) -> AsyncIterator[CoachEntity]:
        """Yield all coaches selected by a query.

        Args:
            query: The query to run. When omitted, a fresh query from
                ``create_query`` is used.

        Yields:
            A coach entity for each row returned by the query.
        """
        # Explicit None check: a truthiness test (`query or ...`) would silently
        # replace a caller-supplied query object that happens to evaluate falsy.
        if query is None:
            query = self.create_query()

        async for row in query.fetch():
            yield CoachQueryRow.map(row).create_entity()