feat: 添加测试用例并优化代码结构

refactor(permission_manager): 调整初始化顺序和逻辑
fix(admin_manager): 修复初始化逻辑和目录创建问题
feat(ws): 优化Bot实例初始化条件
feat(message): 增强MessageSegment功能并添加测试
feat(events): 支持字符串格式的消息解析
test: 添加核心功能测试用例
refactor(plugin_manager): 改进插件路径处理
style: 清理无用导入和代码
chore: 更新依赖项
This commit is contained in:
2026-01-09 00:20:30 +08:00
parent 5d07a84283
commit 77348113e3
18 changed files with 754 additions and 73 deletions

View File

@@ -1,3 +1,3 @@
{ {
"admins": [] "admins": [2221577113]
} }

View File

@@ -6,11 +6,12 @@
""" """
import inspect import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
from ..bot import Bot if TYPE_CHECKING:
from ..bot import Bot
from ..config_loader import global_config from ..config_loader import global_config
from ..managers.permission_manager import Permission from ..permission import Permission
from ..utils.executor import run_in_thread_pool from ..utils.executor import run_in_thread_pool
@@ -22,7 +23,7 @@ class BaseHandler(ABC):
self.handlers: List[Dict[str, Any]] = [] self.handlers: List[Dict[str, Any]] = []
@abstractmethod @abstractmethod
async def handle(self, bot: Bot, event: Any): async def handle(self, bot: "Bot", event: Any):
""" """
处理事件 处理事件
""" """
@@ -31,7 +32,7 @@ class BaseHandler(ABC):
async def _run_handler( async def _run_handler(
self, self,
func: Callable, func: Callable,
bot: Bot, bot: "Bot",
event: Any, event: Any,
args: Optional[List[str]] = None, args: Optional[List[str]] = None,
permission_granted: Optional[bool] = None permission_granted: Optional[bool] = None
@@ -122,7 +123,7 @@ class MessageHandler(BaseHandler):
return func return func
return decorator return decorator
async def handle(self, bot: Bot, event: Any): async def handle(self, bot: "Bot", event: Any):
""" """
处理消息事件,分发给命令处理器或通用消息处理器 处理消息事件,分发给命令处理器或通用消息处理器
""" """

View File

@@ -26,8 +26,7 @@ class AdminManager(Singleton):
""" """
初始化 AdminManager 初始化 AdminManager
""" """
super().__init__() if hasattr(self, '_initialized') and self._initialized:
if not self._initialized:
return return
# 管理员数据文件路径 # 管理员数据文件路径
@@ -39,7 +38,12 @@ class AdminManager(Singleton):
) )
self._admins: Set[int] = set() self._admins: Set[int] = set()
# 确保数据目录存在
os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
logger.info("管理员管理器初始化完成") logger.info("管理员管理器初始化完成")
super().__init__()
async def initialize(self): async def initialize(self):
""" """

View File

@@ -41,7 +41,6 @@ class PermissionManager(Singleton):
如果已经初始化过,则直接返回。 如果已经初始化过,则直接返回。
""" """
super().__init__()
if hasattr(self, '_initialized') and self._initialized: if hasattr(self, '_initialized') and self._initialized:
return return
@@ -64,7 +63,7 @@ class PermissionManager(Singleton):
self.load() self.load()
logger.info("权限管理器初始化完成") logger.info("权限管理器初始化完成")
self._initialized = True super().__init__()
def load(self) -> None: def load(self) -> None:
""" """

View File

@@ -30,12 +30,21 @@ class PluginManager:
""" """
扫描并加载 `plugins` 目录下的所有插件。 扫描并加载 `plugins` 目录下的所有插件。
""" """
plugin_dir = os.path.join( # 使用 pathlib 获取更可靠的路径
os.path.dirname(os.path.abspath(__file__)), "..", "..", "plugins" # 当前文件: core/managers/plugin_manager.py
) # 目标: plugins/
current_dir = os.path.dirname(os.path.abspath(__file__))
# 回退两级到项目根目录 (core/managers -> core -> root)
root_dir = os.path.dirname(os.path.dirname(current_dir))
plugin_dir = os.path.join(root_dir, "plugins")
package_name = "plugins" package_name = "plugins"
logger.info(f"正在从 {package_name} 加载插件...") if not os.path.exists(plugin_dir):
logger.error(f"插件目录不存在: {plugin_dir}")
return
logger.info(f"正在从 {package_name} 加载插件 (路径: {plugin_dir})...")
for _, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]): for _, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
full_module_name = f"{package_name}.{module_name}" full_module_name = f"{package_name}.{module_name}"

View File

