Coverage for src/kwai/core/events/redis_bus.py: 68%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module for defining a publisher using Redis.""" 

2 

3import asyncio 

4 

5from loguru import logger 

6from redis.asyncio import Redis 

7 

8from kwai.core.events.consumer import RedisConsumer 

9from kwai.core.events.event import Event, EventMeta 

10from kwai.core.events.event_router import EventRouter 

11from kwai.core.events.publisher import Publisher 

12from kwai.core.events.stream import RedisMessage, RedisStream 

13from kwai.core.events.subscriber import Subscriber 

14 

15 

class RedisBus(Publisher, Subscriber):
    """An event bus using Redis streams.

    Publishing adds an event's data to a Redis stream whose name is derived
    from the event's meta. Subscribing registers a ``RedisConsumer`` on that
    stream; registered consumers are only started when ``run`` is awaited.
    """

    def __init__(self, redis: Redis):
        """Initialize the bus.

        Args:
            redis: The asynchronous Redis client used for every stream.
        """
        self._redis = redis
        self._consumers: list[RedisConsumer] = []
        # Strong references to the tasks started by ``run``. The event loop
        # only keeps weak references to tasks, and the asyncio documentation
        # for ``shield`` asks callers to keep a reference to the shielded task.
        self._tasks: list[asyncio.Task] = []

    async def publish(self, event: Event) -> None:
        """Publish an event by adding its data to the event's stream.

        Args:
            event: The event to publish; its meta determines the stream name.
        """
        stream_name = self._get_stream_name(event.meta)
        logger.info(f"Publishing event to {stream_name}")
        stream = RedisStream(self._redis, stream_name)
        await stream.add(RedisMessage(data=event.data))

    def subscribe(self, event_router: EventRouter) -> None:
        """Register a consumer for the stream of the router's event.

        The router callback's ``__qualname__`` is passed to the consumer as its
        second argument (presumably the consumer-group name — confirm in
        ``RedisConsumer``). The consumer is not started here; see ``run``.

        Args:
            event_router: Couples an event with the callback that handles it.
        """
        stream_name = self._get_stream_name(event_router.event.meta)
        self._consumers.append(
            RedisConsumer(
                RedisStream(self._redis, stream_name),
                event_router.callback.__qualname__,
                # Resolve through the instance so subclasses that override
                # the trigger or stream-name scheme are honoured.
                self._create_event_trigger(event_router),
            )
        )

    @classmethod
    def _create_event_trigger(cls, event_router: EventRouter):
        """Create the coroutine function a consumer calls for each message.

        Args:
            event_router: The router that executes the message data.

        Returns:
            An async function taking a ``RedisMessage`` that awaits
            ``event_router.execute(message.data)`` inside a loguru context
            carrying the stream name and message id, and returns its result.
        """
        # The stream name only depends on the router's event, so compute it
        # once here instead of for every consumed message.
        stream_name = cls._get_stream_name(event_router.event.meta)

        async def trigger(message: RedisMessage) -> bool:
            with logger.contextualize(stream=stream_name, message_id=message.id):
                return await event_router.execute(message.data)

        return trigger

    async def run(self) -> None:
        """Start all registered consumers and wait until they all finish.

        Each consumer's ``consume`` coroutine runs in its own task and receives
        the name ``consumer-<index>``. The tasks are awaited through
        ``asyncio.shield``, so cancelling ``run`` itself does not cancel the
        consumer tasks.
        """
        self._tasks = [
            asyncio.create_task(consumer.consume(f"consumer-{index}"))
            for index, consumer in enumerate(self._consumers)
        ]
        await asyncio.gather(*(asyncio.shield(task) for task in self._tasks))

    async def cancel(self) -> None:
        """Cancel all consumers by calling ``cancel()`` on each registered one."""
        for consumer in self._consumers:
            consumer.cancel()

    @classmethod
    def _get_stream_name(cls, meta: EventMeta) -> str:
        """Return the stream name for an event: ``kwai.<version>.<module>.<name>``."""
        return f"kwai.{meta.version}.{meta.module}.{meta.name}"