# ai-building-blocks / utils.py
# Author: LiKenun
# Commit message: Switch to use GPU instead of inference client
# Commit: 5c395b2
from functools import lru_cache
from io import BytesIO
from os import getenv
from tempfile import NamedTemporaryFile

import gradio as gr
import librosa
import numpy as np
import requests
import soundfile as sf
import torch
from PIL.Image import Image, open as open_image
from transformers import AutoProcessor
# Try to import the GPU decorator from `spaces` (available on Hugging Face Spaces);
# otherwise fall back to a no-op decorator with a compatible calling convention.
try:
    from spaces import GPU as spaces_gpu
except ImportError:
    # For local development, use a no-op decorator because `spaces` is not available.
    def spaces_gpu(func=None, **kwargs):
        """No-op stand-in for ``spaces.GPU`` when the ``spaces`` module is not available.

        Supports both decorator forms used with ``spaces.GPU``:

        - bare: ``@spaces_gpu``, which returns ``func`` unchanged;
        - parameterized: ``@spaces_gpu(duration=120)`` or ``@spaces_gpu()``, which
          returns an identity decorator. Keyword arguments are accepted and ignored.

        Args:
            func: The function being decorated, or None when called with options.
            **kwargs: Options accepted for signature compatibility; ignored.

        Returns:
            ``func`` itself, or an identity decorator when ``func`` is None.
        """
        if func is None:
            # Parameterized use: hand back a decorator that leaves the function as-is.
            return lambda decorated: decorated
        return func
def get_pytorch_device() -> str:
    """Determine the best available PyTorch device for computation.

    Checks for hardware accelerators in priority order:
    1. CUDA (Nvidia GPUs and AMD ROCm)
    2. XPU (Intel GPUs)
    3. MPS (Apple Silicon/Metal Performance Shaders)
    4. CPU (fallback)

    A backend that the installed PyTorch build does not expose at all is treated
    as unavailable instead of raising ``AttributeError``. Example: the
    ``torch.xpu`` module is missing on older releases.

    Returns:
        String device name: "cuda", "xpu", "mps", or "cpu".
    """
    if torch.cuda.is_available():  # Nvidia CUDA and AMD ROCm
        return "cuda"

    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():  # Intel XPU
        return "xpu"

    # `torch.backends.mps.is_available` is the long-standing MPS check; the
    # `torch.mps.is_available` alias may be missing on older releases.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():  # Apple Silicon
        return "mps"

    return "cpu"
def get_torch_dtype():
    """Choose the floating-point dtype for model loading.

    Reads the ``REDUCED_MEMORY`` environment variable (default ``"False"``) and
    compares it case-insensitively against ``"true"``.

    Returns:
        ``torch.float16`` when reduced memory is enabled; otherwise ``None``, so
        the default precision is used.
    """
    reduced_memory_flag = getenv("REDUCED_MEMORY", "False")
    if reduced_memory_flag.lower() != "true":
        return None
    return torch.float16
def request_image(url: str) -> Image:
    """Fetch an image from a URL and return it as a PIL Image.

    Downloads the image from the provided URL and decodes it into a PIL Image.
    Network failures and decoding failures are both reported as ``gr.Error`` so
    the Gradio UI can show them to the user.

    Args:
        url: HTTP/HTTPS URL pointing to an image file.

    Returns:
        PIL Image object loaded from the URL, already fully decoded.

    Raises:
        gr.Error: If the image cannot be fetched or decoded because of:
            - HTTP errors (4xx, 5xx status codes)
            - Network timeouts
            - Other request exceptions
            - Content that Pillow cannot identify or decode
        ValueError: If REQUEST_TIMEOUT is set to a non-numeric value.

    Note:
        - Timeout is configurable via the REQUEST_TIMEOUT environment variable
          (seconds, default: 45; fractional values are accepted).
        - Supported formats are whatever the installed Pillow build can decode.
    """
    try:
        response = requests.get(url, timeout=float(getenv("REQUEST_TIMEOUT", "45")))
        response.raise_for_status()
    except requests.HTTPError as e:
        raise gr.Error(f"Failed to fetch image from URL because of HTTP error: {e.response.status_code} {e.response.text}") from e
    except requests.Timeout as e:
        raise gr.Error("Failed to fetch image from URL because the request timed out.") from e
    except requests.RequestException as e:
        raise gr.Error(f"Failed to fetch image from URL: {str(e)}") from e

    try:
        image = open_image(BytesIO(response.content))
        # PIL's open() is lazy. Decode now so corrupt or truncated data fails here
        # with a user-facing error, not later inside model code.
        image.load()
    except Exception as e:
        raise gr.Error(f"Failed to load image file: {str(e)}") from e
    return image
