Coverage for src/kwai/core/events/stream.py: 76%

121 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-01-01 00:00 +0000

1"""Define a Redis stream.""" 

2 

3import json 

4 

5from dataclasses import dataclass, field 

6from json import JSONDecodeError 

7from typing import Any 

8 

9import redis.exceptions 

10 

11from redis.asyncio import Redis 

12 

13 

14@dataclass(kw_only=True, frozen=True, slots=True) 

15class RedisStreamInfo: 

16 """Dataclass with information about a redis stream.""" 

17 

18 length: int 

19 first_entry: str 

20 last_entry: str 

21 

22 

23@dataclass(kw_only=True, frozen=True, slots=True) 

24class RedisGroupInfo: 

25 """Dataclass with information about a redis stream group.""" 

26 

27 name: str 

28 consumers: int 

29 pending: int 

30 last_delivered_id: str 

31 

32 

class RedisMessageException(Exception):
    """Raised when an entry of a stream can't be handled as a RedisMessage.

    Besides the error message, the exception remembers the stream and the id of
    the entry that caused it.
    """

    def __init__(self, stream_name: str, message_id: str, message: str):
        # Only the error text is handed to Exception; the origin of the entry is
        # kept on the instance and exposed through read-only properties.
        super().__init__(message)
        self._stream_name = stream_name
        self._message_id = message_id

    @property
    def stream_name(self) -> str:
        """Return the name of the stream that contains the message."""
        return self._stream_name

    @property
    def message_id(self) -> str:
        """Return the id of the message that triggered this exception."""
        return self._message_id

    def __str__(self):
        """Return the error text, prefixed with the stream name and message id."""
        origin = f"({self._stream_name} - {self._message_id}) "
        return origin + super().__str__()

54 

55 

56@dataclass(kw_only=True, frozen=True, slots=True) 

57class RedisMessage: 

58 """Dataclass for a message on a stream.""" 

59 

60 stream: str | None = None 

61 id: str = "*" 

62 data: dict[str, Any] = field(default_factory=dict) 

63 

64 @classmethod 

65 def create_from_redis(cls, messages: list) -> "RedisMessage": 

66 """Create a RedisMessage from messages retrieved from a Redis stream.""" 

67 # A nested list is returned from Redis. For each stream (we only have one here), 

68 # a list of entries read is returned. Because count was 1, this contains only 1 

69 # element. An entry is a tuple with the message id and the message content. 

70 message = messages[0] # we only have one stream, so use index 0 

71 stream_name = message[0].decode("utf-8") 

72 message = message[1] # This is a list with all returned tuple entries 

73 message_id = message[0][0].decode("utf-8") 

74 if b"data" in message[0][1]: 

75 try: 

76 json.loads(message[0][1][b"data"]) 

77 except JSONDecodeError as ex: 

78 raise RedisMessageException(stream_name, message_id, str(ex)) from ex 

79 return RedisMessage( 

80 stream=stream_name, 

81 id=message_id, 

82 data=json.loads(message[0][1][b"data"]), 

83 ) 

84 raise RedisMessageException( 

85 stream_name, message_id, "No data key found in redis message" 

86 ) 

87 

88 

