Spaces: Running on Zero
| import gradio as gr | |
| from io import BytesIO | |
| import librosa | |
| import numpy as np | |
| from os import getenv | |
| from PIL.Image import Image, open as open_image | |
| import soundfile as sf | |
| import requests | |
| from tempfile import NamedTemporaryFile | |
| import torch | |
| from transformers import AutoProcessor | |
# Try to import the ZeroGPU decorator (for Hugging Face Spaces); otherwise use a no-op decorator.
try:
    from spaces import GPU as spaces_gpu
except ImportError:
    # For local development, use a no-op decorator because spaces is not available.
    def spaces_gpu(func=None, **_kwargs):
        """No-op stand-in for ``spaces.GPU`` used when the spaces module is not installed.

        Supports both ways the decorator can be applied, so decorated code runs
        unchanged locally:

        - ``@spaces_gpu`` returns ``func`` itself.
        - ``@spaces_gpu(...)`` with keyword arguments (e.g. ``duration=...``), or with
          no arguments at all, returns a decorator that hands the function back unchanged.

        Args:
            func: The function being decorated, or None when called with arguments.
            **_kwargs: Options of the real decorator; ignored locally.

        Returns:
            ``func`` when given, otherwise an identity decorator.
        """
        if func is None:
            # Called as @spaces_gpu(...): return the actual decorator.
            return lambda decorated: decorated
        return func
def get_pytorch_device() -> str:
    """Determine the best available PyTorch device for computation.

    Checks for available hardware accelerators in priority order:
    1. CUDA (Nvidia GPUs and AMD ROCm)
    2. XPU (Intel GPUs)
    3. MPS (Apple Silicon/Metal Performance Shaders)
    4. CPU (fallback)

    The XPU and MPS backends are looked up with ``getattr``. A PyTorch build that
    lacks one of them is treated as "not available" instead of raising
    ``AttributeError``.

    Returns:
        String device name: "cuda", "xpu", "mps", or "cpu".
    """
    if torch.cuda.is_available():  # Nvidia CUDA and AMD ROCm
        return "cuda"
    xpu_module = getattr(torch, "xpu", None)  # absent in older PyTorch builds
    if xpu_module is not None and xpu_module.is_available():  # Intel XPU
        return "xpu"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():  # Apple Silicon
        return "mps"
    return "cpu"
def get_torch_dtype():
    """Choose the torch dtype for model loading from the REDUCED_MEMORY setting.

    The ``REDUCED_MEMORY`` environment variable is compared case-insensitively
    against ``"true"``. Any other value, or leaving it unset, keeps the default.

    Returns:
        ``torch.float16`` when reduced memory is requested, otherwise ``None``
        (uses default precision).
    """
    reduced_memory_requested = getenv("REDUCED_MEMORY", "False").lower() == "true"
    if not reduced_memory_requested:
        return None
    return torch.float16
def request_image(url: str) -> Image:
    """Fetch an image from a URL and return it as a PIL Image.

    Downloads an image from the provided URL, decodes it with PIL and returns the
    Image object for processing. Network and decoding failures are reported as
    ``gr.Error`` so the UI shows a readable message.

    Args:
        url: HTTP/HTTPS URL pointing to an image file.

    Returns:
        PIL Image object loaded from the URL (pixel data already decoded).

    Raises:
        gr.Error: If the image cannot be fetched or decoded due to:
            - HTTP errors (4xx, 5xx status codes)
            - Network timeouts
            - Other request exceptions
            - Content PIL cannot identify or decode (raised as ``OSError`` by PIL)
        ValueError: If REQUEST_TIMEOUT is set to something that is not a number.

    Note:
        - Timeout in seconds is configurable via the REQUEST_TIMEOUT environment
          variable (default: 45; fractional values such as "2.5" are accepted)
        - Supports common image formats (JPEG, PNG, GIF, WebP, etc.)
    """
    timeout = float(getenv("REQUEST_TIMEOUT", "45"))
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.HTTPError as e:
        raise gr.Error(f"Failed to fetch image from URL because of HTTP error: {e.response.status_code} {e.response.text}") from e
    except requests.Timeout as e:
        raise gr.Error("Failed to fetch image from URL because the request timed out.") from e
    except requests.RequestException as e:
        raise gr.Error(f"Failed to fetch image from URL: {str(e)}") from e
    try:
        image = open_image(BytesIO(response.content))
        # open_image is lazy; decode now so corrupt or truncated data fails here,
        # as a gr.Error, instead of later inside a caller.
        image.load()
    except OSError as e:  # PIL's UnidentifiedImageError is an OSError subclass
        raise gr.Error(f"Failed to load image file: {str(e)}") from e
    return image
