This commit is contained in:
baby20162016
2026-01-01 00:58:01 +08:00
parent 386534c250
commit 3e91c05688
11 changed files with 114 additions and 80 deletions

View File

@@ -1,5 +1,5 @@
from .ws import WS
from .command_manager import matcher
from .config_loader import global_config
from .ws import WS
__all__ = ["WS", "matcher", "global_config"]
__all__ = ["WS", "matcher", "global_config"]

View File

@@ -1,39 +1,47 @@
import inspect
from typing import Any, Tuple, Dict, List, Callable
from typing import Any, Callable, Dict, List, Tuple
from .config_loader import global_config
# 从配置中获取命令前缀
comm_prefixes = global_config.bot.get("command", ("/",))
class CommandManager:
def __init__(self, prefixes: Tuple[str, ...] = ("/",)):
self.prefixes = prefixes
self.commands: Dict[str, Callable] = {} # 存储消息指令
self.notice_handlers: List[Dict] = [] # 存储通知处理器
self.request_handlers: List[Dict] = [] # 存储请求处理器
self.commands: Dict[str, Callable] = {} # 存储消息指令
self.notice_handlers: List[Dict] = [] # 存储通知处理器
self.request_handlers: List[Dict] = [] # 存储请求处理器
# --- 1. 消息指令装饰器 ---
def command(self, name: str):
"""装饰器:注册消息指令,例如 @matcher.command("echo")"""
def decorator(func):
self.commands[name] = func
return func
return decorator
# --- 2. 通知事件装饰器 ---
def on_notice(self, notice_type: str = None):
"""装饰器:注册通知处理器"""
def decorator(func):
self.notice_handlers.append({"type": notice_type, "func": func})
return func
return decorator
# --- 3. 请求事件装饰器 ---
def on_request(self, request_type: str = None):
"""装饰器:注册请求处理器"""
def decorator(func):
self.request_handlers.append({"type": request_type, "func": func})
return func
return decorator
# --- 消息分发逻辑 ---
@@ -41,24 +49,24 @@ class CommandManager:
"""解析并分发消息指令"""
if not event.raw_message:
return
raw_text = event.raw_message.strip()
# 1. 检查前缀
prefix_found = None
for p in self.prefixes:
if raw_text.startswith(p):
prefix_found = p
break
if not prefix_found:
return
return
# 2. 拆分指令和参数
full_cmd = raw_text[len(prefix_found):].split()
full_cmd = raw_text[len(prefix_found) :].split()
if not full_cmd:
return
cmd_name = full_cmd[0]
args = full_cmd[1:]
@@ -87,14 +95,18 @@ class CommandManager:
sig = inspect.signature(func)
params = sig.parameters
kwargs = {}
if "bot" in params: kwargs["bot"] = bot
if "event" in params: kwargs["event"] = event
if "args" in params and args is not None: kwargs["args"] = args
if "bot" in params:
kwargs["bot"] = bot
if "event" in params:
kwargs["event"] = event
if "args" in params and args is not None:
kwargs["args"] = args
# 执行函数
await func(**kwargs)
# 确保前缀是元组格式
if isinstance(comm_prefixes, list):
comm_prefixes = tuple[Any, ...](comm_prefixes)
@@ -102,4 +114,4 @@ elif isinstance(comm_prefixes, str):
comm_prefixes = (comm_prefixes,)
# 实例化全局管理器
matcher = CommandManager(prefixes=comm_prefixes)
matcher = CommandManager(prefixes=comm_prefixes)

View File

@@ -1,16 +1,19 @@
import tomllib
from pathlib import Path
from typing import Any, Dict
import tomllib
class Config:
def __init__(self, file_path: str = "config.toml"):
self.path = Path(file_path)
self._data: Dict[str, Any] = {}
self.load()
def load(self):
if not self.path.exists():
raise FileNotFoundError(f"配置文件 {self.path} 未找到!")
with open(self.path, "rb") as f:
self._data = tomllib.load(f)
@@ -27,6 +30,7 @@ class Config:
def features(self) -> dict:
return self._data.get("features", {})
# 实例化全局配置对象
global_config = Config()
@@ -35,4 +39,3 @@ if __name__ == "__main__":
print(global_config.bot.get("command"))
print(type(global_config.bot.get("command")) is list)
print(global_config.features)

