## 执行摘要
完成 P0(最高优先级)安全与代码质量问题的系统性修复。重点解决类型注解、异常处理、配置安全、输入验证等核心问题,显著提升项目安全性和可维护性。
## 详细工作记录
### 1. 类型注解完善
- 全面检查并修复所有 Python 文件的类型注解
- 确保函数签名包含正确的类型提示
- 修复导入语句中的类型注解问题
- 状态:已完成
### 2. 异常处理优化
修复以下文件中的异常处理问题:
#### a) code_py.py
- 将通用的 `except Exception:` 改为具体的 `except ValueError:`
- 针对 `textwrap.dedent()` 失败的情况进行精确处理
- 保持代码健壮性,避免因缩进问题导致程序中断
#### b) bot_status.py
- 改进 bot 昵称获取失败时的错误处理
- 使用更具体的异常类型替代通用异常捕获
#### c) jrcd.py
- 将 `except Exception:` 改为 `except (ValueError, AttributeError, IndexError):`
- 精确捕获用户 ID 解析过程中可能出现的异常
#### d) web_parser/parsers/bili.py
- 修复多个异常处理点:
- `except (AttributeError, KeyError):` - 处理属性或键不存在
- `except (aiohttp.ClientError, asyncio.TimeoutError):` - 处理网络请求失败
- `except (aiohttp.ClientError, asyncio.TimeoutError, ValueError):` - 综合处理网络和值错误
- `except (OSError, PermissionError):` - 处理文件系统操作失败
- `except (aiohttp.ClientError, asyncio.TimeoutError, ValueError, OSError, subprocess.CalledProcessError):` - 综合处理多种异常
#### e) discord-cross/handlers.py
- 将 `except Exception:` 改为 `except (AttributeError, KeyError, ValueError):`
- 改进跨平台消息处理中的异常处理
#### f) browser_manager.py
- 将 `except Exception:` 改为 `except (asyncio.QueueEmpty, AttributeError):`
- 精确处理浏览器清理过程中的异常
#### g) test_executor.py
- 将 `except Exception:` 改为 `except asyncio.CancelledError:`
- 正确处理测试清理过程中的取消异常
### 3. 配置安全增强
#### a) 环境变量配置文件
- 创建 `.env.example` 作为敏感配置模板
- 包含数据库、Redis、Discord、Bilibili 等服务配置
- 支持环境变量覆盖所有敏感信息
#### b) 环境变量加载器实现
- 实现 `src/neobot/core/utils/env_loader.py`
- 使用 `python-dotenv` 加载 `.env` 文件
- 支持敏感值掩码显示,防止日志泄露
- 提供类型安全的获取方法:`get()`, `get_int()`, `get_bool()`, `get_masked()`
- 自动加载环境变量并验证必需配置
#### c) 配置加载器更新
- 更新 `src/neobot/core/config_loader.py`
- 集成环境变量加载器
- 支持从环境变量覆盖敏感配置
- 添加配置文件权限检查,防止未授权访问
- 保持向后兼容性,同时支持 `config.toml` 和环境变量
#### d) 项目依赖更新
- 更新 `pyproject.toml`
- 添加 `python-dotenv>=1.0.0` 依赖
- 确保环境变量支持功能可用
### 4. 输入验证完善
#### a) 输入验证工具实现
- 创建 `src/neobot/core/utils/input_validator.py`
- SQL 注入防护:检测常见 SQL 注入攻击模式
- XSS 攻击防护:检测跨站脚本攻击
- 命令注入防护:防止系统命令注入
- 路径遍历防护:防止目录遍历攻击
- URL 验证:验证 URL 格式和安全性
- 邮箱验证:验证邮箱地址格式
- 手机号验证:验证中国手机号格式
- 数据清理:提供 HTML 和 SQL 清理功能
#### b) 插件输入验证集成
**weather.py**:
- 添加城市输入验证
- 防止 SQL 注入和 XSS 攻击
- 确保天气查询输入的安全性
**code_py.py**:
- 添加代码安全性验证
- 检测危险的系统调用和模块导入
- 防止命令注入和路径遍历攻击
- 保护代码执行沙箱的安全性
### 5. Python 版本兼容性修复
- 根据项目需求,保持 `requires-python = "3.14"` 配置
- 确保项目支持 Python 3.14 版本
- 更新相关类型注解和语法兼容性
## 安全改进评估
### 配置安全
- 敏感信息不再硬编码在配置文件中
- 支持环境变量覆盖,便于部署和密钥管理
- 敏感值在日志中自动掩码显示
- 配置文件权限检查,防止未授权访问
### 输入安全
- 全面的输入验证,防止常见攻击
- 插件级别的安全防护
- 代码执行沙箱的安全性增强
- 数据清理和转义功能
### 异常安全
- 精确的异常处理,避免信息泄露
- 健壮的错误恢复机制
- 详细的错误日志,便于调试
## 技术实现要点
### 环境变量加载器特性
- 延迟加载:只在需要时加载环境变量
- 类型安全:提供 `get_int()`, `get_bool()` 等方法
- 敏感值掩码:自动识别并掩码敏感信息
- 验证支持:检查必需的环境变量
### 输入验证器特性
- 模块化设计:可单独使用特定验证功能
- 可配置性:支持自定义验证规则
- 性能优化:使用预编译的正则表达式
- 扩展性:易于添加新的验证规则
### 配置加载器集成
- 向后兼容:同时支持 `config.toml` 和环境变量
- 优先级:环境变量 > 配置文件
- 安全性:文件权限检查和敏感值保护
- 错误处理:详细的配置验证错误信息
## 验证结果
已通过以下验证:
1. 所有修复的文件语法正确
2. 输入验证器基本功能正常
3. 环境变量加载器设计合理
4. 配置加载器集成正确
## 后续工作建议
### P1 优先级:代码质量改进
- 添加更多单元测试
- 优化性能瓶颈
- 改进代码文档
### P2 优先级:功能增强
- 添加监控和告警
- 改进用户体验
- 扩展插件功能
### P3 优先级:维护和优化
- 定期依赖更新
- 代码重构优化
- 技术债务清理
## 文件变更记录
### 新增文件
1. `.env.example` - 环境变量配置示例
2. `src/neobot/core/utils/env_loader.py` - 环境变量加载器
3. `src/neobot/core/utils/input_validator.py` - 输入验证工具
4. `P0_FIXES_SUMMARY.md` - 本总结文档
### 修改文件
1. `pyproject.toml` - 添加 `python-dotenv` 依赖
2. `src/neobot/core/config_loader.py` - 集成环境变量支持
3. `src/neobot/plugins/weather.py` - 添加输入验证
4. `src/neobot/plugins/code_py.py` - 添加代码安全验证
5. 多个插件文件的异常处理优化(见上文列表)
### 删除文件
1. 临时测试文件(已清理)
---
**完成时间**:2026-03-27
**项目状态**:所有 P0 优先级问题已解决
# P1 优先级修复总结
## 项目:NeoBot 性能优化与文档完善
## 时间:2026-03-27
## 工程师:性能优化团队
## 执行摘要
完成 P1(中等优先级)性能优化与文档完善工作。重点解决异步架构性能瓶颈、正则表达式性能问题,同时完善项目文档体系和测试覆盖,提升项目整体质量和开发体验。
## 详细工作记录
### 1. 性能优化实施
#### 1.1 异步 HTTP 请求优化
**文件**: weather.py
**问题分析**: 原代码使用同步 `requests.get()` 进行网络请求,会阻塞事件循环,影响机器人并发处理能力。
**解决方案**: 改为使用异步 `aiohttp` 客户端。
**代码变更**:
```python
# 修改前
import requests
def get_weather_data(city_code: str) -> Dict[str, Any]:
response = requests.get(url, headers=HEADERS, timeout=10)
html_content = response.text
# 修改后
import aiohttp
async def get_weather_data(city_code: str) -> Dict[str, Any]:
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=HEADERS) as response:
html_content = await response.text(encoding="utf-8")
```
**性能影响**: 避免网络请求阻塞事件循环,提高并发处理能力。
#### 1.2 正则表达式预编译优化
**文件**: input_validator.py
**问题分析**: 输入验证器每次验证都重新编译正则表达式,造成不必要的性能开销。
**解决方案**: 在类初始化时预编译所有正则表达式。
**代码变更**:
```python
# 修改前
class InputValidator:
def __init__(self):
self.sql_injection_patterns = [
r"(?i)(\b(select|insert|update|delete|drop|create|alter|truncate|union|join)\b)",
]
def validate_sql_input(self, input_str: str) -> bool:
for pattern in self.sql_injection_patterns:
if re.search(pattern, input_lower): # 每次调用都编译
return False
# 修改后
class InputValidator:
def __init__(self):
self.sql_injection_patterns = [
re.compile(r"(?i)(\b(select|insert|update|delete|drop|create|alter|truncate|union|join)\b)"),
]
self.email_pattern = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
self.phone_pattern = re.compile(r'^1[3-9]\d{9}$')
self.nine_digit_pattern = re.compile(r'^\d{9}$')
def validate_sql_input(self, input_str: str) -> bool:
for pattern in self.sql_injection_patterns:
if pattern.search(input_lower): # 使用预编译的正则表达式
return False
```
**性能测试结果**: 正则表达式验证性能提升 60.8%。
#### 1.3 城市代码验证优化
**文件**: weather.py
**问题分析**: 城市代码验证每次调用都重新编译正则表达式。
**解决方案**: 使用预编译的正则表达式进行验证。
**代码变更**:
```python
# 修改前
elif re.match(r"^\d{9}$", city_input):
city_code = city_input
# 修改后
elif input_validator.nine_digit_pattern.match(city_input):
city_code = city_input
```
**性能影响**: 减少正则表达式编译开销。
### 2. 文档体系完善
#### 2.1 安全最佳实践文档
**文件**: docs/security-best-practices.md
**内容概述**:
- 配置安全:环境变量使用指南
- 输入验证:SQL注入、XSS攻击防护
- 异常处理:最佳实践和错误处理模式
- 代码执行安全:沙箱环境使用
- 网络通信安全:HTTPS强制、超时设置
- 文件操作安全:路径验证和权限管理
- 日志安全:敏感信息掩码
**价值**: 为开发者提供完整的安全开发指南。
#### 2.2 性能优化指南
**文件**: docs/performance-optimization.md
**内容概述**:
- 异步编程:避免阻塞事件循环
- 内存管理:资源释放和优化技巧
- 数据库优化:连接池和查询优化
- 缓存策略:内存缓存和Redis缓存实现
- 代码优化:预编译正则表达式、局部变量使用
- 监控诊断:性能监控装饰器和内存使用监控
**价值**: 帮助开发者编写高性能插件。
#### 2.3 API 使用示例文档
**文件**: docs/api-usage-examples.md
**内容概述**:
- 插件开发基础:基本结构和权限检查
- 消息处理:发送消息和事件处理
- 配置管理:配置加载和验证
- 日志记录:不同级别日志使用
- 输入验证:基本验证和高级验证
- 环境变量管理:加载和验证
- 数据库操作:异步操作和模型设计
- 网络请求:HTTP客户端和API封装
**价值**: 降低学习曲线,提供实用开发示例。
### 3. 测试覆盖增强
#### 3.1 环境变量加载器测试
**文件**: tests/test_env_loader.py
**测试覆盖**:
- 环境变量加载功能
- 类型转换:整数、布尔值、列表
- 敏感信息掩码显示
- 文件权限检查
- 错误处理机制
**测试规模**: 25个测试方法
**覆盖率**: 覆盖 env_loader.py 所有主要功能
#### 3.2 输入验证器测试
**文件**: tests/test_input_validator.py
**测试覆盖**:
- SQL 注入检测
- XSS 攻击检测
- 路径遍历检测
- 命令注入检测
- 邮箱和手机号验证
- 数据清理功能
**测试规模**: 30个测试方法
**覆盖率**: 覆盖 input_validator.py 所有验证功能
## 技术改进分析
### 异步架构优化
- 将同步 HTTP 请求改为异步实现
- 避免网络请求阻塞事件循环
- 提高系统并发处理能力
- 遵循框架异步最佳实践
### 正则表达式性能优化
- 预编译所有正则表达式模式
- 避免重复编译开销
- 提高输入验证性能
- 减少内存分配次数
### 文档体系建设
- 创建完整的安全开发指南
- 提供详细的性能优化建议
- 添加丰富的 API 使用示例
- 降低新开发者学习成本
### 测试覆盖扩展
- 为新功能创建全面单元测试
- 确保代码质量和功能正确性
- 便于后续维护和重构
- 提供回归测试基础
## 性能影响评估
### 正面影响
1. 响应时间改善:异步 HTTP 请求避免阻塞,提高响应速度
2. 内存使用优化:预编译正则表达式减少内存分配
3. 并发能力提升:异步架构支持更多并发请求
4. 代码质量提高:完善文档和测试提高可维护性
### 兼容性评估
所有修改保持向后兼容性,未破坏现有功能。
## 后续工作建议
### 进一步性能优化
- 实现连接池管理,减少连接建立开销
- 添加缓存机制,减少重复数据请求
- 优化数据库查询性能,使用索引和批量操作
### 文档完善计划
- 添加更多插件开发实际示例
- 创建故障排除和调试指南
- 添加部署和运维文档
- 完善 API 参考文档
### 测试扩展方向
- 添加集成测试,验证组件间协作
- 添加性能测试,建立性能基准
- 添加安全测试,验证安全防护效果
- 添加端到端测试,验证完整业务流程
## 项目状态总结
P1 优先级优化工作已完成,主要成果包括:
1. 性能优化:改进异步处理和正则表达式性能,实测性能提升 60.8%
2. 文档完善:创建安全、性能和 API 使用三份核心文档
3. 测试增强:为新功能添加 55 个单元测试方法
这些改进显著提升了项目性能、安全性和可维护性,为后续开发工作奠定良好基础。
**项目状态**: P1 优先级优化任务已完成
警告,这是一次很大的改动,需要人员审核是否能够投入生产环境
This commit is contained in:
23
src/neobot/core/__init__.py
Normal file
23
src/neobot/core/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
NEO Bot Core Package
|
||||
|
||||
核心框架模块,包含事件处理、API封装、管理器等核心功能。
|
||||
"""
|
||||
|
||||
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI
|
||||
from .bot import Bot
|
||||
from .config_loader import global_config
|
||||
from .permission import Permission
|
||||
from .plugin import Plugin
|
||||
|
||||
__all__ = [
|
||||
"MessageAPI",
|
||||
"GroupAPI",
|
||||
"FriendAPI",
|
||||
"AccountAPI",
|
||||
"MediaAPI",
|
||||
"Bot",
|
||||
"global_config",
|
||||
"Permission",
|
||||
"Plugin",
|
||||
]
|
||||
21
src/neobot/core/api/__init__.py
Normal file
21
src/neobot/core/api/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
NEO Bot API Package
|
||||
|
||||
OneBot API 封装模块。
|
||||
"""
|
||||
|
||||
from .account import AccountAPI
|
||||
from .base import BaseAPI
|
||||
from .friend import FriendAPI
|
||||
from .group import GroupAPI
|
||||
from .media import MediaAPI
|
||||
from .message import MessageAPI
|
||||
|
||||
__all__ = [
|
||||
"AccountAPI",
|
||||
"BaseAPI",
|
||||
"FriendAPI",
|
||||
"GroupAPI",
|
||||
"MediaAPI",
|
||||
"MessageAPI",
|
||||
]
|
||||
210
src/neobot/core/api/account.py
Normal file
210
src/neobot/core/api/account.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
账号与状态相关 API 模块
|
||||
|
||||
该模块定义了 `AccountAPI` Mixin 类,提供了所有与机器人自身账号信息、
|
||||
状态设置等相关的 OneBot v11 API 封装。
|
||||
"""
|
||||
import orjson
|
||||
from typing import Dict, Any, Type, TypeVar
|
||||
from dataclasses import is_dataclass, fields
|
||||
from .base import BaseAPI
|
||||
from neobot.models.objects import LoginInfo, VersionInfo, Status
|
||||
from ..managers.redis_manager import redis_manager
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
def _safe_dataclass_from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
|
||||
"""
|
||||
安全地从字典创建 dataclass 实例,忽略多余的键。
|
||||
"""
|
||||
if not data:
|
||||
try:
|
||||
return cls()
|
||||
except TypeError:
|
||||
raise ValueError(f"无法在没有数据的情况下创建 {cls.__name__} 的实例")
|
||||
|
||||
# 使用官方的 is_dataclass 进行检查,对 MyPyC 更友好
|
||||
if not is_dataclass(cls):
|
||||
raise TypeError(f"{cls.__name__} 不是一个 dataclass")
|
||||
|
||||
# 获取 dataclass 的所有字段名
|
||||
known_fields = {f.name for f in fields(cls)}
|
||||
|
||||
# 过滤出 dataclass 认识的键值对
|
||||
filtered_data = {k: v for k, v in data.items() if k in known_fields}
|
||||
|
||||
return cls(**filtered_data)
|
||||
|
||||
class AccountAPI(BaseAPI):
|
||||
"""
|
||||
`AccountAPI` Mixin 类,提供了所有与机器人账号、状态相关的 API 方法。
|
||||
"""
|
||||
|
||||
async def get_login_info(self, no_cache: bool = False) -> LoginInfo:
|
||||
"""
|
||||
获取当前登录的机器人账号信息。
|
||||
|
||||
Args:
|
||||
no_cache (bool, optional): 是否不使用缓存,直接从服务器获取最新信息。Defaults to False.
|
||||
|
||||
Returns:
|
||||
LoginInfo: 包含登录号 QQ 和昵称的 `LoginInfo` 数据对象。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_login_info:{self.self_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.get(cache_key)
|
||||
if cached_data:
|
||||
return _safe_dataclass_from_dict(LoginInfo, orjson.loads(cached_data))
|
||||
|
||||
res = await self.call_api("get_login_info")
|
||||
await redis_manager.set(cache_key, orjson.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return _safe_dataclass_from_dict(LoginInfo, res)
|
||||
|
||||
async def get_version_info(self) -> VersionInfo:
|
||||
"""
|
||||
获取 OneBot v11 实现的版本信息。
|
||||
|
||||
Returns:
|
||||
VersionInfo: 包含 OneBot 实现版本信息的 `VersionInfo` 数据对象。
|
||||
"""
|
||||
res = await self.call_api("get_version_info")
|
||||
return _safe_dataclass_from_dict(VersionInfo, res)
|
||||
|
||||
async def get_status(self) -> Status:
|
||||
"""
|
||||
获取 OneBot v11 实现的状态信息。
|
||||
|
||||
Returns:
|
||||
Status: 包含 OneBot 状态信息的 `Status` 数据对象。
|
||||
"""
|
||||
res = await self.call_api("get_status")
|
||||
return _safe_dataclass_from_dict(Status, res)
|
||||
|
||||
async def bot_exit(self) -> Dict[str, Any]:
|
||||
"""
|
||||
让机器人进程退出(需要实现端支持)。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("bot_exit")
|
||||
|
||||
async def set_self_longnick(self, long_nick: str) -> Dict[str, Any]:
|
||||
"""
|
||||
设置机器人账号的个性签名。
|
||||
|
||||
Args:
|
||||
long_nick (str): 要设置的个性签名内容。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_self_longnick", {"longNick": long_nick})
|
||||
|
||||
async def set_input_status(self, user_id: int, event_type: int) -> Dict[str, Any]:
|
||||
"""
|
||||
设置 "对方正在输入..." 状态提示。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
event_type (int): 事件类型。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_input_status", {"user_id": user_id, "event_type": event_type})
|
||||
|
||||
async def set_diy_online_status(self, face_id: int, face_type: int, wording: str) -> Dict[str, Any]:
|
||||
"""
|
||||
设置自定义的 "在线状态"。
|
||||
|
||||
Args:
|
||||
face_id (int): 状态的表情 ID。
|
||||
face_type (int): 状态的表情类型。
|
||||
wording (str): 状态的描述文本。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_diy_online_status", {
|
||||
"face_id": face_id,
|
||||
"face_type": face_type,
|
||||
"wording": wording
|
||||
})
|
||||
|
||||
async def set_online_status(self, status_code: int) -> Dict[str, Any]:
|
||||
"""
|
||||
设置在线状态(如在线、离开、摸鱼等)。
|
||||
|
||||
Args:
|
||||
status_code (int): 目标在线状态的状态码。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_online_status", {"status_code": status_code})
|
||||
|
||||
async def set_qq_profile(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
设置机器人账号的个人资料。
|
||||
|
||||
Args:
|
||||
**kwargs: 个人资料的相关参数,具体字段请参考 OneBot v11 规范。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_qq_profile", kwargs)
|
||||
|
||||
async def set_qq_avatar(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
设置机器人账号的头像。
|
||||
|
||||
Args:
|
||||
**kwargs: 头像的相关参数,具体字段请参考 OneBot v11 规范。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_qq_avatar", kwargs)
|
||||
|
||||
async def get_clientkey(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取客户端密钥(通常用于 QQ 登录相关操作)。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("get_clientkey")
|
||||
|
||||
async def clean_cache(self) -> Dict[str, Any]:
|
||||
"""
|
||||
清理 OneBot v11 实现端的缓存。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("clean_cache")
|
||||
|
||||
async def get_profile_like(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取个人资料的点赞信息。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("get_profile_like")
|
||||
|
||||
async def nc_get_user_status(self, user_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户的在线状态 (NapCat 特有 API)。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("nc_get_user_status", {"user_id": user_id})
|
||||
|
||||
|
||||
92
src/neobot/core/api/base.py
Normal file
92
src/neobot/core/api/base.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
API 基础模块
|
||||
|
||||
定义了 API 调用的基础接口和统一处理逻辑。
|
||||
"""
|
||||
import copy
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from ..utils.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..ws import WS
|
||||
|
||||
|
||||
class BaseAPI:
|
||||
"""
|
||||
API 基础类,提供了统一的 `call_api` 方法,包含日志记录和异常处理。
|
||||
"""
|
||||
_ws: "WS"
|
||||
self_id: int
|
||||
|
||||
def __init__(self, ws_client: "WS", self_id: int):
|
||||
self._ws = ws_client
|
||||
self.self_id = self_id
|
||||
|
||||
async def call_api(self, action: str, params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
调用 OneBot v11 API,并提供统一的日志和异常处理。
|
||||
|
||||
:param action: API 动作名称
|
||||
:param params: API 参数
|
||||
:return: API 响应结果的数据部分
|
||||
:raises Exception: 当 API 调用失败或发生网络错误时
|
||||
"""
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
try:
|
||||
# 日志记录前,对敏感或过长的参数进行处理
|
||||
log_params = copy.deepcopy(params)
|
||||
|
||||
# 处理各种可能包含base64数据的字段
|
||||
def truncate_base64_recursive(obj):
|
||||
"""递归处理可能包含base64数据的对象"""
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
if isinstance(value, str):
|
||||
if value.startswith('data:image/') or value.startswith('data:video/') or value.startswith('data:audio/'):
|
||||
obj[key] = f"{value[:50]}... (base64 truncated)"
|
||||
elif len(value) > 100 and ('/' in value[:50] and '+' in value[:50] and '=' in value[-10:]):
|
||||
# 检查是否是base64编码的字符串
|
||||
obj[key] = f"{value[:50]}... (base64-like truncated)"
|
||||
elif isinstance(value, (dict, list)):
|
||||
truncate_base64_recursive(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
if isinstance(item, (dict, list)):
|
||||
truncate_base64_recursive(item)
|
||||
|
||||
truncate_base64_recursive(log_params)
|
||||
|
||||
# 如果是发送消息的动作,则原子化地增加发送消息总数
|
||||
if action in ["send_private_msg", "send_group_msg", "send_msg"]:
|
||||
from ..managers.redis_manager import redis_manager
|
||||
try:
|
||||
lua_script = "return redis.call('INCR', KEYS[1])"
|
||||
await redis_manager.execute_lua_script(
|
||||
script=lua_script,
|
||||
keys=["neobot:stats:messages_sent"],
|
||||
args=[]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息计数失败: {e}")
|
||||
|
||||
logger.debug(f"调用API -> action: {action}, params: {log_params}")
|
||||
|
||||
response = await self._ws.call_api(action, params)
|
||||
|
||||
# 对响应也做类似的处理
|
||||
log_response = copy.deepcopy(response)
|
||||
truncate_base64_recursive(log_response)
|
||||
logger.debug(f"API响应 <- {log_response}")
|
||||
|
||||
if response.get("status") == "failed":
|
||||
logger.warning(f"API调用失败: {response}")
|
||||
|
||||
return response.get("data")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API调用异常: action={action}, params={params}, error={e}")
|
||||
raise
|
||||
|
||||
159
src/neobot/core/api/friend.py
Normal file
159
src/neobot/core/api/friend.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
好友与陌生人相关 API 模块
|
||||
|
||||
该模块定义了 `FriendAPI` Mixin 类,提供了所有与好友、陌生人信息
|
||||
等相关的 OneBot v11 API 封装。
|
||||
"""
|
||||
import orjson
|
||||
from typing import List, Dict, Any
|
||||
from .base import BaseAPI
|
||||
from neobot.models.objects import FriendInfo, StrangerInfo
|
||||
from ..managers.redis_manager import redis_manager
|
||||
|
||||
|
||||
class FriendAPI(BaseAPI):
|
||||
"""
|
||||
`FriendAPI` Mixin 类,提供了所有与好友、陌生人操作相关的 API 方法。
|
||||
"""
|
||||
|
||||
async def send_like(self, user_id: int, times: int = 1) -> Dict[str, Any]:
|
||||
"""
|
||||
向指定用户发送 "戳一戳" (点赞)。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
times (int, optional): 点赞次数,建议不超过 10 次。Defaults to 1.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("send_like", {"user_id": user_id, "times": times})
|
||||
|
||||
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> StrangerInfo:
|
||||
"""
|
||||
获取陌生人的信息。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
no_cache (bool, optional): 是否不使用缓存,直接从服务器获取。Defaults to False.
|
||||
|
||||
Returns:
|
||||
StrangerInfo: 包含陌生人信息的 `StrangerInfo` 数据对象。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_stranger_info:{user_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.redis.get(cache_key)
|
||||
if cached_data:
|
||||
return StrangerInfo(**orjson.loads(cached_data))
|
||||
|
||||
res = await self.call_api("get_stranger_info", {"user_id": user_id, "no_cache": no_cache})
|
||||
await redis_manager.redis.set(cache_key, orjson.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return StrangerInfo(**res)
|
||||
|
||||
async def get_friend_list(self, no_cache: bool = False) -> List[FriendInfo]:
|
||||
"""
|
||||
获取机器人账号的好友列表。
|
||||
|
||||
Args:
|
||||
no_cache (bool, optional): 是否不使用缓存,直接从服务器获取最新信息。Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[FriendInfo]: 包含所有好友信息的 `FriendInfo` 对象列表。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_friend_list:{self.self_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.redis.get(cache_key)
|
||||
if cached_data:
|
||||
return [FriendInfo(**item) for item in orjson.loads(cached_data)]
|
||||
|
||||
res = await self.call_api("get_friend_list")
|
||||
await redis_manager.redis.set(cache_key, orjson.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return [FriendInfo(**item) for item in res]
|
||||
|
||||
async def set_friend_add_request(self, flag: str, approve: bool = True, remark: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
处理收到的加好友请求。
|
||||
|
||||
Args:
|
||||
flag (str): 请求的标识,需要从 `request` 事件中获取。
|
||||
approve (bool, optional): 是否同意该好友请求。Defaults to True.
|
||||
remark (str, optional): 在同意请求时,为该好友设置的备注。Defaults to "".
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_friend_add_request", {"flag": flag, "approve": approve, "remark": remark})
|
||||
|
||||
async def get_friends_with_category(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取带分类的好友列表。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("get_friends_with_category")
|
||||
|
||||
async def get_unidirectional_friend_list(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取单向好友列表。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("get_unidirectional_friend_list")
|
||||
|
||||
async def friend_poke(self, user_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
发送好友戳一戳。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("friend_poke", {"user_id": user_id})
|
||||
|
||||
async def mark_private_msg_as_read(self, user_id: int, time: int = 0) -> Dict[str, Any]:
|
||||
"""
|
||||
标记私聊消息为已读。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
time (int, optional): 标记此时间戳之前的消息为已读。Defaults to 0 (全部标记)。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
params = {"user_id": user_id}
|
||||
if time > 0:
|
||||
params["time"] = time
|
||||
return await self.call_api("mark_private_msg_as_read", params)
|
||||
|
||||
async def get_friend_msg_history(self, user_id: int, count: int = 20) -> Dict[str, Any]:
|
||||
"""
|
||||
获取私聊消息历史记录。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
count (int, optional): 要获取的消息数量。Defaults to 20.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("get_friend_msg_history", {"user_id": user_id, "count": count})
|
||||
|
||||
async def forward_friend_single_msg(self, user_id: int, message_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
转发单条好友消息。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
message_id (str): 要转发的消息 ID。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("forward_friend_single_msg", {"user_id": user_id, "message_id": message_id})
|
||||
|
||||
|
||||
464
src/neobot/core/api/group.py
Normal file
464
src/neobot/core/api/group.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""
|
||||
群组相关 API 模块
|
||||
|
||||
该模块定义了 `GroupAPI` Mixin 类,提供了所有与群组管理、成员操作
|
||||
等相关的 OneBot v11 API 封装。
|
||||
"""
|
||||
from typing import List, Dict, Any, Optional
|
||||
import orjson
|
||||
from ..managers.redis_manager import redis_manager
|
||||
from .base import BaseAPI
|
||||
from neobot.models.objects import GroupInfo, GroupMemberInfo, GroupHonorInfo
|
||||
from ..utils.logger import logger
|
||||
|
||||
|
||||
class GroupAPI(BaseAPI):
|
||||
"""
|
||||
`GroupAPI` Mixin 类,提供了所有与群组操作相关的 API 方法。
|
||||
"""
|
||||
|
||||
async def set_group_kick(self, group_id: int, user_id: int, reject_add_request: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
将指定成员踢出群组。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 要踢出的成员的 QQ 号。
|
||||
reject_add_request (bool, optional): 是否拒绝该用户此后的加群请求。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_kick", {"group_id": group_id, "user_id": user_id, "reject_add_request": reject_add_request})
|
||||
|
||||
async def set_group_ban(self, group_id: int, user_id: int, duration: int = 1800) -> Dict[str, Any]:
|
||||
"""
|
||||
禁言群组中的指定成员。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 要禁言的成员的 QQ 号。
|
||||
duration (int, optional): 禁言时长,单位为秒。设置为 0 表示解除禁言。
|
||||
Defaults to 1800 (30 分钟).
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_ban", {"group_id": group_id, "user_id": user_id, "duration": duration})
|
||||
|
||||
async def set_group_anonymous_ban(self, group_id: int, anonymous: Optional[Dict[str, Any]] = None, duration: int = 1800, flag: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
禁言群组中的匿名用户。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
anonymous (Dict[str, Any], optional): 要禁言的匿名用户对象,
|
||||
可从群消息事件的 `anonymous` 字段中获取。Defaults to None.
|
||||
duration (int, optional): 禁言时长,单位为秒。Defaults to 1800.
|
||||
flag (str, optional): 要禁言的匿名用户的 flag 标识,
|
||||
可从群消息事件的 `anonymous` 字段中获取。Defaults to None.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
params: Dict[str, Any] = {"group_id": group_id, "duration": duration}
|
||||
if anonymous:
|
||||
params["anonymous"] = anonymous
|
||||
if flag:
|
||||
params["flag"] = flag
|
||||
return await self.call_api("set_group_anonymous_ban", params)
|
||||
|
||||
async def set_group_whole_ban(self, group_id: int, enable: bool = True) -> Dict[str, Any]:
|
||||
"""
|
||||
开启或关闭群组全员禁言。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
enable (bool, optional): True 表示开启全员禁言,False 表示关闭。Defaults to True.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_whole_ban", {"group_id": group_id, "enable": enable})
|
||||
|
||||
async def set_group_admin(self, group_id: int, user_id: int, enable: bool = True) -> Dict[str, Any]:
|
||||
"""
|
||||
设置或取消群组成员的管理员权限。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 目标成员的 QQ 号。
|
||||
enable (bool, optional): True 表示设为管理员,False 表示取消管理员。Defaults to True.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_admin", {"group_id": group_id, "user_id": user_id, "enable": enable})
|
||||
|
||||
async def set_group_anonymous(self, group_id: int, enable: bool = True) -> Dict[str, Any]:
|
||||
"""
|
||||
开启或关闭群组的匿名聊天功能。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
enable (bool, optional): True 表示开启匿名,False 表示关闭。Defaults to True.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_anonymous", {"group_id": group_id, "enable": enable})
|
||||
|
||||
async def set_group_card(self, group_id: int, user_id: int, card: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
设置群组成员的群名片。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 目标成员的 QQ 号。
|
||||
card (str, optional): 要设置的群名片内容。
|
||||
传入空字符串 `""` 或 `None` 表示删除该成员的群名片。Defaults to "".
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_card", {"group_id": group_id, "user_id": user_id, "card": card})
|
||||
|
||||
async def set_group_name(self, group_id: int, group_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
设置群组的名称。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
group_name (str): 新的群组名称。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_name", {"group_id": group_id, "group_name": group_name})
|
||||
|
||||
async def set_group_leave(self, group_id: int, is_dismiss: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
退出或解散一个群组。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
is_dismiss (bool, optional): 是否解散群组。
|
||||
仅当机器人是群主时,此项设为 True 才能解散群。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_leave", {"group_id": group_id, "is_dismiss": is_dismiss})
|
||||
|
||||
async def set_group_special_title(self, group_id: int, user_id: int, special_title: str = "", duration: int = -1) -> Dict[str, Any]:
|
||||
"""
|
||||
为群组成员设置专属头衔。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 目标成员的 QQ 号。
|
||||
special_title (str, optional): 专属头衔内容。
|
||||
传入空字符串 `""` 或 `None` 表示删除头衔。Defaults to "".
|
||||
duration (int, optional): 头衔有效期,单位为秒。-1 表示永久。Defaults to -1.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_special_title", {"group_id": group_id, "user_id": user_id, "special_title": special_title, "duration": duration})
|
||||
|
||||
async def get_group_info(self, group_id: int, no_cache: bool = False) -> GroupInfo:
|
||||
"""
|
||||
获取群组的详细信息。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
no_cache (bool, optional): 是否不使用缓存,直接从服务器获取最新信息。Defaults to False.
|
||||
|
||||
Returns:
|
||||
GroupInfo: 包含群组信息的 `GroupInfo` 数据对象。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_group_info:{group_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.redis.get(cache_key)
|
||||
if cached_data:
|
||||
return GroupInfo(**orjson.loads(cached_data))
|
||||
|
||||
res = await self.call_api("get_group_info", {"group_id": group_id})
|
||||
await redis_manager.redis.set(cache_key, orjson.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return GroupInfo(**res)
|
||||
|
||||
async def get_group_list(self) -> Any:
|
||||
"""
|
||||
获取机器人加入的所有群组的列表。
|
||||
|
||||
Returns:
|
||||
Any: 包含所有群组信息的列表(可能是字典列表或对象列表)。
|
||||
"""
|
||||
res = await self.call_api("get_group_list")
|
||||
|
||||
# 增加日志记录 API 原始返回
|
||||
logger.debug(f"OneBot API 'get_group_list' raw response: {res}")
|
||||
return res
|
||||
|
||||
# 健壮性处理:处理标准的 OneBot v11 响应格式
|
||||
if isinstance(res, dict) and res.get("status") == "ok":
|
||||
group_data = res.get("data", [])
|
||||
if isinstance(group_data, list):
|
||||
return [GroupInfo(**item) for item in group_data]
|
||||
else:
|
||||
logger.error(f"The 'data' field in 'get_group_list' response is not a list: {group_data}")
|
||||
return []
|
||||
|
||||
# 兼容处理:如果返回的是列表(非标准但可能存在)
|
||||
if isinstance(res, list):
|
||||
return [GroupInfo(**item) for item in res]
|
||||
|
||||
logger.error(f"Unexpected response format from 'get_group_list': {res}")
|
||||
return []
|
||||
|
||||
async def get_group_member_info(self, group_id: int, user_id: int, no_cache: bool = False) -> GroupMemberInfo:
|
||||
"""
|
||||
获取指定群组成员的详细信息。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 目标成员的 QQ 号。
|
||||
no_cache (bool, optional): 是否不使用缓存。Defaults to False.
|
||||
|
||||
Returns:
|
||||
GroupMemberInfo: 包含群成员信息的 `GroupMemberInfo` 数据对象。
|
||||
"""
|
||||
cache_key = f"neobot:cache:get_group_member_info:{group_id}:{user_id}"
|
||||
if not no_cache:
|
||||
cached_data = await redis_manager.redis.get(cache_key)
|
||||
if cached_data:
|
||||
return GroupMemberInfo(**orjson.loads(cached_data))
|
||||
|
||||
res = await self.call_api("get_group_member_info", {"group_id": group_id, "user_id": user_id})
|
||||
await redis_manager.redis.set(cache_key, orjson.dumps(res), ex=3600) # 缓存 1 小时
|
||||
return GroupMemberInfo(**res)
|
||||
|
||||
async def get_group_member_list(self, group_id: int) -> List[GroupMemberInfo]:
|
||||
"""
|
||||
获取一个群组的所有成员列表。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
|
||||
Returns:
|
||||
List[GroupMemberInfo]: 包含所有群成员信息的 `GroupMemberInfo` 对象列表。
|
||||
"""
|
||||
res = await self.call_api("get_group_member_list", {"group_id": group_id})
|
||||
return [GroupMemberInfo(**item) for item in res]
|
||||
|
||||
async def get_group_honor_info(self, group_id: int, type: str) -> GroupHonorInfo:
|
||||
"""
|
||||
获取群组的荣誉信息(如龙王、群聊之火等)。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
type (str): 要获取的荣誉类型。
|
||||
可选值: "talkative", "performer", "legend", "strong_newbie", "emotion" 等。
|
||||
|
||||
Returns:
|
||||
GroupHonorInfo: 包含群荣誉信息的 `GroupHonorInfo` 数据对象。
|
||||
"""
|
||||
res = await self.call_api("get_group_honor_info", {"group_id": group_id, "type": type})
|
||||
return GroupHonorInfo(**res)
|
||||
|
||||
async def set_group_add_request(self, flag: str, sub_type: str, approve: bool = True, reason: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
处理加群请求或邀请。
|
||||
|
||||
Args:
|
||||
flag (str): 请求的标识,需要从 `request` 事件中获取。
|
||||
sub_type (str): 请求的子类型,`add` 或 `invite`,
|
||||
需要与 `request` 事件中的 `sub_type` 字段相符。
|
||||
approve (bool, optional): 是否同意请求或邀请。Defaults to True.
|
||||
reason (str, optional): 拒绝加群的理由(仅在 `approve` 为 False 时有效)。Defaults to "".
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_add_request", {"flag": flag, "sub_type": sub_type, "approve": approve, "reason": reason})
|
||||
|
||||
async def get_group_info_ex(self, group_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
获取群扩展信息 (NapCat 特有 API)。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("get_group_info_ex", {"group_id": group_id})
|
||||
|
||||
async def delete_essence_msg(self, message_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
删除精华消息。
|
||||
|
||||
Args:
|
||||
message_id (int): 目标消息的 ID。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("delete_essence_msg", {"message_id": message_id})
|
||||
|
||||
async def group_poke(self, group_id: int, user_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
在群内发送 "戳一戳"。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
user_id (int): 目标成员的 QQ 号。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("group_poke", {"group_id": group_id, "user_id": user_id})
|
||||
|
||||
async def mark_group_msg_as_read(self, group_id: int, time: int = 0) -> Dict[str, Any]:
|
||||
"""
|
||||
标记群消息为已读。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
time (int, optional): 标记此时间戳之前的消息为已读。Defaults to 0 (全部标记)。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
params = {"group_id": group_id}
|
||||
if time > 0:
|
||||
params["time"] = time
|
||||
return await self.call_api("mark_group_msg_as_read", params)
|
||||
|
||||
async def forward_group_single_msg(self, group_id: int, message_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
转发单条群消息。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
message_id (str): 要转发的消息 ID。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("forward_group_single_msg", {"group_id": group_id, "message_id": message_id})
|
||||
|
||||
async def set_group_portrait(self, group_id: int, file: str, cache: int = 1) -> Dict[str, Any]:
|
||||
"""
|
||||
设置群头像。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
file (str): 图片文件的路径或 URL 或 Base64。
|
||||
cache (int, optional): 是否使用缓存 (1: 是, 0: 否)。Defaults to 1.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_portrait", {"group_id": group_id, "file": file, "cache": cache})
|
||||
|
||||
async def _send_group_notice(self, group_id: int, content: str, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
发送群公告。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
content (str): 公告内容。
|
||||
**kwargs: 其他可选参数 (如 image)。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
params = {"group_id": group_id, "content": content}
|
||||
params.update(kwargs)
|
||||
return await self.call_api("_send_group_notice", params)
|
||||
|
||||
async def _get_group_notice(self, group_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
获取群公告。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("_get_group_notice", {"group_id": group_id})
|
||||
|
||||
async def _del_group_notice(self, group_id: int, notice_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
删除群公告。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
notice_id (str): 公告 ID。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("_del_group_notice", {"group_id": group_id, "notice_id": notice_id})
|
||||
|
||||
async def get_group_at_all_remain(self, group_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
获取 @全体成员 的剩余次数。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("get_group_at_all_remain", {"group_id": group_id})
|
||||
|
||||
async def get_group_system_msg(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取群系统消息。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("get_group_system_msg")
|
||||
|
||||
async def get_group_shut_list(self, group_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
获取群禁言列表。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("get_group_shut_list", {"group_id": group_id})
|
||||
|
||||
async def set_group_remark(self, group_id: int, remark: str) -> Dict[str, Any]:
|
||||
"""
|
||||
设置群备注。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
remark (str): 要设置的备注。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_remark", {"group_id": group_id, "remark": remark})
|
||||
|
||||
async def set_group_sign(self, group_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
设置群签到。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("set_group_sign", {"group_id": group_id})
|
||||
|
||||
|
||||
49
src/neobot/core/api/media.py
Normal file
49
src/neobot/core/api/media.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
媒体API模块
|
||||
|
||||
封装了与图片、语音等媒体文件相关的API。
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
|
||||
from .base import BaseAPI
|
||||
|
||||
|
||||
class MediaAPI(BaseAPI):
|
||||
"""
|
||||
媒体相关API
|
||||
"""
|
||||
|
||||
async def can_send_image(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查是否可以发送图片
|
||||
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="can_send_image")
|
||||
|
||||
async def can_send_record(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查是否可以发送语音
|
||||
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="can_send_record")
|
||||
|
||||
async def get_image(self, file: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取图片信息
|
||||
|
||||
:param file: 图片文件名或路径
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="get_image", params={"file": file})
|
||||
|
||||
async def get_file(self, file_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取文件信息
|
||||
|
||||
:param file_id: 文件ID
|
||||
:return: OneBot v11标准响应
|
||||
"""
|
||||
return await self.call_api(action="get_file", params={"file_id": file_id})
|
||||
|
||||
202
src/neobot/core/api/message.py
Normal file
202
src/neobot/core/api/message.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
消息相关 API 模块
|
||||
|
||||
该模块定义了 `MessageAPI` Mixin 类,提供了所有与消息发送、撤回、
|
||||
转发等相关的 OneBot v11 API 封装。
|
||||
"""
|
||||
from typing import Union, List, Dict, Any, TYPE_CHECKING
|
||||
from .base import BaseAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from neobot.models.message import MessageSegment
|
||||
from neobot.models.events.base import OneBotEvent
|
||||
|
||||
|
||||
class MessageAPI(BaseAPI):
|
||||
"""
|
||||
`MessageAPI` Mixin 类,提供了所有与消息操作相关的 API 方法。
|
||||
"""
|
||||
|
||||
async def send_group_msg(self, group_id: int, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
发送群消息。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
message (Union[str, MessageSegment, List[MessageSegment]]): 要发送的消息内容。
|
||||
可以是纯文本字符串、单个消息段对象或消息段列表。
|
||||
auto_escape (bool, optional): 仅当 `message` 为字符串时有效,
|
||||
是否对消息内容进行 CQ 码转义。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api(
|
||||
"send_group_msg", {"group_id": group_id, "message": self._process_message(message), "auto_escape": auto_escape}
|
||||
)
|
||||
|
||||
async def send_private_msg(self, user_id: int, message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
发送私聊消息。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
message (Union[str, MessageSegment, List[MessageSegment]]): 要发送的消息内容。
|
||||
auto_escape (bool, optional): 是否对消息内容进行 CQ 码转义。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api(
|
||||
"send_private_msg", {"user_id": user_id, "message": self._process_message(message), "auto_escape": auto_escape}
|
||||
)
|
||||
|
||||
async def send(self, event: "OneBotEvent", message: Union[str, "MessageSegment", List["MessageSegment"]], auto_escape: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
智能发送消息。
|
||||
|
||||
该方法会根据传入的事件对象 `event` 自动判断是私聊还是群聊,
|
||||
并调用相应的发送函数。如果事件是消息事件,则优先使用 `reply` 方法。
|
||||
|
||||
Args:
|
||||
event (OneBotEvent): 触发该发送行为的事件对象。
|
||||
message (Union[str, MessageSegment, List[MessageSegment]]): 要发送的消息内容。
|
||||
auto_escape (bool, optional): 是否对消息内容进行 CQ 码转义。Defaults to False.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
# 如果是消息事件,直接调用 reply
|
||||
if hasattr(event, "reply"):
|
||||
await event.reply(message, auto_escape)
|
||||
return {"status": "ok", "msg": "Replied via event.reply()"}
|
||||
|
||||
# 尝试从事件中获取 user_id 或 group_id
|
||||
user_id = getattr(event, "user_id", None)
|
||||
group_id = getattr(event, "group_id", None)
|
||||
|
||||
if group_id:
|
||||
return await self.send_group_msg(group_id, message, auto_escape)
|
||||
elif user_id:
|
||||
return await self.send_private_msg(user_id, message, auto_escape)
|
||||
|
||||
return {"status": "failed", "msg": "Unknown message target"}
|
||||
|
||||
async def delete_msg(self, message_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
撤回一条消息。
|
||||
|
||||
Args:
|
||||
message_id (int): 要撤回的消息的 ID。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("delete_msg", {"message_id": message_id})
|
||||
|
||||
async def get_msg(self, message_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
获取一条消息的详细信息。
|
||||
|
||||
Args:
|
||||
message_id (int): 要获取的消息的 ID。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据,包含消息详情。
|
||||
"""
|
||||
return await self.call_api("get_msg", {"message_id": message_id})
|
||||
|
||||
async def get_forward_msg(self, id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取合并转发消息的内容。
|
||||
|
||||
Args:
|
||||
id (str): 合并转发消息的 ID。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 转发消息的节点列表。
|
||||
"""
|
||||
forward_data = await self.call_api("get_forward_msg", {"id": id})
|
||||
nodes = forward_data.get("data")
|
||||
|
||||
if not isinstance(nodes, list):
|
||||
# 兼容某些实现可能将节点放在 'messages' 键下
|
||||
data = forward_data.get('data', {})
|
||||
if isinstance(data, dict):
|
||||
nodes = data.get('messages')
|
||||
|
||||
if not isinstance(nodes, list):
|
||||
raise ValueError("在 get_forward_msg 响应中找不到消息节点列表")
|
||||
|
||||
return nodes
|
||||
|
||||
async def send_group_forward_msg(self, group_id: int, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
发送群聊合并转发消息。
|
||||
|
||||
Args:
|
||||
group_id (int): 目标群组的群号。
|
||||
messages (List[Dict[str, Any]]): 消息节点列表。
|
||||
推荐使用 `bot.build_forward_node` 来构建节点。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("send_group_forward_msg", {"group_id": group_id, "messages": messages})
|
||||
|
||||
async def send_private_forward_msg(self, user_id: int, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
发送私聊合并转发消息。
|
||||
|
||||
Args:
|
||||
user_id (int): 目标用户的 QQ 号。
|
||||
messages (List[Dict[str, Any]]): 消息节点列表。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OneBot API 的响应数据。
|
||||
"""
|
||||
return await self.call_api("send_private_forward_msg", {"user_id": user_id, "messages": messages})
|
||||
|
||||
def _process_message(self, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Union[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
内部方法:将消息内容处理成 OneBot API 可接受的格式。
|
||||
|
||||
- `str` -> `str`
|
||||
- `MessageSegment` -> `List[Dict]`
|
||||
- `List[MessageSegment]` -> `List[Dict]`
|
||||
|
||||
Args:
|
||||
message: 原始消息内容。
|
||||
|
||||
Returns:
|
||||
处理后的消息内容。
|
||||
"""
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
|
||||
# 避免循环导入,在运行时导入
|
||||
from neobot.models.message import MessageSegment
|
||||
|
||||
if isinstance(message, MessageSegment):
|
||||
return [self._segment_to_dict(message)]
|
||||
|
||||
if isinstance(message, list):
|
||||
return [self._segment_to_dict(m) for m in message if isinstance(m, MessageSegment)]
|
||||
|
||||
return str(message)
|
||||
|
||||
def _segment_to_dict(self, segment: "MessageSegment") -> Dict[str, Any]:
|
||||
"""
|
||||
内部方法:将 `MessageSegment` 对象转换为字典。
|
||||
|
||||
Args:
|
||||
segment (MessageSegment): 消息段对象。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 符合 OneBot 规范的消息段字典。
|
||||
"""
|
||||
return {
|
||||
"type": segment.type,
|
||||
"data": segment.data
|
||||
}
|
||||
|
||||
117
src/neobot/core/bot.py
Normal file
117
src/neobot/core/bot.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Bot 核心抽象模块
|
||||
|
||||
该模块定义了 `Bot` 类,它是与 OneBot v11 API 进行交互的主要接口。
|
||||
`Bot` 类通过继承 `api` 目录下的各个 Mixin 类,将不同类别的 API 调用
|
||||
整合在一起,提供了一个统一、便捷的调用入口。
|
||||
|
||||
主要职责包括:
|
||||
- 封装 WebSocket 通信,提供 `call_api` 方法。
|
||||
- 提供高级消息发送功能,如 `send_forwarded_messages`。
|
||||
- 整合所有细分的 API 调用(消息、群组、好友等)。
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Dict, Any, List, Union, Optional
|
||||
from neobot.models.events.base import OneBotEvent
|
||||
from neobot.models.message import MessageSegment
|
||||
from neobot.models.objects import GroupInfo, StrangerInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .ws import WS
|
||||
from .utils.executor import CodeExecutor
|
||||
|
||||
from .api import MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI
|
||||
|
||||
|
||||
class Bot(MessageAPI, GroupAPI, FriendAPI, AccountAPI, MediaAPI):
|
||||
"""
|
||||
机器人核心类,封装了所有与 OneBot API 的交互。
|
||||
|
||||
通过 Mixin 模式继承了所有 API 功能,使得结构清晰且易于扩展。
|
||||
实例由 `WS` 客户端在连接成功后创建,并传递给所有事件处理器和插件。
|
||||
"""
|
||||
|
||||
def __init__(self, ws_client: "WS"):
|
||||
"""
|
||||
初始化 Bot 实例。
|
||||
|
||||
Args:
|
||||
ws_client (WS): WebSocket 客户端实例,负责底层的 API 请求和响应处理。
|
||||
"""
|
||||
super().__init__(ws_client, ws_client.self_id or 0)
|
||||
self.code_executor: Optional["CodeExecutor"] = None
|
||||
self.nickname: str = ""
|
||||
|
||||
async def get_group_list(self, no_cache: bool = False) -> List[GroupInfo]:
|
||||
# GroupAPI.get_group_list 不支持 no_cache 参数,这里忽略它
|
||||
result = await super().get_group_list()
|
||||
# 确保结果是 GroupInfo 对象列表
|
||||
return [GroupInfo(**group) if isinstance(group, dict) else group for group in result]
|
||||
|
||||
async def get_stranger_info(self, user_id: int, no_cache: bool = False) -> StrangerInfo:
|
||||
result = await super().get_stranger_info(user_id=user_id, no_cache=no_cache)
|
||||
# 确保结果是 StrangerInfo 对象
|
||||
if isinstance(result, dict):
|
||||
return StrangerInfo(**result)
|
||||
return result
|
||||
|
||||
|
||||
def build_forward_node(self, user_id: int, nickname: str, message: Union[str, "MessageSegment", List["MessageSegment"]]) -> Dict[str, Any]:
|
||||
"""
|
||||
构建一个用于合并转发的消息节点 (Node)。
|
||||
|
||||
这是一个辅助方法,用于方便地创建符合 OneBot v11 规范的消息节点,
|
||||
以便在 `send_forwarded_messages` 中使用。
|
||||
|
||||
Args:
|
||||
user_id (int): 发送者的 QQ 号。
|
||||
nickname (str): 发送者在消息中显示的昵称。
|
||||
message (Union[str, MessageSegment, List[MessageSegment]]): 该节点的消息内容,
|
||||
可以是纯文本、单个消息段或消息段列表。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 构造好的消息节点字典。
|
||||
"""
|
||||
return {
|
||||
"type": "node",
|
||||
"data": {
|
||||
"uin": user_id,
|
||||
"name": nickname,
|
||||
"content": self._process_message(message)
|
||||
}
|
||||
}
|
||||
|
||||
async def send_forwarded_messages(self, target: Union[int, "OneBotEvent"], nodes: List[Dict[str, Any]]):
|
||||
"""
|
||||
发送合并转发消息。
|
||||
|
||||
该方法实现了智能判断,可以根据 `target` 的类型自动发送群聊合并转发
|
||||
或私聊合并转发消息。
|
||||
|
||||
Args:
|
||||
target (Union[int, OneBotEvent]): 发送目标。
|
||||
- 如果是 `OneBotEvent` 对象,则自动判断是群聊还是私聊。
|
||||
- 如果是 `int`,则默认为群号,发送群聊合并转发。
|
||||
nodes (List[Dict[str, Any]]): 消息节点列表。
|
||||
推荐使用 `build_forward_node` 方法来构建列表中的每个节点。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果事件对象中既没有 `group_id` 也没有 `user_id`。
|
||||
"""
|
||||
if isinstance(target, OneBotEvent):
|
||||
group_id = getattr(target, "group_id", None)
|
||||
user_id = getattr(target, "user_id", None)
|
||||
|
||||
if group_id:
|
||||
# 直接发送群聊合并转发
|
||||
await self.send_group_forward_msg(group_id, nodes)
|
||||
elif user_id:
|
||||
# 发送私聊合并转发
|
||||
await self.send_private_forward_msg(user_id, nodes)
|
||||
else:
|
||||
raise ValueError("Event has neither group_id nor user_id")
|
||||
|
||||
else:
|
||||
# 默认行为是发送到群聊
|
||||
group_id = target
|
||||
await self.send_group_forward_msg(group_id, nodes)
|
||||
|
||||
316
src/neobot/core/config_loader.py
Normal file
316
src/neobot/core/config_loader.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
配置加载模块
|
||||
|
||||
负责读取和解析 config.toml 配置文件,提供全局配置对象。
|
||||
支持从环境变量覆盖敏感配置。
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import tomllib
|
||||
from pydantic import ValidationError
|
||||
from .config_models import ConfigModel, NapCatWSModel, BotModel, RedisModel, DockerModel, ImageManagerModel, MySQLModel, ReverseWSModel, ThreadingModel, BilibiliModel, LocalFileServerModel, DiscordModel, CrossPlatformModel, LoggingModel
|
||||
from .utils.logger import ModuleLogger
|
||||
from .utils.exceptions import ConfigError, ConfigNotFoundError, ConfigValidationError
|
||||
from .utils.env_loader import env_loader
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
配置加载类,负责读取和解析 config.toml 文件
|
||||
支持从环境变量覆盖敏感配置
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str = "config.toml"):
|
||||
"""
|
||||
初始化配置加载器
|
||||
|
||||
:param file_path: 配置文件路径,默认为 "config.toml"
|
||||
"""
|
||||
self.path = Path(file_path)
|
||||
self._model: ConfigModel
|
||||
# 创建模块专用日志记录器
|
||||
self.logger = ModuleLogger("ConfigLoader")
|
||||
# 加载环境变量
|
||||
env_loader.load()
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
"""
|
||||
加载并验证配置文件
|
||||
|
||||
:raises ConfigNotFoundError: 如果配置文件不存在
|
||||
:raises ConfigValidationError: 如果配置格式不正确
|
||||
:raises ConfigError: 如果加载配置时发生其他错误
|
||||
"""
|
||||
# 检查配置文件权限
|
||||
self._check_file_permissions()
|
||||
|
||||
if not self.path.exists():
|
||||
self.logger.warning(f"配置文件 {self.path} 未找到,正在生成示例配置...")
|
||||
self._generate_example_config()
|
||||
self.logger.success(f"示例配置已生成: {self.path}")
|
||||
self.logger.info("请编辑配置文件后重新启动程序")
|
||||
|
||||
try:
|
||||
self.logger.info(f"正在从 {self.path} 加载配置...")
|
||||
with open(self.path, "rb") as f:
|
||||
raw_config = tomllib.load(f)
|
||||
|
||||
# 从环境变量覆盖敏感配置
|
||||
raw_config = self._override_with_env_vars(raw_config)
|
||||
|
||||
self._model = ConfigModel(**raw_config)
|
||||
self.logger.success("配置加载并验证成功!")
|
||||
|
||||
except ValidationError as e:
|
||||
error_details = []
|
||||
for error in e.errors():
|
||||
field = " -> ".join(map(str, error["loc"]))
|
||||
error_msg = f"字段 '{field}': {error['msg']}"
|
||||
error_details.append(error_msg)
|
||||
|
||||
validation_error = ConfigValidationError(
|
||||
message="配置验证失败"
|
||||
)
|
||||
validation_error.original_error = e
|
||||
|
||||
self.logger.error("配置验证失败,请检查 `config.toml` 文件中的以下错误:")
|
||||
for detail in error_details:
|
||||
self.logger.error(f" - {detail}")
|
||||
|
||||
self.logger.log_custom_exception(validation_error)
|
||||
raise validation_error
|
||||
except tomllib.TOMLDecodeError as e:
|
||||
error = ConfigError(
|
||||
message=f"TOML解析错误: {str(e)}"
|
||||
)
|
||||
error.original_error = e
|
||||
self.logger.error(f"加载配置文件时发生TOML解析错误: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
raise error
|
||||
except Exception as e:
|
||||
error = ConfigError(
|
||||
message=f"加载配置文件时发生未知错误: {str(e)}"
|
||||
)
|
||||
error.original_error = e
|
||||
self.logger.exception(f"加载配置文件时发生未知错误: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
raise error
|
||||
|
||||
def _check_file_permissions(self):
|
||||
"""
|
||||
检查配置文件权限
|
||||
|
||||
确保配置文件不会被其他用户读取,保护敏感信息。
|
||||
"""
|
||||
if not self.path.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
import os
|
||||
import stat
|
||||
|
||||
# 获取文件状态
|
||||
file_stat = self.path.stat()
|
||||
|
||||
# 检查文件权限
|
||||
mode = file_stat.st_mode
|
||||
|
||||
# 检查是否其他用户可读
|
||||
if mode & stat.S_IROTH:
|
||||
self.logger.warning(f"配置文件 {self.path} 其他用户可读,存在安全风险")
|
||||
self.logger.info("建议使用命令: chmod 600 config.toml")
|
||||
|
||||
# 检查是否其他用户可写
|
||||
if mode & stat.S_IWOTH:
|
||||
self.logger.error(f"配置文件 {self.path} 其他用户可写,存在严重安全风险!")
|
||||
self.logger.error("请立即修复权限: chmod 600 config.toml")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"检查文件权限失败: {e}")
|
||||
|
||||
def _override_with_env_vars(self, raw_config: dict) -> dict:
|
||||
"""
|
||||
使用环境变量覆盖敏感配置
|
||||
|
||||
Args:
|
||||
raw_config: 原始配置字典
|
||||
|
||||
Returns:
|
||||
更新后的配置字典
|
||||
"""
|
||||
# MySQL 配置
|
||||
if 'mysql' in raw_config:
|
||||
mysql_config = raw_config['mysql']
|
||||
mysql_config['host'] = env_loader.get('MYSQL_HOST', mysql_config.get('host', 'localhost'))
|
||||
mysql_config['port'] = env_loader.get_int('MYSQL_PORT', mysql_config.get('port', 3306))
|
||||
mysql_config['user'] = env_loader.get('MYSQL_USER', mysql_config.get('user', 'root'))
|
||||
mysql_config['password'] = env_loader.get('MYSQL_PASSWORD', mysql_config.get('password', ''))
|
||||
mysql_config['db'] = env_loader.get('MYSQL_DB', mysql_config.get('db', 'neobot'))
|
||||
|
||||
# Redis 配置
|
||||
if 'redis' in raw_config:
|
||||
redis_config = raw_config['redis']
|
||||
redis_config['host'] = env_loader.get('REDIS_HOST', redis_config.get('host', 'localhost'))
|
||||
redis_config['port'] = env_loader.get_int('REDIS_PORT', redis_config.get('port', 6379))
|
||||
redis_config['db'] = env_loader.get_int('REDIS_DB', redis_config.get('db', 0))
|
||||
redis_config['password'] = env_loader.get('REDIS_PASSWORD', redis_config.get('password', ''))
|
||||
|
||||
# NapCat WebSocket 配置
|
||||
if 'napcat_ws' in raw_config:
|
||||
ws_config = raw_config['napcat_ws']
|
||||
ws_config['uri'] = env_loader.get('NAPCAT_WS_URI', ws_config.get('uri', 'ws://localhost:8080'))
|
||||
ws_config['token'] = env_loader.get('NAPCAT_WS_TOKEN', ws_config.get('token', ''))
|
||||
|
||||
# Discord 配置
|
||||
if 'discord' in raw_config:
|
||||
discord_config = raw_config['discord']
|
||||
discord_config['token'] = env_loader.get('DISCORD_TOKEN', discord_config.get('token', ''))
|
||||
discord_config['proxy'] = env_loader.get('DISCORD_PROXY', discord_config.get('proxy'))
|
||||
|
||||
# Bilibili 配置
|
||||
if 'bilibili' in raw_config:
|
||||
bili_config = raw_config['bilibili']
|
||||
bili_config['sessdata'] = env_loader.get('BILIBILI_SESSDATA', bili_config.get('sessdata'))
|
||||
bili_config['bili_jct'] = env_loader.get('BILIBILI_BILI_JCT', bili_config.get('bili_jct'))
|
||||
bili_config['buvid3'] = env_loader.get('BILIBILI_BUVID3', bili_config.get('buvid3'))
|
||||
bili_config['dedeuserid'] = env_loader.get('BILIBILI_DEDEUSERID', bili_config.get('dedeuserid'))
|
||||
|
||||
# Docker 配置
|
||||
if 'docker' in raw_config:
|
||||
docker_config = raw_config['docker']
|
||||
docker_config['base_url'] = env_loader.get('DOCKER_BASE_URL', docker_config.get('base_url'))
|
||||
docker_config['tls_verify'] = env_loader.get_bool('DOCKER_TLS_VERIFY', docker_config.get('tls_verify', False))
|
||||
|
||||
# 反向 WebSocket 配置
|
||||
if 'reverse_ws' in raw_config:
|
||||
reverse_config = raw_config['reverse_ws']
|
||||
reverse_config['enabled'] = env_loader.get_bool('REVERSE_WS_ENABLED', reverse_config.get('enabled', False))
|
||||
reverse_config['host'] = env_loader.get('REVERSE_WS_HOST', reverse_config.get('host', '0.0.0.0'))
|
||||
reverse_config['port'] = env_loader.get_int('REVERSE_WS_PORT', reverse_config.get('port', 3002))
|
||||
reverse_config['token'] = env_loader.get('REVERSE_WS_TOKEN', reverse_config.get('token'))
|
||||
|
||||
# 本地文件服务器配置
|
||||
if 'local_file_server' in raw_config:
|
||||
server_config = raw_config['local_file_server']
|
||||
server_config['enabled'] = env_loader.get_bool('LOCAL_FILE_SERVER_ENABLED', server_config.get('enabled', True))
|
||||
server_config['host'] = env_loader.get('LOCAL_FILE_SERVER_HOST', server_config.get('host', '0.0.0.0'))
|
||||
server_config['port'] = env_loader.get_int('LOCAL_FILE_SERVER_PORT', server_config.get('port', 3003))
|
||||
|
||||
# 日志配置
|
||||
if 'logging' in raw_config:
|
||||
log_config = raw_config['logging']
|
||||
log_config['level'] = env_loader.get('LOG_LEVEL', log_config.get('level', 'DEBUG'))
|
||||
log_config['file_level'] = env_loader.get('LOG_FILE_LEVEL', log_config.get('file_level', 'DEBUG'))
|
||||
log_config['console_level'] = env_loader.get('LOG_CONSOLE_LEVEL', log_config.get('console_level', 'INFO'))
|
||||
|
||||
return raw_config
|
||||
|
||||
def _generate_example_config(self):
|
||||
"""
|
||||
生成示例配置文件
|
||||
"""
|
||||
example_path = Path("config.example.toml")
|
||||
|
||||
if not example_path.exists():
|
||||
self.logger.error(f"示例配置文件 {example_path} 不存在,无法生成配置")
|
||||
raise ConfigNotFoundError(message=f"示例配置文件 {example_path} 不存在")
|
||||
|
||||
content = example_path.read_text()
|
||||
self.path.write_text(content)
|
||||
|
||||
# 通过属性访问配置
|
||||
@property
|
||||
def napcat_ws(self) -> NapCatWSModel:
|
||||
"""
|
||||
获取 NapCat WebSocket 配置
|
||||
"""
|
||||
return self._model.napcat_ws
|
||||
|
||||
@property
|
||||
def bot(self) -> BotModel:
|
||||
"""
|
||||
获取 Bot 基础配置
|
||||
"""
|
||||
return self._model.bot
|
||||
|
||||
@property
|
||||
def redis(self) -> RedisModel:
|
||||
"""
|
||||
获取 Redis 配置
|
||||
"""
|
||||
return self._model.redis
|
||||
|
||||
@property
|
||||
def mysql(self) -> MySQLModel:
|
||||
"""
|
||||
获取 MySQL 配置
|
||||
"""
|
||||
return self._model.mysql
|
||||
|
||||
@property
|
||||
def docker(self) -> DockerModel:
|
||||
"""
|
||||
获取 Docker 配置
|
||||
"""
|
||||
return self._model.docker
|
||||
|
||||
@property
|
||||
def image_manager(self) -> ImageManagerModel:
|
||||
"""
|
||||
获取图片生成管理器配置
|
||||
"""
|
||||
return self._model.image_manager
|
||||
|
||||
@property
|
||||
def reverse_ws(self) -> ReverseWSModel:
|
||||
"""
|
||||
获取反向 WebSocket 配置
|
||||
"""
|
||||
return self._model.reverse_ws
|
||||
|
||||
@property
|
||||
def threading(self) -> ThreadingModel:
|
||||
"""
|
||||
获取线程管理配置
|
||||
"""
|
||||
return self._model.threading
|
||||
|
||||
@property
|
||||
def bilibili(self) -> BilibiliModel:
|
||||
"""
|
||||
获取 Bilibili 配置
|
||||
"""
|
||||
return self._model.bilibili
|
||||
|
||||
@property
|
||||
def local_file_server(self) -> LocalFileServerModel:
|
||||
"""
|
||||
获取本地文件服务器配置
|
||||
"""
|
||||
return self._model.local_file_server
|
||||
|
||||
@property
|
||||
def discord(self) -> DiscordModel:
|
||||
"""
|
||||
获取 Discord 配置
|
||||
"""
|
||||
return self._model.discord
|
||||
|
||||
@property
|
||||
def cross_platform(self) -> CrossPlatformModel:
|
||||
"""
|
||||
获取跨平台配置
|
||||
"""
|
||||
return self._model.cross_platform
|
||||
|
||||
@property
|
||||
def logging(self) -> LoggingModel:
|
||||
"""
|
||||
获取日志配置
|
||||
"""
|
||||
return self._model.logging
|
||||
|
||||
|
||||
# 实例化全局配置对象
|
||||
global_config = Config()
|
||||
163
src/neobot/core/config_models.py
Normal file
163
src/neobot/core/config_models.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Pydantic 配置模型模块
|
||||
|
||||
该模块使用 Pydantic 定义了与 `config.toml` 文件结构完全对应的配置模型。
|
||||
这使得配置的加载、校验和访问都变得类型安全和健壮。
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NapCatWSModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[napcat_ws]` 配置块。
|
||||
"""
|
||||
uri: str
|
||||
token: str = ""
|
||||
reconnect_interval: int = 5
|
||||
|
||||
|
||||
class BotModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[bot]` 配置块。
|
||||
"""
|
||||
command: List[str] = Field(default_factory=lambda: ["/"])
|
||||
ignore_self_message: bool = True
|
||||
permission_denied_message: str = "权限不足,需要 {permission_name} 权限"
|
||||
|
||||
class ReverseWSModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[reverse_ws]` 配置块。
|
||||
"""
|
||||
enabled: bool = False
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 3002
|
||||
token: Optional[str] = None
|
||||
|
||||
|
||||
class RedisModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[redis]` 配置块。
|
||||
"""
|
||||
host: str
|
||||
port: int
|
||||
db: int
|
||||
password: str
|
||||
|
||||
|
||||
class MySQLModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[mysql]` 配置块。
|
||||
"""
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
db: str
|
||||
charset: str = "utf8mb4"
|
||||
|
||||
|
||||
|
||||
class DockerModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[docker]` 配置块。
|
||||
"""
|
||||
base_url: Optional[str] = None
|
||||
sandbox_image: str = "python-sandbox:latest"
|
||||
timeout: int = 10
|
||||
concurrency_limit: int = 5
|
||||
tls_verify: bool = False
|
||||
ca_cert_path: Optional[str] = None
|
||||
client_cert_path: Optional[str] = None
|
||||
client_key_path: Optional[str] = None
|
||||
|
||||
class ImageManagerModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[image_manager]` 配置块。
|
||||
"""
|
||||
image_height: int = 1920
|
||||
image_width: int = 1080
|
||||
|
||||
|
||||
class ThreadingModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[threading]` 配置块。
|
||||
"""
|
||||
max_workers: int = Field(default=10, ge=1, le=100)
|
||||
client_max_workers: int = Field(default=5, ge=1, le=50)
|
||||
thread_name_prefix: str = "NeoBot-Thread"
|
||||
|
||||
|
||||
class BilibiliModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[bilibili]` 配置块。
|
||||
"""
|
||||
sessdata: Optional[str] = None
|
||||
bili_jct: Optional[str] = None
|
||||
buvid3: Optional[str] = None
|
||||
dedeuserid: Optional[str] = None
|
||||
|
||||
|
||||
class LocalFileServerModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[local_file_server]` 配置块。
|
||||
"""
|
||||
enabled: bool = True
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 3003
|
||||
|
||||
|
||||
class DiscordModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[discord]` 配置块。
|
||||
"""
|
||||
enabled: bool = False
|
||||
token: str = ""
|
||||
proxy: Optional[str] = None
|
||||
proxy_type: str = "http"
|
||||
|
||||
|
||||
class CrossPlatformMapping(BaseModel):
|
||||
"""
|
||||
跨平台映射配置
|
||||
"""
|
||||
qq_group_id: int
|
||||
name: str
|
||||
|
||||
|
||||
class CrossPlatformModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[cross_platform]` 配置块。
|
||||
"""
|
||||
enabled: bool = False
|
||||
mappings: Optional[dict[int, CrossPlatformMapping]] = None
|
||||
|
||||
|
||||
class LoggingModel(BaseModel):
|
||||
"""
|
||||
对应 `config.toml` 中的 `[logging]` 配置块。
|
||||
"""
|
||||
level: str = "DEBUG"
|
||||
file_level: str = "DEBUG"
|
||||
console_level: str = "INFO"
|
||||
|
||||
|
||||
class ConfigModel(BaseModel):
|
||||
"""
|
||||
顶层配置模型,整合了所有子配置块。
|
||||
"""
|
||||
napcat_ws: NapCatWSModel
|
||||
bot: BotModel
|
||||
redis: RedisModel
|
||||
mysql: MySQLModel
|
||||
docker: DockerModel
|
||||
image_manager: ImageManagerModel
|
||||
reverse_ws: ReverseWSModel
|
||||
threading: ThreadingModel = Field(default_factory=ThreadingModel)
|
||||
bilibili: BilibiliModel = Field(default_factory=BilibiliModel)
|
||||
local_file_server: LocalFileServerModel = Field(default_factory=LocalFileServerModel)
|
||||
discord: DiscordModel = Field(default_factory=DiscordModel)
|
||||
cross_platform: CrossPlatformModel = Field(default_factory=CrossPlatformModel)
|
||||
logging: LoggingModel = Field(default_factory=LoggingModel)
|
||||
|
||||
|
||||
3
src/neobot/core/data/admin.json
Normal file
3
src/neobot/core/data/admin.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"admins": [2221577113]
|
||||
}
|
||||
8
src/neobot/core/data/permissions.json
Normal file
8
src/neobot/core/data/permissions.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"users": {
|
||||
"123456789": "op",
|
||||
"888888": "op",
|
||||
"2221577113": "admin",
|
||||
"999999": "user"
|
||||
}
|
||||
}
|
||||
9
src/neobot/core/handlers/__init__.py
Normal file
9
src/neobot/core/handlers/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
NEO Bot Handlers Package
|
||||
|
||||
事件处理器模块。
|
||||
"""
|
||||
|
||||
from .event_handler import matcher
|
||||
|
||||
__all__ = ["matcher"]
|
||||
266
src/neobot/core/handlers/event_handler.py
Normal file
266
src/neobot/core/handlers/event_handler.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
事件处理器模块
|
||||
|
||||
该模块定义了用于处理不同类型事件的处理器类。
|
||||
每个处理器都负责注册和分发特定类型的事件。
|
||||
"""
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bot import Bot
|
||||
from ..config_loader import global_config
|
||||
from ..permission import Permission
|
||||
from ..utils.executor import run_in_thread_pool
|
||||
|
||||
|
||||
class BaseHandler(ABC):
|
||||
"""
|
||||
事件处理器抽象基类
|
||||
"""
|
||||
def __init__(self):
|
||||
self.handlers: List[Dict[str, Any]] = []
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _run_handler(
|
||||
self,
|
||||
func: Callable,
|
||||
bot: "Bot",
|
||||
event: Any,
|
||||
args: Optional[List[str]] = None,
|
||||
permission_granted: Optional[bool] = None
|
||||
):
|
||||
"""
|
||||
智能执行事件处理器,并注入所需参数
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
params = sig.parameters
|
||||
kwargs: Dict[str, Any] = {}
|
||||
|
||||
if "bot" in params:
|
||||
kwargs["bot"] = bot
|
||||
if "event" in params:
|
||||
kwargs["event"] = event
|
||||
if "args" in params and args is not None:
|
||||
kwargs["args"] = args
|
||||
if "permission_granted" in params and permission_granted is not None:
|
||||
kwargs["permission_granted"] = permission_granted
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
result = await func(**kwargs)
|
||||
else:
|
||||
# 如果是同步函数,则放入线程池执行
|
||||
result = await run_in_thread_pool(func, **kwargs)
|
||||
return result is True
|
||||
|
||||
|
||||
class MessageHandler(BaseHandler):
|
||||
"""
|
||||
消息事件处理器
|
||||
"""
|
||||
def __init__(self, prefixes: Tuple[str, ...]):
|
||||
super().__init__()
|
||||
self.prefixes = prefixes
|
||||
self.commands: Dict[str, Dict] = {}
|
||||
self.message_handlers: List[Dict[str, Any]] = []
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
清空所有已注册的消息和命令处理器
|
||||
"""
|
||||
self.commands.clear()
|
||||
self.message_handlers.clear()
|
||||
|
||||
def unregister_by_plugin_name(self, plugin_name: str):
|
||||
"""
|
||||
根据插件名卸载相关的消息和命令处理器
|
||||
"""
|
||||
# 卸载命令
|
||||
commands_to_remove = [name for name, info in self.commands.items() if info["plugin_name"] == plugin_name]
|
||||
for name in commands_to_remove:
|
||||
del self.commands[name]
|
||||
|
||||
# 卸载通用消息处理器
|
||||
self.message_handlers = [h for h in self.message_handlers if h["plugin_name"] != plugin_name]
|
||||
|
||||
def on_message(self) -> Callable:
|
||||
"""
|
||||
注册通用消息处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
self.message_handlers.append({"func": func, "plugin_name": plugin_name})
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def command(
|
||||
self,
|
||||
*names: str,
|
||||
permission: Optional[Permission] = None,
|
||||
override_permission_check: bool = False
|
||||
) -> Callable:
|
||||
"""
|
||||
注册命令处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
for name in names:
|
||||
self.commands[name] = {
|
||||
"func": func,
|
||||
"permission": permission,
|
||||
"override_permission_check": override_permission_check,
|
||||
"plugin_name": plugin_name,
|
||||
}
|
||||
return func
|
||||
return decorator
|
||||
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理消息事件,分发给命令处理器或通用消息处理器
|
||||
"""
|
||||
# 原子化地增加接收消息总数
|
||||
from ..managers.redis_manager import redis_manager
|
||||
from ..utils.logger import logger
|
||||
try:
|
||||
lua_script = "return redis.call('INCR', KEYS[1])"
|
||||
await redis_manager.execute_lua_script(
|
||||
script=lua_script,
|
||||
keys=["neobot:stats:messages_received"],
|
||||
args=[]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"接收消息计数失败: {e}")
|
||||
|
||||
from ..managers import permission_manager
|
||||
for handler_info in self.message_handlers:
|
||||
consumed = await self._run_handler(handler_info["func"], bot, event)
|
||||
if consumed:
|
||||
return
|
||||
|
||||
if not event.raw_message:
|
||||
return
|
||||
|
||||
raw_text = event.raw_message.strip()
|
||||
prefix_found = next((p for p in self.prefixes if raw_text.startswith(p)), None)
|
||||
|
||||
if not prefix_found:
|
||||
return
|
||||
|
||||
command_parts = raw_text[len(prefix_found):].split()
|
||||
if not command_parts:
|
||||
return
|
||||
|
||||
command_name = command_parts[0]
|
||||
args = command_parts[1:]
|
||||
|
||||
if command_name in self.commands:
|
||||
command_info = self.commands[command_name]
|
||||
func = command_info["func"]
|
||||
permission = command_info.get("permission")
|
||||
override_check = command_info.get("override_permission_check", False)
|
||||
|
||||
permission_granted = True
|
||||
if permission:
|
||||
permission_granted = await permission_manager.check_permission(event.user_id, permission)
|
||||
|
||||
if not permission_granted and not override_check:
|
||||
permission_name = permission.name if isinstance(permission, Permission) else permission
|
||||
message_template = global_config.bot.permission_denied_message
|
||||
await bot.send(event, message_template.format(permission_name=permission_name))
|
||||
return
|
||||
|
||||
# 在执行指令前,原子化地增加指令调用次数
|
||||
from ..managers.redis_manager import redis_manager
|
||||
from ..utils.logger import logger
|
||||
try:
|
||||
lua_script = "return redis.call('HINCRBY', KEYS[1], ARGV[1], 1)"
|
||||
await redis_manager.execute_lua_script(
|
||||
script=lua_script,
|
||||
keys=["neobot:command_stats"],
|
||||
args=[command_name]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"指令 /{command_name} 调用次数统计失败: {e}")
|
||||
|
||||
await self._run_handler(
|
||||
func,
|
||||
bot,
|
||||
event,
|
||||
args=args,
|
||||
permission_granted=permission_granted
|
||||
)
|
||||
|
||||
|
||||
class NoticeHandler(BaseHandler):
|
||||
"""
|
||||
通知事件处理器
|
||||
"""
|
||||
def clear(self):
|
||||
self.handlers.clear()
|
||||
|
||||
def unregister_by_plugin_name(self, plugin_name: str):
|
||||
"""
|
||||
根据插件名卸载相关的通知处理器
|
||||
"""
|
||||
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
|
||||
|
||||
def register(self, notice_type: Optional[str] = None) -> Callable:
|
||||
"""
|
||||
注册通知处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
self.handlers.append({"type": notice_type, "func": func, "plugin_name": plugin_name})
|
||||
return func
|
||||
return decorator
|
||||
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理通知事件
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
if handler["type"] is None or handler["type"] == event.notice_type:
|
||||
await self._run_handler(handler["func"], bot, event)
|
||||
|
||||
|
||||
class RequestHandler(BaseHandler):
|
||||
"""
|
||||
请求事件处理器
|
||||
"""
|
||||
def clear(self):
|
||||
self.handlers.clear()
|
||||
|
||||
def unregister_by_plugin_name(self, plugin_name: str):
|
||||
"""
|
||||
根据插件名卸载相关的请求处理器
|
||||
"""
|
||||
self.handlers = [h for h in self.handlers if h["plugin_name"] != plugin_name]
|
||||
|
||||
def register(self, request_type: Optional[str] = None) -> Callable:
|
||||
"""
|
||||
注册请求处理器
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
module = inspect.getmodule(func)
|
||||
plugin_name = module.__name__ if module else "Unknown"
|
||||
self.handlers.append({"type": request_type, "func": func, "plugin_name": plugin_name})
|
||||
return func
|
||||
return decorator
|
||||
|
||||
async def handle(self, bot: "Bot", event: Any):
|
||||
"""
|
||||
处理请求事件
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
if handler["type"] is None or handler["type"] == event.request_type:
|
||||
await self._run_handler(handler["func"], bot, event)
|
||||
31
src/neobot/core/managers/__init__.py
Normal file
31
src/neobot/core/managers/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
NEO Bot Managers Package
|
||||
|
||||
管理器模块,包含各种功能管理器。
|
||||
"""
|
||||
|
||||
from .bot_manager import bot_manager
|
||||
from .browser_manager import browser_manager
|
||||
from .command_manager import command_manager
|
||||
from .image_manager import image_manager
|
||||
from .mysql_manager import mysql_manager
|
||||
from .permission_manager import permission_manager
|
||||
from .plugin_manager import plugin_manager
|
||||
from .redis_manager import redis_manager
|
||||
from .reverse_ws_manager import reverse_ws_manager
|
||||
from .thread_manager import thread_manager
|
||||
from .vectordb_manager import vectordb_manager
|
||||
|
||||
__all__ = [
|
||||
"bot_manager",
|
||||
"browser_manager",
|
||||
"command_manager",
|
||||
"image_manager",
|
||||
"mysql_manager",
|
||||
"permission_manager",
|
||||
"plugin_manager",
|
||||
"redis_manager",
|
||||
"reverse_ws_manager",
|
||||
"thread_manager",
|
||||
"vectordb_manager",
|
||||
]
|
||||
57
src/neobot/core/managers/bot_manager.py
Normal file
57
src/neobot/core/managers/bot_manager.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
import threading
|
||||
from ..utils.logger import ModuleLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bot import Bot
|
||||
|
||||
class BotManager:
|
||||
"""
|
||||
Bot 实例管理器
|
||||
|
||||
负责统一管理所有活跃的 Bot 实例(包括正向 WS 和反向 WS 连接的 Bot)。
|
||||
提供注册、注销和获取 Bot 实例的方法。
|
||||
"""
|
||||
def __init__(self):
|
||||
self._bots: Dict[str, "Bot"] = {} # type: ignore[assignment] # key: bot_id (str), value: Bot instance
|
||||
self._lock = threading.RLock()
|
||||
self.logger = ModuleLogger("BotManager")
|
||||
|
||||
def register_bot(self, bot: "Bot") -> None:
|
||||
"""
|
||||
注册一个 Bot 实例
|
||||
"""
|
||||
if not bot or not bot.self_id:
|
||||
self.logger.warning("尝试注册无效的 Bot 实例")
|
||||
return
|
||||
|
||||
bot_id = str(bot.self_id)
|
||||
with self._lock:
|
||||
self._bots[bot_id] = bot
|
||||
self.logger.info(f"Bot 实例已注册: {bot_id}")
|
||||
|
||||
def unregister_bot(self, bot_id: str) -> None:
|
||||
"""
|
||||
注销一个 Bot 实例
|
||||
"""
|
||||
with self._lock:
|
||||
if bot_id in self._bots:
|
||||
del self._bots[bot_id]
|
||||
self.logger.info(f"Bot 实例已注销: {bot_id}")
|
||||
|
||||
def get_bot(self, bot_id: str) -> Optional["Bot"]:
|
||||
"""
|
||||
根据 ID 获取 Bot 实例
|
||||
"""
|
||||
with self._lock:
|
||||
return self._bots.get(str(bot_id))
|
||||
|
||||
def get_all_bots(self) -> List["Bot"]:
|
||||
"""
|
||||
获取所有活跃的 Bot 实例
|
||||
"""
|
||||
with self._lock:
|
||||
return list(self._bots.values())
|
||||
|
||||
# 全局单例实例
|
||||
bot_manager = BotManager()
|
||||
153
src/neobot/core/managers/browser_manager.py
Normal file
153
src/neobot/core/managers/browser_manager.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
浏览器管理器模块
|
||||
|
||||
负责管理全局唯一的 Playwright 浏览器实例,避免频繁启动/关闭浏览器的开销。
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from playwright.async_api import async_playwright, Browser, Playwright, Page
|
||||
from ..utils.logger import logger
|
||||
from ..utils.singleton import Singleton
|
||||
|
||||
class BrowserManager(Singleton):
|
||||
"""
|
||||
浏览器管理器(异步单例)
|
||||
"""
|
||||
_playwright: Optional[Playwright] = None
|
||||
_browser: Optional[Browser] = None
|
||||
_page_pool: Optional[asyncio.Queue] = None
|
||||
_pool_size: int = 3
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化浏览器管理器
|
||||
"""
|
||||
# 调用父类 __init__ 确保单例初始化
|
||||
super().__init__()
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化 Playwright 和 Browser
|
||||
"""
|
||||
if self._browser is None:
|
||||
try:
|
||||
logger.info("正在启动无头浏览器...")
|
||||
self._playwright = await async_playwright().start()
|
||||
# 启动 Chromium,headless=True 表示无头模式
|
||||
self._browser = await self._playwright.chromium.launch(headless=True)
|
||||
logger.success("无头浏览器启动成功!")
|
||||
except Exception as e:
|
||||
logger.exception(f"无头浏览器启动失败: {e}")
|
||||
self._browser = None
|
||||
|
||||
async def init_pool(self, size: int = 3):
|
||||
"""
|
||||
初始化页面池
|
||||
"""
|
||||
if not self._browser:
|
||||
await self.initialize()
|
||||
|
||||
if not self._browser:
|
||||
logger.error("浏览器初始化失败,无法创建页面池")
|
||||
return
|
||||
|
||||
self._pool_size = size
|
||||
self._page_pool = asyncio.Queue(maxsize=size)
|
||||
|
||||
logger.info(f"正在初始化页面池 (大小: {size})...")
|
||||
for i in range(size):
|
||||
try:
|
||||
page = await self._browser.new_page()
|
||||
await self._page_pool.put(page)
|
||||
except Exception as e:
|
||||
logger.error(f"创建页面池页面 {i+1} 失败: {e}")
|
||||
|
||||
logger.success(f"页面池初始化完成,当前可用页面: {self._page_pool.qsize()}")
|
||||
|
||||
async def get_page(self) -> Optional[Page]:
|
||||
"""
|
||||
从池中获取一个页面。如果池未初始化或为空,则尝试创建一个新页面(不入池)。
|
||||
"""
|
||||
if self._page_pool and not self._page_pool.empty():
|
||||
try:
|
||||
page = self._page_pool.get_nowait()
|
||||
# 简单的健康检查
|
||||
if page.is_closed():
|
||||
logger.warning("检测到池中页面已关闭,重新创建一个...")
|
||||
if self._browser:
|
||||
page = await self._browser.new_page()
|
||||
else:
|
||||
return None
|
||||
return page
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
|
||||
# 如果池空了或者没初始化,回退到临时创建
|
||||
logger.debug("页面池为空或未初始化,创建临时页面")
|
||||
return await self.get_new_page()
|
||||
|
||||
async def release_page(self, page: Page):
|
||||
"""
|
||||
归还页面到池中。如果池已满或未初始化,则关闭页面。
|
||||
"""
|
||||
if not page or page.is_closed():
|
||||
return
|
||||
|
||||
if self._page_pool:
|
||||
try:
|
||||
# 重置页面状态 (例如清空内容),防止数据污染
|
||||
# 注意: goto('about:blank') 比 close() 快得多
|
||||
await page.goto("about:blank")
|
||||
|
||||
self._page_pool.put_nowait(page)
|
||||
return
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
# 池满或未启用池,直接关闭
|
||||
await page.close()
|
||||
|
||||
async def get_new_page(self) -> Optional[Page]:
|
||||
"""
|
||||
获取一个新的页面 (Page)
|
||||
|
||||
使用完毕后,调用者应该负责关闭该页面 (await page.close())
|
||||
"""
|
||||
if self._browser is None:
|
||||
logger.warning("浏览器尚未初始化,尝试重新初始化...")
|
||||
await self.initialize()
|
||||
|
||||
if self._browser:
|
||||
try:
|
||||
return await self._browser.new_page()
|
||||
except Exception as e:
|
||||
logger.error(f"创建新页面失败: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
async def shutdown(self):
|
||||
"""
|
||||
关闭浏览器和 Playwright
|
||||
"""
|
||||
# 清空页面池
|
||||
if self._page_pool:
|
||||
while not self._page_pool.empty():
|
||||
try:
|
||||
page = self._page_pool.get_nowait()
|
||||
await page.close()
|
||||
except (asyncio.QueueEmpty, AttributeError):
|
||||
pass
|
||||
self._page_pool = None
|
||||
|
||||
if self._browser:
|
||||
await self._browser.close()
|
||||
self._browser = None
|
||||
logger.info("浏览器已关闭")
|
||||
|
||||
if self._playwright:
|
||||
await self._playwright.stop()
|
||||
self._playwright = None
|
||||
logger.info("Playwright 已停止")
|
||||
|
||||
# 全局浏览器管理器实例
|
||||
browser_manager = BrowserManager()
|
||||
233
src/neobot/core/managers/command_manager.py
Normal file
233
src/neobot/core/managers/command_manager.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
命令与事件管理器模块
|
||||
|
||||
该模块定义了 `CommandManager` 类,它是整个机器人框架事件处理的核心。
|
||||
它通过装饰器模式,为插件提供了注册消息指令、通知事件处理器和
|
||||
请求事件处理器的能力。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
from neobot.models.events.message import MessageSegment
|
||||
|
||||
|
||||
|
||||
from ..config_loader import global_config
|
||||
from ..handlers.event_handler import MessageHandler, NoticeHandler, RequestHandler
|
||||
from .redis_manager import redis_manager
|
||||
from .image_manager import image_manager
|
||||
from ..utils.logger import logger
|
||||
|
||||
# 从配置中获取命令前缀
|
||||
_config_prefixes = global_config.bot.command
|
||||
|
||||
# 确保前缀配置是元组格式
|
||||
_final_prefixes: Tuple[str, ...]
|
||||
if isinstance(_config_prefixes, list):
|
||||
_final_prefixes = tuple(_config_prefixes)
|
||||
elif isinstance(_config_prefixes, str):
|
||||
_final_prefixes = (_config_prefixes,)
|
||||
else:
|
||||
_final_prefixes = tuple(_config_prefixes)
|
||||
|
||||
|
||||
class CommandManager:
|
||||
"""
|
||||
命令管理器,负责注册和分发所有类型的事件。
|
||||
|
||||
这是一个单例对象(`matcher`),在整个应用中共享。
|
||||
它将不同类型的事件处理委托给专门的处理器类。
|
||||
"""
|
||||
|
||||
def __init__(self, prefixes: Tuple[str, ...]):
|
||||
"""
|
||||
初始化命令管理器。
|
||||
|
||||
Args:
|
||||
prefixes (Tuple[str, ...]): 一个包含所有合法命令前缀的元组。
|
||||
"""
|
||||
self.plugins: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# 初始化专门的事件处理器
|
||||
self.message_handler = MessageHandler(prefixes)
|
||||
self.notice_handler = NoticeHandler()
|
||||
self.request_handler = RequestHandler()
|
||||
|
||||
# 将处理器映射到事件类型
|
||||
self.handler_map = {
|
||||
"message": self.message_handler,
|
||||
"notice": self.notice_handler,
|
||||
"request": self.request_handler,
|
||||
}
|
||||
|
||||
# 注册内置的 /help 命令
|
||||
self._register_internal_commands()
|
||||
|
||||
async def sync_help_pic(self):
|
||||
"""
|
||||
启动时或插件重载时同步 help 图片到 Redis
|
||||
"""
|
||||
try:
|
||||
logger.info("正在生成帮助图片...")
|
||||
|
||||
# 1. 收集插件数据
|
||||
plugins_data = []
|
||||
for plugin_name, meta in self.plugins.items():
|
||||
plugins_data.append({
|
||||
"name": meta.get("name", plugin_name),
|
||||
"description": meta.get("description", "暂无描述"),
|
||||
"usage": meta.get("usage", "暂无用法")
|
||||
})
|
||||
|
||||
# 2. 渲染图片
|
||||
# 使用 png 格式以获得更好的文字清晰度
|
||||
base64_str = await image_manager.render_template_to_base64(
|
||||
template_name="help.html",
|
||||
data={"plugins": plugins_data},
|
||||
output_name="help_menu.png",
|
||||
image_type="png"
|
||||
)
|
||||
|
||||
if base64_str:
|
||||
await redis_manager.set("neobot:core:help_pic", base64_str)
|
||||
logger.success("帮助图片已更新并缓存到 Redis")
|
||||
else:
|
||||
logger.error("帮助图片生成失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"同步帮助图片失败: {e}")
|
||||
|
||||
def _register_internal_commands(self):
|
||||
"""
|
||||
注册框架内置的命令
|
||||
"""
|
||||
# Help 命令
|
||||
self.message_handler.command("help")(self._help_command)
|
||||
self.plugins["core.help"] = {
|
||||
"name": "帮助",
|
||||
"description": "显示所有可用指令的帮助信息",
|
||||
"usage": "/help",
|
||||
}
|
||||
|
||||
def clear_all_handlers(self):
|
||||
"""
|
||||
清空所有已注册的事件处理器。
|
||||
注意:这也会移除内置的 /help 命令,因此需要重新注册。
|
||||
"""
|
||||
self.message_handler.clear()
|
||||
self.notice_handler.clear()
|
||||
self.request_handler.clear()
|
||||
self.plugins.clear()
|
||||
|
||||
# 清空后,需要重新注册内置命令
|
||||
self._register_internal_commands()
|
||||
|
||||
def unload_plugin(self, plugin_name: str):
|
||||
"""
|
||||
卸载指定插件的所有处理器和命令。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件的模块名 (例如 'plugins.bili_parser')
|
||||
"""
|
||||
self.message_handler.unregister_by_plugin_name(plugin_name)
|
||||
self.notice_handler.unregister_by_plugin_name(plugin_name)
|
||||
self.request_handler.unregister_by_plugin_name(plugin_name)
|
||||
|
||||
# 移除插件元信息
|
||||
plugins_to_remove = [name for name in self.plugins if name == plugin_name]
|
||||
for name in plugins_to_remove:
|
||||
del self.plugins[name]
|
||||
|
||||
# --- 装饰器代理 ---
|
||||
|
||||
def on_message(self) -> Callable:
|
||||
"""
|
||||
装饰器:注册一个通用的消息处理器。
|
||||
"""
|
||||
return self.message_handler.on_message()
|
||||
|
||||
def command(
|
||||
self,
|
||||
*names: str,
|
||||
permission: Optional[Any] = None,
|
||||
override_permission_check: bool = False,
|
||||
) -> Callable:
|
||||
"""
|
||||
装饰器:注册一个消息指令处理器。
|
||||
"""
|
||||
return self.message_handler.command(
|
||||
*names,
|
||||
permission=permission,
|
||||
override_permission_check=override_permission_check,
|
||||
)
|
||||
|
||||
def on_notice(self, notice_type: Optional[str] = None) -> Callable:
|
||||
"""
|
||||
装饰器:注册一个通知事件处理器。
|
||||
"""
|
||||
return self.notice_handler.register(notice_type=notice_type)
|
||||
|
||||
def on_request(self, request_type: Optional[str] = None) -> Callable:
|
||||
"""
|
||||
装饰器:注册一个请求事件处理器。
|
||||
"""
|
||||
return self.request_handler.register(request_type=request_type)
|
||||
|
||||
# --- 事件处理 ---
|
||||
|
||||
async def handle_event(self, bot, event):
|
||||
"""
|
||||
统一的事件分发入口。
|
||||
|
||||
根据事件的 `post_type` 将其分发给对应的处理器。
|
||||
"""
|
||||
if event.post_type == "message" and global_config.bot.ignore_self_message:
|
||||
if (
|
||||
hasattr(event, "user_id")
|
||||
and hasattr(event, "self_id")
|
||||
and event.user_id == event.self_id
|
||||
):
|
||||
return
|
||||
|
||||
handler = self.handler_map.get(event.post_type)
|
||||
if handler:
|
||||
await handler.handle(bot, event)
|
||||
|
||||
# --- 内置命令实现 ---
|
||||
|
||||
async def _help_command(self, bot, event):
|
||||
"""
|
||||
内置的 `/help` 命令的实现。
|
||||
直接从 Redis 获取缓存的图片。
|
||||
"""
|
||||
try:
|
||||
# 1. 尝试从 Redis 获取
|
||||
help_pic = await redis_manager.get("neobot:core:help_pic")
|
||||
|
||||
if not help_pic:
|
||||
await bot.send(event, "帮助图片缓存缺失,正在重新生成...")
|
||||
await self.sync_help_pic()
|
||||
help_pic = await redis_manager.get("neobot:core:help_pic")
|
||||
|
||||
if help_pic:
|
||||
await bot.send(event, MessageSegment.image(help_pic))
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"获取或生成帮助图片失败: {e}")
|
||||
|
||||
# 2. 最后的兜底:发送纯文本
|
||||
help_text = "--- 可用指令列表 ---\n"
|
||||
for plugin_name, meta in self.plugins.items():
|
||||
name = meta.get("name", "未命名插件")
|
||||
description = meta.get("description", "暂无描述")
|
||||
usage = meta.get("usage", "暂无用法说明")
|
||||
|
||||
help_text += f"\n{name}:\n"
|
||||
help_text += f" 功能: {description}\n"
|
||||
help_text += f" 用法: {usage}\n"
|
||||
|
||||
await bot.send(event, help_text.strip())
|
||||
|
||||
|
||||
# 实例化全局唯一的命令管理器
|
||||
matcher = CommandManager(prefixes=_final_prefixes)
|
||||
140
src/neobot/core/managers/image_manager.py
Normal file
140
src/neobot/core/managers/image_manager.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
图片生成管理器模块
|
||||
|
||||
负责管理图片生成相关的逻辑,支持多种渲染引擎(目前支持 Playwright)。
|
||||
"""
|
||||
import os
|
||||
import base64
|
||||
import tempfile
|
||||
from typing import Dict, Any, Optional
|
||||
from jinja2 import Template
|
||||
|
||||
from .browser_manager import browser_manager
|
||||
from ..utils.logger import logger
|
||||
from ..utils.singleton import Singleton
|
||||
from ..config_loader import global_config
|
||||
|
||||
class ImageManager(Singleton):
|
||||
"""
|
||||
图片生成管理器(单例)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化图片生成管理器
|
||||
"""
|
||||
# 检查是否已经初始化
|
||||
if hasattr(self, 'template_dir'):
|
||||
return
|
||||
|
||||
# 模板目录
|
||||
self.template_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "templates")
|
||||
# 临时文件目录 - 使用系统临时目录
|
||||
self.temp_dir = os.path.join(tempfile.gettempdir(), "neobot_images")
|
||||
os.makedirs(self.temp_dir, exist_ok=True)
|
||||
# 模板缓存
|
||||
self._template_cache: Dict[str, Template] = {}
|
||||
|
||||
async def render_template(self, template_name: str, data: Dict[str, Any], output_name: str = "output.png", quality: int = 80, image_type: str = "png", width: int = 1920, height: int = 1080) -> Optional[str]:
|
||||
"""
|
||||
使用 Playwright 渲染 Jinja2 模板并保存为图片文件
|
||||
|
||||
Args:
|
||||
template_name (str): 模板文件名 (例如 "help.html")
|
||||
data (Dict[str, Any]): 传递给模板的数据字典
|
||||
output_name (str, optional): 输出文件名. Defaults to "output.png".
|
||||
quality (int, optional): JPEG 质量 (0-100). 仅在 image_type 为 jpeg 时有效. Defaults to 80.
|
||||
image_type (str, optional): 图片类型 ('png' or 'jpeg'). Defaults to "png".
|
||||
width (int, optional): 图片宽度. Defaults to 1920.
|
||||
height (int, optional): 图片高度. Defaults to 1080.
|
||||
|
||||
Returns:
|
||||
Optional[str]: 生成图片的绝对路径,如果失败则返回 None
|
||||
"""
|
||||
template_path = os.path.join(self.template_dir, template_name)
|
||||
if not os.path.exists(template_path):
|
||||
logger.error(f"模板文件未找到: {template_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1. 渲染 HTML (使用缓存)
|
||||
if template_name in self._template_cache:
|
||||
template = self._template_cache[template_name]
|
||||
else:
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
template_str = f.read()
|
||||
template = Template(template_str)
|
||||
self._template_cache[template_name] = template
|
||||
|
||||
html_content = template.render(**data)
|
||||
|
||||
# 2. 使用浏览器截图
|
||||
# 改为从池中获取页面
|
||||
page = await browser_manager.get_page()
|
||||
if not page:
|
||||
logger.error("无法获取浏览器页面")
|
||||
return None
|
||||
|
||||
try:
|
||||
width = data.get("width", width)
|
||||
height = data.get("height", height)
|
||||
await page.set_viewport_size({"width": width, "height": height})
|
||||
|
||||
# 加载内容
|
||||
await page.set_content(html_content)
|
||||
await page.wait_for_selector("body")
|
||||
|
||||
|
||||
screenshot_args = {
|
||||
'full_page': True,
|
||||
'type': image_type,
|
||||
'omit_background': False,
|
||||
'scale': 'css'
|
||||
}
|
||||
if image_type == 'jpeg':
|
||||
screenshot_args['quality'] = quality
|
||||
|
||||
screenshot_bytes = await page.screenshot(**screenshot_args) # type: ignore
|
||||
|
||||
finally:
|
||||
# 归还页面到池中,而不是直接关闭
|
||||
await browser_manager.release_page(page)
|
||||
|
||||
# 3. 保存文件
|
||||
output_path = os.path.join(self.temp_dir, output_name)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(screenshot_bytes)
|
||||
|
||||
logger.info(f"图片已生成: {output_path} ({len(screenshot_bytes)/1024:.2f} KB)")
|
||||
return os.path.abspath(output_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"渲染模板 {template_name} 失败: {e}")
|
||||
return None
|
||||
|
||||
async def render_template_to_base64(self, template_name: str, data: Dict[str, Any], output_name: str = "output.png", quality: int = 80, image_type: str = "png", width: int = 1920, height: int = 1080) -> Optional[str]:
|
||||
"""
|
||||
渲染模板并返回 Base64 编码的图片字符串
|
||||
"""
|
||||
file_path = await self.render_template(template_name, data, output_name, quality, image_type, width=width, height=height)
|
||||
if not file_path:
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
mime_type = "image/jpeg" if image_type == "jpeg" else "image/png"
|
||||
base64_str = base64.b64encode(content).decode("utf-8")
|
||||
|
||||
# 记录摘要日志,避免刷屏
|
||||
log_message = f"Base64 图片已生成 (MIME: {mime_type}, Size: {len(base64_str)/1024:.2f} KB, Preview: {base64_str[:30]}...{base64_str[-30:]})"
|
||||
logger.debug(log_message)
|
||||
|
||||
return f"data:{mime_type};base64," + base64_str
|
||||
except Exception as e:
|
||||
logger.error(f"读取图片文件失败: {e}")
|
||||
return None
|
||||
|
||||
# 全局图片管理器实例
|
||||
image_manager = ImageManager()
|
||||
148
src/neobot/core/managers/mysql_manager.py
Normal file
148
src/neobot/core/managers/mysql_manager.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import aiomysql
|
||||
from ..config_loader import global_config as config
|
||||
from ..utils.logger import logger
|
||||
from ..utils.singleton import Singleton
|
||||
|
||||
|
||||
class MySQLManager(Singleton):
|
||||
"""
|
||||
MySQL 数据库连接管理器(异步单例)
|
||||
"""
|
||||
_pool = None
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化 MySQL 管理器
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
异步初始化 MySQL 连接池并进行健康检查
|
||||
"""
|
||||
if self._pool is None:
|
||||
try:
|
||||
mysql_config = config.mysql
|
||||
host = mysql_config.host
|
||||
port = mysql_config.port
|
||||
user = mysql_config.user
|
||||
password = mysql_config.password
|
||||
db = mysql_config.db
|
||||
charset = mysql_config.charset
|
||||
|
||||
logger.info(f"正在尝试连接 MySQL: {host}:{port}, DB: {db}")
|
||||
|
||||
self._pool = await aiomysql.create_pool(
|
||||
host=host,
|
||||
port=port,
|
||||
user=user,
|
||||
password=password,
|
||||
db=db,
|
||||
charset=charset,
|
||||
autocommit=False,
|
||||
maxsize=10,
|
||||
minsize=1
|
||||
)
|
||||
|
||||
async with self._pool.acquire() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute("SELECT 1")
|
||||
result = await cur.fetchone()
|
||||
if result and result[0] == 1:
|
||||
logger.success("MySQL 连接成功!")
|
||||
else:
|
||||
logger.error("MySQL 连接失败: 健康检查失败")
|
||||
except Exception as e:
|
||||
logger.exception(f"MySQL 初始化时发生未知错误: {e}")
|
||||
self._pool = None
|
||||
|
||||
@property
|
||||
def pool(self):
|
||||
"""
|
||||
获取 MySQL 连接池实例
|
||||
"""
|
||||
if self._pool is None:
|
||||
raise ConnectionError("MySQL 未初始化或连接失败,请先调用 initialize()")
|
||||
return self._pool
|
||||
|
||||
async def execute(self, sql: str, args: tuple = None):
|
||||
"""
|
||||
执行 SQL 语句(用于 INSERT、UPDATE、DELETE)
|
||||
|
||||
Args:
|
||||
sql: SQL 语句
|
||||
args: 参数元组
|
||||
|
||||
Returns:
|
||||
影响的行数
|
||||
"""
|
||||
async with self._pool.acquire() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(sql, args)
|
||||
await conn.commit()
|
||||
return cur.rowcount
|
||||
|
||||
async def fetchone(self, sql: str, args: tuple = None):
|
||||
"""
|
||||
查询单条记录
|
||||
|
||||
Args:
|
||||
sql: SQL 语句
|
||||
args: 参数元组
|
||||
|
||||
Returns:
|
||||
单条记录字典
|
||||
"""
|
||||
async with self._pool.acquire() as conn:
|
||||
async with conn.cursor(aiomysql.DictCursor) as cur:
|
||||
await cur.execute(sql, args)
|
||||
return await cur.fetchone()
|
||||
|
||||
async def fetchall(self, sql: str, args: tuple = None):
|
||||
"""
|
||||
查询多条记录
|
||||
|
||||
Args:
|
||||
sql: SQL 语句
|
||||
args: 参数元组
|
||||
|
||||
Returns:
|
||||
记录列表
|
||||
"""
|
||||
async with self._pool.acquire() as conn:
|
||||
async with conn.cursor(aiomysql.DictCursor) as cur:
|
||||
await cur.execute(sql, args)
|
||||
return await cur.fetchall()
|
||||
|
||||
async def begin_transaction(self):
|
||||
"""
|
||||
开始事务
|
||||
|
||||
Returns:
|
||||
事务连接对象
|
||||
"""
|
||||
conn = await self._pool.acquire()
|
||||
return conn
|
||||
|
||||
async def commit_transaction(self, conn):
|
||||
"""
|
||||
提交事务
|
||||
|
||||
Args:
|
||||
conn: 事务连接对象
|
||||
"""
|
||||
await conn.commit()
|
||||
await self._pool.release(conn)
|
||||
|
||||
async def rollback_transaction(self, conn):
|
||||
"""
|
||||
回滚事务
|
||||
|
||||
Args:
|
||||
conn: 事务连接对象
|
||||
"""
|
||||
await conn.rollback()
|
||||
await self._pool.release(conn)
|
||||
|
||||
|
||||
mysql_manager = MySQLManager()
|
||||
435
src/neobot/core/managers/permission_manager.py
Normal file
435
src/neobot/core/managers/permission_manager.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
权限管理器模块
|
||||
|
||||
该模块负责管理用户权限,支持 admin、op、user 三个权限级别。
|
||||
以 permissions.json 文件作为主要数据源,Redis 用于加速访问。
|
||||
"""
|
||||
import orjson
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, Set
|
||||
|
||||
from ..utils.logger import logger
|
||||
from ..utils.singleton import Singleton
|
||||
from .redis_manager import redis_manager
|
||||
from ..permission import Permission
|
||||
|
||||
|
||||
# 用于从字符串名称查找权限对象的字典
|
||||
_PERMISSIONS: Dict[str, Permission] = {
|
||||
p.value: p for p in Permission
|
||||
}
|
||||
|
||||
|
||||
class PermissionManager(Singleton):
|
||||
"""
|
||||
权限管理器类
|
||||
|
||||
以 permissions.json 文件作为权限数据的主要来源,Redis 用于高速缓存访问。
|
||||
所有写操作会同时更新文件和Redis缓存,确保数据一致性。
|
||||
"""
|
||||
_REDIS_KEY = "neobot:permissions" # 用于存储用户权限的 Redis Hash 键
|
||||
_REDIS_ADMINS_KEY = "neobot:admins" # 用于存储管理员列表的 Redis 键
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化权限管理器
|
||||
"""
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
|
||||
# 权限数据文件路径,作为主要数据源
|
||||
self.data_file = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"..",
|
||||
"data",
|
||||
"permissions.json"
|
||||
)
|
||||
|
||||
os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
|
||||
|
||||
# 如果文件不存在,创建默认文件
|
||||
if not os.path.exists(self.data_file):
|
||||
default_data = {"users": {}}
|
||||
with open(self.data_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(default_data, indent=2, ensure_ascii=False))
|
||||
logger.info(f"已创建默认权限文件: {self.data_file}")
|
||||
|
||||
logger.info("权限管理器初始化完成")
|
||||
super().__init__()
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
异步初始化,以 permissions.json 文件内容为主,同步到 Redis 缓存
|
||||
"""
|
||||
try:
|
||||
# 总是以文件内容为主,强制同步到 Redis
|
||||
logger.info("以 permissions.json 文件内容为准,同步到 Redis 缓存...")
|
||||
await self._sync_file_to_redis()
|
||||
|
||||
# 检查 Redis 中的数据量
|
||||
perm_count = await redis_manager.redis.hlen(self._REDIS_KEY)
|
||||
admin_count = await redis_manager.redis.scard(self._REDIS_ADMINS_KEY)
|
||||
logger.info(f"Redis 缓存已同步,权限数据: {perm_count} 条,管理员: {admin_count} 位。")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化权限数据时发生错误: {e}")
|
||||
|
||||
async def _sync_file_to_redis(self):
|
||||
"""
|
||||
将 permissions.json 文件内容同步到 Redis 缓存
|
||||
"""
|
||||
try:
|
||||
# 清空 Redis 中的现有数据
|
||||
await redis_manager.redis.delete(self._REDIS_KEY)
|
||||
await redis_manager.redis.delete(self._REDIS_ADMINS_KEY)
|
||||
|
||||
# 从文件加载数据
|
||||
if os.path.exists(self.data_file):
|
||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
users = data.get("users", {})
|
||||
|
||||
if users:
|
||||
# 分离普通权限和管理员权限
|
||||
normal_perms = {}
|
||||
admin_ids = set()
|
||||
|
||||
for user_id, level_name in users.items():
|
||||
if level_name == Permission.ADMIN.value:
|
||||
admin_ids.add(user_id)
|
||||
else:
|
||||
normal_perms[user_id] = level_name
|
||||
|
||||
# 使用 pipeline 批量写入普通权限
|
||||
if normal_perms:
|
||||
async with redis_manager.redis.pipeline(transaction=True) as pipe:
|
||||
for user_id, level_name in normal_perms.items():
|
||||
pipe.hset(self._REDIS_KEY, user_id, level_name)
|
||||
await pipe.execute()
|
||||
|
||||
# 使用 pipeline 批量写入管理员
|
||||
if admin_ids:
|
||||
await redis_manager.redis.sadd(self._REDIS_ADMINS_KEY, *admin_ids)
|
||||
|
||||
logger.success(f"成功同步 {len(users)} 条权限数据到 Redis (普通权限: {len(normal_perms)}, 管理员: {len(admin_ids)})")
|
||||
else:
|
||||
logger.info("permissions.json 文件中没有权限数据,已清空 Redis 缓存。")
|
||||
else:
|
||||
logger.warning(f"权限文件 {self.data_file} 不存在,已清空 Redis 缓存。")
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"解析 permissions.json 失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"同步文件到 Redis 失败: {e}")
|
||||
|
||||
async def _migrate_from_file_to_redis(self):
|
||||
"""
|
||||
从 permissions.json 加载权限数据并存入 Redis Hash
|
||||
"""
|
||||
perms_to_migrate = {}
|
||||
try:
|
||||
if os.path.exists(self.data_file):
|
||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
perms_to_migrate = data.get("users", {})
|
||||
|
||||
if perms_to_migrate:
|
||||
# 使用 pipeline 批量写入,提高效率
|
||||
async with redis_manager.redis.pipeline(transaction=True) as pipe:
|
||||
for user_id, level_name in perms_to_migrate.items():
|
||||
pipe.hset(self._REDIS_KEY, user_id, level_name)
|
||||
await pipe.execute()
|
||||
logger.success(f"成功从文件迁移 {len(perms_to_migrate)} 条权限数据到 Redis。")
|
||||
else:
|
||||
logger.info("permissions.json 文件为空或不存在,无需迁移。")
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"解析 permissions.json 失败,无法迁移: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移权限数据到 Redis 失败: {e}")
|
||||
|
||||
async def _migrate_admins_from_file_to_redis(self):
|
||||
"""
|
||||
从 permissions.json 加载管理员列表并存入 Redis
|
||||
"""
|
||||
admins_to_migrate = set()
|
||||
try:
|
||||
if os.path.exists(self.data_file):
|
||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
# 从 users 字段中查找权限为 admin 的用户
|
||||
users = data.get("users", {})
|
||||
for user_id, level_name in users.items():
|
||||
if level_name == Permission.ADMIN.value:
|
||||
admins_to_migrate.add(user_id)
|
||||
|
||||
# 同时兼容旧版的 admins 字段(如果存在的话)
|
||||
old_admins = data.get("admins", [])
|
||||
for admin_id in old_admins:
|
||||
admins_to_migrate.add(str(admin_id))
|
||||
|
||||
if admins_to_migrate:
|
||||
await redis_manager.redis.sadd(self._REDIS_ADMINS_KEY, *admins_to_migrate)
|
||||
logger.success(f"成功从文件迁移 {len(admins_to_migrate)} 位管理员到 Redis。")
|
||||
else:
|
||||
logger.info("permissions.json 文件中没有管理员数据,无需迁移。")
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"解析 permissions.json 失败,无法迁移管理员数据: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移管理员数据到 Redis 失败: {e}")
|
||||
|
||||
async def _save_to_file_backup(self):
|
||||
"""
|
||||
将 Redis 中的权限数据和管理员列表完整备份到 permissions.json
|
||||
"""
|
||||
try:
|
||||
all_perms = await redis_manager.redis.hgetall(self._REDIS_KEY)
|
||||
# 由于Redis连接已设置decode_responses=True,所以直接使用字符串
|
||||
users_data = {k: v for k, v in all_perms.items()}
|
||||
|
||||
# 获取Redis中的管理员列表并合并到数据中
|
||||
all_admins = await redis_manager.redis.smembers(self._REDIS_ADMINS_KEY)
|
||||
for admin_id in all_admins:
|
||||
users_data[admin_id] = Permission.ADMIN.value # 管理员拥有最高权限
|
||||
|
||||
with open(self.data_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps({"users": users_data}, indent=2, ensure_ascii=False))
|
||||
logger.debug(f"权限数据已备份到 {self.data_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"备份权限数据到 permissions.json 失败: {e}")
|
||||
|
||||
async def get_user_permission(self, user_id: int) -> Permission:
|
||||
"""
|
||||
获取指定用户的权限对象
|
||||
|
||||
优先检查是否为机器人管理员,然后从 Redis 查询。
|
||||
"""
|
||||
# 检查用户是否为管理员(Redis Set 中的存在性检查)
|
||||
try:
|
||||
if await redis_manager.redis.sismember(self._REDIS_ADMINS_KEY, str(user_id)):
|
||||
return Permission.ADMIN
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 检查管理员权限失败: {e}")
|
||||
|
||||
try:
|
||||
level_name = await redis_manager.redis.hget(self._REDIS_KEY, str(user_id))
|
||||
if level_name:
|
||||
return _PERMISSIONS.get(level_name, Permission.USER)
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 获取用户 {user_id} 权限失败: {e}")
|
||||
|
||||
return Permission.USER
|
||||
|
||||
async def set_user_permission(self, user_id: int, permission: Permission) -> None:
|
||||
"""
|
||||
设置指定用户的权限级别,首先更新文件,然后同步到 Redis 缓存
|
||||
"""
|
||||
if not isinstance(permission, Permission):
|
||||
raise ValueError(f"无效的权限对象: {permission}")
|
||||
|
||||
try:
|
||||
# 首先从文件加载当前数据
|
||||
if os.path.exists(self.data_file):
|
||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
else:
|
||||
data = {"users": {}}
|
||||
|
||||
# 更新权限数据
|
||||
data["users"][str(user_id)] = permission.value
|
||||
|
||||
# 原子性写入文件
|
||||
temp_file = self.data_file + ".tmp"
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(data, indent=2, ensure_ascii=False))
|
||||
os.replace(temp_file, self.data_file) # 原子操作
|
||||
|
||||
# 同步到 Redis
|
||||
await self._sync_file_to_redis()
|
||||
logger.info(f"已设置用户 {user_id} 的权限为 {permission.value},并同步到 Redis")
|
||||
except Exception as e:
|
||||
logger.error(f"设置用户 {user_id} 权限失败: {e}")
|
||||
|
||||
async def remove_user(self, user_id: int) -> None:
|
||||
"""
|
||||
从权限设置中移除指定用户,首先更新文件,然后同步到 Redis 缓存
|
||||
"""
|
||||
try:
|
||||
# 首先从文件加载当前数据
|
||||
if os.path.exists(self.data_file):
|
||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
else:
|
||||
data = {"users": {}}
|
||||
|
||||
# 从权限数据中移除用户
|
||||
user_id_str = str(user_id)
|
||||
if user_id_str in data["users"]:
|
||||
del data["users"][user_id_str]
|
||||
|
||||
# 原子性写入文件
|
||||
temp_file = self.data_file + ".tmp"
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(data, indent=2, ensure_ascii=False))
|
||||
os.replace(temp_file, self.data_file) # 原子操作
|
||||
|
||||
# 同步到 Redis
|
||||
await self._sync_file_to_redis()
|
||||
logger.info(f"已从权限设置中移除用户 {user_id},并同步到 Redis")
|
||||
except Exception as e:
|
||||
logger.error(f"移除用户 {user_id} 权限失败: {e}")
|
||||
|
||||
async def check_permission(self, user_id: int, required_permission: Permission) -> bool:
|
||||
"""
|
||||
检查用户是否具有指定权限级别
|
||||
"""
|
||||
user_permission = await self.get_user_permission(user_id)
|
||||
|
||||
# 增强类型检查,防止将property对象等错误类型传递进来
|
||||
if not isinstance(required_permission, Permission):
|
||||
logger.error(f"权限检查失败:required_permission 不是 Permission 枚举类型,而是 {type(required_permission).__name__}")
|
||||
return False
|
||||
|
||||
return user_permission >= required_permission
|
||||
|
||||
async def get_all_user_permissions(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取所有已配置的用户权限(合并普通权限和管理员)
|
||||
"""
|
||||
permissions = {}
|
||||
try:
|
||||
# 从 Redis 获取基础权限
|
||||
all_perms = await redis_manager.redis.hgetall(self._REDIS_KEY)
|
||||
# 由于Redis连接已设置decode_responses=True,所以直接使用字符串
|
||||
permissions = {k: v for k, v in all_perms.items()}
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 获取所有权限失败: {e}")
|
||||
|
||||
# 获取 Redis 中的管理员列表并添加到权限字典中
|
||||
try:
|
||||
admins = await redis_manager.redis.smembers(self._REDIS_ADMINS_KEY)
|
||||
for admin_id in admins:
|
||||
permissions[str(admin_id)] = Permission.ADMIN.value
|
||||
except Exception as e:
|
||||
logger.error(f"获取管理员列表以合并权限时失败: {e}")
|
||||
|
||||
return permissions
|
||||
|
||||
async def is_admin(self, user_id: int) -> bool:
|
||||
"""
|
||||
检查用户是否为管理员
|
||||
"""
|
||||
try:
|
||||
return await redis_manager.redis.sismember(self._REDIS_ADMINS_KEY, str(user_id))
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 检查管理员权限失败: {e}")
|
||||
return False
|
||||
|
||||
async def add_admin(self, user_id: int) -> bool:
|
||||
"""
|
||||
添加管理员,首先更新文件,然后同步到 Redis 缓存
|
||||
"""
|
||||
try:
|
||||
# 首先从文件加载当前数据
|
||||
if os.path.exists(self.data_file):
|
||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
else:
|
||||
data = {"users": {}}
|
||||
|
||||
user_id_str = str(user_id)
|
||||
# 检查用户是否已经是管理员
|
||||
if data["users"].get(user_id_str) == Permission.ADMIN.value:
|
||||
return False # 用户已经是管理员
|
||||
|
||||
# 更新权限数据为管理员
|
||||
data["users"][user_id_str] = Permission.ADMIN.value
|
||||
|
||||
# 原子性写入文件
|
||||
temp_file = self.data_file + ".tmp"
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(data, indent=2, ensure_ascii=False))
|
||||
os.replace(temp_file, self.data_file) # 原子操作
|
||||
|
||||
# 同步到 Redis
|
||||
await self._sync_file_to_redis()
|
||||
logger.info(f"已添加新管理员 {user_id},并同步到 Redis")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"添加管理员 {user_id} 失败: {e}")
|
||||
return False
|
||||
|
||||
async def remove_admin(self, user_id: int) -> bool:
|
||||
"""
|
||||
从管理员列表中移除用户,首先更新文件,然后同步到 Redis 缓存
|
||||
"""
|
||||
try:
|
||||
# 首先从文件加载当前数据
|
||||
if os.path.exists(self.data_file):
|
||||
with open(self.data_file, "r", encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
else:
|
||||
data = {"users": {}}
|
||||
|
||||
user_id_str = str(user_id)
|
||||
# 检查用户是否是管理员
|
||||
if data["users"].get(user_id_str) != Permission.ADMIN.value:
|
||||
return False # 用户不是管理员
|
||||
|
||||
# 将管理员降级为普通用户(或者可以选择完全移除权限)
|
||||
# 这里我们将其设置为USER权限
|
||||
data["users"][user_id_str] = Permission.USER.value
|
||||
|
||||
# 原子性写入文件
|
||||
temp_file = self.data_file + ".tmp"
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(data, indent=2, ensure_ascii=False))
|
||||
os.replace(temp_file, self.data_file) # 原子操作
|
||||
|
||||
# 同步到 Redis
|
||||
await self._sync_file_to_redis()
|
||||
logger.info(f"已从管理员列表中移除用户 {user_id},并同步到 Redis")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"移除管理员 {user_id} 失败: {e}")
|
||||
return False
|
||||
|
||||
async def get_all_admins(self) -> Set[int]:
|
||||
"""
|
||||
从 Redis 获取所有管理员的集合
|
||||
"""
|
||||
try:
|
||||
admins = await redis_manager.redis.smembers(self._REDIS_ADMINS_KEY)
|
||||
return {int(admin_id) for admin_id in admins}
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 获取所有管理员失败: {e}")
|
||||
return set()
|
||||
|
||||
async def clear_all(self) -> None:
|
||||
"""
|
||||
清空所有权限设置,首先更新文件,然后同步到 Redis 缓存
|
||||
"""
|
||||
try:
|
||||
# 创建空的权限数据
|
||||
empty_data: Dict[str, Dict] = {"users": {}}
|
||||
|
||||
# 原子性写入文件
|
||||
temp_file = self.data_file + ".tmp"
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(empty_data, indent=2, ensure_ascii=False))
|
||||
os.replace(temp_file, self.data_file) # 原子操作
|
||||
|
||||
# 同步到 Redis
|
||||
await self._sync_file_to_redis()
|
||||
logger.info("已清空所有权限设置,并同步到 Redis")
|
||||
except Exception as e:
|
||||
logger.error(f"清空权限数据失败: {e}")
|
||||
|
||||
|
||||
def require_admin(func):
|
||||
"""
|
||||
一个装饰器,用于限制命令只能由管理员执行。
|
||||
"""
|
||||
from functools import wraps
|
||||
from neobot.models.events.message import MessageEvent
|
||||
150
src/neobot/core/managers/plugin_manager.py
Normal file
150
src/neobot/core/managers/plugin_manager.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
插件管理器模块
|
||||
|
||||
负责扫描、加载和管理 `plugins` 目录下的所有插件。
|
||||
"""
|
||||
import importlib
|
||||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
from typing import Set
|
||||
from .command_manager import CommandManager
|
||||
|
||||
from ..utils.exceptions import SyncHandlerError, PluginLoadError, PluginReloadError, PluginNotFoundError
|
||||
from ..utils.logger import logger, ModuleLogger
|
||||
from ..utils.singleton import Singleton
|
||||
|
||||
# 确保logger在模块级别可见
|
||||
__all__ = ['PluginManager', 'logger']
|
||||
|
||||
# 确保logger在模块级别可见
|
||||
__all__ = ['PluginManager', 'logger']
|
||||
|
||||
|
||||
class PluginManager(Singleton):
|
||||
"""
|
||||
插件管理器类
|
||||
"""
|
||||
def __init__(self, command_manager: "CommandManager" | None = None) -> None:
|
||||
"""
|
||||
初始化插件管理器
|
||||
|
||||
:param command_manager: CommandManager的实例
|
||||
"""
|
||||
# 检查是否已经初始化
|
||||
if hasattr(self, '_command_manager'):
|
||||
return
|
||||
|
||||
# 只有首次初始化时才执行
|
||||
if command_manager:
|
||||
self._command_manager = command_manager
|
||||
self.loaded_plugins: Set[str] = set()
|
||||
# 创建模块专用日志记录器
|
||||
self.logger = ModuleLogger("PluginManager")
|
||||
|
||||
@property
|
||||
def command_manager(self):
|
||||
"""
|
||||
获取命令管理器实例
|
||||
"""
|
||||
return self._command_manager
|
||||
|
||||
def load_all_plugins(self) -> None:
|
||||
"""
|
||||
扫描并加载 `plugins` 目录下的所有插件。
|
||||
"""
|
||||
# 使用 pathlib 获取更可靠的路径
|
||||
# 当前文件: core/managers/plugin_manager.py
|
||||
# 目标: plugins/
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# 回退两级到项目根目录 (core/managers -> core -> root)
|
||||
root_dir = os.path.dirname(os.path.dirname(current_dir))
|
||||
plugin_dir = os.path.join(root_dir, "plugins")
|
||||
|
||||
package_name = "plugins"
|
||||
|
||||
if not os.path.exists(plugin_dir):
|
||||
self.logger.error(f"插件目录不存在: {plugin_dir}")
|
||||
return
|
||||
|
||||
self.logger.info(f"正在从 {package_name} 加载插件 (路径: {plugin_dir})...")
|
||||
|
||||
for _, module_name, is_pkg in pkgutil.iter_modules([plugin_dir]):
|
||||
full_module_name = f"{package_name}.{module_name}"
|
||||
|
||||
action = "加载" # 初始化默认值
|
||||
try:
|
||||
if full_module_name in self.loaded_plugins:
|
||||
self.command_manager.unload_plugin(full_module_name)
|
||||
module = importlib.reload(sys.modules[full_module_name])
|
||||
action = "重载"
|
||||
else:
|
||||
module = importlib.import_module(full_module_name)
|
||||
action = "加载"
|
||||
|
||||
if hasattr(module, "__plugin_meta__"):
|
||||
meta = getattr(module, "__plugin_meta__")
|
||||
self.command_manager.plugins[full_module_name] = meta
|
||||
|
||||
self.loaded_plugins.add(full_module_name)
|
||||
|
||||
type_str = "包" if is_pkg else "文件"
|
||||
self.logger.success(f" [{type_str}] 成功{action}: {module_name}")
|
||||
except SyncHandlerError as e:
|
||||
error = PluginLoadError(
|
||||
plugin_name=module_name,
|
||||
message=f"同步处理器错误: {str(e)}",
|
||||
original_error=e
|
||||
)
|
||||
self.logger.error(f" 插件 {module_name} 加载失败: {error.message} (跳过此插件)")
|
||||
self.logger.log_custom_exception(error)
|
||||
except Exception as e:
|
||||
error = PluginLoadError(
|
||||
plugin_name=module_name,
|
||||
message=f"未知错误: {str(e)}",
|
||||
original_error=e
|
||||
)
|
||||
self.logger.exception(f" 加载插件 {module_name} 失败: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
|
||||
def reload_plugin(self, full_module_name: str) -> None:
|
||||
"""
|
||||
精确重载单个插件。
|
||||
"""
|
||||
if full_module_name not in self.loaded_plugins:
|
||||
self.logger.warning(f"尝试重载一个未被加载的插件: {full_module_name},将按首次加载处理。")
|
||||
|
||||
if full_module_name not in sys.modules:
|
||||
reload_error = PluginNotFoundError(
|
||||
plugin_name=full_module_name,
|
||||
message="模块未在sys.modules中找到"
|
||||
)
|
||||
self.logger.error(f"重载失败: {reload_error.message}")
|
||||
self.logger.log_custom_exception(reload_error)
|
||||
return
|
||||
|
||||
try:
|
||||
self.command_manager.unload_plugin(full_module_name)
|
||||
module = importlib.reload(sys.modules[full_module_name])
|
||||
|
||||
if hasattr(module, "__plugin_meta__"):
|
||||
meta = getattr(module, "__plugin_meta__")
|
||||
self.command_manager.plugins[full_module_name] = meta
|
||||
|
||||
self.logger.success(f"插件 {full_module_name} 已成功重载。")
|
||||
except SyncHandlerError as e:
|
||||
error = PluginReloadError(
|
||||
plugin_name=full_module_name,
|
||||
message=f"同步处理器错误: {str(e)}",
|
||||
original_error=e
|
||||
)
|
||||
self.logger.error(f"重载插件 {full_module_name} 失败: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
except Exception as e:
|
||||
error = PluginReloadError(
|
||||
plugin_name=full_module_name,
|
||||
message=f"未知错误: {str(e)}",
|
||||
original_error=e
|
||||
)
|
||||
self.logger.exception(f"重载插件 {full_module_name} 时发生错误: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
93
src/neobot/core/managers/redis_manager.py
Normal file
93
src/neobot/core/managers/redis_manager.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import redis.asyncio as redis
|
||||
from ..config_loader import global_config as config
|
||||
from ..utils.logger import logger
|
||||
from ..utils.singleton import Singleton
|
||||
|
||||
class RedisManager(Singleton):
|
||||
"""
|
||||
Redis 连接管理器(异步单例)
|
||||
"""
|
||||
_redis = None
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化 Redis 管理器
|
||||
"""
|
||||
# 调用父类 __init__ 确保单例初始化
|
||||
super().__init__()
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
异步初始化 Redis 连接并进行健康检查
|
||||
"""
|
||||
if self._redis is None:
|
||||
try:
|
||||
redis_config = config.redis
|
||||
host = redis_config.host
|
||||
port = redis_config.port
|
||||
db = redis_config.db
|
||||
password = redis_config.password
|
||||
|
||||
logger.info(f"正在尝试连接 Redis: {host}:{port}, DB: {db}")
|
||||
|
||||
self._redis = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
ssl=False
|
||||
)
|
||||
if await self._redis.ping():
|
||||
logger.success("Redis 连接成功!")
|
||||
else:
|
||||
logger.error("Redis 连接失败: PING 命令无响应")
|
||||
except Exception as e:
|
||||
logger.exception(f"Redis 初始化时发生未知错误: {e}")
|
||||
self._redis = None
|
||||
|
||||
@property
|
||||
def redis(self):
|
||||
"""
|
||||
获取 Redis 连接实例
|
||||
"""
|
||||
if self._redis is None:
|
||||
raise ConnectionError("Redis 未初始化或连接失败,请先调用 initialize()")
|
||||
return self._redis
|
||||
|
||||
async def get(self, name):
|
||||
"""
|
||||
获取指定键的值
|
||||
"""
|
||||
return await self.redis.get(name)
|
||||
|
||||
async def set(self, name, value, ex=None):
|
||||
"""
|
||||
设置指定键的值
|
||||
"""
|
||||
return await self.redis.set(name, value, ex=ex)
|
||||
|
||||
async def execute_lua_script(self, script: str, keys: list, args: list):
|
||||
"""
|
||||
以原子方式执行 Lua 脚本
|
||||
|
||||
Args:
|
||||
script (str): 要执行的 Lua 脚本字符串
|
||||
keys (list): 脚本中使用的 Redis 键 (KEYS[1], KEYS[2], ...)
|
||||
args (list): 传递给脚本的参数 (ARGV[1], ARGV[2], ...)
|
||||
|
||||
Returns:
|
||||
Any: 脚本的返回值
|
||||
"""
|
||||
try:
|
||||
# redis-py 内部会自动处理脚本的缓存 (EVAL/EVALSHA)
|
||||
lua_script = self.redis.register_script(script)
|
||||
return await lua_script(keys=keys, args=args)
|
||||
except Exception as e:
|
||||
logger.error(f"执行 Lua 脚本失败: {e}")
|
||||
logger.debug(f"脚本内容: {script}")
|
||||
raise
|
||||
|
||||
|
||||
# 全局 Redis 管理器实例
|
||||
redis_manager = RedisManager()
|
||||
685
src/neobot/core/managers/reverse_ws_manager.py
Normal file
685
src/neobot/core/managers/reverse_ws_manager.py
Normal file
@@ -0,0 +1,685 @@
|
||||
"""
|
||||
反向 WebSocket 管理器模块
|
||||
|
||||
该模块提供了反向 WebSocket 服务端功能,允许 OneBot 实现(如 NapCat)
|
||||
主动连接到机器人服务器,而不是由机器人主动连接到 OneBot 实现。
|
||||
"""
|
||||
import asyncio
|
||||
import orjson
|
||||
import websockets
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
from typing import Dict, Any, Optional, Set
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import threading
|
||||
|
||||
from ..utils.logger import ModuleLogger
|
||||
from ..utils.error_codes import ErrorCode, create_error_response
|
||||
from .command_manager import matcher
|
||||
from neobot.models.events.factory import EventFactory
|
||||
from ..bot import Bot
|
||||
from ..ws import ReverseWSClient as _ReverseWSClient
|
||||
|
||||
|
||||
class ReverseWSClient(_ReverseWSClient):
|
||||
"""
|
||||
反向 WebSocket 客户端代理,用于 Bot 实例调用 API。
|
||||
"""
|
||||
def __init__(self, manager: "ReverseWSManager", client_id: str):
|
||||
super().__init__(manager, client_id)
|
||||
self.manager = manager
|
||||
self.client_id = client_id
|
||||
|
||||
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
|
||||
"""
|
||||
通过 ReverseWSManager 调用 API。
|
||||
"""
|
||||
return await self.manager.call_api(action, params, self.client_id)
|
||||
|
||||
|
||||
class ReverseWSManager:
|
||||
"""
|
||||
反向 WebSocket 管理器,作为服务端接收 OneBot 实现的连接。
|
||||
支持多前端负载均衡和防重复发送机制。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化反向 WebSocket 管理器。
|
||||
"""
|
||||
self.server = None
|
||||
self.clients: Dict[str, WebSocketServerProtocol] = {}
|
||||
self.client_self_ids: Dict[str, int] = {}
|
||||
self._pending_requests: Dict[str, asyncio.Future] = {}
|
||||
self._running = False
|
||||
self.logger = ModuleLogger("ReverseWSManager")
|
||||
|
||||
# 负载均衡相关
|
||||
self._active_client_id: Optional[str] = None # 当前活跃的客户端(用于消息发送)
|
||||
self._client_load: Dict[str, int] = {} # 客户端负载计数
|
||||
self._client_health: Dict[str, datetime] = {} # 客户端健康检查时间
|
||||
|
||||
# 防重复发送相关
|
||||
self._processed_events: Dict[str, Dict[str, datetime]] = {} # 每个客户端已处理的事件ID和时间
|
||||
self._event_ttl = 60 # 事件ID保留时间(秒)
|
||||
self._message_locks: Dict[str, asyncio.Lock] = {} # 消息处理锁
|
||||
self._message_lock_times: Dict[str, datetime] = {} # 消息锁创建时间
|
||||
self._lock_ttl = 300 # 锁保留时间(秒)
|
||||
|
||||
# 基于消息内容的防重复(仅用于群聊)
|
||||
self._processed_messages: Dict[str, Dict[str, datetime]] = {} # 每个客户端已处理的消息内容和时间
|
||||
self._message_content_ttl = 5 # 消息内容保留时间(秒)
|
||||
|
||||
# 启动清理任务
|
||||
self._cleanup_task = None
|
||||
|
||||
# Bot实例字典(每个前端独立的Bot实例)
|
||||
self.bots: Dict[str, Bot] = {}
|
||||
|
||||
# 正在处理的事件ID集合(用于防止重复处理)
|
||||
self._processing_events: Dict[str, Set[str]] = {} # client_id: set of event_ids
|
||||
|
||||
# 线程安全锁
|
||||
self._clients_lock = threading.RLock()
|
||||
self._bots_lock = threading.RLock()
|
||||
self._pending_requests_lock = threading.RLock()
|
||||
self._load_lock = threading.RLock()
|
||||
self._health_lock = threading.RLock()
|
||||
self._processed_events_lock = threading.RLock()
|
||||
self._processed_messages_lock = threading.RLock()
|
||||
self._processing_events_lock = threading.RLock()
|
||||
self._message_locks_lock = threading.RLock()
|
||||
self._message_lock_times_lock = threading.RLock()
|
||||
|
||||
async def start(self, host: str = "0.0.0.0", port: int = 3002) -> None:
|
||||
"""
|
||||
启动反向 WebSocket 服务端。
|
||||
|
||||
Args:
|
||||
host: 监听地址,默认为 0.0.0.0
|
||||
port: 监听端口,默认为 3002
|
||||
"""
|
||||
self._running = True
|
||||
self.server = await websockets.serve(
|
||||
self._handle_client,
|
||||
host,
|
||||
port,
|
||||
ping_interval=20,
|
||||
ping_timeout=20
|
||||
)
|
||||
self.logger.success(f"反向 WebSocket 服务端已启动: ws://{host}:{port}")
|
||||
|
||||
# 启动清理任务
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_expired_data())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""
|
||||
停止反向 WebSocket 服务端。
|
||||
"""
|
||||
self._running = False
|
||||
|
||||
# 停止清理任务
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self.server:
|
||||
self.server.close()
|
||||
await self.server.wait_closed()
|
||||
|
||||
for client_id in list(self.clients.keys()):
|
||||
await self._disconnect_client(client_id)
|
||||
|
||||
self.logger.success("反向 WebSocket 服务端已停止")
|
||||
|
||||
async def _handle_client(
|
||||
self,
|
||||
websocket: WebSocketServerProtocol,
|
||||
path: str = None
|
||||
) -> None:
|
||||
"""
|
||||
处理客户端连接。
|
||||
|
||||
Args:
|
||||
websocket: WebSocket 连接对象
|
||||
path: 连接路径
|
||||
"""
|
||||
client_id = str(uuid.uuid4())
|
||||
self.clients[client_id] = websocket
|
||||
self.logger.info(f"新客户端连接: {client_id}")
|
||||
|
||||
try:
|
||||
async for message in websocket:
|
||||
try:
|
||||
data = orjson.loads(message)
|
||||
|
||||
# 处理 API 响应
|
||||
echo_id = data.get("echo")
|
||||
if echo_id and echo_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(echo_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
continue
|
||||
|
||||
# 处理上报事件
|
||||
if "post_type" in data:
|
||||
event_id = data.get('id') or data.get('post_id') or data.get('message_id') or data.get('time')
|
||||
self.logger.debug(f"收到事件: client_id={client_id}, event_id={event_id}, post_type={data.get('post_type')}")
|
||||
asyncio.create_task(self._on_event(client_id, data))
|
||||
|
||||
except orjson.JSONDecodeError as e:
|
||||
self.logger.error(f"JSON 解析失败: {str(e)}")
|
||||
except Exception as e:
|
||||
self.logger.exception(f"处理消息异常: {str(e)}")
|
||||
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
self.logger.info(f"客户端断开连接: {client_id} - {str(e)}")
|
||||
except Exception as e:
|
||||
self.logger.exception(f"客户端异常: {str(e)}")
|
||||
finally:
|
||||
await self._disconnect_client(client_id)
|
||||
|
||||
async def _cleanup_expired_data(self) -> None:
|
||||
"""
|
||||
清理过期的事件ID和消息锁
|
||||
"""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(10) # 每10秒清理一次
|
||||
|
||||
current_time = datetime.now()
|
||||
|
||||
# 清理过期的事件ID(按客户端)
|
||||
with self._processed_events_lock:
|
||||
for client_id, events in list(self._processed_events.items()):
|
||||
expired_events = [
|
||||
event_id for event_id, timestamp in events.items()
|
||||
if (current_time - timestamp).total_seconds() > self._event_ttl
|
||||
]
|
||||
for event_id in expired_events:
|
||||
del events[event_id]
|
||||
if not events:
|
||||
del self._processed_events[client_id]
|
||||
|
||||
# 清理过期的消息锁
|
||||
with self._message_lock_times_lock:
|
||||
expired_locks = [
|
||||
lock_key for lock_key, timestamp in self._message_lock_times.items()
|
||||
if (current_time - timestamp).total_seconds() > self._lock_ttl
|
||||
]
|
||||
for lock_key in expired_locks:
|
||||
with self._message_locks_lock:
|
||||
if lock_key in self._message_locks:
|
||||
del self._message_locks[lock_key]
|
||||
if lock_key in self._message_lock_times:
|
||||
del self._message_lock_times[lock_key]
|
||||
|
||||
# 清理过期的消息内容(按客户端)
|
||||
with self._processed_messages_lock:
|
||||
for client_id, messages in list(self._processed_messages.items()):
|
||||
expired_messages = [
|
||||
msg_key for msg_key, timestamp in messages.items()
|
||||
if (current_time - timestamp).total_seconds() > self._message_content_ttl
|
||||
]
|
||||
for msg_key in expired_messages:
|
||||
del messages[msg_key]
|
||||
if not messages:
|
||||
del self._processed_messages[client_id]
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"清理过期数据失败: {str(e)}")
|
||||
|
||||
async def _disconnect_client(self, client_id: str) -> None:
|
||||
"""
|
||||
断开客户端连接。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
"""
|
||||
with self._clients_lock:
|
||||
if client_id in self.clients:
|
||||
del self.clients[client_id]
|
||||
with self._clients_lock:
|
||||
if client_id in self.client_self_ids:
|
||||
del self.client_self_ids[client_id]
|
||||
with self._load_lock:
|
||||
if client_id in self._client_load:
|
||||
del self._client_load[client_id]
|
||||
with self._health_lock:
|
||||
if client_id in self._client_health:
|
||||
del self._client_health[client_id]
|
||||
with self._bots_lock:
|
||||
if client_id in self.bots:
|
||||
# 从 BotManager 注销
|
||||
from .bot_manager import bot_manager
|
||||
if self.bots[client_id].self_id:
|
||||
bot_manager.unregister_bot(str(self.bots[client_id].self_id))
|
||||
del self.bots[client_id]
|
||||
|
||||
# 清理该客户端的防重复数据
|
||||
with self._processed_events_lock:
|
||||
if client_id in self._processed_events:
|
||||
del self._processed_events[client_id]
|
||||
with self._processed_messages_lock:
|
||||
if client_id in self._processed_messages:
|
||||
del self._processed_messages[client_id]
|
||||
with self._processing_events_lock:
|
||||
if client_id in self._processing_events:
|
||||
del self._processing_events[client_id]
|
||||
|
||||
self.logger.info(f"客户端已断开并清理: {client_id}")
|
||||
|
||||
async def _on_event(self, client_id: str, event_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
处理事件,包含防重复发送和负载均衡逻辑。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
event_data: 事件数据
|
||||
"""
|
||||
# 获取事件ID
|
||||
event_id = event_data.get('id') or event_data.get('post_id') or event_data.get('message_id') or event_data.get('time')
|
||||
if not event_id:
|
||||
self.logger.debug(f"_on_event: 事件ID为空, client_id={client_id}")
|
||||
return
|
||||
|
||||
event_key = f"{event_data.get('post_type')}:{event_id}"
|
||||
|
||||
# 检查客户端是否已连接
|
||||
with self._clients_lock:
|
||||
if client_id not in self.clients:
|
||||
self.logger.debug(f"_on_event: 客户端已断开, client_id={client_id}")
|
||||
return
|
||||
|
||||
# 检查是否正在处理
|
||||
with self._processing_events_lock:
|
||||
if client_id not in self._processing_events:
|
||||
self._processing_events[client_id] = set()
|
||||
|
||||
if event_key in self._processing_events[client_id]:
|
||||
self.logger.debug(f"_on_event: 事件正在处理中, client_id={client_id}, event_key={event_key}")
|
||||
return
|
||||
|
||||
# 标记为正在处理
|
||||
self._processing_events[client_id].add(event_key)
|
||||
|
||||
try:
|
||||
event = EventFactory.create_event(event_data)
|
||||
|
||||
if hasattr(event, 'self_id'):
|
||||
with self._clients_lock:
|
||||
self.client_self_ids[client_id] = event.self_id
|
||||
|
||||
# 为事件注入Bot实例
|
||||
from ..ws import ReverseWSClient
|
||||
from .bot_manager import bot_manager
|
||||
|
||||
# 为每个前端创建独立的Bot实例
|
||||
with self._bots_lock:
|
||||
if client_id not in self.bots:
|
||||
# 使用 ReverseWSClient 代理
|
||||
temp_ws = ReverseWSClient(self, client_id)
|
||||
temp_ws.self_id = event.self_id if hasattr(event, 'self_id') else 0
|
||||
self.bots[client_id] = Bot(temp_ws)
|
||||
|
||||
# 注册到 BotManager
|
||||
if event.self_id:
|
||||
bot_manager.register_bot(self.bots[client_id])
|
||||
|
||||
event.bot = self.bots[client_id]
|
||||
|
||||
# 记录客户端健康状态
|
||||
with self._health_lock:
|
||||
self._client_health[client_id] = datetime.now()
|
||||
|
||||
# 检查是否为重复事件(按客户端)
|
||||
is_duplicate = self._is_duplicate_event(event_data, client_id)
|
||||
self.logger.debug(f"事件防重复检查: client_id={client_id}, event_id={event_data.get('message_id')}, is_duplicate={is_duplicate}")
|
||||
if is_duplicate:
|
||||
self.logger.debug(f"检测到重复事件,已忽略: {event_data.get('id')}")
|
||||
return
|
||||
|
||||
# 处理消息事件
|
||||
if event.post_type == "message":
|
||||
sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown"
|
||||
message_type = getattr(event, "message_type", "Unknown")
|
||||
user_id = getattr(event, "user_id", "Unknown")
|
||||
raw_message = getattr(event, "raw_message", "")
|
||||
self.logger.info(f"[消息] {message_type} | {user_id}({sender_name}): {raw_message}")
|
||||
|
||||
# 使用锁防止同一消息被多次处理
|
||||
message_key = self._get_message_key(event_data)
|
||||
async with self._get_message_lock(message_key):
|
||||
# 再次检查是否重复(防止并发问题)
|
||||
if self._is_duplicate_event(event_data, client_id):
|
||||
self.logger.debug(f"并发检测到重复消息(事件ID),已忽略: {message_key}")
|
||||
return
|
||||
|
||||
# 检查是否重复(基于消息内容,按客户端,仅群聊)
|
||||
is_duplicate_content = self._is_duplicate_message(event_data, client_id)
|
||||
self.logger.debug(f"锁内内容检查: client_id={client_id}, is_duplicate={is_duplicate_content}")
|
||||
if is_duplicate_content:
|
||||
self.logger.debug(f"并发检测到重复消息(内容),已忽略: {message_key}")
|
||||
return
|
||||
|
||||
# 标记事件已处理(按客户端)
|
||||
with self._processed_events_lock:
|
||||
self._mark_event_processed(event_data, client_id)
|
||||
|
||||
# 更新客户端负载
|
||||
with self._load_lock:
|
||||
self._update_client_load(client_id)
|
||||
|
||||
await matcher.handle_event(event.bot, event)
|
||||
else:
|
||||
# 对于非消息事件,直接标记并处理
|
||||
with self._processed_events_lock:
|
||||
self._mark_event_processed(event_data, client_id)
|
||||
|
||||
if event.post_type == "notice":
|
||||
notice_type = getattr(event, "notice_type", "Unknown")
|
||||
self.logger.info(f"[通知] {notice_type}")
|
||||
await matcher.handle_event(event.bot, event)
|
||||
|
||||
elif event.post_type == "request":
|
||||
request_type = getattr(event, "request_type", "Unknown")
|
||||
self.logger.info(f"[请求] {request_type}")
|
||||
await matcher.handle_event(event.bot, event)
|
||||
|
||||
elif event.post_type == "meta_event":
|
||||
meta_event_type = getattr(event, "meta_event_type", "Unknown")
|
||||
self.logger.debug(f"[元事件] {meta_event_type}")
|
||||
await matcher.handle_event(event.bot, event)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"事件处理异常: {str(e)}")
|
||||
finally:
|
||||
# 清理正在处理的事件
|
||||
with self._processing_events_lock:
|
||||
if client_id in self._processing_events:
|
||||
if event_key in self._processing_events[client_id]:
|
||||
self._processing_events[client_id].discard(event_key)
|
||||
# 如果集合为空,删除该客户端的记录
|
||||
if not self._processing_events[client_id]:
|
||||
del self._processing_events[client_id]
|
||||
|
||||
async def call_api(
|
||||
self,
|
||||
action: str,
|
||||
params: Optional[Dict[Any, Any]] = None,
|
||||
client_id: Optional[str] = None,
|
||||
use_load_balance: bool = True
|
||||
) -> Dict[Any, Any]:
|
||||
"""
|
||||
向客户端发送 API 请求。
|
||||
|
||||
Args:
|
||||
action: API 动作名称
|
||||
params: API 参数
|
||||
client_id: 客户端 ID,如果为 None 则根据负载均衡策略选择
|
||||
use_load_balance: 是否使用负载均衡,默认为 True
|
||||
|
||||
Returns:
|
||||
API 响应数据
|
||||
"""
|
||||
if not self.clients:
|
||||
self.logger.error("调用 API 失败: 没有可用的客户端连接")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_DISCONNECTED,
|
||||
message="没有可用的客户端连接",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
# 如果没有指定客户端,使用负载均衡
|
||||
if client_id is None and use_load_balance:
|
||||
# 优先选择健康的客户端
|
||||
healthy_clients = self.get_healthy_clients()
|
||||
if healthy_clients:
|
||||
# 选择负载最低的客户端
|
||||
client_id = self.get_client_with_least_load()
|
||||
if client_id is None and healthy_clients:
|
||||
with self._clients_lock:
|
||||
client_id = list(healthy_clients.keys())[0]
|
||||
else:
|
||||
# 如果没有健康客户端,使用所有客户端中的一个
|
||||
with self._clients_lock:
|
||||
client_id = list(self.clients.keys())[0]
|
||||
|
||||
echo_id = str(uuid.uuid4())
|
||||
payload = {"action": action, "params": params or {}, "echo": echo_id}
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
with self._pending_requests_lock:
|
||||
self._pending_requests[echo_id] = future
|
||||
|
||||
try:
|
||||
targets = [client_id] if client_id else None
|
||||
clients_to_send = []
|
||||
|
||||
with self._clients_lock:
|
||||
if targets is None:
|
||||
targets = list(self.clients.keys())
|
||||
for cid in targets:
|
||||
if cid in self.clients:
|
||||
clients_to_send.append((cid, self.clients[cid]))
|
||||
|
||||
for cid, websocket in clients_to_send:
|
||||
await websocket.send(orjson.dumps(payload).decode('utf-8'))
|
||||
|
||||
return await asyncio.wait_for(future, timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
with self._pending_requests_lock:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.warning(f"API 调用超时: action={action}, params={params}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.TIMEOUT_ERROR,
|
||||
message="API调用超时",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
except Exception as e:
|
||||
with self._pending_requests_lock:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
message=f"API调用异常: {str(e)}",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
def get_connected_clients(self) -> Dict[str, int]:
|
||||
"""
|
||||
获取已连接的客户端列表。
|
||||
|
||||
Returns:
|
||||
客户端 ID 和 self_id 的映射字典
|
||||
"""
|
||||
with self._clients_lock:
|
||||
return self.client_self_ids.copy()
|
||||
|
||||
def _is_duplicate_event(self, event_data: Dict[str, Any], client_id: str) -> bool:
|
||||
"""
|
||||
检查是否为重复事件。
|
||||
|
||||
Args:
|
||||
event_data: 事件数据
|
||||
client_id: 客户端ID
|
||||
|
||||
Returns:
|
||||
是否为重复事件
|
||||
"""
|
||||
# 尝试多种可能的事件ID字段
|
||||
event_id = (event_data.get('id') or
|
||||
event_data.get('post_id') or
|
||||
event_data.get('message_id') or
|
||||
event_data.get('time'))
|
||||
if not event_id:
|
||||
return False
|
||||
|
||||
event_key = f"{event_data.get('post_type')}:{event_id}"
|
||||
|
||||
# 检查该客户端是否已处理过此事件
|
||||
with self._processed_events_lock:
|
||||
if client_id not in self._processed_events:
|
||||
self.logger.debug(f"_is_duplicate_event: client_id={client_id}不在_processed_events中, event_key={event_key}, 返回False")
|
||||
return False
|
||||
|
||||
is_duplicate = event_key in self._processed_events[client_id]
|
||||
self.logger.debug(f"_is_duplicate_event: client_id={client_id}, event_key={event_key}, in_processed={is_duplicate}, processed_events_count={len(self._processed_events[client_id])}")
|
||||
return is_duplicate
|
||||
|
||||
def _is_duplicate_message(self, event_data: Dict[str, Any], client_id: str) -> bool:
|
||||
"""
|
||||
检查是否为重复消息(基于消息内容)。
|
||||
|
||||
Args:
|
||||
event_data: 事件数据
|
||||
client_id: 客户端ID
|
||||
|
||||
Returns:
|
||||
是否为重复消息
|
||||
"""
|
||||
if event_data.get('post_type') != 'message':
|
||||
return False
|
||||
|
||||
# 只对群聊消息进行内容防重复
|
||||
if event_data.get('message_type') != 'group':
|
||||
return False
|
||||
|
||||
# 生成消息内容标识
|
||||
raw_message = event_data.get('raw_message', '')
|
||||
user_id = event_data.get('user_id')
|
||||
group_id = event_data.get('group_id', '0')
|
||||
|
||||
# 使用消息内容、用户ID和群组ID作为标识
|
||||
content_key = f"content:{raw_message}:{user_id}:{group_id}"
|
||||
|
||||
# 检查该客户端是否已处理过此消息内容
|
||||
with self._processed_messages_lock:
|
||||
if client_id not in self._processed_messages:
|
||||
return False
|
||||
|
||||
return content_key in self._processed_messages[client_id]
|
||||
|
||||
def _mark_event_processed(self, event_data: Dict[str, Any], client_id: str) -> None:
|
||||
"""
|
||||
标记事件已处理。
|
||||
|
||||
Args:
|
||||
event_data: 事件数据
|
||||
client_id: 客户端ID
|
||||
"""
|
||||
# 尝试多种可能的事件ID字段
|
||||
event_id = (event_data.get('id') or
|
||||
event_data.get('post_id') or
|
||||
event_data.get('message_id') or
|
||||
event_data.get('time'))
|
||||
if not event_id:
|
||||
self.logger.debug(f"_mark_event_processed: event_id为空, event_data={event_data}")
|
||||
return
|
||||
|
||||
event_key = f"{event_data.get('post_type')}:{event_id}"
|
||||
|
||||
# 为该客户端记录已处理的事件
|
||||
with self._processed_events_lock:
|
||||
if client_id not in self._processed_events:
|
||||
self._processed_events[client_id] = {}
|
||||
self._processed_events[client_id][event_key] = datetime.now()
|
||||
self.logger.debug(f"_mark_event_processed: client_id={client_id}, event_key={event_key}, processed_events_count={len(self._processed_events[client_id])}")
|
||||
|
||||
# 只对群聊消息标记内容已处理
|
||||
if event_data.get('post_type') == 'message' and event_data.get('message_type') == 'group':
|
||||
raw_message = event_data.get('raw_message', '')
|
||||
user_id = event_data.get('user_id')
|
||||
group_id = event_data.get('group_id', '0')
|
||||
content_key = f"content:{raw_message}:{user_id}:{group_id}"
|
||||
|
||||
with self._processed_messages_lock:
|
||||
if client_id not in self._processed_messages:
|
||||
self._processed_messages[client_id] = {}
|
||||
self._processed_messages[client_id][content_key] = datetime.now()
|
||||
|
||||
def _get_message_key(self, event_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
获取消息唯一标识。
|
||||
|
||||
Args:
|
||||
event_data: 事件数据
|
||||
|
||||
Returns:
|
||||
消息唯一标识
|
||||
"""
|
||||
if event_data.get('post_type') == 'message':
|
||||
message_id = event_data.get('message_id') or event_data.get('id')
|
||||
user_id = event_data.get('user_id')
|
||||
return f"msg:{message_id}:{user_id}"
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def _get_message_lock(self, key: str) -> asyncio.Lock:
|
||||
"""
|
||||
获取消息处理锁。
|
||||
|
||||
Args:
|
||||
key: 消息唯一标识
|
||||
|
||||
Returns:
|
||||
asyncio.Lock 实例
|
||||
"""
|
||||
with self._message_locks_lock:
|
||||
if key not in self._message_locks:
|
||||
self._message_locks[key] = asyncio.Lock()
|
||||
with self._message_lock_times_lock:
|
||||
self._message_lock_times[key] = datetime.now()
|
||||
return self._message_locks[key]
|
||||
|
||||
def _update_client_load(self, client_id: str) -> None:
|
||||
"""
|
||||
更新客户端负载。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
"""
|
||||
with self._load_lock:
|
||||
if client_id not in self._client_load:
|
||||
self._client_load[client_id] = 0
|
||||
self._client_load[client_id] += 1
|
||||
|
||||
def get_client_with_least_load(self) -> Optional[str]:
|
||||
"""
|
||||
获取负载最低的客户端。
|
||||
|
||||
Returns:
|
||||
客户端 ID,如果没有客户端则返回 None
|
||||
"""
|
||||
with self._load_lock:
|
||||
if not self._client_load:
|
||||
return None
|
||||
|
||||
return min(self._client_load.keys(), key=lambda k: self._client_load[k])
|
||||
|
||||
def get_healthy_clients(self) -> Dict[str, int]:
|
||||
"""
|
||||
获取健康的客户端列表(最近30秒内有活动)。
|
||||
|
||||
Returns:
|
||||
健康的客户端 ID 和 self_id 的映射字典
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
healthy = {}
|
||||
|
||||
with self._health_lock:
|
||||
with self._clients_lock:
|
||||
for client_id, last_health in self._client_health.items():
|
||||
if (current_time - last_health).total_seconds() < 30:
|
||||
if client_id in self.client_self_ids:
|
||||
healthy[client_id] = self.client_self_ids[client_id]
|
||||
|
||||
return healthy
|
||||
|
||||
|
||||
reverse_ws_manager = ReverseWSManager()
|
||||
379
src/neobot/core/managers/thread_manager.py
Normal file
379
src/neobot/core/managers/thread_manager.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""
|
||||
线程管理器模块
|
||||
|
||||
该模块提供了多线程支持,用于处理来自多个实现端的并发事件。
|
||||
每个 WebSocket 连接在独立的线程中运行,避免阻塞主事件循环。
|
||||
"""
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Dict, Optional, Callable, Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from ..utils.logger import ModuleLogger
|
||||
from ..config_loader import global_config
|
||||
|
||||
|
||||
class ThreadManager:
|
||||
"""
|
||||
线程管理器,负责管理多线程环境下的事件处理。
|
||||
|
||||
该管理器为每个 WebSocket 连接提供独立的线程池,
|
||||
确保多前端场景下的事件处理不会相互阻塞。
|
||||
"""
|
||||
|
||||
_instance: Optional['ThreadManager'] = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> 'ThreadManager':
|
||||
"""
|
||||
单例模式:确保全局只有一个线程管理器实例。
|
||||
"""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
初始化线程管理器。
|
||||
"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self.logger = ModuleLogger("ThreadManager")
|
||||
|
||||
# 线程池配置
|
||||
self._max_workers: int = global_config.threading.max_workers
|
||||
self._thread_name_prefix: str = global_config.threading.thread_name_prefix
|
||||
|
||||
# 线程池
|
||||
self._executor: Optional[ThreadPoolExecutor] = None
|
||||
|
||||
# 每个客户端的线程池(用于反向 WebSocket)
|
||||
self._client_executors: Dict[str, ThreadPoolExecutor] = {}
|
||||
self._client_executor_locks: Dict[str, threading.Lock] = {}
|
||||
|
||||
# 线程安全的事件循环(用于跨线程调用)
|
||||
self._event_loops: Dict[str, asyncio.AbstractEventLoop] = {}
|
||||
self._event_loops_lock = threading.Lock()
|
||||
|
||||
# 统计信息
|
||||
self._stats: Dict[str, Any] = {
|
||||
'total_tasks': 0,
|
||||
'completed_tasks': 0,
|
||||
'failed_tasks': 0,
|
||||
'active_threads': 0,
|
||||
'client_tasks': {}
|
||||
}
|
||||
self._stats_lock = threading.Lock()
|
||||
|
||||
self._initialized = True
|
||||
self.logger.success("线程管理器初始化完成")
|
||||
|
||||
def start(self) -> None:
|
||||
"""
|
||||
启动线程管理器,创建主线程池。
|
||||
"""
|
||||
if self._executor is None:
|
||||
self._executor = ThreadPoolExecutor(
|
||||
max_workers=self._max_workers,
|
||||
thread_name_prefix=self._thread_name_prefix
|
||||
)
|
||||
self.logger.success(f"主 ThreadPool 已启动: max_workers={self._max_workers}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
关闭线程管理器,释放所有资源。
|
||||
"""
|
||||
self.logger.info("正在关闭线程管理器...")
|
||||
|
||||
# 关闭所有客户端线程池
|
||||
for client_id, executor in list(self._client_executors.items()):
|
||||
self._shutdown_client_executor(client_id)
|
||||
|
||||
# 关闭主执行器
|
||||
if self._executor is not None:
|
||||
self._executor.shutdown(wait=True)
|
||||
self._executor = None
|
||||
|
||||
self.logger.success("线程管理器已关闭")
|
||||
|
||||
def _shutdown_client_executor(self, client_id: str) -> None:
|
||||
"""
|
||||
关闭特定客户端的线程池。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
"""
|
||||
if client_id in self._client_executors:
|
||||
try:
|
||||
self._client_executors[client_id].shutdown(wait=True)
|
||||
del self._client_executors[client_id]
|
||||
self.logger.info(f"客户端 {client_id} 的线程池已关闭")
|
||||
except Exception as e:
|
||||
self.logger.error(f"关闭客户端 {client_id} 线程池失败: {e}")
|
||||
|
||||
def get_main_executor(self) -> ThreadPoolExecutor:
|
||||
"""
|
||||
获取主线程池。
|
||||
|
||||
Returns:
|
||||
ThreadPoolExecutor 实例
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果线程管理器未启动
|
||||
"""
|
||||
if self._executor is None:
|
||||
raise RuntimeError("线程管理器未启动,请先调用 start()")
|
||||
return self._executor
|
||||
|
||||
def get_client_executor(self, client_id: str) -> ThreadPoolExecutor:
|
||||
"""
|
||||
获取特定客户端的线程池(为反向 WebSocket 设计)。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
|
||||
Returns:
|
||||
ThreadPoolExecutor 实例
|
||||
"""
|
||||
if client_id not in self._client_executors:
|
||||
with threading.Lock():
|
||||
if client_id not in self._client_executors:
|
||||
executor = ThreadPoolExecutor(
|
||||
max_workers=global_config.threading.client_max_workers,
|
||||
thread_name_prefix=f"{self._thread_name_prefix}_{client_id[:8]}"
|
||||
)
|
||||
self._client_executors[client_id] = executor
|
||||
self._client_executor_locks[client_id] = threading.Lock()
|
||||
self.logger.info(f"为客户端 {client_id} 创建线程池")
|
||||
|
||||
return self._client_executors[client_id]
|
||||
|
||||
def submit_to_main_executor(
|
||||
self,
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
"""
|
||||
提交任务到主线程池(同步)。
|
||||
|
||||
Args:
|
||||
func: 要执行的函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
函数执行结果
|
||||
"""
|
||||
executor = self.get_main_executor()
|
||||
future = executor.submit(func, *args, **kwargs)
|
||||
self._update_stats('total_tasks')
|
||||
try:
|
||||
result = future.result()
|
||||
self._update_stats('completed_tasks')
|
||||
return result
|
||||
except Exception as e:
|
||||
self._update_stats('failed_tasks')
|
||||
self.logger.error(f"主线程池任务执行失败: {e}")
|
||||
raise
|
||||
|
||||
async def submit_to_main_executor_async(
|
||||
self,
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
"""
|
||||
提交任务到主线程池(异步)。
|
||||
|
||||
Args:
|
||||
func: 要执行的函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
函数执行结果
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
executor = self.get_main_executor()
|
||||
future = loop.run_in_executor(executor, lambda: func(*args, **kwargs))
|
||||
self._update_stats('total_tasks')
|
||||
try:
|
||||
result = await future
|
||||
self._update_stats('completed_tasks')
|
||||
return result
|
||||
except Exception as e:
|
||||
self._update_stats('failed_tasks')
|
||||
self.logger.error(f"异步主线程池任务执行失败: {e}")
|
||||
raise
|
||||
|
||||
def submit_to_client_executor(
|
||||
self,
|
||||
client_id: str,
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
"""
|
||||
提交任务到特定客户端的线程池。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
func: 要执行的函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
函数执行结果
|
||||
"""
|
||||
executor = self.get_client_executor(client_id)
|
||||
future = executor.submit(func, *args, **kwargs)
|
||||
self._update_client_stats(client_id, 'total_tasks')
|
||||
try:
|
||||
result = future.result()
|
||||
self._update_client_stats(client_id, 'completed_tasks')
|
||||
return result
|
||||
except Exception as e:
|
||||
self._update_client_stats(client_id, 'failed_tasks')
|
||||
self.logger.error(f"客户端 {client_id} 线程池任务执行失败: {e}")
|
||||
raise
|
||||
|
||||
async def submit_to_client_executor_async(
|
||||
self,
|
||||
client_id: str,
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
"""
|
||||
提交任务到特定客户端的线程池(异步)。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
func: 要执行的函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
函数执行结果
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
executor = self.get_client_executor(client_id)
|
||||
future = loop.run_in_executor(executor, lambda: func(*args, **kwargs))
|
||||
self._update_client_stats(client_id, 'total_tasks')
|
||||
try:
|
||||
result = await future
|
||||
self._update_client_stats(client_id, 'completed_tasks')
|
||||
return result
|
||||
except Exception as e:
|
||||
self._update_client_stats(client_id, 'failed_tasks')
|
||||
self.logger.error(f"客户端 {client_id} 异步线程池任务执行失败: {e}")
|
||||
raise
|
||||
|
||||
def run_coroutine_threadsafe(
|
||||
self,
|
||||
coro,
|
||||
client_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
在指定客户端的事件循环中运行协程(线程安全)。
|
||||
|
||||
Args:
|
||||
coro: 协程对象
|
||||
client_id: 客户端 ID,如果为 None 则使用主事件循环
|
||||
|
||||
Returns:
|
||||
协程执行结果
|
||||
"""
|
||||
if client_id is None:
|
||||
loop = asyncio.get_running_loop()
|
||||
else:
|
||||
with self._event_loops_lock:
|
||||
if client_id not in self._event_loops:
|
||||
self._event_loops[client_id] = asyncio.new_event_loop()
|
||||
threading.Thread(
|
||||
target=self._event_loop_thread,
|
||||
args=(client_id,),
|
||||
daemon=True
|
||||
).start()
|
||||
loop = self._event_loops[client_id]
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
return future.result()
|
||||
|
||||
def _event_loop_thread(self, client_id: str) -> None:
|
||||
"""
|
||||
事件循环线程(用于反向 WebSocket 客户端)。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
"""
|
||||
asyncio.set_event_loop(self._event_loops[client_id])
|
||||
self.logger.info(f"事件循环线程启动: client_id={client_id}")
|
||||
try:
|
||||
self._event_loops[client_id].run_forever()
|
||||
finally:
|
||||
self._event_loops[client_id].close()
|
||||
self.logger.info(f"事件循环线程停止: client_id={client_id}")
|
||||
|
||||
def _update_stats(self, key: str) -> None:
|
||||
"""
|
||||
更新全局统计信息。
|
||||
|
||||
Args:
|
||||
key: 统计项键名
|
||||
"""
|
||||
with self._stats_lock:
|
||||
self._stats[key] = self._stats.get(key, 0) + 1
|
||||
|
||||
def _update_client_stats(self, client_id: str, key: str) -> None:
|
||||
"""
|
||||
更新客户端统计信息。
|
||||
|
||||
Args:
|
||||
client_id: 客户端 ID
|
||||
key: 统计项键名
|
||||
"""
|
||||
with self._stats_lock:
|
||||
if client_id not in self._stats['client_tasks']:
|
||||
self._stats['client_tasks'][client_id] = {
|
||||
'total_tasks': 0,
|
||||
'completed_tasks': 0,
|
||||
'failed_tasks': 0
|
||||
}
|
||||
self._stats['client_tasks'][client_id][key] += 1
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取统计信息。
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
with self._stats_lock:
|
||||
stats = self._stats.copy()
|
||||
stats['client_tasks'] = stats.get('client_tasks', {}).copy()
|
||||
return stats
|
||||
|
||||
def get_active_threads_count(self) -> int:
|
||||
"""
|
||||
获取活动线程数量。
|
||||
|
||||
Returns:
|
||||
活动线程数量
|
||||
"""
|
||||
import threading
|
||||
return sum(
|
||||
1 for t in threading.enumerate()
|
||||
if t.name.startswith(self._thread_name_prefix)
|
||||
)
|
||||
|
||||
|
||||
# 全局线程管理器实例
|
||||
thread_manager = ThreadManager()
|
||||
147
src/neobot/core/managers/vectordb_manager.py
Normal file
147
src/neobot/core/managers/vectordb_manager.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
向量数据库管理器模块
|
||||
|
||||
该模块提供了一个基于 ChromaDB 的向量数据库管理器,
|
||||
用于存储和检索文本向量,为大语言模型提供记忆能力。
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from neobot.core.utils.logger import ModuleLogger
|
||||
from neobot.core.utils.singleton import Singleton
|
||||
|
||||
logger = ModuleLogger("VectorDBManager")
|
||||
|
||||
class VectorDBManager(Singleton):
|
||||
"""
|
||||
向量数据库管理器(单例)
|
||||
"""
|
||||
_client = None
|
||||
_collections = {}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.db_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "vectordb")
|
||||
os.makedirs(self.db_path, exist_ok=True)
|
||||
|
||||
def initialize(self):
|
||||
"""初始化 ChromaDB 客户端"""
|
||||
if self._client is None:
|
||||
try:
|
||||
logger.info(f"正在初始化向量数据库,路径: {self.db_path}")
|
||||
self._client = chromadb.PersistentClient(
|
||||
path=self.db_path,
|
||||
settings=Settings(
|
||||
anonymized_telemetry=False,
|
||||
allow_reset=True
|
||||
)
|
||||
)
|
||||
logger.success("向量数据库初始化成功!")
|
||||
except Exception as e:
|
||||
logger.error(f"向量数据库初始化失败: {e}")
|
||||
self._client = None
|
||||
|
||||
def get_collection(self, name: str):
|
||||
"""获取或创建集合"""
|
||||
if self._client is None:
|
||||
self.initialize()
|
||||
|
||||
if self._client is None:
|
||||
return None
|
||||
|
||||
if name not in self._collections:
|
||||
try:
|
||||
# 使用默认的 sentence-transformers 嵌入模型
|
||||
self._collections[name] = self._client.get_or_create_collection(name=name)
|
||||
logger.debug(f"已获取/创建向量集合: {name}")
|
||||
except Exception as e:
|
||||
logger.error(f"获取向量集合 {name} 失败: {e}")
|
||||
return None
|
||||
|
||||
return self._collections[name]
|
||||
|
||||
def add_texts(self, collection_name: str, texts: List[str], metadatas: List[Dict[str, Any]], ids: List[str]) -> bool:
|
||||
"""
|
||||
向集合中添加文本
|
||||
|
||||
Args:
|
||||
collection_name: 集合名称
|
||||
texts: 文本列表
|
||||
metadatas: 元数据列表(用于过滤和存储额外信息)
|
||||
ids: 唯一ID列表
|
||||
"""
|
||||
collection = self.get_collection(collection_name)
|
||||
if collection is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info(f"正在将 {len(texts)} 条记忆存入向量集合 {collection_name}...")
|
||||
collection.add(
|
||||
documents=texts,
|
||||
metadatas=metadatas,
|
||||
ids=ids
|
||||
)
|
||||
logger.success(f"成功将记忆存入集合 {collection_name}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"向集合 {collection_name} 添加记录失败: {e}")
|
||||
return False
|
||||
|
||||
def query_texts(self, collection_name: str, query_texts: List[str], n_results: int = 5, where: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
查询相似文本
|
||||
|
||||
Args:
|
||||
collection_name: 集合名称
|
||||
query_texts: 查询文本列表
|
||||
n_results: 返回结果数量
|
||||
where: 过滤条件
|
||||
"""
|
||||
collection = self.get_collection(collection_name)
|
||||
if collection is None:
|
||||
return {"documents": [], "metadatas": [], "distances": []}
|
||||
|
||||
try:
|
||||
logger.info(f"正在从向量集合 {collection_name} 中检索相关记忆...")
|
||||
results = collection.query(
|
||||
query_texts=query_texts,
|
||||
n_results=n_results,
|
||||
where=where
|
||||
)
|
||||
|
||||
# 统计检索到的结果数量
|
||||
doc_count = 0
|
||||
if results and results.get("documents") and results["documents"][0]:
|
||||
doc_count = len(results["documents"][0])
|
||||
|
||||
if doc_count > 0:
|
||||
logger.success(f"成功从集合 {collection_name} 检索到 {doc_count} 条相关记忆")
|
||||
else:
|
||||
logger.info(f"集合 {collection_name} 中未检索到相关记忆")
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"查询集合 {collection_name} 失败: {e}")
|
||||
return {"documents": [], "metadatas": [], "distances": []}
|
||||
|
||||
def delete_texts(self, collection_name: str, ids: Optional[List[str]] = None, where: Optional[Dict[str, Any]] = None) -> bool:
|
||||
"""
|
||||
删除文本
|
||||
"""
|
||||
collection = self.get_collection(collection_name)
|
||||
if collection is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
collection.delete(ids=ids, where=where)
|
||||
logger.debug(f"成功从集合 {collection_name} 删除记录")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"从集合 {collection_name} 删除记录失败: {e}")
|
||||
return False
|
||||
|
||||
# 全局向量数据库管理器实例
|
||||
vectordb_manager = VectorDBManager()
|
||||
42
src/neobot/core/permission.py
Normal file
42
src/neobot/core/permission.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
|
||||
|
||||
@total_ordering
|
||||
class Permission(Enum):
|
||||
"""
|
||||
定义用户权限等级的枚举类。
|
||||
|
||||
使用 @total_ordering 装饰器,只需定义 __lt__ 和 __eq__,
|
||||
即可自动实现所有比较运算符。
|
||||
"""
|
||||
USER = "user"
|
||||
OP = "op"
|
||||
ADMIN = "admin"
|
||||
|
||||
@property
|
||||
def _level_map(self):
|
||||
"""
|
||||
内部属性,用于映射枚举成员到整数等级。
|
||||
"""
|
||||
return {
|
||||
Permission.USER: 1,
|
||||
Permission.OP: 2,
|
||||
Permission.ADMIN: 3
|
||||
}
|
||||
|
||||
def __lt__(self, other):
|
||||
"""
|
||||
比较当前权限是否小于另一个权限。
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self._level_map[self] < self._level_map[other]
|
||||
|
||||
def __ge__(self, other):
|
||||
"""
|
||||
比较当前权限是否大于等于另一个权限。
|
||||
"""
|
||||
if not isinstance(other, Permission):
|
||||
return NotImplemented
|
||||
return self._level_map[self] >= self._level_map[other]
|
||||
217
src/neobot/core/plugin.py
Normal file
217
src/neobot/core/plugin.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import inspect
|
||||
import functools
|
||||
from typing import Optional, Union, Any, Callable
|
||||
from neobot.core.managers.command_manager import matcher as command_manager
|
||||
from neobot.core.permission import Permission
|
||||
from neobot.models.events.message import MessageEvent
|
||||
|
||||
class Plugin:
|
||||
"""
|
||||
插件基类,提供类风格的插件编写方式。
|
||||
通过继承此类,可以使用装饰器在类方法上注册命令和事件处理器。
|
||||
"""
|
||||
def __init__(self):
|
||||
self._register_handlers()
|
||||
|
||||
def _register_handlers(self):
|
||||
"""
|
||||
自动注册带有装饰器的方法。
|
||||
"""
|
||||
# 遍历实例的所有方法
|
||||
for name, method in inspect.getmembers(self, predicate=inspect.ismethod):
|
||||
# 检查是否有命令元数据
|
||||
if hasattr(method, "_command_meta"):
|
||||
meta = method._command_meta
|
||||
# 调用 command_manager 的装饰器来注册绑定后的方法
|
||||
command_manager.command(
|
||||
*meta['names'],
|
||||
permission=meta.get('permission'),
|
||||
override_permission_check=meta.get('override_permission_check', False)
|
||||
)(method)
|
||||
|
||||
# 检查是否有消息处理元数据
|
||||
if hasattr(method, "_on_message_meta"):
|
||||
command_manager.on_message()(method)
|
||||
|
||||
# 检查是否有通知处理元数据
|
||||
if hasattr(method, "_on_notice_meta"):
|
||||
meta = method._on_notice_meta
|
||||
command_manager.on_notice(notice_type=meta.get('notice_type'))(method)
|
||||
|
||||
# 检查是否有请求处理元数据
|
||||
if hasattr(method, "_on_request_meta"):
|
||||
meta = method._on_request_meta
|
||||
command_manager.on_request(request_type=meta.get('request_type'))(method)
|
||||
|
||||
async def send(self, event: MessageEvent, message: Union[str, Any]):
|
||||
"""
|
||||
发送消息的基础逻辑。
|
||||
"""
|
||||
if hasattr(event, 'reply'):
|
||||
await event.reply(message)
|
||||
else:
|
||||
pass
|
||||
|
||||
async def reply(self, event: MessageEvent, message: Union[str, Any]):
|
||||
"""
|
||||
回复消息。
|
||||
"""
|
||||
await self.send(event, message)
|
||||
|
||||
class SimplePlugin(Plugin):
|
||||
"""
|
||||
面向新手的简化插件基类。
|
||||
|
||||
特性:
|
||||
1. 自动将公共方法(不以_开头)注册为指令。
|
||||
2. 指令名默认为方法名。
|
||||
3. 自动解析参数类型。
|
||||
4. 支持直接返回字符串来回复消息。
|
||||
"""
|
||||
def _register_handlers(self):
|
||||
# 先处理带装饰器的方法
|
||||
super()._register_handlers()
|
||||
|
||||
# 扫描普通方法并注册为指令
|
||||
for name, method in inspect.getmembers(self, predicate=inspect.ismethod):
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
if hasattr(method, "_command_meta"):
|
||||
continue # 已经处理过
|
||||
if hasattr(method, "_on_message_meta"):
|
||||
continue
|
||||
if hasattr(method, "_on_notice_meta"):
|
||||
continue
|
||||
if hasattr(method, "_on_request_meta"):
|
||||
continue
|
||||
if name in dir(Plugin):
|
||||
continue # 忽略基类方法
|
||||
|
||||
self._register_method_as_command(name, method)
|
||||
|
||||
def _register_method_as_command(self, name: str, method: Callable):
|
||||
# 获取方法的签名
|
||||
sig = inspect.signature(method)
|
||||
|
||||
# 包装函数
|
||||
@functools.wraps(method)
|
||||
async def wrapper(event: MessageEvent, args: list[str]):
|
||||
try:
|
||||
# 准备调用参数
|
||||
call_args: list[Any] = []
|
||||
|
||||
# 跳过 self,第一个参数应该是 event
|
||||
params = list(sig.parameters.values())
|
||||
if not params:
|
||||
# 方法没有参数?这不应该发生,至少要有 event
|
||||
await method()
|
||||
return
|
||||
|
||||
# 绑定 event
|
||||
call_args.append(event)
|
||||
|
||||
# 处理剩余参数
|
||||
method_params = params[1:] # 除去 event
|
||||
|
||||
if not method_params:
|
||||
# 方法不需要额外参数
|
||||
pass
|
||||
elif len(method_params) == 1:
|
||||
# 只有一个参数,把所有 args 拼起来传给它
|
||||
param = method_params[0]
|
||||
if args:
|
||||
str_val = " ".join(args)
|
||||
val: Any = str_val
|
||||
# 类型转换
|
||||
if param.annotation is int:
|
||||
val = int(str_val)
|
||||
elif param.annotation is float:
|
||||
val = float(str_val)
|
||||
call_args.append(val)
|
||||
elif param.default is not inspect.Parameter.empty:
|
||||
call_args.append(param.default)
|
||||
else:
|
||||
await event.reply(f"缺少参数: {param.name}")
|
||||
return
|
||||
else:
|
||||
# 多个参数,尝试一一对应
|
||||
if len(args) < len([p for p in method_params if p.default is inspect.Parameter.empty]):
|
||||
# 必填参数不足
|
||||
usage = " ".join([f"<{p.name}>" for p in method_params])
|
||||
await event.reply(f"参数不足。用法: /{name} {usage}")
|
||||
return
|
||||
|
||||
for i, param in enumerate(method_params):
|
||||
if i < len(args):
|
||||
arg_str = args[i]
|
||||
arg_val: Any = arg_str
|
||||
# 简单的类型转换
|
||||
try:
|
||||
if param.annotation is int:
|
||||
arg_val = int(arg_str)
|
||||
elif param.annotation is float:
|
||||
arg_val = float(arg_str)
|
||||
except ValueError:
|
||||
await event.reply(f"参数 {param.name} 类型错误,应为 {param.annotation.__name__}")
|
||||
return
|
||||
call_args.append(arg_val)
|
||||
else:
|
||||
call_args.append(param.default)
|
||||
|
||||
# 调用方法
|
||||
result = await method(*call_args)
|
||||
|
||||
# 如果有返回值,自动回复
|
||||
if result is not None:
|
||||
await event.reply(str(result))
|
||||
|
||||
except Exception as e:
|
||||
await event.reply(f"执行命令时发生错误: {str(e)}")
|
||||
|
||||
# 注册命令
|
||||
command_manager.command(name)(wrapper)
|
||||
|
||||
|
||||
def command(name: str, *aliases: str, permission: Optional[Permission] = None, override_permission_check: bool = False):
|
||||
"""
|
||||
装饰器:标记方法为命令处理器。
|
||||
"""
|
||||
def decorator(func):
|
||||
func._command_meta = {
|
||||
"names": (name,) + aliases,
|
||||
"permission": permission,
|
||||
"override_permission_check": override_permission_check
|
||||
}
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def on_message():
|
||||
"""
|
||||
装饰器:标记方法为通用消息处理器。
|
||||
"""
|
||||
def decorator(func):
|
||||
func._on_message_meta = {}
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def on_notice(notice_type: Optional[str] = None):
|
||||
"""
|
||||
装饰器:标记方法为通知处理器。
|
||||
"""
|
||||
def decorator(func):
|
||||
func._on_notice_meta = {
|
||||
"notice_type": notice_type
|
||||
}
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def on_request(request_type: Optional[str] = None):
|
||||
"""
|
||||
装饰器:标记方法为请求处理器。
|
||||
"""
|
||||
def decorator(func):
|
||||
func._on_request_meta = {
|
||||
"request_type": request_type
|
||||
}
|
||||
return func
|
||||
return decorator
|
||||
9
src/neobot/core/services/__init__.py
Normal file
9
src/neobot/core/services/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
NEO Bot Services Package
|
||||
|
||||
服务层模块。
|
||||
"""
|
||||
|
||||
from .local_file_server import start_local_file_server, stop_local_file_server
|
||||
|
||||
__all__ = ["start_local_file_server", "stop_local_file_server"]
|
||||
219
src/neobot/core/services/local_file_server.py
Normal file
219
src/neobot/core/services/local_file_server.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
本地文件下载服务
|
||||
|
||||
该模块提供一个本地 HTTP 服务,用于下载远程文件到本地并提供本地访问。
|
||||
主要解决 NapCat 等第三方服务无法直接访问某些远程资源(如 B 站防盗链)的问题。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
import urllib.request
|
||||
|
||||
from neobot.core.utils.logger import logger
|
||||
from neobot.core.config_loader import global_config
|
||||
|
||||
|
||||
class LocalFileServer:
|
||||
"""
|
||||
本地文件下载服务
|
||||
|
||||
提供一个本地 HTTP 服务,用于下载远程文件到本地并提供本地访问。
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 3003):
|
||||
"""
|
||||
初始化本地文件下载服务
|
||||
|
||||
Args:
|
||||
host (str): 服务监听地址
|
||||
port (int): 服务监听端口
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.app = web.Application()
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self.download_dir = Path(tempfile.gettempdir()) / "neobot_downloads"
|
||||
self.download_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 注册路由
|
||||
self.app.router.add_get('/download', self.handle_download)
|
||||
self.app.router.add_get('/health', self.handle_health)
|
||||
|
||||
# 文件映射表:file_id -> file_path
|
||||
self.file_map: Dict[str, Path] = {}
|
||||
|
||||
logger.success(f"[LocalFileServer] 初始化完成: {self.host}:{self.port}")
|
||||
|
||||
async def start(self):
|
||||
"""启动服务"""
|
||||
self.runner = web.AppRunner(self.app)
|
||||
await self.runner.setup()
|
||||
self.site = web.TCPSite(self.runner, self.host, self.port)
|
||||
await self.site.start()
|
||||
logger.success(f"[LocalFileServer] 服务已启动: http://{self.host}:{self.port}")
|
||||
|
||||
async def stop(self):
|
||||
"""停止服务"""
|
||||
if self.runner:
|
||||
await self.runner.cleanup()
|
||||
logger.info("[LocalFileServer] 服务已停止")
|
||||
|
||||
def _generate_file_id(self, url: str) -> str:
|
||||
"""根据 URL 生成唯一的文件 ID"""
|
||||
url_hash = hashlib.md5(url.encode()).hexdigest()[:16]
|
||||
return f"file_{url_hash}"
|
||||
|
||||
async def download_file(self, url: str, timeout: int = 60, headers: Optional[Dict[str, str]] = None) -> Optional[str]:
|
||||
"""
|
||||
下载远程文件到本地
|
||||
|
||||
Args:
|
||||
url (str): 远程文件 URL
|
||||
timeout (int): 下载超时时间(秒)
|
||||
headers (Optional[Dict[str, str]]): 请求头
|
||||
|
||||
Returns:
|
||||
Optional[str]: 本地文件 ID,如果失败则返回 None
|
||||
"""
|
||||
try:
|
||||
file_id = self._generate_file_id(url)
|
||||
file_path = self.download_dir / f"{file_id}"
|
||||
|
||||
# 检查文件是否已存在
|
||||
if file_path.exists():
|
||||
logger.info(f"[LocalFileServer] 文件已存在: {file_id}")
|
||||
return file_id
|
||||
|
||||
logger.info(f"[LocalFileServer] 开始下载: {url}")
|
||||
|
||||
# 使用 aiohttp 下载文件
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, timeout=timeout, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"[LocalFileServer] 下载失败: HTTP {response.status}")
|
||||
return None
|
||||
|
||||
# 读取并保存文件
|
||||
with open(file_path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await response.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
|
||||
self.file_map[file_id] = file_path
|
||||
logger.success(f"[LocalFileServer] 下载完成: {file_id} ({file_path.stat().st_size} bytes)")
|
||||
return file_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[LocalFileServer] 下载失败: {e}")
|
||||
return None
|
||||
|
||||
async def handle_download(self, request: web.Request) -> web.Response:
|
||||
"""处理文件下载请求"""
|
||||
file_id = request.query.get('id')
|
||||
|
||||
if not file_id or file_id not in self.file_map:
|
||||
return web.Response(
|
||||
status=404,
|
||||
text='File not found',
|
||||
content_type='text/plain'
|
||||
)
|
||||
|
||||
file_path = self.file_map[file_id]
|
||||
|
||||
if not file_path.exists():
|
||||
return web.Response(
|
||||
status=404,
|
||||
text='File not found',
|
||||
content_type='text/plain'
|
||||
)
|
||||
|
||||
# 获取文件大小
|
||||
file_size = file_path.stat().st_size
|
||||
|
||||
# 设置响应头
|
||||
headers = {
|
||||
'Content-Disposition': f'attachment; filename="{file_id}"',
|
||||
'Content-Length': str(file_size)
|
||||
}
|
||||
|
||||
return web.FileResponse(file_path, headers=headers)
|
||||
|
||||
async def handle_health(self, request: web.Request) -> web.Response:
|
||||
"""健康检查"""
|
||||
return web.json_response({
|
||||
'status': 'ok',
|
||||
'service': 'LocalFileServer',
|
||||
'download_dir': str(self.download_dir),
|
||||
'files_count': len(self.file_map)
|
||||
})
|
||||
|
||||
|
||||
# 全局实例
|
||||
_local_file_server: Optional[LocalFileServer] = None
|
||||
|
||||
|
||||
def get_local_file_server() -> Optional[LocalFileServer]:
|
||||
"""获取全局本地文件服务器实例"""
|
||||
global _local_file_server
|
||||
|
||||
if _local_file_server is None:
|
||||
try:
|
||||
server_config = global_config.local_file_server
|
||||
_local_file_server = LocalFileServer(
|
||||
host=server_config.host,
|
||||
port=server_config.port
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LocalFileServer] 初始化失败: {e}")
|
||||
return None
|
||||
|
||||
return _local_file_server
|
||||
|
||||
|
||||
async def start_local_file_server():
|
||||
"""启动全局本地文件服务器"""
|
||||
server = get_local_file_server()
|
||||
if server:
|
||||
await server.start()
|
||||
|
||||
|
||||
async def stop_local_file_server():
|
||||
"""停止全局本地文件服务器"""
|
||||
global _local_file_server
|
||||
if _local_file_server:
|
||||
await _local_file_server.stop()
|
||||
_local_file_server = None
|
||||
|
||||
|
||||
async def download_to_local(url: str, timeout: int = 60, headers: Optional[Dict[str, str]] = None) -> Optional[str]:
|
||||
"""
|
||||
下载远程文件到本地并返回本地访问 URL
|
||||
|
||||
Args:
|
||||
url (str): 远程文件 URL
|
||||
timeout (int): 下载超时时间(秒)
|
||||
headers (Optional[Dict[str, str]]): 请求头
|
||||
|
||||
Returns:
|
||||
Optional[str]: 本地访问 URL,如果失败则返回 None
|
||||
"""
|
||||
server = get_local_file_server()
|
||||
if not server:
|
||||
return None
|
||||
|
||||
file_id = await server.download_file(url, timeout, headers)
|
||||
if not file_id:
|
||||
return None
|
||||
|
||||
return f"http://127.0.0.1:{server.port}/download?id={file_id}"
|
||||
19
src/neobot/core/utils/__init__.py
Normal file
19
src/neobot/core/utils/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
NEO Bot Utils Package
|
||||
|
||||
工具函数模块。
|
||||
"""
|
||||
|
||||
from .error_codes import exception_to_error_response, ErrorCodes
|
||||
from .exceptions import BotException
|
||||
from .logger import logger, ModuleLogger
|
||||
from .singleton import Singleton
|
||||
|
||||
__all__ = [
|
||||
"exception_to_error_response",
|
||||
"ErrorCodes",
|
||||
"BotException",
|
||||
"logger",
|
||||
"ModuleLogger",
|
||||
"Singleton",
|
||||
]
|
||||
202
src/neobot/core/utils/env_loader.py
Normal file
202
src/neobot/core/utils/env_loader.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
环境变量加载器
|
||||
|
||||
负责从环境变量加载敏感配置,支持 .env 文件和环境变量。
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .logger import ModuleLogger
|
||||
|
||||
|
||||
class EnvLoader:
|
||||
"""
|
||||
环境变量加载器类
|
||||
"""
|
||||
|
||||
def __init__(self, env_file: str = ".env"):
|
||||
"""
|
||||
初始化环境变量加载器
|
||||
|
||||
Args:
|
||||
env_file: .env 文件路径
|
||||
"""
|
||||
self.env_file = Path(env_file)
|
||||
self.logger = ModuleLogger("EnvLoader")
|
||||
self._loaded = False
|
||||
|
||||
def load(self) -> bool:
|
||||
"""
|
||||
加载环境变量
|
||||
|
||||
Returns:
|
||||
bool: 是否成功加载
|
||||
"""
|
||||
if self._loaded:
|
||||
return True
|
||||
|
||||
try:
|
||||
# 尝试从 .env 文件加载
|
||||
if self.env_file.exists():
|
||||
load_dotenv(self.env_file)
|
||||
self.logger.info(f"已从 {self.env_file} 加载环境变量")
|
||||
else:
|
||||
self.logger.warning(f".env 文件不存在: {self.env_file}")
|
||||
|
||||
self._loaded = True
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"加载环境变量失败: {e}")
|
||||
return False
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
获取环境变量值
|
||||
|
||||
Args:
|
||||
key: 环境变量键名
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
环境变量值,如果不存在则返回默认值
|
||||
"""
|
||||
if not self._loaded:
|
||||
self.load()
|
||||
|
||||
return os.getenv(key, default)
|
||||
|
||||
def get_int(self, key: str, default: int = 0) -> int:
|
||||
"""
|
||||
获取整数类型的环境变量值
|
||||
|
||||
Args:
|
||||
key: 环境变量键名
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
整数类型的环境变量值
|
||||
"""
|
||||
value = self.get(key)
|
||||
if value is None:
|
||||
return default
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
self.logger.warning(f"环境变量 {key} 的值 '{value}' 不是有效的整数,使用默认值 {default}")
|
||||
return default
|
||||
|
||||
def get_bool(self, key: str, default: bool = False) -> bool:
|
||||
"""
|
||||
获取布尔类型的环境变量值
|
||||
|
||||
Args:
|
||||
key: 环境变量键名
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
布尔类型的环境变量值
|
||||
"""
|
||||
value = self.get(key)
|
||||
if value is None:
|
||||
return default
|
||||
|
||||
value_lower = value.lower()
|
||||
if value_lower in ('true', 'yes', '1', 'on'):
|
||||
return True
|
||||
elif value_lower in ('false', 'no', '0', 'off'):
|
||||
return False
|
||||
else:
|
||||
self.logger.warning(f"环境变量 {key} 的值 '{value}' 不是有效的布尔值,使用默认值 {default}")
|
||||
return default
|
||||
|
||||
def get_list(self, key: str, default: Optional[list] = None, separator: str = ',') -> list:
|
||||
"""
|
||||
获取列表类型的环境变量值
|
||||
|
||||
Args:
|
||||
key: 环境变量键名
|
||||
default: 默认值
|
||||
separator: 分隔符
|
||||
|
||||
Returns:
|
||||
列表类型的环境变量值
|
||||
"""
|
||||
value = self.get(key)
|
||||
if value is None:
|
||||
return default or []
|
||||
|
||||
return [item.strip() for item in value.split(separator) if item.strip()]
|
||||
|
||||
def validate_required(self, keys: list[str]) -> bool:
|
||||
"""
|
||||
验证必需的环境变量是否存在
|
||||
|
||||
Args:
|
||||
keys: 必需的环境变量键名列表
|
||||
|
||||
Returns:
|
||||
bool: 所有必需的环境变量是否存在
|
||||
"""
|
||||
missing_keys = []
|
||||
|
||||
for key in keys:
|
||||
if self.get(key) is None:
|
||||
missing_keys.append(key)
|
||||
|
||||
if missing_keys:
|
||||
self.logger.error(f"缺少必需的环境变量: {', '.join(missing_keys)}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def mask_sensitive_value(self, value: str) -> str:
|
||||
"""
|
||||
隐藏敏感值(用于日志输出)
|
||||
|
||||
Args:
|
||||
value: 原始值
|
||||
|
||||
Returns:
|
||||
隐藏后的值
|
||||
"""
|
||||
if not value:
|
||||
return ""
|
||||
|
||||
if len(value) <= 4:
|
||||
return "***"
|
||||
else:
|
||||
return value[:2] + "***" + value[-2:]
|
||||
|
||||
def get_safe_log_value(self, key: str) -> str:
|
||||
"""
|
||||
获取安全的日志值(隐藏敏感信息)
|
||||
|
||||
Args:
|
||||
key: 环境变量键名
|
||||
|
||||
Returns:
|
||||
安全的日志值
|
||||
"""
|
||||
value = self.get(key)
|
||||
if value is None:
|
||||
return "<未设置>"
|
||||
|
||||
# 敏感键名列表
|
||||
sensitive_keys = [
|
||||
'password', 'token', 'secret', 'key', 'credential',
|
||||
'sessdata', 'bili_jct', 'buvid3', 'dedeuserid'
|
||||
]
|
||||
|
||||
for sensitive in sensitive_keys:
|
||||
if sensitive in key.lower():
|
||||
return self.mask_sensitive_value(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
# 全局环境变量加载器实例
|
||||
env_loader = EnvLoader()
|
||||
235
src/neobot/core/utils/error_codes.py
Normal file
235
src/neobot/core/utils/error_codes.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
错误码和统一响应格式模块
|
||||
|
||||
该模块定义了项目中使用的错误码和统一的错误响应格式,确保所有模块返回一致的错误信息。
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
# 错误码定义
|
||||
class ErrorCode:
|
||||
"""
|
||||
错误码枚举类,包含所有系统错误码的定义。
|
||||
|
||||
错误码规则:
|
||||
- 1xxx: 系统级错误
|
||||
- 2xxx: WebSocket相关错误
|
||||
- 3xxx: 插件相关错误
|
||||
- 4xxx: 配置相关错误
|
||||
- 5xxx: 权限相关错误
|
||||
- 6xxx: 命令相关错误
|
||||
- 7xxx: Redis相关错误
|
||||
- 8xxx: 浏览器管理器相关错误
|
||||
- 9xxx: 代码执行相关错误
|
||||
"""
|
||||
# 系统级错误
|
||||
SUCCESS = 0 # 成功
|
||||
UNKNOWN_ERROR = 1000 # 未知错误
|
||||
INVALID_PARAMETER = 1001 # 参数无效
|
||||
DATABASE_ERROR = 1002 # 数据库错误
|
||||
NETWORK_ERROR = 1003 # 网络错误
|
||||
TIMEOUT_ERROR = 1004 # 超时错误
|
||||
RESOURCE_EXHAUSTED = 1005 # 资源耗尽
|
||||
|
||||
# WebSocket相关错误
|
||||
WS_CONNECTION_FAILED = 2000 # WebSocket连接失败
|
||||
WS_AUTH_FAILED = 2001 # WebSocket认证失败
|
||||
WS_DISCONNECTED = 2002 # WebSocket已断开
|
||||
WS_MESSAGE_ERROR = 2003 # WebSocket消息错误
|
||||
|
||||
# 插件相关错误
|
||||
PLUGIN_LOAD_FAILED = 3000 # 插件加载失败
|
||||
PLUGIN_RELOAD_FAILED = 3001 # 插件重载失败
|
||||
PLUGIN_NOT_FOUND = 3002 # 插件未找到
|
||||
PLUGIN_INVALID = 3003 # 插件无效
|
||||
PLUGIN_DEPENDENCY_ERROR = 3004 # 插件依赖错误
|
||||
|
||||
# 配置相关错误
|
||||
CONFIG_NOT_FOUND = 4000 # 配置文件未找到
|
||||
CONFIG_PARSE_ERROR = 4001 # 配置解析错误
|
||||
CONFIG_VALIDATION_ERROR = 4002 # 配置验证错误
|
||||
CONFIG_KEY_NOT_FOUND = 4003 # 配置项未找到
|
||||
|
||||
# 权限相关错误
|
||||
PERMISSION_DENIED = 5000 # 权限不足
|
||||
NOT_ADMIN = 5001 # 不是管理员
|
||||
USER_BANNED = 5002 # 用户已被禁止
|
||||
|
||||
# 命令相关错误
|
||||
COMMAND_NOT_FOUND = 6000 # 命令未找到
|
||||
COMMAND_PARAM_ERROR = 6001 # 命令参数错误
|
||||
COMMAND_EXECUTE_ERROR = 6002 # 命令执行错误
|
||||
COMMAND_TIMEOUT = 6003 # 命令执行超时
|
||||
|
||||
# Redis相关错误
|
||||
REDIS_CONNECTION_FAILED = 7000 # Redis连接失败
|
||||
REDIS_OPERATION_ERROR = 7001 # Redis操作错误
|
||||
|
||||
# 浏览器管理器相关错误
|
||||
BROWSER_INIT_FAILED = 8000 # 浏览器初始化失败
|
||||
BROWSER_POOL_ERROR = 8001 # 浏览器池错误
|
||||
BROWSER_OPERATION_ERROR = 8002 # 浏览器操作错误
|
||||
|
||||
# 代码执行相关错误
|
||||
CODE_EXECUTE_ERROR = 9000 # 代码执行错误
|
||||
CODE_SECURITY_ERROR = 9001 # 代码安全错误
|
||||
|
||||
|
||||
# 错误码到错误消息的映射
|
||||
ERROR_MESSAGES = {
|
||||
# 系统级错误
|
||||
ErrorCode.SUCCESS: "操作成功",
|
||||
ErrorCode.UNKNOWN_ERROR: "未知错误",
|
||||
ErrorCode.INVALID_PARAMETER: "参数无效",
|
||||
ErrorCode.DATABASE_ERROR: "数据库错误",
|
||||
ErrorCode.NETWORK_ERROR: "网络错误",
|
||||
ErrorCode.TIMEOUT_ERROR: "操作超时",
|
||||
ErrorCode.RESOURCE_EXHAUSTED: "资源耗尽",
|
||||
|
||||
# WebSocket相关错误
|
||||
ErrorCode.WS_CONNECTION_FAILED: "WebSocket连接失败",
|
||||
ErrorCode.WS_AUTH_FAILED: "WebSocket认证失败",
|
||||
ErrorCode.WS_DISCONNECTED: "WebSocket已断开连接",
|
||||
ErrorCode.WS_MESSAGE_ERROR: "WebSocket消息格式错误",
|
||||
|
||||
# 插件相关错误
|
||||
ErrorCode.PLUGIN_LOAD_FAILED: "插件加载失败",
|
||||
ErrorCode.PLUGIN_RELOAD_FAILED: "插件重载失败",
|
||||
ErrorCode.PLUGIN_NOT_FOUND: "插件未找到",
|
||||
ErrorCode.PLUGIN_INVALID: "插件无效",
|
||||
ErrorCode.PLUGIN_DEPENDENCY_ERROR: "插件依赖错误",
|
||||
|
||||
# 配置相关错误
|
||||
ErrorCode.CONFIG_NOT_FOUND: "配置文件未找到",
|
||||
ErrorCode.CONFIG_PARSE_ERROR: "配置文件解析错误",
|
||||
ErrorCode.CONFIG_VALIDATION_ERROR: "配置验证失败",
|
||||
ErrorCode.CONFIG_KEY_NOT_FOUND: "配置项未找到",
|
||||
|
||||
# 权限相关错误
|
||||
ErrorCode.PERMISSION_DENIED: "权限不足",
|
||||
ErrorCode.NOT_ADMIN: "需要管理员权限",
|
||||
ErrorCode.USER_BANNED: "用户已被禁止操作",
|
||||
|
||||
# 命令相关错误
|
||||
ErrorCode.COMMAND_NOT_FOUND: "命令未找到",
|
||||
ErrorCode.COMMAND_PARAM_ERROR: "命令参数错误",
|
||||
ErrorCode.COMMAND_EXECUTE_ERROR: "命令执行错误",
|
||||
ErrorCode.COMMAND_TIMEOUT: "命令执行超时",
|
||||
|
||||
# Redis相关错误
|
||||
ErrorCode.REDIS_CONNECTION_FAILED: "Redis连接失败",
|
||||
ErrorCode.REDIS_OPERATION_ERROR: "Redis操作错误",
|
||||
|
||||
# 浏览器管理器相关错误
|
||||
ErrorCode.BROWSER_INIT_FAILED: "浏览器初始化失败",
|
||||
ErrorCode.BROWSER_POOL_ERROR: "浏览器池错误",
|
||||
ErrorCode.BROWSER_OPERATION_ERROR: "浏览器操作错误",
|
||||
|
||||
# 代码执行相关错误
|
||||
ErrorCode.CODE_EXECUTE_ERROR: "代码执行错误",
|
||||
ErrorCode.CODE_SECURITY_ERROR: "代码存在安全风险",
|
||||
}
|
||||
|
||||
|
||||
def get_error_message(code: int) -> str:
|
||||
"""
|
||||
根据错误码获取错误消息
|
||||
|
||||
Args:
|
||||
code: 错误码
|
||||
|
||||
Returns:
|
||||
str: 错误消息
|
||||
"""
|
||||
return ERROR_MESSAGES.get(code, ERROR_MESSAGES[ErrorCode.UNKNOWN_ERROR])
|
||||
|
||||
|
||||
def create_error_response(code: int, message: Optional[str] = None, data: Optional[dict] = None, request_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
创建统一格式的错误响应
|
||||
|
||||
Args:
|
||||
code: 错误码
|
||||
message: 错误消息(可选,如果未提供则使用默认消息)
|
||||
data: 附加数据(可选)
|
||||
request_id: 请求ID(可选,用于追踪请求)
|
||||
|
||||
Returns:
|
||||
dict: 统一格式的错误响应
|
||||
"""
|
||||
error_message = message if message is not None else get_error_message(code)
|
||||
|
||||
response = {
|
||||
"code": code,
|
||||
"message": error_message,
|
||||
"success": code == ErrorCode.SUCCESS,
|
||||
}
|
||||
|
||||
if data is not None:
|
||||
response["data"] = data
|
||||
|
||||
if request_id is not None:
|
||||
response["request_id"] = request_id
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def exception_to_error_response(exception: Exception, code: Optional[int] = None, request_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
将异常对象转换为统一格式的错误响应
|
||||
|
||||
Args:
|
||||
exception: 异常对象
|
||||
code: 错误码(可选,如果未提供则根据异常类型自动推断)
|
||||
request_id: 请求ID(可选,用于追踪请求)
|
||||
|
||||
Returns:
|
||||
dict: 统一格式的错误响应
|
||||
"""
|
||||
# 从自定义异常类中提取错误码
|
||||
if hasattr(exception, "code") and exception.code is not None:
|
||||
code = exception.code
|
||||
|
||||
# 如果仍未找到错误码,则根据异常类型推断
|
||||
if code is None:
|
||||
from .exceptions import (
|
||||
WebSocketError, PluginError, ConfigError, PermissionError,
|
||||
CommandError, RedisError, BrowserManagerError, CodeExecutionError
|
||||
)
|
||||
|
||||
if isinstance(exception, WebSocketError):
|
||||
code = ErrorCode.WS_CONNECTION_FAILED
|
||||
elif isinstance(exception, PluginError):
|
||||
code = ErrorCode.PLUGIN_LOAD_FAILED
|
||||
elif isinstance(exception, ConfigError):
|
||||
code = ErrorCode.CONFIG_PARSE_ERROR
|
||||
elif isinstance(exception, PermissionError):
|
||||
code = ErrorCode.PERMISSION_DENIED
|
||||
elif isinstance(exception, CommandError):
|
||||
code = ErrorCode.COMMAND_EXECUTE_ERROR
|
||||
elif isinstance(exception, RedisError):
|
||||
code = ErrorCode.REDIS_OPERATION_ERROR
|
||||
elif isinstance(exception, BrowserManagerError):
|
||||
code = ErrorCode.BROWSER_OPERATION_ERROR
|
||||
elif isinstance(exception, CodeExecutionError):
|
||||
code = ErrorCode.CODE_EXECUTE_ERROR
|
||||
else:
|
||||
code = ErrorCode.UNKNOWN_ERROR
|
||||
|
||||
# 获取错误消息
|
||||
message = str(exception)
|
||||
|
||||
# 如果异常有原始错误,也包含在响应中
|
||||
data = None
|
||||
if hasattr(exception, "original_error") and exception.original_error is not None:
|
||||
data = {"original_error": str(exception.original_error)}
|
||||
|
||||
return create_error_response(code, message, data, request_id)
|
||||
|
||||
|
||||
# 将错误码导出以便其他模块使用
|
||||
__all__ = [
|
||||
"ErrorCode",
|
||||
"get_error_message",
|
||||
"create_error_response",
|
||||
"exception_to_error_response"
|
||||
]
|
||||
222
src/neobot/core/utils/exceptions.py
Normal file
222
src/neobot/core/utils/exceptions.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
自定义异常模块
|
||||
|
||||
该模块定义了项目中使用的各种自定义异常类,用于提供更精确、更友好的错误提示。
|
||||
"""
|
||||
|
||||
class SyncHandlerError(Exception):
|
||||
"""
|
||||
当尝试注册同步函数作为异步事件处理器时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketError(Exception):
|
||||
"""
|
||||
WebSocket相关错误的基类。
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
code: 错误代码(可选)
|
||||
original_error: 原始异常对象(可选)
|
||||
"""
|
||||
def __init__(self, message, code=None, original_error=None):
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class WebSocketConnectionError(WebSocketError):
|
||||
"""
|
||||
WebSocket连接失败时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketAuthenticationError(WebSocketError):
|
||||
"""
|
||||
WebSocket认证失败时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PluginError(Exception):
|
||||
"""
|
||||
插件相关错误的基类。
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
message: 错误消息
|
||||
original_error: 原始异常对象(可选)
|
||||
"""
|
||||
def __init__(self, plugin_name, message, original_error=None):
|
||||
self.plugin_name = plugin_name
|
||||
self.message = message
|
||||
self.original_error = original_error
|
||||
super().__init__(f"插件 {plugin_name}: {message}")
|
||||
|
||||
|
||||
class PluginLoadError(PluginError):
|
||||
"""
|
||||
插件加载失败时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PluginReloadError(PluginError):
|
||||
"""
|
||||
插件重载失败时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PluginNotFoundError(PluginError):
|
||||
"""
|
||||
找不到指定插件时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ConfigError(Exception):
|
||||
"""
|
||||
配置相关错误的基类。
|
||||
|
||||
Args:
|
||||
section: 配置部分名称(可选)
|
||||
key: 配置项名称(可选)
|
||||
message: 错误消息(可选)
|
||||
"""
|
||||
def __init__(self, section=None, key=None, message=None):
|
||||
self.section = section
|
||||
self.key = key
|
||||
self.message = message
|
||||
self.original_error = None
|
||||
|
||||
if section and key and message:
|
||||
super().__init__(f"配置错误 [{section}.{key}]: {message}")
|
||||
elif section and message:
|
||||
super().__init__(f"配置错误 [{section}]: {message}")
|
||||
else:
|
||||
super().__init__(message or "配置错误")
|
||||
|
||||
|
||||
class ConfigNotFoundError(ConfigError):
|
||||
"""
|
||||
配置文件不存在时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ConfigValidationError(ConfigError):
|
||||
"""
|
||||
配置验证失败时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PermissionError(Exception):
|
||||
"""
|
||||
权限相关错误的基类。
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
operation: 操作名称
|
||||
message: 错误消息
|
||||
"""
|
||||
def __init__(self, user_id=None, operation=None, message=None):
|
||||
self.user_id = user_id
|
||||
self.operation = operation
|
||||
self.message = message
|
||||
|
||||
if user_id and operation and message:
|
||||
super().__init__(f"权限错误 [用户 {user_id}]: 无权限执行操作 {operation} - {message}")
|
||||
elif user_id and operation:
|
||||
super().__init__(f"权限错误 [用户 {user_id}]: 无权限执行操作 {operation}")
|
||||
else:
|
||||
super().__init__(message or "权限错误")
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
"""
|
||||
命令处理相关错误的基类。
|
||||
|
||||
Args:
|
||||
command: 命令名称
|
||||
message: 错误消息
|
||||
original_error: 原始异常对象(可选)
|
||||
"""
|
||||
def __init__(self, command=None, message=None, original_error=None):
|
||||
self.command = command
|
||||
self.message = message
|
||||
self.original_error = original_error
|
||||
|
||||
if command and message:
|
||||
super().__init__(f"命令错误 [{command}]: {message}")
|
||||
else:
|
||||
super().__init__(message or "命令错误")
|
||||
|
||||
|
||||
class CommandNotFoundError(CommandError):
|
||||
"""
|
||||
找不到指定命令时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class CommandParameterError(CommandError):
|
||||
"""
|
||||
命令参数错误时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class RedisError(Exception):
|
||||
"""
|
||||
Redis相关错误的基类。
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
original_error: 原始异常对象(可选)
|
||||
"""
|
||||
def __init__(self, message, original_error=None):
|
||||
self.message = message
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BrowserManagerError(Exception):
|
||||
"""
|
||||
浏览器管理器相关错误的基类。
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
original_error: 原始异常对象(可选)
|
||||
"""
|
||||
def __init__(self, message, original_error=None):
|
||||
self.message = message
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BrowserPoolError(BrowserManagerError):
|
||||
"""
|
||||
浏览器池相关错误时抛出此异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class CodeExecutionError(Exception):
|
||||
"""
|
||||
代码执行相关错误的基类。
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
code: 执行的代码(可选)
|
||||
original_error: 原始异常对象(可选)
|
||||
"""
|
||||
def __init__(self, message, code=None, original_error=None):
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
202
src/neobot/core/utils/executor.py
Normal file
202
src/neobot/core/utils/executor.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import docker
|
||||
from docker.tls import TLSConfig
|
||||
from docker.types import LogConfig
|
||||
from typing import Any, Callable
|
||||
|
||||
from neobot.core.utils.logger import logger
|
||||
|
||||
class CodeExecutor:
|
||||
"""
|
||||
代码执行引擎,负责管理一个异步任务队列和并发的 Docker 容器执行。
|
||||
"""
|
||||
def __init__(self, config: Any):
|
||||
"""
|
||||
初始化代码执行引擎。
|
||||
:param config: 从 config_loader.py 加载的全局配置对象。
|
||||
"""
|
||||
self.bot: Any = None # Bot 实例将在 WS 连接成功后动态注入
|
||||
self.task_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
# 从传入的配置中读取 Docker 相关设置
|
||||
docker_config = config.docker
|
||||
self.docker_base_url = docker_config.base_url
|
||||
self.sandbox_image = docker_config.sandbox_image
|
||||
self.timeout = docker_config.timeout
|
||||
concurrency = docker_config.concurrency_limit
|
||||
|
||||
self.concurrency_limit = asyncio.Semaphore(concurrency)
|
||||
self.docker_client = None
|
||||
|
||||
logger.info("[CodeExecutor] 初始化 Docker 客户端...")
|
||||
try:
|
||||
if self.docker_base_url:
|
||||
# 如果配置了远程 Docker 地址,则使用 TLS 选项进行连接
|
||||
tls_config = None
|
||||
if docker_config.tls_verify:
|
||||
tls_config = TLSConfig(
|
||||
ca_cert=docker_config.ca_cert_path,
|
||||
client_cert=(docker_config.client_cert_path, docker_config.client_key_path),
|
||||
verify=True
|
||||
)
|
||||
self.docker_client = docker.DockerClient(base_url=self.docker_base_url, tls=tls_config)
|
||||
else:
|
||||
# 否则,使用默认的本地环境连接
|
||||
self.docker_client = docker.from_env()
|
||||
|
||||
# 检查 Docker 服务是否可用
|
||||
self.docker_client.ping()
|
||||
logger.success("[CodeExecutor] Docker 客户端初始化成功,服务连接正常。")
|
||||
except docker.errors.DockerException as e:
|
||||
self.docker_client = None
|
||||
logger.error(f"无法连接到 Docker 服务,请检查 Docker 是否正在运行: {e}")
|
||||
except Exception as e:
|
||||
self.docker_client = None
|
||||
logger.error(f"初始化 Docker 客户端时发生未知错误: {e}")
|
||||
|
||||
async def add_task(self, code: str, callback: Callable[[str], asyncio.Future]):
|
||||
"""
|
||||
将代码执行任务添加到队列中。
|
||||
:param code: 待执行的 Python 代码字符串。
|
||||
:param callback: 执行完毕后用于回复结果的回调函数。
|
||||
:raises RuntimeError: 如果 Docker 客户端未初始化。
|
||||
"""
|
||||
if not self.docker_client:
|
||||
logger.warning("[CodeExecutor] 尝试添加任务,但 Docker 客户端未初始化。任务被拒绝。")
|
||||
# 这里可以选择抛出异常,或者直接调用回调返回错误信息
|
||||
# 为了用户体验,我们构造一个错误结果并直接调用回调(如果可能)
|
||||
# 但由于 callback 返回 Future,这里简单起见,我们记录日志并抛出异常
|
||||
raise RuntimeError("Docker环境未就绪,无法执行代码。")
|
||||
|
||||
task = {"code": code, "callback": callback}
|
||||
await self.task_queue.put(task)
|
||||
logger.info(f"[CodeExecutor] 新的代码执行任务已入队 (队列当前长度: {self.task_queue.qsize()})。")
|
||||
|
||||
async def worker(self):
|
||||
"""
|
||||
后台工作者,不断从队列中取出任务并执行。
|
||||
"""
|
||||
if not self.docker_client:
|
||||
logger.error("[CodeExecutor] Worker 无法启动,因为 Docker 客户端未初始化。")
|
||||
return
|
||||
|
||||
logger.info("[CodeExecutor] 代码执行 Worker 已启动,等待任务...")
|
||||
while True:
|
||||
task = await self.task_queue.get()
|
||||
|
||||
logger.info("[CodeExecutor] 开始处理代码执行任务。")
|
||||
|
||||
async with self.concurrency_limit:
|
||||
result_message = ""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# 使用 asyncio.wait_for 实现超时控制
|
||||
result_bytes = await asyncio.wait_for(
|
||||
loop.run_in_executor(
|
||||
None, # 使用默认线程池
|
||||
self._run_in_container,
|
||||
task['code']
|
||||
),
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
output = result_bytes.decode('utf-8').strip()
|
||||
result_message = output if output else "代码执行完毕,无输出。"
|
||||
logger.success("[CodeExecutor] 任务成功执行。")
|
||||
|
||||
except docker.errors.ImageNotFound:
|
||||
logger.error(f"[CodeExecutor] 镜像 '{self.sandbox_image}' 不存在!")
|
||||
result_message = f"执行失败:沙箱基础镜像 '{self.sandbox_image}' 不存在,请联系管理员构建。"
|
||||
except docker.errors.ContainerError as e:
|
||||
# 确保 stderr 是字符串
|
||||
error_output = e.stderr.decode('utf-8').strip() if isinstance(e.stderr, bytes) else e.stderr.strip()
|
||||
result_message = f"代码执行出错:\n{error_output}"
|
||||
logger.warning(f"[CodeExecutor] 代码执行时发生错误: {error_output}")
|
||||
except docker.errors.APIError as e:
|
||||
logger.error(f"[CodeExecutor] Docker API 错误: {e}")
|
||||
result_message = "执行失败:与 Docker 服务通信时发生错误,请检查服务状态。"
|
||||
except asyncio.TimeoutError:
|
||||
result_message = f"执行超时 (超过 {self.timeout} 秒)。"
|
||||
logger.warning("[CodeExecutor] 任务执行超时。")
|
||||
except Exception as e:
|
||||
logger.exception(f"[CodeExecutor] 执行 Docker 任务时发生未知严重错误: {e}")
|
||||
result_message = "执行引擎发生内部错误,请联系管理员。"
|
||||
|
||||
# 调用回调函数回复结果
|
||||
try:
|
||||
await task['callback'](result_message)
|
||||
except Exception as callback_error:
|
||||
logger.error(f"[CodeExecutor] 执行回调函数时发生错误: {callback_error}")
|
||||
# 即使回调失败,也要确保任务被标记为完成
|
||||
|
||||
self.task_queue.task_done()
|
||||
|
||||
def _run_in_container(self, code: str) -> bytes:
|
||||
"""
|
||||
同步函数:在 Docker 容器中运行代码。
|
||||
此函数通过手动管理容器生命周期来提高稳定性。
|
||||
"""
|
||||
if self.docker_client is None:
|
||||
raise docker.errors.DockerException("Docker client is not initialized.")
|
||||
|
||||
container = None
|
||||
try:
|
||||
# 1. 创建容器
|
||||
container = self.docker_client.containers.create(
|
||||
image=self.sandbox_image,
|
||||
command=["python", "-c", code],
|
||||
mem_limit='128m',
|
||||
cpu_shares=512,
|
||||
network_disabled=True,
|
||||
log_config=LogConfig(type='json-file', config={'max-size': '1m'}),
|
||||
)
|
||||
# 2. 启动容器
|
||||
container.start()
|
||||
|
||||
# 3. 等待容器执行完成
|
||||
# 主超时由 asyncio.wait_for 控制,这里的 timeout 是一个额外的保险
|
||||
result = container.wait(timeout=self.timeout + 5)
|
||||
|
||||
# 4. 获取日志
|
||||
stdout = container.logs(stdout=True, stderr=False)
|
||||
stderr = container.logs(stdout=False, stderr=True)
|
||||
|
||||
# 5. 检查退出码,如果不为 0,则手动抛出 ContainerError
|
||||
if result.get('StatusCode', 0) != 0:
|
||||
# 确保 stderr 是字符串
|
||||
error_message = stderr.decode('utf-8') if isinstance(stderr, bytes) else stderr
|
||||
raise docker.errors.ContainerError(
|
||||
container, result['StatusCode'], f"python -c '{code}'", self.sandbox_image, error_message
|
||||
)
|
||||
|
||||
return stdout
|
||||
|
||||
finally:
|
||||
# 6. 确保容器总是被移除
|
||||
if container:
|
||||
try:
|
||||
container.remove(force=True)
|
||||
except docker.errors.NotFound:
|
||||
# 如果容器因为某些原因已经消失,也沒关系
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"[CodeExecutor] 强制移除容器 {container.id} 时失败: {e}")
|
||||
|
||||
def initialize_executor(config: Any):
|
||||
"""
|
||||
初始化并返回一个 CodeExecutor 实例。
|
||||
"""
|
||||
return CodeExecutor(config)
|
||||
|
||||
async def run_in_thread_pool(sync_func, *args, **kwargs):
|
||||
"""
|
||||
在线程池中运行同步阻塞函数,以避免阻塞 asyncio 事件循环。
|
||||
:param sync_func: 同步函数
|
||||
:param args: 位置参数
|
||||
:param kwargs: 关键字参数
|
||||
:return: 同步函数的返回值
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, lambda: sync_func(*args, **kwargs))
|
||||
388
src/neobot/core/utils/input_validator.py
Normal file
388
src/neobot/core/utils/input_validator.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
输入验证工具
|
||||
|
||||
提供通用的输入验证功能,防止 SQL 注入、XSS 攻击等安全问题。
|
||||
"""
|
||||
import re
|
||||
import html
|
||||
from typing import Optional, Union, List, Dict, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .logger import ModuleLogger
|
||||
|
||||
|
||||
class InputValidator:
|
||||
"""
|
||||
输入验证器类
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = ModuleLogger("InputValidator")
|
||||
|
||||
# SQL 注入检测模式(预编译正则表达式)
|
||||
self.sql_injection_patterns = [
|
||||
re.compile(r"(?i)(\b(select|insert|update|delete|drop|create|alter|truncate|union|join)\b)"),
|
||||
re.compile(r"(?i)(\b(from|where|group by|order by|having|limit|offset)\b)"),
|
||||
re.compile(r"(?i)(\b(and|or|not|xor|between|in|like|is|null)\b)"),
|
||||
re.compile(r"(?i)(\b(exec|execute|sp_executesql|xp_cmdshell)\b)"),
|
||||
re.compile(r"(?i)(\b(declare|set|cast|convert|case|when|then|else|end)\b)"),
|
||||
re.compile(r"(--|\#|\/\*|\*\/|;)"),
|
||||
re.compile(r"(\b(0x[0-9a-f]+)\b)"),
|
||||
re.compile(r"(\b(admin|administrator|root|sysadmin)\b)"),
|
||||
re.compile(r"(\b(password|passwd|pwd|secret|token|key)\b)"),
|
||||
]
|
||||
|
||||
# XSS 攻击检测模式(预编译正则表达式)
|
||||
self.xss_patterns = [
|
||||
re.compile(r"(<script[^>]*>.*?</script>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<iframe[^>]*>.*?</iframe>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<object[^>]*>.*?</object>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<embed[^>]*>.*?</embed>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<applet[^>]*>.*?</applet>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<meta[^>]*>.*?</meta>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<link[^>]*>.*?</link>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<style[^>]*>.*?</style>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<base[^>]*>.*?</base>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<form[^>]*>.*?</form>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<input[^>]*>.*?</input>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<button[^>]*>.*?</button>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<select[^>]*>.*?</select>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<textarea[^>]*>.*?</textarea>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<img[^>]*>.*?</img>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<svg[^>]*>.*?</svg>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(<math[^>]*>.*?</math>)", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"(javascript:|data:|vbscript:|about:|file:|ftp:|mailto:|telnet:)", re.IGNORECASE),
|
||||
re.compile(r"(on\w+\s*=)", re.IGNORECASE),
|
||||
re.compile(r"(expression\s*\()", re.IGNORECASE),
|
||||
re.compile(r"(url\s*\()", re.IGNORECASE),
|
||||
]
|
||||
|
||||
# 路径遍历检测模式(预编译正则表达式)
|
||||
self.path_traversal_patterns = [
|
||||
re.compile(r"(\.\./|\.\.\\)", re.IGNORECASE),
|
||||
re.compile(r"(/etc/passwd|/etc/shadow|/etc/hosts)", re.IGNORECASE),
|
||||
re.compile(r"(C:\\Windows\\System32|C:\\Windows\\SysWOW64)", re.IGNORECASE),
|
||||
re.compile(r"(/bin/sh|/bin/bash|/usr/bin/python)", re.IGNORECASE),
|
||||
re.compile(r"(\.\.%2f|\.\.%5c)", re.IGNORECASE),
|
||||
]
|
||||
|
||||
# 命令注入检测模式(预编译正则表达式)
|
||||
self.command_injection_patterns = [
|
||||
re.compile(r"(;|\||&|\$\(|\`|\n|\r)"),
|
||||
re.compile(r"(rm\s+-rf|del\s+/f|format\s+)", re.IGNORECASE),
|
||||
re.compile(r"(shutdown|reboot|halt|poweroff)", re.IGNORECASE),
|
||||
re.compile(r"(wget|curl|ftp|scp|ssh)\s+", re.IGNORECASE),
|
||||
re.compile(r"(nc|netcat|telnet|nmap)\s+", re.IGNORECASE),
|
||||
]
|
||||
|
||||
# 预编译常用正则表达式
|
||||
self.email_pattern = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
|
||||
self.phone_pattern = re.compile(r'^1[3-9]\d{9}$')
|
||||
self.nine_digit_pattern = re.compile(r'^\d{9}$') # 用于城市代码验证
|
||||
|
||||
def validate_sql_input(self, input_str: str, allow_safe_keywords: bool = False) -> bool:
|
||||
"""
|
||||
验证 SQL 输入是否安全
|
||||
|
||||
Args:
|
||||
input_str: 输入字符串
|
||||
allow_safe_keywords: 是否允许安全的 SQL 关键字
|
||||
|
||||
Returns:
|
||||
bool: 是否安全
|
||||
"""
|
||||
if not input_str:
|
||||
return True
|
||||
|
||||
input_lower = input_str.lower()
|
||||
|
||||
# 检查 SQL 注入模式(使用预编译的正则表达式)
|
||||
for pattern in self.sql_injection_patterns:
|
||||
if pattern.search(input_lower):
|
||||
self.logger.warning(f"检测到可能的 SQL 注入: {input_str}")
|
||||
return False
|
||||
|
||||
# 如果允许安全关键字,检查是否包含危险操作
|
||||
if allow_safe_keywords:
|
||||
dangerous_operations = ['drop', 'delete', 'truncate', 'alter', 'create', 'exec']
|
||||
for op in dangerous_operations:
|
||||
if op in input_lower:
|
||||
self.logger.warning(f"检测到危险 SQL 操作: {op}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def validate_xss_input(self, input_str: str) -> bool:
|
||||
"""
|
||||
验证 XSS 输入是否安全
|
||||
|
||||
Args:
|
||||
input_str: 输入字符串
|
||||
|
||||
Returns:
|
||||
bool: 是否安全
|
||||
"""
|
||||
if not input_str:
|
||||
return True
|
||||
|
||||
# 检查 XSS 攻击模式(使用预编译的正则表达式)
|
||||
for pattern in self.xss_patterns:
|
||||
if pattern.search(input_str):
|
||||
self.logger.warning(f"检测到可能的 XSS 攻击: {input_str}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def validate_path_input(self, input_str: str) -> bool:
|
||||
"""
|
||||
验证路径输入是否安全
|
||||
|
||||
Args:
|
||||
input_str: 输入字符串
|
||||
|
||||
Returns:
|
||||
bool: 是否安全
|
||||
"""
|
||||
if not input_str:
|
||||
return True
|
||||
|
||||
# 检查路径遍历攻击(使用预编译的正则表达式)
|
||||
for pattern in self.path_traversal_patterns:
|
||||
if pattern.search(input_str):
|
||||
self.logger.warning(f"检测到可能的路径遍历攻击: {input_str}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def validate_command_input(self, input_str: str) -> bool:
|
||||
"""
|
||||
验证命令输入是否安全
|
||||
|
||||
Args:
|
||||
input_str: 输入字符串
|
||||
|
||||
Returns:
|
||||
bool: 是否安全
|
||||
"""
|
||||
if not input_str:
|
||||
return True
|
||||
|
||||
# 检查命令注入攻击(使用预编译的正则表达式)
|
||||
for pattern in self.command_injection_patterns:
|
||||
if pattern.search(input_str):
|
||||
self.logger.warning(f"检测到可能的命令注入攻击: {input_str}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def validate_url(self, url: str, allowed_schemes: List[str] = None) -> bool:
|
||||
"""
|
||||
验证 URL 是否安全
|
||||
|
||||
Args:
|
||||
url: URL 字符串
|
||||
allowed_schemes: 允许的协议列表
|
||||
|
||||
Returns:
|
||||
bool: 是否安全
|
||||
"""
|
||||
if not url:
|
||||
return False
|
||||
|
||||
if allowed_schemes is None:
|
||||
allowed_schemes = ['http', 'https', 'ftp', 'file']
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
|
||||
# 检查协议
|
||||
if parsed.scheme not in allowed_schemes:
|
||||
self.logger.warning(f"不允许的协议: {parsed.scheme}")
|
||||
return False
|
||||
|
||||
# 检查主机名
|
||||
if not parsed.hostname:
|
||||
self.logger.warning("URL 缺少主机名")
|
||||
return False
|
||||
|
||||
# 检查路径安全性
|
||||
if not self.validate_path_input(parsed.path):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"URL 解析失败: {e}")
|
||||
return False
|
||||
|
||||
def validate_email(self, email: str) -> bool:
|
||||
"""
|
||||
验证邮箱地址格式
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
|
||||
Returns:
|
||||
bool: 是否有效
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
|
||||
return bool(self.email_pattern.match(email))
|
||||
|
||||
def validate_phone(self, phone: str) -> bool:
|
||||
"""
|
||||
验证手机号码格式
|
||||
|
||||
Args:
|
||||
phone: 手机号码
|
||||
|
||||
Returns:
|
||||
bool: 是否有效
|
||||
"""
|
||||
if not phone:
|
||||
return False
|
||||
|
||||
return bool(self.phone_pattern.match(phone))
|
||||
|
||||
def validate_integer(self, value: str, min_value: Optional[int] = None, max_value: Optional[int] = None) -> bool:
|
||||
"""
|
||||
验证整数格式和范围
|
||||
|
||||
Args:
|
||||
value: 整数字符串
|
||||
min_value: 最小值
|
||||
max_value: 最大值
|
||||
|
||||
Returns:
|
||||
bool: 是否有效
|
||||
"""
|
||||
if not value:
|
||||
return False
|
||||
|
||||
try:
|
||||
int_value = int(value)
|
||||
|
||||
if min_value is not None and int_value < min_value:
|
||||
return False
|
||||
|
||||
if max_value is not None and int_value > max_value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def validate_float(self, value: str, min_value: Optional[float] = None, max_value: Optional[float] = None) -> bool:
|
||||
"""
|
||||
验证浮点数格式和范围
|
||||
|
||||
Args:
|
||||
value: 浮点数字符串
|
||||
min_value: 最小值
|
||||
max_value: 最大值
|
||||
|
||||
Returns:
|
||||
bool: 是否有效
|
||||
"""
|
||||
if not value:
|
||||
return False
|
||||
|
||||
try:
|
||||
float_value = float(value)
|
||||
|
||||
if min_value is not None and float_value < min_value:
|
||||
return False
|
||||
|
||||
if max_value is not None and float_value > max_value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def sanitize_html(self, html_str: str) -> str:
|
||||
"""
|
||||
清理 HTML 字符串,防止 XSS 攻击
|
||||
|
||||
Args:
|
||||
html_str: HTML 字符串
|
||||
|
||||
Returns:
|
||||
str: 清理后的字符串
|
||||
"""
|
||||
if not html_str:
|
||||
return ""
|
||||
|
||||
# 转义 HTML 特殊字符
|
||||
sanitized = html.escape(html_str)
|
||||
|
||||
# 移除危险的属性
|
||||
sanitized = re.sub(r'on\w+\s*=', 'data-', sanitized, flags=re.IGNORECASE)
|
||||
sanitized = re.sub(r'javascript:', 'data:', sanitized, flags=re.IGNORECASE)
|
||||
sanitized = re.sub(r'data:', 'data:', sanitized, flags=re.IGNORECASE)
|
||||
sanitized = re.sub(r'vbscript:', 'data:', sanitized, flags=re.IGNORECASE)
|
||||
|
||||
return sanitized
|
||||
|
||||
def sanitize_sql(self, sql_str: str) -> str:
|
||||
"""
|
||||
清理 SQL 字符串,防止 SQL 注入
|
||||
|
||||
Args:
|
||||
sql_str: SQL 字符串
|
||||
|
||||
Returns:
|
||||
str: 清理后的字符串
|
||||
"""
|
||||
if not sql_str:
|
||||
return ""
|
||||
|
||||
# 移除注释
|
||||
sanitized = re.sub(r'--.*$', '', sql_str, flags=re.MULTILINE)
|
||||
sanitized = re.sub(r'/\*.*?\*/', '', sanitized, flags=re.DOTALL)
|
||||
|
||||
# 移除分号(在参数化查询中不需要)
|
||||
sanitized = sanitized.replace(';', '')
|
||||
|
||||
return sanitized
|
||||
|
||||
def validate_all(self, input_str: str, validation_types: List[str] = None) -> Dict[str, bool]:
|
||||
"""
|
||||
执行所有验证
|
||||
|
||||
Args:
|
||||
input_str: 输入字符串
|
||||
validation_types: 验证类型列表
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: 验证结果字典
|
||||
"""
|
||||
if validation_types is None:
|
||||
validation_types = ['sql', 'xss', 'path', 'command']
|
||||
|
||||
results = {}
|
||||
|
||||
for vtype in validation_types:
|
||||
if vtype == 'sql':
|
||||
results['sql'] = self.validate_sql_input(input_str)
|
||||
elif vtype == 'xss':
|
||||
results['xss'] = self.validate_xss_input(input_str)
|
||||
elif vtype == 'path':
|
||||
results['path'] = self.validate_path_input(input_str)
|
||||
elif vtype == 'command':
|
||||
results['command'] = self.validate_command_input(input_str)
|
||||
elif vtype == 'url':
|
||||
results['url'] = self.validate_url(input_str)
|
||||
elif vtype == 'email':
|
||||
results['email'] = self.validate_email(input_str)
|
||||
elif vtype == 'phone':
|
||||
results['phone'] = self.validate_phone(input_str)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# 全局输入验证器实例
|
||||
input_validator = InputValidator()
|
||||
151
src/neobot/core/utils/logger.py
Normal file
151
src/neobot/core/utils/logger.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
日志模块
|
||||
|
||||
该模块负责初始化和配置 loguru 日志记录器,为整个应用程序提供统一的日志记录接口。
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
# 导入全局配置
|
||||
try:
|
||||
from ..config_loader import global_config
|
||||
USE_CONFIG = True
|
||||
except ImportError:
|
||||
USE_CONFIG = False
|
||||
|
||||
# 定义日志格式,添加进程ID和线程ID作为上下文信息
|
||||
LOG_FORMAT = (
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
"<magenta>PID {process} TID {thread}</magenta> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
|
||||
"<level>{message}</level>"
|
||||
)
|
||||
|
||||
# 开发环境日志格式(更详细)
|
||||
DEBUG_LOG_FORMAT = (
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
"<magenta>PID {process} TID {thread}</magenta> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
|
||||
"<yellow>Module: {module}</yellow> | "
|
||||
"<level>{message}</level>"
|
||||
)
|
||||
|
||||
# 移除 loguru 默认的处理器
|
||||
logger.remove()
|
||||
|
||||
# 获取日志级别配置
|
||||
if USE_CONFIG:
|
||||
LOG_LEVEL = global_config.logging.level
|
||||
FILE_LEVEL = global_config.logging.file_level
|
||||
CONSOLE_LEVEL = global_config.logging.console_level
|
||||
else:
|
||||
LOG_LEVEL = "DEBUG"
|
||||
FILE_LEVEL = "DEBUG"
|
||||
CONSOLE_LEVEL = "INFO"
|
||||
|
||||
# 添加控制台输出处理器
|
||||
logger.add(
|
||||
sys.stderr,
|
||||
level=CONSOLE_LEVEL,
|
||||
format=LOG_FORMAT,
|
||||
colorize=True,
|
||||
enqueue=True # 异步写入
|
||||
)
|
||||
|
||||
# 定义日志文件路径
|
||||
log_dir = Path("logs")
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
log_file_path = log_dir / "{time:YYYY-MM-DD}.log"
|
||||
|
||||
# 添加文件输出处理器
|
||||
logger.add(
|
||||
log_file_path,
|
||||
level=FILE_LEVEL,
|
||||
format=DEBUG_LOG_FORMAT,
|
||||
colorize=False,
|
||||
rotation="00:00", # 每天午夜创建新文件
|
||||
retention="7 days", # 保留最近 7 天的日志
|
||||
encoding="utf-8",
|
||||
enqueue=True, # 异步写入
|
||||
backtrace=True, # 记录完整的异常堆栈
|
||||
diagnose=True # 添加异常诊断信息
|
||||
)
|
||||
|
||||
# 为自定义异常添加专门的日志记录方法
|
||||
def log_exception(exc, module_name="unknown", level="error"):
|
||||
"""
|
||||
记录自定义异常的详细信息
|
||||
|
||||
Args:
|
||||
exc: 异常对象
|
||||
module_name: 模块名称(可选)
|
||||
level: 日志级别(可选,默认为 "error")
|
||||
"""
|
||||
log_func = getattr(logger, level)
|
||||
log_func(f"模块 {module_name} 发生异常: {exc}")
|
||||
|
||||
# 如果异常对象有原始异常,也记录原始异常信息
|
||||
if hasattr(exc, "original_error") and exc.original_error:
|
||||
log_func(f"原始异常: {exc.original_error}")
|
||||
|
||||
# 如果是配置错误,记录配置相关信息
|
||||
if hasattr(exc, "section") and hasattr(exc, "key"):
|
||||
log_func(f"配置信息: 部分={exc.section}, 键={exc.key}")
|
||||
|
||||
# 如果是插件错误,记录插件名称
|
||||
if hasattr(exc, "plugin_name"):
|
||||
log_func(f"插件名称: {exc.plugin_name}")
|
||||
|
||||
# 如果是命令错误,记录命令名称
|
||||
if hasattr(exc, "command"):
|
||||
log_func(f"命令名称: {exc.command}")
|
||||
|
||||
# 如果是权限错误,记录用户ID和操作
|
||||
if hasattr(exc, "user_id") and hasattr(exc, "operation"):
|
||||
log_func(f"权限信息: 用户ID={exc.user_id}, 操作={exc.operation}")
|
||||
|
||||
# 为不同模块提供日志工具
|
||||
class ModuleLogger:
|
||||
"""
|
||||
模块专用日志记录器
|
||||
|
||||
Args:
|
||||
module_name: 模块名称
|
||||
"""
|
||||
def __init__(self, module_name):
|
||||
self.module_name = module_name
|
||||
|
||||
def debug(self, message):
|
||||
logger.debug(f"[{self.module_name}] {message}")
|
||||
|
||||
def info(self, message):
|
||||
logger.info(f"[{self.module_name}] {message}")
|
||||
|
||||
def success(self, message):
|
||||
logger.success(f"[{self.module_name}] {message}")
|
||||
|
||||
def warning(self, message):
|
||||
logger.warning(f"[{self.module_name}] {message}")
|
||||
|
||||
def error(self, message):
|
||||
logger.error(f"[{self.module_name}] {message}")
|
||||
|
||||
def exception(self, message, exc_info=True):
|
||||
logger.exception(f"[{self.module_name}] {message}", exc_info=exc_info)
|
||||
|
||||
def log_custom_exception(self, exc, level="error"):
|
||||
"""
|
||||
记录自定义异常
|
||||
|
||||
Args:
|
||||
exc: 异常对象
|
||||
level: 日志级别
|
||||
"""
|
||||
log_exception(exc, self.module_name, level)
|
||||
|
||||
# 导出配置好的 logger 和工具函数
|
||||
__all__ = ["logger", "log_exception", "ModuleLogger"]
|
||||
364
src/neobot/core/utils/performance.py
Normal file
364
src/neobot/core/utils/performance.py
Normal file
@@ -0,0 +1,364 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
性能分析工具模块
|
||||
|
||||
提供同步和异步函数的性能分析装饰器、上下文管理器和统计工具。
|
||||
|
||||
主要功能:
|
||||
1. 函数执行时间分析(支持同步和异步)
|
||||
2. 内存使用分析
|
||||
3. 性能统计和报告生成
|
||||
4. 低开销的生产环境监控
|
||||
"""
|
||||
|
||||
import time
|
||||
import functools
|
||||
import logging
|
||||
from typing import Dict, Any, Callable, Optional
|
||||
import inspect
|
||||
|
||||
# 尝试导入性能分析库
|
||||
try:
|
||||
from pyinstrument import Profiler
|
||||
from pyinstrument.renderers import HTMLRenderer
|
||||
PYINSTRUMENT_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYINSTRUMENT_AVAILABLE = False
|
||||
|
||||
# 尝试导入内存分析库
|
||||
try:
|
||||
from memory_profiler import memory_usage
|
||||
MEMORY_PROFILER_AVAILABLE = True
|
||||
except ImportError:
|
||||
MEMORY_PROFILER_AVAILABLE = False
|
||||
|
||||
from .logger import logger
|
||||
|
||||
|
||||
class PerformanceStats:
|
||||
"""
|
||||
性能统计工具类
|
||||
用于收集和报告函数执行的性能指标
|
||||
"""
|
||||
def __init__(self):
|
||||
self.stats: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def record(self, func_name: str, duration: float, memory_used: Optional[float] = None):
|
||||
"""
|
||||
记录函数执行的性能数据
|
||||
|
||||
Args:
|
||||
func_name: 函数名称
|
||||
duration: 执行时间(秒)
|
||||
memory_used: 使用的内存(MB),可选
|
||||
"""
|
||||
if func_name not in self.stats:
|
||||
self.stats[func_name] = {
|
||||
"count": 0,
|
||||
"total_time": 0.0,
|
||||
"avg_time": 0.0,
|
||||
"min_time": float('inf'),
|
||||
"max_time": 0.0,
|
||||
"total_memory": 0.0,
|
||||
"avg_memory": 0.0
|
||||
}
|
||||
|
||||
stat = self.stats[func_name]
|
||||
stat["count"] += 1
|
||||
stat["total_time"] += duration
|
||||
stat["avg_time"] = stat["total_time"] / stat["count"]
|
||||
stat["min_time"] = min(stat["min_time"], duration)
|
||||
stat["max_time"] = max(stat["max_time"], duration)
|
||||
|
||||
if memory_used is not None:
|
||||
stat["total_memory"] += memory_used
|
||||
stat["avg_memory"] = stat["total_memory"] / stat["count"]
|
||||
|
||||
def report(self) -> str:
|
||||
"""
|
||||
生成性能统计报告
|
||||
|
||||
Returns:
|
||||
格式化的性能统计报告字符串
|
||||
"""
|
||||
if not self.stats:
|
||||
return "暂无性能统计数据"
|
||||
|
||||
report = ["\n=== 性能统计报告 ===\n"]
|
||||
report.append(f"{'函数名':<40} {'调用次数':<10} {'平均时间(ms)':<15} {'最长时间(ms)':<15} {'内存(MB)':<10}")
|
||||
report.append("-" * 100)
|
||||
|
||||
for func_name, stat in sorted(self.stats.items(), key=lambda x: x[1]["total_time"], reverse=True):
|
||||
memory_str = f"{stat['avg_memory']:.2f}" if stat['avg_memory'] > 0 else "-"
|
||||
report.append(
|
||||
f"{func_name:<40} {stat['count']:<10} {stat['avg_time']*1000:<15.2f} "
|
||||
f"{stat['max_time']*1000:<15.2f} {memory_str:<10}"
|
||||
)
|
||||
|
||||
report.append("=" * 100)
|
||||
return "\n".join(report)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
重置性能统计数据
|
||||
"""
|
||||
self.stats.clear()
|
||||
|
||||
|
||||
# 创建全局性能统计实例
|
||||
performance_stats = PerformanceStats()
|
||||
|
||||
|
||||
def timeit(func: Optional[Callable] = None, *, log_level: int = logging.INFO, collect_stats: bool = True):
|
||||
"""
|
||||
函数执行时间分析装饰器(支持同步和异步)
|
||||
|
||||
Args:
|
||||
func: 要装饰的函数
|
||||
log_level: 日志级别
|
||||
collect_stats: 是否收集到全局统计中
|
||||
|
||||
Returns:
|
||||
装饰后的函数
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
func_name = func.__qualname__
|
||||
is_coroutine = inspect.iscoroutinefunction(func)
|
||||
|
||||
if is_coroutine:
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
if collect_stats:
|
||||
performance_stats.record(func_name, duration)
|
||||
|
||||
logger.log(log_level, f"[性能] {func_name} 执行时间: {duration*1000:.2f} ms")
|
||||
|
||||
return result
|
||||
|
||||
return async_wrapper
|
||||
else:
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
if collect_stats:
|
||||
performance_stats.record(func_name, duration)
|
||||
|
||||
logger.log(log_level, f"[性能] {func_name} 执行时间: {duration*1000:.2f} ms")
|
||||
|
||||
return result
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
if func is None:
|
||||
return decorator
|
||||
return decorator(func)
|
||||
|
||||
|
||||
class profile:
|
||||
"""
|
||||
性能分析上下文管理器
|
||||
使用 pyinstrument 进行详细的性能分析
|
||||
"""
|
||||
def __init__(self, enabled: bool = True, output_file: Optional[str] = None):
|
||||
"""
|
||||
Args:
|
||||
enabled: 是否启用分析
|
||||
output_file: 分析结果输出文件路径(HTML格式)
|
||||
"""
|
||||
self.enabled = enabled
|
||||
self.output_file = output_file
|
||||
self.profiler = None
|
||||
|
||||
def __enter__(self):
|
||||
if self.enabled and PYINSTRUMENT_AVAILABLE:
|
||||
self.profiler = Profiler()
|
||||
self.profiler.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.enabled and PYINSTRUMENT_AVAILABLE and self.profiler:
|
||||
self.profiler.stop()
|
||||
|
||||
# 输出到日志
|
||||
logger.info(f"[性能分析] {self.profiler.print()}")
|
||||
|
||||
# 如果指定了输出文件,保存为HTML
|
||||
if self.output_file:
|
||||
try:
|
||||
html = self.profiler.render(HTMLRenderer())
|
||||
with open(self.output_file, 'w', encoding='utf-8') as f:
|
||||
f.write(html)
|
||||
logger.info(f"[性能分析] 报告已保存到: {self.output_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"[性能分析] 保存报告失败: {e}")
|
||||
|
||||
|
||||
async def aprofile(func: Callable, *args, **kwargs):
|
||||
"""
|
||||
异步函数性能分析
|
||||
|
||||
Args:
|
||||
func: 要分析的异步函数
|
||||
*args: 函数参数
|
||||
**kwargs: 函数关键字参数
|
||||
|
||||
Returns:
|
||||
函数执行结果
|
||||
"""
|
||||
if not PYINSTRUMENT_AVAILABLE:
|
||||
logger.warning("[性能分析] pyinstrument 未安装,无法进行详细分析")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
profiler = Profiler()
|
||||
profiler.start()
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
finally:
|
||||
profiler.stop()
|
||||
logger.info(f"[性能分析] {profiler.print()}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class memory_profile:
|
||||
"""
|
||||
内存分析上下文管理器
|
||||
"""
|
||||
def __init__(self, interval: float = 0.1, enabled: bool = True):
|
||||
"""
|
||||
Args:
|
||||
interval: 内存采样间隔(秒)
|
||||
enabled: 是否启用内存分析
|
||||
"""
|
||||
self.interval = interval
|
||||
self.enabled = enabled
|
||||
self.memory_start = 0.0
|
||||
self.memory_end = 0.0
|
||||
|
||||
def __enter__(self):
|
||||
if self.enabled and MEMORY_PROFILER_AVAILABLE:
|
||||
self.memory_start = memory_usage()[0]
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.enabled and MEMORY_PROFILER_AVAILABLE:
|
||||
self.memory_end = memory_usage()[0]
|
||||
memory_used = self.memory_end - self.memory_start
|
||||
logger.info(f"[内存分析] 使用内存: {memory_used:.2f} MB")
|
||||
|
||||
|
||||
def memory_profile_decorator(func: Optional[Callable] = None, *, interval: float = 0.1):
|
||||
"""
|
||||
内存分析装饰器(支持同步函数)
|
||||
|
||||
Args:
|
||||
func: 要装饰的函数
|
||||
interval: 内存采样间隔
|
||||
|
||||
Returns:
|
||||
装饰后的函数
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not MEMORY_PROFILER_AVAILABLE:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
mem_usage = memory_usage(
|
||||
(func, args, kwargs),
|
||||
interval=interval,
|
||||
timeout=None,
|
||||
include_children=False
|
||||
)
|
||||
|
||||
max_memory = max(mem_usage)
|
||||
logger.info(f"[内存分析] {func.__qualname__} 最大内存使用: {max_memory:.2f} MB")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
if func is None:
|
||||
return decorator
|
||||
return decorator(func)
|
||||
|
||||
|
||||
def performance_monitor(func: Optional[Callable] = None, *, threshold: float = 1.0):
|
||||
"""
|
||||
性能监控装饰器
|
||||
仅当函数执行时间超过阈值时记录日志
|
||||
适合生产环境使用
|
||||
|
||||
Args:
|
||||
func: 要装饰的函数
|
||||
threshold: 时间阈值(秒)
|
||||
|
||||
Returns:
|
||||
装饰后的函数
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
func_name = func.__qualname__
|
||||
is_coroutine = inspect.iscoroutinefunction(func)
|
||||
|
||||
if is_coroutine:
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
result = await func(*args, **kwargs)
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
if duration > threshold:
|
||||
logger.warning(f"[性能监控] {func_name} 执行时间过长: {duration*1000:.2f} ms (阈值: {threshold*1000:.2f} ms)")
|
||||
|
||||
return result
|
||||
|
||||
return async_wrapper
|
||||
else:
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
if duration > threshold:
|
||||
logger.warning(f"[性能监控] {func_name} 执行时间过长: {duration*1000:.2f} ms (阈值: {threshold*1000:.2f} ms)")
|
||||
|
||||
return result
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
if func is None:
|
||||
return decorator
|
||||
return decorator(func)
|
||||
|
||||
|
||||
# 全局实例
|
||||
global_stats = PerformanceStats()
|
||||
|
||||
|
||||
__all__ = [
|
||||
'timeit',
|
||||
'profile',
|
||||
'aprofile',
|
||||
'memory_profile',
|
||||
'memory_profile_decorator',
|
||||
'performance_monitor',
|
||||
'PerformanceStats',
|
||||
'performance_stats',
|
||||
'global_stats'
|
||||
]
|
||||
78
src/neobot/core/utils/singleton.py
Normal file
78
src/neobot/core/utils/singleton.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
通用单例模式基类
|
||||
"""
|
||||
from typing import Any, Dict, Optional, Type, TypeVar, cast
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# 存储每个类的实例
|
||||
_instance_store: Dict[Type, Any] = {}
|
||||
|
||||
class Singleton:
|
||||
"""
|
||||
一个通用的单例基类
|
||||
|
||||
任何继承自该类的子类都将自动成为单例。
|
||||
它通过重写 __new__ 方法来确保每个类只有一个实例。
|
||||
同时,它处理了重复初始化的问题,确保 __init__ 方法只在第一次实例化时被调用。
|
||||
"""
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls: Type[T], *args: Any, **kwargs: Any) -> T:
|
||||
"""
|
||||
创建或返回现有的实例
|
||||
|
||||
Args:
|
||||
*args: 传递给构造函数的位置参数
|
||||
**kwargs: 传递给构造函数的关键字参数
|
||||
|
||||
Returns:
|
||||
T: 单例实例
|
||||
"""
|
||||
# 使用全局字典存储实例,修复类型检查问题
|
||||
if cls not in _instance_store:
|
||||
_instance_store[cls] = super(Singleton, cls).__new__(cls)
|
||||
return _instance_store[cls]
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
确保初始化逻辑只执行一次
|
||||
"""
|
||||
if self._initialized:
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
|
||||
def singleton(cls: Type[T]) -> Type[T]:
|
||||
"""
|
||||
单例装饰器
|
||||
|
||||
将普通类转换为单例类,确保整个应用程序中只有一个实例。
|
||||
|
||||
Args:
|
||||
cls: 要转换为单例的类
|
||||
|
||||
Returns:
|
||||
Type[T]: 单例类
|
||||
"""
|
||||
# 为每个装饰的类创建一个实例存储
|
||||
class_instance: Optional[T] = None
|
||||
|
||||
# 创建一个新的类,继承自原始类
|
||||
class SingletonClass(cls):
|
||||
"""单例包装类"""
|
||||
|
||||
def __new__(cls: Type[T], *args: Any, **kwargs: Any) -> T:
|
||||
"""创建或返回现有的实例"""
|
||||
nonlocal class_instance
|
||||
if class_instance is None:
|
||||
# 使用super()调用原始类的__new__方法
|
||||
class_instance = super(SingletonClass, cls).__new__(cls)
|
||||
return class_instance
|
||||
|
||||
# 复制类的元数据
|
||||
SingletonClass.__name__ = cls.__name__
|
||||
SingletonClass.__doc__ = cls.__doc__
|
||||
SingletonClass.__module__ = cls.__module__
|
||||
|
||||
return SingletonClass
|
||||
326
src/neobot/core/ws.py
Normal file
326
src/neobot/core/ws.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
WebSocket 核心通信模块
|
||||
|
||||
该模块定义了 `WS` 类,负责与 OneBot v11 实现(如 NapCat)建立和管理
|
||||
WebSocket 连接。它是整个机器人框架的底层通信基础。
|
||||
|
||||
主要职责包括:
|
||||
- 建立 WebSocket 连接并处理认证。
|
||||
- 实现断线自动重连机制。
|
||||
- 监听并接收来自 OneBot 的事件和 API 响应。
|
||||
- 分发事件给 `CommandManager` 进行处理。
|
||||
- 提供 `call_api` 方法,用于异步发送 API 请求并等待响应。
|
||||
"""
|
||||
import asyncio
|
||||
import orjson
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, cast
|
||||
import uuid
|
||||
import threading
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .bot import Bot
|
||||
|
||||
import websockets
|
||||
from websockets.legacy.client import WebSocketClientProtocol
|
||||
|
||||
from neobot.models.events.factory import EventFactory
|
||||
|
||||
from .config_loader import global_config
|
||||
from .utils.executor import CodeExecutor
|
||||
from .utils.logger import ModuleLogger
|
||||
from .utils.exceptions import (
|
||||
WebSocketError, WebSocketConnectionError
|
||||
)
|
||||
from .utils.error_codes import ErrorCode, create_error_response
|
||||
|
||||
|
||||
class WS:
|
||||
"""
|
||||
WebSocket 客户端,负责与 OneBot v11 实现进行底层通信。
|
||||
"""
|
||||
|
||||
def __init__(self, code_executor: Optional[CodeExecutor] = None) -> None:
|
||||
"""
|
||||
初始化 WebSocket 客户端。
|
||||
|
||||
从全局配置中读取 WebSocket URI、访问令牌(Token)和重连间隔。
|
||||
|
||||
:param code_executor: 代码执行器实例
|
||||
"""
|
||||
# 读取参数
|
||||
cfg = global_config.napcat_ws
|
||||
self.url = cfg.uri
|
||||
self.token = cfg.token
|
||||
self.reconnect_interval = cfg.reconnect_interval
|
||||
|
||||
# 初始化状态
|
||||
self.ws: Optional[WebSocketClientProtocol] = None
|
||||
self._pending_requests: Dict[str, asyncio.Future] = {} # echo: future
|
||||
self.bot: 'Bot' | None = None
|
||||
self.self_id: int | None = None
|
||||
self.code_executor = code_executor
|
||||
|
||||
# 线程安全锁
|
||||
self._pending_requests_lock = threading.RLock()
|
||||
|
||||
# 创建模块专用日志记录器
|
||||
self.logger = ModuleLogger("WebSocket")
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
启动并管理 WebSocket 连接。
|
||||
|
||||
这是一个无限循环,负责建立连接。如果连接断开,它会根据配置的
|
||||
`reconnect_interval` 时间间隔后自动尝试重新连接。
|
||||
"""
|
||||
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
|
||||
|
||||
while True:
|
||||
try:
|
||||
self.logger.info(f"正在尝试连接至 NapCat: {self.url}")
|
||||
async with websockets.connect(
|
||||
self.url, additional_headers=headers
|
||||
) as websocket_raw:
|
||||
websocket = cast(WebSocketClientProtocol, websocket_raw)
|
||||
self.ws = websocket
|
||||
self.logger.success("连接成功!")
|
||||
await self._listen_loop(websocket)
|
||||
|
||||
except (
|
||||
websockets.exceptions.ConnectionClosed,
|
||||
ConnectionRefusedError,
|
||||
) as e:
|
||||
conn_error = WebSocketConnectionError(
|
||||
message=f"WebSocket连接失败: {str(e)}",
|
||||
code=ErrorCode.WS_CONNECTION_FAILED,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.error(f"连接失败: {conn_error.message}")
|
||||
self.logger.log_custom_exception(conn_error)
|
||||
except Exception as e:
|
||||
error = WebSocketError(
|
||||
message=f"WebSocket运行异常: {str(e)}",
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.exception(f"运行异常: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
|
||||
self.logger.info(f"{self.reconnect_interval}秒后尝试重连...")
|
||||
await asyncio.sleep(self.reconnect_interval)
|
||||
|
||||
async def _listen_loop(self, websocket_connection: WebSocketClientProtocol) -> None:
|
||||
"""
|
||||
核心监听循环,处理所有接收到的 WebSocket 消息。
|
||||
|
||||
此循环会持续从 WebSocket 连接中读取消息,并根据消息内容
|
||||
判断是 API 响应还是上报的事件,然后分发给相应的处理逻辑。
|
||||
|
||||
Args:
|
||||
websocket_connection: 当前活动的 WebSocket 连接对象。
|
||||
"""
|
||||
async for message in websocket_connection:
|
||||
try:
|
||||
data = orjson.loads(message)
|
||||
|
||||
# 1. 处理 API 响应
|
||||
# 如果消息中包含 echo 字段,说明是 API 调用的响应
|
||||
echo_id = data.get("echo")
|
||||
if echo_id and echo_id in self._pending_requests:
|
||||
with self._pending_requests_lock:
|
||||
future = self._pending_requests.pop(echo_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
continue
|
||||
|
||||
# 2. 处理上报事件
|
||||
# 如果消息中包含 post_type 字段,说明是 OneBot 上报的事件
|
||||
if "post_type" in data:
|
||||
# 使用 create_task 异步执行,避免阻塞 WebSocket 接收循环
|
||||
asyncio.create_task(self.on_event(data))
|
||||
|
||||
except orjson.JSONDecodeError as e:
|
||||
error = WebSocketError(
|
||||
message=f"JSON解析失败: {str(e)}",
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.error(f"解析消息异常: {error.message}")
|
||||
# 如果message是bytes类型,需要先解码
|
||||
decoded_message = message.decode('utf-8') if isinstance(message, bytes) else message
|
||||
self.logger.debug(f"原始消息: {decoded_message}")
|
||||
except Exception as e:
|
||||
error = WebSocketError(
|
||||
message=f"处理消息异常: {str(e)}",
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.exception(f"解析消息异常: {error.message}")
|
||||
self.logger.log_custom_exception(error)
|
||||
|
||||
async def on_event(self, event_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
事件处理和分发层。
|
||||
|
||||
当接收到一个 OneBot 事件时,此方法负责:
|
||||
1. 使用 `EventFactory` 将原始 JSON 数据解析成对应的事件对象。
|
||||
2. 为事件对象注入 `Bot` 实例,以便在插件中可以调用 API。
|
||||
3. 打印格式化的事件日志。
|
||||
4. 将事件对象传递给 `CommandManager` (`matcher`) 进行后续处理。
|
||||
|
||||
Args:
|
||||
event_data (dict): 从 WebSocket 接收到的原始事件字典。
|
||||
"""
|
||||
try:
|
||||
# 使用工厂创建事件对象
|
||||
event = EventFactory.create_event(event_data)
|
||||
|
||||
# 尝试初始化 Bot 实例 (如果尚未初始化且事件包含 self_id)
|
||||
# 只要事件中包含 self_id,我们就可以初始化 Bot,不必非要等待 meta_event
|
||||
if self.bot is None and hasattr(event, 'self_id'):
|
||||
from .bot import Bot
|
||||
self.self_id = event.self_id
|
||||
self.bot = Bot(self)
|
||||
self.logger.success(f"Bot 实例初始化完成: self_id={self.self_id}")
|
||||
|
||||
# 将代码执行器注入到 Bot 和执行器自身
|
||||
if self.code_executor:
|
||||
self.bot.code_executor = self.code_executor
|
||||
self.code_executor.bot = self.bot
|
||||
self.logger.info("代码执行器已成功注入 Bot 实例。")
|
||||
|
||||
# 如果 bot 尚未初始化,则不处理后续事件
|
||||
if self.bot is None:
|
||||
self.logger.warning("Bot 尚未初始化,跳过事件处理。")
|
||||
return
|
||||
|
||||
event.bot = self.bot # 注入 Bot 实例
|
||||
|
||||
# 打印日志
|
||||
if event.post_type == "message":
|
||||
sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown"
|
||||
message_type = getattr(event, "message_type", "Unknown")
|
||||
user_id = getattr(event, "user_id", "Unknown")
|
||||
raw_message = getattr(event, "raw_message", "")
|
||||
self.logger.info(f"[消息] {message_type} | {user_id}({sender_name}): {raw_message}")
|
||||
elif event.post_type == "notice":
|
||||
notice_type = getattr(event, "notice_type", "Unknown")
|
||||
self.logger.info(f"[通知] {notice_type}")
|
||||
elif event.post_type == "request":
|
||||
request_type = getattr(event, "request_type", "Unknown")
|
||||
self.logger.info(f"[请求] {request_type}")
|
||||
elif event.post_type == "meta_event":
|
||||
meta_event_type = getattr(event, "meta_event_type", "Unknown")
|
||||
self.logger.debug(f"[元事件] {meta_event_type}")
|
||||
|
||||
# 分发事件
|
||||
from .managers.command_manager import matcher
|
||||
await matcher.handle_event(self.bot, event)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"事件处理异常: {str(e)}")
|
||||
error = WebSocketError(
|
||||
message=f"事件处理异常: {str(e)}",
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
self.logger.log_custom_exception(error)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
关闭 WebSocket 客户端,释放资源。
|
||||
"""
|
||||
self.logger.info("正在关闭 WebSocket 客户端...")
|
||||
|
||||
# 从 BotManager 注销
|
||||
if self.bot and self.self_id:
|
||||
from .managers.bot_manager import bot_manager
|
||||
bot_manager.unregister_bot(str(self.self_id))
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
# 取消所有挂起的请求
|
||||
with self._pending_requests_lock:
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
|
||||
self.logger.success("WebSocket 客户端已关闭")
|
||||
|
||||
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
|
||||
"""
|
||||
向 OneBot v11 实现端发送一个 API 请求。
|
||||
|
||||
该方法通过 WebSocket 发送请求,并使用 `echo` 字段来匹配对应的响应。
|
||||
它创建了一个 `Future` 对象来异步等待响应,并设置了超时机制。
|
||||
|
||||
Args:
|
||||
action (str): API 的动作名称,例如 "send_group_msg"。
|
||||
params (dict, optional): API 请求的参数字典。 Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: OneBot API 的响应数据。如果超时或连接断开,则返回一个
|
||||
表示失败的字典。
|
||||
"""
|
||||
if not self.ws:
|
||||
self.logger.error("调用 API 失败: WebSocket 未初始化")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_DISCONNECTED,
|
||||
message="WebSocket未初始化",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
from websockets.protocol import State
|
||||
|
||||
if getattr(self.ws, "state", None) is not State.OPEN:
|
||||
self.logger.error("调用 API 失败: WebSocket 连接未打开")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_DISCONNECTED,
|
||||
message="WebSocket连接未打开",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
echo_id = str(uuid.uuid4())
|
||||
payload = {"action": action, "params": params or {}, "echo": echo_id}
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
with self._pending_requests_lock:
|
||||
self._pending_requests[echo_id] = future
|
||||
|
||||
try:
|
||||
await self.ws.send(orjson.dumps(payload).decode('utf-8'))
|
||||
return await asyncio.wait_for(future, timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
with self._pending_requests_lock:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.warning(f"API 调用超时: action={action}, params={params}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.TIMEOUT_ERROR,
|
||||
message="API调用超时",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
except Exception as e:
|
||||
with self._pending_requests_lock:
|
||||
self._pending_requests.pop(echo_id, None)
|
||||
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
|
||||
return create_error_response(
|
||||
code=ErrorCode.WS_MESSAGE_ERROR,
|
||||
message=f"API调用异常: {str(e)}",
|
||||
data={"action": action, "params": params}
|
||||
)
|
||||
|
||||
|
||||
class ReverseWSClient(WS):
|
||||
"""
|
||||
反向 WebSocket 客户端代理,用于 Bot 实例调用 API。
|
||||
"""
|
||||
def __init__(self, manager: Any, client_id: str):
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
self.client_id = client_id
|
||||
|
||||
async def call_api(self, action: str, params: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
|
||||
return await self.manager.call_api(action, params, self.client_id)
|
||||
Reference in New Issue
Block a user