refactor(websocket): 移除连接池模式并改进资源清理

移除 WebSocket 连接池实现,改为单连接模式以简化代码结构
在 main 函数中添加资源清理逻辑,确保程序退出时正确关闭所有资源
改进 base64 数据处理逻辑,支持递归处理嵌套结构中的敏感数据

呵呵线程池加WS是神人
This commit is contained in:
2026-01-23 18:24:59 +08:00
parent 38bb10ccd9
commit cd5875be24
4 changed files with 129 additions and 460 deletions

View File

@@ -32,7 +32,6 @@ from .utils.exceptions import (
WebSocketError, WebSocketConnectionError WebSocketError, WebSocketConnectionError
) )
from .utils.error_codes import ErrorCode, create_error_response from .utils.error_codes import ErrorCode, create_error_response
from .ws_pool import WSConnectionPool
class WS: class WS:
@@ -40,14 +39,13 @@ class WS:
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。 WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
""" """
def __init__(self, code_executor: Optional[CodeExecutor] = None, use_pool: bool = True) -> None: def __init__(self, code_executor: Optional[CodeExecutor] = None) -> None:
""" """
初始化 WebSocket 客户端。 初始化 WebSocket 客户端。
从全局配置中读取 WebSocket URI、访问令牌Token和重连间隔。 从全局配置中读取 WebSocket URI、访问令牌Token和重连间隔。
:param code_executor: 代码执行器实例 :param code_executor: 代码执行器实例
:param use_pool: 是否使用连接池
""" """
# 读取参数 # 读取参数
cfg = global_config.napcat_ws cfg = global_config.napcat_ws
@@ -61,8 +59,6 @@ class WS:
self.bot: 'Bot' | None = None self.bot: 'Bot' | None = None
self.self_id: int | None = None self.self_id: int | None = None
self.code_executor = code_executor self.code_executor = code_executor
self.use_pool = use_pool
self.pool: Optional[WSConnectionPool] = None
# 创建模块专用日志记录器 # 创建模块专用日志记录器
self.logger = ModuleLogger("WebSocket") self.logger = ModuleLogger("WebSocket")
@@ -76,112 +72,39 @@ class WS:
""" """
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
if self.use_pool:
# 使用连接池模式
self.pool = WSConnectionPool(pool_size=3)
await self.pool.initialize()
self.logger.success("WebSocket 连接池初始化完成")
# 启动连接池监听循环
await self._pool_listen_loop()
else:
# 单连接模式
while True:
try:
self.logger.info(f"正在尝试连接至 NapCat: {self.url}")
async with websockets.connect(
self.url, additional_headers=headers
) as websocket_raw:
websocket = cast(WebSocketClientProtocol, websocket_raw)
self.ws = websocket
self.logger.success("连接成功!")
await self._listen_loop(websocket)
except (
websockets.exceptions.ConnectionClosed,
ConnectionRefusedError,
) as e:
conn_error = WebSocketConnectionError(
message=f"WebSocket连接失败: {str(e)}",
code=ErrorCode.WS_CONNECTION_FAILED,
original_error=e
)
self.logger.error(f"连接失败: {conn_error.message}")
self.logger.log_custom_exception(conn_error)
except Exception as e:
error = WebSocketError(
message=f"WebSocket运行异常: {str(e)}",
code=ErrorCode.WS_MESSAGE_ERROR,
original_error=e
)
self.logger.exception(f"运行异常: {error.message}")
self.logger.log_custom_exception(error)
self.logger.info(f"{self.reconnect_interval}秒后尝试重连...")
await asyncio.sleep(self.reconnect_interval)
async def _pool_listen_loop(self):
"""
连接池模式下的监听循环
"""
while True: while True:
try: try:
# 从连接池获取一个连接 self.logger.info(f"正在尝试连接至 NapCat: {self.url}")
conn = await self.pool.get_connection() async with websockets.connect(
self.url, additional_headers=headers
) as websocket_raw:
websocket = cast(WebSocketClientProtocol, websocket_raw)
self.ws = websocket
self.logger.success("连接成功!")
await self._listen_loop(websocket)
try: except (
# 监听连接上的消息 websockets.exceptions.ConnectionClosed,
async for message in conn.conn: ConnectionRefusedError,
await self._handle_message(message, conn) ) as e:
except Exception as e: conn_error = WebSocketConnectionError(
self.logger.error(f"连接 {conn.conn_id} 监听异常: {e}") message=f"WebSocket连接失败: {str(e)}",
finally: code=ErrorCode.WS_CONNECTION_FAILED,
# 释放连接回连接池 original_error=e
await self.pool.release_connection(conn) )
self.logger.error(f"连接失败: {conn_error.message}")
self.logger.log_custom_exception(conn_error)
except Exception as e: except Exception as e:
self.logger.error(f"连接池监听循环异常: {e}") error = WebSocketError(
await asyncio.sleep(self.reconnect_interval) message=f"WebSocket运行异常: {str(e)}",
code=ErrorCode.WS_MESSAGE_ERROR,
original_error=e
)
self.logger.exception(f"运行异常: {error.message}")
self.logger.log_custom_exception(error)
async def _handle_message(self, message: str, conn): self.logger.info(f"{self.reconnect_interval}秒后尝试重连...")
""" await asyncio.sleep(self.reconnect_interval)
处理从连接池获取的消息
"""
try:
data = orjson.loads(message)
# 1. 处理 API 响应
# 如果消息中包含 echo 字段,说明是 API 调用的响应
echo_id = data.get("echo")
if echo_id and echo_id in self._pending_requests:
future = self._pending_requests.pop(echo_id)
if not future.done():
future.set_result(data)
return
# 2. 处理上报事件
# 如果消息中包含 post_type 字段,说明是 OneBot 上报的事件
if "post_type" in data:
# 使用 create_task 异步执行,避免阻塞 WebSocket 接收循环
asyncio.create_task(self.on_event(data))
except orjson.JSONDecodeError as e:
error = WebSocketError(
message=f"JSON解析失败: {str(e)}",
code=ErrorCode.WS_MESSAGE_ERROR,
original_error=e
)
self.logger.error(f"解析消息异常: {error.message}")
# 如果message是bytes类型需要先解码
decoded_message = message.decode('utf-8') if isinstance(message, bytes) else message
self.logger.debug(f"原始消息: {decoded_message}")
except Exception as e:
error = WebSocketError(
message=f"处理消息异常: {str(e)}",
code=ErrorCode.WS_MESSAGE_ERROR,
original_error=e
)
self.logger.exception(f"解析消息异常: {error.message}")
self.logger.log_custom_exception(error)
async def _listen_loop(self, websocket_connection: WebSocketClientProtocol) -> None: async def _listen_loop(self, websocket_connection: WebSocketClientProtocol) -> None:
""" """
@@ -298,6 +221,23 @@ class WS:
) )
self.logger.log_custom_exception(error) self.logger.log_custom_exception(error)
async def close(self) -> None:
"""
关闭 WebSocket 客户端,释放资源。
"""
self.logger.info("正在关闭 WebSocket 客户端...")
if self.ws:
await self.ws.close()
# 取消所有挂起的请求
for future in self._pending_requests.values():
if not future.done():
future.cancel()
self._pending_requests.clear()
self.logger.success("WebSocket 客户端已关闭")
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]: async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
""" """
向 OneBot v11 实现端发送一个 API 请求。 向 OneBot v11 实现端发送一个 API 请求。
@@ -313,100 +253,47 @@ class WS:
dict: OneBot API 的响应数据。如果超时或连接断开,则返回一个 dict: OneBot API 的响应数据。如果超时或连接断开,则返回一个
表示失败的字典。 表示失败的字典。
""" """
if self.use_pool: if not self.ws:
# 使用连接池模式 self.logger.error("调用 API 失败: WebSocket 未初始化")
if not self.pool: return create_error_response(
self.logger.error("调用 API 失败: WebSocket 连接池未初始化") code=ErrorCode.WS_DISCONNECTED,
return create_error_response( message="WebSocket未初始化",
code=ErrorCode.WS_DISCONNECTED, data={"action": action, "params": params}
message="WebSocket连接池未初始化", )
data={"action": action, "params": params}
)
# 从连接池获取一个连接 from websockets.protocol import State
conn = await self.pool.get_connection()
try:
echo_id = str(uuid.uuid4())
payload = {"action": action, "params": params or {}, "echo": echo_id}
await conn.send(orjson.dumps(payload)) if getattr(self.ws, "state", None) is not State.OPEN:
self.logger.error("调用 API 失败: WebSocket 连接未打开")
return create_error_response(
code=ErrorCode.WS_DISCONNECTED,
message="WebSocket连接未打开",
data={"action": action, "params": params}
)
# 在当前连接上等待特定 echo 的响应,并设置超时 echo_id = str(uuid.uuid4())
try: payload = {"action": action, "params": params or {}, "echo": echo_id}
async def wait_for_response():
async for message in conn.conn:
data = orjson.loads(message)
if data.get("echo") == echo_id:
return data
return await asyncio.wait_for(wait_for_response(), timeout=30.0) loop = asyncio.get_running_loop()
future = loop.create_future()
except asyncio.TimeoutError: self._pending_requests[echo_id] = future
raise # 重新抛出超时异常
except Exception as e:
raise WebSocketError(f"在等待API响应时连接出错: {e}")
except asyncio.TimeoutError:
self.logger.warning(f"API 调用超时: action={action}, params={params}")
return create_error_response(
code=ErrorCode.TIMEOUT_ERROR,
message="API调用超时",
data={"action": action, "params": params}
)
except Exception as e:
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
return create_error_response(
code=ErrorCode.WS_MESSAGE_ERROR,
message=f"API调用异常: {str(e)}",
data={"action": action, "params": params}
)
finally:
# 释放连接回连接池
await self.pool.release_connection(conn)
else:
# 单连接模式
if not self.ws:
self.logger.error("调用 API 失败: WebSocket 未初始化")
return create_error_response(
code=ErrorCode.WS_DISCONNECTED,
message="WebSocket未初始化",
data={"action": action, "params": params}
)
from websockets.protocol import State
if getattr(self.ws, "state", None) is not State.OPEN:
self.logger.error("调用 API 失败: WebSocket 连接未打开")
return create_error_response(
code=ErrorCode.WS_DISCONNECTED,
message="WebSocket连接未打开",
data={"action": action, "params": params}
)
echo_id = str(uuid.uuid4())
payload = {"action": action, "params": params or {}, "echo": echo_id}
loop = asyncio.get_running_loop()
future = loop.create_future()
self._pending_requests[echo_id] = future
try:
await self.ws.send(orjson.dumps(payload))
return await asyncio.wait_for(future, timeout=30.0)
except asyncio.TimeoutError:
self._pending_requests.pop(echo_id, None)
self.logger.warning(f"API 调用超时: action={action}, params={params}")
return create_error_response(
code=ErrorCode.TIMEOUT_ERROR,
message="API调用超时",
data={"action": action, "params": params}
)
except Exception as e:
self._pending_requests.pop(echo_id, None)
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
return create_error_response(
code=ErrorCode.WS_MESSAGE_ERROR,
message=f"API调用异常: {str(e)}",
data={"action": action, "params": params}
)
try:
await self.ws.send(orjson.dumps(payload))
return await asyncio.wait_for(future, timeout=30.0)
except asyncio.TimeoutError:
self._pending_requests.pop(echo_id, None)
self.logger.warning(f"API 调用超时: action={action}, params={params}")
return create_error_response(
code=ErrorCode.TIMEOUT_ERROR,
message="API调用超时",
data={"action": action, "params": params}
)
except Exception as e:
self._pending_requests.pop(echo_id, None)
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
return create_error_response(
code=ErrorCode.WS_MESSAGE_ERROR,
message=f"API调用异常: {str(e)}",
data={"action": action, "params": params}
)