View File

@@ -1,13 +1,17 @@
import asyncio
import json
import uuid
import websockets
import traceback
from .command_manager import matcher
from .config_loader import global_config
from models import Event
import uuid
from datetime import datetime
import websockets
from models import Event
from .command_manager import matcher
from .config_loader import global_config
class WS:
def __init__(self):
# 读取参数
@@ -15,28 +19,33 @@ class WS:
self.url = cfg.get("uri")
self.token = cfg.get("token")
self.reconnect_interval = cfg.get("reconnect_interval", 5)
self.ws = None
self._pending_requests = {}
self.ws = None
self._pending_requests = {}
async def connect(self):
"""主连接循环"""
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
while True:
try:
print(f" 正在尝试连接至 NapCat: {self.url}")
async with websockets.connect(self.url, additional_headers=headers) as websocket:
async with websockets.connect(
self.url, additional_headers=headers
) as websocket:
self.ws = websocket
print(" 连接成功!")
await self._listen_loop(websocket)
except (websockets.exceptions.ConnectionClosed, ConnectionRefusedError) as e:
except (
websockets.exceptions.ConnectionClosed,
ConnectionRefusedError,
) as e:
print(f" 连接断开或服务器拒绝访问: {e}")
except Exception as e:
print(f" 运行异常: {e}")
traceback.print_exc()
print(f" {self.reconnect_interval}秒后尝试重连...")
await asyncio.sleep(self.reconnect_interval)
@@ -45,20 +54,20 @@ class WS:
async for message in websocket:
try:
data = json.loads(message)
# 1. 处理 API 响应
echo_id = data.get("echo")
if echo_id and echo_id in self._pending_requests:
future = self._pending_requests.pop(echo_id)
if not future.done():
future.set_result(data)
continue
continue
# 2. 处理上报事件
if "post_type" in data:
# 使用 create_task 异步执行,避免阻塞
asyncio.create_task(self.on_event(data))
except Exception as e:
print(f" 解析消息异常: {e}")
@@ -67,7 +76,7 @@ class WS:
try:
# 解析为 Event 对象
event = Event.from_dict(raw_data)
# 格式化时间用于打印
t = datetime.fromtimestamp(event.time).strftime("%H:%M:%S")
@@ -75,12 +84,16 @@ class WS:
# A. 消息事件 (Message)
if event.post_type == "message":
print(f" [{t}] [消息] {event.message_type} | {event.user_id}: {event.raw_message}")
print(
f" [{t}] [消息] {event.message_type} | {event.user_id}: {event.raw_message}"
)
await matcher.handle_message(self, event)
# B. 通知事件 (Notice)
elif event.post_type == "notice":
print(f" [{t}] [通知] {event.notice_type} | 来自: {event.group_id or '私聊'}")
print(
f" [{t}] [通知] {event.notice_type} | 来自: {event.group_id or '私聊'}"
)
await matcher.handle_notice(self, event)
# C. 请求事件 (Request)
@@ -100,23 +113,23 @@ class WS:
"""调用 OneBot API"""
if not self.ws:
return {"status": "failed", "msg": "websocket not initialized"}
from websockets.protocol import State
if getattr(self.ws, "state", None) is not State.OPEN:
return {"status": "failed", "msg": "websocket is not open"}
return {"status": "failed", "msg": "websocket is not open"}
echo_id = str(uuid.uuid4())
payload = {"action": action, "params": params or {}, "echo": echo_id}
loop = asyncio.get_running_loop()
future = loop.create_future()
self._pending_requests[echo_id] = future
await self.ws.send(json.dumps(payload))
try:
return await asyncio.wait_for(future, timeout=30.0)
except asyncio.TimeoutError:
self._pending_requests.pop(echo_id, None)
return {"status": "failed", "retcode": -1, "msg": "api timeout"}
return {"status": "failed", "retcode": -1, "msg": "api timeout"}