Coverage for kwai/api/v1/auth/endpoints/login.py: 84%

82 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module that implements all APIs for login.""" 

2 

3import jwt 

4from fastapi import APIRouter, Depends, Form, HTTPException, Response, status 

5from fastapi.security import OAuth2PasswordRequestForm 

6from loguru import logger 

7from pydantic import BaseModel 

8 

9from kwai.api.dependencies import deps, get_current_user 

10from kwai.core.db.database import Database 

11from kwai.core.domain.exceptions import UnprocessableException 

12from kwai.core.domain.value_objects.email_address import InvalidEmailException 

13from kwai.core.events.bus import Bus 

14from kwai.core.settings import SecuritySettings, Settings 

15from kwai.modules.identity.authenticate_user import ( 

16 AuthenticateUser, 

17 AuthenticateUserCommand, 

18 AuthenticationException, 

19) 

20from kwai.modules.identity.exceptions import NotAllowedException 

21from kwai.modules.identity.logout import Logout, LogoutCommand 

22from kwai.modules.identity.recover_user import RecoverUser, RecoverUserCommand 

23from kwai.modules.identity.refresh_access_token import ( 

24 RefreshAccessToken, 

25 RefreshAccessTokenCommand, 

26) 

27from kwai.modules.identity.reset_password import ResetPassword, ResetPasswordCommand 

28from kwai.modules.identity.tokens.access_token_db_repository import ( 

29 AccessTokenDbRepository, 

30) 

31from kwai.modules.identity.tokens.refresh_token import RefreshTokenEntity 

32from kwai.modules.identity.tokens.refresh_token_db_repository import ( 

33 RefreshTokenDbRepository, 

34) 

35from kwai.modules.identity.tokens.refresh_token_repository import ( 

36 RefreshTokenNotFoundException, 

37) 

38from kwai.modules.identity.user_recoveries.user_recovery_db_repository import ( 

39 UserRecoveryDbRepository, 

40) 

41from kwai.modules.identity.user_recoveries.user_recovery_repository import ( 

42 UserRecoveryNotFoundException, 

43) 

44from kwai.modules.identity.users.user import UserEntity 

45from kwai.modules.identity.users.user_account_db_repository import ( 

46 UserAccountDbRepository, 

47) 

48from kwai.modules.identity.users.user_account_repository import ( 

49 UserAccountNotFoundException, 

50) 

51 

52 

class TokenSchema(BaseModel):
    """The response schema for an access/refresh token.

    Instances are produced by `_encode_token` and returned by the `/login` and
    `/access_token` endpoints.

    Attributes:
        access_token: The access token, encoded as a JWT signed with the
            access-token secret.
        refresh_token: The refresh token, encoded as a JWT signed with the
            refresh-token secret.
        expiration: Expiration of the access token, formatted as
            YYYY-MM-DD HH:MM:SS.
    """

    # Signed JWT for the access token.
    access_token: str
    # Signed JWT for the refresh token.
    refresh_token: str
    # Access-token expiry rendered with isoformat(" ", "seconds").
    expiration: str

65 

66 

# Router that collects the authentication endpoints of this module (login,
# logout, access-token renewal, password recovery and reset).
# NOTE(review): presumably included by the API application under an auth
# prefix — confirm where it is mounted.
router = APIRouter()

68 

69 

@router.post(
    "/login",
    response_model=TokenSchema,
    summary="Create access and refresh token for a user.",
)
async def login(
    settings=deps.depends(Settings),
    db=deps.depends(Database),
    form_data: OAuth2PasswordRequestForm = Depends(),
):
    """Authenticate a user with username/password and issue a token pair.

    Note:
        The request body must be a form (application/x-www-form-urlencoded),
        as parsed by OAuth2PasswordRequestForm.

    Args:
        settings: Settings dependency; its security section provides the
            token lifetimes and the JWT signing configuration.
        db: Database dependency used by the account and token repositories.
        form_data: Form data that carries the username and password.

    Returns:
        A TokenSchema holding the encoded access and refresh token.

    Raises:
        HTTPException: 401 when the email address is invalid, when
            authentication fails, or when no user account is found.
    """
    security = settings.security
    command = AuthenticateUserCommand(
        username=form_data.username,
        password=form_data.password,
        access_token_expiry_minutes=security.access_token_expires_in,
        refresh_token_expiry_minutes=security.refresh_token_expires_in,
    )

    try:
        authenticate = AuthenticateUser(
            UserAccountDbRepository(db),
            AccessTokenDbRepository(db),
            RefreshTokenDbRepository(db),
        )
        refresh_token = await authenticate.execute(command)
    except InvalidEmailException as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email address"
        ) from exc
    except (AuthenticationException, UserAccountNotFoundException) as exc:
        # Both failures map to the same response: 401 with the exception text.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
        ) from exc

    return _encode_token(refresh_token, security)

119 

120 

@router.post("/logout", summary="Logout the current user")
async def logout(
    settings=deps.depends(Settings),
    db: Database = deps.depends(Database),
    user: UserEntity = Depends(get_current_user),
    refresh_token: str = Form(),
):
    """Log out the current user.

    A user is logged out by revoking the refresh token. The associated access
    token is revoked as well.

    NOTE(review): `user` is resolved by the `get_current_user` dependency but is
    not compared with the owner of the refresh token — confirm whether Logout
    should enforce that.

    Args:
        settings: Settings dependency (refresh secret and JWT algorithm).
        db: Database dependency.
        user: The currently logged-in user.
        refresh_token: The active refresh token of the user.

    Returns:
        Http code 200 on success, 401 when the user is not logged in or the
        refresh token cannot be decoded/verified, 404 when the refresh token
        is not found.
    """
    try:
        decoded_refresh_token = jwt.decode(
            refresh_token,
            key=settings.security.jwt_refresh_secret,
            algorithms=[settings.security.jwt_algorithm],
        )
    except jwt.InvalidTokenError as exc:
        # Malformed, tampered or expired tokens are client errors; without this
        # handler the PyJWT exception would propagate as an unhandled error.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
        ) from exc

    command = LogoutCommand(identifier=decoded_refresh_token["jti"])
    try:
        await Logout(
            refresh_token_repository=RefreshTokenDbRepository(db),
            access_token_repository=AccessTokenDbRepository(db),
        ).execute(command)
    except RefreshTokenNotFoundException as ex:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(ex)
        ) from ex

158 

159 

@router.post(
    "/access_token",
    response_model=TokenSchema,
    summary="Renew an access token using a refresh token.",
)
async def renew_access_token(
    settings=deps.depends(Settings),
    db=deps.depends(Database),
    refresh_token: str = Form(),
):
    """Refresh the access token.

    The refresh token is decoded with the refresh secret. Its `jti` claim
    identifies the stored refresh token that is used to issue a new pair.

    Args:
        settings(Settings): Settings dependency (JWT secrets/algorithm and
            token lifetimes).
        db(Database): Database dependency
        refresh_token(str): The active refresh token of the user

    Returns:
        TokenSchema: On success a new TokenSchema is returned.

    Raises:
        HTTPException: 401 when the refresh token cannot be decoded or
            verified (malformed, bad signature, expired), or when the renewal
            raises AuthenticationException.
    """
    try:
        decoded_refresh_token = jwt.decode(
            refresh_token,
            key=settings.security.jwt_refresh_secret,
            algorithms=[settings.security.jwt_algorithm],
        )
    except jwt.InvalidTokenError as exc:
        # An expired refresh token is the normal end of a session; report it as
        # 401 like other authentication failures instead of an unhandled error.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
        ) from exc

    command = RefreshAccessTokenCommand(
        identifier=decoded_refresh_token["jti"],
        access_token_expiry_minutes=settings.security.access_token_expires_in,
        refresh_token_expiry_minutes=settings.security.refresh_token_expires_in,
    )

    try:
        new_refresh_token = await RefreshAccessToken(
            RefreshTokenDbRepository(db), AccessTokenDbRepository(db)
        ).execute(command)
    except AuthenticationException as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
        ) from exc

    return _encode_token(new_refresh_token, settings.security)

202 

203 

@router.post(
    "/recover",
    summary="Initiate a password reset flow",
    status_code=status.HTTP_200_OK,
    response_class=Response,
)
async def recover_user(
    email: str = Form(), db=deps.depends(Database), bus=deps.depends(Bus)
) -> None:
    """Start a recover password flow for the given email address.

    A mail with a unique id will be sent using the message bus.

    Note:
        To avoid leaking information, this api will always respond with 200.
        Unknown, malformed and unprocessable requests are only logged.

    Args:
        email(str): The email of the user that wants to reset the password.
        db(Database): Database dependency
        bus(Bus): A message bus used to publish the event
    """
    command = RecoverUserCommand(email=email)
    try:
        await RecoverUser(
            UserAccountDbRepository(db), UserRecoveryDbRepository(db), bus
        ).execute(command)
    except UserAccountNotFoundException:
        logger.warning(f"Unknown email address used for a password recovery: {email}")
    except UnprocessableException as ex:
        logger.warning(f"User recovery could not be started: {ex}")
    except InvalidEmailException:
        # A malformed address must not turn into an error response; that would
        # break the "always 200" contract. This clause is last so an
        # InvalidEmailException that subclasses UnprocessableException is
        # still handled by the clause above, as before.
        logger.warning(f"Invalid email address used for a password recovery: {email}")

234 

235 

@router.post(
    "/reset",
    summary="Reset the password of a user.",
    status_code=status.HTTP_200_OK,
)
async def reset_password(uuid=Form(), password=Form(), db=deps.depends(Database)):
    """Set a new password using the unique id of a password recovery.

    Args:
        uuid(str): The unique id of the password recovery.
        password(str): The new password
        db(Database): Database dependency

    Returns:
        Http code 200 on success, 404 when the unique id is invalid, 422 when
        the request can't be processed, 403 when the request is forbidden.
    """
    command = ResetPasswordCommand(uuid=uuid, password=password)
    try:
        use_case = ResetPassword(
            user_account_repo=UserAccountDbRepository(db),
            user_recovery_repo=UserRecoveryDbRepository(db),
        )
        await use_case.execute(command)
    except (
        UserRecoveryNotFoundException,
        UserAccountNotFoundException,
        NotAllowedException,
    ) as exc:
        # Checked in the same order as separate except clauses would be, so
        # the first matching exception type decides the status code.
        if isinstance(exc, UserRecoveryNotFoundException):
            http_status = status.HTTP_404_NOT_FOUND
        elif isinstance(exc, UserAccountNotFoundException):
            http_status = status.HTTP_422_UNPROCESSABLE_ENTITY
        else:
            http_status = status.HTTP_403_FORBIDDEN
        raise HTTPException(status_code=http_status) from exc

265 

266 

def _encode_token(
    refresh_token: RefreshTokenEntity, settings: SecuritySettings
) -> TokenSchema:
    """Sign the refresh token and its access token as JWTs.

    The access token is signed with `settings.jwt_secret`. The refresh token
    is signed with `settings.jwt_refresh_secret`. Both use
    `settings.jwt_algorithm`.

    Args:
        refresh_token: The refresh token entity; its `access_token` attribute
            supplies the access-token claims.
        settings: The security settings holding secrets and the algorithm.

    Returns:
        A TokenSchema with both encoded tokens and the access token's
        expiration rendered as "YYYY-MM-DD HH:MM:SS".
    """
    access = refresh_token.access_token

    # Claims for the short-lived token; "sub" carries the user's uuid.
    access_claims = {
        "iat": access.traceable_time.created_at.timestamp,
        "exp": access.expiration,
        "jti": str(access.identifier),
        "sub": str(access.user_account.user.uuid),
        "scope": [],
    }
    # The refresh token only identifies itself; no subject or scope.
    refresh_claims = {
        "iat": refresh_token.traceable_time.created_at.timestamp,
        "exp": refresh_token.expiration,
        "jti": str(refresh_token.identifier),
    }

    signed_access = jwt.encode(access_claims, settings.jwt_secret, settings.jwt_algorithm)
    signed_refresh = jwt.encode(
        refresh_claims, settings.jwt_refresh_secret, settings.jwt_algorithm
    )

    return TokenSchema(
        access_token=signed_access,
        refresh_token=signed_refresh,
        expiration=access.expiration.isoformat(" ", "seconds"),
    )