Merge branch 'main' of https://github.com/Fairy-Oracle-Sanctuary/NeoBot
This commit is contained in:
231
core/WS.py
231
core/WS.py
@@ -1,77 +1,125 @@
|
||||
"""
|
||||
WebSocket 核心模块
|
||||
WebSocket 核心通信模块
|
||||
|
||||
负责与 OneBot 实现端建立 WebSocket 连接,处理消息接收、事件分发和 API 调用。
|
||||
该模块定义了 `WS` 类,负责与 OneBot v11 实现(如 NapCat)建立和管理
|
||||
WebSocket 连接。它是整个机器人框架的底层通信基础。
|
||||
|
||||
主要职责包括:
|
||||
- 建立 WebSocket 连接并处理认证。
|
||||
- 实现断线自动重连机制。
|
||||
- 监听并接收来自 OneBot 的事件和 API 响应。
|
||||
- 分发事件给 `CommandManager` 进行处理。
|
||||
- 提供 `call_api` 方法,用于异步发送 API 请求并等待响应。
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import traceback
|
||||
from typing import Any, Dict, Optional, cast
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import websockets
|
||||
from websockets.legacy.client import WebSocketClientProtocol
|
||||
|
||||
from models import EventFactory
|
||||
from models.events.factory import EventFactory
|
||||
|
||||
from .bot import Bot
|
||||
from .command_manager import matcher
|
||||
from .config_loader import global_config
|
||||
from .managers.command_manager import matcher
|
||||
from .utils.executor import CodeExecutor
|
||||
from .utils.logger import logger, ModuleLogger
|
||||
from .utils.exceptions import (
|
||||
WebSocketError, WebSocketConnectionError, WebSocketAuthenticationError
|
||||
)
|
||||
from .utils.error_codes import ErrorCode, create_error_response
|
||||
|
||||
|
||||
class WS:
|
||||
"""
|
||||
WebSocket 客户端类,负责与 OneBot 实现端建立连接并处理通信
|
||||
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, code_executor: Optional[CodeExecutor] = None) -> None:
|
||||
"""
|
||||
初始化 WebSocket 客户端
|
||||
初始化 WebSocket 客户端。
|
||||
|
||||
从全局配置中读取 WebSocket URI、访问令牌(Token)和重连间隔。
|
||||
"""
|
||||
# 读取参数
|
||||
cfg = global_config.napcat_ws
|
||||
self.url = cfg.get("uri")
|
||||
self.token = cfg.get("token")
|
||||
self.reconnect_interval = cfg.get("reconnect_interval", 5)
|
||||
self.url = cfg.uri
|
||||
self.token = cfg.token
|
||||
self.reconnect_interval = cfg.reconnect_interval
|
||||
|
||||
self.ws = None
|
||||
self._pending_requests = {}
|
||||
self.bot = Bot(self)
|
||||
# 初始化状态
|
||||
self.ws: Optional[WebSocketClientProtocol] = None
|
||||
self._pending_requests: Dict[str, asyncio.Future] = {} # echo: future
|
||||
self.bot: Bot | None = None
|
||||
self.self_id: int | None = None
|
||||
self.code_executor = code_executor
|
||||
|
||||
# 创建模块专用日志记录器
|
||||
self.logger = ModuleLogger("WebSocket")
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
主连接循环,负责建立连接和自动重连
|
||||
启动并管理 WebSocket 连接。
|
||||
|
||||
这是一个无限循环,负责建立连接。如果连接断开,它会根据配置的
|
||||
`reconnect_interval` 时间间隔后自动尝试重新连接。
|
||||
"""
|
||||
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
|
||||
|
||||
while True:
|
||||
try:
|
||||
print(f" 正在尝试连接至 NapCat: {self.url}")
|
||||
self.logger.info(f"正在尝试连接至 NapCat: {self.url}")
|
||||
async with websockets.connect(
|
||||
self.url, additional_headers=headers
|
||||
) as websocket:
|
||||
) as websocket_raw:
|
||||
websocket = cast(WebSocketClientProtocol, websocket_raw)
|
||||
self.ws = websocket
|
||||
print(" 连接成功!")
|
||||
self.logger.success("连接成功!")
|
||||
await self._listen_loop(websocket)
|
||||
|
||||
except websockets.exceptions.AuthenticationError as e:
|
||||
error = WebSocketAuthenticationError(
|
||||
message=f"WebSocket认证失败: {str(e)}",
|
||||
code=ErrorCode.WS_AUTH_FAILED,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.error(f"连接失败: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
except (
|
||||
websockets.exceptions.ConnectionClosed,
|
||||
ConnectionRefusedError,
|
||||
) as e:
|
||||
print(f" 连接断开或服务器拒绝访问: {e}")
|
||||
error = WebSocketConnectionError(
|
||||
message=f"连接断开或服务器拒绝访问: {str(e)}",
|
||||
code=ErrorCode.WS_CONNECTION_FAILED,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.warning(f"连接失败: {error.message}")
|
||||
except Exception as e:
|
||||
print(f" 运行异常: {e}")
|
||||
traceback.print_exc()
|
||||
error = WebSocketError(
|
||||
message=f"WebSocket运行异常: {str(e)}",
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.exception(f"运行异常: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
|
||||
print(f" {self.reconnect_interval}秒后尝试重连...")
|
||||
self.logger.info(f"{self.reconnect_interval}秒后尝试重连...")
|
||||
await asyncio.sleep(self.reconnect_interval)
|
||||
|
||||
async def _listen_loop(self, websocket):
|
||||
async def _listen_loop(self, websocket_connection: WebSocketClientProtocol) -> None:
|
||||
"""
|
||||
核心监听循环,处理接收到的 WebSocket 消息
|
||||
核心监听循环,处理所有接收到的 WebSocket 消息。
|
||||
|
||||
:param websocket: WebSocket 连接对象
|
||||
此循环会持续从 WebSocket 连接中读取消息,并根据消息内容
|
||||
判断是 API 响应还是上报的事件,然后分发给相应的处理逻辑。
|
||||
|
||||
Args:
|
||||
websocket_connection: 当前活动的 WebSocket 连接对象。
|
||||
"""
|
||||
async for message in websocket:
|
||||
async for message in websocket_connection:
|
||||
try:
|
||||
data = json.loads(message)
|
||||
|
||||
@@ -90,53 +138,121 @@ class WS:
|
||||
# 使用 create_task 异步执行,避免阻塞 WebSocket 接收循环
|
||||
asyncio.create_task(self.on_event(data))
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
error = WebSocketError(
|
||||
message=f"JSON解析失败: {str(e)}",
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.error(f"解析消息异常: {error.message}")
|
||||
self.logger.debug(f"原始消息: {message}")
|
||||
except Exception as e:
|
||||
print(f" 解析消息异常: {e}")
|
||||
traceback.print_exc()
|
||||
error = WebSocketError(
|
||||
message=f"处理消息异常: {str(e)}",
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.exception(f"解析消息异常: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
|
||||
async def on_event(self, raw_data: dict):
|
||||
async def on_event(self, event_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
事件分发层:根据 post_type 调用 matcher 对应的处理器
|
||||
事件处理和分发层。
|
||||
|
||||
:param raw_data: 原始事件数据字典
|
||||
当接收到一个 OneBot 事件时,此方法负责:
|
||||
1. 使用 `EventFactory` 将原始 JSON 数据解析成对应的事件对象。
|
||||
2. 为事件对象注入 `Bot` 实例,以便在插件中可以调用 API。
|
||||
3. 打印格式化的事件日志。
|
||||
4. 将事件对象传递给 `CommandManager` (`matcher`) 进行后续处理。
|
||||
|
||||
Args:
|
||||
event_data (dict): 从 WebSocket 接收到的原始事件字典。
|
||||
"""
|
||||
try:
|
||||
# 使用工厂创建事件对象
|
||||
event = EventFactory.create_event(raw_data)
|
||||
event = EventFactory.create_event(event_data)
|
||||
|
||||
# 尝试初始化 Bot 实例 (如果尚未初始化且事件包含 self_id)
|
||||
# 只要事件中包含 self_id,我们就可以初始化 Bot,不必非要等待 meta_event
|
||||
if self.bot is None and hasattr(event, 'self_id'):
|
||||
self.self_id = event.self_id
|
||||
self.bot = Bot(self)
|
||||
self.logger.success(f"Bot 实例初始化完成: self_id={self.self_id}")
|
||||
|
||||
# 将代码执行器注入到 Bot 和执行器自身
|
||||
if self.code_executor:
|
||||
self.bot.code_executor = self.code_executor
|
||||
self.code_executor.bot = self.bot
|
||||
self.logger.info("代码执行器已成功注入 Bot 实例。")
|
||||
|
||||
# 如果 bot 尚未初始化,则不处理后续事件
|
||||
if self.bot is None:
|
||||
self.logger.warning("Bot 尚未初始化,跳过事件处理。")
|
||||
return
|
||||
|
||||
event.bot = self.bot # 注入 Bot 实例
|
||||
|
||||
# 打印日志
|
||||
t = datetime.fromtimestamp(event.time).strftime("%H:%M:%S")
|
||||
if event.post_type == "message":
|
||||
sender_name = event.sender.nickname if event.sender else "Unknown"
|
||||
print(f" [{t}] [消息] {event.message_type} | {event.user_id}({sender_name}): {event.raw_message}")
|
||||
sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown"
|
||||
message_type = getattr(event, "message_type", "Unknown")
|
||||
user_id = getattr(event, "user_id", "Unknown")
|
||||
raw_message = getattr(event, "raw_message", "")
|
||||
self.logger.info(f"[消息] {message_type} | {user_id}({sender_name}): {raw_message}")
|
||||
elif event.post_type == "notice":
|
||||
print(f" [{t}] [通知] {event.notice_type}")
|
||||
notice_type = getattr(event, "notice_type", "Unknown")
|
||||
self.logger.info(f"[通知] {notice_type}")
|
||||
elif event.post_type == "request":
|
||||
print(f" [{t}] [请求] {event.request_type}")
|
||||
request_type = getattr(event, "request_type", "Unknown")
|
||||
self.logger.info(f"[请求] {request_type}")
|
||||
elif event.post_type == "meta_event":
|
||||
meta_event_type = getattr(event, "meta_event_type", "Unknown")
|
||||
self.logger.debug(f"[元事件] {meta_event_type}")
|
||||
|
||||
# 分发事件
|
||||
await matcher.handle_event(self.bot, event)
|
||||
|
||||
except Exception as e:
|
||||
print(f" 事件处理异常: {e}")
|
||||
traceback.print_exc()
|
||||
self.logger.exception(f"事件处理异常: {str(e)}")
|
||||
error = WebSocketError(
|
||||
message=f"事件处理异常: {str(e)}",
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.log_custom_exception(error)
|
||||
|
||||
async def call_api(self, action: str, params: dict = None):
|
||||
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
|
||||
"""
|
||||
调用 OneBot API
|
||||
向 OneBot v11 实现端发送一个 API 请求。
|
||||
|
||||
:param action: API 动作名称
|
||||
:param params: API 参数
|
||||
:return: API 响应结果
|
||||
该方法通过 WebSocket 发送请求,并使用 `echo` 字段来匹配对应的响应。
|
||||
它创建了一个 `Future` 对象来异步等待响应,并设置了超时机制。
|
||||
|
||||
Args:
|
||||
action (str): API 的动作名称,例如 "send_group_msg"。
|
||||
params (dict, optional): API 请求的参数字典。 Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: OneBot API 的响应数据。如果超时或连接断开,则返回一个
|
||||
表示失败的字典。
|
||||
"""
|
||||
if not self.ws:
|
||||
return {"status": "failed", "msg": "websocket not initialized"}
|
||||
self.logger.error("调用 API 失败: WebSocket 未初始化")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_DISCONNECTED,
|
||||
message="WebSocket未初始化",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
from websockets.protocol import State
|
||||
|
||||
if getattr(self.ws, "state", None) is not State.OPEN:
|
||||
return {"status": "failed", "msg": "websocket is not open"}
|
||||
self.logger.error("调用 API 失败: WebSocket 连接未打开")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_DISCONNECTED,
|
||||
message="WebSocket连接未打开",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
echo_id = str(uuid.uuid4())
|
||||
payload = {"action": action, "params": params or {}, "echo": echo_id}
|
||||
@@ -145,10 +261,23 @@ class WS:
|
||||
future = loop.create_future()
|
||||
self._pending_requests[echo_id] = future
|
||||
|
||||
await self.ws.send(json.dumps(payload))
|
||||
|
||||
try:
|
||||
await self.ws.send(json.dumps(payload))
|
||||
return await asyncio.wait_for(future, timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
return {"status": "failed", "retcode": -1, "msg": "api timeout"}
|
||||
self.logger.warning(f"API 调用超时: action={action}, params={params}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.TIMEOUT_ERROR,
|
||||
message="API调用超时",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
message=f"API调用异常: {str(e)}",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
from .command_manager import matcher
|
||||
from .config_loader import global_config
|
||||
from .plugin_manager import PluginDataManager
|
||||
from .ws import WS
|
||||
|
||||
__all__ = ["WS", "matcher", "global_config", "PluginDataManager"]
|
||||
|
||||
@@ -3,6 +3,7 @@ from .message import MessageAPI
|
||||
from .group import GroupAPI
|
||||
from .friend import FriendAPI
|
||||
from .account import AccountAPI
|
||||
from .media import MediaAPI
|
||||
|
||||
__all__ = [
|
||||
"BaseAPI",
|
||||
@@ -10,4 +11,5 @@ __all__ = [
|
||||
"GroupAPI",
|
||||
"FriendAPI",
|
||||
"AccountAPI",
|
||||
"MediaAPI",
|
||||
]
|
||||
|
||||
@@ -1,78 +1,106 @@
|
||||
"""
|
||||
账号相关 API 模块
|
||||
账号与状态相关 API 模块
|
||||
|
||||
该模块定义了 `AccountAPI` Mixin 类,提供了所有与机器人自身账号信息、
|
||||
状态设置等相关的 OneBot v11 API 封装。
|
||||
"""
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
from .base import BaseAPI
|
||||
from models.objects import LoginInfo, VersionInfo, Status
|
||||
from ..managers.redis_manager import redis_manager
|
||||
|
||||
|
||||
class AccountAPI(BaseAPI):
|
||||
"""
|
||||
账号相关 API Mixin
|
||||
`AccountAPI` Mixin 类,提供了所有与机器人账号、状态相关的 API 方法。
|
||||
"""
|
||||
|
||||
async def get_login_info(self) -> LoginInfo:
|
||||
async def get_login_info(self, no_cache: bool = False) -> LoginInfo:
|
||||
"""
|
||||
获取登录号信息
|
||||
获取当前登录的机器人账号信息。
|
||||
|
||||
:return: 登录信息对象
|
||||
Args:
|
||||
no_cache (bool, optional): 是否不使用缓存,直接从服务器获取最新信息。Defaults to False.
|
||||
|
||||
Returns:
|
||||
LoginInfo: 包含登录号 QQ 和昵称的 `LoginInfo` 数据对象。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_login_info:{self.self_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.get(cache_key)
|
||||
if cached_data:
|
||||
return LoginInfo(**json.loads(cached_data))
|
||||
|
||||
res = await self.call_api("get_login_info")
|
||||
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return LoginInfo(**res)
|
||||
|
||||
async def get_version_info(self) -> VersionInfo:
|
||||
"""
|
||||
获取版本信息
|
||||
获取 OneBot v11 实现的版本信息。
|
||||
|
||||
:return: 版本信息对象
|
||||
Returns:
|
||||
VersionInfo: 包含 OneBot 实现版本信息的 `VersionInfo` 数据对象。
|
||||
"""
|
||||
res = await self.call_api("get_version_info")
|
||||
return VersionInfo(**res)
|
||||
|
||||
async def get_status(self) -> Status:
|
||||
"""
|
||||
获取状态
|
||||
获取 OneBot v11 实现的状态信息。
|
||||
|
||||
:return: 状态对象
|
||||
Returns:
|
||||
Status: 包含 OneBot 状态信息的 `Status` 数据对象。
|
||||
"""
|
||||
res = await self.call_api("get_status")
|
||||
return Status(**res)
|
||||
|
||||
async def bot_exit(self) -> Dict[str, Any]:
|
||||
"""
|
||||
退出机器人
|
||||
让机器人进程退出(需要实现端支持)。
|
||||
|
||||
:return: API 响应结果
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("bot_exit")
|
||||
|
||||
async def set_self_longnick(self, long_nick: str) -> Dict[str, Any]:
|
||||
"""
|
||||
设置个性签名
|
||||
设置机器人账号的个性签名。
|
||||
|
||||
:param long_nick: 个性签名内容
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
long_nick (str): 要设置的个性签名内容。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_self_longnick", {"longNick": long_nick})
|
||||
|
||||
async def set_input_status(self, user_id: int, event_type: int) -> Dict[str, Any]:
|
||||
"""
|
||||
设置输入状态
|
||||
设置 "对方正在输入..." 状态提示。
|
||||
|
||||
:param user_id: 用户 ID
|
||||
:param event_type: 事件类型
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
event_type (int): 事件类型。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_input_status", {"user_id": user_id, "event_type": event_type})
|
||||
|
||||
async def set_diy_online_status(self, face_id: int, face_type: int, wording: str) -> Dict[str, Any]:
|
||||
"""
|
||||
设置自定义在线状态
|
||||
设置自定义的 "在线状态"。
|
||||
|
||||
:param face_id: 状态 ID
|
||||
:param face_type: 状态类型
|
||||
:param wording: 状态描述
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
face_id (int): 状态的表情 ID。
|
||||
face_type (int): 状态的表情类型。
|
||||
wording (str): 状态的描述文本。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_diy_online_status", {
|
||||
"face_id": face_id,
|
||||
@@ -82,43 +110,108 @@ class AccountAPI(BaseAPI):
|
||||
|
||||
async def set_online_status(self, status_code: int) -> Dict[str, Any]:
|
||||
"""
|
||||
设置在线状态
|
||||
设置在线状态(如在线、离开、摸鱼等)。
|
||||
|
||||
:param status_code: 状态码
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
status_code (int): 目标在线状态的状态码。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_online_status", {"status_code": status_code})
|
||||
|
||||
async def set_qq_profile(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
设置 QQ 资料
|
||||
设置机器人账号的个人资料。
|
||||
|
||||
:param kwargs: 个人资料相关参数
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
**kwargs: 个人资料的相关参数,具体字段请参考 OneBot v11 规范。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_qq_profile", kwargs)
|
||||
|
||||
async def set_qq_avatar(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
设置 QQ 头像
|
||||
设置机器人账号的头像。
|
||||
|
||||
:param kwargs: 头像相关参数
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
**kwargs: 头像的相关参数,具体字段请参考 OneBot v11 规范。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_qq_avatar", kwargs)
|
||||
|
||||
async def get_clientkey(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取客户端密钥
|
||||
获取客户端密钥(通常用于 QQ 登录相关操作)。
|
||||
|
||||
:return: API 响应结果
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("get_clientkey")
|
||||
|
||||
async def clean_cache(self) -> Dict[str, Any]:
|
||||
"""
|
||||
清理缓存
|
||||
清理 OneBot v11 实现端的缓存。
|
||||
|
||||
:return: API 响应结果
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("clean_cache")
|
||||
|
||||
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> Any:
|
||||
"""
|
||||
获取陌生人信息。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Any: 包含陌生人信息的字典或对象。
|
||||
"""
|
||||
return await self.call_api("get_stranger_info", {"user_id": user_id, "no_cache": no_cache})
|
||||
|
||||
async def get_friend_list(self, no_cache: bool = False) -> list:
|
||||
"""
|
||||
获取好友列表。
|
||||
|
||||
Args:
|
||||
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||
|
||||
Returns:
|
||||
list: 好友列表。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_friend_list:{self.self_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.get(cache_key)
|
||||
if cached_data:
|
||||
return json.loads(cached_data)
|
||||
|
||||
res = await self.call_api("get_friend_list")
|
||||
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return res
|
||||
|
||||
async def get_group_list(self, no_cache: bool = False) -> list:
|
||||
"""
|
||||
获取群列表。
|
||||
|
||||
Args:
|
||||
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||
|
||||
Returns:
|
||||
list: 群列表。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_group_list:{self.self_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.get(cache_key)
|
||||
if cached_data:
|
||||
return json.loads(cached_data)
|
||||
|
||||
res = await self.call_api("get_group_list")
|
||||
await redis_manager.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return res
|
||||
|
||||
|
||||
@@ -1,24 +1,50 @@
|
||||
"""
|
||||
API 基础模块
|
||||
|
||||
定义了 API 调用的基础接口。
|
||||
定义了 API 调用的基础接口和统一处理逻辑。
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from ..utils.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..ws import WS
|
||||
|
||||
|
||||
class BaseAPI(ABC):
|
||||
class BaseAPI:
|
||||
"""
|
||||
API 基础抽象类
|
||||
API 基础类,提供了统一的 `call_api` 方法,包含日志记录和异常处理。
|
||||
"""
|
||||
_ws: "WS"
|
||||
self_id: int
|
||||
|
||||
def __init__(self, ws_client: "WS", self_id: int):
|
||||
self._ws = ws_client
|
||||
self.self_id = self_id
|
||||
|
||||
@abstractmethod
|
||||
async def call_api(self, action: str, params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
调用 API
|
||||
调用 OneBot v11 API,并提供统一的日志和异常处理。
|
||||
|
||||
:param action: API 动作名称
|
||||
:param params: API 参数
|
||||
:return: API 响应结果
|
||||
:return: API 响应结果的数据部分
|
||||
:raises Exception: 当 API 调用失败或发生网络错误时
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
try:
|
||||
logger.debug(f"调用API -> action: {action}, params: {params}")
|
||||
response = await self._ws.call_api(action, params)
|
||||
logger.debug(f"API响应 <- {response}")
|
||||
|
||||
if response.get("status") == "failed":
|
||||
logger.warning(f"API调用失败: {response}")
|
||||
|
||||
return response.get("data")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API调用异常: action={action}, params={params}, error={e}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -1,53 +1,86 @@
|
||||
"""
|
||||
好友相关 API 模块
|
||||
好友与陌生人相关 API 模块
|
||||
|
||||
该模块定义了 `FriendAPI` Mixin 类,提供了所有与好友、陌生人信息
|
||||
等相关的 OneBot v11 API 封装。
|
||||
"""
|
||||
import json
|
||||
from typing import List, Dict, Any
|
||||
from .base import BaseAPI
|
||||
from models.objects import FriendInfo, StrangerInfo
|
||||
from ..managers.redis_manager import redis_manager
|
||||
|
||||
|
||||
class FriendAPI(BaseAPI):
|
||||
"""
|
||||
好友相关 API Mixin
|
||||
`FriendAPI` Mixin 类,提供了所有与好友、陌生人操作相关的 API 方法。
|
||||
"""
|
||||
|
||||
async def send_like(self, user_id: int, times: int = 1) -> Dict[str, Any]:
|
||||
"""
|
||||
发送点赞
|
||||
向指定用户发送 "戳一戳" (点赞)。
|
||||
|
||||
:param user_id: 对方 QQ 号
|
||||
:param times: 点赞次数
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
times (int, optional): 点赞次数,建议不超过 10 次。Defaults to 1.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("send_like", {"user_id": user_id, "times": times})
|
||||
|
||||
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> StrangerInfo:
|
||||
"""
|
||||
获取陌生人信息
|
||||
获取陌生人的信息。
|
||||
|
||||
:param user_id: QQ 号
|
||||
:param no_cache: 是否不使用缓存
|
||||
:return: 陌生人信息对象
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
no_cache (bool, optional): 是否不使用缓存,直接从服务器获取。Defaults to False.
|
||||
|
||||
Returns:
|
||||
StrangerInfo: 包含陌生人信息的 `StrangerInfo` 数据对象。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_stranger_info:{user_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.redis.get(cache_key)
|
||||
if cached_data:
|
||||
return StrangerInfo(**json.loads(cached_data))
|
||||
|
||||
res = await self.call_api("get_stranger_info", {"user_id": user_id, "no_cache": no_cache})
|
||||
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return StrangerInfo(**res)
|
||||
|
||||
async def get_friend_list(self) -> List[FriendInfo]:
|
||||
async def get_friend_list(self, no_cache: bool = False) -> List[FriendInfo]:
|
||||
"""
|
||||
获取好友列表
|
||||
获取机器人账号的好友列表。
|
||||
|
||||
:return: 好友信息对象列表
|
||||
Args:
|
||||
no_cache (bool, optional): 是否不使用缓存,直接从服务器获取最新信息。Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[FriendInfo]: 包含所有好友信息的 `FriendInfo` 对象列表。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_friend_list:{self.self_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.redis.get(cache_key)
|
||||
if cached_data:
|
||||
return [FriendInfo(**item) for item in json.loads(cached_data)]
|
||||
|
||||
res = await self.call_api("get_friend_list")
|
||||
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return [FriendInfo(**item) for item in res]
|
||||
|
||||
async def set_friend_add_request(self, flag: str, approve: bool = True, remark: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
处理加好友请求
|
||||
处理收到的加好友请求。
|
||||
|
||||
:param flag: 加好友请求的 flag(需从上报的数据中获取)
|
||||
:param approve: 是否同意请求
|
||||
:param remark: 添加后的好友备注(仅在同意时有效)
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
flag (str): 请求的标识,需要从 `request` 事件中获取。
|
||||
approve (bool, optional): 是否同意该好友请求。Defaults to True.
|
||||
remark (str, optional): 在同意请求时,为该好友设置的备注。Defaults to "".
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_friend_add_request", {"flag": flag, "approve": approve, "remark": remark})
|
||||
|
||||
|
||||
@@ -1,49 +1,67 @@
|
||||
"""
|
||||
群组相关 API 模块
|
||||
|
||||
该模块定义了 `GroupAPI` Mixin 类,提供了所有与群组管理、成员操作
|
||||
等相关的 OneBot v11 API 封装。
|
||||
"""
|
||||
from typing import List, Dict, Any, Optional
|
||||
import json
|
||||
from ..managers.redis_manager import redis_manager
|
||||
from .base import BaseAPI
|
||||
from models.objects import GroupInfo, GroupMemberInfo, GroupHonorInfo
|
||||
from ..utils.logger import logger
|
||||
|
||||
|
||||
class GroupAPI(BaseAPI):
|
||||
"""
|
||||
群组相关 API Mixin
|
||||
`GroupAPI` Mixin 类,提供了所有与群组操作相关的 API 方法。
|
||||
"""
|
||||
|
||||
async def set_group_kick(self, group_id: int, user_id: int, reject_add_request: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
群组踢人
|
||||
将指定成员踢出群组。
|
||||
|
||||
:param group_id: 群号
|
||||
:param user_id: 要踢的 QQ 号
|
||||
:param reject_add_request: 拒绝此人的加群请求
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 要踢出的成员的 QQ 号。
|
||||
reject_add_request (bool, optional): 是否拒绝该用户此后的加群请求。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_kick", {"group_id": group_id, "user_id": user_id, "reject_add_request": reject_add_request})
|
||||
|
||||
async def set_group_ban(self, group_id: int, user_id: int, duration: int = 30 * 60) -> Dict[str, Any]:
|
||||
async def set_group_ban(self, group_id: int, user_id: int, duration: int = 1800) -> Dict[str, Any]:
|
||||
"""
|
||||
群组单人禁言
|
||||
禁言群组中的指定成员。
|
||||
|
||||
:param group_id: 群号
|
||||
:param user_id: 要禁言的 QQ 号
|
||||
:param duration: 禁言时长(秒),0 表示解除禁言
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 要禁言的成员的 QQ 号。
|
||||
duration (int, optional): 禁言时长,单位为秒。设置为 0 表示解除禁言。
|
||||
Defaults to 1800 (30 分钟).
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_ban", {"group_id": group_id, "user_id": user_id, "duration": duration})
|
||||
|
||||
async def set_group_anonymous_ban(self, group_id: int, anonymous: Dict[str, Any] = None, duration: int = 30 * 60, flag: str = None) -> Dict[str, Any]:
|
||||
async def set_group_anonymous_ban(self, group_id: int, anonymous: Optional[Dict[str, Any]] = None, duration: int = 1800, flag: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
群组匿名禁言
|
||||
禁言群组中的匿名用户。
|
||||
|
||||
:param group_id: 群号
|
||||
:param anonymous: 可选,要禁言的匿名用户对象(群消息事件的 anonymous 字段)
|
||||
:param duration: 禁言时长(秒)
|
||||
:param flag: 可选,要禁言的匿名用户的 flag(需从群消息事件的 anonymous 字段中获取)
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
anonymous (Dict[str, Any], optional): 要禁言的匿名用户对象,
|
||||
可从群消息事件的 `anonymous` 字段中获取。Defaults to None.
|
||||
duration (int, optional): 禁言时长,单位为秒。Defaults to 1800.
|
||||
flag (str, optional): 要禁言的匿名用户的 flag 标识,
|
||||
可从群消息事件的 `anonymous` 字段中获取。Defaults to None.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
params = {"group_id": group_id, "duration": duration}
|
||||
params: Dict[str, Any] = {"group_id": group_id, "duration": duration}
|
||||
if anonymous:
|
||||
params["anonymous"] = anonymous
|
||||
if flag:
|
||||
@@ -52,139 +70,215 @@ class GroupAPI(BaseAPI):
|
||||
|
||||
async def set_group_whole_ban(self, group_id: int, enable: bool = True) -> Dict[str, Any]:
|
||||
"""
|
||||
群组全员禁言
|
||||
开启或关闭群组全员禁言。
|
||||
|
||||
:param group_id: 群号
|
||||
:param enable: 是否开启
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
enable (bool, optional): True 表示开启全员禁言,False 表示关闭。Defaults to True.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_whole_ban", {"group_id": group_id, "enable": enable})
|
||||
|
||||
async def set_group_admin(self, group_id: int, user_id: int, enable: bool = True) -> Dict[str, Any]:
|
||||
"""
|
||||
群组设置管理员
|
||||
设置或取消群组成员的管理员权限。
|
||||
|
||||
:param group_id: 群号
|
||||
:param user_id: 要设置的 QQ 号
|
||||
:param enable: True 为设置,False 为取消
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 目标成员的 QQ 号。
|
||||
enable (bool, optional): True 表示设为管理员,False 表示取消管理员。Defaults to True.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_admin", {"group_id": group_id, "user_id": user_id, "enable": enable})
|
||||
|
||||
async def set_group_anonymous(self, group_id: int, enable: bool = True) -> Dict[str, Any]:
|
||||
"""
|
||||
群组匿名
|
||||
开启或关闭群组的匿名聊天功能。
|
||||
|
||||
:param group_id: 群号
|
||||
:param enable: 是否开启
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
enable (bool, optional): True 表示开启匿名,False 表示关闭。Defaults to True.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_anonymous", {"group_id": group_id, "enable": enable})
|
||||
|
||||
async def set_group_card(self, group_id: int, user_id: int, card: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
设置群名片(群备注)
|
||||
设置群组成员的群名片。
|
||||
|
||||
:param group_id: 群号
|
||||
:param user_id: 要设置的 QQ 号
|
||||
:param card: 群名片内容,不填或空字符串表示删除群名片
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 目标成员的 QQ 号。
|
||||
card (str, optional): 要设置的群名片内容。
|
||||
传入空字符串 `""` 或 `None` 表示删除该成员的群名片。Defaults to "".
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_card", {"group_id": group_id, "user_id": user_id, "card": card})
|
||||
|
||||
async def set_group_name(self, group_id: int, group_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
设置群名
|
||||
设置群组的名称。
|
||||
|
||||
:param group_id: 群号
|
||||
:param group_name: 新群名
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
group_name (str): 新的群组名称。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_name", {"group_id": group_id, "group_name": group_name})
|
||||
|
||||
async def set_group_leave(self, group_id: int, is_dismiss: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
退出群组
|
||||
退出或解散一个群组。
|
||||
|
||||
:param group_id: 群号
|
||||
:param is_dismiss: 是否解散,如果登录号是群主,则仅在此项为 True 时能够解散
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
is_dismiss (bool, optional): 是否解散群组。
|
||||
仅当机器人是群主时,此项设为 True 才能解散群。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_leave", {"group_id": group_id, "is_dismiss": is_dismiss})
|
||||
|
||||
async def set_group_special_title(self, group_id: int, user_id: int, special_title: str = "", duration: int = -1) -> Dict[str, Any]:
|
||||
"""
|
||||
设置群组专属头衔
|
||||
为群组成员设置专属头衔。
|
||||
|
||||
:param group_id: 群号
|
||||
:param user_id: 要设置的 QQ 号
|
||||
:param special_title: 专属头衔,不填或空字符串表示删除
|
||||
:param duration: 有效期(秒),-1 表示永久
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 目标成员的 QQ 号。
|
||||
special_title (str, optional): 专属头衔内容。
|
||||
传入空字符串 `""` 或 `None` 表示删除头衔。Defaults to "".
|
||||
duration (int, optional): 头衔有效期,单位为秒。-1 表示永久。Defaults to -1.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_special_title", {"group_id": group_id, "user_id": user_id, "special_title": special_title, "duration": duration})
|
||||
|
||||
async def get_group_info(self, group_id: int, no_cache: bool = False) -> GroupInfo:
|
||||
"""
|
||||
获取群信息
|
||||
获取群组的详细信息。
|
||||
|
||||
:param group_id: 群号
|
||||
:param no_cache: 是否不使用缓存
|
||||
:return: 群信息对象
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
no_cache (bool, optional): 是否不使用缓存,直接从服务器获取最新信息。Defaults to False.
|
||||
|
||||
Returns:
|
||||
GroupInfo: 包含群组信息的 `GroupInfo` 数据对象。
|
||||
"""
|
||||
res = await self.call_api("get_group_info", {"group_id": group_id, "no_cache": no_cache})
|
||||
cache_key = f"neobot:cache:get_group_info:{group_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.redis.get(cache_key)
|
||||
if cached_data:
|
||||
return GroupInfo(**json.loads(cached_data))
|
||||
|
||||
res = await self.call_api("get_group_info", {"group_id": group_id})
|
||||
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return GroupInfo(**res)
|
||||
|
||||
async def get_group_list(self) -> List[GroupInfo]:
|
||||
async def get_group_list(self) -> Any:
|
||||
"""
|
||||
获取群列表
|
||||
获取机器人加入的所有群组的列表。
|
||||
|
||||
:return: 群信息对象列表
|
||||
Returns:
|
||||
Any: 包含所有群组信息的列表(可能是字典列表或对象列表)。
|
||||
"""
|
||||
res = await self.call_api("get_group_list")
|
||||
return [GroupInfo(**item) for item in res]
|
||||
|
||||
# 增加日志记录 API 原始返回
|
||||
logger.debug(f"OneBot API 'get_group_list' raw response: {res}")
|
||||
return res
|
||||
|
||||
# 健壮性处理:处理标准的 OneBot v11 响应格式
|
||||
if isinstance(res, dict) and res.get("status") == "ok":
|
||||
group_data = res.get("data", [])
|
||||
if isinstance(group_data, list):
|
||||
return [GroupInfo(**item) for item in group_data]
|
||||
else:
|
||||
logger.error(f"The 'data' field in 'get_group_list' response is not a list: {group_data}")
|
||||
return []
|
||||
|
||||
# 兼容处理:如果返回的是列表(非标准但可能存在)
|
||||
if isinstance(res, list):
|
||||
return [GroupInfo(**item) for item in res]
|
||||
|
||||
logger.error(f"Unexpected response format from 'get_group_list': {res}")
|
||||
return []
|
||||
|
||||
async def get_group_member_info(self, group_id: int, user_id: int, no_cache: bool = False) -> GroupMemberInfo:
|
||||
"""
|
||||
获取群成员信息
|
||||
获取指定群组成员的详细信息。
|
||||
|
||||
:param group_id: 群号
|
||||
:param user_id: QQ 号
|
||||
:param no_cache: 是否不使用缓存
|
||||
:return: 群成员信息对象
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 目标成员的 QQ 号。
|
||||
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||
|
||||
Returns:
|
||||
GroupMemberInfo: 包含群成员信息的 `GroupMemberInfo` 数据对象。
|
||||
"""
|
||||
res = await self.call_api("get_group_member_info", {"group_id": group_id, "user_id": user_id, "no_cache": no_cache})
|
||||
cache_key = f"neobot:cache:get_group_member_info:{group_id}:{user_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.redis.get(cache_key)
|
||||
if cached_data:
|
||||
return GroupMemberInfo(**json.loads(cached_data))
|
||||
|
||||
res = await self.call_api("get_group_member_info", {"group_id": group_id, "user_id": user_id})
|
||||
await redis_manager.redis.set(cache_key, json.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return GroupMemberInfo(**res)
|
||||
|
||||
async def get_group_member_list(self, group_id: int) -> List[GroupMemberInfo]:
|
||||
"""
|
||||
获取群成员列表
|
||||
获取一个群组的所有成员列表。
|
||||
|
||||
:param group_id: 群号
|
||||
:return: 群成员信息对象列表
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
|
||||
Returns:
|
||||
List[GroupMemberInfo]: 包含所有群成员信息的 `GroupMemberInfo` 对象列表。
|
||||
"""
|
||||
res = await self.call_api("get_group_member_list", {"group_id": group_id})
|
||||
return [GroupMemberInfo(**item) for item in res]
|
||||
|
||||
async def get_group_honor_info(self, group_id: int, type: str) -> GroupHonorInfo:
|
||||
"""
|
||||
获取群荣誉信息
|
||||
获取群组的荣誉信息(如龙王、群聊之火等)。
|
||||
|
||||
:param group_id: 群号
|
||||
:param type: 要获取的群荣誉类型,可传入 talkative, performer, legend, strong_newbie, emotion 等
|
||||
:return: 群荣誉信息对象
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
type (str): 要获取的荣誉类型。
|
||||
可选值: "talkative", "performer", "legend", "strong_newbie", "emotion" 等。
|
||||
|
||||
Returns:
|
||||
GroupHonorInfo: 包含群荣誉信息的 `GroupHonorInfo` 数据对象。
|
||||
"""
|
||||
res = await self.call_api("get_group_honor_info", {"group_id": group_id, "type": type})
|
||||
return GroupHonorInfo(**res)
|
||||
|
||||
async def set_group_add_request(self, flag: str, sub_type: str, approve: bool = True, reason: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
处理加群请求/邀请
|
||||
处理加群请求或邀请。
|
||||
|
||||
:param flag: 加群请求的 flag(需从上报的数据中获取)
|
||||
:param sub_type: add 或 invite,请求类型(需要与上报消息中的 sub_type 字段相符)
|
||||
:param approve: 是否同意请求/邀请
|
||||
:param reason: 拒绝理由(仅在拒绝时有效)
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
flag (str): 请求的标识,需要从 `request` 事件中获取。
|
||||
sub_type (str): 请求的子类型,`add` 或 `invite`,
|
||||
需要与 `request` 事件中的 `sub_type` 字段相符。
|
||||
approve (bool, optional): 是否同意请求或邀请。Defaults to True.
|
||||
reason (str, optional): 拒绝加群的理由(仅在 `approve` 为 False 时有效)。Defaults to "".
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_add_request", {"flag": flag, "sub_type": sub_type, "approve": approve, "reason": reason})
|
||||
|
||||
|
||||
39
core/api/media.py
Normal file
39
core/api/media.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
媒体API模块
|
||||
|
||||
封装了与图片、语音等媒体文件相关的API。
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
|
||||
from .base import BaseAPI
|
||||
|
||||
|
||||
class MediaAPI(BaseAPI):
|
||||
"""
|
||||
媒体相关API
|
||||
"""
|
||||
|
||||
async def can_send_image(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查是否可以发送图片
|
||||
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="can_send_image")
|
||||
|
||||
async def can_send_record(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查是否可以发送语音
|
||||
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="can_send_record")
|
||||
|
||||
async def get_image(self, file: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取图片信息
|
||||
|
||||
:param file: 图片文件名或路径
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="get_image", params={"file": file})
|
||||
@@ -1,26 +1,35 @@
|
||||
"""
|
||||
消息相关 API 模块
|
||||
|
||||
该模块定义了 `MessageAPI` Mixin 类,提供了所有与消息发送、撤回、
|
||||
转发等相关的 OneBot v11 API 封装。
|
||||
"""
|
||||
from typing import Union, List, Dict, Any, TYPE_CHECKING
|
||||
from .base import BaseAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import MessageSegment, OneBotEvent
|
||||
from models.message import MessageSegment
|
||||
from models.events.base import OneBotEvent
|
||||
|
||||
|
||||
class MessageAPI(BaseAPI):
|
||||
"""
|
||||
消息相关 API Mixin
|
||||
`MessageAPI` Mixin 类,提供了所有与消息操作相关的 API 方法。
|
||||
"""
|
||||
|
||||
async def send_group_msg(self, group_id: int, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
发送群消息
|
||||
发送群消息。
|
||||
|
||||
:param group_id: 群号
|
||||
:param message: 消息内容,可以是字符串、MessageSegment 对象或 MessageSegment 列表
|
||||
:param auto_escape: 是否自动转义(仅当 message 为字符串时有效)
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
message (Union[str, MessageSegment, List[MessageSegment]]): 要发送的消息内容。
|
||||
可以是纯文本字符串、单个消息段对象或消息段列表。
|
||||
auto_escape (bool, optional): 仅当 `message` 为字符串时有效,
|
||||
是否对消息内容进行 CQ 码转义。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api(
|
||||
"send_group_msg", {"group_id": group_id, "message": self._process_message(message), "auto_escape": auto_escape}
|
||||
@@ -28,12 +37,15 @@ class MessageAPI(BaseAPI):
|
||||
|
||||
async def send_private_msg(self, user_id: int, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
发送私聊消息
|
||||
发送私聊消息。
|
||||
|
||||
:param user_id: 用户 QQ 号
|
||||
:param message: 消息内容,可以是字符串、MessageSegment 对象或 MessageSegment 列表
|
||||
:param auto_escape: 是否自动转义(仅当 message 为字符串时有效)
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
message (Union[str, MessageSegment, List[MessageSegment]]): 要发送的消息内容。
|
||||
auto_escape (bool, optional): 是否对消息内容进行 CQ 码转义。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api(
|
||||
"send_private_msg", {"user_id": user_id, "message": self._process_message(message), "auto_escape": auto_escape}
|
||||
@@ -41,12 +53,18 @@ class MessageAPI(BaseAPI):
|
||||
|
||||
async def send(self, event: "OneBotEvent", message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
智能发送消息,根据事件类型自动选择发送方式
|
||||
智能发送消息。
|
||||
|
||||
:param event: 触发事件对象
|
||||
:param message: 消息内容
|
||||
:param auto_escape: 是否自动转义
|
||||
:return: API 响应结果
|
||||
该方法会根据传入的事件对象 `event` 自动判断是私聊还是群聊,
|
||||
并调用相应的发送函数。如果事件是消息事件,则优先使用 `reply` 方法。
|
||||
|
||||
Args:
|
||||
event (OneBotEvent): 触发该发送行为的事件对象。
|
||||
message (Union[str, MessageSegment, List[MessageSegment]]): 要发送的消息内容。
|
||||
auto_escape (bool, optional): 是否对消息内容进行 CQ 码转义。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
# 如果是消息事件,直接调用 reply
|
||||
if hasattr(event, "reply"):
|
||||
@@ -66,59 +84,98 @@ class MessageAPI(BaseAPI):
|
||||
|
||||
async def delete_msg(self, message_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
撤回消息
|
||||
撤回一条消息。
|
||||
|
||||
:param message_id: 消息 ID
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
message_id (int): 要撤回的消息的 ID。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("delete_msg", {"message_id": message_id})
|
||||
|
||||
async def get_msg(self, message_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
获取消息
|
||||
获取一条消息的详细信息。
|
||||
|
||||
:param message_id: 消息 ID
|
||||
:return: API 响应结果
|
||||
Args:
|
||||
message_id (int): 要获取的消息的 ID。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据,包含消息详情。
|
||||
"""
|
||||
return await self.call_api("get_msg", {"message_id": message_id})
|
||||
|
||||
async def get_forward_msg(self, id: str) -> Dict[str, Any]:
|
||||
async def get_forward_msg(self, id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取合并转发消息
|
||||
获取合并转发消息的内容。
|
||||
|
||||
:param id: 合并转发 ID
|
||||
:return: API 响应结果
|
||||
"""
|
||||
return await self.call_api("get_forward_msg", {"id": id})
|
||||
Args:
|
||||
id (str): 合并转发消息的 ID。
|
||||
|
||||
async def can_send_image(self) -> Dict[str, Any]:
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 转发消息的节点列表。
|
||||
"""
|
||||
检查是否可以发送图片
|
||||
forward_data = await self.call_api("get_forward_msg", {"id": id})
|
||||
nodes = forward_data.get("data")
|
||||
|
||||
:return: API 响应结果
|
||||
"""
|
||||
return await self.call_api("can_send_image")
|
||||
if not isinstance(nodes, list):
|
||||
# 兼容某些实现可能将节点放在 'messages' 键下
|
||||
data = forward_data.get('data', {})
|
||||
if isinstance(data, dict):
|
||||
nodes = data.get('messages')
|
||||
|
||||
async def can_send_record(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查是否可以发送语音
|
||||
if not isinstance(nodes, list):
|
||||
raise ValueError("在 get_forward_msg 响应中找不到消息节点列表")
|
||||
|
||||
:return: API 响应结果
|
||||
return nodes
|
||||
|
||||
async def send_group_forward_msg(self, group_id: int, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
return await self.call_api("can_send_record")
|
||||
发送群聊合并转发消息。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
messages (List[Dict[str, Any]]): 消息节点列表。
|
||||
推荐使用 `bot.build_forward_node` 来构建节点。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("send_group_forward_msg", {"group_id": group_id, "messages": messages})
|
||||
|
||||
async def send_private_forward_msg(self, user_id: int, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
发送私聊合并转发消息。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
messages (List[Dict[str, Any]]): 消息节点列表。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("send_private_forward_msg", {"user_id": user_id, "messages": messages})
|
||||
|
||||
def _process_message(self, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Union[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
处理消息内容,将其转换为 API 可接受的格式
|
||||
内部方法:将消息内容处理成 OneBot API 可接受的格式。
|
||||
|
||||
:param message: 原始消息内容
|
||||
:return: 处理后的消息内容
|
||||
- `str` -> `str`
|
||||
- `MessageSegment` -> `List[Dict]`
|
||||
- `List[MessageSegment]` -> `List[Dict]`
|
||||
|
||||
Args:
|
||||
message: 原始消息内容。
|
||||
|
||||
Returns:
|
||||
处理后的消息内容。
|
||||
"""
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
|
||||
# 避免循环导入,在运行时导入
|
||||
from models import MessageSegment
|
||||
from models.message import MessageSegment
|
||||
|
||||
if isinstance(message, MessageSegment):
|
||||
return [self._segment_to_dict(message)]
|
||||
@@ -130,12 +187,16 @@ class MessageAPI(BaseAPI):
|
||||
|
||||
def _segment_to_dict(self, segment: "MessageSegment") -> Dict[str, Any]:
|
||||
"""
|
||||
将 MessageSegment 对象转换为字典
|
||||
内部方法:将 `MessageSegment` 对象转换为字典。
|
||||
|
||||
:param segment: MessageSegment 对象
|
||||
:return: 字典格式的消息段
|
||||
Args:
|
||||
segment (MessageSegment): 消息段对象。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 符合 OneBot 规范的消息段字典。
|
||||
"""
|
||||
return {
|
||||
"type": segment.type,
|
||||
"data": segment.data
|
||||
}
|
||||
|
||||
|
||||
114
core/bot.py
114
core/bot.py
@@ -1,36 +1,116 @@
|
||||
"""
|
||||
Bot 抽象模块
|
||||
Bot 核心抽象模块
|
||||
|
||||
定义了 Bot 类,封装了 OneBot API 的调用逻辑,提供了便捷的消息发送方法。
|
||||
该模块定义了 `Bot` 类,它是与 OneBot v11 API 进行交互的主要接口。
|
||||
`Bot` 类通过继承 `api` 目录下的各个 Mixin 类,将不同类别的 API 调用
|
||||
整合在一起,提供了一个统一、便捷的调用入口。
|
||||
|
||||
主要职责包括:
|
||||
- 封装 WebSocket 通信,提供 `call_api` 方法。
|
||||
- 提供高级消息发送功能,如 `send_forwarded_messages`。
|
||||
- 整合所有细分的 API 调用(消息、群组、好友等)。
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Dict, Any
|
||||
from typing import TYPE_CHECKING, Dict, Any, List, Union, Optional
|
||||
from models.events.base import OneBotEvent
|
||||
from models.message import MessageSegment
|
||||
from models.objects import GroupInfo, StrangerInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .ws import WS
|
||||
from .utils.executor import CodeExecutor
|
||||
|
||||
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI
|
||||
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI
|
||||
|
||||
|
||||
class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI):
|
||||
class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI):
|
||||
"""
|
||||
Bot 抽象类,封装 API 调用和常用操作
|
||||
继承各个 API Mixin 以提高代码的可维护性
|
||||
机器人核心类,封装了所有与 OneBot API 的交互。
|
||||
|
||||
通过 Mixin 模式继承了所有 API 功能,使得结构清晰且易于扩展。
|
||||
实例由 `WS` 客户端在连接成功后创建,并传递给所有事件处理器和插件。
|
||||
"""
|
||||
|
||||
def __init__(self, ws_client: "WS"):
|
||||
"""
|
||||
初始化 Bot 实例
|
||||
初始化 Bot 实例。
|
||||
|
||||
:param ws_client: WebSocket 客户端实例,用于底层通信
|
||||
Args:
|
||||
ws_client (WS): WebSocket 客户端实例,负责底层的 API 请求和响应处理。
|
||||
"""
|
||||
self.ws = ws_client
|
||||
super().__init__(ws_client, ws_client.self_id or 0)
|
||||
self.code_executor: Optional["CodeExecutor"] = None
|
||||
|
||||
async def call_api(self, action: str, params: Dict[str, Any] = None) -> Any:
|
||||
"""
|
||||
调用 OneBot API
|
||||
async def get_group_list(self, no_cache: bool = False) -> List[GroupInfo]:
|
||||
# GroupAPI.get_group_list 不支持 no_cache 参数,这里忽略它
|
||||
result = await super().get_group_list()
|
||||
# 确保结果是 GroupInfo 对象列表
|
||||
return [GroupInfo(**group) if isinstance(group, dict) else group for group in result]
|
||||
|
||||
:param action: API 动作名称
|
||||
:param params: API 参数
|
||||
:return: API 响应结果
|
||||
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> StrangerInfo:
|
||||
result = await super().get_stranger_info(user_id=user_id, no_cache=no_cache)
|
||||
# 确保结果是 StrangerInfo 对象
|
||||
if isinstance(result, dict):
|
||||
return StrangerInfo(**result)
|
||||
return result
|
||||
|
||||
|
||||
def build_forward_node(self, user_id: int, nickname: str, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Dict[str, Any]:
|
||||
"""
|
||||
return await self.ws.call_api(action, params)
|
||||
构建一个用于合并转发的消息节点 (Node)。
|
||||
|
||||
这是一个辅助方法,用于方便地创建符合 OneBot v11 规范的消息节点,
|
||||
以便在 `send_forwarded_messages` 中使用。
|
||||
|
||||
Args:
|
||||
user_id (int): 发送者的 QQ 号。
|
||||
nickname (str): 发送者在消息中显示的昵称。
|
||||
message (Union[str, MessageSegment, List[MessageSegment]]): 该节点的消息内容,
|
||||
可以是纯文本、单个消息段或消息段列表。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 构造好的消息节点字典。
|
||||
"""
|
||||
return {
|
||||
"type": "node",
|
||||
"data": {
|
||||
"uin": user_id,
|
||||
"name": nickname,
|
||||
"content": self._process_message(message)
|
||||
}
|
||||
}
|
||||
|
||||
async def send_forwarded_messages(self, target: Union[int, "OneBotEvent"], nodes: List[Dict[str, Any]]):
|
||||
"""
|
||||
发送合并转发消息。
|
||||
|
||||
该方法实现了智能判断,可以根据 `target` 的类型自动发送群聊合并转发
|
||||
或私聊合并转发消息。
|
||||
|
||||
Args:
|
||||
target (Union[int, OneBotEvent]): 发送目标。
|
||||
- 如果是 `OneBotEvent` 对象,则自动判断是群聊还是私聊。
|
||||
- 如果是 `int`,则默认为群号,发送群聊合并转发。
|
||||
nodes (List[Dict[str, Any]]): 消息节点列表。
|
||||
推荐使用 `build_forward_node` 方法来构建列表中的每个节点。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果事件对象中既没有 `group_id` 也没有 `user_id`。
|
||||
"""
|
||||
if isinstance(target, OneBotEvent):
|
||||
group_id = getattr(target, "group_id", None)
|
||||
user_id = getattr(target, "user_id", None)
|
||||
|
||||
if group_id:
|
||||
# 直接发送群聊合并转发
|
||||
await self.send_group_forward_msg(group_id, nodes)
|
||||
elif user_id:
|
||||
# 发送私聊合并转发
|
||||
await self.send_private_forward_msg(user_id, nodes)
|
||||
else:
|
||||
raise ValueError("Event has neither group_id nor user_id")
|
||||
|
||||
else:
|
||||
# 默认行为是发送到群聊
|
||||
group_id = target
|
||||
await self.send_group_forward_msg(group_id, nodes)
|
||||
|
||||
|
||||
@@ -1,214 +0,0 @@
|
||||
"""
|
||||
命令管理器模块
|
||||
|
||||
提供装饰器用于注册消息指令、通知处理器和请求处理器,并负责事件的分发。
|
||||
"""
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
from .config_loader import global_config
|
||||
|
||||
# 从配置中获取命令前缀
|
||||
comm_prefixes = global_config.bot.get("command", ("/",))
|
||||
|
||||
|
||||
class CommandManager:
|
||||
"""
|
||||
命令管理器,负责注册和分发指令、通知和请求事件
|
||||
"""
|
||||
|
||||
def __init__(self, prefixes: Tuple[str, ...] = ("/",)):
|
||||
"""
|
||||
初始化命令管理器
|
||||
|
||||
:param prefixes: 命令前缀元组
|
||||
"""
|
||||
self.prefixes = prefixes
|
||||
self.commands: Dict[str, Callable] = {} # 存储消息指令
|
||||
self.notice_handlers: List[Dict] = [] # 存储通知处理器
|
||||
self.request_handlers: List[Dict] = [] # 存储请求处理器
|
||||
self.plugins: Dict[str, Dict[str, Any]] = {} # 存储插件元数据
|
||||
|
||||
# --- 内置 help 指令 ---
|
||||
self.commands["help"] = self._help_command
|
||||
self.plugins["core.help"] = {
|
||||
"name": "帮助",
|
||||
"description": "显示所有可用指令的帮助信息",
|
||||
"usage": "/help",
|
||||
}
|
||||
|
||||
async def _help_command(self, bot, event):
|
||||
"""
|
||||
内置的 /help 指令处理器
|
||||
|
||||
:param bot: Bot 实例
|
||||
:param event: 消息事件对象
|
||||
"""
|
||||
help_text = "--- 可用指令列表 ---\n"
|
||||
|
||||
for plugin_name, meta in self.plugins.items():
|
||||
name = meta.get("name", "未命名插件")
|
||||
description = meta.get("description", "暂无描述")
|
||||
usage = meta.get("usage", "暂无用法说明")
|
||||
|
||||
help_text += f"\n{name}:\n"
|
||||
help_text += f" 功能: {description}\n"
|
||||
help_text += f" 用法: {usage}\n"
|
||||
|
||||
await bot.send(event, help_text.strip())
|
||||
|
||||
# --- 1. 消息指令装饰器 ---
|
||||
def command(self, name: str):
|
||||
"""
|
||||
装饰器:注册消息指令
|
||||
|
||||
:param name: 指令名称(不含前缀)
|
||||
:return: 装饰器函数
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
self.commands[name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
# --- 2. 通知事件装饰器 ---
|
||||
def on_notice(self, notice_type: str = None):
|
||||
"""
|
||||
装饰器:注册通知处理器
|
||||
|
||||
:param notice_type: 通知类型,如果为 None 则处理所有通知
|
||||
:return: 装饰器函数
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
self.notice_handlers.append({"type": notice_type, "func": func})
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
# --- 3. 请求事件装饰器 ---
|
||||
def on_request(self, request_type: str = None):
|
||||
"""
|
||||
装饰器:注册请求处理器
|
||||
|
||||
:param request_type: 请求类型,如果为 None 则处理所有请求
|
||||
:return: 装饰器函数
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
self.request_handlers.append({"type": request_type, "func": func})
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
# --- 统一事件分发入口 ---
|
||||
async def handle_event(self, bot, event):
|
||||
"""
|
||||
统一事件分发入口
|
||||
|
||||
:param bot: Bot 实例
|
||||
:param event: 事件对象
|
||||
"""
|
||||
post_type = event.post_type
|
||||
|
||||
if post_type == 'message':
|
||||
await self.handle_message(bot, event)
|
||||
elif post_type == 'notice':
|
||||
await self.handle_notice(bot, event)
|
||||
elif post_type == 'request':
|
||||
await self.handle_request(bot, event)
|
||||
|
||||
# --- 消息分发逻辑 ---
|
||||
async def handle_message(self, bot, event):
|
||||
"""
|
||||
解析并分发消息指令
|
||||
|
||||
:param bot: Bot 实例
|
||||
:param event: 消息事件对象
|
||||
"""
|
||||
if not event.raw_message:
|
||||
return
|
||||
|
||||
raw_text = event.raw_message.strip()
|
||||
|
||||
# 1. 检查前缀
|
||||
prefix_found = None
|
||||
for p in self.prefixes:
|
||||
if raw_text.startswith(p):
|
||||
prefix_found = p
|
||||
break
|
||||
|
||||
if not prefix_found:
|
||||
return
|
||||
|
||||
# 2. 拆分指令和参数
|
||||
full_cmd = raw_text[len(prefix_found) :].split()
|
||||
if not full_cmd:
|
||||
return
|
||||
|
||||
cmd_name = full_cmd[0]
|
||||
args = full_cmd[1:]
|
||||
|
||||
# 3. 查找并执行
|
||||
if cmd_name in self.commands:
|
||||
func = self.commands[cmd_name]
|
||||
await self._run_handler(func, bot, event, args)
|
||||
|
||||
# --- 通知分发逻辑 ---
|
||||
async def handle_notice(self, bot, event):
|
||||
"""
|
||||
分发通知事件
|
||||
|
||||
:param bot: Bot 实例
|
||||
:param event: 通知事件对象
|
||||
"""
|
||||
for handler in self.notice_handlers:
|
||||
if handler["type"] is None or handler["type"] == event.notice_type:
|
||||
await self._run_handler(handler["func"], bot, event)
|
||||
|
||||
# --- 请求分发逻辑 ---
|
||||
async def handle_request(self, bot, event):
|
||||
"""
|
||||
分发请求事件
|
||||
|
||||
:param bot: Bot 实例
|
||||
:param event: 请求事件对象
|
||||
"""
|
||||
for handler in self.request_handlers:
|
||||
if handler["type"] is None or handler["type"] == event.request_type:
|
||||
await self._run_handler(handler["func"], bot, event)
|
||||
|
||||
# --- 通用执行器:自动注入参数 ---
|
||||
async def _run_handler(self, func, bot, event, args=None):
|
||||
"""
|
||||
根据函数签名自动注入 bot, event 或 args
|
||||
|
||||
:param func: 目标处理函数
|
||||
:param bot: Bot 实例
|
||||
:param event: 事件对象
|
||||
:param args: 指令参数(仅消息指令有效)
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
params = sig.parameters
|
||||
kwargs = {}
|
||||
|
||||
if "bot" in params:
|
||||
kwargs["bot"] = bot
|
||||
if "event" in params:
|
||||
kwargs["event"] = event
|
||||
if "args" in params and args is not None:
|
||||
kwargs["args"] = args
|
||||
|
||||
# 执行函数
|
||||
await func(**kwargs)
|
||||
|
||||
|
||||
# 确保前缀是元组格式
|
||||
if isinstance(comm_prefixes, list):
|
||||
comm_prefixes = tuple[Any, ...](comm_prefixes)
|
||||
elif isinstance(comm_prefixes, str):
|
||||
comm_prefixes = (comm_prefixes,)
|
||||
|
||||
# 实例化全局管理器
|
||||
matcher = CommandManager(prefixes=comm_prefixes)
|
||||
@@ -4,9 +4,13 @@
|
||||
负责读取和解析 config.toml 配置文件,提供全局配置对象。
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import tomllib
|
||||
from pydantic import ValidationError
|
||||
from .config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel
|
||||
from .utils.logger import logger, ModuleLogger
|
||||
from .utils.exceptions import ConfigError, ConfigNotFoundError, ConfigValidationError
|
||||
from .utils.error_codes import ErrorCode, create_error_response
|
||||
|
||||
|
||||
class Config:
|
||||
@@ -21,55 +25,97 @@ class Config:
|
||||
:param file_path: 配置文件路径,默认为 "config.toml"
|
||||
"""
|
||||
self.path = Path(file_path)
|
||||
self._data: Dict[str, Any] = {}
|
||||
self._model: ConfigModel
|
||||
# 创建模块专用日志记录器
|
||||
self.logger = ModuleLogger("ConfigLoader")
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
"""
|
||||
加载配置文件
|
||||
加载并验证配置文件
|
||||
|
||||
:raises FileNotFoundError: 如果配置文件不存在
|
||||
:raises ConfigNotFoundError: 如果配置文件不存在
|
||||
:raises ConfigValidationError: 如果配置格式不正确
|
||||
:raises ConfigError: 如果加载配置时发生其他错误
|
||||
"""
|
||||
if not self.path.exists():
|
||||
raise FileNotFoundError(f"配置文件 {self.path} 未找到!")
|
||||
error = ConfigNotFoundError(message=f"配置文件 {self.path} 未找到!")
|
||||
self.logger.error(f"配置加载失败: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
raise error
|
||||
|
||||
with open(self.path, "rb") as f:
|
||||
self._data = tomllib.load(f)
|
||||
try:
|
||||
self.logger.info(f"正在从 {self.path} 加载配置...")
|
||||
with open(self.path, "rb") as f:
|
||||
raw_config = tomllib.load(f)
|
||||
|
||||
self._model = ConfigModel(**raw_config)
|
||||
self.logger.success("配置加载并验证成功!")
|
||||
|
||||
except ValidationError as e:
|
||||
error_details = []
|
||||
for error in e.errors():
|
||||
field = " -> ".join(map(str, error["loc"]))
|
||||
error_msg = f"字段 '{field}': {error['msg']}"
|
||||
error_details.append(error_msg)
|
||||
|
||||
validation_error = ConfigValidationError(
|
||||
message="配置验证失败",
|
||||
original_error=e
|
||||
)
|
||||
|
||||
self.logger.error("配置验证失败,请检查 `config.toml` 文件中的以下错误:")
|
||||
for detail in error_details:
|
||||
self.logger.error(f" - {detail}")
|
||||
|
||||
self.logger.log_custom_exception(validation_error)
|
||||
raise validation_error
|
||||
except tomllib.TOMLDecodeError as e:
|
||||
error = ConfigError(
|
||||
message=f"TOML解析错误: {str(e)}",
|
||||
original_error=e
|
||||
)
|
||||
self.logger.error(f"加载配置文件时发生TOML解析错误: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
raise error
|
||||
except Exception as e:
|
||||
error = ConfigError(
|
||||
message=f"加载配置文件时发生未知错误: {str(e)}",
|
||||
original_error=e
|
||||
)
|
||||
self.logger.exception(f"加载配置文件时发生未知错误: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
raise error
|
||||
|
||||
# 通过属性访问配置
|
||||
@property
|
||||
def napcat_ws(self) -> dict:
|
||||
def napcat_ws(self) -> NapCatWSModel:
|
||||
"""
|
||||
获取 NapCat WebSocket 配置
|
||||
|
||||
:return: 配置字典
|
||||
"""
|
||||
return self._data.get("napcat_ws", {})
|
||||
return self._model.napcat_ws
|
||||
|
||||
@property
|
||||
def bot(self) -> dict:
|
||||
def bot(self) -> BotModel:
|
||||
"""
|
||||
获取 Bot 基础配置
|
||||
|
||||
:return: 配置字典
|
||||
"""
|
||||
return self._data.get("bot", {})
|
||||
return self._model.bot
|
||||
|
||||
@property
|
||||
def features(self) -> dict:
|
||||
def redis(self) -> RedisModel:
|
||||
"""
|
||||
获取功能特性配置
|
||||
获取 Redis 配置
|
||||
"""
|
||||
return self._model.redis
|
||||
|
||||
:return: 配置字典
|
||||
@property
|
||||
def docker(self) -> DockerModel:
|
||||
"""
|
||||
return self._data.get("features", {})
|
||||
获取 Docker 配置
|
||||
"""
|
||||
return self._model.docker
|
||||
|
||||
|
||||
# 实例化全局配置对象
|
||||
global_config = Config()
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(global_config.napcat_ws)
|
||||
print(global_config.bot.get("command"))
|
||||
print(type(global_config.bot.get("command")) is list)
|
||||
print(global_config.features)
|
||||
|
||||
60
core/config_models.py
Normal file
60
core/config_models.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Pydantic 配置模型模块
|
||||
|
||||
该模块使用 Pydantic 定义了与 `config.toml` 文件结构完全对应的配置模型。
|
||||
这使得配置的加载、校验和访问都变得类型安全和健壮。
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NapCatWSModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[napcat_ws]` 配置块。
|
||||
"""
|
||||
uri: str
|
||||
token: str
|
||||
reconnect_interval: int = 5
|
||||
|
||||
|
||||
class BotModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[bot]` 配置块。
|
||||
"""
|
||||
command: List[str] = Field(default_factory=lambda: ["/"])
|
||||
ignore_self_message: bool = True
|
||||
permission_denied_message: str = "权限不足,需要 {permission_name} 权限"
|
||||
|
||||
|
||||
class RedisModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[redis]` 配置块。
|
||||
"""
|
||||
host: str
|
||||
port: int
|
||||
db: int
|
||||
password: str
|
||||
|
||||
|
||||
class DockerModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[docker]` 配置块。
|
||||
"""
|
||||
base_url: Optional[str] = None
|
||||
sandbox_image: str = "python-sandbox:latest"
|
||||
timeout: int = 10
|
||||
concurrency_limit: int = 5
|
||||
tls_verify: bool = False
|
||||
ca_cert_path: Optional[str] = None
|
||||
client_cert_path: Optional[str] = None
|
||||
client_key_path: Optional[str] = None
|
||||
|
||||
|
||||
class ConfigModel(BaseModel):
|
||||
"""
|
||||
顶层配置模型,整合了所有子配置块。
|
||||
"""
|
||||
napcat_ws: NapCatWSModel
|
||||
bot: BotModel
|
||||
redis: RedisModel
|
||||
docker: DockerModel
|
||||
3
core/data/admin.json
Normal file
3
core/data/admin.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"admins": [2221577113]
|
||||
}
|
||||
3
core/data/permissions.json
Normal file
3
core/data/permissions.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"users": {}
|
||||
}
|
||||
0
core/handlers/__init__.py
Normal file
0
core/handlers/__init__.py
Normal file
240
core/handlers/event_handler.py
Normal file
240
core/handlers/event_handler.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
事件处理器模块
|
||||
|
||||
该模块定义了用于处理不同类型事件的处理器类。
|
||||
每个处理器都负责注册和分发特定类型的事件。
|
||||
"""
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bot import Bot
|
||||
from ..config_loader import global_config
|
||||
from ..permission import Permission
|
||||
from ..utils.executor import run_in_thread_pool
|
||||
|
||||
|
||||
class BaseHandler(ABC):
|
||||
"""
|
||||
事件处理器抽象基类
|
||||
"""
|
||||
def __init__(self):
|
||||
self.handlers: List[Dict[str, Any]] = []
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _run_handler(
|
||||
self,
|
||||
func: Callable,
|
||||
bot: "Bot",
|
||||
event: Any,
|
||||
args: Optional[List[str]] = None,
|
||||
permission_granted: Optional[bool] = None
|
||||
):
|
||||
"""
|
||||
智能执行事件处理器,并注入所需参数
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
params = sig.parameters
|
||||
kwargs: Dict[str, Any] = {}
|
||||
|
||||
if "bot" in params:
|
||||
kwargs["bot"] = bot
|
||||
if "event" in params:
|
||||
kwargs["event"] = event
|
||||
if "args" in params and args is not None:
|
||||
kwargs["args"] = args
|
||||
if "permission_granted" in params and permission_granted is not None:
|
||||
kwargs["permission_granted"] = permission_granted
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
result = await func(**kwargs)
|
||||
else:
|
||||
# 如果是同步函数,则放入线程池执行
|
||||
result = await run_in_thread_pool(func, **kwargs)
|
||||
return result is True
|
||||
|
||||
|
||||
class MessageHandler(BaseHandler):
|
||||
"""
|
||||
消息事件处理器
|
||||
"""
|
||||
def __init__(self, prefixes: Tuple[str, ...]):
|
||||
super().__init__()
|
||||
self.prefixes = prefixes
|
||||
self.commands: Dict[str, Dict] = {}
|
||||
self.message_handlers: List[Dict[str, Any]] = []
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
清空所有已注册的消息和命令处理器
|
||||
"""
|
||||
self.commands.clear()
|
||||
self.message_handlers.clear()
|
||||
|
||||
def unregister_by_plugin_name(self, plugin_name: str):
|
||||
"""
|
||||
根据插件名卸载相关的消息和命令处理器
|
||||
"""
|
||||
# 卸载命令
|
||||
commands_to_remove = [name for name, info in self.commands.items() if info["plugin_name"] == plugin_name]
|
||||
for name in commands_to_remove:
|
||||
del self.commands[name]
|
||||
|
||||
# 卸载通用消息处理器
|
||||
self.message_handlers = [h for h in self.message_handlers if h["plugin_name"] != plugin_name]
|
||||
|
||||
def on_message(self) -> Callable:
|
||||
"""
|
||||
注册通用消息处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
self.message_handlers.append({"func": func, "plugin_name": plugin_name})
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def command(
|
||||
self,
|
||||
*names: str,
|
||||
permission: Optional[Permission] = None,
|
||||
override_permission_check: bool = False
|
||||
) -> Callable:
|
||||
"""
|
||||
注册命令处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
for name in names:
|
||||
self.commands[name] = {
|
||||
"func": func,
|
||||
"permission": permission,
|
||||
"override_permission_check": override_permission_check,
|
||||
"plugin_name": plugin_name,
|
||||
}
|
||||
return func
|
||||
return decorator
|
||||
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理消息事件,分发给命令处理器或通用消息处理器
|
||||
"""
|
||||
from ..managers import permission_manager
|
||||
for handler_info in self.message_handlers:
|
||||
consumed = await self._run_handler(handler_info["func"], bot, event)
|
||||
if consumed:
|
||||
return
|
||||
|
||||
if not event.raw_message:
|
||||
return
|
||||
|
||||
raw_text = event.raw_message.strip()
|
||||
prefix_found = next((p for p in self.prefixes if raw_text.startswith(p)), None)
|
||||
|
||||
if not prefix_found:
|
||||
return
|
||||
|
||||
command_parts = raw_text[len(prefix_found):].split()
|
||||
if not command_parts:
|
||||
return
|
||||
|
||||
command_name = command_parts[0]
|
||||
args = command_parts[1:]
|
||||
|
||||
if command_name in self.commands:
|
||||
command_info = self.commands[command_name]
|
||||
func = command_info["func"]
|
||||
permission = command_info.get("permission")
|
||||
override_check = command_info.get("override_permission_check", False)
|
||||
|
||||
permission_granted = True
|
||||
if permission:
|
||||
permission_granted = await permission_manager.check_permission(event.user_id, permission)
|
||||
|
||||
if not permission_granted and not override_check:
|
||||
permission_name = permission.name if isinstance(permission, Permission) else permission
|
||||
message_template = global_config.bot.permission_denied_message
|
||||
await bot.send(event, message_template.format(permission_name=permission_name))
|
||||
return
|
||||
|
||||
await self._run_handler(
|
||||
func,
|
||||
bot,
|
||||
event,
|
||||
args=args,
|
||||
permission_granted=permission_granted
|
||||
)
|
||||
|
||||
|
||||
class NoticeHandler(BaseHandler):
|
||||
"""
|
||||
通知事件处理器
|
||||
"""
|
||||
def clear(self):
|
||||
self.handlers.clear()
|
||||
|
||||
def unregister_by_plugin_name(self, plugin_name: str):
|
||||
"""
|
||||
根据插件名卸载相关的通知处理器
|
||||
"""
|
||||
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
|
||||
|
||||
def register(self, notice_type: Optional[str] = None) -> Callable:
|
||||
"""
|
||||
注册通知处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
self.handlers.append({"type": notice_type, "func": func, "plugin_name": plugin_name})
|
||||
return func
|
||||
return decorator
|
||||
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理通知事件
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
if handler["type"] is None or handler["type"] == event.notice_type:
|
||||
await self._run_handler(handler["func"], bot, event)
|
||||
|
||||
|
||||
class RequestHandler(BaseHandler):
|
||||
"""
|
||||
请求事件处理器
|
||||
"""
|
||||
def clear(self):
|
||||
self.handlers.clear()
|
||||
|
||||
def unregister_by_plugin_name(self, plugin_name: str):
|
||||
"""
|
||||
根据插件名卸载相关的请求处理器
|
||||
"""
|
||||
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
|
||||
|
||||
def register(self, request_type: Optional[str] = None) -> Callable:
|
||||
"""
|
||||
注册请求处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
self.handlers.append({"type": request_type, "func": func, "plugin_name": plugin_name})
|
||||
return func
|
||||
return decorator
|
||||
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理请求事件
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
if handler["type"] is None or handler["type"] == event.request_type:
|
||||
await self._run_handler(handler["func"], bot, event)
|
||||
48
core/managers/__init__.py
Normal file
48
core/managers/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
管理器包
|
||||
|
||||
这个包集中了机器人核心的单例管理器。
|
||||
通过从这里导入,可以确保在整个应用中访问到的都是同一个实例。
|
||||
"""
|
||||
from .admin_manager import AdminManager
|
||||
from .command_manager import matcher as command_manager
|
||||
from .permission_manager import PermissionManager
|
||||
from .plugin_manager import PluginManager
|
||||
from .redis_manager import RedisManager
|
||||
from .browser_manager import BrowserManager
|
||||
from .image_manager import ImageManager
|
||||
|
||||
# --- 实例化所有单例管理器 ---
|
||||
|
||||
# 管理员管理器
|
||||
admin_manager = AdminManager()
|
||||
|
||||
# 权限管理器
|
||||
permission_manager = PermissionManager()
|
||||
|
||||
# 命令与事件管理器 (别名 matcher)
|
||||
matcher = command_manager
|
||||
|
||||
# 插件管理器
|
||||
plugin_manager = PluginManager(command_manager)
|
||||
plugin_manager.load_all_plugins()
|
||||
|
||||
# Redis 管理器
|
||||
redis_manager = RedisManager()
|
||||
|
||||
# 浏览器管理器
|
||||
browser_manager = BrowserManager()
|
||||
|
||||
# 图片管理器
|
||||
image_manager = ImageManager()
|
||||
|
||||
__all__ = [
|
||||
"admin_manager",
|
||||
"permission_manager",
|
||||
"command_manager",
|
||||
"matcher",
|
||||
"plugin_manager",
|
||||
"redis_manager",
|
||||
"browser_manager",
|
||||
"image_manager",
|
||||
]
|
||||
150
core/managers/admin_manager.py
Normal file
150
core/managers/admin_manager.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
管理员管理器模块
|
||||
|
||||
该模块负责管理机器人的管理员列表。
|
||||
它现在以 Redis 作为主要数据源,文件仅用作备份。
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from typing import Set
|
||||
|
||||
from ..utils.logger import logger
|
||||
from ..utils.singleton import Singleton
|
||||
from .redis_manager import redis_manager
|
||||
|
||||
|
||||
class AdminManager(Singleton):
|
||||
"""
|
||||
管理员管理器类
|
||||
|
||||
以 Redis Set 作为管理员列表的唯一真实来源,提供高速的读写能力。
|
||||
文件 (admin.json) 仅用于首次启动时的数据迁移和作为灾备。
|
||||
"""
|
||||
_REDIS_KEY = "neobot:admins" # 用于存储管理员集合的 Redis 键
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化 AdminManager
|
||||
"""
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
|
||||
# 管理员数据文件路径,主要用于备份和首次迁移
|
||||
self.data_file = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"..",
|
||||
"data",
|
||||
"admin.json"
|
||||
)
|
||||
|
||||
os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
|
||||
logger.info("管理员管理器初始化完成")
|
||||
super().__init__()
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
异步初始化,检查 Redis 数据,如果为空则尝试从文件迁移
|
||||
"""
|
||||
try:
|
||||
# 检查 Redis 中是否已存在数据
|
||||
if await redis_manager.redis.exists(self._REDIS_KEY):
|
||||
admin_count = await redis_manager.redis.scard(self._REDIS_KEY)
|
||||
logger.info(f"Redis 中已存在管理员数据,共 {admin_count} 位。")
|
||||
else:
|
||||
# Redis 为空,尝试从文件迁移
|
||||
logger.info("Redis 中未找到管理员数据,尝试从 admin.json 文件迁移...")
|
||||
await self._migrate_from_file_to_redis()
|
||||
except Exception as e:
|
||||
logger.error(f"初始化管理员数据时发生错误: {e}")
|
||||
|
||||
async def _migrate_from_file_to_redis(self):
|
||||
"""
|
||||
从 admin.json 加载管理员列表并存入 Redis
|
||||
这通常只在首次启动或 Redis 数据丢失时执行一次
|
||||
"""
|
||||
admins_to_migrate = set()
|
||||
try:
|
||||
if os.path.exists(self.data_file):
|
||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
admins = data.get("admins", [])
|
||||
admins_to_migrate = set(int(admin_id) for admin_id in admins)
|
||||
|
||||
if admins_to_migrate:
|
||||
await redis_manager.redis.sadd(self._REDIS_KEY, *admins_to_migrate)
|
||||
logger.success(f"成功从文件迁移 {len(admins_to_migrate)} 位管理员到 Redis。")
|
||||
else:
|
||||
logger.info("admin.json 文件为空或不存在,无需迁移。")
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.error(f"解析 admin.json 失败,无法迁移: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移管理员数据到 Redis 失败: {e}")
|
||||
|
||||
async def _save_to_file_backup(self):
|
||||
"""
|
||||
将 Redis 中的管理员列表备份到 admin.json
|
||||
"""
|
||||
try:
|
||||
admins = await self.get_all_admins()
|
||||
admin_list = [str(admin_id) for admin_id in admins]
|
||||
with open(self.data_file, "w", encoding="utf-8") as f:
|
||||
json.dump({"admins": admin_list}, f, indent=2, ensure_ascii=False)
|
||||
logger.debug(f"管理员列表已备份到 {self.data_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"备份管理员列表到 admin.json 失败: {e}")
|
||||
|
||||
async def is_admin(self, user_id: int) -> bool:
|
||||
"""
|
||||
检查用户是否为管理员(直接从 Redis 读取)
|
||||
"""
|
||||
try:
|
||||
return await redis_manager.redis.sismember(self._REDIS_KEY, user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 检查管理员权限失败: {e}")
|
||||
return False
|
||||
|
||||
async def add_admin(self, user_id: int) -> bool:
|
||||
"""
|
||||
添加管理员到 Redis,并更新文件备份
|
||||
"""
|
||||
try:
|
||||
# sadd 返回成功添加的成员数量,1 表示成功,0 表示已存在
|
||||
if await redis_manager.redis.sadd(self._REDIS_KEY, user_id) == 1:
|
||||
logger.info(f"已添加新管理员 {user_id} 到 Redis")
|
||||
await self._save_to_file_backup() # 更新备份
|
||||
return True
|
||||
return False # 用户已经是管理员
|
||||
except Exception as e:
|
||||
logger.error(f"添加管理员 {user_id} 到 Redis 失败: {e}")
|
||||
return False
|
||||
|
||||
async def remove_admin(self, user_id: int) -> bool:
|
||||
"""
|
||||
从 Redis 移除管理员,并更新文件备份
|
||||
"""
|
||||
try:
|
||||
# srem 返回成功移除的成员数量,1 表示成功,0 表示不存在
|
||||
if await redis_manager.redis.srem(self._REDIS_KEY, user_id) == 1:
|
||||
logger.info(f"已从 Redis 移除管理员 {user_id}")
|
||||
await self._save_to_file_backup() # 更新备份
|
||||
return True
|
||||
return False # 用户不是管理员
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 移除管理员 {user_id} 失败: {e}")
|
||||
return False
|
||||
|
||||
async def get_all_admins(self) -> Set[int]:
|
||||
"""
|
||||
从 Redis 获取所有管理员的集合
|
||||
"""
|
||||
try:
|
||||
admins = await redis_manager.redis.smembers(self._REDIS_KEY)
|
||||
return {int(admin_id) for admin_id in admins}
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 获取所有管理员失败: {e}")
|
||||
return set()
|
||||
|
||||
|
||||
# 全局 AdminManager 实例
|
||||
admin_manager = AdminManager()
|
||||
151
core/managers/browser_manager.py
Normal file
151
core/managers/browser_manager.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
浏览器管理器模块
|
||||
|
||||
负责管理全局唯一的 Playwright 浏览器实例,避免频繁启动/关闭浏览器的开销。
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from playwright.async_api import async_playwright, Browser, Playwright, Page
|
||||
from ..utils.logger import logger
|
||||
|
||||
class BrowserManager:
|
||||
"""
|
||||
浏览器管理器(异步单例)
|
||||
"""
|
||||
_instance = None
|
||||
_playwright: Optional[Playwright] = None
|
||||
_browser: Optional[Browser] = None
|
||||
_page_pool: Optional[asyncio.Queue] = None
|
||||
_pool_size: int = 3
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化 Playwright 和 Browser
|
||||
"""
|
||||
if self._browser is None:
|
||||
try:
|
||||
logger.info("正在启动无头浏览器...")
|
||||
self._playwright = await async_playwright().start()
|
||||
# 启动 Chromium,headless=True 表示无头模式
|
||||
self._browser = await self._playwright.chromium.launch(headless=True)
|
||||
logger.success("无头浏览器启动成功!")
|
||||
except Exception as e:
|
||||
logger.exception(f"无头浏览器启动失败: {e}")
|
||||
self._browser = None
|
||||
|
||||
async def init_pool(self, size: int = 3):
|
||||
"""
|
||||
初始化页面池
|
||||
"""
|
||||
if not self._browser:
|
||||
await self.initialize()
|
||||
|
||||
if not self._browser:
|
||||
logger.error("浏览器初始化失败,无法创建页面池")
|
||||
return
|
||||
|
||||
self._pool_size = size
|
||||
self._page_pool = asyncio.Queue(maxsize=size)
|
||||
|
||||
logger.info(f"正在初始化页面池 (大小: {size})...")
|
||||
for i in range(size):
|
||||
try:
|
||||
page = await self._browser.new_page()
|
||||
await self._page_pool.put(page)
|
||||
except Exception as e:
|
||||
logger.error(f"创建页面池页面 {i+1} 失败: {e}")
|
||||
|
||||
logger.success(f"页面池初始化完成,当前可用页面: {self._page_pool.qsize()}")
|
||||
|
||||
async def get_page(self) -> Optional[Page]:
|
||||
"""
|
||||
从池中获取一个页面。如果池未初始化或为空,则尝试创建一个新页面(不入池)。
|
||||
"""
|
||||
if self._page_pool and not self._page_pool.empty():
|
||||
try:
|
||||
page = self._page_pool.get_nowait()
|
||||
# 简单的健康检查
|
||||
if page.is_closed():
|
||||
logger.warning("检测到池中页面已关闭,重新创建一个...")
|
||||
if self._browser:
|
||||
page = await self._browser.new_page()
|
||||
else:
|
||||
return None
|
||||
return page
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
|
||||
# 如果池空了或者没初始化,回退到临时创建
|
||||
logger.debug("页面池为空或未初始化,创建临时页面")
|
||||
return await self.get_new_page()
|
||||
|
||||
async def release_page(self, page: Page):
|
||||
"""
|
||||
归还页面到池中。如果池已满或未初始化,则关闭页面。
|
||||
"""
|
||||
if not page or page.is_closed():
|
||||
return
|
||||
|
||||
if self._page_pool:
|
||||
try:
|
||||
# 重置页面状态 (例如清空内容),防止数据污染
|
||||
# 注意: goto('about:blank') 比 close() 快得多
|
||||
await page.goto("about:blank")
|
||||
|
||||
self._page_pool.put_nowait(page)
|
||||
return
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
# 池满或未启用池,直接关闭
|
||||
await page.close()
|
||||
|
||||
async def get_new_page(self) -> Optional[Page]:
|
||||
"""
|
||||
获取一个新的页面 (Page)
|
||||
|
||||
使用完毕后,调用者应该负责关闭该页面 (await page.close())
|
||||
"""
|
||||
if self._browser is None:
|
||||
logger.warning("浏览器尚未初始化,尝试重新初始化...")
|
||||
await self.initialize()
|
||||
|
||||
if self._browser:
|
||||
try:
|
||||
return await self._browser.new_page()
|
||||
except Exception as e:
|
||||
logger.error(f"创建新页面失败: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
async def shutdown(self):
|
||||
"""
|
||||
关闭浏览器和 Playwright
|
||||
"""
|
||||
# 清空页面池
|
||||
if self._page_pool:
|
||||
while not self._page_pool.empty():
|
||||
try:
|
||||
page = self._page_pool.get_nowait()
|
||||
await page.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._page_pool = None
|
||||
|
||||
if self._browser:
|
||||
await self._browser.close()
|
||||
self._browser = None
|
||||
logger.info("浏览器已关闭")
|
||||
|
||||
if self._playwright:
|
||||
await self._playwright.stop()
|
||||
self._playwright = None
|
||||
logger.info("Playwright 已停止")
|
||||
|
||||
# 全局浏览器管理器实例
|
||||
browser_manager = BrowserManager()
|
||||
235
core/managers/command_manager.py
Normal file
235
core/managers/command_manager.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
命令与事件管理器模块
|
||||
|
||||
该模块定义了 `CommandManager` 类,它是整个机器人框架事件处理的核心。
|
||||
它通过装饰器模式,为插件提供了注册消息指令、通知事件处理器和
|
||||
请求事件处理器的能力。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
import os
|
||||
import base64
|
||||
|
||||
from models.events.message import MessageSegment
|
||||
|
||||
from models.events.message import MessageSegment
|
||||
|
||||
from ..config_loader import global_config
|
||||
from ..handlers.event_handler import MessageHandler, NoticeHandler, RequestHandler
|
||||
from .redis_manager import redis_manager
|
||||
from .image_manager import image_manager
|
||||
from ..utils.logger import logger
|
||||
|
||||
# 从配置中获取命令前缀
|
||||
_config_prefixes = global_config.bot.command
|
||||
|
||||
# 确保前缀配置是元组格式
|
||||
_final_prefixes: Tuple[str, ...]
|
||||
if isinstance(_config_prefixes, list):
|
||||
_final_prefixes = tuple(_config_prefixes)
|
||||
elif isinstance(_config_prefixes, str):
|
||||
_final_prefixes = (_config_prefixes,)
|
||||
else:
|
||||
_final_prefixes = tuple(_config_prefixes)
|
||||
|
||||
|
||||
class CommandManager:
|
||||
"""
|
||||
命令管理器,负责注册和分发所有类型的事件。
|
||||
|
||||
这是一个单例对象(`matcher`),在整个应用中共享。
|
||||
它将不同类型的事件处理委托给专门的处理器类。
|
||||
"""
|
||||
|
||||
def __init__(self, prefixes: Tuple[str, ...]):
|
||||
"""
|
||||
初始化命令管理器。
|
||||
|
||||
Args:
|
||||
prefixes (Tuple[str, ...]): 一个包含所有合法命令前缀的元组。
|
||||
"""
|
||||
self.plugins: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# 初始化专门的事件处理器
|
||||
self.message_handler = MessageHandler(prefixes)
|
||||
self.notice_handler = NoticeHandler()
|
||||
self.request_handler = RequestHandler()
|
||||
|
||||
# 将处理器映射到事件类型
|
||||
self.handler_map = {
|
||||
"message": self.message_handler,
|
||||
"notice": self.notice_handler,
|
||||
"request": self.request_handler,
|
||||
}
|
||||
|
||||
# 注册内置的 /help 命令
|
||||
self._register_internal_commands()
|
||||
|
||||
async def sync_help_pic(self):
|
||||
"""
|
||||
启动时或插件重载时同步 help 图片到 Redis
|
||||
"""
|
||||
try:
|
||||
logger.info("正在生成帮助图片...")
|
||||
|
||||
# 1. 收集插件数据
|
||||
plugins_data = []
|
||||
for plugin_name, meta in self.plugins.items():
|
||||
plugins_data.append({
|
||||
"name": meta.get("name", plugin_name),
|
||||
"description": meta.get("description", "暂无描述"),
|
||||
"usage": meta.get("usage", "暂无用法")
|
||||
})
|
||||
|
||||
# 2. 渲染图片
|
||||
# 使用 png 格式以获得更好的文字清晰度
|
||||
base64_str = await image_manager.render_template_to_base64(
|
||||
template_name="help.html",
|
||||
data={"plugins": plugins_data},
|
||||
output_name="help_menu.png",
|
||||
image_type="png"
|
||||
)
|
||||
|
||||
if base64_str:
|
||||
await redis_manager.set("neobot:core:help_pic", base64_str)
|
||||
logger.success("帮助图片已更新并缓存到 Redis")
|
||||
else:
|
||||
logger.error("帮助图片生成失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"同步帮助图片失败: {e}")
|
||||
|
||||
def _register_internal_commands(self):
|
||||
"""
|
||||
注册框架内置的命令
|
||||
"""
|
||||
# Help 命令
|
||||
self.message_handler.command("help")(self._help_command)
|
||||
self.plugins["core.help"] = {
|
||||
"name": "帮助",
|
||||
"description": "显示所有可用指令的帮助信息",
|
||||
"usage": "/help",
|
||||
}
|
||||
|
||||
def clear_all_handlers(self):
|
||||
"""
|
||||
清空所有已注册的事件处理器。
|
||||
注意:这也会移除内置的 /help 命令,因此需要重新注册。
|
||||
"""
|
||||
self.message_handler.clear()
|
||||
self.notice_handler.clear()
|
||||
self.request_handler.clear()
|
||||
self.plugins.clear()
|
||||
|
||||
# 清空后,需要重新注册内置命令
|
||||
self._register_internal_commands()
|
||||
|
||||
def unload_plugin(self, plugin_name: str):
|
||||
"""
|
||||
卸载指定插件的所有处理器和命令。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件的模块名 (例如 'plugins.bili_parser')
|
||||
"""
|
||||
self.message_handler.unregister_by_plugin_name(plugin_name)
|
||||
self.notice_handler.unregister_by_plugin_name(plugin_name)
|
||||
self.request_handler.unregister_by_plugin_name(plugin_name)
|
||||
|
||||
# 移除插件元信息
|
||||
plugins_to_remove = [name for name in self.plugins if name == plugin_name]
|
||||
for name in plugins_to_remove:
|
||||
del self.plugins[name]
|
||||
|
||||
# --- 装饰器代理 ---
|
||||
|
||||
def on_message(self) -> Callable:
|
||||
"""
|
||||
装饰器:注册一个通用的消息处理器。
|
||||
"""
|
||||
return self.message_handler.on_message()
|
||||
|
||||
def command(
|
||||
self,
|
||||
*names: str,
|
||||
permission: Optional[Any] = None,
|
||||
override_permission_check: bool = False,
|
||||
) -> Callable:
|
||||
"""
|
||||
装饰器:注册一个消息指令处理器。
|
||||
"""
|
||||
return self.message_handler.command(
|
||||
*names,
|
||||
permission=permission,
|
||||
override_permission_check=override_permission_check,
|
||||
)
|
||||
|
||||
def on_notice(self, notice_type: Optional[str] = None) -> Callable:
|
||||
"""
|
||||
装饰器:注册一个通知事件处理器。
|
||||
"""
|
||||
return self.notice_handler.register(notice_type=notice_type)
|
||||
|
||||
def on_request(self, request_type: Optional[str] = None) -> Callable:
|
||||
"""
|
||||
装饰器:注册一个请求事件处理器。
|
||||
"""
|
||||
return self.request_handler.register(request_type=request_type)
|
||||
|
||||
# --- 事件处理 ---
|
||||
|
||||
async def handle_event(self, bot, event):
|
||||
"""
|
||||
统一的事件分发入口。
|
||||
|
||||
根据事件的 `post_type` 将其分发给对应的处理器。
|
||||
"""
|
||||
if event.post_type == "message" and global_config.bot.ignore_self_message:
|
||||
if (
|
||||
hasattr(event, "user_id")
|
||||
and hasattr(event, "self_id")
|
||||
and event.user_id == event.self_id
|
||||
):
|
||||
return
|
||||
|
||||
handler = self.handler_map.get(event.post_type)
|
||||
if handler:
|
||||
await handler.handle(bot, event)
|
||||
|
||||
# --- 内置命令实现 ---
|
||||
|
||||
async def _help_command(self, bot, event):
|
||||
"""
|
||||
内置的 `/help` 命令的实现。
|
||||
直接从 Redis 获取缓存的图片。
|
||||
"""
|
||||
try:
|
||||
# 1. 尝试从 Redis 获取
|
||||
help_pic = await redis_manager.get("neobot:core:help_pic")
|
||||
|
||||
if not help_pic:
|
||||
await bot.send(event, "帮助图片缓存缺失,正在重新生成...")
|
||||
await self.sync_help_pic()
|
||||
help_pic = await redis_manager.get("neobot:core:help_pic")
|
||||
|
||||
if help_pic:
|
||||
await bot.send(event, MessageSegment.image(help_pic))
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"获取或生成帮助图片失败: {e}")
|
||||
|
||||
# 2. 最后的兜底:发送纯文本
|
||||
help_text = "--- 可用指令列表 ---\n"
|
||||
for plugin_name, meta in self.plugins.items():
|
||||
name = meta.get("name", "未命名插件")
|
||||
description = meta.get("description", "暂无描述")
|
||||
usage = meta.get("usage", "暂无用法说明")
|
||||
|
||||
help_text += f"\n{name}:\n"
|
||||
help_text += f" 功能: {description}\n"
|
||||
help_text += f" 用法: {usage}\n"
|
||||
|
||||
await bot.send(event, help_text.strip())
|
||||
|
||||
|
||||
# 实例化全局唯一的命令管理器
|
||||
matcher = CommandManager(prefixes=_final_prefixes)
|
||||
1
core/managers/help_pic.py
Normal file
1
core/managers/help_pic.py
Normal file
File diff suppressed because one or more lines are too long
123
core/managers/image_manager.py
Normal file
123
core/managers/image_manager.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
图片生成管理器模块
|
||||
|
||||
负责管理图片生成相关的逻辑,支持多种渲染引擎(目前支持 Playwright)。
|
||||
"""
|
||||
import os
|
||||
import base64
|
||||
from typing import Dict, Any, Optional
|
||||
from jinja2 import Template
|
||||
|
||||
from .browser_manager import browser_manager
|
||||
from ..utils.logger import logger
|
||||
|
||||
class ImageManager:
|
||||
"""
|
||||
图片生成管理器(单例)
|
||||
"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# 模板目录
|
||||
self.template_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "templates")
|
||||
# 临时文件目录
|
||||
# core/managers/image_manager.py -> core/managers -> core -> core/data/temp
|
||||
self.temp_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "temp")
|
||||
os.makedirs(self.temp_dir, exist_ok=True)
|
||||
# 模板缓存
|
||||
self._template_cache: Dict[str, Template] = {}
|
||||
|
||||
async def render_template(self, template_name: str, data: Dict[str, Any], output_name: str = "output.png", quality: int = 80, image_type: str = "png") -> Optional[str]:
|
||||
"""
|
||||
使用 Playwright 渲染 Jinja2 模板并保存为图片文件
|
||||
|
||||
Args:
|
||||
template_name (str): 模板文件名 (例如 "help.html")
|
||||
data (Dict[str, Any]): 传递给模板的数据字典
|
||||
output_name (str, optional): 输出文件名. Defaults to "output.png".
|
||||
quality (int, optional): JPEG 质量 (0-100). 仅在 image_type 为 jpeg 时有效. Defaults to 80.
|
||||
image_type (str, optional): 图片类型 ('png' or 'jpeg'). Defaults to "png".
|
||||
|
||||
Returns:
|
||||
Optional[str]: 生成图片的绝对路径,如果失败则返回 None
|
||||
"""
|
||||
template_path = os.path.join(self.template_dir, template_name)
|
||||
if not os.path.exists(template_path):
|
||||
logger.error(f"模板文件未找到: {template_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1. 渲染 HTML (使用缓存)
|
||||
if template_name in self._template_cache:
|
||||
template = self._template_cache[template_name]
|
||||
else:
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
template_str = f.read()
|
||||
template = Template(template_str)
|
||||
self._template_cache[template_name] = template
|
||||
|
||||
html_content = template.render(**data)
|
||||
|
||||
# 2. 使用浏览器截图
|
||||
# 改为从池中获取页面
|
||||
page = await browser_manager.get_page()
|
||||
if not page:
|
||||
logger.error("无法获取浏览器页面")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 设置视口
|
||||
await page.set_viewport_size({"width": 650, "height": 100})
|
||||
|
||||
# 加载内容
|
||||
await page.set_content(html_content)
|
||||
await page.wait_for_selector("body")
|
||||
|
||||
# 截图
|
||||
screenshot_args = {'full_page': True, 'type': image_type}
|
||||
if image_type == 'jpeg':
|
||||
screenshot_args['quality'] = quality
|
||||
|
||||
screenshot_bytes = await page.screenshot(**screenshot_args) # type: ignore
|
||||
|
||||
finally:
|
||||
# 归还页面到池中,而不是直接关闭
|
||||
await browser_manager.release_page(page)
|
||||
|
||||
# 3. 保存文件
|
||||
output_path = os.path.join(self.temp_dir, output_name)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(screenshot_bytes)
|
||||
|
||||
logger.info(f"图片已生成: {output_path} ({len(screenshot_bytes)/1024:.2f} KB)")
|
||||
return os.path.abspath(output_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"渲染模板 {template_name} 失败: {e}")
|
||||
return None
|
||||
|
||||
async def render_template_to_base64(self, template_name: str, data: Dict[str, Any], output_name: str = "output.png", quality: int = 80, image_type: str = "png") -> Optional[str]:
|
||||
"""
|
||||
渲染模板并返回 Base64 编码的图片字符串
|
||||
"""
|
||||
file_path = await self.render_template(template_name, data, output_name, quality, image_type)
|
||||
if not file_path:
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
mime_type = "image/jpeg" if image_type == "jpeg" else "image/png"
|
||||
return f"data:{mime_type};base64," + base64.b64encode(content).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"读取图片文件失败: {e}")
|
||||
return None
|
||||
|
||||
# 全局图片管理器实例
|
||||
image_manager = ImageManager()
|
||||
209
core/managers/permission_manager.py
Normal file
209
core/managers/permission_manager.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
权限管理器模块
|
||||
|
||||
该模块负责管理用户权限,支持 admin、op、user 三个权限级别。
|
||||
以 Redis Hash 作为主要数据源,文件仅用作备份和首次数据迁移。
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from ..utils.logger import logger
|
||||
from ..utils.singleton import Singleton
|
||||
from .admin_manager import admin_manager
|
||||
from .redis_manager import redis_manager
|
||||
from ..permission import Permission
|
||||
|
||||
|
||||
# 用于从字符串名称查找权限对象的字典
|
||||
_PERMISSIONS: Dict[str, Permission] = {
|
||||
p.value: p for p in Permission
|
||||
}
|
||||
|
||||
|
||||
class PermissionManager(Singleton):
|
||||
"""
|
||||
权限管理器类
|
||||
|
||||
以 Redis Hash 作为权限数据的唯一真实来源,提供高速的读写能力。
|
||||
文件 (permissions.json) 仅用于首次启动时的数据迁移和作为灾备。
|
||||
"""
|
||||
_REDIS_KEY = "neobot:permissions" # 用于存储用户权限的 Redis Hash 键
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化权限管理器
|
||||
"""
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
|
||||
# 权限数据文件路径,主要用于备份和首次迁移
|
||||
self.data_file = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"..",
|
||||
"data",
|
||||
"permissions.json"
|
||||
)
|
||||
|
||||
os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
|
||||
logger.info("权限管理器初始化完成")
|
||||
super().__init__()
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
异步初始化,检查 Redis 数据,如果为空则尝试从文件迁移
|
||||
"""
|
||||
try:
|
||||
if not await redis_manager.redis.exists(self._REDIS_KEY):
|
||||
logger.info("Redis 中未找到权限数据,尝试从 permissions.json 文件迁移...")
|
||||
await self._migrate_from_file_to_redis()
|
||||
else:
|
||||
perm_count = await redis_manager.redis.hlen(self._REDIS_KEY)
|
||||
logger.info(f"Redis 中已存在权限数据,共 {perm_count} 条。")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化权限数据时发生错误: {e}")
|
||||
|
||||
async def _migrate_from_file_to_redis(self):
|
||||
"""
|
||||
从 permissions.json 加载权限数据并存入 Redis Hash
|
||||
"""
|
||||
perms_to_migrate = {}
|
||||
try:
|
||||
if os.path.exists(self.data_file):
|
||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
perms_to_migrate = data.get("users", {})
|
||||
|
||||
if perms_to_migrate:
|
||||
# 使用 pipeline 批量写入,提高效率
|
||||
async with redis_manager.redis.pipeline(transaction=True) as pipe:
|
||||
for user_id, level_name in perms_to_migrate.items():
|
||||
pipe.hset(self._REDIS_KEY, user_id, level_name)
|
||||
await pipe.execute()
|
||||
logger.success(f"成功从文件迁移 {len(perms_to_migrate)} 条权限数据到 Redis。")
|
||||
else:
|
||||
logger.info("permissions.json 文件为空或不存在,无需迁移。")
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.error(f"解析 permissions.json 失败,无法迁移: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移权限数据到 Redis 失败: {e}")
|
||||
|
||||
async def _save_to_file_backup(self):
|
||||
"""
|
||||
将 Redis 中的权限数据完整备份到 permissions.json
|
||||
"""
|
||||
try:
|
||||
all_perms = await redis_manager.redis.hgetall(self._REDIS_KEY)
|
||||
# Redis 返回的是 bytes,需要解码
|
||||
users_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in all_perms.items()}
|
||||
with open(self.data_file, "w", encoding="utf-8") as f:
|
||||
json.dump({"users": users_data}, f, indent=2, ensure_ascii=False)
|
||||
logger.debug(f"权限数据已备份到 {self.data_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"备份权限数据到 permissions.json 失败: {e}")
|
||||
|
||||
async def get_user_permission(self, user_id: int) -> Permission:
|
||||
"""
|
||||
获取指定用户的权限对象
|
||||
|
||||
优先检查是否为机器人管理员,然后从 Redis 查询。
|
||||
"""
|
||||
if await admin_manager.is_admin(user_id):
|
||||
return Permission.ADMIN
|
||||
|
||||
try:
|
||||
level_name_bytes = await redis_manager.redis.hget(self._REDIS_KEY, str(user_id))
|
||||
if level_name_bytes:
|
||||
level_name = level_name_bytes.decode('utf-8')
|
||||
return _PERMISSIONS.get(level_name, Permission.USER)
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 获取用户 {user_id} 权限失败: {e}")
|
||||
|
||||
return Permission.USER
|
||||
|
||||
async def set_user_permission(self, user_id: int, permission: Permission) -> None:
|
||||
"""
|
||||
在 Redis 中设置指定用户的权限级别,并更新文件备份
|
||||
"""
|
||||
if not isinstance(permission, Permission):
|
||||
raise ValueError(f"无效的权限对象: {permission}")
|
||||
|
||||
try:
|
||||
await redis_manager.redis.hset(self._REDIS_KEY, str(user_id), permission.value)
|
||||
await self._save_to_file_backup()
|
||||
logger.info(f"已在 Redis 中设置用户 {user_id} 的权限为 {permission.value}")
|
||||
except Exception as e:
|
||||
logger.error(f"在 Redis 中设置用户 {user_id} 权限失败: {e}")
|
||||
|
||||
async def remove_user(self, user_id: int) -> None:
|
||||
"""
|
||||
从 Redis 中移除指定用户的权限设置,并更新文件备份
|
||||
"""
|
||||
try:
|
||||
if await redis_manager.redis.hdel(self._REDIS_KEY, str(user_id)):
|
||||
await self._save_to_file_backup()
|
||||
logger.info(f"已从 Redis 中移除用户 {user_id} 的权限设置")
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 移除用户 {user_id} 权限失败: {e}")
|
||||
|
||||
async def check_permission(self, user_id: int, required_permission: Permission) -> bool:
|
||||
"""
|
||||
检查用户是否具有指定权限级别
|
||||
"""
|
||||
user_permission = await self.get_user_permission(user_id)
|
||||
return user_permission >= required_permission
|
||||
|
||||
async def get_all_user_permissions(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取所有已配置的用户权限(合并 Redis 和 AdminManager)
|
||||
"""
|
||||
permissions = {}
|
||||
try:
|
||||
# 从 Redis 获取基础权限
|
||||
all_perms = await redis_manager.redis.hgetall(self._REDIS_KEY)
|
||||
permissions = {k.decode('utf-8'): v.decode('utf-8') for k, v in all_perms.items()}
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 获取所有权限失败: {e}")
|
||||
|
||||
# 合并 AdminManager 中的管理员,ADMIN 权限覆盖一切
|
||||
try:
|
||||
admins = await admin_manager.get_all_admins()
|
||||
for admin_id in admins:
|
||||
permissions[str(admin_id)] = Permission.ADMIN.value
|
||||
except Exception as e:
|
||||
logger.error(f"获取管理员列表以合并权限时失败: {e}")
|
||||
|
||||
return permissions
|
||||
|
||||
async def clear_all(self) -> None:
|
||||
"""
|
||||
清空 Redis 中的所有权限设置,并更新备份文件
|
||||
"""
|
||||
try:
|
||||
await redis_manager.redis.delete(self._REDIS_KEY)
|
||||
await self._save_to_file_backup()
|
||||
logger.info("已清空 Redis 中的所有权限设置")
|
||||
except Exception as e:
|
||||
logger.error(f"清空 Redis 权限数据失败: {e}")
|
||||
|
||||
|
||||
def require_admin(func):
|
||||
"""
|
||||
一个装饰器,用于限制命令只能由管理员执行。
|
||||
"""
|
||||
from functools import wraps
|
||||
from models.events.message import MessageEvent
|
||||
from core.managers import permission_manager
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(event: MessageEvent, *args, **kwargs):
|
||||
user_id = event.user_id
|
||||
if await permission_manager.check_permission(user_id, Permission.ADMIN):
|
||||
return await func(event, *args, **kwargs)
|
||||
else:
|
||||
# 假设 event 对象有 reply 方法
|
||||
if hasattr(event, "reply"):
|
||||
await event.reply("抱歉,您没有权限执行此命令。")
|
||||
return None
|
||||
return wrapper
|
||||
134
core/managers/plugin_manager.py
Normal file
134
core/managers/plugin_manager.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
插件管理器模块
|
||||
|
||||
负责扫描、加载和管理 `plugins` 目录下的所有插件。
|
||||
"""
|
||||
import importlib
|
||||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
from typing import Set
|
||||
from .command_manager import CommandManager
|
||||
|
||||
from ..utils.exceptions import SyncHandlerError, PluginError, PluginLoadError, PluginReloadError, PluginNotFoundError
|
||||
from ..utils.logger import logger, ModuleLogger
|
||||
from ..utils.error_codes import ErrorCode, create_error_response
|
||||
|
||||
# 确保logger在模块级别可见
|
||||
__all__ = ['PluginManager', 'logger']
|
||||
|
||||
|
||||
class PluginManager:
|
||||
"""
|
||||
插件管理器类
|
||||
"""
|
||||
def __init__(self, command_manager: "CommandManager") -> None:
|
||||
"""
|
||||
初始化插件管理器
|
||||
|
||||
:param command_manager: CommandManager的实例
|
||||
"""
|
||||
self.command_manager = command_manager
|
||||
self.loaded_plugins: Set[str] = set()
|
||||
# 创建模块专用日志记录器
|
||||
self.logger = ModuleLogger("PluginManager")
|
||||
|
||||
def load_all_plugins(self) -> None:
|
||||
"""
|
||||
扫描并加载 `plugins` 目录下的所有插件。
|
||||
"""
|
||||
# 使用 pathlib 获取更可靠的路径
|
||||
# 当前文件: core/managers/plugin_manager.py
|
||||
# 目标: plugins/
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# 回退两级到项目根目录 (core/managers -> core -> root)
|
||||
root_dir = os.path.dirname(os.path.dirname(current_dir))
|
||||
plugin_dir = os.path.join(root_dir, "plugins")
|
||||
|
||||
package_name = "plugins"
|
||||
|
||||
if not os.path.exists(plugin_dir):
|
||||
self.logger.error(f"插件目录不存在: {plugin_dir}")
|
||||
return
|
||||
|
||||
self.logger.info(f"正在从 {package_name} 加载插件 (路径: {plugin_dir})...")
|
||||
|
||||
for _, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
|
||||
full_module_name = f"{package_name}.{module_name}"
|
||||
|
||||
action = "加载" # 初始化默认值
|
||||
try:
|
||||
if full_module_name in self.loaded_plugins:
|
||||
self.command_manager.unload_plugin(full_module_name)
|
||||
module = importlib.reload(sys.modules[full_module_name])
|
||||
action = "重载"
|
||||
else:
|
||||
module = importlib.import_module(full_module_name)
|
||||
action = "加载"
|
||||
|
||||
if hasattr(module, "__plugin_meta__"):
|
||||
meta = getattr(module, "__plugin_meta__")
|
||||
self.command_manager.plugins[full_module_name] = meta
|
||||
|
||||
self.loaded_plugins.add(full_module_name)
|
||||
|
||||
type_str = "包" if is_pkg else "文件"
|
||||
self.logger.success(f" [{type_str}] 成功{action}: {module_name}")
|
||||
except SyncHandlerError as e:
|
||||
error = PluginLoadError(
|
||||
plugin_name=module_name,
|
||||
message=f"同步处理器错误: {str(e)}",
|
||||
original_error=e
|
||||
)
|
||||
self.logger.error(f" 插件 {module_name} 加载失败: {error.message} (跳过此插件)")
|
||||
self.logger.log_custom_exception(error)
|
||||
except Exception as e:
|
||||
error = PluginLoadError(
|
||||
plugin_name=module_name,
|
||||
message=f"未知错误: {str(e)}",
|
||||
original_error=e
|
||||
)
|
||||
self.logger.exception(f" 加载插件 {module_name} 失败: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
|
||||
def reload_plugin(self, full_module_name: str) -> None:
|
||||
"""
|
||||
精确重载单个插件。
|
||||
"""
|
||||
if full_module_name not in self.loaded_plugins:
|
||||
self.logger.warning(f"尝试重载一个未被加载的插件: {full_module_name},将按首次加载处理。")
|
||||
|
||||
if full_module_name not in sys.modules:
|
||||
error = PluginNotFoundError(
|
||||
plugin_name=full_module_name,
|
||||
message="模块未在sys.modules中找到"
|
||||
)
|
||||
self.logger.error(f"重载失败: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
return
|
||||
|
||||
try:
|
||||
self.command_manager.unload_plugin(full_module_name)
|
||||
module = importlib.reload(sys.modules[full_module_name])
|
||||
|
||||
if hasattr(module, "__plugin_meta__"):
|
||||
meta = getattr(module, "__plugin_meta__")
|
||||
self.command_manager.plugins[full_module_name] = meta
|
||||
|
||||
self.logger.success(f"插件 {full_module_name} 已成功重载。")
|
||||
except SyncHandlerError as e:
|
||||
error = PluginReloadError(
|
||||
plugin_name=full_module_name,
|
||||
message=f"同步处理器错误: {str(e)}",
|
||||
original_error=e
|
||||
)
|
||||
self.logger.error(f"重载插件 {full_module_name} 失败: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
except Exception as e:
|
||||
error = PluginReloadError(
|
||||
plugin_name=full_module_name,
|
||||
message=f"未知错误: {str(e)}",
|
||||
original_error=e
|
||||
)
|
||||
self.logger.exception(f"重载插件 {full_module_name} 时发生错误: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
68
core/managers/redis_manager.py
Normal file
68
core/managers/redis_manager.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import redis.asyncio as redis
|
||||
from ..config_loader import global_config as config
|
||||
from ..utils.logger import logger
|
||||
|
||||
class RedisManager:
|
||||
"""
|
||||
Redis 连接管理器(异步单例)
|
||||
"""
|
||||
_instance = None
|
||||
_redis = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
异步初始化 Redis 连接并进行健康检查
|
||||
"""
|
||||
if self._redis is None:
|
||||
try:
|
||||
redis_config = config.redis
|
||||
host = redis_config.host
|
||||
port = redis_config.port
|
||||
db = redis_config.db
|
||||
password = redis_config.password
|
||||
|
||||
logger.info(f"正在尝试连接 Redis: {host}:{port}, DB: {db}")
|
||||
|
||||
self._redis = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True
|
||||
)
|
||||
if await self._redis.ping():
|
||||
logger.success("Redis 连接成功!")
|
||||
else:
|
||||
logger.error("Redis 连接失败: PING 命令无响应")
|
||||
except Exception as e:
|
||||
logger.exception(f"Redis 初始化时发生未知错误: {e}")
|
||||
self._redis = None
|
||||
|
||||
@property
|
||||
def redis(self):
|
||||
"""
|
||||
获取 Redis 连接实例
|
||||
"""
|
||||
if self._redis is None:
|
||||
raise ConnectionError("Redis 未初始化或连接失败,请先调用 initialize()")
|
||||
return self._redis
|
||||
|
||||
async def get(self, name):
|
||||
"""
|
||||
获取指定键的值
|
||||
"""
|
||||
return await self.redis.get(name)
|
||||
|
||||
async def set(self, name, value, ex=None):
|
||||
"""
|
||||
设置指定键的值
|
||||
"""
|
||||
return await self.redis.set(name, value, ex=ex)
|
||||
|
||||
# 全局 Redis 管理器实例
|
||||
redis_manager = RedisManager()
|
||||
42
core/permission.py
Normal file
42
core/permission.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
|
||||
|
||||
@total_ordering
|
||||
class Permission(Enum):
|
||||
"""
|
||||
定义用户权限等级的枚举类。
|
||||
|
||||
使用 @total_ordering 装饰器,只需定义 __lt__ 和 __eq__,
|
||||
即可自动实现所有比较运算符。
|
||||
"""
|
||||
USER = "user"
|
||||
OP = "op"
|
||||
ADMIN = "admin"
|
||||
|
||||
@property
|
||||
def _level_map(self):
|
||||
"""
|
||||
内部属性,用于映射枚举成员到整数等级。
|
||||
"""
|
||||
return {
|
||||
Permission.USER: 1,
|
||||
Permission.OP: 2,
|
||||
Permission.ADMIN: 3
|
||||
}
|
||||
|
||||
def __lt__(self, other):
|
||||
"""
|
||||
比较当前权限是否小于另一个权限。
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self._level_map[self] < self._level_map[other]
|
||||
|
||||
def __ge__(self, other):
|
||||
"""
|
||||
比较当前权限是否大于等于另一个权限。
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self._level_map[self] >= self._level_map[other]
|
||||
@@ -1,123 +0,0 @@
|
||||
"""
|
||||
插件管理器模块
|
||||
|
||||
负责扫描、加载和管理 `base_plugins` 目录下的所有插件。
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
|
||||
from core.command_manager import matcher
|
||||
|
||||
|
||||
def load_all_plugins():
|
||||
"""
|
||||
扫描并加载 `plugins` 目录下的所有插件。
|
||||
|
||||
该函数会遍历 `plugins` 目录下的所有模块:
|
||||
1. 如果模块已加载,则执行 reload 操作(用于热重载)。
|
||||
2. 如果模块未加载,则执行 import 操作。
|
||||
|
||||
加载过程中会提取插件元数据 `__plugin_meta__` 并注册到 CommandManager。
|
||||
"""
|
||||
plugin_dir = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "..", "plugins"
|
||||
)
|
||||
package_name = "plugins"
|
||||
|
||||
print(f" 正在从 {package_name} 加载插件...")
|
||||
|
||||
for loader, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
|
||||
full_module_name = f"{package_name}.{module_name}"
|
||||
|
||||
try:
|
||||
if full_module_name in sys.modules:
|
||||
module = importlib.reload(sys.modules[full_module_name])
|
||||
action = "重载"
|
||||
else:
|
||||
module = importlib.import_module(full_module_name)
|
||||
action = "加载"
|
||||
|
||||
# 提取插件元数据
|
||||
if hasattr(module, "__plugin_meta__"):
|
||||
meta = getattr(module, "__plugin_meta__")
|
||||
matcher.plugins[full_module_name] = meta
|
||||
|
||||
type_str = "包" if is_pkg else "文件"
|
||||
print(f" [{type_str}] 成功{action}: {module_name}")
|
||||
except Exception as e:
|
||||
print(
|
||||
f" {action if 'action' in locals() else '加载'}插件 {module_name} 失败: {e}"
|
||||
)
|
||||
|
||||
|
||||
class PluginDataManager:
|
||||
"""
|
||||
用于管理插件产生的数据文件的类
|
||||
"""
|
||||
|
||||
def __init__(self, plugin_name: str):
|
||||
"""
|
||||
初始化插件数据管理器
|
||||
|
||||
:param plugin_name: 插件名称
|
||||
"""
|
||||
self.plugin_name = plugin_name
|
||||
self.data_file = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"..",
|
||||
"plugins",
|
||||
"data",
|
||||
self.plugin_name + ".json",
|
||||
)
|
||||
self.data = {}
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
"""读取配置文件"""
|
||||
if not os.path.exists(self.data_file):
|
||||
with open(self.data_file, "w", encoding="utf-8") as f:
|
||||
self.set(self.plugin_name, [])
|
||||
try:
|
||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
||||
self.data = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
self.data = {}
|
||||
|
||||
def save(self):
|
||||
"""保存配置到文件"""
|
||||
with open(self.data_file, "w", encoding="utf-8") as f:
|
||||
json.dump(self.data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""获取配置项"""
|
||||
return self.data.get(key, default)
|
||||
|
||||
def set(self, key, value):
|
||||
"""设置配置项"""
|
||||
self.data[key] = value
|
||||
self.save()
|
||||
|
||||
def add(self, key, value):
|
||||
"""添加配置项"""
|
||||
if key not in self.data:
|
||||
self.data[key] = []
|
||||
self.data[key].append(value)
|
||||
self.save()
|
||||
|
||||
def remove(self, key):
|
||||
"""删除配置项"""
|
||||
if key in self.data:
|
||||
del self.data[key]
|
||||
self.save()
|
||||
|
||||
def clear(self):
|
||||
"""清空所有配置"""
|
||||
self.data.clear()
|
||||
self.save()
|
||||
|
||||
def get_all(self):
|
||||
return self.data.copy()
|
||||
45
core/utils/__init__.py
Normal file
45
core/utils/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
工具函数包
|
||||
"""
|
||||
|
||||
# 导出核心工具
|
||||
from .logger import logger, ModuleLogger, log_exception
|
||||
from .exceptions import *
|
||||
from .json_utils import *
|
||||
from .singleton import singleton
|
||||
from .executor import run_in_thread_pool, initialize_executor
|
||||
from .performance import (
|
||||
timeit,
|
||||
profile,
|
||||
aprofile,
|
||||
memory_profile,
|
||||
memory_profile_decorator,
|
||||
performance_monitor,
|
||||
PerformanceStats,
|
||||
performance_stats,
|
||||
global_stats
|
||||
)
|
||||
from .error_codes import ErrorCode, get_error_message, create_error_response, exception_to_error_response
|
||||
|
||||
__all__ = [
|
||||
'logger',
|
||||
'ModuleLogger',
|
||||
'log_exception',
|
||||
'timeit',
|
||||
'profile',
|
||||
'aprofile',
|
||||
'memory_profile',
|
||||
'memory_profile_decorator',
|
||||
'performance_monitor',
|
||||
'PerformanceStats',
|
||||
'performance_stats',
|
||||
'global_stats',
|
||||
'run_in_thread_pool',
|
||||
'initialize_executor',
|
||||
'singleton',
|
||||
'ErrorCode',
|
||||
'get_error_message',
|
||||
'create_error_response',
|
||||
'exception_to_error_response'
|
||||
]
|
||||
234
core/utils/error_codes.py
Normal file
234
core/utils/error_codes.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
错误码和统一响应格式模块
|
||||
|
||||
该模块定义了项目中使用的错误码和统一的错误响应格式,确保所有模块返回一致的错误信息。
|
||||
"""
|
||||
|
||||
# 错误码定义
|
||||
class ErrorCode:
|
||||
"""
|
||||
错误码枚举类,包含所有系统错误码的定义。
|
||||
|
||||
错误码规则:
|
||||
- 1xxx: 系统级错误
|
||||
- 2xxx: WebSocket相关错误
|
||||
- 3xxx: 插件相关错误
|
||||
- 4xxx: 配置相关错误
|
||||
- 5xxx: 权限相关错误
|
||||
- 6xxx: 命令相关错误
|
||||
- 7xxx: Redis相关错误
|
||||
- 8xxx: 浏览器管理器相关错误
|
||||
- 9xxx: 代码执行相关错误
|
||||
"""
|
||||
# 系统级错误
|
||||
SUCCESS = 0 # 成功
|
||||
UNKNOWN_ERROR = 1000 # 未知错误
|
||||
INVALID_PARAMETER = 1001 # 参数无效
|
||||
DATABASE_ERROR = 1002 # 数据库错误
|
||||
NETWORK_ERROR = 1003 # 网络错误
|
||||
TIMEOUT_ERROR = 1004 # 超时错误
|
||||
RESOURCE_EXHAUSTED = 1005 # 资源耗尽
|
||||
|
||||
# WebSocket相关错误
|
||||
WS_CONNECTION_FAILED = 2000 # WebSocket连接失败
|
||||
WS_AUTH_FAILED = 2001 # WebSocket认证失败
|
||||
WS_DISCONNECTED = 2002 # WebSocket已断开
|
||||
WS_MESSAGE_ERROR = 2003 # WebSocket消息错误
|
||||
|
||||
# 插件相关错误
|
||||
PLUGIN_LOAD_FAILED = 3000 # 插件加载失败
|
||||
PLUGIN_RELOAD_FAILED = 3001 # 插件重载失败
|
||||
PLUGIN_NOT_FOUND = 3002 # 插件未找到
|
||||
PLUGIN_INVALID = 3003 # 插件无效
|
||||
PLUGIN_DEPENDENCY_ERROR = 3004 # 插件依赖错误
|
||||
|
||||
# 配置相关错误
|
||||
CONFIG_NOT_FOUND = 4000 # 配置文件未找到
|
||||
CONFIG_PARSE_ERROR = 4001 # 配置解析错误
|
||||
CONFIG_VALIDATION_ERROR = 4002 # 配置验证错误
|
||||
CONFIG_KEY_NOT_FOUND = 4003 # 配置项未找到
|
||||
|
||||
# 权限相关错误
|
||||
PERMISSION_DENIED = 5000 # 权限不足
|
||||
NOT_ADMIN = 5001 # 不是管理员
|
||||
USER_BANNED = 5002 # 用户已被禁止
|
||||
|
||||
# 命令相关错误
|
||||
COMMAND_NOT_FOUND = 6000 # 命令未找到
|
||||
COMMAND_PARAM_ERROR = 6001 # 命令参数错误
|
||||
COMMAND_EXECUTE_ERROR = 6002 # 命令执行错误
|
||||
COMMAND_TIMEOUT = 6003 # 命令执行超时
|
||||
|
||||
# Redis相关错误
|
||||
REDIS_CONNECTION_FAILED = 7000 # Redis连接失败
|
||||
REDIS_OPERATION_ERROR = 7001 # Redis操作错误
|
||||
|
||||
# 浏览器管理器相关错误
|
||||
BROWSER_INIT_FAILED = 8000 # 浏览器初始化失败
|
||||
BROWSER_POOL_ERROR = 8001 # 浏览器池错误
|
||||
BROWSER_OPERATION_ERROR = 8002 # 浏览器操作错误
|
||||
|
||||
# 代码执行相关错误
|
||||
CODE_EXECUTE_ERROR = 9000 # 代码执行错误
|
||||
CODE_SECURITY_ERROR = 9001 # 代码安全错误
|
||||
|
||||
|
||||
# 错误码到错误消息的映射
|
||||
ERROR_MESSAGES = {
|
||||
# 系统级错误
|
||||
ErrorCode.SUCCESS: "操作成功",
|
||||
ErrorCode.UNKNOWN_ERROR: "未知错误",
|
||||
ErrorCode.INVALID_PARAMETER: "参数无效",
|
||||
ErrorCode.DATABASE_ERROR: "数据库错误",
|
||||
ErrorCode.NETWORK_ERROR: "网络错误",
|
||||
ErrorCode.TIMEOUT_ERROR: "操作超时",
|
||||
ErrorCode.RESOURCE_EXHAUSTED: "资源耗尽",
|
||||
|
||||
# WebSocket相关错误
|
||||
ErrorCode.WS_CONNECTION_FAILED: "WebSocket连接失败",
|
||||
ErrorCode.WS_AUTH_FAILED: "WebSocket认证失败",
|
||||
ErrorCode.WS_DISCONNECTED: "WebSocket已断开连接",
|
||||
ErrorCode.WS_MESSAGE_ERROR: "WebSocket消息格式错误",
|
||||
|
||||
# 插件相关错误
|
||||
ErrorCode.PLUGIN_LOAD_FAILED: "插件加载失败",
|
||||
ErrorCode.PLUGIN_RELOAD_FAILED: "插件重载失败",
|
||||
ErrorCode.PLUGIN_NOT_FOUND: "插件未找到",
|
||||
ErrorCode.PLUGIN_INVALID: "插件无效",
|
||||
ErrorCode.PLUGIN_DEPENDENCY_ERROR: "插件依赖错误",
|
||||
|
||||
# 配置相关错误
|
||||
ErrorCode.CONFIG_NOT_FOUND: "配置文件未找到",
|
||||
ErrorCode.CONFIG_PARSE_ERROR: "配置文件解析错误",
|
||||
ErrorCode.CONFIG_VALIDATION_ERROR: "配置验证失败",
|
||||
ErrorCode.CONFIG_KEY_NOT_FOUND: "配置项未找到",
|
||||
|
||||
# 权限相关错误
|
||||
ErrorCode.PERMISSION_DENIED: "权限不足",
|
||||
ErrorCode.NOT_ADMIN: "需要管理员权限",
|
||||
ErrorCode.USER_BANNED: "用户已被禁止操作",
|
||||
|
||||
# 命令相关错误
|
||||
ErrorCode.COMMAND_NOT_FOUND: "命令未找到",
|
||||
ErrorCode.COMMAND_PARAM_ERROR: "命令参数错误",
|
||||
ErrorCode.COMMAND_EXECUTE_ERROR: "命令执行错误",
|
||||
ErrorCode.COMMAND_TIMEOUT: "命令执行超时",
|
||||
|
||||
# Redis相关错误
|
||||
ErrorCode.REDIS_CONNECTION_FAILED: "Redis连接失败",
|
||||
ErrorCode.REDIS_OPERATION_ERROR: "Redis操作错误",
|
||||
|
||||
# 浏览器管理器相关错误
|
||||
ErrorCode.BROWSER_INIT_FAILED: "浏览器初始化失败",
|
||||
ErrorCode.BROWSER_POOL_ERROR: "浏览器池错误",
|
||||
ErrorCode.BROWSER_OPERATION_ERROR: "浏览器操作错误",
|
||||
|
||||
# 代码执行相关错误
|
||||
ErrorCode.CODE_EXECUTE_ERROR: "代码执行错误",
|
||||
ErrorCode.CODE_SECURITY_ERROR: "代码存在安全风险",
|
||||
}
|
||||
|
||||
|
||||
def get_error_message(code: int) -> str:
|
||||
"""
|
||||
根据错误码获取错误消息
|
||||
|
||||
Args:
|
||||
code: 错误码
|
||||
|
||||
Returns:
|
||||
str: 错误消息
|
||||
"""
|
||||
return ERROR_MESSAGES.get(code, ERROR_MESSAGES[ErrorCode.UNKNOWN_ERROR])
|
||||
|
||||
|
||||
def create_error_response(code: int, message: str = None, data: dict = None, request_id: str = None) -> dict:
|
||||
"""
|
||||
创建统一格式的错误响应
|
||||
|
||||
Args:
|
||||
code: 错误码
|
||||
message: 错误消息(可选,如果未提供则使用默认消息)
|
||||
data: 附加数据(可选)
|
||||
request_id: 请求ID(可选,用于追踪请求)
|
||||
|
||||
Returns:
|
||||
dict: 统一格式的错误响应
|
||||
"""
|
||||
error_message = message if message is not None else get_error_message(code)
|
||||
|
||||
response = {
|
||||
"code": code,
|
||||
"message": error_message,
|
||||
"success": code == ErrorCode.SUCCESS,
|
||||
}
|
||||
|
||||
if data is not None:
|
||||
response["data"] = data
|
||||
|
||||
if request_id is not None:
|
||||
response["request_id"] = request_id
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def exception_to_error_response(exception: Exception, code: int = None, request_id: str = None) -> dict:
|
||||
"""
|
||||
将异常对象转换为统一格式的错误响应
|
||||
|
||||
Args:
|
||||
exception: 异常对象
|
||||
code: 错误码(可选,如果未提供则根据异常类型自动推断)
|
||||
request_id: 请求ID(可选,用于追踪请求)
|
||||
|
||||
Returns:
|
||||
dict: 统一格式的错误响应
|
||||
"""
|
||||
# 从自定义异常类中提取错误码
|
||||
if hasattr(exception, "code") and exception.code is not None:
|
||||
code = exception.code
|
||||
|
||||
# 如果仍未找到错误码,则根据异常类型推断
|
||||
if code is None:
|
||||
from .exceptions import (
|
||||
WebSocketError, PluginError, ConfigError, PermissionError,
|
||||
CommandError, RedisError, BrowserManagerError, CodeExecutionError
|
||||
)
|
||||
|
||||
if isinstance(exception, WebSocketError):
|
||||
code = ErrorCode.WS_CONNECTION_FAILED
|
||||
elif isinstance(exception, PluginError):
|
||||
code = ErrorCode.PLUGIN_LOAD_FAILED
|
||||
elif isinstance(exception, ConfigError):
|
||||
code = ErrorCode.CONFIG_PARSE_ERROR
|
||||
elif isinstance(exception, PermissionError):
|
||||
code = ErrorCode.PERMISSION_DENIED
|
||||
elif isinstance(exception, CommandError):
|
||||
code = ErrorCode.COMMAND_EXECUTE_ERROR
|
||||
elif isinstance(exception, RedisError):
|
||||
code = ErrorCode.REDIS_OPERATION_ERROR
|
||||
elif isinstance(exception, BrowserManagerError):
|
||||
code = ErrorCode.BROWSER_OPERATION_ERROR
|
||||
elif isinstance(exception, CodeExecutionError):
|
||||
code = ErrorCode.CODE_EXECUTE_ERROR
|
||||
else:
|
||||
code = ErrorCode.UNKNOWN_ERROR
|
||||
|
||||
# 获取错误消息
|
||||
message = str(exception)
|
||||
|
||||
# 如果异常有原始错误,也包含在响应中
|
||||
data = None
|
||||
if hasattr(exception, "original_error") and exception.original_error is not None:
|
||||
data = {"original_error": str(exception.original_error)}
|
||||
|
||||
return create_error_response(code, message, data, request_id)
|
||||
|
||||
|
||||
# 将错误码导出以便其他模块使用
|
||||
__all__ = [
|
||||
"ErrorCode",
|
||||
"get_error_message",
|
||||
"create_error_response",
|
||||
"exception_to_error_response"
|
||||
]
|
||||
221
core/utils/exceptions.py
Normal file
221
core/utils/exceptions.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
自定义异常模块
|
||||
|
||||
该模块定义了项目中使用的各种自定义异常类,用于提供更精确、更友好的错误提示。
|
||||
"""
|
||||
|
||||
class SyncHandlerError(Exception):
|
||||
"""
|
||||
当尝试注册同步函数作为异步事件处理器时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketError(Exception):
|
||||
"""
|
||||
WebSocket相关错误的基类。
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
code: 错误代码(可选)
|
||||
original_error: 原始异常对象(可选)
|
||||
"""
|
||||
def __init__(self, message, code=None, original_error=None):
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class WebSocketConnectionError(WebSocketError):
|
||||
"""
|
||||
WebSocket连接失败时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketAuthenticationError(WebSocketError):
|
||||
"""
|
||||
WebSocket认证失败时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PluginError(Exception):
|
||||
"""
|
||||
插件相关错误的基类。
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
message: 错误消息
|
||||
original_error: 原始异常对象(可选)
|
||||
"""
|
||||
def __init__(self, plugin_name, message, original_error=None):
|
||||
self.plugin_name = plugin_name
|
||||
self.message = message
|
||||
self.original_error = original_error
|
||||
super().__init__(f"插件 {plugin_name}: {message}")
|
||||
|
||||
|
||||
class PluginLoadError(PluginError):
|
||||
"""
|
||||
插件加载失败时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PluginReloadError(PluginError):
|
||||
"""
|
||||
插件重载失败时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PluginNotFoundError(PluginError):
|
||||
"""
|
||||
找不到指定插件时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ConfigError(Exception):
|
||||
"""
|
||||
配置相关错误的基类。
|
||||
|
||||
Args:
|
||||
section: 配置部分名称
|
||||
key: 配置项名称
|
||||
message: 错误消息
|
||||
"""
|
||||
def __init__(self, section=None, key=None, message=None):
|
||||
self.section = section
|
||||
self.key = key
|
||||
self.message = message
|
||||
|
||||
if section and key and message:
|
||||
super().__init__(f"配置错误 [{section}.{key}]: {message}")
|
||||
elif section and message:
|
||||
super().__init__(f"配置错误 [{section}]: {message}")
|
||||
else:
|
||||
super().__init__(message or "配置错误")
|
||||
|
||||
|
||||
class ConfigNotFoundError(ConfigError):
|
||||
"""
|
||||
配置文件不存在时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ConfigValidationError(ConfigError):
|
||||
"""
|
||||
配置验证失败时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PermissionError(Exception):
|
||||
"""
|
||||
权限相关错误的基类。
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
operation: 操作名称
|
||||
message: 错误消息
|
||||
"""
|
||||
def __init__(self, user_id=None, operation=None, message=None):
|
||||
self.user_id = user_id
|
||||
self.operation = operation
|
||||
self.message = message
|
||||
|
||||
if user_id and operation and message:
|
||||
super().__init__(f"权限错误 [用户 {user_id}]: 无权限执行操作 {operation} - {message}")
|
||||
elif user_id and operation:
|
||||
super().__init__(f"权限错误 [用户 {user_id}]: 无权限执行操作 {operation}")
|
||||
else:
|
||||
super().__init__(message or "权限错误")
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
"""
|
||||
命令处理相关错误的基类。
|
||||
|
||||
Args:
|
||||
command: 命令名称
|
||||
message: 错误消息
|
||||
original_error: 原始异常对象(可选)
|
||||
"""
|
||||
def __init__(self, command=None, message=None, original_error=None):
|
||||
self.command = command
|
||||
self.message = message
|
||||
self.original_error = original_error
|
||||
|
||||
if command and message:
|
||||
super().__init__(f"命令错误 [{command}]: {message}")
|
||||
else:
|
||||
super().__init__(message or "命令错误")
|
||||
|
||||
|
||||
class CommandNotFoundError(CommandError):
|
||||
"""
|
||||
找不到指定命令时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class CommandParameterError(CommandError):
|
||||
"""
|
||||
命令参数错误时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class RedisError(Exception):
|
||||
"""
|
||||
Redis相关错误的基类。
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
original_error: 原始异常对象(可选)
|
||||
"""
|
||||
def __init__(self, message, original_error=None):
|
||||
self.message = message
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BrowserManagerError(Exception):
|
||||
"""
|
||||
浏览器管理器相关错误的基类。
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
original_error: 原始异常对象(可选)
|
||||
"""
|
||||
def __init__(self, message, original_error=None):
|
||||
self.message = message
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BrowserPoolError(BrowserManagerError):
|
||||
"""
|
||||
浏览器池相关错误时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class CodeExecutionError(Exception):
|
||||
"""
|
||||
代码执行相关错误的基类。
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
code: 执行的代码(可选)
|
||||
original_error: 原始异常对象(可选)
|
||||
"""
|
||||
def __init__(self, message, code=None, original_error=None):
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
195
core/utils/executor.py
Normal file
195
core/utils/executor.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import docker
|
||||
from docker.tls import TLSConfig
|
||||
from docker.types import LogConfig
|
||||
from typing import Any, Callable
|
||||
|
||||
from core.utils.logger import logger
|
||||
|
||||
class CodeExecutor:
|
||||
"""
|
||||
代码执行引擎,负责管理一个异步任务队列和并发的 Docker 容器执行。
|
||||
"""
|
||||
def __init__(self, config: Any):
|
||||
"""
|
||||
初始化代码执行引擎。
|
||||
:param config: 从 config_loader.py 加载的全局配置对象。
|
||||
"""
|
||||
self.bot: Any = None # Bot 实例将在 WS 连接成功后动态注入
|
||||
self.task_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
# 从传入的配置中读取 Docker 相关设置
|
||||
docker_config = config.docker
|
||||
self.docker_base_url = docker_config.base_url
|
||||
self.sandbox_image = docker_config.sandbox_image
|
||||
self.timeout = docker_config.timeout
|
||||
concurrency = docker_config.concurrency_limit
|
||||
|
||||
self.concurrency_limit = asyncio.Semaphore(concurrency)
|
||||
self.docker_client = None
|
||||
|
||||
logger.info("[CodeExecutor] 初始化 Docker 客户端...")
|
||||
try:
|
||||
if self.docker_base_url:
|
||||
# 如果配置了远程 Docker 地址,则使用 TLS 选项进行连接
|
||||
tls_config = None
|
||||
if docker_config.tls_verify:
|
||||
tls_config = TLSConfig(
|
||||
ca_cert=docker_config.ca_cert_path,
|
||||
client_cert=(docker_config.client_cert_path, docker_config.client_key_path),
|
||||
verify=True
|
||||
)
|
||||
self.docker_client = docker.DockerClient(base_url=self.docker_base_url, tls=tls_config)
|
||||
else:
|
||||
# 否则,使用默认的本地环境连接
|
||||
self.docker_client = docker.from_env()
|
||||
|
||||
# 检查 Docker 服务是否可用
|
||||
self.docker_client.ping()
|
||||
logger.success("[CodeExecutor] Docker 客户端初始化成功,服务连接正常。")
|
||||
except docker.errors.DockerException as e:
|
||||
self.docker_client = None
|
||||
logger.error(f"无法连接到 Docker 服务,请检查 Docker 是否正在运行: {e}")
|
||||
except Exception as e:
|
||||
self.docker_client = None
|
||||
logger.error(f"初始化 Docker 客户端时发生未知错误: {e}")
|
||||
|
||||
async def add_task(self, code: str, callback: Callable[[str], asyncio.Future]):
|
||||
"""
|
||||
将代码执行任务添加到队列中。
|
||||
:param code: 待执行的 Python 代码字符串。
|
||||
:param callback: 执行完毕后用于回复结果的回调函数。
|
||||
:raises RuntimeError: 如果 Docker 客户端未初始化。
|
||||
"""
|
||||
if not self.docker_client:
|
||||
logger.warning("[CodeExecutor] 尝试添加任务,但 Docker 客户端未初始化。任务被拒绝。")
|
||||
# 这里可以选择抛出异常,或者直接调用回调返回错误信息
|
||||
# 为了用户体验,我们构造一个错误结果并直接调用回调(如果可能)
|
||||
# 但由于 callback 返回 Future,这里简单起见,我们记录日志并抛出异常
|
||||
raise RuntimeError("Docker环境未就绪,无法执行代码。")
|
||||
|
||||
task = {"code": code, "callback": callback}
|
||||
await self.task_queue.put(task)
|
||||
logger.info(f"[CodeExecutor] 新的代码执行任务已入队 (队列当前长度: {self.task_queue.qsize()})。")
|
||||
|
||||
async def worker(self):
|
||||
"""
|
||||
后台工作者,不断从队列中取出任务并执行。
|
||||
"""
|
||||
if not self.docker_client:
|
||||
logger.error("[CodeExecutor] Worker 无法启动,因为 Docker 客户端未初始化。")
|
||||
return
|
||||
|
||||
logger.info("[CodeExecutor] 代码执行 Worker 已启动,等待任务...")
|
||||
while True:
|
||||
task = await self.task_queue.get()
|
||||
|
||||
logger.info("[CodeExecutor] 开始处理代码执行任务。")
|
||||
|
||||
async with self.concurrency_limit:
|
||||
result_message = ""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# 使用 asyncio.wait_for 实现超时控制
|
||||
result_bytes = await asyncio.wait_for(
|
||||
loop.run_in_executor(
|
||||
None, # 使用默认线程池
|
||||
self._run_in_container,
|
||||
task['code']
|
||||
),
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
output = result_bytes.decode('utf-8').strip()
|
||||
result_message = output if output else "代码执行完毕,无输出。"
|
||||
logger.success("[CodeExecutor] 任务成功执行。")
|
||||
|
||||
except docker.errors.ImageNotFound:
|
||||
logger.error(f"[CodeExecutor] 镜像 '{self.sandbox_image}' 不存在!")
|
||||
result_message = f"执行失败:沙箱基础镜像 '{self.sandbox_image}' 不存在,请联系管理员构建。"
|
||||
except docker.errors.ContainerError as e:
|
||||
error_output = e.stderr.decode('utf-8').strip()
|
||||
result_message = f"代码执行出错:\n{error_output}"
|
||||
logger.warning(f"[CodeExecutor] 代码执行时发生错误: {error_output}")
|
||||
except docker.errors.APIError as e:
|
||||
logger.error(f"[CodeExecutor] Docker API 错误: {e}")
|
||||
result_message = "执行失败:与 Docker 服务通信时发生错误,请检查服务状态。"
|
||||
except asyncio.TimeoutError:
|
||||
result_message = f"执行超时 (超过 {self.timeout} 秒)。"
|
||||
logger.warning("[CodeExecutor] 任务执行超时。")
|
||||
except Exception as e:
|
||||
logger.exception(f"[CodeExecutor] 执行 Docker 任务时发生未知严重错误: {e}")
|
||||
result_message = "执行引擎发生内部错误,请联系管理员。"
|
||||
|
||||
# 调用回调函数回复结果
|
||||
await task['callback'](result_message)
|
||||
|
||||
self.task_queue.task_done()
|
||||
|
||||
def _run_in_container(self, code: str) -> bytes:
|
||||
"""
|
||||
同步函数:在 Docker 容器中运行代码。
|
||||
此函数通过手动管理容器生命周期来提高稳定性。
|
||||
"""
|
||||
if self.docker_client is None:
|
||||
raise docker.errors.DockerException("Docker client is not initialized.")
|
||||
|
||||
container = None
|
||||
try:
|
||||
# 1. 创建容器
|
||||
container = self.docker_client.containers.create(
|
||||
image=self.sandbox_image,
|
||||
command=["python", "-c", code],
|
||||
mem_limit='128m',
|
||||
cpu_shares=512,
|
||||
network_disabled=True,
|
||||
log_config=LogConfig(type='json-file', config={'max-size': '1m'}),
|
||||
)
|
||||
# 2. 启动容器
|
||||
container.start()
|
||||
|
||||
# 3. 等待容器执行完成
|
||||
# 主超时由 asyncio.wait_for 控制,这里的 timeout 是一个额外的保险
|
||||
result = container.wait(timeout=self.timeout + 5)
|
||||
|
||||
# 4. 获取日志
|
||||
stdout = container.logs(stdout=True, stderr=False)
|
||||
stderr = container.logs(stdout=False, stderr=True)
|
||||
|
||||
# 5. 检查退出码,如果不为 0,则手动抛出 ContainerError
|
||||
if result.get('StatusCode', 0) != 0:
|
||||
raise docker.errors.ContainerError(
|
||||
container, result['StatusCode'], f"python -c '{code}'", self.sandbox_image, stderr.decode('utf-8')
|
||||
)
|
||||
|
||||
return stdout
|
||||
|
||||
finally:
|
||||
# 6. 确保容器总是被移除
|
||||
if container:
|
||||
try:
|
||||
container.remove(force=True)
|
||||
except docker.errors.NotFound:
|
||||
# 如果容器因为某些原因已经消失,也沒关系
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"[CodeExecutor] 强制移除容器 {container.id} 时失败: {e}")
|
||||
|
||||
def initialize_executor(config: Any):
|
||||
"""
|
||||
初始化并返回一个 CodeExecutor 实例。
|
||||
"""
|
||||
return CodeExecutor(config)
|
||||
|
||||
async def run_in_thread_pool(sync_func, *args, **kwargs):
|
||||
"""
|
||||
在线程池中运行同步阻塞函数,以避免阻塞 asyncio 事件循环。
|
||||
:param sync_func: 同步函数
|
||||
:param args: 位置参数
|
||||
:param kwargs: 关键字参数
|
||||
:return: 同步函数的返回值
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, lambda: sync_func(*args, **kwargs))
|
||||
34
core/utils/json_utils.py
Normal file
34
core/utils/json_utils.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
JSON 工具模块
|
||||
|
||||
统一使用高性能的 orjson 库进行 JSON 序列化和反序列化。
|
||||
如果 orjson 不可用,则回退到标准库 json。
|
||||
"""
|
||||
from typing import Any, Union
|
||||
import json
|
||||
|
||||
# 在模块加载时检查 orjson 是否可用
|
||||
try:
|
||||
import orjson
|
||||
_orjson_available = True
|
||||
except ImportError:
|
||||
_orjson_available = False
|
||||
|
||||
def dumps(obj: Any) -> str:
|
||||
"""
|
||||
将对象序列化为 JSON 字符串。
|
||||
"""
|
||||
if _orjson_available:
|
||||
# orjson.dumps 返回 bytes,需要 decode
|
||||
return orjson.dumps(obj).decode("utf-8")
|
||||
else:
|
||||
return json.dumps(obj, ensure_ascii=False)
|
||||
|
||||
def loads(json_str: Union[str, bytes]) -> Any:
|
||||
"""
|
||||
将 JSON 字符串反序列化为对象。
|
||||
"""
|
||||
if _orjson_available:
|
||||
return orjson.loads(json_str)
|
||||
else:
|
||||
return json.loads(json_str)
|
||||
137
core/utils/logger.py
Normal file
137
core/utils/logger.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
日志模块
|
||||
|
||||
该模块负责初始化和配置 loguru 日志记录器,为整个应用程序提供统一的日志记录接口。
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
# 定义日志格式,添加进程ID和线程ID作为上下文信息
|
||||
LOG_FORMAT = (
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
"<magenta>PID {process} TID {thread}</magenta> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
|
||||
"<level>{message}</level>"
|
||||
)
|
||||
|
||||
# 开发环境日志格式(更详细)
|
||||
DEBUG_LOG_FORMAT = (
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
"<magenta>PID {process} TID {thread}</magenta> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
|
||||
"<yellow>Module: {module}</yellow> | "
|
||||
"<level>{message}</level>"
|
||||
)
|
||||
|
||||
# 移除 loguru 默认的处理器
|
||||
logger.remove()
|
||||
|
||||
# 获取当前环境
|
||||
ENVIRONMENT = os.getenv("NEOBOT_ENV", "development")
|
||||
|
||||
# 添加控制台输出处理器
|
||||
logger.add(
|
||||
sys.stderr,
|
||||
level="INFO" if ENVIRONMENT == "production" else "DEBUG",
|
||||
format=LOG_FORMAT if ENVIRONMENT == "production" else DEBUG_LOG_FORMAT,
|
||||
colorize=True,
|
||||
enqueue=True # 异步写入
|
||||
)
|
||||
|
||||
# 定义日志文件路径
|
||||
log_dir = Path("logs")
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
log_file_path = log_dir / "{time:YYYY-MM-DD}.log"
|
||||
|
||||
# 添加文件输出处理器
|
||||
logger.add(
|
||||
log_file_path,
|
||||
level="DEBUG",
|
||||
format=DEBUG_LOG_FORMAT,
|
||||
colorize=False,
|
||||
rotation="00:00", # 每天午夜创建新文件
|
||||
retention="7 days", # 保留最近 7 天的日志
|
||||
encoding="utf-8",
|
||||
enqueue=True, # 异步写入
|
||||
backtrace=True, # 记录完整的异常堆栈
|
||||
diagnose=True # 添加异常诊断信息
|
||||
)
|
||||
|
||||
# 为自定义异常添加专门的日志记录方法
|
||||
def log_exception(exc, module_name="unknown", level="error"):
|
||||
"""
|
||||
记录自定义异常的详细信息
|
||||
|
||||
Args:
|
||||
exc: 异常对象
|
||||
module_name: 模块名称(可选)
|
||||
level: 日志级别(可选,默认为 "error")
|
||||
"""
|
||||
log_func = getattr(logger, level)
|
||||
log_func(f"模块 {module_name} 发生异常: {exc}")
|
||||
|
||||
# 如果异常对象有原始异常,也记录原始异常信息
|
||||
if hasattr(exc, "original_error") and exc.original_error:
|
||||
log_func(f"原始异常: {exc.original_error}")
|
||||
|
||||
# 如果是配置错误,记录配置相关信息
|
||||
if hasattr(exc, "section") and hasattr(exc, "key"):
|
||||
log_func(f"配置信息: 部分={exc.section}, 键={exc.key}")
|
||||
|
||||
# 如果是插件错误,记录插件名称
|
||||
if hasattr(exc, "plugin_name"):
|
||||
log_func(f"插件名称: {exc.plugin_name}")
|
||||
|
||||
# 如果是命令错误,记录命令名称
|
||||
if hasattr(exc, "command"):
|
||||
log_func(f"命令名称: {exc.command}")
|
||||
|
||||
# 如果是权限错误,记录用户ID和操作
|
||||
if hasattr(exc, "user_id") and hasattr(exc, "operation"):
|
||||
log_func(f"权限信息: 用户ID={exc.user_id}, 操作={exc.operation}")
|
||||
|
||||
# 为不同模块提供日志工具
|
||||
class ModuleLogger:
|
||||
"""
|
||||
模块专用日志记录器
|
||||
|
||||
Args:
|
||||
module_name: 模块名称
|
||||
"""
|
||||
def __init__(self, module_name):
|
||||
self.module_name = module_name
|
||||
|
||||
def debug(self, message):
|
||||
logger.debug(f"[{self.module_name}] {message}")
|
||||
|
||||
def info(self, message):
|
||||
logger.info(f"[{self.module_name}] {message}")
|
||||
|
||||
def success(self, message):
|
||||
logger.success(f"[{self.module_name}] {message}")
|
||||
|
||||
def warning(self, message):
|
||||
logger.warning(f"[{self.module_name}] {message}")
|
||||
|
||||
def error(self, message):
|
||||
logger.error(f"[{self.module_name}] {message}")
|
||||
|
||||
def exception(self, message, exc_info=True):
|
||||
logger.exception(f"[{self.module_name}] {message}", exc_info=exc_info)
|
||||
|
||||
def log_custom_exception(self, exc, level="error"):
|
||||
"""
|
||||
记录自定义异常
|
||||
|
||||
Args:
|
||||
exc: 异常对象
|
||||
level: 日志级别
|
||||
"""
|
||||
log_exception(exc, self.module_name, level)
|
||||
|
||||
# 导出配置好的 logger 和工具函数
|
||||
__all__ = ["logger", "log_exception", "ModuleLogger"]
|
||||
364
core/utils/performance.py
Normal file
364
core/utils/performance.py
Normal file
@@ -0,0 +1,364 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
性能分析工具模块
|
||||
|
||||
提供同步和异步函数的性能分析装饰器、上下文管理器和统计工具。
|
||||
|
||||
主要功能:
|
||||
1. 函数执行时间分析(支持同步和异步)
|
||||
2. 内存使用分析
|
||||
3. 性能统计和报告生成
|
||||
4. 低开销的生产环境监控
|
||||
"""
|
||||
|
||||
import time
|
||||
import functools
|
||||
import logging
|
||||
from typing import Dict, Any, Callable, Optional
|
||||
import inspect
|
||||
|
||||
# 尝试导入性能分析库
|
||||
try:
|
||||
from pyinstrument import Profiler
|
||||
from pyinstrument.renderers import HTMLRenderer
|
||||
PYINSTRUMENT_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYINSTRUMENT_AVAILABLE = False
|
||||
|
||||
# 尝试导入内存分析库
|
||||
try:
|
||||
from memory_profiler import memory_usage
|
||||
MEMORY_PROFILER_AVAILABLE = True
|
||||
except ImportError:
|
||||
MEMORY_PROFILER_AVAILABLE = False
|
||||
|
||||
from .logger import logger
|
||||
|
||||
|
||||
class PerformanceStats:
|
||||
"""
|
||||
性能统计工具类
|
||||
用于收集和报告函数执行的性能指标
|
||||
"""
|
||||
def __init__(self):
|
||||
self.stats: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def record(self, func_name: str, duration: float, memory_used: Optional[float] = None):
|
||||
"""
|
||||
记录函数执行的性能数据
|
||||
|
||||
Args:
|
||||
func_name: 函数名称
|
||||
duration: 执行时间(秒)
|
||||
memory_used: 使用的内存(MB),可选
|
||||
"""
|
||||
if func_name not in self.stats:
|
||||
self.stats[func_name] = {
|
||||
"count": 0,
|
||||
"total_time": 0.0,
|
||||
"avg_time": 0.0,
|
||||
"min_time": float('inf'),
|
||||
"max_time": 0.0,
|
||||
"total_memory": 0.0,
|
||||
"avg_memory": 0.0
|
||||
}
|
||||
|
||||
stat = self.stats[func_name]
|
||||
stat["count"] += 1
|
||||
stat["total_time"] += duration
|
||||
stat["avg_time"] = stat["total_time"] / stat["count"]
|
||||
stat["min_time"] = min(stat["min_time"], duration)
|
||||
stat["max_time"] = max(stat["max_time"], duration)
|
||||
|
||||
if memory_used is not None:
|
||||
stat["total_memory"] += memory_used
|
||||
stat["avg_memory"] = stat["total_memory"] / stat["count"]
|
||||
|
||||
def report(self) -> str:
|
||||
"""
|
||||
生成性能统计报告
|
||||
|
||||
Returns:
|
||||
格式化的性能统计报告字符串
|
||||
"""
|
||||
if not self.stats:
|
||||
return "暂无性能统计数据"
|
||||
|
||||
report = ["\n=== 性能统计报告 ===\n"]
|
||||
report.append(f"{'函数名':<40} {'调用次数':<10} {'平均时间(ms)':<15} {'最长时间(ms)':<15} {'内存(MB)':<10}")
|
||||
report.append("-" * 100)
|
||||
|
||||
for func_name, stat in sorted(self.stats.items(), key=lambda x: x[1]["total_time"], reverse=True):
|
||||
memory_str = f"{stat['avg_memory']:.2f}" if stat['avg_memory'] > 0 else "-"
|
||||
report.append(
|
||||
f"{func_name:<40} {stat['count']:<10} {stat['avg_time']*1000:<15.2f} "
|
||||
f"{stat['max_time']*1000:<15.2f} {memory_str:<10}"
|
||||
)
|
||||
|
||||
report.append("=" * 100)
|
||||
return "\n".join(report)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
重置性能统计数据
|
||||
"""
|
||||
self.stats.clear()
|
||||
|
||||
|
||||
# 创建全局性能统计实例
|
||||
performance_stats = PerformanceStats()
|
||||
|
||||
|
||||
def timeit(func: Callable = None, *, log_level: int = logging.INFO, collect_stats: bool = True):
|
||||
"""
|
||||
函数执行时间分析装饰器(支持同步和异步)
|
||||
|
||||
Args:
|
||||
func: 要装饰的函数
|
||||
log_level: 日志级别
|
||||
collect_stats: 是否收集到全局统计中
|
||||
|
||||
Returns:
|
||||
装饰后的函数
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
func_name = func.__qualname__
|
||||
is_coroutine = inspect.iscoroutinefunction(func)
|
||||
|
||||
if is_coroutine:
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
if collect_stats:
|
||||
performance_stats.record(func_name, duration)
|
||||
|
||||
logger.log(log_level, f"[性能] {func_name} 执行时间: {duration*1000:.2f} ms")
|
||||
|
||||
return result
|
||||
|
||||
return async_wrapper
|
||||
else:
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
if collect_stats:
|
||||
performance_stats.record(func_name, duration)
|
||||
|
||||
logger.log(log_level, f"[性能] {func_name} 执行时间: {duration*1000:.2f} ms")
|
||||
|
||||
return result
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
if func is None:
|
||||
return decorator
|
||||
return decorator(func)
|
||||
|
||||
|
||||
class profile:
|
||||
"""
|
||||
性能分析上下文管理器
|
||||
使用 pyinstrument 进行详细的性能分析
|
||||
"""
|
||||
def __init__(self, enabled: bool = True, output_file: Optional[str] = None):
|
||||
"""
|
||||
Args:
|
||||
enabled: 是否启用分析
|
||||
output_file: 分析结果输出文件路径(HTML格式)
|
||||
"""
|
||||
self.enabled = enabled
|
||||
self.output_file = output_file
|
||||
self.profiler = None
|
||||
|
||||
def __enter__(self):
|
||||
if self.enabled and PYINSTRUMENT_AVAILABLE:
|
||||
self.profiler = Profiler()
|
||||
self.profiler.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.enabled and PYINSTRUMENT_AVAILABLE and self.profiler:
|
||||
self.profiler.stop()
|
||||
|
||||
# 输出到日志
|
||||
logger.info(f"[性能分析] {self.profiler.print()}")
|
||||
|
||||
# 如果指定了输出文件,保存为HTML
|
||||
if self.output_file:
|
||||
try:
|
||||
html = self.profiler.render(HTMLRenderer())
|
||||
with open(self.output_file, 'w', encoding='utf-8') as f:
|
||||
f.write(html)
|
||||
logger.info(f"[性能分析] 报告已保存到: {self.output_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"[性能分析] 保存报告失败: {e}")
|
||||
|
||||
|
||||
async def aprofile(func: Callable, *args, **kwargs):
|
||||
"""
|
||||
异步函数性能分析
|
||||
|
||||
Args:
|
||||
func: 要分析的异步函数
|
||||
*args: 函数参数
|
||||
**kwargs: 函数关键字参数
|
||||
|
||||
Returns:
|
||||
函数执行结果
|
||||
"""
|
||||
if not PYINSTRUMENT_AVAILABLE:
|
||||
logger.warning("[性能分析] pyinstrument 未安装,无法进行详细分析")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
profiler = Profiler()
|
||||
profiler.start()
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
finally:
|
||||
profiler.stop()
|
||||
logger.info(f"[性能分析] {profiler.print()}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class memory_profile:
|
||||
"""
|
||||
内存分析上下文管理器
|
||||
"""
|
||||
def __init__(self, interval: float = 0.1, enabled: bool = True):
|
||||
"""
|
||||
Args:
|
||||
interval: 内存采样间隔(秒)
|
||||
enabled: 是否启用内存分析
|
||||
"""
|
||||
self.interval = interval
|
||||
self.enabled = enabled
|
||||
self.memory_start = 0.0
|
||||
self.memory_end = 0.0
|
||||
|
||||
def __enter__(self):
|
||||
if self.enabled and MEMORY_PROFILER_AVAILABLE:
|
||||
self.memory_start = memory_usage()[0]
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.enabled and MEMORY_PROFILER_AVAILABLE:
|
||||
self.memory_end = memory_usage()[0]
|
||||
memory_used = self.memory_end - self.memory_start
|
||||
logger.info(f"[内存分析] 使用内存: {memory_used:.2f} MB")
|
||||
|
||||
|
||||
def memory_profile_decorator(func: Callable = None, *, interval: float = 0.1):
|
||||
"""
|
||||
内存分析装饰器(支持同步函数)
|
||||
|
||||
Args:
|
||||
func: 要装饰的函数
|
||||
interval: 内存采样间隔
|
||||
|
||||
Returns:
|
||||
装饰后的函数
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not MEMORY_PROFILER_AVAILABLE:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
mem_usage = memory_usage(
|
||||
(func, args, kwargs),
|
||||
interval=interval,
|
||||
timeout=None,
|
||||
include_children=False
|
||||
)
|
||||
|
||||
max_memory = max(mem_usage)
|
||||
logger.info(f"[内存分析] {func.__qualname__} 最大内存使用: {max_memory:.2f} MB")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
if func is None:
|
||||
return decorator
|
||||
return decorator(func)
|
||||
|
||||
|
||||
def performance_monitor(func: Callable = None, *, threshold: float = 1.0):
|
||||
"""
|
||||
性能监控装饰器
|
||||
仅当函数执行时间超过阈值时记录日志
|
||||
适合生产环境使用
|
||||
|
||||
Args:
|
||||
func: 要装饰的函数
|
||||
threshold: 时间阈值(秒)
|
||||
|
||||
Returns:
|
||||
装饰后的函数
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
func_name = func.__qualname__
|
||||
is_coroutine = inspect.iscoroutinefunction(func)
|
||||
|
||||
if is_coroutine:
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
result = await func(*args, **kwargs)
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
if duration > threshold:
|
||||
logger.warning(f"[性能监控] {func_name} 执行时间过长: {duration*1000:.2f} ms (阈值: {threshold*1000:.2f} ms)")
|
||||
|
||||
return result
|
||||
|
||||
return async_wrapper
|
||||
else:
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
if duration > threshold:
|
||||
logger.warning(f"[性能监控] {func_name} 执行时间过长: {duration*1000:.2f} ms (阈值: {threshold*1000:.2f} ms)")
|
||||
|
||||
return result
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
if func is None:
|
||||
return decorator
|
||||
return decorator(func)
|
||||
|
||||
|
||||
# 全局实例
|
||||
global_stats = PerformanceStats()
|
||||
|
||||
|
||||
__all__ = [
|
||||
'timeit',
|
||||
'profile',
|
||||
'aprofile',
|
||||
'memory_profile',
|
||||
'memory_profile_decorator',
|
||||
'performance_monitor',
|
||||
'PerformanceStats',
|
||||
'performance_stats',
|
||||
'global_stats'
|
||||
]
|
||||
78
core/utils/singleton.py
Normal file
78
core/utils/singleton.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
通用单例模式基类
|
||||
"""
|
||||
from typing import Any, Dict, Optional, Type, TypeVar
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# 存储每个类的实例
|
||||
_instance_store: Dict[Type, Any] = {}
|
||||
|
||||
class Singleton:
|
||||
"""
|
||||
一个通用的单例基类
|
||||
|
||||
任何继承自该类的子类都将自动成为单例。
|
||||
它通过重写 __new__ 方法来确保每个类只有一个实例。
|
||||
同时,它处理了重复初始化的问题,确保 __init__ 方法只在第一次实例化时被调用。
|
||||
"""
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls: Type[T], *args: Any, **kwargs: Any) -> T:
|
||||
"""
|
||||
创建或返回现有的实例
|
||||
|
||||
Args:
|
||||
*args: 传递给构造函数的位置参数
|
||||
**kwargs: 传递给构造函数的关键字参数
|
||||
|
||||
Returns:
|
||||
T: 单例实例
|
||||
"""
|
||||
# 使用全局字典存储实例,避免类型检查问题
|
||||
if cls not in _instance_store:
|
||||
_instance_store[cls] = super().__new__(cls)
|
||||
return _instance_store[cls]
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
确保初始化逻辑只执行一次
|
||||
"""
|
||||
if self._initialized:
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
|
||||
def singleton(cls: Type[T]) -> Type[T]:
|
||||
"""
|
||||
单例装饰器
|
||||
|
||||
将普通类转换为单例类,确保整个应用程序中只有一个实例。
|
||||
|
||||
Args:
|
||||
cls: 要转换为单例的类
|
||||
|
||||
Returns:
|
||||
Type[T]: 单例类
|
||||
"""
|
||||
# 为每个装饰的类创建一个实例存储
|
||||
class_instance: Optional[T] = None
|
||||
|
||||
# 创建一个新的类,继承自原始类
|
||||
class SingletonClass(cls):
|
||||
"""单例包装类"""
|
||||
|
||||
def __new__(cls: Type[T], *args: Any, **kwargs: Any) -> T:
|
||||
"""创建或返回现有的实例"""
|
||||
nonlocal class_instance
|
||||
if class_instance is None:
|
||||
# 使用super()调用原始类的__new__方法
|
||||
class_instance = cls(*args, **kwargs)
|
||||
return class_instance
|
||||
|
||||
# 复制类的元数据
|
||||
SingletonClass.__name__ = cls.__name__
|
||||
SingletonClass.__doc__ = cls.__doc__
|
||||
SingletonClass.__module__ = cls.__module__
|
||||
|
||||
return SingletonClass
|
||||
Reference in New Issue
Block a user