Revert "refactor(WS): 使用连接池上下文管理器简化连接管理"

This reverts commit c851b49db9.
This commit is contained in:
2026-01-23 17:37:41 +08:00
parent 0e04829ac9
commit 57a04e436b
3 changed files with 114 additions and 232 deletions

View File

@@ -127,14 +127,17 @@ class WS:
while True: while True:
try: try:
# 从连接池获取一个连接 # 从连接池获取一个连接
# 使用 connection 上下文管理器确保释放 conn = await self.pool.get_connection()
async with self.pool.connection() as conn:
try: try:
# 监听连接上的消息 # 监听连接上的消息
async for message in conn.conn: async for message in conn.conn:
await self._handle_message(message, conn) await self._handle_message(message, conn)
except Exception as e: except Exception as e:
self.logger.error(f"连接 {conn.conn_id} 监听异常: {e}") self.logger.error(f"连接 {conn.conn_id} 监听异常: {e}")
finally:
# 释放连接回连接池
await self.pool.release_connection(conn)
except Exception as e: except Exception as e:
self.logger.error(f"连接池监听循环异常: {e}") self.logger.error(f"连接池监听循环异常: {e}")
await asyncio.sleep(self.reconnect_interval) await asyncio.sleep(self.reconnect_interval)
@@ -321,33 +324,27 @@ class WS:
) )
# 从连接池获取一个连接 # 从连接池获取一个连接
conn = await self.pool.get_connection()
try: try:
async with self.pool.connection() as conn: echo_id = str(uuid.uuid4())
echo_id = str(uuid.uuid4()) payload = {"action": action, "params": params or {}, "echo": echo_id}
payload = {"action": action, "params": params or {}, "echo": echo_id}
await conn.send(orjson.dumps(payload)) await conn.send(orjson.dumps(payload))
# 在当前连接上等待特定 echo 的响应,并设置超时 # 在当前连接上等待特定 echo 的响应,并设置超时
try: try:
async def wait_for_response(): async def wait_for_response():
async for message in conn.conn: async for message in conn.conn:
data = orjson.loads(message) data = orjson.loads(message)
if data.get("echo") == echo_id:
# 检查是否是我们要的响应 return data
if data.get("echo") == echo_id:
return data return await asyncio.wait_for(wait_for_response(), timeout=30.0)
# 如果不是,可能是事件,需要分发
if "post_type" in data:
asyncio.create_task(self.on_event(data))
return await asyncio.wait_for(wait_for_response(), timeout=30.0)
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise # 重新抛出超时异常 raise # 重新抛出超时异常
except Exception as e: except Exception as e:
raise WebSocketError(f"在等待API响应时连接出错: {e}") raise WebSocketError(f"在等待API响应时连接出错: {e}")
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.logger.warning(f"API 调用超时: action={action}, params={params}") self.logger.warning(f"API 调用超时: action={action}, params={params}")
@@ -363,6 +360,9 @@ class WS:
message=f"API调用异常: {str(e)}", message=f"API调用异常: {str(e)}",
data={"action": action, "params": params} data={"action": action, "params": params}
) )
finally:
# 释放连接回连接池
await self.pool.release_connection(conn)
else: else:
# 单连接模式 # 单连接模式
if not self.ws: if not self.ws:
@@ -409,3 +409,4 @@ class WS:
message=f"API调用异常: {str(e)}", message=f"API调用异常: {str(e)}",
data={"action": action, "params": params} data={"action": action, "params": params}
) )

View File

