Coverage for kwai/modules/identity/tokens/access_token_db_repository.py: 83%
42 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Module that implements an access token repository for a database."""
2from typing import Any, AsyncIterator
4from kwai.core.db.database import Database
5from kwai.core.domain.entity import Entity
6from kwai.modules.identity.tokens.access_token import (
7 AccessTokenEntity,
8 AccessTokenIdentifier,
9)
10from kwai.modules.identity.tokens.access_token_db_query import AccessTokenDbQuery
11from kwai.modules.identity.tokens.access_token_query import AccessTokenQuery
12from kwai.modules.identity.tokens.access_token_repository import (
13 AccessTokenNotFoundException,
14 AccessTokenRepository,
15)
16from kwai.modules.identity.tokens.token_identifier import TokenIdentifier
17from kwai.modules.identity.tokens.token_tables import (
18 AccessTokenRow,
19 AccessTokensTable,
20)
21from kwai.modules.identity.users.user_tables import UserAccountsTable
def _create_entity(row: dict[str, Any]) -> AccessTokenEntity:
    """Build an access token entity out of a single row.

    The same row is handed to both table wrappers: the user account entity is
    created from it first and is then passed to the access token table, which
    creates the access token entity.

    Args:
        row: The row to convert.

    Returns:
        The access token entity created from the row.
    """
    access_tokens_table = AccessTokensTable(row)
    user_account = UserAccountsTable(row).create_entity()
    return access_tokens_table.create_entity(user_account)
class AccessTokenDbRepository(AccessTokenRepository):
    """Access token repository that reads from and writes to a database."""

    def __init__(self, database: Database):
        """Store the database that all operations of this repository use.

        Args:
            database: The database used for queries, inserts, updates and commits.
        """
        self._database = database

    def create_query(self) -> AccessTokenQuery:
        """Return a new access token query bound to the repository's database."""
        return AccessTokenDbQuery(self._database)

    async def get(self, id_: AccessTokenIdentifier) -> AccessTokenEntity:
        """Fetch the access token with the given id.

        Args:
            id_: The id of the access token; its value is used as query filter.

        Returns:
            The entity created from the fetched row.

        Raises:
            AccessTokenNotFoundException: When the query does not return a row.
        """
        token_query = self.create_query()
        token_query.filter_by_id(id_.value)

        found = await token_query.fetch_one()
        if not found:
            raise AccessTokenNotFoundException

        return _create_entity(found)

    async def get_by_identifier(self, identifier: TokenIdentifier) -> AccessTokenEntity:
        """Fetch the access token with the given token identifier.

        Args:
            identifier: The token identifier used as query filter.

        Returns:
            The entity created from the fetched row.

        Raises:
            AccessTokenNotFoundException: When the query does not return a row.
        """
        token_query = self.create_query()
        token_query.filter_by_token_identifier(identifier)

        found = await token_query.fetch_one()
        if not found:
            raise AccessTokenNotFoundException

        return _create_entity(found)

    async def get_all(
        self,
        query: AccessTokenQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[AccessTokenEntity]:
        """Yield an access token entity for each row fetched by a query.

        Args:
            query: The query to execute. When falsy (the default is None), a
                new query is created with create_query.
            limit: Passed on to the fetch method of the query.
            offset: Passed on to the fetch method of the query.

        Yields:
            An entity for each fetched row.
        """
        active_query = query or self.create_query()
        async for record in active_query.fetch(limit, offset):
            yield _create_entity(record)

    async def create(self, access_token: AccessTokenEntity) -> AccessTokenEntity:
        """Insert an access token into the access tokens table and commit.

        Args:
            access_token: The access token to store.

        Returns:
            A copy of the given entity, with its id replaced by an identifier
            built from the value returned by the insert.
        """
        database = self._database
        table_name = AccessTokensTable.table_name
        token_row = AccessTokenRow.persist(access_token)

        new_id = await database.insert(table_name, token_row)
        await self._database.commit()

        return Entity.replace(access_token, id_=AccessTokenIdentifier(new_id))

    async def update(self, access_token: AccessTokenEntity):
        """Update the row of an access token and commit.

        Args:
            access_token: The access token to update; the value of its id is
                passed to the database update together with the persisted row.
        """
        database = self._database
        token_id = access_token.id.value
        table_name = AccessTokensTable.table_name
        token_row = AccessTokenRow.persist(access_token)

        await database.update(token_id, table_name, token_row)
        await self._database.commit()