* 滚木

* feat: 重构核心架构,增强类型安全与插件管理

本次提交对核心模块进行了深度重构,引入 Pydantic 增强配置管理的类型安全性,并全面优化了插件管理系统。

主要变更详情:

1. 核心架构与配置
   - 重构配置加载模块:引入 Pydantic 模型 (`core/config_models.py`),提供严格的配置项类型检查、验证及默认值管理。
   - 统一模块结构:规范化模块导入路径,移除冗余的 `__init__.py` 文件,提升项目结构的清晰度。
   - 性能优化:集成 Redis 缓存支持 (`RedisManager`),有效降低高频 API 调用开销,提升响应速度。

2. 插件系统升级
   - 实现热重载机制:新增插件文件变更监听功能,支持开发过程中自动重载插件,提升开发效率。
   - 优化生命周期管理:改进插件加载与卸载逻辑,支持精确卸载指定插件及其关联的命令、事件处理器和定时任务。

3. 功能特性增强
   - 新增媒体 API:引入 `MediaAPI` 模块,封装图片、语音等富媒体资源的获取与处理接口。
   - 完善权限体系:重构权限管理系统,实现管理员与操作员的分级控制,支持更细粒度的命令权限校验。

4. 代码质量与稳定性
   - 全面类型修复:解决 `mypy` 静态类型检查发现的大量类型错误(包括 `CommandManager`、`EventFactory` 及 `Bot` API 签名不匹配问题)。
   - 增强错误处理:优化消息处理管道的异常捕获机制,完善关键路径的日志记录,提升系统运行稳定性。

* feat: 添加测试用例并优化代码结构

refactor(permission_manager): 调整初始化顺序和逻辑
fix(admin_manager): 修复初始化逻辑和目录创建问题
feat(ws): 优化Bot实例初始化条件
feat(message): 增强MessageSegment功能并添加测试
feat(events): 支持字符串格式的消息解析
test: 添加核心功能测试用例
refactor(plugin_manager): 改进插件路径处理
style: 清理无用导入和代码
chore: 更新依赖项

* refactor(handler): 移除TYPE_CHECKING并直接导入Bot类

简化类型注解,直接导入Bot类而非使用TYPE_CHECKING条件导入,提高代码可读性和维护性

* fix(command_manager): 修复插件卸载时元信息移除不精确的问题

修复 CommandManager 中 unload_plugin 方法移除插件元信息时使用 startswith 导致可能误删其他插件的问题,改为精确匹配
同时调整相关测试用例验证精确匹配行为

* refactor: 清理未使用的导入和更新文档结构

docs: 添加config_models.py到项目结构文档
docs: 调整数据目录位置到core/data下
docs: 更新权限管理器文档描述

* 文档更新

* 更新thpic插件 支持一次返回多张图

* feat: 添加测试覆盖率并修复相关问题

refactor(redis_manager): 移除冗余的ConnectionError处理
refactor(event_handler): 优化Bot类型注解
refactor(factory): 移除未使用的GroupCardNoticeEvent

test: 添加全面的单元测试覆盖
- 添加test_import.py测试模块导入
- 添加test_debug.py测试插件加载调试
- 添加test_plugin_error.py测试错误处理
- 添加test_config_loader.py测试配置加载
- 添加test_redis_manager.py测试Redis管理
- 添加test_bot.py测试Bot功能
- 扩展test_models.py测试消息模型
- 添加test_plugin_manager_coverage.py测试插件管理
- 添加test_executor.py测试代码执行器
- 添加test_ws.py测试WebSocket
- 添加test_api.py测试API接口
- 添加test_core_managers.py测试核心管理模块

fix(plugin_manager): 修复插件加载日志变量问题

覆盖率已到达86%(忽略插件)

* 更新/help指令,现在会发送图片

* feat(help): 重构帮助系统为图片渲染模式

添加浏览器管理器和图片管理器,用于通过 Playwright 渲染帮助菜单为图片
重构命令管理器以支持图片缓存和同步功能
添加 HTML 模板用于帮助菜单渲染

* build: 更新依赖文件 requirements.txt

* build: 更新依赖文件

* feat: 添加性能优化和架构文档,更新依赖和核心模块

refactor(browser_manager): 实现页面池机制以提升性能
refactor(image_manager): 添加模板缓存并集成页面池
refactor(bili_parser): 迁移到异步HTTP请求并实现会话复用
docs: 新增性能优化、架构设计和最佳实践文档
chore: 更新requirements.txt添加新依赖

* docs: 更新文档内容并优化语言风格

重构所有文档内容,使用更简洁直接的语言风格
更新架构、插件开发、部署等核心文档
优化代码示例和图表说明
统一术语和格式规范

* docs: 更新文档内容,简化语言并修正格式

- 简化插件开发指南中的描述,移除冗余内容
- 调整部署文档中的Python版本说明
- 优化最佳实践文档的措辞和格式
- 更新性能优化文档,删除不准确的数据
- 重构核心概念文档,使用更简洁的语言
- 修正README中的项目描述和技术栈说明
- 更新快速上手文档,简化安装步骤
- 调整事件流转文档的描述方式
- 简化架构文档内容
- 更新指令处理文档,添加参数注入示例
- 优化单例管理器文档的表述

* refactor(core): 优化权限管理和事件模型

- 重构 AdminManager 和 PermissionManager 以 Redis 为主要数据源
- 为所有事件模型添加 slots=True 提升性能
- 更新文档说明 Mypyc 编译注意事项
- 清理测试和调试文件
- 移动静态资源到 web_static 目录

* feat: 添加模块编译脚本和导出依赖功能

refactor(events): 移除数据类的slots参数以提升兼容性
build: 更新requirements.txt依赖列表

* docs: 更新性能优化文档并修复命令管理器帮助输出

更新性能优化相关文档,详细说明 Python 3.14 JIT 编译器的使用方法和原理,补充与 Mypyc 的互补策略。同时修复命令管理器中帮助信息的输出方式,移除图片发送仅保留文本输出。

调整部署文档结构,明确两种性能优化方案(AOT 和 JIT)的配置方法和适用场景。完善架构文档中关于 JIT 的原理和启用方式说明。

* feat(help): 重构帮助菜单界面并优化样式

refactor(bili_parser): 修复 API 响应 content-type 问题
fix(command_manager): 添加帮助图片获取的错误处理
docs(deployment): 简化部署文档并移除 JIT 相关内容

* feat: 新增自动同意请求插件和API文档

docs: 更新文档结构和内容

* refactor(scripts): 重构并优化脚本文件结构

feat(scripts): 添加Python环境检查脚本
feat(scripts): 增强依赖导出脚本功能
perf(plugins/bili_parser): 优化B站解析器性能和代码结构
style(plugins/bili_parser): 统一代码风格和常量命名

* fix(scripts): 修复编码问题并添加错误追踪

在compile_machine_code.py中添加utf-8编码设置以避免潜在编码问题
添加traceback.print_exc()以在编译失败时打印完整错误堆栈
更新.gitignore以忽略config.toml文件

* feat(性能分析): 实现性能分析工具模块并添加相关测试

添加性能分析工具模块,包括时间测量、内存分析和性能统计功能
添加测试文件和示例配置,完善性能分析工具的使用场景
在工具模块中实现单例装饰器并导出到__init__.py

* feat(douyin_parser): 新增抖音视频解析插件

refactor(performance): 移除未使用的asyncio导入并优化性能测试
style(compile_modules): 修正字符串引号格式
chore: 删除废弃的编译脚本和临时文件
fix(bili_parser): 增强B站链接解析的健壮性
refactor(singleton): 重构单例模式实现
docs: 更新配置文件和事件模型注释

* feat: 添加抖音视频解析插件并优化代码结构

添加抖音视频解析插件,支持自动解析抖音分享链接并提取视频信息。优化现有代码结构,包括:
- 重构单例模式实现
- 移除未使用的导入和文件
- 修复性能测试脚本中的异步调用
- 优化消息事件模型中的权限常量定义
- 改进编译脚本的错误处理
- 增强B站解析插件的稳定性

同时清理了多个废弃脚本和临时文件,提升代码可维护性。

* 1

* Delete core/data/temp/help_menu.png

* fix(权限管理): 增强权限检查的类型安全并修复权限引用

修复权限检查中可能传入非Permission类型导致的错误,将echo插件的权限引用从MessageEvent.ADMIN迁移到Permission.ADMIN

* redis取消tls

* feat(github_parser): 添加GitHub仓库信息查询功能

- 新增github_parser插件,支持通过命令或自动解析链接查询GitHub仓库信息
- 添加github_repo.html模板用于渲染仓库信息图片
- 优化图片管理器支持高质量截图和CSS缩放
- 重构消息事件类权限常量定义方式
- 更新帮助页面样式为三列布局并优化响应式设计

* feat(web_parser): 新增通用web链接解析插件框架

refactor: 重构B站、抖音、GitHub解析器为模块化结构

fix(executor): 增强docker容器错误处理和回调稳定性

style(templates): 优化帮助页面和代码执行结果的样式

perf(web_parser): 添加API缓存和消息去重机制

docs: 更新插件元信息和注释

chore: 移除旧的独立解析器插件文件

* refactor(managers): 重构单例管理器实现并优化代码结构

feat(ws_pool): 新增 WebSocket 连接池实现

perf(json): 使用 orjson 替代标准 json 库提升性能

style: 清理未使用的导入和冗余代码

docs: 更新架构文档和开发规范

test: 添加 WebSocket 连接池测试用例

fix(plugins): 修复自动审批插件 API 调用参数格式

---------

Co-authored-by: baby20162016 <2185823427@qq.com>
Co-authored-by: web vscode <youremail@example.com>
This commit is contained in:
镀铬酸钾
2026-01-22 16:25:13 +08:00
committed by GitHub
parent 8a6af1ea2a
commit 12d1eb3438
42 changed files with 1285 additions and 261 deletions

View File

