Coverage for src/kwai/core/events/consumer.py: 76%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Module that implements a consumer for a redis stream.""" 

2 

3import asyncio 

4import inspect 

5 

6from asyncio import Event 

7from typing import Awaitable, Callable 

8 

9from kwai.core.events.stream import RedisMessage, RedisStream 

10 

11 

12class RedisConsumer: 

13 """A consumer for a Redis stream. 

14 

15 Attributes: 

16 _stream: The stream to consume. 

17 _group_name: The name of the group. 

18 _callback: The callback to call when a message is consumed. 

19 _is_stopping: An event to stop the consumer. 

20 """ 

21 

22 def __init__( 

23 self, 

24 stream: RedisStream, 

25 group_name: str, 

26 callback: Callable[[RedisMessage], bool | Awaitable[bool]], 

27 ): 

28 self._stream = stream 

29 self._group_name = group_name 

30 self._callback = callback 

31 self._is_stopping = Event() 

32 

33 async def consume(self, consumer_name: str, check_backlog: bool = True): 

34 """Consume messages from a stream. 

35 

36 Args: 

37 consumer_name: The name of the consumer. 

38 check_backlog: When True, all pending messages will be processed first. 

39 """ 

40 await self._stream.create_group(self._group_name) 

41 

42 while True: 

43 if check_backlog: 

44 id_ = "0-0" 

45 else: 

46 id_ = ">" 

47 try: 

48 message = await self._stream.consume( 

49 self._group_name, consumer_name, id_ 

50 ) 

51 if message: 

52 try: 

53 await self._trigger_callback(message) 

54 except Exception as ex: 

55 print(f"Exception: {ex!r}") 

56 # avoid a break of the loop 

57 continue 

58 else: 

59 check_backlog = False 

60 except asyncio.CancelledError: 

61 # happens on shutdown, ignore 

62 return 

63 except Exception as ex: 

64 print(f"Exception: {ex}") 

65 continue 

66 finally: 

67 if self._is_stopping.is_set(): 

68 return # noqa 

69 await asyncio.sleep(1) 

70 

71 def cancel(self): 

72 """Cancel the consumer.""" 

73 self._is_stopping.set() 

74 

75 async def _trigger_callback(self, message: RedisMessage): 

76 if inspect.iscoroutinefunction(self._callback): 

77 result = await self._callback(message) 

78 else: 

79 result = self._callback(message) 

80 if result: 

81 await self._stream.ack(self._group_name, message.id)