# -*- coding: utf-8 -*- """ 本地文件下载服务 该模块提供一个本地 HTTP 服务,用于下载远程文件到本地并提供本地访问。 主要解决 NapCat 等第三方服务无法直接访问某些远程资源(如 B 站防盗链)的问题。 """ import asyncio import os import tempfile import hashlib from pathlib import Path from typing import Optional, Dict from urllib.parse import urlparse import aiohttp from aiohttp import web import urllib.request from core.utils.logger import logger from core.config_loader import global_config class LocalFileServer: """ 本地文件下载服务 提供一个本地 HTTP 服务,用于下载远程文件到本地并提供本地访问。 """ def __init__(self, host: str = "0.0.0.0", port: int = 3003): """ 初始化本地文件下载服务 Args: host (str): 服务监听地址 port (int): 服务监听端口 """ self.host = host self.port = port self.app = web.Application() self.runner = None self.site = None self.download_dir = Path(tempfile.gettempdir()) / "neobot_downloads" self.download_dir.mkdir(parents=True, exist_ok=True) # 注册路由 self.app.router.add_get('/download', self.handle_download) self.app.router.add_get('/health', self.handle_health) # 文件映射表:file_id -> file_path self.file_map: Dict[str, Path] = {} logger.success(f"[LocalFileServer] 初始化完成: {self.host}:{self.port}") async def start(self): """启动服务""" self.runner = web.AppRunner(self.app) await self.runner.setup() self.site = web.TCPSite(self.runner, self.host, self.port) await self.site.start() logger.success(f"[LocalFileServer] 服务已启动: http://{self.host}:{self.port}") async def stop(self): """停止服务""" if self.runner: await self.runner.cleanup() logger.info("[LocalFileServer] 服务已停止") def _generate_file_id(self, url: str) -> str: """根据 URL 生成唯一的文件 ID""" url_hash = hashlib.md5(url.encode()).hexdigest()[:16] return f"file_{url_hash}" async def download_file(self, url: str, timeout: int = 60, headers: Optional[Dict[str, str]] = None) -> Optional[str]: """ 下载远程文件到本地 Args: url (str): 远程文件 URL timeout (int): 下载超时时间(秒) headers (Optional[Dict[str, str]]): 请求头 Returns: Optional[str]: 本地文件 ID,如果失败则返回 None """ try: file_id = self._generate_file_id(url) file_path = self.download_dir / f"{file_id}" # 检查文件是否已存在 if file_path.exists(): logger.info(f"[LocalFileServer] 文件已存在: {file_id}") return file_id logger.info(f"[LocalFileServer] 开始下载: {url}") # 使用 aiohttp 下载文件 async with aiohttp.ClientSession() as session: async with session.get(url, timeout=timeout, headers=headers) as response: if response.status != 200: logger.error(f"[LocalFileServer] 下载失败: HTTP {response.status}") return None # 读取并保存文件 with open(file_path, 'wb') as f: while True: chunk = await response.content.read(8192) if not chunk: break f.write(chunk) self.file_map[file_id] = file_path logger.success(f"[LocalFileServer] 下载完成: {file_id} ({file_path.stat().st_size} bytes)") return file_id except Exception as e: logger.error(f"[LocalFileServer] 下载失败: {e}") return None async def handle_download(self, request: web.Request) -> web.Response: """处理文件下载请求""" file_id = request.query.get('id') if not file_id or file_id not in self.file_map: return web.Response( status=404, text='File not found', content_type='text/plain' ) file_path = self.file_map[file_id] if not file_path.exists(): return web.Response( status=404, text='File not found', content_type='text/plain' ) # 获取文件大小 file_size = file_path.stat().st_size # 设置响应头 headers = { 'Content-Disposition': f'attachment; filename="{file_id}"', 'Content-Length': str(file_size) } return web.FileResponse(file_path, headers=headers) async def handle_health(self, request: web.Request) -> web.Response: """健康检查""" return web.json_response({ 'status': 'ok', 'service': 'LocalFileServer', 'download_dir': str(self.download_dir), 'files_count': len(self.file_map) }) # 全局实例 _local_file_server: Optional[LocalFileServer] = None def get_local_file_server() -> Optional[LocalFileServer]: """获取全局本地文件服务器实例""" global _local_file_server if _local_file_server is None: try: server_config = global_config.local_file_server _local_file_server = LocalFileServer( host=server_config.host, port=server_config.port ) except Exception as e: logger.error(f"[LocalFileServer] 初始化失败: {e}") return None return _local_file_server async def start_local_file_server(): """启动全局本地文件服务器""" server = get_local_file_server() if server: await server.start() async def stop_local_file_server(): """停止全局本地文件服务器""" global _local_file_server if _local_file_server: await _local_file_server.stop() _local_file_server = None async def download_to_local(url: str, timeout: int = 60, headers: Optional[Dict[str, str]] = None) -> Optional[str]: """ 下载远程文件到本地并返回本地访问 URL Args: url (str): 远程文件 URL timeout (int): 下载超时时间(秒) headers (Optional[Dict[str, str]]): 请求头 Returns: Optional[str]: 本地访问 URL,如果失败则返回 None """ server = get_local_file_server() if not server: return None file_id = await server.download_file(url, timeout, headers) if not file_id: return None return f"http://127.0.0.1:{server.port}/download?id={file_id}"