Coverage for src/kwai/modules/identity/tokens/token_tables.py: 100%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that defines all tables for tokens.""" 

2 

3from dataclasses import dataclass 

4from datetime import datetime 

5 

6from kwai.core.db.table_row import TableRow 

7from kwai.core.domain.value_objects.timestamp import Timestamp 

8from kwai.core.domain.value_objects.traceable_time import TraceableTime 

9from kwai.modules.identity.tokens.access_token import ( 

10 AccessTokenEntity, 

11 AccessTokenIdentifier, 

12) 

13from kwai.modules.identity.tokens.refresh_token import ( 

14 RefreshTokenEntity, 

15 RefreshTokenIdentifier, 

16) 

17from kwai.modules.identity.tokens.token_identifier import TokenIdentifier 

18from kwai.modules.identity.users.user_account import UserAccountEntity 

19 

20 

@dataclass(kw_only=True, frozen=True, slots=True)
class AccessTokenRow(TableRow):
    """Map one record of the ``oauth_access_tokens`` table.

    Instances are immutable and must be built with keyword arguments. The
    row only holds plain column values; ``create_entity`` and ``persist``
    translate between this row and an ``AccessTokenEntity``.
    """

    __table_name__ = "oauth_access_tokens"

    id: int | None
    identifier: str  # passed to TokenIdentifier as ``hex_string``
    expiration: datetime
    user_id: int
    revoked: int  # written as 1 or 0; only the value 1 is read as revoked
    created_at: datetime
    updated_at: datetime | None

    def create_entity(self, user_account: UserAccountEntity) -> AccessTokenEntity:
        """Build the access token entity described by this row.

        Args:
            user_account: The user account attached to the entity. The row
                itself only stores ``user_id``, which is not used here.

        Returns:
            An access token entity filled with the values of this row.
        """
        token_id = AccessTokenIdentifier(self.id)
        token_identifier = TokenIdentifier(hex_string=self.identifier)
        expires_at = Timestamp.create_utc(self.expiration)
        is_revoked = self.revoked == 1
        created = Timestamp.create_utc(self.created_at)
        # NOTE(review): updated_at may be None; presumably
        # Timestamp.create_utc accepts None -- confirm.
        updated = Timestamp.create_utc(self.updated_at)
        trace = TraceableTime(created_at=created, updated_at=updated)
        return AccessTokenEntity(
            id=token_id,
            identifier=token_identifier,
            expiration=expires_at,
            user_account=user_account,
            revoked=is_revoked,
            traceable_time=trace,
        )

    @classmethod
    def persist(cls, access_token: AccessTokenEntity) -> "AccessTokenRow":
        """Turn an access token entity into a table row.

        Args:
            access_token: The entity whose values are copied into the row.

        Returns:
            A row with the identifier as a string, the revoked flag as
            1 or 0 and the id of the entity's user account as ``user_id``.
        """
        row_id = access_token.id.value
        hex_identifier = str(access_token.identifier)
        expires_at = access_token.expiration.timestamp
        owner_id = access_token.user_account.id.value
        revoked_flag = int(bool(access_token.revoked))
        trace = access_token.traceable_time
        return AccessTokenRow(
            id=row_id,
            identifier=hex_identifier,
            expiration=expires_at,
            user_id=owner_id,
            revoked=revoked_flag,
            created_at=trace.created_at.timestamp,
            updated_at=trace.updated_at.timestamp,
        )

61 

62 

@dataclass(kw_only=True, frozen=True, slots=True)
class RefreshTokenRow(TableRow):
    """Map one record of the ``oauth_refresh_tokens`` table.

    Instances are immutable and must be built with keyword arguments. The
    row only holds plain column values; ``create_entity`` and ``persist``
    translate between this row and a ``RefreshTokenEntity``.
    """

    __table_name__ = "oauth_refresh_tokens"

    id: int | None
    identifier: str  # passed to TokenIdentifier as ``hex_string``
    access_token_id: int
    expiration: datetime
    revoked: int  # written as 1 or 0; only the value 1 is read as revoked
    created_at: datetime
    updated_at: datetime | None

    def create_entity(self, access_token: AccessTokenEntity) -> RefreshTokenEntity:
        """Build the refresh token entity described by this row.

        Args:
            access_token: The access token entity attached to the refresh
                token. The row itself only stores ``access_token_id``,
                which is not used here.

        Returns:
            A refresh token entity filled with the values of this row.
        """
        token_id = RefreshTokenIdentifier(self.id)
        token_identifier = TokenIdentifier(hex_string=self.identifier)
        expires_at = Timestamp.create_utc(self.expiration)
        is_revoked = self.revoked == 1
        created = Timestamp.create_utc(self.created_at)
        # NOTE(review): updated_at may be None; presumably
        # Timestamp.create_utc accepts None -- confirm.
        updated = Timestamp.create_utc(self.updated_at)
        trace = TraceableTime(created_at=created, updated_at=updated)
        return RefreshTokenEntity(
            id=token_id,
            identifier=token_identifier,
            access_token=access_token,
            expiration=expires_at,
            revoked=is_revoked,
            traceable_time=trace,
        )

    @classmethod
    def persist(cls, refresh_token: RefreshTokenEntity) -> "RefreshTokenRow":
        """Turn a refresh token entity into a table row.

        Args:
            refresh_token: The entity whose values are copied into the row.

        Returns:
            A row with the identifier as a string, the revoked flag as
            1 or 0 and the id of the linked access token as
            ``access_token_id``.
        """
        row_id = refresh_token.id.value
        hex_identifier = str(refresh_token.identifier)
        linked_access_token_id = refresh_token.access_token.id.value
        expires_at = refresh_token.expiration.timestamp
        revoked_flag = int(bool(refresh_token.revoked))
        trace = refresh_token.traceable_time
        return RefreshTokenRow(
            id=row_id,
            identifier=hex_identifier,
            access_token_id=linked_access_token_id,
            expiration=expires_at,
            revoked=revoked_flag,
            created_at=trace.created_at.timestamp,
            updated_at=trace.updated_at.timestamp,
        )