@@ -12,7 +12,7 @@ WebSocket 连接。它是整个机器人框架的底层通信基础。
- 提供 `call_api` 方法,用于异步发送 API 请求并等待响应。 - 提供 `call_api` 方法,用于异步发送 API 请求并等待响应。
""" """
import asyncio import asyncio
import json import orjson
from typing import Any, Dict, Optional, cast from typing import Any, Dict, Optional, cast
import uuid import uuid
@@ -25,11 +25,12 @@ from .bot import Bot
from .config_loader import global_config from .config_loader import global_config
from .managers.command_manager import matcher from .managers.command_manager import matcher
from .utils.executor import CodeExecutor from .utils.executor import CodeExecutor
from .utils.logger import logger, ModuleLogger from .utils.logger import ModuleLogger
from .utils.exceptions import ( from .utils.exceptions import (
WebSocketError, WebSocketConnectionError, WebSocketAuthenticationError WebSocketError, WebSocketConnectionError
) )
from .utils.error_codes import ErrorCode, create_error_response from .utils.error_codes import ErrorCode, create_error_response
from .ws_pool import WSConnectionPool
class WS: class WS:
@@ -37,11 +38,14 @@ class WS:
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。 WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
""" """
def __init__(self, code_executor: Optional[CodeExecutor] = None) -> None: def __init__(self, code_executor: Optional[CodeExecutor] = None, use_pool: bool = True) -> None:
""" """
初始化 WebSocket 客户端。 初始化 WebSocket 客户端。
从全局配置中读取 WebSocket URI、访问令牌Token和重连间隔。 从全局配置中读取 WebSocket URI、访问令牌Token和重连间隔。
:param code_executor: 代码执行器实例
:param use_pool: 是否使用连接池
""" """
# 读取参数 # 读取参数
cfg = global_config.napcat_ws cfg = global_config.napcat_ws
@@ -55,6 +59,8 @@ class WS:
self.bot: Bot | None = None self.bot: Bot | None = None
self.self_id: int | None = None self.self_id: int | None = None
self.code_executor = code_executor self.code_executor = code_executor
self.use_pool = use_pool
self.pool: Optional[WSConnectionPool] = None
# 创建模块专用日志记录器 # 创建模块专用日志记录器
self.logger = ModuleLogger("WebSocket") self.logger = ModuleLogger("WebSocket")
@@ -68,46 +74,112 @@ class WS:
""" """
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
if self.use_pool:
# 使用连接池模式
self.pool = WSConnectionPool(pool_size=3)
await self.pool.initialize()
self.logger.success("WebSocket 连接池初始化完成")
# 启动连接池监听循环
await self._pool_listen_loop()
else:
# 单连接模式
while True:
try:
self.logger.info(f"正在尝试连接至 NapCat: {self.url}")
async with websockets.connect(
self.url, additional_headers=headers
) as websocket_raw:
websocket = cast(WebSocketClientProtocol, websocket_raw)
self.ws = websocket
self.logger.success("连接成功!")
await self._listen_loop(websocket)
except (
websockets.exceptions.ConnectionClosed,
ConnectionRefusedError,
) as e:
conn_error = WebSocketConnectionError(
message=f"WebSocket连接失败: {str(e)}",
code=ErrorCode.WS_CONNECTION_FAILED,
original_error=e
)
self.logger.error(f"连接失败: {conn_error.message}")
self.logger.log_custom_exception(conn_error)
except Exception as e:
error = WebSocketError(
message=f"WebSocket运行异常: {str(e)}",
code=ErrorCode.WS_MESSAGE_ERROR,
original_error=e
)
self.logger.exception(f"运行异常: {error.message}")
self.logger.log_custom_exception(error)
self.logger.info(f"{self.reconnect_interval}秒后尝试重连...")
await asyncio.sleep(self.reconnect_interval)
async def _pool_listen_loop(self):
"""
连接池模式下的监听循环
"""
while True: while True:
try: try:
self.logger.info(f"正在尝试连接至 NapCat: {self.url}") # 从连接池获取一个连接
async with websockets.connect( conn = await self.pool.get_connection()
self.url, additional_headers=headers
) as websocket_raw: try:
websocket = cast(WebSocketClientProtocol, websocket_raw) # 监听连接上的消息
self.ws = websocket async for message in conn.conn:
self.logger.success("连接成功!") await self._handle_message(message, conn)
await self._listen_loop(websocket) except Exception as e:
self.logger.error(f"连接 {conn.conn_id} 监听异常: {e}")
except websockets.exceptions.AuthenticationError as e: finally:
error = WebSocketAuthenticationError( # 释放连接回连接池
message=f"WebSocket认证失败: {str(e)}", await self.pool.release_connection(conn)
code=ErrorCode.WS_AUTH_FAILED,
original_error=e
)
self.logger.error(f"连接失败: {error.message}")
self.logger.log_custom_exception(error)
except (
websockets.exceptions.ConnectionClosed,
ConnectionRefusedError,
) as e:
error = WebSocketConnectionError(
message=f"连接断开或服务器拒绝访问: {str(e)}",
code=ErrorCode.WS_CONNECTION_FAILED,
original_error=e
)
self.logger.warning(f"连接失败: {error.message}")
except Exception as e: except Exception as e:
error = WebSocketError( self.logger.error(f"连接池监听循环异常: {e}")
message=f"WebSocket运行异常: {str(e)}", await asyncio.sleep(self.reconnect_interval)
code=ErrorCode.WS_MESSAGE_ERROR,
original_error=e async def _handle_message(self, message: str, conn):
) """
self.logger.exception(f"运行异常: {error.message}") 处理从连接池获取的消息
self.logger.log_custom_exception(error) """
try:
data = orjson.loads(message)
self.logger.info(f"{self.reconnect_interval}秒后尝试重连...") # 1. 处理 API 响应
await asyncio.sleep(self.reconnect_interval) # 如果消息中包含 echo 字段,说明是 API 调用的响应
echo_id = data.get("echo")
if echo_id and echo_id in self._pending_requests:
future = self._pending_requests.pop(echo_id)
if not future.done():
future.set_result(data)
return
# 2. 处理上报事件
# 如果消息中包含 post_type 字段,说明是 OneBot 上报的事件
if "post_type" in data:
# 使用 create_task 异步执行,避免阻塞 WebSocket 接收循环
asyncio.create_task(self.on_event(data))
except orjson.JSONDecodeError as e:
error = WebSocketError(
message=f"JSON解析失败: {str(e)}",
code=ErrorCode.WS_MESSAGE_ERROR,
original_error=e
)
self.logger.error(f"解析消息异常: {error.message}")
# 如果message是bytes类型需要先解码
decoded_message = message.decode('utf-8') if isinstance(message, bytes) else message
self.logger.debug(f"原始消息: {decoded_message}")
except Exception as e:
error = WebSocketError(
message=f"处理消息异常: {str(e)}",
code=ErrorCode.WS_MESSAGE_ERROR,
original_error=e
)
self.logger.exception(f"解析消息异常: {error.message}")
self.logger.log_custom_exception(error)
async def _listen_loop(self, websocket_connection: WebSocketClientProtocol) -> None: async def _listen_loop(self, websocket_connection: WebSocketClientProtocol) -> None:
""" """
@@ -121,7 +193,7 @@ class WS:
""" """
async for message in websocket_connection: async for message in websocket_connection:
try: try:
data = json.loads(message) data = orjson.loads(message)
# 1. 处理 API 响应 # 1. 处理 API 响应
# 如果消息中包含 echo 字段,说明是 API 调用的响应 # 如果消息中包含 echo 字段,说明是 API 调用的响应
@@ -138,14 +210,16 @@ class WS:
# 使用 create_task 异步执行,避免阻塞 WebSocket 接收循环 # 使用 create_task 异步执行,避免阻塞 WebSocket 接收循环
asyncio.create_task(self.on_event(data)) asyncio.create_task(self.on_event(data))
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
error = WebSocketError( error = WebSocketError(
message=f"JSON解析失败: {str(e)}", message=f"JSON解析失败: {str(e)}",
code=ErrorCode.WS_MESSAGE_ERROR, code=ErrorCode.WS_MESSAGE_ERROR,
original_error=e original_error=e
) )
self.logger.error(f"解析消息异常: {error.message}") self.logger.error(f"解析消息异常: {error.message}")
self.logger.debug(f"原始消息: {message}") # 如果message是bytes类型需要先解码
decoded_message = message.decode('utf-8') if isinstance(message, bytes) else message
self.logger.debug(f"原始消息: {decoded_message}")
except Exception as e: except Exception as e:
error = WebSocketError( error = WebSocketError(
message=f"处理消息异常: {str(e)}", message=f"处理消息异常: {str(e)}",
@@ -236,48 +310,93 @@ class WS:
dict: OneBot API 的响应数据。如果超时或连接断开,则返回一个 dict: OneBot API 的响应数据。如果超时或连接断开,则返回一个
表示失败的字典。 表示失败的字典。
""" """
if not self.ws: if self.use_pool:
self.logger.error("调用 API 失败: WebSocket 未初始化") # 使用连接池模式
return create_error_response( if not self.pool:
code=ErrorCode.WS_DISCONNECTED, self.logger.error("调用 API 失败: WebSocket 连接池未初始化")
message="WebSocket未初始化", return create_error_response(
data={"action": action, "params": params} code=ErrorCode.WS_DISCONNECTED,
) message="WebSocket连接池未初始化",
data={"action": action, "params": params}
)
# 从连接池获取一个连接
conn = await self.pool.get_connection()
try:
echo_id = str(uuid.uuid4())
payload = {"action": action, "params": params or {}, "echo": echo_id}
from websockets.protocol import State loop = asyncio.get_running_loop()
future = loop.create_future()
self._pending_requests[echo_id] = future
if getattr(self.ws, "state", None) is not State.OPEN: try:
self.logger.error("调用 API 失败: WebSocket 连接未打开") await conn.send(orjson.dumps(payload))
return create_error_response( result = await asyncio.wait_for(future, timeout=30.0)
code=ErrorCode.WS_DISCONNECTED, return result
message="WebSocket连接未打开", except asyncio.TimeoutError:
data={"action": action, "params": params} self._pending_requests.pop(echo_id, None)
) self.logger.warning(f"API 调用超时: action={action}, params={params}")
return create_error_response(
code=ErrorCode.TIMEOUT_ERROR,
message="API调用超时",
data={"action": action, "params": params}
)
except Exception as e:
self._pending_requests.pop(echo_id, None)
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
return create_error_response(
code=ErrorCode.WS_MESSAGE_ERROR,
message=f"API调用异常: {str(e)}",
data={"action": action, "params": params}
)
finally:
# 释放连接回连接池
await self.pool.release_connection(conn)
else:
# 单连接模式
if not self.ws:
self.logger.error("调用 API 失败: WebSocket 未初始化")
return create_error_response(
code=ErrorCode.WS_DISCONNECTED,
message="WebSocket未初始化",
data={"action": action, "params": params}
)
echo_id = str(uuid.uuid4()) from websockets.protocol import State
payload = {"action": action, "params": params or {}, "echo": echo_id}
loop = asyncio.get_running_loop() if getattr(self.ws, "state", None) is not State.OPEN:
future = loop.create_future() self.logger.error("调用 API 失败: WebSocket 连接未打开")
self._pending_requests[echo_id] = future return create_error_response(
code=ErrorCode.WS_DISCONNECTED,
message="WebSocket连接未打开",
data={"action": action, "params": params}
)
try: echo_id = str(uuid.uuid4())
await self.ws.send(json.dumps(payload)) payload = {"action": action, "params": params or {}, "echo": echo_id}
return await asyncio.wait_for(future, timeout=30.0)
except asyncio.TimeoutError: loop = asyncio.get_running_loop()
self._pending_requests.pop(echo_id, None) future = loop.create_future()
self.logger.warning(f"API 调用超时: action={action}, params={params}") self._pending_requests[echo_id] = future
return create_error_response(
code=ErrorCode.TIMEOUT_ERROR, try:
message="API调用超时", await self.ws.send(orjson.dumps(payload))
data={"action": action, "params": params} return await asyncio.wait_for(future, timeout=30.0)
) except asyncio.TimeoutError:
except Exception as e: self._pending_requests.pop(echo_id, None)
self._pending_requests.pop(echo_id, None) self.logger.warning(f"API 调用超时: action={action}, params={params}")
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}") return create_error_response(
return create_error_response( code=ErrorCode.TIMEOUT_ERROR,
code=ErrorCode.WS_MESSAGE_ERROR, message="API调用超时",
message=f"API调用异常: {str(e)}", data={"action": action, "params": params}
data={"action": action, "params": params} )
) except Exception as e:
self._pending_requests.pop(echo_id, None)
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
return create_error_response(
code=ErrorCode.WS_MESSAGE_ERROR,
message=f"API调用异常: {str(e)}",
data={"action": action, "params": params}
)

View File

@@ -4,7 +4,7 @@
该模块定义了 `AccountAPI` Mixin 类,提供了所有与机器人自身账号信息、 该模块定义了 `AccountAPI` Mixin 类,提供了所有与机器人自身账号信息、
状态设置等相关的 OneBot v11 API 封装。 状态设置等相关的 OneBot v11 API 封装。
""" """
import json import orjson
from typing import Dict, Any from typing import Dict, Any
from .base import BaseAPI from .base import BaseAPI
from models.objects import LoginInfo, VersionInfo, Status from models.objects import LoginInfo, VersionInfo, Status
@@ -30,10 +30,10 @@ class AccountAPI(BaseAPI):
if not no_cache: if not no_cache:
cached_data = await redis_manager.get(cache_key) cached_data = await redis_manager.get(cache_key)
if cached_data: if cached_data:
return LoginInfo(**json.loads(cached_data)) return LoginInfo(**orjson.loads(cached_data))
res = await self.call_api("get_login_info") res = await self.call_api("get_login_info")
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时 await redis_manager.set(cache_key, orjson.dumps(res), ex=3600) # 缓存 1 小时
return LoginInfo(**res) return LoginInfo(**res)
async def get_version_info(self) -> VersionInfo: async def get_version_info(self) -> VersionInfo:
@@ -43,7 +43,7 @@ class AccountAPI(BaseAPI):
Returns: Returns:
VersionInfo: 包含 OneBot 实现版本信息的 `VersionInfo` 数据对象。 VersionInfo: 包含 OneBot 实现版本信息的 `VersionInfo` 数据对象。
""" """
res = await self.call_api("get_version_info") res = await self.call_api("get_friend_list")
return VersionInfo(**res) return VersionInfo(**res)
async def get_status(self) -> Status: async def get_status(self) -> Status:
@@ -189,10 +189,10 @@ class AccountAPI(BaseAPI):
if not no_cache: if not no_cache:
cached_data = await redis_manager.get(cache_key) cached_data = await redis_manager.get(cache_key)
if cached_data: if cached_data:
return json.loads(cached_data) return orjson.loads(cached_data)
res = await self.call_api("get_friend_list") res = await self.call_api("get_friend_list")
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时 await redis_manager.set(cache_key, orjson.dumps(res), ex=3600) # 缓存 1 小时
return res return res
async def get_group_list(self, no_cache: bool = False) -> list: async def get_group_list(self, no_cache: bool = False) -> list:
@@ -209,9 +209,9 @@ class AccountAPI(BaseAPI):
if not no_cache: if not no_cache:
cached_data = await redis_manager.get(cache_key) cached_data = await redis_manager.get(cache_key)
if cached_data: if cached_data:
return json.loads(cached_data) return orjson.loads(cached_data)
res = await self.call_api("get_group_list") res = await self.call_api("get_group_list")
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时 await redis_manager.set(cache_key, orjson.dumps(res), ex=3600) # 缓存 1 小时
return res return res

View File

