This commit is contained in:
baby20162016
2026-01-01 00:58:01 +08:00
parent 386534c250
commit 3e91c05688
11 changed files with 114 additions and 80 deletions

View File

@@ -1,7 +1,8 @@
import os
import importlib import importlib
import os
import pkgutil import pkgutil
def load_all_plugins(): def load_all_plugins():
"""扫描并加载当前包下所有的插件(支持文件和文件夹)""" """扫描并加载当前包下所有的插件(支持文件和文件夹)"""
package_name = __package__ package_name = __package__
@@ -11,7 +12,7 @@ def load_all_plugins():
for loader, module_name, is_pkg in pkgutil.iter_modules(package_path): for loader, module_name, is_pkg in pkgutil.iter_modules(package_path):
full_module_name = f"{package_name}.{module_name}" full_module_name = f"{package_name}.{module_name}"
try: try:
importlib.import_module(full_module_name) importlib.import_module(full_module_name)
type_str = "" if is_pkg else "文件" type_str = "" if is_pkg else "文件"
@@ -19,4 +20,5 @@ def load_all_plugins():
except Exception as e: except Exception as e:
print(f" 加载插件 {module_name} 失败: {e}") print(f" 加载插件 {module_name} 失败: {e}")
load_all_plugins()
load_all_plugins()

View File

@@ -1,5 +1,7 @@
from core.command_manager import matcher from core.command_manager import matcher
#TODO 把该死的这些给抽象化
# TODO 把该死的这些给抽象化
@matcher.command("echo") @matcher.command("echo")
async def handle_echo(bot, event, args): async def handle_echo(bot, event, args):
if not args: if not args:
@@ -8,12 +10,10 @@ async def handle_echo(bot, event, args):
reply_msg = " ".join(args) reply_msg = " ".join(args)
if event.message_type == "group": if event.message_type == "group":
await bot.call_api("send_group_msg", { await bot.call_api(
"group_id": event.group_id, "send_group_msg", {"group_id": event.group_id, "message": reply_msg}
"message": reply_msg )
})
else: else:
await bot.call_api("send_private_msg", { await bot.call_api(
"user_id": event.user_id, "send_private_msg", {"user_id": event.user_id, "message": reply_msg}
"message": reply_msg )
})

View File

@@ -1,5 +1,5 @@
from .ws import WS
from .command_manager import matcher from .command_manager import matcher
from .config_loader import global_config from .config_loader import global_config
from .ws import WS
__all__ = ["WS", "matcher", "global_config"] __all__ = ["WS", "matcher", "global_config"]

View File