@@ -7,10 +7,9 @@ WebSocket 连接池模块
import asyncio import asyncio
import websockets import websockets
from websockets.legacy.client import WebSocketClientProtocol from websockets.legacy.client import WebSocketClientProtocol
from typing import Optional, Dict, Any, cast, Union, AsyncGenerator from typing import Optional, Dict, Any, cast, Union
import uuid import uuid
from loguru import logger from loguru import logger
import contextlib
from .config_loader import global_config from .config_loader import global_config
from .utils.exceptions import WebSocketError, WebSocketConnectionError from .utils.exceptions import WebSocketError, WebSocketConnectionError
@@ -65,11 +64,9 @@ class WSConnection:
if not self.is_active: if not self.is_active:
return False return False
try: try:
# 使用 wait_for 包装 ping await asyncio.wait_for(self.conn.ping(), timeout=timeout)
pong_waiter = await self.conn.ping()
await asyncio.wait_for(pong_waiter, timeout=timeout)
return True return True
except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed, Exception): except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
self.is_active = False self.is_active = False
return False return False
@@ -79,10 +76,7 @@ class WSConnection:
""" """
if self.is_active: if self.is_active:
self.is_active = False self.is_active = False
try: await self.conn.close()
await self.conn.close()
except Exception:
pass
class WSConnectionPool: class WSConnectionPool:
@@ -103,8 +97,6 @@ class WSConnectionPool:
self.pool: asyncio.Queue[WSConnection] = asyncio.Queue(maxsize=pool_size) self.pool: asyncio.Queue[WSConnection] = asyncio.Queue(maxsize=pool_size)
self._closed = False self._closed = False
self._cleanup_task: Optional[asyncio.Task] = None self._cleanup_task: Optional[asyncio.Task] = None
self._current_size = 0 # 当前管理的连接数(包括池中和借出的)
self._lock = asyncio.Lock() # 用于保护 _current_size 的修改
# 从全局配置读取参数 # 从全局配置读取参数
self.url = global_config.napcat_ws.uri self.url = global_config.napcat_ws.uri
@@ -123,17 +115,14 @@ class WSConnectionPool:
# 启动连接清理任务 # 启动连接清理任务
self._cleanup_task = asyncio.create_task(self._cleanup_idle_connections()) self._cleanup_task = asyncio.create_task(self._cleanup_idle_connections())
# 预热连接 # 创建初始连接
for _ in range(self.pool_size): for _ in range(self.pool_size):
try: try:
conn = await self._create_connection() conn = await self._create_connection()
await self.pool.put(conn) await self.pool.put(conn)
async with self._lock:
self._current_size += 1
logger.info(f"WebSocket 连接 {conn.conn_id} 已创建并加入连接池") logger.info(f"WebSocket 连接 {conn.conn_id} 已创建并加入连接池")
except Exception as e: except Exception as e:
logger.error(f"创建初始连接失败: {e}") logger.error(f"创建初始连接失败: {e}")
# 初始连接失败不抛出异常,允许后续动态创建
async def _create_connection(self) -> WSConnection: async def _create_connection(self) -> WSConnection:
""" """
@@ -154,17 +143,6 @@ class WSConnectionPool:
except Exception as e: except Exception as e:
raise WebSocketConnectionError(f"创建 WebSocket 连接失败: {e}") raise WebSocketConnectionError(f"创建 WebSocket 连接失败: {e}")
@contextlib.asynccontextmanager
async def connection(self) -> AsyncGenerator[WSConnection, None]:
"""
获取连接的上下文管理器
"""
conn = await self.get_connection()
try:
yield conn
finally:
await self.release_connection(conn)
async def get_connection(self) -> WSConnection: async def get_connection(self) -> WSConnection:
""" """
从连接池获取一个健康的连接,包含健康检查。 从连接池获取一个健康的连接,包含健康检查。
@@ -172,64 +150,25 @@ class WSConnectionPool:
if self._closed: if self._closed:
raise WebSocketError("连接池已关闭") raise WebSocketError("连接池已关闭")
start_time = asyncio.get_event_loop().time() try:
timeout = 10 # 获取连接的总超时时间 # 尝试从连接池获取连接
conn = await asyncio.wait_for(self.pool.get(), timeout=5)
# 健康检查
if await conn.ping():
logger.debug(f"连接 {conn.conn_id} 健康检查通过")
return conn
else:
logger.warning(f"连接 {conn.conn_id} 健康检查失败,丢弃并获取新连接")
await conn.close()
return await self.get_connection() # 递归获取下一个
while True: except asyncio.TimeoutError:
if asyncio.get_event_loop().time() - start_time > timeout: # 连接池为空,创建新连接
raise WebSocketError("获取连接超时") logger.warning("连接池在5秒内无可用连接创建新连接")
return await self._create_connection()
try: except Exception as e:
# 1. 尝试从池中获取 raise WebSocketError(f"获取连接时发生未知错误: {e}")
conn = self.pool.get_nowait()
# 健康检查
if await conn.ping():
logger.debug(f"连接 {conn.conn_id} 健康检查通过")
return conn
else:
logger.warning(f"连接 {conn.conn_id} 健康检查失败,丢弃")
await conn.close()
async with self._lock:
self._current_size -= 1
# 继续循环,尝试获取下一个或创建新的
continue
except asyncio.QueueEmpty:
# 池为空,检查是否可以创建新连接
async with self._lock:
if self._current_size < self.pool_size:
# 有配额,创建新连接
self._current_size += 1 # 先占位
create_new = True
else:
create_new = False
if create_new:
try:
conn = await self._create_connection()
return conn
except Exception as e:
async with self._lock:
self._current_size -= 1 # 回滚占位
logger.error(f"创建新连接失败: {e}")
await asyncio.sleep(1) # 避免快速失败循环
continue
else:
# 没有配额,等待池中有可用连接
try:
conn = await asyncio.wait_for(self.pool.get(), timeout=1.0)
# 获取到了,进行健康检查(在下一次循环中处理,或者这里直接处理)
# 为了代码复用,我们把 conn 放回去(或者直接用),这里直接用
if await conn.ping():
return conn
else:
await conn.close()
async with self._lock:
self._current_size -= 1
continue
except asyncio.TimeoutError:
continue
async def release_connection(self, conn: WSConnection): async def release_connection(self, conn: WSConnection):
""" """
@@ -241,26 +180,19 @@ class WSConnectionPool:
if not conn.is_active: if not conn.is_active:
logger.warning(f"连接 {conn.conn_id} 已失效,不返回连接池") logger.warning(f"连接 {conn.conn_id} 已失效,不返回连接池")
await conn.close()
async with self._lock:
self._current_size -= 1
return return
try: try:
# 尝试放回池中 if self.pool.full():
self.pool.put_nowait(conn) # 连接池已满,关闭该连接
logger.debug(f"连接 {conn.conn_id} 已返回连接池") await conn.close()
except asyncio.QueueFull: logger.info(f"连接池已满,关闭连接 {conn.conn_id}")
# 理论上不应该发生,除非 _current_size 逻辑有误 else:
logger.warning(f"连接池已满,关闭多余连接 {conn.conn_id}") await self.pool.put(conn)
await conn.close() logger.debug(f"连接 {conn.conn_id} 已返回连接池")
async with self._lock:
self._current_size -= 1
except Exception as e: except Exception as e:
logger.error(f"释放连接失败: {e}") logger.error(f"释放连接失败: {e}")
await conn.close() await conn.close()
async with self._lock:
self._current_size -= 1
async def _cleanup_idle_connections(self): async def _cleanup_idle_connections(self):
""" """
@@ -270,33 +202,23 @@ class WSConnectionPool:
await asyncio.sleep(60) # 每分钟检查一次 await asyncio.sleep(60) # 每分钟检查一次
try: try:
# 我们不替换队列,而是取出检查再放回 # 检查连接池中的连接
# 这样比较安全,但可能会暂时清空池子 new_pool = asyncio.Queue(maxsize=self.pool_size)
# 更好的做法是只检查队头的连接 current_time = asyncio.get_event_loop().time()
# 获取当前队列大小 while not self.pool.empty():
qsize = self.pool.qsize() conn = await self.pool.get()
for _ in range(qsize):
try:
conn = self.pool.get_nowait()
except asyncio.QueueEmpty:
break
current_time = asyncio.get_event_loop().time()
if current_time - conn.last_used > self.max_idle_time: if current_time - conn.last_used > self.max_idle_time:
logger.info(f"清理空闲连接 {conn.conn_id}") # 连接空闲时间过长,关闭
await conn.close() await conn.close()
async with self._lock: logger.info(f"清理空闲连接 {conn.conn_id}")
self._current_size -= 1
else: else:
# 还没过期,放回去 # 放回新队列
try: await new_pool.put(conn)
self.pool.put_nowait(conn)
except asyncio.QueueFull: # 替换原连接池
# 竞争条件下可能满了 self.pool = new_pool
await conn.close()
async with self._lock:
self._current_size -= 1
except Exception as e: except Exception as e:
logger.error(f"清理空闲连接失败: {e}") logger.error(f"清理空闲连接失败: {e}")
@@ -319,10 +241,7 @@ class WSConnectionPool:
# 关闭所有连接 # 关闭所有连接
while not self.pool.empty(): while not self.pool.empty():
try: conn = await self.pool.get()
conn = self.pool.get_nowait() await conn.close()
await conn.close()
except asyncio.QueueEmpty:
break
logger.info("WebSocket 连接池已关闭") logger.info("WebSocket 连接池已关闭")

