20260101_01

This commit is contained in:
2026-01-01 00:32:36 +08:00
parent 878dbee576
commit dfd4ec638b
5 changed files with 166 additions and 84 deletions

View File

@@ -1,25 +1,47 @@
from typing import Any
import inspect import inspect
from typing import Any, Tuple, Dict, List, Callable
from .config_loader import global_config from .config_loader import global_config
comm = global_config.bot.get("command") # 从配置中获取命令前缀
comm_prefixes = global_config.bot.get("command", ("/",))
class CommandManager: class CommandManager:
def __init__(self, prefixes=(tuple[Any, ...] (comm))): def __init__(self, prefixes: Tuple[str, ...] = ("/",)):
self.prefixes = prefixes self.prefixes = prefixes
self.commands = {} # 存储指令函数 self.commands: Dict[str, Callable] = {} # 存储消息指令
self.notice_handlers: List[Dict] = [] # 存储通知处理器
self.request_handlers: List[Dict] = [] # 存储请求处理器
# --- 1. 消息指令装饰器 ---
def command(self, name: str): def command(self, name: str):
"""装饰器:注册指令""" """装饰器:注册消息指令,例如 @matcher.command("echo")"""
def decorator(func): def decorator(func):
self.commands[name] = func self.commands[name] = func
return func return func
return decorator return decorator
# --- 2. 通知事件装饰器 ---
def on_notice(self, notice_type: str = None):
"""装饰器:注册通知处理器"""
def decorator(func):
self.notice_handlers.append({"type": notice_type, "func": func})
return func
return decorator
# --- 3. 请求事件装饰器 ---
def on_request(self, request_type: str = None):
"""装饰器:注册请求处理器"""
def decorator(func):
self.request_handlers.append({"type": request_type, "func": func})
return func
return decorator
# --- 消息分发逻辑 ---
async def handle_message(self, bot, event): async def handle_message(self, bot, event):
"""解析并分发指令""" """解析并分发消息指令"""
if not event.raw_message:
return
raw_text = event.raw_message.strip() raw_text = event.raw_message.strip()
# 1. 检查前缀 # 1. 检查前缀
@@ -30,7 +52,7 @@ class CommandManager:
break break
if not prefix_found: if not prefix_found:
return # 不是指令,跳过 return
# 2. 拆分指令和参数 # 2. 拆分指令和参数
full_cmd = raw_text[len(prefix_found):].split() full_cmd = raw_text[len(prefix_found):].split()
@@ -43,13 +65,41 @@ class CommandManager:
# 3. 查找并执行 # 3. 查找并执行
if cmd_name in self.commands: if cmd_name in self.commands:
func = self.commands[cmd_name] func = self.commands[cmd_name]
# 自动注入参数 (判断函数是否需要 args) await self._run_handler(func, bot, event, args)
# --- 通知分发逻辑 ---
async def handle_notice(self, bot, event):
"""分发通知事件"""
for handler in self.notice_handlers:
if handler["type"] is None or handler["type"] == event.notice_type:
await self._run_handler(handler["func"], bot, event)
# --- 请求分发逻辑 ---
async def handle_request(self, bot, event):
"""分发请求事件"""
for handler in self.request_handlers:
if handler["type"] is None or handler["type"] == event.request_type:
await self._run_handler(handler["func"], bot, event)
# --- 通用执行器:自动注入参数 ---
async def _run_handler(self, func, bot, event, args=None):
"""根据函数签名自动注入 bot, event 或 args"""
sig = inspect.signature(func) sig = inspect.signature(func)
if "args" in sig.parameters: params = sig.parameters
await func(bot, event, args) kwargs = {}
else:
await func(bot, event) if "bot" in params: kwargs["bot"] = bot
if "event" in params: kwargs["event"] = event
if "args" in params and args is not None: kwargs["args"] = args
# 执行函数
await func(**kwargs)
# 确保前缀是元组格式
if isinstance(comm_prefixes, list):
comm_prefixes = tuple[Any, ...](comm_prefixes)
elif isinstance(comm_prefixes, str):
comm_prefixes = (comm_prefixes,)
# 实例化全局管理器 # 实例化全局管理器
qianzhui = global_config.bot.get("command") matcher = CommandManager(prefixes=comm_prefixes)
matcher = CommandManager(prefixes=(tuple[Any, ...] (comm)))

View File

