Coverage for kwai/core/db/database_query.py: 100%
29 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module that implements a query for a database."""
2from abc import abstractmethod
3from typing import Any, AsyncIterator
5from sql_smith.functions import alias, func
6from sql_smith.query import SelectQuery
8from kwai.core.db.database import Database
9from kwai.core.domain.repository.query import Query
class DatabaseQuery(Query):
    """Creates a query using a database.

    Subclasses build the base SELECT statement in `init` (which is called
    from the constructor) and declare the columns to select with the
    `columns` property. The query is executed with `count`, `fetch_one`
    or `fetch`.
    """

    def __init__(self, database: Database):
        """Create the query.

        Args:
            database: The database used to execute the query.
        """
        self._database: Database = database
        self._query: SelectQuery = Database.create_query_factory().select()
        # Let the subclass build its base query on the freshly created select.
        self.init()

    @abstractmethod
    def init(self) -> None:
        """Override this method to create the base query."""
        raise NotImplementedError

    @property
    @abstractmethod
    def columns(self):
        """Returns the columns used in the query."""
        raise NotImplementedError

    @property
    def count_column(self) -> str:
        """The column used to count records."""
        return "id"

    async def count(self) -> int:
        """Execute the query and counts the number of records.

        The `count_column` is used as column for a distinct count.

        Note: this changes the underlying query. Limit and offset are
        cleared and the selected columns are set to the count expression.
        `fetch` and `fetch_one` set the columns again before executing.

        Returns:
            The number of distinct values of `count_column`, or 0 when the
            database returns no row.
        """
        # Reset limit/offset to avoid a wrong result
        self._query.limit(None)
        self._query.offset(None)

        self._query.columns(
            alias(func("COUNT", func("DISTINCT", self.count_column)), "c")
        )
        result = await self._database.fetch_one(self._query)
        # fetch_one may return None (see the return type of `fetch_one`
        # below); don't index into it in that case.
        if result is None:
            return 0
        return int(result["c"])

    async def fetch_one(self) -> dict[str, Any] | None:
        """Fetch only one record from this query.

        Returns:
            The record returned by the database, or None.
        """
        self._query.columns(*self.columns)
        return await self._database.fetch_one(self._query)

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Fetch all records from this query.

        Args:
            limit: Maximum number of records to return; None to not limit.
            offset: Number of records to skip; None to not skip records.

        Returns:
            An async iterator over the records.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(*self.columns)

        return self._database.fetch(self._query)