Merge pull request #64 from Fairy-Oracle-Sanctuary/dev
refactor(reverse_ws): 重构反向WebSocket管理器的防重复处理逻辑
This commit is contained in:
@@ -21,6 +21,23 @@ from .command_manager import matcher
|
||||
from models.events.factory import EventFactory
|
||||
from .redis_manager import redis_manager
|
||||
from ..bot import Bot
|
||||
from ..ws import ReverseWSClient as _ReverseWSClient
|
||||
|
||||
|
||||
class ReverseWSClient(_ReverseWSClient):
|
||||
"""
|
||||
反向 WebSocket 客户端代理,用于 Bot 实例调用 API。
|
||||
"""
|
||||
def __init__(self, manager: "ReverseWSManager", client_id: str):
|
||||
super().__init__(manager, client_id)
|
||||
self.manager = manager
|
||||
self.client_id = client_id
|
||||
|
||||
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
|
||||
"""
|
||||
通过 ReverseWSManager 调用 API。
|
||||
"""
|
||||
return await self.manager.call_api(action, params, self.client_id)
|
||||
|
||||
|
||||
class ReverseWSManager:
|
||||
@@ -46,14 +63,14 @@ class ReverseWSManager:
|
||||
self._client_health: Dict[str, datetime] = {} # 客户端健康检查时间
|
||||
|
||||
# 防重复发送相关
|
||||
self._processed_events: Dict[str, datetime] = {} # 已处理的事件ID和时间
|
||||
self._processed_events: Dict[str, Dict[str, datetime]] = {} # 每个客户端已处理的事件ID和时间
|
||||
self._event_ttl = 60 # 事件ID保留时间(秒)
|
||||
self._message_locks: Dict[str, asyncio.Lock] = {} # 消息处理锁
|
||||
self._message_lock_times: Dict[str, datetime] = {} # 消息锁创建时间
|
||||
self._lock_ttl = 300 # 锁保留时间(秒)
|
||||
|
||||
# 基于消息内容的防重复
|
||||
self._processed_messages: Dict[str, datetime] = {} # 已处理的消息内容和时间
|
||||
# 基于消息内容的防重复(仅用于群聊)
|
||||
self._processed_messages: Dict[str, Dict[str, datetime]] = {} # 每个客户端已处理的消息内容和时间
|
||||
self._message_content_ttl = 5 # 消息内容保留时间(秒)
|
||||
|
||||
# 启动清理任务
|
||||
@@ -62,6 +79,9 @@ class ReverseWSManager:
|
||||
# Bot实例字典(每个前端独立的Bot实例)
|
||||
self.bots: Dict[str, Bot] = {}
|
||||
|
||||
# 正在处理的事件ID集合(用于防止重复处理)
|
||||
self._processing_events: Dict[str, Set[str]] = {} # client_id: set of event_ids
|
||||
|
||||
async def start(self, host: str = "0.0.0.0", port: int = 3002) -> None:
|
||||
"""
|
||||
启动反向 WebSocket 服务端。
|
||||
@@ -137,6 +157,8 @@ class ReverseWSManager:
|
||||
|
||||
# 处理上报事件
|
||||
if "post_type" in data:
|
||||
event_id = data.get('id') or data.get('post_id') or data.get('message_id') or data.get('time')
|
||||
self.logger.debug(f"收到事件: client_id={client_id}, event_id={event_id}, post_type={data.get('post_type')}")
|
||||
asyncio.create_task(self._on_event(client_id, data))
|
||||
|
||||
except orjson.JSONDecodeError as e:
|
||||
@@ -161,13 +183,16 @@ class ReverseWSManager:
|
||||
|
||||
current_time = datetime.now()
|
||||
|
||||
# 清理过期的事件ID
|
||||
expired_events = [
|
||||
event_id for event_id, timestamp in self._processed_events.items()
|
||||
if (current_time - timestamp).total_seconds() > self._event_ttl
|
||||
]
|
||||
for event_id in expired_events:
|
||||
del self._processed_events[event_id]
|
||||
# 清理过期的事件ID(按客户端)
|
||||
for client_id, events in list(self._processed_events.items()):
|
||||
expired_events = [
|
||||
event_id for event_id, timestamp in events.items()
|
||||
if (current_time - timestamp).total_seconds() > self._event_ttl
|
||||
]
|
||||
for event_id in expired_events:
|
||||
del events[event_id]
|
||||
if not events:
|
||||
del self._processed_events[client_id]
|
||||
|
||||
# 清理过期的消息锁
|
||||
expired_locks = [
|
||||
@@ -180,13 +205,16 @@ class ReverseWSManager:
|
||||
if lock_key in self._message_lock_times:
|
||||
del self._message_lock_times[lock_key]
|
||||
|
||||
# 清理过期的消息内容
|
||||
expired_messages = [
|
||||
msg_key for msg_key, timestamp in self._processed_messages.items()
|
||||
if (current_time - timestamp).total_seconds() > self._message_content_ttl
|
||||
]
|
||||
for msg_key in expired_messages:
|
||||
del self._processed_messages[msg_key]
|
||||
# 清理过期的消息内容(按客户端)
|
||||
for client_id, messages in list(self._processed_messages.items()):
|
||||
expired_messages = [
|
||||
msg_key for msg_key, timestamp in messages.items()
|
||||
if (current_time - timestamp).total_seconds() > self._message_content_ttl
|
||||
]
|
||||
for msg_key in expired_messages:
|
||||
del messages[msg_key]
|
||||
if not messages:
|
||||
del self._processed_messages[client_id]
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
@@ -211,6 +239,14 @@ class ReverseWSManager:
|
||||
if client_id in self.bots:
|
||||
del self.bots[client_id]
|
||||
|
||||
# 清理该客户端的防重复数据
|
||||
if client_id in self._processed_events:
|
||||
del self._processed_events[client_id]
|
||||
if client_id in self._processed_messages:
|
||||
del self._processed_messages[client_id]
|
||||
if client_id in self._processing_events:
|
||||
del self._processing_events[client_id]
|
||||
|
||||
self.logger.info(f"客户端已断开并清理: {client_id}")
|
||||
|
||||
async def _on_event(self, client_id: str, event_data: Dict[str, Any]) -> None:
|
||||
@@ -221,6 +257,30 @@ class ReverseWSManager:
|
||||
client_id: 客户端 ID
|
||||
event_data: 事件数据
|
||||
"""
|
||||
# 获取事件ID
|
||||
event_id = event_data.get('id') or event_data.get('post_id') or event_data.get('message_id') or event_data.get('time')
|
||||
if not event_id:
|
||||
self.logger.debug(f"_on_event: 事件ID为空, client_id={client_id}")
|
||||
return
|
||||
|
||||
event_key = f"{event_data.get('post_type')}:{event_id}"
|
||||
|
||||
# 检查客户端是否已连接
|
||||
if client_id not in self.clients:
|
||||
self.logger.debug(f"_on_event: 客户端已断开, client_id={client_id}")
|
||||
return
|
||||
|
||||
# 检查是否正在处理
|
||||
if client_id not in self._processing_events:
|
||||
self._processing_events[client_id] = set()
|
||||
|
||||
if event_key in self._processing_events[client_id]:
|
||||
self.logger.debug(f"_on_event: 事件正在处理中, client_id={client_id}, event_key={event_key}")
|
||||
return
|
||||
|
||||
# 标记为正在处理
|
||||
self._processing_events[client_id].add(event_key)
|
||||
|
||||
try:
|
||||
event = EventFactory.create_event(event_data)
|
||||
|
||||
@@ -228,11 +288,12 @@ class ReverseWSManager:
|
||||
self.client_self_ids[client_id] = event.self_id
|
||||
|
||||
# 为事件注入Bot实例
|
||||
from ..ws import WS
|
||||
from ..ws import ReverseWSClient
|
||||
|
||||
# 为每个前端创建独立的Bot实例
|
||||
if client_id not in self.bots:
|
||||
temp_ws = WS()
|
||||
# 使用 ReverseWSClient 代理
|
||||
temp_ws = ReverseWSClient(self, client_id)
|
||||
temp_ws.self_id = event.self_id if hasattr(event, 'self_id') else 0
|
||||
self.bots[client_id] = Bot(temp_ws)
|
||||
|
||||
@@ -241,14 +302,13 @@ class ReverseWSManager:
|
||||
# 记录客户端健康状态
|
||||
self._client_health[client_id] = datetime.now()
|
||||
|
||||
# 检查是否为重复事件
|
||||
if self._is_duplicate_event(event_data):
|
||||
# 检查是否为重复事件(按客户端)
|
||||
is_duplicate = self._is_duplicate_event(event_data, client_id)
|
||||
self.logger.debug(f"事件防重复检查: client_id={client_id}, event_id={event_data.get('message_id')}, is_duplicate={is_duplicate}")
|
||||
if is_duplicate:
|
||||
self.logger.debug(f"检测到重复事件,已忽略: {event_data.get('id')}")
|
||||
return
|
||||
|
||||
# 标记事件已处理
|
||||
self._mark_event_processed(event_data)
|
||||
|
||||
# 处理消息事件
|
||||
if event.post_type == "message":
|
||||
sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown"
|
||||
@@ -260,40 +320,54 @@ class ReverseWSManager:
|
||||
# 使用锁防止同一消息被多次处理
|
||||
message_key = self._get_message_key(event_data)
|
||||
async with self._get_message_lock(message_key):
|
||||
# 检查是否重复(基于事件ID)
|
||||
if self._is_duplicate_event(event_data):
|
||||
# 再次检查是否重复(防止并发问题)
|
||||
if self._is_duplicate_event(event_data, client_id):
|
||||
self.logger.debug(f"并发检测到重复消息(事件ID),已忽略: {message_key}")
|
||||
return
|
||||
|
||||
# 检查是否重复(基于消息内容)
|
||||
if self._is_duplicate_message(event_data):
|
||||
# 检查是否重复(基于消息内容,按客户端,仅群聊)
|
||||
is_duplicate_content = self._is_duplicate_message(event_data, client_id)
|
||||
self.logger.debug(f"锁内内容检查: client_id={client_id}, is_duplicate={is_duplicate_content}")
|
||||
if is_duplicate_content:
|
||||
self.logger.debug(f"并发检测到重复消息(内容),已忽略: {message_key}")
|
||||
return
|
||||
|
||||
self._mark_event_processed(event_data)
|
||||
# 标记事件已处理(按客户端)
|
||||
self._mark_event_processed(event_data, client_id)
|
||||
|
||||
# 更新客户端负载
|
||||
self._update_client_load(client_id)
|
||||
|
||||
await matcher.handle_event(None, event)
|
||||
await matcher.handle_event(event.bot, event)
|
||||
else:
|
||||
# 对于非消息事件,直接标记并处理
|
||||
self._mark_event_processed(event_data, client_id)
|
||||
|
||||
elif event.post_type == "notice":
|
||||
notice_type = getattr(event, "notice_type", "Unknown")
|
||||
self.logger.info(f"[通知] {notice_type}")
|
||||
await matcher.handle_event(None, event)
|
||||
if event.post_type == "notice":
|
||||
notice_type = getattr(event, "notice_type", "Unknown")
|
||||
self.logger.info(f"[通知] {notice_type}")
|
||||
await matcher.handle_event(event.bot, event)
|
||||
|
||||
elif event.post_type == "request":
|
||||
request_type = getattr(event, "request_type", "Unknown")
|
||||
self.logger.info(f"[请求] {request_type}")
|
||||
await matcher.handle_event(None, event)
|
||||
elif event.post_type == "request":
|
||||
request_type = getattr(event, "request_type", "Unknown")
|
||||
self.logger.info(f"[请求] {request_type}")
|
||||
await matcher.handle_event(event.bot, event)
|
||||
|
||||
elif event.post_type == "meta_event":
|
||||
meta_event_type = getattr(event, "meta_event_type", "Unknown")
|
||||
self.logger.debug(f"[元事件] {meta_event_type}")
|
||||
await matcher.handle_event(None, event)
|
||||
elif event.post_type == "meta_event":
|
||||
meta_event_type = getattr(event, "meta_event_type", "Unknown")
|
||||
self.logger.debug(f"[元事件] {meta_event_type}")
|
||||
await matcher.handle_event(event.bot, event)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"事件处理异常: {str(e)}")
|
||||
finally:
|
||||
# 清理正在处理的事件
|
||||
if client_id in self._processing_events:
|
||||
if event_key in self._processing_events[client_id]:
|
||||
self._processing_events[client_id].discard(event_key)
|
||||
# 如果集合为空,删除该客户端的记录
|
||||
if not self._processing_events[client_id]:
|
||||
del self._processing_events[client_id]
|
||||
|
||||
async def call_api(
|
||||
self,
|
||||
@@ -376,29 +450,43 @@ class ReverseWSManager:
|
||||
"""
|
||||
return self.client_self_ids.copy()
|
||||
|
||||
def _is_duplicate_event(self, event_data: Dict[str, Any]) -> bool:
|
||||
def _is_duplicate_event(self, event_data: Dict[str, Any], client_id: str) -> bool:
|
||||
"""
|
||||
检查是否为重复事件。
|
||||
|
||||
Args:
|
||||
event_data: 事件数据
|
||||
client_id: 客户端ID
|
||||
|
||||
Returns:
|
||||
是否为重复事件
|
||||
"""
|
||||
event_id = event_data.get('id') or event_data.get('post_id') or event_data.get('time')
|
||||
# 尝试多种可能的事件ID字段
|
||||
event_id = (event_data.get('id') or
|
||||
event_data.get('post_id') or
|
||||
event_data.get('message_id') or
|
||||
event_data.get('time'))
|
||||
if not event_id:
|
||||
return False
|
||||
|
||||
event_key = f"{event_data.get('post_type')}:{event_id}"
|
||||
return event_key in self._processed_events
|
||||
|
||||
def _is_duplicate_message(self, event_data: Dict[str, Any]) -> bool:
|
||||
# 检查该客户端是否已处理过此事件
|
||||
if client_id not in self._processed_events:
|
||||
self.logger.debug(f"_is_duplicate_event: client_id={client_id}不在_processed_events中, event_key={event_key}, 返回False")
|
||||
return False
|
||||
|
||||
is_duplicate = event_key in self._processed_events[client_id]
|
||||
self.logger.debug(f"_is_duplicate_event: client_id={client_id}, event_key={event_key}, in_processed={is_duplicate}, processed_events_count={len(self._processed_events[client_id])}")
|
||||
return is_duplicate
|
||||
|
||||
def _is_duplicate_message(self, event_data: Dict[str, Any], client_id: str) -> bool:
|
||||
"""
|
||||
检查是否为重复消息(基于消息内容)。
|
||||
|
||||
Args:
|
||||
event_data: 事件数据
|
||||
client_id: 客户端ID
|
||||
|
||||
Returns:
|
||||
是否为重复消息
|
||||
@@ -406,35 +494,59 @@ class ReverseWSManager:
|
||||
if event_data.get('post_type') != 'message':
|
||||
return False
|
||||
|
||||
# 只对群聊消息进行内容防重复
|
||||
if event_data.get('message_type') != 'group':
|
||||
return False
|
||||
|
||||
# 生成消息内容标识
|
||||
message_id = event_data.get('message_id')
|
||||
user_id = event_data.get('user_id')
|
||||
raw_message = event_data.get('raw_message', '')
|
||||
user_id = event_data.get('user_id')
|
||||
group_id = event_data.get('group_id', '0')
|
||||
|
||||
# 使用消息内容和用户ID作为标识
|
||||
content_key = f"content:{raw_message}:{user_id}"
|
||||
return content_key in self._processed_messages
|
||||
# 使用消息内容、用户ID和群组ID作为标识
|
||||
content_key = f"content:{raw_message}:{user_id}:{group_id}"
|
||||
|
||||
def _mark_event_processed(self, event_data: Dict[str, Any]) -> None:
|
||||
# 检查该客户端是否已处理过此消息内容
|
||||
if client_id not in self._processed_messages:
|
||||
return False
|
||||
|
||||
return content_key in self._processed_messages[client_id]
|
||||
|
||||
def _mark_event_processed(self, event_data: Dict[str, Any], client_id: str) -> None:
|
||||
"""
|
||||
标记事件已处理。
|
||||
|
||||
Args:
|
||||
event_data: 事件数据
|
||||
client_id: 客户端ID
|
||||
"""
|
||||
event_id = event_data.get('id') or event_data.get('post_id') or event_data.get('time')
|
||||
# 尝试多种可能的事件ID字段
|
||||
event_id = (event_data.get('id') or
|
||||
event_data.get('post_id') or
|
||||
event_data.get('message_id') or
|
||||
event_data.get('time'))
|
||||
if not event_id:
|
||||
self.logger.debug(f"_mark_event_processed: event_id为空, event_data={event_data}")
|
||||
return
|
||||
|
||||
event_key = f"{event_data.get('post_type')}:{event_id}"
|
||||
self._processed_events[event_key] = datetime.now()
|
||||
|
||||
# 同时标记消息内容已处理
|
||||
if event_data.get('post_type') == 'message':
|
||||
# 为该客户端记录已处理的事件
|
||||
if client_id not in self._processed_events:
|
||||
self._processed_events[client_id] = {}
|
||||
self._processed_events[client_id][event_key] = datetime.now()
|
||||
self.logger.debug(f"_mark_event_processed: client_id={client_id}, event_key={event_key}, processed_events_count={len(self._processed_events[client_id])}")
|
||||
|
||||
# 只对群聊消息标记内容已处理
|
||||
if event_data.get('post_type') == 'message' and event_data.get('message_type') == 'group':
|
||||
raw_message = event_data.get('raw_message', '')
|
||||
user_id = event_data.get('user_id')
|
||||
content_key = f"content:{raw_message}:{user_id}"
|
||||
self._processed_messages[content_key] = datetime.now()
|
||||
group_id = event_data.get('group_id', '0')
|
||||
content_key = f"content:{raw_message}:{user_id}:{group_id}"
|
||||
|
||||
if client_id not in self._processed_messages:
|
||||
self._processed_messages[client_id] = {}
|
||||
self._processed_messages[client_id][content_key] = datetime.now()
|
||||
|
||||
def _get_message_key(self, event_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
|
||||
15
core/ws.py
15
core/ws.py
@@ -25,7 +25,6 @@ from websockets.legacy.client import WebSocketClientProtocol
|
||||
from models.events.factory import EventFactory
|
||||
|
||||
from .config_loader import global_config
|
||||
from .managers.command_manager import matcher
|
||||
from .utils.executor import CodeExecutor
|
||||
from .utils.logger import ModuleLogger
|
||||
from .utils.exceptions import (
|
||||
@@ -210,6 +209,7 @@ class WS:
|
||||
self.logger.debug(f"[元事件] {meta_event_type}")
|
||||
|
||||
# 分发事件
|
||||
from .managers.command_manager import matcher
|
||||
await matcher.handle_event(self.bot, event)
|
||||
|
||||
except Exception as e:
|
||||
@@ -297,3 +297,16 @@ class WS:
|
||||
message=f"API调用异常: {str(e)}",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
|
||||
class ReverseWSClient(WS):
|
||||
"""
|
||||
反向 WebSocket 客户端代理,用于 Bot 实例调用 API。
|
||||
"""
|
||||
def __init__(self, manager: Any, client_id: str):
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
self.client_id = client_id
|
||||
|
||||
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
|
||||
return await self.manager.call_api(action, params, self.client_id)
|
||||
|
||||
Reference in New Issue
Block a user