def request_audio(url: str) -> tuple[int, np.ndarray]:
    """Fetch an audio file from a URL and return it as audio data.

    Downloads an audio file from the provided URL and loads it using librosa,
    which supports many audio formats. Returns the audio data in a format
    compatible with Gradio's Audio component.

    Args:
        url: HTTP/HTTPS URL pointing to an audio file.

    Returns:
        Tuple containing:
            - int: Sample rate of the audio in Hz (e.g., 44100, 22050)
            - np.ndarray: Audio waveform data as a numpy array (float32, normalized)

    Raises:
        gr.Error: If the audio cannot be fetched or loaded due to:
            - HTTP errors (4xx, 5xx status codes)
            - Network timeouts
            - Other request exceptions
            - Unsupported audio formats or any other decoding failure
        ValueError: If REQUEST_TIMEOUT is set to something that is not a number.

    Note:
        - Timeout in seconds is configurable via the REQUEST_TIMEOUT environment
          variable (default: 45; fractional values such as "2.5" are accepted)
        - Supports many audio formats (MP3, WAV, FLAC, OGG, M4A, etc.)
        - Audio is loaded at its native sample rate (sr=None)
    """
    timeout = float(getenv("REQUEST_TIMEOUT", "45"))
    # Phase 1: download. Only network errors are translated here.
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.HTTPError as e:
        raise gr.Error(f"Failed to fetch audio from URL because of HTTP error: {e.response.status_code} {e.response.text}") from e
    except requests.Timeout as e:
        raise gr.Error("Failed to fetch audio from URL because the request timed out.") from e
    except requests.RequestException as e:
        raise gr.Error(f"Failed to fetch audio from URL: {str(e)}") from e
    # Phase 2: decode. librosa's backends raise many different exception types,
    # so every decoding failure is reported with one uniform message.
    try:
        audio_array, sample_rate = librosa.load(BytesIO(response.content), sr=None)
    except Exception as e:
        raise gr.Error(f"Failed to load audio file: {str(e)}") from e
    return (sample_rate, audio_array)
def save_image_to_temp_file(image: "Image") -> str:
    """Save a PIL Image to a temporary file on disk.

    Creates a temporary file with an extension derived from the image's format
    and writes the encoded image to it. This is needed because some APIs (like
    Hugging Face InferenceClient) require file paths rather than PIL Image objects.

    Args:
        image: PIL Image object to save.

    Returns:
        String path to the temporary file where the image was saved.

    Raises:
        Whatever ``image.save`` raises when the format cannot be written. In that
        case no temporary file is created.

    Note:
        - Preserves the original image format if available
        - Falls back to PNG format if image.format is None or empty
        - The image is encoded in memory before the temporary file is created, so
          a failed encode cannot leave an orphaned file behind
        - On success the temporary file is not automatically deleted (caller is
          responsible for cleanup)
        - File extension is the lower-cased image format (e.g. ".jpeg", ".png")
    """
    image_format = image.format or "PNG"
    encoded = BytesIO()
    # Encode first; only touch the filesystem once encoding has succeeded.
    image.save(encoded, format=image_format)
    with NamedTemporaryFile(delete=False, suffix=f".{image_format.lower()}") as temp_file:
        temp_file.write(encoded.getvalue())
    return temp_file.name
