Coverage for kwai/core/events/redis_bus.py: 39%

64 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Define a message bus using Redis.""" 

2import asyncio 

3import inspect 

4from typing import Any, Callable 

5 

6from loguru import logger 

7from redis import Redis 

8 

9from kwai.core.events.bus import Bus 

10from kwai.core.events.consumer import RedisConsumer 

11from kwai.core.events.event import Event 

12from kwai.core.events.stream import RedisMessage, RedisStream 

13 

14 

class RedisBus(Bus):
    """A message bus using Redis streams.

    The name of the event is mostly <module>.<entity>.<event>. Each module will have
    its own stream, named ``kwai.<module>``, where <module> is the part of the event
    name before the first dot.
    """

    def __init__(self, redis: Redis):
        """Create the bus.

        Args:
            redis: The Redis client used for all streams of this bus.
        """
        self._redis = redis
        # Event name -> callbacks subscribed to that event.
        self._events: dict[str, list[Callable[[dict[str, Any]], Any]]] = {}
        # Streams that need a consumer once the bus is running.
        self._stream_names: set[str] = set()
        # Consumers created by the most recent call to run().
        self._consumers: list[RedisConsumer] = []

    @staticmethod
    def _get_stream_name(event_name: str) -> str:
        """Return the stream name for an event name.

        The stream is ``kwai.`` followed by the first dot-separated part of the
        event name (the module).
        """
        return f"kwai.{event_name.split('.')[0]}"

    async def publish(self, event: Event):
        """Publish the event.

        The event will be placed on the stream that belongs to the module. Only
        ``event.data`` is written to the stream.
        """
        stream = RedisStream(self._redis, self._get_stream_name(event.meta.name))
        await stream.add(RedisMessage(data=event.data))

    def subscribe(self, event: type[Event], task: Callable[[dict[str, Any]], Any]):
        """Subscribe a callback to an event.

        When an event is retrieved from a stream, the callback will be executed. For
        each stream, a consumer will be started when the bus is running. Subscribing
        only has an effect on consumers started by a later call to run().

        Args:
            event: The event class; its ``meta.name`` identifies the event.
            task: A sync or async callable that receives the message data.
        """
        event_name = event.meta.name
        if event_name not in self._events:
            self._events[event_name] = []
            # First subscription for this event: make sure its stream is consumed.
            self._stream_names.add(self._get_stream_name(event_name))
        self._events[event_name].append(task)

    async def _trigger_event(self, message: RedisMessage) -> bool:
        """Call all callbacks that are linked to the event.

        Returns:
            False when the message is not a valid event (see _is_valid_event),
            True otherwise — also when no handler is subscribed or when a handler
            raised an exception.
        """
        with logger.contextualize(stream=message.stream, message_id=message.id):
            logger.info("An event received.")
            if not self._is_valid_event(message):
                return False

            event_name = message.data["meta"]["name"]
            callbacks = self._events.get(event_name, [])
            if not callbacks:
                logger.warning(
                    f"Event ignored: No handlers found for event '{event_name}'."
                    " Check the subscriptions."
                )
            for callback in callbacks:
                logger.info(
                    f"Calling event handler "
                    f"'{callback.__module__}.{callback.__qualname__}'"
                )
                try:
                    if inspect.iscoroutinefunction(callback):
                        await callback(message.data)
                    else:
                        callback(message.data)
                except Exception as ex:
                    # Best effort: one failing handler must not prevent the others
                    # from running, but keep the traceback for debugging.
                    logger.opt(exception=ex).warning(
                        f"The handler raised an exception: {ex}"
                    )

            logger.info("All handlers are called.")

            return True

    @classmethod
    def _is_valid_event(cls, message: RedisMessage) -> bool:
        """Check that the message data contains ``meta`` with a ``name`` key.

        A warning is logged for each rejected message.
        """
        if "meta" not in message.data:
            logger.warning("Event ignored: The event does not contain meta data.")
            return False
        if "name" not in message.data["meta"]:
            logger.warning("Event ignored: The event meta does not contain a name.")
            return False
        return True

    async def run(self):
        """Start all consumers.

        For each stream a consumer will be started. This method will wait for all tasks
        to end. Each consumer is wrapped in asyncio.shield, so cancelling run() itself
        does not cancel the consumers; use cancel() for that.
        """
        tasks = []
        self._consumers = []
        for stream_name in self._stream_names:
            event_consumer = RedisConsumer(
                RedisStream(self._redis, stream_name),
                f"{stream_name}.group",
                self._trigger_event,
            )
            self._consumers.append(event_consumer)
            # noinspection PyAsyncCall
            tasks.append(
                asyncio.shield(event_consumer.consume(f"{stream_name}.consumer"))
            )
        await asyncio.gather(*tasks)

    async def cancel(self):
        """Cancel all consumers started by run()."""
        # NOTE(review): RedisConsumer.cancel is called without await; presumably it
        # is synchronous — confirm against RedisConsumer.
        for consumer in self._consumers:
            consumer.cancel()