"""Helpers for fetching, converting, and staging image/audio inputs, plus PyTorch device selection."""

from collections.abc import Callable
from contextlib import suppress
from functools import lru_cache
from io import BytesIO
from os import getenv, remove
from tempfile import NamedTemporaryFile

import gradio as gr
import librosa
import numpy as np
from PIL import UnidentifiedImageError
from PIL.Image import Image, open as open_image
import requests
import soundfile as sf
import torch
from transformers import AutoProcessor

# Try to import spaces decorator (for Hugging Face Spaces), otherwise use no-op decorator.
try:
    from spaces import GPU as spaces_gpu
except ImportError:
    # For local development, use a no-op decorator because spaces is not available.
    def spaces_gpu(func=None, /, **_kwargs):
        """No-op stand-in for ``spaces.GPU`` when the ``spaces`` module is not available.

        Supports both the bare form (``@spaces_gpu``) and the parameterized form
        (``@spaces_gpu(duration=...)``); any keyword arguments are ignored.

        Args:
            func: The decorated function when used bare; None when called with arguments.

        Returns:
            ``func`` unchanged, or an identity decorator when ``func`` is None.
        """
        if func is None:
            # Called with arguments only: hand back a decorator that leaves the function as-is.
            return lambda f: f
        return func


def get_pytorch_device() -> str:
    """Determine the best available PyTorch device for computation.

    Checks for available hardware accelerators in priority order:
    1. CUDA (Nvidia GPUs and AMD ROCm)
    2. XPU (Intel GPUs)
    3. MPS (Apple Silicon/Metal Performance Shaders)
    4. CPU (fallback)

    Backends that the installed PyTorch build does not expose (e.g. ``torch.xpu`` on
    older releases) are treated as unavailable instead of raising ``AttributeError``.

    Returns:
        String device name: "cuda", "xpu", "mps", or "cpu"
    """
    if torch.cuda.is_available():  # Nvidia CUDA and AMD ROCm
        return "cuda"
    xpu = getattr(torch, "xpu", None)  # Intel XPU; only present in newer PyTorch builds
    if xpu is not None and xpu.is_available():
        return "xpu"
    mps = getattr(torch.backends, "mps", None)  # Apple Silicon
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"  # No accelerator found.


def get_torch_dtype() -> torch.dtype | None:
    """Get the appropriate torch dtype based on reduced memory setting.

    Reads the REDUCED_MEMORY environment variable (case-insensitive "true" enables it).

    Returns:
        torch.float16 if reduced memory is enabled, None otherwise (uses default precision).
    """
    return torch.float16 if getenv("REDUCED_MEMORY", "False").lower() == "true" else None


def _request_timeout() -> float:
    """Return the HTTP request timeout in seconds from REQUEST_TIMEOUT (default: 45).

    Raises:
        ValueError: If REQUEST_TIMEOUT is set to a non-numeric value.
    """
    return float(getenv("REQUEST_TIMEOUT", "45"))


def _fetch_url(url: str, kind: str) -> bytes:
    """Download ``url`` and return the response body, converting request failures to gr.Error.

    Args:
        url: HTTP/HTTPS URL to fetch.
        kind: Human-readable resource name ("image", "audio") used only in error messages.

    Returns:
        The raw response body.

    Raises:
        gr.Error: On HTTP error status, timeout, or any other ``requests`` exception.
        ValueError: If REQUEST_TIMEOUT is not numeric.
    """
    timeout = _request_timeout()
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.HTTPError as e:
        raise gr.Error(
            f"Failed to fetch {kind} from URL because of HTTP error: {e.response.status_code} {e.response.text}"
        ) from e
    except requests.Timeout as e:
        raise gr.Error(f"Failed to fetch {kind} from URL because the request timed out.") from e
    except requests.RequestException as e:
        # Must come after HTTPError/Timeout, which are subclasses of RequestException.
        raise gr.Error(f"Failed to fetch {kind} from URL: {e}") from e
    return response.content


