* 滚木

* feat: 重构核心架构,增强类型安全与插件管理

本次提交对核心模块进行了深度重构,引入 Pydantic 增强配置管理的类型安全性,并全面优化了插件管理系统。

主要变更详情:

1. 核心架构与配置
   - 重构配置加载模块:引入 Pydantic 模型 (`core/config_models.py`),提供严格的配置项类型检查、验证及默认值管理。
   - 统一模块结构:规范化模块导入路径,移除冗余的 `__init__.py` 文件,提升项目结构的清晰度。
   - 性能优化:集成 Redis 缓存支持 (`RedisManager`),有效降低高频 API 调用开销,提升响应速度。

2. 插件系统升级
   - 实现热重载机制:新增插件文件变更监听功能,支持开发过程中自动重载插件,提升开发效率。
   - 优化生命周期管理:改进插件加载与卸载逻辑,支持精确卸载指定插件及其关联的命令、事件处理器和定时任务。

3. 功能特性增强
   - 新增媒体 API:引入 `MediaAPI` 模块,封装图片、语音等富媒体资源的获取与处理接口。
   - 完善权限体系:重构权限管理系统,实现管理员与操作员的分级控制,支持更细粒度的命令权限校验。

4. 代码质量与稳定性
   - 全面类型修复:解决 `mypy` 静态类型检查发现的大量类型错误(包括 `CommandManager`、`EventFactory` 及 `Bot` API 签名不匹配问题)。
   - 增强错误处理:优化消息处理管道的异常捕获机制,完善关键路径的日志记录,提升系统运行稳定性。

* feat: 添加测试用例并优化代码结构

refactor(permission_manager): 调整初始化顺序和逻辑
fix(admin_manager): 修复初始化逻辑和目录创建问题
feat(ws): 优化Bot实例初始化条件
feat(message): 增强MessageSegment功能并添加测试
feat(events): 支持字符串格式的消息解析
test: 添加核心功能测试用例
refactor(plugin_manager): 改进插件路径处理
style: 清理无用导入和代码
chore: 更新依赖项
This commit is contained in:
镀铬酸钾
2026-01-09 00:20:56 +08:00
committed by GitHub
parent 6d7dfc179d
commit fa81229f6f
42 changed files with 1461 additions and 697 deletions

View File

@@ -1,6 +0,0 @@
from .managers.command_manager import matcher
from .config_loader import global_config
from .managers.plugin_manager import PluginDataManager
from .ws import WS
__all__ = ["WS", "matcher", "global_config", "PluginDataManager"]

View File

