Dev (#28)
* 滚木 * feat: 重构核心架构,增强类型安全与插件管理 本次提交对核心模块进行了深度重构,引入 Pydantic 增强配置管理的类型安全性,并全面优化了插件管理系统。 主要变更详情: 1. 核心架构与配置 - 重构配置加载模块:引入 Pydantic 模型 (`core/config_models.py`),提供严格的配置项类型检查、验证及默认值管理。 - 统一模块结构:规范化模块导入路径,移除冗余的 `__init__.py` 文件,提升项目结构的清晰度。 - 性能优化:集成 Redis 缓存支持 (`RedisManager`),有效降低高频 API 调用开销,提升响应速度。 2. 插件系统升级 - 实现热重载机制:新增插件文件变更监听功能,支持开发过程中自动重载插件,提升开发效率。 - 优化生命周期管理:改进插件加载与卸载逻辑,支持精确卸载指定插件及其关联的命令、事件处理器和定时任务。 3. 功能特性增强 - 新增媒体 API:引入 `MediaAPI` 模块,封装图片、语音等富媒体资源的获取与处理接口。 - 完善权限体系:重构权限管理系统,实现管理员与操作员的分级控制,支持更细粒度的命令权限校验。 4. 代码质量与稳定性 - 全面类型修复:解决 `mypy` 静态类型检查发现的大量类型错误(包括 `CommandManager`、`EventFactory` 及 `Bot` API 签名不匹配问题)。 - 增强错误处理:优化消息处理管道的异常捕获机制,完善关键路径的日志记录,提升系统运行稳定性。 * feat: 添加测试用例并优化代码结构 refactor(permission_manager): 调整初始化顺序和逻辑 fix(admin_manager): 修复初始化逻辑和目录创建问题 feat(ws): 优化Bot实例初始化条件 feat(message): 增强MessageSegment功能并添加测试 feat(events): 支持字符串格式的消息解析 test: 添加核心功能测试用例 refactor(plugin_manager): 改进插件路径处理 style: 清理无用导入和代码 chore: 更新依赖项
This commit is contained in:
@@ -1,6 +0,0 @@
|
|||||||
from .managers.command_manager import matcher
|
|
||||||
from .config_loader import global_config
|
|
||||||
from .managers.plugin_manager import PluginDataManager
|
|
||||||
from .ws import WS
|
|
||||||
|
|
||||||
__all__ = ["WS", "matcher", "global_config", "PluginDataManager"]
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from .message import MessageAPI
|
|||||||
from .group import GroupAPI
|
from .group import GroupAPI
|
||||||
from .friend import FriendAPI
|
from .friend import FriendAPI
|
||||||
from .account import AccountAPI
|
from .account import AccountAPI
|
||||||
|
from .media import MediaAPI
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseAPI",
|
"BaseAPI",
|
||||||
@@ -10,4 +11,5 @@ __all__ = [
|
|||||||
"GroupAPI",
|
"GroupAPI",
|
||||||
"FriendAPI",
|
"FriendAPI",
|
||||||
"AccountAPI",
|
"AccountAPI",
|
||||||
|
"MediaAPI",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -162,3 +162,56 @@ class AccountAPI(BaseAPI):
|
|||||||
"""
|
"""
|
||||||
return await self.call_api("clean_cache")
|
return await self.call_api("clean_cache")
|
||||||
|
|
||||||
|
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> Any:
|
||||||
|
"""
|
||||||
|
获取陌生人信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (int): 目标用户的 QQ 号。
|
||||||
|
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: 包含陌生人信息的字典或对象。
|
||||||
|
"""
|
||||||
|
return await self.call_api("get_stranger_info", {"user_id": user_id, "no_cache": no_cache})
|
||||||
|
|
||||||
|
async def get_friend_list(self, no_cache: bool = False) -> list:
|
||||||
|
"""
|
||||||
|
获取好友列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 好友列表。
|
||||||
|
"""
|
||||||
|
cache_key = f"neobot:cache:get_friend_list:{self.self_id}"
|
||||||
|
if not no_cache:
|
||||||
|
cached_data = await redis_manager.get(cache_key)
|
||||||
|
if cached_data:
|
||||||
|
return json.loads(cached_data)
|
||||||
|
|
||||||
|
res = await self.call_api("get_friend_list")
|
||||||
|
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||||
|
return res
|
||||||
|
|
||||||
|
async def get_group_list(self, no_cache: bool = False) -> list:
|
||||||
|
"""
|
||||||
|
获取群列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 群列表。
|
||||||
|
"""
|
||||||
|
cache_key = f"neobot:cache:get_group_list:{self.self_id}"
|
||||||
|
if not no_cache:
|
||||||
|
cached_data = await redis_manager.get(cache_key)
|
||||||
|
if cached_data:
|
||||||
|
return json.loads(cached_data)
|
||||||
|
|
||||||
|
res = await self.call_api("get_group_list")
|
||||||
|
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||||
|
return res
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +1,50 @@
|
|||||||
"""
|
"""
|
||||||
API 基础模块
|
API 基础模块
|
||||||
|
|
||||||
定义了 API 调用的基础接口。
|
定义了 API 调用的基础接口和统一处理逻辑。
|
||||||
"""
|
"""
|
||||||
from abc import ABC, abstractmethod
|
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
from ..utils.logger import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..ws import WS
|
||||||
|
|
||||||
|
|
||||||
class BaseAPI(ABC):
|
class BaseAPI:
|
||||||
"""
|
"""
|
||||||
API 基础抽象类
|
API 基础类,提供了统一的 `call_api` 方法,包含日志记录和异常处理。
|
||||||
"""
|
"""
|
||||||
|
_ws: "WS"
|
||||||
|
self_id: int
|
||||||
|
|
||||||
|
def __init__(self, ws_client: "WS", self_id: int):
|
||||||
|
self._ws = ws_client
|
||||||
|
self.self_id = self_id
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def call_api(self, action: str, params: Optional[Dict[str, Any]] = None) -> Any:
|
async def call_api(self, action: str, params: Optional[Dict[str, Any]] = None) -> Any:
|
||||||
"""
|
"""
|
||||||
调用 API
|
调用 OneBot v11 API,并提供统一的日志和异常处理。
|
||||||
|
|
||||||
:param action: API 动作名称
|
:param action: API 动作名称
|
||||||
:param params: API 参数
|
:param params: API 参数
|
||||||
:return: API 响应结果
|
:return: API 响应结果的数据部分
|
||||||
|
:raises Exception: 当 API 调用失败或发生网络错误时
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
if params is None:
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug(f"调用API -> action: {action}, params: {params}")
|
||||||
|
response = await self._ws.call_api(action, params)
|
||||||
|
logger.debug(f"API响应 <- {response}")
|
||||||
|
|
||||||
|
if response.get("status") == "failed":
|
||||||
|
logger.warning(f"API调用失败: {response}")
|
||||||
|
|
||||||
|
return response.get("data")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"API调用异常: action={action}, params={params}, error={e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
该模块定义了 `GroupAPI` Mixin 类,提供了所有与群组管理、成员操作
|
该模块定义了 `GroupAPI` Mixin 类,提供了所有与群组管理、成员操作
|
||||||
等相关的 OneBot v11 API 封装。
|
等相关的 OneBot v11 API 封装。
|
||||||
"""
|
"""
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any, Optional
|
||||||
import json
|
import json
|
||||||
from ..managers.redis_manager import redis_manager
|
from ..managers.redis_manager import redis_manager
|
||||||
from .base import BaseAPI
|
from .base import BaseAPI
|
||||||
@@ -46,7 +46,7 @@ class GroupAPI(BaseAPI):
|
|||||||
"""
|
"""
|
||||||
return await self.call_api("set_group_ban", {"group_id": group_id, "user_id": user_id, "duration": duration})
|
return await self.call_api("set_group_ban", {"group_id": group_id, "user_id": user_id, "duration": duration})
|
||||||
|
|
||||||
async def set_group_anonymous_ban(self, group_id: int, anonymous: Dict[str, Any] = None, duration: int = 1800, flag: str = None) -> Dict[str, Any]:
|
async def set_group_anonymous_ban(self, group_id: int, anonymous: Optional[Dict[str, Any]] = None, duration: int = 1800, flag: Optional[str] = None) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
禁言群组中的匿名用户。
|
禁言群组中的匿名用户。
|
||||||
|
|
||||||
@@ -61,7 +61,7 @@ class GroupAPI(BaseAPI):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: OneBot API 的响应数据。
|
Dict[str, Any]: OneBot API 的响应数据。
|
||||||
"""
|
"""
|
||||||
params = {"group_id": group_id, "duration": duration}
|
params: Dict[str, Any] = {"group_id": group_id, "duration": duration}
|
||||||
if anonymous:
|
if anonymous:
|
||||||
params["anonymous"] = anonymous
|
params["anonymous"] = anonymous
|
||||||
if flag:
|
if flag:
|
||||||
@@ -187,17 +187,18 @@ class GroupAPI(BaseAPI):
|
|||||||
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||||
return GroupInfo(**res)
|
return GroupInfo(**res)
|
||||||
|
|
||||||
async def get_group_list(self) -> List[GroupInfo]:
|
async def get_group_list(self) -> Any:
|
||||||
"""
|
"""
|
||||||
获取机器人加入的所有群组的列表。
|
获取机器人加入的所有群组的列表。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[GroupInfo]: 包含所有群组信息的 `GroupInfo` 对象列表。
|
Any: 包含所有群组信息的列表(可能是字典列表或对象列表)。
|
||||||
"""
|
"""
|
||||||
res = await self.call_api("get_group_list")
|
res = await self.call_api("get_group_list")
|
||||||
|
|
||||||
# 增加日志记录 API 原始返回
|
# 增加日志记录 API 原始返回
|
||||||
logger.debug(f"OneBot API 'get_group_list' raw response: {res}")
|
logger.debug(f"OneBot API 'get_group_list' raw response: {res}")
|
||||||
|
return res
|
||||||
|
|
||||||
# 健壮性处理:处理标准的 OneBot v11 响应格式
|
# 健壮性处理:处理标准的 OneBot v11 响应格式
|
||||||
if isinstance(res, dict) and res.get("status") == "ok":
|
if isinstance(res, dict) and res.get("status") == "ok":
|
||||||
|
|||||||
39
core/api/media.py
Normal file
39
core/api/media.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""
|
||||||
|
媒体API模块
|
||||||
|
|
||||||
|
封装了与图片、语音等媒体文件相关的API。
|
||||||
|
"""
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from .base import BaseAPI
|
||||||
|
|
||||||
|
|
||||||
|
class MediaAPI(BaseAPI):
|
||||||
|
"""
|
||||||
|
媒体相关API
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def can_send_image(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
检查是否可以发送图片
|
||||||
|
|
||||||
|
:return: OneBot v11标准响应
|
||||||
|
"""
|
||||||
|
return await self.call_api(action="can_send_image")
|
||||||
|
|
||||||
|
async def can_send_record(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
检查是否可以发送语音
|
||||||
|
|
||||||
|
:return: OneBot v11标准响应
|
||||||
|
"""
|
||||||
|
return await self.call_api(action="can_send_record")
|
||||||
|
|
||||||
|
async def get_image(self, file: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取图片信息
|
||||||
|
|
||||||
|
:param file: 图片文件名或路径
|
||||||
|
:return: OneBot v11标准响应
|
||||||
|
"""
|
||||||
|
return await self.call_api(action="get_image", params={"file": file})
|
||||||
@@ -8,7 +8,8 @@ from typing import Union, List, Dict, Any, TYPE_CHECKING
|
|||||||
from .base import BaseAPI
|
from .base import BaseAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from models import MessageSegment, OneBotEvent
|
from models.message import MessageSegment
|
||||||
|
from models.events.base import OneBotEvent
|
||||||
|
|
||||||
|
|
||||||
class MessageAPI(BaseAPI):
|
class MessageAPI(BaseAPI):
|
||||||
@@ -156,24 +157,6 @@ class MessageAPI(BaseAPI):
|
|||||||
"""
|
"""
|
||||||
return await self.call_api("send_private_forward_msg", {"user_id": user_id, "messages": messages})
|
return await self.call_api("send_private_forward_msg", {"user_id": user_id, "messages": messages})
|
||||||
|
|
||||||
async def can_send_image(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
检查当前机器人账号是否可以发送图片。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: OneBot API 的响应数据。
|
|
||||||
"""
|
|
||||||
return await self.call_api("can_send_image")
|
|
||||||
|
|
||||||
async def can_send_record(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
检查当前机器人账号是否可以发送语音。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: OneBot API 的响应数据。
|
|
||||||
"""
|
|
||||||
return await self.call_api("can_send_record")
|
|
||||||
|
|
||||||
def _process_message(self, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Union[str, List[Dict[str, Any]]]:
|
def _process_message(self, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Union[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
内部方法:将消息内容处理成 OneBot API 可接受的格式。
|
内部方法:将消息内容处理成 OneBot API 可接受的格式。
|
||||||
@@ -192,7 +175,7 @@ class MessageAPI(BaseAPI):
|
|||||||
return message
|
return message
|
||||||
|
|
||||||
# 避免循环导入,在运行时导入
|
# 避免循环导入,在运行时导入
|
||||||
from models import MessageSegment
|
from models.message import MessageSegment
|
||||||
|
|
||||||
if isinstance(message, MessageSegment):
|
if isinstance(message, MessageSegment):
|
||||||
return [self._segment_to_dict(message)]
|
return [self._segment_to_dict(message)]
|
||||||
|
|||||||
33
core/bot.py
33
core/bot.py
@@ -13,14 +13,15 @@ Bot 核心抽象模块
|
|||||||
from typing import TYPE_CHECKING, Dict, Any, List, Union
|
from typing import TYPE_CHECKING, Dict, Any, List, Union
|
||||||
from models.events.base import OneBotEvent
|
from models.events.base import OneBotEvent
|
||||||
from models.message import MessageSegment
|
from models.message import MessageSegment
|
||||||
|
from models.objects import GroupInfo, StrangerInfo
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .WS import WS
|
from .ws import WS
|
||||||
|
|
||||||
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI
|
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI
|
||||||
|
|
||||||
|
|
||||||
class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI):
|
class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI):
|
||||||
"""
|
"""
|
||||||
机器人核心类,封装了所有与 OneBot API 的交互。
|
机器人核心类,封装了所有与 OneBot API 的交互。
|
||||||
|
|
||||||
@@ -35,22 +36,22 @@ class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI):
|
|||||||
Args:
|
Args:
|
||||||
ws_client (WS): WebSocket 客户端实例,负责底层的 API 请求和响应处理。
|
ws_client (WS): WebSocket 客户端实例,负责底层的 API 请求和响应处理。
|
||||||
"""
|
"""
|
||||||
self.ws = ws_client
|
super().__init__(ws_client, ws_client.self_id or 0)
|
||||||
|
self.code_executor = None
|
||||||
|
|
||||||
async def call_api(self, action: str, params: Dict[str, Any] = None) -> Any:
|
async def get_group_list(self, no_cache: bool = False) -> List[GroupInfo]:
|
||||||
"""
|
# GroupAPI.get_group_list 不支持 no_cache 参数,这里忽略它
|
||||||
底层 API 调用方法。
|
result = await super().get_group_list()
|
||||||
|
# 确保结果是 GroupInfo 对象列表
|
||||||
|
return [GroupInfo(**group) if isinstance(group, dict) else group for group in result]
|
||||||
|
|
||||||
所有具体的 API 实现最终都会调用此方法,通过 WebSocket 发送请求。
|
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> StrangerInfo:
|
||||||
|
result = await super().get_stranger_info(user_id=user_id, no_cache=no_cache)
|
||||||
|
# 确保结果是 StrangerInfo 对象
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return StrangerInfo(**result)
|
||||||
|
return result
|
||||||
|
|
||||||
Args:
|
|
||||||
action (str): API 的动作名称,例如 "send_group_msg"。
|
|
||||||
params (Dict[str, Any], optional): API 请求的参数字典。Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: OneBot API 的响应数据。
|
|
||||||
"""
|
|
||||||
return await self.ws.call_api(action, params)
|
|
||||||
|
|
||||||
def build_forward_node(self, user_id: int, nickname: str, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Dict[str, Any]:
|
def build_forward_node(self, user_id: int, nickname: str, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,9 +4,11 @@
|
|||||||
负责读取和解析 config.toml 配置文件,提供全局配置对象。
|
负责读取和解析 config.toml 配置文件,提供全局配置对象。
|
||||||
"""
|
"""
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import tomllib
|
import tomllib
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from .config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel
|
||||||
|
from .utils.logger import logger
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@@ -21,73 +23,67 @@ class Config:
|
|||||||
:param file_path: 配置文件路径,默认为 "config.toml"
|
:param file_path: 配置文件路径,默认为 "config.toml"
|
||||||
"""
|
"""
|
||||||
self.path = Path(file_path)
|
self.path = Path(file_path)
|
||||||
self._data: Dict[str, Any] = {}
|
self._model: ConfigModel
|
||||||
self.load()
|
self.load()
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
"""
|
"""
|
||||||
加载配置文件
|
加载并验证配置文件
|
||||||
|
|
||||||
:raises FileNotFoundError: 如果配置文件不存在
|
:raises FileNotFoundError: 如果配置文件不存在
|
||||||
|
:raises ValidationError: 如果配置格式不正确
|
||||||
"""
|
"""
|
||||||
if not self.path.exists():
|
if not self.path.exists():
|
||||||
|
logger.error(f"配置文件 {self.path} 未找到!")
|
||||||
raise FileNotFoundError(f"配置文件 {self.path} 未找到!")
|
raise FileNotFoundError(f"配置文件 {self.path} 未找到!")
|
||||||
|
|
||||||
with open(self.path, "rb") as f:
|
try:
|
||||||
self._data = tomllib.load(f)
|
logger.info(f"正在从 {self.path} 加载配置...")
|
||||||
|
with open(self.path, "rb") as f:
|
||||||
|
raw_config = tomllib.load(f)
|
||||||
|
|
||||||
|
self._model = ConfigModel(**raw_config)
|
||||||
|
logger.success("配置加载并验证成功!")
|
||||||
|
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.error("配置验证失败,请检查 `config.toml` 文件中的以下错误:")
|
||||||
|
for error in e.errors():
|
||||||
|
field = " -> ".join(map(str, error["loc"]))
|
||||||
|
logger.error(f" - 字段 '{field}': {error['msg']}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"加载配置文件时发生未知错误: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
# 通过属性访问配置
|
# 通过属性访问配置
|
||||||
@property
|
@property
|
||||||
def napcat_ws(self) -> dict:
|
def napcat_ws(self) -> NapCatWSModel:
|
||||||
"""
|
"""
|
||||||
获取 NapCat WebSocket 配置
|
获取 NapCat WebSocket 配置
|
||||||
|
|
||||||
:return: 配置字典
|
|
||||||
"""
|
"""
|
||||||
return self._data.get("napcat_ws", {})
|
return self._model.napcat_ws
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bot(self) -> dict:
|
def bot(self) -> BotModel:
|
||||||
"""
|
"""
|
||||||
获取 Bot 基础配置
|
获取 Bot 基础配置
|
||||||
|
|
||||||
:return: 配置字典
|
|
||||||
"""
|
"""
|
||||||
return self._data.get("bot", {})
|
return self._model.bot
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def features(self) -> dict:
|
def redis(self) -> RedisModel:
|
||||||
"""
|
|
||||||
获取功能特性配置
|
|
||||||
|
|
||||||
:return: 配置字典
|
|
||||||
"""
|
|
||||||
return self._data.get("features", {})
|
|
||||||
|
|
||||||
@property
|
|
||||||
def redis(self) -> dict:
|
|
||||||
"""
|
"""
|
||||||
获取 Redis 配置
|
获取 Redis 配置
|
||||||
|
|
||||||
:return: 配置字典
|
|
||||||
"""
|
"""
|
||||||
return self._data.get("redis", {})
|
return self._model.redis
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def docker(self) -> dict:
|
def docker(self) -> DockerModel:
|
||||||
"""
|
"""
|
||||||
获取 Docker 配置
|
获取 Docker 配置
|
||||||
|
|
||||||
:return: 配置字典
|
|
||||||
"""
|
"""
|
||||||
return self._data.get("docker", {})
|
return self._model.docker
|
||||||
|
|
||||||
|
|
||||||
# 实例化全局配置对象
|
# 实例化全局配置对象
|
||||||
global_config = Config()
|
global_config = Config()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print(global_config.napcat_ws)
|
|
||||||
print(global_config.bot.get("command"))
|
|
||||||
print(type(global_config.bot.get("command")) is list)
|
|
||||||
print(global_config.features)
|
|
||||||
|
|||||||
60
core/config_models.py
Normal file
60
core/config_models.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""
|
||||||
|
Pydantic 配置模型模块
|
||||||
|
|
||||||
|
该模块使用 Pydantic 定义了与 `config.toml` 文件结构完全对应的配置模型。
|
||||||
|
这使得配置的加载、校验和访问都变得类型安全和健壮。
|
||||||
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class NapCatWSModel(BaseModel):
|
||||||
|
"""
|
||||||
|
对应 `config.toml` 中的 `[napcat_ws]` 配置块。
|
||||||
|
"""
|
||||||
|
uri: str
|
||||||
|
token: str
|
||||||
|
reconnect_interval: int = 5
|
||||||
|
|
||||||
|
|
||||||
|
class BotModel(BaseModel):
|
||||||
|
"""
|
||||||
|
对应 `config.toml` 中的 `[bot]` 配置块。
|
||||||
|
"""
|
||||||
|
command: List[str] = Field(default_factory=lambda: ["/"])
|
||||||
|
ignore_self_message: bool = True
|
||||||
|
permission_denied_message: str = "权限不足,需要 {permission_name} 权限"
|
||||||
|
|
||||||
|
|
||||||
|
class RedisModel(BaseModel):
|
||||||
|
"""
|
||||||
|
对应 `config.toml` 中的 `[redis]` 配置块。
|
||||||
|
"""
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
db: int
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class DockerModel(BaseModel):
|
||||||
|
"""
|
||||||
|
对应 `config.toml` 中的 `[docker]` 配置块。
|
||||||
|
"""
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
sandbox_image: str = "python-sandbox:latest"
|
||||||
|
timeout: int = 10
|
||||||
|
concurrency_limit: int = 5
|
||||||
|
tls_verify: bool = False
|
||||||
|
ca_cert_path: Optional[str] = None
|
||||||
|
client_cert_path: Optional[str] = None
|
||||||
|
client_key_path: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigModel(BaseModel):
|
||||||
|
"""
|
||||||
|
顶层配置模型,整合了所有子配置块。
|
||||||
|
"""
|
||||||
|
napcat_ws: NapCatWSModel
|
||||||
|
bot: BotModel
|
||||||
|
redis: RedisModel
|
||||||
|
docker: DockerModel
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
{
|
{
|
||||||
"admins": []
|
"admins": [2221577113]
|
||||||
}
|
}
|
||||||
@@ -6,11 +6,12 @@
|
|||||||
"""
|
"""
|
||||||
import inspect
|
import inspect
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
from ..bot import Bot
|
if TYPE_CHECKING:
|
||||||
|
from ..bot import Bot
|
||||||
from ..config_loader import global_config
|
from ..config_loader import global_config
|
||||||
from ..managers.permission_manager import Permission, permission_manager
|
from ..permission import Permission
|
||||||
from ..utils.executor import run_in_thread_pool
|
from ..utils.executor import run_in_thread_pool
|
||||||
|
|
||||||
|
|
||||||
@@ -22,7 +23,7 @@ class BaseHandler(ABC):
|
|||||||
self.handlers: List[Dict[str, Any]] = []
|
self.handlers: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def handle(self, bot: Bot, event: Any):
|
async def handle(self, bot: "Bot", event: Any):
|
||||||
"""
|
"""
|
||||||
处理事件
|
处理事件
|
||||||
"""
|
"""
|
||||||
@@ -31,7 +32,7 @@ class BaseHandler(ABC):
|
|||||||
async def _run_handler(
|
async def _run_handler(
|
||||||
self,
|
self,
|
||||||
func: Callable,
|
func: Callable,
|
||||||
bot: Bot,
|
bot: "Bot",
|
||||||
event: Any,
|
event: Any,
|
||||||
args: Optional[List[str]] = None,
|
args: Optional[List[str]] = None,
|
||||||
permission_granted: Optional[bool] = None
|
permission_granted: Optional[bool] = None
|
||||||
@@ -41,7 +42,7 @@ class BaseHandler(ABC):
|
|||||||
"""
|
"""
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
params = sig.parameters
|
params = sig.parameters
|
||||||
kwargs = {}
|
kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
if "bot" in params:
|
if "bot" in params:
|
||||||
kwargs["bot"] = bot
|
kwargs["bot"] = bot
|
||||||
@@ -68,21 +69,41 @@ class MessageHandler(BaseHandler):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.prefixes = prefixes
|
self.prefixes = prefixes
|
||||||
self.commands: Dict[str, Dict] = {}
|
self.commands: Dict[str, Dict] = {}
|
||||||
self.message_handlers: List[Callable] = []
|
self.message_handlers: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""
|
||||||
|
清空所有已注册的消息和命令处理器
|
||||||
|
"""
|
||||||
|
self.commands.clear()
|
||||||
|
self.message_handlers.clear()
|
||||||
|
|
||||||
|
def unregister_by_plugin_name(self, plugin_name: str):
|
||||||
|
"""
|
||||||
|
根据插件名卸载相关的消息和命令处理器
|
||||||
|
"""
|
||||||
|
# 卸载命令
|
||||||
|
commands_to_remove = [name for name, info in self.commands.items() if info["plugin_name"] == plugin_name]
|
||||||
|
for name in commands_to_remove:
|
||||||
|
del self.commands[name]
|
||||||
|
|
||||||
|
# 卸载通用消息处理器
|
||||||
|
self.message_handlers = [h for h in self.message_handlers if h["plugin_name"] != plugin_name]
|
||||||
|
|
||||||
def on_message(self) -> Callable:
|
def on_message(self) -> Callable:
|
||||||
"""
|
"""
|
||||||
注册通用消息处理器
|
注册通用消息处理器
|
||||||
"""
|
"""
|
||||||
def decorator(func: Callable) -> Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
self.message_handlers.append(func)
|
module = inspect.getmodule(func)
|
||||||
|
plugin_name = module.__name__ if module else "Unknown"
|
||||||
|
self.message_handlers.append({"func": func, "plugin_name": plugin_name})
|
||||||
return func
|
return func
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def command(
|
def command(
|
||||||
self,
|
self,
|
||||||
*names: str,
|
*names: str,
|
||||||
*names: str,
|
|
||||||
permission: Optional[Permission] = None,
|
permission: Optional[Permission] = None,
|
||||||
override_permission_check: bool = False
|
override_permission_check: bool = False
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
@@ -90,21 +111,25 @@ class MessageHandler(BaseHandler):
|
|||||||
注册命令处理器
|
注册命令处理器
|
||||||
"""
|
"""
|
||||||
def decorator(func: Callable) -> Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
|
module = inspect.getmodule(func)
|
||||||
|
plugin_name = module.__name__ if module else "Unknown"
|
||||||
for name in names:
|
for name in names:
|
||||||
self.commands[name] = {
|
self.commands[name] = {
|
||||||
"func": func,
|
"func": func,
|
||||||
"permission": permission,
|
"permission": permission,
|
||||||
"override_permission_check": override_permission_check,
|
"override_permission_check": override_permission_check,
|
||||||
|
"plugin_name": plugin_name,
|
||||||
}
|
}
|
||||||
return func
|
return func
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
async def handle(self, bot: Bot, event: Any):
|
async def handle(self, bot: "Bot", event: Any):
|
||||||
"""
|
"""
|
||||||
处理消息事件,包括通用消息和命令
|
处理消息事件,分发给命令处理器或通用消息处理器
|
||||||
"""
|
"""
|
||||||
for handler in self.message_handlers:
|
from ..managers import permission_manager
|
||||||
consumed = await self._run_handler(handler, bot, event)
|
for handler_info in self.message_handlers:
|
||||||
|
consumed = await self._run_handler(handler_info["func"], bot, event)
|
||||||
if consumed:
|
if consumed:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -136,7 +161,7 @@ class MessageHandler(BaseHandler):
|
|||||||
|
|
||||||
if not permission_granted and not override_check:
|
if not permission_granted and not override_check:
|
||||||
permission_name = permission.name if isinstance(permission, Permission) else permission
|
permission_name = permission.name if isinstance(permission, Permission) else permission
|
||||||
message_template = global_config.bot.get("permission_denied_message", "权限不足,需要 {permission_name} 权限")
|
message_template = global_config.bot.permission_denied_message
|
||||||
await bot.send(event, message_template.format(permission_name=permission_name))
|
await bot.send(event, message_template.format(permission_name=permission_name))
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -153,12 +178,23 @@ class NoticeHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
通知事件处理器
|
通知事件处理器
|
||||||
"""
|
"""
|
||||||
|
def clear(self):
|
||||||
|
self.handlers.clear()
|
||||||
|
|
||||||
|
def unregister_by_plugin_name(self, plugin_name: str):
|
||||||
|
"""
|
||||||
|
根据插件名卸载相关的通知处理器
|
||||||
|
"""
|
||||||
|
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
|
||||||
|
|
||||||
def register(self, notice_type: Optional[str] = None) -> Callable:
|
def register(self, notice_type: Optional[str] = None) -> Callable:
|
||||||
"""
|
"""
|
||||||
注册通知处理器
|
注册通知处理器
|
||||||
"""
|
"""
|
||||||
def decorator(func: Callable) -> Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
self.handlers.append({"type": notice_type, "func": func})
|
module = inspect.getmodule(func)
|
||||||
|
plugin_name = module.__name__ if module else "Unknown"
|
||||||
|
self.handlers.append({"type": notice_type, "func": func, "plugin_name": plugin_name})
|
||||||
return func
|
return func
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@@ -175,12 +211,23 @@ class RequestHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
请求事件处理器
|
请求事件处理器
|
||||||
"""
|
"""
|
||||||
|
def clear(self):
|
||||||
|
self.handlers.clear()
|
||||||
|
|
||||||
|
def unregister_by_plugin_name(self, plugin_name: str):
|
||||||
|
"""
|
||||||
|
根据插件名卸载相关的请求处理器
|
||||||
|
"""
|
||||||
|
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
|
||||||
|
|
||||||
def register(self, request_type: Optional[str] = None) -> Callable:
|
def register(self, request_type: Optional[str] = None) -> Callable:
|
||||||
"""
|
"""
|
||||||
注册请求处理器
|
注册请求处理器
|
||||||
"""
|
"""
|
||||||
def decorator(func: Callable) -> Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
self.handlers.append({"type": request_type, "func": func})
|
module = inspect.getmodule(func)
|
||||||
|
plugin_name = module.__name__ if module else "Unknown"
|
||||||
|
self.handlers.append({"type": request_type, "func": func, "plugin_name": plugin_name})
|
||||||
return func
|
return func
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
管理器包
|
||||||
|
|
||||||
|
这个包集中了机器人核心的单例管理器。
|
||||||
|
通过从这里导入,可以确保在整个应用中访问到的都是同一个实例。
|
||||||
|
"""
|
||||||
|
from ..config_loader import global_config
|
||||||
|
from .admin_manager import AdminManager
|
||||||
|
from .command_manager import CommandManager
|
||||||
|
from .permission_manager import PermissionManager
|
||||||
|
from .plugin_manager import PluginManager
|
||||||
|
from .redis_manager import RedisManager
|
||||||
|
|
||||||
|
# --- 实例化所有单例管理器 ---
|
||||||
|
|
||||||
|
# 管理员管理器
|
||||||
|
admin_manager = AdminManager()
|
||||||
|
|
||||||
|
# 权限管理器
|
||||||
|
permission_manager = PermissionManager()
|
||||||
|
|
||||||
|
# 命令与事件管理器 (别名 matcher)
|
||||||
|
command_manager = CommandManager(prefixes=tuple(global_config.bot.command))
|
||||||
|
matcher = command_manager
|
||||||
|
|
||||||
|
# 插件管理器
|
||||||
|
plugin_manager = PluginManager(command_manager)
|
||||||
|
plugin_manager.load_all_plugins()
|
||||||
|
|
||||||
|
# Redis 管理器
|
||||||
|
redis_manager = RedisManager()
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"admin_manager",
|
||||||
|
"permission_manager",
|
||||||
|
"command_manager",
|
||||||
|
"matcher",
|
||||||
|
"plugin_manager",
|
||||||
|
"redis_manager",
|
||||||
|
]
|
||||||
|
|||||||
@@ -26,8 +26,7 @@ class AdminManager(Singleton):
|
|||||||
"""
|
"""
|
||||||
初始化 AdminManager
|
初始化 AdminManager
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
if hasattr(self, '_initialized') and self._initialized:
|
||||||
if not self._initialized:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# 管理员数据文件路径
|
# 管理员数据文件路径
|
||||||
@@ -39,7 +38,12 @@ class AdminManager(Singleton):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._admins: Set[int] = set()
|
self._admins: Set[int] = set()
|
||||||
|
|
||||||
|
# 确保数据目录存在
|
||||||
|
os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
|
||||||
|
|
||||||
logger.info("管理员管理器初始化完成")
|
logger.info("管理员管理器初始化完成")
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -12,7 +12,16 @@ from ..handlers.event_handler import MessageHandler, NoticeHandler, RequestHandl
|
|||||||
|
|
||||||
|
|
||||||
# 从配置中获取命令前缀
|
# 从配置中获取命令前缀
|
||||||
command_prefixes = global_config.bot.get("command", ("/",))
|
_config_prefixes = global_config.bot.command
|
||||||
|
|
||||||
|
# 确保前缀配置是元组格式
|
||||||
|
_final_prefixes: Tuple[str, ...]
|
||||||
|
if isinstance(_config_prefixes, list):
|
||||||
|
_final_prefixes = tuple(_config_prefixes)
|
||||||
|
elif isinstance(_config_prefixes, str):
|
||||||
|
_final_prefixes = (_config_prefixes,)
|
||||||
|
else:
|
||||||
|
_final_prefixes = tuple(_config_prefixes)
|
||||||
|
|
||||||
|
|
||||||
class CommandManager:
|
class CommandManager:
|
||||||
@@ -59,6 +68,35 @@ class CommandManager:
|
|||||||
"usage": "/help",
|
"usage": "/help",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def clear_all_handlers(self):
|
||||||
|
"""
|
||||||
|
清空所有已注册的事件处理器。
|
||||||
|
注意:这也会移除内置的 /help 命令,因此需要重新注册。
|
||||||
|
"""
|
||||||
|
self.message_handler.clear()
|
||||||
|
self.notice_handler.clear()
|
||||||
|
self.request_handler.clear()
|
||||||
|
self.plugins.clear()
|
||||||
|
|
||||||
|
# 清空后,需要重新注册内置命令
|
||||||
|
self._register_internal_commands()
|
||||||
|
|
||||||
|
def unload_plugin(self, plugin_name: str):
|
||||||
|
"""
|
||||||
|
卸载指定插件的所有处理器和命令。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name (str): 插件的模块名 (例如 'plugins.bili_parser')
|
||||||
|
"""
|
||||||
|
self.message_handler.unregister_by_plugin_name(plugin_name)
|
||||||
|
self.notice_handler.unregister_by_plugin_name(plugin_name)
|
||||||
|
self.request_handler.unregister_by_plugin_name(plugin_name)
|
||||||
|
|
||||||
|
# 移除插件元信息
|
||||||
|
plugins_to_remove = [name for name in self.plugins if name.startswith(plugin_name)]
|
||||||
|
for name in plugins_to_remove:
|
||||||
|
del self.plugins[name]
|
||||||
|
|
||||||
# --- 装饰器代理 ---
|
# --- 装饰器代理 ---
|
||||||
|
|
||||||
def on_message(self) -> Callable:
|
def on_message(self) -> Callable:
|
||||||
@@ -102,7 +140,7 @@ class CommandManager:
|
|||||||
|
|
||||||
根据事件的 `post_type` 将其分发给对应的处理器。
|
根据事件的 `post_type` 将其分发给对应的处理器。
|
||||||
"""
|
"""
|
||||||
if event.post_type == 'message' and global_config.bot.get('ignore_self_message', False):
|
if event.post_type == 'message' and global_config.bot.ignore_self_message:
|
||||||
if hasattr(event, 'user_id') and hasattr(event, 'self_id') and event.user_id == event.self_id:
|
if hasattr(event, 'user_id') and hasattr(event, 'self_id') and event.user_id == event.self_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -130,14 +168,6 @@ class CommandManager:
|
|||||||
await bot.send(event, help_text.strip())
|
await bot.send(event, help_text.strip())
|
||||||
|
|
||||||
|
|
||||||
# --- 全局单例 ---
|
|
||||||
|
|
||||||
# 确保前缀配置是元组格式
|
|
||||||
if isinstance(command_prefixes, list):
|
|
||||||
command_prefixes = tuple(command_prefixes)
|
|
||||||
elif isinstance(command_prefixes, str):
|
|
||||||
command_prefixes = (command_prefixes,)
|
|
||||||
|
|
||||||
# 实例化全局唯一的命令管理器
|
# 实例化全局唯一的命令管理器
|
||||||
matcher = CommandManager(prefixes=command_prefixes)
|
matcher = CommandManager(prefixes=_final_prefixes)
|
||||||
|
|
||||||
|
|||||||
@@ -13,64 +13,17 @@
|
|||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from functools import total_ordering
|
from typing import Dict, Optional
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from ..utils.logger import logger
|
from ..utils.logger import logger
|
||||||
from ..utils.singleton import Singleton
|
from ..utils.singleton import Singleton
|
||||||
from .admin_manager import admin_manager
|
from .admin_manager import admin_manager
|
||||||
|
from ..permission import Permission
|
||||||
|
|
||||||
|
|
||||||
@total_ordering
|
|
||||||
class Permission:
|
|
||||||
"""
|
|
||||||
权限封装类
|
|
||||||
|
|
||||||
封装了权限的名称和等级,并提供了比较方法。
|
|
||||||
使用 @total_ordering 装饰器可以自动生成所有的比较运算符。
|
|
||||||
"""
|
|
||||||
def __init__(self, name: str, level: int):
|
|
||||||
"""
|
|
||||||
初始化权限对象
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): 权限名称 (e.g., "admin", "op")
|
|
||||||
level (int): 权限等级,数字越大权限越高
|
|
||||||
"""
|
|
||||||
self.name = name
|
|
||||||
self.level = level
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
"""
|
|
||||||
判断权限是否相等
|
|
||||||
"""
|
|
||||||
if not isinstance(other, Permission):
|
|
||||||
return NotImplemented
|
|
||||||
return self.level == other.level
|
|
||||||
|
|
||||||
def __lt__(self, other):
|
|
||||||
"""
|
|
||||||
判断权限是否小于另一个权限
|
|
||||||
"""
|
|
||||||
if not isinstance(other, Permission):
|
|
||||||
return NotImplemented
|
|
||||||
return self.level < other.level
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
"""
|
|
||||||
返回权限的字符串表示(即权限名称)
|
|
||||||
"""
|
|
||||||
return self.name
|
|
||||||
|
|
||||||
|
|
||||||
# 定义全局权限常量
|
|
||||||
ADMIN = Permission("admin", 3)
|
|
||||||
OP = Permission("op", 2)
|
|
||||||
USER = Permission("user", 1)
|
|
||||||
|
|
||||||
# 用于从字符串名称查找权限对象的字典
|
# 用于从字符串名称查找权限对象的字典
|
||||||
_PERMISSIONS: Dict[str, Permission] = {
|
_PERMISSIONS: Dict[str, Permission] = {
|
||||||
p.name: p for p in [ADMIN, OP, USER]
|
p.value: p for p in Permission
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -88,8 +41,7 @@ class PermissionManager(Singleton):
|
|||||||
|
|
||||||
如果已经初始化过,则直接返回。
|
如果已经初始化过,则直接返回。
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
if hasattr(self, '_initialized') and self._initialized:
|
||||||
if not self._initialized:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# 权限数据文件路径
|
# 权限数据文件路径
|
||||||
@@ -111,6 +63,7 @@ class PermissionManager(Singleton):
|
|||||||
self.load()
|
self.load()
|
||||||
|
|
||||||
logger.info("权限管理器初始化完成")
|
logger.info("权限管理器初始化完成")
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -164,12 +117,12 @@ class PermissionManager(Singleton):
|
|||||||
"""
|
"""
|
||||||
# 首先,通过 AdminManager 检查是否为管理员
|
# 首先,通过 AdminManager 检查是否为管理员
|
||||||
if await admin_manager.is_admin(user_id):
|
if await admin_manager.is_admin(user_id):
|
||||||
return ADMIN
|
return Permission.ADMIN
|
||||||
|
|
||||||
# 如果不是管理员,则从 permissions.json 中查找
|
# 如果不是管理员,则从 permissions.json 中查找
|
||||||
user_id_str = str(user_id)
|
user_id_str = str(user_id)
|
||||||
level_name = self._data["users"].get(user_id_str, USER.name)
|
level_name = self._data["users"].get(user_id_str, Permission.USER.value)
|
||||||
return _PERMISSIONS.get(level_name, USER)
|
return _PERMISSIONS.get(level_name, Permission.USER)
|
||||||
|
|
||||||
def set_user_permission(self, user_id: int, permission: Permission) -> None:
|
def set_user_permission(self, user_id: int, permission: Permission) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -182,13 +135,13 @@ class PermissionManager(Singleton):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: 如果权限对象无效
|
ValueError: 如果权限对象无效
|
||||||
"""
|
"""
|
||||||
if not isinstance(permission, Permission) or permission.name not in _PERMISSIONS:
|
if not isinstance(permission, Permission):
|
||||||
raise ValueError(f"无效的权限对象: {permission}")
|
raise ValueError(f"无效的权限对象: {permission}")
|
||||||
|
|
||||||
user_id_str = str(user_id)
|
user_id_str = str(user_id)
|
||||||
self._data["users"][user_id_str] = permission.name
|
self._data["users"][user_id_str] = permission.value
|
||||||
self.save()
|
self.save()
|
||||||
logger.info(f"设置用户 {user_id} 的权限级别为 {permission.name}")
|
logger.info(f"设置用户 {user_id} 的权限级别为 {permission.value}")
|
||||||
|
|
||||||
def remove_user(self, user_id: int) -> None:
|
def remove_user(self, user_id: int) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -214,17 +167,17 @@ class PermissionManager(Singleton):
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 如果用户权限 >= 所需权限,返回 True,否则返回 False
|
bool: 如果用户权限 >= 所需权限,返回 True,否则返回 False
|
||||||
"""
|
"""
|
||||||
# 如果传入的是字符串,先转换为 Permission 对象
|
|
||||||
if isinstance(required_permission, str):
|
|
||||||
required_permission = _PERMISSIONS.get(required_permission.lower())
|
|
||||||
if not required_permission:
|
|
||||||
# 如果是无效的权限字符串,默认拒绝
|
|
||||||
logger.warning(f"检测到无效的权限检查字符串: {required_permission}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
user_permission = await self.get_user_permission(user_id)
|
user_permission = await self.get_user_permission(user_id)
|
||||||
return user_permission >= required_permission
|
return user_permission >= required_permission
|
||||||
|
|
||||||
|
def get_all_user_permissions(self) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
获取所有已配置的用户权限
|
||||||
|
|
||||||
|
:return: 一个包含所有用户权限的字典
|
||||||
|
"""
|
||||||
|
return self._data["users"].copy()
|
||||||
|
|
||||||
def get_all_users(self) -> Dict[str, str]:
|
def get_all_users(self) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
获取所有设置了权限的用户及其级别名称
|
获取所有设置了权限的用户及其级别名称
|
||||||
@@ -243,22 +196,22 @@ class PermissionManager(Singleton):
|
|||||||
logger.info("已清空所有权限设置")
|
logger.info("已清空所有权限设置")
|
||||||
|
|
||||||
|
|
||||||
# 全局权限管理器实例
|
|
||||||
permission_manager = PermissionManager()
|
|
||||||
|
|
||||||
def require_admin(func):
|
def require_admin(func):
|
||||||
"""
|
"""
|
||||||
一个装饰器,用于限制命令只能由管理员执行。
|
一个装饰器,用于限制命令只能由管理员执行。
|
||||||
"""
|
"""
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from models.events.message import MessageEvent
|
from models.events.message import MessageEvent
|
||||||
|
from core.managers import permission_manager
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def wrapper(event: MessageEvent, *args, **kwargs):
|
async def wrapper(event: MessageEvent, *args, **kwargs):
|
||||||
user_id = event.user_id
|
user_id = event.user_id
|
||||||
if await permission_manager.check_permission(user_id, ADMIN):
|
if await permission_manager.check_permission(user_id, Permission.ADMIN):
|
||||||
return await func(event, *args, **kwargs)
|
return await func(event, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
await event.reply("抱歉,您没有权限执行此命令。")
|
# 假设 event 对象有 reply 方法
|
||||||
|
if hasattr(event, "reply"):
|
||||||
|
await event.reply("抱歉,您没有权限执行此命令。")
|
||||||
return None
|
return None
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
@@ -1,126 +1,97 @@
|
|||||||
"""
|
"""
|
||||||
插件管理器模块
|
插件管理器模块
|
||||||
|
|
||||||
负责扫描、加载和管理 `base_plugins` 目录下的所有插件。
|
负责扫描、加载和管理 `plugins` 目录下的所有插件。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Set
|
||||||
|
|
||||||
from .command_manager import matcher
|
|
||||||
from ..utils.exceptions import SyncHandlerError
|
from ..utils.exceptions import SyncHandlerError
|
||||||
from ..utils.logger import logger
|
from ..utils.logger import logger
|
||||||
from ..utils.executor import run_in_thread_pool
|
|
||||||
|
|
||||||
|
|
||||||
def load_all_plugins():
|
class PluginManager:
|
||||||
"""
|
"""
|
||||||
扫描并加载 `plugins` 目录下的所有插件。
|
插件管理器类
|
||||||
|
|
||||||
该函数会遍历 `plugins` 目录下的所有模块:
|
|
||||||
1. 如果模块已加载,则执行 reload 操作(用于热重载)。
|
|
||||||
2. 如果模块未加载,则执行 import 操作。
|
|
||||||
|
|
||||||
加载过程中会提取插件元数据 `__plugin_meta__` 并注册到 CommandManager。
|
|
||||||
"""
|
"""
|
||||||
plugin_dir = os.path.join(
|
def __init__(self, command_manager):
|
||||||
os.path.dirname(os.path.abspath(__file__)), "..", "plugins"
|
"""
|
||||||
)
|
初始化插件管理器
|
||||||
package_name = "plugins"
|
|
||||||
|
|
||||||
logger.info(f"正在从 {package_name} 加载插件...")
|
:param command_manager: CommandManager的实例
|
||||||
|
"""
|
||||||
|
self.command_manager = command_manager
|
||||||
|
self.loaded_plugins: Set[str] = set()
|
||||||
|
|
||||||
for loader, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
|
def load_all_plugins(self):
|
||||||
full_module_name = f"{package_name}.{module_name}"
|
"""
|
||||||
|
扫描并加载 `plugins` 目录下的所有插件。
|
||||||
|
"""
|
||||||
|
# 使用 pathlib 获取更可靠的路径
|
||||||
|
# 当前文件: core/managers/plugin_manager.py
|
||||||
|
# 目标: plugins/
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
# 回退两级到项目根目录 (core/managers -> core -> root)
|
||||||
|
root_dir = os.path.dirname(os.path.dirname(current_dir))
|
||||||
|
plugin_dir = os.path.join(root_dir, "plugins")
|
||||||
|
|
||||||
|
package_name = "plugins"
|
||||||
|
|
||||||
|
if not os.path.exists(plugin_dir):
|
||||||
|
logger.error(f"插件目录不存在: {plugin_dir}")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"正在从 {package_name} 加载插件 (路径: {plugin_dir})...")
|
||||||
|
|
||||||
|
for _, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
|
||||||
|
full_module_name = f"{package_name}.{module_name}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
if full_module_name in self.loaded_plugins:
|
||||||
|
self.command_manager.unload_plugin(full_module_name)
|
||||||
|
module = importlib.reload(sys.modules[full_module_name])
|
||||||
|
action = "重载"
|
||||||
|
else:
|
||||||
|
module = importlib.import_module(full_module_name)
|
||||||
|
action = "加载"
|
||||||
|
|
||||||
|
if hasattr(module, "__plugin_meta__"):
|
||||||
|
meta = getattr(module, "__plugin_meta__")
|
||||||
|
self.command_manager.plugins[full_module_name] = meta
|
||||||
|
|
||||||
|
self.loaded_plugins.add(full_module_name)
|
||||||
|
|
||||||
|
type_str = "包" if is_pkg else "文件"
|
||||||
|
logger.success(f" [{type_str}] 成功{action}: {module_name}")
|
||||||
|
except SyncHandlerError as e:
|
||||||
|
logger.error(f" 插件 {module_name} 加载失败: {e} (跳过此插件)")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
f" {action if 'action' in locals() else '加载'}插件 {module_name} 失败: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def reload_plugin(self, full_module_name: str):
|
||||||
|
"""
|
||||||
|
精确重载单个插件。
|
||||||
|
"""
|
||||||
|
if full_module_name not in self.loaded_plugins:
|
||||||
|
logger.warning(f"尝试重载一个未被加载的插件: {full_module_name},将按首次加载处理。")
|
||||||
|
|
||||||
|
if full_module_name not in sys.modules:
|
||||||
|
logger.error(f"重载失败: 模块 {full_module_name} 未在 sys.modules 中找到。")
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if full_module_name in sys.modules:
|
self.command_manager.unload_plugin(full_module_name)
|
||||||
module = importlib.reload(sys.modules[full_module_name])
|
module = importlib.reload(sys.modules[full_module_name])
|
||||||
action = "重载"
|
|
||||||
else:
|
|
||||||
module = importlib.import_module(full_module_name)
|
|
||||||
action = "加载"
|
|
||||||
|
|
||||||
# 提取插件元数据
|
|
||||||
if hasattr(module, "__plugin_meta__"):
|
if hasattr(module, "__plugin_meta__"):
|
||||||
meta = getattr(module, "__plugin_meta__")
|
meta = getattr(module, "__plugin_meta__")
|
||||||
matcher.plugins[full_module_name] = meta
|
self.command_manager.plugins[full_module_name] = meta
|
||||||
|
|
||||||
type_str = "包" if is_pkg else "文件"
|
logger.success(f"插件 {full_module_name} 已成功重载。")
|
||||||
logger.success(f" [{type_str}] 成功{action}: {module_name}")
|
|
||||||
except SyncHandlerError as e:
|
|
||||||
logger.error(f" 插件 {module_name} 加载失败: {e} (跳过此插件)")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(
|
logger.exception(f"重载插件 {full_module_name} 时发生错误: {e}")
|
||||||
f" {action if 'action' in locals() else '加载'}插件 {module_name} 失败: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PluginDataManager:
|
|
||||||
"""
|
|
||||||
用于管理插件产生的数据文件的类
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, plugin_name: str):
|
|
||||||
"""
|
|
||||||
初始化插件数据管理器
|
|
||||||
|
|
||||||
:param plugin_name: 插件名称
|
|
||||||
"""
|
|
||||||
self.plugin_name = plugin_name
|
|
||||||
self.data_file = os.path.join(
|
|
||||||
os.path.dirname(os.path.abspath(__file__)),
|
|
||||||
"..",
|
|
||||||
"plugins",
|
|
||||||
"data",
|
|
||||||
self.plugin_name + ".json",
|
|
||||||
)
|
|
||||||
self.data = {}
|
|
||||||
|
|
||||||
async def load(self):
|
|
||||||
"""读取配置文件"""
|
|
||||||
if not os.path.exists(self.data_file):
|
|
||||||
await self.set(self.plugin_name, [])
|
|
||||||
try:
|
|
||||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
|
||||||
self.data = await run_in_thread_pool(json.load, f)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
self.data = {}
|
|
||||||
|
|
||||||
async def save(self):
|
|
||||||
"""保存配置到文件"""
|
|
||||||
with open(self.data_file, "w", encoding="utf-8") as f:
|
|
||||||
await run_in_thread_pool(json.dump, self.data, f, indent=2, ensure_ascii=False)
|
|
||||||
|
|
||||||
def get(self, key, default=None):
|
|
||||||
"""获取配置项"""
|
|
||||||
return self.data.get(key, default)
|
|
||||||
|
|
||||||
async def set(self, key, value):
|
|
||||||
"""设置配置项"""
|
|
||||||
self.data[key] = value
|
|
||||||
await self.save()
|
|
||||||
|
|
||||||
async def add(self, key, value):
|
|
||||||
"""添加配置项"""
|
|
||||||
if key not in self.data:
|
|
||||||
self.data[key] = []
|
|
||||||
self.data[key].append(value)
|
|
||||||
await self.save()
|
|
||||||
|
|
||||||
async def remove(self, key):
|
|
||||||
"""删除配置项"""
|
|
||||||
if key in self.data:
|
|
||||||
del self.data[key]
|
|
||||||
await self.save()
|
|
||||||
|
|
||||||
async def clear(self):
|
|
||||||
"""清空所有配置"""
|
|
||||||
self.data.clear()
|
|
||||||
await self.save()
|
|
||||||
|
|
||||||
def get_all(self):
|
|
||||||
return self.data.copy()
|
|
||||||
|
|||||||
@@ -20,10 +20,11 @@ class RedisManager:
|
|||||||
"""
|
"""
|
||||||
if self._redis is None:
|
if self._redis is None:
|
||||||
try:
|
try:
|
||||||
host = config.redis['host']
|
redis_config = config.redis
|
||||||
port = config.redis['port']
|
host = redis_config.host
|
||||||
db = config.redis['db']
|
port = redis_config.port
|
||||||
password = config.redis.get('password')
|
db = redis_config.db
|
||||||
|
password = redis_config.password
|
||||||
|
|
||||||
logger.info(f"正在尝试连接 Redis: {host}:{port}, DB: {db}")
|
logger.info(f"正在尝试连接 Redis: {host}:{port}, DB: {db}")
|
||||||
|
|
||||||
@@ -54,5 +55,17 @@ class RedisManager:
|
|||||||
raise ConnectionError("Redis 未初始化或连接失败,请先调用 initialize()")
|
raise ConnectionError("Redis 未初始化或连接失败,请先调用 initialize()")
|
||||||
return self._redis
|
return self._redis
|
||||||
|
|
||||||
|
async def get(self, name):
|
||||||
|
"""
|
||||||
|
获取指定键的值
|
||||||
|
"""
|
||||||
|
return await self.redis.get(name)
|
||||||
|
|
||||||
|
async def set(self, name, value, ex=None):
|
||||||
|
"""
|
||||||
|
设置指定键的值
|
||||||
|
"""
|
||||||
|
return await self.redis.set(name, value, ex=ex)
|
||||||
|
|
||||||
# 全局 Redis 管理器实例
|
# 全局 Redis 管理器实例
|
||||||
redis_manager = RedisManager()
|
redis_manager = RedisManager()
|
||||||
|
|||||||
42
core/permission.py
Normal file
42
core/permission.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from functools import total_ordering
|
||||||
|
|
||||||
|
|
||||||
|
@total_ordering
|
||||||
|
class Permission(Enum):
|
||||||
|
"""
|
||||||
|
定义用户权限等级的枚举类。
|
||||||
|
|
||||||
|
使用 @total_ordering 装饰器,只需定义 __lt__ 和 __eq__,
|
||||||
|
即可自动实现所有比较运算符。
|
||||||
|
"""
|
||||||
|
USER = "user"
|
||||||
|
OP = "op"
|
||||||
|
ADMIN = "admin"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _level_map(self):
|
||||||
|
"""
|
||||||
|
内部属性,用于映射枚举成员到整数等级。
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
Permission.USER: 1,
|
||||||
|
Permission.OP: 2,
|
||||||
|
Permission.ADMIN: 3
|
||||||
|
}
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
"""
|
||||||
|
比较当前权限是否小于另一个权限。
|
||||||
|
"""
|
||||||
|
if not isinstance(other, Permission):
|
||||||
|
return NotImplemented
|
||||||
|
return self._level_map[self] < self._level_map[other]
|
||||||
|
|
||||||
|
def __ge__(self, other):
|
||||||
|
"""
|
||||||
|
比较当前权限是否大于等于另一个权限。
|
||||||
|
"""
|
||||||
|
if not isinstance(other, Permission):
|
||||||
|
return NotImplemented
|
||||||
|
return self._level_map[self] >= self._level_map[other]
|
||||||
@@ -2,7 +2,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import docker
|
import docker
|
||||||
from docker.tls import TLSConfig
|
from docker.tls import TLSConfig
|
||||||
from typing import Dict, Any, Callable
|
from docker.types import LogConfig
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
from core.utils.logger import logger
|
from core.utils.logger import logger
|
||||||
|
|
||||||
@@ -10,21 +11,20 @@ class CodeExecutor:
|
|||||||
"""
|
"""
|
||||||
代码执行引擎,负责管理一个异步任务队列和并发的 Docker 容器执行。
|
代码执行引擎,负责管理一个异步任务队列和并发的 Docker 容器执行。
|
||||||
"""
|
"""
|
||||||
def __init__(self, bot_instance, config: Dict[str, Any]):
|
def __init__(self, config: Any):
|
||||||
"""
|
"""
|
||||||
初始化代码执行引擎。
|
初始化代码执行引擎。
|
||||||
:param bot_instance: Bot 实例,用于后续的消息回复。
|
:param config: 从 config_loader.py 加载的全局配置对象。
|
||||||
:param config: 从 config.toml 加载的配置字典。
|
|
||||||
"""
|
"""
|
||||||
self.bot = bot_instance
|
self.bot: Any = None # Bot 实例将在 WS 连接成功后动态注入
|
||||||
self.task_queue = asyncio.Queue()
|
self.task_queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
|
||||||
# 从传入的配置中读取 Docker 相关设置
|
# 从传入的配置中读取 Docker 相关设置
|
||||||
docker_config = config.docker
|
docker_config = config.docker
|
||||||
self.docker_base_url = docker_config.get("base_url")
|
self.docker_base_url = docker_config.base_url
|
||||||
self.sandbox_image = docker_config.get("sandbox_image", "python-sandbox:latest")
|
self.sandbox_image = docker_config.sandbox_image
|
||||||
self.timeout = docker_config.get("timeout", 10)
|
self.timeout = docker_config.timeout
|
||||||
concurrency = docker_config.get("concurrency_limit", 5)
|
concurrency = docker_config.concurrency_limit
|
||||||
|
|
||||||
self.concurrency_limit = asyncio.Semaphore(concurrency)
|
self.concurrency_limit = asyncio.Semaphore(concurrency)
|
||||||
self.docker_client = None
|
self.docker_client = None
|
||||||
@@ -34,10 +34,10 @@ class CodeExecutor:
|
|||||||
if self.docker_base_url:
|
if self.docker_base_url:
|
||||||
# 如果配置了远程 Docker 地址,则使用 TLS 选项进行连接
|
# 如果配置了远程 Docker 地址,则使用 TLS 选项进行连接
|
||||||
tls_config = None
|
tls_config = None
|
||||||
if docker_config.get("tls_verify", False):
|
if docker_config.tls_verify:
|
||||||
tls_config = TLSConfig(
|
tls_config = TLSConfig(
|
||||||
ca_cert=docker_config.get("ca_cert_path"),
|
ca_cert=docker_config.ca_cert_path,
|
||||||
client_cert=(docker_config.get("client_cert_path"), docker_config.get("client_key_path")),
|
client_cert=(docker_config.client_cert_path, docker_config.client_key_path),
|
||||||
verify=True
|
verify=True
|
||||||
)
|
)
|
||||||
self.docker_client = docker.DockerClient(base_url=self.docker_base_url, tls=tls_config)
|
self.docker_client = docker.DockerClient(base_url=self.docker_base_url, tls=tls_config)
|
||||||
@@ -60,7 +60,15 @@ class CodeExecutor:
|
|||||||
将代码执行任务添加到队列中。
|
将代码执行任务添加到队列中。
|
||||||
:param code: 待执行的 Python 代码字符串。
|
:param code: 待执行的 Python 代码字符串。
|
||||||
:param callback: 执行完毕后用于回复结果的回调函数。
|
:param callback: 执行完毕后用于回复结果的回调函数。
|
||||||
|
:raises RuntimeError: 如果 Docker 客户端未初始化。
|
||||||
"""
|
"""
|
||||||
|
if not self.docker_client:
|
||||||
|
logger.warning("[CodeExecutor] 尝试添加任务,但 Docker 客户端未初始化。任务被拒绝。")
|
||||||
|
# 这里可以选择抛出异常,或者直接调用回调返回错误信息
|
||||||
|
# 为了用户体验,我们构造一个错误结果并直接调用回调(如果可能)
|
||||||
|
# 但由于 callback 返回 Future,这里简单起见,我们记录日志并抛出异常
|
||||||
|
raise RuntimeError("Docker环境未就绪,无法执行代码。")
|
||||||
|
|
||||||
task = {"code": code, "callback": callback}
|
task = {"code": code, "callback": callback}
|
||||||
await self.task_queue.put(task)
|
await self.task_queue.put(task)
|
||||||
logger.info(f"[CodeExecutor] 新的代码执行任务已入队 (队列当前长度: {self.task_queue.qsize()})。")
|
logger.info(f"[CodeExecutor] 新的代码执行任务已入队 (队列当前长度: {self.task_queue.qsize()})。")
|
||||||
@@ -125,6 +133,9 @@ class CodeExecutor:
|
|||||||
同步函数:在 Docker 容器中运行代码。
|
同步函数:在 Docker 容器中运行代码。
|
||||||
此函数通过手动管理容器生命周期来提高稳定性。
|
此函数通过手动管理容器生命周期来提高稳定性。
|
||||||
"""
|
"""
|
||||||
|
if self.docker_client is None:
|
||||||
|
raise docker.errors.DockerException("Docker client is not initialized.")
|
||||||
|
|
||||||
container = None
|
container = None
|
||||||
try:
|
try:
|
||||||
# 1. 创建容器
|
# 1. 创建容器
|
||||||
@@ -134,7 +145,7 @@ class CodeExecutor:
|
|||||||
mem_limit='128m',
|
mem_limit='128m',
|
||||||
cpu_shares=512,
|
cpu_shares=512,
|
||||||
network_disabled=True,
|
network_disabled=True,
|
||||||
log_config={'type': 'json-file', 'config': {'max-size': '1m'}},
|
log_config=LogConfig(type='json-file', config={'max-size': '1m'}),
|
||||||
)
|
)
|
||||||
# 2. 启动容器
|
# 2. 启动容器
|
||||||
container.start()
|
container.start()
|
||||||
@@ -150,7 +161,7 @@ class CodeExecutor:
|
|||||||
# 5. 检查退出码,如果不为 0,则手动抛出 ContainerError
|
# 5. 检查退出码,如果不为 0,则手动抛出 ContainerError
|
||||||
if result.get('StatusCode', 0) != 0:
|
if result.get('StatusCode', 0) != 0:
|
||||||
raise docker.errors.ContainerError(
|
raise docker.errors.ContainerError(
|
||||||
container, result['StatusCode'], f"python -c '{code}'", self.sandbox_image, stderr
|
container, result['StatusCode'], f"python -c '{code}'", self.sandbox_image, stderr.decode('utf-8')
|
||||||
)
|
)
|
||||||
|
|
||||||
return stdout
|
return stdout
|
||||||
@@ -166,11 +177,11 @@ class CodeExecutor:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[CodeExecutor] 强制移除容器 {container.id} 时失败: {e}")
|
logger.error(f"[CodeExecutor] 强制移除容器 {container.id} 时失败: {e}")
|
||||||
|
|
||||||
def initialize_executor(bot_instance, config: Dict[str, Any]):
|
def initialize_executor(config: Any):
|
||||||
"""
|
"""
|
||||||
初始化并返回一个 CodeExecutor 实例。
|
初始化并返回一个 CodeExecutor 实例。
|
||||||
"""
|
"""
|
||||||
return CodeExecutor(bot_instance, config)
|
return CodeExecutor(config)
|
||||||
|
|
||||||
async def run_in_thread_pool(sync_func, *args, **kwargs):
|
async def run_in_thread_pool(sync_func, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
52
core/ws.py
52
core/ws.py
@@ -13,11 +13,12 @@ WebSocket 连接。它是整个机器人框架的底层通信基础。
|
|||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
from models import EventFactory
|
from models.events.factory import EventFactory
|
||||||
|
|
||||||
from .bot import Bot
|
from .bot import Bot
|
||||||
from .config_loader import global_config
|
from .config_loader import global_config
|
||||||
@@ -30,7 +31,7 @@ class WS:
|
|||||||
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
|
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, code_executor=None):
|
||||||
"""
|
"""
|
||||||
初始化 WebSocket 客户端。
|
初始化 WebSocket 客户端。
|
||||||
|
|
||||||
@@ -38,13 +39,15 @@ class WS:
|
|||||||
"""
|
"""
|
||||||
# 读取参数
|
# 读取参数
|
||||||
cfg = global_config.napcat_ws
|
cfg = global_config.napcat_ws
|
||||||
self.url = cfg.get("uri")
|
self.url = cfg.uri
|
||||||
self.token = cfg.get("token")
|
self.token = cfg.token
|
||||||
self.reconnect_interval = cfg.get("reconnect_interval", 5)
|
self.reconnect_interval = cfg.reconnect_interval
|
||||||
|
|
||||||
self.ws = None
|
self.ws = None
|
||||||
self._pending_requests = {}
|
self._pending_requests = {}
|
||||||
self.bot = Bot(self)
|
self.bot: Bot | None = None
|
||||||
|
self.self_id: int | None = None
|
||||||
|
self.code_executor = code_executor
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
"""
|
"""
|
||||||
@@ -124,18 +127,43 @@ class WS:
|
|||||||
try:
|
try:
|
||||||
# 使用工厂创建事件对象
|
# 使用工厂创建事件对象
|
||||||
event = EventFactory.create_event(event_data)
|
event = EventFactory.create_event(event_data)
|
||||||
|
|
||||||
|
# 尝试初始化 Bot 实例 (如果尚未初始化且事件包含 self_id)
|
||||||
|
# 只要事件中包含 self_id,我们就可以初始化 Bot,不必非要等待 meta_event
|
||||||
|
if self.bot is None and hasattr(event, 'self_id'):
|
||||||
|
self.self_id = event.self_id
|
||||||
|
self.bot = Bot(self)
|
||||||
|
logger.success(f"Bot 实例初始化完成: self_id={self.self_id}")
|
||||||
|
|
||||||
|
# 将代码执行器注入到 Bot 和执行器自身
|
||||||
|
if self.code_executor:
|
||||||
|
self.bot.code_executor = self.code_executor
|
||||||
|
self.code_executor.bot = self.bot
|
||||||
|
logger.info("代码执行器已成功注入 Bot 实例。")
|
||||||
|
|
||||||
|
# 如果 bot 尚未初始化,则不处理后续事件
|
||||||
|
if self.bot is None:
|
||||||
|
logger.warning("Bot 尚未初始化,跳过事件处理。")
|
||||||
|
return
|
||||||
|
|
||||||
event.bot = self.bot # 注入 Bot 实例
|
event.bot = self.bot # 注入 Bot 实例
|
||||||
|
|
||||||
# 打印日志
|
# 打印日志
|
||||||
if event.post_type == "message":
|
if event.post_type == "message":
|
||||||
sender_name = event.sender.nickname if event.sender else "Unknown"
|
sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown"
|
||||||
logger.info(f"[消息] {event.message_type} | {event.user_id}({sender_name}): {event.raw_message}")
|
message_type = getattr(event, "message_type", "Unknown")
|
||||||
|
user_id = getattr(event, "user_id", "Unknown")
|
||||||
|
raw_message = getattr(event, "raw_message", "")
|
||||||
|
logger.info(f"[消息] {message_type} | {user_id}({sender_name}): {raw_message}")
|
||||||
elif event.post_type == "notice":
|
elif event.post_type == "notice":
|
||||||
logger.info(f"[通知] {event.notice_type}")
|
notice_type = getattr(event, "notice_type", "Unknown")
|
||||||
|
logger.info(f"[通知] {notice_type}")
|
||||||
elif event.post_type == "request":
|
elif event.post_type == "request":
|
||||||
logger.info(f"[请求] {event.request_type}")
|
request_type = getattr(event, "request_type", "Unknown")
|
||||||
|
logger.info(f"[请求] {request_type}")
|
||||||
elif event.post_type == "meta_event":
|
elif event.post_type == "meta_event":
|
||||||
logger.debug(f"[元事件] {event.meta_event_type}")
|
meta_event_type = getattr(event, "meta_event_type", "Unknown")
|
||||||
|
logger.debug(f"[元事件] {meta_event_type}")
|
||||||
|
|
||||||
|
|
||||||
# 分发事件
|
# 分发事件
|
||||||
@@ -144,7 +172,7 @@ class WS:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"事件处理异常: {e}")
|
logger.exception(f"事件处理异常: {e}")
|
||||||
|
|
||||||
async def call_api(self, action: str, params: dict = None):
|
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
|
||||||
"""
|
"""
|
||||||
向 OneBot v11 实现端发送一个 API 请求。
|
向 OneBot v11 实现端发送一个 API 请求。
|
||||||
|
|
||||||
|
|||||||
37
main.py
37
main.py
@@ -15,7 +15,7 @@ from core.utils.logger import logger
|
|||||||
|
|
||||||
from core.managers.admin_manager import admin_manager
|
from core.managers.admin_manager import admin_manager
|
||||||
from core.ws import WS
|
from core.ws import WS
|
||||||
from core.managers.plugin_manager import load_all_plugins
|
from core.managers import plugin_manager
|
||||||
from core.managers.redis_manager import redis_manager
|
from core.managers.redis_manager import redis_manager
|
||||||
from core.utils.executor import run_in_thread_pool, initialize_executor
|
from core.utils.executor import run_in_thread_pool, initialize_executor
|
||||||
from core.config_loader import global_config as config
|
from core.config_loader import global_config as config
|
||||||
@@ -25,6 +25,8 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
|||||||
sys.path.insert(0, ROOT_DIR)
|
sys.path.insert(0, ROOT_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
# 获取插件目录的绝对路径
|
||||||
|
PLUGIN_DIR = os.path.join(ROOT_DIR, "plugins")
|
||||||
|
|
||||||
|
|
||||||
class PluginReloadHandler(FileSystemEventHandler):
|
class PluginReloadHandler(FileSystemEventHandler):
|
||||||
@@ -32,7 +34,7 @@ class PluginReloadHandler(FileSystemEventHandler):
|
|||||||
文件变更处理器,用于热重载插件
|
文件变更处理器,用于热重载插件
|
||||||
|
|
||||||
继承自 watchdog.events.FileSystemEventHandler,
|
继承自 watchdog.events.FileSystemEventHandler,
|
||||||
监听 base_plugins 目录下的文件变化,并触发插件重载。
|
监听 plugins 目录下的文件变化,并触发插件重载。
|
||||||
"""
|
"""
|
||||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||||
"""
|
"""
|
||||||
@@ -53,12 +55,14 @@ class PluginReloadHandler(FileSystemEventHandler):
|
|||||||
if file_system_event.is_directory:
|
if file_system_event.is_directory:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
src_path = file_system_event.src_path
|
||||||
|
|
||||||
# 只监控 py 文件
|
# 只监控 py 文件
|
||||||
if not file_system_event.src_path.endswith(".py"):
|
if not src_path.endswith(".py"):
|
||||||
return
|
return
|
||||||
|
|
||||||
# 过滤掉一些临时文件
|
# 过滤掉一些临时文件
|
||||||
if "__pycache__" in file_system_event.src_path:
|
if "__pycache__" in src_path or not src_path.startswith(PLUGIN_DIR):
|
||||||
return
|
return
|
||||||
|
|
||||||
# 简单的防抖动
|
# 简单的防抖动
|
||||||
@@ -68,13 +72,18 @@ class PluginReloadHandler(FileSystemEventHandler):
|
|||||||
|
|
||||||
self.last_reload_time = current_time
|
self.last_reload_time = current_time
|
||||||
|
|
||||||
logger.info(f"检测到文件变更: {file_system_event.src_path}")
|
# 从文件路径解析出模块名
|
||||||
logger.info("正在重载插件...")
|
# 例如: C:\path\to\project\plugins\bili_parser.py -> plugins.bili_parser
|
||||||
|
relative_path = os.path.relpath(src_path, ROOT_DIR)
|
||||||
|
module_name = os.path.splitext(relative_path.replace(os.sep, '.'))[0]
|
||||||
|
|
||||||
|
logger.info(f"检测到文件变更: {src_path}")
|
||||||
|
logger.info(f"准备重载插件: {module_name}...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用线程安全的方式在主事件循环中运行异步的插件加载函数
|
# 使用线程安全的方式在主事件循环中运行异步的插件重载函数
|
||||||
asyncio.run_coroutine_threadsafe(run_in_thread_pool(load_all_plugins), self.loop)
|
asyncio.run_coroutine_threadsafe(run_in_thread_pool(plugin_manager.reload_plugin, module_name), self.loop)
|
||||||
logger.success("插件重载完成")
|
logger.success(f"插件 {module_name} 重载任务已提交")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"重载失败: {e}")
|
logger.exception(f"重载失败: {e}")
|
||||||
|
|
||||||
@@ -88,8 +97,7 @@ async def main():
|
|||||||
2. 初始化 WebSocket 客户端
|
2. 初始化 WebSocket 客户端
|
||||||
3. 建立连接并保持运行
|
3. 建立连接并保持运行
|
||||||
"""
|
"""
|
||||||
# 首次加载插件
|
# 插件加载已移至 core.managers.__init__.py 中自动执行
|
||||||
await run_in_thread_pool(load_all_plugins)
|
|
||||||
|
|
||||||
# 初始化 Redis 连接
|
# 初始化 Redis 连接
|
||||||
await redis_manager.initialize()
|
await redis_manager.initialize()
|
||||||
@@ -114,11 +122,10 @@ async def main():
|
|||||||
logger.warning(f"插件目录不存在 {plugin_path}")
|
logger.warning(f"插件目录不存在 {plugin_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
websocket_client = WS()
|
|
||||||
|
|
||||||
# 初始化代码执行器
|
# 初始化代码执行器
|
||||||
code_executor = initialize_executor(websocket_client, config)
|
code_executor = initialize_executor(config)
|
||||||
websocket_client.bot.code_executor = code_executor # 将执行器实例附加到 bot.bot 对象上
|
|
||||||
|
websocket_client = WS(code_executor=code_executor)
|
||||||
|
|
||||||
# 启动代码执行器的后台 worker
|
# 启动代码执行器的后台 worker
|
||||||
logger.debug("[Main] 检查是否需要启动代码执行 Worker...")
|
logger.debug("[Main] 检查是否需要启动代码执行 Worker...")
|
||||||
|
|||||||
@@ -1,97 +1,23 @@
|
|||||||
from .events.base import OneBotEvent
|
"""
|
||||||
from .events.factory import EventFactory
|
Models 包
|
||||||
from .events.message import (
|
|
||||||
GroupMessageEvent,
|
|
||||||
MessageEvent,
|
|
||||||
MessageSegment,
|
|
||||||
PrivateMessageEvent,
|
|
||||||
)
|
|
||||||
from .events.meta import HeartbeatEvent, HeartbeatStatus, LifeCycleEvent, MetaEvent
|
|
||||||
from .events.notice import (
|
|
||||||
ClientStatus,
|
|
||||||
ClientStatusNoticeEvent,
|
|
||||||
EssenceNoticeEvent,
|
|
||||||
FriendAddNoticeEvent,
|
|
||||||
FriendRecallNoticeEvent,
|
|
||||||
GroupAdminNoticeEvent,
|
|
||||||
GroupBanNoticeEvent,
|
|
||||||
GroupCardNoticeEvent,
|
|
||||||
GroupDecreaseNoticeEvent,
|
|
||||||
GroupIncreaseNoticeEvent,
|
|
||||||
GroupRecallNoticeEvent,
|
|
||||||
GroupUploadFile,
|
|
||||||
GroupUploadNoticeEvent,
|
|
||||||
HonorNotifyEvent,
|
|
||||||
LuckyKingNotifyEvent,
|
|
||||||
NoticeEvent,
|
|
||||||
NotifyNoticeEvent,
|
|
||||||
OfflineFile,
|
|
||||||
OfflineFileNoticeEvent,
|
|
||||||
PokeNotifyEvent,
|
|
||||||
)
|
|
||||||
from .events.request import FriendRequestEvent, GroupRequestEvent, RequestEvent
|
|
||||||
from .objects import (
|
|
||||||
CurrentTalkative,
|
|
||||||
EssenceMessage,
|
|
||||||
FriendInfo,
|
|
||||||
GroupHonorInfo,
|
|
||||||
GroupInfo,
|
|
||||||
GroupMemberInfo,
|
|
||||||
HonorInfo,
|
|
||||||
LoginInfo,
|
|
||||||
Status,
|
|
||||||
StrangerInfo,
|
|
||||||
VersionInfo,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Alias for backward compatibility
|
导出常用的模型类,方便插件导入。
|
||||||
Event = OneBotEvent
|
"""
|
||||||
|
|
||||||
|
from .events.base import OneBotEvent
|
||||||
|
from .events.message import MessageEvent, GroupMessageEvent, PrivateMessageEvent
|
||||||
|
from .events.notice import NoticeEvent
|
||||||
|
from .events.request import RequestEvent
|
||||||
|
from .message import MessageSegment
|
||||||
|
from .sender import Sender
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"OneBotEvent",
|
||||||
|
"MessageEvent",
|
||||||
|
"GroupMessageEvent",
|
||||||
|
"PrivateMessageEvent",
|
||||||
|
"NoticeEvent",
|
||||||
|
"RequestEvent",
|
||||||
"MessageSegment",
|
"MessageSegment",
|
||||||
"Sender",
|
"Sender",
|
||||||
"OneBotEvent",
|
|
||||||
"Event",
|
|
||||||
"MessageEvent",
|
|
||||||
"PrivateMessageEvent",
|
|
||||||
"GroupMessageEvent",
|
|
||||||
"NoticeEvent",
|
|
||||||
"FriendAddNoticeEvent",
|
|
||||||
"FriendRecallNoticeEvent",
|
|
||||||
"GroupRecallNoticeEvent",
|
|
||||||
"GroupIncreaseNoticeEvent",
|
|
||||||
"GroupDecreaseNoticeEvent",
|
|
||||||
"GroupAdminNoticeEvent",
|
|
||||||
"GroupBanNoticeEvent",
|
|
||||||
"GroupUploadNoticeEvent",
|
|
||||||
"GroupUploadFile",
|
|
||||||
"NotifyNoticeEvent",
|
|
||||||
"PokeNotifyEvent",
|
|
||||||
"LuckyKingNotifyEvent",
|
|
||||||
"HonorNotifyEvent",
|
|
||||||
"GroupCardNoticeEvent",
|
|
||||||
"OfflineFileNoticeEvent",
|
|
||||||
"OfflineFile",
|
|
||||||
"ClientStatusNoticeEvent",
|
|
||||||
"ClientStatus",
|
|
||||||
"EssenceNoticeEvent",
|
|
||||||
"RequestEvent",
|
|
||||||
"FriendRequestEvent",
|
|
||||||
"GroupRequestEvent",
|
|
||||||
"MetaEvent",
|
|
||||||
"HeartbeatEvent",
|
|
||||||
"LifeCycleEvent",
|
|
||||||
"HeartbeatStatus",
|
|
||||||
"EventFactory",
|
|
||||||
"GroupInfo",
|
|
||||||
"GroupMemberInfo",
|
|
||||||
"FriendInfo",
|
|
||||||
"StrangerInfo",
|
|
||||||
"LoginInfo",
|
|
||||||
"VersionInfo",
|
|
||||||
"Status",
|
|
||||||
"EssenceMessage",
|
|
||||||
"GroupHonorInfo",
|
|
||||||
"CurrentTalkative",
|
|
||||||
"HonorInfo",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -70,7 +70,11 @@ class EventFactory:
|
|||||||
# 解析消息段
|
# 解析消息段
|
||||||
message_list = []
|
message_list = []
|
||||||
raw_message_list = data.get("message", [])
|
raw_message_list = data.get("message", [])
|
||||||
if isinstance(raw_message_list, list):
|
|
||||||
|
if isinstance(raw_message_list, str):
|
||||||
|
# 如果消息是字符串,将其视为纯文本消息段
|
||||||
|
message_list.append(MessageSegment.text(raw_message_list))
|
||||||
|
elif isinstance(raw_message_list, list):
|
||||||
for item in raw_message_list:
|
for item in raw_message_list:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
message_list.append(MessageSegment(type=item.get("type", ""), data=item.get("data", {})))
|
message_list.append(MessageSegment(type=item.get("type", ""), data=item.get("data", {})))
|
||||||
@@ -252,9 +256,18 @@ class EventFactory:
|
|||||||
card_new=data.get("card_new", ""),
|
card_new=data.get("card_new", ""),
|
||||||
card_old=data.get("card_old", "")
|
card_old=data.get("card_old", "")
|
||||||
)
|
)
|
||||||
|
elif notice_type == "group_card":
|
||||||
|
return GroupCardNoticeEvent(
|
||||||
|
**common_args,
|
||||||
|
notice_type=notice_type,
|
||||||
|
group_id=data.get("group_id", 0),
|
||||||
|
user_id=data.get("user_id", 0),
|
||||||
|
card_new=data.get("card_new", ""),
|
||||||
|
card_old=data.get("card_old", "")
|
||||||
|
)
|
||||||
elif notice_type == "offline_file":
|
elif notice_type == "offline_file":
|
||||||
file_data = data.get("file", {})
|
file_data = data.get("file", {})
|
||||||
file = OfflineFile(
|
offline_file = OfflineFile(
|
||||||
name=file_data.get("name", ""),
|
name=file_data.get("name", ""),
|
||||||
size=file_data.get("size", 0),
|
size=file_data.get("size", 0),
|
||||||
url=file_data.get("url", "")
|
url=file_data.get("url", "")
|
||||||
@@ -263,7 +276,7 @@ class EventFactory:
|
|||||||
**common_args,
|
**common_args,
|
||||||
notice_type=notice_type,
|
notice_type=notice_type,
|
||||||
user_id=data.get("user_id", 0),
|
user_id=data.get("user_id", 0),
|
||||||
file=file
|
file=offline_file
|
||||||
)
|
)
|
||||||
elif notice_type == "client_status":
|
elif notice_type == "client_status":
|
||||||
client_data = data.get("client", {})
|
client_data = data.get("client", {})
|
||||||
|
|||||||
@@ -4,9 +4,9 @@
|
|||||||
定义了消息相关的事件类,包括 MessageEvent, PrivateMessageEvent, GroupMessageEvent。
|
定义了消息相关的事件类,包括 MessageEvent, PrivateMessageEvent, GroupMessageEvent。
|
||||||
"""
|
"""
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from core.managers.permission_manager import ADMIN, OP, USER
|
from core.permission import Permission
|
||||||
from models.message import MessageSegment
|
from models.message import MessageSegment
|
||||||
from models.sender import Sender
|
from models.sender import Sender
|
||||||
from .base import OneBotEvent, EventType
|
from .base import OneBotEvent, EventType
|
||||||
@@ -34,9 +34,9 @@ class MessageEvent(OneBotEvent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 权限级别常量,用于装饰器参数
|
# 权限级别常量,用于装饰器参数
|
||||||
ADMIN = ADMIN
|
ADMIN = Permission.ADMIN
|
||||||
OP = OP
|
OP = Permission.OP
|
||||||
USER = USER
|
USER = Permission.USER
|
||||||
|
|
||||||
message_type: str
|
message_type: str
|
||||||
"""消息类型: private (私聊), group (群聊)"""
|
"""消息类型: private (私聊), group (群聊)"""
|
||||||
@@ -70,7 +70,7 @@ class MessageEvent(OneBotEvent):
|
|||||||
def post_type(self) -> str:
|
def post_type(self) -> str:
|
||||||
return EventType.MESSAGE
|
return EventType.MESSAGE
|
||||||
|
|
||||||
async def reply(self, message: str, auto_escape: bool = False):
|
async def reply(self, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False):
|
||||||
"""
|
"""
|
||||||
回复消息(抽象方法,由子类实现)
|
回复消息(抽象方法,由子类实现)
|
||||||
|
|
||||||
@@ -86,7 +86,7 @@ class PrivateMessageEvent(MessageEvent):
|
|||||||
私聊消息事件
|
私聊消息事件
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def reply(self, message: str, auto_escape: bool = False):
|
async def reply(self, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False):
|
||||||
"""
|
"""
|
||||||
回复私聊消息
|
回复私聊消息
|
||||||
|
|
||||||
@@ -110,7 +110,7 @@ class GroupMessageEvent(MessageEvent):
|
|||||||
anonymous: Optional[Anonymous] = None
|
anonymous: Optional[Anonymous] = None
|
||||||
"""匿名信息"""
|
"""匿名信息"""
|
||||||
|
|
||||||
async def reply(self, message: str, auto_escape: bool = False):
|
async def reply(self, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False):
|
||||||
"""
|
"""
|
||||||
回复群聊消息
|
回复群聊消息
|
||||||
|
|
||||||
|
|||||||
@@ -63,5 +63,5 @@ class LifeCycleEvent(MetaEvent):
|
|||||||
meta_event_type: str = 'lifecycle'
|
meta_event_type: str = 'lifecycle'
|
||||||
"""元事件类型:生命周期事件"""
|
"""元事件类型:生命周期事件"""
|
||||||
|
|
||||||
sub_type: LifeCycleSubType = LifeCycleSubType.ENABLE
|
sub_type: str = LifeCycleSubType.ENABLE
|
||||||
"""子类型:启用、禁用、连接"""
|
"""子类型:启用、禁用、连接"""
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Optional, List
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@@ -23,7 +23,7 @@ class MessageSegment:
|
|||||||
data: Dict[str, Any]
|
data: Dict[str, Any]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> str:
|
def plain_text(self) -> str:
|
||||||
"""
|
"""
|
||||||
当消息段类型为 'text' 时,快速获取其文本内容。
|
当消息段类型为 'text' 时,快速获取其文本内容。
|
||||||
|
|
||||||
@@ -32,6 +32,19 @@ class MessageSegment:
|
|||||||
"""
|
"""
|
||||||
return self.data.get("text", "") if self.type == "text" else ""
|
return self.data.get("text", "") if self.type == "text" else ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def text(text: str) -> "MessageSegment":
|
||||||
|
"""
|
||||||
|
创建一个文本消息段。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): 文本内容。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MessageSegment: 一个类型为 'text' 的消息段对象。
|
||||||
|
"""
|
||||||
|
return MessageSegment(type="text", data={"text": text})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def image_url(self) -> str:
|
def image_url(self) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -76,7 +89,7 @@ class MessageSegment:
|
|||||||
return self.data.get("file", "")
|
return self.data.get("file", "")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def is_at(self, user_id: int = None) -> bool:
|
def is_at(self, user_id: Optional[int] = None) -> bool:
|
||||||
"""
|
"""
|
||||||
检查当前消息段是否是一个 'at' (提及) 消息段。
|
检查当前消息段是否是一个 'at' (提及) 消息段。
|
||||||
|
|
||||||
@@ -93,16 +106,52 @@ class MessageSegment:
|
|||||||
return True
|
return True
|
||||||
return str(self.data.get("qq")) == str(user_id)
|
return str(self.data.get("qq")) == str(user_id)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""
|
||||||
|
返回消息段的 CQ 码字符串表示。
|
||||||
|
"""
|
||||||
|
if self.type == "text":
|
||||||
|
return self.data.get("text", "")
|
||||||
|
|
||||||
|
params = ",".join([f"{k}={v}" for k, v in self.data.items()])
|
||||||
|
if params:
|
||||||
|
return f"[CQ:{self.type},{params}]"
|
||||||
|
return f"[CQ:{self.type}]"
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""
|
"""
|
||||||
返回消息段对象的字符串表示形式,便于调试。
|
返回消息段对象的字符串表示形式,便于调试。
|
||||||
"""
|
"""
|
||||||
return f"[MS:{self.type}:{self.data}]"
|
return f"[MS:{self.type}:{self.data}]"
|
||||||
|
|
||||||
|
def __add__(self, other: Any) -> "List[MessageSegment]":
|
||||||
|
"""
|
||||||
|
支持消息段相加,返回消息段列表。
|
||||||
|
"""
|
||||||
|
if isinstance(other, MessageSegment):
|
||||||
|
return [self, other]
|
||||||
|
elif isinstance(other, str):
|
||||||
|
return [self, MessageSegment.text(other)]
|
||||||
|
elif isinstance(other, list):
|
||||||
|
return [self] + other
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __radd__(self, other: Any) -> "List[MessageSegment]":
|
||||||
|
"""
|
||||||
|
支持反向相加。
|
||||||
|
"""
|
||||||
|
if isinstance(other, MessageSegment):
|
||||||
|
return [other, self]
|
||||||
|
elif isinstance(other, str):
|
||||||
|
return [MessageSegment.text(other), self]
|
||||||
|
elif isinstance(other, list):
|
||||||
|
return other + [self]
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
# --- 快捷构造方法 ---
|
# --- 快捷构造方法 ---
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def text(text: str) -> "MessageSegment": # noqa: F811
|
def from_text(text: str) -> "MessageSegment":
|
||||||
"""
|
"""
|
||||||
创建一个文本消息段。
|
创建一个文本消息段。
|
||||||
|
|
||||||
@@ -115,7 +164,7 @@ class MessageSegment:
|
|||||||
return MessageSegment(type="text", data={"text": text})
|
return MessageSegment(type="text", data={"text": text})
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def at(user_id: int | str, name: str = None) -> "MessageSegment":
|
def at(user_id: int | str, name: Optional[str] = None) -> "MessageSegment":
|
||||||
"""
|
"""
|
||||||
创建一个 @某人 的消息段。
|
创建一个 @某人 的消息段。
|
||||||
|
|
||||||
@@ -132,7 +181,7 @@ class MessageSegment:
|
|||||||
return MessageSegment(type="at", data=data)
|
return MessageSegment(type="at", data=data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def image(file: str, image_type: str = None, cache: bool = True, proxy: bool = True, timeout: int = None, sub_type: int = None) -> "MessageSegment":
|
def image(file: str, image_type: Optional[str] = None, cache: bool = True, proxy: bool = True, timeout: Optional[int] = None, sub_type: Optional[int] = None) -> "MessageSegment":
|
||||||
"""
|
"""
|
||||||
创建一个图片消息段。
|
创建一个图片消息段。
|
||||||
|
|
||||||
@@ -194,7 +243,7 @@ class MessageSegment:
|
|||||||
"""
|
"""
|
||||||
return MessageSegment(type="xml", data={"data": data})
|
return MessageSegment(type="xml", data={"data": data})
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def share(url: str, title: str, content: str = None, image: str = None) -> "MessageSegment":
|
def share(url: str, title: str, content: Optional[str] = None, image: Optional[str] = None) -> "MessageSegment":
|
||||||
"""
|
"""
|
||||||
创建一个分享消息段。
|
创建一个分享消息段。
|
||||||
|
|
||||||
@@ -227,7 +276,7 @@ class MessageSegment:
|
|||||||
"""
|
"""
|
||||||
return MessageSegment(type="music", data={"type": type, "id": id})
|
return MessageSegment(type="music", data={"type": type, "id": id})
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def music_custom(url: str, audio: str, title: str, content: str = None, image: str = None) -> "MessageSegment":
|
def music_custom(url: str, audio: str, title: str, content: Optional[str] = None, image: Optional[str] = None) -> "MessageSegment":
|
||||||
"""
|
"""
|
||||||
创建一个自定义音乐消息段。
|
创建一个自定义音乐消息段。
|
||||||
|
|
||||||
@@ -248,7 +297,7 @@ class MessageSegment:
|
|||||||
data["image"] = image
|
data["image"] = image
|
||||||
return MessageSegment(type="music", data={"type": "custom", **data})
|
return MessageSegment(type="music", data={"type": "custom", **data})
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def record(file: str, magic: bool = False, cache: bool = True, proxy: bool = True, timeout: int = None) -> "MessageSegment":
|
def record(file: str, magic: bool = False, cache: bool = True, proxy: bool = True, timeout: Optional[int] = None) -> "MessageSegment":
|
||||||
"""
|
"""
|
||||||
创建一个语音消息段。
|
创建一个语音消息段。
|
||||||
|
|
||||||
@@ -267,7 +316,7 @@ class MessageSegment:
|
|||||||
data["timeout"] = str(timeout)
|
data["timeout"] = str(timeout)
|
||||||
return MessageSegment(type="record", data=data)
|
return MessageSegment(type="record", data=data)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def video(file: str, cover: str = None, c: int = 2) -> "MessageSegment":
|
def video(file: str, cover: Optional[str] = None, c: int = 2) -> "MessageSegment":
|
||||||
"""
|
"""
|
||||||
创建一个视频消息段。
|
创建一个视频消息段。
|
||||||
|
|
||||||
@@ -297,17 +346,17 @@ class MessageSegment:
|
|||||||
return MessageSegment(type="file", data={"file": file})
|
return MessageSegment(type="file", data={"file": file})
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def reply(message_id: str) -> "MessageSegment":
|
def reply(message_id: str | int) -> "MessageSegment":
|
||||||
"""
|
"""
|
||||||
创建一个回复消息段。
|
创建一个回复消息段。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_id (str): 被回复的消息 ID。
|
message_id (str | int): 被回复的消息 ID。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MessageSegment: 一个类型为 'reply' 的消息段对象。
|
MessageSegment: 一个类型为 'reply' 的消息段对象。
|
||||||
"""
|
"""
|
||||||
return MessageSegment(type="reply", data={"id": message_id})
|
return MessageSegment(type="reply", data={"id": str(message_id)})
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def rps() -> "MessageSegment":
|
def rps() -> "MessageSegment":
|
||||||
|
|||||||
0
plugins/__init__.py
Normal file
0
plugins/__init__.py
Normal file
130
plugins/admin.py
130
plugins/admin.py
@@ -1,74 +1,94 @@
|
|||||||
"""
|
from core.handlers.event_handler import MessageHandler
|
||||||
管理员管理插件
|
from core.managers import command_manager, permission_manager
|
||||||
|
from core.permission import Permission
|
||||||
提供通过聊天指令动态添加或移除机器人管理员的功能。
|
|
||||||
"""
|
|
||||||
from core.bot import Bot
|
|
||||||
from core.managers.command_manager import matcher
|
|
||||||
from core.managers.admin_manager import admin_manager
|
|
||||||
from models.events.message import MessageEvent
|
from models.events.message import MessageEvent
|
||||||
|
|
||||||
|
# 更新插件元信息以包含OP管理
|
||||||
__plugin_meta__ = {
|
__plugin_meta__ = {
|
||||||
"name": "管理员管理",
|
"name": "权限管理",
|
||||||
"description": "管理机器人的全局管理员",
|
"description": "管理机器人的管理员和操作员",
|
||||||
"usage": (
|
"usage": (
|
||||||
"/admin list - 列出所有管理员\n"
|
"/admin list - 列出所有管理员和操作员\n"
|
||||||
"/admin add <QQ号> - 添加管理员\n"
|
"/admin add_admin <QQ号> - 添加管理员\n"
|
||||||
"/admin remove <QQ号> - 移除管理员"
|
"/admin remove_admin <QQ号> - 移除管理员\n"
|
||||||
|
"/admin add_op <QQ号> - 添加操作员\n"
|
||||||
|
"/admin remove_op <QQ号> - 移除操作员"
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@matcher.command("admin", permission=MessageEvent.ADMIN)
|
@command_manager.command("admin", permission=Permission.ADMIN)
|
||||||
async def admin_command_handler(bot: Bot, event: MessageEvent, args: list[str]):
|
async def admin_management(event: MessageEvent, args: str):
|
||||||
"""
|
"""
|
||||||
处理 /admin 指令
|
处理所有权限管理相关的命令。
|
||||||
|
|
||||||
:param bot: Bot 实例
|
|
||||||
:param event: 消息事件实例
|
|
||||||
:param args: 指令参数列表
|
|
||||||
"""
|
"""
|
||||||
if not args:
|
parts = args.split()
|
||||||
await event.reply(__plugin_meta__["usage"])
|
if not parts:
|
||||||
|
await event.reply(f"用法不正确。\n\n{__plugin_meta__['usage']}")
|
||||||
return
|
return
|
||||||
|
|
||||||
action = args[0].lower()
|
subcommand = parts[0].lower()
|
||||||
|
|
||||||
if action == "list":
|
if subcommand == "list":
|
||||||
admins = await admin_manager.get_all_admins()
|
await list_permissions(event)
|
||||||
if not admins:
|
|
||||||
await event.reply("当前没有设置任何管理员。")
|
|
||||||
return
|
|
||||||
|
|
||||||
admin_list_text = "\n".join(str(admin_id) for admin_id in admins)
|
|
||||||
await event.reply(f"当前管理员列表 ({len(admins)}):\n{admin_list_text}")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if action in ("add", "remove"):
|
# 处理需要QQ号的命令
|
||||||
if len(args) < 2 or not args[1].isdigit():
|
if len(parts) < 2 or not parts[1].isdigit():
|
||||||
await event.reply("参数错误,请提供一个有效的 QQ 号。\n示例: /admin add 123456")
|
await event.reply(f"请提供有效的用户QQ号。\n用法: /admin {subcommand} <QQ号>")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = int(args[1])
|
target_user_id = int(parts[1])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
await event.reply("无效的 QQ 号,请输入纯数字。")
|
await event.reply("无效的QQ号。")
|
||||||
return
|
return
|
||||||
|
|
||||||
if action == "add":
|
# 安全检查
|
||||||
success = await admin_manager.add_admin(user_id)
|
if target_user_id == event.user_id:
|
||||||
if success:
|
await event.reply("你不能操作自己的权限。")
|
||||||
await event.reply(f"成功添加管理员: {user_id}")
|
return
|
||||||
else:
|
if target_user_id == event.self_id:
|
||||||
await event.reply(f"管理员 {user_id} 已存在,无需重复添加。")
|
await event.reply("你不能操作机器人自身的权限。")
|
||||||
return
|
return
|
||||||
|
|
||||||
elif action == "remove":
|
# 根据子命令分发
|
||||||
success = await admin_manager.remove_admin(user_id)
|
if subcommand == "add_admin":
|
||||||
if success:
|
permission_manager.set_user_permission(target_user_id, Permission.ADMIN)
|
||||||
await event.reply(f"成功移除管理员: {user_id}")
|
await event.reply(f"已成功添加管理员:{target_user_id}")
|
||||||
else:
|
elif subcommand == "remove_admin":
|
||||||
await event.reply(f"管理员 {user_id} 不存在。")
|
permission_manager.set_user_permission(target_user_id, Permission.USER)
|
||||||
return
|
await event.reply(f"已成功移除管理员:{target_user_id}")
|
||||||
|
elif subcommand == "add_op":
|
||||||
|
permission_manager.set_user_permission(target_user_id, Permission.OP)
|
||||||
|
await event.reply(f"已成功添加操作员:{target_user_id}")
|
||||||
|
elif subcommand == "remove_op":
|
||||||
|
permission_manager.set_user_permission(target_user_id, Permission.USER)
|
||||||
|
await event.reply(f"已成功移除操作员:{target_user_id}")
|
||||||
|
else:
|
||||||
|
await event.reply(f"未知的子命令 '{subcommand}'。\n\n{__plugin_meta__['usage']}")
|
||||||
|
|
||||||
await event.reply(f"未知的指令: {action}\n\n{__plugin_meta__['usage']}")
|
|
||||||
|
async def list_permissions(event: MessageEvent):
|
||||||
|
"""
|
||||||
|
列出所有具有特殊权限(管理员和操作员)的用户。
|
||||||
|
"""
|
||||||
|
permissions = permission_manager.get_all_user_permissions()
|
||||||
|
if not permissions:
|
||||||
|
await event.reply("当前没有配置任何特殊权限的用户。")
|
||||||
|
return
|
||||||
|
|
||||||
|
admins = {uid for uid, p in permissions.items() if p == 'admin'}
|
||||||
|
ops = {uid for uid, p in permissions.items() if p == 'op'}
|
||||||
|
|
||||||
|
reply_msg = "当前权限列表:\n"
|
||||||
|
if admins:
|
||||||
|
reply_msg += "--- 管理员 ---\n"
|
||||||
|
for user_id in admins:
|
||||||
|
reply_msg += f"- {user_id}\n"
|
||||||
|
if ops:
|
||||||
|
reply_msg += "--- 操作员 ---\n"
|
||||||
|
for user_id in ops:
|
||||||
|
reply_msg += f"- {user_id}\n"
|
||||||
|
|
||||||
|
await event.reply(reply_msg.strip())
|
||||||
|
|||||||
@@ -3,12 +3,16 @@ import re
|
|||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any, Union
|
||||||
|
from cachetools import TTLCache
|
||||||
|
|
||||||
from core.utils.logger import logger
|
from core.utils.logger import logger
|
||||||
from core.managers.command_manager import matcher
|
from core.managers.command_manager import matcher
|
||||||
from models import MessageEvent, MessageSegment
|
from models import MessageEvent, MessageSegment
|
||||||
|
|
||||||
|
# 创建一个TTL缓存,最大容量100,缓存时间10秒
|
||||||
|
processed_messages: TTLCache[int, bool] = TTLCache(maxsize=100, ttl=10)
|
||||||
|
|
||||||
__plugin_meta__ = {
|
__plugin_meta__ = {
|
||||||
"name": "bili_parser",
|
"name": "bili_parser",
|
||||||
"description": "自动解析B站分享卡片,提取视频封面和播放量等信息。",
|
"description": "自动解析B站分享卡片,提取视频封面和播放量等信息。",
|
||||||
@@ -52,10 +56,14 @@ def parse_video_info(video_url: str) -> Optional[Dict[str, Any]]:
|
|||||||
soup = BeautifulSoup(response.text, 'html.parser')
|
soup = BeautifulSoup(response.text, 'html.parser')
|
||||||
|
|
||||||
script_tag = soup.find('script', text=re.compile('window.__INITIAL_STATE__'))
|
script_tag = soup.find('script', text=re.compile('window.__INITIAL_STATE__'))
|
||||||
if not script_tag:
|
if not script_tag or not script_tag.string:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
json_str = re.search(r'window\.__INITIAL_STATE__\s*=\s*(\{.*?\});', script_tag.string).group(1)
|
match = re.search(r'window\.__INITIAL_STATE__\s*=\s*(\{.*?\});', script_tag.string)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
|
||||||
|
json_str = match.group(1)
|
||||||
data = json.loads(json_str)
|
data = json.loads(json_str)
|
||||||
|
|
||||||
video_data = data.get('videoData', {})
|
video_data = data.get('videoData', {})
|
||||||
@@ -121,6 +129,15 @@ async def handle_bili_share(event: MessageEvent):
|
|||||||
处理消息,检测B站分享链接(JSON卡片或文本链接)并进行解析。
|
处理消息,检测B站分享链接(JSON卡片或文本链接)并进行解析。
|
||||||
:param event: 消息事件对象
|
:param event: 消息事件对象
|
||||||
"""
|
"""
|
||||||
|
# 消息去重
|
||||||
|
if event.message_id in processed_messages:
|
||||||
|
return
|
||||||
|
processed_messages[event.message_id] = True
|
||||||
|
|
||||||
|
# 忽略机器人自己发送的消息,防止无限循环
|
||||||
|
if event.user_id == event.self_id:
|
||||||
|
return
|
||||||
|
|
||||||
url_to_process = None
|
url_to_process = None
|
||||||
|
|
||||||
# 1. 优先解析JSON卡片中的短链接
|
# 1. 优先解析JSON卡片中的短链接
|
||||||
@@ -176,6 +193,7 @@ async def process_bili_link(event: MessageEvent, url: str):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 检查视频时长
|
# 检查视频时长
|
||||||
|
video_message: Union[str, MessageSegment]
|
||||||
if video_info['duration'] > 300: # 5分钟 = 300秒
|
if video_info['duration'] > 300: # 5分钟 = 300秒
|
||||||
video_message = "视频时长超过5分钟,不进行解析。"
|
video_message = "视频时长超过5分钟,不进行解析。"
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -8,8 +8,8 @@
|
|||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
from core.managers.command_manager import matcher
|
from core.managers.command_manager import matcher
|
||||||
from models import MessageEvent, PrivateMessageEvent
|
from models.events.message import MessageEvent, PrivateMessageEvent
|
||||||
from core.managers.permission_manager import ADMIN
|
from core.permission import Permission
|
||||||
from core.utils.logger import logger
|
from core.utils.logger import logger
|
||||||
|
|
||||||
# --- 会话状态管理 ---
|
# --- 会话状态管理 ---
|
||||||
@@ -24,7 +24,7 @@ def cleanup_session(user_id: int):
|
|||||||
del broadcast_sessions[user_id]
|
del broadcast_sessions[user_id]
|
||||||
logger.info(f"[Broadcast] 会话 {user_id} 已超时,自动取消。")
|
logger.info(f"[Broadcast] 会话 {user_id} 已超时,自动取消。")
|
||||||
|
|
||||||
@matcher.command("broadcast", "广播", permission=ADMIN)
|
@matcher.command("broadcast", "广播", permission=Permission.ADMIN)
|
||||||
async def broadcast_start(event: MessageEvent):
|
async def broadcast_start(event: MessageEvent):
|
||||||
"""
|
"""
|
||||||
广播指令的入口,启动一个等待用户消息的会话。
|
广播指令的入口,启动一个等待用户消息的会话。
|
||||||
@@ -92,7 +92,7 @@ async def handle_broadcast_content(event: MessageEvent):
|
|||||||
nodes_to_send = [
|
nodes_to_send = [
|
||||||
bot.build_forward_node(
|
bot.build_forward_node(
|
||||||
user_id=event.user_id,
|
user_id=event.user_id,
|
||||||
nickname=event.sender.nickname,
|
nickname=event.sender.nickname if event.sender else "未知用户",
|
||||||
message=message_to_broadcast
|
message=message_to_broadcast
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,35 +1,24 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import html
|
import html
|
||||||
import textwrap
|
import textwrap
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
import html
|
|
||||||
import textwrap
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from core.managers.command_manager import matcher
|
from core.managers.command_manager import matcher
|
||||||
from models import MessageEvent
|
from models.events.message import MessageEvent
|
||||||
from core.managers.permission_manager import ADMIN
|
from core.permission import Permission
|
||||||
from core.utils.logger import logger
|
from core.utils.logger import logger
|
||||||
|
|
||||||
__plugin_meta__ = {
|
__plugin_meta__ = {
|
||||||
"name": "Python 代码执行",
|
"name": "Python 代码执行",
|
||||||
"description": "在安全的沙箱环境中执行 Python 代码片段,支持单行、多行和转发回复。",
|
"description": "在安全的沙箱环境中执行 Python 代码片段,支持单行、多行和转发回复。",
|
||||||
"usage": "/py <单行代码>\n/code_py <单行代码>\n/py (进入多行输入模式)",
|
"usage": "/py <单行代码>\n/code_py <单行代码>\n/py (进入多行输入模式)",
|
||||||
"name": "Python 代码执行",
|
|
||||||
"description": "在安全的沙箱环境中执行 Python 代码片段,支持单行、多行和转发回复。",
|
|
||||||
"usage": "/py <单行代码>\n/code_py <单行代码>\n/py (进入多行输入模式)",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# --- 会话状态管理 ---
|
# --- 会话状态管理 ---
|
||||||
# 结构: {(user_id, group_id): asyncio.TimerHandle}
|
# 结构: {(user_id, group_id): asyncio.TimerHandle}
|
||||||
multi_line_sessions: Dict[tuple, asyncio.TimerHandle] = {}
|
multi_line_sessions: Dict[tuple, asyncio.TimerHandle] = {}
|
||||||
|
|
||||||
async def reply_as_forward(event: MessageEvent, input_code: str, output_result: str):
|
|
||||||
# --- 会话状态管理 ---
|
|
||||||
# 结构: {(user_id, group_id): asyncio.TimerHandle}
|
|
||||||
multi_line_sessions: Dict[tuple, asyncio.TimerHandle] = {}
|
|
||||||
|
|
||||||
async def reply_as_forward(event: MessageEvent, input_code: str, output_result: str):
|
async def reply_as_forward(event: MessageEvent, input_code: str, output_result: str):
|
||||||
"""
|
"""
|
||||||
将输入和输出打包成转发消息进行回复。
|
将输入和输出打包成转发消息进行回复。
|
||||||
@@ -41,35 +30,7 @@ async def reply_as_forward(event: MessageEvent, input_code: str, output_result:
|
|||||||
nodes = [
|
nodes = [
|
||||||
bot.build_forward_node(
|
bot.build_forward_node(
|
||||||
user_id=event.user_id,
|
user_id=event.user_id,
|
||||||
nickname=event.sender.nickname or str(event.user_id),
|
nickname=event.sender.nickname if event.sender else str(event.user_id),
|
||||||
message=f"--- Your Code ---\n{input_code}"
|
|
||||||
),
|
|
||||||
bot.build_forward_node(
|
|
||||||
user_id=event.self_id,
|
|
||||||
nickname="Code Executor",
|
|
||||||
message=f"--- Execution Result ---\n{output_result}"
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 2. 发送合并转发消息
|
|
||||||
await bot.send_forwarded_messages(event, nodes)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[code_py] 发送转发消息失败: {e}")
|
|
||||||
# 降级为普通消息回复
|
|
||||||
await event.reply(f"--- 你的代码 ---\n{input_code}\n--- 执行结果 ---\n{output_result}")
|
|
||||||
|
|
||||||
async def execute_code(event: MessageEvent, code: str):
|
|
||||||
将输入和输出打包成转发消息进行回复。
|
|
||||||
参考 forward_test.py 的实现,兼容私聊和群聊。
|
|
||||||
"""
|
|
||||||
bot = event.bot
|
|
||||||
|
|
||||||
# 1. 构建消息节点列表
|
|
||||||
nodes = [
|
|
||||||
bot.build_forward_node(
|
|
||||||
user_id=event.user_id,
|
|
||||||
nickname=event.sender.nickname or str(event.user_id),
|
|
||||||
message=f"--- Your Code ---\n{input_code}"
|
message=f"--- Your Code ---\n{input_code}"
|
||||||
),
|
),
|
||||||
bot.build_forward_node(
|
bot.build_forward_node(
|
||||||
@@ -90,7 +51,6 @@ async def execute_code(event: MessageEvent, code: str):
|
|||||||
async def execute_code(event: MessageEvent, code: str):
|
async def execute_code(event: MessageEvent, code: str):
|
||||||
"""
|
"""
|
||||||
核心代码执行逻辑。
|
核心代码执行逻辑。
|
||||||
核心代码执行逻辑。
|
|
||||||
"""
|
"""
|
||||||
code_executor = getattr(event.bot, 'code_executor', None)
|
code_executor = getattr(event.bot, 'code_executor', None)
|
||||||
if not code_executor or not code_executor.docker_client:
|
if not code_executor or not code_executor.docker_client:
|
||||||
@@ -137,74 +97,15 @@ def normalize_code(code: str) -> str:
|
|||||||
return code.strip()
|
return code.strip()
|
||||||
|
|
||||||
|
|
||||||
@matcher.command("py", "python", "code_py", permission=ADMIN)
|
@matcher.command("py", "python", "code_py", permission=Permission.ADMIN)
|
||||||
async def code_py_main(event: MessageEvent, args: list[str]):
|
|
||||||
code_executor = getattr(event.bot, 'code_executor', None)
|
|
||||||
if not code_executor or not code_executor.docker_client:
|
|
||||||
await event.reply("代码执行服务当前不可用,请检查 Docker 连接配置。")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 修改 add_task,让它能直接接收回复函数
|
|
||||||
await code_executor.add_task(
|
|
||||||
code,
|
|
||||||
lambda result: reply_as_forward(event, code, result)
|
|
||||||
)
|
|
||||||
await event.reply("代码已提交至沙箱执行队列,请稍候...")
|
|
||||||
|
|
||||||
def cleanup_session(session_key: tuple):
|
|
||||||
"""
|
|
||||||
清理超时的会话。
|
|
||||||
"""
|
|
||||||
if session_key in multi_line_sessions:
|
|
||||||
del multi_line_sessions[session_key]
|
|
||||||
logger.info(f"[code_py] 会话 {session_key} 已超时,自动取消。")
|
|
||||||
|
|
||||||
def normalize_code(code: str) -> str:
|
|
||||||
"""
|
|
||||||
规范化用户输入的 Python 代码字符串。
|
|
||||||
|
|
||||||
主要处理两个问题:
|
|
||||||
1. 对消息中可能存在的 HTML 实体进行解码 (e.g., [ -> [)。
|
|
||||||
2. 移除整个代码块的公共前导缩进,以修复因复制粘贴导致的多余缩进。
|
|
||||||
|
|
||||||
:param code: 原始代码字符串。
|
|
||||||
:return: 规范化后的代码字符串。
|
|
||||||
"""
|
|
||||||
# 1. 解码 HTML 实体
|
|
||||||
code = html.unescape(code)
|
|
||||||
|
|
||||||
# 2. 移除公共前导缩进
|
|
||||||
try:
|
|
||||||
code = textwrap.dedent(code)
|
|
||||||
except Exception:
|
|
||||||
# 在某些情况下(例如,不一致的缩进),dedent 可能会失败,
|
|
||||||
# 但我们不希望因此中断流程,所以捕获异常并继续。
|
|
||||||
pass
|
|
||||||
|
|
||||||
return code.strip()
|
|
||||||
|
|
||||||
|
|
||||||
@matcher.command("py", "python", "code_py", permission=ADMIN)
|
|
||||||
async def code_py_main(event: MessageEvent, args: list[str]):
|
async def code_py_main(event: MessageEvent, args: list[str]):
|
||||||
"""
|
"""
|
||||||
/py 命令的主入口。
|
/py 命令的主入口。
|
||||||
- 如果有参数,直接执行。
|
- 如果有参数,直接执行。
|
||||||
- 如果没有参数,开启多行输入模式。
|
- 如果没有参数,开启多行输入模式。
|
||||||
/py 命令的主入口。
|
|
||||||
- 如果有参数,直接执行。
|
|
||||||
- 如果没有参数,开启多行输入模式。
|
|
||||||
"""
|
"""
|
||||||
code_to_run = " ".join(args)
|
code_to_run = " ".join(args)
|
||||||
|
|
||||||
if code_to_run:
|
|
||||||
# 单行模式,对代码进行规范化处理
|
|
||||||
normalized_code = normalize_code(code_to_run)
|
|
||||||
if not normalized_code:
|
|
||||||
await event.reply("代码为空或格式错误,请输入有效的代码。")
|
|
||||||
return
|
|
||||||
await execute_code(event, normalized_code)
|
|
||||||
code_to_run = " ".join(args)
|
|
||||||
|
|
||||||
if code_to_run:
|
if code_to_run:
|
||||||
# 单行模式,对代码进行规范化处理
|
# 单行模式,对代码进行规范化处理
|
||||||
normalized_code = normalize_code(code_to_run)
|
normalized_code = normalize_code(code_to_run)
|
||||||
@@ -231,24 +132,6 @@ async def code_py_main(event: MessageEvent, args: list[str]):
|
|||||||
session_key
|
session_key
|
||||||
)
|
)
|
||||||
multi_line_sessions[session_key] = timeout_handler
|
multi_line_sessions[session_key] = timeout_handler
|
||||||
# 多行模式
|
|
||||||
# 使用 getattr 兼容私聊和群聊
|
|
||||||
session_key = (event.user_id, getattr(event, 'group_id', 'private'))
|
|
||||||
|
|
||||||
# 如果上一个会话的超时任务还在,先取消它
|
|
||||||
if session_key in multi_line_sessions:
|
|
||||||
multi_line_sessions[session_key].cancel()
|
|
||||||
|
|
||||||
await event.reply("已进入多行代码输入模式,请直接发送你的代码。\n(60秒内无操作将自动取消)")
|
|
||||||
|
|
||||||
# 设置 60 秒超时
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
timeout_handler = loop.call_later(
|
|
||||||
60,
|
|
||||||
cleanup_session,
|
|
||||||
session_key
|
|
||||||
)
|
|
||||||
multi_line_sessions[session_key] = timeout_handler
|
|
||||||
|
|
||||||
@matcher.on_message()
|
@matcher.on_message()
|
||||||
async def handle_multi_line_code(event: MessageEvent):
|
async def handle_multi_line_code(event: MessageEvent):
|
||||||
@@ -265,26 +148,6 @@ async def handle_multi_line_code(event: MessageEvent):
|
|||||||
# 对多行代码进行规范化处理
|
# 对多行代码进行规范化处理
|
||||||
normalized_code = normalize_code(event.raw_message)
|
normalized_code = normalize_code(event.raw_message)
|
||||||
|
|
||||||
if not normalized_code:
|
|
||||||
await event.reply("捕获到的代码为空或格式错误,已取消输入。")
|
|
||||||
return
|
|
||||||
|
|
||||||
await execute_code(event, normalized_code)
|
|
||||||
return True # 消费事件,防止其他处理器响应
|
|
||||||
async def handle_multi_line_code(event: MessageEvent):
|
|
||||||
"""
|
|
||||||
通用消息处理器,用于捕获多行模式下的代码输入。
|
|
||||||
"""
|
|
||||||
# 使用 getattr 兼容私聊和群聊
|
|
||||||
session_key = (event.user_id, getattr(event, 'group_id', 'private'))
|
|
||||||
if session_key in multi_line_sessions:
|
|
||||||
# 取消超时任务
|
|
||||||
multi_line_sessions[session_key].cancel()
|
|
||||||
del multi_line_sessions[session_key]
|
|
||||||
|
|
||||||
# 对多行代码进行规范化处理
|
|
||||||
normalized_code = normalize_code(event.raw_message)
|
|
||||||
|
|
||||||
if not normalized_code:
|
if not normalized_code:
|
||||||
await event.reply("捕获到的代码为空或格式错误,已取消输入。")
|
await event.reply("捕获到的代码为空或格式错误,已取消输入。")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ Echo 与交互插件
|
|||||||
"""
|
"""
|
||||||
from core.managers.command_manager import matcher
|
from core.managers.command_manager import matcher
|
||||||
from core.bot import Bot
|
from core.bot import Bot
|
||||||
from models import MessageEvent
|
from models.events.message import MessageEvent
|
||||||
|
|
||||||
__plugin_meta__ = {
|
__plugin_meta__ = {
|
||||||
"name": "echo",
|
"name": "echo",
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
"""
|
"""
|
||||||
from core.managers.command_manager import matcher
|
from core.managers.command_manager import matcher
|
||||||
from core.bot import Bot
|
from core.bot import Bot
|
||||||
from models import MessageEvent
|
from models.events.message import MessageEvent
|
||||||
from models.message import MessageSegment
|
from models.message import MessageSegment
|
||||||
|
|
||||||
__plugin_meta__ = {
|
__plugin_meta__ = {
|
||||||
@@ -22,14 +22,15 @@ async def handle_forward_test(bot: Bot, event: MessageEvent, args: list[str]):
|
|||||||
:param args: 指令参数
|
:param args: 指令参数
|
||||||
"""
|
"""
|
||||||
# 1. 构建消息节点列表
|
# 1. 构建消息节点列表
|
||||||
|
nickname = event.sender.nickname if event.sender else "未知用户"
|
||||||
nodes = [
|
nodes = [
|
||||||
bot.build_forward_node(user_id=event.self_id, nickname="机器人", message="你要的furry来了"),
|
bot.build_forward_node(user_id=event.self_id, nickname="机器人", message="你要的furry来了"),
|
||||||
bot.build_forward_node(user_id=event.user_id, nickname=event.sender.nickname, message="让我看看"),
|
bot.build_forward_node(user_id=event.user_id, nickname=nickname, message="让我看看"),
|
||||||
bot.build_forward_node(
|
bot.build_forward_node(
|
||||||
user_id=event.self_id,
|
user_id=event.self_id,
|
||||||
nickname="机器人",
|
nickname="机器人",
|
||||||
message=[
|
message=[
|
||||||
MessageSegment.text("你要的福瑞图"),
|
MessageSegment.from_text("你要的福瑞图"),
|
||||||
MessageSegment.image("https://api.furry.ist/furry-img/")
|
MessageSegment.image("https://api.furry.ist/furry-img/")
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from datetime import datetime
|
|||||||
from core.bot import Bot
|
from core.bot import Bot
|
||||||
from core.managers.command_manager import matcher
|
from core.managers.command_manager import matcher
|
||||||
from core.utils.executor import run_in_thread_pool
|
from core.utils.executor import run_in_thread_pool
|
||||||
from models import MessageEvent, MessageSegment
|
from models.events.message import MessageEvent, MessageSegment
|
||||||
|
|
||||||
__plugin_meta__ = {
|
__plugin_meta__ = {
|
||||||
"name": "jrcd",
|
"name": "jrcd",
|
||||||
@@ -79,14 +79,17 @@ async def handle_jrcd(bot: Bot, event: MessageEvent, args: list[str]):
|
|||||||
"""
|
"""
|
||||||
user_id = event.user_id
|
user_id = event.user_id
|
||||||
jrcd = await run_in_thread_pool(get_jrcd, user_id)
|
jrcd = await run_in_thread_pool(get_jrcd, user_id)
|
||||||
msg = [MessageSegment.at(user_id)]
|
|
||||||
|
msg_text = ""
|
||||||
if jrcd <= 9:
|
if jrcd <= 9:
|
||||||
msg.append(MessageSegment.text(random.choice(JRCDMSG_1) % jrcd))
|
msg_text = random.choice(JRCDMSG_1) % jrcd
|
||||||
elif jrcd <= 19:
|
elif jrcd <= 19:
|
||||||
msg.append(MessageSegment.text(random.choice(JRCDMSG_2) % jrcd))
|
msg_text = random.choice(JRCDMSG_2) % jrcd
|
||||||
else:
|
else:
|
||||||
msg.append(MessageSegment.text(random.choice(JRCDMSG_3) % jrcd))
|
msg_text = random.choice(JRCDMSG_3) % jrcd
|
||||||
await event.reply(msg)
|
|
||||||
|
reply_segments = [MessageSegment.at(user_id), MessageSegment.from_text(msg_text)]
|
||||||
|
await event.reply(reply_segments)
|
||||||
|
|
||||||
|
|
||||||
@matcher.command("bbcd")
|
@matcher.command("bbcd")
|
||||||
@@ -118,29 +121,31 @@ async def handle_bbcd(bot: Bot, event: MessageEvent, args: list[str]):
|
|||||||
|
|
||||||
jrcz = jrcd1 - jrcd2
|
jrcz = jrcd1 - jrcd2
|
||||||
|
|
||||||
msg = [
|
text_part = ""
|
||||||
|
if jrcz == 0:
|
||||||
|
text_part = f" 一样长。{random.choice(BBCDMSG7)}"
|
||||||
|
elif jrcz > 0:
|
||||||
|
text_part = f" 长{jrcz}cm。"
|
||||||
|
if jrcz <= 9:
|
||||||
|
text_part += random.choice(BBCDMSG1)
|
||||||
|
elif jrcz <= 19:
|
||||||
|
text_part += random.choice(BBCDMSG2)
|
||||||
|
else:
|
||||||
|
text_part += random.choice(BBCDMSG3)
|
||||||
|
else: # jrcz < 0
|
||||||
|
text_part = f" 短{abs(jrcz)}cm。"
|
||||||
|
if jrcz >= -9:
|
||||||
|
text_part += random.choice(BBCDMSG4)
|
||||||
|
elif jrcz >= -19:
|
||||||
|
text_part += random.choice(BBCDMSG5)
|
||||||
|
else:
|
||||||
|
text_part += random.choice(BBCDMSG6)
|
||||||
|
|
||||||
|
segments = [
|
||||||
MessageSegment.at(user_id1),
|
MessageSegment.at(user_id1),
|
||||||
MessageSegment.text("你的长度比"),
|
MessageSegment.from_text(" 你的长度比 "),
|
||||||
MessageSegment.at(user_id2),
|
MessageSegment.at(user_id2),
|
||||||
|
MessageSegment.from_text(text_part),
|
||||||
]
|
]
|
||||||
|
|
||||||
if jrcz == 0:
|
await event.reply(segments)
|
||||||
msg.append(MessageSegment.text("一样长。"))
|
|
||||||
msg.append(MessageSegment.text(random.choice(BBCDMSG7)))
|
|
||||||
elif jrcz > 0:
|
|
||||||
msg.append(MessageSegment.text("长" + str(jrcz) + "cm。"))
|
|
||||||
if jrcz <= 9:
|
|
||||||
msg.append(MessageSegment.text(random.choice(BBCDMSG1)))
|
|
||||||
elif jrcz <= 19:
|
|
||||||
msg.append(MessageSegment.text(random.choice(BBCDMSG2)))
|
|
||||||
else:
|
|
||||||
msg.append(MessageSegment.text(random.choice(BBCDMSG3)))
|
|
||||||
elif jrcz < 0:
|
|
||||||
msg.append(MessageSegment.text("短" + str(abs(jrcz)) + "cm。"))
|
|
||||||
if jrcz >= -9:
|
|
||||||
msg.append(MessageSegment.text(random.choice(BBCDMSG4)))
|
|
||||||
elif jrcz >= -19:
|
|
||||||
msg.append(MessageSegment.text(random.choice(BBCDMSG5)))
|
|
||||||
else:
|
|
||||||
msg.append(MessageSegment.text(random.choice(BBCDMSG6)))
|
|
||||||
await event.reply(msg)
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ thpic 插件
|
|||||||
|
|
||||||
from core.bot import Bot
|
from core.bot import Bot
|
||||||
from core.managers.command_manager import matcher
|
from core.managers.command_manager import matcher
|
||||||
from models import MessageEvent, MessageSegment
|
from models.events.message import MessageEvent, MessageSegment
|
||||||
|
|
||||||
__plugin_meta__ = {
|
__plugin_meta__ = {
|
||||||
"name": "thpic",
|
"name": "thpic",
|
||||||
@@ -26,6 +26,6 @@ async def handle_echo(bot: Bot, event: MessageEvent, args: list[str]):
|
|||||||
:param args: 指令参数列表(未使用)。
|
:param args: 指令参数列表(未使用)。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
await event.reply(MessageSegment.image("https://img.paulzzh.com/touhou/random"))
|
await event.reply(str(MessageSegment.image("https://img.paulzzh.com/touhou/random")))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await event.reply("报错了。。。" + e)
|
await event.reply(f"报错了。。。{e}")
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ pipreqs==0.4.13
|
|||||||
redis==5.0.7
|
redis==5.0.7
|
||||||
requests==2.32.5
|
requests==2.32.5
|
||||||
soupsieve==2.8.1
|
soupsieve==2.8.1
|
||||||
toml==0.10.2
|
|
||||||
typing==3.7.4.3
|
typing==3.7.4.3
|
||||||
typing_extensions==4.15.0
|
typing_extensions==4.15.0
|
||||||
urllib3==2.6.2
|
urllib3==2.6.2
|
||||||
@@ -19,7 +18,15 @@ watchdog==6.0.0
|
|||||||
websockets==15.0.1
|
websockets==15.0.1
|
||||||
win32_setctime==1.2.0
|
win32_setctime==1.2.0
|
||||||
yarg==0.1.10
|
yarg==0.1.10
|
||||||
|
cachetools
|
||||||
|
pydantic
|
||||||
docker
|
docker
|
||||||
pytest
|
pytest
|
||||||
pytest-asyncio
|
pytest-asyncio
|
||||||
pytest-mock
|
pytest-mock
|
||||||
|
pytest-cov
|
||||||
|
httpx==0.27.0
|
||||||
|
|
||||||
|
# Dev Dependencies
|
||||||
|
mypy
|
||||||
|
pydantic[mypy]
|
||||||
|
|||||||
37
tests/test_basic.py
Normal file
37
tests/test_basic.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 确保项目根目录在 sys.path 中
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
def test_import_core():
|
||||||
|
"""测试核心模块是否可以被导入"""
|
||||||
|
try:
|
||||||
|
import core
|
||||||
|
import core.bot
|
||||||
|
import core.ws
|
||||||
|
except ImportError as e:
|
||||||
|
pytest.fail(f"无法导入核心模块: {e}")
|
||||||
|
|
||||||
|
def test_plugin_manager_path():
|
||||||
|
"""测试插件管理器路径逻辑是否正确"""
|
||||||
|
from core.managers.plugin_manager import PluginManager
|
||||||
|
# Mock command manager
|
||||||
|
pm = PluginManager(None)
|
||||||
|
|
||||||
|
# 我们无法直接测试 load_all_plugins 的内部路径变量,
|
||||||
|
# 但我们可以检查它是否能找到 plugins 目录而不报错
|
||||||
|
# 这里我们简单地断言 PluginManager 类存在且可以实例化
|
||||||
|
assert pm is not None
|
||||||
|
|
||||||
|
def test_config_loader_exists():
|
||||||
|
"""测试配置加载器是否存在"""
|
||||||
|
# 注意:导入 config_loader 会尝试读取 config.toml
|
||||||
|
# 如果 config.toml 不存在,这可能会失败。
|
||||||
|
# 这是一个已知的设计问题,但在测试环境中我们假设 config.toml 存在或被 mock
|
||||||
|
if os.path.exists("config.toml"):
|
||||||
|
from core.config_loader import global_config
|
||||||
|
assert global_config is not None
|
||||||
|
else:
|
||||||
|
pytest.skip("config.toml 不存在,跳过配置加载测试")
|
||||||
114
tests/test_command_manager.py
Normal file
114
tests/test_command_manager.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from core.managers.command_manager import CommandManager
|
||||||
|
from models.events.message import GroupMessageEvent
|
||||||
|
from models.message import MessageSegment
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_bot():
|
||||||
|
bot = AsyncMock()
|
||||||
|
bot.self_id = 123456
|
||||||
|
return bot
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def command_manager():
|
||||||
|
# 创建一个新的 CommandManager 实例用于测试,避免单例状态污染
|
||||||
|
return CommandManager(prefixes=("/",))
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_command_registration_and_execution(command_manager, mock_bot):
|
||||||
|
"""测试命令注册和执行"""
|
||||||
|
|
||||||
|
# 定义一个命令处理函数
|
||||||
|
handler_mock = AsyncMock()
|
||||||
|
|
||||||
|
# 注册命令
|
||||||
|
@command_manager.command("test")
|
||||||
|
async def test_command(bot, event):
|
||||||
|
await handler_mock(bot, event)
|
||||||
|
|
||||||
|
# 构造触发命令的事件
|
||||||
|
event = MagicMock(spec=GroupMessageEvent)
|
||||||
|
event.post_type = "message"
|
||||||
|
event.message_type = "group"
|
||||||
|
event.raw_message = "/test"
|
||||||
|
event.message = [MessageSegment.text("/test")]
|
||||||
|
event.user_id = 111
|
||||||
|
event.group_id = 222
|
||||||
|
|
||||||
|
# 处理事件
|
||||||
|
await command_manager.handle_event(mock_bot, event)
|
||||||
|
|
||||||
|
# 验证处理函数被调用
|
||||||
|
handler_mock.assert_called_once_with(mock_bot, event)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_command_prefix_match(command_manager, mock_bot):
|
||||||
|
"""测试命令前缀匹配"""
|
||||||
|
handler_mock = AsyncMock()
|
||||||
|
|
||||||
|
@command_manager.command("hello")
|
||||||
|
async def hello_command(bot, event):
|
||||||
|
await handler_mock(bot, event)
|
||||||
|
|
||||||
|
# 1. 正确的前缀
|
||||||
|
event1 = MagicMock(spec=GroupMessageEvent)
|
||||||
|
event1.post_type = "message"
|
||||||
|
event1.raw_message = "/hello"
|
||||||
|
event1.message = [MessageSegment.text("/hello")]
|
||||||
|
await command_manager.handle_event(mock_bot, event1)
|
||||||
|
handler_mock.assert_called_once()
|
||||||
|
handler_mock.reset_mock()
|
||||||
|
|
||||||
|
# 2. 错误的前缀 (应该忽略)
|
||||||
|
event2 = MagicMock(spec=GroupMessageEvent)
|
||||||
|
event2.post_type = "message"
|
||||||
|
event2.raw_message = ".hello" # 假设前缀是 /
|
||||||
|
event2.message = [MessageSegment.text(".hello")]
|
||||||
|
await command_manager.handle_event(mock_bot, event2)
|
||||||
|
handler_mock.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ignore_self_message(command_manager, mock_bot):
|
||||||
|
"""测试忽略自身消息"""
|
||||||
|
# 模拟配置
|
||||||
|
with patch("core.managers.command_manager.global_config") as mock_config:
|
||||||
|
mock_config.bot.ignore_self_message = True
|
||||||
|
|
||||||
|
event = MagicMock(spec=GroupMessageEvent)
|
||||||
|
event.post_type = "message"
|
||||||
|
event.user_id = 123456 # 与 bot.self_id 相同
|
||||||
|
event.self_id = 123456
|
||||||
|
|
||||||
|
# Mock handle 方法来检测是否被调用
|
||||||
|
command_manager.message_handler.handle = AsyncMock()
|
||||||
|
|
||||||
|
await command_manager.handle_event(mock_bot, event)
|
||||||
|
|
||||||
|
# 应该直接返回,不调用 handler
|
||||||
|
command_manager.message_handler.handle.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_help_command(command_manager, mock_bot):
|
||||||
|
"""测试内置 help 命令"""
|
||||||
|
# 注册一个测试插件信息
|
||||||
|
command_manager.plugins["test_plugin"] = {
|
||||||
|
"name": "测试插件",
|
||||||
|
"description": "这是一个测试",
|
||||||
|
"usage": "/test"
|
||||||
|
}
|
||||||
|
|
||||||
|
event = MagicMock(spec=GroupMessageEvent)
|
||||||
|
event.post_type = "message"
|
||||||
|
event.raw_message = "/help"
|
||||||
|
event.message = [MessageSegment.text("/help")]
|
||||||
|
|
||||||
|
await command_manager.handle_event(mock_bot, event)
|
||||||
|
|
||||||
|
# 验证 bot.send 被调用,且内容包含插件信息
|
||||||
|
mock_bot.send.assert_called_once()
|
||||||
|
args, _ = mock_bot.send.call_args
|
||||||
|
sent_msg = args[1]
|
||||||
|
assert "测试插件" in sent_msg
|
||||||
|
assert "这是一个测试" in sent_msg
|
||||||
141
tests/test_event_factory.py
Normal file
141
tests/test_event_factory.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
import pytest
|
||||||
|
from models.events.factory import EventFactory, EventType
|
||||||
|
from models.events.message import GroupMessageEvent, PrivateMessageEvent
|
||||||
|
from models.events.notice import GroupIncreaseNoticeEvent
|
||||||
|
from models.events.request import FriendRequestEvent
|
||||||
|
from models.events.meta import HeartbeatEvent
|
||||||
|
from models.message import MessageSegment
|
||||||
|
|
||||||
|
class TestEventFactory:
|
||||||
|
def test_create_group_message_event_list(self):
|
||||||
|
"""测试创建群消息事件 (message 为列表格式)"""
|
||||||
|
data = {
|
||||||
|
"post_type": "message",
|
||||||
|
"message_type": "group",
|
||||||
|
"time": 1600000000,
|
||||||
|
"self_id": 123456,
|
||||||
|
"sub_type": "normal",
|
||||||
|
"message_id": 1001,
|
||||||
|
"user_id": 111111,
|
||||||
|
"group_id": 222222,
|
||||||
|
"message": [
|
||||||
|
{"type": "text", "data": {"text": "Hello"}}
|
||||||
|
],
|
||||||
|
"raw_message": "Hello",
|
||||||
|
"font": 0,
|
||||||
|
"sender": {
|
||||||
|
"user_id": 111111,
|
||||||
|
"nickname": "User",
|
||||||
|
"role": "member"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
event = EventFactory.create_event(data)
|
||||||
|
assert isinstance(event, GroupMessageEvent)
|
||||||
|
assert event.group_id == 222222
|
||||||
|
assert len(event.message) == 1
|
||||||
|
assert event.message[0].type == "text"
|
||||||
|
assert event.message[0].data["text"] == "Hello"
|
||||||
|
|
||||||
|
def test_create_group_message_event_str(self):
|
||||||
|
"""测试创建群消息事件 (message 为字符串格式)"""
|
||||||
|
data = {
|
||||||
|
"post_type": "message",
|
||||||
|
"message_type": "group",
|
||||||
|
"time": 1600000000,
|
||||||
|
"self_id": 123456,
|
||||||
|
"sub_type": "normal",
|
||||||
|
"message_id": 1002,
|
||||||
|
"user_id": 111111,
|
||||||
|
"group_id": 222222,
|
||||||
|
"message": "Hello World",
|
||||||
|
"raw_message": "Hello World",
|
||||||
|
"font": 0,
|
||||||
|
"sender": {
|
||||||
|
"user_id": 111111,
|
||||||
|
"nickname": "User"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
event = EventFactory.create_event(data)
|
||||||
|
assert isinstance(event, GroupMessageEvent)
|
||||||
|
assert len(event.message) == 1
|
||||||
|
assert event.message[0].type == "text"
|
||||||
|
assert event.message[0].data["text"] == "Hello World"
|
||||||
|
|
||||||
|
def test_create_private_message_event(self):
|
||||||
|
"""测试创建私聊消息事件"""
|
||||||
|
data = {
|
||||||
|
"post_type": "message",
|
||||||
|
"message_type": "private",
|
||||||
|
"time": 1600000000,
|
||||||
|
"self_id": 123456,
|
||||||
|
"sub_type": "friend",
|
||||||
|
"message_id": 2001,
|
||||||
|
"user_id": 333333,
|
||||||
|
"message": "Private Msg",
|
||||||
|
"raw_message": "Private Msg",
|
||||||
|
"font": 0,
|
||||||
|
"sender": {
|
||||||
|
"user_id": 333333,
|
||||||
|
"nickname": "Friend"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
event = EventFactory.create_event(data)
|
||||||
|
assert isinstance(event, PrivateMessageEvent)
|
||||||
|
assert event.user_id == 333333
|
||||||
|
|
||||||
|
def test_create_notice_event(self):
|
||||||
|
"""测试创建通知事件 (群成员增加)"""
|
||||||
|
data = {
|
||||||
|
"post_type": "notice",
|
||||||
|
"notice_type": "group_increase",
|
||||||
|
"sub_type": "approve",
|
||||||
|
"group_id": 222222,
|
||||||
|
"operator_id": 444444,
|
||||||
|
"user_id": 555555,
|
||||||
|
"time": 1600000000,
|
||||||
|
"self_id": 123456
|
||||||
|
}
|
||||||
|
event = EventFactory.create_event(data)
|
||||||
|
assert isinstance(event, GroupIncreaseNoticeEvent)
|
||||||
|
assert event.group_id == 222222
|
||||||
|
assert event.user_id == 555555
|
||||||
|
|
||||||
|
def test_create_request_event(self):
|
||||||
|
"""测试创建请求事件 (加好友)"""
|
||||||
|
data = {
|
||||||
|
"post_type": "request",
|
||||||
|
"request_type": "friend",
|
||||||
|
"user_id": 666666,
|
||||||
|
"comment": "Add me",
|
||||||
|
"flag": "flag_123",
|
||||||
|
"time": 1600000000,
|
||||||
|
"self_id": 123456
|
||||||
|
}
|
||||||
|
event = EventFactory.create_event(data)
|
||||||
|
assert isinstance(event, FriendRequestEvent)
|
||||||
|
assert event.user_id == 666666
|
||||||
|
assert event.comment == "Add me"
|
||||||
|
|
||||||
|
def test_create_meta_event(self):
|
||||||
|
"""测试创建元事件 (心跳)"""
|
||||||
|
data = {
|
||||||
|
"post_type": "meta_event",
|
||||||
|
"meta_event_type": "heartbeat",
|
||||||
|
"time": 1600000000,
|
||||||
|
"self_id": 123456,
|
||||||
|
"status": {"online": True, "good": True},
|
||||||
|
"interval": 5000
|
||||||
|
}
|
||||||
|
event = EventFactory.create_event(data)
|
||||||
|
assert isinstance(event, HeartbeatEvent)
|
||||||
|
assert event.interval == 5000
|
||||||
|
|
||||||
|
def test_unknown_event_type(self):
|
||||||
|
"""测试未知事件类型"""
|
||||||
|
data = {
|
||||||
|
"post_type": "unknown_type",
|
||||||
|
"time": 1600000000,
|
||||||
|
"self_id": 123456
|
||||||
|
}
|
||||||
|
with pytest.raises(ValueError, match="Unknown event type"):
|
||||||
|
EventFactory.create_event(data)
|
||||||
194
tests/test_event_handler.py
Normal file
194
tests/test_event_handler.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from core.handlers.event_handler import MessageHandler, NoticeHandler, RequestHandler
|
||||||
|
from models.events.message import GroupMessageEvent
|
||||||
|
from models.events.notice import GroupIncreaseNoticeEvent
|
||||||
|
from models.events.request import FriendRequestEvent
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_bot():
|
||||||
|
bot = AsyncMock()
|
||||||
|
return bot
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_message_handler_run_handler_injection(mock_bot):
|
||||||
|
"""测试参数注入"""
|
||||||
|
handler = MessageHandler(prefixes=("/",))
|
||||||
|
|
||||||
|
# 1. 测试注入 bot 和 event
|
||||||
|
async def func1(bot, event):
|
||||||
|
assert bot == mock_bot
|
||||||
|
assert event.user_id == 123
|
||||||
|
return True
|
||||||
|
|
||||||
|
event = MagicMock(spec=GroupMessageEvent)
|
||||||
|
event.user_id = 123
|
||||||
|
|
||||||
|
result = await handler._run_handler(func1, mock_bot, event)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
# 2. 测试注入 args
|
||||||
|
async def func2(args):
|
||||||
|
assert args == ["arg1", "arg2"]
|
||||||
|
return True
|
||||||
|
|
||||||
|
result = await handler._run_handler(func2, mock_bot, event, args=["arg1", "arg2"])
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_message_handler_command_parsing(mock_bot):
|
||||||
|
"""测试命令解析"""
|
||||||
|
handler = MessageHandler(prefixes=("/",))
|
||||||
|
|
||||||
|
mock_func = AsyncMock()
|
||||||
|
handler.commands["test"] = {
|
||||||
|
"func": mock_func,
|
||||||
|
"permission": None,
|
||||||
|
"override_permission_check": False,
|
||||||
|
"plugin_name": "test_plugin"
|
||||||
|
}
|
||||||
|
|
||||||
|
event = MagicMock(spec=GroupMessageEvent)
|
||||||
|
event.raw_message = "/test arg1 arg2"
|
||||||
|
event.user_id = 123
|
||||||
|
|
||||||
|
# Mock permission manager
|
||||||
|
with patch("core.managers.permission_manager.PermissionManager.check_permission", new_callable=AsyncMock) as mock_perm:
|
||||||
|
mock_perm.return_value = True
|
||||||
|
|
||||||
|
await handler.handle(mock_bot, event)
|
||||||
|
|
||||||
|
mock_func.assert_called_once()
|
||||||
|
# 验证 args 参数是否正确传递
|
||||||
|
call_args = mock_func.call_args
|
||||||
|
if "args" in call_args.kwargs:
|
||||||
|
assert call_args.kwargs["args"] == ["arg1", "arg2"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_notice_handler(mock_bot):
|
||||||
|
"""测试通知事件分发"""
|
||||||
|
handler = NoticeHandler()
|
||||||
|
|
||||||
|
mock_func = AsyncMock()
|
||||||
|
handler.handlers.append({
|
||||||
|
"type": "group_increase",
|
||||||
|
"func": mock_func,
|
||||||
|
"plugin_name": "test_plugin"
|
||||||
|
})
|
||||||
|
|
||||||
|
event = MagicMock(spec=GroupIncreaseNoticeEvent)
|
||||||
|
event.notice_type = "group_increase"
|
||||||
|
|
||||||
|
await handler.handle(mock_bot, event)
|
||||||
|
|
||||||
|
mock_func.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sync_handler_execution(mock_bot):
|
||||||
|
"""测试同步处理函数的执行"""
|
||||||
|
handler = MessageHandler(prefixes=("/",))
|
||||||
|
|
||||||
|
def sync_func(event):
|
||||||
|
return True
|
||||||
|
|
||||||
|
event = MagicMock(spec=GroupMessageEvent)
|
||||||
|
|
||||||
|
# 同步函数应该在线程池中运行
|
||||||
|
result = await handler._run_handler(sync_func, mock_bot, event)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_message_handler_management(mock_bot):
|
||||||
|
"""测试消息处理器的管理(注册、卸载、清空)"""
|
||||||
|
handler = MessageHandler(prefixes=("/",))
|
||||||
|
|
||||||
|
# 测试 on_message 装饰器
|
||||||
|
@handler.on_message()
|
||||||
|
async def msg_handler(event):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert len(handler.message_handlers) == 1
|
||||||
|
|
||||||
|
# 测试 command 装饰器
|
||||||
|
@handler.command("cmd1", "cmd2")
|
||||||
|
async def cmd_handler(event):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert len(handler.commands) == 2
|
||||||
|
assert "cmd1" in handler.commands
|
||||||
|
assert "cmd2" in handler.commands
|
||||||
|
|
||||||
|
# 测试 unregister_by_plugin_name
|
||||||
|
# 直接从已注册的处理器中获取 plugin_name
|
||||||
|
if handler.message_handlers:
|
||||||
|
plugin_name = handler.message_handlers[0]["plugin_name"]
|
||||||
|
handler.unregister_by_plugin_name(plugin_name)
|
||||||
|
|
||||||
|
assert len(handler.message_handlers) == 0
|
||||||
|
assert len(handler.commands) == 0
|
||||||
|
|
||||||
|
# 测试 clear
|
||||||
|
handler.commands["cmd"] = {}
|
||||||
|
handler.message_handlers.append({})
|
||||||
|
handler.clear()
|
||||||
|
assert len(handler.commands) == 0
|
||||||
|
assert len(handler.message_handlers) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_handler(mock_bot):
|
||||||
|
"""测试请求事件处理器"""
|
||||||
|
handler = RequestHandler()
|
||||||
|
|
||||||
|
mock_func = AsyncMock()
|
||||||
|
|
||||||
|
# 测试 register 装饰器
|
||||||
|
@handler.register("friend")
|
||||||
|
async def req_handler(event):
|
||||||
|
await mock_func(event)
|
||||||
|
|
||||||
|
assert len(handler.handlers) == 1
|
||||||
|
|
||||||
|
event = MagicMock(spec=FriendRequestEvent)
|
||||||
|
event.request_type = "friend"
|
||||||
|
|
||||||
|
await handler.handle(mock_bot, event)
|
||||||
|
mock_func.assert_called_once()
|
||||||
|
|
||||||
|
# 测试 unregister 和 clear
|
||||||
|
import inspect
|
||||||
|
module = inspect.getmodule(req_handler)
|
||||||
|
plugin_name = module.__name__
|
||||||
|
|
||||||
|
handler.unregister_by_plugin_name(plugin_name)
|
||||||
|
assert len(handler.handlers) == 0
|
||||||
|
|
||||||
|
handler.handlers.append({})
|
||||||
|
handler.clear()
|
||||||
|
assert len(handler.handlers) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_permission_denied(mock_bot):
|
||||||
|
"""测试权限不足的情况"""
|
||||||
|
handler = MessageHandler(prefixes=("/",))
|
||||||
|
|
||||||
|
mock_func = AsyncMock()
|
||||||
|
handler.commands["admin_cmd"] = {
|
||||||
|
"func": mock_func,
|
||||||
|
"permission": "ADMIN", # 假设 Permission.ADMIN
|
||||||
|
"override_permission_check": False,
|
||||||
|
"plugin_name": "test_plugin"
|
||||||
|
}
|
||||||
|
|
||||||
|
event = MagicMock(spec=GroupMessageEvent)
|
||||||
|
event.raw_message = "/admin_cmd"
|
||||||
|
event.user_id = 123
|
||||||
|
|
||||||
|
# Mock permission manager returning False
|
||||||
|
with patch("core.managers.permission_manager.PermissionManager.check_permission", new_callable=AsyncMock) as mock_perm:
|
||||||
|
mock_perm.return_value = False
|
||||||
|
|
||||||
|
await handler.handle(mock_bot, event)
|
||||||
|
|
||||||
|
mock_func.assert_not_called()
|
||||||
|
# 应该发送拒绝消息
|
||||||
|
mock_bot.send.assert_called_once()
|
||||||
75
tests/test_models.py
Normal file
75
tests/test_models.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import pytest
|
||||||
|
from models.message import MessageSegment
|
||||||
|
from models.objects import GroupInfo, StrangerInfo
|
||||||
|
|
||||||
|
class TestMessageSegment:
|
||||||
|
def test_text_segment(self):
|
||||||
|
seg = MessageSegment.text("Hello")
|
||||||
|
assert seg.type == "text"
|
||||||
|
assert seg.data["text"] == "Hello"
|
||||||
|
assert str(seg) == "Hello"
|
||||||
|
|
||||||
|
def test_at_segment(self):
|
||||||
|
seg = MessageSegment.at(123456)
|
||||||
|
assert seg.type == "at"
|
||||||
|
assert seg.data["qq"] == "123456"
|
||||||
|
assert str(seg) == "[CQ:at,qq=123456]"
|
||||||
|
|
||||||
|
def test_image_segment(self):
|
||||||
|
seg = MessageSegment.image("http://example.com/img.jpg", cache=False, proxy=False)
|
||||||
|
assert seg.type == "image"
|
||||||
|
assert seg.data["file"] == "http://example.com/img.jpg"
|
||||||
|
assert str(seg) == "[CQ:image,file=http://example.com/img.jpg,cache=0,proxy=0]"
|
||||||
|
|
||||||
|
def test_face_segment(self):
|
||||||
|
seg = MessageSegment.face(123)
|
||||||
|
assert seg.type == "face"
|
||||||
|
assert seg.data["id"] == "123"
|
||||||
|
assert str(seg) == "[CQ:face,id=123]"
|
||||||
|
|
||||||
|
def test_reply_segment(self):
|
||||||
|
seg = MessageSegment.reply(1001)
|
||||||
|
assert seg.type == "reply"
|
||||||
|
assert seg.data["id"] == "1001"
|
||||||
|
assert str(seg) == "[CQ:reply,id=1001]"
|
||||||
|
|
||||||
|
def test_add_segments(self):
|
||||||
|
seg1 = MessageSegment.text("Hello ")
|
||||||
|
seg2 = MessageSegment.at(123)
|
||||||
|
combined = seg1 + seg2
|
||||||
|
assert isinstance(combined, list)
|
||||||
|
assert len(combined) == 2
|
||||||
|
assert combined[0] == seg1
|
||||||
|
assert combined[1] == seg2
|
||||||
|
|
||||||
|
def test_add_segment_and_string(self):
|
||||||
|
seg = MessageSegment.at(123)
|
||||||
|
combined = seg + " Hello"
|
||||||
|
assert isinstance(combined, list)
|
||||||
|
assert len(combined) == 2
|
||||||
|
assert combined[0] == seg
|
||||||
|
assert combined[1].type == "text"
|
||||||
|
assert combined[1].data["text"] == " Hello"
|
||||||
|
|
||||||
|
class TestObjects:
|
||||||
|
def test_group_info(self):
|
||||||
|
data = {
|
||||||
|
"group_id": 123456,
|
||||||
|
"group_name": "Test Group",
|
||||||
|
"member_count": 10,
|
||||||
|
"max_member_count": 100
|
||||||
|
}
|
||||||
|
group = GroupInfo(**data)
|
||||||
|
assert group.group_id == 123456
|
||||||
|
assert group.group_name == "Test Group"
|
||||||
|
|
||||||
|
def test_stranger_info(self):
|
||||||
|
data = {
|
||||||
|
"user_id": 111111,
|
||||||
|
"nickname": "Stranger",
|
||||||
|
"sex": "male",
|
||||||
|
"age": 18
|
||||||
|
}
|
||||||
|
user = StrangerInfo(**data)
|
||||||
|
assert user.user_id == 111111
|
||||||
|
assert user.nickname == "Stranger"
|
||||||
Reference in New Issue
Block a user