def get_model_sample_rate(model_id: str) -> int:
    """Get the expected sample rate for an audio processing model.

    Retrieves the sample rate configuration from a Hugging Face model's
    feature extractor. This is useful for ensuring audio is resampled to
    match the model's expected input format.

    Args:
        model_id: Hugging Face model identifier (e.g., "openai/whisper-large-v3").

    Returns:
        Integer sample rate in Hz that the model expects (e.g., 16000).
        Defaults to 16000 Hz if the sample rate cannot be determined.

    Note:
        - Most ASR models use 16kHz sample rate
        - ``AutoProcessor`` may return either a composite processor that wraps a
          ``feature_extractor`` or a standalone feature extractor. Both shapes are
          read, so the model's real rate is not silently replaced by the default.
        - Returns a sensible default (16kHz) if the model config cannot be loaded
          or exposes no ``sampling_rate``
    """
    default_rate = 16000  # Fallback value as most ASR models use 16kHz
    try:
        processor = AutoProcessor.from_pretrained(model_id)
    except Exception:
        # Best effort: any failure to load the processor uses the default.
        return default_rate
    # Composite processors expose .feature_extractor; a standalone feature
    # extractor carries sampling_rate itself.
    feature_extractor = getattr(processor, "feature_extractor", processor)
    sampling_rate = getattr(feature_extractor, "sampling_rate", None)
    return sampling_rate if sampling_rate is not None else default_rate
| def resample_audio(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> np.ndarray: | |
| """Resample audio data to a target sample rate. | |
| Converts audio data to the target sample rate using librosa's resampling. | |
| Handles both bytes and numpy array input formats, converting bytes to | |
| float32 numpy arrays as needed. | |
| Args: | |
| target_sample_rate: Desired output sample rate in Hz (e.g., 16000). | |
| audio: Tuple containing: | |
| - int: Current sample rate of the audio | |
| - bytes | np.ndarray: Audio data (can be raw bytes or numpy array) | |
| Returns: | |
| Numpy array (float32) containing the resampled audio waveform. | |
| If sample rates match, returns the audio data unchanged. | |
| Raises: | |
| ValueError: If audio_data is neither bytes nor np.ndarray. | |
| Note: | |
| - Converts bytes to float32 by assuming int16 PCM format | |
| - Normalizes int16 values to [-1.0, 1.0] range | |
| - Only resamples if source and target sample rates differ | |
| - Uses librosa's high-quality resampling algorithm | |
| """ | |
| sample_rate, audio_data = audio | |
| # Convert audio data to a numpy array if it’s bytes | |
| if isinstance(audio_data, bytes): | |
| audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0 | |
| elif isinstance(audio_data, np.ndarray): | |
| audio_array = audio_data.astype(np.float32) | |
| else: | |
| raise ValueError(f"Unsupported audio_data type: {type(audio_data)}") | |
| # Resample if sample rates don’t match. | |
| if sample_rate != target_sample_rate: | |
| audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=target_sample_rate) | |
| return audio_array | |
def save_audio_to_temp_file(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> str:
    """Resample audio to a target sample rate and save it to a temporary WAV file.

    The audio is converted and resampled with ``resample_audio`` and then
    written as a 16-bit PCM WAV file. This is useful for preparing audio for APIs
    that require specific sample rates and file formats.

    Args:
        target_sample_rate: Target sample rate in Hz for the output file (e.g., 16000).
        audio: Tuple containing:
            - int: Current sample rate of the input audio
            - bytes | np.ndarray: Audio data to process

    Returns:
        String path to the temporary WAV file where the audio was saved.

    Note:
        - Resamples audio only if the sample rates don't match
        - Saves audio as WAV format with an explicit 16-bit PCM subtype
        - The WAV is encoded in memory before the temporary file is created, so a
          failed write cannot leave an orphaned file behind
        - On success the temporary file is not automatically deleted (caller is
          responsible for cleanup)
        - Useful for preparing audio for Hugging Face InferenceClient APIs
    """
    audio_array = resample_audio(target_sample_rate, audio)
    encoded = BytesIO()
    # Encode first; only touch the filesystem once encoding has succeeded.
    sf.write(encoded, audio_array, target_sample_rate, format='WAV', subtype='PCM_16')
    with NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
        temp_file.write(encoded.getvalue())
    return temp_file.name