@@ -33,13 +33,10 @@ class Permission(Enum):
return NotImplemented return NotImplemented
return self._level_map[self] < self._level_map[other] return self._level_map[self] < self._level_map[other]
def __eq__(self, other):
if not isinstance(other, Permission):
return NotImplemented
return self is other
def __ge__(self, other): def __ge__(self, other):
"""
比较当前权限是否大于等于另一个权限。
"""
if not isinstance(other, Permission): if not isinstance(other, Permission):
return NotImplemented return NotImplemented
return self._level_map[self] >= self._level_map[other] return self._level_map[self] >= self._level_map[other]

View File

@@ -60,7 +60,15 @@ class CodeExecutor:
将代码执行任务添加到队列中。 将代码执行任务添加到队列中。
:param code: 待执行的 Python 代码字符串。 :param code: 待执行的 Python 代码字符串。
:param callback: 执行完毕后用于回复结果的回调函数。 :param callback: 执行完毕后用于回复结果的回调函数。
:raises RuntimeError: 如果 Docker 客户端未初始化。
""" """
if not self.docker_client:
logger.warning("[CodeExecutor] 尝试添加任务,但 Docker 客户端未初始化。任务被拒绝。")
# 这里可以选择抛出异常,或者直接调用回调返回错误信息
# 为了用户体验,我们构造一个错误结果并直接调用回调(如果可能)
# 但由于 callback 返回 Future这里简单起见我们记录日志并抛出异常
raise RuntimeError("Docker环境未就绪无法执行代码。")
task = {"code": code, "callback": callback} task = {"code": code, "callback": callback}
await self.task_queue.put(task) await self.task_queue.put(task)
logger.info(f"[CodeExecutor] 新的代码执行任务已入队 (队列当前长度: {self.task_queue.qsize()})。") logger.info(f"[CodeExecutor] 新的代码执行任务已入队 (队列当前长度: {self.task_queue.qsize()})。")

View File

@@ -128,8 +128,9 @@ class WS:
# 使用工厂创建事件对象 # 使用工厂创建事件对象
event = EventFactory.create_event(event_data) event = EventFactory.create_event(event_data)
# 在收到第一个 meta_event 时,初始化 Bot 实例 # 尝试初始化 Bot 实例 (如果尚未初始化且事件包含 self_id)
if event.post_type == "meta_event" and self.bot is None: # 只要事件中包含 self_id我们就可以初始化 Bot不必非要等待 meta_event
if self.bot is None and hasattr(event, 'self_id'):
self.self_id = event.self_id self.self_id = event.self_id
self.bot = Bot(self) self.bot = Bot(self)
logger.success(f"Bot 实例初始化完成: self_id={self.self_id}") logger.success(f"Bot 实例初始化完成: self_id={self.self_id}")

View File

@@ -0,0 +1,23 @@
"""
Models 包
导出常用的模型类,方便插件导入。
"""
from .events.base import OneBotEvent
from .events.message import MessageEvent, GroupMessageEvent, PrivateMessageEvent
from .events.notice import NoticeEvent
from .events.request import RequestEvent
from .message import MessageSegment
from .sender import Sender
__all__ = [
"OneBotEvent",
"MessageEvent",
"GroupMessageEvent",
"PrivateMessageEvent",
"NoticeEvent",
"RequestEvent",
"MessageSegment",
"Sender",
]

View File

@@ -70,7 +70,11 @@ class EventFactory:
# 解析消息段 # 解析消息段
message_list = [] message_list = []
raw_message_list = data.get("message", []) raw_message_list = data.get("message", [])
if isinstance(raw_message_list, list):
if isinstance(raw_message_list, str):
# 如果消息是字符串,将其视为纯文本消息段
message_list.append(MessageSegment.text(raw_message_list))
elif isinstance(raw_message_list, list):
for item in raw_message_list: for item in raw_message_list:
if isinstance(item, dict): if isinstance(item, dict):
message_list.append(MessageSegment(type=item.get("type", ""), data=item.get("data", {}))) message_list.append(MessageSegment(type=item.get("type", ""), data=item.get("data", {})))

View File

