Coverage for src/kwai/api/v1/auth/endpoints/login.py: 86%

104 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements all APIs for login.""" 

2 

3from typing import Annotated 

4 

5import jwt 

6 

7from fastapi import APIRouter, Cookie, Depends, Form, HTTPException, status 

8from fastapi.responses import Response 

9from fastapi.security import OAuth2PasswordRequestForm 

10from jwt import ExpiredSignatureError 

11from loguru import logger 

12 

13from kwai.api.dependencies import create_database, get_current_user, get_publisher 

14from kwai.core.db.database import Database 

15from kwai.core.db.uow import UnitOfWork 

16from kwai.core.domain.exceptions import UnprocessableException 

17from kwai.core.domain.value_objects.email_address import InvalidEmailException 

18from kwai.core.events.publisher import Publisher 

19from kwai.core.settings import Settings, get_settings 

20from kwai.modules.identity.authenticate_user import ( 

21 AuthenticateUser, 

22 AuthenticateUserCommand, 

23 AuthenticationException, 

24) 

25from kwai.modules.identity.exceptions import NotAllowedException 

26from kwai.modules.identity.logout import Logout, LogoutCommand 

27from kwai.modules.identity.recover_user import RecoverUser, RecoverUserCommand 

28from kwai.modules.identity.refresh_access_token import ( 

29 RefreshAccessToken, 

30 RefreshAccessTokenCommand, 

31) 

32from kwai.modules.identity.reset_password import ResetPassword, ResetPasswordCommand 

33from kwai.modules.identity.tokens.access_token_db_repository import ( 

34 AccessTokenDbRepository, 

35) 

36from kwai.modules.identity.tokens.refresh_token import RefreshTokenEntity 

37from kwai.modules.identity.tokens.refresh_token_db_repository import ( 

38 RefreshTokenDbRepository, 

39) 

40from kwai.modules.identity.tokens.refresh_token_repository import ( 

41 RefreshTokenNotFoundException, 

42) 

43from kwai.modules.identity.user_recoveries.user_recovery_db_repository import ( 

44 UserRecoveryDbRepository, 

45) 

46from kwai.modules.identity.user_recoveries.user_recovery_repository import ( 

47 UserRecoveryNotFoundException, 

48) 

49from kwai.modules.identity.users.user import UserEntity 

50from kwai.modules.identity.users.user_account_db_repository import ( 

51 UserAccountDbRepository, 

52) 

53from kwai.modules.identity.users.user_account_repository import ( 

54 UserAccountNotFoundException, 

55) 

56 

57 

# Names of the cookies written by `_create_cookie` and cleared by `logout`.
COOKIE_ACCESS_TOKEN: str = "access_token"
COOKIE_REFRESH_TOKEN: str = "refresh_token"
# Marker cookie (value "Y") that is set without `httponly`, next to the token cookies.
COOKIE_KWAI: str = "kwai"


# Router that holds all login-related endpoints of this module.
router: APIRouter = APIRouter()

64 

65 

@router.post(
    "/login",
    summary="Create access and refresh token for a user.",
    responses={
        200: {"description": "The user is logged in successfully."},
        401: {
            "description": "The email is invalid, authentication failed or user is unknown."
        },
    },
)
async def login(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    response: Response,
):
    """Login a user.

    The request body must be a form (application/x-www-form-urlencoded) with a
    `username` field, which holds the email address of the user, and a
    `password` field.

    When the user is authenticated, the access token and the refresh token are
    returned as cookies. Every failure (invalid email, failed authentication,
    unknown user account) is answered with a 401.
    """
    security = settings.security
    command = AuthenticateUserCommand(
        username=form_data.username,
        password=form_data.password,
        access_token_expiry_minutes=security.access_token_expires_in,
        refresh_token_expiry_minutes=security.refresh_token_expires_in,
    )

    try:
        async with UnitOfWork(db):
            token = await AuthenticateUser(
                UserAccountDbRepository(db),
                AccessTokenDbRepository(db),
                RefreshTokenDbRepository(db),
            ).execute(command)
    except InvalidEmailException as exc:
        # A malformed email gets a fixed message instead of the exception text.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email address"
        ) from exc
    except (AuthenticationException, UserAccountNotFoundException) as exc:
        # Failed authentication and unknown accounts are both reported as 401.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
        ) from exc

    _create_cookie(response, token, settings)
    response.status_code = status.HTTP_200_OK
    return response

121 

122 

@router.post(
    "/logout",
    summary="Logout the current user",
    responses={
        200: {"description": "The user is logged out successfully."},
        401: {"description": "The refresh token is invalid or expired."},
        404: {"description": "The token is not found."},
    },
)
async def logout(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    user: Annotated[UserEntity, Depends(get_current_user)],  # noqa
    response: Response,
    refresh_token: Annotated[str | None, Cookie()] = None,
) -> None:
    """Log out the current user.

    A user is logged out by revoking the refresh token. The associated access
    token will also be revoked.

    The refresh token is read from the **refresh_token** cookie (not from a form
    field). When that cookie is absent or empty, nothing is revoked, but the
    kwai, access token and refresh token cookies are still cleared.

    The `user` dependency is not used in the body; presumably it exists so that
    only an authenticated user can reach this endpoint (TODO confirm in
    `get_current_user`).

    Raises:
        HTTPException: 401 when the refresh token cookie cannot be decoded
            (bad signature, malformed, expired); 404 when the refresh token
            is not found.
    """
    if refresh_token:
        try:
            decoded_refresh_token = jwt.decode(
                refresh_token,
                key=settings.security.jwt_refresh_secret,
                algorithms=[settings.security.jwt_algorithm],
            )
        except jwt.InvalidTokenError as exc:
            # Without this, any PyJWT decoding error would surface as a 500.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
            ) from exc

        command = LogoutCommand(identifier=decoded_refresh_token["jti"])
        try:
            async with UnitOfWork(db):
                await Logout(
                    refresh_token_repository=RefreshTokenDbRepository(db),
                    access_token_repository=AccessTokenDbRepository(db),
                ).execute(command)
        except RefreshTokenNotFoundException as ex:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=str(ex)
            ) from ex

    response.delete_cookie(key=COOKIE_KWAI)
    response.delete_cookie(key=COOKIE_ACCESS_TOKEN)
    response.delete_cookie(key=COOKIE_REFRESH_TOKEN)
    response.status_code = status.HTTP_200_OK

168 

169 

@router.post(
    "/access_token",
    summary="Renew an access token using a refresh token.",
    responses={
        200: {"description": "The access token is renewed."},
        401: {"description": "The refresh token is expired or invalid."},
    },
)
async def renew_access_token(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    refresh_token: Annotated[str, Cookie()],
    response: Response,
):
    """Refresh the access token.

    The refresh token is read from the `refresh_token` cookie. On success, a
    new access token / refresh token cookie will be sent.

    When the refresh token is expired or cannot be decoded, or when the use case
    rejects it with an `AuthenticationException`, a 401 is returned and the
    user needs to log in again.
    """
    try:
        decoded_refresh_token = jwt.decode(
            refresh_token,
            key=settings.security.jwt_refresh_secret,
            algorithms=[settings.security.jwt_algorithm],
        )
    except ExpiredSignatureError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
        ) from exc
    except jwt.InvalidTokenError as exc:
        # Bad signature, malformed token, ...: without this clause these PyJWT
        # errors would surface as a 500 instead of an authentication failure.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
        ) from exc

    command = RefreshAccessTokenCommand(
        identifier=decoded_refresh_token["jti"],
        access_token_expiry_minutes=settings.security.access_token_expires_in,
        refresh_token_expiry_minutes=settings.security.refresh_token_expires_in,
    )

    try:
        async with UnitOfWork(db):
            new_refresh_token = await RefreshAccessToken(
                RefreshTokenDbRepository(db), AccessTokenDbRepository(db)
            ).execute(command)
    except AuthenticationException as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
        ) from exc

    _create_cookie(response, new_refresh_token, settings)
    response.status_code = status.HTTP_200_OK

219 

220 

@router.post(
    "/recover",
    summary="Initiate a password reset flow",
    responses={
        200: {"description": "Ok."},
    },
)
async def recover_user(
    db: Annotated[Database, Depends(create_database)],
    publisher: Annotated[Publisher, Depends(get_publisher)],
    email: Annotated[str, Form()],
) -> None:
    """Start a password recovery for the given email address.

    The recovery use case receives the publisher, so the mail with the unique id
    is sent through the message bus.

    This request expects a form (application/x-www-form-urlencoded) that
    contains an **email** field.

    !!! Note
        To avoid leaking information, this api will always respond with 200
    """
    command = RecoverUserCommand(email=email)
    try:
        async with UnitOfWork(db):
            use_case = RecoverUser(
                UserAccountDbRepository(db), UserRecoveryDbRepository(db), publisher
            )
            await use_case.execute(command)
    except UserAccountNotFoundException:
        # Only logged: the response must not reveal whether the account exists.
        logger.warning(f"Unknown email address used for a password recovery: {email}")
    except UnprocessableException as exc:
        logger.warning(f"User recovery could not be started: {exc}")

253 

254 

@router.post(
    "/reset",
    summary="Reset the password of a user.",
    responses={  # noqa B006
        200: {"description": "The password is reset successfully."},
        403: {"description": "This request is forbidden."},
        404: {"description": "The unique id of the recovery could not be found."},
        422: {"description": "The user could not be found."},
    },
)
async def reset_password(
    uuid: Annotated[str, Form()],
    password: Annotated[str, Form()],
    db: Annotated[Database, Depends(create_database)],
):
    """Reset the password of the user.

    Http code 200 on success, 404 when the unique id is invalid, 422 when the
    request can't be processed, 403 when the request is forbidden.

    This request expects a form (application/x-www-form-urlencoded). The form
    must contain a **uuid** and a **password** field. The unique id must be valid
    and is retrieved by [/api/v1/auth/recover][post_/recover].
    """
    command = ResetPasswordCommand(uuid=uuid, password=password)
    try:
        async with UnitOfWork(db):
            await ResetPassword(
                user_account_repo=UserAccountDbRepository(db),
                user_recovery_repo=UserRecoveryDbRepository(db),
            ).execute(command)
    except UserRecoveryNotFoundException as exc:
        # The unique id does not match any user recovery.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) from exc
    except UserAccountNotFoundException as exc:
        # The recovery exists, but the user account could not be found.
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) from exc
    except NotAllowedException as exc:
        # The use case refused the reset (exact conditions live in ResetPassword).
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) from exc

292 

293 

def _create_cookie(
    response: Response, refresh_token: RefreshTokenEntity, settings: Settings
) -> None:
    """Create cookies for the access and refresh token.

    Both tokens are encoded as JWT, and three cookies are set on `response`:

    - `kwai` with value "Y", not http-only, expiring with the refresh token.
      Presumably this lets frontend code see that a session exists, because
      it cannot read the http-only token cookies (TODO confirm).
    - `access_token`: http-only JWT signed with `jwt_secret`, expiring with the
      access token.
    - `refresh_token`: http-only JWT signed with `jwt_refresh_secret`, expiring
      with the refresh token.

    All three cookies are marked `secure` unless `settings.frontend.test` is set.

    Args:
        response: The response the cookies are added to.
        refresh_token: The refresh token; its `access_token` is encoded as well.
        settings: Settings providing the JWT secrets, the JWT algorithm and the
            frontend test flag.
    """
    access_token = refresh_token.access_token
    # One policy for every cookie: only the test frontend opts out of `secure`.
    # (The kwai cookie used to have this flag inverted.)
    secure = not settings.frontend.test

    encoded_access_token = jwt.encode(
        {
            "iat": access_token.traceable_time.created_at.timestamp,
            "exp": access_token.expiration.timestamp,
            "jti": str(access_token.identifier),
            "sub": str(access_token.user_account.user.uuid),
            "scope": [],
        },
        settings.security.jwt_secret,
        settings.security.jwt_algorithm,
    )
    encoded_refresh_token = jwt.encode(
        {
            "iat": refresh_token.traceable_time.created_at.timestamp,
            "exp": refresh_token.expiration.timestamp,
            "jti": str(refresh_token.identifier),
        },
        settings.security.jwt_refresh_secret,
        settings.security.jwt_algorithm,
    )
    response.set_cookie(
        key=COOKIE_KWAI,
        value="Y",
        expires=refresh_token.expiration.timestamp,
        secure=secure,
    )
    response.set_cookie(
        key=COOKIE_ACCESS_TOKEN,
        value=encoded_access_token,
        expires=access_token.expiration.timestamp,
        httponly=True,
        secure=secure,
    )
    response.set_cookie(
        key=COOKIE_REFRESH_TOKEN,
        value=encoded_refresh_token,
        expires=refresh_token.expiration.timestamp,
        httponly=True,
        secure=secure,
    )