View File

@@ -38,15 +38,26 @@ class BaseAPI:
try: try:
# 日志记录前,对敏感或过长的参数进行处理 # 日志记录前,对敏感或过长的参数进行处理
log_params = copy.deepcopy(params) log_params = copy.deepcopy(params)
if 'message' in log_params:
if isinstance(log_params['message'], list): # 处理各种可能包含base64数据的字段
for segment in log_params['message']: def truncate_base64_recursive(obj):
if segment.get('type') == 'image' and 'file' in segment.get('data', {}): """递归处理可能包含base64数据的对象"""
file_data = segment['data']['file'] if isinstance(obj, dict):
if file_data.startswith('data:image/'): for key, value in obj.items():
segment['data']['file'] = f"{file_data[:50]}... (base64 truncated)" if isinstance(value, str):
elif isinstance(log_params['message'], str) and log_params['message'].startswith('data:image/'): if value.startswith('data:image/') or value.startswith('data:video/') or value.startswith('data:audio/'):
log_params['message'] = f"{log_params['message'][:50]}... (base64 truncated)" obj[key] = f"{value[:50]}... (base64 truncated)"
elif len(value) > 100 and ('/' in value[:50] and '+' in value[:50] and '=' in value[-10:]):
# 检查是否是base64编码的字符串
obj[key] = f"{value[:50]}... (base64-like truncated)"
elif isinstance(value, (dict, list)):
truncate_base64_recursive(value)
elif isinstance(obj, list):
for item in obj:
if isinstance(item, (dict, list)):
truncate_base64_recursive(item)
truncate_base64_recursive(log_params)
# 如果是发送消息的动作,则原子化地增加发送消息总数 # 如果是发送消息的动作,则原子化地增加发送消息总数
if action in ["send_private_msg", "send_group_msg", "send_msg"]: if action in ["send_private_msg", "send_group_msg", "send_msg"]:
@@ -62,8 +73,13 @@ class BaseAPI:
logger.error(f"发送消息计数失败: {e}") logger.error(f"发送消息计数失败: {e}")
logger.debug(f"调用API -> action: {action}, params: {log_params}") logger.debug(f"调用API -> action: {action}, params: {log_params}")
response = await self._ws.call_api(action, params) response = await self._ws.call_api(action, params)
logger.debug(f"API响应 <- {response}")
# 对响应也做类似的处理
log_response = copy.deepcopy(response)
truncate_base64_recursive(log_response)
logger.debug(f"API响应 <- {log_response}")
if response.get("status") == "failed": if response.get("status") == "failed":
logger.warning(f"API调用失败: {response}") logger.warning(f"API调用失败: {response}")

