Coverage for kwai/modules/training/teams/team_db_repository.py: 100%

16 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module that defines a team repository for a database.""" 

2from typing import AsyncIterator 

3 

4from sql_smith.query import SelectQuery 

5 

6from kwai.core.db.database import Database 

7from kwai.modules.training.teams.team import TeamEntity, TeamIdentifier 

8from kwai.modules.training.teams.team_repository import TeamRepository 

9from kwai.modules.training.teams.team_tables import TeamsTable 

10 

11 

class TeamDbRepository(TeamRepository):
    """A team repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def _create_query(self) -> SelectQuery:
        """Create the base select query.

        Returns:
            A select query on the teams table that selects the column
            aliases defined by `TeamsTable`.
        """
        return (
            self._database.create_query_factory()
            .select()
            .from_(TeamsTable.table_name)
            .columns(*TeamsTable.aliases())
        )

    async def get_by_ids(self, *ids: TeamIdentifier) -> AsyncIterator[TeamEntity]:
        """Get the teams with the given ids.

        Args:
            ids: The identifiers of the teams to fetch.

        Yields:
            A team entity for each row returned by the database.
        """
        # With no ids there is nothing to fetch. Building the query anyway
        # would produce an empty "IN ()" list, which most SQL dialects reject.
        if not ids:
            return

        unpacked_ids = tuple(identifier.value for identifier in ids)
        query = self._create_query().and_where(
            TeamsTable.field("id").in_(*unpacked_ids)
        )

        async for row in self._database.fetch(query):
            yield TeamsTable(row).create_entity()