# ai-building-blocks / image_to_text.py

import gc
from functools import partial

import gradio as gr
import torch
from PIL.Image import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

from utils import get_pytorch_device, spaces_gpu, request_image, get_torch_dtype


@spaces_gpu
def image_to_text(model: str, image: Image) -> list[str]:
    """Generate text captions for an image using a BLIP model.

    This function uses a BLIP (Bootstrapping Language-Image Pre-training) model
    to generate multiple caption candidates for the input image. The model is
    loaded, inference is performed, and the model is then cleaned up to free
    GPU memory.

    Args:
        model: Hugging Face model ID to use for image captioning.
        image: PIL Image object to generate captions for.

    Returns:
        List of string captions describing the image.

    Note:
        - Uses safetensors for secure model loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Cleans up the model and GPU memory after inference.
        - Uses beam search with 3 beams, max length 20, min length 5.
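
    Example:
        A hypothetical call (the model ID and URL are illustrative, and
        request_image is assumed to return a PIL image)::

            captions = image_to_text(
                "Salesforce/blip-image-captioning-base",
                request_image("https://example.com/photo.jpg"),
            )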
"""
    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    # Load the processor and model for this request; both are released after inference
    # so that GPU memory can be reclaimed.
    processor = AutoProcessor.from_pretrained(model)
    model_instance = BlipForConditionalGeneration.from_pretrained(
        model,
        use_safetensors=True,  # Use safetensors to avoid the torch.load restriction.
        dtype=dtype,
    ).to(pytorch_device)
    inputs = processor(images=image, return_tensors="pt").to(pytorch_device)

    # During inference, gradient calculations are unnecessary. Using torch.no_grad() avoids
    # storing gradients, which significantly reduces memory consumption during this phase.
    with torch.no_grad():
        generated_ids = model_instance.generate(
            pixel_values=inputs.pixel_values,
            num_beams=3,
            max_length=20,
            min_length=5,
        )
    results = processor.batch_decode(generated_ids, skip_special_tokens=True)

    # Clean up GPU memory
    del model_instance, inputs, generated_ids
    if pytorch_device == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

    return results


def create_image_to_text_tab(model: str):
    """Create the image-to-text captioning tab in the Gradio interface.

    This function sets up all UI components for image captioning, including:

    - URL input textbox for fetching images from the web
    - Button to retrieve the image from the URL
    - Image preview component
    - Caption button and output list

    Args:
        model: Hugging Face model ID to use for image captioning.
    """
    gr.Markdown("Generate a text description of an image.")

    image_to_text_url_input = gr.Textbox(label="Image URL")
    image_to_text_image_request_button = gr.Button("Get Image")
    image_to_text_image_input = gr.Image(label="Image", type="pil")
    image_to_text_image_request_button.click(
        fn=request_image,
        inputs=image_to_text_url_input,
        outputs=image_to_text_image_input,
    )

    image_to_text_button = gr.Button("Caption")
    image_to_text_output = gr.List(label="Captions", headers=["Caption"])
    image_to_text_button.click(
        fn=partial(image_to_text, model),
        inputs=image_to_text_image_input,
        outputs=image_to_text_output,
    )
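

# A minimal wiring sketch, assuming this module is run directly rather than imported by the
# app entry point. The model ID "Salesforce/blip-image-captioning-base" is illustrative; the
# real application may pass a different BLIP checkpoint.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Tab("Image to Text"):
            create_image_to_text_tab("Salesforce/blip-image-captioning-base")
    demo.launch()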