Coverage for kwai/core/db/database_query.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module that implements a query for a database.""" 

2from abc import abstractmethod 

3from typing import Any, AsyncIterator 

4 

5from sql_smith.functions import alias, func 

6from sql_smith.query import SelectQuery 

7 

8from kwai.core.db.database import Database 

9from kwai.core.domain.repository.query import Query 

10 

11 

class DatabaseQuery(Query):
    """Creates a query using a database.

    Subclasses build the base query in `init` and declare the selected
    columns through the `columns` property. A single `SelectQuery` instance
    is created in the constructor and reused by `count`, `fetch_one` and
    `fetch`; each of those methods modifies it before executing it.
    """

    def __init__(self, database: Database):
        """Initialize the query.

        Args:
            database: The database used to execute the query.
        """
        self._database: Database = database
        self._query: SelectQuery = Database.create_query_factory().select()
        # Let the subclass build the base query on the freshly created select.
        self.init()

    @abstractmethod
    def init(self):
        """Override this method to create the base query."""
        raise NotImplementedError

    @property
    @abstractmethod
    def columns(self):
        """Returns the columns used in the query."""
        raise NotImplementedError

    @property
    def count_column(self) -> str:
        """The column used to count records."""
        return "id"

    async def count(self) -> int:
        """Execute the query and count the number of records.

        The `count_column` is used as column for a distinct count.

        Side effects: limit and offset of the underlying query are reset and
        the count expression is set as its columns. `fetch_one` and `fetch`
        set the columns again before they execute.

        Returns:
            The distinct count of `count_column`, or 0 when the database
            returns no row at all.
        """
        # Reset limit/offset to avoid a wrong result
        self._query.limit(None)
        self._query.offset(None)

        self._query.columns(
            alias(func("COUNT", func("DISTINCT", self.count_column)), "c")
        )
        result = await self._database.fetch_one(self._query)
        # Database.fetch_one may return None (see fetch_one's return type);
        # report zero records instead of failing on subscripting None.
        if result is None:
            return 0
        return int(result["c"])

    async def fetch_one(self) -> dict[str, Any] | None:
        """Fetch only one record from this query.

        The columns of the underlying query are set to `columns` first.
        Any limit/offset set by an earlier `fetch` call is left untouched.

        Returns:
            The record returned by the database, or None when there is none.
        """
        self._query.columns(*self.columns)
        return await self._database.fetch_one(self._query)

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Fetch all records from this query.

        The limit, offset and columns of the underlying query are set before
        execution.

        Args:
            limit: Maximum number of records to return; None for no limit.
            offset: Number of records to skip; None for no offset.

        Returns:
            An async iterator over the records.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(*self.columns)

        return self._database.fetch(self._query)