Coverage for src/kwai/core/db/database.py: 93%
108 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module for database classes/functions."""
3import dataclasses
5from typing import Any, AsyncIterator, TypeAlias
7import asyncmy
9from loguru import logger
10from sql_smith import QueryFactory
11from sql_smith.engine import MysqlEngine
12from sql_smith.functions import field
13from sql_smith.query import AbstractQuery, SelectQuery
15from kwai.core.db.exceptions import DatabaseException, QueryException
16from kwai.core.settings import DatabaseSettings
# One result row: maps each column name of the result set to its value.
Record: TypeAlias = dict[str, Any]


class Database:
    """Class for communicating with a database.

    A single asyncmy connection is used. It is opened on demand: every method
    that talks to the database first calls `check_connection`, which connects
    when no connection is set yet.

    Attributes:
        _connection: The asyncmy connection, or None when not connected.
        _settings (DatabaseSettings): The settings for this database connection.
    """

    def __init__(self, settings: DatabaseSettings):
        self._connection: asyncmy.Connection | None = None
        self._settings = settings

    async def setup(self):
        """Open the connection to the database.

        Raises:
            (DatabaseException): Raised when the connection could not be made.
        """
        try:
            self._connection = await asyncmy.connect(
                host=self._settings.host,
                database=self._settings.name,
                user=self._settings.user,
                password=self._settings.password,
            )
        except Exception as exc:
            raise DatabaseException(
                f"Setting up connection for database {self._settings.name} "
                f"failed: {exc}"
            ) from exc

    async def check_connection(self):
        """Check if the connection is set, if not it will try to connect."""
        if self._connection is None:
            await self.setup()

    async def close(self):
        """Close the connection.

        The stored connection is always cleared, even when closing it fails,
        so that a later `check_connection` opens a fresh connection instead of
        reusing a broken one.
        """
        if self._connection is not None:
            try:
                await self._connection.ensure_closed()
            finally:
                self._connection = None

    @classmethod
    def create_query_factory(cls) -> QueryFactory:
        """Return a query factory for the current database engine.

        The query factory is used to start creating a SELECT, INSERT, UPDATE or
        DELETE query.

        Returns:
            (QueryFactory): The query factory from sql-smith.
                Currently, it returns a query factory for the mysql engine. In the
                future it can provide other engines.
        """
        return QueryFactory(MysqlEngine())

    async def commit(self):
        """Commit all changes."""
        await self.check_connection()
        await self._connection.commit()

    async def execute(self, query: AbstractQuery) -> int | None:
        """Execute a query.

        The last rowid from the cursor is returned when the query executed
        successfully. On insert, this can be used to determine the new id of a row.

        Args:
            query (AbstractQuery): The query to execute.

        Returns:
            (int): When the query is an insert query, it will return the last rowid.
            (None): When there is no last rowid.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()
        self.log_query(compiled_query.sql)

        await self.check_connection()
        # The try also covers opening/closing the cursor, like fetch_one and
        # fetch do, so all database errors surface as QueryException.
        try:
            async with self._connection.cursor() as cursor:
                await cursor.execute(compiled_query.sql, compiled_query.params)
                # DB-API cursors report -1 when the row count is unknown.
                if cursor.rowcount != -1:
                    self.log_affected_rows(cursor.rowcount)
                return cursor.lastrowid
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

    @staticmethod
    def _column_names(cursor: Any) -> list[str]:
        """Return the column names of the result set of the executed query."""
        return [column[0] for column in cursor.description]

    @staticmethod
    def _to_record(column_names: list[str], row: Any) -> Record:
        """Combine the column names with the values of one row into a record."""
        # strict=True: a row with a different number of values than there are
        # columns raises ValueError instead of silently dropping values.
        return dict(zip(column_names, row, strict=True))

    async def fetch_one(self, query: SelectQuery) -> Record | None:
        """Execute a query and return the first row.

        Args:
            query (SelectQuery): The query to execute.

        Returns:
            (Record): A row is a dictionary using the column names
                as key and the column values as value.
            (None): The query resulted in no rows found.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()
        self.log_query(compiled_query.sql)

        await self.check_connection()
        try:
            async with self._connection.cursor() as cursor:
                await cursor.execute(compiled_query.sql, compiled_query.params)
                column_names = self._column_names(cursor)
                if row := await cursor.fetchone():
                    return self._to_record(column_names, row)
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

        return None  # Nothing found

    async def fetch(self, query: SelectQuery) -> AsyncIterator[Record]:
        """Execute a query and yields each row.

        Args:
            query (SelectQuery): The query to execute.

        Yields:
            (Record): A row is a dictionary using the column names
                as key and the column values as value.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()
        self.log_query(compiled_query.sql)

        await self.check_connection()
        # NOTE(review): the try also spans the yield, so an exception thrown
        # into this generator with athrow() is re-raised as QueryException.
        try:
            async with self._connection.cursor() as cursor:
                await cursor.execute(compiled_query.sql, compiled_query.params)
                column_names = self._column_names(cursor)
                while row := await cursor.fetchone():
                    yield self._to_record(column_names, row)
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

    async def insert(self, table_name: str, *table_data: Any) -> int:
        """Insert one or more instances of a dataclass into the given table.

        A field named "id" is never inserted. The columns are taken from the
        first instance; all instances are inserted with one query.

        Args:
            table_name (str): The name of the table
            table_data (Any): One or more instances of a dataclass containing the values

        Returns:
            (int): The last inserted id. When multiple inserts are performed, this will
                be the id of the last executed insert. This is the value returned by
                `execute`, which can be None when the driver reports no last rowid.

        Raises:
            (ValueError): Raised when no table_data is passed.
            (QueryException): Raised when the query contains an error.
        """
        if not table_data:
            raise ValueError("table_data should contain at least one dataclass")

        records: list[Record] = []
        for data in table_data:
            assert dataclasses.is_dataclass(data), "table_data should be a dataclass"
            record = dataclasses.asdict(data)
            # Leave the id out of the insert; the new id is returned instead.
            record.pop("id", None)
            records.append(record)

        query = (
            self.create_query_factory()
            .insert(table_name)
            .columns(*records[0].keys())
        )
        for record in records:
            query = query.values(*record.values())

        last_insert_id = await self.execute(query)
        return last_insert_id

    async def update(self, id_: Any, table_name: str, table_data: Any):
        """Update a dataclass in the given table.

        A field named "id" is never written; the row is selected with `id_`.

        Args:
            id_ (Any): The id of the data to update.
            table_name: The name of the table.
            table_data: The dataclass containing the data.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        assert dataclasses.is_dataclass(table_data), "table_data should be a dataclass"

        record = dataclasses.asdict(table_data)
        # pop instead of del: a dataclass without an id field must not raise
        # KeyError (consistent with insert).
        record.pop("id", None)
        query = (
            self.create_query_factory()
            .update(table_name)
            .set(record)
            .where(field("id").eq(id_))
        )
        await self.execute(query)

    async def delete(self, id_: Any, table_name: str):
        """Delete a row from the table using the id field.

        Args:
            id_ (Any): The id of the row to delete.
            table_name (str): The name of the table.

        Raises:
            (QueryException): Raised when the query results in an error.
        """
        query = (
            self.create_query_factory().delete(table_name).where(field("id").eq(id_))
        )
        await self.execute(query)

    def log_query(self, query: str):
        """Log a query.

        Args:
            query (str): The query to log.
        """
        db_logger = logger.bind(database=self._settings.name)
        db_logger.info(
            "DB: {database} - Query: {query}", database=self._settings.name, query=query
        )

    def log_affected_rows(self, rowcount: int):
        """Log the number of affected rows of the last executed query.

        Args:
            rowcount: The number of affected rows.
        """
        db_logger = logger.bind(database=self._settings.name)
        db_logger.info(
            "DB: {database} - Affected rows: {rowcount}",
            database=self._settings.name,
            rowcount=rowcount,
        )

    @property
    def settings(self) -> DatabaseSettings:
        """Return the database settings.

        This property is immutable: the returned value is a copy of the current
        settings.
        """
        return self._settings.copy()

    async def begin(self):
        """Start a transaction."""
        await self.check_connection()
        await self._connection.begin()

    async def rollback(self):
        """Rollback a transaction."""
        await self.check_connection()
        await self._connection.rollback()