112 lines
4.4 KiB
Python
112 lines
4.4 KiB
Python
import asyncio
|
||
import json
|
||
import uuid
|
||
import websockets
|
||
import traceback
|
||
from .command_manager import matcher
|
||
from .config_loader import global_config
|
||
from models import Event
|
||
|
||
class WS:
|
||
def __init__(self):
|
||
# 读取参数
|
||
cfg = global_config.napcat_ws
|
||
self.url = cfg.get("uri")
|
||
self.token = cfg.get("token")
|
||
self.reconnect_interval = cfg.get("reconnect_interval", 5)
|
||
|
||
self.ws = None # 存储当前的活跃连接
|
||
self._pending_requests = {} # 存储等待 API 返回的 Future 对象
|
||
|
||
async def connect(self):
|
||
"""主连接循环:负责建立连接并处理断线重连"""
|
||
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
|
||
|
||
while True:
|
||
try:
|
||
print(f" 正在尝试连接至 NapCat: {self.url}")
|
||
async with websockets.connect(self.url, additional_headers=headers) as websocket:
|
||
self.ws = websocket
|
||
print(" 连接成功!")
|
||
|
||
# 进入阻塞式的监听循环
|
||
await self._listen_loop(websocket)
|
||
|
||
except (websockets.exceptions.ConnectionClosed, ConnectionRefusedError) as e:
|
||
print(f" 连接断开或服务器拒绝访问: {e}")
|
||
except Exception as e:
|
||
print(f" 运行异常: {e}")
|
||
traceback.print_exc()
|
||
|
||
|
||
print(f" {self.reconnect_interval}秒后尝试重连...")
|
||
await asyncio.sleep(self.reconnect_interval)
|
||
|
||
async def _listen_loop(self, websocket):
|
||
"""核心监听循环:负责从 WebSocket 读取原始数据并分类分发"""
|
||
async for message in websocket:
|
||
try:
|
||
data = json.loads(message)
|
||
|
||
# 1. 优先处理 API 响应 (带有 echo 字段)
|
||
echo_id = data.get("echo")
|
||
if echo_id and echo_id in self._pending_requests:
|
||
future = self._pending_requests.pop(echo_id)
|
||
if not future.done():
|
||
future.set_result(data) # 唤醒对应的 call_api 函数
|
||
continue # 处理完 API 响应后跳过本次循环
|
||
|
||
# 2. 处理上报的事件 (含有 post_type 字段)
|
||
if "post_type" in data:
|
||
# 使用 create_task 异步执行,确保复杂的业务逻辑不阻塞消息接收
|
||
asyncio.create_task(self.on_event(data))
|
||
|
||
except Exception as e:
|
||
print(f" 解析消息异常: {e}")
|
||
|
||
async def on_event(self, raw_data: dict):
|
||
"""事件分发层:将原始字典转换为 Event 对象并交给 matcher"""
|
||
# 仅处理消息事件 (message),忽略元事件 (meta_event) 或请求事件 (request)
|
||
if raw_data.get("post_type") != "message":
|
||
return
|
||
|
||
try:
|
||
# 将字典解析为强类型的 Event 对象
|
||
event = Event.from_dict(raw_data)
|
||
|
||
# 调试日志:可以看到收到的每条指令内容
|
||
print(f" 收到消息: [{event.user_id}] -> {event.raw_message}")
|
||
|
||
# 调用插件系统的入口函数
|
||
await matcher.handle_message(self, event)
|
||
|
||
except Exception as e:
|
||
print(f" 事件分发失败: {e}")
|
||
|
||
async def call_api(self, action: str, params: dict = None):
|
||
"""公有 API:供插件调用,发送指令并异步等待结果"""
|
||
if not self.ws or self.ws.closed:
|
||
return {"status": "failed", "msg": "websocket not connected"}
|
||
|
||
# 创建唯一的 echo ID
|
||
echo_id = str(uuid.uuid4())
|
||
payload = {
|
||
"action": action,
|
||
"params": params or {},
|
||
"echo": echo_id
|
||
}
|
||
|
||
# 创建一个 Future 对象用于等待返回结果
|
||
loop = asyncio.get_running_loop()
|
||
future = loop.create_future()
|
||
self._pending_requests[echo_id] = future
|
||
|
||
# 通过 WebSocket 发送请求
|
||
await self.ws.send(json.dumps(payload))
|
||
|
||
try:
|
||
# 设置 100 秒超时,防止 API 请求永久挂起
|
||
return await asyncio.wait_for(future, timeout=100.0)
|
||
except asyncio.TimeoutError:
|
||
self._pending_requests.pop(echo_id, None)
|
||
return {"status": "failed", "retcode": -1, "msg": "api timeout"} |