class RedisStream:
    """A stream using Redis.

    Attributes:
        _redis: Redis connection.
        _stream_name: Name of the Redis stream.

    A stream will be created when a first group is created or when a first message is
    added.

    Ids and names returned by the connection are decoded from bytes to str, so the
    connection is expected to return raw bytes.
    """

    def __init__(self, redis_: Redis, stream_name: str):
        self._redis = redis_
        self._stream_name = stream_name

    @staticmethod
    def _entry_id(entry) -> str:
        """Return the decoded id of an entry, or an empty string when there is none."""
        if entry is None:
            return ""
        return entry[0].decode("utf-8")

    @staticmethod
    def _to_message(messages) -> RedisMessage | None:
        """Turn the result of a read command into a message, or None when empty."""
        if not messages:
            return None

        # Check if there is a message returned for our stream.
        _, stream_messages = messages[0]
        if not stream_messages:
            return None

        return RedisMessage.create_from_redis(messages)

    @property
    def name(self) -> str:
        """Return the name of the stream."""
        return self._stream_name

    async def ack(self, group_name: str, id_: str) -> None:
        """Acknowledge the message with the given id for the given group.

        Args:
            group_name: The name of the group.
            id_: The id of the message to acknowledge.
        """
        await self._redis.xack(self._stream_name, group_name, id_)

    async def add(self, message: RedisMessage) -> RedisMessage:
        """Add a new message to the stream.

        Args:
            message: The message to add to the stream.

        Returns:
            A copy of the original message with the id returned from redis. When
            the id of the message was a *, this is the id generated by redis.

        The data will be serialized to JSON. The field 'data' will be used to store
        this JSON.
        """
        message_id = await self._redis.xadd(
            self._stream_name, {"data": json.dumps(message.data)}, id=message.id
        )
        # Keep the stream of the original message: only the id may change.
        return RedisMessage(
            stream=message.stream,
            id=message_id.decode("utf-8"),
            data=message.data,
        )

    async def consume(
        self,
        group_name: str,
        consumer_name: str,
        id_: str = ">",
        block: int | None = None,
    ) -> RedisMessage | None:
        """Consume a message from a stream.

        Args:
            group_name: Name of the group.
            consumer_name: Name of the consumer.
            id_: The id to start from (default is >)
            block: milliseconds to wait for an entry. Use None to not block.

        Returns:
            The message that was read, or None when no message was returned.

        Raises:
            RedisMessageException: When the entry is not a valid message.
        """
        messages = await self._redis.xreadgroup(
            group_name, consumer_name, {self._stream_name: id_}, count=1, block=block
        )
        return self._to_message(messages)

    async def create_group(self, group_name: str, id_: str = "$") -> bool:
        """Create a group (if it doesn't exist yet).

        Args:
            group_name: The name of the group
            id_: The id used as starting id. Default is $, which means only
                new messages.

        Returns:
            True, when the group is created, False when the group already exists.

        When the stream does not exist yet, it will be created.
        """
        try:
            await self._redis.xgroup_create(
                self._stream_name, group_name, id_, mkstream=True
            )
        except redis.exceptions.ResponseError:
            # NOTE(review): every error response from Redis is reported as
            # "group already exists", not only the one for an existing group.
            return False
        return True

    async def delete(self) -> bool:
        """Delete the stream.

        Returns:
            True when the stream is deleted. False when the stream didn't exist or
            isn't deleted.
        """
        result = await self._redis.delete(self._stream_name)
        return result == 1

    async def delete_group(self, group_name: str) -> None:
        """Delete the group."""
        await self._redis.xgroup_destroy(self._stream_name, group_name)

    async def delete_entries(self, *ids) -> int:
        """Delete entries from the stream.

        Args:
            ids: The ids of the entries to delete.

        Returns the number of deleted entries.
        """
        return await self._redis.xdel(self._stream_name, *ids)

    async def get_group(self, group_name: str) -> RedisGroupInfo | None:
        """Get the information about a group.

        Returns:
            RedisGroupInfo when the group exists, otherwise None is returned.
        """
        groups = await self.get_groups()
        return groups.get(group_name)

    async def get_groups(self) -> dict[str, RedisGroupInfo]:
        """Get all groups of the stream.

        Returns:
            A dictionary with the information of each group. The name of the
            group is used as key.
        """
        result = {}
        for group in await self._redis.xinfo_groups(self._stream_name):
            group_name = group["name"].decode("utf-8")
            result[group_name] = RedisGroupInfo(
                name=group_name,
                consumers=group["consumers"],
                pending=group["pending"],
                last_delivered_id=group["last-delivered-id"].decode("utf-8"),
            )
        return result

    async def first_entry_id(self) -> str:
        """Return the id of the first entry.

        An empty string will be returned when there is no entry on the stream.
        """
        result = await self.info()
        if result is None:
            return ""
        return result.first_entry

    async def info(self) -> RedisStreamInfo | None:
        """Return information about the stream.

        Returns:
            A RedisStreamInfo with the length, the id of the first entry and the
            id of the last entry. The ids are empty strings when the stream has
            no entries. None is returned when the stream does not exist.
        """
        try:
            result = await self._redis.xinfo_stream(self._stream_name)
        except redis.exceptions.ResponseError:
            return None

        # A stream without entries (for example one that was only created for a
        # group) has no first or last entry.
        return RedisStreamInfo(
            length=result["length"],
            first_entry=self._entry_id(result.get("first-entry")),
            last_entry=self._entry_id(result.get("last-entry")),
        )

    async def last_entry_id(self) -> str:
        """Return the id of the last entry.

        An empty string will be returned when there is no entry on the stream.
        """
        result = await self.info()
        if result is None:
            return ""
        return result.last_entry

    async def length(self) -> int:
        """Return the number of entries on the stream.

        0 will be returned when the stream does not exist.
        """
        result = await self.info()
        if result is None:
            return 0
        return result.length

    async def read(
        self, last_id: str = "$", block: int | None = None
    ) -> RedisMessage | None:
        """Read an entry from the stream.

        Args:
            last_id: only entries with an id greater than this id will be returned.
                The default is $, which means to return only new entries.
            block: milliseconds to wait for an entry. Use None to not block.

        Returns:
            None when no entry is read. A RedisMessage with the id and data of the
            entry when an entry is read.

        Raises:
            RedisMessageException: When the entry is not a valid message.
        """
        messages = await self._redis.xread(
            {self._stream_name: last_id}, count=1, block=block
        )
        return self._to_message(messages)