feat: 重构核心架构,增强类型安全与插件管理

本次提交对核心模块进行了深度重构,引入 Pydantic 增强配置管理的类型安全性,并全面优化了插件管理系统。

主要变更详情:

1. 核心架构与配置
   - 重构配置加载模块:引入 Pydantic 模型 (`core/config_models.py`),提供严格的配置项类型检查、验证及默认值管理。
   - 统一模块结构:规范化模块导入路径,移除冗余的 `__init__.py` 文件,提升项目结构的清晰度。
   - 性能优化:集成 Redis 缓存支持 (`RedisManager`),有效降低高频 API 调用开销,提升响应速度。

2. 插件系统升级
   - 实现热重载机制:新增插件文件变更监听功能,支持开发过程中自动重载插件,提升开发效率。
   - 优化生命周期管理:改进插件加载与卸载逻辑,支持精确卸载指定插件及其关联的命令、事件处理器和定时任务。

3. 功能特性增强
   - 新增媒体 API:引入 `MediaAPI` 模块,封装图片、语音等富媒体资源的获取与处理接口。
   - 完善权限体系:重构权限管理系统,实现管理员与操作员的分级控制,支持更细粒度的命令权限校验。

4. 代码质量与稳定性
   - 全面类型修复:解决 `mypy` 静态类型检查发现的大量类型错误(包括 `CommandManager`、`EventFactory` 及 `Bot` API 签名不匹配问题)。
   - 增强错误处理:优化消息处理管道的异常捕获机制,完善关键路径的日志记录,提升系统运行稳定性。
This commit is contained in:
2026-01-08 23:42:53 +08:00
parent c2de743098
commit 5d07a84283
35 changed files with 829 additions and 608 deletions

View File

@@ -1,6 +0,0 @@
from .managers.command_manager import matcher
from .config_loader import global_config
from .managers.plugin_manager import PluginDataManager
from .ws import WS
__all__ = ["WS", "matcher", "global_config", "PluginDataManager"]

View File

