Coverage for src/kwai/api/v1/auth/endpoints/login.py: 81%

100 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2024-01-01 00:00 +0000

1"""Module that implements all APIs for login.""" 

2 

3from typing import Annotated 

4 

5import jwt 

6 

7from fastapi import ( 

8 APIRouter, 

9 Cookie, 

10 Depends, 

11 Form, 

12 Header, 

13 HTTPException, 

14 Request, 

15 status, 

16) 

17from fastapi.responses import Response 

18from fastapi.security import OAuth2PasswordRequestForm 

19from loguru import logger 

20 

21from kwai.api.dependencies import create_database, get_publisher 

22from kwai.api.v1.auth.cookies import create_cookies, delete_cookies 

23from kwai.core.db.database import Database 

24from kwai.core.db.uow import UnitOfWork 

25from kwai.core.domain.exceptions import UnprocessableException 

26from kwai.core.domain.value_objects.email_address import InvalidEmailException 

27from kwai.core.events.publisher import Publisher 

28from kwai.core.settings import Settings, get_settings 

29from kwai.modules.identity.authenticate_user import ( 

30 AuthenticateUser, 

31 AuthenticateUserCommand, 

32 AuthenticationException, 

33) 

34from kwai.modules.identity.exceptions import NotAllowedException 

35from kwai.modules.identity.logout import Logout, LogoutCommand 

36from kwai.modules.identity.recover_user import RecoverUser, RecoverUserCommand 

37from kwai.modules.identity.refresh_access_token import ( 

38 RefreshAccessToken, 

39 RefreshAccessTokenCommand, 

40) 

41from kwai.modules.identity.reset_password import ( 

42 ResetPassword, 

43 ResetPasswordCommand, 

44 UserRecoveryConfirmedException, 

45) 

46from kwai.modules.identity.tokens.access_token_db_repository import ( 

47 AccessTokenDbRepository, 

48) 

49from kwai.modules.identity.tokens.log_user_login_db_service import LogUserLoginDbService 

50from kwai.modules.identity.tokens.refresh_token_db_repository import ( 

51 RefreshTokenDbRepository, 

52) 

53from kwai.modules.identity.tokens.refresh_token_repository import ( 

54 RefreshTokenNotFoundException, 

55) 

56from kwai.modules.identity.user_recoveries.user_recovery_db_repository import ( 

57 UserRecoveryDbRepository, 

58) 

59from kwai.modules.identity.user_recoveries.user_recovery_repository import ( 

60 UserRecoveryNotFoundException, 

61) 

62from kwai.modules.identity.users.user_account_db_repository import ( 

63 UserAccountDbRepository, 

64) 

65from kwai.modules.identity.users.user_account_repository import ( 

66 UserAccountNotFoundException, 

67) 

68 

69 

70router = APIRouter() 

71 

72 

@router.post(
    "/login",
    summary="Create access and refresh token for a user.",
    responses={
        200: {"description": "The user is logged in successfully."},
        401: {
            "description": "The email is invalid, authentication failed or user is unknown."
        },
    },
)
async def login(
    request: Request,
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    response: Response,
    x_forwarded_for: Annotated[str | None, Header()] = None,
    user_agent: Annotated[str | None, Header()] = "",
):
    """Log a user in.

    The request body is a form (application/x-www-form-urlencoded) with a
    `username` and a `password` field. The username is the email address of
    the user.

    When the login succeeds, the access token and the refresh token are returned
    as cookies.
    """
    command = AuthenticateUserCommand(
        username=form_data.username,
        password=form_data.password,
        access_token_expiry_minutes=settings.security.access_token_expires_in,
        refresh_token_expiry_minutes=settings.security.refresh_token_expires_in,
    )

    # Use the X-Forwarded-For header when it is present; otherwise fall back to
    # the address of the connected client (empty when it is not available).
    client_ip = x_forwarded_for or (request.client.host if request.client else "")

    try:
        async with UnitOfWork(db, always_commit=True):
            user_account_repo = UserAccountDbRepository(db)
            access_token_repo = AccessTokenDbRepository(db)
            refresh_token_repo = RefreshTokenDbRepository(db)
            login_log = LogUserLoginDbService(
                db,
                email=form_data.username,
                user_agent=user_agent or "",
                client_ip=client_ip,
            )
            issued_refresh_token = await AuthenticateUser(
                user_account_repo,
                access_token_repo,
                refresh_token_repo,
                login_log,
            ).execute(command)
    except InvalidEmailException as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email address"
        ) from exc
    except (AuthenticationException, UserAccountNotFoundException) as exc:
        # Both failures are answered with a 401 carrying the exception message.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
        ) from exc

    create_cookies(response, issued_refresh_token, settings)
    response.status_code = status.HTTP_200_OK

    return response