def request_image(url: str) -> Image:
    """Fetch an image from a URL and return it as a PIL Image.

    Downloads an image from the provided URL and converts it to a PIL Image object
    for processing. Handles various HTTP errors and timeouts gracefully.

    Args:
        url: HTTP/HTTPS URL pointing to an image file.

    Returns:
        PIL Image object loaded from the URL.

    Raises:
        gr.Error: If the image cannot be fetched or decoded due to:
            - HTTP errors (4xx, 5xx status codes)
            - Network timeouts
            - Other request exceptions
            - Content that PIL cannot identify as an image

    Note:
        - Timeout is configurable via REQUEST_TIMEOUT environment variable (default: 45 seconds)
        - Supports the image formats the installed Pillow build can open (JPEG, PNG, GIF, WebP, etc.)
    """
    content = _fetch_url(url, "image")
    try:
        return open_image(BytesIO(content))
    except UnidentifiedImageError as e:
        raise gr.Error(f"Failed to load image from URL: {e}") from e


def request_audio(url: str) -> tuple[int, np.ndarray]:
    """Fetch an audio file from a URL and return it as audio data.

    Downloads an audio file from the provided URL and loads it using librosa,
    which supports many audio formats. Returns the audio data in a format
    compatible with Gradio's Audio component.

    Args:
        url: HTTP/HTTPS URL pointing to an audio file.

    Returns:
        Tuple containing:
            - int: Sample rate of the audio in Hz (e.g., 44100, 22050)
            - np.ndarray: Audio waveform data as a numpy array (float32, normalized)

    Raises:
        gr.Error: If the audio cannot be fetched or loaded due to:
            - HTTP errors (4xx, 5xx status codes)
            - Network timeouts
            - Unsupported audio formats
            - Other request or audio loading exceptions

    Note:
        - Timeout is configurable via REQUEST_TIMEOUT environment variable (default: 45 seconds)
        - Supports many audio formats (MP3, WAV, FLAC, OGG, M4A, etc.)
        - Audio is loaded at its native sample rate (sr=None)
        - Returns normalized float32 audio data suitable for processing
    """
    content = _fetch_url(url, "audio")
    try:
        audio_array, sample_rate = librosa.load(BytesIO(content), sr=None)
    except Exception as e:  # Decoder backends raise many unrelated exception types.
        raise gr.Error(f"Failed to load audio file: {e}") from e
    return (sample_rate, audio_array)


def _write_to_temp_file(suffix: str, writer: Callable[[str], object]) -> str:
    """Create a persistent temporary file with ``suffix`` and fill it via ``writer(path)``.

    The file is created with ``delete=False`` and closed before ``writer`` runs, so the
    writer can reopen it by path. If ``writer`` raises, the file is removed (best effort)
    and the exception is re-raised, so failed conversions do not leak temp files.

    Returns:
        Path of the written temporary file; the caller is responsible for deleting it.
    """
    with NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_path = temp_file.name
    try:
        writer(temp_path)
    except BaseException:
        with suppress(OSError):
            remove(temp_path)
        raise
    return temp_path


def save_image_to_temp_file(image: Image) -> str:
    """Save a PIL Image to a temporary file on disk.

    Creates a temporary file with an appropriate extension based on the image's format
    and saves the image to it. This is needed because some APIs (like Hugging Face
    InferenceClient) require file paths rather than PIL Image objects.

    Args:
        image: PIL Image object to save.

    Returns:
        String path to the temporary file where the image was saved.

    Note:
        - Preserves the original image format if available
        - Falls back to PNG format if image.format is None
        - Temporary file is not automatically deleted (caller is responsible for cleanup)
        - The temporary file is removed again if saving fails
        - File extension is the lowercased image format (e.g. ".png", ".jpeg")
    """
    image_format = image.format or "PNG"
    return _write_to_temp_file(
        f".{image_format.lower()}",
        lambda path: image.save(path, format=image_format),
    )


@lru_cache(maxsize=32)
def _load_model_sample_rate(model_id: str) -> int:
    """Load ``model_id``'s processor and return its feature extractor's sampling rate.

    Results are cached per model id. ``lru_cache`` does not cache raised exceptions,
    so a failed lookup (e.g. a transient network error) is retried on the next call.
    """
    processor = AutoProcessor.from_pretrained(model_id)
    # AutoProcessor can return a bare feature extractor (which has no
    # ``.feature_extractor`` attribute) when the model has no combined processor class.
    feature_extractor = getattr(processor, "feature_extractor", processor)
    return feature_extractor.sampling_rate


