Coverage for src/kwai/core/events/stream.py: 76%
121 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-01-01 00:00 +0000
1"""Define a Redis stream."""
3import json
5from dataclasses import dataclass, field
6from json import JSONDecodeError
7from typing import Any
9import redis.exceptions
11from redis.asyncio import Redis
14@dataclass(kw_only=True, frozen=True, slots=True)
15class RedisStreamInfo:
16 """Dataclass with information about a redis stream."""
18 length: int
19 first_entry: str
20 last_entry: str
23@dataclass(kw_only=True, frozen=True, slots=True)
24class RedisGroupInfo:
25 """Dataclass with information about a redis stream group."""
27 name: str
28 consumers: int
29 pending: int
30 last_delivered_id: str
class RedisMessageException(Exception):
    """Exception raised when a stream entry cannot be converted into a RedisMessage.

    Args:
        stream_name: Name of the stream the entry was read from.
        message_id: Id of the offending entry.
        message: Description of the problem.
    """

    def __init__(self, stream_name: str, message_id: str, message: str):
        self._stream_name = stream_name
        self._message_id = message_id
        super().__init__(message)

    @property
    def stream_name(self) -> str:
        """Return the stream of the message."""
        return self._stream_name

    @property
    def message_id(self) -> str:
        """Return the message id of the message that raised this exception."""
        return self._message_id

    def __reduce__(self):
        """Return the arguments needed to re-create this exception when unpickling.

        The default Exception pickling calls the class with ``self.args`` only,
        which holds just the message. That would fail because ``__init__`` also
        requires the stream name and message id.
        """
        return type(self), (self._stream_name, self._message_id, *self.args)

    def __str__(self) -> str:
        """Return a string representation of this exception."""
        return f"({self._stream_name} - {self._message_id}) " + super().__str__()
