Coverage for src/kwai/core/db/database_query.py: 100%

28 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements a query for a database.""" 

2 

3from abc import abstractmethod 

4from typing import Any, AsyncIterator 

5 

6from sql_smith.functions import alias, func 

7from sql_smith.query import SelectQuery 

8 

9from kwai.core.db.database import Database 

10from kwai.core.domain.repository.query import Query 

11 

12 

class DatabaseQuery(Query):
    """Creates a query using a database.

    Subclasses build the base query in `init` and expose the selected
    columns through `columns`. This class adds counting and fetching on top
    of the select query stored in `self._query`.
    """

    def __init__(self, database: Database):
        """Store the database, create an empty select query and call `init`.

        Args:
            database: The database used to execute the query.
        """
        self._database: Database = database
        self._query: SelectQuery = Database.create_query_factory().select()
        self.init()

    @abstractmethod
    def init(self):
        """Override this method to create the base query."""
        raise NotImplementedError

    @property
    @abstractmethod
    def columns(self):
        """Returns the columns used in the query."""
        raise NotImplementedError

    @property
    def count_column(self) -> str:
        """The column used to count records."""
        return "id"

    async def count(self) -> int:
        """Execute the query and count the number of records.

        The `count_column` is used as column for a distinct count.

        Note that this replaces the columns of the underlying query with the
        count expression and clears any limit/offset set earlier.

        Returns:
            The value of the distinct count, or 0 when the database returns
            no row at all.
        """
        # Reset limit/offset to avoid a wrong result
        self._query.limit(None)
        self._query.offset(None)

        self._query.columns(
            alias(func("COUNT", func("DISTINCT", self.count_column)), "c")
        )
        result = await self._database.fetch_one(self._query)
        # The database's fetch_one is typed as returning a row or None (see
        # `fetch_one` below); treat a missing row as "nothing counted"
        # instead of failing on `None["c"]`.
        if result is None:
            return 0
        return int(result["c"])

    async def fetch_one(self) -> dict[str, Any] | None:
        """Fetch only one record from this query.

        Returns:
            The row as returned by the database, or None when there is none.
        """
        self._query.columns(*self.columns)
        return await self._database.fetch_one(self._query)

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Fetch all records from this query.

        Args:
            limit: Passed to the query as its limit; None by default.
            offset: Passed to the query as its offset; None by default.

        Returns:
            The async iterator produced by the database for this query.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(*self.columns)

        return self._database.fetch(self._query)