@@ -4,7 +4,7 @@
该模块定义了 `FriendAPI` Mixin 类,提供了所有与好友、陌生人信息 该模块定义了 `FriendAPI` Mixin 类,提供了所有与好友、陌生人信息
等相关的 OneBot v11 API 封装。 等相关的 OneBot v11 API 封装。
""" """
import json import orjson
from typing import List, Dict, Any from typing import List, Dict, Any
from .base import BaseAPI from .base import BaseAPI
from models.objects import FriendInfo, StrangerInfo from models.objects import FriendInfo, StrangerInfo
@@ -44,10 +44,10 @@ class FriendAPI(BaseAPI):
if not no_cache: if not no_cache:
cached_data = await redis_manager.redis.get(cache_key) cached_data = await redis_manager.redis.get(cache_key)
if cached_data: if cached_data:
return StrangerInfo(**json.loads(cached_data)) return StrangerInfo(**orjson.loads(cached_data))
res = await self.call_api("get_stranger_info", {"user_id": user_id, "no_cache": no_cache}) res = await self.call_api("get_stranger_info", {"user_id": user_id, "no_cache": no_cache})
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时 await redis_manager.redis.set(cache_key, orjson.dumps(res), ex=3600) # 缓存 1 小时
return StrangerInfo(**res) return StrangerInfo(**res)
async def get_friend_list(self, no_cache: bool = False) -> List[FriendInfo]: async def get_friend_list(self, no_cache: bool = False) -> List[FriendInfo]:
@@ -64,10 +64,10 @@ class FriendAPI(BaseAPI):
if not no_cache: if not no_cache:
cached_data = await redis_manager.redis.get(cache_key) cached_data = await redis_manager.redis.get(cache_key)
if cached_data: if cached_data:
return [FriendInfo(**item) for item in json.loads(cached_data)] return [FriendInfo(**item) for item in orjson.loads(cached_data)]
res = await self.call_api("get_friend_list") res = await self.call_api("get_friend_list")
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时 await redis_manager.redis.set(cache_key, orjson.dumps(res), ex=3600) # 缓存 1 小时
return [FriendInfo(**item) for item in res] return [FriendInfo(**item) for item in res]
async def set_friend_add_request(self, flag: str, approve: bool = True, remark: str = "") -> Dict[str, Any]: async def set_friend_add_request(self, flag: str, approve: bool = True, remark: str = "") -> Dict[str, Any]:

View File

@@ -5,7 +5,7 @@
等相关的 OneBot v11 API 封装。 等相关的 OneBot v11 API 封装。
""" """
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
import json import orjson
from ..managers.redis_manager import redis_manager from ..managers.redis_manager import redis_manager
from .base import BaseAPI from .base import BaseAPI
from models.objects import GroupInfo, GroupMemberInfo, GroupHonorInfo from models.objects import GroupInfo, GroupMemberInfo, GroupHonorInfo
@@ -181,10 +181,10 @@ class GroupAPI(BaseAPI):
if not no_cache: if not no_cache:
cached_data = await redis_manager.redis.get(cache_key) cached_data = await redis_manager.redis.get(cache_key)
if cached_data: if cached_data:
return GroupInfo(**json.loads(cached_data)) return GroupInfo(**orjson.loads(cached_data))
res = await self.call_api("get_group_info", {"group_id": group_id}) res = await self.call_api("get_group_info", {"group_id": group_id})
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时 await redis_manager.redis.set(cache_key, orjson.dumps(res), ex=3600) # 缓存 1 小时
return GroupInfo(**res) return GroupInfo(**res)
async def get_group_list(self) -> Any: async def get_group_list(self) -> Any:
@@ -232,10 +232,10 @@ class GroupAPI(BaseAPI):
if not no_cache: if not no_cache:
cached_data = await redis_manager.redis.get(cache_key) cached_data = await redis_manager.redis.get(cache_key)
if cached_data: if cached_data:
return GroupMemberInfo(**json.loads(cached_data)) return GroupMemberInfo(**orjson.loads(cached_data))
res = await self.call_api("get_group_member_info", {"group_id": group_id, "user_id": user_id}) res = await self.call_api("get_group_member_info", {"group_id": group_id, "user_id": user_id})
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时 await redis_manager.redis.set(cache_key, orjson.dumps(res), ex=3600) # 缓存 1 小时
return GroupMemberInfo(**res) return GroupMemberInfo(**res)
async def get_group_member_list(self, group_id: int) -> List[GroupMemberInfo]: async def get_group_member_list(self, group_id: int) -> List[GroupMemberInfo]:

View File

@@ -8,9 +8,8 @@ from pathlib import Path
import tomllib import tomllib
from pydantic import ValidationError from pydantic import ValidationError
from .config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel from .config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel
from .utils.logger import logger, ModuleLogger from .utils.logger import ModuleLogger
from .utils.exceptions import ConfigError, ConfigNotFoundError, ConfigValidationError from .utils.exceptions import ConfigError, ConfigNotFoundError, ConfigValidationError
from .utils.error_codes import ErrorCode, create_error_response
class Config: class Config:

View File

@@ -4,7 +4,7 @@
该模块负责管理机器人的管理员列表。 该模块负责管理机器人的管理员列表。
它现在以 Redis 作为主要数据源,文件仅用作备份。 它现在以 Redis 作为主要数据源,文件仅用作备份。
""" """
import json import orjson
import os import os
from typing import Set from typing import Set
@@ -66,7 +66,7 @@ class AdminManager(Singleton):
try: try:
if os.path.exists(self.data_file): if os.path.exists(self.data_file):
with open(self.data_file, "r", encoding="utf-8") as f: with open(self.data_file, "r", encoding="utf-8") as f:
data = json.load(f) data = orjson.loads(f.read())
admins = data.get("admins", []) admins = data.get("admins", [])
admins_to_migrate = set(int(admin_id) for admin_id in admins) admins_to_migrate = set(int(admin_id) for admin_id in admins)
@@ -76,7 +76,7 @@ class AdminManager(Singleton):
else: else:
logger.info("admin.json 文件为空或不存在,无需迁移。") logger.info("admin.json 文件为空或不存在,无需迁移。")
except (json.JSONDecodeError, ValueError) as e: except ValueError as e:
logger.error(f"解析 admin.json 失败,无法迁移: {e}") logger.error(f"解析 admin.json 失败,无法迁移: {e}")
except Exception as e: except Exception as e:
logger.error(f"迁移管理员数据到 Redis 失败: {e}") logger.error(f"迁移管理员数据到 Redis 失败: {e}")
@@ -89,7 +89,7 @@ class AdminManager(Singleton):
admins = await self.get_all_admins() admins = await self.get_all_admins()
admin_list = [str(admin_id) for admin_id in admins] admin_list = [str(admin_id) for admin_id in admins]
with open(self.data_file, "w", encoding="utf-8") as f: with open(self.data_file, "w", encoding="utf-8") as f:
json.dump({"admins": admin_list}, f, indent=2, ensure_ascii=False) f.write(orjson.dumps({"admins": admin_list}, indent=2, ensure_ascii=False).decode('utf-8'))
logger.debug(f"管理员列表已备份到 {self.data_file}") logger.debug(f"管理员列表已备份到 {self.data_file}")
except Exception as e: except Exception as e:
logger.error(f"备份管理员列表到 admin.json 失败: {e}") logger.error(f"备份管理员列表到 admin.json 失败: {e}")

View File

@@ -7,21 +7,23 @@ import asyncio
from typing import Optional from typing import Optional
from playwright.async_api import async_playwright, Browser, Playwright, Page from playwright.async_api import async_playwright, Browser, Playwright, Page
from ..utils.logger import logger from ..utils.logger import logger
from ..utils.singleton import Singleton
class BrowserManager: class BrowserManager(Singleton):
""" """
浏览器管理器(异步单例) 浏览器管理器(异步单例)
""" """
_instance = None
_playwright: Optional[Playwright] = None _playwright: Optional[Playwright] = None
_browser: Optional[Browser] = None _browser: Optional[Browser] = None
_page_pool: Optional[asyncio.Queue] = None _page_pool: Optional[asyncio.Queue] = None
_pool_size: int = 3 _pool_size: int = 3
def __new__(cls): def __init__(self):
if cls._instance is None: """
cls._instance = super().__new__(cls) 初始化浏览器管理器
return cls._instance """
# 调用父类 __init__ 确保单例初始化
super().__init__()
async def initialize(self): async def initialize(self):
""" """

View File

@@ -7,12 +7,9 @@
""" """
from typing import Any, Callable, Dict, Optional, Tuple from typing import Any, Callable, Dict, Optional, Tuple
import os
import base64
from models.events.message import MessageSegment from models.events.message import MessageSegment
from models.events.message import MessageSegment
from ..config_loader import global_config from ..config_loader import global_config
from ..handlers.event_handler import MessageHandler, NoticeHandler, RequestHandler from ..handlers.event_handler import MessageHandler, NoticeHandler, RequestHandler

View File

@@ -10,19 +10,21 @@ from jinja2 import Template
from .browser_manager import browser_manager from .browser_manager import browser_manager
from ..utils.logger import logger from ..utils.logger import logger
from ..utils.singleton import Singleton
class ImageManager: class ImageManager(Singleton):
""" """
图片生成管理器(单例) 图片生成管理器(单例)
""" """
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self): def __init__(self):
"""
初始化图片生成管理器
"""
# 检查是否已经初始化
if hasattr(self, 'template_dir'):
return
# 模板目录 # 模板目录
self.template_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "templates") self.template_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "templates")
# 临时文件目录 # 临时文件目录

View File

@@ -4,7 +4,7 @@
该模块负责管理用户权限,支持 admin、op、user 三个权限级别。 该模块负责管理用户权限,支持 admin、op、user 三个权限级别。
以 Redis Hash 作为主要数据源,文件仅用作备份和首次数据迁移。 以 Redis Hash 作为主要数据源,文件仅用作备份和首次数据迁移。
""" """
import json import orjson
import os import os
from typing import Dict from typing import Dict
@@ -71,7 +71,7 @@ class PermissionManager(Singleton):
try: try:
if os.path.exists(self.data_file): if os.path.exists(self.data_file):
with open(self.data_file, "r", encoding="utf-8") as f: with open(self.data_file, "r", encoding="utf-8") as f:
data = json.load(f) data = orjson.loads(f.read())
perms_to_migrate = data.get("users", {}) perms_to_migrate = data.get("users", {})
if perms_to_migrate: if perms_to_migrate:
@@ -84,7 +84,7 @@ class PermissionManager(Singleton):
else: else:
logger.info("permissions.json 文件为空或不存在,无需迁移。") logger.info("permissions.json 文件为空或不存在,无需迁移。")
except (json.JSONDecodeError, ValueError) as e: except ValueError as e:
logger.error(f"解析 permissions.json 失败,无法迁移: {e}") logger.error(f"解析 permissions.json 失败,无法迁移: {e}")
except Exception as e: except Exception as e:
logger.error(f"迁移权限数据到 Redis 失败: {e}") logger.error(f"迁移权限数据到 Redis 失败: {e}")
@@ -98,7 +98,7 @@ class PermissionManager(Singleton):
# Redis 返回的是 bytes需要解码 # Redis 返回的是 bytes需要解码
users_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in all_perms.items()} users_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in all_perms.items()}
with open(self.data_file, "w", encoding="utf-8") as f: with open(self.data_file, "w", encoding="utf-8") as f:
json.dump({"users": users_data}, f, indent=2, ensure_ascii=False) f.write(orjson.dumps({"users": users_data}, indent=2, ensure_ascii=False).decode('utf-8'))
logger.debug(f"权限数据已备份到 {self.data_file}") logger.debug(f"权限数据已备份到 {self.data_file}")
except Exception as e: except Exception as e:
logger.error(f"备份权限数据到 permissions.json 失败: {e}") logger.error(f"备份权限数据到 permissions.json 失败: {e}")

View File

@@ -10,28 +10,41 @@ import sys
from typing import Set from typing import Set
from .command_manager import CommandManager from .command_manager import CommandManager
from ..utils.exceptions import SyncHandlerError, PluginError, PluginLoadError, PluginReloadError, PluginNotFoundError from ..utils.exceptions import SyncHandlerError, PluginLoadError, PluginReloadError, PluginNotFoundError
from ..utils.logger import logger, ModuleLogger from ..utils.logger import logger, ModuleLogger
from ..utils.error_codes import ErrorCode, create_error_response from ..utils.singleton import Singleton
# 确保logger在模块级别可见 # 确保logger在模块级别可见
__all__ = ['PluginManager', 'logger'] __all__ = ['PluginManager', 'logger']
class PluginManager: class PluginManager(Singleton):
""" """
插件管理器类 插件管理器类
""" """
def __init__(self, command_manager: "CommandManager") -> None: def __init__(self, command_manager: "CommandManager" | None = None) -> None:
""" """
初始化插件管理器 初始化插件管理器
:param command_manager: CommandManager的实例 :param command_manager: CommandManager的实例
""" """
self.command_manager = command_manager # 检查是否已经初始化
self.loaded_plugins: Set[str] = set() if hasattr(self, '_command_manager'):
# 创建模块专用日志记录器 return
self.logger = ModuleLogger("PluginManager")
# 只有首次初始化时才执行
if command_manager:
self._command_manager = command_manager
self.loaded_plugins: Set[str] = set()
# 创建模块专用日志记录器
self.logger = ModuleLogger("PluginManager")
@property
def command_manager(self):
"""
获取命令管理器实例
"""
return self._command_manager
def load_all_plugins(self) -> None: def load_all_plugins(self) -> None:
""" """
@@ -99,12 +112,12 @@ class PluginManager:
self.logger.warning(f"尝试重载一个未被加载的插件: {full_module_name},将按首次加载处理。") self.logger.warning(f"尝试重载一个未被加载的插件: {full_module_name},将按首次加载处理。")
if full_module_name not in sys.modules: if full_module_name not in sys.modules:
error = PluginNotFoundError( reload_error = PluginNotFoundError(
plugin_name=full_module_name, plugin_name=full_module_name,
message="模块未在sys.modules中找到" message="模块未在sys.modules中找到"
) )
self.logger.error(f"重载失败: {error.message}") self.logger.error(f"重载失败: {reload_error.message}")
self.logger.log_custom_exception(error) self.logger.log_custom_exception(reload_error)
return return
try: try:

