Coverage for src/kwai/core/db/table.py: 100%
29 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module for the table decorator."""
3from dataclasses import fields, is_dataclass
4from typing import Any, Callable
6from sql_smith.functions import alias
7from sql_smith.functions import field as sql_field
10class Table[T: Callable]:
11 """Represent a table in the database.
13 With this class a table row can be transformed into a dataclass. It can also
14 be used to generate columns or aliases for queries.
15 """
17 def __init__(self, table_name: str, data_class: T):
18 assert is_dataclass(data_class)
19 self._table_name: str = table_name
20 self._data_class: T = data_class
22 @property
23 def table_name(self) -> str:
24 """Return the table name."""
25 return self._table_name
27 def __call__(self, row: dict[str, Any], table_name: str | None = None) -> T:
28 """Shortcut for map_row."""
29 return self.map_row(row, table_name)
31 def alias_name(self, column_name: str, table_name: str | None = None) -> str:
32 """Return an alias for a column.
34 The alias will be the name of the table delimited with an
35 underscore: <table_name>_<column_name>.
36 By default, the table name associated with the Table instance will be used.
38 Args:
39 column_name: The name of the column
40 table_name: To differ from the current table name, use this table name.
42 Returns:
43 The alias for the given column.
44 """
45 table_name = table_name or self._table_name
46 return table_name + "_" + column_name
48 def aliases(self, table_name: str | None = None):
49 """Return aliases for all fields of the dataclass."""
50 table_name = table_name or self._table_name
51 return [
52 alias(table_name + "." + prop.name, self.alias_name(prop.name, table_name))
53 for prop in fields(self._data_class)
54 ]
56 def column(self, column_name: str) -> str:
57 """Return column as <table>.<column>."""
58 return self._table_name + "." + column_name
60 def field(self, column_name: str):
61 """Call sql-smith field with the given column.
63 short-cut for: field(table.table_name + '.' + column_name)
64 """
65 return sql_field(self.column(column_name))
67 def map_row(self, row: dict[str, Any], table_name: str | None = None) -> T:
68 """Map the data of a row to the dataclass.
70 Only the fields that have the alias prefix for this table will be selected.
71 This makes it possible to pass it a row that contains data from multiple
72 tables (which can be the case with a join).
73 """
74 table_name = table_name or self._table_name
75 table_alias = table_name + "_"
76 # First, only select the values that belong to this table.
77 filtered = {
78 k.removeprefix(table_alias): v
79 for (k, v) in row.items()
80 if k.startswith(table_alias)
81 }
82 return self._data_class(**filtered)