Revert "refactor(WS): 使用连接池上下文管理器简化连接管理"
This reverts commit c851b49db9.
This commit is contained in:
19
core/WS.py
19
core/WS.py
@@ -127,14 +127,17 @@ class WS:
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# 从连接池获取一个连接
|
# 从连接池获取一个连接
|
||||||
# 使用 connection 上下文管理器确保释放
|
conn = await self.pool.get_connection()
|
||||||
async with self.pool.connection() as conn:
|
|
||||||
try:
|
try:
|
||||||
# 监听连接上的消息
|
# 监听连接上的消息
|
||||||
async for message in conn.conn:
|
async for message in conn.conn:
|
||||||
await self._handle_message(message, conn)
|
await self._handle_message(message, conn)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"连接 {conn.conn_id} 监听异常: {e}")
|
self.logger.error(f"连接 {conn.conn_id} 监听异常: {e}")
|
||||||
|
finally:
|
||||||
|
# 释放连接回连接池
|
||||||
|
await self.pool.release_connection(conn)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"连接池监听循环异常: {e}")
|
self.logger.error(f"连接池监听循环异常: {e}")
|
||||||
await asyncio.sleep(self.reconnect_interval)
|
await asyncio.sleep(self.reconnect_interval)
|
||||||
@@ -321,8 +324,8 @@ class WS:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 从连接池获取一个连接
|
# 从连接池获取一个连接
|
||||||
|
conn = await self.pool.get_connection()
|
||||||
try:
|
try:
|
||||||
async with self.pool.connection() as conn:
|
|
||||||
echo_id = str(uuid.uuid4())
|
echo_id = str(uuid.uuid4())
|
||||||
payload = {"action": action, "params": params or {}, "echo": echo_id}
|
payload = {"action": action, "params": params or {}, "echo": echo_id}
|
||||||
|
|
||||||
@@ -333,15 +336,9 @@ class WS:
|
|||||||
async def wait_for_response():
|
async def wait_for_response():
|
||||||
async for message in conn.conn:
|
async for message in conn.conn:
|
||||||
data = orjson.loads(message)
|
data = orjson.loads(message)
|
||||||
|
|
||||||
# 检查是否是我们要的响应
|
|
||||||
if data.get("echo") == echo_id:
|
if data.get("echo") == echo_id:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
# 如果不是,可能是事件,需要分发
|
|
||||||
if "post_type" in data:
|
|
||||||
asyncio.create_task(self.on_event(data))
|
|
||||||
|
|
||||||
return await asyncio.wait_for(wait_for_response(), timeout=30.0)
|
return await asyncio.wait_for(wait_for_response(), timeout=30.0)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
@@ -363,6 +360,9 @@ class WS:
|
|||||||
message=f"API调用异常: {str(e)}",
|
message=f"API调用异常: {str(e)}",
|
||||||
data={"action": action, "params": params}
|
data={"action": action, "params": params}
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
# 释放连接回连接池
|
||||||
|
await self.pool.release_connection(conn)
|
||||||
else:
|
else:
|
||||||
# 单连接模式
|
# 单连接模式
|
||||||
if not self.ws:
|
if not self.ws:
|
||||||
@@ -409,3 +409,4 @@ class WS:
|
|||||||
message=f"API调用异常: {str(e)}",
|
message=f"API调用异常: {str(e)}",
|
||||||
data={"action": action, "params": params}
|
data={"action": action, "params": params}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
147
core/ws_pool.py
147
core/ws_pool.py
@@ -7,10 +7,9 @@ WebSocket 连接池模块
|
|||||||
import asyncio
|
import asyncio
|
||||||
import websockets
|
import websockets
|
||||||
from websockets.legacy.client import WebSocketClientProtocol
|
from websockets.legacy.client import WebSocketClientProtocol
|
||||||
from typing import Optional, Dict, Any, cast, Union, AsyncGenerator
|
from typing import Optional, Dict, Any, cast, Union
|
||||||
import uuid
|
import uuid
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import contextlib
|
|
||||||
|
|
||||||
from .config_loader import global_config
|
from .config_loader import global_config
|
||||||
from .utils.exceptions import WebSocketError, WebSocketConnectionError
|
from .utils.exceptions import WebSocketError, WebSocketConnectionError
|
||||||
@@ -65,11 +64,9 @@ class WSConnection:
|
|||||||
if not self.is_active:
|
if not self.is_active:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
# 使用 wait_for 包装 ping
|
await asyncio.wait_for(self.conn.ping(), timeout=timeout)
|
||||||
pong_waiter = await self.conn.ping()
|
|
||||||
await asyncio.wait_for(pong_waiter, timeout=timeout)
|
|
||||||
return True
|
return True
|
||||||
except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed, Exception):
|
except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
|
||||||
self.is_active = False
|
self.is_active = False
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -79,10 +76,7 @@ class WSConnection:
|
|||||||
"""
|
"""
|
||||||
if self.is_active:
|
if self.is_active:
|
||||||
self.is_active = False
|
self.is_active = False
|
||||||
try:
|
|
||||||
await self.conn.close()
|
await self.conn.close()
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class WSConnectionPool:
|
class WSConnectionPool:
|
||||||
@@ -103,8 +97,6 @@ class WSConnectionPool:
|
|||||||
self.pool: asyncio.Queue[WSConnection] = asyncio.Queue(maxsize=pool_size)
|
self.pool: asyncio.Queue[WSConnection] = asyncio.Queue(maxsize=pool_size)
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self._cleanup_task: Optional[asyncio.Task] = None
|
self._cleanup_task: Optional[asyncio.Task] = None
|
||||||
self._current_size = 0 # 当前管理的连接数(包括池中和借出的)
|
|
||||||
self._lock = asyncio.Lock() # 用于保护 _current_size 的修改
|
|
||||||
|
|
||||||
# 从全局配置读取参数
|
# 从全局配置读取参数
|
||||||
self.url = global_config.napcat_ws.uri
|
self.url = global_config.napcat_ws.uri
|
||||||
@@ -123,17 +115,14 @@ class WSConnectionPool:
|
|||||||
# 启动连接清理任务
|
# 启动连接清理任务
|
||||||
self._cleanup_task = asyncio.create_task(self._cleanup_idle_connections())
|
self._cleanup_task = asyncio.create_task(self._cleanup_idle_connections())
|
||||||
|
|
||||||
# 预热连接池
|
# 创建初始连接
|
||||||
for _ in range(self.pool_size):
|
for _ in range(self.pool_size):
|
||||||
try:
|
try:
|
||||||
conn = await self._create_connection()
|
conn = await self._create_connection()
|
||||||
await self.pool.put(conn)
|
await self.pool.put(conn)
|
||||||
async with self._lock:
|
|
||||||
self._current_size += 1
|
|
||||||
logger.info(f"WebSocket 连接 {conn.conn_id} 已创建并加入连接池")
|
logger.info(f"WebSocket 连接 {conn.conn_id} 已创建并加入连接池")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建初始连接失败: {e}")
|
logger.error(f"创建初始连接失败: {e}")
|
||||||
# 初始连接失败不抛出异常,允许后续动态创建
|
|
||||||
|
|
||||||
async def _create_connection(self) -> WSConnection:
|
async def _create_connection(self) -> WSConnection:
|
||||||
"""
|
"""
|
||||||
@@ -154,17 +143,6 @@ class WSConnectionPool:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise WebSocketConnectionError(f"创建 WebSocket 连接失败: {e}")
|
raise WebSocketConnectionError(f"创建 WebSocket 连接失败: {e}")
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
|
||||||
async def connection(self) -> AsyncGenerator[WSConnection, None]:
|
|
||||||
"""
|
|
||||||
获取连接的上下文管理器
|
|
||||||
"""
|
|
||||||
conn = await self.get_connection()
|
|
||||||
try:
|
|
||||||
yield conn
|
|
||||||
finally:
|
|
||||||
await self.release_connection(conn)
|
|
||||||
|
|
||||||
async def get_connection(self) -> WSConnection:
|
async def get_connection(self) -> WSConnection:
|
||||||
"""
|
"""
|
||||||
从连接池获取一个健康的连接,包含健康检查。
|
从连接池获取一个健康的连接,包含健康检查。
|
||||||
@@ -172,64 +150,25 @@ class WSConnectionPool:
|
|||||||
if self._closed:
|
if self._closed:
|
||||||
raise WebSocketError("连接池已关闭")
|
raise WebSocketError("连接池已关闭")
|
||||||
|
|
||||||
start_time = asyncio.get_event_loop().time()
|
|
||||||
timeout = 10 # 获取连接的总超时时间
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if asyncio.get_event_loop().time() - start_time > timeout:
|
|
||||||
raise WebSocketError("获取连接超时")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 尝试从池中获取
|
# 尝试从连接池获取连接
|
||||||
conn = self.pool.get_nowait()
|
conn = await asyncio.wait_for(self.pool.get(), timeout=5)
|
||||||
|
|
||||||
# 健康检查
|
# 健康检查
|
||||||
if await conn.ping():
|
if await conn.ping():
|
||||||
logger.debug(f"连接 {conn.conn_id} 健康检查通过")
|
logger.debug(f"连接 {conn.conn_id} 健康检查通过")
|
||||||
return conn
|
return conn
|
||||||
else:
|
else:
|
||||||
logger.warning(f"连接 {conn.conn_id} 健康检查失败,丢弃")
|
logger.warning(f"连接 {conn.conn_id} 健康检查失败,丢弃并获取新连接")
|
||||||
await conn.close()
|
await conn.close()
|
||||||
async with self._lock:
|
return await self.get_connection() # 递归获取下一个
|
||||||
self._current_size -= 1
|
|
||||||
# 继续循环,尝试获取下一个或创建新的
|
|
||||||
continue
|
|
||||||
|
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
# 池为空,检查是否可以创建新连接
|
|
||||||
async with self._lock:
|
|
||||||
if self._current_size < self.pool_size:
|
|
||||||
# 有配额,创建新连接
|
|
||||||
self._current_size += 1 # 先占位
|
|
||||||
create_new = True
|
|
||||||
else:
|
|
||||||
create_new = False
|
|
||||||
|
|
||||||
if create_new:
|
|
||||||
try:
|
|
||||||
conn = await self._create_connection()
|
|
||||||
return conn
|
|
||||||
except Exception as e:
|
|
||||||
async with self._lock:
|
|
||||||
self._current_size -= 1 # 回滚占位
|
|
||||||
logger.error(f"创建新连接失败: {e}")
|
|
||||||
await asyncio.sleep(1) # 避免快速失败循环
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
# 没有配额,等待池中有可用连接
|
|
||||||
try:
|
|
||||||
conn = await asyncio.wait_for(self.pool.get(), timeout=1.0)
|
|
||||||
# 获取到了,进行健康检查(在下一次循环中处理,或者这里直接处理)
|
|
||||||
# 为了代码复用,我们把 conn 放回去(或者直接用),这里直接用
|
|
||||||
if await conn.ping():
|
|
||||||
return conn
|
|
||||||
else:
|
|
||||||
await conn.close()
|
|
||||||
async with self._lock:
|
|
||||||
self._current_size -= 1
|
|
||||||
continue
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
# 连接池为空,创建新连接
|
||||||
|
logger.warning("连接池在5秒内无可用连接,创建新连接")
|
||||||
|
return await self._create_connection()
|
||||||
|
except Exception as e:
|
||||||
|
raise WebSocketError(f"获取连接时发生未知错误: {e}")
|
||||||
|
|
||||||
async def release_connection(self, conn: WSConnection):
|
async def release_connection(self, conn: WSConnection):
|
||||||
"""
|
"""
|
||||||
@@ -241,26 +180,19 @@ class WSConnectionPool:
|
|||||||
|
|
||||||
if not conn.is_active:
|
if not conn.is_active:
|
||||||
logger.warning(f"连接 {conn.conn_id} 已失效,不返回连接池")
|
logger.warning(f"连接 {conn.conn_id} 已失效,不返回连接池")
|
||||||
await conn.close()
|
|
||||||
async with self._lock:
|
|
||||||
self._current_size -= 1
|
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试放回池中
|
if self.pool.full():
|
||||||
self.pool.put_nowait(conn)
|
# 连接池已满,关闭该连接
|
||||||
logger.debug(f"连接 {conn.conn_id} 已返回连接池")
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
# 理论上不应该发生,除非 _current_size 逻辑有误
|
|
||||||
logger.warning(f"连接池已满,关闭多余连接 {conn.conn_id}")
|
|
||||||
await conn.close()
|
await conn.close()
|
||||||
async with self._lock:
|
logger.info(f"连接池已满,关闭连接 {conn.conn_id}")
|
||||||
self._current_size -= 1
|
else:
|
||||||
|
await self.pool.put(conn)
|
||||||
|
logger.debug(f"连接 {conn.conn_id} 已返回连接池")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"释放连接失败: {e}")
|
logger.error(f"释放连接失败: {e}")
|
||||||
await conn.close()
|
await conn.close()
|
||||||
async with self._lock:
|
|
||||||
self._current_size -= 1
|
|
||||||
|
|
||||||
async def _cleanup_idle_connections(self):
|
async def _cleanup_idle_connections(self):
|
||||||
"""
|
"""
|
||||||
@@ -270,33 +202,23 @@ class WSConnectionPool:
|
|||||||
await asyncio.sleep(60) # 每分钟检查一次
|
await asyncio.sleep(60) # 每分钟检查一次
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 我们不替换队列,而是取出检查再放回
|
# 检查连接池中的连接
|
||||||
# 这样比较安全,但可能会暂时清空池子
|
new_pool = asyncio.Queue(maxsize=self.pool_size)
|
||||||
# 更好的做法是只检查队头的连接
|
|
||||||
|
|
||||||
# 获取当前队列大小
|
|
||||||
qsize = self.pool.qsize()
|
|
||||||
for _ in range(qsize):
|
|
||||||
try:
|
|
||||||
conn = self.pool.get_nowait()
|
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
break
|
|
||||||
|
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
while not self.pool.empty():
|
||||||
|
conn = await self.pool.get()
|
||||||
|
|
||||||
if current_time - conn.last_used > self.max_idle_time:
|
if current_time - conn.last_used > self.max_idle_time:
|
||||||
|
# 连接空闲时间过长,关闭
|
||||||
|
await conn.close()
|
||||||
logger.info(f"清理空闲连接 {conn.conn_id}")
|
logger.info(f"清理空闲连接 {conn.conn_id}")
|
||||||
await conn.close()
|
|
||||||
async with self._lock:
|
|
||||||
self._current_size -= 1
|
|
||||||
else:
|
else:
|
||||||
# 还没过期,放回去
|
# 放回新队列
|
||||||
try:
|
await new_pool.put(conn)
|
||||||
self.pool.put_nowait(conn)
|
|
||||||
except asyncio.QueueFull:
|
# 替换原连接池
|
||||||
# 竞争条件下可能满了
|
self.pool = new_pool
|
||||||
await conn.close()
|
|
||||||
async with self._lock:
|
|
||||||
self._current_size -= 1
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"清理空闲连接失败: {e}")
|
logger.error(f"清理空闲连接失败: {e}")
|
||||||
|
|
||||||
@@ -319,10 +241,7 @@ class WSConnectionPool:
|
|||||||
|
|
||||||
# 关闭所有连接
|
# 关闭所有连接
|
||||||
while not self.pool.empty():
|
while not self.pool.empty():
|
||||||
try:
|
conn = await self.pool.get()
|
||||||
conn = self.pool.get_nowait()
|
|
||||||
await conn.close()
|
await conn.close()
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.info("WebSocket 连接池已关闭")
|
logger.info("WebSocket 连接池已关闭")
|
||||||
@@ -6,7 +6,6 @@ Bot 状态查询插件
|
|||||||
import os
|
import os
|
||||||
import psutil
|
import psutil
|
||||||
import time
|
import time
|
||||||
import asyncio
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from core.bot import Bot
|
from core.bot import Bot
|
||||||
@@ -33,7 +32,6 @@ def _get_system_info():
|
|||||||
"""
|
"""
|
||||||
同步函数:使用 psutil 获取系统信息,避免阻塞事件循环。
|
同步函数:使用 psutil 获取系统信息,避免阻塞事件循环。
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
# interval=1 会阻塞1秒,必须在线程池中运行
|
# interval=1 会阻塞1秒,必须在线程池中运行
|
||||||
cpu_percent = psutil.cpu_percent(interval=1)
|
cpu_percent = psutil.cpu_percent(interval=1)
|
||||||
mem_info = psutil.virtual_memory()
|
mem_info = psutil.virtual_memory()
|
||||||
@@ -43,13 +41,6 @@ def _get_system_info():
|
|||||||
"mem_percent": f"{mem_info.percent:.1f}",
|
"mem_percent": f"{mem_info.percent:.1f}",
|
||||||
"bot_mem_mb": f"{bot_mem_mb:.2f}",
|
"bot_mem_mb": f"{bot_mem_mb:.2f}",
|
||||||
}
|
}
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取系统信息失败: {e}")
|
|
||||||
return {
|
|
||||||
"cpu_percent": "N/A",
|
|
||||||
"mem_percent": "N/A",
|
|
||||||
"bot_mem_mb": "N/A",
|
|
||||||
}
|
|
||||||
|
|
||||||
@matcher.command("status", "状态")
|
@matcher.command("status", "状态")
|
||||||
async def handle_status(bot: Bot, event: MessageEvent, args: list[str]):
|
async def handle_status(bot: Bot, event: MessageEvent, args: list[str]):
|
||||||
@@ -102,7 +93,6 @@ async def handle_status(bot: Bot, event: MessageEvent, args: list[str]):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 3. 获取统计数据
|
# 3. 获取统计数据
|
||||||
try:
|
|
||||||
msgs_recv = await redis_manager.get("neobot:stats:messages_received") or 0
|
msgs_recv = await redis_manager.get("neobot:stats:messages_received") or 0
|
||||||
msgs_sent = await redis_manager.get("neobot:stats:messages_sent") or 0
|
msgs_sent = await redis_manager.get("neobot:stats:messages_sent") or 0
|
||||||
command_stats_raw = await redis_manager.redis.hgetall("neobot:command_stats")
|
command_stats_raw = await redis_manager.redis.hgetall("neobot:command_stats")
|
||||||
@@ -120,33 +110,9 @@ async def handle_status(bot: Bot, event: MessageEvent, args: list[str]):
|
|||||||
key=lambda x: x["count"],
|
key=lambda x: x["count"],
|
||||||
reverse=True
|
reverse=True
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取Redis统计数据失败: {e}")
|
|
||||||
stats_data = {
|
|
||||||
"messages_received": 0,
|
|
||||||
"messages_sent": 0,
|
|
||||||
"total_commands": 0,
|
|
||||||
}
|
|
||||||
command_stats_data = []
|
|
||||||
|
|
||||||
# 4. 异步获取系统信息
|
# 4. 异步获取系统信息
|
||||||
# 设置超时,防止 psutil 阻塞过久
|
system_data = await run_in_thread_pool(_get_system_info)
|
||||||
try:
|
|
||||||
system_data = await asyncio.wait_for(run_in_thread_pool(_get_system_info), timeout=5.0)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error("获取系统信息超时")
|
|
||||||
system_data = {
|
|
||||||
"cpu_percent": "Timeout",
|
|
||||||
"mem_percent": "Timeout",
|
|
||||||
"bot_mem_mb": "Timeout",
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取系统信息异常: {e}")
|
|
||||||
system_data = {
|
|
||||||
"cpu_percent": "Error",
|
|
||||||
"mem_percent": "Error",
|
|
||||||
"bot_mem_mb": "Error",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 5. 准备模板所需的所有数据
|
# 5. 准备模板所需的所有数据
|
||||||
template_data = {
|
template_data = {
|
||||||
@@ -159,7 +125,6 @@ async def handle_status(bot: Bot, event: MessageEvent, args: list[str]):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 6. 渲染图片
|
# 6. 渲染图片
|
||||||
try:
|
|
||||||
base64_str = await image_manager.render_template_to_base64(
|
base64_str = await image_manager.render_template_to_base64(
|
||||||
template_name="status.html",
|
template_name="status.html",
|
||||||
data=template_data,
|
data=template_data,
|
||||||
@@ -172,9 +137,6 @@ async def handle_status(bot: Bot, event: MessageEvent, args: list[str]):
|
|||||||
else:
|
else:
|
||||||
# 如果渲染失败,image_manager 内部会记录错误,这里给用户一个通用提示
|
# 如果渲染失败,image_manager 内部会记录错误,这里给用户一个通用提示
|
||||||
await event.reply("状态图片生成失败,可能是渲染服务出现问题,请联系管理员。")
|
await event.reply("状态图片生成失败,可能是渲染服务出现问题,请联系管理员。")
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"渲染图片失败: {e}")
|
|
||||||
await event.reply("状态图片渲染过程中发生错误。")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"生成状态图时发生意外错误, 用户: {event.user_id}")
|
logger.exception(f"生成状态图时发生意外错误, 用户: {event.user_id}")
|
||||||
|
|||||||
Reference in New Issue
Block a user