@@ -3,6 +3,7 @@ from .message import MessageAPI
from .group import GroupAPI
from .friend import FriendAPI
from .account import AccountAPI
from .media import MediaAPI
__all__ = [
"BaseAPI",
@@ -10,4 +11,5 @@ __all__ = [
"GroupAPI",
"FriendAPI",
"AccountAPI",
"MediaAPI",
]

View File

@@ -162,3 +162,56 @@ class AccountAPI(BaseAPI):
"""
return await self.call_api("clean_cache")
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> Any:
"""
获取陌生人信息。
Args:
user_id (int): 目标用户的 QQ 号。
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
Returns:
Any: 包含陌生人信息的字典或对象。
"""
return await self.call_api("get_stranger_info", {"user_id": user_id, "no_cache": no_cache})
async def get_friend_list(self, no_cache: bool = False) -> list:
"""
获取好友列表。
Args:
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
Returns:
list: 好友列表。
"""
cache_key = f"neobot:cache:get_friend_list:{self.self_id}"
if not no_cache:
cached_data = await redis_manager.get(cache_key)
if cached_data:
return json.loads(cached_data)
res = await self.call_api("get_friend_list")
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
return res
async def get_group_list(self, no_cache: bool = False) -> list:
"""
获取群列表。
Args:
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
Returns:
list: 群列表。
"""
cache_key = f"neobot:cache:get_group_list:{self.self_id}"
if not no_cache:
cached_data = await redis_manager.get(cache_key)
if cached_data:
return json.loads(cached_data)
res = await self.call_api("get_group_list")
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
return res

View File

@@ -1,24 +1,50 @@
"""
API 基础模块
定义了 API 调用的基础接口。
定义了 API 调用的基础接口和统一处理逻辑
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, TYPE_CHECKING
from ..utils.logger import logger
if TYPE_CHECKING:
from ..ws import WS
class BaseAPI(ABC):
class BaseAPI:
"""
API 基础抽象
API 基础类,提供了统一的 `call_api` 方法,包含日志记录和异常处理。
"""
_ws: "WS"
self_id: int
def __init__(self, ws_client: "WS", self_id: int):
self._ws = ws_client
self.self_id = self_id
@abstractmethod
async def call_api(self, action: str, params: Optional[Dict[str, Any]] = None) -> Any:
"""
调用 API
调用 OneBot v11 API并提供统一的日志和异常处理。
:param action: API 动作名称
:param params: API 参数
:return: API 响应结果
:return: API 响应结果的数据部分
:raises Exception: 当 API 调用失败或发生网络错误时
"""
raise NotImplementedError
if params is None:
params = {}
try:
logger.debug(f"调用API -> action: {action}, params: {params}")
response = await self._ws.call_api(action, params)
logger.debug(f"API响应 <- {response}")
if response.get("status") == "failed":
logger.warning(f"API调用失败: {response}")
return response.get("data")
except Exception as e:
logger.error(f"API调用异常: action={action}, params={params}, error={e}")
raise

View File

@@ -4,7 +4,7 @@
该模块定义了 `GroupAPI` Mixin 类,提供了所有与群组管理、成员操作
等相关的 OneBot v11 API 封装。
"""
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional
import json
from ..managers.redis_manager import redis_manager
from .base import BaseAPI
@@ -46,7 +46,7 @@ class GroupAPI(BaseAPI):
"""
return await self.call_api("set_group_ban", {"group_id": group_id, "user_id": user_id, "duration": duration})
async def set_group_anonymous_ban(self, group_id: int, anonymous: Dict[str, Any] = None, duration: int = 1800, flag: str = None) -> Dict[str, Any]:
async def set_group_anonymous_ban(self, group_id: int, anonymous: Optional[Dict[str, Any]] = None, duration: int = 1800, flag: Optional[str] = None) -> Dict[str, Any]:
"""
禁言群组中的匿名用户。
@@ -61,7 +61,7 @@ class GroupAPI(BaseAPI):
Returns:
Dict[str, Any]: OneBot API 的响应数据。
"""
params = {"group_id": group_id, "duration": duration}
params: Dict[str, Any] = {"group_id": group_id, "duration": duration}
if anonymous:
params["anonymous"] = anonymous
if flag:
@@ -187,17 +187,18 @@ class GroupAPI(BaseAPI):
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
return GroupInfo(**res)
async def get_group_list(self) -> List[GroupInfo]:
async def get_group_list(self) -> Any:
"""
获取机器人加入的所有群组的列表。
Returns:
List[GroupInfo]: 包含所有群组信息的 `GroupInfo` 对象列表。
Any: 包含所有群组信息的列表(可能是字典列表或对象列表
"""
res = await self.call_api("get_group_list")
# 增加日志记录 API 原始返回
logger.debug(f"OneBot API 'get_group_list' raw response: {res}")
return res
# 健壮性处理:处理标准的 OneBot v11 响应格式
if isinstance(res, dict) and res.get("status") == "ok":

39
core/api/media.py Normal file
View File

@@ -0,0 +1,39 @@
"""
媒体API模块
封装了与图片、语音等媒体文件相关的API。
"""
from typing import Dict, Any
from .base import BaseAPI
class MediaAPI(BaseAPI):
"""
媒体相关API
"""
async def can_send_image(self) -> Dict[str, Any]:
"""
检查是否可以发送图片
:return: OneBot v11标准响应
"""
return await self.call_api(action="can_send_image")
async def can_send_record(self) -> Dict[str, Any]:
"""
检查是否可以发送语音
:return: OneBot v11标准响应
"""
return await self.call_api(action="can_send_record")
async def get_image(self, file: str) -> Dict[str, Any]:
"""
获取图片信息
:param file: 图片文件名或路径
:return: OneBot v11标准响应
"""
return await self.call_api(action="get_image", params={"file": file})

View File

@@ -8,7 +8,8 @@ from typing import Union, List, Dict, Any, TYPE_CHECKING
from .base import BaseAPI
if TYPE_CHECKING:
from models import MessageSegment, OneBotEvent
from models.message import MessageSegment
from models.events.base import OneBotEvent
class MessageAPI(BaseAPI):
@@ -156,24 +157,6 @@ class MessageAPI(BaseAPI):
"""
return await self.call_api("send_private_forward_msg", {"user_id": user_id, "messages": messages})
async def can_send_image(self) -> Dict[str, Any]:
"""
检查当前机器人账号是否可以发送图片。
Returns:
Dict[str, Any]: OneBot API 的响应数据。
"""
return await self.call_api("can_send_image")
async def can_send_record(self) -> Dict[str, Any]:
"""
检查当前机器人账号是否可以发送语音。
Returns:
Dict[str, Any]: OneBot API 的响应数据。
"""
return await self.call_api("can_send_record")
def _process_message(self, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Union[str, List[Dict[str, Any]]]:
"""
内部方法:将消息内容处理成 OneBot API 可接受的格式。
@@ -192,7 +175,7 @@ class MessageAPI(BaseAPI):
return message
# 避免循环导入,在运行时导入
from models import MessageSegment
from models.message import MessageSegment
if isinstance(message, MessageSegment):
return [self._segment_to_dict(message)]

View File

@@ -13,14 +13,15 @@ Bot 核心抽象模块
from typing import TYPE_CHECKING, Dict, Any, List, Union
from models.events.base import OneBotEvent
from models.message import MessageSegment
from models.objects import GroupInfo, StrangerInfo
if TYPE_CHECKING:
from .WS import WS
from .ws import WS
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI
class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI):
class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI):
"""
机器人核心类,封装了所有与 OneBot API 的交互。
@@ -35,22 +36,22 @@ class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI):
Args:
ws_client (WS): WebSocket 客户端实例,负责底层的 API 请求和响应处理。
"""
self.ws = ws_client
super().__init__(ws_client, ws_client.self_id or 0)
self.code_executor = None
async def call_api(self, action: str, params: Dict[str, Any] = None) -> Any:
"""
底层 API 调用方法。
async def get_group_list(self, no_cache: bool = False) -> List[GroupInfo]:
# GroupAPI.get_group_list 不支持 no_cache 参数,这里忽略它
result = await super().get_group_list()
# 确保结果是 GroupInfo 对象列表
return [GroupInfo(**group) if isinstance(group, dict) else group for group in result]
所有具体的 API 实现最终都会调用此方法,通过 WebSocket 发送请求。
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> StrangerInfo:
result = await super().get_stranger_info(user_id=user_id, no_cache=no_cache)
# 确保结果是 StrangerInfo 对象
if isinstance(result, dict):
return StrangerInfo(**result)
return result
Args:
action (str): API 的动作名称,例如 "send_group_msg"
params (Dict[str, Any], optional): API 请求的参数字典。Defaults to None.
Returns:
Any: OneBot API 的响应数据。
"""
return await self.ws.call_api(action, params)
def build_forward_node(self, user_id: int, nickname: str, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Dict[str, Any]:
"""

View File

@@ -4,9 +4,11 @@
负责读取和解析 config.toml 配置文件,提供全局配置对象。
"""
from pathlib import Path
from typing import Any, Dict
import tomllib
from pydantic import ValidationError
from .config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel
from .utils.logger import logger
class Config:
@@ -21,73 +23,67 @@ class Config:
:param file_path: 配置文件路径,默认为 "config.toml"
"""
self.path = Path(file_path)
self._data: Dict[str, Any] = {}
self._model: ConfigModel
self.load()
def load(self):
"""
加载配置文件
加载并验证配置文件
:raises FileNotFoundError: 如果配置文件不存在
:raises ValidationError: 如果配置格式不正确
"""
if not self.path.exists():
logger.error(f"配置文件 {self.path} 未找到!")
raise FileNotFoundError(f"配置文件 {self.path} 未找到!")
with open(self.path, "rb") as f:
self._data = tomllib.load(f)
try:
logger.info(f"正在从 {self.path} 加载配置...")
with open(self.path, "rb") as f:
raw_config = tomllib.load(f)
self._model = ConfigModel(**raw_config)
logger.success("配置加载并验证成功!")
except ValidationError as e:
logger.error("配置验证失败,请检查 `config.toml` 文件中的以下错误:")
for error in e.errors():
field = " -> ".join(map(str, error["loc"]))
logger.error(f" - 字段 '{field}': {error['msg']}")
raise
except Exception as e:
logger.exception(f"加载配置文件时发生未知错误: {e}")
raise
# 通过属性访问配置
@property
def napcat_ws(self) -> dict:
def napcat_ws(self) -> NapCatWSModel:
"""
获取 NapCat WebSocket 配置
:return: 配置字典
"""
return self._data.get("napcat_ws", {})
return self._model.napcat_ws
@property
def bot(self) -> dict:
def bot(self) -> BotModel:
"""
获取 Bot 基础配置
:return: 配置字典
"""
return self._data.get("bot", {})
return self._model.bot
@property
def features(self) -> dict:
"""
获取功能特性配置
:return: 配置字典
"""
return self._data.get("features", {})
@property
def redis(self) -> dict:
def redis(self) -> RedisModel:
"""
获取 Redis 配置
:return: 配置字典
"""
return self._data.get("redis", {})
return self._model.redis
@property
def docker(self) -> dict:
def docker(self) -> DockerModel:
"""
获取 Docker 配置
:return: 配置字典
"""
return self._data.get("docker", {})
return self._model.docker
# 实例化全局配置对象
global_config = Config()
if __name__ == "__main__":
print(global_config.napcat_ws)
print(global_config.bot.get("command"))
print(type(global_config.bot.get("command")) is list)
print(global_config.features)

60
core/config_models.py Normal file
View File

@@ -0,0 +1,60 @@
"""
Pydantic 配置模型模块
该模块使用 Pydantic 定义了与 `config.toml` 文件结构完全对应的配置模型。
这使得配置的加载、校验和访问都变得类型安全和健壮。
"""
from typing import List, Optional
from pydantic import BaseModel, Field
class NapCatWSModel(BaseModel):
"""
对应 `config.toml` 中的 `[napcat_ws]` 配置块。
"""
uri: str
token: str
reconnect_interval: int = 5
class BotModel(BaseModel):
"""
对应 `config.toml` 中的 `[bot]` 配置块。
"""
command: List[str] = Field(default_factory=lambda: ["/"])
ignore_self_message: bool = True
permission_denied_message: str = "权限不足,需要 {permission_name} 权限"
class RedisModel(BaseModel):
"""
对应 `config.toml` 中的 `[redis]` 配置块。
"""
host: str
port: int
db: int
password: str
class DockerModel(BaseModel):
"""
对应 `config.toml` 中的 `[docker]` 配置块。
"""
base_url: Optional[str] = None
sandbox_image: str = "python-sandbox:latest"
timeout: int = 10
concurrency_limit: int = 5
tls_verify: bool = False
ca_cert_path: Optional[str] = None
client_cert_path: Optional[str] = None
client_key_path: Optional[str] = None
class ConfigModel(BaseModel):
"""
顶层配置模型,整合了所有子配置块。
"""
napcat_ws: NapCatWSModel
bot: BotModel
redis: RedisModel
docker: DockerModel

View File

@@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
from ..bot import Bot
from ..config_loader import global_config
from ..managers.permission_manager import Permission, permission_manager
from ..managers.permission_manager import Permission
from ..utils.executor import run_in_thread_pool
@@ -41,7 +41,7 @@ class BaseHandler(ABC):
"""
sig = inspect.signature(func)
params = sig.parameters
kwargs = {}
kwargs: Dict[str, Any] = {}
if "bot" in params:
kwargs["bot"] = bot
@@ -68,14 +68,35 @@ class MessageHandler(BaseHandler):
super().__init__()
self.prefixes = prefixes
self.commands: Dict[str, Dict] = {}
self.message_handlers: List[Callable] = []
self.message_handlers: List[Dict[str, Any]] = []
def clear(self):
"""
清空所有已注册的消息和命令处理器
"""
self.commands.clear()
self.message_handlers.clear()
def unregister_by_plugin_name(self, plugin_name: str):
"""
根据插件名卸载相关的消息和命令处理器
"""
# 卸载命令
commands_to_remove = [name for name, info in self.commands.items() if info["plugin_name"] == plugin_name]
for name in commands_to_remove:
del self.commands[name]
# 卸载通用消息处理器
self.message_handlers = [h for h in self.message_handlers if h["plugin_name"] != plugin_name]
def on_message(self) -> Callable:
"""
注册通用消息处理器
"""
def decorator(func: Callable) -> Callable:
self.message_handlers.append(func)
module = inspect.getmodule(func)
plugin_name = module.__name__ if module else "Unknown"
self.message_handlers.append({"func": func, "plugin_name": plugin_name})
return func
return decorator
@@ -89,21 +110,25 @@ class MessageHandler(BaseHandler):
注册命令处理器
"""
def decorator(func: Callable) -> Callable:
module = inspect.getmodule(func)
plugin_name = module.__name__ if module else "Unknown"
for name in names:
self.commands[name] = {
"func": func,
"permission": permission,
"override_permission_check": override_permission_check,
"plugin_name": plugin_name,
}
return func
return decorator
async def handle(self, bot: Bot, event: Any):
"""
处理消息事件,包括通用消息和命令
处理消息事件,分发给命令处理器或通用消息处理器
"""
for handler in self.message_handlers:
consumed = await self._run_handler(handler, bot, event)
from ..managers import permission_manager
for handler_info in self.message_handlers:
consumed = await self._run_handler(handler_info["func"], bot, event)
if consumed:
return
@@ -135,7 +160,7 @@ class MessageHandler(BaseHandler):
if not permission_granted and not override_check:
permission_name = permission.name if isinstance(permission, Permission) else permission
message_template = global_config.bot.get("permission_denied_message", "权限不足,需要 {permission_name} 权限")
message_template = global_config.bot.permission_denied_message
await bot.send(event, message_template.format(permission_name=permission_name))
return
@@ -152,12 +177,23 @@ class NoticeHandler(BaseHandler):
"""
通知事件处理器
"""
def clear(self):
self.handlers.clear()
def unregister_by_plugin_name(self, plugin_name: str):
"""
根据插件名卸载相关的通知处理器
"""
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
def register(self, notice_type: Optional[str] = None) -> Callable:
"""
注册通知处理器
"""
def decorator(func: Callable) -> Callable:
self.handlers.append({"type": notice_type, "func": func})
module = inspect.getmodule(func)
plugin_name = module.__name__ if module else "Unknown"
self.handlers.append({"type": notice_type, "func": func, "plugin_name": plugin_name})
return func
return decorator
@@ -174,12 +210,23 @@ class RequestHandler(BaseHandler):
"""
请求事件处理器
"""
def clear(self):
self.handlers.clear()
def unregister_by_plugin_name(self, plugin_name: str):
"""
根据插件名卸载相关的请求处理器
"""
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
def register(self, request_type: Optional[str] = None) -> Callable:
"""
注册请求处理器
"""
def decorator(func: Callable) -> Callable:
self.handlers.append({"type": request_type, "func": func})
module = inspect.getmodule(func)
plugin_name = module.__name__ if module else "Unknown"
self.handlers.append({"type": request_type, "func": func, "plugin_name": plugin_name})
return func
return decorator

View File

@@ -0,0 +1,40 @@
"""
管理器包
这个包集中了机器人核心的单例管理器。
通过从这里导入,可以确保在整个应用中访问到的都是同一个实例。
"""
from ..config_loader import global_config
from .admin_manager import AdminManager
from .command_manager import CommandManager
from .permission_manager import PermissionManager
from .plugin_manager import PluginManager
from .redis_manager import RedisManager
# --- 实例化所有单例管理器 ---
# 管理员管理器
admin_manager = AdminManager()
# 权限管理器
permission_manager = PermissionManager()
# 命令与事件管理器 (别名 matcher)
command_manager = CommandManager(prefixes=tuple(global_config.bot.command))
matcher = command_manager
# 插件管理器
plugin_manager = PluginManager(command_manager)
plugin_manager.load_all_plugins()
# Redis 管理器
redis_manager = RedisManager()
__all__ = [
"admin_manager",
"permission_manager",
"command_manager",
"matcher",
"plugin_manager",
"redis_manager",
]

View File

@@ -12,7 +12,16 @@ from ..handlers.event_handler import MessageHandler, NoticeHandler, RequestHandl
# 从配置中获取命令前缀
command_prefixes = global_config.bot.get("command", ("/",))
_config_prefixes = global_config.bot.command
# 确保前缀配置是元组格式
_final_prefixes: Tuple[str, ...]
if isinstance(_config_prefixes, list):
_final_prefixes = tuple(_config_prefixes)
elif isinstance(_config_prefixes, str):
_final_prefixes = (_config_prefixes,)
else:
_final_prefixes = tuple(_config_prefixes)
class CommandManager:
@@ -59,6 +68,35 @@ class CommandManager:
"usage": "/help",
}
def clear_all_handlers(self):
"""
清空所有已注册的事件处理器。
注意:这也会移除内置的 /help 命令,因此需要重新注册。
"""
self.message_handler.clear()
self.notice_handler.clear()
self.request_handler.clear()
self.plugins.clear()
# 清空后,需要重新注册内置命令
self._register_internal_commands()
def unload_plugin(self, plugin_name: str):
"""
卸载指定插件的所有处理器和命令。
Args:
plugin_name (str): 插件的模块名 (例如 'plugins.bili_parser')
"""
self.message_handler.unregister_by_plugin_name(plugin_name)
self.notice_handler.unregister_by_plugin_name(plugin_name)
self.request_handler.unregister_by_plugin_name(plugin_name)
# 移除插件元信息
plugins_to_remove = [name for name in self.plugins if name.startswith(plugin_name)]
for name in plugins_to_remove:
del self.plugins[name]
# --- 装饰器代理 ---
def on_message(self) -> Callable:
@@ -102,7 +140,7 @@ class CommandManager:
根据事件的 `post_type` 将其分发给对应的处理器。
"""
if event.post_type == 'message' and global_config.bot.get('ignore_self_message', False):
if event.post_type == 'message' and global_config.bot.ignore_self_message:
if hasattr(event, 'user_id') and hasattr(event, 'self_id') and event.user_id == event.self_id:
return
@@ -130,14 +168,6 @@ class CommandManager:
await bot.send(event, help_text.strip())
# --- 全局单例 ---
# 确保前缀配置是元组格式
if isinstance(command_prefixes, list):
command_prefixes = tuple(command_prefixes)
elif isinstance(command_prefixes, str):
command_prefixes = (command_prefixes,)
# 实例化全局唯一的命令管理器
matcher = CommandManager(prefixes=command_prefixes)
matcher = CommandManager(prefixes=_final_prefixes)

View File

@@ -13,64 +13,17 @@
"""
import json
import os
from functools import total_ordering
from typing import Dict
from typing import Dict, Optional
from ..utils.logger import logger
from ..utils.singleton import Singleton
from .admin_manager import admin_manager
from ..permission import Permission
@total_ordering
class Permission:
"""
权限封装类
封装了权限的名称和等级,并提供了比较方法。
使用 @total_ordering 装饰器可以自动生成所有的比较运算符。
"""
def __init__(self, name: str, level: int):
"""
初始化权限对象
Args:
name (str): 权限名称 (e.g., "admin", "op")
level (int): 权限等级,数字越大权限越高
"""
self.name = name
self.level = level
def __eq__(self, other):
"""
判断权限是否相等
"""
if not isinstance(other, Permission):
return NotImplemented
return self.level == other.level
def __lt__(self, other):
"""
判断权限是否小于另一个权限
"""
if not isinstance(other, Permission):
return NotImplemented
return self.level < other.level
def __str__(self) -> str:
"""
返回权限的字符串表示(即权限名称)
"""
return self.name
# 定义全局权限常量
ADMIN = Permission("admin", 3)
OP = Permission("op", 2)
USER = Permission("user", 1)
# 用于从字符串名称查找权限对象的字典
_PERMISSIONS: Dict[str, Permission] = {
p.name: p for p in [ADMIN, OP, USER]
p.value: p for p in Permission
}
@@ -89,7 +42,7 @@ class PermissionManager(Singleton):
如果已经初始化过,则直接返回。
"""
super().__init__()
if not self._initialized:
if hasattr(self, '_initialized') and self._initialized:
return
# 权限数据文件路径
@@ -111,6 +64,7 @@ class PermissionManager(Singleton):
self.load()
logger.info("权限管理器初始化完成")
self._initialized = True
def load(self) -> None:
"""
@@ -164,12 +118,12 @@ class PermissionManager(Singleton):
"""
# 首先,通过 AdminManager 检查是否为管理员
if await admin_manager.is_admin(user_id):
return ADMIN
return Permission.ADMIN
# 如果不是管理员,则从 permissions.json 中查找
user_id_str = str(user_id)
level_name = self._data["users"].get(user_id_str, USER.name)
return _PERMISSIONS.get(level_name, USER)
level_name = self._data["users"].get(user_id_str, Permission.USER.value)
return _PERMISSIONS.get(level_name, Permission.USER)
def set_user_permission(self, user_id: int, permission: Permission) -> None:
"""
@@ -182,13 +136,13 @@ class PermissionManager(Singleton):
Raises:
ValueError: 如果权限对象无效
"""
if not isinstance(permission, Permission) or permission.name not in _PERMISSIONS:
if not isinstance(permission, Permission):
raise ValueError(f"无效的权限对象: {permission}")
user_id_str = str(user_id)
self._data["users"][user_id_str] = permission.name
self._data["users"][user_id_str] = permission.value
self.save()
logger.info(f"设置用户 {user_id} 的权限级别为 {permission.name}")
logger.info(f"设置用户 {user_id} 的权限级别为 {permission.value}")
def remove_user(self, user_id: int) -> None:
"""
@@ -214,17 +168,17 @@ class PermissionManager(Singleton):
Returns:
bool: 如果用户权限 >= 所需权限,返回 True否则返回 False
"""
# 如果传入的是字符串,先转换为 Permission 对象
if isinstance(required_permission, str):
required_permission = _PERMISSIONS.get(required_permission.lower())
if not required_permission:
# 如果是无效的权限字符串,默认拒绝
logger.warning(f"检测到无效的权限检查字符串: {required_permission}")
return False
user_permission = await self.get_user_permission(user_id)
return user_permission >= required_permission
def get_all_user_permissions(self) -> Dict[str, str]:
"""
获取所有已配置的用户权限
:return: 一个包含所有用户权限的字典
"""
return self._data["users"].copy()
def get_all_users(self) -> Dict[str, str]:
"""
获取所有设置了权限的用户及其级别名称
@@ -243,22 +197,22 @@ class PermissionManager(Singleton):
logger.info("已清空所有权限设置")
# 全局权限管理器实例
permission_manager = PermissionManager()
def require_admin(func):
"""
一个装饰器,用于限制命令只能由管理员执行。
"""
from functools import wraps
from models.events.message import MessageEvent
from core.managers import permission_manager
@wraps(func)
async def wrapper(event: MessageEvent, *args, **kwargs):
user_id = event.user_id
if await permission_manager.check_permission(user_id, ADMIN):
if await permission_manager.check_permission(user_id, Permission.ADMIN):
return await func(event, *args, **kwargs)
else:
await event.reply("抱歉,您没有权限执行此命令。")
# 假设 event 对象有 reply 方法
if hasattr(event, "reply"):
await event.reply("抱歉,您没有权限执行此命令。")
return None
return wrapper

View File

@@ -1,126 +1,88 @@
"""
插件管理器模块
负责扫描、加载和管理 `base_plugins` 目录下的所有插件。
负责扫描、加载和管理 `plugins` 目录下的所有插件。
"""
import importlib
import json
import os
import pkgutil
import sys
from typing import Set
from .command_manager import matcher
from ..utils.exceptions import SyncHandlerError
from ..utils.logger import logger
from ..utils.executor import run_in_thread_pool
def load_all_plugins():
class PluginManager:
"""
扫描并加载 `plugins` 目录下的所有插件。
该函数会遍历 `plugins` 目录下的所有模块:
1. 如果模块已加载,则执行 reload 操作(用于热重载)。
2. 如果模块未加载,则执行 import 操作。
加载过程中会提取插件元数据 `__plugin_meta__` 并注册到 CommandManager。
插件管理器类
"""
plugin_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "..", "plugins"
)
package_name = "plugins"
def __init__(self, command_manager):
"""
初始化插件管理器
logger.info(f"正在从 {package_name} 加载插件...")
:param command_manager: CommandManager的实例
"""
self.command_manager = command_manager
self.loaded_plugins: Set[str] = set()
for loader, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
full_module_name = f"{package_name}.{module_name}"
def load_all_plugins(self):
"""
扫描并加载 `plugins` 目录下的所有插件。
"""
plugin_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "..", "plugins"
)
package_name = "plugins"
logger.info(f"正在从 {package_name} 加载插件...")
for _, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
full_module_name = f"{package_name}.{module_name}"
try:
if full_module_name in self.loaded_plugins:
self.command_manager.unload_plugin(full_module_name)
module = importlib.reload(sys.modules[full_module_name])
action = "重载"
else:
module = importlib.import_module(full_module_name)
action = "加载"
if hasattr(module, "__plugin_meta__"):
meta = getattr(module, "__plugin_meta__")
self.command_manager.plugins[full_module_name] = meta
self.loaded_plugins.add(full_module_name)
type_str = "" if is_pkg else "文件"
logger.success(f" [{type_str}] 成功{action}: {module_name}")
except SyncHandlerError as e:
logger.error(f" 插件 {module_name} 加载失败: {e} (跳过此插件)")
except Exception as e:
logger.exception(
f" {action if 'action' in locals() else '加载'}插件 {module_name} 失败: {e}"
)
def reload_plugin(self, full_module_name: str):
"""
精确重载单个插件。
"""
if full_module_name not in self.loaded_plugins:
logger.warning(f"尝试重载一个未被加载的插件: {full_module_name},将按首次加载处理。")
if full_module_name not in sys.modules:
logger.error(f"重载失败: 模块 {full_module_name} 未在 sys.modules 中找到。")
return
try:
if full_module_name in sys.modules:
module = importlib.reload(sys.modules[full_module_name])
action = "重载"
else:
module = importlib.import_module(full_module_name)
action = "加载"
# 提取插件元数据
self.command_manager.unload_plugin(full_module_name)
module = importlib.reload(sys.modules[full_module_name])
if hasattr(module, "__plugin_meta__"):
meta = getattr(module, "__plugin_meta__")
matcher.plugins[full_module_name] = meta
type_str = "" if is_pkg else "文件"
logger.success(f" [{type_str}] 成功{action}: {module_name}")
except SyncHandlerError as e:
logger.error(f" 插件 {module_name} 加载失败: {e} (跳过此插件)")
self.command_manager.plugins[full_module_name] = meta
logger.success(f"插件 {full_module_name} 已成功重载。")
except Exception as e:
print(
f" {action if 'action' in locals() else '加载'}插件 {module_name} 失败: {e}"
)
class PluginDataManager:
"""
用于管理插件产生的数据文件的类
"""
def __init__(self, plugin_name: str):
"""
初始化插件数据管理器
:param plugin_name: 插件名称
"""
self.plugin_name = plugin_name
self.data_file = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"plugins",
"data",
self.plugin_name + ".json",
)
self.data = {}
async def load(self):
"""读取配置文件"""
if not os.path.exists(self.data_file):
await self.set(self.plugin_name, [])
try:
with open(self.data_file, "r", encoding="utf-8") as f:
self.data = await run_in_thread_pool(json.load, f)
except json.JSONDecodeError:
self.data = {}
async def save(self):
"""保存配置到文件"""
with open(self.data_file, "w", encoding="utf-8") as f:
await run_in_thread_pool(json.dump, self.data, f, indent=2, ensure_ascii=False)
def get(self, key, default=None):
"""获取配置项"""
return self.data.get(key, default)
async def set(self, key, value):
"""设置配置项"""
self.data[key] = value
await self.save()
async def add(self, key, value):
"""添加配置项"""
if key not in self.data:
self.data[key] = []
self.data[key].append(value)
await self.save()
async def remove(self, key):
"""删除配置项"""
if key in self.data:
del self.data[key]
await self.save()
async def clear(self):
"""清空所有配置"""
self.data.clear()
await self.save()
def get_all(self):
return self.data.copy()
logger.exception(f"重载插件 {full_module_name} 时发生错误: {e}")

View File

@@ -20,10 +20,11 @@ class RedisManager:
"""
if self._redis is None:
try:
host = config.redis['host']
port = config.redis['port']
db = config.redis['db']
password = config.redis.get('password')
redis_config = config.redis
host = redis_config.host
port = redis_config.port
db = redis_config.db
password = redis_config.password
logger.info(f"正在尝试连接 Redis: {host}:{port}, DB: {db}")
@@ -54,5 +55,17 @@ class RedisManager:
raise ConnectionError("Redis 未初始化或连接失败,请先调用 initialize()")
return self._redis
async def get(self, name):
"""
获取指定键的值
"""
return await self.redis.get(name)
async def set(self, name, value, ex=None):
"""
设置指定键的值
"""
return await self.redis.set(name, value, ex=ex)
# 全局 Redis 管理器实例
redis_manager = RedisManager()

45
core/permission.py Normal file
View File

@@ -0,0 +1,45 @@
from enum import Enum
from functools import total_ordering
@total_ordering
class Permission(Enum):
"""
定义用户权限等级的枚举类。
使用 @total_ordering 装饰器,只需定义 __lt__ 和 __eq__
即可自动实现所有比较运算符。
"""
USER = "user"
OP = "op"
ADMIN = "admin"
@property
def _level_map(self):
"""
内部属性,用于映射枚举成员到整数等级。
"""
return {
Permission.USER: 1,
Permission.OP: 2,
Permission.ADMIN: 3
}
def __lt__(self, other):
"""
比较当前权限是否小于另一个权限。
"""
if not isinstance(other, Permission):
return NotImplemented
return self._level_map[self] < self._level_map[other]
def __eq__(self, other):
if not isinstance(other, Permission):
return NotImplemented
return self is other
def __ge__(self, other):
if not isinstance(other, Permission):
return NotImplemented
return self._level_map[self] >= self._level_map[other]

View File

@@ -2,7 +2,8 @@
import asyncio
import docker
from docker.tls import TLSConfig
from typing import Dict, Any, Callable
from docker.types import LogConfig
from typing import Any, Callable
from core.utils.logger import logger
@@ -10,21 +11,20 @@ class CodeExecutor:
"""
代码执行引擎,负责管理一个异步任务队列和并发的 Docker 容器执行。
"""
def __init__(self, bot_instance, config: Dict[str, Any]):
def __init__(self, config: Any):
"""
初始化代码执行引擎。
:param bot_instance: Bot 实例,用于后续的消息回复
:param config: 从 config.toml 加载的配置字典。
:param config: 从 config_loader.py 加载的全局配置对象
"""
self.bot = bot_instance
self.task_queue = asyncio.Queue()
self.bot: Any = None # Bot 实例将在 WS 连接成功后动态注入
self.task_queue: asyncio.Queue = asyncio.Queue()
# 从传入的配置中读取 Docker 相关设置
docker_config = config.docker
self.docker_base_url = docker_config.get("base_url")
self.sandbox_image = docker_config.get("sandbox_image", "python-sandbox:latest")
self.timeout = docker_config.get("timeout", 10)
concurrency = docker_config.get("concurrency_limit", 5)
self.docker_base_url = docker_config.base_url
self.sandbox_image = docker_config.sandbox_image
self.timeout = docker_config.timeout
concurrency = docker_config.concurrency_limit
self.concurrency_limit = asyncio.Semaphore(concurrency)
self.docker_client = None
@@ -34,10 +34,10 @@ class CodeExecutor:
if self.docker_base_url:
# 如果配置了远程 Docker 地址,则使用 TLS 选项进行连接
tls_config = None
if docker_config.get("tls_verify", False):
if docker_config.tls_verify:
tls_config = TLSConfig(
ca_cert=docker_config.get("ca_cert_path"),
client_cert=(docker_config.get("client_cert_path"), docker_config.get("client_key_path")),
ca_cert=docker_config.ca_cert_path,
client_cert=(docker_config.client_cert_path, docker_config.client_key_path),
verify=True
)
self.docker_client = docker.DockerClient(base_url=self.docker_base_url, tls=tls_config)
@@ -125,6 +125,9 @@ class CodeExecutor:
同步函数:在 Docker 容器中运行代码。
此函数通过手动管理容器生命周期来提高稳定性。
"""
if self.docker_client is None:
raise docker.errors.DockerException("Docker client is not initialized.")
container = None
try:
# 1. 创建容器
@@ -134,7 +137,7 @@ class CodeExecutor:
mem_limit='128m',
cpu_shares=512,
network_disabled=True,
log_config={'type': 'json-file', 'config': {'max-size': '1m'}},
log_config=LogConfig(type='json-file', config={'max-size': '1m'}),
)
# 2. 启动容器
container.start()
@@ -150,7 +153,7 @@ class CodeExecutor:
# 5. 检查退出码,如果不为 0则手动抛出 ContainerError
if result.get('StatusCode', 0) != 0:
raise docker.errors.ContainerError(
container, result['StatusCode'], f"python -c '{code}'", self.sandbox_image, stderr
container, result['StatusCode'], f"python -c '{code}'", self.sandbox_image, stderr.decode('utf-8')
)
return stdout
@@ -166,11 +169,11 @@ class CodeExecutor:
except Exception as e:
logger.error(f"[CodeExecutor] 强制移除容器 {container.id} 时失败: {e}")
def initialize_executor(bot_instance, config: Dict[str, Any]):
def initialize_executor(config: Any):
"""
初始化并返回一个 CodeExecutor 实例。
"""
return CodeExecutor(bot_instance, config)
return CodeExecutor(config)
async def run_in_thread_pool(sync_func, *args, **kwargs):
"""

View File

@@ -13,11 +13,12 @@ WebSocket 连接。它是整个机器人框架的底层通信基础。
"""
import asyncio
import json
from typing import Any, Dict, Optional
import uuid
import websockets
from models import EventFactory
from models.events.factory import EventFactory
from .bot import Bot
from .config_loader import global_config
@@ -30,7 +31,7 @@ class WS:
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
"""
def __init__(self):
def __init__(self, code_executor=None):
"""
初始化 WebSocket 客户端。
@@ -38,13 +39,15 @@ class WS:
"""
# 读取参数
cfg = global_config.napcat_ws
self.url = cfg.get("uri")
self.token = cfg.get("token")
self.reconnect_interval = cfg.get("reconnect_interval", 5)
self.url = cfg.uri
self.token = cfg.token
self.reconnect_interval = cfg.reconnect_interval
self.ws = None
self._pending_requests = {}
self.bot = Bot(self)
self.bot: Bot | None = None
self.self_id: int | None = None
self.code_executor = code_executor
async def connect(self):
"""
@@ -124,18 +127,42 @@ class WS:
try:
# 使用工厂创建事件对象
event = EventFactory.create_event(event_data)
# 在收到第一个 meta_event 时,初始化 Bot 实例
if event.post_type == "meta_event" and self.bot is None:
self.self_id = event.self_id
self.bot = Bot(self)
logger.success(f"Bot 实例初始化完成: self_id={self.self_id}")
# 将代码执行器注入到 Bot 和执行器自身
if self.code_executor:
self.bot.code_executor = self.code_executor
self.code_executor.bot = self.bot
logger.info("代码执行器已成功注入 Bot 实例。")
# 如果 bot 尚未初始化,则不处理后续事件
if self.bot is None:
logger.warning("Bot 尚未初始化,跳过事件处理。")
return
event.bot = self.bot # 注入 Bot 实例
# 打印日志
if event.post_type == "message":
sender_name = event.sender.nickname if event.sender else "Unknown"
logger.info(f"[消息] {event.message_type} | {event.user_id}({sender_name}): {event.raw_message}")
sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown"
message_type = getattr(event, "message_type", "Unknown")
user_id = getattr(event, "user_id", "Unknown")
raw_message = getattr(event, "raw_message", "")
logger.info(f"[消息] {message_type} | {user_id}({sender_name}): {raw_message}")
elif event.post_type == "notice":
logger.info(f"[通知] {event.notice_type}")
notice_type = getattr(event, "notice_type", "Unknown")
logger.info(f"[通知] {notice_type}")
elif event.post_type == "request":
logger.info(f"[请求] {event.request_type}")
request_type = getattr(event, "request_type", "Unknown")
logger.info(f"[请求] {request_type}")
elif event.post_type == "meta_event":
logger.debug(f"[元事件] {event.meta_event_type}")
meta_event_type = getattr(event, "meta_event_type", "Unknown")
logger.debug(f"[元事件] {meta_event_type}")
# 分发事件
@@ -144,7 +171,7 @@ class WS:
except Exception as e:
logger.exception(f"事件处理异常: {e}")
async def call_api(self, action: str, params: dict = None):
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
"""
向 OneBot v11 实现端发送一个 API 请求。