@@ -1,39 +1,47 @@
import inspect import inspect
from typing import Any, Tuple, Dict, List, Callable from typing import Any, Callable, Dict, List, Tuple
from .config_loader import global_config from .config_loader import global_config
# 从配置中获取命令前缀 # 从配置中获取命令前缀
comm_prefixes = global_config.bot.get("command", ("/",)) comm_prefixes = global_config.bot.get("command", ("/",))
class CommandManager: class CommandManager:
def __init__(self, prefixes: Tuple[str, ...] = ("/",)): def __init__(self, prefixes: Tuple[str, ...] = ("/",)):
self.prefixes = prefixes self.prefixes = prefixes
self.commands: Dict[str, Callable] = {} # 存储消息指令 self.commands: Dict[str, Callable] = {} # 存储消息指令
self.notice_handlers: List[Dict] = [] # 存储通知处理器 self.notice_handlers: List[Dict] = [] # 存储通知处理器
self.request_handlers: List[Dict] = [] # 存储请求处理器 self.request_handlers: List[Dict] = [] # 存储请求处理器
# --- 1. 消息指令装饰器 --- # --- 1. 消息指令装饰器 ---
def command(self, name: str): def command(self, name: str):
"""装饰器:注册消息指令,例如 @matcher.command("echo")""" """装饰器:注册消息指令,例如 @matcher.command("echo")"""
def decorator(func): def decorator(func):
self.commands[name] = func self.commands[name] = func
return func return func
return decorator return decorator
# --- 2. 通知事件装饰器 --- # --- 2. 通知事件装饰器 ---
def on_notice(self, notice_type: str = None): def on_notice(self, notice_type: str = None):
"""装饰器:注册通知处理器""" """装饰器:注册通知处理器"""
def decorator(func): def decorator(func):
self.notice_handlers.append({"type": notice_type, "func": func}) self.notice_handlers.append({"type": notice_type, "func": func})
return func return func
return decorator return decorator
# --- 3. 请求事件装饰器 --- # --- 3. 请求事件装饰器 ---
def on_request(self, request_type: str = None): def on_request(self, request_type: str = None):
"""装饰器:注册请求处理器""" """装饰器:注册请求处理器"""
def decorator(func): def decorator(func):
self.request_handlers.append({"type": request_type, "func": func}) self.request_handlers.append({"type": request_type, "func": func})
return func return func
return decorator return decorator
# --- 消息分发逻辑 --- # --- 消息分发逻辑 ---
@@ -41,24 +49,24 @@ class CommandManager:
"""解析并分发消息指令""" """解析并分发消息指令"""
if not event.raw_message: if not event.raw_message:
return return
raw_text = event.raw_message.strip() raw_text = event.raw_message.strip()
# 1. 检查前缀 # 1. 检查前缀
prefix_found = None prefix_found = None
for p in self.prefixes: for p in self.prefixes:
if raw_text.startswith(p): if raw_text.startswith(p):
prefix_found = p prefix_found = p
break break
if not prefix_found: if not prefix_found:
return return
# 2. 拆分指令和参数 # 2. 拆分指令和参数
full_cmd = raw_text[len(prefix_found):].split() full_cmd = raw_text[len(prefix_found) :].split()
if not full_cmd: if not full_cmd:
return return
cmd_name = full_cmd[0] cmd_name = full_cmd[0]
args = full_cmd[1:] args = full_cmd[1:]
@@ -87,14 +95,18 @@ class CommandManager:
sig = inspect.signature(func) sig = inspect.signature(func)
params = sig.parameters params = sig.parameters
kwargs = {} kwargs = {}
if "bot" in params: kwargs["bot"] = bot if "bot" in params:
if "event" in params: kwargs["event"] = event kwargs["bot"] = bot
if "args" in params and args is not None: kwargs["args"] = args if "event" in params:
kwargs["event"] = event
if "args" in params and args is not None:
kwargs["args"] = args
# 执行函数 # 执行函数
await func(**kwargs) await func(**kwargs)
# 确保前缀是元组格式 # 确保前缀是元组格式
if isinstance(comm_prefixes, list): if isinstance(comm_prefixes, list):
comm_prefixes = tuple[Any, ...](comm_prefixes) comm_prefixes = tuple[Any, ...](comm_prefixes)
@@ -102,4 +114,4 @@ elif isinstance(comm_prefixes, str):
comm_prefixes = (comm_prefixes,) comm_prefixes = (comm_prefixes,)
# 实例化全局管理器 # 实例化全局管理器
matcher = CommandManager(prefixes=comm_prefixes) matcher = CommandManager(prefixes=comm_prefixes)

View File

@@ -1,16 +1,19 @@
import tomllib
from pathlib import Path from pathlib import Path
from typing import Any, Dict from typing import Any, Dict
import tomllib
class Config: class Config:
def __init__(self, file_path: str = "config.toml"): def __init__(self, file_path: str = "config.toml"):
self.path = Path(file_path) self.path = Path(file_path)
self._data: Dict[str, Any] = {} self._data: Dict[str, Any] = {}
self.load() self.load()
def load(self): def load(self):
if not self.path.exists(): if not self.path.exists():
raise FileNotFoundError(f"配置文件 {self.path} 未找到!") raise FileNotFoundError(f"配置文件 {self.path} 未找到!")
with open(self.path, "rb") as f: with open(self.path, "rb") as f:
self._data = tomllib.load(f) self._data = tomllib.load(f)
@@ -27,6 +30,7 @@ class Config:
def features(self) -> dict: def features(self) -> dict:
return self._data.get("features", {}) return self._data.get("features", {})
# 实例化全局配置对象 # 实例化全局配置对象
global_config = Config() global_config = Config()
@@ -35,4 +39,3 @@ if __name__ == "__main__":
print(global_config.bot.get("command")) print(global_config.bot.get("command"))
print(type(global_config.bot.get("command")) is list) print(type(global_config.bot.get("command")) is list)
print(global_config.features) print(global_config.features)

View File