@@ -3,6 +3,7 @@ from .message import MessageAPI
from .group import GroupAPI from .group import GroupAPI
from .friend import FriendAPI from .friend import FriendAPI
from .account import AccountAPI from .account import AccountAPI
from .media import MediaAPI
__all__ = [ __all__ = [
"BaseAPI", "BaseAPI",
@@ -10,4 +11,5 @@ __all__ = [
"GroupAPI", "GroupAPI",
"FriendAPI", "FriendAPI",
"AccountAPI", "AccountAPI",
"MediaAPI",
] ]

View File

@@ -162,3 +162,56 @@ class AccountAPI(BaseAPI):
""" """
return await self.call_api("clean_cache") return await self.call_api("clean_cache")
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> Any:
"""
获取陌生人信息。
Args:
user_id (int): 目标用户的 QQ 号。
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
Returns:
Any: 包含陌生人信息的字典或对象。
"""
return await self.call_api("get_stranger_info", {"user_id": user_id, "no_cache": no_cache})
async def get_friend_list(self, no_cache: bool = False) -> list:
"""
获取好友列表。
Args:
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
Returns:
list: 好友列表。
"""
cache_key = f"neobot:cache:get_friend_list:{self.self_id}"
if not no_cache:
cached_data = await redis_manager.get(cache_key)
if cached_data:
return json.loads(cached_data)
res = await self.call_api("get_friend_list")
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
return res
async def get_group_list(self, no_cache: bool = False) -> list:
"""
获取群列表。
Args:
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
Returns:
list: 群列表。
"""
cache_key = f"neobot:cache:get_group_list:{self.self_id}"
if not no_cache:
cached_data = await redis_manager.get(cache_key)
if cached_data:
return json.loads(cached_data)
res = await self.call_api("get_group_list")
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
return res

View File

@@ -1,24 +1,50 @@
""" """
API 基础模块 API 基础模块
定义了 API 调用的基础接口。 定义了 API 调用的基础接口和统一处理逻辑
""" """
from abc import ABC, abstractmethod from typing import Any, Dict, Optional, TYPE_CHECKING
from typing import Any, Dict, Optional
from ..utils.logger import logger
if TYPE_CHECKING:
from ..ws import WS
class BaseAPI(ABC): class BaseAPI:
""" """
API 基础抽象 API 基础类,提供了统一的 `call_api` 方法,包含日志记录和异常处理。
""" """
_ws: "WS"
self_id: int
def __init__(self, ws_client: "WS", self_id: int):
self._ws = ws_client
self.self_id = self_id
@abstractmethod
async def call_api(self, action: str, params: Optional[Dict[str, Any]] = None) -> Any: async def call_api(self, action: str, params: Optional[Dict[str, Any]] = None) -> Any:
""" """
调用 API 调用 OneBot v11 API并提供统一的日志和异常处理。
:param action: API 动作名称 :param action: API 动作名称
:param params: API 参数 :param params: API 参数
:return: API 响应结果 :return: API 响应结果的数据部分
:raises Exception: 当 API 调用失败或发生网络错误时
""" """
raise NotImplementedError if params is None:
params = {}
try:
logger.debug(f"调用API -> action: {action}, params: {params}")
response = await self._ws.call_api(action, params)
logger.debug(f"API响应 <- {response}")
if response.get("status") == "failed":
logger.warning(f"API调用失败: {response}")
return response.get("data")
except Exception as e:
logger.error(f"API调用异常: action={action}, params={params}, error={e}")
raise

View File

@@ -4,7 +4,7 @@
该模块定义了 `GroupAPI` Mixin 类,提供了所有与群组管理、成员操作 该模块定义了 `GroupAPI` Mixin 类,提供了所有与群组管理、成员操作
等相关的 OneBot v11 API 封装。 等相关的 OneBot v11 API 封装。
""" """
from typing import List, Dict, Any from typing import List, Dict, Any, Optional
import json import json
from ..managers.redis_manager import redis_manager from ..managers.redis_manager import redis_manager
from .base import BaseAPI from .base import BaseAPI
@@ -46,7 +46,7 @@ class GroupAPI(BaseAPI):
""" """
return await self.call_api("set_group_ban", {"group_id": group_id, "user_id": user_id, "duration": duration}) return await self.call_api("set_group_ban", {"group_id": group_id, "user_id": user_id, "duration": duration})
async def set_group_anonymous_ban(self, group_id: int, anonymous: Dict[str, Any] = None, duration: int = 1800, flag: str = None) -> Dict[str, Any]: async def set_group_anonymous_ban(self, group_id: int, anonymous: Optional[Dict[str, Any]] = None, duration: int = 1800, flag: Optional[str] = None) -> Dict[str, Any]:
""" """
禁言群组中的匿名用户。 禁言群组中的匿名用户。
@@ -61,7 +61,7 @@ class GroupAPI(BaseAPI):
Returns: Returns:
Dict[str, Any]: OneBot API 的响应数据。 Dict[str, Any]: OneBot API 的响应数据。
""" """
params = {"group_id": group_id, "duration": duration} params: Dict[str, Any] = {"group_id": group_id, "duration": duration}
if anonymous: if anonymous:
params["anonymous"] = anonymous params["anonymous"] = anonymous
if flag: if flag:
@@ -187,17 +187,18 @@ class GroupAPI(BaseAPI):
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时 await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
return GroupInfo(**res) return GroupInfo(**res)
async def get_group_list(self) -> List[GroupInfo]: async def get_group_list(self) -> Any:
""" """
获取机器人加入的所有群组的列表。 获取机器人加入的所有群组的列表。
Returns: Returns:
List[GroupInfo]: 包含所有群组信息的 `GroupInfo` 对象列表。 Any: 包含所有群组信息的列表(可能是字典列表或对象列表
""" """
res = await self.call_api("get_group_list") res = await self.call_api("get_group_list")
# 增加日志记录 API 原始返回 # 增加日志记录 API 原始返回
logger.debug(f"OneBot API 'get_group_list' raw response: {res}") logger.debug(f"OneBot API 'get_group_list' raw response: {res}")
return res
# 健壮性处理:处理标准的 OneBot v11 响应格式 # 健壮性处理:处理标准的 OneBot v11 响应格式
if isinstance(res, dict) and res.get("status") == "ok": if isinstance(res, dict) and res.get("status") == "ok":

39
core/api/media.py Normal file
View File

@@ -0,0 +1,39 @@
"""
媒体API模块
封装了与图片、语音等媒体文件相关的API。
"""
from typing import Dict, Any
from .base import BaseAPI
class MediaAPI(BaseAPI):
"""
媒体相关API
"""
async def can_send_image(self) -> Dict[str, Any]:
"""
检查是否可以发送图片
:return: OneBot v11标准响应
"""
return await self.call_api(action="can_send_image")
async def can_send_record(self) -> Dict[str, Any]:
"""
检查是否可以发送语音
:return: OneBot v11标准响应
"""
return await self.call_api(action="can_send_record")
async def get_image(self, file: str) -> Dict[str, Any]:
"""
获取图片信息
:param file: 图片文件名或路径
:return: OneBot v11标准响应
"""
return await self.call_api(action="get_image", params={"file": file})

View File

@@ -8,7 +8,8 @@ from typing import Union, List, Dict, Any, TYPE_CHECKING
from .base import BaseAPI from .base import BaseAPI
if TYPE_CHECKING: if TYPE_CHECKING:
from models import MessageSegment, OneBotEvent from models.message import MessageSegment
from models.events.base import OneBotEvent
class MessageAPI(BaseAPI): class MessageAPI(BaseAPI):
@@ -156,24 +157,6 @@ class MessageAPI(BaseAPI):
""" """
return await self.call_api("send_private_forward_msg", {"user_id": user_id, "messages": messages}) return await self.call_api("send_private_forward_msg", {"user_id": user_id, "messages": messages})
async def can_send_image(self) -> Dict[str, Any]:
"""
检查当前机器人账号是否可以发送图片。
Returns:
Dict[str, Any]: OneBot API 的响应数据。
"""
return await self.call_api("can_send_image")
async def can_send_record(self) -> Dict[str, Any]:
"""
检查当前机器人账号是否可以发送语音。
Returns:
Dict[str, Any]: OneBot API 的响应数据。
"""
return await self.call_api("can_send_record")
def _process_message(self, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Union[str, List[Dict[str, Any]]]: def _process_message(self, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Union[str, List[Dict[str, Any]]]:
""" """
内部方法:将消息内容处理成 OneBot API 可接受的格式。 内部方法:将消息内容处理成 OneBot API 可接受的格式。
@@ -192,7 +175,7 @@ class MessageAPI(BaseAPI):
return message return message
# 避免循环导入,在运行时导入 # 避免循环导入,在运行时导入
from models import MessageSegment from models.message import MessageSegment
if isinstance(message, MessageSegment): if isinstance(message, MessageSegment):
return [self._segment_to_dict(message)] return [self._segment_to_dict(message)]

View File

@@ -13,14 +13,15 @@ Bot 核心抽象模块
from typing import TYPE_CHECKING, Dict, Any, List, Union from typing import TYPE_CHECKING, Dict, Any, List, Union
from models.events.base import OneBotEvent from models.events.base import OneBotEvent
from models.message import MessageSegment from models.message import MessageSegment
from models.objects import GroupInfo, StrangerInfo
if TYPE_CHECKING: if TYPE_CHECKING:
from .WS import WS from .ws import WS
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI
class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI): class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI):
""" """
机器人核心类,封装了所有与 OneBot API 的交互。 机器人核心类,封装了所有与 OneBot API 的交互。
@@ -35,22 +36,22 @@ class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI):
Args: Args:
ws_client (WS): WebSocket 客户端实例,负责底层的 API 请求和响应处理。 ws_client (WS): WebSocket 客户端实例,负责底层的 API 请求和响应处理。
""" """
self.ws = ws_client super().__init__(ws_client, ws_client.self_id or 0)
self.code_executor = None
async def call_api(self, action: str, params: Dict[str, Any] = None) -> Any: async def get_group_list(self, no_cache: bool = False) -> List[GroupInfo]:
""" # GroupAPI.get_group_list 不支持 no_cache 参数,这里忽略它
底层 API 调用方法。 result = await super().get_group_list()
# 确保结果是 GroupInfo 对象列表
return [GroupInfo(**group) if isinstance(group, dict) else group for group in result]
所有具体的 API 实现最终都会调用此方法,通过 WebSocket 发送请求。 async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> StrangerInfo:
result = await super().get_stranger_info(user_id=user_id, no_cache=no_cache)
# 确保结果是 StrangerInfo 对象
if isinstance(result, dict):
return StrangerInfo(**result)
return result
Args:
action (str): API 的动作名称,例如 "send_group_msg"
params (Dict[str, Any], optional): API 请求的参数字典。Defaults to None.
Returns:
Any: OneBot API 的响应数据。
"""
return await self.ws.call_api(action, params)
def build_forward_node(self, user_id: int, nickname: str, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Dict[str, Any]: def build_forward_node(self, user_id: int, nickname: str, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Dict[str, Any]:
""" """

View File

@@ -4,9 +4,11 @@
负责读取和解析 config.toml 配置文件,提供全局配置对象。 负责读取和解析 config.toml 配置文件,提供全局配置对象。
""" """
from pathlib import Path from pathlib import Path
from typing import Any, Dict
import tomllib import tomllib
from pydantic import ValidationError
from .config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel
from .utils.logger import logger
class Config: class Config:
@@ -21,73 +23,67 @@ class Config:
:param file_path: 配置文件路径,默认为 "config.toml" :param file_path: 配置文件路径,默认为 "config.toml"
""" """
self.path = Path(file_path) self.path = Path(file_path)
self._data: Dict[str, Any] = {} self._model: ConfigModel
self.load() self.load()
def load(self): def load(self):
""" """
加载配置文件 加载并验证配置文件
:raises FileNotFoundError: 如果配置文件不存在 :raises FileNotFoundError: 如果配置文件不存在
:raises ValidationError: 如果配置格式不正确
""" """
if not self.path.exists(): if not self.path.exists():
logger.error(f"配置文件 {self.path} 未找到!")
raise FileNotFoundError(f"配置文件 {self.path} 未找到!") raise FileNotFoundError(f"配置文件 {self.path} 未找到!")
with open(self.path, "rb") as f: try:
self._data = tomllib.load(f) logger.info(f"正在从 {self.path} 加载配置...")
with open(self.path, "rb") as f:
raw_config = tomllib.load(f)
self._model = ConfigModel(**raw_config)
logger.success("配置加载并验证成功!")
except ValidationError as e:
logger.error("配置验证失败,请检查 `config.toml` 文件中的以下错误:")
for error in e.errors():
field = " -> ".join(map(str, error["loc"]))
logger.error(f" - 字段 '{field}': {error['msg']}")
raise
except Exception as e:
logger.exception(f"加载配置文件时发生未知错误: {e}")
raise
# 通过属性访问配置 # 通过属性访问配置
@property @property
def napcat_ws(self) -> dict: def napcat_ws(self) -> NapCatWSModel:
""" """
获取 NapCat WebSocket 配置 获取 NapCat WebSocket 配置
:return: 配置字典
""" """
return self._data.get("napcat_ws", {}) return self._model.napcat_ws
@property @property
def bot(self) -> dict: def bot(self) -> BotModel:
""" """
获取 Bot 基础配置 获取 Bot 基础配置
:return: 配置字典
""" """
return self._data.get("bot", {}) return self._model.bot
@property @property
def features(self) -> dict: def redis(self) -> RedisModel:
"""
获取功能特性配置
:return: 配置字典
"""
return self._data.get("features", {})
@property
def redis(self) -> dict:
""" """
获取 Redis 配置 获取 Redis 配置
:return: 配置字典
""" """
return self._data.get("redis", {}) return self._model.redis
@property @property
def docker(self) -> dict: def docker(self) -> DockerModel:
""" """
获取 Docker 配置 获取 Docker 配置
:return: 配置字典
""" """
return self._data.get("docker", {}) return self._model.docker
# 实例化全局配置对象 # 实例化全局配置对象
global_config = Config() global_config = Config()
if __name__ == "__main__":
print(global_config.napcat_ws)
print(global_config.bot.get("command"))
print(type(global_config.bot.get("command")) is list)
print(global_config.features)

60
core/config_models.py Normal file
View File

@@ -0,0 +1,60 @@
"""
Pydantic 配置模型模块
该模块使用 Pydantic 定义了与 `config.toml` 文件结构完全对应的配置模型。
这使得配置的加载、校验和访问都变得类型安全和健壮。
"""
from typing import List, Optional
from pydantic import BaseModel, Field
class NapCatWSModel(BaseModel):
"""
对应 `config.toml` 中的 `[napcat_ws]` 配置块。
"""
uri: str
token: str
reconnect_interval: int = 5
class BotModel(BaseModel):
"""
对应 `config.toml` 中的 `[bot]` 配置块。
"""
command: List[str] = Field(default_factory=lambda: ["/"])
ignore_self_message: bool = True
permission_denied_message: str = "权限不足,需要 {permission_name} 权限"
class RedisModel(BaseModel):
"""
对应 `config.toml` 中的 `[redis]` 配置块。
"""
host: str
port: int
db: int
password: str
class DockerModel(BaseModel):
"""
对应 `config.toml` 中的 `[docker]` 配置块。
"""
base_url: Optional[str] = None
sandbox_image: str = "python-sandbox:latest"
timeout: int = 10
concurrency_limit: int = 5
tls_verify: bool = False
ca_cert_path: Optional[str] = None
client_cert_path: Optional[str] = None
client_key_path: Optional[str] = None
class ConfigModel(BaseModel):
"""
顶层配置模型,整合了所有子配置块。
"""
napcat_ws: NapCatWSModel
bot: BotModel
redis: RedisModel
docker: DockerModel

View File

@@ -1,3 +1,3 @@
{ {
"admins": [] "admins": [2221577113]
} }

View File

@@ -6,11 +6,12 @@
""" """
import inspect import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
from ..bot import Bot if TYPE_CHECKING:
from ..bot import Bot
from ..config_loader import global_config from ..config_loader import global_config
from ..managers.permission_manager import Permission, permission_manager from ..permission import Permission
from ..utils.executor import run_in_thread_pool from ..utils.executor import run_in_thread_pool
@@ -22,7 +23,7 @@ class BaseHandler(ABC):
self.handlers: List[Dict[str, Any]] = [] self.handlers: List[Dict[str, Any]] = []
@abstractmethod @abstractmethod
async def handle(self, bot: Bot, event: Any): async def handle(self, bot: "Bot", event: Any):
""" """
处理事件 处理事件
""" """
@@ -31,7 +32,7 @@ class BaseHandler(ABC):
async def _run_handler( async def _run_handler(
self, self,
func: Callable, func: Callable,
bot: Bot, bot: "Bot",
event: Any, event: Any,
args: Optional[List[str]] = None, args: Optional[List[str]] = None,
permission_granted: Optional[bool] = None permission_granted: Optional[bool] = None
@@ -41,7 +42,7 @@ class BaseHandler(ABC):
""" """
sig = inspect.signature(func) sig = inspect.signature(func)
params = sig.parameters params = sig.parameters
kwargs = {} kwargs: Dict[str, Any] = {}
if "bot" in params: if "bot" in params:
kwargs["bot"] = bot kwargs["bot"] = bot
@@ -68,21 +69,41 @@ class MessageHandler(BaseHandler):
super().__init__() super().__init__()
self.prefixes = prefixes self.prefixes = prefixes
self.commands: Dict[str, Dict] = {} self.commands: Dict[str, Dict] = {}
self.message_handlers: List[Callable] = [] self.message_handlers: List[Dict[str, Any]] = []
def clear(self):
"""
清空所有已注册的消息和命令处理器
"""
self.commands.clear()
self.message_handlers.clear()
def unregister_by_plugin_name(self, plugin_name: str):
"""
根据插件名卸载相关的消息和命令处理器
"""
# 卸载命令
commands_to_remove = [name for name, info in self.commands.items() if info["plugin_name"] == plugin_name]
for name in commands_to_remove:
del self.commands[name]
# 卸载通用消息处理器
self.message_handlers = [h for h in self.message_handlers if h["plugin_name"] != plugin_name]
def on_message(self) -> Callable: def on_message(self) -> Callable:
""" """
注册通用消息处理器 注册通用消息处理器
""" """
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
self.message_handlers.append(func) module = inspect.getmodule(func)
plugin_name = module.__name__ if module else "Unknown"
self.message_handlers.append({"func": func, "plugin_name": plugin_name})
return func return func
return decorator return decorator
def command( def command(
self, self,
*names: str, *names: str,
*names: str,
permission: Optional[Permission] = None, permission: Optional[Permission] = None,
override_permission_check: bool = False override_permission_check: bool = False
) -> Callable: ) -> Callable:
@@ -90,21 +111,25 @@ class MessageHandler(BaseHandler):
注册命令处理器 注册命令处理器
""" """
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
module = inspect.getmodule(func)
plugin_name = module.__name__ if module else "Unknown"
for name in names: for name in names:
self.commands[name] = { self.commands[name] = {
"func": func, "func": func,
"permission": permission, "permission": permission,
"override_permission_check": override_permission_check, "override_permission_check": override_permission_check,
"plugin_name": plugin_name,
} }
return func return func
return decorator return decorator
async def handle(self, bot: Bot, event: Any): async def handle(self, bot: "Bot", event: Any):
""" """
处理消息事件,包括通用消息和命令 处理消息事件,分发给命令处理器或通用消息处理器
""" """
for handler in self.message_handlers: from ..managers import permission_manager
consumed = await self._run_handler(handler, bot, event) for handler_info in self.message_handlers:
consumed = await self._run_handler(handler_info["func"], bot, event)
if consumed: if consumed:
return return
@@ -136,7 +161,7 @@ class MessageHandler(BaseHandler):
if not permission_granted and not override_check: if not permission_granted and not override_check:
permission_name = permission.name if isinstance(permission, Permission) else permission permission_name = permission.name if isinstance(permission, Permission) else permission
message_template = global_config.bot.get("permission_denied_message", "权限不足,需要 {permission_name} 权限") message_template = global_config.bot.permission_denied_message
await bot.send(event, message_template.format(permission_name=permission_name)) await bot.send(event, message_template.format(permission_name=permission_name))
return return
@@ -153,12 +178,23 @@ class NoticeHandler(BaseHandler):
""" """
通知事件处理器 通知事件处理器
""" """
def clear(self):
self.handlers.clear()
def unregister_by_plugin_name(self, plugin_name: str):
"""
根据插件名卸载相关的通知处理器
"""
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
def register(self, notice_type: Optional[str] = None) -> Callable: def register(self, notice_type: Optional[str] = None) -> Callable:
""" """
注册通知处理器 注册通知处理器
""" """
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
self.handlers.append({"type": notice_type, "func": func}) module = inspect.getmodule(func)
plugin_name = module.__name__ if module else "Unknown"
self.handlers.append({"type": notice_type, "func": func, "plugin_name": plugin_name})
return func return func
return decorator return decorator
@@ -175,12 +211,23 @@ class RequestHandler(BaseHandler):
""" """
请求事件处理器 请求事件处理器
""" """
def clear(self):
self.handlers.clear()
def unregister_by_plugin_name(self, plugin_name: str):
"""
根据插件名卸载相关的请求处理器
"""
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
def register(self, request_type: Optional[str] = None) -> Callable: def register(self, request_type: Optional[str] = None) -> Callable:
""" """
注册请求处理器 注册请求处理器
""" """
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
self.handlers.append({"type": request_type, "func": func}) module = inspect.getmodule(func)
plugin_name = module.__name__ if module else "Unknown"
self.handlers.append({"type": request_type, "func": func, "plugin_name": plugin_name})
return func return func
return decorator return decorator

View File

@@ -0,0 +1,40 @@
"""
管理器包
这个包集中了机器人核心的单例管理器。
通过从这里导入,可以确保在整个应用中访问到的都是同一个实例。
"""
from ..config_loader import global_config
from .admin_manager import AdminManager
from .command_manager import CommandManager
from .permission_manager import PermissionManager
from .plugin_manager import PluginManager
from .redis_manager import RedisManager
# --- 实例化所有单例管理器 ---
# 管理员管理器
admin_manager = AdminManager()
# 权限管理器
permission_manager = PermissionManager()
# 命令与事件管理器 (别名 matcher)
command_manager = CommandManager(prefixes=tuple(global_config.bot.command))
matcher = command_manager
# 插件管理器
plugin_manager = PluginManager(command_manager)
plugin_manager.load_all_plugins()
# Redis 管理器
redis_manager = RedisManager()
__all__ = [
"admin_manager",
"permission_manager",
"command_manager",
"matcher",
"plugin_manager",
"redis_manager",
]

View File

@@ -26,8 +26,7 @@ class AdminManager(Singleton):
""" """
初始化 AdminManager 初始化 AdminManager
""" """
super().__init__() if hasattr(self, '_initialized') and self._initialized:
if not self._initialized:
return return
# 管理员数据文件路径 # 管理员数据文件路径
@@ -39,7 +38,12 @@ class AdminManager(Singleton):
) )
self._admins: Set[int] = set() self._admins: Set[int] = set()
# 确保数据目录存在
os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
logger.info("管理员管理器初始化完成") logger.info("管理员管理器初始化完成")
super().__init__()
async def initialize(self): async def initialize(self):
""" """

View File

@@ -12,7 +12,16 @@ from ..handlers.event_handler import MessageHandler, NoticeHandler, RequestHandl
# 从配置中获取命令前缀 # 从配置中获取命令前缀
command_prefixes = global_config.bot.get("command", ("/",)) _config_prefixes = global_config.bot.command
# 确保前缀配置是元组格式
_final_prefixes: Tuple[str, ...]
if isinstance(_config_prefixes, list):
_final_prefixes = tuple(_config_prefixes)
elif isinstance(_config_prefixes, str):
_final_prefixes = (_config_prefixes,)
else:
_final_prefixes = tuple(_config_prefixes)
class CommandManager: class CommandManager:
@@ -59,6 +68,35 @@ class CommandManager:
"usage": "/help", "usage": "/help",
} }
def clear_all_handlers(self):
"""
清空所有已注册的事件处理器。
注意:这也会移除内置的 /help 命令,因此需要重新注册。
"""
self.message_handler.clear()
self.notice_handler.clear()
self.request_handler.clear()
self.plugins.clear()
# 清空后,需要重新注册内置命令
self._register_internal_commands()
def unload_plugin(self, plugin_name: str):
"""
卸载指定插件的所有处理器和命令。
Args:
plugin_name (str): 插件的模块名 (例如 'plugins.bili_parser')
"""
self.message_handler.unregister_by_plugin_name(plugin_name)
self.notice_handler.unregister_by_plugin_name(plugin_name)
self.request_handler.unregister_by_plugin_name(plugin_name)
# 移除插件元信息
plugins_to_remove = [name for name in self.plugins if name.startswith(plugin_name)]
for name in plugins_to_remove:
del self.plugins[name]
# --- 装饰器代理 --- # --- 装饰器代理 ---
def on_message(self) -> Callable: def on_message(self) -> Callable:
@@ -102,7 +140,7 @@ class CommandManager:
根据事件的 `post_type` 将其分发给对应的处理器。 根据事件的 `post_type` 将其分发给对应的处理器。
""" """
if event.post_type == 'message' and global_config.bot.get('ignore_self_message', False): if event.post_type == 'message' and global_config.bot.ignore_self_message:
if hasattr(event, 'user_id') and hasattr(event, 'self_id') and event.user_id == event.self_id: if hasattr(event, 'user_id') and hasattr(event, 'self_id') and event.user_id == event.self_id:
return return
@@ -130,14 +168,6 @@ class CommandManager:
await bot.send(event, help_text.strip()) await bot.send(event, help_text.strip())
# --- 全局单例 ---
# 确保前缀配置是元组格式
if isinstance(command_prefixes, list):
command_prefixes = tuple(command_prefixes)
elif isinstance(command_prefixes, str):
command_prefixes = (command_prefixes,)
# 实例化全局唯一的命令管理器 # 实例化全局唯一的命令管理器
matcher = CommandManager(prefixes=command_prefixes) matcher = CommandManager(prefixes=_final_prefixes)

View File

@@ -13,64 +13,17 @@
""" """
import json import json
import os import os
from functools import total_ordering from typing import Dict, Optional
from typing import Dict
from ..utils.logger import logger from ..utils.logger import logger
from ..utils.singleton import Singleton from ..utils.singleton import Singleton
from .admin_manager import admin_manager from .admin_manager import admin_manager
from ..permission import Permission
@total_ordering
class Permission:
"""
权限封装类
封装了权限的名称和等级,并提供了比较方法。
使用 @total_ordering 装饰器可以自动生成所有的比较运算符。
"""
def __init__(self, name: str, level: int):
"""
初始化权限对象
Args:
name (str): 权限名称 (e.g., "admin", "op")
level (int): 权限等级,数字越大权限越高
"""
self.name = name
self.level = level
def __eq__(self, other):
"""
判断权限是否相等
"""
if not isinstance(other, Permission):
return NotImplemented
return self.level == other.level
def __lt__(self, other):
"""
判断权限是否小于另一个权限
"""
if not isinstance(other, Permission):
return NotImplemented
return self.level < other.level
def __str__(self) -> str:
"""
返回权限的字符串表示(即权限名称)
"""
return self.name
# 定义全局权限常量
ADMIN = Permission("admin", 3)
OP = Permission("op", 2)
USER = Permission("user", 1)
# 用于从字符串名称查找权限对象的字典 # 用于从字符串名称查找权限对象的字典
_PERMISSIONS: Dict[str, Permission] = { _PERMISSIONS: Dict[str, Permission] = {
p.name: p for p in [ADMIN, OP, USER] p.value: p for p in Permission
} }
@@ -88,8 +41,7 @@ class PermissionManager(Singleton):
如果已经初始化过,则直接返回。 如果已经初始化过,则直接返回。
""" """
super().__init__() if hasattr(self, '_initialized') and self._initialized:
if not self._initialized:
return return
# 权限数据文件路径 # 权限数据文件路径
@@ -111,6 +63,7 @@ class PermissionManager(Singleton):
self.load() self.load()
logger.info("权限管理器初始化完成") logger.info("权限管理器初始化完成")
super().__init__()
def load(self) -> None: def load(self) -> None:
""" """
@@ -164,12 +117,12 @@ class PermissionManager(Singleton):
""" """
# 首先,通过 AdminManager 检查是否为管理员 # 首先,通过 AdminManager 检查是否为管理员
if await admin_manager.is_admin(user_id): if await admin_manager.is_admin(user_id):
return ADMIN return Permission.ADMIN
# 如果不是管理员,则从 permissions.json 中查找 # 如果不是管理员,则从 permissions.json 中查找
user_id_str = str(user_id) user_id_str = str(user_id)
level_name = self._data["users"].get(user_id_str, USER.name) level_name = self._data["users"].get(user_id_str, Permission.USER.value)
return _PERMISSIONS.get(level_name, USER) return _PERMISSIONS.get(level_name, Permission.USER)
def set_user_permission(self, user_id: int, permission: Permission) -> None: def set_user_permission(self, user_id: int, permission: Permission) -> None:
""" """
@@ -182,13 +135,13 @@ class PermissionManager(Singleton):
Raises: Raises:
ValueError: 如果权限对象无效 ValueError: 如果权限对象无效
""" """
if not isinstance(permission, Permission) or permission.name not in _PERMISSIONS: if not isinstance(permission, Permission):
raise ValueError(f"无效的权限对象: {permission}") raise ValueError(f"无效的权限对象: {permission}")
user_id_str = str(user_id) user_id_str = str(user_id)
self._data["users"][user_id_str] = permission.name self._data["users"][user_id_str] = permission.value
self.save() self.save()
logger.info(f"设置用户 {user_id} 的权限级别为 {permission.name}") logger.info(f"设置用户 {user_id} 的权限级别为 {permission.value}")
def remove_user(self, user_id: int) -> None: def remove_user(self, user_id: int) -> None:
""" """
@@ -214,17 +167,17 @@ class PermissionManager(Singleton):
Returns: Returns:
bool: 如果用户权限 >= 所需权限,返回 True否则返回 False bool: 如果用户权限 >= 所需权限,返回 True否则返回 False
""" """
# 如果传入的是字符串,先转换为 Permission 对象
if isinstance(required_permission, str):
required_permission = _PERMISSIONS.get(required_permission.lower())
if not required_permission:
# 如果是无效的权限字符串,默认拒绝
logger.warning(f"检测到无效的权限检查字符串: {required_permission}")
return False
user_permission = await self.get_user_permission(user_id) user_permission = await self.get_user_permission(user_id)
return user_permission >= required_permission return user_permission >= required_permission
def get_all_user_permissions(self) -> Dict[str, str]:
"""
获取所有已配置的用户权限
:return: 一个包含所有用户权限的字典
"""
return self._data["users"].copy()
def get_all_users(self) -> Dict[str, str]: def get_all_users(self) -> Dict[str, str]:
""" """
获取所有设置了权限的用户及其级别名称 获取所有设置了权限的用户及其级别名称
@@ -243,22 +196,22 @@ class PermissionManager(Singleton):
logger.info("已清空所有权限设置") logger.info("已清空所有权限设置")
# 全局权限管理器实例
permission_manager = PermissionManager()
def require_admin(func): def require_admin(func):
""" """
一个装饰器,用于限制命令只能由管理员执行。 一个装饰器,用于限制命令只能由管理员执行。
""" """
from functools import wraps from functools import wraps
from models.events.message import MessageEvent from models.events.message import MessageEvent
from core.managers import permission_manager
@wraps(func) @wraps(func)
async def wrapper(event: MessageEvent, *args, **kwargs): async def wrapper(event: MessageEvent, *args, **kwargs):
user_id = event.user_id user_id = event.user_id
if await permission_manager.check_permission(user_id, ADMIN): if await permission_manager.check_permission(user_id, Permission.ADMIN):
return await func(event, *args, **kwargs) return await func(event, *args, **kwargs)
else: else:
await event.reply("抱歉,您没有权限执行此命令。") # 假设 event 对象有 reply 方法
if hasattr(event, "reply"):
await event.reply("抱歉,您没有权限执行此命令。")
return None return None
return wrapper return wrapper

View File

@@ -1,126 +1,97 @@
""" """
插件管理器模块 插件管理器模块
负责扫描、加载和管理 `base_plugins` 目录下的所有插件。 负责扫描、加载和管理 `plugins` 目录下的所有插件。
""" """
import importlib import importlib
import json
import os import os
import pkgutil import pkgutil
import sys import sys
from typing import Set
from .command_manager import matcher
from ..utils.exceptions import SyncHandlerError from ..utils.exceptions import SyncHandlerError
from ..utils.logger import logger from ..utils.logger import logger
from ..utils.executor import run_in_thread_pool
def load_all_plugins(): class PluginManager:
""" """
扫描并加载 `plugins` 目录下的所有插件。 插件管理器类
该函数会遍历 `plugins` 目录下的所有模块:
1. 如果模块已加载,则执行 reload 操作(用于热重载)。
2. 如果模块未加载,则执行 import 操作。
加载过程中会提取插件元数据 `__plugin_meta__` 并注册到 CommandManager。
""" """
plugin_dir = os.path.join( def __init__(self, command_manager):
os.path.dirname(os.path.abspath(__file__)), "..", "plugins" """
) 初始化插件管理器
package_name = "plugins"
logger.info(f"正在从 {package_name} 加载插件...") :param command_manager: CommandManager的实例
"""
self.command_manager = command_manager
self.loaded_plugins: Set[str] = set()
for loader, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]): def load_all_plugins(self):
full_module_name = f"{package_name}.{module_name}" """
扫描并加载 `plugins` 目录下的所有插件。
"""
# 使用 pathlib 获取更可靠的路径
# 当前文件: core/managers/plugin_manager.py
# 目标: plugins/
current_dir = os.path.dirname(os.path.abspath(__file__))
# 回退两级到项目根目录 (core/managers -> core -> root)
root_dir = os.path.dirname(os.path.dirname(current_dir))
plugin_dir = os.path.join(root_dir, "plugins")
package_name = "plugins"
if not os.path.exists(plugin_dir):
logger.error(f"插件目录不存在: {plugin_dir}")
return
logger.info(f"正在从 {package_name} 加载插件 (路径: {plugin_dir})...")
for _, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
full_module_name = f"{package_name}.{module_name}"
try:
if full_module_name in self.loaded_plugins:
self.command_manager.unload_plugin(full_module_name)
module = importlib.reload(sys.modules[full_module_name])
action = "重载"
else:
module = importlib.import_module(full_module_name)
action = "加载"
if hasattr(module, "__plugin_meta__"):
meta = getattr(module, "__plugin_meta__")
self.command_manager.plugins[full_module_name] = meta
self.loaded_plugins.add(full_module_name)
type_str = "" if is_pkg else "文件"
logger.success(f" [{type_str}] 成功{action}: {module_name}")
except SyncHandlerError as e:
logger.error(f" 插件 {module_name} 加载失败: {e} (跳过此插件)")
except Exception as e:
logger.exception(
f" {action if 'action' in locals() else '加载'}插件 {module_name} 失败: {e}"
)
def reload_plugin(self, full_module_name: str):
"""
精确重载单个插件。
"""
if full_module_name not in self.loaded_plugins:
logger.warning(f"尝试重载一个未被加载的插件: {full_module_name},将按首次加载处理。")
if full_module_name not in sys.modules:
logger.error(f"重载失败: 模块 {full_module_name} 未在 sys.modules 中找到。")
return
try: try:
if full_module_name in sys.modules: self.command_manager.unload_plugin(full_module_name)
module = importlib.reload(sys.modules[full_module_name]) module = importlib.reload(sys.modules[full_module_name])
action = "重载"
else:
module = importlib.import_module(full_module_name)
action = "加载"
# 提取插件元数据
if hasattr(module, "__plugin_meta__"): if hasattr(module, "__plugin_meta__"):
meta = getattr(module, "__plugin_meta__") meta = getattr(module, "__plugin_meta__")
matcher.plugins[full_module_name] = meta self.command_manager.plugins[full_module_name] = meta
type_str = "" if is_pkg else "文件" logger.success(f"插件 {full_module_name} 已成功重载。")
logger.success(f" [{type_str}] 成功{action}: {module_name}")
except SyncHandlerError as e:
logger.error(f" 插件 {module_name} 加载失败: {e} (跳过此插件)")
except Exception as e: except Exception as e:
print( logger.exception(f"重载插件 {full_module_name} 时发生错误: {e}")
f" {action if 'action' in locals() else '加载'}插件 {module_name} 失败: {e}"
)
class PluginDataManager:
"""
用于管理插件产生的数据文件的类
"""
def __init__(self, plugin_name: str):
"""
初始化插件数据管理器
:param plugin_name: 插件名称
"""
self.plugin_name = plugin_name
self.data_file = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"plugins",
"data",
self.plugin_name + ".json",
)
self.data = {}
async def load(self):
"""读取配置文件"""
if not os.path.exists(self.data_file):
await self.set(self.plugin_name, [])
try:
with open(self.data_file, "r", encoding="utf-8") as f:
self.data = await run_in_thread_pool(json.load, f)
except json.JSONDecodeError:
self.data = {}
async def save(self):
"""保存配置到文件"""
with open(self.data_file, "w", encoding="utf-8") as f:
await run_in_thread_pool(json.dump, self.data, f, indent=2, ensure_ascii=False)
def get(self, key, default=None):
"""获取配置项"""
return self.data.get(key, default)
async def set(self, key, value):
"""设置配置项"""
self.data[key] = value
await self.save()
async def add(self, key, value):
"""添加配置项"""
if key not in self.data:
self.data[key] = []
self.data[key].append(value)
await self.save()
async def remove(self, key):
"""删除配置项"""
if key in self.data:
del self.data[key]
await self.save()
async def clear(self):
"""清空所有配置"""
self.data.clear()
await self.save()
def get_all(self):
return self.data.copy()

