# Provenance: Hugging Face upload "Upload 47 files" by cytopa99, commit 68e5689 (verified).
"""
并行处理池模块
提供音频片段的并行处理功能,集成完整的处理流程:
ASR → LLM → TTS → Sync
包含性能监控和自适应并发控制。
"""
import os
import asyncio
import logging
import time
from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass
from .groq_client import GroqClient, GroqConfig, GroqError
from .tts_generator import TTSGenerator, TTSConfig, TTSError
from .audio_sync import AudioSyncEngine, SyncConfig, AudioSyncError
from .performance_monitor import (
get_performance_monitor,
AdaptiveConcurrencyController,
track_performance
)
# 配置日志
logger = logging.getLogger(__name__)
class ProcessingError(Exception):
    """Base class for all processing-pipeline exceptions."""
class SegmentProcessingError(ProcessingError):
    """Raised when one segment fails at a specific pipeline stage.

    Attributes:
        segment_index: 0-based index of the failing segment.
        stage: pipeline stage name (e.g. "ASR", "LLM", "TTS").
        reason: human-readable failure description.
    """

    def __init__(self, segment_index: int, stage: str, reason: str):
        self.segment_index = segment_index
        self.stage = stage
        self.reason = reason
        # BUG FIX: the original f-string had no separator between the index
        # and the stage name, producing messages like "片段 2ASR 阶段失败".
        # Insert a space so the message reads "片段 2 ASR 阶段失败: ...".
        self.message = f"片段 {segment_index} {stage} 阶段失败: {reason}"
        super().__init__(self.message)
@dataclass
class ProcessorConfig:
    """
    Configuration for the parallel processing pool.

    Attributes:
        max_workers: maximum number of concurrent workers, default 3.
        min_workers: minimum number of concurrent workers, default 1.
        segment_timeout: per-segment processing timeout in seconds,
            default 480 (8 minutes).
        retry_count: number of retries after a failed attempt, default 2.
        temp_dir: directory for temporary files (created on pool init).
        adaptive_concurrency: enable the adaptive concurrency controller.
    """
    max_workers: int = 3
    min_workers: int = 1
    segment_timeout: float = 480.0  # 8 minutes
    retry_count: int = 2
    temp_dir: str = "temp/processing"
    adaptive_concurrency: bool = True
@dataclass
class SegmentResult:
    """
    Result of processing one audio segment.

    Attributes:
        index: segment index (matches the input order).
        success: whether the full pipeline succeeded for this segment.
        audio_path: path of the generated (synced) audio file, on success.
        duration: segment duration in seconds.
        start_time: segment start offset in seconds.
        error: error message when the segment failed.
        processing_time: wall-clock processing time in seconds.
    """
    index: int
    success: bool
    audio_path: Optional[str] = None
    duration: Optional[float] = None
    start_time: Optional[float] = None
    error: Optional[str] = None
    processing_time: Optional[float] = None
