Coverage for kwai/core/events/stream.py: 76%
119 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-05 17:55 +0000
1"""Define a Redis stream."""
2import json
3from dataclasses import dataclass, field
4from json import JSONDecodeError
5from typing import Any
7import redis.exceptions
8from redis.asyncio import Redis
11@dataclass(kw_only=True, frozen=True, slots=True)
12class RedisStreamInfo:
13 """Dataclass with information about a redis stream."""
15 length: int
16 first_entry: str
17 last_entry: str
20@dataclass(kw_only=True, frozen=True, slots=True)
21class RedisGroupInfo:
22 """Dataclass with information about a redis stream group."""
24 name: str
25 consumers: int
26 pending: int
27 last_delivered_id: str
class RedisMessageException(Exception):
    """Raised when an entry read from a stream cannot be turned into a RedisMessage.

    Besides the error text, the exception keeps the stream name and the message id
    so the offending entry can be identified.
    """

    def __init__(self, stream_name: str, message_id: str, message: str):
        super().__init__(message)
        self._stream_name = stream_name
        self._message_id = message_id

    @property
    def stream_name(self) -> str:
        """Name of the stream the failing message was read from."""
        return self._stream_name

    @property
    def message_id(self) -> str:
        """Id of the message that could not be handled."""
        return self._message_id

    def __str__(self):
        """Prefix the error text with the stream name and message id."""
        detail = super().__str__()
        return f"({self._stream_name} - {self._message_id}) {detail}"
