diff --git a/adapters/discord_adapter.py b/adapters/discord_adapter.py
index 155702b..c4a72de 100644
--- a/adapters/discord_adapter.py
+++ b/adapters/discord_adapter.py
@@ -112,7 +112,17 @@ class DiscordAdapter(discord.Client if DISCORD_AVAILABLE else object):
                     try:
                         data = json.loads(message["data"])
                         if data.get("type") == "send_message":
-                            await self.handle_send_message(data)
+                            # Handle the message in a background task so the subscription
+                            # loop is not blocked. The event loop keeps only weak
+                            # references to tasks, so hold a strong one until it is done.
+                            # NOTE(review): errors raised inside the task no longer reach
+                            # the except clauses below.
+                            send_task = asyncio.create_task(self.handle_send_message(data))
+                            pending_tasks = getattr(self, "_pending_send_tasks", None)
+                            if pending_tasks is None:
+                                pending_tasks = self._pending_send_tasks = set()
+                            pending_tasks.add(send_task)
+                            send_task.add_done_callback(pending_tasks.discard)
                     except json.JSONDecodeError as e:
                         self.logger.error(f"[DiscordAdapter] 解析 Redis 消息失败: {e}")
                     except Exception as e:
diff --git a/adapters/router.py b/adapters/router.py
index e1c97ef..372540e 100644
--- a/adapters/router.py
+++ b/adapters/router.py
@@ -356,7 +356,8 @@ class DiscordToOneBotConverter:
 
         # 注入 Discord 特定信息(用于跨平台插件识别)
         discord_channel_id = discord_message.channel.id if not isinstance(discord_message.channel, discord.DMChannel) else None
-        discord_username = discord_message.author.name
+        # Prefer global_name (display name / nickname) when it exists, otherwise fall back to name (username)
+        discord_username = getattr(discord_message.author, 'global_name', None) or discord_message.author.name
         discord_discriminator = f"#{discord_message.author.discriminator}" if discord_message.author.discriminator != "0" else ""
 
         if is_private:
diff --git a/core/managers/__init__.py b/core/managers/__init__.py
index cdda6aa..4e88f1a 100644
--- a/core/managers/__init__.py
+++ b/core/managers/__init__.py
@@ -13,6 +13,7 @@ from .browser_manager import BrowserManager
 from .image_manager import ImageManager
 from .reverse_ws_manager import ReverseWSManager
 from .thread_manager import thread_manager
+from .vectordb_manager import vectordb_manager
 
 # --- 实例化所有单例管理器 ---
 
@@ -55,4 +56,5 @@ __all__ = [
     "image_manager",
     "reverse_ws_manager",
     "thread_manager",
+    "vectordb_manager",
 ]
