Coverage for kwai/core/events/stream.py: 76%

119 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-05 17:55 +0000

1"""Define a Redis stream.""" 

2import json 

3from dataclasses import dataclass, field 

4from json import JSONDecodeError 

5from typing import Any 

6 

7import redis.exceptions 

8from redis.asyncio import Redis 

9 

10 

11@dataclass(kw_only=True, frozen=True, slots=True) 

12class RedisStreamInfo: 

13 """Dataclass with information about a redis stream.""" 

14 

15 length: int 

16 first_entry: str 

17 last_entry: str 

18 

19 

20@dataclass(kw_only=True, frozen=True, slots=True) 

21class RedisGroupInfo: 

22 """Dataclass with information about a redis stream group.""" 

23 

24 name: str 

25 consumers: int 

26 pending: int 

27 last_delivered_id: str 

28 

29 

class RedisMessageException(Exception):
    """Exception raised when the message is not a RedisMessage.

    Args:
        stream_name: Name of the stream the offending message was read from.
        message_id: Id of the offending message.
        message: Description of the problem; this becomes the exception text.
    """

    def __init__(self, stream_name: str, message_id: str, message: str):
        self._stream_name = stream_name
        self._message_id = message_id
        super().__init__(message)

    @property
    def stream_name(self) -> str:
        """Return the stream of the message."""
        return self._stream_name

    @property
    def message_id(self) -> str:
        """Return the message id of the message that raised this exception."""
        return self._message_id

    def __reduce__(self):
        """Support pickling/copying.

        The default BaseException.__reduce__ replays ``self.args`` (only the
        message), which does not match this __init__ signature. Rebuild from all
        three constructor arguments and restore any extra instance state.
        """
        return (
            type(self),
            (self._stream_name, self._message_id, *self.args),
            self.__dict__,
        )

    def __str__(self) -> str:
        """Return a string representation of this exception."""
        return f"({self._stream_name} - {self._message_id}) " + super().__str__()

51 

52 

53@dataclass(kw_only=True, frozen=True, slots=True) 

54class RedisMessage: 

55 """Dataclass for a message on a stream.""" 

56 

57 stream: str | None = None 

58 id: str = "*" 

59 data: dict[str, Any] = field(default_factory=dict) 

60 

61 @classmethod 

62 def create_from_redis(cls, messages: list) -> "RedisMessage": 

63 """Create a RedisMessage from messages retrieved from a Redis stream.""" 

64 # A nested list is returned from Redis. For each stream (we only have one here), 

65 # a list of entries read is returned. Because count was 1, this contains only 1 

66 # element. An entry is a tuple with the message id and the message content. 

67 message = messages[0] # we only have one stream, so use index 0 

68 stream_name = message[0].decode("utf-8") 

69 message = message[1] # This is a list with all returned tuple entries 

70 message_id = message[0][0].decode("utf-8") 

71 if b"data" in message[0][1]: 

72 try: 

73 json.loads(message[0][1][b"data"]) 

74 except JSONDecodeError as ex: 

75 raise RedisMessageException(stream_name, message_id, str(ex)) from ex 

76 return RedisMessage( 

77 stream=stream_name, 

78 id=message_id, 

79 data=json.loads(message[0][1][b"data"]), 

80 ) 

81 raise RedisMessageException( 

82 stream_name, message_id, "No data key found in redis message" 

83 ) 

84 

85 

86class RedisStream: 

87 """A stream using Redis. 

88 

89 Attributes: 

90 _redis: Redis connection. 

91 _stream_name: Name of the Redis stream. 

92 

93 A stream will be created when a first group is created or when a first message is 

94 added. 

95 """ 

96 

97 def __init__(self, redis_: Redis, stream_name: str): 

98 self._redis = redis_ 

99 self._stream_name = stream_name 

100 

101 @property 

102 def name(self) -> str: 

103 """Return the name of the stream.""" 

104 return self._stream_name 

105 

106 async def ack(self, group_name: str, id_: str): 

107 """Acknowledge the message with the given id for the given group. 

108 

109 Args: 

110 group_name: The name of the group. 

111 id_: The id of the message to acknowledge. 

112 """ 

113 await self._redis.xack(self._stream_name, group_name, id_) 

114 

115 async def add(self, message: RedisMessage) -> RedisMessage: 

116 """Add a new message to the stream. 

117 

118 Args: 

119 message: The message to add to the stream. 

120 

121 Returns: 

122 The original message. When the id of the message was a *, the id returned 

123 from redis will be set. 

124 

125 The data will be serialized to JSON. The field 'data' will be used to store 

126 this JSON. 

127 """ 

128 message_id = await self._redis.xadd( 

129 self._stream_name, {"data": json.dumps(message.data)}, id=message.id 

130 ) 

131 return RedisMessage(id=message_id.decode("utf-8"), data=message.data) 

132 

133 async def consume( 

134 self, group_name: str, consumer_name: str, id_=">", block: int | None = None 

135 ) -> RedisMessage | None: 