class ParallelProcessingPool:
    """
    Parallel processing pool.

    Manages parallel processing of multiple audio segments through the full
    pipeline: ASR -> LLM translation -> TTS -> audio synchronization.

    Example:
        pool = ParallelProcessingPool()
        await pool.initialize()
        segments = [
            {"audio_path": "seg1.mp3", "start_time": 0, "duration": 300},
            {"audio_path": "seg2.mp3", "start_time": 300, "duration": 300}
        ]
        results = await pool.process_segments(
            segments,
            progress_callback=lambda msg, pct: print(f"{pct}%: {msg}")
        )
    """

    def __init__(
        self,
        config: Optional[ProcessorConfig] = None,
        groq_config: Optional[GroqConfig] = None,
        tts_config: Optional[TTSConfig] = None,
        sync_config: Optional[SyncConfig] = None
    ):
        """
        Initialize the pool.

        Args:
            config: pool configuration (defaults to ProcessorConfig()).
            groq_config: Groq client configuration (ASR + LLM).
            tts_config: TTS generator configuration.
            sync_config: audio sync engine configuration.
        """
        self.config = config or ProcessorConfig()
        # Sub-module configs are only stored here; the actual clients are
        # created lazily in initialize().
        self._groq_config = groq_config
        self._tts_config = tts_config
        self._sync_config = sync_config
        # Sub-module instances (None until initialize() runs).
        self.groq_client: Optional[GroqClient] = None
        self.tts_generator: Optional[TTSGenerator] = None
        self.audio_sync: Optional[AudioSyncEngine] = None
        self._initialized = False
        # Performance monitor (shared accessor from performance_monitor module).
        self._performance_monitor = get_performance_monitor()
        # Adaptive concurrency controller — only built when enabled in config.
        self._concurrency_controller = None
        if self.config.adaptive_concurrency:
            self._concurrency_controller = AdaptiveConcurrencyController(
                min_workers=self.config.min_workers,
                max_workers=self.config.max_workers
            )
        # Make sure the temp directory exists up front.
        os.makedirs(self.config.temp_dir, exist_ok=True)
        logger.info(
            f"并行处理池配置: 最大并发={self.config.max_workers}, "
            f"超时={self.config.segment_timeout}s, "
            f"自适应并发={'启用' if self.config.adaptive_concurrency else '禁用'}"
        )

    async def initialize(self) -> None:
        """
        Initialize all processing modules.

        Creates and initializes the Groq client, TTS generator and audio
        sync engine. Idempotent: a second call is a no-op.
        """
        if self._initialized:
            logger.debug("处理池已初始化,跳过")
            return
        logger.info("初始化并行处理池...")
        # Groq client (the only sub-module with async initialization).
        self.groq_client = GroqClient(self._groq_config)
        await self.groq_client.initialize()
        # TTS generator.
        self.tts_generator = TTSGenerator(self._tts_config)
        # Audio sync engine.
        self.audio_sync = AudioSyncEngine(self._sync_config)
        self._initialized = True
        logger.info("并行处理池初始化完成")

    def _ensure_initialized(self) -> None:
        """Raise ProcessingError unless initialize() has completed."""
        if not self._initialized:
            raise ProcessingError("处理池未初始化,请先调用 initialize()")

    async def process_segments(
        self,
        segments: List[Dict[str, Any]],
        progress_callback: Optional[Callable[[str, float], None]] = None,
        config: Optional[Dict[str, Any]] = None
    ) -> List[SegmentResult]:
        """
        Process multiple audio segments in parallel.

        Args:
            segments: list of segment dicts, each containing:
                - audio_path: str - audio file path
                - start_time: float - start offset in seconds
                - duration: float - segment length in seconds
            progress_callback: callback receiving (message, percentage).
            config: optional processing config; a 'client_config' entry is
                forwarded to the TTS and sync stages.

        Returns:
            One SegmentResult per input segment, in input order.
        """
        self._ensure_initialized()
        if not segments:
            logger.warning("处理输入为空")
            return []
        # Extract the per-client configuration, if provided.
        client_config = {}
        if config and 'client_config' in config:
            client_config = config['client_config']
            logger.info(f"处理器使用客户端配置: {list(client_config.keys())}")
        total = len(segments)
        logger.info(f"开始并行处理 {total} 个片段")
        if progress_callback:
            progress_callback("开始处理...", 0)
        # Worker count: ask the adaptive controller when enabled, otherwise
        # use the configured maximum.
        if self._concurrency_controller:
            current_workers = self._concurrency_controller.get_recommended_workers()
            logger.info(f"自适应并发: 当前推荐 {current_workers} 个工作线程")
        else:
            current_workers = self.config.max_workers
        # Build one coroutine per segment (client config passed through).
        tasks = []
        for i, segment in enumerate(segments):
            task = self._process_single_segment(
                segment,
                i,
                total,
                progress_callback,
                client_config  # forward the client configuration
            )
            tasks.append(task)
        # A semaphore caps how many segments run at the same time.
        semaphore = asyncio.Semaphore(current_workers)

        async def limited_task(task):
            async with semaphore:
                return await task

        # Track the whole batch with the performance monitor.
        with self._performance_monitor.track_operation("并行片段处理") as metrics:
            metrics.extra["total_segments"] = total
            metrics.extra["workers"] = current_workers
            # Run in parallel; exceptions are returned instead of raised.
            results = await asyncio.gather(
                *[limited_task(task) for task in tasks],
                return_exceptions=True
            )
        # Normalize: exceptions (and anything unexpected) become failed results.
        processed_results = []
        success_count = 0
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                logger.error(f"片段 {i} 处理异常: {result}")
                processed_results.append(SegmentResult(
                    index=i,
                    success=False,
                    error=str(result)
                ))
            elif isinstance(result, SegmentResult):
                processed_results.append(result)
                if result.success:
                    success_count += 1
            else:
                processed_results.append(SegmentResult(
                    index=i,
                    success=False,
                    error="未知结果类型"
                ))
        logger.info(f"并行处理完成: {success_count}/{total} 成功")
        if progress_callback:
            progress_callback("处理完成", 100)
        return processed_results

    async def _process_single_segment(
        self,
        segment: Dict[str, Any],
        index: int,
        total: int,
        progress_callback: Optional[Callable[[str, float], None]] = None,
        client_config: Optional[Dict[str, Any]] = None
    ) -> SegmentResult:
        """
        Run the full pipeline for one segment, with timeout and retries.

        Pipeline: ASR -> LLM -> TTS -> Sync.

        Args:
            segment: segment info (audio_path / start_time / duration).
            index: segment index (0-based).
            total: total number of segments.
            progress_callback: progress callback.
            client_config: client configuration (takes precedence over defaults).

        Returns:
            SegmentResult describing success or the last failure.
        """
        start_time = time.time()
        audio_path = segment.get('audio_path')
        seg_start = segment.get('start_time', 0)
        seg_duration = segment.get('duration', 0)
        logger.info(f"开始处理片段 {index + 1}/{total}")
        # Log which configuration source is in effect for this segment.
        if client_config:
            logger.info(f"片段 {index + 1} 使用客户端配置: {list(client_config.keys())}")
        else:
            logger.info(f"片段 {index + 1} 使用默认配置")
        # Retry loop: retry_count retries on top of the initial attempt.
        last_error = None
        for attempt in range(self.config.retry_count + 1):
            try:
                # Bound each attempt by the configured per-segment timeout.
                result = await asyncio.wait_for(
                    self._do_process_segment(
                        audio_path,
                        seg_start,
                        seg_duration,
                        index,
                        total,
                        progress_callback,
                        client_config  # forward the client configuration
                    ),
                    timeout=self.config.segment_timeout
                )
                processing_time = time.time() - start_time
                return SegmentResult(
                    index=index,
                    success=True,
                    audio_path=result['audio_path'],
                    duration=seg_duration,
                    start_time=seg_start,
                    processing_time=processing_time
                )
            except asyncio.TimeoutError:
                last_error = f"处理超时({self.config.segment_timeout}秒)"
                logger.warning(f"片段 {index} 超时(第 {attempt + 1} 次尝试)")
            except Exception as e:
                last_error = str(e)
                logger.warning(
                    f"片段 {index} 处理失败(第 {attempt + 1} 次尝试): {e}"
                )
            # Brief pause before the next attempt (skipped after the last one).
            if attempt < self.config.retry_count:
                await asyncio.sleep(1)
        # All attempts exhausted — report the last error.
        processing_time = time.time() - start_time
        logger.error(f"片段 {index} 处理失败: {last_error}")
        return SegmentResult(
            index=index,
            success=False,
            start_time=seg_start,
            error=last_error,
            processing_time=processing_time
        )

    async def _do_process_segment(
        self,
        audio_path: str,
        start_time: float,
        duration: float,
        index: int,
        total: int,
        progress_callback: Optional[Callable[[str, float], None]] = None,
        client_config: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Execute the actual processing of one segment.

        Args:
            audio_path: audio file path.
            start_time: segment start offset in seconds.
            duration: segment length in seconds.
            index: segment index.
            total: total number of segments.
            progress_callback: progress callback.
            client_config: client configuration forwarded to TTS and sync.

        Returns:
            Dict with 'audio_path' (synced output), 'transcription' and
            'translation'.

        Raises:
            SegmentProcessingError: when a pipeline stage yields no usable
                output (empty ASR text, empty translation, all TTS failed).
        """
        # Each segment owns an equal slice of the overall progress range.
        base_progress = (index / total) * 100
        step_progress = (1 / total) * 100

        def update_progress(stage: str, stage_pct: float):
            # Map a stage-local percentage into this segment's global slice.
            if progress_callback:
                pct = base_progress + (stage_pct / 100) * step_progress
                progress_callback(f"片段 {index + 1}: {stage}", pct)

        # 1. ASR - speech recognition
        update_progress("语音识别中...", 0)
        transcription = await self.groq_client.transcribe(audio_path)
        if not transcription.get('text'):
            raise SegmentProcessingError(index, "ASR", "识别结果为空")
        logger.debug(
            f"片段 {index} ASR完成: 语言={transcription.get('language')}, "
            f"片段数={len(transcription.get('segments', []))}"
        )
        # 2. LLM - translation and speaker identification
        update_progress("翻译中...", 25)
        translation = await self.groq_client.translate(
            transcription['text'],
            transcription['language'],
            transcription.get('segments')
        )
        if not translation.get('segments'):
            raise SegmentProcessingError(index, "LLM", "翻译结果为空")
        logger.debug(f"片段 {index} 翻译完成: {len(translation['segments'])} 个片段")
        # 3. TTS - speech synthesis (uses the client configuration)
        update_progress("生成配音...", 50)
        # Forward the client configuration to the TTS generator.
        tts_paths = await self.tts_generator.generate(
            translation['segments'],
            client_config
        )
        # Keep only the TTS outputs that actually produced a file.
        valid_tts = [(i, p) for i, p in enumerate(tts_paths) if p is not None]
        if not valid_tts:
            raise SegmentProcessingError(index, "TTS", "所有TTS生成失败")
        logger.debug(f"片段 {index} TTS完成: {len(valid_tts)}/{len(tts_paths)} 成功")
        # 4. Audio synchronization
        update_progress("音频同步...", 75)
        # Build the timing info the sync engine needs, skipping failed TTS.
        sync_segments = []
        sync_tts_paths = []
        for i, seg in enumerate(translation['segments']):
            if i < len(tts_paths) and tts_paths[i] is not None:
                sync_segments.append({
                    'start': seg.get('start', 0),
                    'end': seg.get('end', 0)
                })
                sync_tts_paths.append(tts_paths[i])
        # No timestamps at all -> distribute the pieces evenly across the
        # segment's duration.
        if not any(s.get('start', 0) or s.get('end', 0) for s in sync_segments):
            segment_duration = duration / len(sync_segments) if sync_segments else duration
            for i, seg in enumerate(sync_segments):
                seg['start'] = i * segment_duration
                seg['end'] = (i + 1) * segment_duration
        synced_audio = await self.audio_sync.align(
            sync_tts_paths,
            sync_segments,
            duration,
            client_config  # forward the client configuration
        )
        update_progress("完成", 100)
        logger.info(f"片段 {index} 处理完成")
        return {
            'audio_path': synced_audio,
            'transcription': transcription,
            'translation': translation
        }

    def cleanup(self) -> int:
        """
        Remove all temporary files created by the sub-modules.

        Returns:
            Number of files removed.
        """
        cleaned = 0
        if self.tts_generator:
            cleaned += self.tts_generator.cleanup()
        if self.audio_sync:
            cleaned += self.audio_sync.cleanup()
        logger.info(f"处理池清理完成: {cleaned} 个文件")
        return cleaned

    @property
    def is_initialized(self) -> bool:
        """Whether initialize() has completed successfully."""
        return self._initialized