View File

@@ -1,18 +1,20 @@
import redis.asyncio as redis import redis.asyncio as redis
from ..config_loader import global_config as config from ..config_loader import global_config as config
from ..utils.logger import logger from ..utils.logger import logger
from ..utils.singleton import Singleton
class RedisManager: class RedisManager(Singleton):
""" """
Redis 连接管理器(异步单例) Redis 连接管理器(异步单例)
""" """
_instance = None
_redis = None _redis = None
def __new__(cls): def __init__(self):
if cls._instance is None: """
cls._instance = super().__new__(cls) 初始化 Redis 管理器
return cls._instance """
# 调用父类 __init__ 确保单例初始化
super().__init__()
async def initialize(self): async def initialize(self):
""" """

View File

@@ -6,7 +6,6 @@
# 导出核心工具 # 导出核心工具
from .logger import logger from .logger import logger
from .exceptions import * from .exceptions import *
from .json_utils import *
from .singleton import singleton from .singleton import singleton
from .executor import run_in_thread_pool, initialize_executor from .executor import run_in_thread_pool, initialize_executor
from .performance import ( from .performance import (

View File

@@ -3,6 +3,7 @@
该模块定义了项目中使用的错误码和统一的错误响应格式,确保所有模块返回一致的错误信息。 该模块定义了项目中使用的错误码和统一的错误响应格式,确保所有模块返回一致的错误信息。
""" """
from typing import Optional
# 错误码定义 # 错误码定义
class ErrorCode: class ErrorCode:
@@ -142,7 +143,7 @@ def get_error_message(code: int) -> str:
return ERROR_MESSAGES.get(code, ERROR_MESSAGES[ErrorCode.UNKNOWN_ERROR]) return ERROR_MESSAGES.get(code, ERROR_MESSAGES[ErrorCode.UNKNOWN_ERROR])
def create_error_response(code: int, message: str = None, data: dict = None, request_id: str = None) -> dict: def create_error_response(code: int, message: Optional[str] = None, data: Optional[dict] = None, request_id: Optional[str] = None) -> dict:
""" """
创建统一格式的错误响应 创建统一格式的错误响应
@@ -172,7 +173,7 @@ def create_error_response(code: int, message: str = None, data: dict = None, req
return response return response
def exception_to_error_response(exception: Exception, code: int = None, request_id: str = None) -> dict: def exception_to_error_response(exception: Exception, code: Optional[int] = None, request_id: Optional[str] = None) -> dict:
""" """
将异常对象转换为统一格式的错误响应 将异常对象转换为统一格式的错误响应

View File

@@ -1,34 +0,0 @@
"""
JSON 工具模块
统一使用高性能的 orjson 库进行 JSON 序列化和反序列化。
如果 orjson 不可用,则回退到标准库 json。
"""
from typing import Any, Union
import json
# 在模块加载时检查 orjson 是否可用
try:
import orjson
_orjson_available = True
except ImportError:
_orjson_available = False
def dumps(obj: Any) -> str:
"""
将对象序列化为 JSON 字符串。
"""
if _orjson_available:
# orjson.dumps 返回 bytes需要 decode
return orjson.dumps(obj).decode("utf-8")
else:
return json.dumps(obj, ensure_ascii=False)
def loads(json_str: Union[str, bytes]) -> Any:
"""
将 JSON 字符串反序列化为对象。
"""
if _orjson_available:
return orjson.loads(json_str)
else:
return json.loads(json_str)

View File

@@ -109,7 +109,7 @@ class PerformanceStats:
performance_stats = PerformanceStats() performance_stats = PerformanceStats()
def timeit(func: Callable = None, *, log_level: int = logging.INFO, collect_stats: bool = True): def timeit(func: Optional[Callable] = None, *, log_level: int = logging.INFO, collect_stats: bool = True):
""" """
函数执行时间分析装饰器(支持同步和异步) 函数执行时间分析装饰器(支持同步和异步)
@@ -261,7 +261,7 @@ class memory_profile:
logger.info(f"[内存分析] 使用内存: {memory_used:.2f} MB") logger.info(f"[内存分析] 使用内存: {memory_used:.2f} MB")
def memory_profile_decorator(func: Callable = None, *, interval: float = 0.1): def memory_profile_decorator(func: Optional[Callable] = None, *, interval: float = 0.1):
""" """
内存分析装饰器(支持同步函数) 内存分析装饰器(支持同步函数)
@@ -296,7 +296,7 @@ def memory_profile_decorator(func: Callable = None, *, interval: float = 0.1):
return decorator(func) return decorator(func)
def performance_monitor(func: Callable = None, *, threshold: float = 1.0): def performance_monitor(func: Optional[Callable] = None, *, threshold: float = 1.0):
""" """
性能监控装饰器 性能监控装饰器
仅当函数执行时间超过阈值时记录日志 仅当函数执行时间超过阈值时记录日志

View File

@@ -1,7 +1,7 @@
""" """
通用单例模式基类 通用单例模式基类
""" """
from typing import Any, Dict, Optional, Type, TypeVar from typing import Any, Dict, Optional, Type, TypeVar, cast
T = TypeVar('T') T = TypeVar('T')
@@ -29,9 +29,9 @@ class Singleton:
Returns: Returns:
T: 单例实例 T: 单例实例
""" """
# 使用全局字典存储实例,避免类型检查问题 # 使用全局字典存储实例,修复类型检查问题
if cls not in _instance_store: if cls not in _instance_store:
_instance_store[cls] = super().__new__(cls) _instance_store[cls] = super(Singleton, cls).__new__(cls)
return _instance_store[cls] return _instance_store[cls]
def __init__(self) -> None: def __init__(self) -> None:
@@ -67,7 +67,7 @@ def singleton(cls: Type[T]) -> Type[T]:
nonlocal class_instance nonlocal class_instance
if class_instance is None: if class_instance is None:
# 使用super()调用原始类的__new__方法 # 使用super()调用原始类的__new__方法
class_instance = cls(*args, **kwargs) class_instance = super(SingletonClass, cls).__new__(cls)
return class_instance return class_instance
# 复制类的元数据 # 复制类的元数据

231
core/ws_pool.py Normal file
View File

@@ -0,0 +1,231 @@
"""
WebSocket 连接池模块
该模块实现了 WebSocket 连接池功能,用于管理多个 WebSocket 连接,
提高并发处理能力和连接复用效率。
"""
import asyncio
import websockets
from websockets.legacy.client import WebSocketClientProtocol
from typing import Optional, Dict, Any, cast
import uuid
from loguru import logger
from .config_loader import global_config
from .utils.exceptions import WebSocketError, WebSocketConnectionError
class WSConnection:
"""
WebSocket 连接包装类
封装单个 WebSocket 连接的状态和操作
"""
def __init__(self, conn: WebSocketClientProtocol, conn_id: str):
self.conn = conn
self.conn_id = conn_id
self.last_used = asyncio.get_event_loop().time()
self.is_active = True
self._pending_requests: Dict[str, asyncio.Future] = {}
async def send(self, data: dict):
"""
发送数据到 WebSocket 连接
"""
if not self.is_active:
raise WebSocketError(f"连接 {self.conn_id} 已关闭")
try:
await self.conn.send(data)
self.last_used = asyncio.get_event_loop().time()
except Exception as e:
self.is_active = False
raise WebSocketError(f"发送数据失败: {e}")
async def recv(self):
"""
从 WebSocket 连接接收数据
"""
if not self.is_active:
raise WebSocketError(f"连接 {self.conn_id} 已关闭")
try:
data = await self.conn.recv()
self.last_used = asyncio.get_event_loop().time()
return data
except Exception as e:
self.is_active = False
raise WebSocketError(f"接收数据失败: {e}")
async def close(self):
"""
关闭 WebSocket 连接
"""
if self.is_active:
self.is_active = False
await self.conn.close()
class WSConnectionPool:
"""
WebSocket 连接池
管理多个 WebSocket 连接,提供连接的获取、释放和回收功能
"""
def __init__(self, pool_size: int = 3, max_idle_time: int = 300):
"""
初始化连接池
:param pool_size: 连接池大小
:param max_idle_time: 连接最大空闲时间(秒)
"""
self.pool_size = pool_size
self.max_idle_time = max_idle_time
self.pool: asyncio.Queue[WSConnection] = asyncio.Queue(maxsize=pool_size)
self._closed = False
self._cleanup_task: Optional[asyncio.Task] = None
# 从全局配置读取参数
self.url = global_config.napcat_ws.uri
self.token = global_config.napcat_ws.token
self.reconnect_interval = global_config.napcat_ws.reconnect_interval
logger.info(f"WebSocket 连接池初始化完成,大小: {pool_size}")
async def initialize(self):
"""
初始化连接池,创建初始连接
"""
if self._closed:
raise WebSocketError("连接池已关闭")
# 启动连接清理任务
self._cleanup_task = asyncio.create_task(self._cleanup_idle_connections())
# 创建初始连接
for _ in range(self.pool_size):
try:
conn = await self._create_connection()
await self.pool.put(conn)
logger.info(f"WebSocket 连接 {conn.conn_id} 已创建并加入连接池")
except Exception as e:
logger.error(f"创建初始连接失败: {e}")
async def _create_connection(self) -> WSConnection:
"""
创建新的 WebSocket 连接
"""
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
try:
conn_id = str(uuid.uuid4())
websocket_raw = await websockets.connect(
self.url, additional_headers=headers
)
websocket = cast(WebSocketClientProtocol, websocket_raw)
conn = WSConnection(websocket, conn_id)
logger.info(f"WebSocket 连接 {conn_id} 已建立")
return conn
except Exception as e:
raise WebSocketConnectionError(f"创建 WebSocket 连接失败: {e}")
async def get_connection(self) -> WSConnection:
"""
从连接池获取一个连接
"""
if self._closed:
raise WebSocketError("连接池已关闭")
try:
# 尝试从连接池获取连接
conn = await asyncio.wait_for(self.pool.get(), timeout=5)
# 检查连接是否活跃
if not conn.is_active:
logger.warning(f"连接 {conn.conn_id} 已失效,重新创建")
return await self._create_connection()
return conn
except asyncio.TimeoutError:
# 连接池为空,创建新连接
logger.warning("连接池为空,创建临时连接")
return await self._create_connection()
except Exception as e:
raise WebSocketError(f"获取连接失败: {e}")
async def release_connection(self, conn: WSConnection):
"""
释放连接回连接池
"""
if self._closed:
await conn.close()
return
if not conn.is_active:
logger.warning(f"连接 {conn.conn_id} 已失效,不返回连接池")
return
try:
if self.pool.full():
# 连接池已满,关闭该连接
await conn.close()
logger.info(f"连接池已满,关闭连接 {conn.conn_id}")
else:
await self.pool.put(conn)
logger.debug(f"连接 {conn.conn_id} 已返回连接池")
except Exception as e:
logger.error(f"释放连接失败: {e}")
await conn.close()
async def _cleanup_idle_connections(self):
"""
清理空闲连接任务
"""
while not self._closed:
await asyncio.sleep(60) # 每分钟检查一次
try:
# 检查连接池中的连接
new_pool = asyncio.Queue(maxsize=self.pool_size)
current_time = asyncio.get_event_loop().time()
while not self.pool.empty():
conn = await self.pool.get()
if current_time - conn.last_used > self.max_idle_time:
# 连接空闲时间过长,关闭
await conn.close()
logger.info(f"清理空闲连接 {conn.conn_id}")
else:
# 放回新队列
await new_pool.put(conn)
# 替换原连接池
self.pool = new_pool
except Exception as e:
logger.error(f"清理空闲连接失败: {e}")
async def close(self):
"""
关闭连接池
"""
if self._closed:
return
self._closed = True
# 停止清理任务
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
# 关闭所有连接
while not self.pool.empty():
conn = await self.pool.get()
await conn.close()
logger.info("WebSocket 连接池已关闭")

View File

@@ -54,3 +54,118 @@ graph LR
别几把开多个实例。。。 别几把开多个实例。。。
* **Browser Pool**: 浏览器页面提前开好,用完洗干净放回去 * **Browser Pool**: 浏览器页面提前开好,用完洗干净放回去
* **Connection Pool**: Redis 和 HTTP 请求都用连接池 * **Connection Pool**: Redis 和 HTTP 请求都用连接池
## 4. 技术栈全景
NEO Bot 的“骨架”是由一堆现代 Python 库和技术堆起来的。下面这张清单能让你一眼看清整个项目的技术选型。
### 编程语言与运行时
* **Python 3.14**: 镀铬酸钾创项目的时候用的 Python 3.14 3.14兼容JIT那就这样吧
* **JIT (Just-In-Time)**: 启动时加 `-X jit` 参数,运行时把热点代码编译成机器码
* **Mypyc (AOT)**: 核心模块(`core/ws.py`, `core/managers/*.py`编译成C扩展机器码运行
### 异步与网络
* **asyncio**: Python 原生异步框架,所有 IO 操作都是非阻塞的
* **uvloop (Linux)**: 替代 asyncio 默认事件循环,性能更高
* **IOCP (Windows)**: Windows 上的高性能 IO 完成端口
* **aiohttp**: 异步 HTTP 客户端/服务器,用于 API 请求和 WebSocket 通信
* **websockets**: 纯粹的 WebSocket 客户端/服务器库
* **Playwright**: 浏览器自动化工具,负责截图、页面渲染
### 数据与存储
* **Redis**: 内存数据库,用于缓存帮助图片、会话状态等
* **orjson**: Rust 编写的 JSON 序列化库,比标准 `json` 快很多
* **Pydantic**: 数据验证与设置管理配置文件、API 请求/响应都靠它
### 工具与工具链
* **Loguru**: 结构化日志记录,输出漂亮且支持文件轮转
* **Watchdog**: 文件系统监控,实现插件热重载
* **Jinja2**: 模板引擎,渲染 HTML 页面然后转为图片
* **Pillow**: 图像处理库,负责图片格式转换、尺寸调整
* **BeautifulSoup4**: HTML 解析B站、抖音等链接解析插件在用
* **httpx**: 异步 HTTP 客户端,某些插件用它发请求
### 测试与开发
* **Pytest**: 测试框架,写单元测试、集成测试
* **Docker**: 容器化,沙箱执行用户代码时可能用到
* **cryptography**: 加密解密,处理一些安全相关的操作
### 架构模式
* **Singleton (单例)**: 全局唯一实例,所有管理器都是单例
* **Connection Pool (连接池)**: Redis 连接、HTTP 会话都复用
* **Plugin System (插件系统)**: 动态导入、装饰器注册,一个 `.py` 文件就是一个插件
## 5. Python 动态语言特性运用
Python 是一门“动态”语言这意味着你可以在运行时做很多静态语言做不到的事情。NEO Bot 大量利用了这些特性,让框架变得灵活、易扩展。
### 装饰器 (Decorator)
* **何用**: 给函数“贴上标签”,告诉框架这个函数是干什么的
* **何处**:
* `@matcher.command("echo")` 注册一个消息指令
* `@matcher.on_message()` 注册一个通用消息处理器
* `@matcher.on_notice()` 注册一个通知事件处理器
* **何原理**: 装饰器本质上是一个高阶函数,它接收被装饰的函数,然后把它“注册”到某个管理器里
### 动态导入 (Dynamic Import)
* **何用**: 不需要在代码开头写死 `import`,运行时根据情况加载模块
* **何处**: `PluginManager.load_all_plugins()``importlib.import_module()` 扫描 `plugins/` 目录,找到 `.py` 文件就导入
* **何原理**: Python 的模块系统是完全动态的,`import` 语句实际上调用了 `__import__()` 函数
### 自省 (Introspection)
* **何用**: 让代码能“看到”自己的结构,比如函数属于哪个模块、有哪些参数
* **何处**:
* `inspect.getmodule(func)` 获取函数所在的模块名,用于记录插件来源
* `func.__name__`, `func.__module__` 获取函数名和模块名
* **何原理**: Python 把几乎所有元信息都存在对象的 `__dict__` 里,你可以随时翻看
### 鸭子类型 (Duck Typing)
* **何用**: “如果它走起来像鸭子,叫起来像鸭子,那它就是鸭子。”——不检查类型,只检查行为
* **何处**:
* 事件处理器不要求事件对象必须是某个类,只要它有 `post_type``user_id` 等属性就行
* 插件不需要继承某个基类,只要它有 `__plugin_meta__` 字典就行
* **何原理**: Python 的变量没有类型,类型是对象自己的事。只要对象有你需要的方法或属性,你就可以调用它
### 反射 (Reflection)
* **何用**: 在运行时检查、修改对象的结构
* **何处**:
* `getattr(module, "__plugin_meta__")` 获取插件的元数据字典
* `hasattr(event, "raw_message")` 检查事件对象是否有某个属性
* `setattr()` 动态设置属性(虽然用得少)
* **何原理**: Python 的对象本质上就是字典(`__dict__``getattr`/`setattr` 就是对这个字典的操作
### 元编程 (Metaprogramming)
* **何用**: 在代码运行时改变代码的行为
* **何处**:
* `Singleton` 基类重写 `__new__` 方法,控制实例创建,确保全局只有一个实例
* 装饰器在函数定义时修改函数,给它添加额外逻辑
* **何原理**: Python 的类也是对象(类型对象),你可以通过修改类来影响它所有实例的行为
### 上下文管理器 (Context Manager)
* **何用**: 安全地获取和释放资源,比如文件、网络连接、浏览器页面
* **何处**:
* `async with browser_manager.get_page() as page:` 从页面池获取一个页面,用完后自动放回
* `async with aiohttp.ClientSession() as session:` 发起 HTTP 请求后自动关闭会话
* **何原理**: `__enter__`/`__exit__`(同步)或 `__aenter__`/`__aexit__`(异步)协议
### 描述符 (Descriptor)
* **何用**: 控制属性访问的逻辑,比如把方法伪装成属性
* **何处**:
* `@property` 把方法变成只读属性,比如 `PluginManager.command_manager`
* `@property.setter` 给属性设置值时的自定义逻辑
* **何原理**: 描述符是一个实现了 `__get__``__set__``__delete__` 方法的类
### 猴子补丁 (Monkey Patching)
* **何用**: 在运行时修改模块、类或对象,通常用于测试或修复第三方库
* **何处**: 测试中可能会用 `unittest.mock.patch` 临时替换某个函数,模拟它的行为
* **何原理**: Python 的模块和类都是可变的,你可以直接给它们赋值新属性
### eval/exec
* **何用**: 执行字符串形式的 Python 代码
* **何处**: `code_py.py` 插件中,用户发送的代码片段会被 `exec()` 执行,实现代码沙箱功能
* **何原理**: Python 解释器本身就是一个运行时环境,`eval()` 用于表达式,`exec()` 用于语句
### 类型提示 (Type Hints)
* **何用**: 虽然 Python 是动态类型,但类型提示能让代码更清晰,工具(如 Mypy也能做静态检查
* **何处**: 几乎所有函数和方法的参数、返回值都加了类型提示,这让 Mypyc 编译成为可能
* **何原理**: 类型提示只是注解,运行时通常被忽略(除非你用 `typing` 模块做检查)

View File

@@ -0,0 +1,357 @@
# NEO Bot 开发规范与公约
写代码很简单,但写出**高性能、不炸裂、好维护**的代码需要遵守规矩。
本文档定义了 NEO Bot 项目的开发守则、编码公约、注意事项和代码规范。所有贡献者和插件开发者都**必须**遵循这些规范,确保机器人稳定运行、代码质量统一。
> 如果你觉得规范太麻烦,可以问问镀铬酸钾,他会给你一对一教学。。。但最好还是遵守规矩。
**补充阅读**
- [插件开发最佳实践](./plugin-development/best-practices.md) - 必读!写插件的基本规矩
- [项目结构](./project-structure.md) - 了解代码组织
- [核心概念](./core-concepts/architecture.md) - 理解框架设计
## 1. 开发守则(基本原则)
### 1.1 异步优先原则
- **绝对不要阻塞事件循环**NeoBot 采用单线程异步架构,任何同步阻塞操作都会导致整个机器人卡死。
- **禁止**`time.sleep()`、同步 `requests`、密集 CPU 计算
- **必须**:使用 `await asyncio.sleep()`、异步 HTTP 客户端、线程池执行同步任务
- **异步任务处理**:长时间运行的任务应使用 `run_in_thread_pool``asyncio.create_task` 执行,避免阻塞主循环。
### 1.2 资源管理原则
- **连接复用**:禁止重复创建连接和资源实例。
- HTTP 请求:使用全局 `aiohttp` session 或插件提供的 `get_session()`
- 浏览器操作:必须通过 `browser_manager.get_page()` 获取页面实例
- Redis 连接:通过 `redis_manager` 单例访问
- **资源池化**:浏览器页面、数据库连接等资源必须使用框架提供的池化机制。
### 1.3 性能优化原则
- **缓存策略**:频繁访问的外部数据必须添加缓存。
- 短期缓存(<1小时使用 Redis 或内存缓存
- 长期缓存:考虑持久化存储
- **懒加载**:大型资源或初始化成本高的组件应延迟加载。
### 1.4 错误处理原则
- **异常捕获**:所有插件代码都应妥善处理异常,避免插件崩溃影响机器人运行。
- **友好提示**:向用户返回清晰、友好的错误信息,避免暴露内部细节。
- **日志记录**:所有重要操作和错误都应记录日志,使用 `ModuleLogger` 进行结构化日志记录。
### 1.5 安全性原则
- **输入验证**:所有用户输入都必须验证和清理,防止注入攻击。
- **代码执行安全**:使用沙箱环境执行用户代码,隔离系统资源。
- **权限控制**:严格遵循权限管理系统,禁止越权操作。
### 1.6 跨平台兼容性原则
NEO Bot 需要在 **Windows 开发环境**和 **Linux 生产环境**中都能正常运行。
- **路径处理**
- 使用 `pathlib.Path` 处理文件路径,避免手动拼接字符串。
- 使用 `/` 作为路径分隔符Python 会自动转换)。
- 禁止使用硬编码的路径分隔符(如 `\\``/`)。
- **系统依赖**
- 避免使用平台特定的系统调用。
- 如果必须使用,通过 `sys.platform` 检测平台并提供备选方案。
- **环境变量**
- 通过 `global_config` 获取配置,而不是直接读取环境变量。
- 敏感信息(如 API 密钥)必须通过配置管理。
- **文件权限**
- 在 Linux 上注意文件权限设置,确保 Bot 有读写权限。
- 临时文件应放在系统临时目录(`tempfile.gettempdir()`)。
## 2. 公约(编码约定)
### 2.1 项目结构公约
- **插件位置**:所有插件必须放置在 `plugins/` 目录下,单个 `.py` 文件或包含 `__init__.py` 的目录。
- **模块导入**:遵循标准导入顺序:标准库 → 第三方库 → 本地模块。
- **配置访问**:通过 `global_config` 单例访问配置,禁止硬编码配置值。
### 2.2 单例管理器使用公约
NEO Bot 的核心是**单例管理器**`core/managers/` 目录下的类)。所有全局资源都必须通过管理器访问。
- **禁止重复创建**:严禁自己实例化管理器类,必须通过导入的单例对象访问。
-`from core.managers.redis_manager import redis_manager`
-`RedisManager()` (错误!会创建新实例)
- **资源池化**:浏览器页面、数据库连接等资源必须使用管理器提供的池化接口。
-`await browser_manager.get_page()`
-`playwright.chromium.launch()` (错误!会创建新浏览器进程)
- **数据一致性**:单例管理器确保全局数据一致性,不要绕过管理器直接操作底层资源。
### 2.2.1 单例模式实现机制
NEO Bot 提供了两种单例模式实现方式,位于 `core/utils/singleton.py`
#### 1. Singleton 基类(继承方式)
```python
from core.utils.singleton import Singleton
class MyManager(Singleton):
"""通过继承 Singleton 基类实现单例"""
def __init__(self, config: dict):
"""
初始化管理器
Args:
config: 配置字典
"""
# 调用父类 __init__ 确保单例初始化
super().__init__()
# 检查是否已经初始化(防止 __init__ 被多次调用)
if hasattr(self, '_my_initialized') and self._my_initialized:
return
# 执行一次性初始化逻辑
self.config = config
self.resource = None
self._initialize_resource()
# 标记为已初始化
self._my_initialized = True
def _initialize_resource(self):
"""初始化资源(只执行一次)"""
self.resource = initialize_resource(self.config)
async def cleanup(self):
"""清理资源(单例管理器应实现清理方法)"""
if self.resource:
await self.resource.close()
```
**特性**
- 通过重写 `__new__` 方法确保每个类只有一个实例
- 自动处理重复初始化问题,但建议子类添加额外的初始化检查
- 使用全局字典存储实例,避免类型检查问题
- 支持带参数的 `__init__` 方法
#### 2. @singleton 装饰器(装饰器方式)
```python
from core.utils.singleton import singleton
@singleton
class MyManager:
"""通过装饰器实现单例"""
def __init__(self, config):
self.config = config
self.resource = None
async def initialize(self):
self.resource = await load_resource()
```
**特性**
- 将普通类转换为单例类,无需修改类继承关系
- 保持原始类的元数据(名称、文档字符串等)
- 适用于无法修改基类的现有类
#### 3. 使用建议
- **新管理器类**:优先使用 **Singleton 基类继承方式**,结构更清晰
- **现有类转换**:使用 **@singleton 装饰器**,无需重构
- **线程安全**:两种方式都假设在单线程异步环境中使用,如需线程安全请自行加锁
- **导入方式**:单例类应该通过模块级别的实例变量导出,如:
```python
# redis_manager.py
class RedisManager(Singleton):
...
redis_manager = RedisManager() # 创建并导出单例实例
```
#### 4. 重要注意事项
- **避免循环导入**:单例类的导入应谨慎处理,避免循环依赖
- **初始化时机**:单例在第一次导入时创建,确保所需依赖已就绪
- **__init__ 调用语义**:虽然实例是单例,但 `__init__` 方法可能被多次调用(如重新导入时)。应添加额外检查确保一次性逻辑只执行一次。
- **资源清理**:单例管理器应在程序退出时清理资源,实现 `cleanup()` 方法
### 2.3 命名公约
- **文件命名**:使用小写字母和下划线,例如 `my_plugin.py`。
- **类命名**:使用 `PascalCase`,例如 `CommandManager`。
- **函数/方法命名**:使用 `snake_case`,例如 `handle_message`。
- **常量命名**:使用 `UPPER_SNAKE_CASE`,例如 `MAX_RETRY_COUNT`。
- **变量命名**:使用 `snake_case`,具有描述性,避免单字母变量(循环变量除外)。
### 2.4 类型提示公约
- **全面使用**:所有函数、方法、类属性都应提供类型提示。**这是强制要求**,因为框架开启了 Mypyc 编译。
- **性能优化**:类型提示不仅帮助发现 Bug还能让 Mypyc 生成更高效的机器码。
- **返回类型**:明确指定返回类型,包括 `None`。
- **复杂类型**:使用 `typing` 模块中的泛型,如 `List[str]`、`Dict[str, Any]`。
- **可选参数**:使用 `Optional[...]` 或默认值 `= None`。
**示例**
```python
# 好的写法
async def handle(event: MessageEvent, args: list[str]) -> None:
...
# 不好写法(会导致编译警告)
async def handle(event, args):
...
```
### 2.5 异常处理公约
- **自定义异常**:使用框架提供的自定义异常类,避免抛出通用的 `Exception`。
- **异常链**:保留原始异常信息,使用 `raise CustomError(...) from e`。
- **资源清理**:使用 `try...finally` 或上下文管理器确保资源释放。
### 2.6 日志记录公约
- **模块化日志**:每个模块使用 `ModuleLogger("ModuleName")` 创建专用日志记录器。
- **日志级别**
- `DEBUG`:调试信息,详细操作记录
- `INFO`:常规操作记录
- `WARNING`:预期内的异常或潜在问题
- `ERROR`:操作失败但可恢复的错误
- `CRITICAL`:系统级错误,需要立即关注
## 3. 注意事项(常见陷阱)
### 3.1 异步编程陷阱
- **忘记 await**:异步函数调用必须使用 `await`,否则任务不会执行。
- **阻塞循环**:在异步函数中执行同步阻塞操作会冻结整个事件循环。
- **任务泄漏**:创建的异步任务必须被妥善管理,避免内存泄漏。
### 3.2 资源管理陷阱
- **连接泄漏**:未关闭的 HTTP 连接、数据库连接会导致资源耗尽。
- **文件句柄泄漏**:打开的文件必须显式关闭或使用上下文管理器。
- **缓存雪崩**:大量缓存同时过期可能导致系统负载激增。
### 3.3 性能陷阱
- **N+1 查询**:避免在循环中执行数据库或 API 查询,使用批量操作。
- **内存泄漏**:大型数据结构长时间驻留内存,应定期清理。
- **重复计算**:相同的计算结果应缓存,避免重复计算。
### 3.4 安全性陷阱
- **SQL 注入**:使用参数化查询或 ORM禁止拼接 SQL 字符串。
- **XSS 攻击**:渲染用户输入时必须进行 HTML 转义。
- **路径遍历**:用户提供的文件路径必须进行规范化验证。
## 4. 代码规范(详细指南)
### 4.1 文档字符串规范(强制要求)
**所有代码必须包含完整的文档字符串**,这是项目质量保证的基础。缺少文档字符串的代码将在审查中被拒绝。
- **模块级文档**:每个模块顶部应有文档字符串,描述模块功能和主要接口。
- **类级文档**:每个类应有文档字符串,描述类的职责、使用方法和示例。
- **函数/方法级文档**:每个公共函数和方法必须有文档字符串,包含参数说明、返回值和异常信息。
**参数注释要求**
1. 每个参数都必须有类型提示和简要说明
2. 返回值必须明确说明类型和含义
3. 可能抛出的异常必须列出
4. 复杂的函数应提供使用示例
**标准格式示例:**
```python
def process_data(data: List[str], timeout: int = 30) -> Dict[str, Any]:
"""
处理数据并返回结果。
Args:
data: 待处理的数据列表
timeout: 操作超时时间,单位秒
Returns:
处理结果的字典,包含状态和详情
Raises:
TimeoutError: 处理超时时抛出
ValueError: 数据格式错误时抛出
Example:
>>> result = process_data(["item1", "item2"])
>>> print(result["status"])
"""
```
### 4.2 函数设计规范
- **单一职责**:每个函数只做一件事,保持功能简洁。
- **参数数量**:函数参数不宜过多(建议 ≤5过多时考虑使用 `dataclass` 或 `TypedDict`。
- **默认参数**:避免使用可变对象作为默认参数,使用 `None` 代替。
### 4.3 类设计规范
- **单一职责**:每个类应有明确的单一职责。
- **组合优于继承**:优先使用组合而非继承来复用功能。
- **属性访问控制**:使用 `@property` 装饰器控制属性访问,隐藏内部实现。
### 4.4 错误处理规范
- **错误码统一**:使用框架定义的 `ErrorCode` 枚举,避免自定义魔法数字。
- **错误响应格式**:使用 `exception_to_error_response` 生成统一错误响应。
- **用户友好消息**:错误消息应同时包含技术细节(日志)和用户友好提示(界面)。
### 4.5 测试规范
- **测试覆盖率**:核心功能应达到 80% 以上的测试覆盖率。
- **异步测试**:使用 `pytest-asyncio` 进行异步测试。
- **测试隔离**:测试用例之间应相互独立,避免依赖执行顺序。
## 5. 提交与协作规范
### 5.1 Git 提交规范
- **提交信息格式**:遵循 Conventional Commits 规范
```
<type>(<scope>): <subject>
<body>
<footer>
```
- **type**feat、fix、docs、style、refactor、test、chore
- **scope**:影响的模块或功能区域
- **subject**简洁的描述50字符以内
- **body**:详细说明(可选)
- **footer**Breaking Changes 或 Issue 引用
### 5.2 代码审查规范
- **审查重点**:功能正确性、代码规范、性能影响、安全性。
- **审查态度**:建设性反馈,避免人身攻击。
- **审查时效**24小时内响应审查请求。
### 5.3 分支管理规范
- **主分支**`main` 分支始终保持可部署状态。
- **功能分支**:从 `main` 创建,命名格式 `feature/简短描述`。
- **修复分支**:从 `main` 创建,命名格式 `fix/问题描述`。
### 5.4 发布规范
- **版本号**遵循语义化版本控制SemVer`主版本.次版本.修订版本`
- **更新日志**:每次发布都应更新 `CHANGELOG.md`。
- **向后兼容**:非主版本更新应保持 API 向后兼容。
## 6. 插件开发特别规范
### 6.1 插件元数据
每个插件必须在文件顶部定义 `__plugin_meta__` 字典:
```python
__plugin_meta__ = {
"name": "插件名称",
"description": "插件功能描述",
"usage": "使用说明,包括命令格式和示例",
"author": "作者名(可选)",
"version": "版本号(可选)",
}
```
### 6.2 命令注册
- **命令前缀**:使用配置中定义的前缀,不要硬编码。
- **权限控制**:使用 `Permission` 枚举指定命令权限级别。
- **参数解析**:利用框架的自动参数解析功能,避免手动解析。
### 6.3 插件生命周期
- **初始化**:避免在模块级别执行初始化操作,使用函数包装。
- **资源清理**:提供清理函数或使用上下文管理器管理资源。
- **错误恢复**:插件崩溃后应能优雅恢复,不影响其他插件。
## 7. 总结
遵循这些规范将确保 NeoBot 项目保持高质量、高性能和高可维护性。所有贡献者都应阅读并理解这些规范,并在代码审查中互相监督执行。
**记住:规范不是束缚,而是高效协作的基础。**

View File

@@ -198,7 +198,8 @@ async def main():
# 初始化代码执行器 # 初始化代码执行器
code_executor = initialize_executor(config) code_executor = initialize_executor(config)
websocket_client = WS(code_executor=code_executor) # 使用连接池模式初始化 WebSocket 客户端
websocket_client = WS(code_executor=code_executor, use_pool=True)
# 启动代码执行器的后台 worker # 启动代码执行器的后台 worker
logger.debug("[Main] 检查是否需要启动代码执行 Worker...") logger.debug("[Main] 检查是否需要启动代码执行 Worker...")

View File

@@ -68,12 +68,15 @@ class MessageEvent(OneBotEvent):
sender: Optional[Sender] = None sender: Optional[Sender] = None
"""发送者信息""" """发送者信息"""
# 权限级别常量,用于装饰器参数
ADMIN = Permission.ADMIN
OP = Permission.OP
USER = Permission.USER
@property @property
def post_type(self) -> str: def post_type(self) -> str:
return EventType.MESSAGE return EventType.MESSAGE
async def reply(self, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False): async def reply(self, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False):
""" """
回复消息(抽象方法,由子类实现) 回复消息(抽象方法,由子类实现)
@@ -84,12 +87,6 @@ class MessageEvent(OneBotEvent):
raise NotImplementedError("reply method must be implemented by subclasses") raise NotImplementedError("reply method must be implemented by subclasses")
# 在类定义之后添加权限常量作为类变量
MessageEvent.ADMIN = MESSAGE_EVENT_ADMIN
MessageEvent.OP = MESSAGE_EVENT_OP
MessageEvent.USER = MESSAGE_EVENT_USER
@dataclass(slots=True) @dataclass(slots=True)
class PrivateMessageEvent(MessageEvent): class PrivateMessageEvent(MessageEvent):
""" """

View File

@@ -41,7 +41,6 @@ def get_performance_config():
dict: 性能分析配置 dict: 性能分析配置
""" """
import os import os
import json
# 从环境变量加载配置 # 从环境变量加载配置
config = PERFORMANCE_CONFIG.copy() config = PERFORMANCE_CONFIG.copy()

View File

@@ -53,16 +53,16 @@ async def admin_management(event: MessageEvent, args: list[str]):
# 根据子命令分发 # 根据子命令分发
if subcommand == "add_admin": if subcommand == "add_admin":
permission_manager.set_user_permission(target_user_id, Permission.ADMIN) await permission_manager.set_user_permission(target_user_id, Permission.ADMIN)
await event.reply(f"已成功添加管理员:{target_user_id}") await event.reply(f"已成功添加管理员:{target_user_id}")
elif subcommand == "remove_admin": elif subcommand == "remove_admin":
permission_manager.set_user_permission(target_user_id, Permission.USER) await permission_manager.set_user_permission(target_user_id, Permission.USER)
await event.reply(f"已成功移除管理员:{target_user_id}") await event.reply(f"已成功移除管理员:{target_user_id}")
elif subcommand == "add_op": elif subcommand == "add_op":
permission_manager.set_user_permission(target_user_id, Permission.OP) await permission_manager.set_user_permission(target_user_id, Permission.OP)
await event.reply(f"已成功添加操作员:{target_user_id}") await event.reply(f"已成功添加操作员:{target_user_id}")
elif subcommand == "remove_op": elif subcommand == "remove_op":
permission_manager.set_user_permission(target_user_id, Permission.USER) await permission_manager.set_user_permission(target_user_id, Permission.USER)
await event.reply(f"已成功移除操作员:{target_user_id}") await event.reply(f"已成功移除操作员:{target_user_id}")
else: else:
await event.reply(f"未知的子命令 '{subcommand}'\n\n{__plugin_meta__['usage']}") await event.reply(f"未知的子命令 '{subcommand}'\n\n{__plugin_meta__['usage']}")

View File

@@ -25,8 +25,10 @@ async def handle_friend_request(bot: Bot, event: FriendRequestEvent):
# 自动同意好友请求 # 自动同意好友请求
await bot.call_api( await bot.call_api(
"set_friend_add_request", "set_friend_add_request",
flag=event.flag, params={
approve=True "flag": event.flag,
"approve": True
}
) )
print(f"[自动同意] 已同意用户 {event.user_id} 的好友请求") print(f"[自动同意] 已同意用户 {event.user_id} 的好友请求")
except Exception as e: except Exception as e:
@@ -44,9 +46,11 @@ async def handle_group_request(bot: Bot, event: GroupRequestEvent):
# 自动同意群聊邀请 # 自动同意群聊邀请
await bot.call_api( await bot.call_api(
"set_group_add_request", "set_group_add_request",
flag=event.flag, params={
sub_type=event.sub_type, "flag": event.flag,
approve=True "sub_type": event.sub_type,
"approve": True
}
) )
print(f"[自动同意] 已同意加入群聊 {event.group_id} (邀请人: {event.user_id})") print(f"[自动同意] 已同意加入群聊 {event.group_id} (邀请人: {event.user_id})")
except Exception as e: except Exception as e:

View File

@@ -1,13 +1,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import re import re
import json import orjson
import abc import abc
import aiohttp import aiohttp
from typing import Optional, Dict, Any, List, Union from typing import Optional, Dict, Any, List, Union
from cachetools import TTLCache
from core.utils.logger import logger from core.utils.logger import logger
from models import MessageEvent, MessageSegment from models import MessageEvent
class BaseParser(metaclass=abc.ABCMeta): class BaseParser(metaclass=abc.ABCMeta):
@@ -38,6 +37,7 @@ class BaseParser(metaclass=abc.ABCMeta):
""" """
self.name = "Base Parser" self.name = "Base Parser"
self.url_pattern = re.compile(r"https?://[^\s]+") self.url_pattern = re.compile(r"https?://[^\s]+")
self.processed_messages = {} # 用于存储已处理的消息ID防止重复处理
@classmethod @classmethod
def get_session(cls) -> aiohttp.ClientSession: def get_session(cls) -> aiohttp.ClientSession:
@@ -105,12 +105,12 @@ class BaseParser(metaclass=abc.ABCMeta):
if segment.type == "json": if segment.type == "json":
logger.info(f"[{self.name}] 检测到JSON CQ码: {segment.data}") logger.info(f"[{self.name}] 检测到JSON CQ码: {segment.data}")
try: try:
json_data = json.loads(segment.data.get("data", "{}")) json_data = orjson.loads(segment.data.get("data", "{}"))
short_url = json_data.get("meta", {}).get("detail_1", {}).get("qqdocurl") short_url = json_data.get("meta", {}).get("detail_1", {}).get("qqdocurl")
if short_url: if short_url:
logger.success(f"[{self.name}] 成功从JSON卡片中提取到链接: {short_url}") logger.success(f"[{self.name}] 成功从JSON卡片中提取到链接: {short_url}")
return short_url return short_url
except (json.JSONDecodeError, KeyError) as e: except (orjson.JSONDecodeError, KeyError) as e:
logger.error(f"[{self.name}] 解析JSON失败: {e}") logger.error(f"[{self.name}] 解析JSON失败: {e}")
continue continue
return None return None

View File

@@ -1,14 +1,14 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import re import re
import json import orjson
import aiohttp import aiohttp
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List, Union
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from core.utils.logger import logger from core.utils.logger import logger
from models import MessageEvent, MessageSegment from models import MessageEvent, MessageSegment
from ..base import BaseParser from ..base import BaseParser
from ..utils import format_duration, clean_url from ..utils import format_duration
from cachetools import TTLCache from cachetools import TTLCache
@@ -42,7 +42,7 @@ class BiliParser(BaseParser):
clean_url = clean_url.split('#/')[0] clean_url = clean_url.split('#/')[0]
session = self.get_session() session = self.get_session()
async with session.get(clean_url, headers=self.HEADERS, timeout=5) as response: async with session.get(clean_url, headers=self.HEADERS, timeout=aiohttp.ClientTimeout(total=5)) as response:
response.raise_for_status() response.raise_for_status()
text = await response.text() text = await response.text()
soup = BeautifulSoup(text, 'html.parser') soup = BeautifulSoup(text, 'html.parser')
@@ -93,14 +93,14 @@ class BiliParser(BaseParser):
json_str = json_str.strip().rstrip(';') json_str = json_str.strip().rstrip(';')
try: try:
data = json.loads(json_str) data = orjson.loads(json_str)
except json.JSONDecodeError: except ValueError:
# 如果直接解析失败尝试清理JSON字符串 # 如果直接解析失败尝试清理JSON字符串
# 移除可能的注释或无效字符 # 移除可能的注释或无效字符
cleaned_json = re.sub(r',\s*[}]', '}', json_str) # 移除末尾多余的逗号 cleaned_json = re.sub(r',\s*[}]', '}', json_str) # 移除末尾多余的逗号
cleaned_json = re.sub(r'/\*.*?\*/', '', cleaned_json) # 移除注释 cleaned_json = re.sub(r'/\*.*?\*/', '', cleaned_json) # 移除注释
cleaned_json = re.sub(r'//.*', '', cleaned_json) # 移除行注释 cleaned_json = re.sub(r'//.*', '', cleaned_json) # 移除行注释
data = json.loads(cleaned_json) data = orjson.loads(cleaned_json)
video_data = data.get('videoData', {}) video_data = data.get('videoData', {})
up_data = data.get('upData', {}) up_data = data.get('upData', {})
@@ -134,7 +134,7 @@ class BiliParser(BaseParser):
"followers": up_data.get('fans', 0), "followers": up_data.get('fans', 0),
} }
except (aiohttp.ClientError, KeyError, AttributeError, json.JSONDecodeError) as e: except (aiohttp.ClientError, KeyError, AttributeError, ValueError) as e:
logger.error(f"[{self.name}] 解析视频信息失败: {e}") logger.error(f"[{self.name}] 解析视频信息失败: {e}")
logger.debug(f"失败的URL: {url}") logger.debug(f"失败的URL: {url}")
except Exception as e: except Exception as e:
@@ -155,7 +155,7 @@ class BiliParser(BaseParser):
""" """
try: try:
session = self.get_session() session = self.get_session()
async with session.head(short_url, headers=self.HEADERS, allow_redirects=False, timeout=5) as response: async with session.head(short_url, headers=self.HEADERS, allow_redirects=False, timeout=aiohttp.ClientTimeout(total=5)) as response:
if response.status == 302: if response.status == 302:
return response.headers.get('Location') return response.headers.get('Location')
except Exception as e: except Exception as e:
@@ -175,13 +175,13 @@ class BiliParser(BaseParser):
api_url = f"https://api.mir6.com/api/bzjiexi?url={video_url}&type=json" api_url = f"https://api.mir6.com/api/bzjiexi?url={video_url}&type=json"
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(api_url, headers=self.HEADERS, timeout=10) as response: async with session.get(api_url, headers=self.HEADERS, timeout=aiohttp.ClientTimeout(total=10)) as response:
response.raise_for_status() response.raise_for_status()
# 使用 content_type=None 来忽略 Content-Type 检查 # 使用 content_type=None 来忽略 Content-Type 检查
data = await response.json(content_type=None) data = await response.json(content_type=None)
if data.get("code") == 200 and data.get("data"): if data.get("code") == 200 and data.get("data"):
return data["data"][0].get("video_url") return data["data"][0].get("video_url")
except (aiohttp.ClientError, json.JSONDecodeError, KeyError, IndexError) as e: except (aiohttp.ClientError, ValueError, KeyError, IndexError) as e:
logger.error(f"[{self.name}] 调用第三方API解析视频失败: {e}") logger.error(f"[{self.name}] 调用第三方API解析视频失败: {e}")
return None return None
@@ -197,6 +197,7 @@ class BiliParser(BaseParser):
List[Any]: 消息段列表 List[Any]: 消息段列表
""" """
# 检查视频时长 # 检查视频时长
video_message: Union[str, MessageSegment]
if data['duration'] > 1200: # 20分钟 = 1200秒 if data['duration'] > 1200: # 20分钟 = 1200秒
video_message = "视频时长超过20分钟不进行解析。" video_message = "视频时长超过20分钟不进行解析。"
else: else:

View File

@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import re import re
import json
import aiohttp import aiohttp
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
@@ -40,7 +39,7 @@ class DouyinParser(BaseParser):
api_url = f"http://api.xhus.cn/api/douyin?url={url}" api_url = f"http://api.xhus.cn/api/douyin?url={url}"
session = self.get_session() session = self.get_session()
async with session.get(api_url, headers=self.HEADERS, timeout=10) as response: async with session.get(api_url, headers=self.HEADERS, timeout=aiohttp.ClientTimeout(total=10)) as response:
if response.status != 200: if response.status != 200:
logger.error(f"[{self.name}] API请求失败状态码: {response.status}") logger.error(f"[{self.name}] API请求失败状态码: {response.status}")
return None return None
@@ -75,7 +74,7 @@ class DouyinParser(BaseParser):
"music": data.get("music", {}), "music": data.get("music", {}),
} }
except (aiohttp.ClientError, KeyError, AttributeError, json.JSONDecodeError) as e: except (aiohttp.ClientError, KeyError, AttributeError, ValueError) as e:
logger.error(f"[{self.name}] 解析抖音视频信息失败: {e}") logger.error(f"[{self.name}] 解析抖音视频信息失败: {e}")
logger.debug(f"失败的URL: {url}") logger.debug(f"失败的URL: {url}")
except Exception as e: except Exception as e:
@@ -110,7 +109,7 @@ class DouyinParser(BaseParser):
'Referer': 'https://www.douyin.com/' 'Referer': 'https://www.douyin.com/'
}) })
async with session.get(short_url, headers=mobile_headers, allow_redirects=True, timeout=10) as response: async with session.get(short_url, headers=mobile_headers, allow_redirects=True, timeout=aiohttp.ClientTimeout(total=10)) as response:
redirected_url = str(response.url) redirected_url = str(response.url)
# 检查重定向后的URL是否包含视频ID # 检查重定向后的URL是否包含视频ID