View File

@@ -20,10 +20,11 @@ class RedisManager:
""" """
if self._redis is None: if self._redis is None:
try: try:
host = config.redis['host'] redis_config = config.redis
port = config.redis['port'] host = redis_config.host
db = config.redis['db'] port = redis_config.port
password = config.redis.get('password') db = redis_config.db
password = redis_config.password
logger.info(f"正在尝试连接 Redis: {host}:{port}, DB: {db}") logger.info(f"正在尝试连接 Redis: {host}:{port}, DB: {db}")
@@ -54,5 +55,17 @@ class RedisManager:
raise ConnectionError("Redis 未初始化或连接失败,请先调用 initialize()") raise ConnectionError("Redis 未初始化或连接失败,请先调用 initialize()")
return self._redis return self._redis
async def get(self, name):
"""
获取指定键的值
"""
return await self.redis.get(name)
async def set(self, name, value, ex=None):
"""
设置指定键的值
"""
return await self.redis.set(name, value, ex=ex)
# 全局 Redis 管理器实例 # 全局 Redis 管理器实例
redis_manager = RedisManager() redis_manager = RedisManager()

42
core/permission.py Normal file
View File

@@ -0,0 +1,42 @@
from enum import Enum
from functools import total_ordering
@total_ordering
class Permission(Enum):
"""
定义用户权限等级的枚举类。
使用 @total_ordering 装饰器,只需定义 __lt__ 和 __eq__
即可自动实现所有比较运算符。
"""
USER = "user"
OP = "op"
ADMIN = "admin"
@property
def _level_map(self):
"""
内部属性,用于映射枚举成员到整数等级。
"""
return {
Permission.USER: 1,
Permission.OP: 2,
Permission.ADMIN: 3
}
def __lt__(self, other):
"""
比较当前权限是否小于另一个权限。
"""
if not isinstance(other, Permission):
return NotImplemented
return self._level_map[self] < self._level_map[other]
def __ge__(self, other):
"""
比较当前权限是否大于等于另一个权限。
"""
if not isinstance(other, Permission):
return NotImplemented
return self._level_map[self] >= self._level_map[other]