56@dataclass(kw_only=True, frozen=True, slots=True)
57class RedisMessage:
58 """Dataclass for a message on a stream."""
60 stream: str | None = None
61 id: str = "*"
62 data: dict[str, Any] = field(default_factory=dict)
64 @classmethod
65 def create_from_redis(cls, messages: list) -> "RedisMessage":
66 """Create a RedisMessage from messages retrieved from a Redis stream."""
67 # A nested list is returned from Redis. For each stream (we only have one here),
68 # a list of entries read is returned. Because count was 1, this contains only 1
69 # element. An entry is a tuple with the message id and the message content.
70 message = messages[0] # we only have one stream, so use index 0
71 stream_name = message[0].decode("utf-8")
72 message = message[1] # This is a list with all returned tuple entries
73 message_id = message[0][0].decode("utf-8")
74 if b"data" in message[0][1]:
75 try:
76 json.loads(message[0][1][b"data"])
77 except JSONDecodeError as ex:
78 raise RedisMessageException(stream_name, message_id, str(ex)) from ex
79 return RedisMessage(
80 stream=stream_name,
81 id=message_id,
82 data=json.loads(message[0][1][b"data"]),
83 )
84 raise RedisMessageException(
85 stream_name, message_id, "No data key found in redis message"
86 )
class RedisStream:
    """A stream using Redis.

    Attributes:
        _redis: Redis connection.
        _stream_name: Name of the Redis stream.

    A stream will be created when a first group is created or when a first message is
    added.
    """

    def __init__(self, redis_: Redis, stream_name: str):
        self._redis = redis_
        self._stream_name = stream_name

    @property
    def name(self) -> str:
        """Return the name of the stream."""
        return self._stream_name

    async def ack(self, group_name: str, id_: str):
        """Acknowledge the message with the given id for the given group.

        Args:
            group_name: The name of the group.
            id_: The id of the message to acknowledge.
        """
        await self._redis.xack(self._stream_name, group_name, id_)

    async def add(self, message: RedisMessage) -> RedisMessage:
        """Add a new message to the stream.

        The data will be serialized to JSON. The field 'data' will be used to store
        this JSON.

        Args:
            message: The message to add to the stream.

        Returns:
            A message with the data of the original message, the name of this
            stream and the id stored by redis. When the id of the message was a *,
            this is the id generated by redis.
        """
        message_id = await self._redis.xadd(
            self._stream_name, {"data": json.dumps(message.data)}, id=message.id
        )
        return RedisMessage(
            stream=self._stream_name,
            id=message_id.decode("utf-8"),
            data=message.data,
        )

    async def consume(
        self,
        group_name: str,
        consumer_name: str,
        id_: str = ">",
        block: int | None = None,
    ) -> RedisMessage | None:
        """Consume a message from a stream.

        Args:
            group_name: Name of the group.
            consumer_name: Name of the consumer.
            id_: The id to start from (default is >)
            block: milliseconds to wait for an entry. Use None to not block.

        Returns:
            The message read, or None when no message was returned.

        Raises:
            RedisMessageException: when the entry read is not a valid RedisMessage.
        """
        messages = await self._redis.xreadgroup(
            group_name,
            consumer_name,
            {self._stream_name: id_},
            count=1,
            block=block,
        )
        return self._to_message(messages)

    async def create_group(self, group_name: str, id_: str = "$") -> bool:
        """Create a group (if it doesn't exist yet).

        When the stream does not exist yet, it will be created.

        Args:
            group_name: The name of the group
            id_: The id used as starting id. Default is $, which means only
                new messages.

        Returns:
            True, when the group is created, False when the group already exists.

        Raises:
            redis.exceptions.ResponseError: for any error other than "group already
                exists".
        """
        try:
            await self._redis.xgroup_create(
                self._stream_name, group_name, id_, mkstream=True
            )
        except redis.exceptions.ResponseError as ex:
            # Redis reports an existing group with a BUSYGROUP error. Other response
            # errors are real failures and must not be reported as "already exists".
            if "BUSYGROUP" not in str(ex):
                raise
            return False
        return True

    async def delete(self) -> bool:
        """Delete the stream.

        Returns:
            True when the stream is deleted. False when the stream didn't exist or
            isn't deleted.
        """
        result = await self._redis.delete(self._stream_name)
        return result == 1

    async def delete_group(self, group_name: str) -> None:
        """Delete the group."""
        await self._redis.xgroup_destroy(self._stream_name, group_name)

    async def delete_entries(self, *ids: str) -> int:
        """Delete entries from the stream.

        Args:
            ids: The ids of the entries to delete.

        Returns:
            The number of deleted entries.
        """
        return await self._redis.xdel(self._stream_name, *ids)

    async def get_group(self, group_name: str) -> RedisGroupInfo | None:
        """Get the information about a group.

        Returns:
            RedisGroupInfo when the group exists, otherwise None is returned.
        """
        groups = await self.get_groups()
        return groups.get(group_name)

    async def get_groups(self) -> dict[str, RedisGroupInfo]:
        """Get all groups of the stream.

        Returns:
            A dictionary with the name of the group as key and a RedisGroupInfo
            as value.
        """
        result: dict[str, RedisGroupInfo] = {}
        groups = await self._redis.xinfo_groups(self._stream_name)
        for group in groups:
            group_name = group["name"].decode("utf-8")
            result[group_name] = RedisGroupInfo(
                name=group_name,
                consumers=group["consumers"],
                pending=group["pending"],
                last_delivered_id=group["last-delivered-id"].decode("utf-8"),
            )
        return result

    async def first_entry_id(self) -> str:
        """Return the id of the first entry.

        An empty string will be returned when there is no entry on the stream.
        """
        result = await self.info()
        if result is None:
            return ""
        return result.first_entry

    async def info(self) -> RedisStreamInfo | None:
        """Return information about the stream.

        Returns:
            A RedisStreamInfo with the length and the ids of the first and the last
            entry. The ids are empty strings when the stream exists but has no
            entries. None is returned when the stream does not exist.
        """
        try:
            result = await self._redis.xinfo_stream(self._stream_name)
        except redis.exceptions.ResponseError:
            return None

        return RedisStreamInfo(
            length=result["length"],
            first_entry=self._entry_id(result.get("first-entry")),
            last_entry=self._entry_id(result.get("last-entry")),
        )

    async def last_entry_id(self) -> str:
        """Return the id of the last entry.

        An empty string will be returned when there is no entry on the stream.
        """
        result = await self.info()
        if result is None:
            return ""
        return result.last_entry

    async def length(self) -> int:
        """Return the number of entries on the stream.

        0 will be returned when the stream does not exist.
        """
        result = await self.info()
        if result is None:
            return 0
        return result.length

    async def read(
        self, last_id: str = "$", block: int | None = None
    ) -> RedisMessage | None:
        """Read an entry from the stream.

        Args:
            last_id: only entries with an id greater than this id will be returned.
                The default is $, which means to return only new entries.
            block: milliseconds to wait for an entry. Use None to not block.

        Returns:
            None when no entry is read, otherwise the entry as a RedisMessage.

        Raises:
            RedisMessageException: when the entry read is not a valid RedisMessage.
        """
        messages = await self._redis.xread(
            {self._stream_name: last_id}, count=1, block=block
        )
        return self._to_message(messages)

    @staticmethod
    def _entry_id(entry) -> str:
        """Return the decoded id of an XINFO STREAM entry, or "" when it is missing.

        A stream that exists but has no entries (for example one created by
        create_group) has no first/last entry, so the value is None.
        """
        if not entry:
            return ""
        return entry[0].decode("utf-8")

    @staticmethod
    def _to_message(messages) -> RedisMessage | None:
        """Convert the raw result of a stream read into a RedisMessage.

        None is returned when nothing was read: the result is None or empty, or
        it contains no entries for the stream.
        """
        if not messages:
            return None
        # Check if there is an entry returned for our stream.
        _, stream_messages = messages[0]
        if not stream_messages:
            return None
        return RedisMessage.create_from_redis(messages)