Coverage for src/kwai/core/db/database.py: 91%
111 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module for database classes/functions."""
3import dataclasses
5from collections import namedtuple
6from typing import Any, AsyncIterator, TypeAlias
8import asyncmy
10from loguru import logger
11from sql_smith import QueryFactory
12from sql_smith.engine import MysqlEngine
13from sql_smith.functions import field
14from sql_smith.query import AbstractQuery, SelectQuery
16from kwai.core.db.exceptions import DatabaseException, QueryException
17from kwai.core.settings import DatabaseSettings
20Record: TypeAlias = dict[str, Any]
21ExecuteResult = namedtuple("ExecuteResult", ("rowcount", "last_insert_id"))
24class Database:
25 """Class for communicating with a database.
27 Attributes:
28 _connection: A connection
29 _settings (DatabaseSettings): The settings for this database connection.
30 """
32 def __init__(self, settings: DatabaseSettings):
33 self._connection: asyncmy.Connection | None = None
34 self._settings = settings
36 async def setup(self):
37 """Set up the connection pool."""
38 try:
39 self._connection = await asyncmy.connect(
40 host=self._settings.host,
41 database=self._settings.name,
42 user=self._settings.user,
43 password=self._settings.password,
44 )
45 except Exception as exc:
46 raise DatabaseException(
47 f"Setting up connection for database {self._settings.name} "
48 f"failed: {exc}"
49 ) from exc
51 async def check_connection(self):
52 """Check if the connection is set, if not it will try to connect."""
53 if self._connection is None:
54 await self.setup()
56 async def close(self):
57 """Close the connection."""
58 if self._connection:
59 await self._connection.ensure_closed()
60 self._connection = None
62 @classmethod
63 def create_query_factory(cls) -> QueryFactory:
64 """Return a query factory for the current database engine.
66 The query factory is used to start creating a SELECT, INSERT, UPDATE or
67 DELETE query.
69 Returns:
70 (QueryFactory): The query factory from sql-smith.
71 Currently, it returns a query factory for the mysql engine. In the
72 future it can provide other engines.
73 """
74 return QueryFactory(MysqlEngine())
76 async def commit(self):
77 """Commit all changes."""
78 await self.check_connection()
79 await self._connection.commit()
81 async def execute(self, query: AbstractQuery) -> ExecuteResult:
82 """Execute a query.
84 The last rowid from the cursor is returned when the query executed
85 successfully. On insert, this can be used to determine the new id of a row.
87 Args:
88 query (AbstractQuery): The query to execute.
90 Returns:
91 (int): When the query is an insert query, it will return the last rowid.
92 (None): When there is no last rowid.
94 Raises:
95 (QueryException): Raised when the query contains an error.
96 """
97 compiled_query = query.compile()
99 await self.check_connection()
100 async with self._connection.cursor() as cursor:
101 try:
102 generated_sql = cursor.mogrify(
103 compiled_query.sql, compiled_query.params
104 )
105 self.log_query(generated_sql)
106 await cursor.execute(generated_sql)
107 return ExecuteResult(cursor.rowcount, cursor.lastrowid)
108 except Exception as exc:
109 raise QueryException(compiled_query.sql) from exc
111 async def fetch_one(self, query: SelectQuery) -> Record | None:
112 """Execute a query and return the first row.
114 Args:
115 query (SelectQuery): The query to execute.
117 Returns:
118 (Record): A row is a dictionary using the column names
119 as key and the column values as value.
120 (None): The query resulted in no rows found.
122 Raises:
123 (QueryException): Raised when the query contains an error.
124 """
125 compiled_query = query.compile()
127 await self.check_connection()
128 try:
129 async with self._connection.cursor() as cursor:
130 generated_sql = cursor.mogrify(
131 compiled_query.sql, compiled_query.params
132 )
133 self.log_query(generated_sql)
134 await cursor.execute(generated_sql)
135 column_names = [column[0] for column in cursor.description]
136 if row := await cursor.fetchone():
137 return {
138 column_name: column
139 for column, column_name in zip(row, column_names, strict=True)
140 }
141 except Exception as exc:
142 raise QueryException(compiled_query.sql) from exc
144 return None # Nothing found
146 async def fetch(self, query: SelectQuery) -> AsyncIterator[Record]:
147 """Execute a query and yields each row.
149 Args:
150 query (SelectQuery): The query to execute.
152 Yields:
153 (Record): A row is a dictionary using the column names
154 as key and the column values as value.
156 Raises:
157 (QueryException): Raised when the query contains an error.
158 """
159 compiled_query = query.compile()
160 self.log_query(compiled_query.sql)
162 await self.check_connection()
163 try:
164 async with self._connection.cursor() as cursor:
165 await cursor.execute(compiled_query.sql, compiled_query.params)
166 column_names = [column[0] for column in cursor.description]
167 while row := await cursor.fetchone():
168 yield {
169 column_name: column
170 for column, column_name in zip(row, column_names, strict=True)
171 }
172 except Exception as exc:
173 raise QueryException(compiled_query.sql) from exc
175 async def insert(self, table_name: str, *table_data: Any) -> int:
176 """Insert one or more instances of a dataclass into the given table.
178 Args:
179 table_name (str): The name of the table
180 table_data (Any): One or more instances of a dataclass containing the values
182 Returns:
183 (int): The last inserted id. When multiple inserts are performed, this will
184 be the id of the last executed insert.
186 Raises:
187 (QueryException): Raised when the query contains an error.
188 """
189 assert dataclasses.is_dataclass(table_data[0]), (
190 "table_data should be a dataclass"
191 )
193 record = dataclasses.asdict(table_data[0])
194 if "id" in record:
195 del record["id"]
196 query = self.create_query_factory().insert(table_name).columns(*record.keys())
198 for data in table_data:
199 assert dataclasses.is_dataclass(data), "table_data should be a dataclass"
200 record = dataclasses.asdict(data)
201 if "id" in record:
202 del record["id"]
203 query = query.values(*record.values())
205 execute_result = await self.execute(query)
206 return execute_result.last_insert_id
208 async def update(self, id_: Any, table_name: str, table_data: Any) -> int:
209 """Update a dataclass in the given table.
211 Args:
212 id_ (Any): The id of the data to update.
213 table_name: The name of the table.
214 table_data: The dataclass containing the data.
216 Raises:
217 (QueryException): Raised when the query contains an error.
219 Returns:
220 The number of rows affected.
221 """
222 assert dataclasses.is_dataclass(table_data), "table_data should be a dataclass"
224 record = dataclasses.asdict(table_data)
225 del record["id"]
226 query = (
227 self.create_query_factory()
228 .update(table_name)
229 .set(record)
230 .where(field("id").eq(id_))
231 )
232 execute_result = await self.execute(query)
233 return execute_result.rowcount
235 async def delete(self, id_: Any, table_name: str):
236 """Delete a row from the table using the id field.
238 Args:
239 id_ (Any): The id of the row to delete.
240 table_name (str): The name of the table.
242 Raises:
243 (QueryException): Raised when the query results in an error.
244 """
245 query = (
246 self.create_query_factory().delete(table_name).where(field("id").eq(id_))
247 )
248 await self.execute(query)
250 def log_query(self, query: str):
251 """Log a query.
253 Args:
254 query (str): The query to log.
255 """
256 db_logger = logger.bind(database=self._settings.name)
257 db_logger.info(
258 "DB: {database} - Query: {query}", database=self._settings.name, query=query
259 )
261 def log_affected_rows(self, rowcount: int):
262 """Log the number of affected rows of the last executed query.
264 Args:
265 rowcount: The number of affected rows.
266 """
267 db_logger = logger.bind(database=self._settings.name)
268 db_logger.info(
269 "DB: {database} - Affected rows: {rowcount}",
270 database=self._settings.name,
271 rowcount=rowcount,
272 )
274 @property
275 def settings(self) -> DatabaseSettings:
276 """Return the database settings.
278 This property is immutable: the returned value is a copy of the current
279 settings.
280 """
281 return self._settings.model_copy()
283 async def begin(self):
284 """Start a transaction."""
285 await self.check_connection()
286 await self._connection.begin()
288 async def rollback(self):
289 """Rollback a transaction."""
290 await self.check_connection()
291 await self._connection.rollback()