Coverage for src/kwai/core/db/table_row.py: 100%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that defines some dataclasses that can be used as data transfer objects.""" 

2 

3from dataclasses import dataclass, fields 

4from typing import ClassVar, Self 

5 

6from sql_smith.functions import alias 

7from sql_smith.functions import field as sql_field 

8from sql_smith.interfaces import ExpressionInterface 

9 

10from kwai.core.db.database import Record 

11 

12 

13def _validate_dataclass(t): 

14 """Check if all fields contains data with the correct type. 

15 

16 A ValueError will be raised when the data for a given field contains data with 

17 an invalid type. 

18 """ 

19 for k, v in t.__annotations__.items(): 

20 value = getattr(t, k) 

21 if not isinstance(value, v): 

22 raise ValueError(f"{k}({value}) of {t} should be of type {v}!") 

23 

24 

@dataclass(frozen=True, kw_only=True, slots=True)
class TableRow:
    """Data transfer object that represents one row of a single table.

    A class deriving from this one must be a dataclass as well.

    Note:
        The derived class is also a good place to build an entity from the row.
    """

    __table_name__: ClassVar[str]

    @classmethod
    def get_column_alias(cls, name: str, prefix: str | None = None) -> str:
        """Return the alias of a column in the form ``<prefix>_<name>``.

        When no (or an empty) prefix is passed, the table name is used as prefix.
        """
        if not prefix:
            prefix = cls.__table_name__
        return f"{prefix}_{name}"

    @classmethod
    def get_aliases(cls, prefix: str | None = None) -> list[ExpressionInterface]:
        """Return an alias expression for each field of the dataclass.

        Each column is qualified with the table name and aliased with the
        result of ``get_column_alias``.
        """
        return [
            alias(
                f"{cls.__table_name__}.{column.name}",
                cls.get_column_alias(column.name, prefix),
            )
            for column in fields(cls)
        ]

    @classmethod
    def column(cls, column_name: str) -> str:
        """Return the column name qualified with the table name."""
        table_name = cls.__table_name__
        return f"{table_name}.{column_name}"

    @classmethod
    def field(cls, column_name: str):
        """Create a sql-smith field for the given column.

        This is a shortcut for: field(table.table_name + '.' + column_name)
        """
        qualified_name = cls.column(column_name)
        return sql_field(qualified_name)

    @classmethod
    def map(cls, row: Record, prefix: str | None = None) -> Self:
        """Create an instance of the dataclass from the data of a row.

        For each field, the value is looked up in the row with the alias of
        the column (see ``get_column_alias``).

        A ValueError will be raised when a field contains data with the wrong type.
        """
        values = {
            column.name: row.get(cls.get_column_alias(column.name, prefix))
            for column in fields(cls)
        }

        instance = cls(**values)  # noqa
        _validate_dataclass(instance)

        return instance

84 

85 

@dataclass(frozen=True, kw_only=True, slots=True)
class JoinedTableRow:
    """Data transfer object that combines the data of multiple tables.

    Every field of the dataclass stands for one table. The field name serves
    as the prefix of the alias for each column of that table.

    A class deriving from this one must be a dataclass as well.

    Note:
        The derived class is also a good place to build an entity from the row.
    """

    @classmethod
    def get_aliases(cls) -> list[ExpressionInterface]:
        """Return the aliases of the fields of all the TableRow dataclasses.

        The field name is passed as prefix to ``get_aliases`` of the field type.
        """
        table_fields = fields(cls)
        assert len(table_fields) > 0, "There are no fields. Is this a dataclass?"

        return [
            column_alias
            for table_field in table_fields
            for column_alias in table_field.type.get_aliases(table_field.name)
        ]

    @classmethod
    def map(cls, row: Record) -> Self:
        """Map the row to each TableRow dataclass and build an instance with them.

        The field name is passed as prefix to ``map`` of the field type.
        """
        mapped_tables = {
            table_field.name: table_field.type.map(row, table_field.name)
            for table_field in fields(cls)
        }
        return cls(**mapped_tables)  # noqa