@@ -6,7 +6,7 @@
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, List
@dataclass(slots=True) @dataclass(slots=True)
@@ -23,7 +23,7 @@ class MessageSegment:
data: Dict[str, Any] data: Dict[str, Any]
@property @property
def text(self) -> str: def plain_text(self) -> str:
""" """
当消息段类型为 'text' 时,快速获取其文本内容。 当消息段类型为 'text' 时,快速获取其文本内容。
@@ -32,6 +32,19 @@ class MessageSegment:
""" """
return self.data.get("text", "") if self.type == "text" else "" return self.data.get("text", "") if self.type == "text" else ""
@staticmethod
def text(text: str) -> "MessageSegment":
"""
创建一个文本消息段。
Args:
text (str): 文本内容。
Returns:
MessageSegment: 一个类型为 'text' 的消息段对象。
"""
return MessageSegment(type="text", data={"text": text})
@property @property
def image_url(self) -> str: def image_url(self) -> str:
""" """
@@ -93,12 +106,48 @@ class MessageSegment:
return True return True
return str(self.data.get("qq")) == str(user_id) return str(self.data.get("qq")) == str(user_id)
def __str__(self):
"""
返回消息段的 CQ 码字符串表示。
"""
if self.type == "text":
return self.data.get("text", "")
params = ",".join([f"{k}={v}" for k, v in self.data.items()])
if params:
return f"[CQ:{self.type},{params}]"
return f"[CQ:{self.type}]"
def __repr__(self): def __repr__(self):
""" """
返回消息段对象的字符串表示形式,便于调试。 返回消息段对象的字符串表示形式,便于调试。
""" """
return f"[MS:{self.type}:{self.data}]" return f"[MS:{self.type}:{self.data}]"
def __add__(self, other: Any) -> "List[MessageSegment]":
"""
支持消息段相加,返回消息段列表。
"""
if isinstance(other, MessageSegment):
return [self, other]
elif isinstance(other, str):
return [self, MessageSegment.text(other)]
elif isinstance(other, list):
return [self] + other
return NotImplemented
def __radd__(self, other: Any) -> "List[MessageSegment]":
"""
支持反向相加。
"""
if isinstance(other, MessageSegment):
return [other, self]
elif isinstance(other, str):
return [MessageSegment.text(other), self]
elif isinstance(other, list):
return other + [self]
return NotImplemented
# --- 快捷构造方法 --- # --- 快捷构造方法 ---
@staticmethod @staticmethod
@@ -297,17 +346,17 @@ class MessageSegment:
return MessageSegment(type="file", data={"file": file}) return MessageSegment(type="file", data={"file": file})
@staticmethod @staticmethod
def reply(message_id: str) -> "MessageSegment": def reply(message_id: str | int) -> "MessageSegment":
""" """
创建一个回复消息段。 创建一个回复消息段。
Args: Args:
message_id (str): 被回复的消息 ID。 message_id (str | int): 被回复的消息 ID。
Returns: Returns:
MessageSegment: 一个类型为 'reply' 的消息段对象。 MessageSegment: 一个类型为 'reply' 的消息段对象。
""" """
return MessageSegment(type="reply", data={"id": message_id}) return MessageSegment(type="reply", data={"id": str(message_id)})
@staticmethod @staticmethod
def rps() -> "MessageSegment": def rps() -> "MessageSegment":

View File

