Dev (#28)
* 滚木 * feat: 重构核心架构,增强类型安全与插件管理 本次提交对核心模块进行了深度重构,引入 Pydantic 增强配置管理的类型安全性,并全面优化了插件管理系统。 主要变更详情: 1. 核心架构与配置 - 重构配置加载模块:引入 Pydantic 模型 (`core/config_models.py`),提供严格的配置项类型检查、验证及默认值管理。 - 统一模块结构:规范化模块导入路径,移除冗余的 `__init__.py` 文件,提升项目结构的清晰度。 - 性能优化:集成 Redis 缓存支持 (`RedisManager`),有效降低高频 API 调用开销,提升响应速度。 2. 插件系统升级 - 实现热重载机制:新增插件文件变更监听功能,支持开发过程中自动重载插件,提升开发效率。 - 优化生命周期管理:改进插件加载与卸载逻辑,支持精确卸载指定插件及其关联的命令、事件处理器和定时任务。 3. 功能特性增强 - 新增媒体 API:引入 `MediaAPI` 模块,封装图片、语音等富媒体资源的获取与处理接口。 - 完善权限体系:重构权限管理系统,实现管理员与操作员的分级控制,支持更细粒度的命令权限校验。 4. 代码质量与稳定性 - 全面类型修复:解决 `mypy` 静态类型检查发现的大量类型错误(包括 `CommandManager`、`EventFactory` 及 `Bot` API 签名不匹配问题)。 - 增强错误处理:优化消息处理管道的异常捕获机制,完善关键路径的日志记录,提升系统运行稳定性。 * feat: 添加测试用例并优化代码结构 refactor(permission_manager): 调整初始化顺序和逻辑 fix(admin_manager): 修复初始化逻辑和目录创建问题 feat(ws): 优化Bot实例初始化条件 feat(message): 增强MessageSegment功能并添加测试 feat(events): 支持字符串格式的消息解析 test: 添加核心功能测试用例 refactor(plugin_manager): 改进插件路径处理 style: 清理无用导入和代码 chore: 更新依赖项
This commit is contained in:
@@ -1,6 +0,0 @@
|
||||
from .managers.command_manager import matcher
|
||||
from .config_loader import global_config
|
||||
from .managers.plugin_manager import PluginDataManager
|
||||
from .ws import WS
|
||||
|
||||
__all__ = ["WS", "matcher", "global_config", "PluginDataManager"]
|
||||
|
||||
@@ -3,6 +3,7 @@ from .message import MessageAPI
|
||||
from .group import GroupAPI
|
||||
from .friend import FriendAPI
|
||||
from .account import AccountAPI
|
||||
from .media import MediaAPI
|
||||
|
||||
__all__ = [
|
||||
"BaseAPI",
|
||||
@@ -10,4 +11,5 @@ __all__ = [
|
||||
"GroupAPI",
|
||||
"FriendAPI",
|
||||
"AccountAPI",
|
||||
"MediaAPI",
|
||||
]
|
||||
|
||||
@@ -162,3 +162,56 @@ class AccountAPI(BaseAPI):
|
||||
"""
|
||||
return await self.call_api("clean_cache")
|
||||
|
||||
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> Any:
|
||||
"""
|
||||
获取陌生人信息。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Any: 包含陌生人信息的字典或对象。
|
||||
"""
|
||||
return await self.call_api("get_stranger_info", {"user_id": user_id, "no_cache": no_cache})
|
||||
|
||||
async def get_friend_list(self, no_cache: bool = False) -> list:
|
||||
"""
|
||||
获取好友列表。
|
||||
|
||||
Args:
|
||||
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||
|
||||
Returns:
|
||||
list: 好友列表。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_friend_list:{self.self_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.get(cache_key)
|
||||
if cached_data:
|
||||
return json.loads(cached_data)
|
||||
|
||||
res = await self.call_api("get_friend_list")
|
||||
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return res
|
||||
|
||||
async def get_group_list(self, no_cache: bool = False) -> list:
|
||||
"""
|
||||
获取群列表。
|
||||
|
||||
Args:
|
||||
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||
|
||||
Returns:
|
||||
list: 群列表。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_group_list:{self.self_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.get(cache_key)
|
||||
if cached_data:
|
||||
return json.loads(cached_data)
|
||||
|
||||
res = await self.call_api("get_group_list")
|
||||
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return res
|
||||
|
||||
|
||||
@@ -1,24 +1,50 @@
|
||||
"""
|
||||
API 基础模块
|
||||
|
||||
定义了 API 调用的基础接口。
|
||||
定义了 API 调用的基础接口和统一处理逻辑。
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from ..utils.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..ws import WS
|
||||
|
||||
|
||||
class BaseAPI(ABC):
|
||||
class BaseAPI:
|
||||
"""
|
||||
API 基础抽象类
|
||||
API 基础类,提供了统一的 `call_api` 方法,包含日志记录和异常处理。
|
||||
"""
|
||||
_ws: "WS"
|
||||
self_id: int
|
||||
|
||||
def __init__(self, ws_client: "WS", self_id: int):
|
||||
self._ws = ws_client
|
||||
self.self_id = self_id
|
||||
|
||||
@abstractmethod
|
||||
async def call_api(self, action: str, params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
调用 API
|
||||
调用 OneBot v11 API,并提供统一的日志和异常处理。
|
||||
|
||||
:param action: API 动作名称
|
||||
:param params: API 参数
|
||||
:return: API 响应结果
|
||||
:return: API 响应结果的数据部分
|
||||
:raises Exception: 当 API 调用失败或发生网络错误时
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
try:
|
||||
logger.debug(f"调用API -> action: {action}, params: {params}")
|
||||
response = await self._ws.call_api(action, params)
|
||||
logger.debug(f"API响应 <- {response}")
|
||||
|
||||
if response.get("status") == "failed":
|
||||
logger.warning(f"API调用失败: {response}")
|
||||
|
||||
return response.get("data")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API调用异常: action={action}, params={params}, error={e}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
该模块定义了 `GroupAPI` Mixin 类,提供了所有与群组管理、成员操作
|
||||
等相关的 OneBot v11 API 封装。
|
||||
"""
|
||||
from typing import List, Dict, Any
|
||||
from typing import List, Dict, Any, Optional
|
||||
import json
|
||||
from ..managers.redis_manager import redis_manager
|
||||
from .base import BaseAPI
|
||||
@@ -46,7 +46,7 @@ class GroupAPI(BaseAPI):
|
||||
"""
|
||||
return await self.call_api("set_group_ban", {"group_id": group_id, "user_id": user_id, "duration": duration})
|
||||
|
||||
async def set_group_anonymous_ban(self, group_id: int, anonymous: Dict[str, Any] = None, duration: int = 1800, flag: str = None) -> Dict[str, Any]:
|
||||
async def set_group_anonymous_ban(self, group_id: int, anonymous: Optional[Dict[str, Any]] = None, duration: int = 1800, flag: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
禁言群组中的匿名用户。
|
||||
|
||||
@@ -61,7 +61,7 @@ class GroupAPI(BaseAPI):
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
params = {"group_id": group_id, "duration": duration}
|
||||
params: Dict[str, Any] = {"group_id": group_id, "duration": duration}
|
||||
if anonymous:
|
||||
params["anonymous"] = anonymous
|
||||
if flag:
|
||||
@@ -187,17 +187,18 @@ class GroupAPI(BaseAPI):
|
||||
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return GroupInfo(**res)
|
||||
|
||||
async def get_group_list(self) -> List[GroupInfo]:
|
||||
async def get_group_list(self) -> Any:
|
||||
"""
|
||||
获取机器人加入的所有群组的列表。
|
||||
|
||||
Returns:
|
||||
List[GroupInfo]: 包含所有群组信息的 `GroupInfo` 对象列表。
|
||||
Any: 包含所有群组信息的列表(可能是字典列表或对象列表)。
|
||||
"""
|
||||
res = await self.call_api("get_group_list")
|
||||
|
||||
# 增加日志记录 API 原始返回
|
||||
logger.debug(f"OneBot API 'get_group_list' raw response: {res}")
|
||||
return res
|
||||
|
||||
# 健壮性处理:处理标准的 OneBot v11 响应格式
|
||||
if isinstance(res, dict) and res.get("status") == "ok":
|
||||
|
||||
39
core/api/media.py
Normal file
39
core/api/media.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
媒体API模块
|
||||
|
||||
封装了与图片、语音等媒体文件相关的API。
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
|
||||
from .base import BaseAPI
|
||||
|
||||
|
||||
class MediaAPI(BaseAPI):
|
||||
"""
|
||||
媒体相关API
|
||||
"""
|
||||
|
||||
async def can_send_image(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查是否可以发送图片
|
||||
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="can_send_image")
|
||||
|
||||
async def can_send_record(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查是否可以发送语音
|
||||
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="can_send_record")
|
||||
|
||||
async def get_image(self, file: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取图片信息
|
||||
|
||||
:param file: 图片文件名或路径
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="get_image", params={"file": file})
|
||||
@@ -8,7 +8,8 @@ from typing import Union, List, Dict, Any, TYPE_CHECKING
|
||||
from .base import BaseAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import MessageSegment, OneBotEvent
|
||||
from models.message import MessageSegment
|
||||
from models.events.base import OneBotEvent
|
||||
|
||||
|
||||
class MessageAPI(BaseAPI):
|
||||
@@ -156,24 +157,6 @@ class MessageAPI(BaseAPI):
|
||||
"""
|
||||
return await self.call_api("send_private_forward_msg", {"user_id": user_id, "messages": messages})
|
||||
|
||||
async def can_send_image(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查当前机器人账号是否可以发送图片。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("can_send_image")
|
||||
|
||||
async def can_send_record(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查当前机器人账号是否可以发送语音。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("can_send_record")
|
||||
|
||||
def _process_message(self, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Union[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
内部方法:将消息内容处理成 OneBot API 可接受的格式。
|
||||
@@ -192,7 +175,7 @@ class MessageAPI(BaseAPI):
|
||||
return message
|
||||
|
||||
# 避免循环导入,在运行时导入
|
||||
from models import MessageSegment
|
||||
from models.message import MessageSegment
|
||||
|
||||
if isinstance(message, MessageSegment):
|
||||
return [self._segment_to_dict(message)]
|
||||
|
||||
33
core/bot.py
33
core/bot.py
@@ -13,14 +13,15 @@ Bot 核心抽象模块
|
||||
from typing import TYPE_CHECKING, Dict, Any, List, Union
|
||||
from models.events.base import OneBotEvent
|
||||
from models.message import MessageSegment
|
||||
from models.objects import GroupInfo, StrangerInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .WS import WS
|
||||
from .ws import WS
|
||||
|
||||
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI
|
||||
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI
|
||||
|
||||
|
||||
class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI):
|
||||
class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI):
|
||||
"""
|
||||
机器人核心类,封装了所有与 OneBot API 的交互。
|
||||
|
||||
@@ -35,22 +36,22 @@ class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI):
|
||||
Args:
|
||||
ws_client (WS): WebSocket 客户端实例,负责底层的 API 请求和响应处理。
|
||||
"""
|
||||
self.ws = ws_client
|
||||
super().__init__(ws_client, ws_client.self_id or 0)
|
||||
self.code_executor = None
|
||||
|
||||
async def call_api(self, action: str, params: Dict[str, Any] = None) -> Any:
|
||||
"""
|
||||
底层 API 调用方法。
|
||||
async def get_group_list(self, no_cache: bool = False) -> List[GroupInfo]:
|
||||
# GroupAPI.get_group_list 不支持 no_cache 参数,这里忽略它
|
||||
result = await super().get_group_list()
|
||||
# 确保结果是 GroupInfo 对象列表
|
||||
return [GroupInfo(**group) if isinstance(group, dict) else group for group in result]
|
||||
|
||||
所有具体的 API 实现最终都会调用此方法,通过 WebSocket 发送请求。
|
||||
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> StrangerInfo:
|
||||
result = await super().get_stranger_info(user_id=user_id, no_cache=no_cache)
|
||||
# 确保结果是 StrangerInfo 对象
|
||||
if isinstance(result, dict):
|
||||
return StrangerInfo(**result)
|
||||
return result
|
||||
|
||||
Args:
|
||||
action (str): API 的动作名称,例如 "send_group_msg"。
|
||||
params (Dict[str, Any], optional): API 请求的参数字典。Defaults to None.
|
||||
|
||||
Returns:
|
||||
Any: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.ws.call_api(action, params)
|
||||
|
||||
def build_forward_node(self, user_id: int, nickname: str, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -4,9 +4,11 @@
|
||||
负责读取和解析 config.toml 配置文件,提供全局配置对象。
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import tomllib
|
||||
from pydantic import ValidationError
|
||||
from .config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel
|
||||
from .utils.logger import logger
|
||||
|
||||
|
||||
class Config:
|
||||
@@ -21,73 +23,67 @@ class Config:
|
||||
:param file_path: 配置文件路径,默认为 "config.toml"
|
||||
"""
|
||||
self.path = Path(file_path)
|
||||
self._data: Dict[str, Any] = {}
|
||||
self._model: ConfigModel
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
"""
|
||||
加载配置文件
|
||||
加载并验证配置文件
|
||||
|
||||
:raises FileNotFoundError: 如果配置文件不存在
|
||||
:raises ValidationError: 如果配置格式不正确
|
||||
"""
|
||||
if not self.path.exists():
|
||||
logger.error(f"配置文件 {self.path} 未找到!")
|
||||
raise FileNotFoundError(f"配置文件 {self.path} 未找到!")
|
||||
|
||||
with open(self.path, "rb") as f:
|
||||
self._data = tomllib.load(f)
|
||||
try:
|
||||
logger.info(f"正在从 {self.path} 加载配置...")
|
||||
with open(self.path, "rb") as f:
|
||||
raw_config = tomllib.load(f)
|
||||
|
||||
self._model = ConfigModel(**raw_config)
|
||||
logger.success("配置加载并验证成功!")
|
||||
|
||||
except ValidationError as e:
|
||||
logger.error("配置验证失败,请检查 `config.toml` 文件中的以下错误:")
|
||||
for error in e.errors():
|
||||
field = " -> ".join(map(str, error["loc"]))
|
||||
logger.error(f" - 字段 '{field}': {error['msg']}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"加载配置文件时发生未知错误: {e}")
|
||||
raise
|
||||
|
||||
# 通过属性访问配置
|
||||
@property
|
||||
def napcat_ws(self) -> dict:
|
||||
def napcat_ws(self) -> NapCatWSModel:
|
||||
"""
|
||||
获取 NapCat WebSocket 配置
|
||||
|
||||
:return: 配置字典
|
||||
"""
|
||||
return self._data.get("napcat_ws", {})
|
||||
return self._model.napcat_ws
|
||||
|
||||
@property
|
||||
def bot(self) -> dict:
|
||||
def bot(self) -> BotModel:
|
||||
"""
|
||||
获取 Bot 基础配置
|
||||
|
||||
:return: 配置字典
|
||||
"""
|
||||
return self._data.get("bot", {})
|
||||
return self._model.bot
|
||||
|
||||
@property
|
||||
def features(self) -> dict:
|
||||
"""
|
||||
获取功能特性配置
|
||||
|
||||
:return: 配置字典
|
||||
"""
|
||||
return self._data.get("features", {})
|
||||
|
||||
@property
|
||||
def redis(self) -> dict:
|
||||
def redis(self) -> RedisModel:
|
||||
"""
|
||||
获取 Redis 配置
|
||||
|
||||
:return: 配置字典
|
||||
"""
|
||||
return self._data.get("redis", {})
|
||||
return self._model.redis
|
||||
|
||||
@property
|
||||
def docker(self) -> dict:
|
||||
def docker(self) -> DockerModel:
|
||||
"""
|
||||
获取 Docker 配置
|
||||
|
||||
:return: 配置字典
|
||||
"""
|
||||
return self._data.get("docker", {})
|
||||
return self._model.docker
|
||||
|
||||
|
||||
# 实例化全局配置对象
|
||||
global_config = Config()
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(global_config.napcat_ws)
|
||||
print(global_config.bot.get("command"))
|
||||
print(type(global_config.bot.get("command")) is list)
|
||||
print(global_config.features)
|
||||
|
||||
60
core/config_models.py
Normal file
60
core/config_models.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Pydantic 配置模型模块
|
||||
|
||||
该模块使用 Pydantic 定义了与 `config.toml` 文件结构完全对应的配置模型。
|
||||
这使得配置的加载、校验和访问都变得类型安全和健壮。
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NapCatWSModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[napcat_ws]` 配置块。
|
||||
"""
|
||||
uri: str
|
||||
token: str
|
||||
reconnect_interval: int = 5
|
||||
|
||||
|
||||
class BotModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[bot]` 配置块。
|
||||
"""
|
||||
command: List[str] = Field(default_factory=lambda: ["/"])
|
||||
ignore_self_message: bool = True
|
||||
permission_denied_message: str = "权限不足,需要 {permission_name} 权限"
|
||||
|
||||
|
||||
class RedisModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[redis]` 配置块。
|
||||
"""
|
||||
host: str
|
||||
port: int
|
||||
db: int
|
||||
password: str
|
||||
|
||||
|
||||
class DockerModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[docker]` 配置块。
|
||||
"""
|
||||
base_url: Optional[str] = None
|
||||
sandbox_image: str = "python-sandbox:latest"
|
||||
timeout: int = 10
|
||||
concurrency_limit: int = 5
|
||||
tls_verify: bool = False
|
||||
ca_cert_path: Optional[str] = None
|
||||
client_cert_path: Optional[str] = None
|
||||
client_key_path: Optional[str] = None
|
||||
|
||||
|
||||
class ConfigModel(BaseModel):
|
||||
"""
|
||||
顶层配置模型,整合了所有子配置块。
|
||||
"""
|
||||
napcat_ws: NapCatWSModel
|
||||
bot: BotModel
|
||||
redis: RedisModel
|
||||
docker: DockerModel
|
||||
@@ -1,3 +1,3 @@
|
||||
{
|
||||
"admins": []
|
||||
"admins": [2221577113]
|
||||
}
|
||||
@@ -6,11 +6,12 @@
|
||||
"""
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from ..bot import Bot
|
||||
if TYPE_CHECKING:
|
||||
from ..bot import Bot
|
||||
from ..config_loader import global_config
|
||||
from ..managers.permission_manager import Permission, permission_manager
|
||||
from ..permission import Permission
|
||||
from ..utils.executor import run_in_thread_pool
|
||||
|
||||
|
||||
@@ -22,7 +23,7 @@ class BaseHandler(ABC):
|
||||
self.handlers: List[Dict[str, Any]] = []
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, bot: Bot, event: Any):
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理事件
|
||||
"""
|
||||
@@ -31,7 +32,7 @@ class BaseHandler(ABC):
|
||||
async def _run_handler(
|
||||
self,
|
||||
func: Callable,
|
||||
bot: Bot,
|
||||
bot: "Bot",
|
||||
event: Any,
|
||||
args: Optional[List[str]] = None,
|
||||
permission_granted: Optional[bool] = None
|
||||
@@ -41,7 +42,7 @@ class BaseHandler(ABC):
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
params = sig.parameters
|
||||
kwargs = {}
|
||||
kwargs: Dict[str, Any] = {}
|
||||
|
||||
if "bot" in params:
|
||||
kwargs["bot"] = bot
|
||||
@@ -68,21 +69,41 @@ class MessageHandler(BaseHandler):
|
||||
super().__init__()
|
||||
self.prefixes = prefixes
|
||||
self.commands: Dict[str, Dict] = {}
|
||||
self.message_handlers: List[Callable] = []
|
||||
self.message_handlers: List[Dict[str, Any]] = []
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
清空所有已注册的消息和命令处理器
|
||||
"""
|
||||
self.commands.clear()
|
||||
self.message_handlers.clear()
|
||||
|
||||
def unregister_by_plugin_name(self, plugin_name: str):
|
||||
"""
|
||||
根据插件名卸载相关的消息和命令处理器
|
||||
"""
|
||||
# 卸载命令
|
||||
commands_to_remove = [name for name, info in self.commands.items() if info["plugin_name"] == plugin_name]
|
||||
for name in commands_to_remove:
|
||||
del self.commands[name]
|
||||
|
||||
# 卸载通用消息处理器
|
||||
self.message_handlers = [h for h in self.message_handlers if h["plugin_name"] != plugin_name]
|
||||
|
||||
def on_message(self) -> Callable:
|
||||
"""
|
||||
注册通用消息处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self.message_handlers.append(func)
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
self.message_handlers.append({"func": func, "plugin_name": plugin_name})
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def command(
|
||||
self,
|
||||
*names: str,
|
||||
*names: str,
|
||||
permission: Optional[Permission] = None,
|
||||
override_permission_check: bool = False
|
||||
) -> Callable:
|
||||
@@ -90,21 +111,25 @@ class MessageHandler(BaseHandler):
|
||||
注册命令处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
for name in names:
|
||||
self.commands[name] = {
|
||||
"func": func,
|
||||
"permission": permission,
|
||||
"override_permission_check": override_permission_check,
|
||||
"plugin_name": plugin_name,
|
||||
}
|
||||
return func
|
||||
return decorator
|
||||
|
||||
async def handle(self, bot: Bot, event: Any):
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理消息事件,包括通用消息和命令
|
||||
处理消息事件,分发给命令处理器或通用消息处理器
|
||||
"""
|
||||
for handler in self.message_handlers:
|
||||
consumed = await self._run_handler(handler, bot, event)
|
||||
from ..managers import permission_manager
|
||||
for handler_info in self.message_handlers:
|
||||
consumed = await self._run_handler(handler_info["func"], bot, event)
|
||||
if consumed:
|
||||
return
|
||||
|
||||
@@ -136,7 +161,7 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
if not permission_granted and not override_check:
|
||||
permission_name = permission.name if isinstance(permission, Permission) else permission
|
||||
message_template = global_config.bot.get("permission_denied_message", "权限不足,需要 {permission_name} 权限")
|
||||
message_template = global_config.bot.permission_denied_message
|
||||
await bot.send(event, message_template.format(permission_name=permission_name))
|
||||
return
|
||||
|
||||
@@ -153,12 +178,23 @@ class NoticeHandler(BaseHandler):
|
||||
"""
|
||||
通知事件处理器
|
||||
"""
|
||||
def clear(self):
|
||||
self.handlers.clear()
|
||||
|
||||
def unregister_by_plugin_name(self, plugin_name: str):
|
||||
"""
|
||||
根据插件名卸载相关的通知处理器
|
||||
"""
|
||||
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
|
||||
|
||||
def register(self, notice_type: Optional[str] = None) -> Callable:
|
||||
"""
|
||||
注册通知处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self.handlers.append({"type": notice_type, "func": func})
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
self.handlers.append({"type": notice_type, "func": func, "plugin_name": plugin_name})
|
||||
return func
|
||||
return decorator
|
||||
|
||||
@@ -175,12 +211,23 @@ class RequestHandler(BaseHandler):
|
||||
"""
|
||||
请求事件处理器
|
||||
"""
|
||||
def clear(self):
|
||||
self.handlers.clear()
|
||||
|
||||
def unregister_by_plugin_name(self, plugin_name: str):
|
||||
"""
|
||||
根据插件名卸载相关的请求处理器
|
||||
"""
|
||||
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
|
||||
|
||||
def register(self, request_type: Optional[str] = None) -> Callable:
|
||||
"""
|
||||
注册请求处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self.handlers.append({"type": request_type, "func": func})
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
self.handlers.append({"type": request_type, "func": func, "plugin_name": plugin_name})
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
管理器包
|
||||
|
||||
这个包集中了机器人核心的单例管理器。
|
||||
通过从这里导入,可以确保在整个应用中访问到的都是同一个实例。
|
||||
"""
|
||||
from ..config_loader import global_config
|
||||
from .admin_manager import AdminManager
|
||||
from .command_manager import CommandManager
|
||||
from .permission_manager import PermissionManager
|
||||
from .plugin_manager import PluginManager
|
||||
from .redis_manager import RedisManager
|
||||
|
||||
# --- 实例化所有单例管理器 ---
|
||||
|
||||
# 管理员管理器
|
||||
admin_manager = AdminManager()
|
||||
|
||||
# 权限管理器
|
||||
permission_manager = PermissionManager()
|
||||
|
||||
# 命令与事件管理器 (别名 matcher)
|
||||
command_manager = CommandManager(prefixes=tuple(global_config.bot.command))
|
||||
matcher = command_manager
|
||||
|
||||
# 插件管理器
|
||||
plugin_manager = PluginManager(command_manager)
|
||||
plugin_manager.load_all_plugins()
|
||||
|
||||
# Redis 管理器
|
||||
redis_manager = RedisManager()
|
||||
|
||||
__all__ = [
|
||||
"admin_manager",
|
||||
"permission_manager",
|
||||
"command_manager",
|
||||
"matcher",
|
||||
"plugin_manager",
|
||||
"redis_manager",
|
||||
]
|
||||
|
||||
@@ -26,8 +26,7 @@ class AdminManager(Singleton):
|
||||
"""
|
||||
初始化 AdminManager
|
||||
"""
|
||||
super().__init__()
|
||||
if not self._initialized:
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
|
||||
# 管理员数据文件路径
|
||||
@@ -39,7 +38,12 @@ class AdminManager(Singleton):
|
||||
)
|
||||
|
||||
self._admins: Set[int] = set()
|
||||
|
||||
# 确保数据目录存在
|
||||
os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
|
||||
|
||||
logger.info("管理员管理器初始化完成")
|
||||
super().__init__()
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
|
||||
@@ -12,7 +12,16 @@ from ..handlers.event_handler import MessageHandler, NoticeHandler, RequestHandl
|
||||
|
||||
|
||||
# 从配置中获取命令前缀
|
||||
command_prefixes = global_config.bot.get("command", ("/",))
|
||||
_config_prefixes = global_config.bot.command
|
||||
|
||||
# 确保前缀配置是元组格式
|
||||
_final_prefixes: Tuple[str, ...]
|
||||
if isinstance(_config_prefixes, list):
|
||||
_final_prefixes = tuple(_config_prefixes)
|
||||
elif isinstance(_config_prefixes, str):
|
||||
_final_prefixes = (_config_prefixes,)
|
||||
else:
|
||||
_final_prefixes = tuple(_config_prefixes)
|
||||
|
||||
|
||||
class CommandManager:
|
||||
@@ -59,6 +68,35 @@ class CommandManager:
|
||||
"usage": "/help",
|
||||
}
|
||||
|
||||
def clear_all_handlers(self):
|
||||
"""
|
||||
清空所有已注册的事件处理器。
|
||||
注意:这也会移除内置的 /help 命令,因此需要重新注册。
|
||||
"""
|
||||
self.message_handler.clear()
|
||||
self.notice_handler.clear()
|
||||
self.request_handler.clear()
|
||||
self.plugins.clear()
|
||||
|
||||
# 清空后,需要重新注册内置命令
|
||||
self._register_internal_commands()
|
||||
|
||||
def unload_plugin(self, plugin_name: str):
|
||||
"""
|
||||
卸载指定插件的所有处理器和命令。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件的模块名 (例如 'plugins.bili_parser')
|
||||
"""
|
||||
self.message_handler.unregister_by_plugin_name(plugin_name)
|
||||
self.notice_handler.unregister_by_plugin_name(plugin_name)
|
||||
self.request_handler.unregister_by_plugin_name(plugin_name)
|
||||
|
||||
# 移除插件元信息
|
||||
plugins_to_remove = [name for name in self.plugins if name.startswith(plugin_name)]
|
||||
for name in plugins_to_remove:
|
||||
del self.plugins[name]
|
||||
|
||||
# --- 装饰器代理 ---
|
||||
|
||||
def on_message(self) -> Callable:
|
||||
@@ -102,7 +140,7 @@ class CommandManager:
|
||||
|
||||
根据事件的 `post_type` 将其分发给对应的处理器。
|
||||
"""
|
||||
if event.post_type == 'message' and global_config.bot.get('ignore_self_message', False):
|
||||
if event.post_type == 'message' and global_config.bot.ignore_self_message:
|
||||
if hasattr(event, 'user_id') and hasattr(event, 'self_id') and event.user_id == event.self_id:
|
||||
return
|
||||
|
||||
@@ -130,14 +168,6 @@ class CommandManager:
|
||||
await bot.send(event, help_text.strip())
|
||||
|
||||
|
||||
# --- 全局单例 ---
|
||||
|
||||
# 确保前缀配置是元组格式
|
||||
if isinstance(command_prefixes, list):
|
||||
command_prefixes = tuple(command_prefixes)
|
||||
elif isinstance(command_prefixes, str):
|
||||
command_prefixes = (command_prefixes,)
|
||||
|
||||
# 实例化全局唯一的命令管理器
|
||||
matcher = CommandManager(prefixes=command_prefixes)
|
||||
matcher = CommandManager(prefixes=_final_prefixes)
|
||||
|
||||
|
||||
@@ -13,64 +13,17 @@
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from functools import total_ordering
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..utils.logger import logger
|
||||
from ..utils.singleton import Singleton
|
||||
from .admin_manager import admin_manager
|
||||
from ..permission import Permission
|
||||
|
||||
|
||||
@total_ordering
|
||||
class Permission:
|
||||
"""
|
||||
权限封装类
|
||||
|
||||
封装了权限的名称和等级,并提供了比较方法。
|
||||
使用 @total_ordering 装饰器可以自动生成所有的比较运算符。
|
||||
"""
|
||||
def __init__(self, name: str, level: int):
|
||||
"""
|
||||
初始化权限对象
|
||||
|
||||
Args:
|
||||
name (str): 权限名称 (e.g., "admin", "op")
|
||||
level (int): 权限等级,数字越大权限越高
|
||||
"""
|
||||
self.name = name
|
||||
self.level = level
|
||||
|
||||
def __eq__(self, other):
|
||||
"""
|
||||
判断权限是否相等
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self.level == other.level
|
||||
|
||||
def __lt__(self, other):
|
||||
"""
|
||||
判断权限是否小于另一个权限
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self.level < other.level
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
返回权限的字符串表示(即权限名称)
|
||||
"""
|
||||
return self.name
|
||||
|
||||
|
||||
# 定义全局权限常量
|
||||
ADMIN = Permission("admin", 3)
|
||||
OP = Permission("op", 2)
|
||||
USER = Permission("user", 1)
|
||||
|
||||
# 用于从字符串名称查找权限对象的字典
|
||||
_PERMISSIONS: Dict[str, Permission] = {
|
||||
p.name: p for p in [ADMIN, OP, USER]
|
||||
p.value: p for p in Permission
|
||||
}
|
||||
|
||||
|
||||
@@ -88,8 +41,7 @@ class PermissionManager(Singleton):
|
||||
|
||||
如果已经初始化过,则直接返回。
|
||||
"""
|
||||
super().__init__()
|
||||
if not self._initialized:
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
|
||||
# 权限数据文件路径
|
||||
@@ -111,6 +63,7 @@ class PermissionManager(Singleton):
|
||||
self.load()
|
||||
|
||||
logger.info("权限管理器初始化完成")
|
||||
super().__init__()
|
||||
|
||||
def load(self) -> None:
|
||||
"""
|
||||
@@ -164,12 +117,12 @@ class PermissionManager(Singleton):
|
||||
"""
|
||||
# 首先,通过 AdminManager 检查是否为管理员
|
||||
if await admin_manager.is_admin(user_id):
|
||||
return ADMIN
|
||||
return Permission.ADMIN
|
||||
|
||||
# 如果不是管理员,则从 permissions.json 中查找
|
||||
user_id_str = str(user_id)
|
||||
level_name = self._data["users"].get(user_id_str, USER.name)
|
||||
return _PERMISSIONS.get(level_name, USER)
|
||||
level_name = self._data["users"].get(user_id_str, Permission.USER.value)
|
||||
return _PERMISSIONS.get(level_name, Permission.USER)
|
||||
|
||||
def set_user_permission(self, user_id: int, permission: Permission) -> None:
|
||||
"""
|
||||
@@ -182,13 +135,13 @@ class PermissionManager(Singleton):
|
||||
Raises:
|
||||
ValueError: 如果权限对象无效
|
||||
"""
|
||||
if not isinstance(permission, Permission) or permission.name not in _PERMISSIONS:
|
||||
if not isinstance(permission, Permission):
|
||||
raise ValueError(f"无效的权限对象: {permission}")
|
||||
|
||||
user_id_str = str(user_id)
|
||||
self._data["users"][user_id_str] = permission.name
|
||||
self._data["users"][user_id_str] = permission.value
|
||||
self.save()
|
||||
logger.info(f"设置用户 {user_id} 的权限级别为 {permission.name}")
|
||||
logger.info(f"设置用户 {user_id} 的权限级别为 {permission.value}")
|
||||
|
||||
def remove_user(self, user_id: int) -> None:
|
||||
"""
|
||||
@@ -214,17 +167,17 @@ class PermissionManager(Singleton):
|
||||
Returns:
|
||||
bool: 如果用户权限 >= 所需权限,返回 True,否则返回 False
|
||||
"""
|
||||
# 如果传入的是字符串,先转换为 Permission 对象
|
||||
if isinstance(required_permission, str):
|
||||
required_permission = _PERMISSIONS.get(required_permission.lower())
|
||||
if not required_permission:
|
||||
# 如果是无效的权限字符串,默认拒绝
|
||||
logger.warning(f"检测到无效的权限检查字符串: {required_permission}")
|
||||
return False
|
||||
|
||||
user_permission = await self.get_user_permission(user_id)
|
||||
return user_permission >= required_permission
|
||||
|
||||
def get_all_user_permissions(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取所有已配置的用户权限
|
||||
|
||||
:return: 一个包含所有用户权限的字典
|
||||
"""
|
||||
return self._data["users"].copy()
|
||||
|
||||
def get_all_users(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取所有设置了权限的用户及其级别名称
|
||||
@@ -243,22 +196,22 @@ class PermissionManager(Singleton):
|
||||
logger.info("已清空所有权限设置")
|
||||
|
||||
|
||||
# 全局权限管理器实例
|
||||
permission_manager = PermissionManager()
|
||||
|
||||
def require_admin(func):
|
||||
"""
|
||||
一个装饰器,用于限制命令只能由管理员执行。
|
||||
"""
|
||||
from functools import wraps
|
||||
from models.events.message import MessageEvent
|
||||
from core.managers import permission_manager
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(event: MessageEvent, *args, **kwargs):
|
||||
user_id = event.user_id
|
||||
if await permission_manager.check_permission(user_id, ADMIN):
|
||||
if await permission_manager.check_permission(user_id, Permission.ADMIN):
|
||||
return await func(event, *args, **kwargs)
|
||||
else:
|
||||
await event.reply("抱歉,您没有权限执行此命令。")
|
||||
# 假设 event 对象有 reply 方法
|
||||
if hasattr(event, "reply"):
|
||||
await event.reply("抱歉,您没有权限执行此命令。")
|
||||
return None
|
||||
return wrapper
|
||||
|
||||
@@ -1,126 +1,97 @@
|
||||
"""
|
||||
插件管理器模块
|
||||
|
||||
负责扫描、加载和管理 `base_plugins` 目录下的所有插件。
|
||||
负责扫描、加载和管理 `plugins` 目录下的所有插件。
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
from typing import Set
|
||||
|
||||
from .command_manager import matcher
|
||||
from ..utils.exceptions import SyncHandlerError
|
||||
from ..utils.logger import logger
|
||||
from ..utils.executor import run_in_thread_pool
|
||||
|
||||
|
||||
def load_all_plugins():
|
||||
class PluginManager:
|
||||
"""
|
||||
扫描并加载 `plugins` 目录下的所有插件。
|
||||
|
||||
该函数会遍历 `plugins` 目录下的所有模块:
|
||||
1. 如果模块已加载,则执行 reload 操作(用于热重载)。
|
||||
2. 如果模块未加载,则执行 import 操作。
|
||||
|
||||
加载过程中会提取插件元数据 `__plugin_meta__` 并注册到 CommandManager。
|
||||
插件管理器类
|
||||
"""
|
||||
plugin_dir = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "..", "plugins"
|
||||
)
|
||||
package_name = "plugins"
|
||||
def __init__(self, command_manager):
|
||||
"""
|
||||
初始化插件管理器
|
||||
|
||||
logger.info(f"正在从 {package_name} 加载插件...")
|
||||
:param command_manager: CommandManager的实例
|
||||
"""
|
||||
self.command_manager = command_manager
|
||||
self.loaded_plugins: Set[str] = set()
|
||||
|
||||
for loader, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
|
||||
full_module_name = f"{package_name}.{module_name}"
|
||||
def load_all_plugins(self):
|
||||
"""
|
||||
扫描并加载 `plugins` 目录下的所有插件。
|
||||
"""
|
||||
# 使用 pathlib 获取更可靠的路径
|
||||
# 当前文件: core/managers/plugin_manager.py
|
||||
# 目标: plugins/
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# 回退两级到项目根目录 (core/managers -> core -> root)
|
||||
root_dir = os.path.dirname(os.path.dirname(current_dir))
|
||||
plugin_dir = os.path.join(root_dir, "plugins")
|
||||
|
||||
package_name = "plugins"
|
||||
|
||||
if not os.path.exists(plugin_dir):
|
||||
logger.error(f"插件目录不存在: {plugin_dir}")
|
||||
return
|
||||
|
||||
logger.info(f"正在从 {package_name} 加载插件 (路径: {plugin_dir})...")
|
||||
|
||||
for _, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
|
||||
full_module_name = f"{package_name}.{module_name}"
|
||||
|
||||
try:
|
||||
if full_module_name in self.loaded_plugins:
|
||||
self.command_manager.unload_plugin(full_module_name)
|
||||
module = importlib.reload(sys.modules[full_module_name])
|
||||
action = "重载"
|
||||
else:
|
||||
module = importlib.import_module(full_module_name)
|
||||
action = "加载"
|
||||
|
||||
if hasattr(module, "__plugin_meta__"):
|
||||
meta = getattr(module, "__plugin_meta__")
|
||||
self.command_manager.plugins[full_module_name] = meta
|
||||
|
||||
self.loaded_plugins.add(full_module_name)
|
||||
|
||||
type_str = "包" if is_pkg else "文件"
|
||||
logger.success(f" [{type_str}] 成功{action}: {module_name}")
|
||||
except SyncHandlerError as e:
|
||||
logger.error(f" 插件 {module_name} 加载失败: {e} (跳过此插件)")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f" {action if 'action' in locals() else '加载'}插件 {module_name} 失败: {e}"
|
||||
)
|
||||
|
||||
def reload_plugin(self, full_module_name: str):
|
||||
"""
|
||||
精确重载单个插件。
|
||||
"""
|
||||
if full_module_name not in self.loaded_plugins:
|
||||
logger.warning(f"尝试重载一个未被加载的插件: {full_module_name},将按首次加载处理。")
|
||||
|
||||
if full_module_name not in sys.modules:
|
||||
logger.error(f"重载失败: 模块 {full_module_name} 未在 sys.modules 中找到。")
|
||||
return
|
||||
|
||||
try:
|
||||
if full_module_name in sys.modules:
|
||||
module = importlib.reload(sys.modules[full_module_name])
|
||||
action = "重载"
|
||||
else:
|
||||
module = importlib.import_module(full_module_name)
|
||||
action = "加载"
|
||||
|
||||
# 提取插件元数据
|
||||
self.command_manager.unload_plugin(full_module_name)
|
||||
module = importlib.reload(sys.modules[full_module_name])
|
||||
|
||||
if hasattr(module, "__plugin_meta__"):
|
||||
meta = getattr(module, "__plugin_meta__")
|
||||
matcher.plugins[full_module_name] = meta
|
||||
|
||||
type_str = "包" if is_pkg else "文件"
|
||||
logger.success(f" [{type_str}] 成功{action}: {module_name}")
|
||||
except SyncHandlerError as e:
|
||||
logger.error(f" 插件 {module_name} 加载失败: {e} (跳过此插件)")
|
||||
self.command_manager.plugins[full_module_name] = meta
|
||||
|
||||
logger.success(f"插件 {full_module_name} 已成功重载。")
|
||||
except Exception as e:
|
||||
print(
|
||||
f" {action if 'action' in locals() else '加载'}插件 {module_name} 失败: {e}"
|
||||
)
|
||||
|
||||
|
||||
class PluginDataManager:
|
||||
"""
|
||||
用于管理插件产生的数据文件的类
|
||||
"""
|
||||
|
||||
def __init__(self, plugin_name: str):
|
||||
"""
|
||||
初始化插件数据管理器
|
||||
|
||||
:param plugin_name: 插件名称
|
||||
"""
|
||||
self.plugin_name = plugin_name
|
||||
self.data_file = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"..",
|
||||
"plugins",
|
||||
"data",
|
||||
self.plugin_name + ".json",
|
||||
)
|
||||
self.data = {}
|
||||
|
||||
async def load(self):
|
||||
"""读取配置文件"""
|
||||
if not os.path.exists(self.data_file):
|
||||
await self.set(self.plugin_name, [])
|
||||
try:
|
||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
||||
self.data = await run_in_thread_pool(json.load, f)
|
||||
except json.JSONDecodeError:
|
||||
self.data = {}
|
||||
|
||||
async def save(self):
|
||||
"""保存配置到文件"""
|
||||
with open(self.data_file, "w", encoding="utf-8") as f:
|
||||
await run_in_thread_pool(json.dump, self.data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""获取配置项"""
|
||||
return self.data.get(key, default)
|
||||
|
||||
async def set(self, key, value):
|
||||
"""设置配置项"""
|
||||
self.data[key] = value
|
||||
await self.save()
|
||||
|
||||
async def add(self, key, value):
|
||||
"""添加配置项"""
|
||||
if key not in self.data:
|
||||
self.data[key] = []
|
||||
self.data[key].append(value)
|
||||
await self.save()
|
||||
|
||||
async def remove(self, key):
|
||||
"""删除配置项"""
|
||||
if key in self.data:
|
||||
del self.data[key]
|
||||
await self.save()
|
||||
|
||||
async def clear(self):
|
||||
"""清空所有配置"""
|
||||
self.data.clear()
|
||||
await self.save()
|
||||
|
||||
def get_all(self):
|
||||
return self.data.copy()
|
||||
logger.exception(f"重载插件 {full_module_name} 时发生错误: {e}")
|
||||
|
||||
@@ -20,10 +20,11 @@ class RedisManager:
|
||||
"""
|
||||
if self._redis is None:
|
||||
try:
|
||||
host = config.redis['host']
|
||||
port = config.redis['port']
|
||||
db = config.redis['db']
|
||||
password = config.redis.get('password')
|
||||
redis_config = config.redis
|
||||
host = redis_config.host
|
||||
port = redis_config.port
|
||||
db = redis_config.db
|
||||
password = redis_config.password
|
||||
|
||||
logger.info(f"正在尝试连接 Redis: {host}:{port}, DB: {db}")
|
||||
|
||||
@@ -54,5 +55,17 @@ class RedisManager:
|
||||
raise ConnectionError("Redis 未初始化或连接失败,请先调用 initialize()")
|
||||
return self._redis
|
||||
|
||||
async def get(self, name):
|
||||
"""
|
||||
获取指定键的值
|
||||
"""
|
||||
return await self.redis.get(name)
|
||||
|
||||
async def set(self, name, value, ex=None):
|
||||
"""
|
||||
设置指定键的值
|
||||
"""
|
||||
return await self.redis.set(name, value, ex=ex)
|
||||
|
||||
# 全局 Redis 管理器实例
|
||||
redis_manager = RedisManager()
|
||||
|
||||
42
core/permission.py
Normal file
42
core/permission.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
|
||||
|
||||
@total_ordering
|
||||
class Permission(Enum):
|
||||
"""
|
||||
定义用户权限等级的枚举类。
|
||||
|
||||
使用 @total_ordering 装饰器,只需定义 __lt__ 和 __eq__,
|
||||
即可自动实现所有比较运算符。
|
||||
"""
|
||||
USER = "user"
|
||||
OP = "op"
|
||||
ADMIN = "admin"
|
||||
|
||||
@property
|
||||
def _level_map(self):
|
||||
"""
|
||||
内部属性,用于映射枚举成员到整数等级。
|
||||
"""
|
||||
return {
|
||||
Permission.USER: 1,
|
||||
Permission.OP: 2,
|
||||
Permission.ADMIN: 3
|
||||
}
|
||||
|
||||
def __lt__(self, other):
|
||||
"""
|
||||
比较当前权限是否小于另一个权限。
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self._level_map[self] < self._level_map[other]
|
||||
|
||||
def __ge__(self, other):
|
||||
"""
|
||||
比较当前权限是否大于等于另一个权限。
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self._level_map[self] >= self._level_map[other]
|
||||
@@ -2,7 +2,8 @@
|
||||
import asyncio
|
||||
import docker
|
||||
from docker.tls import TLSConfig
|
||||
from typing import Dict, Any, Callable
|
||||
from docker.types import LogConfig
|
||||
from typing import Any, Callable
|
||||
|
||||
from core.utils.logger import logger
|
||||
|
||||
@@ -10,21 +11,20 @@ class CodeExecutor:
|
||||
"""
|
||||
代码执行引擎,负责管理一个异步任务队列和并发的 Docker 容器执行。
|
||||
"""
|
||||
def __init__(self, bot_instance, config: Dict[str, Any]):
|
||||
def __init__(self, config: Any):
|
||||
"""
|
||||
初始化代码执行引擎。
|
||||
:param bot_instance: Bot 实例,用于后续的消息回复。
|
||||
:param config: 从 config.toml 加载的配置字典。
|
||||
:param config: 从 config_loader.py 加载的全局配置对象。
|
||||
"""
|
||||
self.bot = bot_instance
|
||||
self.task_queue = asyncio.Queue()
|
||||
self.bot: Any = None # Bot 实例将在 WS 连接成功后动态注入
|
||||
self.task_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
# 从传入的配置中读取 Docker 相关设置
|
||||
docker_config = config.docker
|
||||
self.docker_base_url = docker_config.get("base_url")
|
||||
self.sandbox_image = docker_config.get("sandbox_image", "python-sandbox:latest")
|
||||
self.timeout = docker_config.get("timeout", 10)
|
||||
concurrency = docker_config.get("concurrency_limit", 5)
|
||||
self.docker_base_url = docker_config.base_url
|
||||
self.sandbox_image = docker_config.sandbox_image
|
||||
self.timeout = docker_config.timeout
|
||||
concurrency = docker_config.concurrency_limit
|
||||
|
||||
self.concurrency_limit = asyncio.Semaphore(concurrency)
|
||||
self.docker_client = None
|
||||
@@ -34,10 +34,10 @@ class CodeExecutor:
|
||||
if self.docker_base_url:
|
||||
# 如果配置了远程 Docker 地址,则使用 TLS 选项进行连接
|
||||
tls_config = None
|
||||
if docker_config.get("tls_verify", False):
|
||||
if docker_config.tls_verify:
|
||||
tls_config = TLSConfig(
|
||||
ca_cert=docker_config.get("ca_cert_path"),
|
||||
client_cert=(docker_config.get("client_cert_path"), docker_config.get("client_key_path")),
|
||||
ca_cert=docker_config.ca_cert_path,
|
||||
client_cert=(docker_config.client_cert_path, docker_config.client_key_path),
|
||||
verify=True
|
||||
)
|
||||
self.docker_client = docker.DockerClient(base_url=self.docker_base_url, tls=tls_config)
|
||||
@@ -60,7 +60,15 @@ class CodeExecutor:
|
||||
将代码执行任务添加到队列中。
|
||||
:param code: 待执行的 Python 代码字符串。
|
||||
:param callback: 执行完毕后用于回复结果的回调函数。
|
||||
:raises RuntimeError: 如果 Docker 客户端未初始化。
|
||||
"""
|
||||
if not self.docker_client:
|
||||
logger.warning("[CodeExecutor] 尝试添加任务,但 Docker 客户端未初始化。任务被拒绝。")
|
||||
# 这里可以选择抛出异常,或者直接调用回调返回错误信息
|
||||
# 为了用户体验,我们构造一个错误结果并直接调用回调(如果可能)
|
||||
# 但由于 callback 返回 Future,这里简单起见,我们记录日志并抛出异常
|
||||
raise RuntimeError("Docker环境未就绪,无法执行代码。")
|
||||
|
||||
task = {"code": code, "callback": callback}
|
||||
await self.task_queue.put(task)
|
||||
logger.info(f"[CodeExecutor] 新的代码执行任务已入队 (队列当前长度: {self.task_queue.qsize()})。")
|
||||
@@ -125,6 +133,9 @@ class CodeExecutor:
|
||||
同步函数:在 Docker 容器中运行代码。
|
||||
此函数通过手动管理容器生命周期来提高稳定性。
|
||||
"""
|
||||
if self.docker_client is None:
|
||||
raise docker.errors.DockerException("Docker client is not initialized.")
|
||||
|
||||
container = None
|
||||
try:
|
||||
# 1. 创建容器
|
||||
@@ -134,7 +145,7 @@ class CodeExecutor:
|
||||
mem_limit='128m',
|
||||
cpu_shares=512,
|
||||
network_disabled=True,
|
||||
log_config={'type': 'json-file', 'config': {'max-size': '1m'}},
|
||||
log_config=LogConfig(type='json-file', config={'max-size': '1m'}),
|
||||
)
|
||||
# 2. 启动容器
|
||||
container.start()
|
||||
@@ -150,7 +161,7 @@ class CodeExecutor:
|
||||
# 5. 检查退出码,如果不为 0,则手动抛出 ContainerError
|
||||
if result.get('StatusCode', 0) != 0:
|
||||
raise docker.errors.ContainerError(
|
||||
container, result['StatusCode'], f"python -c '{code}'", self.sandbox_image, stderr
|
||||
container, result['StatusCode'], f"python -c '{code}'", self.sandbox_image, stderr.decode('utf-8')
|
||||
)
|
||||
|
||||
return stdout
|
||||
@@ -166,11 +177,11 @@ class CodeExecutor:
|
||||
except Exception as e:
|
||||
logger.error(f"[CodeExecutor] 强制移除容器 {container.id} 时失败: {e}")
|
||||
|
||||
def initialize_executor(bot_instance, config: Dict[str, Any]):
|
||||
def initialize_executor(config: Any):
|
||||
"""
|
||||
初始化并返回一个 CodeExecutor 实例。
|
||||
"""
|
||||
return CodeExecutor(bot_instance, config)
|
||||
return CodeExecutor(config)
|
||||
|
||||
async def run_in_thread_pool(sync_func, *args, **kwargs):
|
||||
"""
|
||||
|
||||
52
core/ws.py
52
core/ws.py
@@ -13,11 +13,12 @@ WebSocket 连接。它是整个机器人框架的底层通信基础。
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
import uuid
|
||||
|
||||
import websockets
|
||||
|
||||
from models import EventFactory
|
||||
from models.events.factory import EventFactory
|
||||
|
||||
from .bot import Bot
|
||||
from .config_loader import global_config
|
||||
@@ -30,7 +31,7 @@ class WS:
|
||||
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, code_executor=None):
|
||||
"""
|
||||
初始化 WebSocket 客户端。
|
||||
|
||||
@@ -38,13 +39,15 @@ class WS:
|
||||
"""
|
||||
# 读取参数
|
||||
cfg = global_config.napcat_ws
|
||||
self.url = cfg.get("uri")
|
||||
self.token = cfg.get("token")
|
||||
self.reconnect_interval = cfg.get("reconnect_interval", 5)
|
||||
self.url = cfg.uri
|
||||
self.token = cfg.token
|
||||
self.reconnect_interval = cfg.reconnect_interval
|
||||
|
||||
self.ws = None
|
||||
self._pending_requests = {}
|
||||
self.bot = Bot(self)
|
||||
self.bot: Bot | None = None
|
||||
self.self_id: int | None = None
|
||||
self.code_executor = code_executor
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
@@ -124,18 +127,43 @@ class WS:
|
||||
try:
|
||||
# 使用工厂创建事件对象
|
||||
event = EventFactory.create_event(event_data)
|
||||
|
||||
# 尝试初始化 Bot 实例 (如果尚未初始化且事件包含 self_id)
|
||||
# 只要事件中包含 self_id,我们就可以初始化 Bot,不必非要等待 meta_event
|
||||
if self.bot is None and hasattr(event, 'self_id'):
|
||||
self.self_id = event.self_id
|
||||
self.bot = Bot(self)
|
||||
logger.success(f"Bot 实例初始化完成: self_id={self.self_id}")
|
||||
|
||||
# 将代码执行器注入到 Bot 和执行器自身
|
||||
if self.code_executor:
|
||||
self.bot.code_executor = self.code_executor
|
||||
self.code_executor.bot = self.bot
|
||||
logger.info("代码执行器已成功注入 Bot 实例。")
|
||||
|
||||
# 如果 bot 尚未初始化,则不处理后续事件
|
||||
if self.bot is None:
|
||||
logger.warning("Bot 尚未初始化,跳过事件处理。")
|
||||
return
|
||||
|
||||
event.bot = self.bot # 注入 Bot 实例
|
||||
|
||||
# 打印日志
|
||||
if event.post_type == "message":
|
||||
sender_name = event.sender.nickname if event.sender else "Unknown"
|
||||
logger.info(f"[消息] {event.message_type} | {event.user_id}({sender_name}): {event.raw_message}")
|
||||
sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown"
|
||||
message_type = getattr(event, "message_type", "Unknown")
|
||||
user_id = getattr(event, "user_id", "Unknown")
|
||||
raw_message = getattr(event, "raw_message", "")
|
||||
logger.info(f"[消息] {message_type} | {user_id}({sender_name}): {raw_message}")
|
||||
elif event.post_type == "notice":
|
||||
logger.info(f"[通知] {event.notice_type}")
|
||||
notice_type = getattr(event, "notice_type", "Unknown")
|
||||
logger.info(f"[通知] {notice_type}")
|
||||
elif event.post_type == "request":
|
||||
logger.info(f"[请求] {event.request_type}")
|
||||
request_type = getattr(event, "request_type", "Unknown")
|
||||
logger.info(f"[请求] {request_type}")
|
||||
elif event.post_type == "meta_event":
|
||||
logger.debug(f"[元事件] {event.meta_event_type}")
|
||||
meta_event_type = getattr(event, "meta_event_type", "Unknown")
|
||||
logger.debug(f"[元事件] {meta_event_type}")
|
||||
|
||||
|
||||
# 分发事件
|
||||
@@ -144,7 +172,7 @@ class WS:
|
||||
except Exception as e:
|
||||
logger.exception(f"事件处理异常: {e}")
|
||||
|
||||
async def call_api(self, action: str, params: dict = None):
|
||||
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
|
||||
"""
|
||||
向 OneBot v11 实现端发送一个 API 请求。
|
||||
|
||||
|
||||
Reference in New Issue
Block a user