Coverage for kwai/core/db/database.py: 92%

97 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module for database classes/functions.""" 

2import dataclasses 

3from typing import Any, AsyncIterator, TypeAlias 

4 

5import asyncmy 

6from loguru import logger 

7from sql_smith import QueryFactory 

8from sql_smith.engine import MysqlEngine 

9from sql_smith.functions import field 

10from sql_smith.query import AbstractQuery, SelectQuery 

11 

12from kwai.core.db.exceptions import DatabaseException, QueryException 

13from kwai.core.settings import DatabaseSettings 

14 

# A single database row: maps each column name to that column's value.
# ``Database.fetch_one`` and ``Database.fetch`` return rows in this shape.
Record: TypeAlias = dict[str, Any]

16 

17 

class Database:
    """Class for communicating with a database.

    A single asyncmy connection is opened lazily (on first use through
    ``check_connection``) and reused for every query run by this instance.

    Attributes:
        _connection: The open asyncmy connection, or None when not connected.
        _settings (DatabaseSettings): The settings for this database connection.
    """

    def __init__(self, settings: DatabaseSettings):
        self._connection: asyncmy.Connection | None = None
        self._settings = settings

    async def setup(self):
        """Open the connection to the database.

        Raises:
            (DatabaseException): Raised when the connection can't be established.
        """
        try:
            self._connection = await asyncmy.connect(
                host=self._settings.host,
                database=self._settings.name,
                user=self._settings.user,
                password=self._settings.password,
            )
        except Exception as exc:
            raise DatabaseException(
                f"Setting up connection for database {self._settings.name} "
                f"failed: {exc}"
            ) from exc

    async def check_connection(self):
        """Check if the connection is set, if not it will try to connect."""
        if self._connection is None:
            await self.setup()

    async def close(self):
        """Close the connection.

        The connection reference is cleared even when closing fails, so a later
        call to ``check_connection`` can set up a fresh connection instead of
        reusing a broken one.
        """
        if self._connection is not None:
            try:
                await self._connection.ensure_closed()
            finally:
                self._connection = None

    @classmethod
    def create_query_factory(cls) -> QueryFactory:
        """Return a query factory for the current database engine.

        The query factory is used to start creating a SELECT, INSERT, UPDATE or
        DELETE query.

        Returns:
            (QueryFactory): The query factory from sql-smith.
                Currently, it returns a query factory for the mysql engine. In the
                future it can provide other engines.
        """
        return QueryFactory(MysqlEngine())

    async def commit(self):
        """Commit all changes."""
        await self.check_connection()
        await self._connection.commit()

    async def execute(self, query: AbstractQuery) -> int | None:
        """Execute a query.

        The last rowid from the cursor is returned when the query executed
        successfully. On insert, this can be used to determine the new id of a row.

        Args:
            query (AbstractQuery): The query to execute.

        Returns:
            (int): When the query is an insert query, it will return the last rowid.
            (None): When there is no last rowid.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()
        self.log_query(compiled_query.sql)

        await self.check_connection()
        # Same error wrapping as fetch_one/fetch: cursor acquisition and release
        # are covered too, not only the execute call.
        try:
            async with self._connection.cursor() as cursor:
                await cursor.execute(compiled_query.sql, compiled_query.params)
                return cursor.lastrowid
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

    @staticmethod
    def _column_names(cursor: Any) -> list[str]:
        """Return the column names of the result set of an executed cursor."""
        # Each entry of the DB-API ``description`` starts with the column name.
        return [column[0] for column in cursor.description]

    @staticmethod
    def _create_record(column_names: list[str], row: tuple[Any, ...]) -> Record:
        """Map a fetched row onto the column names of its result set."""
        # strict=True turns a column/value count mismatch into an error.
        return dict(zip(column_names, row, strict=True))

    async def fetch_one(self, query: SelectQuery) -> Record | None:
        """Execute a query and return the first row.

        Args:
            query (SelectQuery): The query to execute.

        Returns:
            (Record): A row is a dictionary using the column names
                as key and the column values as value.
            (None): The query resulted in no rows found.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()
        self.log_query(compiled_query.sql)

        await self.check_connection()
        try:
            async with self._connection.cursor() as cursor:
                await cursor.execute(compiled_query.sql, compiled_query.params)
                column_names = self._column_names(cursor)
                if (row := await cursor.fetchone()) is not None:
                    return self._create_record(column_names, row)
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

        return None  # Nothing found

    async def fetch(self, query: SelectQuery) -> AsyncIterator[Record]:
        """Execute a query and yields each row.

        Args:
            query (SelectQuery): The query to execute.

        Yields:
            (Record): A row is a dictionary using the column names
                as key and the column values as value.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()
        self.log_query(compiled_query.sql)

        await self.check_connection()
        try:
            async with self._connection.cursor() as cursor:
                await cursor.execute(compiled_query.sql, compiled_query.params)
                column_names = self._column_names(cursor)
                while (row := await cursor.fetchone()) is not None:
                    yield self._create_record(column_names, row)
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

    @staticmethod
    def _dataclass_to_record(table_data: Any) -> Record:
        """Convert a dataclass instance to a record without its ``id`` field."""
        assert dataclasses.is_dataclass(table_data), "table_data should be a dataclass"
        record = dataclasses.asdict(table_data)
        # The id is never written: the database generates it on insert, and
        # update receives it separately. Not every dataclass has an id field.
        record.pop("id", None)
        return record

    async def insert(self, table_name: str, *table_data: Any) -> int:
        """Insert one or more instances of a dataclass into the given table.

        All instances are inserted with one multi-row INSERT statement. The
        columns are taken from the first instance; an ``id`` field is skipped.

        Args:
            table_name (str): The name of the table
            table_data (Any): One or more instances of a dataclass containing the values

        Returns:
            (int): The ``lastrowid`` reported for the INSERT statement. Note that
                for a multi-row INSERT, MySQL reports the auto-increment id of the
                *first* inserted row.

        Raises:
            (ValueError): Raised when no dataclass instance is given.
            (QueryException): Raised when the query contains an error.
        """
        if not table_data:
            raise ValueError("insert requires at least one dataclass instance")

        records = [self._dataclass_to_record(data) for data in table_data]

        query = (
            self.create_query_factory().insert(table_name).columns(*records[0].keys())
        )
        for record in records:
            query = query.values(*record.values())

        return await self.execute(query)

    async def update(self, id_: Any, table_name: str, table_data: Any):
        """Update a dataclass in the given table.

        Args:
            id_ (Any): The id of the data to update.
            table_name: The name of the table.
            table_data: The dataclass containing the data. An ``id`` field, when
                present, is not written; ``id_`` selects the row instead.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        record = self._dataclass_to_record(table_data)
        query = (
            self.create_query_factory()
            .update(table_name)
            .set(record)
            .where(field("id").eq(id_))
        )
        await self.execute(query)

    async def delete(self, id_: Any, table_name: str):
        """Delete a row from the table using the id field.

        Args:
            id_ (Any): The id of the row to delete.
            table_name (str): The name of the table.

        Raises:
            (QueryException): Raised when the query results in an error.
        """
        query = (
            self.create_query_factory().delete(table_name).where(field("id").eq(id_))
        )
        await self.execute(query)

    def log_query(self, query: str):
        """Log a query at info level, tagged with the database name.

        Args:
            query (str): The query to log.
        """
        db_logger = logger.bind(database=self._settings.name)
        db_logger.info(
            "DB: {database} - Query: {query}", database=self._settings.name, query=query
        )

    @property
    def settings(self) -> DatabaseSettings:
        """Return the database settings.

        A copy is returned, so changing it does not affect the settings used by
        this instance.
        """
        return self._settings.copy()