141 

142 

@router.post(
    "/logout",
    summary="Logout the current user",
    responses={200: {"description": "The user is logged out successfully."}},
)
async def logout(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    response: Response,
    refresh_token: Annotated[str | None, Cookie()] = None,
) -> None:
    """Log out the current user.

    A user is logged out by revoking the refresh token. The associated access token
    will also be revoked.

    The refresh token is read from the **refresh_token** cookie.

    Even when the token is missing, expired, invalid or could not be found,
    the cookies will be deleted.
    """
    if refresh_token:
        try:
            decoded_refresh_token = jwt.decode(
                refresh_token,
                key=settings.security.jwt_refresh_secret,
                algorithms=[settings.security.jwt_algorithm],
            )
        except jwt.InvalidTokenError as exc:
            # Covers expired, tampered and malformed tokens. Nothing can be
            # revoked without a trustworthy identifier, but the logout must
            # still complete so the client gets rid of its cookies.
            logger.warning(f"Logout with an unusable refresh token: {exc}")
        else:
            command = LogoutCommand(identifier=decoded_refresh_token["jti"])
            try:
                async with UnitOfWork(db):
                    await Logout(
                        refresh_token_repository=RefreshTokenDbRepository(db),
                        access_token_repository=AccessTokenDbRepository(db),
                    ).execute(command)
            except RefreshTokenNotFoundException:
                # Deliberately ignored: an unknown token leaves nothing to
                # revoke, and the cookies are removed below anyway.
                pass

    delete_cookies(response)
    response.status_code = status.HTTP_200_OK

182 

183 

@router.post(
    "/access_token",
    summary="Renew an access token using a refresh token.",
    responses={
        200: {"description": "The access token is renewed."},
        401: {"description": "The refresh token is expired or invalid."},
    },
)
async def renew_access_token(
    request: Request,
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    refresh_token: Annotated[str, Cookie()],
    response: Response,
    x_forwarded_for: Annotated[str | None, Header()] = None,
    user_agent: Annotated[str | None, Header()] = "",
):
    """Refresh the access token.

    The refresh token is read from the **refresh_token** cookie.

    On success, a new access token / refresh token cookie will be sent.

    When the refresh token is expired or invalid, the user needs to log in again.
    """
    try:
        decoded_refresh_token = jwt.decode(
            refresh_token,
            key=settings.security.jwt_refresh_secret,
            algorithms=[settings.security.jwt_algorithm],
        )
    except jwt.InvalidTokenError as exc:
        # InvalidTokenError is the base class of ExpiredSignatureError, so the
        # expired case is still answered as before, while a tampered or
        # malformed token now results in a 401 instead of an unhandled error.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
        ) from exc

    command = RefreshAccessTokenCommand(
        identifier=decoded_refresh_token["jti"],
        access_token_expiry_minutes=settings.security.access_token_expires_in,
        refresh_token_expiry_minutes=settings.security.refresh_token_expires_in,
    )

    # Use the X-Forwarded-For header when it is present; otherwise fall back to
    # the address of the connected client (empty when it is not available).
    if x_forwarded_for:
        client_ip = x_forwarded_for
    else:
        client_ip = request.client.host if request.client else ""

    try:
        async with UnitOfWork(db, always_commit=True):
            new_refresh_token = await RefreshAccessToken(
                RefreshTokenDbRepository(db),
                AccessTokenDbRepository(db),
                LogUserLoginDbService(
                    db,
                    # No email address is available in this flow.
                    email="",
                    user_agent=user_agent or "",
                    client_ip=client_ip,
                ),
            ).execute(command)
    except AuthenticationException as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
        ) from exc
    except RefreshTokenNotFoundException as exc:
        # NOTE(review): assumes RefreshAccessToken lets the repository's
        # not-found error propagate for an unknown identifier, like Logout
        # does - confirm. Such a token can't be renewed, so answer with 401.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
        ) from exc

    create_cookies(response, new_refresh_token, settings)
    response.status_code = status.HTTP_200_OK

