Spaces:
Running
Running
| """ | |
| 并行处理池模块 | |
| 提供音频片段的并行处理功能,集成完整的处理流程: | |
| ASR → LLM → TTS → Sync | |
| 包含性能监控和自适应并发控制。 | |
| """ | |
| import os | |
| import asyncio | |
| import logging | |
| import time | |
| from typing import List, Dict, Any, Optional, Callable | |
| from dataclasses import dataclass | |
| from .groq_client import GroqClient, GroqConfig, GroqError | |
| from .tts_generator import TTSGenerator, TTSConfig, TTSError | |
| from .audio_sync import AudioSyncEngine, SyncConfig, AudioSyncError | |
| from .performance_monitor import ( | |
| get_performance_monitor, | |
| AdaptiveConcurrencyController, | |
| track_performance | |
| ) | |
| # 配置日志 | |
| logger = logging.getLogger(__name__) | |
class ProcessingError(Exception):
    """Base class for every error raised by the processing pool.

    Catching this type also catches the more specific
    ``SegmentProcessingError`` defined in this module.
    """
class SegmentProcessingError(ProcessingError):
    """Error raised when one pipeline stage fails for a single segment.

    Attributes:
        segment_index: Index of the segment that failed.
        stage: Name of the pipeline stage that failed (the callers in this
            module pass "ASR", "LLM" or "TTS").
        reason: Human-readable description of the failure.
        message: Fully formatted message, also used as ``str(exc)``.
    """

    def __init__(self, segment_index: int, stage: str, reason: str):
        # Build the message once and hand it to the base class so that
        # ``str(exc)`` and ``exc.message`` always agree.
        formatted = f"片段 {segment_index} 在 {stage} 阶段失败: {reason}"
        super().__init__(formatted)
        self.segment_index, self.stage, self.reason = segment_index, stage, reason
        self.message = formatted
@dataclass
class ProcessorConfig:
    """Configuration for the parallel processing pool.

    Declared as a dataclass so that individual settings can be overridden
    through keyword arguments (``ProcessorConfig(max_workers=5)``) while
    ``ProcessorConfig()`` still yields the defaults below.

    Attributes:
        max_workers: Maximum number of segments processed concurrently
            (default 3).
        min_workers: Minimum number of concurrent workers (default 1).
        segment_timeout: Timeout in seconds for one processing attempt of a
            single segment (default 480, i.e. 8 minutes).
        retry_count: Number of retries after a failed attempt (default 2).
        temp_dir: Directory used for temporary files.
        adaptive_concurrency: Whether adaptive concurrency control is enabled.
    """

    max_workers: int = 3
    min_workers: int = 1
    segment_timeout: float = 480.0  # 8 minutes
    retry_count: int = 2
    temp_dir: str = "temp/processing"
    adaptive_concurrency: bool = True
@dataclass
class SegmentResult:
    """Outcome of processing one audio segment.

    Declared as a dataclass because the processing pool builds instances
    with keyword arguments (``SegmentResult(index=..., success=...)``); a
    plain class with only annotations would reject those arguments.

    Attributes:
        index: Index of the segment.
        success: Whether processing succeeded.
        audio_path: Path of the generated audio file, if any.
        duration: Duration of the segment.
        start_time: Start time of the segment.
        error: Error description when processing failed.
        processing_time: Wall-clock processing time in seconds.
    """

    index: int
    success: bool
    audio_path: Optional[str] = None
    duration: Optional[float] = None
    start_time: Optional[float] = None
    error: Optional[str] = None
    processing_time: Optional[float] = None