View File

@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import re import re
import json
import aiohttp import aiohttp
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from cachetools import TTLCache from cachetools import TTLCache
@@ -60,7 +59,7 @@ class GitHubParser(BaseParser):
""" """
try: try:
session = self.get_session() session = self.get_session()
async with session.head(short_url, headers=self.HEADERS, allow_redirects=False, timeout=5) as response: async with session.head(short_url, headers=self.HEADERS, allow_redirects=False, timeout=aiohttp.ClientTimeout(total=5)) as response:
if response.status == 302: if response.status == 302:
return response.headers.get('Location') return response.headers.get('Location')
except Exception as e: except Exception as e:
@@ -86,7 +85,7 @@ class GitHubParser(BaseParser):
api_url = f"https://api.github.com/repos/{owner}/{repo}" api_url = f"https://api.github.com/repos/{owner}/{repo}"
try: try:
session = self.get_session() session = self.get_session()
async with session.get(api_url, timeout=10) as response: async with session.get(api_url, timeout=aiohttp.ClientTimeout(total=10)) as response:
response.raise_for_status() response.raise_for_status()
repo_data = await response.json() repo_data = await response.json()
@@ -97,7 +96,7 @@ class GitHubParser(BaseParser):
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
logger.error(f"[{self.name}] GitHub API请求失败: {e}") logger.error(f"[{self.name}] GitHub API请求失败: {e}")
except json.JSONDecodeError as e: except ValueError as e:
logger.error(f"[{self.name}] 解析GitHub API响应失败: {e}") logger.error(f"[{self.name}] 解析GitHub API响应失败: {e}")
except Exception as e: except Exception as e:
logger.error(f"[{self.name}] 获取仓库信息时发生未知错误: {e}") logger.error(f"[{self.name}] 获取仓库信息时发生未知错误: {e}")