@@ -1,17 +1,17 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import re import re
import json import json
import httpx import requests
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from typing import Optional, Dict, Any from typing import Optional, Dict, Any, Union
from cachetools import TTLCache from cachetools import TTLCache
from core.utils.logger import logger from core.utils.logger import logger
from core.managers.command_manager import matcher from core.managers.command_manager import matcher
from models.events.message import MessageEvent, MessageSegment from models import MessageEvent, MessageSegment
# 创建一个TTL缓存最大容量100缓存时间60秒 # 创建一个TTL缓存最大容量100缓存时间10秒
processed_messages: TTLCache[Any, bool] = TTLCache(maxsize=100, ttl=60) processed_messages: TTLCache[int, bool] = TTLCache(maxsize=100, ttl=10)
__plugin_meta__ = { __plugin_meta__ = {
"name": "bili_parser", "name": "bili_parser",
@@ -23,9 +23,6 @@ HEADERS = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
} }
# 创建可复用的异步HTTP客户端
async_client = httpx.AsyncClient(headers=HEADERS, follow_redirects=False, timeout=10)
def format_count(num: int) -> str: def format_count(num: int) -> str:
if not isinstance(num, int): if not isinstance(num, int):
@@ -43,32 +40,29 @@ def format_duration(seconds: int) -> str:
return f"{minutes:02d}:{seconds:02d}" return f"{minutes:02d}:{seconds:02d}"
async def get_real_url(short_url: str) -> Optional[str]: def get_real_url(short_url: str) -> Optional[str]:
try: try:
response = await async_client.head(short_url) response = requests.head(short_url, headers=HEADERS, allow_redirects=False, timeout=5)
if response.status_code == 302: if response.status_code == 302:
return response.headers.get('Location') return response.headers.get('Location')
except httpx.RequestError as e: except requests.RequestException as e:
logger.error(f"获取真实URL失败: {e}") print(f"获取真实URL失败: {e}")
return None return None
async def parse_video_info(video_url: str) -> Optional[Dict[str, Any]]: def parse_video_info(video_url: str) -> Optional[Dict[str, Any]]:
try: try:
response = await async_client.get(video_url, follow_redirects=True) response = requests.get(video_url, headers=HEADERS, timeout=5)
response.raise_for_status() response.raise_for_status()
soup = BeautifulSoup(response.text, 'html.parser') soup = BeautifulSoup(response.text, 'html.parser')
script_tag = soup.find('script', text=re.compile('window.__INITIAL_STATE__')) script_tag = soup.find('script', text=re.compile('window.__INITIAL_STATE__'))
if not script_tag: if not script_tag or not script_tag.string:
return None return None
script_tag_content = script_tag.string match = re.search(r'window\.__INITIAL_STATE__\s*=\s*(\{.*?\});', script_tag.string)
if not script_tag_content:
return None
match = re.search(r'window\.__INITIAL_STATE__\s*=\s*(\{.*?\});', script_tag_content)
if not match: if not match:
return None return None
json_str = match.group(1) json_str = match.group(1)
data = json.loads(json_str) data = json.loads(json_str)
@@ -104,12 +98,12 @@ async def parse_video_info(video_url: str) -> Optional[Dict[str, Any]]:
"followers": up_data.get('fans', 0), "followers": up_data.get('fans', 0),
} }
except (httpx.RequestError, KeyError, AttributeError, json.JSONDecodeError) as e: except (requests.RequestException, KeyError, AttributeError, json.JSONDecodeError) as e:
logger.error(f"解析视频信息失败: {e}") print(f"解析视频信息失败: {e}")
return None return None
async def get_direct_video_url(video_url: str) -> Optional[str]: def get_direct_video_url(video_url: str) -> Optional[str]:
""" """
调用第三方API解析B站视频直链 调用第三方API解析B站视频直链
:param video_url: B站视频的完整URL :param video_url: B站视频的完整URL
@@ -117,12 +111,12 @@ async def get_direct_video_url(video_url: str) -> Optional[str]:
""" """
api_url = f"https://api.mir6.com/api/bzjiexi?url={video_url}&type=json" api_url = f"https://api.mir6.com/api/bzjiexi?url={video_url}&type=json"
try: try:
response = await async_client.get(api_url) response = requests.get(api_url, headers=HEADERS, timeout=10)
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
if data.get("code") == 200 and data.get("data"): if data.get("code") == 200 and data.get("data"):
return data["data"][0].get("video_url") return data["data"][0].get("video_url")
except (httpx.RequestError, json.JSONDecodeError, KeyError, IndexError) as e: except (requests.RequestException, json.JSONDecodeError, KeyError, IndexError) as e:
logger.error(f"[bili_parser] 调用第三方API解析视频失败: {e}") logger.error(f"[bili_parser] 调用第三方API解析视频失败: {e}")
return None return None
@@ -184,7 +178,7 @@ async def process_bili_link(event: MessageEvent, url: str):
:param url: 待处理的B站链接 :param url: 待处理的B站链接
""" """
if "b23.tv" in url: if "b23.tv" in url:
real_url = await get_real_url(url) real_url = get_real_url(url)
if not real_url: if not real_url:
logger.error(f"[bili_parser] 无法从 {url} 获取真实URL。") logger.error(f"[bili_parser] 无法从 {url} 获取真实URL。")
await event.reply("无法解析B站短链接。") await event.reply("无法解析B站短链接。")
@@ -192,28 +186,59 @@ async def process_bili_link(event: MessageEvent, url: str):
else: else:
real_url = url.split('?')[0] real_url = url.split('?')[0]
video_info = await parse_video_info(real_url) video_info = parse_video_info(real_url)
if not video_info: if not video_info:
logger.error(f"[bili_parser] 无法从 {real_url} 解析视频信息。") logger.error(f"[bili_parser] 无法从 {real_url} 解析视频信息。")
await event.reply("无法获取视频信息可能是B站接口变动或视频不存在。") await event.reply("无法获取视频信息可能是B站接口变动或视频不存在。")
return return
title = video_info.get("title", "未知标题") # 检查视频时长
owner_name = video_info.get("owner_name", "未知UP主") video_message: Union[str, MessageSegment]
cover_url = video_info.get("cover_url") if video_info['duration'] > 300: # 5分钟 = 300秒
bvid = video_info.get("bvid", "N/A") video_message = "视频时长超过5分钟不进行解析。"
play_count = format_count(video_info.get("play", 0)) else:
like_count = format_count(video_info.get("like", 0)) direct_url = get_direct_video_url(real_url)
if direct_url:
video_message = MessageSegment.video(direct_url)
else:
video_message = "视频解析失败,无法获取直链。"
text_part = ( text_message = (
f"标题: {title}\n" f"BiliBili 视频解析\n"
f"UP主: {owner_name}\n" f"--------------------\n"
f"BV: {bvid} | ▶️ {play_count} | 👍 {like_count}" f" UP主: {video_info['owner_name']}\n"
f" 粉丝: {format_count(video_info['followers'])}\n"
f"--------------------\n"
f" 标题: {video_info['title']}\n"
f" BV号: {video_info['bvid']}\n"
f" 时长: {format_duration(video_info['duration'])}\n"
f"--------------------\n"
f" 数据:\n"
f" 播放: {format_count(video_info['play'])}\n"
f" 点赞: {format_count(video_info['like'])}\n"
f" 投币: {format_count(video_info['coin'])}\n"
f" 收藏: {format_count(video_info['favorite'])}\n"
f" 转发: {format_count(video_info['share'])}\n"
f" B站链接: {url}"
) )
reply_message = [MessageSegment.from_text(text_part)] image_message_segment = [
if cover_url: MessageSegment.text("B站封面"),
reply_message.append(MessageSegment.image(cover_url)) MessageSegment.image(video_info['cover_url'])
]
logger.success(f"[bili_parser] 成功解析视频信息并准备回复: {title}") up_info_segment = [
await event.reply(reply_message) MessageSegment.text("UP主头像"),
MessageSegment.image(video_info['owner_avatar'])
]
nodes = [
event.bot.build_forward_node(user_id=event.self_id, nickname="B站视频解析", message=text_message),
event.bot.build_forward_node(user_id=event.self_id, nickname="B站视频解析", message=image_message_segment),
event.bot.build_forward_node(user_id=event.self_id, nickname="B站视频解析", message=up_info_segment),
event.bot.build_forward_node(user_id=event.self_id, nickname="B站视频解析", message=video_message)
]
logger.success(f"[bili_parser] 成功解析视频信息并准备以聊天记录形式回复: {video_info['title']}")
# 使用更通用的 send_forwarded_messages 方法,自动判断私聊或群聊
await event.bot.send_forwarded_messages(target=event, nodes=nodes)