@@ -1,13 +1,17 @@
import asyncio import asyncio
import json import json
import uuid
import websockets
import traceback import traceback
from .command_manager import matcher import uuid
from .config_loader import global_config
from models import Event
from datetime import datetime from datetime import datetime
import websockets
from models import Event
from .command_manager import matcher
from .config_loader import global_config
class WS: class WS:
def __init__(self): def __init__(self):
# 读取参数 # 读取参数
@@ -15,28 +19,33 @@ class WS:
self.url = cfg.get("uri") self.url = cfg.get("uri")
self.token = cfg.get("token") self.token = cfg.get("token")
self.reconnect_interval = cfg.get("reconnect_interval", 5) self.reconnect_interval = cfg.get("reconnect_interval", 5)
self.ws = None self.ws = None
self._pending_requests = {} self._pending_requests = {}
async def connect(self): async def connect(self):
"""主连接循环""" """主连接循环"""
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
while True: while True:
try: try:
print(f" 正在尝试连接至 NapCat: {self.url}") print(f" 正在尝试连接至 NapCat: {self.url}")
async with websockets.connect(self.url, additional_headers=headers) as websocket: async with websockets.connect(
self.url, additional_headers=headers
) as websocket:
self.ws = websocket self.ws = websocket
print(" 连接成功!") print(" 连接成功!")
await self._listen_loop(websocket) await self._listen_loop(websocket)
except (websockets.exceptions.ConnectionClosed, ConnectionRefusedError) as e: except (
websockets.exceptions.ConnectionClosed,
ConnectionRefusedError,
) as e:
print(f" 连接断开或服务器拒绝访问: {e}") print(f" 连接断开或服务器拒绝访问: {e}")
except Exception as e: except Exception as e:
print(f" 运行异常: {e}") print(f" 运行异常: {e}")
traceback.print_exc() traceback.print_exc()
print(f" {self.reconnect_interval}秒后尝试重连...") print(f" {self.reconnect_interval}秒后尝试重连...")
await asyncio.sleep(self.reconnect_interval) await asyncio.sleep(self.reconnect_interval)
@@ -45,20 +54,20 @@ class WS:
async for message in websocket: async for message in websocket:
try: try:
data = json.loads(message) data = json.loads(message)
# 1. 处理 API 响应 # 1. 处理 API 响应
echo_id = data.get("echo") echo_id = data.get("echo")
if echo_id and echo_id in self._pending_requests: if echo_id and echo_id in self._pending_requests:
future = self._pending_requests.pop(echo_id) future = self._pending_requests.pop(echo_id)
if not future.done(): if not future.done():
future.set_result(data) future.set_result(data)
continue continue
# 2. 处理上报事件 # 2. 处理上报事件
if "post_type" in data: if "post_type" in data:
# 使用 create_task 异步执行,避免阻塞 # 使用 create_task 异步执行,避免阻塞
asyncio.create_task(self.on_event(data)) asyncio.create_task(self.on_event(data))
except Exception as e: except Exception as e:
print(f" 解析消息异常: {e}") print(f" 解析消息异常: {e}")
@@ -67,7 +76,7 @@ class WS:
try: try:
# 解析为 Event 对象 # 解析为 Event 对象
event = Event.from_dict(raw_data) event = Event.from_dict(raw_data)
# 格式化时间用于打印 # 格式化时间用于打印
t = datetime.fromtimestamp(event.time).strftime("%H:%M:%S") t = datetime.fromtimestamp(event.time).strftime("%H:%M:%S")
@@ -75,12 +84,16 @@ class WS:
# A. 消息事件 (Message) # A. 消息事件 (Message)
if event.post_type == "message": if event.post_type == "message":
print(f" [{t}] [消息] {event.message_type} | {event.user_id}: {event.raw_message}") print(
f" [{t}] [消息] {event.message_type} | {event.user_id}: {event.raw_message}"
)
await matcher.handle_message(self, event) await matcher.handle_message(self, event)
# B. 通知事件 (Notice) # B. 通知事件 (Notice)
elif event.post_type == "notice": elif event.post_type == "notice":
print(f" [{t}] [通知] {event.notice_type} | 来自: {event.group_id or '私聊'}") print(
f" [{t}] [通知] {event.notice_type} | 来自: {event.group_id or '私聊'}"
)
await matcher.handle_notice(self, event) await matcher.handle_notice(self, event)
# C. 请求事件 (Request) # C. 请求事件 (Request)
@@ -100,23 +113,23 @@ class WS:
"""调用 OneBot API""" """调用 OneBot API"""
if not self.ws: if not self.ws:
return {"status": "failed", "msg": "websocket not initialized"} return {"status": "failed", "msg": "websocket not initialized"}
from websockets.protocol import State from websockets.protocol import State
if getattr(self.ws, "state", None) is not State.OPEN: if getattr(self.ws, "state", None) is not State.OPEN:
return {"status": "failed", "msg": "websocket is not open"} return {"status": "failed", "msg": "websocket is not open"}
echo_id = str(uuid.uuid4()) echo_id = str(uuid.uuid4())
payload = {"action": action, "params": params or {}, "echo": echo_id} payload = {"action": action, "params": params or {}, "echo": echo_id}
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
future = loop.create_future() future = loop.create_future()
self._pending_requests[echo_id] = future self._pending_requests[echo_id] = future
await self.ws.send(json.dumps(payload)) await self.ws.send(json.dumps(payload))
try: try:
return await asyncio.wait_for(future, timeout=30.0) return await asyncio.wait_for(future, timeout=30.0)
except asyncio.TimeoutError: except asyncio.TimeoutError:
self._pending_requests.pop(echo_id, None) self._pending_requests.pop(echo_id, None)
return {"status": "failed", "retcode": -1, "msg": "api timeout"} return {"status": "failed", "retcode": -1, "msg": "api timeout"}

