feat: 添加测试用例并优化代码结构
refactor(permission_manager): 调整初始化顺序和逻辑 fix(admin_manager): 修复初始化逻辑和目录创建问题 feat(ws): 优化Bot实例初始化条件 feat(message): 增强MessageSegment功能并添加测试 feat(events): 支持字符串格式的消息解析 test: 添加核心功能测试用例 refactor(plugin_manager): 改进插件路径处理 style: 清理无用导入和代码 chore: 更新依赖项
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
{
|
||||
"admins": []
|
||||
"admins": [2221577113]
|
||||
}
|
||||
@@ -6,11 +6,12 @@
|
||||
"""
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from ..bot import Bot
|
||||
if TYPE_CHECKING:
|
||||
from ..bot import Bot
|
||||
from ..config_loader import global_config
|
||||
from ..managers.permission_manager import Permission
|
||||
from ..permission import Permission
|
||||
from ..utils.executor import run_in_thread_pool
|
||||
|
||||
|
||||
@@ -22,7 +23,7 @@ class BaseHandler(ABC):
|
||||
self.handlers: List[Dict[str, Any]] = []
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, bot: Bot, event: Any):
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理事件
|
||||
"""
|
||||
@@ -31,7 +32,7 @@ class BaseHandler(ABC):
|
||||
async def _run_handler(
|
||||
self,
|
||||
func: Callable,
|
||||
bot: Bot,
|
||||
bot: "Bot",
|
||||
event: Any,
|
||||
args: Optional[List[str]] = None,
|
||||
permission_granted: Optional[bool] = None
|
||||
@@ -122,7 +123,7 @@ class MessageHandler(BaseHandler):
|
||||
return func
|
||||
return decorator
|
||||
|
||||
async def handle(self, bot: Bot, event: Any):
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理消息事件,分发给命令处理器或通用消息处理器
|
||||
"""
|
||||
|
||||
@@ -26,8 +26,7 @@ class AdminManager(Singleton):
|
||||
"""
|
||||
初始化 AdminManager
|
||||
"""
|
||||
super().__init__()
|
||||
if not self._initialized:
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
|
||||
# 管理员数据文件路径
|
||||
@@ -39,7 +38,12 @@ class AdminManager(Singleton):
|
||||
)
|
||||
|
||||
self._admins: Set[int] = set()
|
||||
|
||||
# 确保数据目录存在
|
||||
os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
|
||||
|
||||
logger.info("管理员管理器初始化完成")
|
||||
super().__init__()
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
|
||||
@@ -41,7 +41,6 @@ class PermissionManager(Singleton):
|
||||
|
||||
如果已经初始化过,则直接返回。
|
||||
"""
|
||||
super().__init__()
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
|
||||
@@ -64,7 +63,7 @@ class PermissionManager(Singleton):
|
||||
self.load()
|
||||
|
||||
logger.info("权限管理器初始化完成")
|
||||
self._initialized = True
|
||||
super().__init__()
|
||||
|
||||
def load(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -30,12 +30,21 @@ class PluginManager:
|
||||
"""
|
||||
扫描并加载 `plugins` 目录下的所有插件。
|
||||
"""
|
||||
plugin_dir = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "..", "..", "plugins"
|
||||
)
|
||||
# 使用 pathlib 获取更可靠的路径
|
||||
# 当前文件: core/managers/plugin_manager.py
|
||||
# 目标: plugins/
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# 回退两级到项目根目录 (core/managers -> core -> root)
|
||||
root_dir = os.path.dirname(os.path.dirname(current_dir))
|
||||
plugin_dir = os.path.join(root_dir, "plugins")
|
||||
|
||||
package_name = "plugins"
|
||||
|
||||
logger.info(f"正在从 {package_name} 加载插件...")
|
||||
if not os.path.exists(plugin_dir):
|
||||
logger.error(f"插件目录不存在: {plugin_dir}")
|
||||
return
|
||||
|
||||
logger.info(f"正在从 {package_name} 加载插件 (路径: {plugin_dir})...")
|
||||
|
||||
for _, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
|
||||
full_module_name = f"{package_name}.{module_name}"
|
||||
|
||||
@@ -33,13 +33,10 @@ class Permission(Enum):
|
||||
return NotImplemented
|
||||
return self._level_map[self] < self._level_map[other]
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self is other
|
||||
|
||||
def __ge__(self, other):
|
||||
"""
|
||||
比较当前权限是否大于等于另一个权限。
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self._level_map[self] >= self._level_map[other]
|
||||
|
||||
|
||||
@@ -60,7 +60,15 @@ class CodeExecutor:
|
||||
将代码执行任务添加到队列中。
|
||||
:param code: 待执行的 Python 代码字符串。
|
||||
:param callback: 执行完毕后用于回复结果的回调函数。
|
||||
:raises RuntimeError: 如果 Docker 客户端未初始化。
|
||||
"""
|
||||
if not self.docker_client:
|
||||
logger.warning("[CodeExecutor] 尝试添加任务,但 Docker 客户端未初始化。任务被拒绝。")
|
||||
# 这里可以选择抛出异常,或者直接调用回调返回错误信息
|
||||
# 为了用户体验,我们构造一个错误结果并直接调用回调(如果可能)
|
||||
# 但由于 callback 返回 Future,这里简单起见,我们记录日志并抛出异常
|
||||
raise RuntimeError("Docker环境未就绪,无法执行代码。")
|
||||
|
||||
task = {"code": code, "callback": callback}
|
||||
await self.task_queue.put(task)
|
||||
logger.info(f"[CodeExecutor] 新的代码执行任务已入队 (队列当前长度: {self.task_queue.qsize()})。")
|
||||
|
||||
@@ -128,8 +128,9 @@ class WS:
|
||||
# 使用工厂创建事件对象
|
||||
event = EventFactory.create_event(event_data)
|
||||
|
||||
# 在收到第一个 meta_event 时,初始化 Bot 实例
|
||||
if event.post_type == "meta_event" and self.bot is None:
|
||||
# 尝试初始化 Bot 实例 (如果尚未初始化且事件包含 self_id)
|
||||
# 只要事件中包含 self_id,我们就可以初始化 Bot,不必非要等待 meta_event
|
||||
if self.bot is None and hasattr(event, 'self_id'):
|
||||
self.self_id = event.self_id
|
||||
self.bot = Bot(self)
|
||||
logger.success(f"Bot 实例初始化完成: self_id={self.self_id}")
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Models 包
|
||||
|
||||
导出常用的模型类,方便插件导入。
|
||||
"""
|
||||
|
||||
from .events.base import OneBotEvent
|
||||
from .events.message import MessageEvent, GroupMessageEvent, PrivateMessageEvent
|
||||
from .events.notice import NoticeEvent
|
||||
from .events.request import RequestEvent
|
||||
from .message import MessageSegment
|
||||
from .sender import Sender
|
||||
|
||||
__all__ = [
|
||||
"OneBotEvent",
|
||||
"MessageEvent",
|
||||
"GroupMessageEvent",
|
||||
"PrivateMessageEvent",
|
||||
"NoticeEvent",
|
||||
"RequestEvent",
|
||||
"MessageSegment",
|
||||
"Sender",
|
||||
]
|
||||
|
||||
@@ -70,7 +70,11 @@ class EventFactory:
|
||||
# 解析消息段
|
||||
message_list = []
|
||||
raw_message_list = data.get("message", [])
|
||||
if isinstance(raw_message_list, list):
|
||||
|
||||
if isinstance(raw_message_list, str):
|
||||
# 如果消息是字符串,将其视为纯文本消息段
|
||||
message_list.append(MessageSegment.text(raw_message_list))
|
||||
elif isinstance(raw_message_list, list):
|
||||
for item in raw_message_list:
|
||||
if isinstance(item, dict):
|
||||
message_list.append(MessageSegment(type=item.get("type", ""), data=item.get("data", {})))
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, List
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -23,7 +23,7 @@ class MessageSegment:
|
||||
data: Dict[str, Any]
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
def plain_text(self) -> str:
|
||||
"""
|
||||
当消息段类型为 'text' 时,快速获取其文本内容。
|
||||
|
||||
@@ -32,6 +32,19 @@ class MessageSegment:
|
||||
"""
|
||||
return self.data.get("text", "") if self.type == "text" else ""
|
||||
|
||||
@staticmethod
|
||||
def text(text: str) -> "MessageSegment":
|
||||
"""
|
||||
创建一个文本消息段。
|
||||
|
||||
Args:
|
||||
text (str): 文本内容。
|
||||
|
||||
Returns:
|
||||
MessageSegment: 一个类型为 'text' 的消息段对象。
|
||||
"""
|
||||
return MessageSegment(type="text", data={"text": text})
|
||||
|
||||
@property
|
||||
def image_url(self) -> str:
|
||||
"""
|
||||
@@ -93,12 +106,48 @@ class MessageSegment:
|
||||
return True
|
||||
return str(self.data.get("qq")) == str(user_id)
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
返回消息段的 CQ 码字符串表示。
|
||||
"""
|
||||
if self.type == "text":
|
||||
return self.data.get("text", "")
|
||||
|
||||
params = ",".join([f"{k}={v}" for k, v in self.data.items()])
|
||||
if params:
|
||||
return f"[CQ:{self.type},{params}]"
|
||||
return f"[CQ:{self.type}]"
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
返回消息段对象的字符串表示形式,便于调试。
|
||||
"""
|
||||
return f"[MS:{self.type}:{self.data}]"
|
||||
|
||||
def __add__(self, other: Any) -> "List[MessageSegment]":
|
||||
"""
|
||||
支持消息段相加,返回消息段列表。
|
||||
"""
|
||||
if isinstance(other, MessageSegment):
|
||||
return [self, other]
|
||||
elif isinstance(other, str):
|
||||
return [self, MessageSegment.text(other)]
|
||||
elif isinstance(other, list):
|
||||
return [self] + other
|
||||
return NotImplemented
|
||||
|
||||
def __radd__(self, other: Any) -> "List[MessageSegment]":
|
||||
"""
|
||||
支持反向相加。
|
||||
"""
|
||||
if isinstance(other, MessageSegment):
|
||||
return [other, self]
|
||||
elif isinstance(other, str):
|
||||
return [MessageSegment.text(other), self]
|
||||
elif isinstance(other, list):
|
||||
return other + [self]
|
||||
return NotImplemented
|
||||
|
||||
# --- 快捷构造方法 ---
|
||||
|
||||
@staticmethod
|
||||
@@ -297,17 +346,17 @@ class MessageSegment:
|
||||
return MessageSegment(type="file", data={"file": file})
|
||||
|
||||
@staticmethod
|
||||
def reply(message_id: str) -> "MessageSegment":
|
||||
def reply(message_id: str | int) -> "MessageSegment":
|
||||
"""
|
||||
创建一个回复消息段。
|
||||
|
||||
Args:
|
||||
message_id (str): 被回复的消息 ID。
|
||||
message_id (str | int): 被回复的消息 ID。
|
||||
|
||||
Returns:
|
||||
MessageSegment: 一个类型为 'reply' 的消息段对象。
|
||||
"""
|
||||
return MessageSegment(type="reply", data={"id": message_id})
|
||||
return MessageSegment(type="reply", data={"id": str(message_id)})
|
||||
|
||||
@staticmethod
|
||||
def rps() -> "MessageSegment":
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
import json
|
||||
import httpx
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, Union
|
||||
from cachetools import TTLCache
|
||||
|
||||
from core.utils.logger import logger
|
||||
from core.managers.command_manager import matcher
|
||||
from models.events.message import MessageEvent, MessageSegment
|
||||
from models import MessageEvent, MessageSegment
|
||||
|
||||
# 创建一个TTL缓存,最大容量100,缓存时间60秒
|
||||
processed_messages: TTLCache[Any, bool] = TTLCache(maxsize=100, ttl=60)
|
||||
# 创建一个TTL缓存,最大容量100,缓存时间10秒
|
||||
processed_messages: TTLCache[int, bool] = TTLCache(maxsize=100, ttl=10)
|
||||
|
||||
__plugin_meta__ = {
|
||||
"name": "bili_parser",
|
||||
@@ -23,9 +23,6 @@ HEADERS = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
|
||||
}
|
||||
|
||||
# 创建可复用的异步HTTP客户端
|
||||
async_client = httpx.AsyncClient(headers=HEADERS, follow_redirects=False, timeout=10)
|
||||
|
||||
|
||||
def format_count(num: int) -> str:
|
||||
if not isinstance(num, int):
|
||||
@@ -43,32 +40,29 @@ def format_duration(seconds: int) -> str:
|
||||
return f"{minutes:02d}:{seconds:02d}"
|
||||
|
||||
|
||||
async def get_real_url(short_url: str) -> Optional[str]:
|
||||
def get_real_url(short_url: str) -> Optional[str]:
|
||||
try:
|
||||
response = await async_client.head(short_url)
|
||||
response = requests.head(short_url, headers=HEADERS, allow_redirects=False, timeout=5)
|
||||
if response.status_code == 302:
|
||||
return response.headers.get('Location')
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"获取真实URL失败: {e}")
|
||||
except requests.RequestException as e:
|
||||
print(f"获取真实URL失败: {e}")
|
||||
return None
|
||||
|
||||
async def parse_video_info(video_url: str) -> Optional[Dict[str, Any]]:
|
||||
def parse_video_info(video_url: str) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
response = await async_client.get(video_url, follow_redirects=True)
|
||||
response = requests.get(video_url, headers=HEADERS, timeout=5)
|
||||
response.raise_for_status()
|
||||
soup = BeautifulSoup(response.text, 'html.parser')
|
||||
|
||||
script_tag = soup.find('script', text=re.compile('window.__INITIAL_STATE__'))
|
||||
if not script_tag:
|
||||
if not script_tag or not script_tag.string:
|
||||
return None
|
||||
|
||||
script_tag_content = script_tag.string
|
||||
if not script_tag_content:
|
||||
return None
|
||||
|
||||
match = re.search(r'window\.__INITIAL_STATE__\s*=\s*(\{.*?\});', script_tag_content)
|
||||
match = re.search(r'window\.__INITIAL_STATE__\s*=\s*(\{.*?\});', script_tag.string)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
json_str = match.group(1)
|
||||
data = json.loads(json_str)
|
||||
|
||||
@@ -104,12 +98,12 @@ async def parse_video_info(video_url: str) -> Optional[Dict[str, Any]]:
|
||||
"followers": up_data.get('fans', 0),
|
||||
}
|
||||
|
||||
except (httpx.RequestError, KeyError, AttributeError, json.JSONDecodeError) as e:
|
||||
logger.error(f"解析视频信息失败: {e}")
|
||||
except (requests.RequestException, KeyError, AttributeError, json.JSONDecodeError) as e:
|
||||
print(f"解析视频信息失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def get_direct_video_url(video_url: str) -> Optional[str]:
|
||||
def get_direct_video_url(video_url: str) -> Optional[str]:
|
||||
"""
|
||||
调用第三方API解析B站视频直链
|
||||
:param video_url: B站视频的完整URL
|
||||
@@ -117,12 +111,12 @@ async def get_direct_video_url(video_url: str) -> Optional[str]:
|
||||
"""
|
||||
api_url = f"https://api.mir6.com/api/bzjiexi?url={video_url}&type=json"
|
||||
try:
|
||||
response = await async_client.get(api_url)
|
||||
response = requests.get(api_url, headers=HEADERS, timeout=10)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if data.get("code") == 200 and data.get("data"):
|
||||
return data["data"][0].get("video_url")
|
||||
except (httpx.RequestError, json.JSONDecodeError, KeyError, IndexError) as e:
|
||||
except (requests.RequestException, json.JSONDecodeError, KeyError, IndexError) as e:
|
||||
logger.error(f"[bili_parser] 调用第三方API解析视频失败: {e}")
|
||||
return None
|
||||
|
||||
@@ -184,7 +178,7 @@ async def process_bili_link(event: MessageEvent, url: str):
|
||||
:param url: 待处理的B站链接
|
||||
"""
|
||||
if "b23.tv" in url:
|
||||
real_url = await get_real_url(url)
|
||||
real_url = get_real_url(url)
|
||||
if not real_url:
|
||||
logger.error(f"[bili_parser] 无法从 {url} 获取真实URL。")
|
||||
await event.reply("无法解析B站短链接。")
|
||||
@@ -192,28 +186,59 @@ async def process_bili_link(event: MessageEvent, url: str):
|
||||
else:
|
||||
real_url = url.split('?')[0]
|
||||
|
||||
video_info = await parse_video_info(real_url)
|
||||
video_info = parse_video_info(real_url)
|
||||
if not video_info:
|
||||
logger.error(f"[bili_parser] 无法从 {real_url} 解析视频信息。")
|
||||
await event.reply("无法获取视频信息,可能是B站接口变动或视频不存在。")
|
||||
return
|
||||
|
||||
title = video_info.get("title", "未知标题")
|
||||
owner_name = video_info.get("owner_name", "未知UP主")
|
||||
cover_url = video_info.get("cover_url")
|
||||
bvid = video_info.get("bvid", "N/A")
|
||||
play_count = format_count(video_info.get("play", 0))
|
||||
like_count = format_count(video_info.get("like", 0))
|
||||
# 检查视频时长
|
||||
video_message: Union[str, MessageSegment]
|
||||
if video_info['duration'] > 300: # 5分钟 = 300秒
|
||||
video_message = "视频时长超过5分钟,不进行解析。"
|
||||
else:
|
||||
direct_url = get_direct_video_url(real_url)
|
||||
if direct_url:
|
||||
video_message = MessageSegment.video(direct_url)
|
||||
else:
|
||||
video_message = "视频解析失败,无法获取直链。"
|
||||
|
||||
text_part = (
|
||||
f"标题: {title}\n"
|
||||
f"UP主: {owner_name}\n"
|
||||
f"BV: {bvid} | ▶️ {play_count} | 👍 {like_count}"
|
||||
text_message = (
|
||||
f"BiliBili 视频解析\n"
|
||||
f"--------------------\n"
|
||||
f" UP主: {video_info['owner_name']}\n"
|
||||
f" 粉丝: {format_count(video_info['followers'])}\n"
|
||||
f"--------------------\n"
|
||||
f" 标题: {video_info['title']}\n"
|
||||
f" BV号: {video_info['bvid']}\n"
|
||||
f" 时长: {format_duration(video_info['duration'])}\n"
|
||||
f"--------------------\n"
|
||||
f" 数据:\n"
|
||||
f" 播放: {format_count(video_info['play'])}\n"
|
||||
f" 点赞: {format_count(video_info['like'])}\n"
|
||||
f" 投币: {format_count(video_info['coin'])}\n"
|
||||
f" 收藏: {format_count(video_info['favorite'])}\n"
|
||||
f" 转发: {format_count(video_info['share'])}\n"
|
||||
f" B站链接: {url}"
|
||||
)
|
||||
|
||||
reply_message = [MessageSegment.from_text(text_part)]
|
||||
if cover_url:
|
||||
reply_message.append(MessageSegment.image(cover_url))
|
||||
|
||||
logger.success(f"[bili_parser] 成功解析视频信息并准备回复: {title}")
|
||||
await event.reply(reply_message)
|
||||
image_message_segment = [
|
||||
MessageSegment.text("B站封面:"),
|
||||
MessageSegment.image(video_info['cover_url'])
|
||||
]
|
||||
|
||||
up_info_segment = [
|
||||
MessageSegment.text("UP主头像:"),
|
||||
MessageSegment.image(video_info['owner_avatar'])
|
||||
]
|
||||
|
||||
nodes = [
|
||||
event.bot.build_forward_node(user_id=event.self_id, nickname="B站视频解析", message=text_message),
|
||||
event.bot.build_forward_node(user_id=event.self_id, nickname="B站视频解析", message=image_message_segment),
|
||||
event.bot.build_forward_node(user_id=event.self_id, nickname="B站视频解析", message=up_info_segment),
|
||||
event.bot.build_forward_node(user_id=event.self_id, nickname="B站视频解析", message=video_message)
|
||||
]
|
||||
|
||||
logger.success(f"[bili_parser] 成功解析视频信息并准备以聊天记录形式回复: {video_info['title']}")
|
||||
# 使用更通用的 send_forwarded_messages 方法,自动判断私聊或群聊
|
||||
await event.bot.send_forwarded_messages(target=event, nodes=nodes)
|
||||
|
||||
@@ -11,7 +11,6 @@ pipreqs==0.4.13
|
||||
redis==5.0.7
|
||||
requests==2.32.5
|
||||
soupsieve==2.8.1
|
||||
toml==0.10.2
|
||||
typing==3.7.4.3
|
||||
typing_extensions==4.15.0
|
||||
urllib3==2.6.2
|
||||
@@ -25,6 +24,7 @@ docker
|
||||
pytest
|
||||
pytest-asyncio
|
||||
pytest-mock
|
||||
pytest-cov
|
||||
httpx==0.27.0
|
||||
|
||||
# Dev Dependencies
|
||||
|
||||
37
tests/test_basic.py
Normal file
37
tests/test_basic.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 确保项目根目录在 sys.path 中
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
def test_import_core():
|
||||
"""测试核心模块是否可以被导入"""
|
||||
try:
|
||||
import core
|
||||
import core.bot
|
||||
import core.ws
|
||||
except ImportError as e:
|
||||
pytest.fail(f"无法导入核心模块: {e}")
|
||||
|
||||
def test_plugin_manager_path():
|
||||
"""测试插件管理器路径逻辑是否正确"""
|
||||
from core.managers.plugin_manager import PluginManager
|
||||
# Mock command manager
|
||||
pm = PluginManager(None)
|
||||
|
||||
# 我们无法直接测试 load_all_plugins 的内部路径变量,
|
||||
# 但我们可以检查它是否能找到 plugins 目录而不报错
|
||||
# 这里我们简单地断言 PluginManager 类存在且可以实例化
|
||||
assert pm is not None
|
||||
|
||||
def test_config_loader_exists():
|
||||
"""测试配置加载器是否存在"""
|
||||
# 注意:导入 config_loader 会尝试读取 config.toml
|
||||
# 如果 config.toml 不存在,这可能会失败。
|
||||
# 这是一个已知的设计问题,但在测试环境中我们假设 config.toml 存在或被 mock
|
||||
if os.path.exists("config.toml"):
|
||||
from core.config_loader import global_config
|
||||
assert global_config is not None
|
||||
else:
|
||||
pytest.skip("config.toml 不存在,跳过配置加载测试")
|
||||
114
tests/test_command_manager.py
Normal file
114
tests/test_command_manager.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from core.managers.command_manager import CommandManager
|
||||
from models.events.message import GroupMessageEvent
|
||||
from models.message import MessageSegment
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot():
|
||||
bot = AsyncMock()
|
||||
bot.self_id = 123456
|
||||
return bot
|
||||
|
||||
@pytest.fixture
|
||||
def command_manager():
|
||||
# 创建一个新的 CommandManager 实例用于测试,避免单例状态污染
|
||||
return CommandManager(prefixes=("/",))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_registration_and_execution(command_manager, mock_bot):
|
||||
"""测试命令注册和执行"""
|
||||
|
||||
# 定义一个命令处理函数
|
||||
handler_mock = AsyncMock()
|
||||
|
||||
# 注册命令
|
||||
@command_manager.command("test")
|
||||
async def test_command(bot, event):
|
||||
await handler_mock(bot, event)
|
||||
|
||||
# 构造触发命令的事件
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
event.post_type = "message"
|
||||
event.message_type = "group"
|
||||
event.raw_message = "/test"
|
||||
event.message = [MessageSegment.text("/test")]
|
||||
event.user_id = 111
|
||||
event.group_id = 222
|
||||
|
||||
# 处理事件
|
||||
await command_manager.handle_event(mock_bot, event)
|
||||
|
||||
# 验证处理函数被调用
|
||||
handler_mock.assert_called_once_with(mock_bot, event)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_prefix_match(command_manager, mock_bot):
|
||||
"""测试命令前缀匹配"""
|
||||
handler_mock = AsyncMock()
|
||||
|
||||
@command_manager.command("hello")
|
||||
async def hello_command(bot, event):
|
||||
await handler_mock(bot, event)
|
||||
|
||||
# 1. 正确的前缀
|
||||
event1 = MagicMock(spec=GroupMessageEvent)
|
||||
event1.post_type = "message"
|
||||
event1.raw_message = "/hello"
|
||||
event1.message = [MessageSegment.text("/hello")]
|
||||
await command_manager.handle_event(mock_bot, event1)
|
||||
handler_mock.assert_called_once()
|
||||
handler_mock.reset_mock()
|
||||
|
||||
# 2. 错误的前缀 (应该忽略)
|
||||
event2 = MagicMock(spec=GroupMessageEvent)
|
||||
event2.post_type = "message"
|
||||
event2.raw_message = ".hello" # 假设前缀是 /
|
||||
event2.message = [MessageSegment.text(".hello")]
|
||||
await command_manager.handle_event(mock_bot, event2)
|
||||
handler_mock.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignore_self_message(command_manager, mock_bot):
|
||||
"""测试忽略自身消息"""
|
||||
# 模拟配置
|
||||
with patch("core.managers.command_manager.global_config") as mock_config:
|
||||
mock_config.bot.ignore_self_message = True
|
||||
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
event.post_type = "message"
|
||||
event.user_id = 123456 # 与 bot.self_id 相同
|
||||
event.self_id = 123456
|
||||
|
||||
# Mock handle 方法来检测是否被调用
|
||||
command_manager.message_handler.handle = AsyncMock()
|
||||
|
||||
await command_manager.handle_event(mock_bot, event)
|
||||
|
||||
# 应该直接返回,不调用 handler
|
||||
command_manager.message_handler.handle.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_command(command_manager, mock_bot):
|
||||
"""测试内置 help 命令"""
|
||||
# 注册一个测试插件信息
|
||||
command_manager.plugins["test_plugin"] = {
|
||||
"name": "测试插件",
|
||||
"description": "这是一个测试",
|
||||
"usage": "/test"
|
||||
}
|
||||
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
event.post_type = "message"
|
||||
event.raw_message = "/help"
|
||||
event.message = [MessageSegment.text("/help")]
|
||||
|
||||
await command_manager.handle_event(mock_bot, event)
|
||||
|
||||
# 验证 bot.send 被调用,且内容包含插件信息
|
||||
mock_bot.send.assert_called_once()
|
||||
args, _ = mock_bot.send.call_args
|
||||
sent_msg = args[1]
|
||||
assert "测试插件" in sent_msg
|
||||
assert "这是一个测试" in sent_msg
|
||||
141
tests/test_event_factory.py
Normal file
141
tests/test_event_factory.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import pytest
|
||||
from models.events.factory import EventFactory, EventType
|
||||
from models.events.message import GroupMessageEvent, PrivateMessageEvent
|
||||
from models.events.notice import GroupIncreaseNoticeEvent
|
||||
from models.events.request import FriendRequestEvent
|
||||
from models.events.meta import HeartbeatEvent
|
||||
from models.message import MessageSegment
|
||||
|
||||
class TestEventFactory:
|
||||
def test_create_group_message_event_list(self):
|
||||
"""测试创建群消息事件 (message 为列表格式)"""
|
||||
data = {
|
||||
"post_type": "message",
|
||||
"message_type": "group",
|
||||
"time": 1600000000,
|
||||
"self_id": 123456,
|
||||
"sub_type": "normal",
|
||||
"message_id": 1001,
|
||||
"user_id": 111111,
|
||||
"group_id": 222222,
|
||||
"message": [
|
||||
{"type": "text", "data": {"text": "Hello"}}
|
||||
],
|
||||
"raw_message": "Hello",
|
||||
"font": 0,
|
||||
"sender": {
|
||||
"user_id": 111111,
|
||||
"nickname": "User",
|
||||
"role": "member"
|
||||
}
|
||||
}
|
||||
event = EventFactory.create_event(data)
|
||||
assert isinstance(event, GroupMessageEvent)
|
||||
assert event.group_id == 222222
|
||||
assert len(event.message) == 1
|
||||
assert event.message[0].type == "text"
|
||||
assert event.message[0].data["text"] == "Hello"
|
||||
|
||||
def test_create_group_message_event_str(self):
|
||||
"""测试创建群消息事件 (message 为字符串格式)"""
|
||||
data = {
|
||||
"post_type": "message",
|
||||
"message_type": "group",
|
||||
"time": 1600000000,
|
||||
"self_id": 123456,
|
||||
"sub_type": "normal",
|
||||
"message_id": 1002,
|
||||
"user_id": 111111,
|
||||
"group_id": 222222,
|
||||
"message": "Hello World",
|
||||
"raw_message": "Hello World",
|
||||
"font": 0,
|
||||
"sender": {
|
||||
"user_id": 111111,
|
||||
"nickname": "User"
|
||||
}
|
||||
}
|
||||
event = EventFactory.create_event(data)
|
||||
assert isinstance(event, GroupMessageEvent)
|
||||
assert len(event.message) == 1
|
||||
assert event.message[0].type == "text"
|
||||
assert event.message[0].data["text"] == "Hello World"
|
||||
|
||||
def test_create_private_message_event(self):
|
||||
"""测试创建私聊消息事件"""
|
||||
data = {
|
||||
"post_type": "message",
|
||||
"message_type": "private",
|
||||
"time": 1600000000,
|
||||
"self_id": 123456,
|
||||
"sub_type": "friend",
|
||||
"message_id": 2001,
|
||||
"user_id": 333333,
|
||||
"message": "Private Msg",
|
||||
"raw_message": "Private Msg",
|
||||
"font": 0,
|
||||
"sender": {
|
||||
"user_id": 333333,
|
||||
"nickname": "Friend"
|
||||
}
|
||||
}
|
||||
event = EventFactory.create_event(data)
|
||||
assert isinstance(event, PrivateMessageEvent)
|
||||
assert event.user_id == 333333
|
||||
|
||||
def test_create_notice_event(self):
|
||||
"""测试创建通知事件 (群成员增加)"""
|
||||
data = {
|
||||
"post_type": "notice",
|
||||
"notice_type": "group_increase",
|
||||
"sub_type": "approve",
|
||||
"group_id": 222222,
|
||||
"operator_id": 444444,
|
||||
"user_id": 555555,
|
||||
"time": 1600000000,
|
||||
"self_id": 123456
|
||||
}
|
||||
event = EventFactory.create_event(data)
|
||||
assert isinstance(event, GroupIncreaseNoticeEvent)
|
||||
assert event.group_id == 222222
|
||||
assert event.user_id == 555555
|
||||
|
||||
def test_create_request_event(self):
|
||||
"""测试创建请求事件 (加好友)"""
|
||||
data = {
|
||||
"post_type": "request",
|
||||
"request_type": "friend",
|
||||
"user_id": 666666,
|
||||
"comment": "Add me",
|
||||
"flag": "flag_123",
|
||||
"time": 1600000000,
|
||||
"self_id": 123456
|
||||
}
|
||||
event = EventFactory.create_event(data)
|
||||
assert isinstance(event, FriendRequestEvent)
|
||||
assert event.user_id == 666666
|
||||
assert event.comment == "Add me"
|
||||
|
||||
def test_create_meta_event(self):
|
||||
"""测试创建元事件 (心跳)"""
|
||||
data = {
|
||||
"post_type": "meta_event",
|
||||
"meta_event_type": "heartbeat",
|
||||
"time": 1600000000,
|
||||
"self_id": 123456,
|
||||
"status": {"online": True, "good": True},
|
||||
"interval": 5000
|
||||
}
|
||||
event = EventFactory.create_event(data)
|
||||
assert isinstance(event, HeartbeatEvent)
|
||||
assert event.interval == 5000
|
||||
|
||||
def test_unknown_event_type(self):
|
||||
"""测试未知事件类型"""
|
||||
data = {
|
||||
"post_type": "unknown_type",
|
||||
"time": 1600000000,
|
||||
"self_id": 123456
|
||||
}
|
||||
with pytest.raises(ValueError, match="Unknown event type"):
|
||||
EventFactory.create_event(data)
|
||||
194
tests/test_event_handler.py
Normal file
194
tests/test_event_handler.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from core.handlers.event_handler import MessageHandler, NoticeHandler, RequestHandler
|
||||
from models.events.message import GroupMessageEvent
|
||||
from models.events.notice import GroupIncreaseNoticeEvent
|
||||
from models.events.request import FriendRequestEvent
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot():
|
||||
bot = AsyncMock()
|
||||
return bot
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_handler_run_handler_injection(mock_bot):
|
||||
"""测试参数注入"""
|
||||
handler = MessageHandler(prefixes=("/",))
|
||||
|
||||
# 1. 测试注入 bot 和 event
|
||||
async def func1(bot, event):
|
||||
assert bot == mock_bot
|
||||
assert event.user_id == 123
|
||||
return True
|
||||
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
event.user_id = 123
|
||||
|
||||
result = await handler._run_handler(func1, mock_bot, event)
|
||||
assert result is True
|
||||
|
||||
# 2. 测试注入 args
|
||||
async def func2(args):
|
||||
assert args == ["arg1", "arg2"]
|
||||
return True
|
||||
|
||||
result = await handler._run_handler(func2, mock_bot, event, args=["arg1", "arg2"])
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_handler_command_parsing(mock_bot):
|
||||
"""测试命令解析"""
|
||||
handler = MessageHandler(prefixes=("/",))
|
||||
|
||||
mock_func = AsyncMock()
|
||||
handler.commands["test"] = {
|
||||
"func": mock_func,
|
||||
"permission": None,
|
||||
"override_permission_check": False,
|
||||
"plugin_name": "test_plugin"
|
||||
}
|
||||
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
event.raw_message = "/test arg1 arg2"
|
||||
event.user_id = 123
|
||||
|
||||
# Mock permission manager
|
||||
with patch("core.managers.permission_manager.PermissionManager.check_permission", new_callable=AsyncMock) as mock_perm:
|
||||
mock_perm.return_value = True
|
||||
|
||||
await handler.handle(mock_bot, event)
|
||||
|
||||
mock_func.assert_called_once()
|
||||
# 验证 args 参数是否正确传递
|
||||
call_args = mock_func.call_args
|
||||
if "args" in call_args.kwargs:
|
||||
assert call_args.kwargs["args"] == ["arg1", "arg2"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notice_handler(mock_bot):
|
||||
"""测试通知事件分发"""
|
||||
handler = NoticeHandler()
|
||||
|
||||
mock_func = AsyncMock()
|
||||
handler.handlers.append({
|
||||
"type": "group_increase",
|
||||
"func": mock_func,
|
||||
"plugin_name": "test_plugin"
|
||||
})
|
||||
|
||||
event = MagicMock(spec=GroupIncreaseNoticeEvent)
|
||||
event.notice_type = "group_increase"
|
||||
|
||||
await handler.handle(mock_bot, event)
|
||||
|
||||
mock_func.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_handler_execution(mock_bot):
|
||||
"""测试同步处理函数的执行"""
|
||||
handler = MessageHandler(prefixes=("/",))
|
||||
|
||||
def sync_func(event):
|
||||
return True
|
||||
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
|
||||
# 同步函数应该在线程池中运行
|
||||
result = await handler._run_handler(sync_func, mock_bot, event)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_handler_management(mock_bot):
|
||||
"""测试消息处理器的管理(注册、卸载、清空)"""
|
||||
handler = MessageHandler(prefixes=("/",))
|
||||
|
||||
# 测试 on_message 装饰器
|
||||
@handler.on_message()
|
||||
async def msg_handler(event):
|
||||
pass
|
||||
|
||||
assert len(handler.message_handlers) == 1
|
||||
|
||||
# 测试 command 装饰器
|
||||
@handler.command("cmd1", "cmd2")
|
||||
async def cmd_handler(event):
|
||||
pass
|
||||
|
||||
assert len(handler.commands) == 2
|
||||
assert "cmd1" in handler.commands
|
||||
assert "cmd2" in handler.commands
|
||||
|
||||
# 测试 unregister_by_plugin_name
|
||||
# 直接从已注册的处理器中获取 plugin_name
|
||||
if handler.message_handlers:
|
||||
plugin_name = handler.message_handlers[0]["plugin_name"]
|
||||
handler.unregister_by_plugin_name(plugin_name)
|
||||
|
||||
assert len(handler.message_handlers) == 0
|
||||
assert len(handler.commands) == 0
|
||||
|
||||
# 测试 clear
|
||||
handler.commands["cmd"] = {}
|
||||
handler.message_handlers.append({})
|
||||
handler.clear()
|
||||
assert len(handler.commands) == 0
|
||||
assert len(handler.message_handlers) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_handler(mock_bot):
|
||||
"""测试请求事件处理器"""
|
||||
handler = RequestHandler()
|
||||
|
||||
mock_func = AsyncMock()
|
||||
|
||||
# 测试 register 装饰器
|
||||
@handler.register("friend")
|
||||
async def req_handler(event):
|
||||
await mock_func(event)
|
||||
|
||||
assert len(handler.handlers) == 1
|
||||
|
||||
event = MagicMock(spec=FriendRequestEvent)
|
||||
event.request_type = "friend"
|
||||
|
||||
await handler.handle(mock_bot, event)
|
||||
mock_func.assert_called_once()
|
||||
|
||||
# 测试 unregister 和 clear
|
||||
import inspect
|
||||
module = inspect.getmodule(req_handler)
|
||||
plugin_name = module.__name__
|
||||
|
||||
handler.unregister_by_plugin_name(plugin_name)
|
||||
assert len(handler.handlers) == 0
|
||||
|
||||
handler.handlers.append({})
|
||||
handler.clear()
|
||||
assert len(handler.handlers) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_denied(mock_bot):
|
||||
"""测试权限不足的情况"""
|
||||
handler = MessageHandler(prefixes=("/",))
|
||||
|
||||
mock_func = AsyncMock()
|
||||
handler.commands["admin_cmd"] = {
|
||||
"func": mock_func,
|
||||
"permission": "ADMIN", # 假设 Permission.ADMIN
|
||||
"override_permission_check": False,
|
||||
"plugin_name": "test_plugin"
|
||||
}
|
||||
|
||||
event = MagicMock(spec=GroupMessageEvent)
|
||||
event.raw_message = "/admin_cmd"
|
||||
event.user_id = 123
|
||||
|
||||
# Mock permission manager returning False
|
||||
with patch("core.managers.permission_manager.PermissionManager.check_permission", new_callable=AsyncMock) as mock_perm:
|
||||
mock_perm.return_value = False
|
||||
|
||||
await handler.handle(mock_bot, event)
|
||||
|
||||
mock_func.assert_not_called()
|
||||
# 应该发送拒绝消息
|
||||
mock_bot.send.assert_called_once()
|
||||
75
tests/test_models.py
Normal file
75
tests/test_models.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import pytest
|
||||
from models.message import MessageSegment
|
||||
from models.objects import GroupInfo, StrangerInfo
|
||||
|
||||
class TestMessageSegment:
|
||||
def test_text_segment(self):
|
||||
seg = MessageSegment.text("Hello")
|
||||
assert seg.type == "text"
|
||||
assert seg.data["text"] == "Hello"
|
||||
assert str(seg) == "Hello"
|
||||
|
||||
def test_at_segment(self):
|
||||
seg = MessageSegment.at(123456)
|
||||
assert seg.type == "at"
|
||||
assert seg.data["qq"] == "123456"
|
||||
assert str(seg) == "[CQ:at,qq=123456]"
|
||||
|
||||
def test_image_segment(self):
|
||||
seg = MessageSegment.image("http://example.com/img.jpg", cache=False, proxy=False)
|
||||
assert seg.type == "image"
|
||||
assert seg.data["file"] == "http://example.com/img.jpg"
|
||||
assert str(seg) == "[CQ:image,file=http://example.com/img.jpg,cache=0,proxy=0]"
|
||||
|
||||
def test_face_segment(self):
|
||||
seg = MessageSegment.face(123)
|
||||
assert seg.type == "face"
|
||||
assert seg.data["id"] == "123"
|
||||
assert str(seg) == "[CQ:face,id=123]"
|
||||
|
||||
def test_reply_segment(self):
|
||||
seg = MessageSegment.reply(1001)
|
||||
assert seg.type == "reply"
|
||||
assert seg.data["id"] == "1001"
|
||||
assert str(seg) == "[CQ:reply,id=1001]"
|
||||
|
||||
def test_add_segments(self):
|
||||
seg1 = MessageSegment.text("Hello ")
|
||||
seg2 = MessageSegment.at(123)
|
||||
combined = seg1 + seg2
|
||||
assert isinstance(combined, list)
|
||||
assert len(combined) == 2
|
||||
assert combined[0] == seg1
|
||||
assert combined[1] == seg2
|
||||
|
||||
def test_add_segment_and_string(self):
|
||||
seg = MessageSegment.at(123)
|
||||
combined = seg + " Hello"
|
||||
assert isinstance(combined, list)
|
||||
assert len(combined) == 2
|
||||
assert combined[0] == seg
|
||||
assert combined[1].type == "text"
|
||||
assert combined[1].data["text"] == " Hello"
|
||||
|
||||
class TestObjects:
|
||||
def test_group_info(self):
|
||||
data = {
|
||||
"group_id": 123456,
|
||||
"group_name": "Test Group",
|
||||
"member_count": 10,
|
||||
"max_member_count": 100
|
||||
}
|
||||
group = GroupInfo(**data)
|
||||
assert group.group_id == 123456
|
||||
assert group.group_name == "Test Group"
|
||||
|
||||
def test_stranger_info(self):
|
||||
data = {
|
||||
"user_id": 111111,
|
||||
"nickname": "Stranger",
|
||||
"sex": "male",
|
||||
"age": 18
|
||||
}
|
||||
user = StrangerInfo(**data)
|
||||
assert user.user_id == 111111
|
||||
assert user.nickname == "Stranger"
|
||||
Reference in New Issue
Block a user