feat: 添加多线程架构支持并优化性能
实现线程管理器以支持高并发场景,添加GIL-free模式提升Python 3.14下的多线程性能 新增B站API集成和本地文件服务器功能,改进镜像插件支持GIF处理 更新文档说明多线程架构和GIL-free模式的使用方法
This commit is contained in:
@@ -7,7 +7,7 @@ from pathlib import Path
|
||||
|
||||
import tomllib
|
||||
from pydantic import ValidationError
|
||||
from .config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel, ImageManagerModel, MySQLModel, ReverseWSModel
|
||||
from .config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel, ImageManagerModel, MySQLModel, ReverseWSModel, ThreadingModel, BilibiliModel, LocalFileServerModel
|
||||
from .utils.logger import ModuleLogger
|
||||
from .utils.exceptions import ConfigError, ConfigNotFoundError, ConfigValidationError
|
||||
|
||||
@@ -136,6 +136,27 @@ class Config:
|
||||
"""
|
||||
return self._model.reverse_ws
|
||||
|
||||
@property
|
||||
def threading(self) -> ThreadingModel:
|
||||
"""
|
||||
获取线程管理配置
|
||||
"""
|
||||
return self._model.threading
|
||||
|
||||
@property
|
||||
def bilibili(self) -> BilibiliModel:
|
||||
"""
|
||||
获取 Bilibili 配置
|
||||
"""
|
||||
return self._model.bilibili
|
||||
|
||||
@property
|
||||
def local_file_server(self) -> LocalFileServerModel:
|
||||
"""
|
||||
获取本地文件服务器配置
|
||||
"""
|
||||
return self._model.local_file_server
|
||||
|
||||
|
||||
|
||||
# 实例化全局配置对象
|
||||
|
||||
@@ -79,14 +79,32 @@ class ImageManagerModel(BaseModel):
|
||||
image_width: int = 1080
|
||||
|
||||
|
||||
class ReverseWSModel(BaseModel):
|
||||
class ThreadingModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[reverse_ws]` 配置块。
|
||||
对应 `config.toml` 中的 `[threading]` 配置块。
|
||||
"""
|
||||
enabled: bool = False
|
||||
max_workers: int = Field(default=10, ge=1, le=100)
|
||||
client_max_workers: int = Field(default=5, ge=1, le=50)
|
||||
thread_name_prefix: str = "NeoBot-Thread"
|
||||
|
||||
|
||||
class BilibiliModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[bilibili]` 配置块。
|
||||
"""
|
||||
sessdata: Optional[str] = None
|
||||
bili_jct: Optional[str] = None
|
||||
buvid3: Optional[str] = None
|
||||
dedeuserid: Optional[str] = None
|
||||
|
||||
|
||||
class LocalFileServerModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[local_file_server]` 配置块。
|
||||
"""
|
||||
enabled: bool = True
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 3002
|
||||
token: Optional[str] = None
|
||||
port: int = 3003
|
||||
|
||||
|
||||
class ConfigModel(BaseModel):
|
||||
@@ -100,5 +118,8 @@ class ConfigModel(BaseModel):
|
||||
docker: DockerModel
|
||||
image_manager: ImageManagerModel
|
||||
reverse_ws: ReverseWSModel
|
||||
threading: ThreadingModel = Field(default_factory=ThreadingModel)
|
||||
bilibili: BilibiliModel = Field(default_factory=BilibiliModel)
|
||||
local_file_server: LocalFileServerModel = Field(default_factory=LocalFileServerModel)
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from .mysql_manager import MySQLManager
|
||||
from .browser_manager import BrowserManager
|
||||
from .image_manager import ImageManager
|
||||
from .reverse_ws_manager import ReverseWSManager
|
||||
from .thread_manager import thread_manager
|
||||
|
||||
# --- 实例化所有单例管理器 ---
|
||||
|
||||
@@ -40,6 +41,9 @@ image_manager = ImageManager()
|
||||
# 反向 WebSocket 管理器
|
||||
reverse_ws_manager = ReverseWSManager()
|
||||
|
||||
# 线程管理器
|
||||
thread_manager.start()
|
||||
|
||||
__all__ = [
|
||||
"permission_manager",
|
||||
"command_manager",
|
||||
@@ -50,4 +54,5 @@ __all__ = [
|
||||
"browser_manager",
|
||||
"image_manager",
|
||||
"reverse_ws_manager",
|
||||
"thread_manager",
|
||||
]
|
||||
|
||||
@@ -11,15 +11,12 @@ from websockets.server import WebSocketServerProtocol
|
||||
from typing import Dict, Any, Optional, Set
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import random
|
||||
import threading
|
||||
|
||||
from ..config_loader import global_config
|
||||
from ..utils.logger import ModuleLogger
|
||||
from ..utils.exceptions import WebSocketError, WebSocketConnectionError
|
||||
from ..utils.error_codes import ErrorCode, create_error_response
|
||||
from .command_manager import matcher
|
||||
from models.events.factory import EventFactory
|
||||
from .redis_manager import redis_manager
|
||||
from ..bot import Bot
|
||||
from ..ws import ReverseWSClient as _ReverseWSClient
|
||||
|
||||
@@ -82,6 +79,18 @@ class ReverseWSManager:
|
||||
# 正在处理的事件ID集合(用于防止重复处理)
|
||||
self._processing_events: Dict[str, Set[str]] = {} # client_id: set of event_ids
|
||||
|
||||
# 线程安全锁
|
||||
self._clients_lock = threading.RLock()
|
||||
self._bots_lock = threading.RLock()
|
||||
self._pending_requests_lock = threading.RLock()
|
||||
self._load_lock = threading.RLock()
|
||||
self._health_lock = threading.RLock()
|
||||
self._processed_events_lock = threading.RLock()
|
||||
self._processed_messages_lock = threading.RLock()
|
||||
self._processing_events_lock = threading.RLock()
|
||||
self._message_locks_lock = threading.RLock()
|
||||
self._message_lock_times_lock = threading.RLock()
|
||||
|
||||
async def start(self, host: str = "0.0.0.0", port: int = 3002) -> None:
|
||||
"""
|
||||
启动反向 WebSocket 服务端。
|
||||
@@ -184,37 +193,41 @@ class ReverseWSManager:
|
||||
current_time = datetime.now()
|
||||
|
||||
# 清理过期的事件ID(按客户端)
|
||||
for client_id, events in list(self._processed_events.items()):
|
||||
expired_events = [
|
||||
event_id for event_id, timestamp in events.items()
|
||||
if (current_time - timestamp).total_seconds() > self._event_ttl
|
||||
]
|
||||
for event_id in expired_events:
|
||||
del events[event_id]
|
||||
if not events:
|
||||
del self._processed_events[client_id]
|
||||
|
||||
with self._processed_events_lock:
|
||||
for client_id, events in list(self._processed_events.items()):
|
||||
expired_events = [
|
||||
event_id for event_id, timestamp in events.items()
|
||||
if (current_time - timestamp).total_seconds() > self._event_ttl
|
||||
]
|
||||
for event_id in expired_events:
|
||||
del events[event_id]
|
||||
if not events:
|
||||
del self._processed_events[client_id]
|
||||
|
||||
# 清理过期的消息锁
|
||||
expired_locks = [
|
||||
lock_key for lock_key, timestamp in self._message_lock_times.items()
|
||||
if (current_time - timestamp).total_seconds() > self._lock_ttl
|
||||
]
|
||||
for lock_key in expired_locks:
|
||||
if lock_key in self._message_locks:
|
||||
del self._message_locks[lock_key]
|
||||
if lock_key in self._message_lock_times:
|
||||
del self._message_lock_times[lock_key]
|
||||
with self._message_lock_times_lock:
|
||||
expired_locks = [
|
||||
lock_key for lock_key, timestamp in self._message_lock_times.items()
|
||||
if (current_time - timestamp).total_seconds() > self._lock_ttl
|
||||
]
|
||||
for lock_key in expired_locks:
|
||||
with self._message_locks_lock:
|
||||
if lock_key in self._message_locks:
|
||||
del self._message_locks[lock_key]
|
||||
if lock_key in self._message_lock_times:
|
||||
del self._message_lock_times[lock_key]
|
||||
|
||||
# 清理过期的消息内容(按客户端)
|
||||
for client_id, messages in list(self._processed_messages.items()):
|
||||
expired_messages = [
|
||||
msg_key for msg_key, timestamp in messages.items()
|
||||
if (current_time - timestamp).total_seconds() > self._message_content_ttl
|
||||
]
|
||||
for msg_key in expired_messages:
|
||||
del messages[msg_key]
|
||||
if not messages:
|
||||
del self._processed_messages[client_id]
|
||||
with self._processed_messages_lock:
|
||||
for client_id, messages in list(self._processed_messages.items()):
|
||||
expired_messages = [
|
||||
msg_key for msg_key, timestamp in messages.items()
|
||||
if (current_time - timestamp).total_seconds() > self._message_content_ttl
|
||||
]
|
||||
for msg_key in expired_messages:
|
||||
del messages[msg_key]
|
||||
if not messages:
|
||||
del self._processed_messages[client_id]
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
@@ -228,24 +241,32 @@ class ReverseWSManager:
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
"""
|
||||
if client_id in self.clients:
|
||||
del self.clients[client_id]
|
||||
if client_id in self.client_self_ids:
|
||||
del self.client_self_ids[client_id]
|
||||
if client_id in self._client_load:
|
||||
del self._client_load[client_id]
|
||||
if client_id in self._client_health:
|
||||
del self._client_health[client_id]
|
||||
if client_id in self.bots:
|
||||
del self.bots[client_id]
|
||||
with self._clients_lock:
|
||||
if client_id in self.clients:
|
||||
del self.clients[client_id]
|
||||
with self._clients_lock:
|
||||
if client_id in self.client_self_ids:
|
||||
del self.client_self_ids[client_id]
|
||||
with self._load_lock:
|
||||
if client_id in self._client_load:
|
||||
del self._client_load[client_id]
|
||||
with self._health_lock:
|
||||
if client_id in self._client_health:
|
||||
del self._client_health[client_id]
|
||||
with self._bots_lock:
|
||||
if client_id in self.bots:
|
||||
del self.bots[client_id]
|
||||
|
||||
# 清理该客户端的防重复数据
|
||||
if client_id in self._processed_events:
|
||||
del self._processed_events[client_id]
|
||||
if client_id in self._processed_messages:
|
||||
del self._processed_messages[client_id]
|
||||
if client_id in self._processing_events:
|
||||
del self._processing_events[client_id]
|
||||
with self._processed_events_lock:
|
||||
if client_id in self._processed_events:
|
||||
del self._processed_events[client_id]
|
||||
with self._processed_messages_lock:
|
||||
if client_id in self._processed_messages:
|
||||
del self._processed_messages[client_id]
|
||||
with self._processing_events_lock:
|
||||
if client_id in self._processing_events:
|
||||
del self._processing_events[client_id]
|
||||
|
||||
self.logger.info(f"客户端已断开并清理: {client_id}")
|
||||
|
||||
@@ -266,41 +287,46 @@ class ReverseWSManager:
|
||||
event_key = f"{event_data.get('post_type')}:{event_id}"
|
||||
|
||||
# 检查客户端是否已连接
|
||||
if client_id not in self.clients:
|
||||
self.logger.debug(f"_on_event: 客户端已断开, client_id={client_id}")
|
||||
return
|
||||
|
||||
with self._clients_lock:
|
||||
if client_id not in self.clients:
|
||||
self.logger.debug(f"_on_event: 客户端已断开, client_id={client_id}")
|
||||
return
|
||||
|
||||
# 检查是否正在处理
|
||||
if client_id not in self._processing_events:
|
||||
self._processing_events[client_id] = set()
|
||||
with self._processing_events_lock:
|
||||
if client_id not in self._processing_events:
|
||||
self._processing_events[client_id] = set()
|
||||
|
||||
if event_key in self._processing_events[client_id]:
|
||||
self.logger.debug(f"_on_event: 事件正在处理中, client_id={client_id}, event_key={event_key}")
|
||||
return
|
||||
if event_key in self._processing_events[client_id]:
|
||||
self.logger.debug(f"_on_event: 事件正在处理中, client_id={client_id}, event_key={event_key}")
|
||||
return
|
||||
|
||||
# 标记为正在处理
|
||||
self._processing_events[client_id].add(event_key)
|
||||
# 标记为正在处理
|
||||
self._processing_events[client_id].add(event_key)
|
||||
|
||||
try:
|
||||
event = EventFactory.create_event(event_data)
|
||||
|
||||
if hasattr(event, 'self_id'):
|
||||
self.client_self_ids[client_id] = event.self_id
|
||||
with self._clients_lock:
|
||||
self.client_self_ids[client_id] = event.self_id
|
||||
|
||||
# 为事件注入Bot实例
|
||||
from ..ws import ReverseWSClient
|
||||
|
||||
# 为每个前端创建独立的Bot实例
|
||||
if client_id not in self.bots:
|
||||
# 使用 ReverseWSClient 代理
|
||||
temp_ws = ReverseWSClient(self, client_id)
|
||||
temp_ws.self_id = event.self_id if hasattr(event, 'self_id') else 0
|
||||
self.bots[client_id] = Bot(temp_ws)
|
||||
with self._bots_lock:
|
||||
if client_id not in self.bots:
|
||||
# 使用 ReverseWSClient 代理
|
||||
temp_ws = ReverseWSClient(self, client_id)
|
||||
temp_ws.self_id = event.self_id if hasattr(event, 'self_id') else 0
|
||||
self.bots[client_id] = Bot(temp_ws)
|
||||
|
||||
event.bot = self.bots[client_id]
|
||||
|
||||
# 记录客户端健康状态
|
||||
self._client_health[client_id] = datetime.now()
|
||||
with self._health_lock:
|
||||
self._client_health[client_id] = datetime.now()
|
||||
|
||||
# 检查是否为重复事件(按客户端)
|
||||
is_duplicate = self._is_duplicate_event(event_data, client_id)
|
||||
@@ -333,15 +359,18 @@ class ReverseWSManager:
|
||||
return
|
||||
|
||||
# 标记事件已处理(按客户端)
|
||||
self._mark_event_processed(event_data, client_id)
|
||||
with self._processed_events_lock:
|
||||
self._mark_event_processed(event_data, client_id)
|
||||
|
||||
# 更新客户端负载
|
||||
self._update_client_load(client_id)
|
||||
with self._load_lock:
|
||||
self._update_client_load(client_id)
|
||||
|
||||
await matcher.handle_event(event.bot, event)
|
||||
else:
|
||||
# 对于非消息事件,直接标记并处理
|
||||
self._mark_event_processed(event_data, client_id)
|
||||
with self._processed_events_lock:
|
||||
self._mark_event_processed(event_data, client_id)
|
||||
|
||||
if event.post_type == "notice":
|
||||
notice_type = getattr(event, "notice_type", "Unknown")
|
||||
@@ -362,12 +391,13 @@ class ReverseWSManager:
|
||||
self.logger.exception(f"事件处理异常: {str(e)}")
|
||||
finally:
|
||||
# 清理正在处理的事件
|
||||
if client_id in self._processing_events:
|
||||
if event_key in self._processing_events[client_id]:
|
||||
self._processing_events[client_id].discard(event_key)
|
||||
# 如果集合为空,删除该客户端的记录
|
||||
if not self._processing_events[client_id]:
|
||||
del self._processing_events[client_id]
|
||||
with self._processing_events_lock:
|
||||
if client_id in self._processing_events:
|
||||
if event_key in self._processing_events[client_id]:
|
||||
self._processing_events[client_id].discard(event_key)
|
||||
# 如果集合为空,删除该客户端的记录
|
||||
if not self._processing_events[client_id]:
|
||||
del self._processing_events[client_id]
|
||||
|
||||
async def call_api(
|
||||
self,
|
||||
@@ -404,28 +434,39 @@ class ReverseWSManager:
|
||||
# 选择负载最低的客户端
|
||||
client_id = self.get_client_with_least_load()
|
||||
if client_id is None and healthy_clients:
|
||||
client_id = list(healthy_clients.keys())[0]
|
||||
with self._clients_lock:
|
||||
client_id = list(healthy_clients.keys())[0]
|
||||
else:
|
||||
# 如果没有健康客户端,使用所有客户端中的一个
|
||||
client_id = list(self.clients.keys())[0]
|
||||
|
||||
with self._clients_lock:
|
||||
client_id = list(self.clients.keys())[0]
|
||||
|
||||
echo_id = str(uuid.uuid4())
|
||||
payload = {"action": action, "params": params or {}, "echo": echo_id}
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
self._pending_requests[echo_id] = future
|
||||
with self._pending_requests_lock:
|
||||
self._pending_requests[echo_id] = future
|
||||
|
||||
try:
|
||||
targets = [client_id] if client_id else list(self.clients.keys())
|
||||
targets = [client_id] if client_id else None
|
||||
clients_to_send = []
|
||||
|
||||
for cid in targets:
|
||||
if cid in self.clients:
|
||||
await self.clients[cid].send(orjson.dumps(payload))
|
||||
with self._clients_lock:
|
||||
if targets is None:
|
||||
targets = list(self.clients.keys())
|
||||
for cid in targets:
|
||||
if cid in self.clients:
|
||||
clients_to_send.append((cid, self.clients[cid]))
|
||||
|
||||
for cid, websocket in clients_to_send:
|
||||
await websocket.send(orjson.dumps(payload))
|
||||
|
||||
return await asyncio.wait_for(future, timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
with self._pending_requests_lock:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.warning(f"API 调用超时: action={action}, params={params}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.TIMEOUT_ERROR,
|
||||
@@ -433,7 +474,8 @@ class ReverseWSManager:
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
with self._pending_requests_lock:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
@@ -448,7 +490,8 @@ class ReverseWSManager:
|
||||
Returns:
|
||||
客户端 ID 和 self_id 的映射字典
|
||||
"""
|
||||
return self.client_self_ids.copy()
|
||||
with self._clients_lock:
|
||||
return self.client_self_ids.copy()
|
||||
|
||||
def _is_duplicate_event(self, event_data: Dict[str, Any], client_id: str) -> bool:
|
||||
"""
|
||||
@@ -472,13 +515,14 @@ class ReverseWSManager:
|
||||
event_key = f"{event_data.get('post_type')}:{event_id}"
|
||||
|
||||
# 检查该客户端是否已处理过此事件
|
||||
if client_id not in self._processed_events:
|
||||
self.logger.debug(f"_is_duplicate_event: client_id={client_id}不在_processed_events中, event_key={event_key}, 返回False")
|
||||
return False
|
||||
with self._processed_events_lock:
|
||||
if client_id not in self._processed_events:
|
||||
self.logger.debug(f"_is_duplicate_event: client_id={client_id}不在_processed_events中, event_key={event_key}, 返回False")
|
||||
return False
|
||||
|
||||
is_duplicate = event_key in self._processed_events[client_id]
|
||||
self.logger.debug(f"_is_duplicate_event: client_id={client_id}, event_key={event_key}, in_processed={is_duplicate}, processed_events_count={len(self._processed_events[client_id])}")
|
||||
return is_duplicate
|
||||
is_duplicate = event_key in self._processed_events[client_id]
|
||||
self.logger.debug(f"_is_duplicate_event: client_id={client_id}, event_key={event_key}, in_processed={is_duplicate}, processed_events_count={len(self._processed_events[client_id])}")
|
||||
return is_duplicate
|
||||
|
||||
def _is_duplicate_message(self, event_data: Dict[str, Any], client_id: str) -> bool:
|
||||
"""
|
||||
@@ -507,10 +551,11 @@ class ReverseWSManager:
|
||||
content_key = f"content:{raw_message}:{user_id}:{group_id}"
|
||||
|
||||
# 检查该客户端是否已处理过此消息内容
|
||||
if client_id not in self._processed_messages:
|
||||
return False
|
||||
with self._processed_messages_lock:
|
||||
if client_id not in self._processed_messages:
|
||||
return False
|
||||
|
||||
return content_key in self._processed_messages[client_id]
|
||||
return content_key in self._processed_messages[client_id]
|
||||
|
||||
def _mark_event_processed(self, event_data: Dict[str, Any], client_id: str) -> None:
|
||||
"""
|
||||
@@ -532,10 +577,11 @@ class ReverseWSManager:
|
||||
event_key = f"{event_data.get('post_type')}:{event_id}"
|
||||
|
||||
# 为该客户端记录已处理的事件
|
||||
if client_id not in self._processed_events:
|
||||
self._processed_events[client_id] = {}
|
||||
self._processed_events[client_id][event_key] = datetime.now()
|
||||
self.logger.debug(f"_mark_event_processed: client_id={client_id}, event_key={event_key}, processed_events_count={len(self._processed_events[client_id])}")
|
||||
with self._processed_events_lock:
|
||||
if client_id not in self._processed_events:
|
||||
self._processed_events[client_id] = {}
|
||||
self._processed_events[client_id][event_key] = datetime.now()
|
||||
self.logger.debug(f"_mark_event_processed: client_id={client_id}, event_key={event_key}, processed_events_count={len(self._processed_events[client_id])}")
|
||||
|
||||
# 只对群聊消息标记内容已处理
|
||||
if event_data.get('post_type') == 'message' and event_data.get('message_type') == 'group':
|
||||
@@ -544,9 +590,10 @@ class ReverseWSManager:
|
||||
group_id = event_data.get('group_id', '0')
|
||||
content_key = f"content:{raw_message}:{user_id}:{group_id}"
|
||||
|
||||
if client_id not in self._processed_messages:
|
||||
self._processed_messages[client_id] = {}
|
||||
self._processed_messages[client_id][content_key] = datetime.now()
|
||||
with self._processed_messages_lock:
|
||||
if client_id not in self._processed_messages:
|
||||
self._processed_messages[client_id] = {}
|
||||
self._processed_messages[client_id][content_key] = datetime.now()
|
||||
|
||||
def _get_message_key(self, event_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
@@ -574,9 +621,11 @@ class ReverseWSManager:
|
||||
Returns:
|
||||
asyncio.Lock 实例
|
||||
"""
|
||||
if key not in self._message_locks:
|
||||
self._message_locks[key] = asyncio.Lock()
|
||||
self._message_lock_times[key] = datetime.now()
|
||||
with self._message_locks_lock:
|
||||
if key not in self._message_locks:
|
||||
self._message_locks[key] = asyncio.Lock()
|
||||
with self._message_lock_times_lock:
|
||||
self._message_lock_times[key] = datetime.now()
|
||||
return self._message_locks[key]
|
||||
|
||||
def _update_client_load(self, client_id: str) -> None:
|
||||
@@ -586,9 +635,10 @@ class ReverseWSManager:
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
"""
|
||||
if client_id not in self._client_load:
|
||||
self._client_load[client_id] = 0
|
||||
self._client_load[client_id] += 1
|
||||
with self._load_lock:
|
||||
if client_id not in self._client_load:
|
||||
self._client_load[client_id] = 0
|
||||
self._client_load[client_id] += 1
|
||||
|
||||
def get_client_with_least_load(self) -> Optional[str]:
|
||||
"""
|
||||
@@ -597,10 +647,11 @@ class ReverseWSManager:
|
||||
Returns:
|
||||
客户端 ID,如果没有客户端则返回 None
|
||||
"""
|
||||
if not self._client_load:
|
||||
return None
|
||||
with self._load_lock:
|
||||
if not self._client_load:
|
||||
return None
|
||||
|
||||
return min(self._client_load.keys(), key=lambda k: self._client_load[k])
|
||||
return min(self._client_load.keys(), key=lambda k: self._client_load[k])
|
||||
|
||||
def get_healthy_clients(self) -> Dict[str, int]:
|
||||
"""
|
||||
@@ -612,11 +663,13 @@ class ReverseWSManager:
|
||||
current_time = datetime.now()
|
||||
healthy = {}
|
||||
|
||||
for client_id, last_health in self._client_health.items():
|
||||
if (current_time - last_health).total_seconds() < 30:
|
||||
if client_id in self.client_self_ids:
|
||||
healthy[client_id] = self.client_self_ids[client_id]
|
||||
|
||||
with self._health_lock:
|
||||
with self._clients_lock:
|
||||
for client_id, last_health in self._client_health.items():
|
||||
if (current_time - last_health).total_seconds() < 30:
|
||||
if client_id in self.client_self_ids:
|
||||
healthy[client_id] = self.client_self_ids[client_id]
|
||||
|
||||
return healthy
|
||||
|
||||
|
||||
|
||||
379
core/managers/thread_manager.py
Normal file
379
core/managers/thread_manager.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""
|
||||
线程管理器模块
|
||||
|
||||
该模块提供了多线程支持,用于处理来自多个实现端的并发事件。
|
||||
每个 WebSocket 连接在独立的线程中运行,避免阻塞主事件循环。
|
||||
"""
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Dict, Optional, Callable, Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from ..utils.logger import ModuleLogger
|
||||
from ..config_loader import global_config
|
||||
|
||||
|
||||
class ThreadManager:
|
||||
"""
|
||||
线程管理器,负责管理多线程环境下的事件处理。
|
||||
|
||||
该管理器为每个 WebSocket 连接提供独立的线程池,
|
||||
确保多前端场景下的事件处理不会相互阻塞。
|
||||
"""
|
||||
|
||||
_instance: Optional['ThreadManager'] = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> 'ThreadManager':
|
||||
"""
|
||||
单例模式:确保全局只有一个线程管理器实例。
|
||||
"""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
初始化线程管理器。
|
||||
"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self.logger = ModuleLogger("ThreadManager")
|
||||
|
||||
# 线程池配置
|
||||
self._max_workers: int = global_config.threading.max_workers
|
||||
self._thread_name_prefix: str = global_config.threading.thread_name_prefix
|
||||
|
||||
# 线程池
|
||||
self._executor: Optional[ThreadPoolExecutor] = None
|
||||
|
||||
# 每个客户端的线程池(用于反向 WebSocket)
|
||||
self._client_executors: Dict[str, ThreadPoolExecutor] = {}
|
||||
self._client_executor_locks: Dict[str, threading.Lock] = {}
|
||||
|
||||
# 线程安全的事件循环(用于跨线程调用)
|
||||
self._event_loops: Dict[str, asyncio.AbstractEventLoop] = {}
|
||||
self._event_loops_lock = threading.Lock()
|
||||
|
||||
# 统计信息
|
||||
self._stats: Dict[str, Any] = {
|
||||
'total_tasks': 0,
|
||||
'completed_tasks': 0,
|
||||
'failed_tasks': 0,
|
||||
'active_threads': 0,
|
||||
'client_tasks': {}
|
||||
}
|
||||
self._stats_lock = threading.Lock()
|
||||
|
||||
self._initialized = True
|
||||
self.logger.success("线程管理器初始化完成")
|
||||
|
||||
def start(self) -> None:
|
||||
"""
|
||||
启动线程管理器,创建主线程池。
|
||||
"""
|
||||
if self._executor is None:
|
||||
self._executor = ThreadPoolExecutor(
|
||||
max_workers=self._max_workers,
|
||||
thread_name_prefix=self._thread_name_prefix
|
||||
)
|
||||
self.logger.success(f"主 ThreadPool 已启动: max_workers={self._max_workers}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
关闭线程管理器,释放所有资源。
|
||||
"""
|
||||
self.logger.info("正在关闭线程管理器...")
|
||||
|
||||
# 关闭所有客户端线程池
|
||||
for client_id, executor in list(self._client_executors.items()):
|
||||
self._shutdown_client_executor(client_id)
|
||||
|
||||
# 关闭主执行器
|
||||
if self._executor is not None:
|
||||
self._executor.shutdown(wait=True)
|
||||
self._executor = None
|
||||
|
||||
self.logger.success("线程管理器已关闭")
|
||||
|
||||
def _shutdown_client_executor(self, client_id: str) -> None:
|
||||
"""
|
||||
关闭特定客户端的线程池。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
"""
|
||||
if client_id in self._client_executors:
|
||||
try:
|
||||
self._client_executors[client_id].shutdown(wait=True)
|
||||
del self._client_executors[client_id]
|
||||
self.logger.info(f"客户端 {client_id} 的线程池已关闭")
|
||||
except Exception as e:
|
||||
self.logger.error(f"关闭客户端 {client_id} 线程池失败: {e}")
|
||||
|
||||
def get_main_executor(self) -> ThreadPoolExecutor:
|
||||
"""
|
||||
获取主线程池。
|
||||
|
||||
Returns:
|
||||
ThreadPoolExecutor 实例
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果线程管理器未启动
|
||||
"""
|
||||
if self._executor is None:
|
||||
raise RuntimeError("线程管理器未启动,请先调用 start()")
|
||||
return self._executor
|
||||
|
||||
def get_client_executor(self, client_id: str) -> ThreadPoolExecutor:
|
||||
"""
|
||||
获取特定客户端的线程池(为反向 WebSocket 设计)。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
|
||||
Returns:
|
||||
ThreadPoolExecutor 实例
|
||||
"""
|
||||
if client_id not in self._client_executors:
|
||||
with threading.Lock():
|
||||
if client_id not in self._client_executors:
|
||||
executor = ThreadPoolExecutor(
|
||||
max_workers=global_config.threading.client_max_workers,
|
||||
thread_name_prefix=f"{self._thread_name_prefix}_{client_id[:8]}"
|
||||
)
|
||||
self._client_executors[client_id] = executor
|
||||
self._client_executor_locks[client_id] = threading.Lock()
|
||||
self.logger.info(f"为客户端 {client_id} 创建线程池")
|
||||
|
||||
return self._client_executors[client_id]
|
||||
|
||||
def submit_to_main_executor(
|
||||
self,
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
"""
|
||||
提交任务到主线程池(同步)。
|
||||
|
||||
Args:
|
||||
func: 要执行的函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
函数执行结果
|
||||
"""
|
||||
executor = self.get_main_executor()
|
||||
future = executor.submit(func, *args, **kwargs)
|
||||
self._update_stats('total_tasks')
|
||||
try:
|
||||
result = future.result()
|
||||
self._update_stats('completed_tasks')
|
||||
return result
|
||||
except Exception as e:
|
||||
self._update_stats('failed_tasks')
|
||||
self.logger.error(f"主线程池任务执行失败: {e}")
|
||||
raise
|
||||
|
||||
async def submit_to_main_executor_async(
|
||||
self,
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
"""
|
||||
提交任务到主线程池(异步)。
|
||||
|
||||
Args:
|
||||
func: 要执行的函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
函数执行结果
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
executor = self.get_main_executor()
|
||||
future = loop.run_in_executor(executor, lambda: func(*args, **kwargs))
|
||||
self._update_stats('total_tasks')
|
||||
try:
|
||||
result = await future
|
||||
self._update_stats('completed_tasks')
|
||||
return result
|
||||
except Exception as e:
|
||||
self._update_stats('failed_tasks')
|
||||
self.logger.error(f"异步主线程池任务执行失败: {e}")
|
||||
raise
|
||||
|
||||
def submit_to_client_executor(
|
||||
self,
|
||||
client_id: str,
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
"""
|
||||
提交任务到特定客户端的线程池。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
func: 要执行的函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
函数执行结果
|
||||
"""
|
||||
executor = self.get_client_executor(client_id)
|
||||
future = executor.submit(func, *args, **kwargs)
|
||||
self._update_client_stats(client_id, 'total_tasks')
|
||||
try:
|
||||
result = future.result()
|
||||
self._update_client_stats(client_id, 'completed_tasks')
|
||||
return result
|
||||
except Exception as e:
|
||||
self._update_client_stats(client_id, 'failed_tasks')
|
||||
self.logger.error(f"客户端 {client_id} 线程池任务执行失败: {e}")
|
||||
raise
|
||||
|
||||
async def submit_to_client_executor_async(
|
||||
self,
|
||||
client_id: str,
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
"""
|
||||
提交任务到特定客户端的线程池(异步)。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
func: 要执行的函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
函数执行结果
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
executor = self.get_client_executor(client_id)
|
||||
future = loop.run_in_executor(executor, lambda: func(*args, **kwargs))
|
||||
self._update_client_stats(client_id, 'total_tasks')
|
||||
try:
|
||||
result = await future
|
||||
self._update_client_stats(client_id, 'completed_tasks')
|
||||
return result
|
||||
except Exception as e:
|
||||
self._update_client_stats(client_id, 'failed_tasks')
|
||||
self.logger.error(f"客户端 {client_id} 异步线程池任务执行失败: {e}")
|
||||
raise
|
||||
|
||||
def run_coroutine_threadsafe(
|
||||
self,
|
||||
coro,
|
||||
client_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
在指定客户端的事件循环中运行协程(线程安全)。
|
||||
|
||||
Args:
|
||||
coro: 协程对象
|
||||
client_id: 客户端 ID,如果为 None 则使用主事件循环
|
||||
|
||||
Returns:
|
||||
协程执行结果
|
||||
"""
|
||||
if client_id is None:
|
||||
loop = asyncio.get_running_loop()
|
||||
else:
|
||||
with self._event_loops_lock:
|
||||
if client_id not in self._event_loops:
|
||||
self._event_loops[client_id] = asyncio.new_event_loop()
|
||||
threading.Thread(
|
||||
target=self._event_loop_thread,
|
||||
args=(client_id,),
|
||||
daemon=True
|
||||
).start()
|
||||
loop = self._event_loops[client_id]
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
return future.result()
|
||||
|
||||
def _event_loop_thread(self, client_id: str) -> None:
|
||||
"""
|
||||
事件循环线程(用于反向 WebSocket 客户端)。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
"""
|
||||
asyncio.set_event_loop(self._event_loops[client_id])
|
||||
self.logger.info(f"事件循环线程启动: client_id={client_id}")
|
||||
try:
|
||||
self._event_loops[client_id].run_forever()
|
||||
finally:
|
||||
self._event_loops[client_id].close()
|
||||
self.logger.info(f"事件循环线程停止: client_id={client_id}")
|
||||
|
||||
def _update_stats(self, key: str) -> None:
|
||||
"""
|
||||
更新全局统计信息。
|
||||
|
||||
Args:
|
||||
key: 统计项键名
|
||||
"""
|
||||
with self._stats_lock:
|
||||
self._stats[key] = self._stats.get(key, 0) + 1
|
||||
|
||||
def _update_client_stats(self, client_id: str, key: str) -> None:
|
||||
"""
|
||||
更新客户端统计信息。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
key: 统计项键名
|
||||
"""
|
||||
with self._stats_lock:
|
||||
if client_id not in self._stats['client_tasks']:
|
||||
self._stats['client_tasks'][client_id] = {
|
||||
'total_tasks': 0,
|
||||
'completed_tasks': 0,
|
||||
'failed_tasks': 0
|
||||
}
|
||||
self._stats['client_tasks'][client_id][key] += 1
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取统计信息。
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
with self._stats_lock:
|
||||
stats = self._stats.copy()
|
||||
stats['client_tasks'] = stats.get('client_tasks', {}).copy()
|
||||
return stats
|
||||
|
||||
def get_active_threads_count(self) -> int:
|
||||
"""
|
||||
获取活动线程数量。
|
||||
|
||||
Returns:
|
||||
活动线程数量
|
||||
"""
|
||||
import threading
|
||||
return sum(
|
||||
1 for t in threading.enumerate()
|
||||
if t.name.startswith(self._thread_name_prefix)
|
||||
)
|
||||
|
||||
|
||||
# 全局线程管理器实例
|
||||
thread_manager = ThreadManager()
|
||||
217
core/services/local_file_server.py
Normal file
217
core/services/local_file_server.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
本地文件下载服务
|
||||
|
||||
该模块提供一个本地 HTTP 服务,用于下载远程文件到本地并提供本地访问。
|
||||
主要解决 NapCat 等第三方服务无法直接访问某些远程资源(如 B 站防盗链)的问题。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
import urllib.request
|
||||
|
||||
from core.utils.logger import logger
|
||||
from core.config_loader import global_config
|
||||
|
||||
|
||||
class LocalFileServer:
|
||||
"""
|
||||
本地文件下载服务
|
||||
|
||||
提供一个本地 HTTP 服务,用于下载远程文件到本地并提供本地访问。
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 3003):
|
||||
"""
|
||||
初始化本地文件下载服务
|
||||
|
||||
Args:
|
||||
host (str): 服务监听地址
|
||||
port (int): 服务监听端口
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.app = web.Application()
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self.download_dir = Path(tempfile.gettempdir()) / "neobot_downloads"
|
||||
self.download_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 注册路由
|
||||
self.app.router.add_get('/download', self.handle_download)
|
||||
self.app.router.add_get('/health', self.handle_health)
|
||||
|
||||
# 文件映射表:file_id -> file_path
|
||||
self.file_map: Dict[str, Path] = {}
|
||||
|
||||
logger.success(f"[LocalFileServer] 初始化完成: {self.host}:{self.port}")
|
||||
|
||||
async def start(self):
|
||||
"""启动服务"""
|
||||
self.runner = web.AppRunner(self.app)
|
||||
await self.runner.setup()
|
||||
self.site = web.TCPSite(self.runner, self.host, self.port)
|
||||
await self.site.start()
|
||||
logger.success(f"[LocalFileServer] 服务已启动: http://{self.host}:{self.port}")
|
||||
|
||||
async def stop(self):
|
||||
"""停止服务"""
|
||||
if self.runner:
|
||||
await self.runner.cleanup()
|
||||
logger.info("[LocalFileServer] 服务已停止")
|
||||
|
||||
def _generate_file_id(self, url: str) -> str:
|
||||
"""根据 URL 生成唯一的文件 ID"""
|
||||
url_hash = hashlib.md5(url.encode()).hexdigest()[:16]
|
||||
return f"file_{url_hash}"
|
||||
|
||||
async def download_file(self, url: str, timeout: int = 60) -> Optional[str]:
|
||||
"""
|
||||
下载远程文件到本地
|
||||
|
||||
Args:
|
||||
url (str): 远程文件 URL
|
||||
timeout (int): 下载超时时间(秒)
|
||||
|
||||
Returns:
|
||||
Optional[str]: 本地文件 ID,如果失败则返回 None
|
||||
"""
|
||||
try:
|
||||
file_id = self._generate_file_id(url)
|
||||
file_path = self.download_dir / f"{file_id}"
|
||||
|
||||
# 检查文件是否已存在
|
||||
if file_path.exists():
|
||||
logger.info(f"[LocalFileServer] 文件已存在: {file_id}")
|
||||
return file_id
|
||||
|
||||
logger.info(f"[LocalFileServer] 开始下载: {url}")
|
||||
|
||||
# 使用 aiohttp 下载文件
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, timeout=timeout) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"[LocalFileServer] 下载失败: HTTP {response.status}")
|
||||
return None
|
||||
|
||||
# 读取并保存文件
|
||||
with open(file_path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await response.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
|
||||
self.file_map[file_id] = file_path
|
||||
logger.success(f"[LocalFileServer] 下载完成: {file_id} ({file_path.stat().st_size} bytes)")
|
||||
return file_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[LocalFileServer] 下载失败: {e}")
|
||||
return None
|
||||
|
||||
async def handle_download(self, request: web.Request) -> web.Response:
|
||||
"""处理文件下载请求"""
|
||||
file_id = request.query.get('id')
|
||||
|
||||
if not file_id or file_id not in self.file_map:
|
||||
return web.Response(
|
||||
status=404,
|
||||
text='File not found',
|
||||
content_type='text/plain'
|
||||
)
|
||||
|
||||
file_path = self.file_map[file_id]
|
||||
|
||||
if not file_path.exists():
|
||||
return web.Response(
|
||||
status=404,
|
||||
text='File not found',
|
||||
content_type='text/plain'
|
||||
)
|
||||
|
||||
# 获取文件大小
|
||||
file_size = file_path.stat().st_size
|
||||
|
||||
# 设置响应头
|
||||
headers = {
|
||||
'Content-Disposition': f'attachment; filename="{file_id}"',
|
||||
'Content-Length': str(file_size)
|
||||
}
|
||||
|
||||
return web.FileResponse(file_path, headers=headers)
|
||||
|
||||
async def handle_health(self, request: web.Request) -> web.Response:
|
||||
"""健康检查"""
|
||||
return web.json_response({
|
||||
'status': 'ok',
|
||||
'service': 'LocalFileServer',
|
||||
'download_dir': str(self.download_dir),
|
||||
'files_count': len(self.file_map)
|
||||
})
|
||||
|
||||
|
||||
# 全局实例
|
||||
_local_file_server: Optional[LocalFileServer] = None
|
||||
|
||||
|
||||
def get_local_file_server() -> Optional[LocalFileServer]:
|
||||
"""获取全局本地文件服务器实例"""
|
||||
global _local_file_server
|
||||
|
||||
if _local_file_server is None:
|
||||
try:
|
||||
server_config = global_config.local_file_server
|
||||
_local_file_server = LocalFileServer(
|
||||
host=server_config.host,
|
||||
port=server_config.port
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LocalFileServer] 初始化失败: {e}")
|
||||
return None
|
||||
|
||||
return _local_file_server
|
||||
|
||||
|
||||
async def start_local_file_server():
|
||||
"""启动全局本地文件服务器"""
|
||||
server = get_local_file_server()
|
||||
if server:
|
||||
await server.start()
|
||||
|
||||
|
||||
async def stop_local_file_server():
|
||||
"""停止全局本地文件服务器"""
|
||||
global _local_file_server
|
||||
if _local_file_server:
|
||||
await _local_file_server.stop()
|
||||
_local_file_server = None
|
||||
|
||||
|
||||
async def download_to_local(url: str, timeout: int = 60) -> Optional[str]:
|
||||
"""
|
||||
下载远程文件到本地并返回本地访问 URL
|
||||
|
||||
Args:
|
||||
url (str): 远程文件 URL
|
||||
timeout (int): 下载超时时间(秒)
|
||||
|
||||
Returns:
|
||||
Optional[str]: 本地访问 URL,如果失败则返回 None
|
||||
"""
|
||||
server = get_local_file_server()
|
||||
if not server:
|
||||
return None
|
||||
|
||||
file_id = await server.download_file(url, timeout)
|
||||
if not file_id:
|
||||
return None
|
||||
|
||||
return f"http://127.0.0.1:{server.port}/download?id={file_id}"
|
||||
33
core/ws.py
33
core/ws.py
@@ -15,6 +15,7 @@ import asyncio
|
||||
import orjson
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, cast
|
||||
import uuid
|
||||
import threading
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .bot import Bot
|
||||
@@ -59,6 +60,9 @@ class WS:
|
||||
self.self_id: int | None = None
|
||||
self.code_executor = code_executor
|
||||
|
||||
# 线程安全锁
|
||||
self._pending_requests_lock = threading.RLock()
|
||||
|
||||
# 创建模块专用日志记录器
|
||||
self.logger = ModuleLogger("WebSocket")
|
||||
|
||||
@@ -123,9 +127,10 @@ class WS:
|
||||
# 如果消息中包含 echo 字段,说明是 API 调用的响应
|
||||
echo_id = data.get("echo")
|
||||
if echo_id and echo_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(echo_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
with self._pending_requests_lock:
|
||||
future = self._pending_requests.pop(echo_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
continue
|
||||
|
||||
# 2. 处理上报事件
|
||||
@@ -229,12 +234,13 @@ class WS:
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
|
||||
# 取消所有挂起的请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
with self._pending_requests_lock:
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
|
||||
self.logger.success("WebSocket 客户端已关闭")
|
||||
|
||||
@@ -276,13 +282,15 @@ class WS:
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
self._pending_requests[echo_id] = future
|
||||
|
||||
with self._pending_requests_lock:
|
||||
self._pending_requests[echo_id] = future
|
||||
|
||||
try:
|
||||
await self.ws.send(orjson.dumps(payload))
|
||||
return await asyncio.wait_for(future, timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
with self._pending_requests_lock:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.warning(f"API 调用超时: action={action}, params={params}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.TIMEOUT_ERROR,
|
||||
@@ -290,7 +298,8 @@ class WS:
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
with self._pending_requests_lock:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
|
||||
Reference in New Issue
Block a user