248 

249 

@router.post(
    "/recover",
    summary="Initiate a password reset flow",
    responses={
        200: {"description": "Ok."},
    },
)
async def recover_user(
    db: Annotated[Database, Depends(create_database)],
    publisher: Annotated[Publisher, Depends(get_publisher)],
    email: Annotated[str, Form()],
) -> None:
    """Start the password recovery flow for an email address.

    A mail containing a unique id will be sent by way of the message bus.

    The request body is a form (application/x-www-form-urlencoded) with an
    **email** field.

    !!! Note
        An unknown email address, or a recovery that can't be started, is only
        logged; the response is 200 in those cases as well, to avoid leaking
        information.
    """
    command = RecoverUserCommand(email=email)
    try:
        async with UnitOfWork(db):
            user_account_repo = UserAccountDbRepository(db)
            user_recovery_repo = UserRecoveryDbRepository(db)
            recover = RecoverUser(user_account_repo, user_recovery_repo, publisher)
            await recover.execute(command)
    except UserAccountNotFoundException:
        # Only logged: the caller must not learn whether the account exists.
        logger.warning(f"Unknown email address used for a password recovery: {email}")
    except UnprocessableException as exc:
        logger.warning(f"User recovery could not be started: {exc}")

282 

283 

@router.post(
    "/reset",
    summary="Reset the password of a user.",
    responses={  # noqa B006
        200: {"description": "The password is reset successfully."},
        400: {"description": "The reset code was already used."},
        403: {"description": "This request is forbidden."},
        404: {"description": "The uniqued id of the recovery could not be found."},
        422: {"description": "The user could not be found."},
    },
)
async def reset_password(
    uuid: Annotated[str, Form()],
    password: Annotated[str, Form()],
    db: Annotated[Database, Depends(create_database)],
):
    """Set a new password for a user.

    Responds with http code 200 on success, 400 when the reset code was already
    used, 403 when the request is forbidden, 404 when the unique id is invalid
    and 422 when the request can't be processed.

    The request body is a form (application/x-www-form-urlencoded) with an
    **uuid** and a **password** field. The unique id must be valid and is
    retrieved by [/api/v1/auth/recover][post_/recover].
    """
    command = ResetPasswordCommand(uuid=uuid, password=password)
    try:
        async with UnitOfWork(db):
            reset = ResetPassword(
                user_account_repo=UserAccountDbRepository(db),
                user_recovery_repo=UserRecoveryDbRepository(db),
            )
            await reset.execute(command)
    except UserRecoveryNotFoundException as exc:
        # No recovery is known for the given unique id.
        not_found = HTTPException(status_code=status.HTTP_404_NOT_FOUND)
        raise not_found from exc
    except UserAccountNotFoundException as exc:
        # The user account could not be found.
        unprocessable = HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
        raise unprocessable from exc
    except UserRecoveryConfirmedException as exc:
        # The recovery was already confirmed: the code can't be used again.
        already_used = HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Reset code was already used.",
        )
        raise already_used from exc
    except NotAllowedException as exc:
        forbidden = HTTPException(status_code=status.HTTP_403_FORBIDDEN)
        raise forbidden from exc