View File

@@ -1,10 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import re import re
import json from typing import Dict, Any, List
from typing import Optional, Dict, Any, Union, List
from core.utils.logger import logger from models import MessageEvent
from models import MessageEvent, MessageSegment
def format_duration(seconds: int) -> str: def format_duration(seconds: int) -> str:

View File

@@ -42,7 +42,6 @@ os.environ['PERFORMANCE_THRESHOLD'] = str(args.threshold)
os.environ['PERFORMANCE_STATS'] = '1' if args.stats else '0' os.environ['PERFORMANCE_STATS'] = '1' if args.stats else '0'
# 导入并运行主程序 # 导入并运行主程序
from core.utils.performance import profile, aprofile
from main import main from main import main
import asyncio import asyncio

View File

@@ -369,21 +369,17 @@ if __name__ == '__main__':
""" """
import os import os
import sys import sys
import glob
import subprocess
import shutil
import argparse
# 检测当前平台 # 检测当前平台
PLATFORM = sys.platform PLATFORM = sys.platform
if PLATFORM.startswith('win'): if PLATFORM.startswith('win'):
EXTENSION = '.pyd' EXTENSION = '.pyd'
BUILD_PREFIX = 'cp314-win_amd64' BUILD_PREFIX = 'cp314-win_amd64'
BUILD_PATH = os.path.join('build', f'lib.win-amd64-cpython-314') BUILD_PATH = os.path.join('build', 'lib.win-amd64-cpython-314')
elif PLATFORM.startswith('linux'): elif PLATFORM.startswith('linux'):
EXTENSION = '.so' EXTENSION = '.so'
BUILD_PREFIX = 'cp314-x86_64-linux-gnu' BUILD_PREFIX = 'cp314-x86_64-linux-gnu'
BUILD_PATH = os.path.join('build', f'lib.linux-x86_64-cpython-314') BUILD_PATH = os.path.join('build', 'lib.linux-x86_64-cpython-314')
else: else:
print(f"不支持的平台: {PLATFORM}") print(f"不支持的平台: {PLATFORM}")
sys.exit(1) sys.exit(1)