diff --git a/core/managers/vectordb_manager.py b/core/managers/vectordb_manager.py
new file mode 100644
index 0000000..3be5eb7
--- /dev/null
+++ b/core/managers/vectordb_manager.py
@@ -0,0 +1,191 @@
+# -*- coding: utf-8 -*-
+"""
+Vector database manager module.
+
+Provides a ChromaDB-backed manager for storing and retrieving text
+embeddings, giving large language models a memory capability.
+"""
+import os
+import json
+from typing import List, Dict, Any, Optional
+
+from core.utils.logger import ModuleLogger
+from core.utils.singleton import Singleton
+
+# chromadb is treated as an optional dependency: core/managers/__init__.py
+# imports this module, so a hard ImportError here would make the whole
+# managers package unimportable. Without chromadb every operation fails soft.
+try:
+    import chromadb
+    from chromadb.config import Settings
+    CHROMADB_AVAILABLE = True
+except ImportError:
+    chromadb = None
+    Settings = None
+    CHROMADB_AVAILABLE = False
+
+logger = ModuleLogger("VectorDBManager")
+
+
+def _empty_query_result() -> Dict[str, Any]:
+    """Return a fresh empty result for queries that could not be executed."""
+    return {"documents": [], "metadatas": [], "distances": []}
+
+
+class VectorDBManager(Singleton):
+    """
+    Vector database manager (singleton).
+
+    Every public method fails soft: errors are logged and reported through
+    the return value instead of being raised.
+    """
+    _client = None
+    # Collection handles cached by name; class-level, so shared by design.
+    _collections = {}
+
+    def __init__(self):
+        super().__init__()
+        # <three directory levels above this file>/data/vectordb
+        self.db_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "vectordb")
+
+    def initialize(self):
+        """Create the persistent ChromaDB client if it does not exist yet.
+
+        The storage directory is created here rather than in ``__init__`` so
+        that importing the module has no filesystem side effects. Failures
+        are logged and leave the client unset.
+        """
+        if self._client is not None:
+            return
+        if not CHROMADB_AVAILABLE:
+            logger.error("向量数据库初始化失败: chromadb 未安装")
+            return
+        try:
+            os.makedirs(self.db_path, exist_ok=True)
+            logger.info(f"正在初始化向量数据库,路径: {self.db_path}")
+            self._client = chromadb.PersistentClient(
+                path=self.db_path,
+                settings=Settings(
+                    anonymized_telemetry=False,
+                    allow_reset=True
+                )
+            )
+            logger.success("向量数据库初始化成功!")
+        except Exception as e:
+            logger.error(f"向量数据库初始化失败: {e}")
+            self._client = None
+
+    def get_collection(self, name: str):
+        """Return the collection called *name*, creating it on first use.
+
+        Returns:
+            The collection handle, or ``None`` when the client is unavailable
+            or the collection could not be obtained.
+        """
+        if self._client is None:
+            self.initialize()
+
+        if self._client is None:
+            return None
+
+        if name not in self._collections:
+            try:
+                # No embedding function is passed, so ChromaDB's default applies.
+                self._collections[name] = self._client.get_or_create_collection(name=name)
+                logger.debug(f"已获取/创建向量集合: {name}")
+            except Exception as e:
+                logger.error(f"获取向量集合 {name} 失败: {e}")
+                return None
+
+        return self._collections[name]
+
+    def add_texts(self, collection_name: str, texts: List[str], metadatas: List[Dict[str, Any]], ids: List[str]) -> bool:
+        """
+        Add texts to a collection.
+
+        Args:
+            collection_name: Name of the target collection.
+            texts: Documents to embed and store.
+            metadatas: One metadata dict per text (for filtering / extra data).
+            ids: One unique ID per text.
+
+        Returns:
+            True when the records were added, False otherwise.
+        """
+        collection = self.get_collection(collection_name)
+        if collection is None:
+            return False
+
+        try:
+            collection.add(
+                documents=texts,
+                metadatas=metadatas,
+                ids=ids
+            )
+            logger.debug(f"成功向集合 {collection_name} 添加 {len(texts)} 条记录")
+            return True
+        except Exception as e:
+            logger.error(f"向集合 {collection_name} 添加记录失败: {e}")
+            return False
+
+    def query_texts(self, collection_name: str, query_texts: List[str], n_results: int = 5, where: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+        """
+        Query a collection for texts similar to the given queries.
+
+        Args:
+            collection_name: Name of the collection to search.
+            query_texts: Query strings.
+            n_results: Maximum number of results per query.
+            where: Optional metadata filter.
+
+        Returns:
+            The ChromaDB query result, or a dict with empty "documents",
+            "metadatas" and "distances" lists when the query cannot run.
+        """
+        # Nothing to search for: skip the round trip (and the error log).
+        if not query_texts:
+            return _empty_query_result()
+
+        collection = self.get_collection(collection_name)
+        if collection is None:
+            return _empty_query_result()
+
+        try:
+            return collection.query(
+                query_texts=query_texts,
+                n_results=n_results,
+                where=where
+            )
+        except Exception as e:
+            logger.error(f"查询集合 {collection_name} 失败: {e}")
+            return _empty_query_result()
+
+    def delete_texts(self, collection_name: str, ids: Optional[List[str]] = None, where: Optional[Dict[str, Any]] = None) -> bool:
+        """
+        Delete records from a collection by ID and/or metadata filter.
+
+        At least one of *ids* / *where* must be given: an unfiltered delete is
+        refused so that a missing argument can never wipe a whole collection.
+
+        Returns:
+            True when the delete call succeeded, False otherwise.
+        """
+        if not ids and not where:
+            logger.error(f"从集合 {collection_name} 删除记录失败: 必须提供 ids 或 where")
+            return False
+
+        collection = self.get_collection(collection_name)
+        if collection is None:
+            return False
+
+        try:
+            collection.delete(ids=ids, where=where)
+            logger.debug(f"成功从集合 {collection_name} 删除记录")
+            return True
+        except Exception as e:
+            logger.error(f"从集合 {collection_name} 删除记录失败: {e}")
+            return False
+
+
+# Global vector database manager instance
+vectordb_manager = VectorDBManager()
diff --git a/data/vectordb/chroma.sqlite3 b/data/vectordb/chroma.sqlite3
new file mode 100644
index 0000000..c0ab1dd
Binary files /dev/null and b/data/vectordb/chroma.sqlite3 differ
diff --git a/main.py b/main.py
index a6793eb..e2a8433 100644
--- a/main.py
+++ b/main.py
@@ -111,6 +111,10 @@ async def main():
     2. 初始化 WebSocket 客户端
     3. 
建立连接并保持运行
     """
+    # Initialize the vector database
+    from core.managers.vectordb_manager import vectordb_manager
+    vectordb_manager.initialize()
+
     # 首先加载所有插件
     plugin_manager.load_all_plugins()
 
diff --git a/plugins/ai_chat.py b/plugins/ai_chat.py
new file mode 100644
index 0000000..1e94bcf
--- /dev/null
+++ b/plugins/ai_chat.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+"""
+AI chat plugin with vector-database-backed memory.
+"""
+import asyncio
+import time
+import uuid
+from core.managers.command_manager import matcher
+from models.events.message import GroupMessageEvent, PrivateMessageEvent
+from core.managers.vectordb_manager import vectordb_manager
+from core.utils.logger import ModuleLogger
+from core.config_loader import global_config
+
+logger = ModuleLogger("AIChat")
+
+# Try to import the OpenAI client (optional dependency)
+try:
+    from openai import AsyncOpenAI
+    OPENAI_AVAILABLE = True
+except ImportError:
+    OPENAI_AVAILABLE = False
+
+
+async def _recall_memories(collection_name: str, user_message: str) -> str:
+    """Return a prompt fragment listing stored exchanges similar to *user_message*.
+
+    Best effort: returns an empty string when nothing matches or retrieval fails.
+    """
+    try:
+        # ChromaDB embeds the query synchronously; use a worker thread so the
+        # event loop keeps serving other events in the meantime.
+        results = await asyncio.to_thread(
+            vectordb_manager.query_texts,
+            collection_name=collection_name,
+            query_texts=[user_message],
+            n_results=3
+        )
+        if results and results.get("documents") and results["documents"][0]:
+            recalled = "".join(
+                f"{i}. {doc}\n" for i, doc in enumerate(results["documents"][0], 1)
+            )
+            return "\n\n相关历史记忆:\n" + recalled
+    except Exception as e:
+        logger.error(f"检索聊天记忆失败: {e}")
+    return ""
+
+
+async def _remember_exchange(collection_name: str, user_id: int, group_id: int, user_message: str, ai_reply: str) -> None:
+    """Store one user/AI exchange in the vector database.
+
+    Best effort: failures are logged and never propagated to the caller.
+    """
+    try:
+        metadata = {
+            "user_id": user_id,
+            "group_id": group_id,
+            "timestamp": int(time.time())
+        }
+        # Embedding and writing are synchronous as well, so offload them too.
+        await asyncio.to_thread(
+            vectordb_manager.add_texts,
+            collection_name=collection_name,
+            texts=[f"用户: {user_message}\nAI: {ai_reply}"],
+            metadatas=[metadata],
+            ids=[str(uuid.uuid4())]
+        )
+    except Exception as e:
+        logger.error(f"保存聊天记忆失败: {e}")
+
+
+async def get_ai_response(user_id: int, group_id: int, user_message: str) -> str:
+    """Return the AI reply to *user_message*, enriched with vector-database memory.
+
+    Args:
+        user_id: ID of the requesting user; selects the memory collection.
+        group_id: ID of the originating group, stored as memory metadata.
+        user_message: The text to answer.
+
+    Returns:
+        The reply text, or a user-facing error message when the request
+        cannot be made or fails.
+    """
+    if not OPENAI_AVAILABLE:
+        return "请先安装 openai 库: pip install openai"
+
+    # DeepSeek API settings are read from the cross-platform plugin's config section
+    api_key = getattr(global_config.cross_platform, 'deepseek_api_key', None) or "your-api-key"
+    api_url = getattr(global_config.cross_platform, 'deepseek_api_url', "https://api.deepseek.com/v1")
+    model = getattr(global_config.cross_platform, 'deepseek_model', "deepseek-chat")
+
+    if api_key == "your-api-key":
+        return "请先在配置中设置 DeepSeek API Key"
+
+    # 1. Retrieve related memories from the vector database
+    collection_name = f"chat_memory_{user_id}"
+    memory_context = await _recall_memories(collection_name, user_message)
+
+    # 2. Build the prompt
+    system_prompt = f"""你是一个友好的 AI 助手。请根据用户的输入进行回复。
+如果提供了相关历史记忆,请参考这些记忆来保持对话的连贯性。{memory_context}"""
+
+    try:
+        # The context manager closes the client's HTTP connections once the
+        # request is done instead of leaking one client per command.
+        async with AsyncOpenAI(
+            api_key=api_key,
+            base_url=api_url.replace("/chat/completions", "")
+        ) as client:
+            response = await client.chat.completions.create(
+                model=model,
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_message}
+                ],
+                temperature=0.7,
+                max_tokens=1000
+            )
+        ai_reply = response.choices[0].message.content
+    except Exception as e:
+        logger.error(f"AI 聊天请求失败: {e}")
+        return f"请求失败: {str(e)}"
+
+    # The API may return no content; never hand None to the caller.
+    if not ai_reply:
+        return "AI 未返回任何内容"
+
+    # 3. Store this exchange in the vector database
+    await _remember_exchange(collection_name, user_id, group_id, user_message, ai_reply)
+    return ai_reply
+
+
+@matcher.command("chat", "聊天")
+async def chat_command(event: GroupMessageEvent | PrivateMessageEvent, args: list[str]):
+    """Handle the chat command: send the arguments to the AI and reply with its answer."""
+    if not args:
+        await event.reply("请提供要聊天的内容,例如:/chat 你好")
+        return
+
+    user_message = " ".join(args)
+    user_id = event.user_id
+    # Events without a group_id attribute fall back to 0.
+    group_id = getattr(event, 'group_id', 0)
+
+    await event.reply("正在思考中...")
+    reply = await get_ai_response(user_id, group_id, user_message)
+    await event.reply(reply)
diff --git a/plugins/discord-cross/handlers.py b/plugins/discord-cross/handlers.py
index 1e13b51..bc92f9a 100644
--- a/plugins/discord-cross/handlers.py
+++ b/plugins/discord-cross/handlers.py
@@ -148,7 +148,7 @@ async def handle_qq_group_message(event: GroupMessageEvent):
         group_name = f"群{group_id}"
 
     await handle_qq_message(
-        nickname=event.sender.nickname or event.sender.card or str(event.user_id),
+        nickname=event.sender.card or event.sender.nickname or str(event.user_id),
         user_id=event.user_id,
         group_name=group_name,
         group_id=group_id,
diff --git a/plugins/discord-cross/translator.py b/plugins/discord-cross/translator.py
index 472c7bf..8b9cf55 100644
--- a/plugins/discord-cross/translator.py
+++ b/plugins/discord-cross/translator.py
@@ -2,8 +2,12 @@
 """
 跨平台消息互通插件翻译模块
 """
+import asyncio
+import time
+import uuid
 from typing import Dict, List
 from core.utils.logger import ModuleLogger
+from core.managers.vectordb_manager import vectordb_manager
 from .config import config
 
 # 创建模块专用日志记录器
@@ -19,7 +23,7 @@ def get_translation_context(channel_id: int, direction: str) -> List[Dict[str, s
     return TRANSLATION_CONTEXT_CACHE.get(cache_key, [])
 
 def add_translation_context(channel_id: int, direction: str, original: str, translated: str):
-    """添加翻译到上下文缓存"""
+    """Add a translation to the context cache and to the vector database."""
     cache_key = f"{channel_id}_{direction}"
     if cache_key not in TRANSLATION_CONTEXT_CACHE:
         TRANSLATION_CONTEXT_CACHE[cache_key] = []
@@ -31,6 +35,81 @@ def add_translation_context(channel_id: int, direction: str, original: str, tran
 
     if len(TRANSLATION_CONTEXT_CACHE[cache_key]) > MAX_CONTEXT_MESSAGES:
         TRANSLATION_CONTEXT_CACHE[cache_key] = TRANSLATION_CONTEXT_CACHE[cache_key][-MAX_CONTEXT_MESSAGES:]
+
+    # Persist the translation in the vector database too. Embedding is
+    # synchronous, so when an event loop is running the write is handed to
+    # the default executor instead of blocking the loop.
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        _store_translation_memory(channel_id, direction, original, translated)
+    else:
+        loop.run_in_executor(None, _store_translation_memory, channel_id, direction, original, translated)
+
+def _store_translation_memory(channel_id: int, direction: str, original: str, translated: str) -> None:
+    """Save one translation pair to the channel's vector collection (best effort)."""
+    try:
+        collection_name = f"translation_memory_{channel_id}"
+        metadata = {
+            "channel_id": channel_id,
+            "direction": direction,
+            "original": original,
+            "translated": translated,
+            "timestamp": int(time.time())
+        }
+
+        saved = vectordb_manager.add_texts(
+            collection_name=collection_name,
+            # The original and its translation are embedded as one document.
+            texts=[f"原文: {original}\n译文: {translated}"],
+            metadatas=[metadata],
+            ids=[str(uuid.uuid4())]
+        )
+        # add_texts reports failure via its return value; only log success then.
+        if saved:
+            logger.debug(f"[CrossPlatform] 翻译记录已保存到向量数据库: {collection_name}")
+    except Exception as e:
+        logger.error(f"[CrossPlatform] 保存翻译记录到向量数据库失败: {e}")
+
+def get_similar_translations(channel_id: int, text: str, direction: str, limit: int = 3) -> str:
+    """Return a prompt fragment with stored translations similar to *text*.
+
+    Args:
+        channel_id: Channel whose translation memory is searched.
+        text: The text about to be translated.
+        direction: Translation direction; only records stored with it match.
+        limit: Maximum number of records to retrieve.
+
+    Returns:
+        The formatted reference block, or "" when nothing matches or
+        retrieval fails.
+    """
+    try:
+        results = vectordb_manager.query_texts(
+            collection_name=f"translation_memory_{channel_id}",
+            query_texts=[text],
+            n_results=limit,
+            where={"direction": direction}
+        )
+
+        if not results or not results.get("documents") or not results["documents"][0]:
+            return ""
+
+        # "metadatas" may be missing and individual entries may be None.
+        metadatas = (results.get("metadatas") or [[]])[0] or []
+        entries = []
+        for i, metadata in enumerate(metadatas, 1):
+            metadata = metadata or {}
+            original = metadata.get("original", "")
+            translated = metadata.get("translated", "")
+            entries.append(f"{i}. 原文: {original[:100]}\n 译文: {translated[:100]}\n")
+
+        if not entries:
+            return ""
+        return "\n\n参考历史相似翻译(向量检索):\n" + "".join(entries)
+    except Exception as e:
+        logger.error(f"[CrossPlatform] 从向量数据库检索翻译记录失败: {e}")
+        return ""
 
 async def translate_with_deepseek(
     text: str,
@@ -51,11 +130,17 @@ async def translate_with_deepseek(
     messages = []
     context_ref = ""
     if channel_id > 0:
+        # 1. Fetch the most recent entries from the context cache
         context = get_translation_context(channel_id, direction)
         if context:
-            context_ref = "\n\n参考之前的翻译:\n"
+            context_ref = "\n\n参考最近的翻译:\n"
             for i, ctx in enumerate(context[-5:], 1):
                 context_ref += f"{i}. 原文: {ctx['original'][:100]}\n 译文: {ctx['translated'][:100]}\n"
+
+        # 2. Retrieve similar past translations from the vector database (in a worker thread)
+        similar_context = await asyncio.to_thread(get_similar_translations, channel_id, text, direction)
+        if similar_context:
+            context_ref += similar_context
 
     system_prompt = f"""你是一个专业的翻译助手。请将以下文本翻译成{lang_name}。
 只返回翻译后的文本,不要添加任何解释、注释或其他内容。避免翻译出仇视言论以及违反中国大陆相关法律法规的内容。如果有,请在翻译后有敏感的词语中把文本替换成井号(#)
@@ -115,11 +200,17 @@ async def translate_with_deepseek_sync(
 
     context_ref = ""
     if channel_id > 0:
+        # 1. Fetch the most recent entries from the context cache
         context = get_translation_context(channel_id, direction)
         if context:
-            context_ref = "\n\n参考之前的翻译:\n"
+            context_ref = "\n\n参考最近的翻译:\n"
             for i, ctx in enumerate(context[-5:], 1):
                 context_ref += f"{i}. 原文: {ctx['original'][:100]}\n 译文: {ctx['translated'][:100]}\n"
+
+        # 2. Retrieve similar past translations from the vector database (in a worker thread)
+        similar_context = await asyncio.to_thread(get_similar_translations, channel_id, text, direction)
+        if similar_context:
+            context_ref += similar_context
 
     system_prompt = f"""你是一个专业的翻译助手。请将以下文本翻译成{lang_name}。
 只返回翻译后的文本,不要添加任何解释、注释或其他内容。避免翻译出仇视言论以及违反中国大陆相关法律法规的内容。如果有,请在翻译后有敏感的词语中把文本替换成井号(#)
diff --git a/plugins/knowledge_base.py b/plugins/knowledge_base.py
new file mode 100644
index 0000000..88dd5e2
--- /dev/null
+++ b/plugins/knowledge_base.py
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+"""
+Group knowledge base plugin backed by vector-database retrieval.
+"""
+import asyncio
+import time
+import uuid
+from core.managers.command_manager import matcher
+from models.events.message import GroupMessageEvent
+from core.managers.vectordb_manager import vectordb_manager
+from core.utils.logger import ModuleLogger
+from core.permission import Permission
+
+logger = ModuleLogger("GroupKnowledgeBase")
+
+@matcher.command("kb_add", "添加知识库", permission=Permission.ADMIN)
+async def kb_add_command(event: GroupMessageEvent, args: list[str]):
+    """Add a question/answer entry to the group's knowledge base.
+
+    The first argument is the question; the remaining arguments, joined by
+    spaces, form the answer.
+    """
+    if len(args) < 2:
+        await event.reply("用法: /kb_add <问题> <答案>")
+        return
+
+    question = args[0]
+    answer = " ".join(args[1:])
+    group_id = event.group_id
+
+    try:
+        metadata = {
+            "group_id": group_id,
+            "question": question,
+            "answer": answer,
+            "added_by": event.user_id,
+            "timestamp": int(time.time())
+        }
+
+        # Embedding is synchronous; run it in a worker thread so the event
+        # loop is not blocked while the entry is stored.
+        success = await asyncio.to_thread(
+            vectordb_manager.add_texts,
+            collection_name=f"knowledge_base_{group_id}",
+            texts=[f"问题: {question}\n答案: {answer}"],
+            metadatas=[metadata],
+            ids=[str(uuid.uuid4())]
+        )
+
+        if success:
+            await event.reply(f"知识库条目添加成功!\n问题: {question}")
+        else:
+            await event.reply("知识库条目添加失败,请查看日志。")
+    except Exception as e:
+        logger.error(f"添加知识库失败: {e}")
+        await event.reply(f"添加失败: {str(e)}")
+
+@matcher.command("kb_search", "搜索知识库")
+async def kb_search_command(event: GroupMessageEvent, args: list[str]):
+    """Search the group's knowledge base and reply with the closest entries."""
+    if not args:
+        await event.reply("用法: /kb_search <关键词>")
+        return
+
+    query = " ".join(args)
+    group_id = event.group_id
+
+    try:
+        results = await asyncio.to_thread(
+            vectordb_manager.query_texts,
+            collection_name=f"knowledge_base_{group_id}",
+            query_texts=[query],
+            n_results=3
+        )
+
+        if not results or not results.get("documents") or not results["documents"][0]:
+            await event.reply("未找到相关的知识库条目。")
+            return
+
+        # "metadatas" may be missing and individual entries may be None.
+        metadatas = (results.get("metadatas") or [[]])[0] or []
+        entries = []
+        for i, metadata in enumerate(metadatas, 1):
+            metadata = metadata or {}
+            question = metadata.get("question", "")
+            answer = metadata.get("answer", "")
+            entries.append(f"\n{i}. Q: {question}\n A: {answer}")
+
+        await event.reply("为您找到以下相关知识:\n" + "".join(entries))
+    except Exception as e:
+        logger.error(f"搜索知识库失败: {e}")
+        await event.reply(f"搜索失败: {str(e)}")