View File

@@ -1,247 +0,0 @@
"""
WebSocket 连接池模块
该模块实现了 WebSocket 连接池功能,用于管理多个 WebSocket 连接,
提高并发处理能力和连接复用效率。
"""
import asyncio
import websockets
from websockets.legacy.client import WebSocketClientProtocol
from typing import Optional, Dict, Any, cast, Union
import uuid
from loguru import logger
from .config_loader import global_config
from .utils.exceptions import WebSocketError, WebSocketConnectionError
class WSConnection:
"""
WebSocket 连接包装类
封装单个 WebSocket 连接的状态和操作
"""
def __init__(self, conn: WebSocketClientProtocol, conn_id: str):
self.conn = conn
self.conn_id = conn_id
self.last_used = asyncio.get_event_loop().time()
self.is_active = True
self._pending_requests: Dict[str, asyncio.Future] = {}
async def send(self, data: Union[Dict[Any, Any], bytes]):
"""
发送数据到 WebSocket 连接
"""
if not self.is_active:
raise WebSocketError(f"连接 {self.conn_id} 已关闭")
try:
await self.conn.send(data)
self.last_used = asyncio.get_event_loop().time()
except Exception as e:
self.is_active = False
raise WebSocketError(f"发送数据失败: {e}")
async def recv(self):
"""
从 WebSocket 连接接收数据
"""
if not self.is_active:
raise WebSocketError(f"连接 {self.conn_id} 已关闭")
try:
data = await self.conn.recv()
self.last_used = asyncio.get_event_loop().time()
return data
except Exception as e:
self.is_active = False
raise WebSocketError(f"接收数据失败: {e}")
async def ping(self, timeout: int = 5) -> bool:
"""
对 WebSocket 连接执行 ping-pong 健康检查
"""
if not self.is_active:
return False
try:
await asyncio.wait_for(self.conn.ping(), timeout=timeout)
return True
except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
self.is_active = False
return False
async def close(self):
"""
关闭 WebSocket 连接
"""
if self.is_active:
self.is_active = False
await self.conn.close()
class WSConnectionPool:
"""
WebSocket 连接池
管理多个 WebSocket 连接,提供连接的获取、释放和回收功能
"""
def __init__(self, pool_size: int = 3, max_idle_time: int = 300):
"""
初始化连接池
:param pool_size: 连接池大小
:param max_idle_time: 连接最大空闲时间(秒)
"""
self.pool_size = pool_size
self.max_idle_time = max_idle_time
self.pool: asyncio.Queue[WSConnection] = asyncio.Queue(maxsize=pool_size)
self._closed = False
self._cleanup_task: Optional[asyncio.Task] = None
# 从全局配置读取参数
self.url = global_config.napcat_ws.uri
self.token = global_config.napcat_ws.token
self.reconnect_interval = global_config.napcat_ws.reconnect_interval
logger.info(f"WebSocket 连接池初始化完成,大小: {pool_size}")
async def initialize(self):
"""
初始化连接池,创建初始连接
"""
if self._closed:
raise WebSocketError("连接池已关闭")
# 启动连接清理任务
self._cleanup_task = asyncio.create_task(self._cleanup_idle_connections())
# 创建初始连接
for _ in range(self.pool_size):
try:
conn = await self._create_connection()
await self.pool.put(conn)
logger.info(f"WebSocket 连接 {conn.conn_id} 已创建并加入连接池")
except Exception as e:
logger.error(f"创建初始连接失败: {e}")
async def _create_connection(self) -> WSConnection:
"""
创建新的 WebSocket 连接
"""
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
try:
conn_id = str(uuid.uuid4())
websocket_raw = await websockets.connect(
self.url, additional_headers=headers
)
websocket = cast(WebSocketClientProtocol, websocket_raw)
conn = WSConnection(websocket, conn_id)
logger.info(f"WebSocket 连接 {conn_id} 已建立")
return conn
except Exception as e:
raise WebSocketConnectionError(f"创建 WebSocket 连接失败: {e}")
async def get_connection(self) -> WSConnection:
"""
从连接池获取一个健康的连接,包含健康检查。
"""
if self._closed:
raise WebSocketError("连接池已关闭")
try:
# 尝试从连接池获取连接
conn = await asyncio.wait_for(self.pool.get(), timeout=5)
# 健康检查
if await conn.ping():
logger.debug(f"连接 {conn.conn_id} 健康检查通过")
return conn
else:
logger.warning(f"连接 {conn.conn_id} 健康检查失败,丢弃并获取新连接")
await conn.close()
return await self.get_connection() # 递归获取下一个
except asyncio.TimeoutError:
# 连接池为空,创建新连接
logger.warning("连接池在5秒内无可用连接创建新连接")
return await self._create_connection()
except Exception as e:
raise WebSocketError(f"获取连接时发生未知错误: {e}")
async def release_connection(self, conn: WSConnection):
"""
释放连接回连接池
"""
if self._closed:
await conn.close()
return
if not conn.is_active:
logger.warning(f"连接 {conn.conn_id} 已失效,不返回连接池")
return
try:
if self.pool.full():
# 连接池已满,关闭该连接
await conn.close()
logger.info(f"连接池已满,关闭连接 {conn.conn_id}")
else:
await self.pool.put(conn)
logger.debug(f"连接 {conn.conn_id} 已返回连接池")
except Exception as e:
logger.error(f"释放连接失败: {e}")
await conn.close()
async def _cleanup_idle_connections(self):
"""
清理空闲连接任务
"""
while not self._closed:
await asyncio.sleep(60) # 每分钟检查一次
try:
# 检查连接池中的连接
new_pool = asyncio.Queue(maxsize=self.pool_size)
current_time = asyncio.get_event_loop().time()
while not self.pool.empty():
conn = await self.pool.get()
if current_time - conn.last_used > self.max_idle_time:
# 连接空闲时间过长,关闭
await conn.close()
logger.info(f"清理空闲连接 {conn.conn_id}")
else:
# 放回新队列
await new_pool.put(conn)
# 替换原连接池
self.pool = new_pool
except Exception as e:
logger.error(f"清理空闲连接失败: {e}")
async def close(self):
"""
关闭连接池
"""
if self._closed:
return
self._closed = True
# 停止清理任务
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
# 关闭所有连接
while not self.pool.empty():
conn = await self.pool.get()
await conn.close()
logger.info("WebSocket 连接池已关闭")

