""" WebSocket 连接池模块 该模块实现了 WebSocket 连接池功能,用于管理多个 WebSocket 连接, 提高并发处理能力和连接复用效率。 """ import asyncio import websockets from websockets.legacy.client import WebSocketClientProtocol from typing import Optional, Dict, Any, cast, Union, AsyncGenerator import uuid from loguru import logger import contextlib from .config_loader import global_config from .utils.exceptions import WebSocketError, WebSocketConnectionError class WSConnection: """ WebSocket 连接包装类 封装单个 WebSocket 连接的状态和操作 """ def __init__(self, conn: WebSocketClientProtocol, conn_id: str): self.conn = conn self.conn_id = conn_id self.last_used = asyncio.get_event_loop().time() self.is_active = True self._pending_requests: Dict[str, asyncio.Future] = {} async def send(self, data: Union[Dict[Any, Any], bytes]): """ 发送数据到 WebSocket 连接 """ if not self.is_active: raise WebSocketError(f"连接 {self.conn_id} 已关闭") try: await self.conn.send(data) self.last_used = asyncio.get_event_loop().time() except Exception as e: self.is_active = False raise WebSocketError(f"发送数据失败: {e}") async def recv(self): """ 从 WebSocket 连接接收数据 """ if not self.is_active: raise WebSocketError(f"连接 {self.conn_id} 已关闭") try: data = await self.conn.recv() self.last_used = asyncio.get_event_loop().time() return data except Exception as e: self.is_active = False raise WebSocketError(f"接收数据失败: {e}") async def ping(self, timeout: int = 5) -> bool: """ 对 WebSocket 连接执行 ping-pong 健康检查 """ if not self.is_active: return False try: # 使用 wait_for 包装 ping pong_waiter = await self.conn.ping() await asyncio.wait_for(pong_waiter, timeout=timeout) return True except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed, Exception): self.is_active = False return False async def close(self): """ 关闭 WebSocket 连接 """ if self.is_active: self.is_active = False try: await self.conn.close() except Exception: pass class WSConnectionPool: """ WebSocket 连接池 管理多个 WebSocket 连接,提供连接的获取、释放和回收功能 """ def __init__(self, pool_size: int = 3, max_idle_time: int = 300): """ 初始化连接池 :param pool_size: 连接池大小 :param max_idle_time: 连接最大空闲时间(秒) """ self.pool_size = pool_size self.max_idle_time = max_idle_time self.pool: asyncio.Queue[WSConnection] = asyncio.Queue(maxsize=pool_size) self._closed = False self._cleanup_task: Optional[asyncio.Task] = None self._current_size = 0 # 当前管理的连接数(包括池中和借出的) self._lock = asyncio.Lock() # 用于保护 _current_size 的修改 # 从全局配置读取参数 self.url = global_config.napcat_ws.uri self.token = global_config.napcat_ws.token self.reconnect_interval = global_config.napcat_ws.reconnect_interval logger.info(f"WebSocket 连接池初始化完成,大小: {pool_size}") async def initialize(self): """ 初始化连接池,创建初始连接 """ if self._closed: raise WebSocketError("连接池已关闭") # 启动连接清理任务 self._cleanup_task = asyncio.create_task(self._cleanup_idle_connections()) # 预热连接池 for _ in range(self.pool_size): try: conn = await self._create_connection() await self.pool.put(conn) async with self._lock: self._current_size += 1 logger.info(f"WebSocket 连接 {conn.conn_id} 已创建并加入连接池") except Exception as e: logger.error(f"创建初始连接失败: {e}") # 初始连接失败不抛出异常,允许后续动态创建 async def _create_connection(self) -> WSConnection: """ 创建新的 WebSocket 连接 """ headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} try: conn_id = str(uuid.uuid4()) websocket_raw = await websockets.connect( self.url, additional_headers=headers ) websocket = cast(WebSocketClientProtocol, websocket_raw) conn = WSConnection(websocket, conn_id) logger.info(f"WebSocket 连接 {conn_id} 已建立") return conn except Exception as e: raise WebSocketConnectionError(f"创建 WebSocket 连接失败: {e}") @contextlib.asynccontextmanager async def connection(self) -> AsyncGenerator[WSConnection, None]: """ 获取连接的上下文管理器 """ conn = await self.get_connection() try: yield conn finally: await self.release_connection(conn) async def get_connection(self) -> WSConnection: """ 从连接池获取一个健康的连接,包含健康检查。 """ if self._closed: raise WebSocketError("连接池已关闭") start_time = asyncio.get_event_loop().time() timeout = 10 # 获取连接的总超时时间 while True: if asyncio.get_event_loop().time() - start_time > timeout: raise WebSocketError("获取连接超时") try: # 1. 尝试从池中获取 conn = self.pool.get_nowait() # 健康检查 if await conn.ping(): logger.debug(f"连接 {conn.conn_id} 健康检查通过") return conn else: logger.warning(f"连接 {conn.conn_id} 健康检查失败,丢弃") await conn.close() async with self._lock: self._current_size -= 1 # 继续循环,尝试获取下一个或创建新的 continue except asyncio.QueueEmpty: # 池为空,检查是否可以创建新连接 async with self._lock: if self._current_size < self.pool_size: # 有配额,创建新连接 self._current_size += 1 # 先占位 create_new = True else: create_new = False if create_new: try: conn = await self._create_connection() return conn except Exception as e: async with self._lock: self._current_size -= 1 # 回滚占位 logger.error(f"创建新连接失败: {e}") await asyncio.sleep(1) # 避免快速失败循环 continue else: # 没有配额,等待池中有可用连接 try: conn = await asyncio.wait_for(self.pool.get(), timeout=1.0) # 获取到了,进行健康检查(在下一次循环中处理,或者这里直接处理) # 为了代码复用,我们把 conn 放回去(或者直接用),这里直接用 if await conn.ping(): return conn else: await conn.close() async with self._lock: self._current_size -= 1 continue except asyncio.TimeoutError: continue async def release_connection(self, conn: WSConnection): """ 释放连接回连接池 """ if self._closed: await conn.close() return if not conn.is_active: logger.warning(f"连接 {conn.conn_id} 已失效,不返回连接池") await conn.close() async with self._lock: self._current_size -= 1 return try: # 尝试放回池中 self.pool.put_nowait(conn) logger.debug(f"连接 {conn.conn_id} 已返回连接池") except asyncio.QueueFull: # 理论上不应该发生,除非 _current_size 逻辑有误 logger.warning(f"连接池已满,关闭多余连接 {conn.conn_id}") await conn.close() async with self._lock: self._current_size -= 1 except Exception as e: logger.error(f"释放连接失败: {e}") await conn.close() async with self._lock: self._current_size -= 1 async def _cleanup_idle_connections(self): """ 清理空闲连接任务 """ while not self._closed: await asyncio.sleep(60) # 每分钟检查一次 try: # 我们不替换队列,而是取出检查再放回 # 这样比较安全,但可能会暂时清空池子 # 更好的做法是只检查队头的连接 # 获取当前队列大小 qsize = self.pool.qsize() for _ in range(qsize): try: conn = self.pool.get_nowait() except asyncio.QueueEmpty: break current_time = asyncio.get_event_loop().time() if current_time - conn.last_used > self.max_idle_time: logger.info(f"清理空闲连接 {conn.conn_id}") await conn.close() async with self._lock: self._current_size -= 1 else: # 还没过期,放回去 try: self.pool.put_nowait(conn) except asyncio.QueueFull: # 竞争条件下可能满了 await conn.close() async with self._lock: self._current_size -= 1 except Exception as e: logger.error(f"清理空闲连接失败: {e}") async def close(self): """ 关闭连接池 """ if self._closed: return self._closed = True # 停止清理任务 if self._cleanup_task: self._cleanup_task.cancel() try: await self._cleanup_task except asyncio.CancelledError: pass # 关闭所有连接 while not self.pool.empty(): try: conn = self.pool.get_nowait() await conn.close() except asyncio.QueueEmpty: break logger.info("WebSocket 连接池已关闭")