View File

@@ -11,7 +11,6 @@ pipreqs==0.4.13
redis==5.0.7 redis==5.0.7
requests==2.32.5 requests==2.32.5
soupsieve==2.8.1 soupsieve==2.8.1
toml==0.10.2
typing==3.7.4.3 typing==3.7.4.3
typing_extensions==4.15.0 typing_extensions==4.15.0
urllib3==2.6.2 urllib3==2.6.2
@@ -25,6 +24,7 @@ docker
pytest pytest
pytest-asyncio pytest-asyncio
pytest-mock pytest-mock
pytest-cov
httpx==0.27.0 httpx==0.27.0
# Dev Dependencies # Dev Dependencies

37
tests/test_basic.py Normal file
View File

@@ -0,0 +1,37 @@
import pytest
import sys
import os
# 确保项目根目录在 sys.path 中
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
def test_import_core():
"""测试核心模块是否可以被导入"""
try:
import core
import core.bot
import core.ws
except ImportError as e:
pytest.fail(f"无法导入核心模块: {e}")
def test_plugin_manager_path():
"""测试插件管理器路径逻辑是否正确"""
from core.managers.plugin_manager import PluginManager
# Mock command manager
pm = PluginManager(None)
# 我们无法直接测试 load_all_plugins 的内部路径变量,
# 但我们可以检查它是否能找到 plugins 目录而不报错
# 这里我们简单地断言 PluginManager 类存在且可以实例化
assert pm is not None
def test_config_loader_exists():
"""测试配置加载器是否存在"""
# 注意:导入 config_loader 会尝试读取 config.toml
# 如果 config.toml 不存在,这可能会失败。
# 这是一个已知的设计问题,但在测试环境中我们假设 config.toml 存在或被 mock
if os.path.exists("config.toml"):
from core.config_loader import global_config
assert global_config is not None
else:
pytest.skip("config.toml 不存在,跳过配置加载测试")

View File