class ParallelProcessingPool:
    """Pool that processes several audio segments concurrently.

    Each segment runs through the full pipeline
    ASR -> LLM translation -> TTS -> audio sync, with a per-attempt timeout,
    retries and an upper bound on the number of segments in flight.

    Example:
        pool = ParallelProcessingPool()
        await pool.initialize()
        segments = [
            {"audio_path": "seg1.mp3", "start_time": 0, "duration": 300},
            {"audio_path": "seg2.mp3", "start_time": 300, "duration": 300},
        ]
        results = await pool.process_segments(
            segments,
            progress_callback=lambda msg, pct: print(f"{pct}%: {msg}"),
        )
    """

    def __init__(
        self,
        config: Optional[ProcessorConfig] = None,
        groq_config: Optional[GroqConfig] = None,
        tts_config: Optional[TTSConfig] = None,
        sync_config: Optional[SyncConfig] = None
    ):
        """Create the pool without initializing its sub-modules.

        Args:
            config: Pool configuration; defaults to ``ProcessorConfig()``.
            groq_config: Configuration passed to ``GroqClient``.
            tts_config: Configuration passed to ``TTSGenerator``.
            sync_config: Configuration passed to ``AudioSyncEngine``.

        Side effects:
            Creates ``config.temp_dir`` if it does not exist yet.
        """
        self.config = config or ProcessorConfig()

        # Sub-module configs are only stored here; the instances are built
        # lazily in ``initialize()``.
        self._groq_config = groq_config
        self._tts_config = tts_config
        self._sync_config = sync_config

        # Sub-module instances (populated by ``initialize()``).
        self.groq_client: Optional[GroqClient] = None
        self.tts_generator: Optional[TTSGenerator] = None
        self.audio_sync: Optional[AudioSyncEngine] = None
        self._initialized = False

        # Performance monitoring.
        self._performance_monitor = get_performance_monitor()

        # Adaptive concurrency controller (only when enabled in the config).
        self._concurrency_controller = None
        if self.config.adaptive_concurrency:
            self._concurrency_controller = AdaptiveConcurrencyController(
                min_workers=self.config.min_workers,
                max_workers=self.config.max_workers
            )

        # Make sure the temporary directory exists.
        os.makedirs(self.config.temp_dir, exist_ok=True)

        logger.info(
            f"并行处理池配置: 最大并发={self.config.max_workers}, "
            f"超时={self.config.segment_timeout}s, "
            f"自适应并发={'启用' if self.config.adaptive_concurrency else '禁用'}"
        )

    async def initialize(self) -> None:
        """Create the Groq client, TTS generator and audio sync engine.

        Calling this again after a successful initialization is a no-op.
        """
        if self._initialized:
            logger.debug("处理池已初始化,跳过")
            return

        logger.info("初始化并行处理池...")

        # Groq client (the only sub-module that is awaited during setup).
        self.groq_client = GroqClient(self._groq_config)
        await self.groq_client.initialize()

        # TTS generator.
        self.tts_generator = TTSGenerator(self._tts_config)

        # Audio sync engine.
        self.audio_sync = AudioSyncEngine(self._sync_config)

        self._initialized = True
        logger.info("并行处理池初始化完成")

    def _ensure_initialized(self) -> None:
        """Raise ``ProcessingError`` unless ``initialize()`` has completed."""
        if not self._initialized:
            raise ProcessingError("处理池未初始化,请先调用 initialize()")

    @staticmethod
    def _describe_error(exc: BaseException) -> str:
        """Return a non-empty description of *exc*.

        Some exceptions stringify to an empty string; an empty ``error``
        field would be indistinguishable from "no error" for callers that
        test it for truthiness, so the class name is used as a fallback.
        """
        return str(exc) or type(exc).__name__

    async def process_segments(
        self,
        segments: List[Dict[str, Any]],
        progress_callback: Optional[Callable[[str, float], None]] = None,
        config: Optional[Dict[str, Any]] = None
    ) -> List[SegmentResult]:
        """Process several audio segments concurrently.

        Args:
            segments: Segment descriptions; each dict is read for the keys
                ``audio_path`` (str), ``start_time`` (float, seconds) and
                ``duration`` (float, seconds).
            progress_callback: Optional callable receiving
                ``(message, percentage)``.
            config: Optional processing config; only its ``client_config``
                entry is read here and forwarded to the pipeline stages.

        Returns:
            One ``SegmentResult`` per input segment, in input order. An
            empty input yields an empty list.

        Raises:
            ProcessingError: If the pool has not been initialized.
        """
        self._ensure_initialized()

        if not segments:
            logger.warning("处理输入为空")
            return []

        # Extract the client configuration. An explicit ``None`` value is
        # treated like "no overrides" instead of crashing on ``.keys()``.
        client_config: Dict[str, Any] = {}
        if config and 'client_config' in config:
            client_config = config['client_config'] or {}
            logger.info(f"处理器使用客户端配置: {list(client_config.keys())}")

        total = len(segments)
        logger.info(f"开始并行处理 {total} 个片段")

        if progress_callback:
            progress_callback("开始处理...", 0)

        # Determine how many segments may be in flight at once.
        if self._concurrency_controller:
            current_workers = self._concurrency_controller.get_recommended_workers()
            logger.info(f"自适应并发: 当前推荐 {current_workers} 个工作线程")
        else:
            current_workers = self.config.max_workers

        # A semaphore created with 0 would block every task forever and a
        # negative value raises ValueError, so always allow at least one.
        current_workers = max(1, current_workers)
        semaphore = asyncio.Semaphore(current_workers)

        async def run_limited(position: int, segment: Dict[str, Any]) -> SegmentResult:
            # The coroutine for the segment is created only once a slot is
            # free, so nothing is left un-awaited if setup fails earlier.
            async with semaphore:
                return await self._process_single_segment(
                    segment,
                    position,
                    total,
                    progress_callback,
                    client_config
                )

        # Track the whole batch with the performance monitor.
        with self._performance_monitor.track_operation("并行片段处理") as metrics:
            metrics.extra["total_segments"] = total
            metrics.extra["workers"] = current_workers

            # return_exceptions=True keeps one failing segment from
            # discarding the results of the others.
            results = await asyncio.gather(
                *(run_limited(i, segment) for i, segment in enumerate(segments)),
                return_exceptions=True
            )

        # Normalize everything gather() returned into SegmentResult objects.
        processed_results: List[SegmentResult] = []
        success_count = 0

        for i, result in enumerate(results):
            if isinstance(result, SegmentResult):
                processed_results.append(result)
                if result.success:
                    success_count += 1
            elif isinstance(result, BaseException):
                # BaseException also covers CancelledError, which gather()
                # may hand back for a cancelled child task.
                logger.error(f"片段 {i} 处理异常: {result}")
                processed_results.append(SegmentResult(
                    index=i,
                    success=False,
                    error=self._describe_error(result)
                ))
            else:
                processed_results.append(SegmentResult(
                    index=i,
                    success=False,
                    error="未知结果类型"
                ))

        logger.info(f"并行处理完成: {success_count}/{total} 成功")

        if progress_callback:
            progress_callback("处理完成", 100)

        return processed_results

    async def _process_single_segment(
        self,
        segment: Dict[str, Any],
        index: int,
        total: int,
        progress_callback: Optional[Callable[[str, float], None]] = None,
        client_config: Optional[Dict[str, Any]] = None
    ) -> SegmentResult:
        """Run the full pipeline for one segment with timeout and retries.

        Pipeline: ASR -> LLM -> TTS -> Sync.

        Args:
            segment: Segment description (``audio_path``, ``start_time``,
                ``duration``).
            index: Zero-based index of the segment.
            total: Total number of segments in the batch.
            progress_callback: Optional progress callback.
            client_config: Client-supplied configuration forwarded to the
                TTS and sync stages.

        Returns:
            A ``SegmentResult``; failures are reported through
            ``success=False`` and ``error`` rather than raised.
        """
        start_time = time.time()
        audio_path = segment.get('audio_path')
        seg_start = segment.get('start_time', 0)
        seg_duration = segment.get('duration', 0)

        logger.info(f"开始处理片段 {index + 1}/{total}")

        # Record which configuration is in effect for this segment.
        if client_config:
            logger.info(f"片段 {index + 1} 使用客户端配置: {list(client_config.keys())}")
        else:
            logger.info(f"片段 {index + 1} 使用默认配置")

        # One initial attempt plus ``retry_count`` retries.
        last_error = None
        for attempt in range(self.config.retry_count + 1):
            try:
                # The timeout applies to each attempt separately.
                result = await asyncio.wait_for(
                    self._do_process_segment(
                        audio_path,
                        seg_start,
                        seg_duration,
                        index,
                        total,
                        progress_callback,
                        client_config
                    ),
                    timeout=self.config.segment_timeout
                )

                processing_time = time.time() - start_time
                return SegmentResult(
                    index=index,
                    success=True,
                    audio_path=result['audio_path'],
                    duration=seg_duration,
                    start_time=seg_start,
                    processing_time=processing_time
                )
            except asyncio.TimeoutError:
                last_error = f"处理超时({self.config.segment_timeout}秒)"
                logger.warning(f"片段 {index} 超时(第 {attempt + 1} 次尝试)")
            except Exception as e:
                # Best-effort pipeline: any stage failure is recorded and
                # retried instead of aborting the whole batch.
                last_error = self._describe_error(e)
                logger.warning(
                    f"片段 {index} 处理失败(第 {attempt + 1} 次尝试): {e}"
                )

            # Short pause before the next attempt (not after the last one).
            if attempt < self.config.retry_count:
                await asyncio.sleep(1)

        # Every attempt failed.
        processing_time = time.time() - start_time
        logger.error(f"片段 {index} 处理失败: {last_error}")
        return SegmentResult(
            index=index,
            success=False,
            start_time=seg_start,
            error=last_error,
            processing_time=processing_time
        )

    async def _do_process_segment(
        self,
        audio_path: str,
        start_time: float,
        duration: float,
        index: int,
        total: int,
        progress_callback: Optional[Callable[[str, float], None]] = None,
        client_config: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Perform one processing attempt for a single segment.

        Args:
            audio_path: Path of the segment's audio file.
            start_time: Start time of the segment (not used by the
                pipeline stages in this method).
            duration: Duration of the segment; used as the target length
                for syncing and for the evenly spaced fallback timestamps.
            index: Zero-based index of the segment.
            total: Total number of segments; must be non-zero.
            progress_callback: Optional progress callback.
            client_config: Client-supplied configuration forwarded to the
                TTS generator and the sync engine.

        Returns:
            Dict with the keys ``audio_path`` (value returned by the sync
            engine), ``transcription`` and ``translation``.

        Raises:
            SegmentProcessingError: If ASR returns no text, translation
                returns no segments, or every TTS item failed.
        """
        # Each segment owns an equal slice of the overall 0-100 range.
        base_progress = (index / total) * 100
        step_progress = (1 / total) * 100

        def update_progress(stage: str, stage_pct: float):
            if progress_callback:
                pct = base_progress + (stage_pct / 100) * step_progress
                progress_callback(f"片段 {index + 1}: {stage}", pct)

        # 1. ASR - speech recognition.
        update_progress("语音识别中...", 0)
        transcription = await self.groq_client.transcribe(audio_path)

        if not transcription.get('text'):
            raise SegmentProcessingError(index, "ASR", "识别结果为空")

        logger.debug(
            f"片段 {index} ASR完成: 语言={transcription.get('language')}, "
            f"片段数={len(transcription.get('segments', []))}"
        )

        # 2. LLM - translation.
        update_progress("翻译中...", 25)
        translation = await self.groq_client.translate(
            transcription['text'],
            transcription['language'],
            transcription.get('segments')
        )

        if not translation.get('segments'):
            raise SegmentProcessingError(index, "LLM", "翻译结果为空")

        logger.debug(f"片段 {index} 翻译完成: {len(translation['segments'])} 个片段")

        # 3. TTS - speech synthesis, using the client configuration.
        update_progress("生成配音...", 50)
        tts_paths = await self.tts_generator.generate(
            translation['segments'],
            client_config
        )

        # Entries that are None mark TTS items that could not be generated.
        generated_count = sum(1 for path in tts_paths if path is not None)
        if not generated_count:
            raise SegmentProcessingError(index, "TTS", "所有TTS生成失败")

        logger.debug(f"片段 {index} TTS完成: {generated_count}/{len(tts_paths)} 成功")

        # 4. Audio sync.
        update_progress("音频同步...", 75)

        # Pair every translated segment with its TTS file, skipping the
        # ones without audio. zip() stops at the shorter sequence, so a
        # short tts_paths list never causes an index error.
        sync_segments = []
        sync_tts_paths = []
        for seg, tts_path in zip(translation['segments'], tts_paths):
            if tts_path is None:
                continue
            sync_segments.append({
                'start': seg.get('start', 0),
                'end': seg.get('end', 0)
            })
            sync_tts_paths.append(tts_path)

        # Without any timestamp information, spread the clips evenly over
        # the segment duration.
        has_timestamps = any(s['start'] or s['end'] for s in sync_segments)
        if not has_timestamps:
            segment_duration = duration / len(sync_segments) if sync_segments else duration
            for i, seg in enumerate(sync_segments):
                seg['start'] = i * segment_duration
                seg['end'] = (i + 1) * segment_duration

        synced_audio = await self.audio_sync.align(
            sync_tts_paths,
            sync_segments,
            duration,
            client_config
        )

        update_progress("完成", 100)
        logger.info(f"片段 {index} 处理完成")

        return {
            'audio_path': synced_audio,
            'transcription': transcription,
            'translation': translation
        }

    def cleanup(self) -> int:
        """Remove temporary files created by the sub-modules.

        Returns:
            Sum of the counts reported by the TTS generator's and the sync
            engine's ``cleanup()``; 0 when neither has been created.
        """
        cleaned = 0

        if self.tts_generator:
            cleaned += self.tts_generator.cleanup()

        if self.audio_sync:
            cleaned += self.audio_sync.cleanup()

        logger.info(f"处理池清理完成: {cleaned} 个文件")
        return cleaned

    def is_initialized(self) -> bool:
        """Return whether ``initialize()`` has completed."""
        return self._initialized