Coverage for kwai/core/db/table.py: 100%
30 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module for the table decorator."""
3from dataclasses import fields, is_dataclass
4from typing import Any, Callable, Generic, TypeVar
6from sql_smith.functions import alias
7from sql_smith.functions import field as sql_field
9T = TypeVar("T", bound=Callable)
12class Table(Generic[T]):
13 """Represent a table in the database.
15 With this class a table row can be transformed into a dataclass. It can also
16 be used to generate columns or aliases for queries.
17 """
19 def __init__(self, table_name: str, data_class: T):
20 assert is_dataclass(data_class)
21 self._table_name: str = table_name
22 self._data_class: T = data_class
24 @property
25 def table_name(self) -> str:
26 """Return the table name."""
27 return self._table_name
29 def __call__(self, row: dict[str, Any], table_name: str | None = None) -> T:
30 """Shortcut for map_row."""
31 return self.map_row(row, table_name)
33 def alias_name(self, column_name: str, table_name: str | None = None):
34 """Return an alias for a column.
36 The alias will be the name of the table delimited with an
37 underscore: <table_name>_<column_name>.
38 By default, the table name associated with the Table instance will be used.
40 Args:
41 column_name: The name of the column
42 table_name: To differ from the current table name, use this table name.
44 Returns:
45 The alias for the given column.
46 """
47 table_name = table_name or self._table_name
48 return table_name + "_" + column_name
50 def aliases(self, table_name: str | None = None):
51 """Return aliases for all fields of the dataclass."""
52 table_name = table_name or self._table_name
53 return [
54 alias(table_name + "." + prop.name, self.alias_name(prop.name, table_name))
55 for prop in fields(self._data_class)
56 ]
58 def column(self, column_name: str) -> str:
59 """Return column as <table>.<column>."""
60 return self._table_name + "." + column_name
62 def field(self, column_name: str):
63 """Call sql-smith field with the given column.
65 short-cut for: field(table.table_name + '.' + column_name)
66 """
67 return sql_field(self.column(column_name))
69 def map_row(self, row: dict[str, Any], table_name: str | None = None) -> T:
70 """Map the data of a row to the dataclass.
72 Only the fields that have the alias prefix for this table will be selected.
73 This makes it possible to pass it a row that contains data from multiple
74 tables (which can be the case with a join).
75 """
76 table_name = table_name or self._table_name
77 table_alias = table_name + "_"
78 # First, only select the values that belong to this table.
79 filtered = {
80 k.removeprefix(table_alias): v
81 for (k, v) in row.items()
82 if k.startswith(table_alias)
83 }
84 return self._data_class(**filtered)