def get_model_sample_rate(model_id: str) -> int:
    """Get the expected sample rate for an audio processing model.

    Retrieves the sample rate configuration from a Hugging Face model's feature extractor.
    This is useful for ensuring audio is resampled to match the model's expected input format.

    Args:
        model_id: Hugging Face model identifier (e.g., "openai/whisper-large-v3").

    Returns:
        Integer sample rate in Hz that the model expects (e.g., 16000).
        Defaults to 16000 Hz if the sample rate cannot be determined.

    Note:
        - Most ASR models use 16kHz sample rate
        - Uses AutoProcessor to access the model's feature extractor configuration
        - Successful lookups are cached per model id
        - Returns a sensible default (16kHz) if the model config cannot be loaded
    """
    try:
        return _load_model_sample_rate(model_id)
    except Exception:
        return 16000  # Fallback value as most ASR models use 16kHz


def _pcm_to_float32(samples: np.ndarray) -> np.ndarray:
    """Convert sample data to float32, scaling integer PCM into the [-1.0, 1.0) range.

    Signed integers are divided by 2**(bits-1) (e.g. 32768 for int16). Unsigned integers
    are centered on their midpoint first (e.g. 128 for uint8). Float input is only cast.
    """
    if np.issubdtype(samples.dtype, np.signedinteger):
        full_scale = np.float32(-np.iinfo(samples.dtype).min)
        return samples.astype(np.float32) / full_scale
    if np.issubdtype(samples.dtype, np.unsignedinteger):
        midpoint = np.float32(np.iinfo(samples.dtype).max // 2 + 1)
        return (samples.astype(np.float32) - midpoint) / midpoint
    return samples.astype(np.float32)


def resample_audio(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> np.ndarray:
    """Resample audio data to a target sample rate.

    Converts audio data to the target sample rate using librosa's resampling. Handles
    both bytes and numpy array input formats, converting them to normalized float32
    numpy arrays as needed.

    Args:
        target_sample_rate: Desired output sample rate in Hz (e.g., 16000).
        audio: Tuple containing:
            - int: Current sample rate of the audio
            - bytes | np.ndarray: Audio data (can be raw bytes or numpy array)

    Returns:
        Numpy array (float32) containing the resampled audio waveform.
        If sample rates match, returns the converted audio data without resampling.

    Raises:
        ValueError: If audio_data is neither bytes nor np.ndarray.

    Note:
        - Converts bytes to float32 by assuming native-endian int16 PCM format
        - Normalizes integer samples (bytes or integer arrays, e.g. Gradio's int16 output)
          to the [-1.0, 1.0] range; float arrays are only cast to float32
        - Resamples along axis 0, i.e. multi-channel arrays are expected in
          (samples, channels) layout
        - Only resamples if source and target sample rates differ
        - Uses librosa's high-quality resampling algorithm
    """
    sample_rate, audio_data = audio

    if isinstance(audio_data, bytes):
        audio_array = _pcm_to_float32(np.frombuffer(audio_data, dtype=np.int16))
    elif isinstance(audio_data, np.ndarray):
        audio_array = _pcm_to_float32(audio_data)
    else:
        raise ValueError(f"Unsupported audio_data type: {type(audio_data)}")

    if sample_rate != target_sample_rate:
        # axis=0 is the time axis for both mono (samples,) and (samples, channels) arrays;
        # librosa's default axis=-1 would resample across channels for stereo input.
        audio_array = librosa.resample(
            audio_array, orig_sr=sample_rate, target_sr=target_sample_rate, axis=0
        )
    return audio_array


def save_audio_to_temp_file(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> str:
    """Resample audio to target sample rate and save to a temporary WAV file.

    This function resamples audio data to match a target sample rate and saves it
    as a WAV file. This is useful for preparing audio for APIs that require specific
    sample rates and file formats.

    Args:
        target_sample_rate: Target sample rate in Hz for the output file (e.g., 16000).
        audio: Tuple containing:
            - int: Current sample rate of the input audio
            - bytes | np.ndarray: Audio data to process

    Returns:
        String path to the temporary WAV file where the audio was saved.

    Note:
        - Automatically resamples audio if sample rates don't match
        - Saves audio as WAV format (16-bit PCM)
        - Temporary file is not automatically deleted (caller is responsible for cleanup)
        - The temporary file is removed again if writing fails
        - Audio is normalized and converted to float32 before saving
        - Useful for preparing audio for Hugging Face InferenceClient APIs
    """
    audio_array = resample_audio(target_sample_rate, audio)
    return _write_to_temp_file(
        ".wav",
        lambda path: sf.write(path, audio_array, target_sample_rate, format="WAV", subtype="PCM_16"),
    )