@@ -0,0 +1,114 @@
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from core.managers.command_manager import CommandManager
from models.events.message import GroupMessageEvent
from models.message import MessageSegment
@pytest.fixture
def mock_bot():
bot = AsyncMock()
bot.self_id = 123456
return bot
@pytest.fixture
def command_manager():
# 创建一个新的 CommandManager 实例用于测试,避免单例状态污染
return CommandManager(prefixes=("/",))
@pytest.mark.asyncio
async def test_command_registration_and_execution(command_manager, mock_bot):
"""测试命令注册和执行"""
# 定义一个命令处理函数
handler_mock = AsyncMock()
# 注册命令
@command_manager.command("test")
async def test_command(bot, event):
await handler_mock(bot, event)
# 构造触发命令的事件
event = MagicMock(spec=GroupMessageEvent)
event.post_type = "message"
event.message_type = "group"
event.raw_message = "/test"
event.message = [MessageSegment.text("/test")]
event.user_id = 111
event.group_id = 222
# 处理事件
await command_manager.handle_event(mock_bot, event)
# 验证处理函数被调用
handler_mock.assert_called_once_with(mock_bot, event)
@pytest.mark.asyncio
async def test_command_prefix_match(command_manager, mock_bot):
"""测试命令前缀匹配"""
handler_mock = AsyncMock()
@command_manager.command("hello")
async def hello_command(bot, event):
await handler_mock(bot, event)
# 1. 正确的前缀
event1 = MagicMock(spec=GroupMessageEvent)
event1.post_type = "message"
event1.raw_message = "/hello"
event1.message = [MessageSegment.text("/hello")]
await command_manager.handle_event(mock_bot, event1)
handler_mock.assert_called_once()
handler_mock.reset_mock()
# 2. 错误的前缀 (应该忽略)
event2 = MagicMock(spec=GroupMessageEvent)
event2.post_type = "message"
event2.raw_message = ".hello" # 假设前缀是 /
event2.message = [MessageSegment.text(".hello")]
await command_manager.handle_event(mock_bot, event2)
handler_mock.assert_not_called()
@pytest.mark.asyncio
async def test_ignore_self_message(command_manager, mock_bot):
"""测试忽略自身消息"""
# 模拟配置
with patch("core.managers.command_manager.global_config") as mock_config:
mock_config.bot.ignore_self_message = True
event = MagicMock(spec=GroupMessageEvent)
event.post_type = "message"
event.user_id = 123456 # 与 bot.self_id 相同
event.self_id = 123456
# Mock handle 方法来检测是否被调用
command_manager.message_handler.handle = AsyncMock()
await command_manager.handle_event(mock_bot, event)
# 应该直接返回,不调用 handler
command_manager.message_handler.handle.assert_not_called()
@pytest.mark.asyncio
async def test_help_command(command_manager, mock_bot):
"""测试内置 help 命令"""
# 注册一个测试插件信息
command_manager.plugins["test_plugin"] = {
"name": "测试插件",
"description": "这是一个测试",
"usage": "/test"
}
event = MagicMock(spec=GroupMessageEvent)
event.post_type = "message"
event.raw_message = "/help"
event.message = [MessageSegment.text("/help")]
await command_manager.handle_event(mock_bot, event)
# 验证 bot.send 被调用,且内容包含插件信息
mock_bot.send.assert_called_once()
args, _ = mock_bot.send.call_args
sent_msg = args[1]
assert "测试插件" in sent_msg
assert "这是一个测试" in sent_msg

141
tests/test_event_factory.py Normal file
View File

@@ -0,0 +1,141 @@
import pytest
from models.events.factory import EventFactory, EventType
from models.events.message import GroupMessageEvent, PrivateMessageEvent
from models.events.notice import GroupIncreaseNoticeEvent
from models.events.request import FriendRequestEvent
from models.events.meta import HeartbeatEvent
from models.message import MessageSegment
class TestEventFactory:
def test_create_group_message_event_list(self):
"""测试创建群消息事件 (message 为列表格式)"""
data = {
"post_type": "message",
"message_type": "group",
"time": 1600000000,
"self_id": 123456,
"sub_type": "normal",
"message_id": 1001,
"user_id": 111111,
"group_id": 222222,
"message": [
{"type": "text", "data": {"text": "Hello"}}
],
"raw_message": "Hello",
"font": 0,
"sender": {
"user_id": 111111,
"nickname": "User",
"role": "member"
}
}
event = EventFactory.create_event(data)
assert isinstance(event, GroupMessageEvent)
assert event.group_id == 222222
assert len(event.message) == 1
assert event.message[0].type == "text"
assert event.message[0].data["text"] == "Hello"
def test_create_group_message_event_str(self):
"""测试创建群消息事件 (message 为字符串格式)"""
data = {
"post_type": "message",
"message_type": "group",
"time": 1600000000,
"self_id": 123456,
"sub_type": "normal",
"message_id": 1002,
"user_id": 111111,
"group_id": 222222,
"message": "Hello World",
"raw_message": "Hello World",
"font": 0,
"sender": {
"user_id": 111111,
"nickname": "User"
}
}
event = EventFactory.create_event(data)
assert isinstance(event, GroupMessageEvent)
assert len(event.message) == 1
assert event.message[0].type == "text"
assert event.message[0].data["text"] == "Hello World"
def test_create_private_message_event(self):
"""测试创建私聊消息事件"""
data = {
"post_type": "message",
"message_type": "private",
"time": 1600000000,
"self_id": 123456,
"sub_type": "friend",
"message_id": 2001,
"user_id": 333333,
"message": "Private Msg",
"raw_message": "Private Msg",
"font": 0,
"sender": {
"user_id": 333333,
"nickname": "Friend"
}
}
event = EventFactory.create_event(data)
assert isinstance(event, PrivateMessageEvent)
assert event.user_id == 333333
def test_create_notice_event(self):
"""测试创建通知事件 (群成员增加)"""
data = {
"post_type": "notice",
"notice_type": "group_increase",
"sub_type": "approve",
"group_id": 222222,
"operator_id": 444444,
"user_id": 555555,
"time": 1600000000,
"self_id": 123456
}
event = EventFactory.create_event(data)
assert isinstance(event, GroupIncreaseNoticeEvent)
assert event.group_id == 222222
assert event.user_id == 555555
def test_create_request_event(self):
"""测试创建请求事件 (加好友)"""
data = {
"post_type": "request",
"request_type": "friend",
"user_id": 666666,
"comment": "Add me",
"flag": "flag_123",
"time": 1600000000,
"self_id": 123456
}
event = EventFactory.create_event(data)
assert isinstance(event, FriendRequestEvent)
assert event.user_id == 666666
assert event.comment == "Add me"
def test_create_meta_event(self):
"""测试创建元事件 (心跳)"""
data = {
"post_type": "meta_event",
"meta_event_type": "heartbeat",
"time": 1600000000,
"self_id": 123456,
"status": {"online": True, "good": True},
"interval": 5000
}
event = EventFactory.create_event(data)
assert isinstance(event, HeartbeatEvent)
assert event.interval == 5000
def test_unknown_event_type(self):
"""测试未知事件类型"""
data = {
"post_type": "unknown_type",
"time": 1600000000,
"self_id": 123456
}
with pytest.raises(ValueError, match="Unknown event type"):
EventFactory.create_event(data)

