diff --git a/core/command_manager.py b/core/command_manager.py index b750214..4b67756 100644 --- a/core/command_manager.py +++ b/core/command_manager.py @@ -1,25 +1,47 @@ -from typing import Any - - import inspect +from typing import Any, Tuple, Dict, List, Callable from .config_loader import global_config -comm = global_config.bot.get("command") +# 从配置中获取命令前缀 +comm_prefixes = global_config.bot.get("command", ("/",)) class CommandManager: - def __init__(self, prefixes=(tuple[Any, ...] (comm))): + def __init__(self, prefixes: Tuple[str, ...] = ("/",)): self.prefixes = prefixes - self.commands = {} # 存储指令函数 + self.commands: Dict[str, Callable] = {} # 存储消息指令 + self.notice_handlers: List[Dict] = [] # 存储通知处理器 + self.request_handlers: List[Dict] = [] # 存储请求处理器 + # --- 1. 消息指令装饰器 --- def command(self, name: str): - """装饰器:注册指令""" + """装饰器:注册消息指令,例如 @matcher.command("echo")""" def decorator(func): self.commands[name] = func return func return decorator + # --- 2. 通知事件装饰器 --- + def on_notice(self, notice_type: str = None): + """装饰器:注册通知处理器""" + def decorator(func): + self.notice_handlers.append({"type": notice_type, "func": func}) + return func + return decorator + + # --- 3. 请求事件装饰器 --- + def on_request(self, request_type: str = None): + """装饰器:注册请求处理器""" + def decorator(func): + self.request_handlers.append({"type": request_type, "func": func}) + return func + return decorator + + # --- 消息分发逻辑 --- async def handle_message(self, bot, event): - """解析并分发指令""" + """解析并分发消息指令""" + if not event.raw_message: + return + raw_text = event.raw_message.strip() # 1. 检查前缀 @@ -30,7 +52,7 @@ class CommandManager: break if not prefix_found: - return # 不是指令,跳过 + return # 2. 拆分指令和参数 full_cmd = raw_text[len(prefix_found):].split() @@ -43,13 +65,41 @@ class CommandManager: # 3. 查找并执行 if cmd_name in self.commands: func = self.commands[cmd_name] - # 自动注入参数 (判断函数是否需要 args) - sig = inspect.signature(func) - if "args" in sig.parameters: - await func(bot, event, args) - else: - await func(bot, event) + await self._run_handler(func, bot, event, args) + + # --- 通知分发逻辑 --- + async def handle_notice(self, bot, event): + """分发通知事件""" + for handler in self.notice_handlers: + if handler["type"] is None or handler["type"] == event.notice_type: + await self._run_handler(handler["func"], bot, event) + + # --- 请求分发逻辑 --- + async def handle_request(self, bot, event): + """分发请求事件""" + for handler in self.request_handlers: + if handler["type"] is None or handler["type"] == event.request_type: + await self._run_handler(handler["func"], bot, event) + + # --- 通用执行器:自动注入参数 --- + async def _run_handler(self, func, bot, event, args=None): + """根据函数签名自动注入 bot, event 或 args""" + sig = inspect.signature(func) + params = sig.parameters + kwargs = {} + + if "bot" in params: kwargs["bot"] = bot + if "event" in params: kwargs["event"] = event + if "args" in params and args is not None: kwargs["args"] = args + + # 执行函数 + await func(**kwargs) + +# 确保前缀是元组格式 +if isinstance(comm_prefixes, list): + comm_prefixes = tuple[Any, ...](comm_prefixes) +elif isinstance(comm_prefixes, str): + comm_prefixes = (comm_prefixes,) # 实例化全局管理器 -qianzhui = global_config.bot.get("command") -matcher = CommandManager(prefixes=(tuple[Any, ...] (comm))) \ No newline at end of file +matcher = CommandManager(prefixes=comm_prefixes) \ No newline at end of file diff --git a/core/ws.py b/core/ws.py index aa8d7c7..88a3b0d 100644 --- a/core/ws.py +++ b/core/ws.py @@ -6,6 +6,7 @@ import traceback from .command_manager import matcher from .config_loader import global_config from models import Event +from datetime import datetime class WS: def __init__(self): @@ -15,11 +16,11 @@ class WS: self.token = cfg.get("token") self.reconnect_interval = cfg.get("reconnect_interval", 5) - self.ws = None # 存储当前的活跃连接 - self._pending_requests = {} # 存储等待 API 返回的 Future 对象 + self.ws = None + self._pending_requests = {} async def connect(self): - """主连接循环:负责建立连接并处理断线重连""" + """主连接循环""" headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} while True: @@ -28,8 +29,6 @@ class WS: async with websockets.connect(self.url, additional_headers=headers) as websocket: self.ws = websocket print(" 连接成功!") - - # 进入阻塞式的监听循环 await self._listen_loop(websocket) except (websockets.exceptions.ConnectionClosed, ConnectionRefusedError) as e: @@ -38,78 +37,86 @@ class WS: print(f" 运行异常: {e}") traceback.print_exc() - print(f" {self.reconnect_interval}秒后尝试重连...") await asyncio.sleep(self.reconnect_interval) async def _listen_loop(self, websocket): - """核心监听循环:负责从 WebSocket 读取原始数据并分类分发""" + """核心监听循环""" async for message in websocket: try: data = json.loads(message) - # 1. 优先处理 API 响应 (带有 echo 字段) + # 1. 处理 API 响应 echo_id = data.get("echo") if echo_id and echo_id in self._pending_requests: future = self._pending_requests.pop(echo_id) if not future.done(): - future.set_result(data) # 唤醒对应的 call_api 函数 - continue # 处理完 API 响应后跳过本次循环 + future.set_result(data) + continue - # 2. 处理上报的事件 (含有 post_type 字段) + # 2. 处理上报事件 if "post_type" in data: - # 使用 create_task 异步执行,确保复杂的业务逻辑不阻塞消息接收 + # 使用 create_task 异步执行,避免阻塞 asyncio.create_task(self.on_event(data)) except Exception as e: print(f" 解析消息异常: {e}") async def on_event(self, raw_data: dict): - """事件分发层:将原始字典转换为 Event 对象并交给 matcher""" - # 仅处理消息事件 (message),忽略元事件 (meta_event) 或请求事件 (request) - if raw_data.get("post_type") != "message": - return - + """事件分发层:根据 post_type 调用 matcher 对应的处理器""" try: - # 将字典解析为强类型的 Event 对象 + # 解析为 Event 对象 event = Event.from_dict(raw_data) - # 调试日志:可以看到收到的每条指令内容 - print(f" 收到消息: [{event.user_id}] -> {event.raw_message}") + # 格式化时间用于打印 + t = datetime.fromtimestamp(event.time).strftime("%H:%M:%S") + + # --- 分流处理 --- + + # A. 消息事件 (Message) + if event.post_type == "message": + print(f" [{t}] [消息] {event.message_type} | {event.user_id}: {event.raw_message}") + await matcher.handle_message(self, event) + + # B. 通知事件 (Notice) + elif event.post_type == "notice": + print(f" [{t}] [通知] {event.notice_type} | 来自: {event.group_id or '私聊'}") + await matcher.handle_notice(self, event) + + # C. 请求事件 (Request) + elif event.post_type == "request": + print(f" [{t}] [请求] {event.request_type} | 内容: {event.comment}") + await matcher.handle_request(self, event) + + # D. 元事件 (Meta Event) - 通常用来心跳检测,可不处理 + elif event.post_type == "meta_event": + pass - # 调用插件系统的入口函数 - await matcher.handle_message(self, event) - except Exception as e: - print(f" 事件分发失败: {e}") + print(f"事件分发失败: {e}") + traceback.print_exc() async def call_api(self, action: str, params: dict = None): + """调用 OneBot API""" if not self.ws: return {"status": "failed", "msg": "websocket not initialized"} - # 检查 websockets 13.x+ 的状态属性 from websockets.protocol import State if getattr(self.ws, "state", None) is not State.OPEN: return {"status": "failed", "msg": "websocket is not open"} echo_id = str(uuid.uuid4()) - payload = { - "action": action, - "params": params or {}, - "echo": echo_id - } + payload = {"action": action, "params": params or {}, "echo": echo_id} - # 创建一个 Future 对象用于等待返回结果 loop = asyncio.get_running_loop() future = loop.create_future() self._pending_requests[echo_id] = future - # 通过 WebSocket 发送请求 await self.ws.send(json.dumps(payload)) try: - # 设置 100 秒超时,防止 API 请求永久挂起 - return await asyncio.wait_for(future, timeout=100.0) + + return await asyncio.wait_for(future, timeout=30.0) except asyncio.TimeoutError: self._pending_requests.pop(echo_id, None) return {"status": "failed", "retcode": -1, "msg": "api timeout"} \ No newline at end of file diff --git a/models/event.py b/models/event.py index 6a80ff6..06b58cd 100644 --- a/models/event.py +++ b/models/event.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List, Optional, Dict, Any -from .sender import Sender # 导入上面的 Sender +from .sender import Sender @dataclass class MessageSegment: @@ -9,16 +9,13 @@ class MessageSegment: @property def text(self) -> str: - """如果是文本段,返回文本内容,否则返回空字符串""" return self.data.get("text", "") if self.type == "text" else "" @property def image_url(self) -> str: - """如果是图片段,返回图片 URL""" return self.data.get("url", "") if self.type == "image" else "" def is_at(self, user_id: int = None) -> bool: - """判断是否是 @某人""" if self.type != "at": return False if user_id is None: @@ -28,36 +25,72 @@ class MessageSegment: def __repr__(self): return f"[MS:{self.type}:{self.data}]" + + @dataclass class Event: post_type: str - message_type: str # group 或 private - user_id: int self_id: int - raw_message: str - message: List[MessageSegment] - sender: Sender time: int + + message_type: Optional[str] = None + sub_type: Optional[str] = None + message_id: Optional[int] = None + user_id: Optional[int] = None + raw_message: Optional[str] = None + message: List[MessageSegment] = field(default_factory=list) + sender: Optional[Sender] = None group_id: Optional[int] = None target_id: Optional[int] = None + notice_type: Optional[str] = None + operator_id: Optional[int] = None + duration: Optional[int] = None + honor_type: Optional[str] = None + + request_type: Optional[str] = None + flag: Optional[str] = None + comment: Optional[str] = None + @classmethod def from_dict(cls, data: dict): - raw_msg_array = data.get("message", []) - segments = [ - MessageSegment(type=seg["type"], data=seg["data"]) - for seg in raw_msg_array - ] + raw_msg_array = data.get("message") + segments = [] + if isinstance(raw_msg_array, list): + segments = [ + MessageSegment(type=seg["type"], data=seg["data"]) + for seg in raw_msg_array + ] - data_copy = data.copy() - data_copy["message"] = segments - - sender_data = data.get("sender", {}) - sender_obj = Sender(**{k: v for k, v in sender_data.items() if k in Sender.__annotations__}) + sender_data = data.get("sender") + sender_obj = None + if isinstance(sender_data, dict): + sender_obj = Sender(**{ + k: v for k, v in sender_data.items() + if k in Sender.__annotations__ + }) - data_copy = data.copy() - data_copy["message"] = segments - data_copy["sender"] = sender_obj # 关键点:把对象塞进去 + # 数据整合 + processed_data = data.copy() + processed_data["message"] = segments + processed_data["sender"] = sender_obj - valid_data = {k: v for k, v in data_copy.items() if k in cls.__annotations__} - return cls(**valid_data) \ No newline at end of file + # 字段过滤:只提取 dataclass 中定义的字段 + valid_data = { + k: v for k, v in processed_data.items() + if k in cls.__annotations__ + } + return cls(**valid_data) + + # --- 快捷判断工具 --- + @property + def is_message(self) -> bool: + return self.post_type == "message" + + @property + def is_notice(self) -> bool: + return self.post_type == "notice" + + @property + def is_request(self) -> bool: + return self.post_type == "request" \ No newline at end of file diff --git a/models/sender.py b/models/sender.py deleted file mode 100644 index 13f71cb..0000000 --- a/models/sender.py +++ /dev/null @@ -1,8 +0,0 @@ -from dataclasses import dataclass -#TODO 数据类型 -@dataclass -class Sender: - user_id: int - nickname: str - card: str = "" - role: str = "" # admin, owner, member \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1e0504e..7278891 100644 Binary files a/requirements.txt and b/requirements.txt differ