def request_audio(url: str) -> tuple[int, np.ndarray]:
    """Fetch an audio file from a URL and return it as audio data.

    Downloads an audio file from the provided URL and decodes it with
    ``librosa.load`` from an in-memory buffer. Returns the data in the
    ``(sample_rate, waveform)`` shape used by Gradio's Audio component.

    Args:
        url: HTTP/HTTPS URL pointing to an audio file.

    Returns:
        Tuple containing:
        - int: Sample rate of the audio in Hz (e.g., 44100, 22050).
        - np.ndarray: Audio waveform as returned by ``librosa.load`` with its
          defaults (float32, mixed down to mono).

    Raises:
        gr.Error: If the audio cannot be fetched or decoded because of:
            - HTTP errors (4xx, 5xx status codes)
            - Network timeouts
            - Other request exceptions
            - Unsupported or corrupt audio data
        ValueError: If REQUEST_TIMEOUT is set to a non-numeric value.

    Note:
        - Timeout is configurable via the REQUEST_TIMEOUT environment variable
          (seconds, default: 45; fractional values are accepted).
        - Audio is loaded at its native sample rate (``sr=None``).
        - Decoding happens from a BytesIO buffer. NOTE(review): formats that need
          librosa's audioread fallback (e.g. M4A/AAC) may fail, because that
          fallback expects a file path. Confirm with the deployed librosa and
          libsndfile versions.
    """
    try:
        response = requests.get(url, timeout=float(getenv("REQUEST_TIMEOUT", "45")))
        response.raise_for_status()
    except requests.HTTPError as e:
        raise gr.Error(f"Failed to fetch audio from URL because of HTTP error: {e.response.status_code} {e.response.text}") from e
    except requests.Timeout as e:
        raise gr.Error("Failed to fetch audio from URL because the request timed out.") from e
    except requests.RequestException as e:
        raise gr.Error(f"Failed to fetch audio from URL: {str(e)}") from e

    # Keep the broad catch limited to decoding, so unrelated bugs in the fetch
    # path are not mislabelled as "failed to load audio".
    try:
        audio_array, sample_rate = librosa.load(BytesIO(response.content), sr=None)
    except Exception as e:
        raise gr.Error(f"Failed to load audio file: {str(e)}") from e
    return (sample_rate, audio_array)
def save_image_to_temp_file(image: Image) -> str:
    """Save a PIL Image to a temporary file on disk.

    Creates a temporary file whose extension matches the image's format and
    saves the image into it. This helps with APIs that take file paths rather
    than in-memory PIL Image objects.

    Args:
        image: PIL Image object to save.

    Returns:
        String path to the temporary file where the image was saved.

    Raises:
        Whatever ``image.save`` raises (e.g. ``OSError``/``KeyError`` for an
        unwritable format or mode). The temporary file is removed first.

    Note:
        - Preserves the original image format when ``image.format`` is set and
          falls back to PNG otherwise.
        - The file extension is the lower-cased format name (e.g. ".jpeg", ".png").
        - On success the file is not deleted automatically; the caller is
          responsible for cleanup.
    """
    import os
    from contextlib import suppress

    image_format = image.format or "PNG"
    temp_file = NamedTemporaryFile(delete=False, suffix=f".{image_format.lower()}")
    temp_path = temp_file.name
    # Close our handle so PIL can reopen the path by name for writing.
    temp_file.close()
    try:
        image.save(temp_path, format=image_format)
    except BaseException:
        # Don't leave an empty or partial file behind when encoding fails.
        with suppress(OSError):
            os.remove(temp_path)
        raise
    return temp_path
