refactor(websocket): 移除连接池模式并改进资源清理
移除 WebSocket 连接池实现,改为单连接模式以简化代码结构 在 main 函数中添加资源清理逻辑,确保程序退出时正确关闭所有资源 改进 base64 数据处理逻辑,支持递归处理嵌套结构中的敏感数据 呵呵线程池加WS是神人
This commit is contained in:
285
core/WS.py
285
core/WS.py
@@ -32,7 +32,6 @@ from .utils.exceptions import (
|
||||
WebSocketError, WebSocketConnectionError
|
||||
)
|
||||
from .utils.error_codes import ErrorCode, create_error_response
|
||||
from .ws_pool import WSConnectionPool
|
||||
|
||||
|
||||
class WS:
|
||||
@@ -40,14 +39,13 @@ class WS:
|
||||
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
|
||||
"""
|
||||
|
||||
def __init__(self, code_executor: Optional[CodeExecutor] = None, use_pool: bool = True) -> None:
|
||||
def __init__(self, code_executor: Optional[CodeExecutor] = None) -> None:
|
||||
"""
|
||||
初始化 WebSocket 客户端。
|
||||
|
||||
从全局配置中读取 WebSocket URI、访问令牌(Token)和重连间隔。
|
||||
|
||||
:param code_executor: 代码执行器实例
|
||||
:param use_pool: 是否使用连接池
|
||||
"""
|
||||
# 读取参数
|
||||
cfg = global_config.napcat_ws
|
||||
@@ -61,8 +59,6 @@ class WS:
|
||||
self.bot: 'Bot' | None = None
|
||||
self.self_id: int | None = None
|
||||
self.code_executor = code_executor
|
||||
self.use_pool = use_pool
|
||||
self.pool: Optional[WSConnectionPool] = None
|
||||
|
||||
# 创建模块专用日志记录器
|
||||
self.logger = ModuleLogger("WebSocket")
|
||||
@@ -76,112 +72,39 @@ class WS:
|
||||
"""
|
||||
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
|
||||
|
||||
if self.use_pool:
|
||||
# 使用连接池模式
|
||||
self.pool = WSConnectionPool(pool_size=3)
|
||||
await self.pool.initialize()
|
||||
self.logger.success("WebSocket 连接池初始化完成")
|
||||
|
||||
# 启动连接池监听循环
|
||||
await self._pool_listen_loop()
|
||||
else:
|
||||
# 单连接模式
|
||||
while True:
|
||||
try:
|
||||
self.logger.info(f"正在尝试连接至 NapCat: {self.url}")
|
||||
async with websockets.connect(
|
||||
self.url, additional_headers=headers
|
||||
) as websocket_raw:
|
||||
websocket = cast(WebSocketClientProtocol, websocket_raw)
|
||||
self.ws = websocket
|
||||
self.logger.success("连接成功!")
|
||||
await self._listen_loop(websocket)
|
||||
|
||||
except (
|
||||
websockets.exceptions.ConnectionClosed,
|
||||
ConnectionRefusedError,
|
||||
) as e:
|
||||
conn_error = WebSocketConnectionError(
|
||||
message=f"WebSocket连接失败: {str(e)}",
|
||||
code=ErrorCode.WS_CONNECTION_FAILED,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.error(f"连接失败: {conn_error.message}")
|
||||
self.logger.log_custom_exception(conn_error)
|
||||
except Exception as e:
|
||||
error = WebSocketError(
|
||||
message=f"WebSocket运行异常: {str(e)}",
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.exception(f"运行异常: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
|
||||
self.logger.info(f"{self.reconnect_interval}秒后尝试重连...")
|
||||
await asyncio.sleep(self.reconnect_interval)
|
||||
|
||||
async def _pool_listen_loop(self):
|
||||
"""
|
||||
连接池模式下的监听循环
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# 从连接池获取一个连接
|
||||
conn = await self.pool.get_connection()
|
||||
|
||||
try:
|
||||
# 监听连接上的消息
|
||||
async for message in conn.conn:
|
||||
await self._handle_message(message, conn)
|
||||
except Exception as e:
|
||||
self.logger.error(f"连接 {conn.conn_id} 监听异常: {e}")
|
||||
finally:
|
||||
# 释放连接回连接池
|
||||
await self.pool.release_connection(conn)
|
||||
self.logger.info(f"正在尝试连接至 NapCat: {self.url}")
|
||||
async with websockets.connect(
|
||||
self.url, additional_headers=headers
|
||||
) as websocket_raw:
|
||||
websocket = cast(WebSocketClientProtocol, websocket_raw)
|
||||
self.ws = websocket
|
||||
self.logger.success("连接成功!")
|
||||
await self._listen_loop(websocket)
|
||||
|
||||
except (
|
||||
websockets.exceptions.ConnectionClosed,
|
||||
ConnectionRefusedError,
|
||||
) as e:
|
||||
conn_error = WebSocketConnectionError(
|
||||
message=f"WebSocket连接失败: {str(e)}",
|
||||
code=ErrorCode.WS_CONNECTION_FAILED,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.error(f"连接失败: {conn_error.message}")
|
||||
self.logger.log_custom_exception(conn_error)
|
||||
except Exception as e:
|
||||
self.logger.error(f"连接池监听循环异常: {e}")
|
||||
await asyncio.sleep(self.reconnect_interval)
|
||||
|
||||
async def _handle_message(self, message: str, conn):
|
||||
"""
|
||||
处理从连接池获取的消息
|
||||
"""
|
||||
try:
|
||||
data = orjson.loads(message)
|
||||
error = WebSocketError(
|
||||
message=f"WebSocket运行异常: {str(e)}",
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.exception(f"运行异常: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
|
||||
# 1. 处理 API 响应
|
||||
# 如果消息中包含 echo 字段,说明是 API 调用的响应
|
||||
echo_id = data.get("echo")
|
||||
if echo_id and echo_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(echo_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
return
|
||||
|
||||
# 2. 处理上报事件
|
||||
# 如果消息中包含 post_type 字段,说明是 OneBot 上报的事件
|
||||
if "post_type" in data:
|
||||
# 使用 create_task 异步执行,避免阻塞 WebSocket 接收循环
|
||||
asyncio.create_task(self.on_event(data))
|
||||
|
||||
except orjson.JSONDecodeError as e:
|
||||
error = WebSocketError(
|
||||
message=f"JSON解析失败: {str(e)}",
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.error(f"解析消息异常: {error.message}")
|
||||
# 如果message是bytes类型,需要先解码
|
||||
decoded_message = message.decode('utf-8') if isinstance(message, bytes) else message
|
||||
self.logger.debug(f"原始消息: {decoded_message}")
|
||||
except Exception as e:
|
||||
error = WebSocketError(
|
||||
message=f"处理消息异常: {str(e)}",
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.exception(f"解析消息异常: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
self.logger.info(f"{self.reconnect_interval}秒后尝试重连...")
|
||||
await asyncio.sleep(self.reconnect_interval)
|
||||
|
||||
async def _listen_loop(self, websocket_connection: WebSocketClientProtocol) -> None:
|
||||
"""
|
||||
@@ -298,6 +221,23 @@ class WS:
|
||||
)
|
||||
self.logger.log_custom_exception(error)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
关闭 WebSocket 客户端,释放资源。
|
||||
"""
|
||||
self.logger.info("正在关闭 WebSocket 客户端...")
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
# 取消所有挂起的请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
|
||||
self.logger.success("WebSocket 客户端已关闭")
|
||||
|
||||
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
|
||||
"""
|
||||
向 OneBot v11 实现端发送一个 API 请求。
|
||||
@@ -313,100 +253,47 @@ class WS:
|
||||
dict: OneBot API 的响应数据。如果超时或连接断开,则返回一个
|
||||
表示失败的字典。
|
||||
"""
|
||||
if self.use_pool:
|
||||
# 使用连接池模式
|
||||
if not self.pool:
|
||||
self.logger.error("调用 API 失败: WebSocket 连接池未初始化")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_DISCONNECTED,
|
||||
message="WebSocket连接池未初始化",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
# 从连接池获取一个连接
|
||||
conn = await self.pool.get_connection()
|
||||
try:
|
||||
echo_id = str(uuid.uuid4())
|
||||
payload = {"action": action, "params": params or {}, "echo": echo_id}
|
||||
if not self.ws:
|
||||
self.logger.error("调用 API 失败: WebSocket 未初始化")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_DISCONNECTED,
|
||||
message="WebSocket未初始化",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
await conn.send(orjson.dumps(payload))
|
||||
from websockets.protocol import State
|
||||
|
||||
# 在当前连接上等待特定 echo 的响应,并设置超时
|
||||
try:
|
||||
async def wait_for_response():
|
||||
async for message in conn.conn:
|
||||
data = orjson.loads(message)
|
||||
if data.get("echo") == echo_id:
|
||||
return data
|
||||
|
||||
return await asyncio.wait_for(wait_for_response(), timeout=30.0)
|
||||
if getattr(self.ws, "state", None) is not State.OPEN:
|
||||
self.logger.error("调用 API 失败: WebSocket 连接未打开")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_DISCONNECTED,
|
||||
message="WebSocket连接未打开",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise # 重新抛出超时异常
|
||||
except Exception as e:
|
||||
raise WebSocketError(f"在等待API响应时连接出错: {e}")
|
||||
echo_id = str(uuid.uuid4())
|
||||
payload = {"action": action, "params": params or {}, "echo": echo_id}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning(f"API 调用超时: action={action}, params={params}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.TIMEOUT_ERROR,
|
||||
message="API调用超时",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
message=f"API调用异常: {str(e)}",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
finally:
|
||||
# 释放连接回连接池
|
||||
await self.pool.release_connection(conn)
|
||||
else:
|
||||
# 单连接模式
|
||||
if not self.ws:
|
||||
self.logger.error("调用 API 失败: WebSocket 未初始化")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_DISCONNECTED,
|
||||
message="WebSocket未初始化",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
from websockets.protocol import State
|
||||
|
||||
if getattr(self.ws, "state", None) is not State.OPEN:
|
||||
self.logger.error("调用 API 失败: WebSocket 连接未打开")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_DISCONNECTED,
|
||||
message="WebSocket连接未打开",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
echo_id = str(uuid.uuid4())
|
||||
payload = {"action": action, "params": params or {}, "echo": echo_id}
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
self._pending_requests[echo_id] = future
|
||||
|
||||
try:
|
||||
await self.ws.send(orjson.dumps(payload))
|
||||
return await asyncio.wait_for(future, timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.warning(f"API 调用超时: action={action}, params={params}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.TIMEOUT_ERROR,
|
||||
message="API调用超时",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
message=f"API调用异常: {str(e)}",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
self._pending_requests[echo_id] = future
|
||||
|
||||
try:
|
||||
await self.ws.send(orjson.dumps(payload))
|
||||
return await asyncio.wait_for(future, timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.warning(f"API 调用超时: action={action}, params={params}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.TIMEOUT_ERROR,
|
||||
message="API调用超时",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
message=f"API调用异常: {str(e)}",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user