Coverage for src/kwai/core/db/table_row.py: 100%
54 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that defines some dataclasses that can be used as data transfer objects."""
3from dataclasses import dataclass, fields
4from typing import ClassVar, Self
6from sql_smith.functions import alias
7from sql_smith.functions import field as sql_field
8from sql_smith.interfaces import ExpressionInterface
10from kwai.core.db.database import Record
13def _validate_dataclass(t):
14 """Check if all fields contains data with the correct type.
16 A ValueError will be raised when the data for a given field contains data with
17 an invalid type.
18 """
19 for k, v in t.__annotations__.items():
20 value = getattr(t, k)
21 if not isinstance(value, v):
22 raise ValueError(f"{k}({value}) of {t} should be of type {v}!")
@dataclass(frozen=True, kw_only=True, slots=True)
class TableRow:
    """Base data transfer object holding one row of a single database table.

    Subclasses must be dataclasses themselves and must provide
    ``__table_name__``. Each dataclass field corresponds to the table column
    with the same name.

    Note:
        A subclass is a convenient place for the code that builds an entity
        from the row.
    """

    __table_name__: ClassVar[str]

    @classmethod
    def get_column_alias(cls, name: str, prefix: str | None = None) -> str:
        """Build the alias ``<prefix>_<name>`` for a column.

        When ``prefix`` is empty or ``None``, the table name is used as prefix.
        """
        if not prefix:
            prefix = cls.__table_name__
        return f"{prefix}_{name}"

    @classmethod
    def get_aliases(cls, prefix: str | None = None) -> list[ExpressionInterface]:
        """Return one aliased sql-smith expression per dataclass field.

        Each expression refers to ``<table>.<field>`` and is named with the
        alias produced by ``get_column_alias(field, prefix)``.
        """
        return [
            alias(
                f"{cls.__table_name__}.{dc_field.name}",
                cls.get_column_alias(dc_field.name, prefix),
            )
            for dc_field in fields(cls)
        ]

    @classmethod
    def column(cls, column_name: str) -> str:
        """Qualify ``column_name`` with the table name: ``<table>.<column>``."""
        return f"{cls.__table_name__}.{column_name}"

    @classmethod
    def field(cls, column_name: str):
        """Wrap the table-qualified column in a sql-smith ``field``.

        Shortcut for ``sql_field(cls.column(column_name))``.
        """
        qualified_name = cls.column(column_name)
        return sql_field(qualified_name)

    @classmethod
    def map(cls, row: Record, prefix: str | None = None) -> Self:
        """Create an instance of the dataclass from a database record.

        Each field value is looked up with ``row.get`` under the key returned
        by ``get_column_alias(field, prefix)``.

        Raises:
            ValueError: When a field ends up with data of the wrong type.
        """
        values = {
            dc_field.name: row.get(cls.get_column_alias(dc_field.name, prefix))
            for dc_field in fields(cls)
        }
        instance = cls(**values)  # noqa
        _validate_dataclass(instance)
        return instance
@dataclass(frozen=True, kw_only=True, slots=True)
class JoinedTableRow:
    """Data transfer object combining rows that come from several tables.

    Every field of the (dataclass) subclass represents one table and is typed
    with that table's TableRow dataclass. The field name serves as the alias
    prefix for all columns of the associated table.

    Note:
        A subclass is a convenient place for the code that builds an entity
        from the combined rows.
    """

    @classmethod
    def get_aliases(cls) -> list[ExpressionInterface]:
        """Collect the column aliases of every table field, in field order.

        Each field's type is asked for its aliases via ``get_aliases(prefix)``,
        with the field name passed as the prefix.
        """
        table_fields = fields(cls)
        assert table_fields, "There are no fields. Is this a dataclass?"
        return [
            expression
            for table_field in table_fields
            for expression in table_field.type.get_aliases(table_field.name)
        ]

    @classmethod
    def map(cls, row: Record) -> Self:
        """Build an instance by letting each table field map the record.

        Each field's type maps the record with its own ``map`` method, using
        the field name as the alias prefix.
        """
        return cls(  # noqa
            **{
                table_field.name: table_field.type.map(row, table_field.name)
                for table_field in fields(cls)
            }
        )