Coverage for src/kwai/core/db/database.py: 91%
111 statements
« prev ^ index » next coverage.py v7.7.1, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2024-01-01 00:00 +0000
1"""Module for database classes/functions."""
3import dataclasses
5from collections import namedtuple
6from typing import Any, AsyncIterator, TypeAlias
8import asyncmy
10from loguru import logger
11from sql_smith import QueryFactory
12from sql_smith.engine import MysqlEngine
13from sql_smith.functions import field
14from sql_smith.query import AbstractQuery, SelectQuery
16from kwai.core.db.exceptions import DatabaseException, QueryException
17from kwai.core.settings import DatabaseSettings
20Record: TypeAlias = dict[str, Any]
21ExecuteResult = namedtuple("ExecuteResult", ("rowcount", "last_insert_id"))
class Database:
    """Class for communicating with a database.

    A single connection is opened lazily (see check_connection) and reused
    for every query until close is called.

    Attributes:
        _connection: The asyncmy connection, or None when not connected.
        _settings (DatabaseSettings): The settings for this database connection.
    """

    def __init__(self, settings: DatabaseSettings):
        self._connection: asyncmy.Connection | None = None
        self._settings = settings

    async def setup(self):
        """Open the connection to the database.

        Raises:
            (DatabaseException): Raised when the connection cannot be established.
        """
        try:
            self._connection = await asyncmy.connect(
                host=self._settings.host,
                database=self._settings.name,
                user=self._settings.user,
                password=self._settings.password,
            )
        except Exception as exc:
            raise DatabaseException(
                f"Setting up connection for database {self._settings.name} "
                f"failed: {exc}"
            ) from exc

    async def check_connection(self):
        """Check if the connection is set, if not it will try to connect.

        Raises:
            (DatabaseException): Raised when a new connection cannot be set up.
        """
        if self._connection is None:
            await self.setup()

    async def close(self):
        """Close the connection.

        The connection attribute is reset even when closing fails, so the next
        call to check_connection opens a fresh connection instead of reusing a
        half-closed one.
        """
        if self._connection is not None:
            try:
                await self._connection.ensure_closed()
            finally:
                self._connection = None

    @classmethod
    def create_query_factory(cls) -> QueryFactory:
        """Return a query factory for the current database engine.

        The query factory is used to start creating a SELECT, INSERT, UPDATE or
        DELETE query.

        Returns:
            (QueryFactory): The query factory from sql-smith.
                Currently, it returns a query factory for the mysql engine. In the
                future it can provide other engines.
        """
        return QueryFactory(MysqlEngine())

    async def commit(self):
        """Commit all changes."""
        await self.check_connection()
        await self._connection.commit()

    async def _execute_compiled(self, cursor: Any, compiled_query: Any) -> None:
        """Render a compiled query with its params, log it and run it on cursor.

        Shared by execute, fetch_one and fetch so that every query is logged
        with the SQL that is actually sent to the server.
        """
        # compiled_query is the sql-smith compiled query (exposes .sql / .params).
        generated_sql = cursor.mogrify(compiled_query.sql, compiled_query.params)
        self.log_query(generated_sql)
        await cursor.execute(generated_sql)

    @staticmethod
    def _row_to_record(column_names: list[str], row: Any) -> Record:
        """Map each column name to the value at the same position in row."""
        # strict=True: a column/value count mismatch raises instead of truncating.
        return dict(zip(column_names, row, strict=True))

    @staticmethod
    def _dataclass_to_record(data: Any, id_column: str) -> Record:
        """Convert a dataclass instance to a record without its id column."""
        assert dataclasses.is_dataclass(data), "table_data should be a dataclass"
        record = dataclasses.asdict(data)
        # The id column is not written; it may also be absent from the dataclass.
        record.pop(id_column, None)
        return record

    async def execute(self, query: AbstractQuery) -> ExecuteResult:
        """Execute a query.

        Args:
            query (AbstractQuery): The query to execute.

        Returns:
            (ExecuteResult): A named tuple with the number of affected rows
                (rowcount) and the last inserted id reported by the cursor
                (last_insert_id). On insert, last_insert_id can be used to
                determine the new id of a row.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()

        await self.check_connection()
        async with self._connection.cursor() as cursor:
            try:
                await self._execute_compiled(cursor, compiled_query)
                return ExecuteResult(cursor.rowcount, cursor.lastrowid)
            except Exception as exc:
                raise QueryException(compiled_query.sql) from exc

    async def fetch_one(self, query: SelectQuery) -> Record | None:
        """Execute a query and return the first row.

        Args:
            query (SelectQuery): The query to execute.

        Returns:
            (Record): A row is a dictionary using the column names
                as key and the column values as value.
            (None): The query resulted in no rows found.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()

        await self.check_connection()
        try:
            async with self._connection.cursor() as cursor:
                await self._execute_compiled(cursor, compiled_query)
                column_names = [column[0] for column in cursor.description]
                if row := await cursor.fetchone():
                    return self._row_to_record(column_names, row)
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

        return None  # Nothing found

    async def fetch(self, query: SelectQuery) -> AsyncIterator[Record]:
        """Execute a query and yield each row.

        The query is rendered and logged exactly like in execute and fetch_one.

        Args:
            query (SelectQuery): The query to execute.

        Yields:
            (Record): A row is a dictionary using the column names
                as key and the column values as value.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()

        await self.check_connection()
        try:
            async with self._connection.cursor() as cursor:
                await self._execute_compiled(cursor, compiled_query)
                column_names = [column[0] for column in cursor.description]
                while row := await cursor.fetchone():
                    yield self._row_to_record(column_names, row)
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

    async def insert(
        self, table_name: str, *table_data: Any, id_column: str = "id"
    ) -> int:
        """Insert one or more instances of a dataclass into the given table.

        The columns of the INSERT are taken from the first instance; the id
        column is left out of every row.

        Args:
            table_name: The name of the table
            table_data: One or more instances of a dataclass containing the values
            id_column: The name of the id column (default is 'id')

        Returns:
            (int): The last inserted id. When multiple inserts are performed, this will
                be the id of the last executed insert.

        Raises:
            (ValueError): Raised when no table_data is given.
            (QueryException): Raised when the query contains an error.
        """
        if not table_data:
            raise ValueError("table_data should contain at least one dataclass")

        records = [self._dataclass_to_record(data, id_column) for data in table_data]

        query = (
            self.create_query_factory().insert(table_name).columns(*records[0].keys())
        )
        for record in records:
            query = query.values(*record.values())

        execute_result = await self.execute(query)
        return execute_result.last_insert_id

    async def update(
        self, id_: Any, table_name: str, table_data: Any, id_column: str = "id"
    ) -> int:
        """Update a dataclass in the given table.

        The id column itself is never part of the SET clause.

        Args:
            id_: The id of the data to update.
            table_name: The name of the table.
            table_data: The dataclass containing the data.
            id_column: The name of the id column (default is 'id').

        Raises:
            (QueryException): Raised when the query contains an error.

        Returns:
            The number of rows affected.
        """
        record = self._dataclass_to_record(table_data, id_column)
        query = (
            self.create_query_factory()
            .update(table_name)
            .set(record)
            .where(field(id_column).eq(id_))
        )
        execute_result = await self.execute(query)
        return execute_result.rowcount

    async def delete(self, id_: Any, table_name: str, id_column: str = "id"):
        """Delete a row from the table using the id field.

        Args:
            id_ (Any): The id of the row to delete.
            table_name (str): The name of the table.
            id_column (str): The name of the id column (default is 'id')

        Raises:
            (QueryException): Raised when the query results in an error.
        """
        query = (
            self.create_query_factory()
            .delete(table_name)
            .where(field(id_column).eq(id_))
        )
        await self.execute(query)

    def log_query(self, query: str):
        """Log a query.

        Args:
            query (str): The query to log.
        """
        db_logger = logger.bind(database=self._settings.name)
        db_logger.info(
            "DB: {database} - Query: {query}", database=self._settings.name, query=query
        )

    def log_affected_rows(self, rowcount: int):
        """Log the number of affected rows of the last executed query.

        Args:
            rowcount: The number of affected rows.
        """
        db_logger = logger.bind(database=self._settings.name)
        db_logger.info(
            "DB: {database} - Affected rows: {rowcount}",
            database=self._settings.name,
            rowcount=rowcount,
        )

    @property
    def settings(self) -> DatabaseSettings:
        """Return the database settings.

        This property is immutable: the returned value is a copy of the current
        settings.
        """
        return self._settings.model_copy()

    async def begin(self):
        """Start a transaction."""
        await self.check_connection()
        await self._connection.begin()

    async def rollback(self):
        """Rollback a transaction."""
        await self.check_connection()
        await self._connection.rollback()