21
main.py
View File

@@ -158,12 +158,13 @@ async def main():
else: else:
logger.warning(f"插件目录不存在 {plugin_path}") logger.warning(f"插件目录不存在 {plugin_path}")
websocket_client = None
try: try:
# 初始化代码执行器 # 初始化代码执行器
code_executor = initialize_executor(config) code_executor = initialize_executor(config)
# 使用连接池模式初始化 WebSocket 客户端 # 初始化 WebSocket 客户端
websocket_client = WS(code_executor=code_executor, use_pool=True) websocket_client = WS(code_executor=code_executor)
# 启动代码执行器的后台 worker # 启动代码执行器的后台 worker
logger.debug("[Main] 检查是否需要启动代码执行 Worker...") logger.debug("[Main] 检查是否需要启动代码执行 Worker...")
@@ -174,11 +175,22 @@ async def main():
logger.warning("[Main] 未启动代码执行 Worker因为 Docker 客户端未初始化或连接失败。") logger.warning("[Main] 未启动代码执行 Worker因为 Docker 客户端未初始化或连接失败。")
await websocket_client.connect() await websocket_client.connect()
except asyncio.CancelledError:
logger.info("主任务被取消,正在停止...")
finally: finally:
logger.info("正在清理资源...")
if observer.is_alive(): if observer.is_alive():
observer.stop() observer.stop()
observer.join() observer.join()
if websocket_client:
await websocket_client.close()
# 关闭浏览器管理器
await browser_manager.shutdown()
logger.success("资源清理完成,程序退出。")
if __name__ == "__main__": if __name__ == "__main__":
""" """
@@ -193,8 +205,9 @@ if __name__ == "__main__":
try: try:
asyncio.run(main()) asyncio.run(main())
except KeyboardInterrupt: except KeyboardInterrupt:
main_logger.info("程序已被用户中断") # 捕获 KeyboardInterrupt不做任何操作让 asyncio.run 正常结束
exit(0) # 这样 main 函数中的 finally 块会被执行
pass
except Exception as e: except Exception as e:
main_logger.exception("程序发生未处理的全局异常") main_logger.exception("程序发生未处理的全局异常")