diff --git a/core/data/admin.json b/core/data/admin.json index b3c7949..577c240 100644 --- a/core/data/admin.json +++ b/core/data/admin.json @@ -1,3 +1,3 @@ { - "admins": [] + "admins": [2221577113] } \ No newline at end of file diff --git a/core/handlers/event_handler.py b/core/handlers/event_handler.py index e785349..b188eca 100644 --- a/core/handlers/event_handler.py +++ b/core/handlers/event_handler.py @@ -6,11 +6,12 @@ """ import inspect from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING -from ..bot import Bot +if TYPE_CHECKING: + from ..bot import Bot from ..config_loader import global_config -from ..managers.permission_manager import Permission +from ..permission import Permission from ..utils.executor import run_in_thread_pool @@ -22,7 +23,7 @@ class BaseHandler(ABC): self.handlers: List[Dict[str, Any]] = [] @abstractmethod - async def handle(self, bot: Bot, event: Any): + async def handle(self, bot: "Bot", event: Any): """ 处理事件 """ @@ -31,7 +32,7 @@ class BaseHandler(ABC): async def _run_handler( self, func: Callable, - bot: Bot, + bot: "Bot", event: Any, args: Optional[List[str]] = None, permission_granted: Optional[bool] = None @@ -122,7 +123,7 @@ class MessageHandler(BaseHandler): return func return decorator - async def handle(self, bot: Bot, event: Any): + async def handle(self, bot: "Bot", event: Any): """ 处理消息事件,分发给命令处理器或通用消息处理器 """ diff --git a/core/managers/admin_manager.py b/core/managers/admin_manager.py index 7e5f0d1..83b222f 100644 --- a/core/managers/admin_manager.py +++ b/core/managers/admin_manager.py @@ -26,8 +26,7 @@ class AdminManager(Singleton): """ 初始化 AdminManager """ - super().__init__() - if not self._initialized: + if hasattr(self, '_initialized') and self._initialized: return # 管理员数据文件路径 @@ -39,7 +38,12 @@ class AdminManager(Singleton): ) self._admins: Set[int] = set() + + # 确保数据目录存在 + os.makedirs(os.path.dirname(self.data_file), exist_ok=True) + logger.info("管理员管理器初始化完成") + super().__init__() async def initialize(self): """ diff --git a/core/managers/permission_manager.py b/core/managers/permission_manager.py index bf51990..de808c1 100644 --- a/core/managers/permission_manager.py +++ b/core/managers/permission_manager.py @@ -41,7 +41,6 @@ class PermissionManager(Singleton): 如果已经初始化过,则直接返回。 """ - super().__init__() if hasattr(self, '_initialized') and self._initialized: return @@ -64,7 +63,7 @@ class PermissionManager(Singleton): self.load() logger.info("权限管理器初始化完成") - self._initialized = True + super().__init__() def load(self) -> None: """ diff --git a/core/managers/plugin_manager.py b/core/managers/plugin_manager.py index 462def3..a287527 100644 --- a/core/managers/plugin_manager.py +++ b/core/managers/plugin_manager.py @@ -30,12 +30,21 @@ class PluginManager: """ 扫描并加载 `plugins` 目录下的所有插件。 """ - plugin_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "..", "..", "plugins" - ) + # 使用 pathlib 获取更可靠的路径 + # 当前文件: core/managers/plugin_manager.py + # 目标: plugins/ + current_dir = os.path.dirname(os.path.abspath(__file__)) + # 回退两级到项目根目录 (core/managers -> core -> root) + root_dir = os.path.dirname(os.path.dirname(current_dir)) + plugin_dir = os.path.join(root_dir, "plugins") + package_name = "plugins" - logger.info(f"正在从 {package_name} 加载插件...") + if not os.path.exists(plugin_dir): + logger.error(f"插件目录不存在: {plugin_dir}") + return + + logger.info(f"正在从 {package_name} 加载插件 (路径: {plugin_dir})...") for _, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]): full_module_name = f"{package_name}.{module_name}" diff --git a/core/permission.py b/core/permission.py index f574748..c66bd3b 100644 --- a/core/permission.py +++ b/core/permission.py @@ -33,13 +33,10 @@ class Permission(Enum): return NotImplemented return self._level_map[self] < self._level_map[other] - def __eq__(self, other): - if not isinstance(other, Permission): - return NotImplemented - return self is other - def __ge__(self, other): + """ + 比较当前权限是否大于等于另一个权限。 + """ if not isinstance(other, Permission): return NotImplemented return self._level_map[self] >= self._level_map[other] - diff --git a/core/utils/executor.py b/core/utils/executor.py index ca24514..79f2103 100644 --- a/core/utils/executor.py +++ b/core/utils/executor.py @@ -60,7 +60,15 @@ class CodeExecutor: 将代码执行任务添加到队列中。 :param code: 待执行的 Python 代码字符串。 :param callback: 执行完毕后用于回复结果的回调函数。 + :raises RuntimeError: 如果 Docker 客户端未初始化。 """ + if not self.docker_client: + logger.warning("[CodeExecutor] 尝试添加任务,但 Docker 客户端未初始化。任务被拒绝。") + # 这里可以选择抛出异常,或者直接调用回调返回错误信息 + # 为了用户体验,我们构造一个错误结果并直接调用回调(如果可能) + # 但由于 callback 返回 Future,这里简单起见,我们记录日志并抛出异常 + raise RuntimeError("Docker环境未就绪,无法执行代码。") + task = {"code": code, "callback": callback} await self.task_queue.put(task) logger.info(f"[CodeExecutor] 新的代码执行任务已入队 (队列当前长度: {self.task_queue.qsize()})。") diff --git a/core/ws.py b/core/ws.py index 544947c..8216cce 100644 --- a/core/ws.py +++ b/core/ws.py @@ -128,8 +128,9 @@ class WS: # 使用工厂创建事件对象 event = EventFactory.create_event(event_data) - # 在收到第一个 meta_event 时,初始化 Bot 实例 - if event.post_type == "meta_event" and self.bot is None: + # 尝试初始化 Bot 实例 (如果尚未初始化且事件包含 self_id) + # 只要事件中包含 self_id,我们就可以初始化 Bot,不必非要等待 meta_event + if self.bot is None and hasattr(event, 'self_id'): self.self_id = event.self_id self.bot = Bot(self) logger.success(f"Bot 实例初始化完成: self_id={self.self_id}") diff --git a/models/__init__.py b/models/__init__.py index e69de29..3418164 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -0,0 +1,23 @@ +""" +Models 包 + +导出常用的模型类,方便插件导入。 +""" + +from .events.base import OneBotEvent +from .events.message import MessageEvent, GroupMessageEvent, PrivateMessageEvent +from .events.notice import NoticeEvent +from .events.request import RequestEvent +from .message import MessageSegment +from .sender import Sender + +__all__ = [ + "OneBotEvent", + "MessageEvent", + "GroupMessageEvent", + "PrivateMessageEvent", + "NoticeEvent", + "RequestEvent", + "MessageSegment", + "Sender", +] diff --git a/models/events/factory.py b/models/events/factory.py index c1d93e6..7eb4e9f 100644 --- a/models/events/factory.py +++ b/models/events/factory.py @@ -70,7 +70,11 @@ class EventFactory: # 解析消息段 message_list = [] raw_message_list = data.get("message", []) - if isinstance(raw_message_list, list): + + if isinstance(raw_message_list, str): + # 如果消息是字符串,将其视为纯文本消息段 + message_list.append(MessageSegment.text(raw_message_list)) + elif isinstance(raw_message_list, list): for item in raw_message_list: if isinstance(item, dict): message_list.append(MessageSegment(type=item.get("type", ""), data=item.get("data", {}))) diff --git a/models/message.py b/models/message.py index 53e6435..2a8cafc 100644 --- a/models/message.py +++ b/models/message.py @@ -6,7 +6,7 @@ """ from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List @dataclass(slots=True) @@ -23,7 +23,7 @@ class MessageSegment: data: Dict[str, Any] @property - def text(self) -> str: + def plain_text(self) -> str: """ 当消息段类型为 'text' 时,快速获取其文本内容。 @@ -32,6 +32,19 @@ class MessageSegment: """ return self.data.get("text", "") if self.type == "text" else "" + @staticmethod + def text(text: str) -> "MessageSegment": + """ + 创建一个文本消息段。 + + Args: + text (str): 文本内容。 + + Returns: + MessageSegment: 一个类型为 'text' 的消息段对象。 + """ + return MessageSegment(type="text", data={"text": text}) + @property def image_url(self) -> str: """ @@ -93,12 +106,48 @@ class MessageSegment: return True return str(self.data.get("qq")) == str(user_id) + def __str__(self): + """ + 返回消息段的 CQ 码字符串表示。 + """ + if self.type == "text": + return self.data.get("text", "") + + params = ",".join([f"{k}={v}" for k, v in self.data.items()]) + if params: + return f"[CQ:{self.type},{params}]" + return f"[CQ:{self.type}]" + def __repr__(self): """ 返回消息段对象的字符串表示形式,便于调试。 """ return f"[MS:{self.type}:{self.data}]" + def __add__(self, other: Any) -> "List[MessageSegment]": + """ + 支持消息段相加,返回消息段列表。 + """ + if isinstance(other, MessageSegment): + return [self, other] + elif isinstance(other, str): + return [self, MessageSegment.text(other)] + elif isinstance(other, list): + return [self] + other + return NotImplemented + + def __radd__(self, other: Any) -> "List[MessageSegment]": + """ + 支持反向相加。 + """ + if isinstance(other, MessageSegment): + return [other, self] + elif isinstance(other, str): + return [MessageSegment.text(other), self] + elif isinstance(other, list): + return other + [self] + return NotImplemented + # --- 快捷构造方法 --- @staticmethod @@ -297,17 +346,17 @@ class MessageSegment: return MessageSegment(type="file", data={"file": file}) @staticmethod - def reply(message_id: str) -> "MessageSegment": + def reply(message_id: str | int) -> "MessageSegment": """ 创建一个回复消息段。 Args: - message_id (str): 被回复的消息 ID。 + message_id (str | int): 被回复的消息 ID。 Returns: MessageSegment: 一个类型为 'reply' 的消息段对象。 """ - return MessageSegment(type="reply", data={"id": message_id}) + return MessageSegment(type="reply", data={"id": str(message_id)}) @staticmethod def rps() -> "MessageSegment": diff --git a/plugins/bili_parser.py b/plugins/bili_parser.py index 57367ec..a4a8ac5 100644 --- a/plugins/bili_parser.py +++ b/plugins/bili_parser.py @@ -1,17 +1,17 @@ # -*- coding: utf-8 -*- import re import json -import httpx +import requests from bs4 import BeautifulSoup -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Union from cachetools import TTLCache from core.utils.logger import logger from core.managers.command_manager import matcher -from models.events.message import MessageEvent, MessageSegment +from models import MessageEvent, MessageSegment -# 创建一个TTL缓存,最大容量100,缓存时间60秒 -processed_messages: TTLCache[Any, bool] = TTLCache(maxsize=100, ttl=60) +# 创建一个TTL缓存,最大容量100,缓存时间10秒 +processed_messages: TTLCache[int, bool] = TTLCache(maxsize=100, ttl=10) __plugin_meta__ = { "name": "bili_parser", @@ -23,9 +23,6 @@ HEADERS = { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' } -# 创建可复用的异步HTTP客户端 -async_client = httpx.AsyncClient(headers=HEADERS, follow_redirects=False, timeout=10) - def format_count(num: int) -> str: if not isinstance(num, int): @@ -43,32 +40,29 @@ def format_duration(seconds: int) -> str: return f"{minutes:02d}:{seconds:02d}" -async def get_real_url(short_url: str) -> Optional[str]: +def get_real_url(short_url: str) -> Optional[str]: try: - response = await async_client.head(short_url) + response = requests.head(short_url, headers=HEADERS, allow_redirects=False, timeout=5) if response.status_code == 302: return response.headers.get('Location') - except httpx.RequestError as e: - logger.error(f"获取真实URL失败: {e}") + except requests.RequestException as e: + print(f"获取真实URL失败: {e}") return None -async def parse_video_info(video_url: str) -> Optional[Dict[str, Any]]: +def parse_video_info(video_url: str) -> Optional[Dict[str, Any]]: try: - response = await async_client.get(video_url, follow_redirects=True) + response = requests.get(video_url, headers=HEADERS, timeout=5) response.raise_for_status() soup = BeautifulSoup(response.text, 'html.parser') script_tag = soup.find('script', text=re.compile('window.__INITIAL_STATE__')) - if not script_tag: + if not script_tag or not script_tag.string: return None - script_tag_content = script_tag.string - if not script_tag_content: - return None - - match = re.search(r'window\.__INITIAL_STATE__\s*=\s*(\{.*?\});', script_tag_content) + match = re.search(r'window\.__INITIAL_STATE__\s*=\s*(\{.*?\});', script_tag.string) if not match: return None + json_str = match.group(1) data = json.loads(json_str) @@ -104,12 +98,12 @@ async def parse_video_info(video_url: str) -> Optional[Dict[str, Any]]: "followers": up_data.get('fans', 0), } - except (httpx.RequestError, KeyError, AttributeError, json.JSONDecodeError) as e: - logger.error(f"解析视频信息失败: {e}") + except (requests.RequestException, KeyError, AttributeError, json.JSONDecodeError) as e: + print(f"解析视频信息失败: {e}") return None -async def get_direct_video_url(video_url: str) -> Optional[str]: +def get_direct_video_url(video_url: str) -> Optional[str]: """ 调用第三方API解析B站视频直链 :param video_url: B站视频的完整URL @@ -117,12 +111,12 @@ async def get_direct_video_url(video_url: str) -> Optional[str]: """ api_url = f"https://api.mir6.com/api/bzjiexi?url={video_url}&type=json" try: - response = await async_client.get(api_url) + response = requests.get(api_url, headers=HEADERS, timeout=10) response.raise_for_status() data = response.json() if data.get("code") == 200 and data.get("data"): return data["data"][0].get("video_url") - except (httpx.RequestError, json.JSONDecodeError, KeyError, IndexError) as e: + except (requests.RequestException, json.JSONDecodeError, KeyError, IndexError) as e: logger.error(f"[bili_parser] 调用第三方API解析视频失败: {e}") return None @@ -184,7 +178,7 @@ async def process_bili_link(event: MessageEvent, url: str): :param url: 待处理的B站链接 """ if "b23.tv" in url: - real_url = await get_real_url(url) + real_url = get_real_url(url) if not real_url: logger.error(f"[bili_parser] 无法从 {url} 获取真实URL。") await event.reply("无法解析B站短链接。") @@ -192,28 +186,59 @@ async def process_bili_link(event: MessageEvent, url: str): else: real_url = url.split('?')[0] - video_info = await parse_video_info(real_url) + video_info = parse_video_info(real_url) if not video_info: logger.error(f"[bili_parser] 无法从 {real_url} 解析视频信息。") await event.reply("无法获取视频信息,可能是B站接口变动或视频不存在。") return - title = video_info.get("title", "未知标题") - owner_name = video_info.get("owner_name", "未知UP主") - cover_url = video_info.get("cover_url") - bvid = video_info.get("bvid", "N/A") - play_count = format_count(video_info.get("play", 0)) - like_count = format_count(video_info.get("like", 0)) + # 检查视频时长 + video_message: Union[str, MessageSegment] + if video_info['duration'] > 300: # 5分钟 = 300秒 + video_message = "视频时长超过5分钟,不进行解析。" + else: + direct_url = get_direct_video_url(real_url) + if direct_url: + video_message = MessageSegment.video(direct_url) + else: + video_message = "视频解析失败,无法获取直链。" - text_part = ( - f"标题: {title}\n" - f"UP主: {owner_name}\n" - f"BV: {bvid} | ▶️ {play_count} | 👍 {like_count}" + text_message = ( + f"BiliBili 视频解析\n" + f"--------------------\n" + f" UP主: {video_info['owner_name']}\n" + f" 粉丝: {format_count(video_info['followers'])}\n" + f"--------------------\n" + f" 标题: {video_info['title']}\n" + f" BV号: {video_info['bvid']}\n" + f" 时长: {format_duration(video_info['duration'])}\n" + f"--------------------\n" + f" 数据:\n" + f" 播放: {format_count(video_info['play'])}\n" + f" 点赞: {format_count(video_info['like'])}\n" + f" 投币: {format_count(video_info['coin'])}\n" + f" 收藏: {format_count(video_info['favorite'])}\n" + f" 转发: {format_count(video_info['share'])}\n" + f" B站链接: {url}" ) - - reply_message = [MessageSegment.from_text(text_part)] - if cover_url: - reply_message.append(MessageSegment.image(cover_url)) - logger.success(f"[bili_parser] 成功解析视频信息并准备回复: {title}") - await event.reply(reply_message) + image_message_segment = [ + MessageSegment.text("B站封面:"), + MessageSegment.image(video_info['cover_url']) + ] + + up_info_segment = [ + MessageSegment.text("UP主头像:"), + MessageSegment.image(video_info['owner_avatar']) + ] + + nodes = [ + event.bot.build_forward_node(user_id=event.self_id, nickname="B站视频解析", message=text_message), + event.bot.build_forward_node(user_id=event.self_id, nickname="B站视频解析", message=image_message_segment), + event.bot.build_forward_node(user_id=event.self_id, nickname="B站视频解析", message=up_info_segment), + event.bot.build_forward_node(user_id=event.self_id, nickname="B站视频解析", message=video_message) + ] + + logger.success(f"[bili_parser] 成功解析视频信息并准备以聊天记录形式回复: {video_info['title']}") + # 使用更通用的 send_forwarded_messages 方法,自动判断私聊或群聊 + await event.bot.send_forwarded_messages(target=event, nodes=nodes) diff --git a/requirements.txt b/requirements.txt index 95b3fc5..fe0f977 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,6 @@ pipreqs==0.4.13 redis==5.0.7 requests==2.32.5 soupsieve==2.8.1 -toml==0.10.2 typing==3.7.4.3 typing_extensions==4.15.0 urllib3==2.6.2 @@ -25,6 +24,7 @@ docker pytest pytest-asyncio pytest-mock +pytest-cov httpx==0.27.0 # Dev Dependencies diff --git a/tests/test_basic.py b/tests/test_basic.py new file mode 100644 index 0000000..7dafa19 --- /dev/null +++ b/tests/test_basic.py @@ -0,0 +1,37 @@ +import pytest +import sys +import os + +# 确保项目根目录在 sys.path 中 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +def test_import_core(): + """测试核心模块是否可以被导入""" + try: + import core + import core.bot + import core.ws + except ImportError as e: + pytest.fail(f"无法导入核心模块: {e}") + +def test_plugin_manager_path(): + """测试插件管理器路径逻辑是否正确""" + from core.managers.plugin_manager import PluginManager + # Mock command manager + pm = PluginManager(None) + + # 我们无法直接测试 load_all_plugins 的内部路径变量, + # 但我们可以检查它是否能找到 plugins 目录而不报错 + # 这里我们简单地断言 PluginManager 类存在且可以实例化 + assert pm is not None + +def test_config_loader_exists(): + """测试配置加载器是否存在""" + # 注意:导入 config_loader 会尝试读取 config.toml + # 如果 config.toml 不存在,这可能会失败。 + # 这是一个已知的设计问题,但在测试环境中我们假设 config.toml 存在或被 mock + if os.path.exists("config.toml"): + from core.config_loader import global_config + assert global_config is not None + else: + pytest.skip("config.toml 不存在,跳过配置加载测试") diff --git a/tests/test_command_manager.py b/tests/test_command_manager.py new file mode 100644 index 0000000..3743d99 --- /dev/null +++ b/tests/test_command_manager.py @@ -0,0 +1,114 @@ +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch +from core.managers.command_manager import CommandManager +from models.events.message import GroupMessageEvent +from models.message import MessageSegment + +@pytest.fixture +def mock_bot(): + bot = AsyncMock() + bot.self_id = 123456 + return bot + +@pytest.fixture +def command_manager(): + # 创建一个新的 CommandManager 实例用于测试,避免单例状态污染 + return CommandManager(prefixes=("/",)) + +@pytest.mark.asyncio +async def test_command_registration_and_execution(command_manager, mock_bot): + """测试命令注册和执行""" + + # 定义一个命令处理函数 + handler_mock = AsyncMock() + + # 注册命令 + @command_manager.command("test") + async def test_command(bot, event): + await handler_mock(bot, event) + + # 构造触发命令的事件 + event = MagicMock(spec=GroupMessageEvent) + event.post_type = "message" + event.message_type = "group" + event.raw_message = "/test" + event.message = [MessageSegment.text("/test")] + event.user_id = 111 + event.group_id = 222 + + # 处理事件 + await command_manager.handle_event(mock_bot, event) + + # 验证处理函数被调用 + handler_mock.assert_called_once_with(mock_bot, event) + +@pytest.mark.asyncio +async def test_command_prefix_match(command_manager, mock_bot): + """测试命令前缀匹配""" + handler_mock = AsyncMock() + + @command_manager.command("hello") + async def hello_command(bot, event): + await handler_mock(bot, event) + + # 1. 正确的前缀 + event1 = MagicMock(spec=GroupMessageEvent) + event1.post_type = "message" + event1.raw_message = "/hello" + event1.message = [MessageSegment.text("/hello")] + await command_manager.handle_event(mock_bot, event1) + handler_mock.assert_called_once() + handler_mock.reset_mock() + + # 2. 错误的前缀 (应该忽略) + event2 = MagicMock(spec=GroupMessageEvent) + event2.post_type = "message" + event2.raw_message = ".hello" # 假设前缀是 / + event2.message = [MessageSegment.text(".hello")] + await command_manager.handle_event(mock_bot, event2) + handler_mock.assert_not_called() + +@pytest.mark.asyncio +async def test_ignore_self_message(command_manager, mock_bot): + """测试忽略自身消息""" + # 模拟配置 + with patch("core.managers.command_manager.global_config") as mock_config: + mock_config.bot.ignore_self_message = True + + event = MagicMock(spec=GroupMessageEvent) + event.post_type = "message" + event.user_id = 123456 # 与 bot.self_id 相同 + event.self_id = 123456 + + # Mock handle 方法来检测是否被调用 + command_manager.message_handler.handle = AsyncMock() + + await command_manager.handle_event(mock_bot, event) + + # 应该直接返回,不调用 handler + command_manager.message_handler.handle.assert_not_called() + +@pytest.mark.asyncio +async def test_help_command(command_manager, mock_bot): + """测试内置 help 命令""" + # 注册一个测试插件信息 + command_manager.plugins["test_plugin"] = { + "name": "测试插件", + "description": "这是一个测试", + "usage": "/test" + } + + event = MagicMock(spec=GroupMessageEvent) + event.post_type = "message" + event.raw_message = "/help" + event.message = [MessageSegment.text("/help")] + + await command_manager.handle_event(mock_bot, event) + + # 验证 bot.send 被调用,且内容包含插件信息 + mock_bot.send.assert_called_once() + args, _ = mock_bot.send.call_args + sent_msg = args[1] + assert "测试插件" in sent_msg + assert "这是一个测试" in sent_msg diff --git a/tests/test_event_factory.py b/tests/test_event_factory.py new file mode 100644 index 0000000..1e038fd --- /dev/null +++ b/tests/test_event_factory.py @@ -0,0 +1,141 @@ +import pytest +from models.events.factory import EventFactory, EventType +from models.events.message import GroupMessageEvent, PrivateMessageEvent +from models.events.notice import GroupIncreaseNoticeEvent +from models.events.request import FriendRequestEvent +from models.events.meta import HeartbeatEvent +from models.message import MessageSegment + +class TestEventFactory: + def test_create_group_message_event_list(self): + """测试创建群消息事件 (message 为列表格式)""" + data = { + "post_type": "message", + "message_type": "group", + "time": 1600000000, + "self_id": 123456, + "sub_type": "normal", + "message_id": 1001, + "user_id": 111111, + "group_id": 222222, + "message": [ + {"type": "text", "data": {"text": "Hello"}} + ], + "raw_message": "Hello", + "font": 0, + "sender": { + "user_id": 111111, + "nickname": "User", + "role": "member" + } + } + event = EventFactory.create_event(data) + assert isinstance(event, GroupMessageEvent) + assert event.group_id == 222222 + assert len(event.message) == 1 + assert event.message[0].type == "text" + assert event.message[0].data["text"] == "Hello" + + def test_create_group_message_event_str(self): + """测试创建群消息事件 (message 为字符串格式)""" + data = { + "post_type": "message", + "message_type": "group", + "time": 1600000000, + "self_id": 123456, + "sub_type": "normal", + "message_id": 1002, + "user_id": 111111, + "group_id": 222222, + "message": "Hello World", + "raw_message": "Hello World", + "font": 0, + "sender": { + "user_id": 111111, + "nickname": "User" + } + } + event = EventFactory.create_event(data) + assert isinstance(event, GroupMessageEvent) + assert len(event.message) == 1 + assert event.message[0].type == "text" + assert event.message[0].data["text"] == "Hello World" + + def test_create_private_message_event(self): + """测试创建私聊消息事件""" + data = { + "post_type": "message", + "message_type": "private", + "time": 1600000000, + "self_id": 123456, + "sub_type": "friend", + "message_id": 2001, + "user_id": 333333, + "message": "Private Msg", + "raw_message": "Private Msg", + "font": 0, + "sender": { + "user_id": 333333, + "nickname": "Friend" + } + } + event = EventFactory.create_event(data) + assert isinstance(event, PrivateMessageEvent) + assert event.user_id == 333333 + + def test_create_notice_event(self): + """测试创建通知事件 (群成员增加)""" + data = { + "post_type": "notice", + "notice_type": "group_increase", + "sub_type": "approve", + "group_id": 222222, + "operator_id": 444444, + "user_id": 555555, + "time": 1600000000, + "self_id": 123456 + } + event = EventFactory.create_event(data) + assert isinstance(event, GroupIncreaseNoticeEvent) + assert event.group_id == 222222 + assert event.user_id == 555555 + + def test_create_request_event(self): + """测试创建请求事件 (加好友)""" + data = { + "post_type": "request", + "request_type": "friend", + "user_id": 666666, + "comment": "Add me", + "flag": "flag_123", + "time": 1600000000, + "self_id": 123456 + } + event = EventFactory.create_event(data) + assert isinstance(event, FriendRequestEvent) + assert event.user_id == 666666 + assert event.comment == "Add me" + + def test_create_meta_event(self): + """测试创建元事件 (心跳)""" + data = { + "post_type": "meta_event", + "meta_event_type": "heartbeat", + "time": 1600000000, + "self_id": 123456, + "status": {"online": True, "good": True}, + "interval": 5000 + } + event = EventFactory.create_event(data) + assert isinstance(event, HeartbeatEvent) + assert event.interval == 5000 + + def test_unknown_event_type(self): + """测试未知事件类型""" + data = { + "post_type": "unknown_type", + "time": 1600000000, + "self_id": 123456 + } + with pytest.raises(ValueError, match="Unknown event type"): + EventFactory.create_event(data) diff --git a/tests/test_event_handler.py b/tests/test_event_handler.py new file mode 100644 index 0000000..80af28f --- /dev/null +++ b/tests/test_event_handler.py @@ -0,0 +1,194 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from core.handlers.event_handler import MessageHandler, NoticeHandler, RequestHandler +from models.events.message import GroupMessageEvent +from models.events.notice import GroupIncreaseNoticeEvent +from models.events.request import FriendRequestEvent + +@pytest.fixture +def mock_bot(): + bot = AsyncMock() + return bot + +@pytest.mark.asyncio +async def test_message_handler_run_handler_injection(mock_bot): + """测试参数注入""" + handler = MessageHandler(prefixes=("/",)) + + # 1. 测试注入 bot 和 event + async def func1(bot, event): + assert bot == mock_bot + assert event.user_id == 123 + return True + + event = MagicMock(spec=GroupMessageEvent) + event.user_id = 123 + + result = await handler._run_handler(func1, mock_bot, event) + assert result is True + + # 2. 测试注入 args + async def func2(args): + assert args == ["arg1", "arg2"] + return True + + result = await handler._run_handler(func2, mock_bot, event, args=["arg1", "arg2"]) + assert result is True + +@pytest.mark.asyncio +async def test_message_handler_command_parsing(mock_bot): + """测试命令解析""" + handler = MessageHandler(prefixes=("/",)) + + mock_func = AsyncMock() + handler.commands["test"] = { + "func": mock_func, + "permission": None, + "override_permission_check": False, + "plugin_name": "test_plugin" + } + + event = MagicMock(spec=GroupMessageEvent) + event.raw_message = "/test arg1 arg2" + event.user_id = 123 + + # Mock permission manager + with patch("core.managers.permission_manager.PermissionManager.check_permission", new_callable=AsyncMock) as mock_perm: + mock_perm.return_value = True + + await handler.handle(mock_bot, event) + + mock_func.assert_called_once() + # 验证 args 参数是否正确传递 + call_args = mock_func.call_args + if "args" in call_args.kwargs: + assert call_args.kwargs["args"] == ["arg1", "arg2"] + +@pytest.mark.asyncio +async def test_notice_handler(mock_bot): + """测试通知事件分发""" + handler = NoticeHandler() + + mock_func = AsyncMock() + handler.handlers.append({ + "type": "group_increase", + "func": mock_func, + "plugin_name": "test_plugin" + }) + + event = MagicMock(spec=GroupIncreaseNoticeEvent) + event.notice_type = "group_increase" + + await handler.handle(mock_bot, event) + + mock_func.assert_called_once() + +@pytest.mark.asyncio +async def test_sync_handler_execution(mock_bot): + """测试同步处理函数的执行""" + handler = MessageHandler(prefixes=("/",)) + + def sync_func(event): + return True + + event = MagicMock(spec=GroupMessageEvent) + + # 同步函数应该在线程池中运行 + result = await handler._run_handler(sync_func, mock_bot, event) + assert result is True + +@pytest.mark.asyncio +async def test_message_handler_management(mock_bot): + """测试消息处理器的管理(注册、卸载、清空)""" + handler = MessageHandler(prefixes=("/",)) + + # 测试 on_message 装饰器 + @handler.on_message() + async def msg_handler(event): + pass + + assert len(handler.message_handlers) == 1 + + # 测试 command 装饰器 + @handler.command("cmd1", "cmd2") + async def cmd_handler(event): + pass + + assert len(handler.commands) == 2 + assert "cmd1" in handler.commands + assert "cmd2" in handler.commands + + # 测试 unregister_by_plugin_name + # 直接从已注册的处理器中获取 plugin_name + if handler.message_handlers: + plugin_name = handler.message_handlers[0]["plugin_name"] + handler.unregister_by_plugin_name(plugin_name) + + assert len(handler.message_handlers) == 0 + assert len(handler.commands) == 0 + + # 测试 clear + handler.commands["cmd"] = {} + handler.message_handlers.append({}) + handler.clear() + assert len(handler.commands) == 0 + assert len(handler.message_handlers) == 0 + +@pytest.mark.asyncio +async def test_request_handler(mock_bot): + """测试请求事件处理器""" + handler = RequestHandler() + + mock_func = AsyncMock() + + # 测试 register 装饰器 + @handler.register("friend") + async def req_handler(event): + await mock_func(event) + + assert len(handler.handlers) == 1 + + event = MagicMock(spec=FriendRequestEvent) + event.request_type = "friend" + + await handler.handle(mock_bot, event) + mock_func.assert_called_once() + + # 测试 unregister 和 clear + import inspect + module = inspect.getmodule(req_handler) + plugin_name = module.__name__ + + handler.unregister_by_plugin_name(plugin_name) + assert len(handler.handlers) == 0 + + handler.handlers.append({}) + handler.clear() + assert len(handler.handlers) == 0 + +@pytest.mark.asyncio +async def test_permission_denied(mock_bot): + """测试权限不足的情况""" + handler = MessageHandler(prefixes=("/",)) + + mock_func = AsyncMock() + handler.commands["admin_cmd"] = { + "func": mock_func, + "permission": "ADMIN", # 假设 Permission.ADMIN + "override_permission_check": False, + "plugin_name": "test_plugin" + } + + event = MagicMock(spec=GroupMessageEvent) + event.raw_message = "/admin_cmd" + event.user_id = 123 + + # Mock permission manager returning False + with patch("core.managers.permission_manager.PermissionManager.check_permission", new_callable=AsyncMock) as mock_perm: + mock_perm.return_value = False + + await handler.handle(mock_bot, event) + + mock_func.assert_not_called() + # 应该发送拒绝消息 + mock_bot.send.assert_called_once() diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..497581d --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,75 @@ +import pytest +from models.message import MessageSegment +from models.objects import GroupInfo, StrangerInfo + +class TestMessageSegment: + def test_text_segment(self): + seg = MessageSegment.text("Hello") + assert seg.type == "text" + assert seg.data["text"] == "Hello" + assert str(seg) == "Hello" + + def test_at_segment(self): + seg = MessageSegment.at(123456) + assert seg.type == "at" + assert seg.data["qq"] == "123456" + assert str(seg) == "[CQ:at,qq=123456]" + + def test_image_segment(self): + seg = MessageSegment.image("http://example.com/img.jpg", cache=False, proxy=False) + assert seg.type == "image" + assert seg.data["file"] == "http://example.com/img.jpg" + assert str(seg) == "[CQ:image,file=http://example.com/img.jpg,cache=0,proxy=0]" + + def test_face_segment(self): + seg = MessageSegment.face(123) + assert seg.type == "face" + assert seg.data["id"] == "123" + assert str(seg) == "[CQ:face,id=123]" + + def test_reply_segment(self): + seg = MessageSegment.reply(1001) + assert seg.type == "reply" + assert seg.data["id"] == "1001" + assert str(seg) == "[CQ:reply,id=1001]" + + def test_add_segments(self): + seg1 = MessageSegment.text("Hello ") + seg2 = MessageSegment.at(123) + combined = seg1 + seg2 + assert isinstance(combined, list) + assert len(combined) == 2 + assert combined[0] == seg1 + assert combined[1] == seg2 + + def test_add_segment_and_string(self): + seg = MessageSegment.at(123) + combined = seg + " Hello" + assert isinstance(combined, list) + assert len(combined) == 2 + assert combined[0] == seg + assert combined[1].type == "text" + assert combined[1].data["text"] == " Hello" + +class TestObjects: + def test_group_info(self): + data = { + "group_id": 123456, + "group_name": "Test Group", + "member_count": 10, + "max_member_count": 100 + } + group = GroupInfo(**data) + assert group.group_id == 123456 + assert group.group_name == "Test Group" + + def test_stranger_info(self): + data = { + "user_id": 111111, + "nickname": "Stranger", + "sex": "male", + "age": 18 + } + user = StrangerInfo(**data) + assert user.user_id == 111111 + assert user.nickname == "Stranger"