Coverage for src/kwai/api/dependencies.py: 80%

56 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2024-01-01 00:00 +0000

1"""Module that integrates the dependencies in FastAPI.""" 

2 

3from typing import Annotated, AsyncGenerator 

4 

5import jwt 

6 

7from fastapi import Cookie, Depends, HTTPException, status 

8from fastapi.security import OAuth2PasswordBearer 

9from fastapi.templating import Jinja2Templates 

10from jwt import ExpiredSignatureError 

11from redis.asyncio import Redis 

12 

13from kwai.core.db.database import Database 

14from kwai.core.events.publisher import Publisher 

15from kwai.core.events.redis_bus import RedisBus 

16from kwai.core.settings import SecuritySettings, Settings, get_settings 

17from kwai.core.template.jinja2_engine import Jinja2Engine 

18from kwai.modules.identity.tokens.access_token_db_repository import ( 

19 AccessTokenDbRepository, 

20) 

21from kwai.modules.identity.tokens.access_token_repository import ( 

22 AccessTokenNotFoundException, 

23) 

24from kwai.modules.identity.tokens.token_identifier import TokenIdentifier 

25from kwai.modules.identity.users.user import UserEntity 

26 

27 

# OAuth2 password-bearer security scheme whose token URL is the login endpoint.
# NOTE(review): the user dependencies in this module read the access token from
# the `access_token` cookie, not from this scheme; presumably it is still
# referenced elsewhere (e.g. for OpenAPI docs) — confirm before removing.
oauth = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")

29 

30 

async def create_database(
    settings=Depends(get_settings),
) -> AsyncGenerator[Database, None]:
    """Provide a Database for the lifetime of a single request.

    The database is built from the `db` section of the application settings.
    It is always closed when FastAPI finishes with the dependency, also when
    the request handling raised an exception.
    """
    request_db = Database(settings.db)
    try:
        yield request_db
    finally:
        # Teardown of the yield-dependency: release the database resources.
        await request_db.close()

40 

41 

async def create_templates(settings=Depends(get_settings)) -> Jinja2Templates:
    """Return the Jinja2 templates used for rendering web pages.

    A Jinja2Engine is configured with the `website` settings and its
    `web_templates` are handed to the caller.
    """
    engine = Jinja2Engine(website=settings.website)
    templates = engine.web_templates
    return templates

45 

46 

async def get_current_user(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    access_token: Annotated[str | None, Cookie()] = None,
) -> UserEntity:
    """Resolve the authenticated user from the `access_token` cookie.

    Args:
        settings: Application settings; the `security` section is used to
            decode the token.
        db: Database used to look up the access token.
        access_token: Value of the `access_token` cookie, if any.

    Returns:
        The user the access token belongs to.

    Raises:
        HTTPException: 401 when the cookie is missing or empty, or when the
            token is not found, expired, revoked or its user is revoked.
    """
    if access_token:
        return await _get_user_from_token(access_token, settings.security, db)
    raise HTTPException(
        status.HTTP_401_UNAUTHORIZED, detail="Access token cookie missing"
    )

62 

63 

# Same bearer scheme as `oauth`, but created with auto_error=False so that a
# missing token does not raise an error by itself.
# NOTE(review): not referenced by the cookie-based dependencies visible here;
# presumably used elsewhere — confirm.
optional_oauth = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login", auto_error=False)

65 

66 

async def get_publisher(
    settings=Depends(get_settings),
) -> AsyncGenerator[Publisher, None]:
    """Get the publisher dependency.

    A Redis client is created from the `redis` section of the settings and
    wrapped in a RedisBus. The client is closed when FastAPI tears down the
    dependency, so its connection pool does not outlive the request.

    Yields:
        A RedisBus publisher backed by a per-request Redis client.
    """
    redis = Redis(
        host=settings.redis.host,
        port=settings.redis.port,
        password=settings.redis.password,
    )
    try:
        yield RedisBus(redis)
    finally:
        # Newer redis-py releases provide `aclose()` and deprecate `close()`;
        # fall back to `close()` when running against an older version.
        close = getattr(redis, "aclose", None) or redis.close
        await close()

78 

79 

async def get_optional_user(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    access_token: Annotated[str | None, Cookie()] = None,
) -> UserEntity | None:
    """Try to get the current user from an access token.

    When no token is available in the request (the cookie is missing or
    empty), None will be returned. This matches `get_current_user`, which
    also treats an empty cookie as missing.

    Args:
        settings: Application settings; the `security` section is used to
            decode the token.
        db: Database used to look up the access token.
        access_token: Value of the `access_token` cookie, if any.

    Returns:
        The user the access token belongs to, or None without a token.

    Raises:
        HTTPException: 401 when the access token is expired, revoked, unknown
            or when the user is revoked.
    """
    # An empty cookie carries no token; without this check it would be handed
    # to the JWT decoder and fail there.
    if not access_token:
        return None

    return await _get_user_from_token(access_token, settings.security, db)

96 

97 

async def _get_user_from_token(
    token: str, security_settings: SecuritySettings, db: Database
) -> UserEntity:
    """Try to get the user from the token.

    The JWT is decoded with the configured secret and algorithm. The access
    token it refers to (claim `jti`) is loaded from the database and checked
    against the JWT subject (claim `sub`) and its revocation/expiration state.

    Args:
        token: The encoded JWT access token.
        security_settings: Provides the JWT secret and algorithm.
        db: Database used to look up the access token.

    Returns:
        The user associated with the access token.

    Raises:
        HTTPException: 401 when the token has expired, cannot be decoded or
            verified, lacks the `jti`/`sub` claims, is unknown, belongs to
            another user, is revoked or expired, or when the user account is
            revoked.
    """
    try:
        payload = jwt.decode(
            token,
            security_settings.jwt_secret,
            algorithms=[security_settings.jwt_algorithm],
        )
    except ExpiredSignatureError as exc:
        # Must come before InvalidTokenError, of which it is a subclass.
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
    except jwt.InvalidTokenError as exc:
        # Malformed tokens, bad signatures, wrong algorithms, ... are client
        # errors; without this they would surface as a 500.
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail="The access token is invalid."
        ) from exc

    # A validly signed token without the claims we rely on is still unusable;
    # indexing the payload directly would raise KeyError (a 500).
    token_id = payload.get("jti")
    subject = payload.get("sub")
    if token_id is None or subject is None:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail="The access token is invalid."
        )

    access_token_repo = AccessTokenDbRepository(db)
    try:
        access_token = await access_token_repo.get_by_identifier(
            TokenIdentifier(hex_string=token_id)
        )
    except AccessTokenNotFoundException as exc:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail="The access token is unknown."
        ) from exc

    # Check if the access token is assigned to the user we have in the subject of JWT.
    if not access_token.user_account.user.uuid == subject:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    if access_token.revoked:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    if access_token.user_account.revoked:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    if access_token.expired:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    return access_token.user_account.user