Coverage for src/kwai/core/events/redis_bus.py: 68%
38 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Module for defining a publisher using Redis."""
3import asyncio
5from loguru import logger
6from redis.asyncio import Redis
8from kwai.core.events.consumer import RedisConsumer
9from kwai.core.events.event import Event, EventMeta
10from kwai.core.events.event_router import EventRouter
11from kwai.core.events.publisher import Publisher
12from kwai.core.events.stream import RedisMessage, RedisStream
13from kwai.core.events.subscriber import Subscriber
class RedisBus(Publisher, Subscriber):
    """Event bus that publishes to and consumes from Redis streams.

    Each event maps to one stream whose name is derived from the event's
    metadata. Every subscription adds one consumer for the stream of the
    event the router handles.
    """

    def __init__(self, redis: Redis):
        """Keep the Redis connection and start without any consumers.

        Args:
            redis: The connection passed to every stream this bus creates.
        """
        self._redis = redis
        self._consumers = []

    async def publish(self, event: Event):
        """Add the data of the event as a message to the event's stream.

        Args:
            event: The event to publish; its metadata selects the stream.
        """
        target = self._get_stream_name(event.meta)
        logger.info(f"Publishing event to {target}")
        destination = RedisStream(self._redis, target)
        payload = RedisMessage(data=event.data)
        await destination.add(payload)

    def subscribe(self, event_router: EventRouter) -> None:
        """Register a consumer for the event handled by the router.

        The consumer receives the stream of the event, the qualified name of
        the router's callback and a trigger that executes the router.

        Args:
            event_router: The router that couples an event to a callback.
        """
        source = RedisStream(
            self._redis, self._get_stream_name(event_router.event.meta)
        )
        callback_name = event_router.callback.__qualname__
        on_message = RedisBus._create_event_trigger(event_router)
        self._consumers.append(RedisConsumer(source, callback_name, on_message))

    @classmethod
    def _create_event_trigger(cls, event_router: EventRouter):
        """Build the coroutine function a consumer calls for each message.

        Args:
            event_router: The router to execute with the message data.

        Returns:
            An async function that takes a message and returns the result of
            executing the router.
        """

        async def trigger(message: RedisMessage) -> bool:
            # Bind the stream name and message id to the log records written
            # while the router is executing.
            log_context = logger.contextualize(
                stream=RedisBus._get_stream_name(event_router.event.meta),
                message_id=message.id,
            )
            with log_context:
                return await event_router.execute(message.data)

        return trigger

    async def run(self):
        """Run every registered consumer and wait until all of them end.

        Consumers are named ``consumer-<index>`` in registration order. Each
        one is wrapped in ``asyncio.shield``.
        """
        # noinspection PyAsyncCall
        shielded = [
            asyncio.shield(consumer.consume(f"consumer-{position}"))
            for position, consumer in enumerate(self._consumers)
        ]
        await asyncio.gather(*shielded)

    async def cancel(self):
        """Call ``cancel`` on every registered consumer."""
        for consumer in self._consumers:
            consumer.cancel()

    @classmethod
    def _get_stream_name(cls, meta: EventMeta) -> str:
        """Return ``kwai.<version>.<module>.<name>`` for the given metadata.

        Args:
            meta: The metadata of the event.
        """
        return "kwai.{}.{}.{}".format(meta.version, meta.module, meta.name)