Coverage for src/kwai/modules/identity/tokens/access_token_db_repository.py: 82%
39 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that implements an access token repository for a database."""
3from typing import Any, AsyncIterator
5from kwai.core.db.database import Database
6from kwai.modules.identity.tokens.access_token import (
7 AccessTokenEntity,
8 AccessTokenIdentifier,
9)
10from kwai.modules.identity.tokens.access_token_db_query import AccessTokenDbQuery
11from kwai.modules.identity.tokens.access_token_query import AccessTokenQuery
12from kwai.modules.identity.tokens.access_token_repository import (
13 AccessTokenNotFoundException,
14 AccessTokenRepository,
15)
16from kwai.modules.identity.tokens.token_identifier import TokenIdentifier
17from kwai.modules.identity.tokens.token_tables import AccessTokenRow
18from kwai.modules.identity.users.user_tables import UserAccountRow
def _create_entity(row: dict[str, Any]) -> AccessTokenEntity:
    """Build an access token entity from a single database row.

    The same row is mapped twice: first as an access token row, then as a
    user account row. The resulting user account entity is passed to the
    access token row when it creates the access token entity.
    """
    token_row = AccessTokenRow.map(row)
    owner = UserAccountRow.map(row).create_entity()
    return token_row.create_entity(owner)
class AccessTokenDbRepository(AccessTokenRepository):
    """Database repository for the access token entity.

    Reads go through AccessTokenDbQuery. Writes use the database's
    insert/update helpers together with the AccessTokenRow table mapping.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used for queries, inserts and updates.
        """
        self._database = database

    def create_query(self) -> AccessTokenQuery:
        """Return a new database query for access tokens."""
        return AccessTokenDbQuery(self._database)

    async def _fetch_entity(self, query: AccessTokenQuery) -> AccessTokenEntity:
        """Run the query and convert its single row into an entity.

        Args:
            query: An already-filtered access token query.

        Returns:
            The access token entity built from the row returned by the query.

        Raises:
            AccessTokenNotFoundException: When the query returns no row.
        """
        row = await query.fetch_one()
        if not row:
            raise AccessTokenNotFoundException
        return _create_entity(row)

    async def get(self, id_: AccessTokenIdentifier) -> AccessTokenEntity:
        """Return the access token with the given id.

        Raises:
            AccessTokenNotFoundException: When no access token matches the id.
        """
        query = self.create_query()
        query.filter_by_id(id_.value)
        return await self._fetch_entity(query)

    async def get_by_identifier(self, identifier: TokenIdentifier) -> AccessTokenEntity:
        """Return the access token with the given token identifier.

        Raises:
            AccessTokenNotFoundException: When no access token matches the
                identifier.
        """
        query = self.create_query()
        query.filter_by_token_identifier(identifier)
        return await self._fetch_entity(query)

    async def get_all(
        self,
        query: AccessTokenQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[AccessTokenEntity]:
        """Yield access token entities for the rows of a query.

        Args:
            query: The query to run. When omitted, a new unfiltered query
                is created.
            limit: Maximum number of rows, passed through to ``query.fetch``.
            offset: Number of rows to skip, passed through to ``query.fetch``.
        """
        # Compare against None explicitly. A truthiness test (`query or ...`)
        # would silently replace a caller-supplied query object that happens
        # to be falsy.
        if query is None:
            query = self.create_query()
        async for row in query.fetch(limit, offset):
            yield _create_entity(row)

    async def create(self, access_token: AccessTokenEntity) -> AccessTokenEntity:
        """Insert the access token and return it with its new database id."""
        new_id = await self._database.insert(
            AccessTokenRow.__table_name__, AccessTokenRow.persist(access_token)
        )
        return access_token.set_id(AccessTokenIdentifier(new_id))

    async def update(self, access_token: AccessTokenEntity) -> None:
        """Persist the current state of an existing access token."""
        await self._database.update(
            access_token.id.value,
            AccessTokenRow.__table_name__,
            AccessTokenRow.persist(access_token),
        )