View File

@@ -2,7 +2,8 @@
import asyncio import asyncio
import docker import docker
from docker.tls import TLSConfig from docker.tls import TLSConfig
from typing import Dict, Any, Callable from docker.types import LogConfig
from typing import Any, Callable
from core.utils.logger import logger from core.utils.logger import logger
@@ -10,21 +11,20 @@ class CodeExecutor:
""" """
代码执行引擎,负责管理一个异步任务队列和并发的 Docker 容器执行。 代码执行引擎,负责管理一个异步任务队列和并发的 Docker 容器执行。
""" """
def __init__(self, bot_instance, config: Dict[str, Any]): def __init__(self, config: Any):
""" """
初始化代码执行引擎。 初始化代码执行引擎。
:param bot_instance: Bot 实例,用于后续的消息回复 :param config: 从 config_loader.py 加载的全局配置对象
:param config: 从 config.toml 加载的配置字典。
""" """
self.bot = bot_instance self.bot: Any = None # Bot 实例将在 WS 连接成功后动态注入
self.task_queue = asyncio.Queue() self.task_queue: asyncio.Queue = asyncio.Queue()
# 从传入的配置中读取 Docker 相关设置 # 从传入的配置中读取 Docker 相关设置
docker_config = config.docker docker_config = config.docker
self.docker_base_url = docker_config.get("base_url") self.docker_base_url = docker_config.base_url
self.sandbox_image = docker_config.get("sandbox_image", "python-sandbox:latest") self.sandbox_image = docker_config.sandbox_image
self.timeout = docker_config.get("timeout", 10) self.timeout = docker_config.timeout
concurrency = docker_config.get("concurrency_limit", 5) concurrency = docker_config.concurrency_limit
self.concurrency_limit = asyncio.Semaphore(concurrency) self.concurrency_limit = asyncio.Semaphore(concurrency)
self.docker_client = None self.docker_client = None
@@ -34,10 +34,10 @@ class CodeExecutor:
if self.docker_base_url: if self.docker_base_url:
# 如果配置了远程 Docker 地址,则使用 TLS 选项进行连接 # 如果配置了远程 Docker 地址,则使用 TLS 选项进行连接
tls_config = None tls_config = None
if docker_config.get("tls_verify", False): if docker_config.tls_verify:
tls_config = TLSConfig( tls_config = TLSConfig(
ca_cert=docker_config.get("ca_cert_path"), ca_cert=docker_config.ca_cert_path,
client_cert=(docker_config.get("client_cert_path"), docker_config.get("client_key_path")), client_cert=(docker_config.client_cert_path, docker_config.client_key_path),
verify=True verify=True
) )
self.docker_client = docker.DockerClient(base_url=self.docker_base_url, tls=tls_config) self.docker_client = docker.DockerClient(base_url=self.docker_base_url, tls=tls_config)
@@ -60,7 +60,15 @@ class CodeExecutor:
将代码执行任务添加到队列中。 将代码执行任务添加到队列中。
:param code: 待执行的 Python 代码字符串。 :param code: 待执行的 Python 代码字符串。
:param callback: 执行完毕后用于回复结果的回调函数。 :param callback: 执行完毕后用于回复结果的回调函数。
:raises RuntimeError: 如果 Docker 客户端未初始化。
""" """
if not self.docker_client:
logger.warning("[CodeExecutor] 尝试添加任务,但 Docker 客户端未初始化。任务被拒绝。")
# 这里可以选择抛出异常,或者直接调用回调返回错误信息
# 为了用户体验,我们构造一个错误结果并直接调用回调(如果可能)
# 但由于 callback 返回 Future这里简单起见我们记录日志并抛出异常
raise RuntimeError("Docker环境未就绪无法执行代码。")
task = {"code": code, "callback": callback} task = {"code": code, "callback": callback}
await self.task_queue.put(task) await self.task_queue.put(task)
logger.info(f"[CodeExecutor] 新的代码执行任务已入队 (队列当前长度: {self.task_queue.qsize()})。") logger.info(f"[CodeExecutor] 新的代码执行任务已入队 (队列当前长度: {self.task_queue.qsize()})。")
@@ -125,6 +133,9 @@ class CodeExecutor:
同步函数:在 Docker 容器中运行代码。 同步函数:在 Docker 容器中运行代码。
此函数通过手动管理容器生命周期来提高稳定性。 此函数通过手动管理容器生命周期来提高稳定性。
""" """
if self.docker_client is None:
raise docker.errors.DockerException("Docker client is not initialized.")
container = None container = None
try: try:
# 1. 创建容器 # 1. 创建容器
@@ -134,7 +145,7 @@ class CodeExecutor:
mem_limit='128m', mem_limit='128m',
cpu_shares=512, cpu_shares=512,
network_disabled=True, network_disabled=True,
log_config={'type': 'json-file', 'config': {'max-size': '1m'}}, log_config=LogConfig(type='json-file', config={'max-size': '1m'}),
) )
# 2. 启动容器 # 2. 启动容器
container.start() container.start()
@@ -150,7 +161,7 @@ class CodeExecutor:
# 5. 检查退出码,如果不为 0则手动抛出 ContainerError # 5. 检查退出码,如果不为 0则手动抛出 ContainerError
if result.get('StatusCode', 0) != 0: if result.get('StatusCode', 0) != 0:
raise docker.errors.ContainerError( raise docker.errors.ContainerError(
container, result['StatusCode'], f"python -c '{code}'", self.sandbox_image, stderr container, result['StatusCode'], f"python -c '{code}'", self.sandbox_image, stderr.decode('utf-8')
) )
return stdout return stdout
@@ -166,11 +177,11 @@ class CodeExecutor:
except Exception as e: except Exception as e:
logger.error(f"[CodeExecutor] 强制移除容器 {container.id} 时失败: {e}") logger.error(f"[CodeExecutor] 强制移除容器 {container.id} 时失败: {e}")
def initialize_executor(bot_instance, config: Dict[str, Any]): def initialize_executor(config: Any):
""" """
初始化并返回一个 CodeExecutor 实例。 初始化并返回一个 CodeExecutor 实例。
""" """
return CodeExecutor(bot_instance, config) return CodeExecutor(config)
async def run_in_thread_pool(sync_func, *args, **kwargs): async def run_in_thread_pool(sync_func, *args, **kwargs):
""" """

View File

@@ -13,11 +13,12 @@ WebSocket 连接。它是整个机器人框架的底层通信基础。
""" """
import asyncio import asyncio
import json import json
from typing import Any, Dict, Optional
import uuid import uuid
import websockets import websockets
from models import EventFactory from models.events.factory import EventFactory
from .bot import Bot from .bot import Bot
from .config_loader import global_config from .config_loader import global_config
@@ -30,7 +31,7 @@ class WS:
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。 WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
""" """
def __init__(self): def __init__(self, code_executor=None):
""" """
初始化 WebSocket 客户端。 初始化 WebSocket 客户端。
@@ -38,13 +39,15 @@ class WS:
""" """
# 读取参数 # 读取参数
cfg = global_config.napcat_ws cfg = global_config.napcat_ws
self.url = cfg.get("uri") self.url = cfg.uri
self.token = cfg.get("token") self.token = cfg.token
self.reconnect_interval = cfg.get("reconnect_interval", 5) self.reconnect_interval = cfg.reconnect_interval
self.ws = None self.ws = None
self._pending_requests = {} self._pending_requests = {}
self.bot = Bot(self) self.bot: Bot | None = None
self.self_id: int | None = None
self.code_executor = code_executor
async def connect(self): async def connect(self):
""" """
@@ -124,18 +127,43 @@ class WS:
try: try:
# 使用工厂创建事件对象 # 使用工厂创建事件对象
event = EventFactory.create_event(event_data) event = EventFactory.create_event(event_data)
# 尝试初始化 Bot 实例 (如果尚未初始化且事件包含 self_id)
# 只要事件中包含 self_id我们就可以初始化 Bot不必非要等待 meta_event
if self.bot is None and hasattr(event, 'self_id'):
self.self_id = event.self_id
self.bot = Bot(self)
logger.success(f"Bot 实例初始化完成: self_id={self.self_id}")
# 将代码执行器注入到 Bot 和执行器自身
if self.code_executor:
self.bot.code_executor = self.code_executor
self.code_executor.bot = self.bot
logger.info("代码执行器已成功注入 Bot 实例。")
# 如果 bot 尚未初始化,则不处理后续事件
if self.bot is None:
logger.warning("Bot 尚未初始化,跳过事件处理。")
return
event.bot = self.bot # 注入 Bot 实例 event.bot = self.bot # 注入 Bot 实例
# 打印日志 # 打印日志
if event.post_type == "message": if event.post_type == "message":
sender_name = event.sender.nickname if event.sender else "Unknown" sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown"
logger.info(f"[消息] {event.message_type} | {event.user_id}({sender_name}): {event.raw_message}") message_type = getattr(event, "message_type", "Unknown")
user_id = getattr(event, "user_id", "Unknown")
raw_message = getattr(event, "raw_message", "")
logger.info(f"[消息] {message_type} | {user_id}({sender_name}): {raw_message}")
elif event.post_type == "notice": elif event.post_type == "notice":
logger.info(f"[通知] {event.notice_type}") notice_type = getattr(event, "notice_type", "Unknown")
logger.info(f"[通知] {notice_type}")
elif event.post_type == "request": elif event.post_type == "request":
logger.info(f"[请求] {event.request_type}") request_type = getattr(event, "request_type", "Unknown")
logger.info(f"[请求] {request_type}")
elif event.post_type == "meta_event": elif event.post_type == "meta_event":
logger.debug(f"[元事件] {event.meta_event_type}") meta_event_type = getattr(event, "meta_event_type", "Unknown")
logger.debug(f"[元事件] {meta_event_type}")
# 分发事件 # 分发事件
@@ -144,7 +172,7 @@ class WS:
except Exception as e: except Exception as e:
logger.exception(f"事件处理异常: {e}") logger.exception(f"事件处理异常: {e}")
async def call_api(self, action: str, params: dict = None): async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
""" """
向 OneBot v11 实现端发送一个 API 请求。 向 OneBot v11 实现端发送一个 API 请求。

37
main.py
View File

@@ -15,7 +15,7 @@ from core.utils.logger import logger
from core.managers.admin_manager import admin_manager from core.managers.admin_manager import admin_manager
from core.ws import WS from core.ws import WS
from core.managers.plugin_manager import load_all_plugins from core.managers import plugin_manager
from core.managers.redis_manager import redis_manager from core.managers.redis_manager import redis_manager
from core.utils.executor import run_in_thread_pool, initialize_executor from core.utils.executor import run_in_thread_pool, initialize_executor
from core.config_loader import global_config as config from core.config_loader import global_config as config
@@ -25,6 +25,8 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, ROOT_DIR) sys.path.insert(0, ROOT_DIR)
# 获取插件目录的绝对路径
PLUGIN_DIR = os.path.join(ROOT_DIR, "plugins")
class PluginReloadHandler(FileSystemEventHandler): class PluginReloadHandler(FileSystemEventHandler):
@@ -32,7 +34,7 @@ class PluginReloadHandler(FileSystemEventHandler):
文件变更处理器,用于热重载插件 文件变更处理器,用于热重载插件
继承自 watchdog.events.FileSystemEventHandler 继承自 watchdog.events.FileSystemEventHandler
监听 base_plugins 目录下的文件变化,并触发插件重载。 监听 plugins 目录下的文件变化,并触发插件重载。
""" """
def __init__(self, loop: asyncio.AbstractEventLoop): def __init__(self, loop: asyncio.AbstractEventLoop):
""" """
@@ -53,12 +55,14 @@ class PluginReloadHandler(FileSystemEventHandler):
if file_system_event.is_directory: if file_system_event.is_directory:
return return
src_path = file_system_event.src_path
# 只监控 py 文件 # 只监控 py 文件
if not file_system_event.src_path.endswith(".py"): if not src_path.endswith(".py"):
return return
# 过滤掉一些临时文件 # 过滤掉一些临时文件
if "__pycache__" in file_system_event.src_path: if "__pycache__" in src_path or not src_path.startswith(PLUGIN_DIR):
return return
# 简单的防抖动 # 简单的防抖动
@@ -68,13 +72,18 @@ class PluginReloadHandler(FileSystemEventHandler):
self.last_reload_time = current_time self.last_reload_time = current_time
logger.info(f"检测到文件变更: {file_system_event.src_path}") # 从文件路径解析出模块名
logger.info("正在重载插件...") # 例如: C:\path\to\project\plugins\bili_parser.py -> plugins.bili_parser
relative_path = os.path.relpath(src_path, ROOT_DIR)
module_name = os.path.splitext(relative_path.replace(os.sep, '.'))[0]
logger.info(f"检测到文件变更: {src_path}")
logger.info(f"准备重载插件: {module_name}...")
try: try:
# 使用线程安全的方式在主事件循环中运行异步的插件载函数 # 使用线程安全的方式在主事件循环中运行异步的插件载函数
asyncio.run_coroutine_threadsafe(run_in_thread_pool(load_all_plugins), self.loop) asyncio.run_coroutine_threadsafe(run_in_thread_pool(plugin_manager.reload_plugin, module_name), self.loop)
logger.success("插件重载完成") logger.success(f"插件 {module_name} 重载任务已提交")
except Exception as e: except Exception as e:
logger.exception(f"重载失败: {e}") logger.exception(f"重载失败: {e}")
@@ -88,8 +97,7 @@ async def main():
2. 初始化 WebSocket 客户端 2. 初始化 WebSocket 客户端
3. 建立连接并保持运行 3. 建立连接并保持运行
""" """
# 首次加载插件 # 插件加载已移至 core.managers.__init__.py 中自动执行
await run_in_thread_pool(load_all_plugins)
# 初始化 Redis 连接 # 初始化 Redis 连接
await redis_manager.initialize() await redis_manager.initialize()
@@ -114,11 +122,10 @@ async def main():
logger.warning(f"插件目录不存在 {plugin_path}") logger.warning(f"插件目录不存在 {plugin_path}")
try: try:
websocket_client = WS()
# 初始化代码执行器 # 初始化代码执行器
code_executor = initialize_executor(websocket_client, config) code_executor = initialize_executor(config)
websocket_client.bot.code_executor = code_executor # 将执行器实例附加到 bot.bot 对象上
websocket_client = WS(code_executor=code_executor)
# 启动代码执行器的后台 worker # 启动代码执行器的后台 worker
logger.debug("[Main] 检查是否需要启动代码执行 Worker...") logger.debug("[Main] 检查是否需要启动代码执行 Worker...")