53@dataclass(kw_only=True, frozen=True, slots=True)
54class RedisMessage:
55 """Dataclass for a message on a stream."""
57 stream: str | None = None
58 id: str = "*"
59 data: dict[str, Any] = field(default_factory=dict)
61 @classmethod
62 def create_from_redis(cls, messages: list) -> "RedisMessage":
63 """Create a RedisMessage from messages retrieved from a Redis stream."""
64 # A nested list is returned from Redis. For each stream (we only have one here),
65 # a list of entries read is returned. Because count was 1, this contains only 1
66 # element. An entry is a tuple with the message id and the message content.
67 message = messages[0] # we only have one stream, so use index 0
68 stream_name = message[0].decode("utf-8")
69 message = message[1] # This is a list with all returned tuple entries
70 message_id = message[0][0].decode("utf-8")
71 if b"data" in message[0][1]:
72 try:
73 json.loads(message[0][1][b"data"])
74 except JSONDecodeError as ex:
75 raise RedisMessageException(stream_name, message_id, str(ex)) from ex
76 return RedisMessage(
77 stream=stream_name,
78 id=message_id,
79 data=json.loads(message[0][1][b"data"]),
80 )
81 raise RedisMessageException(
82 stream_name, message_id, "No data key found in redis message"
83 )
class RedisStream:
    """A stream using Redis.

    Attributes:
        _redis: Redis connection.
        _stream_name: Name of the Redis stream.

    A stream will be created when a first group is created or when a first message is
    added.
    """

    def __init__(self, redis_: Redis, stream_name: str):
        self._redis = redis_
        self._stream_name = stream_name

    @property
    def name(self) -> str:
        """Return the name of the stream."""
        return self._stream_name

    async def ack(self, group_name: str, id_: str):
        """Acknowledge the message with the given id for the given group.

        Args:
            group_name: The name of the group.
            id_: The id of the message to acknowledge.
        """
        await self._redis.xack(self._stream_name, group_name, id_)

    async def add(self, message: RedisMessage) -> RedisMessage:
        """Add a new message to the stream.

        The data will be serialized to JSON. The field 'data' will be used to store
        this JSON.

        Args:
            message: The message to add to the stream.

        Returns:
            A message with the same data, the name of this stream and the id stored
            by Redis. When the id of the original message was a *, this is the id
            generated by Redis.
        """
        message_id = await self._redis.xadd(
            self._stream_name, {"data": json.dumps(message.data)}, id=message.id
        )
        # Set the stream name, so the result matches what create_from_redis builds
        # for the same entry.
        return RedisMessage(
            stream=self._stream_name,
            id=message_id.decode("utf-8"),
            data=message.data,
        )

    async def consume(
        self, group_name: str, consumer_name: str, id_=">", block: int | None = None
    ) -> RedisMessage | None:
        """Consume a message from a stream.

        Args:
            group_name: Name of the group.
            consumer_name: Name of the consumer.
            id_: The id to start from (default is >)
            block: milliseconds to wait for an entry. Use None to not block.

        Returns:
            The consumed message, or None when no message was read.

        Raises:
            RedisMessageException: when the entry read is not a valid RedisMessage.
        """
        messages = await self._redis.xreadgroup(
            group_name, consumer_name, {self._stream_name: id_}, 1, block
        )
        if not messages:  # No message read
            return None

        # Check if there is a message returned for our stream.
        _, stream_messages = messages[0]
        if not stream_messages:
            return None

        return RedisMessage.create_from_redis(messages)

    async def create_group(self, group_name: str, id_="$") -> bool:
        """Create a group (if it doesn't exist yet).

        When the stream does not exist yet, it will be created.

        Args:
            group_name: The name of the group
            id_: The id used as starting id. Default is $, which means only
                new messages.

        Returns:
            True, when the group is created, False when the group already exists.

        Raises:
            redis.exceptions.ResponseError: for any Redis error other than the
                group already existing.
        """
        try:
            await self._redis.xgroup_create(
                self._stream_name, group_name, id_, mkstream=True
            )
        except redis.exceptions.ResponseError as ex:
            # Redis answers BUSYGROUP when the group already exists. Other errors
            # are real failures and must not be reported as "already exists".
            if "BUSYGROUP" in str(ex):
                return False
            raise
        return True

    async def delete(self) -> bool:
        """Delete the stream.

        Returns:
            True when the stream is deleted. False when the stream didn't exist or
            isn't deleted.
        """
        result = await self._redis.delete(self._stream_name)
        return result == 1

    async def delete_entries(self, *ids) -> int:
        """Delete entries from the stream.

        Args:
            ids: The ids of the entries to delete.

        Returns:
            The number of deleted entries.
        """
        return await self._redis.xdel(self._stream_name, *ids)

    async def get_group(self, group_name: str) -> RedisGroupInfo | None:
        """Get the information about a group.

        Args:
            group_name: The name of the group.

        Returns:
            RedisGroupInfo when the group exists, otherwise None is returned.
        """
        groups = await self.get_groups()
        return groups.get(group_name)

    async def get_groups(self) -> dict[str, RedisGroupInfo]:
        """Get all groups of the stream.

        Returns:
            A dict mapping each group name to its RedisGroupInfo.
        """
        groups = await self._redis.xinfo_groups(self._stream_name)
        result: dict[str, RedisGroupInfo] = {}
        for group in groups:
            group_name = group["name"].decode("utf-8")
            result[group_name] = RedisGroupInfo(
                name=group_name,
                consumers=group["consumers"],
                pending=group["pending"],
                last_delivered_id=group["last-delivered-id"].decode("utf-8"),
            )
        return result

    async def first_entry_id(self) -> str:
        """Return the id of the first entry.

        An empty string will be returned when there is no entry on the stream.
        """
        result = await self.info()
        return "" if result is None else result.first_entry

    @staticmethod
    def _entry_id(entry) -> str:
        """Return the decoded id of an XINFO STREAM entry, or "" when it is missing."""
        # XINFO STREAM reports no first/last entry (None) for an empty stream.
        if not entry:
            return ""
        return entry[0].decode("utf-8")

    async def info(self) -> RedisStreamInfo | None:
        """Return information about the stream.

        Returns:
            A RedisStreamInfo with the length, first entry id and last entry id. The
            entry ids are empty strings when the stream has no entries. None is
            returned when the stream does not exist.
        """
        try:
            result = await self._redis.xinfo_stream(self._stream_name)
        except redis.exceptions.ResponseError:
            return None
        return RedisStreamInfo(
            length=result["length"],
            first_entry=self._entry_id(result.get("first-entry")),
            last_entry=self._entry_id(result.get("last-entry")),
        )

    async def last_entry_id(self) -> str:
        """Return the id of the last entry.

        An empty string will be returned when there is no entry on the stream.
        """
        result = await self.info()
        return "" if result is None else result.last_entry

    async def length(self) -> int:
        """Return the number of entries on the stream.

        0 will be returned when the stream does not exist.
        """
        result = await self.info()
        return 0 if result is None else result.length

    async def read(
        self, last_id: str = "$", block: int | None = None
    ) -> RedisMessage | None:
        """Read an entry from the stream.

        Args:
            last_id: only entries with an id greater than this id will be returned.
                The default is $, which means to return only new entries.
            block: milliseconds to wait for an entry. Use None to not block.

        Returns:
            None when no entry is read. Otherwise, a RedisMessage with the stream
            name, id and data of the entry.

        Raises:
            RedisMessageException: when the entry read is not a valid RedisMessage.
        """
        messages = await self._redis.xread({self._stream_name: last_id}, 1, block)
        if not messages:  # No message read
            return None

        return RedisMessage.create_from_redis(messages)