Coverage for kwai/core/events/consumer.py: 76%

42 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Module that implements a consumer for a redis stream.""" 

2import asyncio 

3import inspect 

4from asyncio import Event 

5from typing import Awaitable, Callable 

6 

7from kwai.core.events.stream import RedisMessage, RedisStream 

8 

9 

10class RedisConsumer: 

11 """A consumer for a Redis stream. 

12 

13 Attributes: 

14 _stream: The stream to consume. 

15 _group_name: The name of the group. 

16 _callback: The callback to call when a message is consumed. 

17 _is_stopping: An event to stop the consumer. 

18 """ 

19 

20 def __init__( 

21 self, 

22 stream: RedisStream, 

23 group_name: str, 

24 callback: Callable[[RedisMessage], bool | Awaitable[bool]], 

25 ): 

26 self._stream = stream 

27 self._group_name = group_name 

28 self._callback = callback 

29 self._is_stopping = Event() 

30 

31 async def consume(self, consumer_name: str, check_backlog: bool = True): 

32 """Consume messages from a stream. 

33 

34 Args: 

35 consumer_name: The name of the consumer. 

36 check_backlog: When True, all pending messages will be processed first. 

37 """ 

38 await self._stream.create_group(self._group_name) 

39 

40 while True: 

41 if check_backlog: 

42 id_ = "0-0" 

43 else: 

44 id_ = ">" 

45 try: 

46 message = await self._stream.consume( 

47 self._group_name, consumer_name, id_ 

48 ) 

49 if message: 

50 try: 

51 await self._trigger_callback(message) 

52 except Exception as ex: 

53 print(f"Exception: {ex}") 

54 # avoid a break of the loop 

55 continue 

56 else: 

57 check_backlog = False 

58 except asyncio.CancelledError: 

59 # happens on shutdown, ignore 

60 return 

61 except Exception as ex: 

62 print(f"Exception: {ex}") 

63 continue 

64 finally: 

65 if self._is_stopping.is_set(): 

66 return # noqa 

67 await asyncio.sleep(1) 

68 

69 def cancel(self): 

70 """Cancel the consumer.""" 

71 self._is_stopping.set() 

72 

73 async def _trigger_callback(self, message: RedisMessage): 

74 if inspect.iscoroutinefunction(self._callback): 

75 result = await self._callback(message) 

76 else: 

77 result = self._callback(message) 

78 if result: 

79 await self._stream.ack(self._group_name, message.id)