Coverage for src/kwai/api/dependencies.py: 81%

57 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that integrates the dependencies in FastAPI.""" 

2 

3from typing import Annotated, AsyncGenerator 

4 

5import jwt 

6 

7from fastapi import Cookie, Depends, HTTPException, status 

8from fastapi.security import OAuth2PasswordBearer 

9from fastapi.templating import Jinja2Templates 

10from faststream.redis import RedisBroker 

11from faststream.security import SASLPlaintext 

12from jwt import ExpiredSignatureError 

13 

14from kwai.core.db.database import Database 

15from kwai.core.events.fast_stream_publisher import FastStreamPublisher 

16from kwai.core.events.publisher import Publisher 

17from kwai.core.settings import SecuritySettings, Settings, get_settings 

18from kwai.core.template.jinja2_engine import Jinja2Engine 

19from kwai.modules.identity.tokens.access_token_db_repository import ( 

20 AccessTokenDbRepository, 

21) 

22from kwai.modules.identity.tokens.access_token_repository import ( 

23 AccessTokenNotFoundException, 

24) 

25from kwai.modules.identity.tokens.token_identifier import TokenIdentifier 

26from kwai.modules.identity.users.user import UserEntity 

27 

28 

# OAuth2 password-flow bearer scheme; the token URL is the login endpoint.
oauth = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/login",
)

30 

31 

async def create_database(
    settings=Depends(get_settings),
) -> AsyncGenerator[Database, None]:
    """Yield a Database built from the `db` section of the settings.

    The database is closed when the dependency is torn down, whether the
    code that used it finished normally or raised.
    """
    db_settings = settings.db
    db_instance = Database(db_settings)
    try:
        yield db_instance
    finally:
        # Runs on normal completion and on errors alike.
        await db_instance.close()

41 

42 

async def create_templates(settings=Depends(get_settings)) -> Jinja2Templates:
    """Return the web templates of a Jinja2 engine configured for the website."""
    engine = Jinja2Engine(website=settings.website)
    templates = engine.web_templates
    return templates

46 

47 

async def get_current_user(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    access_token: Annotated[str | None, Cookie()] = None,
) -> UserEntity:
    """Return the user that owns the access token found in the cookie.

    Raises:
        HTTPException: 401 when the access token cookie is missing or empty.
            Any error raised while validating the token itself comes from
            `_get_user_from_token` and is propagated unchanged.
    """
    if access_token:
        return await _get_user_from_token(access_token, settings.security, db)

    # Both a missing cookie (None) and an empty one ("") end up here.
    raise HTTPException(
        status.HTTP_401_UNAUTHORIZED, detail="Access token cookie missing"
    )

63 

64 

# Same login token URL as `oauth`, but with auto_error disabled so a request
# without a bearer token is not rejected by the scheme itself.
optional_oauth = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/login",
    auto_error=False,
)

66 

67 

async def get_publisher(
    settings=Depends(get_settings),
) -> AsyncGenerator[Publisher, None]:
    """Get the publisher dependency.

    A Redis broker is created from the `redis` settings (host, port and
    password) and started before a FastStreamPublisher wrapping it is
    yielded. When the dependency is torn down, also after an error, the
    broker is closed again. Otherwise every request that uses this
    dependency would leave an open connection to Redis behind.

    Yields:
        A publisher that sends events through the started Redis broker.
    """
    broker = RedisBroker(
        url=f"redis://{settings.redis.host}:{settings.redis.port}",
        security=SASLPlaintext(
            username="",
            password=settings.redis.password,
        ),
    )
    try:
        # start() is inside the try block so that a broker that fails while
        # starting still gets its (partially) opened connection released.
        await broker.start()
        yield FastStreamPublisher(broker)
    finally:
        await broker.close()

82 

83 

async def get_optional_user(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    access_token: Annotated[str | None, Cookie()] = None,
) -> UserEntity | None:
    """Try to get the current user from an access token.

    When no token is available in the request, None will be returned. A
    missing cookie and an empty cookie are both treated as "no token". This
    is the same truthiness check get_current_user uses to decide that the
    cookie is missing.

    Not authorized will be raised when the access token is expired, revoked
    or when the user is revoked.
    """
    # An empty string carries no token; passing it on to token validation
    # would reject it instead of treating the request as anonymous.
    if not access_token:
        return None

    return await _get_user_from_token(access_token, settings.security, db)

100 

101 

async def _get_user_from_token(
    token: str, security_settings: SecuritySettings, db: Database
) -> UserEntity:
    """Try to get the user from the token.

    The token is decoded as a JWT using the configured secret and algorithm.
    Its `jti` claim identifies the stored access token, which must belong to
    the user in the `sub` claim and must be neither revoked nor expired. The
    user account of that access token must not be revoked either.

    Args:
        token: The encoded JWT.
        security_settings: Provides the JWT secret and algorithm.
        db: Database used to look up the stored access token.

    Returns: The user associated with the access token.

    Raises:
        HTTPException: 401 when the JWT cannot be decoded or validated, when
            the access token is unknown, belongs to another user, is revoked
            or expired, or when the user account is revoked.
    """
    try:
        payload = jwt.decode(
            token,
            security_settings.jwt_secret,
            algorithms=[security_settings.jwt_algorithm],
            # Both claims are read below. Requiring them makes a token
            # without them an invalid token (401) instead of a KeyError (500).
            options={"require": ["jti", "sub"]},
        )
    except ExpiredSignatureError as exc:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
    except jwt.InvalidTokenError as exc:
        # Base class of all other decode failures (bad signature, malformed
        # token, missing required claim, ...). A token that can't be
        # validated is a client error, not a server error.
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail="The access token is invalid."
        ) from exc

    access_token_repo = AccessTokenDbRepository(db)
    try:
        access_token = await access_token_repo.get_by_identifier(
            TokenIdentifier(hex_string=payload["jti"])
        )
    except AccessTokenNotFoundException as exc:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail="The access token is unknown."
        ) from exc

    # Check if the access token is assigned to the user we have in the subject of JWT.
    if not access_token.user_account.user.uuid == payload["sub"]:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    if access_token.revoked:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    if access_token.user_account.revoked:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    if access_token.expired:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    return access_token.user_account.user