Coverage for kwai/core/db/table.py: 100%

30 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module for the table decorator.""" 

2 

3from dataclasses import fields, is_dataclass 

4from typing import Any, Callable, Generic, TypeVar 

5 

6from sql_smith.functions import alias 

7from sql_smith.functions import field as sql_field 

8 

9T = TypeVar("T", bound=Callable) 

10 

11 

12class Table(Generic[T]): 

13 """Represent a table in the database. 

14 

15 With this class a table row can be transformed into a dataclass. It can also 

16 be used to generate columns or aliases for queries. 

17 """ 

18 

19 def __init__(self, table_name: str, data_class: T): 

20 assert is_dataclass(data_class) 

21 self._table_name: str = table_name 

22 self._data_class: T = data_class 

23 

24 @property 

25 def table_name(self) -> str: 

26 """Return the table name.""" 

27 return self._table_name 

28 

29 def __call__(self, row: dict[str, Any], table_name: str | None = None) -> T: 

30 """Shortcut for map_row.""" 

31 return self.map_row(row, table_name) 

32 

33 def alias_name(self, column_name: str, table_name: str | None = None): 

34 """Return an alias for a column. 

35 

36 The alias will be the name of the table delimited with an 

37 underscore: <table_name>_<column_name>. 

38 By default, the table name associated with the Table instance will be used. 

39 

40 Args: 

41 column_name: The name of the column 

42 table_name: To differ from the current table name, use this table name. 

43 

44 Returns: 

45 The alias for the given column. 

46 """ 

47 table_name = table_name or self._table_name 

48 return table_name + "_" + column_name 

49 

50 def aliases(self, table_name: str | None = None): 

51 """Return aliases for all fields of the dataclass.""" 

52 table_name = table_name or self._table_name 

53 return [ 

54 alias(table_name + "." + prop.name, self.alias_name(prop.name, table_name)) 

55 for prop in fields(self._data_class) 

56 ] 

57 

58 def column(self, column_name: str) -> str: 

59 """Return column as <table>.<column>.""" 

60 return self._table_name + "." + column_name 

61 

62 def field(self, column_name: str): 

63 """Call sql-smith field with the given column. 

64 

65 short-cut for: field(table.table_name + '.' + column_name) 

66 """ 

67 return sql_field(self.column(column_name)) 

68 

69 def map_row(self, row: dict[str, Any], table_name: str | None = None) -> T: 

70 """Map the data of a row to the dataclass. 

71 

72 Only the fields that have the alias prefix for this table will be selected. 

73 This makes it possible to pass it a row that contains data from multiple 

74 tables (which can be the case with a join). 

75 """ 

76 table_name = table_name or self._table_name 

77 table_alias = table_name + "_" 

78 # First, only select the values that belong to this table. 

79 filtered = { 

80 k.removeprefix(table_alias): v 

81 for (k, v) in row.items() 

82 if k.startswith(table_alias) 

83 } 

84 return self._data_class(**filtered)