Coverage for src/kwai/core/db/table.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module for the table decorator.""" 

2 

3from dataclasses import fields, is_dataclass 

4from typing import Any, Callable 

5 

6from sql_smith.functions import alias 

7from sql_smith.functions import field as sql_field 

8 

9 

10class Table[T: Callable]: 

11 """Represent a table in the database. 

12 

13 With this class a table row can be transformed into a dataclass. It can also 

14 be used to generate columns or aliases for queries. 

15 """ 

16 

17 def __init__(self, table_name: str, data_class: T): 

18 assert is_dataclass(data_class) 

19 self._table_name: str = table_name 

20 self._data_class: T = data_class 

21 

22 @property 

23 def table_name(self) -> str: 

24 """Return the table name.""" 

25 return self._table_name 

26 

27 def __call__(self, row: dict[str, Any], table_name: str | None = None) -> T: 

28 """Shortcut for map_row.""" 

29 return self.map_row(row, table_name) 

30 

31 def alias_name(self, column_name: str, table_name: str | None = None) -> str: 

32 """Return an alias for a column. 

33 

34 The alias will be the name of the table delimited with an 

35 underscore: <table_name>_<column_name>. 

36 By default, the table name associated with the Table instance will be used. 

37 

38 Args: 

39 column_name: The name of the column 

40 table_name: To differ from the current table name, use this table name. 

41 

42 Returns: 

43 The alias for the given column. 

44 """ 

45 table_name = table_name or self._table_name 

46 return table_name + "_" + column_name 

47 

48 def aliases(self, table_name: str | None = None): 

49 """Return aliases for all fields of the dataclass.""" 

50 table_name = table_name or self._table_name 

51 return [ 

52 alias(table_name + "." + prop.name, self.alias_name(prop.name, table_name)) 

53 for prop in fields(self._data_class) 

54 ] 

55 

56 def column(self, column_name: str) -> str: 

57 """Return column as <table>.<column>.""" 

58 return self._table_name + "." + column_name 

59 

60 def field(self, column_name: str): 

61 """Call sql-smith field with the given column. 

62 

63 short-cut for: field(table.table_name + '.' + column_name) 

64 """ 

65 return sql_field(self.column(column_name)) 

66 

67 def map_row(self, row: dict[str, Any], table_name: str | None = None) -> T: 

68 """Map the data of a row to the dataclass. 

69 

70 Only the fields that have the alias prefix for this table will be selected. 

71 This makes it possible to pass it a row that contains data from multiple 

72 tables (which can be the case with a join). 

73 """ 

74 table_name = table_name or self._table_name 

75 table_alias = table_name + "_" 

76 # First, only select the values that belong to this table. 

77 filtered = { 

78 k.removeprefix(table_alias): v 

79 for (k, v) in row.items() 

80 if k.startswith(table_alias) 

81 } 

82 return self._data_class(**filtered)