Coverage for src/kwai/modules/identity/tokens/refresh_token_db_repository.py: 82%
39 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that implements a refresh token repository for a database."""
3from typing import AsyncIterator
5from kwai.core.db.database import Database
6from kwai.modules.identity.tokens.refresh_token import (
7 RefreshTokenEntity,
8 RefreshTokenIdentifier,
9)
10from kwai.modules.identity.tokens.refresh_token_db_query import RefreshTokenDbQuery
11from kwai.modules.identity.tokens.refresh_token_query import RefreshTokenQuery
12from kwai.modules.identity.tokens.refresh_token_repository import (
13 RefreshTokenNotFoundException,
14 RefreshTokenRepository,
15)
16from kwai.modules.identity.tokens.token_identifier import TokenIdentifier
17from kwai.modules.identity.tokens.token_tables import (
18 AccessTokenRow,
19 RefreshTokenRow,
20)
21from kwai.modules.identity.users.user_tables import UserAccountRow
def _create_entity(row) -> RefreshTokenEntity:
    """Build a refresh token entity from a single joined result row.

    The same row is mapped three times, once per table row type. The user
    account entity is built first and passed into the access token entity,
    and that access token entity is then passed into the refresh token entity.
    """
    # Map all three row types up front, in the same order as before.
    refresh_token_row = RefreshTokenRow.map(row)
    access_token_row = AccessTokenRow.map(row)
    user_account = UserAccountRow.map(row).create_entity()

    access_token = access_token_row.create_entity(user_account)
    return refresh_token_row.create_entity(access_token)
class RefreshTokenDbRepository(RefreshTokenRepository):
    """Database repository for the refresh token entity.

    Reads are performed through a RefreshTokenDbQuery bound to the injected
    database. Writes use the database's insert/update helpers on the
    refresh token table.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: Database used for queries, inserts and updates.
        """
        self._database = database

    def create_query(self) -> RefreshTokenQuery:
        """Create a new refresh token query bound to this repository's database."""
        return RefreshTokenDbQuery(self._database)

    async def _get_one(self, query: RefreshTokenQuery) -> RefreshTokenEntity:
        """Fetch the single row selected by `query` and map it to an entity.

        Shared by `get` and `get_by_token_identifier`.

        Raises:
            RefreshTokenNotFoundException: When the query yields no row.
        """
        row = await query.fetch_one()
        # Truthiness check kept from the original: a missing (or empty) row
        # means "not found".
        if not row:
            raise RefreshTokenNotFoundException()
        return _create_entity(row)

    async def get_by_token_identifier(
        self, identifier: TokenIdentifier
    ) -> RefreshTokenEntity:
        """Get the refresh token that matches the given token identifier.

        Args:
            identifier: The token identifier to filter on.

        Raises:
            RefreshTokenNotFoundException: When no refresh token matches.
        """
        query = self.create_query()
        query.filter_by_token_identifier(identifier)
        return await self._get_one(query)

    async def get(self, id_: RefreshTokenIdentifier) -> RefreshTokenEntity:
        """Get the refresh token with the given id.

        Args:
            id_: Identifier of the refresh token; its `.value` is used as filter.

        Raises:
            RefreshTokenNotFoundException: When no refresh token has this id.
        """
        query = self.create_query()
        query.filter_by_id(id_.value)
        return await self._get_one(query)

    async def get_all(
        self,
        query: RefreshTokenQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[RefreshTokenEntity]:
        """Yield refresh token entities for every row returned by the query.

        Args:
            query: Query to run. When None, a new unfiltered query from
                `create_query` is used.
            limit: Passed unchanged to `query.fetch`.
            offset: Passed unchanged to `query.fetch`.
        """
        # Explicit None check: a caller-supplied query object must never be
        # replaced just because it happens to be falsy.
        if query is None:
            query = self.create_query()
        async for row in query.fetch(limit, offset):
            yield _create_entity(row)

    async def create(self, refresh_token: RefreshTokenEntity) -> RefreshTokenEntity:
        """Insert a refresh token and return it with the id from the insert.

        Args:
            refresh_token: The entity to persist.

        Returns:
            The result of `refresh_token.set_id(...)` with the id returned
            by the database insert.
        """
        new_id = await self._database.insert(
            RefreshTokenRow.__table_name__, RefreshTokenRow.persist(refresh_token)
        )
        return refresh_token.set_id(RefreshTokenIdentifier(new_id))

    async def update(self, refresh_token: RefreshTokenEntity) -> None:
        """Persist the current state of an existing refresh token.

        Args:
            refresh_token: The entity to update; its `id.value` selects the row.
        """
        await self._database.update(
            refresh_token.id.value,
            RefreshTokenRow.__table_name__,
            RefreshTokenRow.persist(refresh_token),
        )