Coverage for kwai/core/events/redis_bus.py: 39%
64 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Define a message bus using Redis."""
2import asyncio
3import inspect
4from typing import Any, Callable
6from loguru import logger
7from redis import Redis
9from kwai.core.events.bus import Bus
10from kwai.core.events.consumer import RedisConsumer
11from kwai.core.events.event import Event
12from kwai.core.events.stream import RedisMessage, RedisStream
class RedisBus(Bus):
    """A message bus using Redis streams.

    The name of the event is mostly <module>.<entity>.<event>. Each module will have
    its own stream, named ``kwai.<module>``.
    """

    def __init__(self, redis: Redis):
        """Initialize the bus.

        Args:
            redis: The Redis connection that is passed to every stream.
        """
        self._redis = redis
        # Maps an event name to the callbacks subscribed to that event.
        self._events: dict[str, list[Callable[[dict[str, Any]], Any]]] = {}
        # Names of the streams for which at least one subscription exists.
        self._stream_names: set[str] = set()
        # Consumers created by the most recent call to run().
        self._consumers: list[RedisConsumer] = []

    @staticmethod
    def _get_stream_name(event_name: str) -> str:
        """Return the name of the stream that carries the given event.

        The first dot-separated part of the event name selects the stream.
        """
        module_name = event_name.split(".")[0]
        return f"kwai.{module_name}"

    async def publish(self, event: Event):
        """Publish the event.

        The event will be placed on the stream that belongs to the module.
        """
        stream = RedisStream(self._redis, self._get_stream_name(event.meta.name))
        await stream.add(RedisMessage(data=event.data))

    def subscribe(self, event: type[Event], task: Callable[[dict[str, Any]], Any]):
        """Subscribe a callback to an event.

        When an event is retrieved from a stream, the callback will be executed. For
        each stream, a consumer will be started when the bus is running.

        Args:
            event: The event class to subscribe to.
            task: A function or coroutine function that receives the message data.
        """
        event_name = event.meta.name
        self._events.setdefault(event_name, []).append(task)
        # Adding to a set is idempotent, so repeated subscriptions are harmless.
        self._stream_names.add(self._get_stream_name(event_name))

    async def _trigger_event(self, message: RedisMessage) -> bool:
        """Call all callbacks that are linked to the event.

        Returns:
            False when the message is not a valid event, True otherwise (also when
            no handler is subscribed or a handler raised an exception).
            NOTE(review): the value is handed back to RedisConsumer; presumably it
            decides whether the message is acknowledged -- confirm in the consumer.
        """
        with logger.contextualize(stream=message.stream, message_id=message.id):
            logger.info("An event received.")
            if not self._is_valid_event(message):
                return False

            event_name = message.data["meta"]["name"]
            callbacks = self._events.get(event_name, [])
            if not callbacks:
                logger.warning(
                    f"Event ignored: No handlers found for event '{event_name}'."
                    " Check the subscriptions."
                )
            for callback in callbacks:
                logger.info(
                    f"Calling event handler "
                    f"'{callback.__module__}.{callback.__qualname__}'"
                )
                # Deliberately broad: a failing handler must not prevent the
                # remaining handlers from being called.
                try:
                    if inspect.iscoroutinefunction(callback):
                        await callback(message.data)
                    else:
                        callback(message.data)
                except Exception as ex:
                    logger.warning(f"The handler raised an exception: {ex}")

            logger.info("All handlers are called.")

            return True

    @classmethod
    def _is_valid_event(cls, message: RedisMessage) -> bool:
        """Check the event message.

        A valid event message contains a "meta" entry that contains a "name".
        A warning is logged when the message is rejected.
        """
        if "meta" not in message.data:
            logger.warning("Event ignored: The event does not contain meta data.")
            return False
        if "name" not in message.data["meta"]:
            logger.warning("Event ignored: The event meta does not contain a name.")
            return False
        return True

    async def run(self):
        """Start all consumers.

        For each stream a consumer will be started. This method will wait for all tasks
        to end.
        """
        tasks = []
        # Start from a clean list, so cancel() only sees consumers of this run.
        self._consumers = []
        for stream_name in self._stream_names:
            event_consumer = RedisConsumer(
                RedisStream(self._redis, stream_name),
                f"{stream_name}.group",
                self._trigger_event,
            )
            self._consumers.append(event_consumer)
            # noinspection PyAsyncCall
            tasks.append(
                asyncio.shield(event_consumer.consume(f"{stream_name}.consumer"))
            )
        await asyncio.gather(*tasks)

    async def cancel(self):
        """Cancel all consumers."""
        for consumer in self._consumers:
            consumer.cancel()