Coverage for src/kwai/modules/identity/logout.py: 100%

16 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements the logout use case.""" 

2 

3from dataclasses import dataclass 

4 

5from kwai.modules.identity.tokens.access_token_repository import AccessTokenRepository 

6from kwai.modules.identity.tokens.refresh_token_repository import RefreshTokenRepository 

7from kwai.modules.identity.tokens.token_identifier import TokenIdentifier 

8 

9 

10@dataclass(frozen=True, kw_only=True) 

11class LogoutCommand: 

12 """Command for the logout use case. 

13 

14 Attributes: 

15 identifier(str): The refresh token to revoke 

16 """ 

17 

18 identifier: str 

19 

20 

class Logout:
    """Use case that logs out a user.

    Logging out means revoking the user's refresh token and storing the
    result. The access token attached to that refresh token is stored as
    well, so that it is revoked together with the refresh token.

    Attributes:
        _refresh_token_repository (RefreshTokenRepository): Repository used
            to look up and update the refresh token.
        _access_token_repository (AccessTokenRepository): Repository used
            to update the access token of the refresh token.
    """

    def __init__(
        self,
        refresh_token_repository: RefreshTokenRepository,
        access_token_repository: AccessTokenRepository,
    ):
        self._access_token_repository = access_token_repository
        self._refresh_token_repository = refresh_token_repository

    async def execute(self, command: LogoutCommand):
        """Revoke the refresh token named by the command and persist it.

        Args:
            command: The input for this use case; its identifier is used
                as the hex string of the token identifier to look up.

        Raises:
            RefreshTokenNotFoundException: The refresh token with the
                identifier could not be found.
        """
        token_identifier = TokenIdentifier(hex_string=command.identifier)
        token = await self._refresh_token_repository.get_by_token_identifier(
            token_identifier
        )
        token.revoke()

        # Persist the refresh token first, then its access token.
        await self._refresh_token_repository.update(token)
        await self._access_token_repository.update(token.access_token)

58 await self._access_token_repository.update(refresh_token.access_token)