Coverage for kwai/core/db/database.py: 92%
97 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module for database classes/functions."""
2import dataclasses
3from typing import Any, AsyncIterator, TypeAlias
5import asyncmy
6from loguru import logger
7from sql_smith import QueryFactory
8from sql_smith.engine import MysqlEngine
9from sql_smith.functions import field
10from sql_smith.query import AbstractQuery, SelectQuery
12from kwai.core.db.exceptions import DatabaseException, QueryException
13from kwai.core.settings import DatabaseSettings
15Record: TypeAlias = dict[str, Any]
class Database:
    """Class for communicating with a database.

    Attributes:
        _connection: A connection
        _settings (DatabaseSettings): The settings for this database connection.
    """

    def __init__(self, settings: DatabaseSettings):
        # No connection is made here: it is created by setup(), which is
        # called on demand by check_connection().
        self._connection: asyncmy.Connection | None = None
        self._settings = settings

    async def setup(self):
        """Set up the connection.

        The connection is created with the host, name, user and password of
        the settings.

        Raises:
            (DatabaseException): Raised when the connection can't be created.
        """
        try:
            self._connection = await asyncmy.connect(
                host=self._settings.host,
                database=self._settings.name,
                user=self._settings.user,
                password=self._settings.password,
            )
        except Exception as exc:
            raise DatabaseException(
                f"Setting up connection for database {self._settings.name} "
                f"failed: {exc}"
            ) from exc

    async def check_connection(self):
        """Check if the connection is set, if not it will try to connect."""
        if self._connection is None:
            await self.setup()

    async def close(self):
        """Close the connection."""
        if self._connection:
            await self._connection.ensure_closed()
            # Forget the connection, so check_connection() will create a new one.
            self._connection = None

    @classmethod
    def create_query_factory(cls) -> QueryFactory:
        """Return a query factory for the current database engine.

        The query factory is used to start creating a SELECT, INSERT, UPDATE or
        DELETE query.

        Returns:
            (QueryFactory): The query factory from sql-smith.
                Currently, it returns a query factory for the mysql engine. In the
                future it can provide other engines.
        """
        return QueryFactory(MysqlEngine())

    @staticmethod
    def _row_to_record(column_names: list[str], row: Any) -> Record:
        """Combine the column names and the values of a row into a record."""
        # strict=True: a row with more or fewer values than there are columns
        # raises a ValueError instead of being silently truncated.
        return dict(zip(column_names, row, strict=True))

    @staticmethod
    def _record_without_id(data: Any) -> Record:
        """Convert a dataclass into a record, leaving out the "id" field."""
        assert dataclasses.is_dataclass(data), "table_data should be a dataclass"
        record = dataclasses.asdict(data)
        # A dataclass without an "id" field is accepted as well.
        record.pop("id", None)
        return record

    async def commit(self):
        """Commit all changes."""
        await self.check_connection()
        await self._connection.commit()

    async def execute(self, query: AbstractQuery) -> int | None:
        """Execute a query.

        The last rowid from the cursor is returned when the query executed
        successfully. On insert, this can be used to determine the new id of a row.

        Args:
            query (AbstractQuery): The query to execute.

        Returns:
            (int): When the query is an insert query, it will return the last rowid.
            (None): When there is no last rowid.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()
        self.log_query(compiled_query.sql)

        await self.check_connection()
        # The cursor is acquired inside the try block, just like in fetch_one
        # and fetch, so that every failure is reported as a QueryException.
        try:
            async with self._connection.cursor() as cursor:
                await cursor.execute(compiled_query.sql, compiled_query.params)
                return cursor.lastrowid
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

    async def fetch_one(self, query: SelectQuery) -> Record | None:
        """Execute a query and return the first row.

        Args:
            query (SelectQuery): The query to execute.

        Returns:
            (Record): A row is a dictionary using the column names
                as key and the column values as value.
            (None): The query resulted in no rows found.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()
        self.log_query(compiled_query.sql)

        await self.check_connection()
        try:
            async with self._connection.cursor() as cursor:
                await cursor.execute(compiled_query.sql, compiled_query.params)
                # The first element of each description entry is used as the
                # name of the column.
                column_names = [column[0] for column in cursor.description]
                if row := await cursor.fetchone():
                    return self._row_to_record(column_names, row)
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

        return None  # Nothing found

    async def fetch(self, query: SelectQuery) -> AsyncIterator[Record]:
        """Execute a query and yields each row.

        Args:
            query (SelectQuery): The query to execute.

        Yields:
            (Record): A row is a dictionary using the column names
                as key and the column values as value.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()
        self.log_query(compiled_query.sql)

        await self.check_connection()
        try:
            async with self._connection.cursor() as cursor:
                await cursor.execute(compiled_query.sql, compiled_query.params)
                column_names = [column[0] for column in cursor.description]
                # Rows are fetched one by one, the loop stops on the first
                # falsy result of fetchone.
                while row := await cursor.fetchone():
                    yield self._row_to_record(column_names, row)
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

    async def insert(self, table_name: str, *table_data: Any) -> int:
        """Insert one or more instances of a dataclass into the given table.

        The "id" field of a dataclass is never inserted. The fields of the first
        dataclass determine the columns of the insert query.

        Args:
            table_name (str): The name of the table
            table_data (Any): One or more instances of a dataclass containing the values

        Returns:
            (int): The last inserted id. When multiple inserts are performed, this will
                be the id of the last executed insert.

        Raises:
            (ValueError): Raised when no dataclass is passed.
            (QueryException): Raised when the query contains an error.
        """
        if not table_data:
            # Without this check, the failure was a bare IndexError.
            raise ValueError("insert requires at least one dataclass")

        # Each dataclass is converted only once.
        records = [self._record_without_id(data) for data in table_data]

        query = (
            self.create_query_factory().insert(table_name).columns(*records[0].keys())
        )
        for record in records:
            query = query.values(*record.values())

        return await self.execute(query)

    async def update(self, id_: Any, table_name: str, table_data: Any):
        """Update a dataclass in the given table.

        Args:
            id_ (Any): The id of the data to update.
            table_name: The name of the table.
            table_data: The dataclass containing the data.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        # The "id" field is not written: the row is selected with id_ in the
        # where clause.
        record = self._record_without_id(table_data)
        query = (
            self.create_query_factory()
            .update(table_name)
            .set(record)
            .where(field("id").eq(id_))
        )
        await self.execute(query)

    async def delete(self, id_: Any, table_name: str):
        """Delete a row from the table using the id field.

        Args:
            id_ (Any): The id of the row to delete.
            table_name (str): The name of the table.

        Raises:
            (QueryException): Raised when the query results in an error.
        """
        query = (
            self.create_query_factory().delete(table_name).where(field("id").eq(id_))
        )
        await self.execute(query)

    def log_query(self, query: str):
        """Log a query.

        The query is logged on info level, together with the name of the
        database from the settings.

        Args:
            query (str): The query to log.
        """
        db_logger = logger.bind(database=self._settings.name)
        db_logger.info(
            "DB: {database} - Query: {query}", database=self._settings.name, query=query
        )

    @property
    def settings(self) -> DatabaseSettings:
        """Return the database settings.

        This property is immutable: the returned value is a copy of the current
        settings.
        """
        return self._settings.copy()