refactor(reverse_ws): 重构反向WebSocket管理器的防重复处理逻辑

将防重复处理数据结构改为按客户端隔离,防止不同客户端间的事件冲突
添加事件处理中的状态跟踪,避免并发处理同一事件
优化群消息内容防重复检查,仅对群聊消息生效
增加详细的调试日志,便于问题排查
This commit is contained in:
2026-02-28 22:45:36 +08:00
parent 8e6f6cca0c
commit 311b1985dd
2 changed files with 189 additions and 64 deletions

View File

@@ -21,6 +21,23 @@ from .command_manager import matcher
from models.events.factory import EventFactory from models.events.factory import EventFactory
from .redis_manager import redis_manager from .redis_manager import redis_manager
from ..bot import Bot from ..bot import Bot
from ..ws import ReverseWSClient as _ReverseWSClient
class ReverseWSClient(_ReverseWSClient):
"""
反向 WebSocket 客户端代理,用于 Bot 实例调用 API。
"""
def __init__(self, manager: "ReverseWSManager", client_id: str):
super().__init__(manager, client_id)
self.manager = manager
self.client_id = client_id
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
"""
通过 ReverseWSManager 调用 API。
"""
return await self.manager.call_api(action, params, self.client_id)
class ReverseWSManager: class ReverseWSManager:
@@ -46,14 +63,14 @@ class ReverseWSManager:
self._client_health: Dict[str, datetime] = {} # 客户端健康检查时间 self._client_health: Dict[str, datetime] = {} # 客户端健康检查时间
# 防重复发送相关 # 防重复发送相关
self._processed_events: Dict[str, datetime] = {} # 已处理的事件ID和时间 self._processed_events: Dict[str, Dict[str, datetime]] = {} # 每个客户端已处理的事件ID和时间
self._event_ttl = 60 # 事件ID保留时间 self._event_ttl = 60 # 事件ID保留时间
self._message_locks: Dict[str, asyncio.Lock] = {} # 消息处理锁 self._message_locks: Dict[str, asyncio.Lock] = {} # 消息处理锁
self._message_lock_times: Dict[str, datetime] = {} # 消息锁创建时间 self._message_lock_times: Dict[str, datetime] = {} # 消息锁创建时间
self._lock_ttl = 300 # 锁保留时间(秒) self._lock_ttl = 300 # 锁保留时间(秒)
# 基于消息内容的防重复 # 基于消息内容的防重复(仅用于群聊)
self._processed_messages: Dict[str, datetime] = {} # 已处理的消息内容和时间 self._processed_messages: Dict[str, Dict[str, datetime]] = {} # 每个客户端已处理的消息内容和时间
self._message_content_ttl = 5 # 消息内容保留时间(秒) self._message_content_ttl = 5 # 消息内容保留时间(秒)
# 启动清理任务 # 启动清理任务
@@ -62,6 +79,9 @@ class ReverseWSManager:
# Bot实例字典每个前端独立的Bot实例 # Bot实例字典每个前端独立的Bot实例
self.bots: Dict[str, Bot] = {} self.bots: Dict[str, Bot] = {}
# 正在处理的事件ID集合用于防止重复处理
self._processing_events: Dict[str, Set[str]] = {} # client_id: set of event_ids
async def start(self, host: str = "0.0.0.0", port: int = 3002) -> None: async def start(self, host: str = "0.0.0.0", port: int = 3002) -> None:
""" """
启动反向 WebSocket 服务端。 启动反向 WebSocket 服务端。
@@ -137,6 +157,8 @@ class ReverseWSManager:
# 处理上报事件 # 处理上报事件
if "post_type" in data: if "post_type" in data:
event_id = data.get('id') or data.get('post_id') or data.get('message_id') or data.get('time')
self.logger.debug(f"收到事件: client_id={client_id}, event_id={event_id}, post_type={data.get('post_type')}")
asyncio.create_task(self._on_event(client_id, data)) asyncio.create_task(self._on_event(client_id, data))
except orjson.JSONDecodeError as e: except orjson.JSONDecodeError as e:
@@ -161,13 +183,16 @@ class ReverseWSManager:
current_time = datetime.now() current_time = datetime.now()
# 清理过期的事件ID # 清理过期的事件ID(按客户端)
for client_id, events in list(self._processed_events.items()):
expired_events = [ expired_events = [
event_id for event_id, timestamp in self._processed_events.items() event_id for event_id, timestamp in events.items()
if (current_time - timestamp).total_seconds() > self._event_ttl if (current_time - timestamp).total_seconds() > self._event_ttl
] ]
for event_id in expired_events: for event_id in expired_events:
del self._processed_events[event_id] del events[event_id]
if not events:
del self._processed_events[client_id]
# 清理过期的消息锁 # 清理过期的消息锁
expired_locks = [ expired_locks = [
@@ -180,13 +205,16 @@ class ReverseWSManager:
if lock_key in self._message_lock_times: if lock_key in self._message_lock_times:
del self._message_lock_times[lock_key] del self._message_lock_times[lock_key]
# 清理过期的消息内容 # 清理过期的消息内容(按客户端)
for client_id, messages in list(self._processed_messages.items()):
expired_messages = [ expired_messages = [
msg_key for msg_key, timestamp in self._processed_messages.items() msg_key for msg_key, timestamp in messages.items()
if (current_time - timestamp).total_seconds() > self._message_content_ttl if (current_time - timestamp).total_seconds() > self._message_content_ttl
] ]
for msg_key in expired_messages: for msg_key in expired_messages:
del self._processed_messages[msg_key] del messages[msg_key]
if not messages:
del self._processed_messages[client_id]
except asyncio.CancelledError: except asyncio.CancelledError:
break break
@@ -211,6 +239,14 @@ class ReverseWSManager:
if client_id in self.bots: if client_id in self.bots:
del self.bots[client_id] del self.bots[client_id]
# 清理该客户端的防重复数据
if client_id in self._processed_events:
del self._processed_events[client_id]
if client_id in self._processed_messages:
del self._processed_messages[client_id]
if client_id in self._processing_events:
del self._processing_events[client_id]
self.logger.info(f"客户端已断开并清理: {client_id}") self.logger.info(f"客户端已断开并清理: {client_id}")
async def _on_event(self, client_id: str, event_data: Dict[str, Any]) -> None: async def _on_event(self, client_id: str, event_data: Dict[str, Any]) -> None:
@@ -221,6 +257,30 @@ class ReverseWSManager:
client_id: 客户端 ID client_id: 客户端 ID
event_data: 事件数据 event_data: 事件数据
""" """
# 获取事件ID
event_id = event_data.get('id') or event_data.get('post_id') or event_data.get('message_id') or event_data.get('time')
if not event_id:
self.logger.debug(f"_on_event: 事件ID为空, client_id={client_id}")
return
event_key = f"{event_data.get('post_type')}:{event_id}"
# 检查客户端是否已连接
if client_id not in self.clients:
self.logger.debug(f"_on_event: 客户端已断开, client_id={client_id}")
return
# 检查是否正在处理
if client_id not in self._processing_events:
self._processing_events[client_id] = set()
if event_key in self._processing_events[client_id]:
self.logger.debug(f"_on_event: 事件正在处理中, client_id={client_id}, event_key={event_key}")
return
# 标记为正在处理
self._processing_events[client_id].add(event_key)
try: try:
event = EventFactory.create_event(event_data) event = EventFactory.create_event(event_data)
@@ -228,11 +288,12 @@ class ReverseWSManager:
self.client_self_ids[client_id] = event.self_id self.client_self_ids[client_id] = event.self_id
# 为事件注入Bot实例 # 为事件注入Bot实例
from ..ws import WS from ..ws import ReverseWSClient
# 为每个前端创建独立的Bot实例 # 为每个前端创建独立的Bot实例
if client_id not in self.bots: if client_id not in self.bots:
temp_ws = WS() # 使用 ReverseWSClient 代理
temp_ws = ReverseWSClient(self, client_id)
temp_ws.self_id = event.self_id if hasattr(event, 'self_id') else 0 temp_ws.self_id = event.self_id if hasattr(event, 'self_id') else 0
self.bots[client_id] = Bot(temp_ws) self.bots[client_id] = Bot(temp_ws)
@@ -241,14 +302,13 @@ class ReverseWSManager:
# 记录客户端健康状态 # 记录客户端健康状态
self._client_health[client_id] = datetime.now() self._client_health[client_id] = datetime.now()
# 检查是否为重复事件 # 检查是否为重复事件(按客户端)
if self._is_duplicate_event(event_data): is_duplicate = self._is_duplicate_event(event_data, client_id)
self.logger.debug(f"事件防重复检查: client_id={client_id}, event_id={event_data.get('message_id')}, is_duplicate={is_duplicate}")
if is_duplicate:
self.logger.debug(f"检测到重复事件,已忽略: {event_data.get('id')}") self.logger.debug(f"检测到重复事件,已忽略: {event_data.get('id')}")
return return
# 标记事件已处理
self._mark_event_processed(event_data)
# 处理消息事件 # 处理消息事件
if event.post_type == "message": if event.post_type == "message":
sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown" sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown"
@@ -260,40 +320,54 @@ class ReverseWSManager:
# 使用锁防止同一消息被多次处理 # 使用锁防止同一消息被多次处理
message_key = self._get_message_key(event_data) message_key = self._get_message_key(event_data)
async with self._get_message_lock(message_key): async with self._get_message_lock(message_key):
# 检查是否重复(基于事件ID # 再次检查是否重复(防止并发问题
if self._is_duplicate_event(event_data): if self._is_duplicate_event(event_data, client_id):
self.logger.debug(f"并发检测到重复消息事件ID已忽略: {message_key}") self.logger.debug(f"并发检测到重复消息事件ID已忽略: {message_key}")
return return
# 检查是否重复(基于消息内容) # 检查是否重复(基于消息内容,按客户端,仅群聊
if self._is_duplicate_message(event_data): is_duplicate_content = self._is_duplicate_message(event_data, client_id)
self.logger.debug(f"锁内内容检查: client_id={client_id}, is_duplicate={is_duplicate_content}")
if is_duplicate_content:
self.logger.debug(f"并发检测到重复消息(内容),已忽略: {message_key}") self.logger.debug(f"并发检测到重复消息(内容),已忽略: {message_key}")
return return
self._mark_event_processed(event_data) # 标记事件已处理(按客户端)
self._mark_event_processed(event_data, client_id)
# 更新客户端负载 # 更新客户端负载
self._update_client_load(client_id) self._update_client_load(client_id)
await matcher.handle_event(None, event) await matcher.handle_event(event.bot, event)
else:
# 对于非消息事件,直接标记并处理
self._mark_event_processed(event_data, client_id)
elif event.post_type == "notice": if event.post_type == "notice":
notice_type = getattr(event, "notice_type", "Unknown") notice_type = getattr(event, "notice_type", "Unknown")
self.logger.info(f"[通知] {notice_type}") self.logger.info(f"[通知] {notice_type}")
await matcher.handle_event(None, event) await matcher.handle_event(event.bot, event)
elif event.post_type == "request": elif event.post_type == "request":
request_type = getattr(event, "request_type", "Unknown") request_type = getattr(event, "request_type", "Unknown")
self.logger.info(f"[请求] {request_type}") self.logger.info(f"[请求] {request_type}")
await matcher.handle_event(None, event) await matcher.handle_event(event.bot, event)
elif event.post_type == "meta_event": elif event.post_type == "meta_event":
meta_event_type = getattr(event, "meta_event_type", "Unknown") meta_event_type = getattr(event, "meta_event_type", "Unknown")
self.logger.debug(f"[元事件] {meta_event_type}") self.logger.debug(f"[元事件] {meta_event_type}")
await matcher.handle_event(None, event) await matcher.handle_event(event.bot, event)
except Exception as e: except Exception as e:
self.logger.exception(f"事件处理异常: {str(e)}") self.logger.exception(f"事件处理异常: {str(e)}")
finally:
# 清理正在处理的事件
if client_id in self._processing_events:
if event_key in self._processing_events[client_id]:
self._processing_events[client_id].discard(event_key)
# 如果集合为空,删除该客户端的记录
if not self._processing_events[client_id]:
del self._processing_events[client_id]
async def call_api( async def call_api(
self, self,
@@ -376,29 +450,43 @@ class ReverseWSManager:
""" """
return self.client_self_ids.copy() return self.client_self_ids.copy()
def _is_duplicate_event(self, event_data: Dict[str, Any]) -> bool: def _is_duplicate_event(self, event_data: Dict[str, Any], client_id: str) -> bool:
""" """
检查是否为重复事件。 检查是否为重复事件。
Args: Args:
event_data: 事件数据 event_data: 事件数据
client_id: 客户端ID
Returns: Returns:
是否为重复事件 是否为重复事件
""" """
event_id = event_data.get('id') or event_data.get('post_id') or event_data.get('time') # 尝试多种可能的事件ID字段
event_id = (event_data.get('id') or
event_data.get('post_id') or
event_data.get('message_id') or
event_data.get('time'))
if not event_id: if not event_id:
return False return False
event_key = f"{event_data.get('post_type')}:{event_id}" event_key = f"{event_data.get('post_type')}:{event_id}"
return event_key in self._processed_events
def _is_duplicate_message(self, event_data: Dict[str, Any]) -> bool: # 检查该客户端是否已处理过此事件
if client_id not in self._processed_events:
self.logger.debug(f"_is_duplicate_event: client_id={client_id}不在_processed_events中, event_key={event_key}, 返回False")
return False
is_duplicate = event_key in self._processed_events[client_id]
self.logger.debug(f"_is_duplicate_event: client_id={client_id}, event_key={event_key}, in_processed={is_duplicate}, processed_events_count={len(self._processed_events[client_id])}")
return is_duplicate
def _is_duplicate_message(self, event_data: Dict[str, Any], client_id: str) -> bool:
""" """
检查是否为重复消息(基于消息内容)。 检查是否为重复消息(基于消息内容)。
Args: Args:
event_data: 事件数据 event_data: 事件数据
client_id: 客户端ID
Returns: Returns:
是否为重复消息 是否为重复消息
@@ -406,35 +494,59 @@ class ReverseWSManager:
if event_data.get('post_type') != 'message': if event_data.get('post_type') != 'message':
return False return False
# 只对群聊消息进行内容防重复
if event_data.get('message_type') != 'group':
return False
# 生成消息内容标识 # 生成消息内容标识
message_id = event_data.get('message_id')
user_id = event_data.get('user_id')
raw_message = event_data.get('raw_message', '') raw_message = event_data.get('raw_message', '')
user_id = event_data.get('user_id')
group_id = event_data.get('group_id', '0')
# 使用消息内容用户ID作为标识 # 使用消息内容用户ID和群组ID作为标识
content_key = f"content:{raw_message}:{user_id}" content_key = f"content:{raw_message}:{user_id}:{group_id}"
return content_key in self._processed_messages
def _mark_event_processed(self, event_data: Dict[str, Any]) -> None: # 检查该客户端是否已处理过此消息内容
if client_id not in self._processed_messages:
return False
return content_key in self._processed_messages[client_id]
def _mark_event_processed(self, event_data: Dict[str, Any], client_id: str) -> None:
""" """
标记事件已处理。 标记事件已处理。
Args: Args:
event_data: 事件数据 event_data: 事件数据
client_id: 客户端ID
""" """
event_id = event_data.get('id') or event_data.get('post_id') or event_data.get('time') # 尝试多种可能的事件ID字段
event_id = (event_data.get('id') or
event_data.get('post_id') or
event_data.get('message_id') or
event_data.get('time'))
if not event_id: if not event_id:
self.logger.debug(f"_mark_event_processed: event_id为空, event_data={event_data}")
return return
event_key = f"{event_data.get('post_type')}:{event_id}" event_key = f"{event_data.get('post_type')}:{event_id}"
self._processed_events[event_key] = datetime.now()
# 同时标记消息内容已处理 # 为该客户端记录已处理的事件
if event_data.get('post_type') == 'message': if client_id not in self._processed_events:
self._processed_events[client_id] = {}
self._processed_events[client_id][event_key] = datetime.now()
self.logger.debug(f"_mark_event_processed: client_id={client_id}, event_key={event_key}, processed_events_count={len(self._processed_events[client_id])}")
# 只对群聊消息标记内容已处理
if event_data.get('post_type') == 'message' and event_data.get('message_type') == 'group':
raw_message = event_data.get('raw_message', '') raw_message = event_data.get('raw_message', '')
user_id = event_data.get('user_id') user_id = event_data.get('user_id')
content_key = f"content:{raw_message}:{user_id}" group_id = event_data.get('group_id', '0')
self._processed_messages[content_key] = datetime.now() content_key = f"content:{raw_message}:{user_id}:{group_id}"
if client_id not in self._processed_messages:
self._processed_messages[client_id] = {}
self._processed_messages[client_id][content_key] = datetime.now()
def _get_message_key(self, event_data: Dict[str, Any]) -> str: def _get_message_key(self, event_data: Dict[str, Any]) -> str:
""" """

View File

@@ -25,7 +25,6 @@ from websockets.legacy.client import WebSocketClientProtocol
from models.events.factory import EventFactory from models.events.factory import EventFactory
from .config_loader import global_config from .config_loader import global_config
from .managers.command_manager import matcher
from .utils.executor import CodeExecutor from .utils.executor import CodeExecutor
from .utils.logger import ModuleLogger from .utils.logger import ModuleLogger
from .utils.exceptions import ( from .utils.exceptions import (
@@ -210,6 +209,7 @@ class WS:
self.logger.debug(f"[元事件] {meta_event_type}") self.logger.debug(f"[元事件] {meta_event_type}")
# 分发事件 # 分发事件
from .managers.command_manager import matcher
await matcher.handle_event(self.bot, event) await matcher.handle_event(self.bot, event)
except Exception as e: except Exception as e:
@@ -297,3 +297,16 @@ class WS:
message=f"API调用异常: {str(e)}", message=f"API调用异常: {str(e)}",
data={"action": action, "params": params} data={"action": action, "params": params}
) )
class ReverseWSClient(WS):
"""
反向 WebSocket 客户端代理,用于 Bot 实例调用 API。
"""
def __init__(self, manager: Any, client_id: str):
super().__init__()
self.manager = manager
self.client_id = client_id
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
return await self.manager.call_api(action, params, self.client_id)