ai-building-blocks / image_to_text.py
LiKenun's picture
Add documentation
5bebd85
raw
history blame
2.8 kB
import gc
from os import getenv
import gradio as gr
from PIL.Image import Image
from transformers import AutoProcessor, BlipForConditionalGeneration
from utils import get_pytorch_device, spaces_gpu, request_image
@spaces_gpu
def image_to_text(image: Image) -> list[str]:
    """Generate a text caption for an image using a BLIP model.

    This function uses a BLIP (Bootstrapping Language-Image Pre-training) model
    to caption the input image. The processor and model are loaded on every
    call, inference is performed, and all references to the model and device
    tensors are then released so their memory can be reclaimed between calls.

    Args:
        image: PIL Image object to generate a caption for. ``None`` (e.g. when
            the Caption button is clicked before any image is loaded) is
            rejected with a user-facing ``gr.Error``.

    Returns:
        List of decoded caption strings. Beam search explores 3 beams, but
        ``num_return_sequences`` is not passed, so only the best sequence is
        returned (one caption), unless the model's own generation config
        overrides that default.

    Raises:
        gr.Error: If no image was provided.
        RuntimeError: If the IMAGE_TO_TEXT_MODEL environment variable is unset
            or empty.

    Note:
        - The model ID is read from the IMAGE_TO_TEXT_MODEL environment variable.
        - Weights are loaded with ``use_safetensors=True``.
        - Inference runs on the device returned by ``get_pytorch_device()``.
        - Generation uses beam search with 3 beams, max length 20 tokens and
          min length 5 tokens.
        - Cleanup runs in a ``finally`` block (also when generation fails):
          model/tensor references are dropped, ``gc.collect()`` is run, and the
          CUDA caching allocator is emptied when CUDA is available.
    """
    import torch  # Needed only to release cached CUDA memory; the BLIP model requires torch anyway.

    if image is None:
        raise gr.Error("No image provided. Load an image before captioning.")

    image_to_text_model_id = getenv("IMAGE_TO_TEXT_MODEL")
    if not image_to_text_model_id:
        # Fail fast with a clear message instead of an opaque error from from_pretrained(None).
        raise RuntimeError("The IMAGE_TO_TEXT_MODEL environment variable is not set.")

    pytorch_device = get_pytorch_device()
    processor = AutoProcessor.from_pretrained(image_to_text_model_id)
    model = BlipForConditionalGeneration.from_pretrained(
        image_to_text_model_id,
        use_safetensors=True,  # Use safetensors to avoid torch.load restriction.
    ).to(pytorch_device)

    # Pre-bind so the `del` in `finally` is valid even if an early step raises.
    inputs = generated_ids = None
    try:
        inputs = processor(images=image, return_tensors="pt").to(pytorch_device)
        generated_ids = model.generate(
            pixel_values=inputs.pixel_values,
            num_beams=3,
            max_length=20,
            min_length=5,
        )
        return processor.batch_decode(generated_ids, skip_special_tokens=True)
    finally:
        # Drop every reference to the model and device tensors before collecting.
        # On the error path a live traceback would otherwise keep this frame's
        # locals (and thus the weights) alive.
        del model, inputs, generated_ids
        gc.collect()
        if torch.cuda.is_available():
            # gc.collect() only returns blocks to PyTorch's caching allocator;
            # empty_cache() hands the unused cached blocks back to the driver.
            torch.cuda.empty_cache()
def create_image_to_text_tab():
    """Build the image-captioning tab's widgets and wire up its two actions.

    Widgets are created (and therefore laid out) in this order:

    1. A one-line description of the tab.
    2. An "Image URL" textbox plus a "Get Image" button. Clicking the button
       passes the URL to ``request_image`` and shows the result in a PIL-typed
       "Image" preview.
    3. A "Caption" button. Clicking it sends the preview image to
       ``image_to_text`` and shows the returned strings in a single-column
       "Captions" list.

    NOTE(review): presumably called inside an enclosing ``gr.Blocks`` /
    ``gr.Tab`` context, which Gradio event listeners require. Confirm at the
    call site.
    """
    gr.Markdown("Generate a text description of an image.")

    # Fetch an image from a URL into the preview component.
    url_box = gr.Textbox(label="Image URL")
    fetch_button = gr.Button("Get Image")
    preview = gr.Image(label="Image", type="pil")
    fetch_button.click(fn=request_image, inputs=url_box, outputs=preview)

    # Caption whatever image is currently shown in the preview.
    caption_button = gr.Button("Caption")
    caption_list = gr.List(label="Captions", headers=["Caption"])
    caption_button.click(fn=image_to_text, inputs=preview, outputs=caption_list)