Coverage for src/kwai/core/events/redis_bus.py: 59%
39 statements
« prev ^ index » next coverage.py v7.7.1, created at 2024-01-01 00:00 +0000
1"""Module for defining a publisher using Redis."""
3import asyncio
5from loguru import logger
6from redis.asyncio import Redis
8from kwai.core.events.consumer import RedisConsumer
9from kwai.core.events.event import Event
10from kwai.core.events.event_router import EventRouter
11from kwai.core.events.publisher import Publisher
12from kwai.core.events.stream import RedisMessage, RedisStream
13from kwai.core.events.subscriber import Subscriber
class RedisBus(Publisher, Subscriber):
    """An event bus using Redis streams.

    Publishing appends the event data to a Redis stream named after the event's
    full name. Subscribing registers a ``RedisConsumer`` on such a stream; the
    registered consumers are only started when :meth:`run` is awaited.
    """

    def __init__(self, redis: Redis):
        """Initialize the bus.

        Args:
            redis: The asynchronous Redis client used for every stream.
        """
        self._redis = redis
        self._consumers: list[RedisConsumer] = []
        # Strong references to the consumer tasks started by run(). The event
        # loop only keeps weak references to tasks, and once a shielded outer
        # future is cancelled nothing else keeps the inner task alive, so
        # without this set a running consumer could be garbage-collected.
        self._tasks: set[asyncio.Task] = set()

    async def publish(self, event: Event):
        """Append the data of an event to the stream named after the event.

        Args:
            event: The event to publish; ``event.meta.full_name`` is used as
                the stream name.
        """
        stream_name = event.meta.full_name
        logger.info(f"Publishing event to {stream_name}")
        stream = RedisStream(self._redis, stream_name)
        await stream.add(RedisMessage(data=event.data))

    def subscribe(self, event_router: EventRouter) -> None:
        """Register a consumer for the stream of the routed event.

        The consumer receives the qualified name of the router's callback
        (presumably used as its group name — verify in ``RedisConsumer``) and a
        trigger that forwards each message to the router. Consumption does not
        start until :meth:`run` is awaited.

        Args:
            event_router: Couples the event to subscribe for with its handler.
        """
        stream_name = event_router.event.meta.full_name
        logger.info(f"Subscribing for {stream_name}")
        self._consumers.append(
            RedisConsumer(
                RedisStream(self._redis, stream_name),
                event_router.callback.__qualname__,
                RedisBus._create_event_trigger(event_router),
            )
        )

    @classmethod
    def _create_event_trigger(cls, event_router: EventRouter):
        """Create an event trigger.

        Args:
            event_router: The router that will handle the messages.

        Returns:
            A coroutine function that passes the data of a ``RedisMessage`` to
            ``event_router.execute`` and returns its result. Log records emitted
            while the router executes carry the stream name and message id.
        """

        async def trigger(message: RedisMessage) -> bool:
            with logger.contextualize(
                stream=event_router.event.meta.full_name,
                message_id=message.id,
            ):
                return await event_router.execute(message.data)

        return trigger

    async def run(self):
        """Start all consumers.

        For each registered consumer a task is started, and this method waits
        until all of them end. Every task is wrapped in ``asyncio.shield``: when
        this method itself is cancelled, the cancellation is logged and
        swallowed, and the consumer tasks are NOT cancelled — they keep running
        and stay referenced by the bus. :meth:`cancel` calls ``cancel()`` on
        every consumer.
        """
        shielded = []
        for index, consumer in enumerate(self._consumers):
            consumer_name = f"consumer-{index}"
            task = asyncio.create_task(
                consumer.consume(consumer_name), name=consumer_name
            )
            # Keep a strong reference until the task finishes (see __init__).
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)
            shielded.append(asyncio.shield(task))

        try:
            await asyncio.gather(*shielded)
        except asyncio.CancelledError:
            logger.info("The bus has been cancelled.")

    async def cancel(self):
        """Cancel all consumers by calling ``cancel()`` on each of them."""
        for consumer in self._consumers:
            consumer.cancel()