View File

@@ -1,97 +1,23 @@
from .events.base import OneBotEvent """
from .events.factory import EventFactory Models 包
from .events.message import (
GroupMessageEvent,
MessageEvent,
MessageSegment,
PrivateMessageEvent,
)
from .events.meta import HeartbeatEvent, HeartbeatStatus, LifeCycleEvent, MetaEvent
from .events.notice import (
ClientStatus,
ClientStatusNoticeEvent,
EssenceNoticeEvent,
FriendAddNoticeEvent,
FriendRecallNoticeEvent,
GroupAdminNoticeEvent,
GroupBanNoticeEvent,
GroupCardNoticeEvent,
GroupDecreaseNoticeEvent,
GroupIncreaseNoticeEvent,
GroupRecallNoticeEvent,
GroupUploadFile,
GroupUploadNoticeEvent,
HonorNotifyEvent,
LuckyKingNotifyEvent,
NoticeEvent,
NotifyNoticeEvent,
OfflineFile,
OfflineFileNoticeEvent,
PokeNotifyEvent,
)
from .events.request import FriendRequestEvent, GroupRequestEvent, RequestEvent
from .objects import (
CurrentTalkative,
EssenceMessage,
FriendInfo,
GroupHonorInfo,
GroupInfo,
GroupMemberInfo,
HonorInfo,
LoginInfo,
Status,
StrangerInfo,
VersionInfo,
)
# Alias for backward compatibility 导出常用的模型类,方便插件导入。
Event = OneBotEvent """
from .events.base import OneBotEvent
from .events.message import MessageEvent, GroupMessageEvent, PrivateMessageEvent
from .events.notice import NoticeEvent
from .events.request import RequestEvent
from .message import MessageSegment
from .sender import Sender
__all__ = [ __all__ = [
"OneBotEvent",
"MessageEvent",
"GroupMessageEvent",
"PrivateMessageEvent",
"NoticeEvent",
"RequestEvent",
"MessageSegment", "MessageSegment",
"Sender", "Sender",
"OneBotEvent",
"Event",
"MessageEvent",
"PrivateMessageEvent",
"GroupMessageEvent",
"NoticeEvent",
"FriendAddNoticeEvent",
"FriendRecallNoticeEvent",
"GroupRecallNoticeEvent",
"GroupIncreaseNoticeEvent",
"GroupDecreaseNoticeEvent",
"GroupAdminNoticeEvent",
"GroupBanNoticeEvent",
"GroupUploadNoticeEvent",
"GroupUploadFile",
"NotifyNoticeEvent",
"PokeNotifyEvent",
"LuckyKingNotifyEvent",
"HonorNotifyEvent",
"GroupCardNoticeEvent",
"OfflineFileNoticeEvent",
"OfflineFile",
"ClientStatusNoticeEvent",
"ClientStatus",
"EssenceNoticeEvent",
"RequestEvent",
"FriendRequestEvent",
"GroupRequestEvent",
"MetaEvent",
"HeartbeatEvent",
"LifeCycleEvent",
"HeartbeatStatus",
"EventFactory",
"GroupInfo",
"GroupMemberInfo",
"FriendInfo",
"StrangerInfo",
"LoginInfo",
"VersionInfo",
"Status",
"EssenceMessage",
"GroupHonorInfo",
"CurrentTalkative",
"HonorInfo",
] ]

View File

@@ -70,7 +70,11 @@ class EventFactory:
# 解析消息段 # 解析消息段
message_list = [] message_list = []
raw_message_list = data.get("message", []) raw_message_list = data.get("message", [])
if isinstance(raw_message_list, list):
if isinstance(raw_message_list, str):
# 如果消息是字符串,将其视为纯文本消息段
message_list.append(MessageSegment.text(raw_message_list))
elif isinstance(raw_message_list, list):
for item in raw_message_list: for item in raw_message_list:
if isinstance(item, dict): if isinstance(item, dict):
message_list.append(MessageSegment(type=item.get("type", ""), data=item.get("data", {}))) message_list.append(MessageSegment(type=item.get("type", ""), data=item.get("data", {})))
@@ -252,9 +256,18 @@ class EventFactory:
card_new=data.get("card_new", ""), card_new=data.get("card_new", ""),
card_old=data.get("card_old", "") card_old=data.get("card_old", "")
) )
elif notice_type == "group_card":
return GroupCardNoticeEvent(
**common_args,
notice_type=notice_type,
group_id=data.get("group_id", 0),
user_id=data.get("user_id", 0),
card_new=data.get("card_new", ""),
card_old=data.get("card_old", "")
)
elif notice_type == "offline_file": elif notice_type == "offline_file":
file_data = data.get("file", {}) file_data = data.get("file", {})
file = OfflineFile( offline_file = OfflineFile(
name=file_data.get("name", ""), name=file_data.get("name", ""),
size=file_data.get("size", 0), size=file_data.get("size", 0),
url=file_data.get("url", "") url=file_data.get("url", "")
@@ -263,7 +276,7 @@ class EventFactory:
**common_args, **common_args,
notice_type=notice_type, notice_type=notice_type,
user_id=data.get("user_id", 0), user_id=data.get("user_id", 0),
file=file file=offline_file
) )
elif notice_type == "client_status": elif notice_type == "client_status":
client_data = data.get("client", {}) client_data = data.get("client", {})

View File

@@ -4,9 +4,9 @@
定义了消息相关的事件类,包括 MessageEvent, PrivateMessageEvent, GroupMessageEvent。 定义了消息相关的事件类,包括 MessageEvent, PrivateMessageEvent, GroupMessageEvent。
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional from typing import List, Optional, Union
from core.managers.permission_manager import ADMIN, OP, USER from core.permission import Permission
from models.message import MessageSegment from models.message import MessageSegment
from models.sender import Sender from models.sender import Sender
from .base import OneBotEvent, EventType from .base import OneBotEvent, EventType
@@ -34,9 +34,9 @@ class MessageEvent(OneBotEvent):
""" """
# 权限级别常量,用于装饰器参数 # 权限级别常量,用于装饰器参数
ADMIN = ADMIN ADMIN = Permission.ADMIN
OP = OP OP = Permission.OP
USER = USER USER = Permission.USER
message_type: str message_type: str
"""消息类型: private (私聊), group (群聊)""" """消息类型: private (私聊), group (群聊)"""
@@ -70,7 +70,7 @@ class MessageEvent(OneBotEvent):
def post_type(self) -> str: def post_type(self) -> str:
return EventType.MESSAGE return EventType.MESSAGE
async def reply(self, message: str, auto_escape: bool = False): async def reply(self, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False):
""" """
回复消息(抽象方法,由子类实现) 回复消息(抽象方法,由子类实现)
@@ -86,7 +86,7 @@ class PrivateMessageEvent(MessageEvent):
私聊消息事件 私聊消息事件
""" """
async def reply(self, message: str, auto_escape: bool = False): async def reply(self, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False):
""" """
回复私聊消息 回复私聊消息
@@ -110,7 +110,7 @@ class GroupMessageEvent(MessageEvent):
anonymous: Optional[Anonymous] = None anonymous: Optional[Anonymous] = None
"""匿名信息""" """匿名信息"""
async def reply(self, message: str, auto_escape: bool = False): async def reply(self, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False):
""" """
回复群聊消息 回复群聊消息

View File

@@ -63,5 +63,5 @@ class LifeCycleEvent(MetaEvent):
meta_event_type: str = 'lifecycle' meta_event_type: str = 'lifecycle'
"""元事件类型:生命周期事件""" """元事件类型:生命周期事件"""
sub_type: LifeCycleSubType = LifeCycleSubType.ENABLE sub_type: str = LifeCycleSubType.ENABLE
"""子类型:启用、禁用、连接""" """子类型:启用、禁用、连接"""

View File

@@ -6,7 +6,7 @@
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict from typing import Any, Dict, Optional, List
@dataclass(slots=True) @dataclass(slots=True)
@@ -23,7 +23,7 @@ class MessageSegment:
data: Dict[str, Any] data: Dict[str, Any]
@property @property
def text(self) -> str: def plain_text(self) -> str:
""" """
当消息段类型为 'text' 时,快速获取其文本内容。 当消息段类型为 'text' 时,快速获取其文本内容。
@@ -32,6 +32,19 @@ class MessageSegment:
""" """
return self.data.get("text", "") if self.type == "text" else "" return self.data.get("text", "") if self.type == "text" else ""
@staticmethod
def text(text: str) -> "MessageSegment":
"""
创建一个文本消息段。
Args:
text (str): 文本内容。
Returns:
MessageSegment: 一个类型为 'text' 的消息段对象。
"""
return MessageSegment(type="text", data={"text": text})
@property @property
def image_url(self) -> str: def image_url(self) -> str:
""" """
@@ -76,7 +89,7 @@ class MessageSegment:
return self.data.get("file", "") return self.data.get("file", "")
return "" return ""
def is_at(self, user_id: int = None) -> bool: def is_at(self, user_id: Optional[int] = None) -> bool:
""" """
检查当前消息段是否是一个 'at' (提及) 消息段。 检查当前消息段是否是一个 'at' (提及) 消息段。
@@ -93,16 +106,52 @@ class MessageSegment:
return True return True
return str(self.data.get("qq")) == str(user_id) return str(self.data.get("qq")) == str(user_id)
def __str__(self):
"""
返回消息段的 CQ 码字符串表示。
"""
if self.type == "text":
return self.data.get("text", "")
params = ",".join([f"{k}={v}" for k, v in self.data.items()])
if params:
return f"[CQ:{self.type},{params}]"
return f"[CQ:{self.type}]"
def __repr__(self): def __repr__(self):
""" """
返回消息段对象的字符串表示形式,便于调试。 返回消息段对象的字符串表示形式,便于调试。
""" """
return f"[MS:{self.type}:{self.data}]" return f"[MS:{self.type}:{self.data}]"
def __add__(self, other: Any) -> "List[MessageSegment]":
"""
支持消息段相加,返回消息段列表。
"""
if isinstance(other, MessageSegment):
return [self, other]
elif isinstance(other, str):
return [self, MessageSegment.text(other)]
elif isinstance(other, list):
return [self] + other
return NotImplemented
def __radd__(self, other: Any) -> "List[MessageSegment]":
"""
支持反向相加。
"""
if isinstance(other, MessageSegment):
return [other, self]
elif isinstance(other, str):
return [MessageSegment.text(other), self]
elif isinstance(other, list):
return other + [self]
return NotImplemented
# --- 快捷构造方法 --- # --- 快捷构造方法 ---
@staticmethod @staticmethod
def text(text: str) -> "MessageSegment": # noqa: F811 def from_text(text: str) -> "MessageSegment":
""" """
创建一个文本消息段。 创建一个文本消息段。
@@ -115,7 +164,7 @@ class MessageSegment:
return MessageSegment(type="text", data={"text": text}) return MessageSegment(type="text", data={"text": text})
@staticmethod @staticmethod
def at(user_id: int | str, name: str = None) -> "MessageSegment": def at(user_id: int | str, name: Optional[str] = None) -> "MessageSegment":
""" """
创建一个 @某人 的消息段。 创建一个 @某人 的消息段。
@@ -132,7 +181,7 @@ class MessageSegment:
return MessageSegment(type="at", data=data) return MessageSegment(type="at", data=data)
@staticmethod @staticmethod
def image(file: str, image_type: str = None, cache: bool = True, proxy: bool = True, timeout: int = None, sub_type: int = None) -> "MessageSegment": def image(file: str, image_type: Optional[str] = None, cache: bool = True, proxy: bool = True, timeout: Optional[int] = None, sub_type: Optional[int] = None) -> "MessageSegment":
""" """
创建一个图片消息段。 创建一个图片消息段。
@@ -194,7 +243,7 @@ class MessageSegment:
""" """
return MessageSegment(type="xml", data={"data": data}) return MessageSegment(type="xml", data={"data": data})
@staticmethod @staticmethod
def share(url: str, title: str, content: str = None, image: str = None) -> "MessageSegment": def share(url: str, title: str, content: Optional[str] = None, image: Optional[str] = None) -> "MessageSegment":
""" """
创建一个分享消息段。 创建一个分享消息段。
@@ -227,7 +276,7 @@ class MessageSegment:
""" """
return MessageSegment(type="music", data={"type": type, "id": id}) return MessageSegment(type="music", data={"type": type, "id": id})
@staticmethod @staticmethod
def music_custom(url: str, audio: str, title: str, content: str = None, image: str = None) -> "MessageSegment": def music_custom(url: str, audio: str, title: str, content: Optional[str] = None, image: Optional[str] = None) -> "MessageSegment":
""" """
创建一个自定义音乐消息段。 创建一个自定义音乐消息段。
@@ -248,7 +297,7 @@ class MessageSegment:
data["image"] = image data["image"] = image
return MessageSegment(type="music", data={"type": "custom", **data}) return MessageSegment(type="music", data={"type": "custom", **data})
@staticmethod @staticmethod
def record(file: str, magic: bool = False, cache: bool = True, proxy: bool = True, timeout: int = None) -> "MessageSegment": def record(file: str, magic: bool = False, cache: bool = True, proxy: bool = True, timeout: Optional[int] = None) -> "MessageSegment":
""" """
创建一个语音消息段。 创建一个语音消息段。
@@ -267,7 +316,7 @@ class MessageSegment:
data["timeout"] = str(timeout) data["timeout"] = str(timeout)
return MessageSegment(type="record", data=data) return MessageSegment(type="record", data=data)
@staticmethod @staticmethod
def video(file: str, cover: str = None, c: int = 2) -> "MessageSegment": def video(file: str, cover: Optional[str] = None, c: int = 2) -> "MessageSegment":
""" """
创建一个视频消息段。 创建一个视频消息段。
@@ -297,17 +346,17 @@ class MessageSegment:
return MessageSegment(type="file", data={"file": file}) return MessageSegment(type="file", data={"file": file})
@staticmethod @staticmethod
def reply(message_id: str) -> "MessageSegment": def reply(message_id: str | int) -> "MessageSegment":
""" """
创建一个回复消息段。 创建一个回复消息段。
Args: Args:
message_id (str): 被回复的消息 ID。 message_id (str | int): 被回复的消息 ID。
Returns: Returns:
MessageSegment: 一个类型为 'reply' 的消息段对象。 MessageSegment: 一个类型为 'reply' 的消息段对象。
""" """
return MessageSegment(type="reply", data={"id": message_id}) return MessageSegment(type="reply", data={"id": str(message_id)})
@staticmethod @staticmethod
def rps() -> "MessageSegment": def rps() -> "MessageSegment":

