This commit is contained in:
baby20162016
2026-01-01 00:58:01 +08:00
parent 386534c250
commit 3e91c05688
11 changed files with 114 additions and 80 deletions

View File

@@ -1,7 +1,8 @@
import os
import importlib import importlib
import os
import pkgutil import pkgutil
def load_all_plugins(): def load_all_plugins():
"""扫描并加载当前包下所有的插件(支持文件和文件夹)""" """扫描并加载当前包下所有的插件(支持文件和文件夹)"""
package_name = __package__ package_name = __package__
@@ -19,4 +20,5 @@ def load_all_plugins():
except Exception as e: except Exception as e:
print(f" 加载插件 {module_name} 失败: {e}") print(f" 加载插件 {module_name} 失败: {e}")
load_all_plugins() load_all_plugins()

View File

@@ -1,4 +1,6 @@
from core.command_manager import matcher from core.command_manager import matcher
# TODO 把该死的这些给抽象化 # TODO 把该死的这些给抽象化
@matcher.command("echo") @matcher.command("echo")
async def handle_echo(bot, event, args): async def handle_echo(bot, event, args):
@@ -8,12 +10,10 @@ async def handle_echo(bot, event, args):
reply_msg = " ".join(args) reply_msg = " ".join(args)
if event.message_type == "group": if event.message_type == "group":
await bot.call_api("send_group_msg", { await bot.call_api(
"group_id": event.group_id, "send_group_msg", {"group_id": event.group_id, "message": reply_msg}
"message": reply_msg )
})
else: else:
await bot.call_api("send_private_msg", { await bot.call_api(
"user_id": event.user_id, "send_private_msg", {"user_id": event.user_id, "message": reply_msg}
"message": reply_msg )
})

View File

@@ -1,5 +1,5 @@
from .ws import WS
from .command_manager import matcher from .command_manager import matcher
from .config_loader import global_config from .config_loader import global_config
from .ws import WS
__all__ = ["WS", "matcher", "global_config"] __all__ = ["WS", "matcher", "global_config"]

View File

@@ -1,10 +1,12 @@
import inspect import inspect
from typing import Any, Tuple, Dict, List, Callable from typing import Any, Callable, Dict, List, Tuple
from .config_loader import global_config from .config_loader import global_config
# 从配置中获取命令前缀 # 从配置中获取命令前缀
comm_prefixes = global_config.bot.get("command", ("/",)) comm_prefixes = global_config.bot.get("command", ("/",))
class CommandManager: class CommandManager:
def __init__(self, prefixes: Tuple[str, ...] = ("/",)): def __init__(self, prefixes: Tuple[str, ...] = ("/",)):
self.prefixes = prefixes self.prefixes = prefixes
@@ -15,25 +17,31 @@ class CommandManager:
# --- 1. 消息指令装饰器 --- # --- 1. 消息指令装饰器 ---
def command(self, name: str): def command(self, name: str):
"""装饰器:注册消息指令,例如 @matcher.command("echo")""" """装饰器:注册消息指令,例如 @matcher.command("echo")"""
def decorator(func): def decorator(func):
self.commands[name] = func self.commands[name] = func
return func return func
return decorator return decorator
# --- 2. 通知事件装饰器 --- # --- 2. 通知事件装饰器 ---
def on_notice(self, notice_type: str = None): def on_notice(self, notice_type: str = None):
"""装饰器:注册通知处理器""" """装饰器:注册通知处理器"""
def decorator(func): def decorator(func):
self.notice_handlers.append({"type": notice_type, "func": func}) self.notice_handlers.append({"type": notice_type, "func": func})
return func return func
return decorator return decorator
# --- 3. 请求事件装饰器 --- # --- 3. 请求事件装饰器 ---
def on_request(self, request_type: str = None): def on_request(self, request_type: str = None):
"""装饰器:注册请求处理器""" """装饰器:注册请求处理器"""
def decorator(func): def decorator(func):
self.request_handlers.append({"type": request_type, "func": func}) self.request_handlers.append({"type": request_type, "func": func})
return func return func
return decorator return decorator
# --- 消息分发逻辑 --- # --- 消息分发逻辑 ---
@@ -88,13 +96,17 @@ class CommandManager:
params = sig.parameters params = sig.parameters
kwargs = {} kwargs = {}
if "bot" in params: kwargs["bot"] = bot if "bot" in params:
if "event" in params: kwargs["event"] = event kwargs["bot"] = bot
if "args" in params and args is not None: kwargs["args"] = args if "event" in params:
kwargs["event"] = event
if "args" in params and args is not None:
kwargs["args"] = args
# 执行函数 # 执行函数
await func(**kwargs) await func(**kwargs)
# 确保前缀是元组格式 # 确保前缀是元组格式
if isinstance(comm_prefixes, list): if isinstance(comm_prefixes, list):
comm_prefixes = tuple[Any, ...](comm_prefixes) comm_prefixes = tuple[Any, ...](comm_prefixes)

View File

@@ -1,12 +1,15 @@
import tomllib
from pathlib import Path from pathlib import Path
from typing import Any, Dict from typing import Any, Dict
import tomllib
class Config: class Config:
def __init__(self, file_path: str = "config.toml"): def __init__(self, file_path: str = "config.toml"):
self.path = Path(file_path) self.path = Path(file_path)
self._data: Dict[str, Any] = {} self._data: Dict[str, Any] = {}
self.load() self.load()
def load(self): def load(self):
if not self.path.exists(): if not self.path.exists():
raise FileNotFoundError(f"配置文件 {self.path} 未找到!") raise FileNotFoundError(f"配置文件 {self.path} 未找到!")
@@ -27,6 +30,7 @@ class Config:
def features(self) -> dict: def features(self) -> dict:
return self._data.get("features", {}) return self._data.get("features", {})
# 实例化全局配置对象 # 实例化全局配置对象
global_config = Config() global_config = Config()
@@ -35,4 +39,3 @@ if __name__ == "__main__":
print(global_config.bot.get("command")) print(global_config.bot.get("command"))
print(type(global_config.bot.get("command")) is list) print(type(global_config.bot.get("command")) is list)
print(global_config.features) print(global_config.features)