194
tests/test_event_handler.py Normal file
View File

@@ -0,0 +1,194 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from core.handlers.event_handler import MessageHandler, NoticeHandler, RequestHandler
from models.events.message import GroupMessageEvent
from models.events.notice import GroupIncreaseNoticeEvent
from models.events.request import FriendRequestEvent
@pytest.fixture
def mock_bot():
bot = AsyncMock()
return bot
@pytest.mark.asyncio
async def test_message_handler_run_handler_injection(mock_bot):
"""测试参数注入"""
handler = MessageHandler(prefixes=("/",))
# 1. 测试注入 bot 和 event
async def func1(bot, event):
assert bot == mock_bot
assert event.user_id == 123
return True
event = MagicMock(spec=GroupMessageEvent)
event.user_id = 123
result = await handler._run_handler(func1, mock_bot, event)
assert result is True
# 2. 测试注入 args
async def func2(args):
assert args == ["arg1", "arg2"]
return True
result = await handler._run_handler(func2, mock_bot, event, args=["arg1", "arg2"])
assert result is True
@pytest.mark.asyncio
async def test_message_handler_command_parsing(mock_bot):
"""测试命令解析"""
handler = MessageHandler(prefixes=("/",))
mock_func = AsyncMock()
handler.commands["test"] = {
"func": mock_func,
"permission": None,
"override_permission_check": False,
"plugin_name": "test_plugin"
}
event = MagicMock(spec=GroupMessageEvent)
event.raw_message = "/test arg1 arg2"
event.user_id = 123
# Mock permission manager
with patch("core.managers.permission_manager.PermissionManager.check_permission", new_callable=AsyncMock) as mock_perm:
mock_perm.return_value = True
await handler.handle(mock_bot, event)
mock_func.assert_called_once()
# 验证 args 参数是否正确传递
call_args = mock_func.call_args
if "args" in call_args.kwargs:
assert call_args.kwargs["args"] == ["arg1", "arg2"]
@pytest.mark.asyncio
async def test_notice_handler(mock_bot):
"""测试通知事件分发"""
handler = NoticeHandler()
mock_func = AsyncMock()
handler.handlers.append({
"type": "group_increase",
"func": mock_func,
"plugin_name": "test_plugin"
})
event = MagicMock(spec=GroupIncreaseNoticeEvent)
event.notice_type = "group_increase"
await handler.handle(mock_bot, event)
mock_func.assert_called_once()
@pytest.mark.asyncio
async def test_sync_handler_execution(mock_bot):
"""测试同步处理函数的执行"""
handler = MessageHandler(prefixes=("/",))
def sync_func(event):
return True
event = MagicMock(spec=GroupMessageEvent)
# 同步函数应该在线程池中运行
result = await handler._run_handler(sync_func, mock_bot, event)
assert result is True
@pytest.mark.asyncio
async def test_message_handler_management(mock_bot):
"""测试消息处理器的管理(注册、卸载、清空)"""
handler = MessageHandler(prefixes=("/",))
# 测试 on_message 装饰器
@handler.on_message()
async def msg_handler(event):
pass
assert len(handler.message_handlers) == 1
# 测试 command 装饰器
@handler.command("cmd1", "cmd2")
async def cmd_handler(event):
pass
assert len(handler.commands) == 2
assert "cmd1" in handler.commands
assert "cmd2" in handler.commands
# 测试 unregister_by_plugin_name
# 直接从已注册的处理器中获取 plugin_name
if handler.message_handlers:
plugin_name = handler.message_handlers[0]["plugin_name"]
handler.unregister_by_plugin_name(plugin_name)
assert len(handler.message_handlers) == 0
assert len(handler.commands) == 0
# 测试 clear
handler.commands["cmd"] = {}
handler.message_handlers.append({})
handler.clear()
assert len(handler.commands) == 0
assert len(handler.message_handlers) == 0
@pytest.mark.asyncio
async def test_request_handler(mock_bot):
"""测试请求事件处理器"""
handler = RequestHandler()
mock_func = AsyncMock()
# 测试 register 装饰器
@handler.register("friend")
async def req_handler(event):
await mock_func(event)
assert len(handler.handlers) == 1
event = MagicMock(spec=FriendRequestEvent)
event.request_type = "friend"
await handler.handle(mock_bot, event)
mock_func.assert_called_once()
# 测试 unregister 和 clear
import inspect
module = inspect.getmodule(req_handler)
plugin_name = module.__name__
handler.unregister_by_plugin_name(plugin_name)
assert len(handler.handlers) == 0
handler.handlers.append({})
handler.clear()
assert len(handler.handlers) == 0
@pytest.mark.asyncio
async def test_permission_denied(mock_bot):
"""测试权限不足的情况"""
handler = MessageHandler(prefixes=("/",))
mock_func = AsyncMock()
handler.commands["admin_cmd"] = {
"func": mock_func,
"permission": "ADMIN", # 假设 Permission.ADMIN
"override_permission_check": False,
"plugin_name": "test_plugin"
}
event = MagicMock(spec=GroupMessageEvent)
event.raw_message = "/admin_cmd"
event.user_id = 123
# Mock permission manager returning False
with patch("core.managers.permission_manager.PermissionManager.check_permission", new_callable=AsyncMock) as mock_perm:
mock_perm.return_value = False
await handler.handle(mock_bot, event)
mock_func.assert_not_called()
# 应该发送拒绝消息
mock_bot.send.assert_called_once()