@@ -6,6 +6,7 @@ import traceback
from .command_manager import matcher from .command_manager import matcher
from .config_loader import global_config from .config_loader import global_config
from models import Event from models import Event
from datetime import datetime
class WS: class WS:
def __init__(self): def __init__(self):
@@ -15,11 +16,11 @@ class WS:
self.token = cfg.get("token") self.token = cfg.get("token")
self.reconnect_interval = cfg.get("reconnect_interval", 5) self.reconnect_interval = cfg.get("reconnect_interval", 5)
self.ws = None # 存储当前的活跃连接 self.ws = None
self._pending_requests = {} # 存储等待 API 返回的 Future 对象 self._pending_requests = {}
async def connect(self): async def connect(self):
"""主连接循环:负责建立连接并处理断线重连""" """主连接循环"""
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
while True: while True:
@@ -28,8 +29,6 @@ class WS:
async with websockets.connect(self.url, additional_headers=headers) as websocket: async with websockets.connect(self.url, additional_headers=headers) as websocket:
self.ws = websocket self.ws = websocket
print(" 连接成功!") print(" 连接成功!")
# 进入阻塞式的监听循环
await self._listen_loop(websocket) await self._listen_loop(websocket)
except (websockets.exceptions.ConnectionClosed, ConnectionRefusedError) as e: except (websockets.exceptions.ConnectionClosed, ConnectionRefusedError) as e:
@@ -38,78 +37,86 @@ class WS:
print(f" 运行异常: {e}") print(f" 运行异常: {e}")
traceback.print_exc() traceback.print_exc()
print(f" {self.reconnect_interval}秒后尝试重连...") print(f" {self.reconnect_interval}秒后尝试重连...")
await asyncio.sleep(self.reconnect_interval) await asyncio.sleep(self.reconnect_interval)
async def _listen_loop(self, websocket): async def _listen_loop(self, websocket):
"""核心监听循环:负责从 WebSocket 读取原始数据并分类分发""" """核心监听循环"""
async for message in websocket: async for message in websocket:
try: try:
data = json.loads(message) data = json.loads(message)
# 1. 优先处理 API 响应 (带有 echo 字段) # 1. 处理 API 响应
echo_id = data.get("echo") echo_id = data.get("echo")
if echo_id and echo_id in self._pending_requests: if echo_id and echo_id in self._pending_requests:
future = self._pending_requests.pop(echo_id) future = self._pending_requests.pop(echo_id)
if not future.done(): if not future.done():
future.set_result(data) # 唤醒对应的 call_api 函数 future.set_result(data)
continue # 处理完 API 响应后跳过本次循环 continue
# 2. 处理上报事件 (含有 post_type 字段) # 2. 处理上报事件
if "post_type" in data: if "post_type" in data:
# 使用 create_task 异步执行,确保复杂的业务逻辑不阻塞消息接收 # 使用 create_task 异步执行,避免阻塞
asyncio.create_task(self.on_event(data)) asyncio.create_task(self.on_event(data))
except Exception as e: except Exception as e:
print(f" 解析消息异常: {e}") print(f" 解析消息异常: {e}")
async def on_event(self, raw_data: dict): async def on_event(self, raw_data: dict):
"""事件分发层:将原始字典转换为 Event 对象并交给 matcher""" """事件分发层:根据 post_type 调用 matcher 对应的处理器"""
# 仅处理消息事件 (message),忽略元事件 (meta_event) 或请求事件 (request)
if raw_data.get("post_type") != "message":
return
try: try:
# 将字典解析为强类型的 Event 对象 # 解析为 Event 对象
event = Event.from_dict(raw_data) event = Event.from_dict(raw_data)
# 调试日志:可以看到收到的每条指令内容 # 格式化时间用于打印
print(f" 收到消息: [{event.user_id}] -> {event.raw_message}") t = datetime.fromtimestamp(event.time).strftime("%H:%M:%S")
# 调用插件系统的入口函数 # --- 分流处理 ---
# A. 消息事件 (Message)
if event.post_type == "message":
print(f" [{t}] [消息] {event.message_type} | {event.user_id}: {event.raw_message}")
await matcher.handle_message(self, event) await matcher.handle_message(self, event)
# B. 通知事件 (Notice)
elif event.post_type == "notice":
print(f" [{t}] [通知] {event.notice_type} | 来自: {event.group_id or '私聊'}")
await matcher.handle_notice(self, event)
# C. 请求事件 (Request)
elif event.post_type == "request":
print(f" [{t}] [请求] {event.request_type} | 内容: {event.comment}")
await matcher.handle_request(self, event)
# D. 元事件 (Meta Event) - 通常用来心跳检测,可不处理
elif event.post_type == "meta_event":
pass
except Exception as e: except Exception as e:
print(f"事件分发失败: {e}") print(f"事件分发失败: {e}")
traceback.print_exc()
async def call_api(self, action: str, params: dict = None): async def call_api(self, action: str, params: dict = None):
"""调用 OneBot API"""
if not self.ws: if not self.ws:
return {"status": "failed", "msg": "websocket not initialized"} return {"status": "failed", "msg": "websocket not initialized"}
# 检查 websockets 13.x+ 的状态属性
from websockets.protocol import State from websockets.protocol import State
if getattr(self.ws, "state", None) is not State.OPEN: if getattr(self.ws, "state", None) is not State.OPEN:
return {"status": "failed", "msg": "websocket is not open"} return {"status": "failed", "msg": "websocket is not open"}
echo_id = str(uuid.uuid4()) echo_id = str(uuid.uuid4())
payload = { payload = {"action": action, "params": params or {}, "echo": echo_id}
"action": action,
"params": params or {},
"echo": echo_id
}
# 创建一个 Future 对象用于等待返回结果
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
future = loop.create_future() future = loop.create_future()
self._pending_requests[echo_id] = future self._pending_requests[echo_id] = future
# 通过 WebSocket 发送请求
await self.ws.send(json.dumps(payload)) await self.ws.send(json.dumps(payload))
try: try:
# 设置 100 秒超时,防止 API 请求永久挂起
return await asyncio.wait_for(future, timeout=100.0) return await asyncio.wait_for(future, timeout=30.0)
except asyncio.TimeoutError: except asyncio.TimeoutError:
self._pending_requests.pop(echo_id, None) self._pending_requests.pop(echo_id, None)
return {"status": "failed", "retcode": -1, "msg": "api timeout"} return {"status": "failed", "retcode": -1, "msg": "api timeout"}

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
from .sender import Sender # 导入上面的 Sender from .sender import Sender
@dataclass @dataclass
class MessageSegment: class MessageSegment:
@@ -9,16 +9,13 @@ class MessageSegment:
@property @property
def text(self) -> str: def text(self) -> str:
"""如果是文本段,返回文本内容,否则返回空字符串"""
return self.data.get("text", "") if self.type == "text" else "" return self.data.get("text", "") if self.type == "text" else ""
@property @property
def image_url(self) -> str: def image_url(self) -> str:
"""如果是图片段,返回图片 URL"""
return self.data.get("url", "") if self.type == "image" else "" return self.data.get("url", "") if self.type == "image" else ""
def is_at(self, user_id: int = None) -> bool: def is_at(self, user_id: int = None) -> bool:
"""判断是否是 @某人"""
if self.type != "at": if self.type != "at":
return False return False
if user_id is None: if user_id is None:
@@ -28,36 +25,72 @@ class MessageSegment:
def __repr__(self): def __repr__(self):
return f"[MS:{self.type}:{self.data}]" return f"[MS:{self.type}:{self.data}]"
@dataclass @dataclass
class Event: class Event:
post_type: str post_type: str
message_type: str # group 或 private
user_id: int
self_id: int self_id: int
raw_message: str
message: List[MessageSegment]
sender: Sender
time: int time: int
message_type: Optional[str] = None
sub_type: Optional[str] = None
message_id: Optional[int] = None
user_id: Optional[int] = None
raw_message: Optional[str] = None
message: List[MessageSegment] = field(default_factory=list)
sender: Optional[Sender] = None
group_id: Optional[int] = None group_id: Optional[int] = None
target_id: Optional[int] = None target_id: Optional[int] = None
notice_type: Optional[str] = None
operator_id: Optional[int] = None
duration: Optional[int] = None
honor_type: Optional[str] = None
request_type: Optional[str] = None
flag: Optional[str] = None
comment: Optional[str] = None
@classmethod @classmethod
def from_dict(cls, data: dict): def from_dict(cls, data: dict):
raw_msg_array = data.get("message", []) raw_msg_array = data.get("message")
segments = []
if isinstance(raw_msg_array, list):
segments = [ segments = [
MessageSegment(type=seg["type"], data=seg["data"]) MessageSegment(type=seg["type"], data=seg["data"])
for seg in raw_msg_array for seg in raw_msg_array
] ]
data_copy = data.copy() sender_data = data.get("sender")
data_copy["message"] = segments sender_obj = None
if isinstance(sender_data, dict):
sender_obj = Sender(**{
k: v for k, v in sender_data.items()
if k in Sender.__annotations__
})
sender_data = data.get("sender", {}) # 数据整合
sender_obj = Sender(**{k: v for k, v in sender_data.items() if k in Sender.__annotations__}) processed_data = data.copy()
processed_data["message"] = segments
processed_data["sender"] = sender_obj
data_copy = data.copy() # 字段过滤:只提取 dataclass 中定义的字段
data_copy["message"] = segments valid_data = {
data_copy["sender"] = sender_obj # 关键点:把对象塞进去 k: v for k, v in processed_data.items()
if k in cls.__annotations__
valid_data = {k: v for k, v in data_copy.items() if k in cls.__annotations__} }
return cls(**valid_data) return cls(**valid_data)
# --- 快捷判断工具 ---
@property
def is_message(self) -> bool:
return self.post_type == "message"
@property
def is_notice(self) -> bool:
return self.post_type == "notice"
@property
def is_request(self) -> bool:
return self.post_type == "request"

View File

@@ -1,8 +0,0 @@
from dataclasses import dataclass
#TODO 数据类型
@dataclass
class Sender:
user_id: int
nickname: str
card: str = ""
role: str = "" # admin, owner, member

Binary file not shown.