feat(reverse_ws): 添加反向WebSocket支持及负载均衡功能

- 新增反向WebSocket管理器模块,支持多客户端连接
- 实现负载均衡机制,自动选择健康且负载最低的客户端
- 添加防重复事件处理机制,防止消息重复处理
- 更新配置模型和加载器以支持反向WebSocket配置
- 添加示例文件和文档说明使用方法
- 修改主程序启动逻辑以支持反向WebSocket服务
This commit is contained in:
2026-02-28 20:57:48 +08:00
parent ed4da64a7a
commit 014c6c9092
10 changed files with 769 additions and 6 deletions

View File

@@ -11,6 +11,7 @@ from .redis_manager import RedisManager
from .mysql_manager import MySQLManager
from .browser_manager import BrowserManager
from .image_manager import ImageManager
from .reverse_ws_manager import ReverseWSManager
# --- 实例化所有单例管理器 ---
@@ -36,6 +37,9 @@ browser_manager = BrowserManager()
# 图片管理器
image_manager = ImageManager()
# 反向 WebSocket 管理器
reverse_ws_manager = ReverseWSManager()
__all__ = [
"permission_manager",
"command_manager",
@@ -45,4 +49,5 @@ __all__ = [
"mysql_manager",
"browser_manager",
"image_manager",
"reverse_ws_manager",
]

View File

@@ -35,7 +35,7 @@ class ImageManager(Singleton):
# 模板缓存
self._template_cache: Dict[str, Template] = {}
async def render_template(self, template_name: str, data: Dict[str, Any], output_name: str = "output.png", quality: int = 80, image_type: str = "png") -> Optional[str]:
async def render_template(self, template_name: str, data: Dict[str, Any], output_name: str = "output.png", quality: int = 80, image_type: str = "png", width: int = 1920, height: int = 1080) -> Optional[str]:
"""
使用 Playwright 渲染 Jinja2 模板并保存为图片文件
@@ -45,6 +45,8 @@ class ImageManager(Singleton):
output_name (str, optional): 输出文件名. Defaults to "output.png".
quality (int, optional): JPEG 质量 (0-100). 仅在 image_type 为 jpeg 时有效. Defaults to 80.
image_type (str, optional): 图片类型 ('png' or 'jpeg'). Defaults to "png".
width (int, optional): 图片宽度. Defaults to 1920.
height (int, optional): 图片高度. Defaults to 1080.
Returns:
Optional[str]: 生成图片的绝对路径,如果失败则返回 None
@@ -74,8 +76,8 @@ class ImageManager(Singleton):
return None
try:
width = global_config.image_manager.image_width
height = global_config.image_manager.image_height
width = data.get("width", width)
height = data.get("height", height)
await page.set_viewport_size({"width": width, "height": height})
# 加载内容

View File

@@ -0,0 +1,438 @@
"""
反向 WebSocket 管理器模块
该模块提供了反向 WebSocket 服务端功能,允许 OneBot 实现(如 NapCat
主动连接到机器人服务器,而不是由机器人主动连接到 OneBot 实现。
"""
import asyncio
import orjson
import websockets
from websockets.server import WebSocketServerProtocol
from typing import Dict, Any, Optional, Set
from datetime import datetime
import uuid
import random
from ..config_loader import global_config
from ..utils.logger import ModuleLogger
from ..utils.exceptions import WebSocketError, WebSocketConnectionError
from ..utils.error_codes import ErrorCode, create_error_response
from .command_manager import matcher
from models.events.factory import EventFactory
from .redis_manager import redis_manager
class ReverseWSManager:
"""
反向 WebSocket 管理器,作为服务端接收 OneBot 实现的连接。
支持多前端负载均衡和防重复发送机制。
"""
def __init__(self):
"""
初始化反向 WebSocket 管理器。
"""
self.server = None
self.clients: Dict[str, WebSocketServerProtocol] = {}
self.client_self_ids: Dict[str, int] = {}
self._pending_requests: Dict[str, asyncio.Future] = {}
self._running = False
self.logger = ModuleLogger("ReverseWSManager")
# 负载均衡相关
self._active_client_id: Optional[str] = None # 当前活跃的客户端(用于消息发送)
self._client_load: Dict[str, int] = {} # 客户端负载计数
self._client_health: Dict[str, datetime] = {} # 客户端健康检查时间
# 防重复发送相关
self._processed_events: Dict[str, datetime] = {} # 已处理的事件ID和时间
self._event_ttl = 60 # 事件ID保留时间
self._message_locks: Dict[str, asyncio.Lock] = {} # 消息处理锁
self._lock_ttl = 300 # 锁保留时间(秒)
# 启动清理任务
self._cleanup_task = None
async def start(self, host: str = "0.0.0.0", port: int = 3002) -> None:
"""
启动反向 WebSocket 服务端。
Args:
host: 监听地址,默认为 0.0.0.0
port: 监听端口,默认为 3002
"""
self._running = True
self.server = await websockets.serve(
self._handle_client,
host,
port,
ping_interval=20,
ping_timeout=20
)
self.logger.success(f"反向 WebSocket 服务端已启动: ws://{host}:{port}")
# 启动清理任务
self._cleanup_task = asyncio.create_task(self._cleanup_expired_data())
async def stop(self) -> None:
"""
停止反向 WebSocket 服务端。
"""
self._running = False
# 停止清理任务
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
if self.server:
self.server.close()
await self.server.wait_closed()
for client_id in list(self.clients.keys()):
await self._disconnect_client(client_id)
self.logger.success("反向 WebSocket 服务端已停止")
async def _handle_client(
self,
websocket: WebSocketServerProtocol,
path: str = None
) -> None:
"""
处理客户端连接。
Args:
websocket: WebSocket 连接对象
path: 连接路径
"""
client_id = str(uuid.uuid4())
self.clients[client_id] = websocket
self.logger.info(f"新客户端连接: {client_id}")
try:
async for message in websocket:
try:
data = orjson.loads(message)
# 处理 API 响应
echo_id = data.get("echo")
if echo_id and echo_id in self._pending_requests:
future = self._pending_requests.pop(echo_id)
if not future.done():
future.set_result(data)
continue
# 处理上报事件
if "post_type" in data:
asyncio.create_task(self._on_event(client_id, data))
except orjson.JSONDecodeError as e:
self.logger.error(f"JSON 解析失败: {str(e)}")
except Exception as e:
self.logger.exception(f"处理消息异常: {str(e)}")
except websockets.exceptions.ConnectionClosed as e:
self.logger.info(f"客户端断开连接: {client_id} - {str(e)}")
except Exception as e:
self.logger.exception(f"客户端异常: {str(e)}")
finally:
await self._disconnect_client(client_id)
async def _cleanup_expired_data(self) -> None:
"""
清理过期的事件ID和消息锁
"""
while self._running:
try:
await asyncio.sleep(10) # 每10秒清理一次
current_time = datetime.now()
# 清理过期的事件ID
expired_events = [
event_id for event_id, timestamp in self._processed_events.items()
if (current_time - timestamp).total_seconds() > self._event_ttl
]
for event_id in expired_events:
del self._processed_events[event_id]
# 清理过期的消息锁
expired_locks = [
lock_key for lock_key, timestamp in self._message_locks.items()
if (current_time - timestamp).total_seconds() > self._lock_ttl
]
for lock_key in expired_locks:
del self._message_locks[lock_key]
except asyncio.CancelledError:
break
except Exception as e:
self.logger.error(f"清理过期数据失败: {str(e)}")
async def _disconnect_client(self, client_id: str) -> None:
"""
断开客户端连接。
Args:
client_id: 客户端 ID
"""
if client_id in self.clients:
del self.clients[client_id]
if client_id in self.client_self_ids:
del self.client_self_ids[client_id]
async def _on_event(self, client_id: str, event_data: Dict[str, Any]) -> None:
"""
处理事件,包含防重复发送和负载均衡逻辑。
Args:
client_id: 客户端 ID
event_data: 事件数据
"""
try:
event = EventFactory.create_event(event_data)
if hasattr(event, 'self_id'):
self.client_self_ids[client_id] = event.self_id
event.bot = None
# 记录客户端健康状态
self._client_health[client_id] = datetime.now()
# 检查是否为重复事件
if self._is_duplicate_event(event_data):
self.logger.debug(f"检测到重复事件,已忽略: {event_data.get('id')}")
return
# 标记事件已处理
self._mark_event_processed(event_data)
# 处理消息事件
if event.post_type == "message":
sender_name = event.sender.nickname if hasattr(event, "sender") and event.sender else "Unknown"
message_type = getattr(event, "message_type", "Unknown")
user_id = getattr(event, "user_id", "Unknown")
raw_message = getattr(event, "raw_message", "")
self.logger.info(f"[消息] {message_type} | {user_id}({sender_name}): {raw_message}")
# 使用锁防止同一消息被多次处理
message_key = self._get_message_key(event_data)
async with self._get_message_lock(message_key):
# 再次检查是否重复(防止并发问题)
if self._is_duplicate_event(event_data):
self.logger.debug(f"并发检测到重复消息,已忽略: {message_key}")
return
self._mark_event_processed(event_data)
# 更新客户端负载
self._update_client_load(client_id)
await matcher.handle_event(None, event)
elif event.post_type == "notice":
notice_type = getattr(event, "notice_type", "Unknown")
self.logger.info(f"[通知] {notice_type}")
await matcher.handle_event(None, event)
elif event.post_type == "request":
request_type = getattr(event, "request_type", "Unknown")
self.logger.info(f"[请求] {request_type}")
await matcher.handle_event(None, event)
elif event.post_type == "meta_event":
meta_event_type = getattr(event, "meta_event_type", "Unknown")
self.logger.debug(f"[元事件] {meta_event_type}")
await matcher.handle_event(None, event)
except Exception as e:
self.logger.exception(f"事件处理异常: {str(e)}")
async def call_api(
self,
action: str,
params: Optional[Dict[Any, Any]] = None,
client_id: Optional[str] = None,
use_load_balance: bool = True
) -> Dict[Any, Any]:
"""
向客户端发送 API 请求。
Args:
action: API 动作名称
params: API 参数
client_id: 客户端 ID如果为 None 则根据负载均衡策略选择
use_load_balance: 是否使用负载均衡,默认为 True
Returns:
API 响应数据
"""
if not self.clients:
self.logger.error("调用 API 失败: 没有可用的客户端连接")
return create_error_response(
code=ErrorCode.WS_DISCONNECTED,
message="没有可用的客户端连接",
data={"action": action, "params": params}
)
# 如果没有指定客户端,使用负载均衡
if client_id is None and use_load_balance:
# 优先选择健康的客户端
healthy_clients = self.get_healthy_clients()
if healthy_clients:
# 选择负载最低的客户端
client_id = self.get_client_with_least_load()
if client_id is None and healthy_clients:
client_id = list(healthy_clients.keys())[0]
else:
# 如果没有健康客户端,使用所有客户端中的一个
client_id = list(self.clients.keys())[0]
echo_id = str(uuid.uuid4())
payload = {"action": action, "params": params or {}, "echo": echo_id}
loop = asyncio.get_running_loop()
future = loop.create_future()
self._pending_requests[echo_id] = future
try:
targets = [client_id] if client_id else list(self.clients.keys())
for cid in targets:
if cid in self.clients:
await self.clients[cid].send(orjson.dumps(payload))
return await asyncio.wait_for(future, timeout=30.0)
except asyncio.TimeoutError:
self._pending_requests.pop(echo_id, None)
self.logger.warning(f"API 调用超时: action={action}, params={params}")
return create_error_response(
code=ErrorCode.TIMEOUT_ERROR,
message="API调用超时",
data={"action": action, "params": params}
)
except Exception as e:
self._pending_requests.pop(echo_id, None)
self.logger.exception(f"API 调用异常: action={action}, error={str(e)}")
return create_error_response(
code=ErrorCode.WS_MESSAGE_ERROR,
message=f"API调用异常: {str(e)}",
data={"action": action, "params": params}
)
def get_connected_clients(self) -> Dict[str, int]:
"""
获取已连接的客户端列表。
Returns:
客户端 ID 和 self_id 的映射字典
"""
return self.client_self_ids.copy()
def _is_duplicate_event(self, event_data: Dict[str, Any]) -> bool:
"""
检查是否为重复事件。
Args:
event_data: 事件数据
Returns:
是否为重复事件
"""
event_id = event_data.get('id') or event_data.get('post_id') or event_data.get('time')
if not event_id:
return False
event_key = f"{event_data.get('post_type')}:{event_id}"
return event_key in self._processed_events
def _mark_event_processed(self, event_data: Dict[str, Any]) -> None:
"""
标记事件已处理。
Args:
event_data: 事件数据
"""
event_id = event_data.get('id') or event_data.get('post_id') or event_data.get('time')
if not event_id:
return
event_key = f"{event_data.get('post_type')}:{event_id}"
self._processed_events[event_key] = datetime.now()
def _get_message_key(self, event_data: Dict[str, Any]) -> str:
"""
获取消息唯一标识。
Args:
event_data: 事件数据
Returns:
消息唯一标识
"""
if event_data.get('post_type') == 'message':
message_id = event_data.get('message_id') or event_data.get('id')
user_id = event_data.get('user_id')
return f"msg:{message_id}:{user_id}"
return str(uuid.uuid4())
def _get_message_lock(self, key: str) -> asyncio.Lock:
"""
获取消息处理锁。
Args:
key: 消息唯一标识
Returns:
asyncio.Lock 实例
"""
if key not in self._message_locks:
self._message_locks[key] = asyncio.Lock()
return self._message_locks[key]
def _update_client_load(self, client_id: str) -> None:
"""
更新客户端负载。
Args:
client_id: 客户端 ID
"""
if client_id not in self._client_load:
self._client_load[client_id] = 0
self._client_load[client_id] += 1
def get_client_with_least_load(self) -> Optional[str]:
"""
获取负载最低的客户端。
Returns:
客户端 ID如果没有客户端则返回 None
"""
if not self._client_load:
return None
return min(self._client_load.keys(), key=lambda k: self._client_load[k])
def get_healthy_clients(self) -> Dict[str, int]:
"""
获取健康的客户端列表最近30秒内有活动
Returns:
健康的客户端 ID 和 self_id 的映射字典
"""
current_time = datetime.now()
healthy = {}
for client_id, last_health in self._client_health.items():
if (current_time - last_health).total_seconds() < 30:
if client_id in self.client_self_ids:
healthy[client_id] = self.client_self_ids[client_id]
return healthy
reverse_ws_manager = ReverseWSManager()