Coverage for kwai/modules/training/coaches/coach_db_repository.py: 100%

25 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module that defines a coach repository for a database.""" 

2from typing import AsyncIterator 

3 

4from sql_smith.functions import on 

5from sql_smith.query import SelectQuery 

6 

7from kwai.core.db.database import Database 

8from kwai.modules.training.coaches.coach import CoachEntity, CoachIdentifier 

9from kwai.modules.training.coaches.coach_repository import ( 

10 CoachNotFoundException, 

11 CoachRepository, 

12) 

13from kwai.modules.training.coaches.coach_tables import ( 

14 CoachesTable, 

15 CoachRow, 

16 PersonRow, 

17 PersonsTable, 

18) 

19 

20 

def _create_entity(coach_row: CoachRow, person_row: PersonRow) -> CoachEntity:
    """Build a coach entity out of a coach row and its joined person row.

    Args:
        coach_row: The row read from the coaches table.
        person_row: The row read from the persons table for the same coach.

    Returns:
        The entity produced by ``coach_row.create_entity``.
    """
    entity = coach_row.create_entity(person_row)
    return entity

23 

24 

class CoachDbRepository(CoachRepository):
    """A coach repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def _create_query(self) -> SelectQuery:
        """Create the base select query.

        The query selects the aliased columns of both the coaches and the
        persons table, joining each coach with its person on
        ``coaches.person_id = persons.id``.
        """
        return (
            self._database.create_query_factory()
            .select()
            .from_(CoachesTable.table_name)
            .columns(*(CoachesTable.aliases() + PersonsTable.aliases()))
            .join(
                PersonsTable.table_name,
                on(CoachesTable.column("person_id"), PersonsTable.column("id")),
            )
        )

    async def get_by_id(self, id: CoachIdentifier) -> CoachEntity:
        """Get the coach with the given id.

        Args:
            id: The identifier of the coach.

        Returns:
            The coach entity built from the coach row and its person row.

        Raises:
            CoachNotFoundException: when no row matches the given id.
        """
        query = self._create_query().and_where(CoachesTable.field("id").eq(id.value))
        row = await self._database.fetch_one(query)

        if not row:
            raise CoachNotFoundException(f"Coach with id {id} not found.")

        return _create_entity(CoachesTable(row), PersonsTable(row))

    async def get_by_ids(self, *ids: CoachIdentifier) -> AsyncIterator[CoachEntity]:
        """Yield the coaches matching any of the given ids.

        Ids without a matching row are simply absent from the output; no
        exception is raised for them. No ORDER BY is applied, so the yield
        order is whatever the database returns.

        Args:
            ids: The identifiers of the coaches to fetch.

        Yields:
            A coach entity for each matching row.
        """
        if not ids:
            # An empty ``IN ()`` list is a syntax error in MySQL (and some
            # other dialects); with no ids there is nothing to fetch anyway.
            return

        unpacked_ids = tuple(i.value for i in ids)
        query = self._create_query().and_where(
            CoachesTable.field("id").in_(*unpacked_ids)
        )

        async for row in self._database.fetch(query):
            yield _create_entity(CoachesTable(row), PersonsTable(row))