@lru_cache(maxsize=32)
def _load_model_sample_rate(model_id: str) -> int:
    """Load and memoize the sampling rate declared by a model's processor.

    ``lru_cache`` does not cache raised exceptions, so a failed lookup (e.g. a
    transient download error) is retried on the next call.
    """
    processor = AutoProcessor.from_pretrained(model_id)
    # AutoProcessor returns a composite processor (exposing .feature_extractor)
    # for models such as Whisper. For models without a processor class it can
    # return the feature extractor itself, so accept either shape.
    feature_extractor = getattr(processor, "feature_extractor", processor)
    return feature_extractor.sampling_rate


def get_model_sample_rate(model_id: str) -> int:
    """Get the expected sample rate for an audio processing model.

    Reads ``sampling_rate`` from the feature extractor of the model's
    Hugging Face processor. Use it to resample audio to the model's expected
    input rate. Successful lookups are memoized per ``model_id``, so the
    processor is loaded only once.

    Args:
        model_id: Hugging Face model identifier (e.g., "openai/whisper-large-v3").

    Returns:
        Integer sample rate in Hz that the model expects (e.g., 16000).
        Defaults to 16000 Hz if the sample rate cannot be determined.

    Note:
        - Most ASR models use a 16 kHz sample rate, hence the fallback value.
        - Any failure (download, missing attribute, etc.) returns the fallback
          instead of raising.
    """
    try:
        return _load_model_sample_rate(model_id)
    except Exception:
        return 16000  # Fallback value as most ASR models use 16kHz
def resample_audio(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> np.ndarray:
"""Resample audio data to a target sample rate.
Converts audio data to the target sample rate using librosa's resampling.
Handles both bytes and numpy array input formats, converting bytes to
float32 numpy arrays as needed.
Args:
target_sample_rate: Desired output sample rate in Hz (e.g., 16000).
audio: Tuple containing:
- int: Current sample rate of the audio
- bytes | np.ndarray: Audio data (can be raw bytes or numpy array)
Returns:
Numpy array (float32) containing the resampled audio waveform.
If sample rates match, returns the audio data unchanged.
Raises:
ValueError: If audio_data is neither bytes nor np.ndarray.
Note:
- Converts bytes to float32 by assuming int16 PCM format
- Normalizes int16 values to [-1.0, 1.0] range
- Only resamples if source and target sample rates differ
- Uses librosa's high-quality resampling algorithm
"""
sample_rate, audio_data = audio
# Convert audio data to a numpy array if it’s bytes
if isinstance(audio_data, bytes):
audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
elif isinstance(audio_data, np.ndarray):
audio_array = audio_data.astype(np.float32)
else:
raise ValueError(f"Unsupported audio_data type: {type(audio_data)}")
# Resample if sample rates don’t match.
if sample_rate != target_sample_rate:
audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=target_sample_rate)
return audio_array
def save_audio_to_temp_file(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> str:
    """Resample audio to a target sample rate and save it to a temporary WAV file.

    Passes the audio through ``resample_audio`` (float32 conversion plus
    resampling when the rates differ). It then writes the result as a 16-bit
    PCM WAV file. This is useful for APIs that need a file path at a specific
    sample rate.

    Args:
        target_sample_rate: Target sample rate in Hz for the output file (e.g., 16000).
        audio: Tuple containing:
            - int: Current sample rate of the input audio
            - bytes | np.ndarray: Audio data to process

    Returns:
        String path to the temporary WAV file where the audio was saved.

    Raises:
        ValueError: Propagated from ``resample_audio`` for unsupported data types.
            No file is created in that case.
        Whatever ``soundfile.write`` raises. The temporary file is removed first.

    Note:
        - The file is written with ``format='WAV'`` and ``subtype='PCM_16'``.
        - On success the file is not deleted automatically; the caller is
          responsible for cleanup.
        - NOTE(review): samples are expected in [-1.0, 1.0]. Larger float values
          cannot be represented in 16-bit PCM and may clip or distort.
    """
    import os
    from contextlib import suppress

    audio_array = resample_audio(target_sample_rate, audio)
    temp_file = NamedTemporaryFile(delete=False, suffix='.wav')
    temp_path = temp_file.name
    # Close our handle so soundfile can reopen the path by name for writing.
    temp_file.close()
    try:
        sf.write(temp_path, audio_array, target_sample_rate, format='WAV', subtype='PCM_16')
    except BaseException:
        # Don't leave an empty or partial file behind when writing fails.
        with suppress(OSError):
            os.remove(temp_path)
        raise
    return temp_path