View File

@@ -6,7 +6,6 @@ Bot 状态查询插件
import os import os
import psutil import psutil
import time import time
import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
from core.bot import Bot from core.bot import Bot
@@ -33,23 +32,15 @@ def _get_system_info():
""" """
同步函数:使用 psutil 获取系统信息,避免阻塞事件循环。 同步函数:使用 psutil 获取系统信息,避免阻塞事件循环。
""" """
try: # interval=1 会阻塞1秒必须在线程池中运行
# interval=1 会阻塞1秒必须在线程池中运行 cpu_percent = psutil.cpu_percent(interval=1)
cpu_percent = psutil.cpu_percent(interval=1) mem_info = psutil.virtual_memory()
mem_info = psutil.virtual_memory() bot_mem_mb = PROCESS.memory_info().rss / (1024 * 1024)
bot_mem_mb = PROCESS.memory_info().rss / (1024 * 1024) return {
return { "cpu_percent": f"{cpu_percent:.1f}",
"cpu_percent": f"{cpu_percent:.1f}", "mem_percent": f"{mem_info.percent:.1f}",
"mem_percent": f"{mem_info.percent:.1f}", "bot_mem_mb": f"{bot_mem_mb:.2f}",
"bot_mem_mb": f"{bot_mem_mb:.2f}", }
}
except Exception as e:
logger.error(f"获取系统信息失败: {e}")
return {
"cpu_percent": "N/A",
"mem_percent": "N/A",
"bot_mem_mb": "N/A",
}
@matcher.command("status", "状态") @matcher.command("status", "状态")
async def handle_status(bot: Bot, event: MessageEvent, args: list[str]): async def handle_status(bot: Bot, event: MessageEvent, args: list[str]):
@@ -102,51 +93,26 @@ async def handle_status(bot: Bot, event: MessageEvent, args: list[str]):
} }
# 3. 获取统计数据 # 3. 获取统计数据
try: msgs_recv = await redis_manager.get("neobot:stats:messages_received") or 0
msgs_recv = await redis_manager.get("neobot:stats:messages_received") or 0 msgs_sent = await redis_manager.get("neobot:stats:messages_sent") or 0
msgs_sent = await redis_manager.get("neobot:stats:messages_sent") or 0 command_stats_raw = await redis_manager.redis.hgetall("neobot:command_stats")
command_stats_raw = await redis_manager.redis.hgetall("neobot:command_stats")
total_commands = sum(int(v) for v in command_stats_raw.values())
total_commands = sum(int(v) for v in command_stats_raw.values())
stats_data = {
stats_data = { "messages_received": int(msgs_recv),
"messages_received": int(msgs_recv), "messages_sent": int(msgs_sent),
"messages_sent": int(msgs_sent), "total_commands": total_commands,
"total_commands": total_commands, }
}
command_stats_data = sorted( command_stats_data = sorted(
[{"name": k, "count": int(v)} for k, v in command_stats_raw.items()], [{"name": k, "count": int(v)} for k, v in command_stats_raw.items()],
key=lambda x: x["count"], key=lambda x: x["count"],
reverse=True reverse=True
) )
except Exception as e:
logger.error(f"获取Redis统计数据失败: {e}")
stats_data = {
"messages_received": 0,
"messages_sent": 0,
"total_commands": 0,
}
command_stats_data = []
# 4. 异步获取系统信息 # 4. 异步获取系统信息
# 设置超时,防止 psutil 阻塞过久 system_data = await run_in_thread_pool(_get_system_info)
try:
system_data = await asyncio.wait_for(run_in_thread_pool(_get_system_info), timeout=5.0)
except asyncio.TimeoutError:
logger.error("获取系统信息超时")
system_data = {
"cpu_percent": "Timeout",
"mem_percent": "Timeout",
"bot_mem_mb": "Timeout",
}
except Exception as e:
logger.error(f"获取系统信息异常: {e}")
system_data = {
"cpu_percent": "Error",
"mem_percent": "Error",
"bot_mem_mb": "Error",
}
# 5. 准备模板所需的所有数据 # 5. 准备模板所需的所有数据
template_data = { template_data = {
@@ -159,22 +125,18 @@ async def handle_status(bot: Bot, event: MessageEvent, args: list[str]):
} }
# 6. 渲染图片 # 6. 渲染图片
try: base64_str = await image_manager.render_template_to_base64(
base64_str = await image_manager.render_template_to_base64( template_name="status.html",
template_name="status.html", data=template_data,
data=template_data, output_name="status.png",
output_name="status.png", image_type="png"
image_type="png" )
)
if base64_str: if base64_str:
await event.reply(MessageSegment.image(base64_str)) await event.reply(MessageSegment.image(base64_str))
else: else:
# 如果渲染失败image_manager 内部会记录错误,这里给用户一个通用提示 # 如果渲染失败image_manager 内部会记录错误,这里给用户一个通用提示
await event.reply("状态图片生成失败,可能是渲染服务出现问题,请联系管理员。") await event.reply("状态图片生成失败,可能是渲染服务出现问题,请联系管理员。")
except Exception as e:
logger.error(f"渲染图片失败: {e}")
await event.reply("状态图片渲染过程中发生错误。")
except Exception as e: except Exception as e:
logger.exception(f"生成状态图时发生意外错误, 用户: {event.user_id}") logger.exception(f"生成状态图时发生意外错误, 用户: {event.user_id}")