View File

@@ -10,8 +10,7 @@ from core.api.group import GroupAPI
from core.api.media import MediaAPI from core.api.media import MediaAPI
from core.api.message import MessageAPI from core.api.message import MessageAPI
from models.objects import ( from models.objects import (
LoginInfo, VersionInfo, Status, StrangerInfo, FriendInfo, LoginInfo, VersionInfo, Status
GroupInfo, GroupMemberInfo, GroupHonorInfo
) )
from models.message import MessageSegment from models.message import MessageSegment

View File

@@ -1,5 +1,4 @@
import pytest import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from core.managers.command_manager import CommandManager from core.managers.command_manager import CommandManager
from models.events.message import GroupMessageEvent from models.events.message import GroupMessageEvent

View File

@@ -1,6 +1,4 @@
import pytest import pytest
import tomllib
from pathlib import Path
from core.config_loader import Config from core.config_loader import Config
from core.config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel from core.config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel

View File

@@ -2,7 +2,7 @@ import asyncio
import pytest import pytest
from unittest.mock import MagicMock, patch, AsyncMock from unittest.mock import MagicMock, patch, AsyncMock
import docker import docker
from core.utils.executor import CodeExecutor, initialize_executor from core.utils.executor import CodeExecutor
# Mock 配置对象 # Mock 配置对象
@pytest.fixture @pytest.fixture
@@ -136,8 +136,10 @@ async def test_worker_docker_errors(executor):
await executor.task_queue.join() await executor.task_queue.join()
callback.assert_called_with(f"执行失败:沙箱基础镜像 '{executor.sandbox_image}' 不存在,请联系管理员构建。") callback.assert_called_with(f"执行失败:沙箱基础镜像 '{executor.sandbox_image}' 不存在,请联系管理员构建。")
worker_task.cancel() worker_task.cancel()
try: await worker_task try:
except: pass await worker_task
except Exception:
pass
# ContainerError # ContainerError
executor._run_in_container = MagicMock(side_effect=docker.errors.ContainerError( executor._run_in_container = MagicMock(side_effect=docker.errors.ContainerError(
@@ -150,8 +152,10 @@ async def test_worker_docker_errors(executor):
await executor.task_queue.join() await executor.task_queue.join()
callback.assert_called_with("代码执行出错:\nError output") callback.assert_called_with("代码执行出错:\nError output")
worker_task.cancel() worker_task.cancel()
try: await worker_task try:
except: pass await worker_task
except Exception:
pass
def test_run_in_container_success(executor): def test_run_in_container_success(executor):
"""测试 _run_in_container 成功""" """测试 _run_in_container 成功"""

View File

@@ -1,4 +1,3 @@
import pytest
from models.message import MessageSegment from models.message import MessageSegment
from models.objects import GroupInfo, StrangerInfo from models.objects import GroupInfo, StrangerInfo

View File

@@ -8,7 +8,6 @@
import asyncio import asyncio
import time import time
import pytest import pytest
from typing import Optional
# 导入性能分析工具 # 导入性能分析工具
from core.utils.performance import ( from core.utils.performance import (
@@ -66,7 +65,6 @@ class TestProfileContextManager:
"""测试同步代码的性能分析""" """测试同步代码的性能分析"""
# 捕获标准输出 # 捕获标准输出
import io import io
import sys
from contextlib import redirect_stdout from contextlib import redirect_stdout
f = io.StringIO() f = io.StringIO()
@@ -254,7 +252,7 @@ if __name__ == "__main__":
async_result = asyncio.run(test_async()) async_result = asyncio.run(test_async())
slow_result = slow_func() slow_result = slow_func()
print(f"\n测试结果:") print("\n测试结果:")
print(f"sync_result: {sync_result}") print(f"sync_result: {sync_result}")
print(f"async_result: {async_result}") print(f"async_result: {async_result}")
print(f"slow_result: {slow_result}") print(f"slow_result: {slow_result}")

View File

@@ -2,7 +2,6 @@
import sys import sys
import pytest import pytest
from unittest.mock import MagicMock, patch, call from unittest.mock import MagicMock, patch, call
import core.managers.plugin_manager as pm_module
from core.managers.plugin_manager import PluginManager from core.managers.plugin_manager import PluginManager
from core.managers.command_manager import CommandManager from core.managers.command_manager import CommandManager

View File

@@ -1,5 +1,5 @@
import pytest import pytest
from unittest.mock import MagicMock, patch, AsyncMock from unittest.mock import patch, AsyncMock
from core.managers.redis_manager import RedisManager from core.managers.redis_manager import RedisManager

View File

@@ -1,9 +1,7 @@
import pytest import pytest
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch from unittest.mock import MagicMock, AsyncMock, patch
from core.ws import WS from core.ws import WS
from core.bot import Bot from core.bot import Bot
from models.objects import GroupInfo, StrangerInfo
class TestWS: class TestWS:
@@ -38,7 +36,7 @@ class TestWS:
# 测试 WebSocket 未初始化的情况 # 测试 WebSocket 未初始化的情况
result = await ws.call_api("send_group_msg", {"group_id": 123456, "message": "test"}) result = await ws.call_api("send_group_msg", {"group_id": 123456, "message": "test"})
assert result["code"] == 2002 # WS_DISCONNECTED assert result["code"] == 2002 # WS_DISCONNECTED
assert result["success"] == False assert not result["success"]
assert "WebSocket未初始化" in result["message"] assert "WebSocket未初始化" in result["message"]
# 测试 WebSocket 已初始化但未连接的情况 # 测试 WebSocket 已初始化但未连接的情况
@@ -47,7 +45,7 @@ class TestWS:
ws.ws = mock_ws ws.ws = mock_ws
result = await ws.call_api("send_group_msg", {"group_id": 123456, "message": "test"}) result = await ws.call_api("send_group_msg", {"group_id": 123456, "message": "test"})
assert result["code"] == 2002 # WS_DISCONNECTED assert result["code"] == 2002 # WS_DISCONNECTED
assert result["success"] == False assert not result["success"]
assert "WebSocket连接未打开" in result["message"] assert "WebSocket连接未打开" in result["message"]
@pytest.mark.asyncio @pytest.mark.asyncio

234
tests/test_ws_pool.py Normal file
View File

@@ -0,0 +1,234 @@
"""
WebSocket 连接池测试模块
该模块包含对 WebSocket 连接池的单元测试和集成测试。
"""
import pytest
import asyncio
from unittest.mock import Mock, patch, MagicMock
from core.ws_pool import WSConnection, WSConnectionPool
from core.utils.exceptions import WebSocketError, WebSocketConnectionError
class TestWSConnection:
"""
WebSocket 连接包装类测试
"""
def test_connection_initialization(self):
"""测试连接初始化"""
mock_conn = Mock()
conn_id = "test-connection-id"
conn = WSConnection(mock_conn, conn_id)
assert conn.conn == mock_conn
assert conn.conn_id == conn_id
assert conn.is_active
assert conn._pending_requests == {}
assert isinstance(conn.last_used, float)
@pytest.mark.asyncio
async def test_send_data(self):
"""测试发送数据"""
mock_conn = Mock()
mock_conn.send = Mock(return_value=asyncio.coroutine(lambda x: None)())
conn = WSConnection(mock_conn, "test-id")
data = {"action": "test", "params": {}}
await conn.send(data)
mock_conn.send.assert_called_once()
assert conn.last_used > 0
@pytest.mark.asyncio
async def test_send_data_inactive_connection(self):
"""测试向已关闭的连接发送数据"""
mock_conn = Mock()
conn = WSConnection(mock_conn, "test-id")
conn.is_active = False
with pytest.raises(WebSocketError):
await conn.send({"action": "test"})
@pytest.mark.asyncio
async def test_recv_data(self):
"""测试接收数据"""
mock_conn = Mock()
mock_conn.recv = Mock(return_value=asyncio.coroutine(lambda: "test-data")())
conn = WSConnection(mock_conn, "test-id")
result = await conn.recv()
assert result == "test-data"
mock_conn.recv.assert_called_once()
@pytest.mark.asyncio
async def test_close_connection(self):
"""测试关闭连接"""
mock_conn = Mock()
mock_conn.close = Mock(return_value=asyncio.coroutine(lambda: None)())
conn = WSConnection(mock_conn, "test-id")
await conn.close()
assert not conn.is_active
mock_conn.close.assert_called_once()
class TestWSConnectionPool:
"""
WebSocket 连接池测试
"""
@pytest.mark.asyncio
async def test_pool_initialization(self):
"""测试连接池初始化"""
pool = WSConnectionPool(pool_size=2, max_idle_time=300)
assert pool.pool_size == 2
assert pool.max_idle_time == 300
assert not pool._closed
assert pool.pool is not None
@pytest.mark.asyncio
@patch('websockets.connect')
async def test_create_connection(self, mock_connect):
"""测试创建新连接"""
mock_websocket = Mock()
mock_connect.return_value = asyncio.coroutine(lambda: mock_websocket)()
pool = WSConnectionPool(pool_size=1)
conn = await pool._create_connection()
assert isinstance(conn, WSConnection)
assert conn.is_active
mock_connect.assert_called_once()
@pytest.mark.asyncio
@patch('websockets.connect')
async def test_pool_initialize(self, mock_connect):
"""测试连接池初始化"""
mock_websocket = Mock()
mock_connect.return_value = asyncio.coroutine(lambda: mock_websocket)()
pool = WSConnectionPool(pool_size=2)
await pool.initialize()
assert pool.pool.qsize() == 2
mock_connect.assert_called()
@pytest.mark.asyncio
@patch('websockets.connect')
async def test_get_connection(self, mock_connect):
"""测试从连接池获取连接"""
mock_websocket = Mock()
mock_connect.return_value = asyncio.coroutine(lambda: mock_websocket)()
pool = WSConnectionPool(pool_size=1)
await pool.initialize()
conn = await pool.get_connection()
assert isinstance(conn, WSConnection)
assert conn.is_active
assert pool.pool.qsize() == 0
@pytest.mark.asyncio
@patch('websockets.connect')
async def test_release_connection(self, mock_connect):
"""测试释放连接回连接池"""
mock_websocket = Mock()
mock_connect.return_value = asyncio.coroutine(lambda: mock_websocket)()
pool = WSConnectionPool(pool_size=1)
await pool.initialize()
conn = await pool.get_connection()
await pool.release_connection(conn)
assert pool.pool.qsize() == 1
@pytest.mark.asyncio
@patch('websockets.connect')
async def test_release_inactive_connection(self, mock_connect):
"""测试释放已关闭的连接"""
mock_websocket = Mock()
mock_connect.return_value = asyncio.coroutine(lambda: mock_websocket)()
pool = WSConnectionPool(pool_size=1)
await pool.initialize()
conn = await pool.get_connection()
conn.is_active = False
await pool.release_connection(conn)
assert pool.pool.qsize() == 0
@pytest.mark.asyncio
@patch('websockets.connect')
async def test_cleanup_idle_connections(self, mock_connect):
"""测试清理空闲连接"""
mock_websocket = Mock()
mock_connect.return_value = asyncio.coroutine(lambda: mock_websocket)()
pool = WSConnectionPool(pool_size=2, max_idle_time=0.1)
await pool.initialize()
# 等待清理任务执行
await asyncio.sleep(0.2)
# 检查连接池是否为空
assert pool.pool.qsize() == 0
@pytest.mark.asyncio
@patch('websockets.connect')
async def test_pool_close(self, mock_connect):
"""测试关闭连接池"""
mock_websocket = Mock()
mock_websocket.close = Mock(return_value=asyncio.coroutine(lambda: None)())
mock_connect.return_value = asyncio.coroutine(lambda: mock_websocket)()
pool = WSConnectionPool(pool_size=2)
await pool.initialize()
await pool.close()
assert pool._closed
assert pool.pool.qsize() == 0
mock_websocket.close.assert_called()
@pytest.mark.asyncio
async def test_get_connection_from_closed_pool(self):
"""测试从已关闭的连接池获取连接"""
pool = WSConnectionPool(pool_size=1)
pool._closed = True
with pytest.raises(WebSocketError):
await pool.get_connection()
@pytest.mark.asyncio
@patch('websockets.connect')
async def test_pool_with_max_size(self, mock_connect):
"""测试连接池大小限制"""
mock_websocket = Mock()
mock_connect.return_value = asyncio.coroutine(lambda: mock_websocket)()
pool = WSConnectionPool(pool_size=2)
await pool.initialize()
# 获取两个连接
conn1 = await pool.get_connection()
conn2 = await pool.get_connection()
# 第三个连接会创建临时连接
conn3 = await pool.get_connection()
# 释放所有连接
await pool.release_connection(conn1)
await pool.release_connection(conn2)
await pool.release_connection(conn3)
# 连接池应保持最大大小
assert pool.pool.qsize() == 2
if __name__ == "__main__":
pytest.main([__file__])