View File

@@ -1,10 +1,13 @@
# main.py # main.py
import asyncio import asyncio
from core import WS from core import WS
import base_plugins
async def main(): async def main():
bot = WS() bot = WS()
await bot.connect() await bot.connect()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,2 +1,3 @@
from .event import Event, MessageSegment from .event import Event, MessageSegment
__all__ = ["Event", "MessageSegment", "Sender"]
__all__ = ["Event", "MessageSegment", "Sender"]

View File

@@ -1 +1 @@
#TODO 数据类型 # TODO 数据类型

View File

@@ -1,7 +1,9 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any from typing import Any, Dict, List, Optional
from .sender import Sender from .sender import Sender
@dataclass @dataclass
class MessageSegment: class MessageSegment:
type: str type: str
@@ -26,7 +28,6 @@ class MessageSegment:
return f"[MS:{self.type}:{self.data}]" return f"[MS:{self.type}:{self.data}]"
@dataclass @dataclass
class Event: class Event:
post_type: str post_type: str
@@ -58,27 +59,25 @@ class Event:
segments = [] segments = []
if isinstance(raw_msg_array, list): if isinstance(raw_msg_array, list):
segments = [ segments = [
MessageSegment(type=seg["type"], data=seg["data"]) MessageSegment(type=seg["type"], data=seg["data"])
for seg in raw_msg_array for seg in raw_msg_array
] ]
sender_data = data.get("sender") sender_data = data.get("sender")
sender_obj = None sender_obj = None
if isinstance(sender_data, dict): if isinstance(sender_data, dict):
sender_obj = Sender(**{ sender_obj = Sender(
k: v for k, v in sender_data.items() **{k: v for k, v in sender_data.items() if k in Sender.__annotations__}
if k in Sender.__annotations__ )
})
# 数据整合 # 数据整合
processed_data = data.copy() processed_data = data.copy()
processed_data["message"] = segments processed_data["message"] = segments
processed_data["sender"] = sender_obj processed_data["sender"] = sender_obj
# 字段过滤:只提取 dataclass 中定义的字段 # 字段过滤:只提取 dataclass 中定义的字段
valid_data = { valid_data = {
k: v for k, v in processed_data.items() k: v for k, v in processed_data.items() if k in cls.__annotations__
if k in cls.__annotations__
} }
return cls(**valid_data) return cls(**valid_data)
@@ -93,4 +92,4 @@ class Event:
@property @property
def is_request(self) -> bool: def is_request(self) -> bool:
return self.post_type == "request" return self.post_type == "request"

View File

@@ -1,16 +1,17 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
@dataclass @dataclass
class Sender: class Sender:
user_id: int user_id: int
nickname: str nickname: str
sex: str = "unknown" sex: str = "unknown"
age: int = 0 age: int = 0
# 群聊特有字段 # 群聊特有字段
card: Optional[str] = None # 群名片 card: Optional[str] = None # 群名片
area: Optional[str] = None # 地区 area: Optional[str] = None # 地区
level: Optional[str] = None # 等级 level: Optional[str] = None # 等级
role: Optional[str] = None # 角色: owner/admin/member role: Optional[str] = None # 角色: owner/admin/member
title: Optional[str] = None # 专属头衔 title: Optional[str] = None # 专属头衔