Dev (#28)
* 滚木 * feat: 重构核心架构,增强类型安全与插件管理 本次提交对核心模块进行了深度重构,引入 Pydantic 增强配置管理的类型安全性,并全面优化了插件管理系统。 主要变更详情: 1. 核心架构与配置 - 重构配置加载模块:引入 Pydantic 模型 (`core/config_models.py`),提供严格的配置项类型检查、验证及默认值管理。 - 统一模块结构:规范化模块导入路径,移除冗余的 `__init__.py` 文件,提升项目结构的清晰度。 - 性能优化:集成 Redis 缓存支持 (`RedisManager`),有效降低高频 API 调用开销,提升响应速度。 2. 插件系统升级 - 实现热重载机制:新增插件文件变更监听功能,支持开发过程中自动重载插件,提升开发效率。 - 优化生命周期管理:改进插件加载与卸载逻辑,支持精确卸载指定插件及其关联的命令、事件处理器和定时任务。 3. 功能特性增强 - 新增媒体 API:引入 `MediaAPI` 模块,封装图片、语音等富媒体资源的获取与处理接口。 - 完善权限体系:重构权限管理系统,实现管理员与操作员的分级控制,支持更细粒度的命令权限校验。 4. 代码质量与稳定性 - 全面类型修复:解决 `mypy` 静态类型检查发现的大量类型错误(包括 `CommandManager`、`EventFactory` 及 `Bot` API 签名不匹配问题)。 - 增强错误处理:优化消息处理管道的异常捕获机制,完善关键路径的日志记录,提升系统运行稳定性。 * feat: 添加测试用例并优化代码结构 refactor(permission_manager): 调整初始化顺序和逻辑 fix(admin_manager): 修复初始化逻辑和目录创建问题 feat(ws): 优化Bot实例初始化条件 feat(message): 增强MessageSegment功能并添加测试 feat(events): 支持字符串格式的消息解析 test: 添加核心功能测试用例 refactor(plugin_manager): 改进插件路径处理 style: 清理无用导入和代码 chore: 更新依赖项
This commit is contained in:
@@ -1,6 +0,0 @@
|
||||
from .managers.command_manager import matcher
|
||||
from .config_loader import global_config
|
||||
from .managers.plugin_manager import PluginDataManager
|
||||
from .ws import WS
|
||||
|
||||
__all__ = ["WS", "matcher", "global_config", "PluginDataManager"]
|
||||
|
||||
@@ -3,6 +3,7 @@ from .message import MessageAPI
|
||||
from .group import GroupAPI
|
||||
from .friend import FriendAPI
|
||||
from .account import AccountAPI
|
||||
from .media import MediaAPI
|
||||
|
||||
__all__ = [
|
||||
"BaseAPI",
|
||||
@@ -10,4 +11,5 @@ __all__ = [
|
||||
"GroupAPI",
|
||||
"FriendAPI",
|
||||
"AccountAPI",
|
||||
"MediaAPI",
|
||||
]
|
||||
|
||||
@@ -162,3 +162,56 @@ class AccountAPI(BaseAPI):
|
||||
"""
|
||||
return await self.call_api("clean_cache")
|
||||
|
||||
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> Any:
|
||||
"""
|
||||
获取陌生人信息。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Any: 包含陌生人信息的字典或对象。
|
||||
"""
|
||||
return await self.call_api("get_stranger_info", {"user_id": user_id, "no_cache": no_cache})
|
||||
|
||||
async def get_friend_list(self, no_cache: bool = False) -> list:
|
||||
"""
|
||||
获取好友列表。
|
||||
|
||||
Args:
|
||||
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||
|
||||
Returns:
|
||||
list: 好友列表。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_friend_list:{self.self_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.get(cache_key)
|
||||
if cached_data:
|
||||
return json.loads(cached_data)
|
||||
|
||||
res = await self.call_api("get_friend_list")
|
||||
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return res
|
||||
|
||||
async def get_group_list(self, no_cache: bool = False) -> list:
|
||||
"""
|
||||
获取群列表。
|
||||
|
||||
Args:
|
||||
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||
|
||||
Returns:
|
||||
list: 群列表。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_group_list:{self.self_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.get(cache_key)
|
||||
if cached_data:
|
||||
return json.loads(cached_data)
|
||||
|
||||
res = await self.call_api("get_group_list")
|
||||
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return res
|
||||
|
||||
|
||||
@@ -1,24 +1,50 @@
|
||||
"""
|
||||
API 基础模块
|
||||
|
||||
定义了 API 调用的基础接口。
|
||||
定义了 API 调用的基础接口和统一处理逻辑。
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from ..utils.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..ws import WS
|
||||
|
||||
|
||||
class BaseAPI(ABC):
|
||||
class BaseAPI:
|
||||
"""
|
||||
API 基础抽象类
|
||||
API 基础类,提供了统一的 `call_api` 方法,包含日志记录和异常处理。
|
||||
"""
|
||||
_ws: "WS"
|
||||
self_id: int
|
||||
|
||||
def __init__(self, ws_client: "WS", self_id: int):
|
||||
self._ws = ws_client
|
||||
self.self_id = self_id
|
||||
|
||||
@abstractmethod
|
||||
async def call_api(self, action: str, params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
调用 API
|
||||
调用 OneBot v11 API,并提供统一的日志和异常处理。
|
||||
|
||||
:param action: API 动作名称
|
||||
:param params: API 参数
|
||||
:return: API 响应结果
|
||||
:return: API 响应结果的数据部分
|
||||
:raises Exception: 当 API 调用失败或发生网络错误时
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
try:
|
||||
logger.debug(f"调用API -> action: {action}, params: {params}")
|
||||
response = await self._ws.call_api(action, params)
|
||||
logger.debug(f"API响应 <- {response}")
|
||||
|
||||
if response.get("status") == "failed":
|
||||
logger.warning(f"API调用失败: {response}")
|
||||
|
||||
return response.get("data")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API调用异常: action={action}, params={params}, error={e}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
该模块定义了 `GroupAPI` Mixin 类,提供了所有与群组管理、成员操作
|
||||
等相关的 OneBot v11 API 封装。
|
||||
"""
|
||||
from typing import List, Dict, Any
|
||||
from typing import List, Dict, Any, Optional
|
||||
import json
|
||||
from ..managers.redis_manager import redis_manager
|
||||
from .base import BaseAPI
|
||||
@@ -46,7 +46,7 @@ class GroupAPI(BaseAPI):
|
||||
"""
|
||||
return await self.call_api("set_group_ban", {"group_id": group_id, "user_id": user_id, "duration": duration})
|
||||
|
||||
async def set_group_anonymous_ban(self, group_id: int, anonymous: Dict[str, Any] = None, duration: int = 1800, flag: str = None) -> Dict[str, Any]:
|
||||
async def set_group_anonymous_ban(self, group_id: int, anonymous: Optional[Dict[str, Any]] = None, duration: int = 1800, flag: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
禁言群组中的匿名用户。
|
||||
|
||||
@@ -61,7 +61,7 @@ class GroupAPI(BaseAPI):
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
params = {"group_id": group_id, "duration": duration}
|
||||
params: Dict[str, Any] = {"group_id": group_id, "duration": duration}
|
||||
if anonymous:
|
||||
params["anonymous"] = anonymous
|
||||
if flag:
|
||||
@@ -187,17 +187,18 @@ class GroupAPI(BaseAPI):
|
||||
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return GroupInfo(**res)
|
||||
|
||||
async def get_group_list(self) -> List[GroupInfo]:
|
||||
async def get_group_list(self) -> Any:
|
||||
"""
|
||||
获取机器人加入的所有群组的列表。
|
||||
|
||||
Returns:
|
||||
List[GroupInfo]: 包含所有群组信息的 `GroupInfo` 对象列表。
|
||||
Any: 包含所有群组信息的列表(可能是字典列表或对象列表)。
|
||||
"""
|
||||
res = await self.call_api("get_group_list")
|
||||
|
||||
# 增加日志记录 API 原始返回
|
||||
logger.debug(f"OneBot API 'get_group_list' raw response: {res}")
|
||||
return res
|
||||
|
||||
# 健壮性处理:处理标准的 OneBot v11 响应格式
|
||||
if isinstance(res, dict) and res.get("status") == "ok":
|
||||
|
||||
39
core/api/media.py
Normal file
39
core/api/media.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
媒体API模块
|
||||
|
||||
封装了与图片、语音等媒体文件相关的API。
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
|
||||
from .base import BaseAPI
|
||||
|
||||
|
||||
class MediaAPI(BaseAPI):
|
||||
"""
|
||||
媒体相关API
|
||||
"""
|
||||
|
||||
async def can_send_image(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查是否可以发送图片
|
||||
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="can_send_image")
|
||||
|
||||
async def can_send_record(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查是否可以发送语音
|
||||
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="can_send_record")
|
||||
|
||||
async def get_image(self, file: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取图片信息
|
||||
|
||||
:param file: 图片文件名或路径
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="get_image", params={"file": file})
|
||||
@@ -8,7 +8,8 @@ from typing import Union, List, Dict, Any, TYPE_CHECKING
|
||||
from .base import BaseAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import MessageSegment, OneBotEvent
|
||||
from models.message import MessageSegment
|
||||
from models.events.base import OneBotEvent
|
||||
|
||||
|
||||
class MessageAPI(BaseAPI):
|
||||
@@ -156,24 +157,6 @@ class MessageAPI(BaseAPI):
|
||||
"""
|
||||
return await self.call_api("send_private_forward_msg", {"user_id": user_id, "messages": messages})
|
||||
|
||||
async def can_send_image(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查当前机器人账号是否可以发送图片。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("can_send_image")
|
||||
|
||||
async def can_send_record(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查当前机器人账号是否可以发送语音。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("can_send_record")
|
||||
|
||||
def _process_message(self, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Union[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
内部方法:将消息内容处理成 OneBot API 可接受的格式。
|
||||
@@ -192,7 +175,7 @@ class MessageAPI(BaseAPI):
|
||||
return message
|
||||
|
||||
# 避免循环导入,在运行时导入
|
||||
from models import MessageSegment
|
||||
from models.message import MessageSegment
|
||||
|
||||
if isinstance(message, MessageSegment):
|
||||
return [self._segment_to_dict(message)]
|
||||
|
||||
33
core/bot.py
33
core/bot.py
@@ -13,14 +13,15 @@ Bot 核心抽象模块
|
||||
from typing import TYPE_CHECKING, Dict, Any, List, Union
|
||||
from models.events.base import OneBotEvent
|
||||
from models.message import MessageSegment
|
||||
from models.objects import GroupInfo, StrangerInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .WS import WS
|
||||
from .ws import WS
|
||||
|
||||
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI
|
||||
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI
|
||||
|
||||
|
||||
class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI):
|
||||
class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI):
|
||||
"""
|
||||
机器人核心类,封装了所有与 OneBot API 的交互。
|
||||
|
||||
@@ -35,22 +36,22 @@ class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI):
|
||||
Args:
|
||||
ws_client (WS): WebSocket 客户端实例,负责底层的 API 请求和响应处理。
|
||||
"""
|
||||
self.ws = ws_client
|
||||
super().__init__(ws_client, ws_client.self_id or 0)
|
||||
self.code_executor = None
|
||||
|
||||
async def call_api(self, action: str, params: Dict[str, Any] = None) -> Any:
|
||||
"""
|
||||
底层 API 调用方法。
|
||||
async def get_group_list(self, no_cache: bool = False) -> List[GroupInfo]:
|
||||
# GroupAPI.get_group_list 不支持 no_cache 参数,这里忽略它
|
||||
result = await super().get_group_list()
|
||||
# 确保结果是 GroupInfo 对象列表
|
||||
return [GroupInfo(**group) if isinstance(group, dict) else group for group in result]
|
||||
|
||||
所有具体的 API 实现最终都会调用此方法,通过 WebSocket 发送请求。
|
||||
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> StrangerInfo:
|
||||
result = await super().get_stranger_info(user_id=user_id, no_cache=no_cache)
|
||||
# 确保结果是 StrangerInfo 对象
|
||||
if isinstance(result, dict):
|
||||
return StrangerInfo(**result)
|
||||
return result
|
||||
|
||||
Args:
|
||||
action (str): API 的动作名称,例如 "send_group_msg"。
|
||||
params (Dict[str, Any], optional): API 请求的参数字典。Defaults to None.
|
||||
|
||||
Returns:
|
||||
Any: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.ws.call_api(action, params)
|
||||
|
||||
def build_forward_node(self, user_id: int, nickname: str, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -4,9 +4,11 @@
|
||||
负责读取和解析 config.toml 配置文件,提供全局配置对象。
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import tomllib
|
||||
from pydantic import ValidationError
|
||||
from .config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel
|
||||
from .utils.logger import logger
|
||||
|
||||
|
||||
class Config:
|
||||
@@ -21,73 +23,67 @@ class Config:
|
||||
:param file_path: 配置文件路径,默认为 "config.toml"
|
||||
"""
|
||||
self.path = Path(file_path)
|
||||
self._data: Dict[str, Any] = {}
|
||||
self._model: ConfigModel
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
"""
|
||||
加载配置文件
|
||||
加载并验证配置文件
|
||||
|
||||
:raises FileNotFoundError: 如果配置文件不存在
|
||||
:raises ValidationError: 如果配置格式不正确
|
||||
"""
|
||||
if not self.path.exists():
|
||||
logger.error(f"配置文件 {self.path} 未找到!")
|
||||
raise FileNotFoundError(f"配置文件 {self.path} 未找到!")
|
||||
|
||||
with open(self.path, "rb") as f:
|
||||
self._data = tomllib.load(f)
|
||||
try:
|
||||
logger.info(f"正在从 {self.path} 加载配置...")
|
||||
with open(self.path, "rb") as f:
|
||||
raw_config = tomllib.load(f)
|
||||
|
||||
self._model = ConfigModel(**raw_config)
|
||||
logger.success("配置加载并验证成功!")
|
||||
|
||||
except ValidationError as e:
|
||||
logger.error("配置验证失败,请检查 `config.toml` 文件中的以下错误:")
|
||||
for error in e.errors():
|
||||
field = " -> ".join(map(str, error["loc"]))
|
||||
logger.error(f" - 字段 '{field}': {error['msg']}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"加载配置文件时发生未知错误: {e}")
|
||||
raise
|
||||
|
||||
# 通过属性访问配置
|
||||
@property
|
||||
def napcat_ws(self) -> dict:
|
||||
def napcat_ws(self) -> NapCatWSModel:
|
||||
"""
|
||||
获取 NapCat WebSocket 配置
|
||||
|
||||
:return: 配置字典
|
||||
"""
|
||||
return self._data.get("napcat_ws", {})
|
||||
return self._model.napcat_ws
|
||||
|
||||
@property
|
||||
def bot(self) -> dict:
|
||||
def bot(self) -> BotModel:
|
||||
"""
|
||||
获取 Bot 基础配置
|
||||
|
||||
:return: 配置字典
|
||||
"""
|
||||
return self._data.get("bot", {})
|
||||
return self._model.bot
|
||||
|
||||
@property
|
||||
def features(self) -> dict:
|
||||
"""
|
||||
获取功能特性配置
|
||||
|
||||
:return: 配置字典
|
||||
"""
|
||||
return self._data.get("features", {})
|
||||
|
||||
@property
|
||||
def redis(self) -> dict:
|
||||
def redis(self) -> RedisModel:
|
||||
"""
|
||||
获取 Redis 配置
|
||||
|
||||
:return: 配置字典
|
||||
"""
|
||||
return self._data.get("redis", {})
|
||||
return self._model.redis
|
||||
|
||||
@property
|
||||
def docker(self) -> dict:
|
||||
def docker(self) -> DockerModel:
|
||||
"""
|
||||
获取 Docker 配置
|
||||
|
||||
:return: 配置字典
|
||||
"""
|
||||
return self._data.get("docker", {})
|
||||
return self._model.docker
|
||||
|
||||
|
||||
# 实例化全局配置对象
|
||||
global_config = Config()
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(global_config.napcat_ws)
|
||||
print(global_config.bot.get("command"))
|
||||
print(type(global_config.bot.get("command")) is list)
|
||||
print(global_config.features)
|
||||
|
||||
60
core/config_models.py
Normal file
60
core/config_models.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Pydantic 配置模型模块
|
||||
|
||||
该模块使用 Pydantic 定义了与 `config.toml` 文件结构完全对应的配置模型。
|
||||
这使得配置的加载、校验和访问都变得类型安全和健壮。
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NapCatWSModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[napcat_ws]` 配置块。
|
||||
"""
|
||||
uri: str
|
||||
token: str
|
||||
reconnect_interval: int = 5
|
||||
|
||||
|
||||
class BotModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[bot]` 配置块。
|
||||
"""
|
||||
command: List[str] = Field(default_factory=lambda: ["/"])
|
||||
ignore_self_message: bool = True
|
||||
permission_denied_message: str = "权限不足,需要 {permission_name} 权限"
|
||||
|
||||
|
||||
class RedisModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[redis]` 配置块。
|
||||
"""
|
||||
host: str
|
||||
port: int
|
||||
db: int
|
||||
password: str
|
||||
|
||||
|
||||
class DockerModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[docker]` 配置块。
|
||||
"""
|
||||
base_url: Optional[str] = None
|
||||
sandbox_image: str = "python-sandbox:latest"
|
||||
timeout: int = 10
|
||||
concurrency_limit: int = 5
|
||||
tls_verify: bool = False
|
||||
ca_cert_path: Optional[str] = None
|
||||
client_cert_path: Optional[str] = None
|
||||
client_key_path: Optional[str] = None
|
||||
|
||||
|
||||
class ConfigModel(BaseModel):
|
||||
"""
|
||||
顶层配置模型,整合了所有子配置块。
|
||||
"""
|
||||
napcat_ws: NapCatWSModel
|
||||
bot: BotModel
|
||||
redis: RedisModel
|
||||
docker: DockerModel
|
||||
@@ -1,3 +1,3 @@
|
||||
{
|
||||
"admins": []
|
||||
"admins": [2221577113]
|
||||
}
|
||||
@@ -6,11 +6,12 @@
|
||||
"""
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from ..bot import Bot
|
||||
if TYPE_CHECKING:
|
||||
from ..bot import Bot
|
||||
from ..config_loader import global_config
|
||||
from ..managers.permission_manager import Permission, permission_manager
|
||||
from ..permission import Permission
|
||||
from ..utils.executor import run_in_thread_pool
|
||||
|
||||
|
||||
@@ -22,7 +23,7 @@ class BaseHandler(ABC):
|
||||
self.handlers: List[Dict[str, Any]] = []
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, bot: Bot, event: Any):
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理事件
|
||||
"""
|
||||
@@ -31,7 +32,7 @@ class BaseHandler(ABC):
|
||||
async def _run_handler(
|
||||
self,
|
||||
func: Callable,
|
||||
bot: Bot,
|
||||
bot: "Bot",
|
||||
event: Any,
|
||||
args: Optional[List[str]] = None,
|
||||
permission_granted: Optional[bool] = None
|
||||
@@ -41,7 +42,7 @@ class BaseHandler(ABC):
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
params = sig.parameters
|
||||
kwargs = {}
|
||||
kwargs: Dict[str, Any] = {}
|
||||
|
||||
if "bot" in params:
|
||||
kwargs["bot"] = bot
|
||||
@@ -68,21 +69,41 @@ class MessageHandler(BaseHandler):
|
||||
super().__init__()
|
||||
self.prefixes = prefixes
|
||||
self.commands: Dict[str, Dict] = {}
|
||||
self.message_handlers: List[Callable] = []
|
||||
self.message_handlers: List[Dict[str, Any]] = []
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
清空所有已注册的消息和命令处理器
|
||||
"""
|
||||
self.commands.clear()
|
||||
self.message_handlers.clear()
|
||||
|
||||
def unregister_by_plugin_name(self, plugin_name: str):
|
||||
"""
|
||||
根据插件名卸载相关的消息和命令处理器
|
||||
"""
|
||||
# 卸载命令
|
||||
commands_to_remove = [name for name, info in self.commands.items() if info["plugin_name"] == plugin_name]
|
||||
for name in commands_to_remove:
|
||||
del self.commands[name]
|
||||
|
||||
# 卸载通用消息处理器
|
||||
self.message_handlers = [h for h in self.message_handlers if h["plugin_name"] != plugin_name]
|
||||
|
||||
def on_message(self) -> Callable:
|
||||
"""
|
||||
注册通用消息处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self.message_handlers.append(func)
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
self.message_handlers.append({"func": func, "plugin_name": plugin_name})
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def command(
|
||||
self,
|
||||
*names: str,
|
||||
*names: str,
|
||||
permission: Optional[Permission] = None,
|
||||
override_permission_check: bool = False
|
||||
) -> Callable:
|
||||
@@ -90,21 +111,25 @@ class MessageHandler(BaseHandler):
|
||||
注册命令处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
for name in names:
|
||||
self.commands[name] = {
|
||||
"func": func,
|
||||
"permission": permission,
|
||||
"override_permission_check": override_permission_check,
|
||||
"plugin_name": plugin_name,
|
||||
}
|
||||
return func
|
||||
return decorator
|
||||
|
||||
async def handle(self, bot: Bot, event: Any):
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理消息事件,包括通用消息和命令
|
||||
处理消息事件,分发给命令处理器或通用消息处理器
|
||||
"""
|
||||
for handler in self.message_handlers:
|
||||
consumed = await self._run_handler(handler, bot, event)
|
||||
from ..managers import permission_manager
|
||||
for handler_info in self.message_handlers:
|
||||
consumed = await self._run_handler(handler_info["func"], bot, event)
|
||||
if consumed:
|
||||
return
|
||||
|
||||
@@ -136,7 +161,7 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
if not permission_granted and not override_check:
|
||||
permission_name = permission.name if isinstance(permission, Permission) else permission
|
||||
message_template = global_config.bot.get("permission_denied_message", "权限不足,需要 {permission_name} 权限")
|
||||
message_template = global_config.bot.permission_denied_message
|
||||
await bot.send(event, message_template.format(permission_name=permission_name))
|
||||
return
|
||||
|
||||
@@ -153,12 +178,23 @@ class NoticeHandler(BaseHandler):
|
||||
"""
|
||||
通知事件处理器
|
||||
"""
|
||||
def clear(self):
|
||||
self.handlers.clear()
|
||||
|
||||
def unregister_by_plugin_name(self, plugin_name: str):
|
||||
"""
|
||||
根据插件名卸载相关的通知处理器
|
||||
"""
|
||||
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
|
||||
|
||||
def register(self, notice_type: Optional[str] = None) -> Callable:
|
||||
"""
|
||||
注册通知处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self.handlers.append({"type": notice_type, "func": func})
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
self.handlers.append({"type": notice_type, "func": func, "plugin_name": plugin_name})
|
||||
return func
|
||||
return decorator
|
||||
|
||||
@@ -175,12 +211,23 @@ class RequestHandler(BaseHandler):
|
||||
"""
|
||||
请求事件处理器
|
||||
"""
|
||||
def clear(self):
|
||||
self.handlers.clear()
|
||||
|
||||
def unregister_by_plugin_name(self, plugin_name: str):
|
||||
"""
|
||||
根据插件名卸载相关的请求处理器
|
||||
"""
|
||||
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
|
||||
|
||||
def register(self, request_type: Optional[str] = None) -> Callable:
|
||||
"""
|
||||
注册请求处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self.handlers.append({"type": request_type, "func": func})
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
self.handlers.append({"type": request_type, "func": func, "plugin_name": plugin_name})
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
管理器包
|
||||
|
||||
这个包集中了机器人核心的单例管理器。
|
||||
通过从这里导入,可以确保在整个应用中访问到的都是同一个实例。
|
||||
"""
|
||||
from ..config_loader import global_config
|
||||
from .admin_manager import AdminManager
|
||||
from .command_manager import CommandManager
|
||||
from .permission_manager import PermissionManager
|
||||
from .plugin_manager import PluginManager
|
||||
from .redis_manager import RedisManager
|
||||
|
||||
# --- 实例化所有单例管理器 ---
|
||||
|
||||
# 管理员管理器
|
||||
admin_manager = AdminManager()
|
||||
|
||||
# 权限管理器
|
||||
permission_manager = PermissionManager()
|
||||
|
||||
# 命令与事件管理器 (别名 matcher)
|
||||
command_manager = CommandManager(prefixes=tuple(global_config.bot.command))
|
||||
matcher = command_manager
|
||||
|
||||
# 插件管理器
|
||||
plugin_manager = PluginManager(command_manager)
|
||||
plugin_manager.load_all_plugins()
|
||||
|
||||
# Redis 管理器
|
||||
redis_manager = RedisManager()
|
||||
|
||||
__all__ = [
|
||||
"admin_manager",
|
||||
"permission_manager",
|
||||
"command_manager",
|
||||
"matcher",
|
||||
"plugin_manager",
|
||||
"redis_manager",
|
||||
]
|
||||
|
||||
@@ -26,8 +26,7 @@ class AdminManager(Singleton):
|
||||
"""
|
||||
初始化 AdminManager
|
||||
"""
|
||||
super().__init__()
|
||||
if not self._initialized:
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
|
||||
# 管理员数据文件路径
|
||||
@@ -39,7 +38,12 @@ class AdminManager(Singleton):
|
||||
)
|
||||
|
||||
self._admins: Set[int] = set()
|
||||
|
||||
# 确保数据目录存在
|
||||
os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
|
||||
|
||||
logger.info("管理员管理器初始化完成")
|
||||
super().__init__()
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
|
||||
@@ -12,7 +12,16 @@ from ..handlers.event_handler import MessageHandler, NoticeHandler, RequestHandl
|
||||
|
||||
|
||||
# 从配置中获取命令前缀
|
||||
command_prefixes = global_config.bot.get("command", ("/",))
|
||||
_config_prefixes = global_config.bot.command
|
||||
|
||||
# 确保前缀配置是元组格式
|
||||
_final_prefixes: Tuple[str, ...]
|
||||
if isinstance(_config_prefixes, list):
|
||||
_final_prefixes = tuple(_config_prefixes)
|
||||
elif isinstance(_config_prefixes, str):
|
||||
_final_prefixes = (_config_prefixes,)
|
||||
else:
|
||||
_final_prefixes = tuple(_config_prefixes)
|
||||
|
||||
|
||||
class CommandManager:
|
||||
@@ -59,6 +68,35 @@ class CommandManager:
|
||||
"usage": "/help",
|
||||
}
|
||||
|
||||
def clear_all_handlers(self):
|
||||
"""
|
||||
清空所有已注册的事件处理器。
|
||||
注意:这也会移除内置的 /help 命令,因此需要重新注册。
|
||||
"""
|
||||
self.message_handler.clear()
|
||||
self.notice_handler.clear()
|
||||
self.request_handler.clear()
|
||||
self.plugins.clear()
|
||||
|
||||
# 清空后,需要重新注册内置命令
|
||||
self._register_internal_commands()
|
||||
|
||||
def unload_plugin(self, plugin_name: str):
|
||||
"""
|
||||
卸载指定插件的所有处理器和命令。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件的模块名 (例如 'plugins.bili_parser')
|
||||
"""
|
||||
self.message_handler.unregister_by_plugin_name(plugin_name)
|
||||
self.notice_handler.unregister_by_plugin_name(plugin_name)
|
||||
self.request_handler.unregister_by_plugin_name(plugin_name)
|
||||
|
||||
# 移除插件元信息
|
||||
plugins_to_remove = [name for name in self.plugins if name.startswith(plugin_name)]
|
||||
for name in plugins_to_remove:
|
||||
del self.plugins[name]
|
||||
|
||||
# --- 装饰器代理 ---
|
||||
|
||||
def on_message(self) -> Callable:
|
||||
@@ -102,7 +140,7 @@ class CommandManager:
|
||||
|
||||
根据事件的 `post_type` 将其分发给对应的处理器。
|
||||
"""
|
||||
if event.post_type == 'message' and global_config.bot.get('ignore_self_message', False):
|
||||
if event.post_type == 'message' and global_config.bot.ignore_self_message:
|
||||
if hasattr(event, 'user_id') and hasattr(event, 'self_id') and event.user_id == event.self_id:
|
||||
return
|
||||
|
||||
@@ -130,14 +168,6 @@ class CommandManager:
|
||||
await bot.send(event, help_text.strip())
|
||||
|
||||
|
||||
# --- 全局单例 ---
|
||||
|
||||
# 确保前缀配置是元组格式
|
||||
if isinstance(command_prefixes, list):
|
||||
command_prefixes = tuple(command_prefixes)
|
||||
elif isinstance(command_prefixes, str):
|
||||
command_prefixes = (command_prefixes,)
|
||||
|
||||
# 实例化全局唯一的命令管理器
|
||||
matcher = CommandManager(prefixes=command_prefixes)
|
||||
matcher = CommandManager(prefixes=_final_prefixes)
|
||||
|
||||
|
||||
@@ -13,64 +13,17 @@
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from functools import total_ordering
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..utils.logger import logger
|
||||
from ..utils.singleton import Singleton
|
||||
from .admin_manager import admin_manager
|
||||
from ..permission import Permission
|
||||
|
||||
|
||||
@total_ordering
|
||||
class Permission:
|
||||
"""
|
||||
权限封装类
|
||||
|
||||
封装了权限的名称和等级,并提供了比较方法。
|
||||
使用 @total_ordering 装饰器可以自动生成所有的比较运算符。
|
||||
"""
|
||||
def __init__(self, name: str, level: int):
|
||||
"""
|
||||
初始化权限对象
|
||||
|
||||
Args:
|
||||
name (str): 权限名称 (e.g., "admin", "op")
|
||||
level (int): 权限等级,数字越大权限越高
|
||||
"""
|
||||
self.name = name
|
||||
self.level = level
|
||||
|
||||
def __eq__(self, other):
|
||||
"""
|
||||
判断权限是否相等
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self.level == other.level
|
||||
|
||||
def __lt__(self, other):
|
||||
"""
|
||||
判断权限是否小于另一个权限
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self.level < other.level
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
返回权限的字符串表示(即权限名称)
|
||||
"""
|
||||
return self.name
|
||||
|
||||
|
||||
# 定义全局权限常量
|
||||
ADMIN = Permission("admin", 3)
|
||||
OP = Permission("op", 2)
|
||||
USER = Permission("user", 1)
|
||||
|
||||
# 用于从字符串名称查找权限对象的字典
|
||||
_PERMISSIONS: Dict[str, Permission] = {
|
||||
p.name: p for p in [ADMIN, OP, USER]
|
||||
p.value: p for p in Permission
|
||||
}
|
||||
|
||||
|
||||
@@ -88,8 +41,7 @@ class PermissionManager(Singleton):
|
||||
|
||||
如果已经初始化过,则直接返回。
|
||||
"""
|
||||
super().__init__()
|
||||
if not self._initialized:
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
|
||||
# 权限数据文件路径
|
||||
@@ -111,6 +63,7 @@ class PermissionManager(Singleton):
|
||||
self.load()
|
||||
|
||||
logger.info("权限管理器初始化完成")
|
||||
super().__init__()
|
||||
|
||||
def load(self) -> None:
|
||||
"""
|
||||
@@ -164,12 +117,12 @@ class PermissionManager(Singleton):
|
||||
"""
|
||||
# 首先,通过 AdminManager 检查是否为管理员
|
||||
if await admin_manager.is_admin(user_id):
|
||||
return ADMIN
|
||||
return Permission.ADMIN
|
||||
|
||||
# 如果不是管理员,则从 permissions.json 中查找
|
||||
user_id_str = str(user_id)
|
||||
level_name = self._data["users"].get(user_id_str, USER.name)
|
||||
return _PERMISSIONS.get(level_name, USER)
|
||||
level_name = self._data["users"].get(user_id_str, Permission.USER.value)
|
||||
return _PERMISSIONS.get(level_name, Permission.USER)
|
||||
|
||||
def set_user_permission(self, user_id: int, permission: Permission) -> None:
|
||||
"""
|
||||
@@ -182,13 +135,13 @@ class PermissionManager(Singleton):
|
||||
Raises:
|
||||
ValueError: 如果权限对象无效
|
||||
"""
|
||||
if not isinstance(permission, Permission) or permission.name not in _PERMISSIONS:
|
||||
if not isinstance(permission, Permission):
|
||||
raise ValueError(f"无效的权限对象: {permission}")
|
||||
|
||||
user_id_str = str(user_id)
|
||||
self._data["users"][user_id_str] = permission.name
|
||||
self._data["users"][user_id_str] = permission.value
|
||||
self.save()
|
||||
logger.info(f"设置用户 {user_id} 的权限级别为 {permission.name}")
|
||||
logger.info(f"设置用户 {user_id} 的权限级别为 {permission.value}")
|
||||
|
||||
def remove_user(self, user_id: int) -> None:
|
||||
"""
|
||||
@@ -214,17 +167,17 @@ class PermissionManager(Singleton):
|
||||
Returns:
|
||||
bool: 如果用户权限 >= 所需权限,返回 True,否则返回 False
|
||||
"""
|
||||
# 如果传入的是字符串,先转换为 Permission 对象
|
||||
if isinstance(required_permission, str):
|
||||
required_permission = _PERMISSIONS.get(required_permission.lower())
|
||||
if not required_permission:
|
||||
# 如果是无效的权限字符串,默认拒绝
|
||||
logger.warning(f"检测到无效的权限检查字符串: {required_permission}")
|
||||
return False
|
||||
|
||||
user_permission = await self.get_user_permission(user_id)
|
||||
return user_permission >= required_permission
|
||||
|
||||
def get_all_user_permissions(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取所有已配置的用户权限
|
||||
|
||||
:return: 一个包含所有用户权限的字典
|
||||
"""
|
||||
return self._data["users"].copy()
|
||||
|
||||
def get_all_users(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取所有设置了权限的用户及其级别名称
|
||||
@@ -243,22 +196,22 @@ class PermissionManager(Singleton):
|
||||
logger.info("已清空所有权限设置")
|
||||
|
||||
|
||||
# 全局权限管理器实例
|
||||
permission_manager = PermissionManager()
|
||||
|
||||
def require_admin(func):
|
||||
"""
|
||||
一个装饰器,用于限制命令只能由管理员执行。
|
||||
"""
|
||||
from functools import wraps
|
||||
from models.events.message import MessageEvent
|
||||
from core.managers import permission_manager
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(event: MessageEvent, *args, **kwargs):
|
||||
user_id = event.user_id
|
||||
if await permission_manager.check_permission(user_id, ADMIN):
|
||||
if await permission_manager.check_permission(user_id, Permission.ADMIN):
|
||||
return await func(event, *args, **kwargs)
|
||||
else:
|
||||
await event.reply("抱歉,您没有权限执行此命令。")
|
||||
# 假设 event 对象有 reply 方法
|
||||
if hasattr(event, "reply"):
|
||||
await event.reply("抱歉,您没有权限执行此命令。")
|
||||
return None
|
||||
return wrapper
|
||||
|
||||
@@ -1,126 +1,97 @@
|
||||
"""
|
||||
插件管理器模块
|
||||
|
||||
负责扫描、加载和管理 `base_plugins` 目录下的所有插件。
|
||||
负责扫描、加载和管理 `plugins` 目录下的所有插件。
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
from typing import Set
|
||||
|
||||
from .command_manager import matcher
|
||||
from ..utils.exceptions import SyncHandlerError
|
||||
from ..utils.logger import logger
|
||||
from ..utils.executor import run_in_thread_pool
|
||||
|
||||
|
||||
def load_all_plugins():
|
||||
class PluginManager:
|
||||
"""
|
||||
扫描并加载 `plugins` 目录下的所有插件。
|
||||
|
||||
该函数会遍历 `plugins` 目录下的所有模块:
|
||||
1. 如果模块已加载,则执行 reload 操作(用于热重载)。
|
||||
2. 如果模块未加载,则执行 import 操作。
|
||||
|
||||
加载过程中会提取插件元数据 `__plugin_meta__` 并注册到 CommandManager。
|
||||
插件管理器类
|
||||
"""
|
||||
plugin_dir = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "..", "plugins"
|
||||
)
|
||||
package_name = "plugins"
|
||||
def __init__(self, command_manager):
|
||||
"""
|
||||
初始化插件管理器
|
||||
|
||||
logger.info(f"正在从 {package_name} 加载插件...")
|
||||
:param command_manager: CommandManager的实例
|
||||
"""
|
||||
self.command_manager = command_manager
|
||||
self.loaded_plugins: Set[str] = set()
|
||||
|
||||
for loader, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
|
||||
full_module_name = f"{package_name}.{module_name}"
|
||||
def load_all_plugins(self):
|
||||
"""
|
||||
扫描并加载 `plugins` 目录下的所有插件。
|
||||
"""
|
||||
# 使用 pathlib 获取更可靠的路径
|
||||
# 当前文件: core/managers/plugin_manager.py
|
||||
# 目标: plugins/
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# 回退两级到项目根目录 (core/managers -> core -> root)
|
||||
root_dir = os.path.dirname(os.path.dirname(current_dir))
|
||||
plugin_dir = os.path.join(root_dir, "plugins")
|
||||
|
||||
package_name = "plugins"
|
||||
|
||||
if not os.path.exists(plugin_dir):
|
||||
logger.error(f"插件目录不存在: {plugin_dir}")
|
||||
return
|
||||
|
||||
logger.info(f"正在从 {package_name} 加载插件 (路径: {plugin_dir})...")
|
||||
|
||||
for _, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
|
||||
full_module_name = f"{package_name}.{module_name}"
|
||||
|
||||
try:
|
||||
if full_module_name in self.loaded_plugins:
|
||||
self.command_manager.unload_plugin(full_module_name)
|
||||
module = importlib.reload(sys.modules[full_module_name])
|
||||
action = "重载"
|
||||
else:
|
||||
module = importlib.import_module(full_module_name)
|
||||
action = "加载"
|
||||
|
||||
if hasattr(module, "__plugin_meta__"):
|
||||
meta = getattr(module, "__plugin_meta__")
|
||||
self.command_manager.plugins[full_module_name] = meta
|
||||
|
||||
self.loaded_plugins.add(full_module_name)
|
||||
|
||||
type_str = "包" if is_pkg else "文件"
|
||||
logger.success(f" [{type_str}] 成功{action}: {module_name}")
|
||||
except SyncHandlerError as e:
|
||||
logger.error(f" 插件 {module_name} 加载失败: {e} (跳过此插件)")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f" {action if 'action' in locals() else '加载'}插件 {module_name} 失败: {e}"
|
||||
)
|
||||
|
||||
def reload_plugin(self, full_module_name: str):
|
||||
"""
|
||||
精确重载单个插件。
|
||||
"""
|
||||
if full_module_name not in self.loaded_plugins:
|
||||
logger.warning(f"尝试重载一个未被加载的插件: {full_module_name},将按首次加载处理。")
|
||||
|
||||
if full_module_name not in sys.modules:
|
||||
logger.error(f"重载失败: 模块 {full_module_name} 未在 sys.modules 中找到。")
|
||||
return
|
||||
|
||||
try:
|
||||
if full_module_name in sys.modules:
|
||||
module = importlib.reload(sys.modules[full_module_name])
|
||||
action = "重载"
|
||||
else:
|
||||
module = importlib.import_module(full_module_name)
|
||||
action = "加载"
|
||||
self.command_manager.unload_plugin(full_module_name)
|
||||
module = importlib.reload(sys.modules[full_module_name])
|
||||
|
||||
# 提取插件元数据
|
||||
if hasattr(module, "__plugin_meta__"):
|
||||
meta = getattr(module, "__plugin_meta__")
|
||||
matcher.plugins[full_module_name] = meta
|
||||
self.command_manager.plugins[full_module_name] = meta
|
||||
|
||||
type_str = "包" if is_pkg else "文件"
|
||||
logger.success(f" [{type_str}] 成功{action}: {module_name}")
|
||||
except SyncHandlerError as e:
|
||||
logger.error(f" 插件 {module_name} 加载失败: {e} (跳过此插件)")
|
||||
logger.success(f"插件 {full_module_name} 已成功重载。")
|
||||
except Exception as e:
|
||||
print(
|
||||
f" {action if 'action' in locals() else '加载'}插件 {module_name} 失败: {e}"
|
||||
)
|
||||
|
||||
|
||||
class PluginDataManager:
|
||||
"""
|
||||
用于管理插件产生的数据文件的类
|
||||
"""
|
||||
|
||||
def __init__(self, plugin_name: str):
|
||||
"""
|
||||
初始化插件数据管理器
|
||||
|
||||
:param plugin_name: 插件名称
|
||||
"""
|
||||
self.plugin_name = plugin_name
|
||||
self.data_file = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"..",
|
||||
"plugins",
|
||||
"data",
|
||||
self.plugin_name + ".json",
|
||||
)
|
||||
self.data = {}
|
||||
|
||||
async def load(self):
|
||||
"""读取配置文件"""
|
||||
if not os.path.exists(self.data_file):
|
||||
await self.set(self.plugin_name, [])
|
||||
try:
|
||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
||||
self.data = await run_in_thread_pool(json.load, f)
|
||||
except json.JSONDecodeError:
|
||||
self.data = {}
|
||||
|
||||
async def save(self):
|
||||
"""保存配置到文件"""
|
||||
with open(self.data_file, "w", encoding="utf-8") as f:
|
||||
await run_in_thread_pool(json.dump, self.data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""获取配置项"""
|
||||
return self.data.get(key, default)
|
||||
|
||||
async def set(self, key, value):
|
||||
"""设置配置项"""
|
||||
self.data[key] = value
|
||||
await self.save()
|
||||
|
||||
async def add(self, key, value):
|
||||
"""添加配置项"""
|
||||
if key not in self.data:
|
||||
self.data[key] = []
|
||||
self.data[key].append(value)
|
||||
await self.save()
|
||||
|
||||
async def remove(self, key):
|
||||
"""删除配置项"""
|
||||
if key in self.data:
|
||||
del self.data[key]
|
||||
await self.save()
|
||||
|
||||
async def clear(self):
|
||||
"""清空所有配置"""
|
||||
self.data.clear()
|
||||
await self.save()
|
||||
|
||||
def get_all(self):
|
||||
return self.data.copy()
|
||||
logger.exception(f"重载插件 {full_module_name} 时发生错误: {e}")
|
||||
|
||||
@@ -20,10 +20,11 @@ class RedisManager:
|
||||
"""
|
||||
if self._redis is None:
|
||||
try:
|
||||
host = config.redis['host']
|
||||
port = config.redis['port']
|
||||
db = config.redis['db']
|
||||
password = config.redis.get('password')
|
||||
redis_config = config.redis
|
||||
host = redis_config.host
|
||||
port = redis_config.port
|
||||
db = redis_config.db
|
||||
password = redis_config.password
|
||||
|
||||
logger.info(f"正在尝试连接 Redis: {host}:{port}, DB: {db}")
|
||||
|
||||
@@ -54,5 +55,17 @@ class RedisManager:
|
||||
raise ConnectionError("Redis 未初始化或连接失败,请先调用 initialize()")
|
||||
return self._redis
|
||||
|
||||
async def get(self, name):
|
||||
"""
|
||||
获取指定键的值
|
||||
"""
|
||||
return await self.redis.get(name)
|
||||
|
||||
async def set(self, name, value, ex=None):
|
||||
"""
|
||||
设置指定键的值
|
||||
"""
|
||||
return await self.redis.set(name, value, ex=ex)
|
||||
|
||||
# 全局 Redis 管理器实例
|
||||
redis_manager = RedisManager()
|
||||
|
||||
42
core/permission.py
Normal file
42
core/permission.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
|
||||
|
||||
@total_ordering
|
||||
class Permission(Enum):
|
||||
"""
|
||||
定义用户权限等级的枚举类。
|
||||
|
||||
使用 @total_ordering 装饰器,只需定义 __lt__ 和 __eq__,
|
||||
即可自动实现所有比较运算符。
|
||||
"""
|
||||
USER = "user"
|
||||
OP = "op"
|
||||
ADMIN = "admin"
|
||||
|
||||
@property
|
||||
def _level_map(self):
|
||||
"""
|
||||
内部属性,用于映射枚举成员到整数等级。
|
||||
"""
|
||||
return {
|
||||
Permission.USER: 1,
|
||||
Permission.OP: 2,
|
||||
Permission.ADMIN: 3
|
||||
}
|
||||
|
||||
def __lt__(self, other):
|
||||
"""
|
||||
比较当前权限是否小于另一个权限。
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self._level_map[self] < self._level_map[other]
|
||||
|
||||
def __ge__(self, other):
|
||||
"""
|
||||
比较当前权限是否大于等于另一个权限。
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self._level_map[self] >= self._level_map[other]
|
||||
@@ -2,7 +2,8 @@
|
||||
import asyncio
|
||||
import docker
|
||||
from docker.tls import TLSConfig
|
||||
from typing import Dict, Any, Callable
|
||||
from docker.types import LogConfig
|
||||
from typing import Any, Callable
|
||||
|
||||
from core.utils.logger import logger
|
||||
|
||||
@@ -10,21 +11,20 @@ class CodeExecutor:
|
||||
"""
|
||||
代码执行引擎,负责管理一个异步任务队列和并发的 Docker 容器执行。
|
||||
"""
|
||||
def __init__(self, bot_instance, config: Dict[str, Any]):
|
||||
def __init__(self, config: Any):
|
||||
"""
|
||||
初始化代码执行引擎。
|
||||
:param bot_instance: Bot 实例,用于后续的消息回复。
|
||||
:param config: 从 config.toml 加载的配置字典。
|
||||
:param config: 从 config_loader.py 加载的全局配置对象。
|
||||
"""
|
||||
self.bot = bot_instance
|
||||
self.task_queue = asyncio.Queue()
|
||||
self.bot: Any = None # Bot 实例将在 WS 连接成功后动态注入
|
||||
self.task_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
# 从传入的配置中读取 Docker 相关设置
|
||||
docker_config = config.docker
|
||||
self.docker_base_url = docker_config.get("base_url")
|
||||
self.sandbox_image = docker_config.get("sandbox_image", "python-sandbox:latest")
|
||||
self.timeout = docker_config.get("timeout", 10)
|
||||
concurrency = docker_config.get("concurrency_limit", 5)
|
||||
self.docker_base_url = docker_config.base_url
|
||||
self.sandbox_image = docker_config.sandbox_image
|
||||
self.timeout = docker_config.timeout
|
||||
concurrency = docker_config.concurrency_limit
|
||||
|
||||
self.concurrency_limit = asyncio.Semaphore(concurrency)
|
||||
self.docker_client = None
|
||||
@@ -34,10 +34,10 @@ class CodeExecutor:
|
||||
if self.docker_base_url:
|
||||
# 如果配置了远程 Docker 地址,则使用 TLS 选项进行连接
|
||||
tls_config = None
|
||||
if docker_config.get("tls_verify", False):
|
||||
if docker_config.tls_verify:
|
||||
tls_config = TLSConfig(
|
||||
ca_cert=docker_config.get("ca_cert_path"),
|
||||
client_cert=(docker_config.get("client_cert_path"), docker_config.get("client_key_path")),
|
||||
ca_cert=docker_config.ca_cert_path,
|
||||
client_cert=(docker_config.client_cert_path, docker_config.client_key_path),
|
||||
verify=True
|
||||
)
|
||||
self.docker_client = docker.DockerClient(base_url=self.docker_base_url, tls=tls_config)
|
||||
@@ -60,7 +60,15 @@ class CodeExecutor:
|
||||
将代码执行任务添加到队列中。
|
||||
:param code: 待执行的 Python 代码字符串。
|
||||
:param callback: 执行完毕后用于回复结果的回调函数。
|
||||
:raises RuntimeError: 如果 Docker 客户端未初始化。
|
||||
"""
|
||||
if not self.docker_client:
|
||||
logger.warning("[CodeExecutor] 尝试添加任务,但 Docker 客户端未初始化。任务被拒绝。")
|
||||
# 这里可以选择抛出异常,或者直接调用回调返回错误信息
|
||||
# 为了用户体验,我们构造一个错误结果并直接调用回调(如果可能)
|
||||
# 但由于 callback 返回 Future,这里简单起见,我们记录日志并抛出异常
|
||||
raise RuntimeError("Docker环境未就绪,无法执行代码。")
|
||||
|
||||
task = {"code": code, "callback": callback}
|
||||
await self.task_queue.put(task)
|
||||
logger.info(f"[CodeExecutor] 新的代码执行任务已入队 (队列当前长度: {self.task_queue.qsize()})。")
|
||||
@@ -125,6 +133,9 @@ class CodeExecutor:
|
||||
同步函数:在 Docker 容器中运行代码。
|
||||
此函数通过手动管理容器生命周期来提高稳定性。
|
||||
"""
|
||||
if self.docker_client is None:
|
||||
raise docker.errors.DockerException("Docker client is not initialized.")
|
||||
|
||||
container = None
|
||||
try:
|
||||
# 1. 创建容器
|
||||
@@ -134,7 +145,7 @@ class CodeExecutor:
|
||||
mem_limit='128m',
|
||||
cpu_shares=512,
|
||||
network_disabled=True,
|
||||
log_config={'type': 'json-file', 'config': {'max-size': '1m'}},
|
||||
log_config=LogConfig(type='json-file', config={'max-size': '1m'}),
|
||||
)
|
||||
# 2. 启动容器
|
||||
container.start()
|
||||
@@ -150,7 +161,7 @@ class CodeExecutor:
|
||||
# 5. 检查退出码,如果不为 0,则手动抛出 ContainerError
|
||||
if result.get('StatusCode', 0) != 0:
|
||||
raise docker.errors.ContainerError(
|
||||
container, result['StatusCode'], f"python -c '{code}'", self.sandbox_image, stderr
|
||||
container, result['StatusCode'], f"python -c '{code}'", self.sandbox_image, stderr.decode('utf-8')
|
||||
)
|
||||
|
||||
return stdout
|
||||
@@ -166,11 +177,11 @@ class CodeExecutor:
|
||||
except Exception as e:
|
||||
logger.error(f"[CodeExecutor] 强制移除容器 {container.id} 时失败: {e}")
|
||||
|
||||
def initialize_executor(bot_instance, config: Dict[str, Any]):
|
||||
def initialize_executor(config: Any):
|
||||
"""
|
||||
初始化并返回一个 CodeExecutor 实例。
|
||||
"""
|
||||
return CodeExecutor(bot_instance, config)
|
||||
return CodeExecutor(config)
|
||||
|
||||
async def run_in_thread_pool(sync_func, *args, **kwargs):
|
||||
"""
|
||||
|
||||
52
core/ws.py
52
core/ws.py
@@ -13,11 +13,12 @@ WebSocket 连接。它是整个机器人框架的底层通信基础。
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
import uuid
|
||||
|
||||
import websockets
|
||||
|
||||
from models import EventFactory
|
||||
from models.events.factory import EventFactory
|
||||
|
||||
from .bot import Bot
|
||||
from .config_loader import global_config
|
||||
@@ -30,7 +31,7 @@ class WS:
|
||||
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, code_executor=None):
|
||||
"""
|
||||
初始化 WebSocket 客户端。
|
||||
|
||||
@@ -38,13 +39,15 @@ class WS:
|
||||
"""
|
||||
# 读取参数
|
||||
cfg = global_config.napcat_ws
|
||||
self.url = cfg.get("uri")
|
||||
self.token = cfg.get("token")
|
||||
self.reconnect_interval = cfg.get("reconnect_interval", 5)
|
||||
self.url = cfg.uri
|
||||
self.token = cfg.token
|
||||
self.reconnect_interval = cfg.reconnect_interval
|
||||
|
||||
self.ws = None
|
||||
self._pending_requests = {}
|
||||
self.bot = Bot(self)
|
||||
self.bot: Bot | None = None
|
||||
self.self_id: int | None = None
|
||||
self.code_executor = code_executor
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
@@ -124,18 +127,43 @@ class WS:
|
||||
try:
|
||||
# 使用工厂创建事件对象
|
||||
event = EventFactory.create_event(event_data)
|
||||
|
||||
# 尝试初始化 Bot 实例 (如果尚未初始化且事件包含 self_id)
|
||||
# 只要事件中包含 self_id,我们就可以初始化 Bot,不必非要等待 meta_event
|
||||
if self.bot is None and hasattr(event, 'self_id'):
|
||||
self.self_id = event.self_id
|
||||
self.bot = Bot(self)
|
||||
logger.success(f"Bot 实例初始化完成: self_id={self.self_id}")
|
||||
|
||||
# 将代码执行器注入到 Bot 和执行器自身
|
||||
if self.code_executor:
|
||||
self.bot.code_executor = self.code_executor
|
||||
self.code_executor.bot = self.bot
|
||||
logger.info("代码执行器已成功注入 Bot 实例。")
|
||||
|
||||
# 如果 bot 尚未初始化,则不处理后续事件
|
||||
if self.bot is None:
|
||||
logger.warning("Bot 尚未初始化,跳过事件处理。")
|
||||
return
|
||||
|
||||
event.bot = self.bot # 注入 Bot 实例
|
||||
|
||||
# 打印日志
|
||||
if event.post_type == "message":
|
||||
sender_name = event.sender.nickname if event.sender else "Unknown"
|
||||
logger.info(f"[消息] {event.message_type} | {event.user_id}({sender_name}): {event.raw_message}")
|
||||
sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown"
|
||||
message_type = getattr(event, "message_type", "Unknown")
|
||||
user_id = getattr(event, "user_id", "Unknown")
|
||||
raw_message = getattr(event, "raw_message", "")
|
||||
logger.info(f"[消息] {message_type} | {user_id}({sender_name}): {raw_message}")
|
||||
elif event.post_type == "notice":
|
||||
logger.info(f"[通知] {event.notice_type}")
|
||||
notice_type = getattr(event, "notice_type", "Unknown")
|
||||
logger.info(f"[通知] {notice_type}")
|
||||
elif event.post_type == "request":
|
||||
logger.info(f"[请求] {event.request_type}")
|
||||
request_type = getattr(event, "request_type", "Unknown")
|
||||
logger.info(f"[请求] {request_type}")
|
||||
elif event.post_type == "meta_event":
|
||||
logger.debug(f"[元事件] {event.meta_event_type}")
|
||||
meta_event_type = getattr(event, "meta_event_type", "Unknown")
|
||||
logger.debug(f"[元事件] {meta_event_type}")
|
||||
|
||||
|
||||
# 分发事件
|
||||
@@ -144,7 +172,7 @@ class WS:
|
||||
except Exception as e:
|
||||
logger.exception(f"事件处理异常: {e}")
|
||||
|
||||
async def call_api(self, action: str, params: dict = None):
|
||||
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
|
||||
"""
|
||||
向 OneBot v11 实现端发送一个 API 请求。
|
||||
|
||||
|
||||
37
main.py
37
main.py
@@ -15,7 +15,7 @@ from core.utils.logger import logger
|
||||
|
||||
from core.managers.admin_manager import admin_manager
|
||||
from core.ws import WS
|
||||
from core.managers.plugin_manager import load_all_plugins
|
||||
from core.managers import plugin_manager
|
||||
from core.managers.redis_manager import redis_manager
|
||||
from core.utils.executor import run_in_thread_pool, initialize_executor
|
||||
from core.config_loader import global_config as config
|
||||
@@ -25,6 +25,8 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, ROOT_DIR)
|
||||
|
||||
|
||||
# 获取插件目录的绝对路径
|
||||
PLUGIN_DIR = os.path.join(ROOT_DIR, "plugins")
|
||||
|
||||
|
||||
class PluginReloadHandler(FileSystemEventHandler):
|
||||
@@ -32,7 +34,7 @@ class PluginReloadHandler(FileSystemEventHandler):
|
||||
文件变更处理器,用于热重载插件
|
||||
|
||||
继承自 watchdog.events.FileSystemEventHandler,
|
||||
监听 base_plugins 目录下的文件变化,并触发插件重载。
|
||||
监听 plugins 目录下的文件变化,并触发插件重载。
|
||||
"""
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
"""
|
||||
@@ -53,12 +55,14 @@ class PluginReloadHandler(FileSystemEventHandler):
|
||||
if file_system_event.is_directory:
|
||||
return
|
||||
|
||||
src_path = file_system_event.src_path
|
||||
|
||||
# 只监控 py 文件
|
||||
if not file_system_event.src_path.endswith(".py"):
|
||||
if not src_path.endswith(".py"):
|
||||
return
|
||||
|
||||
# 过滤掉一些临时文件
|
||||
if "__pycache__" in file_system_event.src_path:
|
||||
if "__pycache__" in src_path or not src_path.startswith(PLUGIN_DIR):
|
||||
return
|
||||
|
||||
# 简单的防抖动
|
||||
@@ -68,13 +72,18 @@ class PluginReloadHandler(FileSystemEventHandler):
|
||||
|
||||
self.last_reload_time = current_time
|
||||
|
||||
logger.info(f"检测到文件变更: {file_system_event.src_path}")
|
||||
logger.info("正在重载插件...")
|
||||
# 从文件路径解析出模块名
|
||||
# 例如: C:\path\to\project\plugins\bili_parser.py -> plugins.bili_parser
|
||||
relative_path = os.path.relpath(src_path, ROOT_DIR)
|
||||
module_name = os.path.splitext(relative_path.replace(os.sep, '.'))[0]
|
||||
|
||||
logger.info(f"检测到文件变更: {src_path}")
|
||||
logger.info(f"准备重载插件: {module_name}...")
|
||||
|
||||
try:
|
||||
# 使用线程安全的方式在主事件循环中运行异步的插件加载函数
|
||||
asyncio.run_coroutine_threadsafe(run_in_thread_pool(load_all_plugins), self.loop)
|
||||
logger.success("插件重载完成")
|
||||
# 使用线程安全的方式在主事件循环中运行异步的插件重载函数
|
||||
asyncio.run_coroutine_threadsafe(run_in_thread_pool(plugin_manager.reload_plugin, module_name), self.loop)
|
||||
logger.success(f"插件 {module_name} 重载任务已提交")
|
||||
except Exception as e:
|
||||
logger.exception(f"重载失败: {e}")
|
||||
|
||||
@@ -88,8 +97,7 @@ async def main():
|
||||
2. 初始化 WebSocket 客户端
|
||||
3. 建立连接并保持运行
|
||||
"""
|
||||
# 首次加载插件
|
||||
await run_in_thread_pool(load_all_plugins)
|
||||
# 插件加载已移至 core.managers.__init__.py 中自动执行
|
||||
|
||||
# 初始化 Redis 连接
|
||||
await redis_manager.initialize()
|
||||
@@ -114,11 +122,10 @@ async def main():
|
||||
logger.warning(f"插件目录不存在 {plugin_path}")
|
||||
|
||||
try:
|
||||
websocket_client = WS()
|
||||
|
||||
# 初始化代码执行器
|
||||
code_executor = initialize_executor(websocket_client, config)
|
||||
websocket_client.bot.code_executor = code_executor # 将执行器实例附加到 bot.bot 对象上
|
||||
code_executor = initialize_executor(config)
|
||||
|
||||
websocket_client = WS(code_executor=code_executor)
|
||||
|
||||
# 启动代码执行器的后台 worker
|
||||
logger.debug("[Main] 检查是否需要启动代码执行 Worker...")
|
||||
|
||||
@@ -1,97 +1,23 @@
|
||||
from .events.base import OneBotEvent
|
||||
from .events.factory import EventFactory
|
||||
from .events.message import (
|
||||
GroupMessageEvent,
|
||||
MessageEvent,
|
||||
MessageSegment,
|
||||
PrivateMessageEvent,
|
||||
)
|
||||
from .events.meta import HeartbeatEvent, HeartbeatStatus, LifeCycleEvent, MetaEvent
|
||||
from .events.notice import (
|
||||
ClientStatus,
|
||||
ClientStatusNoticeEvent,
|
||||
EssenceNoticeEvent,
|
||||
FriendAddNoticeEvent,
|
||||
FriendRecallNoticeEvent,
|
||||
GroupAdminNoticeEvent,
|
||||
GroupBanNoticeEvent,
|
||||
GroupCardNoticeEvent,
|
||||
GroupDecreaseNoticeEvent,
|
||||
GroupIncreaseNoticeEvent,
|
||||
GroupRecallNoticeEvent,
|
||||
GroupUploadFile,
|
||||
GroupUploadNoticeEvent,
|
||||
HonorNotifyEvent,
|
||||
LuckyKingNotifyEvent,
|
||||
NoticeEvent,
|
||||
NotifyNoticeEvent,
|
||||
OfflineFile,
|
||||
OfflineFileNoticeEvent,
|
||||
PokeNotifyEvent,
|
||||
)
|
||||
from .events.request import FriendRequestEvent, GroupRequestEvent, RequestEvent
|
||||
from .objects import (
|
||||
CurrentTalkative,
|
||||
EssenceMessage,
|
||||
FriendInfo,
|
||||
GroupHonorInfo,
|
||||
GroupInfo,
|
||||
GroupMemberInfo,
|
||||
HonorInfo,
|
||||
LoginInfo,
|
||||
Status,
|
||||
StrangerInfo,
|
||||
VersionInfo,
|
||||
)
|
||||
"""
|
||||
Models 包
|
||||
|
||||
# Alias for backward compatibility
|
||||
Event = OneBotEvent
|
||||
导出常用的模型类,方便插件导入。
|
||||
"""
|
||||
|
||||
from .events.base import OneBotEvent
|
||||
from .events.message import MessageEvent, GroupMessageEvent, PrivateMessageEvent
|
||||
from .events.notice import NoticeEvent
|
||||
from .events.request import RequestEvent
|
||||
from .message import MessageSegment
|
||||
from .sender import Sender
|
||||
|
||||
__all__ = [
|
||||
"OneBotEvent",
|
||||
"MessageEvent",
|
||||
"GroupMessageEvent",
|
||||
"PrivateMessageEvent",
|
||||
"NoticeEvent",
|
||||
"RequestEvent",
|
||||
"MessageSegment",
|
||||
"Sender",
|
||||
"OneBotEvent",
|
||||
"Event",
|
||||
"MessageEvent",
|
||||
"PrivateMessageEvent",
|
||||
"GroupMessageEvent",
|
||||
"NoticeEvent",
|
||||
"FriendAddNoticeEvent",
|
||||
"FriendRecallNoticeEvent",
|
||||
"GroupRecallNoticeEvent",
|
||||
"GroupIncreaseNoticeEvent",
|
||||
"GroupDecreaseNoticeEvent",
|
||||
"GroupAdminNoticeEvent",
|
||||
"GroupBanNoticeEvent",
|
||||
"GroupUploadNoticeEvent",
|
||||
"GroupUploadFile",
|
||||
"NotifyNoticeEvent",
|
||||
"PokeNotifyEvent",
|
||||
"LuckyKingNotifyEvent",
|
||||
"HonorNotifyEvent",
|
||||
"GroupCardNoticeEvent",
|
||||
"OfflineFileNoticeEvent",
|
||||
"OfflineFile",
|
||||
"ClientStatusNoticeEvent",
|
||||
"ClientStatus",
|
||||
"EssenceNoticeEvent",
|
||||
"RequestEvent",
|
||||
"FriendRequestEvent",
|
||||
"GroupRequestEvent",
|
||||
"MetaEvent",
|
||||
"HeartbeatEvent",
|
||||
"LifeCycleEvent",
|
||||
"HeartbeatStatus",
|
||||
"EventFactory",
|
||||
"GroupInfo",
|
||||
"GroupMemberInfo",
|
||||
"FriendInfo",
|
||||
"StrangerInfo",
|
||||
"LoginInfo",
|
||||
"VersionInfo",
|
||||
"Status",
|
||||
"EssenceMessage",
|
||||
"GroupHonorInfo",
|
||||
"CurrentTalkative",
|
||||
"HonorInfo",
|
||||
]
|
||||
|
||||
@@ -70,7 +70,11 @@ class EventFactory:
|
||||
# 解析消息段
|
||||
message_list = []
|
||||
raw_message_list = data.get("message", [])
|
||||
if isinstance(raw_message_list, list):
|
||||
|
||||
if isinstance(raw_message_list, str):
|
||||
# 如果消息是字符串,将其视为纯文本消息段
|
||||
message_list.append(MessageSegment.text(raw_message_list))
|
||||
elif isinstance(raw_message_list, list):
|
||||
for item in raw_message_list:
|
||||
if isinstance(item, dict):
|
||||
message_list.append(MessageSegment(type=item.get("type", ""), data=item.get("data", {})))
|
||||
@@ -252,9 +256,18 @@ class EventFactory:
|
||||
card_new=data.get("card_new", ""),
|
||||
card_old=data.get("card_old", "")
|
||||
)
|
||||
elif notice_type == "group_card":
|
||||
return GroupCardNoticeEvent(
|
||||
**common_args,
|
||||
notice_type=notice_type,
|
||||
group_id=data.get("group_id", 0),
|
||||
user_id=data.get("user_id", 0),
|
||||
card_new=data.get("card_new", ""),
|
||||
card_old=data.get("card_old", "")
|
||||
)
|
||||
elif notice_type == "offline_file":
|
||||
file_data = data.get("file", {})
|
||||
file = OfflineFile(
|
||||
offline_file = OfflineFile(
|
||||
name=file_data.get("name", ""),
|
||||
size=file_data.get("size", 0),
|
||||
url=file_data.get("url", "")
|
||||
@@ -263,7 +276,7 @@ class EventFactory:
|
||||
**common_args,
|
||||
notice_type=notice_type,
|
||||
user_id=data.get("user_id", 0),
|
||||
file=file
|
||||
file=offline_file
|
||||
)
|
||||
elif notice_type == "client_status":
|
||||
client_data = data.get("client", {})
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
定义了消息相关的事件类,包括 MessageEvent, PrivateMessageEvent, GroupMessageEvent。
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from core.managers.permission_manager import ADMIN, OP, USER
|
||||
from core.permission import Permission
|
||||
from models.message import MessageSegment
|
||||
from models.sender import Sender
|
||||
from .base import OneBotEvent, EventType
|
||||
@@ -34,9 +34,9 @@ class MessageEvent(OneBotEvent):
|
||||
"""
|
||||
|
||||
# 权限级别常量,用于装饰器参数
|
||||
ADMIN = ADMIN
|
||||
OP = OP
|
||||
USER = USER
|
||||
ADMIN = Permission.ADMIN
|
||||
OP = Permission.OP
|
||||
USER = Permission.USER
|
||||
|
||||
message_type: str
|
||||
"""消息类型: private (私聊), group (群聊)"""
|
||||
@@ -70,7 +70,7 @@ class MessageEvent(OneBotEvent):
|
||||
def post_type(self) -> str:
|
||||
return EventType.MESSAGE
|
||||
|
||||
async def reply(self, message: str, auto_escape: bool = False):
|
||||
async def reply(self, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False):
|
||||
"""
|
||||
回复消息(抽象方法,由子类实现)
|
||||
|
||||
@@ -86,7 +86,7 @@ class PrivateMessageEvent(MessageEvent):
|
||||
私聊消息事件
|
||||
"""
|
||||
|
||||
async def reply(self, message: str, auto_escape: bool = False):
|
||||
async def reply(self, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False):
|
||||
"""
|
||||
回复私聊消息
|
||||
|
||||
@@ -110,7 +110,7 @@ class GroupMessageEvent(MessageEvent):
|
||||
anonymous: Optional[Anonymous] = None
|
||||
"""匿名信息"""
|
||||
|
||||
async def reply(self, message: str, auto_escape: bool = False):
|
||||
async def reply(self, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False):
|
||||
"""
|
||||
回复群聊消息
|
||||
|
||||
|
||||
@@ -63,5 +63,5 @@ class LifeCycleEvent(MetaEvent):
|
||||
meta_event_type: str = 'lifecycle'
|
||||
"""元事件类型:生命周期事件"""
|
||||
|
||||
sub_type: LifeCycleSubType = LifeCycleSubType.ENABLE
|
||||
sub_type: str = LifeCycleSubType.ENABLE
|
||||
"""子类型:启用、禁用、连接"""
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional, List
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -23,7 +23,7 @@ class MessageSegment:
|
||||
data: Dict[str, Any]
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
def plain_text(self) -> str:
|
||||
"""
|
||||
当消息段类型为 'text' 时,快速获取其文本内容。
|
||||
|
||||
@@ -32,6 +32,19 @@ class MessageSegment:
|
||||
"""
|
||||
return self.data.get("text", "") if self.type == "text" else ""
|
||||
|
||||
@staticmethod
|
||||
def text(text: str) -> "MessageSegment":
|
||||
"""
|
||||
创建一个文本消息段。
|
||||
|
||||
Args:
|
||||
text (str): 文本内容。
|
||||
|
||||
Returns:
|
||||
MessageSegment: 一个类型为 'text' 的消息段对象。
|
||||
"""
|
||||
return MessageSegment(type="text", data={"text": text})
|
||||
|
||||
@property
|
||||
def image_url(self) -> str:
|
||||
"""
|
||||
@@ -76,7 +89,7 @@ class MessageSegment:
|
||||
return self.data.get("file", "")
|
||||
return ""
|
||||
|
||||
def is_at(self, user_id: int = None) -> bool:
|
||||
def is_at(self, user_id: Optional[int] = None) -> bool:
|
||||
"""
|
||||
检查当前消息段是否是一个 'at' (提及) 消息段。
|
||||
|
||||
@@ -93,16 +106,52 @@ class MessageSegment:
|
||||
return True
|
||||
return str(self.data.get("qq")) == str(user_id)
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
返回消息段的 CQ 码字符串表示。
|
||||
"""
|
||||
if self.type == "text":
|
||||
return self.data.get("text", "")
|
||||
|
||||
params = ",".join([f"{k}={v}" for k, v in self.data.items()])
|
||||
if params:
|
||||
return f"[CQ:{self.type},{params}]"
|
||||
return f"[CQ:{self.type}]"
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
返回消息段对象的字符串表示形式,便于调试。
|
||||
"""
|
||||
return f"[MS:{self.type}:{self.data}]"
|
||||
|
||||
def __add__(self, other: Any) -> "List[MessageSegment]":
|
||||
"""
|
||||
支持消息段相加,返回消息段列表。
|
||||
"""
|
||||
if isinstance(other, MessageSegment):
|
||||
return [self, other]
|
||||
elif isinstance(other, str):
|
||||
return [self, MessageSegment.text(other)]
|
||||
elif isinstance(other, list):
|
||||
return [self] + other
|
||||
return NotImplemented
|
||||
|
||||
def __radd__(self, other: Any) -> "List[MessageSegment]":
|
||||
"""
|
||||
支持反向相加。
|
||||
"""
|
||||
if isinstance(other, MessageSegment):
|
||||
return [other, self]
|
||||
elif isinstance(other, str):
|
||||
return [MessageSegment.text(other), self]
|
||||
elif isinstance(other, list):
|
||||
return other + [self]
|
||||
return NotImplemented
|
||||
|
||||
# --- 快捷构造方法 ---
|
||||
|
||||
@staticmethod
|
||||
def text(text: str) -> "MessageSegment": # noqa: F811
|
||||
def from_text(text: str) -> "MessageSegment":
|
||||
"""
|
||||
创建一个文本消息段。
|
||||
|
||||
@@ -115,7 +164,7 @@ class MessageSegment:
|
||||
return MessageSegment(type="text", data={"text": text})
|
||||
|
||||
@staticmethod
|
||||
def at(user_id: int | str, name: str = None) -> "MessageSegment":
|
||||
def at(user_id: int | str, name: Optional[str] = None) -> "MessageSegment":
|
||||
"""
|
||||
创建一个 @某人 的消息段。
|
||||
|
||||
@@ -132,7 +181,7 @@ class MessageSegment:
|
||||
return MessageSegment(type="at", data=data)
|
||||
|
||||
@staticmethod
|
||||
def image(file: str, image_type: str = None, cache: bool = True, proxy: bool = True, timeout: int = None, sub_type: int = None) -> "MessageSegment":
|
||||
def image(file: str, image_type: Optional[str] = None, cache: bool = True, proxy: bool = True, timeout: Optional[int] = None, sub_type: Optional[int] = None) -> "MessageSegment":
|
||||
"""
|
||||
创建一个图片消息段。
|
||||
|
||||
@@ -194,7 +243,7 @@ class MessageSegment:
|
||||
"""
|
||||
return MessageSegment(type="xml", data={"data": data})
|
||||
@staticmethod
|
||||
def share(url: str, title: str, content: str = None, image: str = None) -> "MessageSegment":
|
||||
def share(url: str, title: str, content: Optional[str] = None, image: Optional[str] = None) -> "MessageSegment":
|
||||
"""
|
||||
创建一个分享消息段。
|
||||
|
||||
@@ -227,7 +276,7 @@ class MessageSegment:
|
||||
"""
|
||||
return MessageSegment(type="music", data={"type": type, "id": id})
|
||||
@staticmethod
|
||||
def music_custom(url: str, audio: str, title: str, content: str = None, image: str = None) -> "MessageSegment":
|
||||
def music_custom(url: str, audio: str, title: str, content: Optional[str] = None, image: Optional[str] = None) -> "MessageSegment":
|
||||
"""
|
||||
创建一个自定义音乐消息段。
|
||||
|
||||
@@ -248,7 +297,7 @@ class MessageSegment:
|
||||
data["image"] = image
|
||||
return MessageSegment(type="music", data={"type": "custom", **data})
|
||||
@staticmethod
|
||||
def record(file: str, magic: bool = False, cache: bool = True, proxy: bool = True, timeout: int = None) -> "MessageSegment":
|
||||
def record(file: str, magic: bool = False, cache: bool = True, proxy: bool = True, timeout: Optional[int] = None) -> "MessageSegment":
|
||||
"""
|
||||
创建一个语音消息段。
|
||||
|
||||
@@ -267,7 +316,7 @@ class MessageSegment:
|
||||
data["timeout"] = str(timeout)
|
||||
return MessageSegment(type="record", data=data)
|
||||
@staticmethod
|
||||
def video(file: str, cover: str = None, c: int = 2) -> "MessageSegment":
|
||||
def video(file: str, cover: Optional[str] = None, c: int = 2) -> "MessageSegment":
|
||||
"""
|
||||
创建一个视频消息段。
|
||||
|
||||
@@ -297,17 +346,17 @@ class MessageSegment:
|
||||
return MessageSegment(type="file", data={"file": file})
|
||||
|
||||
@staticmethod
|
||||
def reply(message_id: str) -> "MessageSegment":
|
||||
def reply(message_id: str | int) -> "MessageSegment":
|
||||
"""
|
||||
创建一个回复消息段。
|
||||
|
||||
Args:
|
||||
message_id (str): 被回复的消息 ID。
|
||||
message_id (str | int): 被回复的消息 ID。
|
||||
|
||||
Returns:
|
||||
MessageSegment: 一个类型为 'reply' 的消息段对象。
|
||||
"""
|
||||
return MessageSegment(type="reply", data={"id": message_id})
|
||||
return MessageSegment(type="reply", data={"id": str(message_id)})
|
||||
|
||||
@staticmethod
|
||||
def rps() -> "MessageSegment":
|
||||
|
||||
0
plugins/__init__.py
Normal file
0
plugins/__init__.py
Normal file
130
plugins/admin.py
130
plugins/admin.py
@@ -1,74 +1,94 @@
|
||||
"""
|
||||
管理员管理插件
|
||||
|
||||
提供通过聊天指令动态添加或移除机器人管理员的功能。
|
||||
"""
|
||||
from core.bot import Bot
|
||||
from core.managers.command_manager import matcher
|
||||
from core.managers.admin_manager import admin_manager
|
||||
from core.handlers.event_handler import MessageHandler
|
||||
from core.managers import command_manager, permission_manager
|
||||
from core.permission import Permission
|
||||
from models.events.message import MessageEvent
|
||||
|
||||
# 更新插件元信息以包含OP管理
|
||||
__plugin_meta__ = {
|
||||
"name": "管理员管理",
|
||||
"description": "管理机器人的全局管理员",
|
||||
"name": "权限管理",
|
||||
"description": "管理机器人的管理员和操作员",
|
||||
"usage": (
|
||||
"/admin list - 列出所有管理员\n"
|
||||
"/admin add <QQ号> - 添加管理员\n"
|
||||
"/admin remove <QQ号> - 移除管理员"
|
||||
"/admin list - 列出所有管理员和操作员\n"
|
||||
"/admin add_admin <QQ号> - 添加管理员\n"
|
||||
"/admin remove_admin <QQ号> - 移除管理员\n"
|
||||
"/admin add_op <QQ号> - 添加操作员\n"
|
||||
"/admin remove_op <QQ号> - 移除操作员"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@matcher.command("admin", permission=MessageEvent.ADMIN)
|
||||
async def admin_command_handler(bot: Bot, event: MessageEvent, args: list[str]):
|
||||
@command_manager.command("admin", permission=Permission.ADMIN)
|
||||
async def admin_management(event: MessageEvent, args: str):
|
||||
"""
|
||||
处理 /admin 指令
|
||||
|
||||
:param bot: Bot 实例
|
||||
:param event: 消息事件实例
|
||||
:param args: 指令参数列表
|
||||
处理所有权限管理相关的命令。
|
||||
"""
|
||||
if not args:
|
||||
await event.reply(__plugin_meta__["usage"])
|
||||
parts = args.split()
|
||||
if not parts:
|
||||
await event.reply(f"用法不正确。\n\n{__plugin_meta__['usage']}")
|
||||
return
|
||||
|
||||
action = args[0].lower()
|
||||
subcommand = parts[0].lower()
|
||||
|
||||
if action == "list":
|
||||
admins = await admin_manager.get_all_admins()
|
||||
if not admins:
|
||||
await event.reply("当前没有设置任何管理员。")
|
||||
return
|
||||
|
||||
admin_list_text = "\n".join(str(admin_id) for admin_id in admins)
|
||||
await event.reply(f"当前管理员列表 ({len(admins)}):\n{admin_list_text}")
|
||||
if subcommand == "list":
|
||||
await list_permissions(event)
|
||||
return
|
||||
|
||||
if action in ("add", "remove"):
|
||||
if len(args) < 2 or not args[1].isdigit():
|
||||
await event.reply("参数错误,请提供一个有效的 QQ 号。\n示例: /admin add 123456")
|
||||
return
|
||||
# 处理需要QQ号的命令
|
||||
if len(parts) < 2 or not parts[1].isdigit():
|
||||
await event.reply(f"请提供有效的用户QQ号。\n用法: /admin {subcommand} <QQ号>")
|
||||
return
|
||||
|
||||
try:
|
||||
user_id = int(args[1])
|
||||
except ValueError:
|
||||
await event.reply("无效的 QQ 号,请输入纯数字。")
|
||||
return
|
||||
try:
|
||||
target_user_id = int(parts[1])
|
||||
except ValueError:
|
||||
await event.reply("无效的QQ号。")
|
||||
return
|
||||
|
||||
if action == "add":
|
||||
success = await admin_manager.add_admin(user_id)
|
||||
if success:
|
||||
await event.reply(f"成功添加管理员: {user_id}")
|
||||
else:
|
||||
await event.reply(f"管理员 {user_id} 已存在,无需重复添加。")
|
||||
return
|
||||
# 安全检查
|
||||
if target_user_id == event.user_id:
|
||||
await event.reply("你不能操作自己的权限。")
|
||||
return
|
||||
if target_user_id == event.self_id:
|
||||
await event.reply("你不能操作机器人自身的权限。")
|
||||
return
|
||||
|
||||
elif action == "remove":
|
||||
success = await admin_manager.remove_admin(user_id)
|
||||
if success:
|
||||
await event.reply(f"成功移除管理员: {user_id}")
|
||||
else:
|
||||
await event.reply(f"管理员 {user_id} 不存在。")
|
||||
return
|
||||
# 根据子命令分发
|
||||
if subcommand == "add_admin":
|
||||
permission_manager.set_user_permission(target_user_id, Permission.ADMIN)
|
||||
await event.reply(f"已成功添加管理员:{target_user_id}")
|
||||
elif subcommand == "remove_admin":
|
||||
permission_manager.set_user_permission(target_user_id, Permission.USER)
|
||||
await event.reply(f"已成功移除管理员:{target_user_id}")
|
||||
elif subcommand == "add_op":
|
||||
permission_manager.set_user_permission(target_user_id, Permission.OP)
|
||||
await event.reply(f"已成功添加操作员:{target_user_id}")
|
||||
elif subcommand == "remove_op":
|
||||
permission_manager.set_user_permission(target_user_id, Permission.USER)
|
||||
await event.reply(f"已成功移除操作员:{target_user_id}")
|
||||
else:
|
||||
await event.reply(f"未知的子命令 '{subcommand}'。\n\n{__plugin_meta__['usage']}")
|
||||
|
||||
await event.reply(f"未知的指令: {action}\n\n{__plugin_meta__['usage']}")
|
||||
|
||||
async def list_permissions(event: MessageEvent):
|
||||
"""
|
||||
列出所有具有特殊权限(管理员和操作员)的用户。
|
||||
"""
|
||||
permissions = permission_manager.get_all_user_permissions()
|
||||
if not permissions:
|
||||
await event.reply("当前没有配置任何特殊权限的用户。")
|
||||
return
|
||||
|
||||
admins = {uid for uid, p in permissions.items() if p == 'admin'}
|
||||
ops = {uid for uid, p in permissions.items() if p == 'op'}
|
||||
|
||||
reply_msg = "当前权限列表:\n"
|
||||
if admins:
|
||||
reply_msg += "--- 管理员 ---\n"
|
||||
for user_id in admins:
|
||||
reply_msg += f"- {user_id}\n"
|
||||
if ops:
|
||||
reply_msg += "--- 操作员 ---\n"
|
||||
for user_id in ops:
|
||||
reply_msg += f"- {user_id}\n"
|
||||
|
||||
await event.reply(reply_msg.strip())
|
||||
|
||||
@@ -3,12 +3,16 @@ import re
|
||||
import json
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, Union
|
||||
from cachetools import TTLCache
|
||||
|
||||
from core.utils.logger import logger
|
||||
from core.managers.command_manager import matcher
|
||||
from models import MessageEvent, MessageSegment
|
||||
|
||||
# 创建一个TTL缓存,最大容量100,缓存时间10秒
|
||||
processed_messages: TTLCache[int, bool] = TTLCache(maxsize=100, ttl=10)
|
||||
|
||||
__plugin_meta__ = {
|
||||
"name": "bili_parser",
|
||||
"description": "自动解析B站分享卡片,提取视频封面和播放量等信息。",
|
||||
@@ -52,10 +56,14 @@ def parse_video_info(video_url: str) -> Optional[Dict[str, Any]]:
|
||||
soup = BeautifulSoup(response.text, 'html.parser')
|
||||
|
||||
script_tag = soup.find('script', text=re.compile('window.__INITIAL_STATE__'))
|
||||
if not script_tag:
|
||||
if not script_tag or not script_tag.string:
|
||||
return None
|
||||
|
||||
json_str = re.search(r'window\.__INITIAL_STATE__\s*=\s*(\{.*?\});', script_tag.string).group(1)
|
||||
match = re.search(r'window\.__INITIAL_STATE__\s*=\s*(\{.*?\});', script_tag.string)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
json_str = match.group(1)
|
||||
data = json.loads(json_str)
|
||||
|
||||
video_data = data.get('videoData', {})
|
||||
@@ -121,6 +129,15 @@ async def handle_bili_share(event: MessageEvent):
|
||||
处理消息,检测B站分享链接(JSON卡片或文本链接)并进行解析。
|
||||
:param event: 消息事件对象
|
||||
"""
|
||||
# 消息去重
|
||||
if event.message_id in processed_messages:
|
||||
return
|
||||
processed_messages[event.message_id] = True
|
||||
|
||||
# 忽略机器人自己发送的消息,防止无限循环
|
||||
if event.user_id == event.self_id:
|
||||
return
|
||||
|
||||
url_to_process = None
|
||||
|
||||
# 1. 优先解析JSON卡片中的短链接
|
||||
@@ -176,6 +193,7 @@ async def process_bili_link(event: MessageEvent, url: str):
|
||||
return
|
||||
|
||||
# 检查视频时长
|
||||
video_message: Union[str, MessageSegment]
|
||||
if video_info['duration'] > 300: # 5分钟 = 300秒
|
||||
video_message = "视频时长超过5分钟,不进行解析。"
|
||||
else:
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
"""
|
||||
import asyncio
|
||||
from core.managers.command_manager import matcher
|
||||
from models import MessageEvent, PrivateMessageEvent
|
||||
from core.managers.permission_manager import ADMIN
|
||||
from models.events.message import MessageEvent, PrivateMessageEvent
|
||||
from core.permission import Permission
|
||||
from core.utils.logger import logger
|
||||
|
||||
# --- 会话状态管理 ---
|
||||
@@ -24,7 +24,7 @@ def cleanup_session(user_id: int):
|
||||
del broadcast_sessions[user_id]
|
||||
logger.info(f"[Broadcast] 会话 {user_id} 已超时,自动取消。")
|
||||
|
||||
@matcher.command("broadcast", "广播", permission=ADMIN)
|
||||
@matcher.command("broadcast", "广播", permission=Permission.ADMIN)
|
||||
async def broadcast_start(event: MessageEvent):
|
||||
"""
|
||||
广播指令的入口,启动一个等待用户消息的会话。
|
||||
@@ -92,7 +92,7 @@ async def handle_broadcast_content(event: MessageEvent):
|
||||
nodes_to_send = [
|
||||
bot.build_forward_node(
|
||||
user_id=event.user_id,
|
||||
nickname=event.sender.nickname,
|
||||
nickname=event.sender.nickname if event.sender else "未知用户",
|
||||
message=message_to_broadcast
|
||||
)
|
||||
]
|
||||
|
||||
@@ -1,35 +1,24 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import html
|
||||
import textwrap
|
||||
# -*- coding: utf-8 -*-
|
||||
import html
|
||||
import textwrap
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
|
||||
from core.managers.command_manager import matcher
|
||||
from models import MessageEvent
|
||||
from core.managers.permission_manager import ADMIN
|
||||
from models.events.message import MessageEvent
|
||||
from core.permission import Permission
|
||||
from core.utils.logger import logger
|
||||
|
||||
__plugin_meta__ = {
|
||||
"name": "Python 代码执行",
|
||||
"description": "在安全的沙箱环境中执行 Python 代码片段,支持单行、多行和转发回复。",
|
||||
"usage": "/py <单行代码>\n/code_py <单行代码>\n/py (进入多行输入模式)",
|
||||
"name": "Python 代码执行",
|
||||
"description": "在安全的沙箱环境中执行 Python 代码片段,支持单行、多行和转发回复。",
|
||||
"usage": "/py <单行代码>\n/code_py <单行代码>\n/py (进入多行输入模式)",
|
||||
}
|
||||
|
||||
# --- 会话状态管理 ---
|
||||
# 结构: {(user_id, group_id): asyncio.TimerHandle}
|
||||
multi_line_sessions: Dict[tuple, asyncio.TimerHandle] = {}
|
||||
|
||||
async def reply_as_forward(event: MessageEvent, input_code: str, output_result: str):
|
||||
# --- 会话状态管理 ---
|
||||
# 结构: {(user_id, group_id): asyncio.TimerHandle}
|
||||
multi_line_sessions: Dict[tuple, asyncio.TimerHandle] = {}
|
||||
|
||||
async def reply_as_forward(event: MessageEvent, input_code: str, output_result: str):
|
||||
"""
|
||||
将输入和输出打包成转发消息进行回复。
|
||||
@@ -41,35 +30,7 @@ async def reply_as_forward(event: MessageEvent, input_code: str, output_result:
|
||||
nodes = [
|
||||
bot.build_forward_node(
|
||||
user_id=event.user_id,
|
||||
nickname=event.sender.nickname or str(event.user_id),
|
||||
message=f"--- Your Code ---\n{input_code}"
|
||||
),
|
||||
bot.build_forward_node(
|
||||
user_id=event.self_id,
|
||||
nickname="Code Executor",
|
||||
message=f"--- Execution Result ---\n{output_result}"
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
# 2. 发送合并转发消息
|
||||
await bot.send_forwarded_messages(event, nodes)
|
||||
except Exception as e:
|
||||
logger.error(f"[code_py] 发送转发消息失败: {e}")
|
||||
# 降级为普通消息回复
|
||||
await event.reply(f"--- 你的代码 ---\n{input_code}\n--- 执行结果 ---\n{output_result}")
|
||||
|
||||
async def execute_code(event: MessageEvent, code: str):
|
||||
将输入和输出打包成转发消息进行回复。
|
||||
参考 forward_test.py 的实现,兼容私聊和群聊。
|
||||
"""
|
||||
bot = event.bot
|
||||
|
||||
# 1. 构建消息节点列表
|
||||
nodes = [
|
||||
bot.build_forward_node(
|
||||
user_id=event.user_id,
|
||||
nickname=event.sender.nickname or str(event.user_id),
|
||||
nickname=event.sender.nickname if event.sender else str(event.user_id),
|
||||
message=f"--- Your Code ---\n{input_code}"
|
||||
),
|
||||
bot.build_forward_node(
|
||||
@@ -90,7 +51,6 @@ async def execute_code(event: MessageEvent, code: str):
|
||||
async def execute_code(event: MessageEvent, code: str):
|
||||
"""
|
||||
核心代码执行逻辑。
|
||||
核心代码执行逻辑。
|
||||
"""
|
||||
code_executor = getattr(event.bot, 'code_executor', None)
|
||||
if not code_executor or not code_executor.docker_client:
|
||||
@@ -137,74 +97,15 @@ def normalize_code(code: str) -> str:
|
||||
return code.strip()
|
||||
|
||||
|
||||
@matcher.command("py", "python", "code_py", permission=ADMIN)
|
||||
async def code_py_main(event: MessageEvent, args: list[str]):
|
||||
code_executor = getattr(event.bot, 'code_executor', None)
|
||||
if not code_executor or not code_executor.docker_client:
|
||||
await event.reply("代码执行服务当前不可用,请检查 Docker 连接配置。")
|
||||
return
|
||||
|
||||
# 修改 add_task,让它能直接接收回复函数
|
||||
await code_executor.add_task(
|
||||
code,
|
||||
lambda result: reply_as_forward(event, code, result)
|
||||
)
|
||||
await event.reply("代码已提交至沙箱执行队列,请稍候...")
|
||||
|
||||
def cleanup_session(session_key: tuple):
|
||||
"""
|
||||
清理超时的会话。
|
||||
"""
|
||||
if session_key in multi_line_sessions:
|
||||
del multi_line_sessions[session_key]
|
||||
logger.info(f"[code_py] 会话 {session_key} 已超时,自动取消。")
|
||||
|
||||
def normalize_code(code: str) -> str:
|
||||
"""
|
||||
规范化用户输入的 Python 代码字符串。
|
||||
|
||||
主要处理两个问题:
|
||||
1. 对消息中可能存在的 HTML 实体进行解码 (e.g., [ -> [)。
|
||||
2. 移除整个代码块的公共前导缩进,以修复因复制粘贴导致的多余缩进。
|
||||
|
||||
:param code: 原始代码字符串。
|
||||
:return: 规范化后的代码字符串。
|
||||
"""
|
||||
# 1. 解码 HTML 实体
|
||||
code = html.unescape(code)
|
||||
|
||||
# 2. 移除公共前导缩进
|
||||
try:
|
||||
code = textwrap.dedent(code)
|
||||
except Exception:
|
||||
# 在某些情况下(例如,不一致的缩进),dedent 可能会失败,
|
||||
# 但我们不希望因此中断流程,所以捕获异常并继续。
|
||||
pass
|
||||
|
||||
return code.strip()
|
||||
|
||||
|
||||
@matcher.command("py", "python", "code_py", permission=ADMIN)
|
||||
@matcher.command("py", "python", "code_py", permission=Permission.ADMIN)
|
||||
async def code_py_main(event: MessageEvent, args: list[str]):
|
||||
"""
|
||||
/py 命令的主入口。
|
||||
- 如果有参数,直接执行。
|
||||
- 如果没有参数,开启多行输入模式。
|
||||
/py 命令的主入口。
|
||||
- 如果有参数,直接执行。
|
||||
- 如果没有参数,开启多行输入模式。
|
||||
"""
|
||||
code_to_run = " ".join(args)
|
||||
|
||||
if code_to_run:
|
||||
# 单行模式,对代码进行规范化处理
|
||||
normalized_code = normalize_code(code_to_run)
|
||||
if not normalized_code:
|
||||
await event.reply("代码为空或格式错误,请输入有效的代码。")
|
||||
return
|
||||
await execute_code(event, normalized_code)
|
||||
code_to_run = " ".join(args)
|
||||
|
||||
if code_to_run:
|
||||
# 单行模式,对代码进行规范化处理
|
||||
normalized_code = normalize_code(code_to_run)
|
||||
@@ -231,24 +132,6 @@ async def code_py_main(event: MessageEvent, args: list[str]):
|
||||
session_key
|
||||
)
|
||||
multi_line_sessions[session_key] = timeout_handler
|
||||
# 多行模式
|
||||
# 使用 getattr 兼容私聊和群聊
|
||||
session_key = (event.user_id, getattr(event, 'group_id', 'private'))
|
||||
|
||||
# 如果上一个会话的超时任务还在,先取消它
|
||||
if session_key in multi_line_sessions:
|
||||
multi_line_sessions[session_key].cancel()
|
||||
|
||||
await event.reply("已进入多行代码输入模式,请直接发送你的代码。\n(60秒内无操作将自动取消)")
|
||||
|
||||
# 设置 60 秒超时
|
||||
loop = asyncio.get_running_loop()
|
||||
timeout_handler = loop.call_later(
|
||||
60,
|
||||
cleanup_session,
|
||||
session_key
|
||||
)
|
||||
multi_line_sessions[session_key] = timeout_handler
|
||||
|
||||
@matcher.on_message()
|
||||
async def handle_multi_line_code(event: MessageEvent):
|
||||
@@ -265,26 +148,6 @@ async def handle_multi_line_code(event: MessageEvent):
|
||||
# 对多行代码进行规范化处理
|
||||
normalized_code = normalize_code(event.raw_message)
|
||||
|
||||
if not normalized_code:
|
||||
await event.reply("捕获到的代码为空或格式错误,已取消输入。")
|
||||
return
|
||||
|
||||
await execute_code(event, normalized_code)
|
||||
return True # 消费事件,防止其他处理器响应
|
||||
async def handle_multi_line_code(event: MessageEvent):
|
||||
"""
|
||||
通用消息处理器,用于捕获多行模式下的代码输入。
|
||||
"""
|
||||
# 使用 getattr 兼容私聊和群聊
|
||||
session_key = (event.user_id, getattr(event, 'group_id', 'private'))
|
||||
if session_key in multi_line_sessions:
|
||||
# 取消超时任务
|
||||
multi_line_sessions[session_key].cancel()
|
||||
del multi_line_sessions[session_key]
|
||||
|
||||
# 对多行代码进行规范化处理
|
||||
normalized_code = normalize_code(event.raw_message)
|
||||
|
||||
if not normalized_code:
|
||||
await event.reply("捕获到的代码为空或格式错误,已取消输入。")
|
||||
return
|
||||
|
||||
@@ -5,7 +5,7 @@ Echo 与交互插件
|
||||
"""
|
||||
from core.managers.command_manager import matcher
|
||||
from core.bot import Bot
|
||||
from models import MessageEvent
|
||||
from models.events.message import MessageEvent
|
||||
|
||||
__plugin_meta__ = {
|
||||
"name": "echo",
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"""
|
||||
from core.managers.command_manager import matcher
|
||||
from core.bot import Bot
|
||||
from models import MessageEvent
|
||||
from models.events.message import MessageEvent
|
||||
from models.message import MessageSegment
|
||||
|
||||
__plugin_meta__ = {
|
||||
@@ -22,14 +22,15 @@ async def handle_forward_test(bot: Bot, event: MessageEvent, args: list[str]):
|
||||
:param args: 指令参数
|
||||
"""
|
||||
# 1. 构建消息节点列表
|
||||
nickname = event.sender.nickname if event.sender else "未知用户"
|
||||
nodes = [
|
||||
bot.build_forward_node(user_id=event.self_id, nickname="机器人", message="你要的furry来了"),
|
||||
bot.build_forward_node(user_id=event.user_id, nickname=event.sender.nickname, message="让我看看"),
|
||||
bot.build_forward_node(user_id=event.user_id, nickname=nickname, message="让我看看"),
|
||||
bot.build_forward_node(
|
||||
user_id=event.self_id,
|
||||
nickname="机器人",
|
||||
message=[
|
||||
MessageSegment.text("你要的福瑞图"),
|
||||
MessageSegment.from_text("你要的福瑞图"),
|
||||
MessageSegment.image("https://api.furry.ist/furry-img/")
|
||||
]
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ from datetime import datetime
|
||||
from core.bot import Bot
|
||||
from core.managers.command_manager import matcher
|
||||
from core.utils.executor import run_in_thread_pool
|
||||
from models import MessageEvent, MessageSegment
|
||||
from models.events.message import MessageEvent, MessageSegment
|
||||
|
||||
__plugin_meta__ = {
|
||||
"name": "jrcd",
|
||||
@@ -79,14 +79,17 @@ async def handle_jrcd(bot: Bot, event: MessageEvent, args: list[str]):
|
||||
"""
|
||||
user_id = event.user_id
|
||||
jrcd = await run_in_thread_pool(get_jrcd, user_id)
|
||||
msg = [MessageSegment.at(user_id)]
|
||||
|
||||
msg_text = ""
|
||||
if jrcd <= 9:
|
||||
msg.append(MessageSegment.text(random.choice(JRCDMSG_1) % jrcd))
|
||||
msg_text = random.choice(JRCDMSG_1) % jrcd
|
||||
elif jrcd <= 19:
|
||||
msg.append(MessageSegment.text(random.choice(JRCDMSG_2) % jrcd))
|
||||
msg_text = random.choice(JRCDMSG_2) % jrcd
|
||||
else:
|
||||
msg.append(MessageSegment.text(random.choice(JRCDMSG_3) % jrcd))
|
||||
await event.reply(msg)
|
||||
msg_text = random.choice(JRCDMSG_3) % jrcd
|
||||
|
||||
reply_segments = [MessageSegment.at(user_id), MessageSegment.from_text(msg_text)]
|
||||
await event.reply(reply_segments)
|
||||
|
||||
|
||||
@matcher.command("bbcd")
|
||||
@@ -118,29 +121,31 @@ async def handle_bbcd(bot: Bot, event: MessageEvent, args: list[str]):
|
||||
|
||||
jrcz = jrcd1 - jrcd2
|
||||
|
||||
msg = [
|
||||
text_part = ""
|
||||
if jrcz == 0:
|
||||
text_part = f" 一样长。{random.choice(BBCDMSG7)}"
|
||||
elif jrcz > 0:
|
||||
text_part = f" 长{jrcz}cm。"
|
||||
if jrcz <= 9:
|
||||
text_part += random.choice(BBCDMSG1)
|
||||
elif jrcz <= 19:
|
||||
text_part += random.choice(BBCDMSG2)
|
||||
else:
|
||||
text_part += random.choice(BBCDMSG3)
|
||||
else: # jrcz < 0
|
||||
text_part = f" 短{abs(jrcz)}cm。"
|
||||
if jrcz >= -9:
|
||||
text_part += random.choice(BBCDMSG4)
|
||||
elif jrcz >= -19:
|
||||
text_part += random.choice(BBCDMSG5)
|
||||
else:
|
||||
text_part += random.choice(BBCDMSG6)
|
||||
|
||||
segments = [
|
||||
MessageSegment.at(user_id1),
|
||||
MessageSegment.text("你的长度比"),
|
||||
MessageSegment.from_text(" 你的长度比 "),
|
||||
MessageSegment.at(user_id2),
|
||||
MessageSegment.from_text(text_part),
|
||||
]
|
||||
|
||||
if jrcz == 0:
|
||||
msg.append(MessageSegment.text("一样长。"))
|
||||
msg.append(MessageSegment.text(random.choice(BBCDMSG7)))
|
||||
elif jrcz > 0:
|
||||
msg.append(MessageSegment.text("长" + str(jrcz) + "cm。"))
|
||||
if jrcz <= 9:
|
||||
msg.append(MessageSegment.text(random.choice(BBCDMSG1)))
|
||||
elif jrcz <= 19:
|
||||
msg.append(MessageSegment.text(random.choice(BBCDMSG2)))
|
||||
else:
|
||||
msg.append(MessageSegment.text(random.choice(BBCDMSG3)))
|
||||
elif jrcz < 0:
|
||||
msg.append(MessageSegment.text("短" + str(abs(jrcz)) + "cm。"))
|
||||
if jrcz >= -9:
|
||||
msg.append(MessageSegment.text(random.choice(BBCDMSG4)))
|
||||
elif jrcz >= -19:
|
||||
msg.append(MessageSegment.text(random.choice(BBCDMSG5)))
|
||||
else:
|
||||
msg.append(MessageSegment.text(random.choice(BBCDMSG6)))
|
||||
await event.reply(msg)
|
||||
await event.reply(segments)
|
||||
|
||||
@@ -7,7 +7,7 @@ thpic 插件
|
||||
|
||||
from core.bot import Bot
|
||||
from core.managers.command_manager import matcher
|
||||
from models import MessageEvent, MessageSegment
|
||||
from models.events.message import MessageEvent, MessageSegment
|
||||
|
||||
__plugin_meta__ = {
|
||||
"name": "thpic",
|
||||
@@ -26,6 +26,6 @@ async def handle_echo(bot: Bot, event: MessageEvent, args: list[str]):
|
||||
:param args: 指令参数列表(未使用)。
|
||||
"""
|
||||
try:
|
||||
await event.reply(MessageSegment.image("https://img.paulzzh.com/touhou/random"))
|
||||
await event.reply(str(MessageSegment.image("https://img.paulzzh.com/touhou/random")))
|
||||
except Exception as e:
|
||||
await event.reply("报错了。。。" + e)
|
||||
await event.reply(f"报错了。。。{e}")
|
||||
|
||||
@@ -11,7 +11,6 @@ pipreqs==0.4.13
|
||||
redis==5.0.7
|
||||
requests==2.32.5
|
||||
soupsieve==2.8.1
|
||||
toml==0.10.2
|
||||
typing==3.7.4.3
|
||||
typing_extensions==4.15.0
|
||||
urllib3==2.6.2
|
||||
@@ -19,7 +18,15 @@ watchdog==6.0.0
|
||||
websockets==15.0.1
|
||||
win32_setctime==1.2.0
|
||||
yarg==0.1.10
|
||||
cachetools
|
||||
pydantic
|
||||
docker
|
||||
pytest
|
||||
pytest-asyncio
|
||||
pytest-mock
|
||||
pytest-cov
|
||||
httpx==0.27.0
|
||||
|
||||
# Dev Dependencies
|
||||
mypy
|
||||
pydantic[mypy]
|
||||
|
||||
37
tests/test_basic.py
Normal file
37
tests/test_basic.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 确保项目根目录在 sys.path 中
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
def test_import_core():
|
||||
"""测试核心模块是否可以被导入"""
|
||||
try:
|
||||
import core
|
||||
import core.bot
|
||||
import core.ws
|
||||
except ImportError as e:
|
||||
pytest.fail(f"无法导入核心模块: {e}")
|
||||
|
||||
def test_plugin_manager_path():
|
||||
"""测试插件管理器路径逻辑是否正确"""
|
||||
from core.managers.plugin_manager import PluginManager
|
||||
# Mock command manager
|
||||
pm = PluginManager(None)
|
||||
|
||||
# 我们无法直接测试 load_all_plugins 的内部路径变量,
|
||||
# 但我们可以检查它是否能找到 plugins 目录而不报错
|
||||
# 这里我们简单地断言 PluginManager 类存在且可以实例化
|
||||
assert pm is not None
|
||||
|
||||
def test_config_loader_exists():
|
||||
"""测试配置加载器是否存在"""
|
||||
# 注意:导入 config_loader 会尝试读取 config.toml
|
||||
# 如果 config.toml 不存在,这可能会失败。
|
||||
# 这是一个已知的设计问题,但在测试环境中我们假设 config.toml 存在或被 mock
|
||||
if os.path.exists("config.toml"):
|
||||
from core.config_loader import global_config
|
||||
assert global_config is not None
|
||||
else:
|
||||
pytest.skip("config.toml 不存在,跳过配置加载测试")
|
||||
114
tests/test_command_manager.py
Normal file
114
tests/test_command_manager.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from core.managers.command_manager import CommandManager
|
||||
from models.events.message import GroupMessageEvent
|
||||
from models.message import MessageSegment
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot():
|
||||
bot = AsyncMock()
|
||||
bot.self_id = 123456
|
||||
return bot
|
||||
|
||||
@pytest.fixture
|
||||
def command_manager():
|
||||
# 创建一个新的 CommandManager 实例用于测试,避免单例状态污染
|
||||
return CommandManager(prefixes=("/",))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_registration_and_execution(command_manager, mock_bot):
|
||||
"""测试命令注册和执行"""
|
||||
|
||||
# 定义一个命令处理函数
|
||||
handler_mock = AsyncMock()
|
||||
|
||||
# 注册命令
|
||||
@command_manager.command("test")
|
||||
async def test_command(bot, event):
|
||||
await handler_mock(bot, event)
|
||||
|
||||
# 构造触发命令的事件
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
event.post_type = "message"
|
||||
event.message_type = "group"
|
||||
event.raw_message = "/test"
|
||||
event.message = [MessageSegment.text("/test")]
|
||||
event.user_id = 111
|
||||
event.group_id = 222
|
||||
|
||||
# 处理事件
|
||||
await command_manager.handle_event(mock_bot, event)
|
||||
|
||||
# 验证处理函数被调用
|
||||
handler_mock.assert_called_once_with(mock_bot, event)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_prefix_match(command_manager, mock_bot):
|
||||
"""测试命令前缀匹配"""
|
||||
handler_mock = AsyncMock()
|
||||
|
||||
@command_manager.command("hello")
|
||||
async def hello_command(bot, event):
|
||||
await handler_mock(bot, event)
|
||||
|
||||
# 1. 正确的前缀
|
||||
event1 = MagicMock(spec=GroupMessageEvent)
|
||||
event1.post_type = "message"
|
||||
event1.raw_message = "/hello"
|
||||
event1.message = [MessageSegment.text("/hello")]
|
||||
await command_manager.handle_event(mock_bot, event1)
|
||||
handler_mock.assert_called_once()
|
||||
handler_mock.reset_mock()
|
||||
|
||||
# 2. 错误的前缀 (应该忽略)
|
||||
event2 = MagicMock(spec=GroupMessageEvent)
|
||||
event2.post_type = "message"
|
||||
event2.raw_message = ".hello" # 假设前缀是 /
|
||||
event2.message = [MessageSegment.text(".hello")]
|
||||
await command_manager.handle_event(mock_bot, event2)
|
||||
handler_mock.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignore_self_message(command_manager, mock_bot):
|
||||
"""测试忽略自身消息"""
|
||||
# 模拟配置
|
||||
with patch("core.managers.command_manager.global_config") as mock_config:
|
||||
mock_config.bot.ignore_self_message = True
|
||||
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
event.post_type = "message"
|
||||
event.user_id = 123456 # 与 bot.self_id 相同
|
||||
event.self_id = 123456
|
||||
|
||||
# Mock handle 方法来检测是否被调用
|
||||
command_manager.message_handler.handle = AsyncMock()
|
||||
|
||||
await command_manager.handle_event(mock_bot, event)
|
||||
|
||||
# 应该直接返回,不调用 handler
|
||||
command_manager.message_handler.handle.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_command(command_manager, mock_bot):
|
||||
"""测试内置 help 命令"""
|
||||
# 注册一个测试插件信息
|
||||
command_manager.plugins["test_plugin"] = {
|
||||
"name": "测试插件",
|
||||
"description": "这是一个测试",
|
||||
"usage": "/test"
|
||||
}
|
||||
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
event.post_type = "message"
|
||||
event.raw_message = "/help"
|
||||
event.message = [MessageSegment.text("/help")]
|
||||
|
||||
await command_manager.handle_event(mock_bot, event)
|
||||
|
||||
# 验证 bot.send 被调用,且内容包含插件信息
|
||||
mock_bot.send.assert_called_once()
|
||||
args, _ = mock_bot.send.call_args
|
||||
sent_msg = args[1]
|
||||
assert "测试插件" in sent_msg
|
||||
assert "这是一个测试" in sent_msg
|
||||
141
tests/test_event_factory.py
Normal file
141
tests/test_event_factory.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import pytest
|
||||
from models.events.factory import EventFactory, EventType
|
||||
from models.events.message import GroupMessageEvent, PrivateMessageEvent
|
||||
from models.events.notice import GroupIncreaseNoticeEvent
|
||||
from models.events.request import FriendRequestEvent
|
||||
from models.events.meta import HeartbeatEvent
|
||||
from models.message import MessageSegment
|
||||
|
||||
class TestEventFactory:
|
||||
def test_create_group_message_event_list(self):
|
||||
"""测试创建群消息事件 (message 为列表格式)"""
|
||||
data = {
|
||||
"post_type": "message",
|
||||
"message_type": "group",
|
||||
"time": 1600000000,
|
||||
"self_id": 123456,
|
||||
"sub_type": "normal",
|
||||
"message_id": 1001,
|
||||
"user_id": 111111,
|
||||
"group_id": 222222,
|
||||
"message": [
|
||||
{"type": "text", "data": {"text": "Hello"}}
|
||||
],
|
||||
"raw_message": "Hello",
|
||||
"font": 0,
|
||||
"sender": {
|
||||
"user_id": 111111,
|
||||
"nickname": "User",
|
||||
"role": "member"
|
||||
}
|
||||
}
|
||||
event = EventFactory.create_event(data)
|
||||
assert isinstance(event, GroupMessageEvent)
|
||||
assert event.group_id == 222222
|
||||
assert len(event.message) == 1
|
||||
assert event.message[0].type == "text"
|
||||
assert event.message[0].data["text"] == "Hello"
|
||||
|
||||
def test_create_group_message_event_str(self):
|
||||
"""测试创建群消息事件 (message 为字符串格式)"""
|
||||
data = {
|
||||
"post_type": "message",
|
||||
"message_type": "group",
|
||||
"time": 1600000000,
|
||||
"self_id": 123456,
|
||||
"sub_type": "normal",
|
||||
"message_id": 1002,
|
||||
"user_id": 111111,
|
||||
"group_id": 222222,
|
||||
"message": "Hello World",
|
||||
"raw_message": "Hello World",
|
||||
"font": 0,
|
||||
"sender": {
|
||||
"user_id": 111111,
|
||||
"nickname": "User"
|
||||
}
|
||||
}
|
||||
event = EventFactory.create_event(data)
|
||||
assert isinstance(event, GroupMessageEvent)
|
||||
assert len(event.message) == 1
|
||||
assert event.message[0].type == "text"
|
||||
assert event.message[0].data["text"] == "Hello World"
|
||||
|
||||
def test_create_private_message_event(self):
|
||||
"""测试创建私聊消息事件"""
|
||||
data = {
|
||||
"post_type": "message",
|
||||
"message_type": "private",
|
||||
"time": 1600000000,
|
||||
"self_id": 123456,
|
||||
"sub_type": "friend",
|
||||
"message_id": 2001,
|
||||
"user_id": 333333,
|
||||
"message": "Private Msg",
|
||||
"raw_message": "Private Msg",
|
||||
"font": 0,
|
||||
"sender": {
|
||||
"user_id": 333333,
|
||||
"nickname": "Friend"
|
||||
}
|
||||
}
|
||||
event = EventFactory.create_event(data)
|
||||
assert isinstance(event, PrivateMessageEvent)
|
||||
assert event.user_id == 333333
|
||||
|
||||
def test_create_notice_event(self):
|
||||
"""测试创建通知事件 (群成员增加)"""
|
||||
data = {
|
||||
"post_type": "notice",
|
||||
"notice_type": "group_increase",
|
||||
"sub_type": "approve",
|
||||
"group_id": 222222,
|
||||
"operator_id": 444444,
|
||||
"user_id": 555555,
|
||||
"time": 1600000000,
|
||||
"self_id": 123456
|
||||
}
|
||||
event = EventFactory.create_event(data)
|
||||
assert isinstance(event, GroupIncreaseNoticeEvent)
|
||||
assert event.group_id == 222222
|
||||
assert event.user_id == 555555
|
||||
|
||||
def test_create_request_event(self):
|
||||
"""测试创建请求事件 (加好友)"""
|
||||
data = {
|
||||
"post_type": "request",
|
||||
"request_type": "friend",
|
||||
"user_id": 666666,
|
||||
"comment": "Add me",
|
||||
"flag": "flag_123",
|
||||
"time": 1600000000,
|
||||
"self_id": 123456
|
||||
}
|
||||
event = EventFactory.create_event(data)
|
||||
assert isinstance(event, FriendRequestEvent)
|
||||
assert event.user_id == 666666
|
||||
assert event.comment == "Add me"
|
||||
|
||||
def test_create_meta_event(self):
|
||||
"""测试创建元事件 (心跳)"""
|
||||
data = {
|
||||
"post_type": "meta_event",
|
||||
"meta_event_type": "heartbeat",
|
||||
"time": 1600000000,
|
||||
"self_id": 123456,
|
||||
"status": {"online": True, "good": True},
|
||||
"interval": 5000
|
||||
}
|
||||
event = EventFactory.create_event(data)
|
||||
assert isinstance(event, HeartbeatEvent)
|
||||
assert event.interval == 5000
|
||||
|
||||
def test_unknown_event_type(self):
|
||||
"""测试未知事件类型"""
|
||||
data = {
|
||||
"post_type": "unknown_type",
|
||||
"time": 1600000000,
|
||||
"self_id": 123456
|
||||
}
|
||||
with pytest.raises(ValueError, match="Unknown event type"):
|
||||
EventFactory.create_event(data)
|
||||
194
tests/test_event_handler.py
Normal file
194
tests/test_event_handler.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from core.handlers.event_handler import MessageHandler, NoticeHandler, RequestHandler
|
||||
from models.events.message import GroupMessageEvent
|
||||
from models.events.notice import GroupIncreaseNoticeEvent
|
||||
from models.events.request import FriendRequestEvent
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot():
|
||||
bot = AsyncMock()
|
||||
return bot
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_handler_run_handler_injection(mock_bot):
|
||||
"""测试参数注入"""
|
||||
handler = MessageHandler(prefixes=("/",))
|
||||
|
||||
# 1. 测试注入 bot 和 event
|
||||
async def func1(bot, event):
|
||||
assert bot == mock_bot
|
||||
assert event.user_id == 123
|
||||
return True
|
||||
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
event.user_id = 123
|
||||
|
||||
result = await handler._run_handler(func1, mock_bot, event)
|
||||
assert result is True
|
||||
|
||||
# 2. 测试注入 args
|
||||
async def func2(args):
|
||||
assert args == ["arg1", "arg2"]
|
||||
return True
|
||||
|
||||
result = await handler._run_handler(func2, mock_bot, event, args=["arg1", "arg2"])
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_handler_command_parsing(mock_bot):
|
||||
"""测试命令解析"""
|
||||
handler = MessageHandler(prefixes=("/",))
|
||||
|
||||
mock_func = AsyncMock()
|
||||
handler.commands["test"] = {
|
||||
"func": mock_func,
|
||||
"permission": None,
|
||||
"override_permission_check": False,
|
||||
"plugin_name": "test_plugin"
|
||||
}
|
||||
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
event.raw_message = "/test arg1 arg2"
|
||||
event.user_id = 123
|
||||
|
||||
# Mock permission manager
|
||||
with patch("core.managers.permission_manager.PermissionManager.check_permission", new_callable=AsyncMock) as mock_perm:
|
||||
mock_perm.return_value = True
|
||||
|
||||
await handler.handle(mock_bot, event)
|
||||
|
||||
mock_func.assert_called_once()
|
||||
# 验证 args 参数是否正确传递
|
||||
call_args = mock_func.call_args
|
||||
if "args" in call_args.kwargs:
|
||||
assert call_args.kwargs["args"] == ["arg1", "arg2"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notice_handler(mock_bot):
|
||||
"""测试通知事件分发"""
|
||||
handler = NoticeHandler()
|
||||
|
||||
mock_func = AsyncMock()
|
||||
handler.handlers.append({
|
||||
"type": "group_increase",
|
||||
"func": mock_func,
|
||||
"plugin_name": "test_plugin"
|
||||
})
|
||||
|
||||
event = MagicMock(spec=GroupIncreaseNoticeEvent)
|
||||
event.notice_type = "group_increase"
|
||||
|
||||
await handler.handle(mock_bot, event)
|
||||
|
||||
mock_func.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_handler_execution(mock_bot):
|
||||
"""测试同步处理函数的执行"""
|
||||
handler = MessageHandler(prefixes=("/",))
|
||||
|
||||
def sync_func(event):
|
||||
return True
|
||||
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
|
||||
# 同步函数应该在线程池中运行
|
||||
result = await handler._run_handler(sync_func, mock_bot, event)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_handler_management(mock_bot):
|
||||
"""测试消息处理器的管理(注册、卸载、清空)"""
|
||||
handler = MessageHandler(prefixes=("/",))
|
||||
|
||||
# 测试 on_message 装饰器
|
||||
@handler.on_message()
|
||||
async def msg_handler(event):
|
||||
pass
|
||||
|
||||
assert len(handler.message_handlers) == 1
|
||||
|
||||
# 测试 command 装饰器
|
||||
@handler.command("cmd1", "cmd2")
|
||||
async def cmd_handler(event):
|
||||
pass
|
||||
|
||||
assert len(handler.commands) == 2
|
||||
assert "cmd1" in handler.commands
|
||||
assert "cmd2" in handler.commands
|
||||
|
||||
# 测试 unregister_by_plugin_name
|
||||
# 直接从已注册的处理器中获取 plugin_name
|
||||
if handler.message_handlers:
|
||||
plugin_name = handler.message_handlers[0]["plugin_name"]
|
||||
handler.unregister_by_plugin_name(plugin_name)
|
||||
|
||||
assert len(handler.message_handlers) == 0
|
||||
assert len(handler.commands) == 0
|
||||
|
||||
# 测试 clear
|
||||
handler.commands["cmd"] = {}
|
||||
handler.message_handlers.append({})
|
||||
handler.clear()
|
||||
assert len(handler.commands) == 0
|
||||
assert len(handler.message_handlers) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_handler(mock_bot):
|
||||
"""测试请求事件处理器"""
|
||||
handler = RequestHandler()
|
||||
|
||||
mock_func = AsyncMock()
|
||||
|
||||
# 测试 register 装饰器
|
||||
@handler.register("friend")
|
||||
async def req_handler(event):
|
||||
await mock_func(event)
|
||||
|
||||
assert len(handler.handlers) == 1
|
||||
|
||||
event = MagicMock(spec=FriendRequestEvent)
|
||||
event.request_type = "friend"
|
||||
|
||||
await handler.handle(mock_bot, event)
|
||||
mock_func.assert_called_once()
|
||||
|
||||
# 测试 unregister 和 clear
|
||||
import inspect
|
||||
module = inspect.getmodule(req_handler)
|
||||
plugin_name = module.__name__
|
||||
|
||||
handler.unregister_by_plugin_name(plugin_name)
|
||||
assert len(handler.handlers) == 0
|
||||
|
||||
handler.handlers.append({})
|
||||
handler.clear()
|
||||
assert len(handler.handlers) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_denied(mock_bot):
|
||||
"""测试权限不足的情况"""
|
||||
handler = MessageHandler(prefixes=("/",))
|
||||
|
||||
mock_func = AsyncMock()
|
||||
handler.commands["admin_cmd"] = {
|
||||
"func": mock_func,
|
||||
"permission": "ADMIN", # 假设 Permission.ADMIN
|
||||
"override_permission_check": False,
|
||||
"plugin_name": "test_plugin"
|
||||
}
|
||||
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
event.raw_message = "/admin_cmd"
|
||||
event.user_id = 123
|
||||
|
||||
# Mock permission manager returning False
|
||||
with patch("core.managers.permission_manager.PermissionManager.check_permission", new_callable=AsyncMock) as mock_perm:
|
||||
mock_perm.return_value = False
|
||||
|
||||
await handler.handle(mock_bot, event)
|
||||
|
||||
mock_func.assert_not_called()
|
||||
# 应该发送拒绝消息
|
||||
mock_bot.send.assert_called_once()
|
||||
75
tests/test_models.py
Normal file
75
tests/test_models.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import pytest
|
||||
from models.message import MessageSegment
|
||||
from models.objects import GroupInfo, StrangerInfo
|
||||
|
||||
class TestMessageSegment:
|
||||
def test_text_segment(self):
|
||||
seg = MessageSegment.text("Hello")
|
||||
assert seg.type == "text"
|
||||
assert seg.data["text"] == "Hello"
|
||||
assert str(seg) == "Hello"
|
||||
|
||||
def test_at_segment(self):
|
||||
seg = MessageSegment.at(123456)
|
||||
assert seg.type == "at"
|
||||
assert seg.data["qq"] == "123456"
|
||||
assert str(seg) == "[CQ:at,qq=123456]"
|
||||
|
||||
def test_image_segment(self):
|
||||
seg = MessageSegment.image("http://example.com/img.jpg", cache=False, proxy=False)
|
||||
assert seg.type == "image"
|
||||
assert seg.data["file"] == "http://example.com/img.jpg"
|
||||
assert str(seg) == "[CQ:image,file=http://example.com/img.jpg,cache=0,proxy=0]"
|
||||
|
||||
def test_face_segment(self):
|
||||
seg = MessageSegment.face(123)
|
||||
assert seg.type == "face"
|
||||
assert seg.data["id"] == "123"
|
||||
assert str(seg) == "[CQ:face,id=123]"
|
||||
|
||||
def test_reply_segment(self):
|
||||
seg = MessageSegment.reply(1001)
|
||||
assert seg.type == "reply"
|
||||
assert seg.data["id"] == "1001"
|
||||
assert str(seg) == "[CQ:reply,id=1001]"
|
||||
|
||||
def test_add_segments(self):
|
||||
seg1 = MessageSegment.text("Hello ")
|
||||
seg2 = MessageSegment.at(123)
|
||||
combined = seg1 + seg2
|
||||
assert isinstance(combined, list)
|
||||
assert len(combined) == 2
|
||||
assert combined[0] == seg1
|
||||
assert combined[1] == seg2
|
||||
|
||||
def test_add_segment_and_string(self):
|
||||
seg = MessageSegment.at(123)
|
||||
combined = seg + " Hello"
|
||||
assert isinstance(combined, list)
|
||||
assert len(combined) == 2
|
||||
assert combined[0] == seg
|
||||
assert combined[1].type == "text"
|
||||
assert combined[1].data["text"] == " Hello"
|
||||
|
||||
class TestObjects:
|
||||
def test_group_info(self):
|
||||
data = {
|
||||
"group_id": 123456,
|
||||
"group_name": "Test Group",
|
||||
"member_count": 10,
|
||||
"max_member_count": 100
|
||||
}
|
||||
group = GroupInfo(**data)
|
||||
assert group.group_id == 123456
|
||||
assert group.group_name == "Test Group"
|
||||
|
||||
def test_stranger_info(self):
|
||||
data = {
|
||||
"user_id": 111111,
|
||||
"nickname": "Stranger",
|
||||
"sex": "male",
|
||||
"age": 18
|
||||
}
|
||||
user = StrangerInfo(**data)
|
||||
assert user.user_id == 111111
|
||||
assert user.nickname == "Stranger"
|
||||
Reference in New Issue
Block a user