Coverage for src/kwai/core/db/database_query.py: 100%
28 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that implements a query for a database."""
3from abc import abstractmethod
4from typing import Any, AsyncIterator
6from sql_smith.functions import alias, func
7from sql_smith.query import SelectQuery
9from kwai.core.db.database import Database
10from kwai.core.domain.repository.query import Query
class DatabaseQuery(Query):
    """Creates a query using a database.

    Subclasses build the base SELECT query in `init` and declare the selected
    columns through the `columns` property. This class provides counting and
    fetching on top of that query.
    """

    def __init__(self, database: Database):
        """Create the query and let the subclass initialise it.

        Args:
            database: The database used to execute the query.
        """
        self._database: Database = database
        # Start from an empty SELECT; the subclass fills it in via init().
        self._query: SelectQuery = Database.create_query_factory().select()
        self.init()

    @abstractmethod
    def init(self) -> None:
        """Override this method to create the base query.

        Called from `__init__` after `self._query` has been created.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def columns(self):
        """Returns the columns used in the query.

        The value is unpacked (`*self.columns`) into the select query by
        `fetch` and `fetch_one`, so it must be iterable.
        """
        raise NotImplementedError

    @property
    def count_column(self) -> str:
        """The column used to count records.

        Defaults to "id"; override when another column must be counted.
        """
        return "id"

    async def count(self) -> int:
        """Execute the query and count the number of records.

        The `count_column` is used as column for a distinct count.

        Note: this modifies the underlying query. Limit and offset are cleared
        and the selected columns are replaced by the count expression.
        `fetch` and `fetch_one` set the columns again before executing.

        Returns:
            The distinct count of `count_column`, or 0 when the database
            returns no row at all.
        """
        # Reset limit/offset to avoid a wrong result: they would apply to the
        # aggregated result row(s), not to the records being counted.
        self._query.limit(None)
        self._query.offset(None)

        self._query.columns(
            alias(func("COUNT", func("DISTINCT", self.count_column)), "c")
        )
        result = await self._database.fetch_one(self._query)
        # fetch_one may return None when the query yields no row (e.g. a
        # grouped query without matching groups); report zero records instead
        # of failing with a TypeError when subscripting None.
        if result is None:
            return 0
        return int(result["c"])

    async def fetch_one(self) -> dict[str, Any] | None:
        """Fetch only one record from this query.

        Limit and offset are not changed here. A limit/offset set by an
        earlier `fetch` call therefore stays in effect.

        Returns:
            The record as returned by the database, or None.
        """
        self._query.columns(*self.columns)
        return await self._database.fetch_one(self._query)

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Fetch all records from this query.

        Args:
            limit: Maximum number of records; None clears any previous limit.
            offset: Number of records to skip; None clears any previous offset.

        Returns:
            The async iterator produced by the database for this query.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(*self.columns)

        return self._database.fetch(self._query)