diff --git a/core/managers/reverse_ws_manager.py b/core/managers/reverse_ws_manager.py index b64b39d..2848778 100644 --- a/core/managers/reverse_ws_manager.py +++ b/core/managers/reverse_ws_manager.py @@ -21,6 +21,23 @@ from .command_manager import matcher from models.events.factory import EventFactory from .redis_manager import redis_manager from ..bot import Bot +from ..ws import ReverseWSClient as _ReverseWSClient + + +class ReverseWSClient(_ReverseWSClient): + """ + 反向 WebSocket 客户端代理,用于 Bot 实例调用 API。 + """ + def __init__(self, manager: "ReverseWSManager", client_id: str): + super().__init__(manager, client_id) + self.manager = manager + self.client_id = client_id + + async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]: + """ + 通过 ReverseWSManager 调用 API。 + """ + return await self.manager.call_api(action, params, self.client_id) class ReverseWSManager: @@ -46,14 +63,14 @@ class ReverseWSManager: self._client_health: Dict[str, datetime] = {} # 客户端健康检查时间 # 防重复发送相关 - self._processed_events: Dict[str, datetime] = {} # 已处理的事件ID和时间 + self._processed_events: Dict[str, Dict[str, datetime]] = {} # 每个客户端已处理的事件ID和时间 self._event_ttl = 60 # 事件ID保留时间(秒) self._message_locks: Dict[str, asyncio.Lock] = {} # 消息处理锁 self._message_lock_times: Dict[str, datetime] = {} # 消息锁创建时间 self._lock_ttl = 300 # 锁保留时间(秒) - # 基于消息内容的防重复 - self._processed_messages: Dict[str, datetime] = {} # 已处理的消息内容和时间 + # 基于消息内容的防重复(仅用于群聊) + self._processed_messages: Dict[str, Dict[str, datetime]] = {} # 每个客户端已处理的消息内容和时间 self._message_content_ttl = 5 # 消息内容保留时间(秒) # 启动清理任务 @@ -62,6 +79,9 @@ class ReverseWSManager: # Bot实例字典(每个前端独立的Bot实例) self.bots: Dict[str, Bot] = {} + # 正在处理的事件ID集合(用于防止重复处理) + self._processing_events: Dict[str, Set[str]] = {} # client_id: set of event_ids + async def start(self, host: str = "0.0.0.0", port: int = 3002) -> None: """ 启动反向 WebSocket 服务端。 @@ -137,6 +157,8 @@ class ReverseWSManager: # 处理上报事件 if "post_type" in data: + event_id = data.get('id') or data.get('post_id') or data.get('message_id') or data.get('time') + self.logger.debug(f"收到事件: client_id={client_id}, event_id={event_id}, post_type={data.get('post_type')}") asyncio.create_task(self._on_event(client_id, data)) except orjson.JSONDecodeError as e: @@ -161,14 +183,17 @@ class ReverseWSManager: current_time = datetime.now() - # 清理过期的事件ID - expired_events = [ - event_id for event_id, timestamp in self._processed_events.items() - if (current_time - timestamp).total_seconds() > self._event_ttl - ] - for event_id in expired_events: - del self._processed_events[event_id] - + # 清理过期的事件ID(按客户端) + for client_id, events in list(self._processed_events.items()): + expired_events = [ + event_id for event_id, timestamp in events.items() + if (current_time - timestamp).total_seconds() > self._event_ttl + ] + for event_id in expired_events: + del events[event_id] + if not events: + del self._processed_events[client_id] + # 清理过期的消息锁 expired_locks = [ lock_key for lock_key, timestamp in self._message_lock_times.items() @@ -180,14 +205,17 @@ class ReverseWSManager: if lock_key in self._message_lock_times: del self._message_lock_times[lock_key] - # 清理过期的消息内容 - expired_messages = [ - msg_key for msg_key, timestamp in self._processed_messages.items() - if (current_time - timestamp).total_seconds() > self._message_content_ttl - ] - for msg_key in expired_messages: - del self._processed_messages[msg_key] - + # 清理过期的消息内容(按客户端) + for client_id, messages in list(self._processed_messages.items()): + expired_messages = [ + msg_key for msg_key, timestamp in messages.items() + if (current_time - timestamp).total_seconds() > self._message_content_ttl + ] + for msg_key in expired_messages: + del messages[msg_key] + if not messages: + del self._processed_messages[client_id] + except asyncio.CancelledError: break except Exception as e: @@ -211,6 +239,14 @@ class ReverseWSManager: if client_id in self.bots: del self.bots[client_id] + # 清理该客户端的防重复数据 + if client_id in self._processed_events: + del self._processed_events[client_id] + if client_id in self._processed_messages: + del self._processed_messages[client_id] + if client_id in self._processing_events: + del self._processing_events[client_id] + self.logger.info(f"客户端已断开并清理: {client_id}") async def _on_event(self, client_id: str, event_data: Dict[str, Any]) -> None: @@ -221,6 +257,30 @@ class ReverseWSManager: client_id: 客户端 ID event_data: 事件数据 """ + # 获取事件ID + event_id = event_data.get('id') or event_data.get('post_id') or event_data.get('message_id') or event_data.get('time') + if not event_id: + self.logger.debug(f"_on_event: 事件ID为空, client_id={client_id}") + return + + event_key = f"{event_data.get('post_type')}:{event_id}" + + # 检查客户端是否已连接 + if client_id not in self.clients: + self.logger.debug(f"_on_event: 客户端已断开, client_id={client_id}") + return + + # 检查是否正在处理 + if client_id not in self._processing_events: + self._processing_events[client_id] = set() + + if event_key in self._processing_events[client_id]: + self.logger.debug(f"_on_event: 事件正在处理中, client_id={client_id}, event_key={event_key}") + return + + # 标记为正在处理 + self._processing_events[client_id].add(event_key) + try: event = EventFactory.create_event(event_data) @@ -228,11 +288,12 @@ class ReverseWSManager: self.client_self_ids[client_id] = event.self_id # 为事件注入Bot实例 - from ..ws import WS + from ..ws import ReverseWSClient # 为每个前端创建独立的Bot实例 if client_id not in self.bots: - temp_ws = WS() + # 使用 ReverseWSClient 代理 + temp_ws = ReverseWSClient(self, client_id) temp_ws.self_id = event.self_id if hasattr(event, 'self_id') else 0 self.bots[client_id] = Bot(temp_ws) @@ -241,14 +302,13 @@ class ReverseWSManager: # 记录客户端健康状态 self._client_health[client_id] = datetime.now() - # 检查是否为重复事件 - if self._is_duplicate_event(event_data): + # 检查是否为重复事件(按客户端) + is_duplicate = self._is_duplicate_event(event_data, client_id) + self.logger.debug(f"事件防重复检查: client_id={client_id}, event_id={event_data.get('message_id')}, is_duplicate={is_duplicate}") + if is_duplicate: self.logger.debug(f"检测到重复事件,已忽略: {event_data.get('id')}") return - - # 标记事件已处理 - self._mark_event_processed(event_data) - + # 处理消息事件 if event.post_type == "message": sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown" @@ -260,40 +320,54 @@ class ReverseWSManager: # 使用锁防止同一消息被多次处理 message_key = self._get_message_key(event_data) async with self._get_message_lock(message_key): - # 检查是否重复(基于事件ID) - if self._is_duplicate_event(event_data): + # 再次检查是否重复(防止并发问题) + if self._is_duplicate_event(event_data, client_id): self.logger.debug(f"并发检测到重复消息(事件ID),已忽略: {message_key}") return - # 检查是否重复(基于消息内容) - if self._is_duplicate_message(event_data): + # 检查是否重复(基于消息内容,按客户端,仅群聊) + is_duplicate_content = self._is_duplicate_message(event_data, client_id) + self.logger.debug(f"锁内内容检查: client_id={client_id}, is_duplicate={is_duplicate_content}") + if is_duplicate_content: self.logger.debug(f"并发检测到重复消息(内容),已忽略: {message_key}") return - self._mark_event_processed(event_data) + # 标记事件已处理(按客户端) + self._mark_event_processed(event_data, client_id) # 更新客户端负载 self._update_client_load(client_id) - await matcher.handle_event(None, event) + await matcher.handle_event(event.bot, event) + else: + # 对于非消息事件,直接标记并处理 + self._mark_event_processed(event_data, client_id) + + if event.post_type == "notice": + notice_type = getattr(event, "notice_type", "Unknown") + self.logger.info(f"[通知] {notice_type}") + await matcher.handle_event(event.bot, event) - elif event.post_type == "notice": - notice_type = getattr(event, "notice_type", "Unknown") - self.logger.info(f"[通知] {notice_type}") - await matcher.handle_event(None, event) - - elif event.post_type == "request": - request_type = getattr(event, "request_type", "Unknown") - self.logger.info(f"[请求] {request_type}") - await matcher.handle_event(None, event) - - elif event.post_type == "meta_event": - meta_event_type = getattr(event, "meta_event_type", "Unknown") - self.logger.debug(f"[元事件] {meta_event_type}") - await matcher.handle_event(None, event) + elif event.post_type == "request": + request_type = getattr(event, "request_type", "Unknown") + self.logger.info(f"[请求] {request_type}") + await matcher.handle_event(event.bot, event) + + elif event.post_type == "meta_event": + meta_event_type = getattr(event, "meta_event_type", "Unknown") + self.logger.debug(f"[元事件] {meta_event_type}") + await matcher.handle_event(event.bot, event) except Exception as e: self.logger.exception(f"事件处理异常: {str(e)}") + finally: + # 清理正在处理的事件 + if client_id in self._processing_events: + if event_key in self._processing_events[client_id]: + self._processing_events[client_id].discard(event_key) + # 如果集合为空,删除该客户端的记录 + if not self._processing_events[client_id]: + del self._processing_events[client_id] async def call_api( self, @@ -376,29 +450,43 @@ class ReverseWSManager: """ return self.client_self_ids.copy() - def _is_duplicate_event(self, event_data: Dict[str, Any]) -> bool: + def _is_duplicate_event(self, event_data: Dict[str, Any], client_id: str) -> bool: """ 检查是否为重复事件。 Args: event_data: 事件数据 + client_id: 客户端ID Returns: 是否为重复事件 """ - event_id = event_data.get('id') or event_data.get('post_id') or event_data.get('time') + # 尝试多种可能的事件ID字段 + event_id = (event_data.get('id') or + event_data.get('post_id') or + event_data.get('message_id') or + event_data.get('time')) if not event_id: return False event_key = f"{event_data.get('post_type')}:{event_id}" - return event_key in self._processed_events - def _is_duplicate_message(self, event_data: Dict[str, Any]) -> bool: + # 检查该客户端是否已处理过此事件 + if client_id not in self._processed_events: + self.logger.debug(f"_is_duplicate_event: client_id={client_id}不在_processed_events中, event_key={event_key}, 返回False") + return False + + is_duplicate = event_key in self._processed_events[client_id] + self.logger.debug(f"_is_duplicate_event: client_id={client_id}, event_key={event_key}, in_processed={is_duplicate}, processed_events_count={len(self._processed_events[client_id])}") + return is_duplicate + + def _is_duplicate_message(self, event_data: Dict[str, Any], client_id: str) -> bool: """ 检查是否为重复消息(基于消息内容)。 Args: event_data: 事件数据 + client_id: 客户端ID Returns: 是否为重复消息 @@ -406,35 +494,59 @@ class ReverseWSManager: if event_data.get('post_type') != 'message': return False + # 只对群聊消息进行内容防重复 + if event_data.get('message_type') != 'group': + return False + # 生成消息内容标识 - message_id = event_data.get('message_id') - user_id = event_data.get('user_id') raw_message = event_data.get('raw_message', '') + user_id = event_data.get('user_id') + group_id = event_data.get('group_id', '0') - # 使用消息内容和用户ID作为标识 - content_key = f"content:{raw_message}:{user_id}" - return content_key in self._processed_messages + # 使用消息内容、用户ID和群组ID作为标识 + content_key = f"content:{raw_message}:{user_id}:{group_id}" - def _mark_event_processed(self, event_data: Dict[str, Any]) -> None: + # 检查该客户端是否已处理过此消息内容 + if client_id not in self._processed_messages: + return False + + return content_key in self._processed_messages[client_id] + + def _mark_event_processed(self, event_data: Dict[str, Any], client_id: str) -> None: """ 标记事件已处理。 Args: event_data: 事件数据 + client_id: 客户端ID """ - event_id = event_data.get('id') or event_data.get('post_id') or event_data.get('time') + # 尝试多种可能的事件ID字段 + event_id = (event_data.get('id') or + event_data.get('post_id') or + event_data.get('message_id') or + event_data.get('time')) if not event_id: + self.logger.debug(f"_mark_event_processed: event_id为空, event_data={event_data}") return event_key = f"{event_data.get('post_type')}:{event_id}" - self._processed_events[event_key] = datetime.now() - # 同时标记消息内容已处理 - if event_data.get('post_type') == 'message': + # 为该客户端记录已处理的事件 + if client_id not in self._processed_events: + self._processed_events[client_id] = {} + self._processed_events[client_id][event_key] = datetime.now() + self.logger.debug(f"_mark_event_processed: client_id={client_id}, event_key={event_key}, processed_events_count={len(self._processed_events[client_id])}") + + # 只对群聊消息标记内容已处理 + if event_data.get('post_type') == 'message' and event_data.get('message_type') == 'group': raw_message = event_data.get('raw_message', '') user_id = event_data.get('user_id') - content_key = f"content:{raw_message}:{user_id}" - self._processed_messages[content_key] = datetime.now() + group_id = event_data.get('group_id', '0') + content_key = f"content:{raw_message}:{user_id}:{group_id}" + + if client_id not in self._processed_messages: + self._processed_messages[client_id] = {} + self._processed_messages[client_id][content_key] = datetime.now() def _get_message_key(self, event_data: Dict[str, Any]) -> str: """ diff --git a/core/ws.py b/core/ws.py index 187db30..6030e9f 100644 --- a/core/ws.py +++ b/core/ws.py @@ -25,7 +25,6 @@ from websockets.legacy.client import WebSocketClientProtocol from models.events.factory import EventFactory from .config_loader import global_config -from .managers.command_manager import matcher from .utils.executor import CodeExecutor from .utils.logger import ModuleLogger from .utils.exceptions import ( @@ -210,6 +209,7 @@ class WS: self.logger.debug(f"[元事件] {meta_event_type}") # 分发事件 + from .managers.command_manager import matcher await matcher.handle_event(self.bot, event) except Exception as e: @@ -297,3 +297,16 @@ class WS: message=f"API调用异常: {str(e)}", data={"action": action, "params": params} ) + + +class ReverseWSClient(WS): + """ + 反向 WebSocket 客户端代理,用于 Bot 实例调用 API。 + """ + def __init__(self, manager: Any, client_id: str): + super().__init__() + self.manager = manager + self.client_id = client_id + + async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]: + return await self.manager.call_api(action, params, self.client_id)