Coverage for kwai/modules/identity/tokens/refresh_token_db_repository.py: 83%
42 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module that implements a refresh token repository for a database."""
2from typing import AsyncIterator
4from kwai.core.db.database import Database
5from kwai.core.domain.entity import Entity
6from kwai.modules.identity.tokens.refresh_token import (
7 RefreshTokenEntity,
8 RefreshTokenIdentifier,
9)
10from kwai.modules.identity.tokens.refresh_token_db_query import RefreshTokenDbQuery
11from kwai.modules.identity.tokens.refresh_token_query import RefreshTokenQuery
12from kwai.modules.identity.tokens.refresh_token_repository import (
13 RefreshTokenNotFoundException,
14 RefreshTokenRepository,
15)
16from kwai.modules.identity.tokens.token_identifier import TokenIdentifier
17from kwai.modules.identity.tokens.token_tables import (
18 AccessTokensTable,
19 RefreshTokenRow,
20 RefreshTokensTable,
21)
22from kwai.modules.identity.users.user_tables import (
23 UserAccountsTable,
24)
def _create_entity(row) -> RefreshTokenEntity:
    """Build a refresh token entity (with its access token and user account) from a row.

    The same row is given to each table wrapper. The user account entity is
    created first and wrapped into an access token entity. That access token
    entity is then wrapped into the refresh token entity that is returned.
    """
    # Wrappers are constructed in the same order as the original nested call.
    refresh_tokens = RefreshTokensTable(row)
    access_tokens = AccessTokensTable(row)
    user_account = UserAccountsTable(row).create_entity()
    access_token = access_tokens.create_entity(user_account)
    return refresh_tokens.create_entity(access_token)
class RefreshTokenDbRepository(RefreshTokenRepository):
    """Database repository for the refresh token entity."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: Database used for queries, inserts, updates and commits.
        """
        self._database = database

    def create_query(self) -> RefreshTokenQuery:
        """Create a new refresh token query bound to this repository's database."""
        return RefreshTokenDbQuery(self._database)

    async def _fetch_entity(self, query: RefreshTokenQuery) -> RefreshTokenEntity:
        """Return the entity built from the single row returned by the query.

        Args:
            query: A (filtered) refresh token query.

        Raises:
            RefreshTokenNotFoundException: When the query returns no row.
        """
        row = await query.fetch_one()
        if row:
            return _create_entity(row)

        raise RefreshTokenNotFoundException()

    async def get_by_token_identifier(
        self, identifier: TokenIdentifier
    ) -> RefreshTokenEntity:
        """Get the refresh token with the given token identifier.

        Args:
            identifier: The token identifier to look for.

        Raises:
            RefreshTokenNotFoundException: When no refresh token matches.
        """
        query = self.create_query()
        query.filter_by_token_identifier(identifier)
        return await self._fetch_entity(query)

    async def get(self, id_: RefreshTokenIdentifier) -> RefreshTokenEntity:
        """Get the refresh token with the given id.

        Args:
            id_: The id of the refresh token.

        Raises:
            RefreshTokenNotFoundException: When no refresh token matches.
        """
        query = self.create_query()
        query.filter_by_id(id_.value)
        return await self._fetch_entity(query)

    async def get_all(
        self,
        query: RefreshTokenQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[RefreshTokenEntity]:
        """Yield a refresh token entity for each row returned by the query.

        Args:
            query: The query to run. When omitted, a new query from
                `create_query` is used.
            limit: Maximum number of rows, passed through to `fetch`.
            offset: Number of rows to skip, passed through to `fetch`.
        """
        query = query or self.create_query()
        async for row in query.fetch(limit, offset):
            yield _create_entity(row)

    async def create(self, refresh_token: RefreshTokenEntity) -> RefreshTokenEntity:
        """Insert the refresh token, commit, and return it with its new id.

        Args:
            refresh_token: The refresh token entity to persist.

        Returns:
            A copy of the entity whose id is the id returned by the insert.
        """
        new_id = await self._database.insert(
            RefreshTokensTable.table_name, RefreshTokenRow.persist(refresh_token)
        )
        await self._database.commit()
        return Entity.replace(refresh_token, id_=RefreshTokenIdentifier(new_id))

    async def update(self, refresh_token: RefreshTokenEntity) -> None:
        """Update the stored row of the refresh token and commit.

        Args:
            refresh_token: The refresh token entity; its id selects the row.
        """
        await self._database.update(
            refresh_token.id.value,
            RefreshTokensTable.table_name,
            RefreshTokenRow.persist(refresh_token),
        )
        await self._database.commit()