0
plugins/__init__.py Normal file
View File

View File

@@ -1,74 +1,94 @@
""" from core.handlers.event_handler import MessageHandler
管理员管理插件 from core.managers import command_manager, permission_manager
from core.permission import Permission
提供通过聊天指令动态添加或移除机器人管理员的功能。
"""
from core.bot import Bot
from core.managers.command_manager import matcher
from core.managers.admin_manager import admin_manager
from models.events.message import MessageEvent from models.events.message import MessageEvent
# 更新插件元信息以包含OP管理
__plugin_meta__ = { __plugin_meta__ = {
"name": "管理员管理", "name": "权限管理",
"description": "管理机器人的全局管理员", "description": "管理机器人的管理员和操作",
"usage": ( "usage": (
"/admin list - 列出所有管理员\n" "/admin list - 列出所有管理员和操作员\n"
"/admin add <QQ号> - 添加管理员\n" "/admin add_admin <QQ号> - 添加管理员\n"
"/admin remove <QQ号> - 移除管理员" "/admin remove_admin <QQ号> - 移除管理员\n"
"/admin add_op <QQ号> - 添加操作员\n"
"/admin remove_op <QQ号> - 移除操作员"
), ),
} }
@matcher.command("admin", permission=MessageEvent.ADMIN) @command_manager.command("admin", permission=Permission.ADMIN)
async def admin_command_handler(bot: Bot, event: MessageEvent, args: list[str]): async def admin_management(event: MessageEvent, args: str):
""" """
处理 /admin 指令 处理所有权限管理相关的命令。
:param bot: Bot 实例
:param event: 消息事件实例
:param args: 指令参数列表
""" """
if not args: parts = args.split()
await event.reply(__plugin_meta__["usage"]) if not parts:
await event.reply(f"用法不正确。\n\n{__plugin_meta__['usage']}")
return return
action = args[0].lower() subcommand = parts[0].lower()
if action == "list": if subcommand == "list":
admins = await admin_manager.get_all_admins() await list_permissions(event)
if not admins:
await event.reply("当前没有设置任何管理员。")
return
admin_list_text = "\n".join(str(admin_id) for admin_id in admins)
await event.reply(f"当前管理员列表 ({len(admins)}):\n{admin_list_text}")
return return
if action in ("add", "remove"): # 处理需要QQ号的命令
if len(args) < 2 or not args[1].isdigit(): if len(parts) < 2 or not parts[1].isdigit():
await event.reply("参数错误,请提供一个有效的 QQ 号。\n示例: /admin add 123456") await event.reply(f"请提供有效的用户QQ号。\n用法: /admin {subcommand} <QQ号>")
return return
try: try:
user_id = int(args[1]) target_user_id = int(parts[1])
except ValueError: except ValueError:
await event.reply("无效的 QQ 号,请输入纯数字") await event.reply("无效的QQ")
return return
if action == "add": # 安全检查
success = await admin_manager.add_admin(user_id) if target_user_id == event.user_id:
if success: await event.reply("你不能操作自己的权限。")
await event.reply(f"成功添加管理员: {user_id}") return
else: if target_user_id == event.self_id:
await event.reply(f"管理员 {user_id} 已存在,无需重复添加") await event.reply("你不能操作机器人自身的权限")
return return
elif action == "remove":
success = await admin_manager.remove_admin(user_id)
if success:
await event.reply(f"成功移除管理员: {user_id}")
else:
await event.reply(f"管理员 {user_id} 不存在。")
return
await event.reply(f"未知的指令: {action}\n\n{__plugin_meta__['usage']}") # 根据子命令分发
if subcommand == "add_admin":
permission_manager.set_user_permission(target_user_id, Permission.ADMIN)
await event.reply(f"已成功添加管理员:{target_user_id}")
elif subcommand == "remove_admin":
permission_manager.set_user_permission(target_user_id, Permission.USER)
await event.reply(f"已成功移除管理员:{target_user_id}")
elif subcommand == "add_op":
permission_manager.set_user_permission(target_user_id, Permission.OP)
await event.reply(f"已成功添加操作员:{target_user_id}")
elif subcommand == "remove_op":
permission_manager.set_user_permission(target_user_id, Permission.USER)
await event.reply(f"已成功移除操作员:{target_user_id}")
else:
await event.reply(f"未知的子命令 '{subcommand}'\n\n{__plugin_meta__['usage']}")
async def list_permissions(event: MessageEvent):
"""
列出所有具有特殊权限(管理员和操作员)的用户。
"""
permissions = permission_manager.get_all_user_permissions()
if not permissions:
await event.reply("当前没有配置任何特殊权限的用户。")
return
admins = {uid for uid, p in permissions.items() if p == 'admin'}
ops = {uid for uid, p in permissions.items() if p == 'op'}
reply_msg = "当前权限列表:\n"
if admins:
reply_msg += "--- 管理员 ---\n"
for user_id in admins:
reply_msg += f"- {user_id}\n"
if ops:
reply_msg += "--- 操作员 ---\n"
for user_id in ops:
reply_msg += f"- {user_id}\n"
await event.reply(reply_msg.strip())

View File

