Coverage for src/kwai/core/events/consumer.py: 76%
42 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module that implements a consumer for a redis stream."""
3import asyncio
4import inspect
6from asyncio import Event
7from typing import Awaitable, Callable
9from kwai.core.events.stream import RedisMessage, RedisStream
12class RedisConsumer:
13 """A consumer for a Redis stream.
15 Attributes:
16 _stream: The stream to consume.
17 _group_name: The name of the group.
18 _callback: The callback to call when a message is consumed.
19 _is_stopping: An event to stop the consumer.
20 """
22 def __init__(
23 self,
24 stream: RedisStream,
25 group_name: str,
26 callback: Callable[[RedisMessage], bool | Awaitable[bool]],
27 ):
28 self._stream = stream
29 self._group_name = group_name
30 self._callback = callback
31 self._is_stopping = Event()
33 async def consume(self, consumer_name: str, check_backlog: bool = True):
34 """Consume messages from a stream.
36 Args:
37 consumer_name: The name of the consumer.
38 check_backlog: When True, all pending messages will be processed first.
39 """
40 await self._stream.create_group(self._group_name)
42 while True:
43 if check_backlog:
44 id_ = "0-0"
45 else:
46 id_ = ">"
47 try:
48 message = await self._stream.consume(
49 self._group_name, consumer_name, id_
50 )
51 if message:
52 try:
53 await self._trigger_callback(message)
54 except Exception as ex:
55 print(f"Exception: {ex!r}")
56 # avoid a break of the loop
57 continue
58 else:
59 check_backlog = False
60 except asyncio.CancelledError:
61 # happens on shutdown, ignore
62 return
63 except Exception as ex:
64 print(f"Exception: {ex}")
65 continue
66 finally:
67 if self._is_stopping.is_set():
68 return # noqa
69 await asyncio.sleep(1)
71 def cancel(self):
72 """Cancel the consumer."""
73 self._is_stopping.set()
75 async def _trigger_callback(self, message: RedisMessage):
76 if inspect.iscoroutinefunction(self._callback):
77 result = await self._callback(message)
78 else:
79 result = self._callback(message)
80 if result:
81 await self._stream.ack(self._group_name, message.id)