Files
NeoBot/core/services/local_file_server.py
K2Cr2O1 2a6e9b8f89 feat(bili): 支持合并B站分离的音视频流并添加请求头支持
添加对B站分离音视频流的合并功能,使用ffmpeg合并m4s格式的视频和音频流
扩展download_file接口支持自定义请求头,用于B站视频下载的Referer校验
2026-03-15 01:34:00 +08:00

220 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
本地文件下载服务
该模块提供一个本地 HTTP 服务,用于下载远程文件到本地并提供本地访问。
主要解决 NapCat 等第三方服务无法直接访问某些远程资源(如 B 站防盗链)的问题。
"""
import asyncio
import os
import tempfile
import hashlib
from pathlib import Path
from typing import Optional, Dict
from urllib.parse import urlparse
import aiohttp
from aiohttp import web
import urllib.request
from core.utils.logger import logger
from core.config_loader import global_config
class LocalFileServer:
"""
本地文件下载服务
提供一个本地 HTTP 服务,用于下载远程文件到本地并提供本地访问。
"""
def __init__(self, host: str = "0.0.0.0", port: int = 3003):
"""
初始化本地文件下载服务
Args:
host (str): 服务监听地址
port (int): 服务监听端口
"""
self.host = host
self.port = port
self.app = web.Application()
self.runner = None
self.site = None
self.download_dir = Path(tempfile.gettempdir()) / "neobot_downloads"
self.download_dir.mkdir(parents=True, exist_ok=True)
# 注册路由
self.app.router.add_get('/download', self.handle_download)
self.app.router.add_get('/health', self.handle_health)
# 文件映射表file_id -> file_path
self.file_map: Dict[str, Path] = {}
logger.success(f"[LocalFileServer] 初始化完成: {self.host}:{self.port}")
async def start(self):
"""启动服务"""
self.runner = web.AppRunner(self.app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, self.host, self.port)
await self.site.start()
logger.success(f"[LocalFileServer] 服务已启动: http://{self.host}:{self.port}")
async def stop(self):
"""停止服务"""
if self.runner:
await self.runner.cleanup()
logger.info("[LocalFileServer] 服务已停止")
def _generate_file_id(self, url: str) -> str:
"""根据 URL 生成唯一的文件 ID"""
url_hash = hashlib.md5(url.encode()).hexdigest()[:16]
return f"file_{url_hash}"
async def download_file(self, url: str, timeout: int = 60, headers: Optional[Dict[str, str]] = None) -> Optional[str]:
"""
下载远程文件到本地
Args:
url (str): 远程文件 URL
timeout (int): 下载超时时间(秒)
headers (Optional[Dict[str, str]]): 请求头
Returns:
Optional[str]: 本地文件 ID如果失败则返回 None
"""
try:
file_id = self._generate_file_id(url)
file_path = self.download_dir / f"{file_id}"
# 检查文件是否已存在
if file_path.exists():
logger.info(f"[LocalFileServer] 文件已存在: {file_id}")
return file_id
logger.info(f"[LocalFileServer] 开始下载: {url}")
# 使用 aiohttp 下载文件
async with aiohttp.ClientSession() as session:
async with session.get(url, timeout=timeout, headers=headers) as response:
if response.status != 200:
logger.error(f"[LocalFileServer] 下载失败: HTTP {response.status}")
return None
# 读取并保存文件
with open(file_path, 'wb') as f:
while True:
chunk = await response.content.read(8192)
if not chunk:
break
f.write(chunk)
self.file_map[file_id] = file_path
logger.success(f"[LocalFileServer] 下载完成: {file_id} ({file_path.stat().st_size} bytes)")
return file_id
except Exception as e:
logger.error(f"[LocalFileServer] 下载失败: {e}")
return None
async def handle_download(self, request: web.Request) -> web.Response:
"""处理文件下载请求"""
file_id = request.query.get('id')
if not file_id or file_id not in self.file_map:
return web.Response(
status=404,
text='File not found',
content_type='text/plain'
)
file_path = self.file_map[file_id]
if not file_path.exists():
return web.Response(
status=404,
text='File not found',
content_type='text/plain'
)
# 获取文件大小
file_size = file_path.stat().st_size
# 设置响应头
headers = {
'Content-Disposition': f'attachment; filename="{file_id}"',
'Content-Length': str(file_size)
}
return web.FileResponse(file_path, headers=headers)
async def handle_health(self, request: web.Request) -> web.Response:
"""健康检查"""
return web.json_response({
'status': 'ok',
'service': 'LocalFileServer',
'download_dir': str(self.download_dir),
'files_count': len(self.file_map)
})
# 全局实例
_local_file_server: Optional[LocalFileServer] = None
def get_local_file_server() -> Optional[LocalFileServer]:
"""获取全局本地文件服务器实例"""
global _local_file_server
if _local_file_server is None:
try:
server_config = global_config.local_file_server
_local_file_server = LocalFileServer(
host=server_config.host,
port=server_config.port
)
except Exception as e:
logger.error(f"[LocalFileServer] 初始化失败: {e}")
return None
return _local_file_server
async def start_local_file_server():
"""启动全局本地文件服务器"""
server = get_local_file_server()
if server:
await server.start()
async def stop_local_file_server():
"""停止全局本地文件服务器"""
global _local_file_server
if _local_file_server:
await _local_file_server.stop()
_local_file_server = None
async def download_to_local(url: str, timeout: int = 60, headers: Optional[Dict[str, str]] = None) -> Optional[str]:
"""
下载远程文件到本地并返回本地访问 URL
Args:
url (str): 远程文件 URL
timeout (int): 下载超时时间(秒)
headers (Optional[Dict[str, str]]): 请求头
Returns:
Optional[str]: 本地访问 URL如果失败则返回 None
"""
server = get_local_file_server()
if not server:
return None
file_id = await server.download_file(url, timeout, headers)
if not file_id:
return None
return f"http://127.0.0.1:{server.port}/download?id={file_id}"