@@ -3,12 +3,16 @@ import re
import json import json
import requests import requests
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from typing import Optional, Dict, Any from typing import Optional, Dict, Any, Union
from cachetools import TTLCache
from core.utils.logger import logger from core.utils.logger import logger
from core.managers.command_manager import matcher from core.managers.command_manager import matcher
from models import MessageEvent, MessageSegment from models import MessageEvent, MessageSegment
# 创建一个TTL缓存最大容量100缓存时间10秒
processed_messages: TTLCache[int, bool] = TTLCache(maxsize=100, ttl=10)
__plugin_meta__ = { __plugin_meta__ = {
"name": "bili_parser", "name": "bili_parser",
"description": "自动解析B站分享卡片提取视频封面和播放量等信息。", "description": "自动解析B站分享卡片提取视频封面和播放量等信息。",
@@ -52,10 +56,14 @@ def parse_video_info(video_url: str) -> Optional[Dict[str, Any]]:
soup = BeautifulSoup(response.text, 'html.parser') soup = BeautifulSoup(response.text, 'html.parser')
script_tag = soup.find('script', text=re.compile('window.__INITIAL_STATE__')) script_tag = soup.find('script', text=re.compile('window.__INITIAL_STATE__'))
if not script_tag: if not script_tag or not script_tag.string:
return None return None
json_str = re.search(r'window\.__INITIAL_STATE__\s*=\s*(\{.*?\});', script_tag.string).group(1) match = re.search(r'window\.__INITIAL_STATE__\s*=\s*(\{.*?\});', script_tag.string)
if not match:
return None
json_str = match.group(1)
data = json.loads(json_str) data = json.loads(json_str)
video_data = data.get('videoData', {}) video_data = data.get('videoData', {})
@@ -121,6 +129,15 @@ async def handle_bili_share(event: MessageEvent):
处理消息检测B站分享链接JSON卡片或文本链接并进行解析。 处理消息检测B站分享链接JSON卡片或文本链接并进行解析。
:param event: 消息事件对象 :param event: 消息事件对象
""" """
# 消息去重
if event.message_id in processed_messages:
return
processed_messages[event.message_id] = True
# 忽略机器人自己发送的消息,防止无限循环
if event.user_id == event.self_id:
return
url_to_process = None url_to_process = None
# 1. 优先解析JSON卡片中的短链接 # 1. 优先解析JSON卡片中的短链接
@@ -176,6 +193,7 @@ async def process_bili_link(event: MessageEvent, url: str):
return return
# 检查视频时长 # 检查视频时长
video_message: Union[str, MessageSegment]
if video_info['duration'] > 300: # 5分钟 = 300秒 if video_info['duration'] > 300: # 5分钟 = 300秒
video_message = "视频时长超过5分钟不进行解析。" video_message = "视频时长超过5分钟不进行解析。"
else: else:

View File

@@ -8,8 +8,8 @@
""" """
import asyncio import asyncio
from core.managers.command_manager import matcher from core.managers.command_manager import matcher
from models import MessageEvent, PrivateMessageEvent from models.events.message import MessageEvent, PrivateMessageEvent
from core.managers.permission_manager import ADMIN from core.permission import Permission
from core.utils.logger import logger from core.utils.logger import logger
# --- 会话状态管理 --- # --- 会话状态管理 ---
@@ -24,7 +24,7 @@ def cleanup_session(user_id: int):
del broadcast_sessions[user_id] del broadcast_sessions[user_id]
logger.info(f"[Broadcast] 会话 {user_id} 已超时,自动取消。") logger.info(f"[Broadcast] 会话 {user_id} 已超时,自动取消。")
@matcher.command("broadcast", "广播", permission=ADMIN) @matcher.command("broadcast", "广播", permission=Permission.ADMIN)
async def broadcast_start(event: MessageEvent): async def broadcast_start(event: MessageEvent):
""" """
广播指令的入口,启动一个等待用户消息的会话。 广播指令的入口,启动一个等待用户消息的会话。
@@ -92,7 +92,7 @@ async def handle_broadcast_content(event: MessageEvent):
nodes_to_send = [ nodes_to_send = [
bot.build_forward_node( bot.build_forward_node(
user_id=event.user_id, user_id=event.user_id,
nickname=event.sender.nickname, nickname=event.sender.nickname if event.sender else "未知用户",
message=message_to_broadcast message=message_to_broadcast
) )
] ]

View File

@@ -1,35 +1,24 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import html import html
import textwrap import textwrap
# -*- coding: utf-8 -*-
import html
import textwrap
import asyncio import asyncio
from typing import Dict from typing import Dict
from core.managers.command_manager import matcher from core.managers.command_manager import matcher
from models import MessageEvent from models.events.message import MessageEvent
from core.managers.permission_manager import ADMIN from core.permission import Permission
from core.utils.logger import logger from core.utils.logger import logger
__plugin_meta__ = { __plugin_meta__ = {
"name": "Python 代码执行", "name": "Python 代码执行",
"description": "在安全的沙箱环境中执行 Python 代码片段,支持单行、多行和转发回复。", "description": "在安全的沙箱环境中执行 Python 代码片段,支持单行、多行和转发回复。",
"usage": "/py <单行代码>\n/code_py <单行代码>\n/py (进入多行输入模式)", "usage": "/py <单行代码>\n/code_py <单行代码>\n/py (进入多行输入模式)",
"name": "Python 代码执行",
"description": "在安全的沙箱环境中执行 Python 代码片段,支持单行、多行和转发回复。",
"usage": "/py <单行代码>\n/code_py <单行代码>\n/py (进入多行输入模式)",
} }
# --- 会话状态管理 --- # --- 会话状态管理 ---
# 结构: {(user_id, group_id): asyncio.TimerHandle} # 结构: {(user_id, group_id): asyncio.TimerHandle}
multi_line_sessions: Dict[tuple, asyncio.TimerHandle] = {} multi_line_sessions: Dict[tuple, asyncio.TimerHandle] = {}
async def reply_as_forward(event: MessageEvent, input_code: str, output_result: str):
# --- 会话状态管理 ---
# 结构: {(user_id, group_id): asyncio.TimerHandle}
multi_line_sessions: Dict[tuple, asyncio.TimerHandle] = {}
async def reply_as_forward(event: MessageEvent, input_code: str, output_result: str): async def reply_as_forward(event: MessageEvent, input_code: str, output_result: str):
""" """
将输入和输出打包成转发消息进行回复。 将输入和输出打包成转发消息进行回复。
@@ -41,35 +30,7 @@ async def reply_as_forward(event: MessageEvent, input_code: str, output_result:
nodes = [ nodes = [
bot.build_forward_node( bot.build_forward_node(
user_id=event.user_id, user_id=event.user_id,
nickname=event.sender.nickname or str(event.user_id), nickname=event.sender.nickname if event.sender else str(event.user_id),
message=f"--- Your Code ---\n{input_code}"
),
bot.build_forward_node(
user_id=event.self_id,
nickname="Code Executor",
message=f"--- Execution Result ---\n{output_result}"
)
]
try:
# 2. 发送合并转发消息
await bot.send_forwarded_messages(event, nodes)
except Exception as e:
logger.error(f"[code_py] 发送转发消息失败: {e}")
# 降级为普通消息回复
await event.reply(f"--- 你的代码 ---\n{input_code}\n--- 执行结果 ---\n{output_result}")
async def execute_code(event: MessageEvent, code: str):
将输入和输出打包成转发消息进行回复
参考 forward_test.py 的实现兼容私聊和群聊
"""
bot = event.bot
# 1. 构建消息节点列表
nodes = [
bot.build_forward_node(
user_id=event.user_id,
nickname=event.sender.nickname or str(event.user_id),
message=f"--- Your Code ---\n{input_code}" message=f"--- Your Code ---\n{input_code}"
), ),
bot.build_forward_node( bot.build_forward_node(
@@ -90,7 +51,6 @@ async def execute_code(event: MessageEvent, code: str):
async def execute_code(event: MessageEvent, code: str): async def execute_code(event: MessageEvent, code: str):
""" """
核心代码执行逻辑。 核心代码执行逻辑。
核心代码执行逻辑
""" """
code_executor = getattr(event.bot, 'code_executor', None) code_executor = getattr(event.bot, 'code_executor', None)
if not code_executor or not code_executor.docker_client: if not code_executor or not code_executor.docker_client:
@@ -137,74 +97,15 @@ def normalize_code(code: str) -> str:
return code.strip() return code.strip()
@matcher.command("py", "python", "code_py", permission=ADMIN) @matcher.command("py", "python", "code_py", permission=Permission.ADMIN)
async def code_py_main(event: MessageEvent, args: list[str]):
code_executor = getattr(event.bot, 'code_executor', None)
if not code_executor or not code_executor.docker_client:
await event.reply("代码执行服务当前不可用,请检查 Docker 连接配置。")
return
# 修改 add_task让它能直接接收回复函数
await code_executor.add_task(
code,
lambda result: reply_as_forward(event, code, result)
)
await event.reply("代码已提交至沙箱执行队列,请稍候...")
def cleanup_session(session_key: tuple):
"""
清理超时的会话
"""
if session_key in multi_line_sessions:
del multi_line_sessions[session_key]
logger.info(f"[code_py] 会话 {session_key} 已超时,自动取消。")
def normalize_code(code: str) -> str:
"""
规范化用户输入的 Python 代码字符串
主要处理两个问题
1. 对消息中可能存在的 HTML 实体进行解码 (e.g., &#91; -> [)。
2. 移除整个代码块的公共前导缩进以修复因复制粘贴导致的多余缩进
:param code: 原始代码字符串
:return: 规范化后的代码字符串
"""
# 1. 解码 HTML 实体
code = html.unescape(code)
# 2. 移除公共前导缩进
try:
code = textwrap.dedent(code)
except Exception:
# 在某些情况下例如不一致的缩进dedent 可能会失败,
# 但我们不希望因此中断流程,所以捕获异常并继续。
pass
return code.strip()
@matcher.command("py", "python", "code_py", permission=ADMIN)
async def code_py_main(event: MessageEvent, args: list[str]): async def code_py_main(event: MessageEvent, args: list[str]):
""" """
/py 命令的主入口。 /py 命令的主入口。
- 如果有参数,直接执行。 - 如果有参数,直接执行。
- 如果没有参数,开启多行输入模式。 - 如果没有参数,开启多行输入模式。
/py 命令的主入口
- 如果有参数直接执行
- 如果没有参数开启多行输入模式
""" """
code_to_run = " ".join(args) code_to_run = " ".join(args)
if code_to_run:
# 单行模式,对代码进行规范化处理
normalized_code = normalize_code(code_to_run)
if not normalized_code:
await event.reply("代码为空或格式错误,请输入有效的代码。")
return
await execute_code(event, normalized_code)
code_to_run = " ".join(args)
if code_to_run: if code_to_run:
# 单行模式,对代码进行规范化处理 # 单行模式,对代码进行规范化处理
normalized_code = normalize_code(code_to_run) normalized_code = normalize_code(code_to_run)
@@ -231,24 +132,6 @@ async def code_py_main(event: MessageEvent, args: list[str]):
session_key session_key
) )
multi_line_sessions[session_key] = timeout_handler multi_line_sessions[session_key] = timeout_handler
# 多行模式
# 使用 getattr 兼容私聊和群聊
session_key = (event.user_id, getattr(event, 'group_id', 'private'))
# 如果上一个会话的超时任务还在,先取消它
if session_key in multi_line_sessions:
multi_line_sessions[session_key].cancel()
await event.reply("已进入多行代码输入模式,请直接发送你的代码。\n(60秒内无操作将自动取消)")
# 设置 60 秒超时
loop = asyncio.get_running_loop()
timeout_handler = loop.call_later(
60,
cleanup_session,
session_key
)
multi_line_sessions[session_key] = timeout_handler
@matcher.on_message() @matcher.on_message()
async def handle_multi_line_code(event: MessageEvent): async def handle_multi_line_code(event: MessageEvent):
@@ -265,26 +148,6 @@ async def handle_multi_line_code(event: MessageEvent):
# 对多行代码进行规范化处理 # 对多行代码进行规范化处理
normalized_code = normalize_code(event.raw_message) normalized_code = normalize_code(event.raw_message)
if not normalized_code:
await event.reply("捕获到的代码为空或格式错误,已取消输入。")
return
await execute_code(event, normalized_code)
return True # 消费事件,防止其他处理器响应
async def handle_multi_line_code(event: MessageEvent):
"""
通用消息处理器用于捕获多行模式下的代码输入
"""
# 使用 getattr 兼容私聊和群聊
session_key = (event.user_id, getattr(event, 'group_id', 'private'))
if session_key in multi_line_sessions:
# 取消超时任务
multi_line_sessions[session_key].cancel()
del multi_line_sessions[session_key]
# 对多行代码进行规范化处理
normalized_code = normalize_code(event.raw_message)
if not normalized_code: if not normalized_code:
await event.reply("捕获到的代码为空或格式错误,已取消输入。") await event.reply("捕获到的代码为空或格式错误,已取消输入。")
return return

View File

@@ -5,7 +5,7 @@ Echo 与交互插件
""" """
from core.managers.command_manager import matcher from core.managers.command_manager import matcher
from core.bot import Bot from core.bot import Bot
from models import MessageEvent from models.events.message import MessageEvent
__plugin_meta__ = { __plugin_meta__ = {
"name": "echo", "name": "echo",

View File

@@ -3,7 +3,7 @@
""" """
from core.managers.command_manager import matcher from core.managers.command_manager import matcher
from core.bot import Bot from core.bot import Bot
from models import MessageEvent from models.events.message import MessageEvent
from models.message import MessageSegment from models.message import MessageSegment
__plugin_meta__ = { __plugin_meta__ = {
@@ -22,14 +22,15 @@ async def handle_forward_test(bot: Bot, event: MessageEvent, args: list[str]):
:param args: 指令参数 :param args: 指令参数
""" """
# 1. 构建消息节点列表 # 1. 构建消息节点列表
nickname = event.sender.nickname if event.sender else "未知用户"
nodes = [ nodes = [
bot.build_forward_node(user_id=event.self_id, nickname="机器人", message="你要的furry来了"), bot.build_forward_node(user_id=event.self_id, nickname="机器人", message="你要的furry来了"),
bot.build_forward_node(user_id=event.user_id, nickname=event.sender.nickname, message="让我看看"), bot.build_forward_node(user_id=event.user_id, nickname=nickname, message="让我看看"),
bot.build_forward_node( bot.build_forward_node(
user_id=event.self_id, user_id=event.self_id,
nickname="机器人", nickname="机器人",
message=[ message=[
MessageSegment.text("你要的福瑞图"), MessageSegment.from_text("你要的福瑞图"),
MessageSegment.image("https://api.furry.ist/furry-img/") MessageSegment.image("https://api.furry.ist/furry-img/")
] ]
) )

View File

@@ -10,7 +10,7 @@ from datetime import datetime
from core.bot import Bot from core.bot import Bot
from core.managers.command_manager import matcher from core.managers.command_manager import matcher
from core.utils.executor import run_in_thread_pool from core.utils.executor import run_in_thread_pool
from models import MessageEvent, MessageSegment from models.events.message import MessageEvent, MessageSegment
__plugin_meta__ = { __plugin_meta__ = {
"name": "jrcd", "name": "jrcd",
@@ -79,14 +79,17 @@ async def handle_jrcd(bot: Bot, event: MessageEvent, args: list[str]):
""" """
user_id = event.user_id user_id = event.user_id
jrcd = await run_in_thread_pool(get_jrcd, user_id) jrcd = await run_in_thread_pool(get_jrcd, user_id)
msg = [MessageSegment.at(user_id)]
msg_text = ""
if jrcd <= 9: if jrcd <= 9:
msg.append(MessageSegment.text(random.choice(JRCDMSG_1) % jrcd)) msg_text = random.choice(JRCDMSG_1) % jrcd
elif jrcd <= 19: elif jrcd <= 19:
msg.append(MessageSegment.text(random.choice(JRCDMSG_2) % jrcd)) msg_text = random.choice(JRCDMSG_2) % jrcd
else: else:
msg.append(MessageSegment.text(random.choice(JRCDMSG_3) % jrcd)) msg_text = random.choice(JRCDMSG_3) % jrcd
await event.reply(msg)
reply_segments = [MessageSegment.at(user_id), MessageSegment.from_text(msg_text)]
await event.reply(reply_segments)
@matcher.command("bbcd") @matcher.command("bbcd")
@@ -118,29 +121,31 @@ async def handle_bbcd(bot: Bot, event: MessageEvent, args: list[str]):
jrcz = jrcd1 - jrcd2 jrcz = jrcd1 - jrcd2
msg = [ text_part = ""
MessageSegment.at(user_id1),
MessageSegment.text("你的长度比"),
MessageSegment.at(user_id2),
]
if jrcz == 0: if jrcz == 0:
msg.append(MessageSegment.text("一样长。")) text_part = f" 一样长。{random.choice(BBCDMSG7)}"
msg.append(MessageSegment.text(random.choice(BBCDMSG7)))
elif jrcz > 0: elif jrcz > 0:
msg.append(MessageSegment.text("" + str(jrcz) + "cm。")) text_part = f"{jrcz}cm。"
if jrcz <= 9: if jrcz <= 9:
msg.append(MessageSegment.text(random.choice(BBCDMSG1))) text_part += random.choice(BBCDMSG1)
elif jrcz <= 19: elif jrcz <= 19:
msg.append(MessageSegment.text(random.choice(BBCDMSG2))) text_part += random.choice(BBCDMSG2)
else: else:
msg.append(MessageSegment.text(random.choice(BBCDMSG3))) text_part += random.choice(BBCDMSG3)
elif jrcz < 0: else: # jrcz < 0
msg.append(MessageSegment.text("" + str(abs(jrcz)) + "cm。")) text_part = f"{abs(jrcz)}cm。"
if jrcz >= -9: if jrcz >= -9:
msg.append(MessageSegment.text(random.choice(BBCDMSG4))) text_part += random.choice(BBCDMSG4)
elif jrcz >= -19: elif jrcz >= -19:
msg.append(MessageSegment.text(random.choice(BBCDMSG5))) text_part += random.choice(BBCDMSG5)
else: else:
msg.append(MessageSegment.text(random.choice(BBCDMSG6))) text_part += random.choice(BBCDMSG6)
await event.reply(msg)
segments = [
MessageSegment.at(user_id1),
MessageSegment.from_text(" 你的长度比 "),
MessageSegment.at(user_id2),
MessageSegment.from_text(text_part),
]
await event.reply(segments)

View File

@@ -7,7 +7,7 @@ thpic 插件
from core.bot import Bot from core.bot import Bot
from core.managers.command_manager import matcher from core.managers.command_manager import matcher
from models import MessageEvent, MessageSegment from models.events.message import MessageEvent, MessageSegment
__plugin_meta__ = { __plugin_meta__ = {
"name": "thpic", "name": "thpic",
@@ -26,6 +26,6 @@ async def handle_echo(bot: Bot, event: MessageEvent, args: list[str]):
:param args: 指令参数列表(未使用)。 :param args: 指令参数列表(未使用)。
""" """
try: try:
await event.reply(MessageSegment.image("https://img.paulzzh.com/touhou/random")) await event.reply(str(MessageSegment.image("https://img.paulzzh.com/touhou/random")))
except Exception as e: except Exception as e:
await event.reply("报错了。。。" + e) await event.reply(f"报错了。。。{e}")

View File

@@ -11,7 +11,6 @@ pipreqs==0.4.13
redis==5.0.7 redis==5.0.7
requests==2.32.5 requests==2.32.5
soupsieve==2.8.1 soupsieve==2.8.1
toml==0.10.2
typing==3.7.4.3 typing==3.7.4.3
typing_extensions==4.15.0 typing_extensions==4.15.0
urllib3==2.6.2 urllib3==2.6.2
@@ -19,7 +18,15 @@ watchdog==6.0.0
websockets==15.0.1 websockets==15.0.1
win32_setctime==1.2.0 win32_setctime==1.2.0
yarg==0.1.10 yarg==0.1.10
cachetools
pydantic
docker docker
pytest pytest
pytest-asyncio pytest-asyncio
pytest-mock pytest-mock
pytest-cov
httpx==0.27.0
# Dev Dependencies
mypy
pydantic[mypy]

37
tests/test_basic.py Normal file
View File

@@ -0,0 +1,37 @@
import pytest
import sys
import os
# 确保项目根目录在 sys.path 中
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
def test_import_core():
"""测试核心模块是否可以被导入"""
try:
import core
import core.bot
import core.ws
except ImportError as e:
pytest.fail(f"无法导入核心模块: {e}")
def test_plugin_manager_path():
"""测试插件管理器路径逻辑是否正确"""
from core.managers.plugin_manager import PluginManager
# Mock command manager
pm = PluginManager(None)
# 我们无法直接测试 load_all_plugins 的内部路径变量,
# 但我们可以检查它是否能找到 plugins 目录而不报错
# 这里我们简单地断言 PluginManager 类存在且可以实例化
assert pm is not None
def test_config_loader_exists():
"""测试配置加载器是否存在"""
# 注意:导入 config_loader 会尝试读取 config.toml
# 如果 config.toml 不存在,这可能会失败。
# 这是一个已知的设计问题,但在测试环境中我们假设 config.toml 存在或被 mock
if os.path.exists("config.toml"):
from core.config_loader import global_config
assert global_config is not None
else:
pytest.skip("config.toml 不存在,跳过配置加载测试")

View File

@@ -0,0 +1,114 @@
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from core.managers.command_manager import CommandManager
from models.events.message import GroupMessageEvent
from models.message import MessageSegment
@pytest.fixture
def mock_bot():
bot = AsyncMock()
bot.self_id = 123456
return bot
@pytest.fixture
def command_manager():
# 创建一个新的 CommandManager 实例用于测试,避免单例状态污染
return CommandManager(prefixes=("/",))
@pytest.mark.asyncio
async def test_command_registration_and_execution(command_manager, mock_bot):
"""测试命令注册和执行"""
# 定义一个命令处理函数
handler_mock = AsyncMock()
# 注册命令
@command_manager.command("test")
async def test_command(bot, event):
await handler_mock(bot, event)
# 构造触发命令的事件
event = MagicMock(spec=GroupMessageEvent)
event.post_type = "message"
event.message_type = "group"
event.raw_message = "/test"
event.message = [MessageSegment.text("/test")]
event.user_id = 111
event.group_id = 222
# 处理事件
await command_manager.handle_event(mock_bot, event)
# 验证处理函数被调用
handler_mock.assert_called_once_with(mock_bot, event)
@pytest.mark.asyncio
async def test_command_prefix_match(command_manager, mock_bot):
"""测试命令前缀匹配"""
handler_mock = AsyncMock()
@command_manager.command("hello")
async def hello_command(bot, event):
await handler_mock(bot, event)
# 1. 正确的前缀
event1 = MagicMock(spec=GroupMessageEvent)
event1.post_type = "message"
event1.raw_message = "/hello"
event1.message = [MessageSegment.text("/hello")]
await command_manager.handle_event(mock_bot, event1)
handler_mock.assert_called_once()
handler_mock.reset_mock()
# 2. 错误的前缀 (应该忽略)
event2 = MagicMock(spec=GroupMessageEvent)
event2.post_type = "message"
event2.raw_message = ".hello" # 假设前缀是 /
event2.message = [MessageSegment.text(".hello")]
await command_manager.handle_event(mock_bot, event2)
handler_mock.assert_not_called()
@pytest.mark.asyncio
async def test_ignore_self_message(command_manager, mock_bot):
"""测试忽略自身消息"""
# 模拟配置
with patch("core.managers.command_manager.global_config") as mock_config:
mock_config.bot.ignore_self_message = True
event = MagicMock(spec=GroupMessageEvent)
event.post_type = "message"
event.user_id = 123456 # 与 bot.self_id 相同
event.self_id = 123456
# Mock handle 方法来检测是否被调用
command_manager.message_handler.handle = AsyncMock()
await command_manager.handle_event(mock_bot, event)
# 应该直接返回,不调用 handler
command_manager.message_handler.handle.assert_not_called()
@pytest.mark.asyncio
async def test_help_command(command_manager, mock_bot):
"""测试内置 help 命令"""
# 注册一个测试插件信息
command_manager.plugins["test_plugin"] = {
"name": "测试插件",
"description": "这是一个测试",
"usage": "/test"
}
event = MagicMock(spec=GroupMessageEvent)
event.post_type = "message"
event.raw_message = "/help"
event.message = [MessageSegment.text("/help")]
await command_manager.handle_event(mock_bot, event)
# 验证 bot.send 被调用,且内容包含插件信息
mock_bot.send.assert_called_once()
args, _ = mock_bot.send.call_args
sent_msg = args[1]
assert "测试插件" in sent_msg
assert "这是一个测试" in sent_msg

141
tests/test_event_factory.py Normal file
View File

@@ -0,0 +1,141 @@
import pytest
from models.events.factory import EventFactory, EventType
from models.events.message import GroupMessageEvent, PrivateMessageEvent
from models.events.notice import GroupIncreaseNoticeEvent
from models.events.request import FriendRequestEvent
from models.events.meta import HeartbeatEvent
from models.message import MessageSegment
class TestEventFactory:
def test_create_group_message_event_list(self):
"""测试创建群消息事件 (message 为列表格式)"""
data = {
"post_type": "message",
"message_type": "group",
"time": 1600000000,
"self_id": 123456,
"sub_type": "normal",
"message_id": 1001,
"user_id": 111111,
"group_id": 222222,
"message": [
{"type": "text", "data": {"text": "Hello"}}
],
"raw_message": "Hello",
"font": 0,
"sender": {
"user_id": 111111,
"nickname": "User",
"role": "member"
}
}
event = EventFactory.create_event(data)
assert isinstance(event, GroupMessageEvent)
assert event.group_id == 222222
assert len(event.message) == 1
assert event.message[0].type == "text"
assert event.message[0].data["text"] == "Hello"
def test_create_group_message_event_str(self):
"""测试创建群消息事件 (message 为字符串格式)"""
data = {
"post_type": "message",
"message_type": "group",
"time": 1600000000,
"self_id": 123456,
"sub_type": "normal",
"message_id": 1002,
"user_id": 111111,
"group_id": 222222,
"message": "Hello World",
"raw_message": "Hello World",
"font": 0,
"sender": {
"user_id": 111111,
"nickname": "User"
}
}
event = EventFactory.create_event(data)
assert isinstance(event, GroupMessageEvent)
assert len(event.message) == 1
assert event.message[0].type == "text"
assert event.message[0].data["text"] == "Hello World"
def test_create_private_message_event(self):
"""测试创建私聊消息事件"""
data = {
"post_type": "message",
"message_type": "private",
"time": 1600000000,
"self_id": 123456,
"sub_type": "friend",
"message_id": 2001,
"user_id": 333333,
"message": "Private Msg",
"raw_message": "Private Msg",
"font": 0,
"sender": {
"user_id": 333333,
"nickname": "Friend"
}
}
event = EventFactory.create_event(data)
assert isinstance(event, PrivateMessageEvent)
assert event.user_id == 333333
def test_create_notice_event(self):
"""测试创建通知事件 (群成员增加)"""
data = {
"post_type": "notice",
"notice_type": "group_increase",
"sub_type": "approve",
"group_id": 222222,
"operator_id": 444444,
"user_id": 555555,
"time": 1600000000,
"self_id": 123456
}
event = EventFactory.create_event(data)
assert isinstance(event, GroupIncreaseNoticeEvent)
assert event.group_id == 222222
assert event.user_id == 555555
def test_create_request_event(self):
"""测试创建请求事件 (加好友)"""
data = {
"post_type": "request",
"request_type": "friend",
"user_id": 666666,
"comment": "Add me",
"flag": "flag_123",
"time": 1600000000,
"self_id": 123456
}
event = EventFactory.create_event(data)
assert isinstance(event, FriendRequestEvent)
assert event.user_id == 666666
assert event.comment == "Add me"
def test_create_meta_event(self):
"""测试创建元事件 (心跳)"""
data = {
"post_type": "meta_event",
"meta_event_type": "heartbeat",
"time": 1600000000,
"self_id": 123456,
"status": {"online": True, "good": True},
"interval": 5000
}
event = EventFactory.create_event(data)
assert isinstance(event, HeartbeatEvent)
assert event.interval == 5000
def test_unknown_event_type(self):
"""测试未知事件类型"""
data = {
"post_type": "unknown_type",
"time": 1600000000,
"self_id": 123456
}
with pytest.raises(ValueError, match="Unknown event type"):
EventFactory.create_event(data)

194
tests/test_event_handler.py Normal file
View File

@@ -0,0 +1,194 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from core.handlers.event_handler import MessageHandler, NoticeHandler, RequestHandler
from models.events.message import GroupMessageEvent
from models.events.notice import GroupIncreaseNoticeEvent
from models.events.request import FriendRequestEvent
@pytest.fixture
def mock_bot():
bot = AsyncMock()
return bot
@pytest.mark.asyncio
async def test_message_handler_run_handler_injection(mock_bot):
"""测试参数注入"""
handler = MessageHandler(prefixes=("/",))
# 1. 测试注入 bot 和 event
async def func1(bot, event):
assert bot == mock_bot
assert event.user_id == 123
return True
event = MagicMock(spec=GroupMessageEvent)
event.user_id = 123
result = await handler._run_handler(func1, mock_bot, event)
assert result is True
# 2. 测试注入 args
async def func2(args):
assert args == ["arg1", "arg2"]
return True
result = await handler._run_handler(func2, mock_bot, event, args=["arg1", "arg2"])
assert result is True
@pytest.mark.asyncio
async def test_message_handler_command_parsing(mock_bot):
"""测试命令解析"""
handler = MessageHandler(prefixes=("/",))
mock_func = AsyncMock()
handler.commands["test"] = {
"func": mock_func,
"permission": None,
"override_permission_check": False,
"plugin_name": "test_plugin"
}
event = MagicMock(spec=GroupMessageEvent)
event.raw_message = "/test arg1 arg2"
event.user_id = 123
# Mock permission manager
with patch("core.managers.permission_manager.PermissionManager.check_permission", new_callable=AsyncMock) as mock_perm:
mock_perm.return_value = True
await handler.handle(mock_bot, event)
mock_func.assert_called_once()
# 验证 args 参数是否正确传递
call_args = mock_func.call_args
if "args" in call_args.kwargs:
assert call_args.kwargs["args"] == ["arg1", "arg2"]
@pytest.mark.asyncio
async def test_notice_handler(mock_bot):
"""测试通知事件分发"""
handler = NoticeHandler()
mock_func = AsyncMock()
handler.handlers.append({
"type": "group_increase",
"func": mock_func,
"plugin_name": "test_plugin"
})
event = MagicMock(spec=GroupIncreaseNoticeEvent)
event.notice_type = "group_increase"
await handler.handle(mock_bot, event)
mock_func.assert_called_once()
@pytest.mark.asyncio
async def test_sync_handler_execution(mock_bot):
"""测试同步处理函数的执行"""
handler = MessageHandler(prefixes=("/",))
def sync_func(event):
return True
event = MagicMock(spec=GroupMessageEvent)
# 同步函数应该在线程池中运行
result = await handler._run_handler(sync_func, mock_bot, event)
assert result is True
@pytest.mark.asyncio
async def test_message_handler_management(mock_bot):
"""测试消息处理器的管理(注册、卸载、清空)"""
handler = MessageHandler(prefixes=("/",))
# 测试 on_message 装饰器
@handler.on_message()
async def msg_handler(event):
pass
assert len(handler.message_handlers) == 1
# 测试 command 装饰器
@handler.command("cmd1", "cmd2")
async def cmd_handler(event):
pass
assert len(handler.commands) == 2
assert "cmd1" in handler.commands
assert "cmd2" in handler.commands
# 测试 unregister_by_plugin_name
# 直接从已注册的处理器中获取 plugin_name
if handler.message_handlers:
plugin_name = handler.message_handlers[0]["plugin_name"]
handler.unregister_by_plugin_name(plugin_name)
assert len(handler.message_handlers) == 0
assert len(handler.commands) == 0
# 测试 clear
handler.commands["cmd"] = {}
handler.message_handlers.append({})
handler.clear()
assert len(handler.commands) == 0
assert len(handler.message_handlers) == 0
@pytest.mark.asyncio
async def test_request_handler(mock_bot):
"""测试请求事件处理器"""
handler = RequestHandler()
mock_func = AsyncMock()
# 测试 register 装饰器
@handler.register("friend")
async def req_handler(event):
await mock_func(event)
assert len(handler.handlers) == 1
event = MagicMock(spec=FriendRequestEvent)
event.request_type = "friend"
await handler.handle(mock_bot, event)
mock_func.assert_called_once()
# 测试 unregister 和 clear
import inspect
module = inspect.getmodule(req_handler)
plugin_name = module.__name__
handler.unregister_by_plugin_name(plugin_name)
assert len(handler.handlers) == 0
handler.handlers.append({})
handler.clear()
assert len(handler.handlers) == 0
@pytest.mark.asyncio
async def test_permission_denied(mock_bot):
"""测试权限不足的情况"""
handler = MessageHandler(prefixes=("/",))
mock_func = AsyncMock()
handler.commands["admin_cmd"] = {
"func": mock_func,
"permission": "ADMIN", # 假设 Permission.ADMIN
"override_permission_check": False,
"plugin_name": "test_plugin"
}
event = MagicMock(spec=GroupMessageEvent)
event.raw_message = "/admin_cmd"
event.user_id = 123
# Mock permission manager returning False
with patch("core.managers.permission_manager.PermissionManager.check_permission", new_callable=AsyncMock) as mock_perm:
mock_perm.return_value = False
await handler.handle(mock_bot, event)
mock_func.assert_not_called()
# 应该发送拒绝消息
mock_bot.send.assert_called_once()

75
tests/test_models.py Normal file
View File

@@ -0,0 +1,75 @@
import pytest
from models.message import MessageSegment
from models.objects import GroupInfo, StrangerInfo
class TestMessageSegment:
def test_text_segment(self):
seg = MessageSegment.text("Hello")
assert seg.type == "text"
assert seg.data["text"] == "Hello"
assert str(seg) == "Hello"
def test_at_segment(self):
seg = MessageSegment.at(123456)
assert seg.type == "at"
assert seg.data["qq"] == "123456"
assert str(seg) == "[CQ:at,qq=123456]"
def test_image_segment(self):
seg = MessageSegment.image("http://example.com/img.jpg", cache=False, proxy=False)
assert seg.type == "image"
assert seg.data["file"] == "http://example.com/img.jpg"
assert str(seg) == "[CQ:image,file=http://example.com/img.jpg,cache=0,proxy=0]"
def test_face_segment(self):
seg = MessageSegment.face(123)
assert seg.type == "face"
assert seg.data["id"] == "123"
assert str(seg) == "[CQ:face,id=123]"
def test_reply_segment(self):
seg = MessageSegment.reply(1001)
assert seg.type == "reply"
assert seg.data["id"] == "1001"
assert str(seg) == "[CQ:reply,id=1001]"
def test_add_segments(self):
seg1 = MessageSegment.text("Hello ")
seg2 = MessageSegment.at(123)
combined = seg1 + seg2
assert isinstance(combined, list)
assert len(combined) == 2
assert combined[0] == seg1
assert combined[1] == seg2
def test_add_segment_and_string(self):
seg = MessageSegment.at(123)
combined = seg + " Hello"
assert isinstance(combined, list)
assert len(combined) == 2
assert combined[0] == seg
assert combined[1].type == "text"
assert combined[1].data["text"] == " Hello"
class TestObjects:
def test_group_info(self):
data = {
"group_id": 123456,
"group_name": "Test Group",
"member_count": 10,
"max_member_count": 100
}
group = GroupInfo(**data)
assert group.group_id == 123456
assert group.group_name == "Test Group"
def test_stranger_info(self):
data = {
"user_id": 111111,
"nickname": "Stranger",
"sex": "male",
"age": 18
}
user = StrangerInfo(**data)
assert user.user_id == 111111
assert user.nickname == "Stranger"