diff --git a/core/WS.py b/core/WS.py index 4cccba5..12a5413 100644 --- a/core/WS.py +++ b/core/WS.py @@ -127,17 +127,14 @@ class WS: while True: try: # 从连接池获取一个连接 - conn = await self.pool.get_connection() - - try: - # 监听连接上的消息 - async for message in conn.conn: - await self._handle_message(message, conn) - except Exception as e: - self.logger.error(f"连接 {conn.conn_id} 监听异常: {e}") - finally: - # 释放连接回连接池 - await self.pool.release_connection(conn) + # 使用 connection 上下文管理器确保释放 + async with self.pool.connection() as conn: + try: + # 监听连接上的消息 + async for message in conn.conn: + await self._handle_message(message, conn) + except Exception as e: + self.logger.error(f"连接 {conn.conn_id} 监听异常: {e}") except Exception as e: self.logger.error(f"连接池监听循环异常: {e}") await asyncio.sleep(self.reconnect_interval) @@ -324,27 +321,33 @@ class WS: ) # 从连接池获取一个连接 - conn = await self.pool.get_connection() try: - echo_id = str(uuid.uuid4()) - payload = {"action": action, "params": params or {}, "echo": echo_id} + async with self.pool.connection() as conn: + echo_id = str(uuid.uuid4()) + payload = {"action": action, "params": params or {}, "echo": echo_id} - await conn.send(orjson.dumps(payload)) + await conn.send(orjson.dumps(payload)) - # 在当前连接上等待特定 echo 的响应,并设置超时 - try: - async def wait_for_response(): - async for message in conn.conn: - data = orjson.loads(message) - if data.get("echo") == echo_id: - return data - - return await asyncio.wait_for(wait_for_response(), timeout=30.0) + # 在当前连接上等待特定 echo 的响应,并设置超时 + try: + async def wait_for_response(): + async for message in conn.conn: + data = orjson.loads(message) + + # 检查是否是我们要的响应 + if data.get("echo") == echo_id: + return data + + # 如果不是,可能是事件,需要分发 + if "post_type" in data: + asyncio.create_task(self.on_event(data)) + + return await asyncio.wait_for(wait_for_response(), timeout=30.0) - except asyncio.TimeoutError: - raise # 重新抛出超时异常 - except Exception as e: - raise WebSocketError(f"在等待API响应时连接出错: {e}") + except asyncio.TimeoutError: + raise # 重新抛出超时异常 + except Exception as e: + raise WebSocketError(f"在等待API响应时连接出错: {e}") except asyncio.TimeoutError: self.logger.warning(f"API 调用超时: action={action}, params={params}") @@ -360,9 +363,6 @@ class WS: message=f"API调用异常: {str(e)}", data={"action": action, "params": params} ) - finally: - # 释放连接回连接池 - await self.pool.release_connection(conn) else: # 单连接模式 if not self.ws: @@ -409,4 +409,3 @@ class WS: message=f"API调用异常: {str(e)}", data={"action": action, "params": params} ) - diff --git a/core/managers/command_manager.py b/core/managers/command_manager.py index ddfc846..95777ac 100644 --- a/core/managers/command_manager.py +++ b/core/managers/command_manager.py @@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, Optional, Tuple from models.events.message import MessageSegment + from ..config_loader import global_config from ..handlers.event_handler import MessageHandler, NoticeHandler, RequestHandler from .redis_manager import redis_manager diff --git a/core/ws_pool.py b/core/ws_pool.py index aea2578..80a96f7 100644 --- a/core/ws_pool.py +++ b/core/ws_pool.py @@ -7,9 +7,10 @@ WebSocket 连接池模块 import asyncio import websockets from websockets.legacy.client import WebSocketClientProtocol -from typing import Optional, Dict, Any, cast, Union +from typing import Optional, Dict, Any, cast, Union, AsyncGenerator import uuid from loguru import logger +import contextlib from .config_loader import global_config from .utils.exceptions import WebSocketError, WebSocketConnectionError @@ -64,9 +65,11 @@ class WSConnection: if not self.is_active: return False try: - await asyncio.wait_for(self.conn.ping(), timeout=timeout) + # 使用 wait_for 包装 ping + pong_waiter = await self.conn.ping() + await asyncio.wait_for(pong_waiter, timeout=timeout) return True - except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed): + except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed, Exception): self.is_active = False return False @@ -76,7 +79,10 @@ class WSConnection: """ if self.is_active: self.is_active = False - await self.conn.close() + try: + await self.conn.close() + except Exception: + pass class WSConnectionPool: @@ -97,6 +103,8 @@ class WSConnectionPool: self.pool: asyncio.Queue[WSConnection] = asyncio.Queue(maxsize=pool_size) self._closed = False self._cleanup_task: Optional[asyncio.Task] = None + self._current_size = 0 # 当前管理的连接数(包括池中和借出的) + self._lock = asyncio.Lock() # 用于保护 _current_size 的修改 # 从全局配置读取参数 self.url = global_config.napcat_ws.uri @@ -115,14 +123,17 @@ class WSConnectionPool: # 启动连接清理任务 self._cleanup_task = asyncio.create_task(self._cleanup_idle_connections()) - # 创建初始连接 + # 预热连接池 for _ in range(self.pool_size): try: conn = await self._create_connection() await self.pool.put(conn) + async with self._lock: + self._current_size += 1 logger.info(f"WebSocket 连接 {conn.conn_id} 已创建并加入连接池") except Exception as e: logger.error(f"创建初始连接失败: {e}") + # 初始连接失败不抛出异常,允许后续动态创建 async def _create_connection(self) -> WSConnection: """ @@ -143,6 +154,17 @@ class WSConnectionPool: except Exception as e: raise WebSocketConnectionError(f"创建 WebSocket 连接失败: {e}") + @contextlib.asynccontextmanager + async def connection(self) -> AsyncGenerator[WSConnection, None]: + """ + 获取连接的上下文管理器 + """ + conn = await self.get_connection() + try: + yield conn + finally: + await self.release_connection(conn) + async def get_connection(self) -> WSConnection: """ 从连接池获取一个健康的连接,包含健康检查。 @@ -150,25 +172,64 @@ class WSConnectionPool: if self._closed: raise WebSocketError("连接池已关闭") - try: - # 尝试从连接池获取连接 - conn = await asyncio.wait_for(self.pool.get(), timeout=5) - - # 健康检查 - if await conn.ping(): - logger.debug(f"连接 {conn.conn_id} 健康检查通过") - return conn - else: - logger.warning(f"连接 {conn.conn_id} 健康检查失败,丢弃并获取新连接") - await conn.close() - return await self.get_connection() # 递归获取下一个 + start_time = asyncio.get_event_loop().time() + timeout = 10 # 获取连接的总超时时间 - except asyncio.TimeoutError: - # 连接池为空,创建新连接 - logger.warning("连接池在5秒内无可用连接,创建新连接") - return await self._create_connection() - except Exception as e: - raise WebSocketError(f"获取连接时发生未知错误: {e}") + while True: + if asyncio.get_event_loop().time() - start_time > timeout: + raise WebSocketError("获取连接超时") + + try: + # 1. 尝试从池中获取 + conn = self.pool.get_nowait() + + # 健康检查 + if await conn.ping(): + logger.debug(f"连接 {conn.conn_id} 健康检查通过") + return conn + else: + logger.warning(f"连接 {conn.conn_id} 健康检查失败,丢弃") + await conn.close() + async with self._lock: + self._current_size -= 1 + # 继续循环,尝试获取下一个或创建新的 + continue + + except asyncio.QueueEmpty: + # 池为空,检查是否可以创建新连接 + async with self._lock: + if self._current_size < self.pool_size: + # 有配额,创建新连接 + self._current_size += 1 # 先占位 + create_new = True + else: + create_new = False + + if create_new: + try: + conn = await self._create_connection() + return conn + except Exception as e: + async with self._lock: + self._current_size -= 1 # 回滚占位 + logger.error(f"创建新连接失败: {e}") + await asyncio.sleep(1) # 避免快速失败循环 + continue + else: + # 没有配额,等待池中有可用连接 + try: + conn = await asyncio.wait_for(self.pool.get(), timeout=1.0) + # 获取到了,进行健康检查(在下一次循环中处理,或者这里直接处理) + # 为了代码复用,我们把 conn 放回去(或者直接用),这里直接用 + if await conn.ping(): + return conn + else: + await conn.close() + async with self._lock: + self._current_size -= 1 + continue + except asyncio.TimeoutError: + continue async def release_connection(self, conn: WSConnection): """ @@ -180,19 +241,26 @@ class WSConnectionPool: if not conn.is_active: logger.warning(f"连接 {conn.conn_id} 已失效,不返回连接池") + await conn.close() + async with self._lock: + self._current_size -= 1 return try: - if self.pool.full(): - # 连接池已满,关闭该连接 - await conn.close() - logger.info(f"连接池已满,关闭连接 {conn.conn_id}") - else: - await self.pool.put(conn) - logger.debug(f"连接 {conn.conn_id} 已返回连接池") + # 尝试放回池中 + self.pool.put_nowait(conn) + logger.debug(f"连接 {conn.conn_id} 已返回连接池") + except asyncio.QueueFull: + # 理论上不应该发生,除非 _current_size 逻辑有误 + logger.warning(f"连接池已满,关闭多余连接 {conn.conn_id}") + await conn.close() + async with self._lock: + self._current_size -= 1 except Exception as e: logger.error(f"释放连接失败: {e}") await conn.close() + async with self._lock: + self._current_size -= 1 async def _cleanup_idle_connections(self): """ @@ -202,23 +270,33 @@ class WSConnectionPool: await asyncio.sleep(60) # 每分钟检查一次 try: - # 检查连接池中的连接 - new_pool = asyncio.Queue(maxsize=self.pool_size) - current_time = asyncio.get_event_loop().time() + # 我们不替换队列,而是取出检查再放回 + # 这样比较安全,但可能会暂时清空池子 + # 更好的做法是只检查队头的连接 - while not self.pool.empty(): - conn = await self.pool.get() + # 获取当前队列大小 + qsize = self.pool.qsize() + for _ in range(qsize): + try: + conn = self.pool.get_nowait() + except asyncio.QueueEmpty: + break + current_time = asyncio.get_event_loop().time() if current_time - conn.last_used > self.max_idle_time: - # 连接空闲时间过长,关闭 - await conn.close() logger.info(f"清理空闲连接 {conn.conn_id}") + await conn.close() + async with self._lock: + self._current_size -= 1 else: - # 放回新队列 - await new_pool.put(conn) - - # 替换原连接池 - self.pool = new_pool + # 还没过期,放回去 + try: + self.pool.put_nowait(conn) + except asyncio.QueueFull: + # 竞争条件下可能满了 + await conn.close() + async with self._lock: + self._current_size -= 1 except Exception as e: logger.error(f"清理空闲连接失败: {e}") @@ -241,7 +319,10 @@ class WSConnectionPool: # 关闭所有连接 while not self.pool.empty(): - conn = await self.pool.get() - await conn.close() + try: + conn = self.pool.get_nowait() + await conn.close() + except asyncio.QueueEmpty: + break - logger.info("WebSocket 连接池已关闭") \ No newline at end of file + logger.info("WebSocket 连接池已关闭") diff --git a/docs/development-standards.md b/docs/development-standards.md index 56a9578..b46748b 100644 --- a/docs/development-standards.md +++ b/docs/development-standards.md @@ -12,8 +12,10 @@ - **应当**: 使用 `asyncio.sleep()`、异步库(如 `aiohttp`),并通过 `asyncio.to_thread` 或 `run_in_executor` 将同步代码移出主事件循环。 - **禁止**: 直接在异步函数中使用任何可能阻塞的同步调用。 -### 2. 资源管理 -**复用优于重建**。频繁创建和销毁资源(如网络连接、浏览器页面)会严重影响性能。 +### 1.1 异步优先原则 +- **绝对不要阻塞事件循环**:NeoBot 采用多线程异步架构,任何同步阻塞操作都会导致整个机器人卡死。 + - **禁止**:`time.sleep()`、同步 `requests`、密集 CPU 计算 + - **必须**:使用 `await asyncio.sleep()`、异步 HTTP 客户端、线程池执行同步任务 - **应当**: 通过框架提供的单例管理器(如 `redis_manager`, `browser_manager`)获取和管理资源。 - **禁止**: 自行实例化管理器或在插件中创建独立的资源实例(如 `aiohttp.ClientSession`)。 diff --git a/plugins/bot_status.py b/plugins/bot_status.py index e29a9a1..8828533 100644 --- a/plugins/bot_status.py +++ b/plugins/bot_status.py @@ -6,6 +6,7 @@ Bot 状态查询插件 import os import psutil import time +import asyncio from datetime import datetime, timedelta from core.bot import Bot @@ -32,15 +33,23 @@ def _get_system_info(): """ 同步函数:使用 psutil 获取系统信息,避免阻塞事件循环。 """ - # interval=1 会阻塞1秒,必须在线程池中运行 - cpu_percent = psutil.cpu_percent(interval=1) - mem_info = psutil.virtual_memory() - bot_mem_mb = PROCESS.memory_info().rss / (1024 * 1024) - return { - "cpu_percent": f"{cpu_percent:.1f}", - "mem_percent": f"{mem_info.percent:.1f}", - "bot_mem_mb": f"{bot_mem_mb:.2f}", - } + try: + # interval=1 会阻塞1秒,必须在线程池中运行 + cpu_percent = psutil.cpu_percent(interval=1) + mem_info = psutil.virtual_memory() + bot_mem_mb = PROCESS.memory_info().rss / (1024 * 1024) + return { + "cpu_percent": f"{cpu_percent:.1f}", + "mem_percent": f"{mem_info.percent:.1f}", + "bot_mem_mb": f"{bot_mem_mb:.2f}", + } + except Exception as e: + logger.error(f"获取系统信息失败: {e}") + return { + "cpu_percent": "N/A", + "mem_percent": "N/A", + "bot_mem_mb": "N/A", + } @matcher.command("status", "状态") async def handle_status(bot: Bot, event: MessageEvent, args: list[str]): @@ -93,26 +102,51 @@ async def handle_status(bot: Bot, event: MessageEvent, args: list[str]): } # 3. 获取统计数据 - msgs_recv = await redis_manager.get("neobot:stats:messages_received") or 0 - msgs_sent = await redis_manager.get("neobot:stats:messages_sent") or 0 - command_stats_raw = await redis_manager.redis.hgetall("neobot:command_stats") - - total_commands = sum(int(v) for v in command_stats_raw.values()) - - stats_data = { - "messages_received": int(msgs_recv), - "messages_sent": int(msgs_sent), - "total_commands": total_commands, - } + try: + msgs_recv = await redis_manager.get("neobot:stats:messages_received") or 0 + msgs_sent = await redis_manager.get("neobot:stats:messages_sent") or 0 + command_stats_raw = await redis_manager.redis.hgetall("neobot:command_stats") + + total_commands = sum(int(v) for v in command_stats_raw.values()) + + stats_data = { + "messages_received": int(msgs_recv), + "messages_sent": int(msgs_sent), + "total_commands": total_commands, + } - command_stats_data = sorted( - [{"name": k, "count": int(v)} for k, v in command_stats_raw.items()], - key=lambda x: x["count"], - reverse=True - ) + command_stats_data = sorted( + [{"name": k, "count": int(v)} for k, v in command_stats_raw.items()], + key=lambda x: x["count"], + reverse=True + ) + except Exception as e: + logger.error(f"获取Redis统计数据失败: {e}") + stats_data = { + "messages_received": 0, + "messages_sent": 0, + "total_commands": 0, + } + command_stats_data = [] # 4. 异步获取系统信息 - system_data = await run_in_thread_pool(_get_system_info) + # 设置超时,防止 psutil 阻塞过久 + try: + system_data = await asyncio.wait_for(run_in_thread_pool(_get_system_info), timeout=5.0) + except asyncio.TimeoutError: + logger.error("获取系统信息超时") + system_data = { + "cpu_percent": "Timeout", + "mem_percent": "Timeout", + "bot_mem_mb": "Timeout", + } + except Exception as e: + logger.error(f"获取系统信息异常: {e}") + system_data = { + "cpu_percent": "Error", + "mem_percent": "Error", + "bot_mem_mb": "Error", + } # 5. 准备模板所需的所有数据 template_data = { @@ -125,18 +159,22 @@ async def handle_status(bot: Bot, event: MessageEvent, args: list[str]): } # 6. 渲染图片 - base64_str = await image_manager.render_template_to_base64( - template_name="status.html", - data=template_data, - output_name="status.png", - image_type="png" - ) + try: + base64_str = await image_manager.render_template_to_base64( + template_name="status.html", + data=template_data, + output_name="status.png", + image_type="png" + ) - if base64_str: - await event.reply(MessageSegment.image(base64_str)) - else: - # 如果渲染失败,image_manager 内部会记录错误,这里给用户一个通用提示 - await event.reply("状态图片生成失败,可能是渲染服务出现问题,请联系管理员。") + if base64_str: + await event.reply(MessageSegment.image(base64_str)) + else: + # 如果渲染失败,image_manager 内部会记录错误,这里给用户一个通用提示 + await event.reply("状态图片生成失败,可能是渲染服务出现问题,请联系管理员。") + except Exception as e: + logger.error(f"渲染图片失败: {e}") + await event.reply("状态图片渲染过程中发生错误。") except Exception as e: logger.exception(f"生成状态图时发生意外错误, 用户: {event.user_id}")