refactor(websocket): 移除连接池模式并改进资源清理
移除 WebSocket 连接池实现,改为单连接模式以简化代码结构 在 main 函数中添加资源清理逻辑,确保程序退出时正确关闭所有资源 改进 base64 数据处理逻辑,支持递归处理嵌套结构中的敏感数据 呵呵线程池加WS是神人
This commit is contained in:
283
core/WS.py
283
core/WS.py
@@ -32,7 +32,6 @@ from .utils.exceptions import (
|
|||||||
WebSocketError, WebSocketConnectionError
|
WebSocketError, WebSocketConnectionError
|
||||||
)
|
)
|
||||||
from .utils.error_codes import ErrorCode, create_error_response
|
from .utils.error_codes import ErrorCode, create_error_response
|
||||||
from .ws_pool import WSConnectionPool
|
|
||||||
|
|
||||||
|
|
||||||
class WS:
|
class WS:
|
||||||
@@ -40,14 +39,13 @@ class WS:
|
|||||||
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
|
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, code_executor: Optional[CodeExecutor] = None, use_pool: bool = True) -> None:
|
def __init__(self, code_executor: Optional[CodeExecutor] = None) -> None:
|
||||||
"""
|
"""
|
||||||
初始化 WebSocket 客户端。
|
初始化 WebSocket 客户端。
|
||||||
|
|
||||||
从全局配置中读取 WebSocket URI、访问令牌(Token)和重连间隔。
|
从全局配置中读取 WebSocket URI、访问令牌(Token)和重连间隔。
|
||||||
|
|
||||||
:param code_executor: 代码执行器实例
|
:param code_executor: 代码执行器实例
|
||||||
:param use_pool: 是否使用连接池
|
|
||||||
"""
|
"""
|
||||||
# 读取参数
|
# 读取参数
|
||||||
cfg = global_config.napcat_ws
|
cfg = global_config.napcat_ws
|
||||||
@@ -61,8 +59,6 @@ class WS:
|
|||||||
self.bot: 'Bot' | None = None
|
self.bot: 'Bot' | None = None
|
||||||
self.self_id: int | None = None
|
self.self_id: int | None = None
|
||||||
self.code_executor = code_executor
|
self.code_executor = code_executor
|
||||||
self.use_pool = use_pool
|
|
||||||
self.pool: Optional[WSConnectionPool] = None
|
|
||||||
|
|
||||||
# 创建模块专用日志记录器
|
# 创建模块专用日志记录器
|
||||||
self.logger = ModuleLogger("WebSocket")
|
self.logger = ModuleLogger("WebSocket")
|
||||||
@@ -76,112 +72,39 @@ class WS:
|
|||||||
"""
|
"""
|
||||||
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
|
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
|
||||||
|
|
||||||
if self.use_pool:
|
|
||||||
# 使用连接池模式
|
|
||||||
self.pool = WSConnectionPool(pool_size=3)
|
|
||||||
await self.pool.initialize()
|
|
||||||
self.logger.success("WebSocket 连接池初始化完成")
|
|
||||||
|
|
||||||
# 启动连接池监听循环
|
|
||||||
await self._pool_listen_loop()
|
|
||||||
else:
|
|
||||||
# 单连接模式
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
self.logger.info(f"正在尝试连接至 NapCat: {self.url}")
|
|
||||||
async with websockets.connect(
|
|
||||||
self.url, additional_headers=headers
|
|
||||||
) as websocket_raw:
|
|
||||||
websocket = cast(WebSocketClientProtocol, websocket_raw)
|
|
||||||
self.ws = websocket
|
|
||||||
self.logger.success("连接成功!")
|
|
||||||
await self._listen_loop(websocket)
|
|
||||||
|
|
||||||
except (
|
|
||||||
websockets.exceptions.ConnectionClosed,
|
|
||||||
ConnectionRefusedError,
|
|
||||||
) as e:
|
|
||||||
conn_error = WebSocketConnectionError(
|
|
||||||
message=f"WebSocket连接失败: {str(e)}",
|
|
||||||
code=ErrorCode.WS_CONNECTION_FAILED,
|
|
||||||
original_error=e
|
|
||||||
)
|
|
||||||
self.logger.error(f"连接失败: {conn_error.message}")
|
|
||||||
self.logger.log_custom_exception(conn_error)
|
|
||||||
except Exception as e:
|
|
||||||
error = WebSocketError(
|
|
||||||
message=f"WebSocket运行异常: {str(e)}",
|
|
||||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
|
||||||
original_error=e
|
|
||||||
)
|
|
||||||
self.logger.exception(f"运行异常: {error.message}")
|
|
||||||
self.logger.log_custom_exception(error)
|
|
||||||
|
|
||||||
self.logger.info(f"{self.reconnect_interval}秒后尝试重连...")
|
|
||||||
await asyncio.sleep(self.reconnect_interval)
|
|
||||||
|
|
||||||
async def _pool_listen_loop(self):
|
|
||||||
"""
|
|
||||||
连接池模式下的监听循环
|
|
||||||
"""
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# 从连接池获取一个连接
|
self.logger.info(f"正在尝试连接至 NapCat: {self.url}")
|
||||||
conn = await self.pool.get_connection()
|
async with websockets.connect(
|
||||||
|
self.url, additional_headers=headers
|
||||||
|
) as websocket_raw:
|
||||||
|
websocket = cast(WebSocketClientProtocol, websocket_raw)
|
||||||
|
self.ws = websocket
|
||||||
|
self.logger.success("连接成功!")
|
||||||
|
await self._listen_loop(websocket)
|
||||||
|
|
||||||
try:
|
except (
|
||||||
# 监听连接上的消息
|
websockets.exceptions.ConnectionClosed,
|
||||||
async for message in conn.conn:
|
ConnectionRefusedError,
|
||||||
await self._handle_message(message, conn)
|
) as e:
|
||||||
except Exception as e:
|
conn_error = WebSocketConnectionError(
|
||||||
self.logger.error(f"连接 {conn.conn_id} 监听异常: {e}")
|
message=f"WebSocket连接失败: {str(e)}",
|
||||||
finally:
|
code=ErrorCode.WS_CONNECTION_FAILED,
|
||||||
# 释放连接回连接池
|
original_error=e
|
||||||
await self.pool.release_connection(conn)
|
)
|
||||||
|
self.logger.error(f"连接失败: {conn_error.message}")
|
||||||
|
self.logger.log_custom_exception(conn_error)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"连接池监听循环异常: {e}")
|
error = WebSocketError(
|
||||||
await asyncio.sleep(self.reconnect_interval)
|
message=f"WebSocket运行异常: {str(e)}",
|
||||||
|
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||||
|
original_error=e
|
||||||
|
)
|
||||||
|
self.logger.exception(f"运行异常: {error.message}")
|
||||||
|
self.logger.log_custom_exception(error)
|
||||||
|
|
||||||
async def _handle_message(self, message: str, conn):
|
self.logger.info(f"{self.reconnect_interval}秒后尝试重连...")
|
||||||
"""
|
await asyncio.sleep(self.reconnect_interval)
|
||||||
处理从连接池获取的消息
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
data = orjson.loads(message)
|
|
||||||
|
|
||||||
# 1. 处理 API 响应
|
|
||||||
# 如果消息中包含 echo 字段,说明是 API 调用的响应
|
|
||||||
echo_id = data.get("echo")
|
|
||||||
if echo_id and echo_id in self._pending_requests:
|
|
||||||
future = self._pending_requests.pop(echo_id)
|
|
||||||
if not future.done():
|
|
||||||
future.set_result(data)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 2. 处理上报事件
|
|
||||||
# 如果消息中包含 post_type 字段,说明是 OneBot 上报的事件
|
|
||||||
if "post_type" in data:
|
|
||||||
# 使用 create_task 异步执行,避免阻塞 WebSocket 接收循环
|
|
||||||
asyncio.create_task(self.on_event(data))
|
|
||||||
|
|
||||||
except orjson.JSONDecodeError as e:
|
|
||||||
error = WebSocketError(
|
|
||||||
message=f"JSON解析失败: {str(e)}",
|
|
||||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
|
||||||
original_error=e
|
|
||||||
)
|
|
||||||
self.logger.error(f"解析消息异常: {error.message}")
|
|
||||||
# 如果message是bytes类型,需要先解码
|
|
||||||
decoded_message = message.decode('utf-8') if isinstance(message, bytes) else message
|
|
||||||
self.logger.debug(f"原始消息: {decoded_message}")
|
|
||||||
except Exception as e:
|
|
||||||
error = WebSocketError(
|
|
||||||
message=f"处理消息异常: {str(e)}",
|
|
||||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
|
||||||
original_error=e
|
|
||||||
)
|
|
||||||
self.logger.exception(f"解析消息异常: {error.message}")
|
|
||||||
self.logger.log_custom_exception(error)
|
|
||||||
|
|
||||||
async def _listen_loop(self, websocket_connection: WebSocketClientProtocol) -> None:
|
async def _listen_loop(self, websocket_connection: WebSocketClientProtocol) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -298,6 +221,23 @@ class WS:
|
|||||||
)
|
)
|
||||||
self.logger.log_custom_exception(error)
|
self.logger.log_custom_exception(error)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""
|
||||||
|
关闭 WebSocket 客户端,释放资源。
|
||||||
|
"""
|
||||||
|
self.logger.info("正在关闭 WebSocket 客户端...")
|
||||||
|
|
||||||
|
if self.ws:
|
||||||
|
await self.ws.close()
|
||||||
|
|
||||||
|
# 取消所有挂起的请求
|
||||||
|
for future in self._pending_requests.values():
|
||||||
|
if not future.done():
|
||||||
|
future.cancel()
|
||||||
|
self._pending_requests.clear()
|
||||||
|
|
||||||
|
self.logger.success("WebSocket 客户端已关闭")
|
||||||
|
|
||||||
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
|
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
|
||||||
"""
|
"""
|
||||||
向 OneBot v11 实现端发送一个 API 请求。
|
向 OneBot v11 实现端发送一个 API 请求。
|
||||||
@@ -313,100 +253,47 @@ class WS:
|
|||||||
dict: OneBot API 的响应数据。如果超时或连接断开,则返回一个
|
dict: OneBot API 的响应数据。如果超时或连接断开,则返回一个
|
||||||
表示失败的字典。
|
表示失败的字典。
|
||||||
"""
|
"""
|
||||||
if self.use_pool:
|
if not self.ws:
|
||||||
# 使用连接池模式
|
self.logger.error("调用 API 失败: WebSocket 未初始化")
|
||||||
if not self.pool:
|
return create_error_response(
|
||||||
self.logger.error("调用 API 失败: WebSocket 连接池未初始化")
|
code=ErrorCode.WS_DISCONNECTED,
|
||||||
return create_error_response(
|
message="WebSocket未初始化",
|
||||||
code=ErrorCode.WS_DISCONNECTED,
|
data={"action": action, "params": params}
|
||||||
message="WebSocket连接池未初始化",
|
)
|
||||||
data={"action": action, "params": params}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 从连接池获取一个连接
|
from websockets.protocol import State
|
||||||
conn = await self.pool.get_connection()
|
|
||||||
try:
|
|
||||||
echo_id = str(uuid.uuid4())
|
|
||||||
payload = {"action": action, "params": params or {}, "echo": echo_id}
|
|
||||||
|
|
||||||
await conn.send(orjson.dumps(payload))
|
if getattr(self.ws, "state", None) is not State.OPEN:
|
||||||
|
self.logger.error("调用 API 失败: WebSocket 连接未打开")
|
||||||
|
return create_error_response(
|
||||||
|
code=ErrorCode.WS_DISCONNECTED,
|
||||||
|
message="WebSocket连接未打开",
|
||||||
|
data={"action": action, "params": params}
|
||||||
|
)
|
||||||
|
|
||||||
# 在当前连接上等待特定 echo 的响应,并设置超时
|
echo_id = str(uuid.uuid4())
|
||||||
try:
|
payload = {"action": action, "params": params or {}, "echo": echo_id}
|
||||||
async def wait_for_response():
|
|
||||||
async for message in conn.conn:
|
|
||||||
data = orjson.loads(message)
|
|
||||||
if data.get("echo") == echo_id:
|
|
||||||
return data
|
|
||||||
|
|
||||||
return await asyncio.wait_for(wait_for_response(), timeout=30.0)
|
loop = asyncio.get_running_loop()
|
||||||
|
future = loop.create_future()
|
||||||
except asyncio.TimeoutError:
|
self._pending_requests[echo_id] = future
|
||||||
raise # 重新抛出超时异常
|
|
||||||
except Exception as e:
|
|
||||||
raise WebSocketError(f"在等待API响应时连接出错: {e}")
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
self.logger.warning(f"API 调用超时: action={action}, params={params}")
|
|
||||||
return create_error_response(
|
|
||||||
code=ErrorCode.TIMEOUT_ERROR,
|
|
||||||
message="API调用超时",
|
|
||||||
data={"action": action, "params": params}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
|
|
||||||
return create_error_response(
|
|
||||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
|
||||||
message=f"API调用异常: {str(e)}",
|
|
||||||
data={"action": action, "params": params}
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
# 释放连接回连接池
|
|
||||||
await self.pool.release_connection(conn)
|
|
||||||
else:
|
|
||||||
# 单连接模式
|
|
||||||
if not self.ws:
|
|
||||||
self.logger.error("调用 API 失败: WebSocket 未初始化")
|
|
||||||
return create_error_response(
|
|
||||||
code=ErrorCode.WS_DISCONNECTED,
|
|
||||||
message="WebSocket未初始化",
|
|
||||||
data={"action": action, "params": params}
|
|
||||||
)
|
|
||||||
|
|
||||||
from websockets.protocol import State
|
|
||||||
|
|
||||||
if getattr(self.ws, "state", None) is not State.OPEN:
|
|
||||||
self.logger.error("调用 API 失败: WebSocket 连接未打开")
|
|
||||||
return create_error_response(
|
|
||||||
code=ErrorCode.WS_DISCONNECTED,
|
|
||||||
message="WebSocket连接未打开",
|
|
||||||
data={"action": action, "params": params}
|
|
||||||
)
|
|
||||||
|
|
||||||
echo_id = str(uuid.uuid4())
|
|
||||||
payload = {"action": action, "params": params or {}, "echo": echo_id}
|
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
future = loop.create_future()
|
|
||||||
self._pending_requests[echo_id] = future
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.ws.send(orjson.dumps(payload))
|
|
||||||
return await asyncio.wait_for(future, timeout=30.0)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
self._pending_requests.pop(echo_id, None)
|
|
||||||
self.logger.warning(f"API 调用超时: action={action}, params={params}")
|
|
||||||
return create_error_response(
|
|
||||||
code=ErrorCode.TIMEOUT_ERROR,
|
|
||||||
message="API调用超时",
|
|
||||||
data={"action": action, "params": params}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self._pending_requests.pop(echo_id, None)
|
|
||||||
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
|
|
||||||
return create_error_response(
|
|
||||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
|
||||||
message=f"API调用异常: {str(e)}",
|
|
||||||
data={"action": action, "params": params}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.ws.send(orjson.dumps(payload))
|
||||||
|
return await asyncio.wait_for(future, timeout=30.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._pending_requests.pop(echo_id, None)
|
||||||
|
self.logger.warning(f"API 调用超时: action={action}, params={params}")
|
||||||
|
return create_error_response(
|
||||||
|
code=ErrorCode.TIMEOUT_ERROR,
|
||||||
|
message="API调用超时",
|
||||||
|
data={"action": action, "params": params}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self._pending_requests.pop(echo_id, None)
|
||||||
|
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
|
||||||
|
return create_error_response(
|
||||||
|
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||||
|
message=f"API调用异常: {str(e)}",
|
||||||
|
data={"action": action, "params": params}
|
||||||
|
)
|
||||||
|
|||||||
@@ -38,15 +38,26 @@ class BaseAPI:
|
|||||||
try:
|
try:
|
||||||
# 日志记录前,对敏感或过长的参数进行处理
|
# 日志记录前,对敏感或过长的参数进行处理
|
||||||
log_params = copy.deepcopy(params)
|
log_params = copy.deepcopy(params)
|
||||||
if 'message' in log_params:
|
|
||||||
if isinstance(log_params['message'], list):
|
# 处理各种可能包含base64数据的字段
|
||||||
for segment in log_params['message']:
|
def truncate_base64_recursive(obj):
|
||||||
if segment.get('type') == 'image' and 'file' in segment.get('data', {}):
|
"""递归处理可能包含base64数据的对象"""
|
||||||
file_data = segment['data']['file']
|
if isinstance(obj, dict):
|
||||||
if file_data.startswith('data:image/'):
|
for key, value in obj.items():
|
||||||
segment['data']['file'] = f"{file_data[:50]}... (base64 truncated)"
|
if isinstance(value, str):
|
||||||
elif isinstance(log_params['message'], str) and log_params['message'].startswith('data:image/'):
|
if value.startswith('data:image/') or value.startswith('data:video/') or value.startswith('data:audio/'):
|
||||||
log_params['message'] = f"{log_params['message'][:50]}... (base64 truncated)"
|
obj[key] = f"{value[:50]}... (base64 truncated)"
|
||||||
|
elif len(value) > 100 and ('/' in value[:50] and '+' in value[:50] and '=' in value[-10:]):
|
||||||
|
# 检查是否是base64编码的字符串
|
||||||
|
obj[key] = f"{value[:50]}... (base64-like truncated)"
|
||||||
|
elif isinstance(value, (dict, list)):
|
||||||
|
truncate_base64_recursive(value)
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
for item in obj:
|
||||||
|
if isinstance(item, (dict, list)):
|
||||||
|
truncate_base64_recursive(item)
|
||||||
|
|
||||||
|
truncate_base64_recursive(log_params)
|
||||||
|
|
||||||
# 如果是发送消息的动作,则原子化地增加发送消息总数
|
# 如果是发送消息的动作,则原子化地增加发送消息总数
|
||||||
if action in ["send_private_msg", "send_group_msg", "send_msg"]:
|
if action in ["send_private_msg", "send_group_msg", "send_msg"]:
|
||||||
@@ -62,8 +73,13 @@ class BaseAPI:
|
|||||||
logger.error(f"发送消息计数失败: {e}")
|
logger.error(f"发送消息计数失败: {e}")
|
||||||
|
|
||||||
logger.debug(f"调用API -> action: {action}, params: {log_params}")
|
logger.debug(f"调用API -> action: {action}, params: {log_params}")
|
||||||
|
|
||||||
response = await self._ws.call_api(action, params)
|
response = await self._ws.call_api(action, params)
|
||||||
logger.debug(f"API响应 <- {response}")
|
|
||||||
|
# 对响应也做类似的处理
|
||||||
|
log_response = copy.deepcopy(response)
|
||||||
|
truncate_base64_recursive(log_response)
|
||||||
|
logger.debug(f"API响应 <- {log_response}")
|
||||||
|
|
||||||
if response.get("status") == "failed":
|
if response.get("status") == "failed":
|
||||||
logger.warning(f"API调用失败: {response}")
|
logger.warning(f"API调用失败: {response}")
|
||||||
|
|||||||
247
core/ws_pool.py
247
core/ws_pool.py
@@ -1,247 +0,0 @@
|
|||||||
"""
|
|
||||||
WebSocket 连接池模块
|
|
||||||
|
|
||||||
该模块实现了 WebSocket 连接池功能,用于管理多个 WebSocket 连接,
|
|
||||||
提高并发处理能力和连接复用效率。
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
import websockets
|
|
||||||
from websockets.legacy.client import WebSocketClientProtocol
|
|
||||||
from typing import Optional, Dict, Any, cast, Union
|
|
||||||
import uuid
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from .config_loader import global_config
|
|
||||||
from .utils.exceptions import WebSocketError, WebSocketConnectionError
|
|
||||||
|
|
||||||
|
|
||||||
class WSConnection:
|
|
||||||
"""
|
|
||||||
WebSocket 连接包装类
|
|
||||||
|
|
||||||
封装单个 WebSocket 连接的状态和操作
|
|
||||||
"""
|
|
||||||
def __init__(self, conn: WebSocketClientProtocol, conn_id: str):
|
|
||||||
self.conn = conn
|
|
||||||
self.conn_id = conn_id
|
|
||||||
self.last_used = asyncio.get_event_loop().time()
|
|
||||||
self.is_active = True
|
|
||||||
self._pending_requests: Dict[str, asyncio.Future] = {}
|
|
||||||
|
|
||||||
async def send(self, data: Union[Dict[Any, Any], bytes]):
|
|
||||||
"""
|
|
||||||
发送数据到 WebSocket 连接
|
|
||||||
"""
|
|
||||||
if not self.is_active:
|
|
||||||
raise WebSocketError(f"连接 {self.conn_id} 已关闭")
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.conn.send(data)
|
|
||||||
self.last_used = asyncio.get_event_loop().time()
|
|
||||||
except Exception as e:
|
|
||||||
self.is_active = False
|
|
||||||
raise WebSocketError(f"发送数据失败: {e}")
|
|
||||||
|
|
||||||
async def recv(self):
|
|
||||||
"""
|
|
||||||
从 WebSocket 连接接收数据
|
|
||||||
"""
|
|
||||||
if not self.is_active:
|
|
||||||
raise WebSocketError(f"连接 {self.conn_id} 已关闭")
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = await self.conn.recv()
|
|
||||||
self.last_used = asyncio.get_event_loop().time()
|
|
||||||
return data
|
|
||||||
except Exception as e:
|
|
||||||
self.is_active = False
|
|
||||||
raise WebSocketError(f"接收数据失败: {e}")
|
|
||||||
|
|
||||||
async def ping(self, timeout: int = 5) -> bool:
|
|
||||||
"""
|
|
||||||
对 WebSocket 连接执行 ping-pong 健康检查
|
|
||||||
"""
|
|
||||||
if not self.is_active:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self.conn.ping(), timeout=timeout)
|
|
||||||
return True
|
|
||||||
except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
|
|
||||||
self.is_active = False
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""
|
|
||||||
关闭 WebSocket 连接
|
|
||||||
"""
|
|
||||||
if self.is_active:
|
|
||||||
self.is_active = False
|
|
||||||
await self.conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
class WSConnectionPool:
|
|
||||||
"""
|
|
||||||
WebSocket 连接池
|
|
||||||
|
|
||||||
管理多个 WebSocket 连接,提供连接的获取、释放和回收功能
|
|
||||||
"""
|
|
||||||
def __init__(self, pool_size: int = 3, max_idle_time: int = 300):
|
|
||||||
"""
|
|
||||||
初始化连接池
|
|
||||||
|
|
||||||
:param pool_size: 连接池大小
|
|
||||||
:param max_idle_time: 连接最大空闲时间(秒)
|
|
||||||
"""
|
|
||||||
self.pool_size = pool_size
|
|
||||||
self.max_idle_time = max_idle_time
|
|
||||||
self.pool: asyncio.Queue[WSConnection] = asyncio.Queue(maxsize=pool_size)
|
|
||||||
self._closed = False
|
|
||||||
self._cleanup_task: Optional[asyncio.Task] = None
|
|
||||||
|
|
||||||
# 从全局配置读取参数
|
|
||||||
self.url = global_config.napcat_ws.uri
|
|
||||||
self.token = global_config.napcat_ws.token
|
|
||||||
self.reconnect_interval = global_config.napcat_ws.reconnect_interval
|
|
||||||
|
|
||||||
logger.info(f"WebSocket 连接池初始化完成,大小: {pool_size}")
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
"""
|
|
||||||
初始化连接池,创建初始连接
|
|
||||||
"""
|
|
||||||
if self._closed:
|
|
||||||
raise WebSocketError("连接池已关闭")
|
|
||||||
|
|
||||||
# 启动连接清理任务
|
|
||||||
self._cleanup_task = asyncio.create_task(self._cleanup_idle_connections())
|
|
||||||
|
|
||||||
# 创建初始连接
|
|
||||||
for _ in range(self.pool_size):
|
|
||||||
try:
|
|
||||||
conn = await self._create_connection()
|
|
||||||
await self.pool.put(conn)
|
|
||||||
logger.info(f"WebSocket 连接 {conn.conn_id} 已创建并加入连接池")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"创建初始连接失败: {e}")
|
|
||||||
|
|
||||||
async def _create_connection(self) -> WSConnection:
|
|
||||||
"""
|
|
||||||
创建新的 WebSocket 连接
|
|
||||||
"""
|
|
||||||
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
conn_id = str(uuid.uuid4())
|
|
||||||
websocket_raw = await websockets.connect(
|
|
||||||
self.url, additional_headers=headers
|
|
||||||
)
|
|
||||||
websocket = cast(WebSocketClientProtocol, websocket_raw)
|
|
||||||
|
|
||||||
conn = WSConnection(websocket, conn_id)
|
|
||||||
logger.info(f"WebSocket 连接 {conn_id} 已建立")
|
|
||||||
return conn
|
|
||||||
except Exception as e:
|
|
||||||
raise WebSocketConnectionError(f"创建 WebSocket 连接失败: {e}")
|
|
||||||
|
|
||||||
async def get_connection(self) -> WSConnection:
|
|
||||||
"""
|
|
||||||
从连接池获取一个健康的连接,包含健康检查。
|
|
||||||
"""
|
|
||||||
if self._closed:
|
|
||||||
raise WebSocketError("连接池已关闭")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 尝试从连接池获取连接
|
|
||||||
conn = await asyncio.wait_for(self.pool.get(), timeout=5)
|
|
||||||
|
|
||||||
# 健康检查
|
|
||||||
if await conn.ping():
|
|
||||||
logger.debug(f"连接 {conn.conn_id} 健康检查通过")
|
|
||||||
return conn
|
|
||||||
else:
|
|
||||||
logger.warning(f"连接 {conn.conn_id} 健康检查失败,丢弃并获取新连接")
|
|
||||||
await conn.close()
|
|
||||||
return await self.get_connection() # 递归获取下一个
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
# 连接池为空,创建新连接
|
|
||||||
logger.warning("连接池在5秒内无可用连接,创建新连接")
|
|
||||||
return await self._create_connection()
|
|
||||||
except Exception as e:
|
|
||||||
raise WebSocketError(f"获取连接时发生未知错误: {e}")
|
|
||||||
|
|
||||||
async def release_connection(self, conn: WSConnection):
|
|
||||||
"""
|
|
||||||
释放连接回连接池
|
|
||||||
"""
|
|
||||||
if self._closed:
|
|
||||||
await conn.close()
|
|
||||||
return
|
|
||||||
|
|
||||||
if not conn.is_active:
|
|
||||||
logger.warning(f"连接 {conn.conn_id} 已失效,不返回连接池")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
if self.pool.full():
|
|
||||||
# 连接池已满,关闭该连接
|
|
||||||
await conn.close()
|
|
||||||
logger.info(f"连接池已满,关闭连接 {conn.conn_id}")
|
|
||||||
else:
|
|
||||||
await self.pool.put(conn)
|
|
||||||
logger.debug(f"连接 {conn.conn_id} 已返回连接池")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"释放连接失败: {e}")
|
|
||||||
await conn.close()
|
|
||||||
|
|
||||||
async def _cleanup_idle_connections(self):
|
|
||||||
"""
|
|
||||||
清理空闲连接任务
|
|
||||||
"""
|
|
||||||
while not self._closed:
|
|
||||||
await asyncio.sleep(60) # 每分钟检查一次
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 检查连接池中的连接
|
|
||||||
new_pool = asyncio.Queue(maxsize=self.pool_size)
|
|
||||||
current_time = asyncio.get_event_loop().time()
|
|
||||||
|
|
||||||
while not self.pool.empty():
|
|
||||||
conn = await self.pool.get()
|
|
||||||
|
|
||||||
if current_time - conn.last_used > self.max_idle_time:
|
|
||||||
# 连接空闲时间过长,关闭
|
|
||||||
await conn.close()
|
|
||||||
logger.info(f"清理空闲连接 {conn.conn_id}")
|
|
||||||
else:
|
|
||||||
# 放回新队列
|
|
||||||
await new_pool.put(conn)
|
|
||||||
|
|
||||||
# 替换原连接池
|
|
||||||
self.pool = new_pool
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"清理空闲连接失败: {e}")
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""
|
|
||||||
关闭连接池
|
|
||||||
"""
|
|
||||||
if self._closed:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._closed = True
|
|
||||||
|
|
||||||
# 停止清理任务
|
|
||||||
if self._cleanup_task:
|
|
||||||
self._cleanup_task.cancel()
|
|
||||||
try:
|
|
||||||
await self._cleanup_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 关闭所有连接
|
|
||||||
while not self.pool.empty():
|
|
||||||
conn = await self.pool.get()
|
|
||||||
await conn.close()
|
|
||||||
|
|
||||||
logger.info("WebSocket 连接池已关闭")
|
|
||||||
21
main.py
21
main.py
@@ -158,12 +158,13 @@ async def main():
|
|||||||
else:
|
else:
|
||||||
logger.warning(f"插件目录不存在 {plugin_path}")
|
logger.warning(f"插件目录不存在 {plugin_path}")
|
||||||
|
|
||||||
|
websocket_client = None
|
||||||
try:
|
try:
|
||||||
# 初始化代码执行器
|
# 初始化代码执行器
|
||||||
code_executor = initialize_executor(config)
|
code_executor = initialize_executor(config)
|
||||||
|
|
||||||
# 使用连接池模式初始化 WebSocket 客户端
|
# 初始化 WebSocket 客户端
|
||||||
websocket_client = WS(code_executor=code_executor, use_pool=True)
|
websocket_client = WS(code_executor=code_executor)
|
||||||
|
|
||||||
# 启动代码执行器的后台 worker
|
# 启动代码执行器的后台 worker
|
||||||
logger.debug("[Main] 检查是否需要启动代码执行 Worker...")
|
logger.debug("[Main] 检查是否需要启动代码执行 Worker...")
|
||||||
@@ -174,11 +175,22 @@ async def main():
|
|||||||
logger.warning("[Main] 未启动代码执行 Worker,因为 Docker 客户端未初始化或连接失败。")
|
logger.warning("[Main] 未启动代码执行 Worker,因为 Docker 客户端未初始化或连接失败。")
|
||||||
|
|
||||||
await websocket_client.connect()
|
await websocket_client.connect()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("主任务被取消,正在停止...")
|
||||||
finally:
|
finally:
|
||||||
|
logger.info("正在清理资源...")
|
||||||
if observer.is_alive():
|
if observer.is_alive():
|
||||||
observer.stop()
|
observer.stop()
|
||||||
observer.join()
|
observer.join()
|
||||||
|
|
||||||
|
if websocket_client:
|
||||||
|
await websocket_client.close()
|
||||||
|
|
||||||
|
# 关闭浏览器管理器
|
||||||
|
await browser_manager.shutdown()
|
||||||
|
|
||||||
|
logger.success("资源清理完成,程序退出。")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
"""
|
"""
|
||||||
@@ -193,8 +205,9 @@ if __name__ == "__main__":
|
|||||||
try:
|
try:
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
main_logger.info("程序已被用户中断")
|
# 捕获 KeyboardInterrupt,不做任何操作,让 asyncio.run 正常结束
|
||||||
exit(0)
|
# 这样 main 函数中的 finally 块会被执行
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
main_logger.exception("程序发生未处理的全局异常")
|
main_logger.exception("程序发生未处理的全局异常")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user