Coverage for src/kwai/core/db/database.py: 93%

108 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module for database classes/functions.""" 

2 

3import dataclasses 

4 

5from typing import Any, AsyncIterator, TypeAlias 

6 

7import asyncmy 

8 

9from loguru import logger 

10from sql_smith import QueryFactory 

11from sql_smith.engine import MysqlEngine 

12from sql_smith.functions import field 

13from sql_smith.query import AbstractQuery, SelectQuery 

14 

15from kwai.core.db.exceptions import DatabaseException, QueryException 

16from kwai.core.settings import DatabaseSettings 

17 

18 

# A single database row: maps column names to the column values.
Record: TypeAlias = dict[str, Any]

20 

21 

class Database:
    """Class for communicating with a database.

    Attributes:
        _connection: The asyncmy connection, or None as long as no connection
            has been set up (or after it was closed).
        _settings (DatabaseSettings): The settings for this database connection.
    """

    def __init__(self, settings: DatabaseSettings):
        self._connection: asyncmy.Connection | None = None
        self._settings = settings

    async def setup(self) -> None:
        """Open the connection to the database.

        The host, database name, user and password are read from the settings.

        Raises:
            (DatabaseException): Raised when the connection can't be created.
        """
        try:
            self._connection = await asyncmy.connect(
                host=self._settings.host,
                database=self._settings.name,
                user=self._settings.user,
                password=self._settings.password,
            )
        except Exception as exc:
            raise DatabaseException(
                f"Setting up connection for database {self._settings.name} "
                f"failed: {exc}"
            ) from exc

    async def check_connection(self) -> None:
        """Check if the connection is set, if not it will try to connect."""
        if self._connection is None:
            await self.setup()

    async def close(self) -> None:
        """Close the connection.

        Does nothing when there is no connection.
        """
        if self._connection:
            await self._connection.ensure_closed()
            self._connection = None

    @classmethod
    def create_query_factory(cls) -> QueryFactory:
        """Return a query factory for the current database engine.

        The query factory is used to start creating a SELECT, INSERT, UPDATE or
        DELETE query.

        Returns:
            (QueryFactory): The query factory from sql-smith.
                Currently, it returns a query factory for the mysql engine. In the
                future it can provide other engines.
        """
        return QueryFactory(MysqlEngine())

    async def commit(self) -> None:
        """Commit all changes."""
        await self.check_connection()
        await self._connection.commit()

    async def execute(self, query: AbstractQuery) -> int | None:
        """Execute a query.

        The last rowid from the cursor is returned when the query executed
        successfully. On insert, this can be used to determine the new id of a row.

        Args:
            query (AbstractQuery): The query to execute.

        Returns:
            (int): When the query is an insert query, it will return the last rowid.
            (None): When there is no last rowid.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()
        self.log_query(compiled_query.sql)

        await self.check_connection()
        async with self._connection.cursor() as cursor:
            try:
                await cursor.execute(compiled_query.sql, compiled_query.params)
                # A rowcount of -1 is not logged.
                if cursor.rowcount != -1:
                    self.log_affected_rows(cursor.rowcount)
                return cursor.lastrowid
            except Exception as exc:
                raise QueryException(compiled_query.sql) from exc

    async def fetch_one(self, query: SelectQuery) -> Record | None:
        """Execute a query and return the first row.

        Args:
            query (SelectQuery): The query to execute.

        Returns:
            (Record): A row is a dictionary using the column names
                as key and the column values as value.
            (None): The query resulted in no rows found.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()
        self.log_query(compiled_query.sql)

        await self.check_connection()
        try:
            async with self._connection.cursor() as cursor:
                await cursor.execute(compiled_query.sql, compiled_query.params)
                column_names = [column[0] for column in cursor.description]
                if row := await cursor.fetchone():
                    # strict=True: a row/description length mismatch is an error.
                    return dict(zip(column_names, row, strict=True))
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

        return None  # Nothing found

    async def fetch(self, query: SelectQuery) -> AsyncIterator[Record]:
        """Execute a query and yields each row.

        Args:
            query (SelectQuery): The query to execute.

        Yields:
            (Record): A row is a dictionary using the column names
                as key and the column values as value.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()
        self.log_query(compiled_query.sql)

        await self.check_connection()
        try:
            async with self._connection.cursor() as cursor:
                await cursor.execute(compiled_query.sql, compiled_query.params)
                column_names = [column[0] for column in cursor.description]
                # Rows are fetched one by one, so the result is never fully
                # materialized in this method.
                while row := await cursor.fetchone():
                    yield dict(zip(column_names, row, strict=True))
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

    @staticmethod
    def _record_without_id(data: Any) -> Record:
        """Convert a dataclass instance to a dict, leaving out the "id" key.

        A dataclass without an "id" field is accepted as is.
        """
        assert dataclasses.is_dataclass(data), "table_data should be a dataclass"
        record = dataclasses.asdict(data)
        record.pop("id", None)
        return record

    async def insert(self, table_name: str, *table_data: Any) -> int:
        """Insert one or more instances of a dataclass into the given table.

        The "id" field of the dataclass (when present) is not part of the insert.
        The columns of the query are taken from the first instance.

        Args:
            table_name (str): The name of the table
            table_data (Any): One or more instances of a dataclass containing the values

        Returns:
            (int): The last inserted id. When multiple inserts are performed, this will
                be the id of the last executed insert.

        Raises:
            (ValueError): Raised when no table_data is passed.
            (QueryException): Raised when the query contains an error.
        """
        if not table_data:
            raise ValueError("insert requires at least one dataclass instance")

        # Convert every instance exactly once, before the query is built.
        records = [self._record_without_id(data) for data in table_data]

        query = (
            self.create_query_factory().insert(table_name).columns(*records[0].keys())
        )
        for record in records:
            query = query.values(*record.values())

        return await self.execute(query)

    async def update(self, id_: Any, table_name: str, table_data: Any) -> None:
        """Update a dataclass in the given table.

        The "id" field of the dataclass (when present) is not part of the
        SET clause; the row is selected with id_.

        Args:
            id_ (Any): The id of the data to update.
            table_name: The name of the table.
            table_data: The dataclass containing the data.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        record = self._record_without_id(table_data)
        query = (
            self.create_query_factory()
            .update(table_name)
            .set(record)
            .where(field("id").eq(id_))
        )
        await self.execute(query)

    async def delete(self, id_: Any, table_name: str) -> None:
        """Delete a row from the table using the id field.

        Args:
            id_ (Any): The id of the row to delete.
            table_name (str): The name of the table.

        Raises:
            (QueryException): Raised when the query results in an error.
        """
        query = (
            self.create_query_factory().delete(table_name).where(field("id").eq(id_))
        )
        await self.execute(query)

    def log_query(self, query: str) -> None:
        """Log a query.

        Args:
            query (str): The query to log.
        """
        db_logger = logger.bind(database=self._settings.name)
        db_logger.info(
            "DB: {database} - Query: {query}", database=self._settings.name, query=query
        )

    def log_affected_rows(self, rowcount: int) -> None:
        """Log the number of affected rows of the last executed query.

        Args:
            rowcount: The number of affected rows.
        """
        db_logger = logger.bind(database=self._settings.name)
        db_logger.info(
            "DB: {database} - Affected rows: {rowcount}",
            database=self._settings.name,
            rowcount=rowcount,
        )

    @property
    def settings(self) -> DatabaseSettings:
        """Return the database settings.

        This property is immutable: the returned value is a copy of the current
        settings.
        """
        return self._settings.copy()

    async def begin(self) -> None:
        """Start a transaction."""
        await self.check_connection()
        await self._connection.begin()

    async def rollback(self) -> None:
        """Rollback a transaction."""
        await self.check_connection()
        await self._connection.rollback()

279 """Rollback a transaction.""" 

280 await self.check_connection() 

281 await self._connection.rollback()