Coverage for src/kwai/modules/identity/tokens/token_tables.py: 98%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that defines all tables for tokens.""" 

2 

3from dataclasses import dataclass 

4from datetime import datetime 

5from typing import Self 

6 

7from kwai.core.db.table_row import TableRow, unwrap 

8from kwai.core.domain.value_objects.timestamp import Timestamp 

9from kwai.core.domain.value_objects.traceable_time import TraceableTime 

10from kwai.modules.identity.tokens.access_token import ( 

11 AccessTokenEntity, 

12 AccessTokenIdentifier, 

13) 

14from kwai.modules.identity.tokens.refresh_token import ( 

15 RefreshTokenEntity, 

16 RefreshTokenIdentifier, 

17) 

18from kwai.modules.identity.tokens.token_identifier import TokenIdentifier 

19from kwai.modules.identity.tokens.user_log import UserLogEntity, UserLogIdentifier 

20from kwai.modules.identity.tokens.value_objects import IpAddress, OpenId 

21from kwai.modules.identity.users.user_account import UserAccountEntity 

22 

23 

@dataclass(kw_only=True, frozen=True, slots=True)
class AccessTokenRow(TableRow):
    """Represent a table row in the access tokens table.

    ``create_entity`` turns a row into an :class:`AccessTokenEntity` and
    ``persist`` turns an entity back into a row. The ``revoked`` flag is
    stored as an integer (``1`` means revoked).
    """

    __table_name__ = "oauth_access_tokens"

    # Optional in the row type; create_entity requires it to be set.
    id: int | None
    # Hex string form of the token identifier.
    identifier: str
    expiration: datetime
    # Only the owning user's id is stored; the account entity is supplied
    # to create_entity by the caller.
    user_id: int
    # 1 = revoked, anything else is read back as "not revoked".
    revoked: int
    created_at: datetime
    updated_at: datetime | None

    def create_entity(self, user_account: UserAccountEntity) -> AccessTokenEntity:
        """Create an access token entity from the table row.

        Args:
            user_account: The user account that owns this token. The row only
                stores ``user_id``, so the caller must load the account.

        Returns:
            The access token entity built from the row values.
        """
        return AccessTokenEntity(
            # NOTE(review): unwrap presumably rejects a missing id — confirm.
            id=AccessTokenIdentifier(unwrap(self.id)),
            identifier=TokenIdentifier(hex_string=self.identifier),
            expiration=Timestamp.create_utc(self.expiration),
            user_account=user_account,
            revoked=self.revoked == 1,
            traceable_time=TraceableTime(
                created_at=Timestamp.create_utc(self.created_at),
                updated_at=Timestamp.create_utc(self.updated_at),
            ),
        )

    @classmethod
    def persist(cls, access_token: AccessTokenEntity) -> Self:
        """Persist an access token entity to a table record.

        Uses ``cls`` (like ``UserLogRow.persist``) so subclasses get an
        instance of their own type.

        Args:
            access_token: The entity to convert.

        Returns:
            A row holding the entity's values.
        """
        return cls(
            id=access_token.id.value,
            identifier=str(access_token.identifier),
            expiration=unwrap(access_token.expiration.timestamp),
            user_id=access_token.user_account.id.value,
            revoked=1 if access_token.revoked else 0,
            created_at=unwrap(access_token.traceable_time.created_at.timestamp),
            updated_at=access_token.traceable_time.updated_at.timestamp,
        )

64 

65 

@dataclass(kw_only=True, frozen=True, slots=True)
class RefreshTokenRow(TableRow):
    """Represent a table row in the refresh token table.

    ``create_entity`` turns a row into a :class:`RefreshTokenEntity` and
    ``persist`` turns an entity back into a row. The ``revoked`` flag is
    stored as an integer (``1`` means revoked).
    """

    __table_name__ = "oauth_refresh_tokens"

    # Defaults to None; create_entity requires it to be set.
    id: int | None = None
    # Hex string form of the token identifier.
    identifier: str
    # Only the linked access token's id is stored; the entity is supplied
    # to create_entity by the caller.
    access_token_id: int
    expiration: datetime
    # 1 = revoked, anything else is read back as "not revoked".
    revoked: int
    created_at: datetime
    updated_at: datetime | None

    def create_entity(self, access_token: AccessTokenEntity) -> RefreshTokenEntity:
        """Create a refresh token entity from the table row.

        Args:
            access_token: The access token linked to this refresh token. The
                row only stores ``access_token_id``, so the caller must load it.

        Returns:
            The refresh token entity built from the row values.
        """
        return RefreshTokenEntity(
            # NOTE(review): unwrap presumably rejects a missing id — confirm.
            id=RefreshTokenIdentifier(unwrap(self.id)),
            identifier=TokenIdentifier(hex_string=self.identifier),
            access_token=access_token,
            expiration=Timestamp.create_utc(self.expiration),
            revoked=self.revoked == 1,
            traceable_time=TraceableTime(
                # created_at is a non-optional datetime column, so it is
                # wrapped directly (the old unwrap() ran on the Timestamp
                # object itself and checked nothing useful).
                created_at=Timestamp.create_utc(self.created_at),
                updated_at=Timestamp.create_utc(self.updated_at),
            ),
        )

    @classmethod
    def persist(cls, refresh_token: RefreshTokenEntity) -> Self:
        """Transform a refresh token entity into a table record.

        Uses ``cls`` (like ``UserLogRow.persist``) so subclasses get an
        instance of their own type.

        Args:
            refresh_token: The entity to convert.

        Returns:
            A row holding the entity's values.
        """
        return cls(
            id=refresh_token.id.value,
            identifier=str(refresh_token.identifier),
            access_token_id=refresh_token.access_token.id.value,
            expiration=unwrap(refresh_token.expiration.timestamp),
            revoked=1 if refresh_token.revoked else 0,
            created_at=unwrap(refresh_token.traceable_time.created_at.timestamp),
            updated_at=refresh_token.traceable_time.updated_at.timestamp,
        )

106 

107 

@dataclass(kw_only=True, frozen=True, slots=True)
class UserLogRow(TableRow):
    """Represent a table row in the user logs table.

    ``create_entity`` turns a row into a :class:`UserLogEntity` and
    ``persist`` turns an entity back into a row. The ``success`` flag is
    stored as an integer (``1`` means success).
    """

    __table_name__ = "user_logs"

    id: int
    # 1 = success, anything else is read back as "not successful".
    success: int
    email: str
    # None when the log entry is not linked to a user account.
    user_id: int | None
    # None when the log entry is not linked to a refresh token.
    refresh_token_id: int | None
    # String form of the client IP address.
    client_ip: str
    user_agent: str
    openid_sub: str
    openid_provider: str
    remark: str
    created_at: datetime

    def create_entity(
        self,
        user_account: UserAccountEntity | None,
        refresh_token: RefreshTokenEntity | None,
    ) -> UserLogEntity:
        """Create a User Log entity from the table row.

        Args:
            user_account: The linked user account, or None. The row only
                stores ``user_id``, so the caller must load it.
            refresh_token: The linked refresh token, or None. The row only
                stores ``refresh_token_id``, so the caller must load it.

        Returns:
            The user log entity built from the row values.
        """
        return UserLogEntity(
            id=UserLogIdentifier(self.id),
            # persist() writes this column, so it must be read back as well;
            # otherwise a round trip loses whether the attempt succeeded.
            success=self.success == 1,
            email=self.email,
            user_account=user_account,
            refresh_token=refresh_token,
            client_ip=IpAddress.create(self.client_ip),
            user_agent=self.user_agent,
            remark=self.remark,
            openid=OpenId(sub=self.openid_sub, provider=self.openid_provider),
            created_at=Timestamp.create_utc(self.created_at),
        )

    @classmethod
    def persist(cls, user_log: UserLogEntity) -> Self:
        """Transform a user log entity into a table record.

        Args:
            user_log: The entity to convert.

        Returns:
            A row holding the entity's values. Missing user account or
            refresh token links become None foreign keys.
        """
        user_account = user_log.user_account
        refresh_token = user_log.refresh_token
        return cls(
            id=user_log.id.value,
            success=1 if user_log.success else 0,
            email=user_log.email,
            user_id=None if user_account is None else user_account.id.value,
            refresh_token_id=(
                None if refresh_token is None else refresh_token.id.value
            ),
            client_ip=str(user_log.client_ip),
            user_agent=user_log.user_agent,
            openid_sub=user_log.openid.sub,
            openid_provider=user_log.openid.provider,
            remark=user_log.remark,
            created_at=unwrap(user_log.created_at.timestamp),
        )