75
tests/test_models.py Normal file
View File

@@ -0,0 +1,75 @@
import pytest
from models.message import MessageSegment
from models.objects import GroupInfo, StrangerInfo
class TestMessageSegment:
def test_text_segment(self):
seg = MessageSegment.text("Hello")
assert seg.type == "text"
assert seg.data["text"] == "Hello"
assert str(seg) == "Hello"
def test_at_segment(self):
seg = MessageSegment.at(123456)
assert seg.type == "at"
assert seg.data["qq"] == "123456"
assert str(seg) == "[CQ:at,qq=123456]"
def test_image_segment(self):
seg = MessageSegment.image("http://example.com/img.jpg", cache=False, proxy=False)
assert seg.type == "image"
assert seg.data["file"] == "http://example.com/img.jpg"
assert str(seg) == "[CQ:image,file=http://example.com/img.jpg,cache=0,proxy=0]"
def test_face_segment(self):
seg = MessageSegment.face(123)
assert seg.type == "face"
assert seg.data["id"] == "123"
assert str(seg) == "[CQ:face,id=123]"
def test_reply_segment(self):
seg = MessageSegment.reply(1001)
assert seg.type == "reply"
assert seg.data["id"] == "1001"
assert str(seg) == "[CQ:reply,id=1001]"
def test_add_segments(self):
seg1 = MessageSegment.text("Hello ")
seg2 = MessageSegment.at(123)
combined = seg1 + seg2
assert isinstance(combined, list)
assert len(combined) == 2
assert combined[0] == seg1
assert combined[1] == seg2
def test_add_segment_and_string(self):
seg = MessageSegment.at(123)
combined = seg + " Hello"
assert isinstance(combined, list)
assert len(combined) == 2
assert combined[0] == seg
assert combined[1].type == "text"
assert combined[1].data["text"] == " Hello"
class TestObjects:
def test_group_info(self):
data = {
"group_id": 123456,
"group_name": "Test Group",
"member_count": 10,
"max_member_count": 100
}
group = GroupInfo(**data)
assert group.group_id == 123456
assert group.group_name == "Test Group"
def test_stranger_info(self):
data = {
"user_id": 111111,
"nickname": "Stranger",
"sex": "male",
"age": 18
}
user = StrangerInfo(**data)
assert user.user_id == 111111
assert user.nickname == "Stranger"