View File

@@ -1,12 +1,16 @@
import asyncio import asyncio
import json import json
import uuid
import websockets
import traceback import traceback
import uuid
from datetime import datetime
import websockets
from models import Event
from .command_manager import matcher from .command_manager import matcher
from .config_loader import global_config from .config_loader import global_config
from models import Event
from datetime import datetime
class WS: class WS:
def __init__(self): def __init__(self):
@@ -26,12 +30,17 @@ class WS:
while True: while True:
try: try:
print(f" 正在尝试连接至 NapCat: {self.url}") print(f" 正在尝试连接至 NapCat: {self.url}")
async with websockets.connect(self.url, additional_headers=headers) as websocket: async with websockets.connect(
self.url, additional_headers=headers
) as websocket:
self.ws = websocket self.ws = websocket
print(" 连接成功!") print(" 连接成功!")
await self._listen_loop(websocket) await self._listen_loop(websocket)
except (websockets.exceptions.ConnectionClosed, ConnectionRefusedError) as e: except (
websockets.exceptions.ConnectionClosed,
ConnectionRefusedError,
) as e:
print(f" 连接断开或服务器拒绝访问: {e}") print(f" 连接断开或服务器拒绝访问: {e}")
except Exception as e: except Exception as e:
print(f" 运行异常: {e}") print(f" 运行异常: {e}")
@@ -75,12 +84,16 @@ class WS:
# A. 消息事件 (Message) # A. 消息事件 (Message)
if event.post_type == "message": if event.post_type == "message":
print(f" [{t}] [消息] {event.message_type} | {event.user_id}: {event.raw_message}") print(
f" [{t}] [消息] {event.message_type} | {event.user_id}: {event.raw_message}"
)
await matcher.handle_message(self, event) await matcher.handle_message(self, event)
# B. 通知事件 (Notice) # B. 通知事件 (Notice)
elif event.post_type == "notice": elif event.post_type == "notice":
print(f" [{t}] [通知] {event.notice_type} | 来自: {event.group_id or '私聊'}") print(
f" [{t}] [通知] {event.notice_type} | 来自: {event.group_id or '私聊'}"
)
await matcher.handle_notice(self, event) await matcher.handle_notice(self, event)
# C. 请求事件 (Request) # C. 请求事件 (Request)
@@ -102,6 +115,7 @@ class WS:
return {"status": "failed", "msg": "websocket not initialized"} return {"status": "failed", "msg": "websocket not initialized"}
from websockets.protocol import State from websockets.protocol import State
if getattr(self.ws, "state", None) is not State.OPEN: if getattr(self.ws, "state", None) is not State.OPEN:
return {"status": "failed", "msg": "websocket is not open"} return {"status": "failed", "msg": "websocket is not open"}
@@ -115,7 +129,6 @@ class WS:
await self.ws.send(json.dumps(payload)) await self.ws.send(json.dumps(payload))
try: try:
return await asyncio.wait_for(future, timeout=30.0) return await asyncio.wait_for(future, timeout=30.0)
except asyncio.TimeoutError: except asyncio.TimeoutError:
self._pending_requests.pop(echo_id, None) self._pending_requests.pop(echo_id, None)

View File

@@ -1,10 +1,13 @@
# main.py # main.py
import asyncio import asyncio
from core import WS from core import WS
import base_plugins
async def main(): async def main():
bot = WS() bot = WS()
await bot.connect() await bot.connect()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,2 +1,3 @@
from .event import Event, MessageSegment from .event import Event, MessageSegment
__all__ = ["Event", "MessageSegment", "Sender"] __all__ = ["Event", "MessageSegment", "Sender"]

View File

@@ -1,7 +1,9 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any from typing import Any, Dict, List, Optional
from .sender import Sender from .sender import Sender
@dataclass @dataclass
class MessageSegment: class MessageSegment:
type: str type: str
@@ -26,7 +28,6 @@ class MessageSegment:
return f"[MS:{self.type}:{self.data}]" return f"[MS:{self.type}:{self.data}]"
@dataclass @dataclass
class Event: class Event:
post_type: str post_type: str
@@ -65,10 +66,9 @@ class Event:
sender_data = data.get("sender") sender_data = data.get("sender")
sender_obj = None sender_obj = None
if isinstance(sender_data, dict): if isinstance(sender_data, dict):
sender_obj = Sender(**{ sender_obj = Sender(
k: v for k, v in sender_data.items() **{k: v for k, v in sender_data.items() if k in Sender.__annotations__}
if k in Sender.__annotations__ )
})
# 数据整合 # 数据整合
processed_data = data.copy() processed_data = data.copy()
@@ -77,8 +77,7 @@ class Event:
# 字段过滤:只提取 dataclass 中定义的字段 # 字段过滤:只提取 dataclass 中定义的字段
valid_data = { valid_data = {
k: v for k, v in processed_data.items() k: v for k, v in processed_data.items() if k in cls.__annotations__
if k in cls.__annotations__
} }
return cls(**valid_data) return cls(**valid_data)

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
@dataclass @dataclass
class Sender: class Sender:
user_id: int user_id: int