136 """Consume a message from a stream. 

137 

138 Args: 

139 group_name: Name of the group. 

140 consumer_name: Name of the consumer. 

141 id_: The id to start from (default is >) 

142 block: milliseconds to wait for an entry. Use None to not block. 

143 """ 

144 messages = await self._redis.xreadgroup( 

145 group_name, consumer_name, {self._stream_name: id_}, 1, block 

146 ) 

147 if messages is None: 

148 return 

149 if len(messages) == 0: 

150 return 

151 

152 # Check if there is a message returned for our stream. 

153 _, stream_messages = messages[0] 

154 if len(stream_messages) == 0: 

155 return 

156 

157 return RedisMessage.create_from_redis(messages) 

158 

159 async def create_group(self, group_name: str, id_="$") -> bool: 

160 """Create a group (if it doesn't exist yet). 

161 

162 Args: 

163 group_name: The name of the group 

164 id_: The id used as starting id. Default is $, which means only 

165 new messages. 

166 

167 Returns: 

168 True, when the group is created, False when the group already exists. 

169 

170 When the stream does not exist yet, it will be created. 

171 """ 

172 try: 

173 await self._redis.xgroup_create(self._stream_name, group_name, id_, True) 

174 return True 

175 except redis.ResponseError: 

176 return False 

177 

178 async def delete(self) -> bool: 

179 """Delete the stream. 

180 

181 Returns: 

182 True when the stream is deleted. False when the stream didn't exist or 

183 isn't deleted. 

184 """ 

185 result = await self._redis.delete(self._stream_name) 

186 return result == 1 

187 

188 async def delete_entries(self, *ids) -> int: 

189 """Delete entries from the stream. 

190 

191 Returns the number of deleted entries. 

192 """ 

193 return await self._redis.xdel(self._stream_name, *ids) 

194 

195 async def get_group(self, group_name: str) -> RedisGroupInfo | None: 

196 """Get the information about a group. 

197 

198 Returns: 

199 RedisGroup when the group exist, otherwise None is returned. 

200 """ 

201 groups = await self.get_groups() 

202 return groups.get(group_name, None) 

203 

204 async def get_groups(self) -> dict[str, RedisGroupInfo]: 

205 """Get all groups of the stream. 

206 

207 Returns: 

208 A list of groups. 

209 """ 

210 result = {} 

211 groups = await self._redis.xinfo_groups(self._stream_name) 

212 for group in groups: 

213 group_name = group["name"].decode("utf-8") 

214 result[group_name] = RedisGroupInfo( 

215 name=group_name, 

216 consumers=group["consumers"], 

217 pending=group["pending"], 

218 last_delivered_id=group["last-delivered-id"].decode("utf-8"), 

219 ) 

220 

221 return result 

222 

223 async def first_entry_id(self) -> str: 

224 """Return the id of the first entry. 

225 

226 An empty string will be returned when there is no entry on the stream. 

227 """ 

228 result = await self.info() 

229 if result is None: 

230 return "" 

231 return result.first_entry 

232 

233 async def info(self) -> RedisStreamInfo | None: 

234 """Return information about the stream. 

235 

236 Returns: 

237 A tuple with length, first-entry-id and last-entry-id. None is returned 

238 when the stream does not exist. 

239 """ 

240 try: 

241 result = await self._redis.xinfo_stream(self._stream_name) 

242 return RedisStreamInfo( 

243 length=result["length"], 

244 first_entry=result["first-entry"][0].decode("utf-8"), 

245 last_entry=result["last-entry"][0].decode("utf-8"), 

246 ) 

247 except redis.exceptions.ResponseError: 

248 return None 

249 

250 async def last_entry_id(self) -> str: 

251 """Return the id of the last entry. 

252 

253 An empty string will be returned when there is no entry on the stream. 

254 """ 

255 result = await self.info() 

256 if result is None: 

257 return "" 

258 return result.last_entry 

259 

260 async def length(self): 

261 """Return the number of entries on the stream. 

262 

263 0 will be returned when the stream does not exist. 

264 """ 

265 result = await self.info() 

266 if result is None: 

267 return 0 

268 return result.length 

269 

270 async def read( 

271 self, last_id: str = "$", block: int | None = None 

272 ) -> RedisMessage | None: 

273 """Read an entry from the stream. 

274 

275 Args: 

276 last_id: only entries with an id greater than this id will be returned. 

277 The default is $, which means to return only new entries. 

278 block: milliseconds to wait for an entry. Use None to not block. 

279 

280 Returns: 

281 None when no entry is read. A tuple with the id and data of the entry 

282 when an entry is read. 

283 """ 

284 messages = await self._redis.xread({self._stream_name: last_id}, 1, block) 

285 if messages is None: # No message read 

286 return 

287 if len(messages) == 0: 

288 return 

289 

290 return RedisMessage.create_from_redis(messages)