import gc
from os import getenv

import gradio as gr
import torch
from PIL.Image import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

from utils import get_pytorch_device, spaces_gpu, request_image


def _release_device_memory(device) -> None:
    """Run the garbage collector and empty the accelerator's caching allocator.

    Args:
        device: The value returned by ``get_pytorch_device()``; presumably a
            device string such as ``"cuda"`` or a ``torch.device`` — anything
            ``torch.device(...)`` accepts.

    Never raises: this runs from a ``finally`` clause, and an exception here
    would replace the caller's original exception.
    """
    gc.collect()
    try:
        device_type = torch.device(device).type
    except (TypeError, RuntimeError):
        # Unrecognised device value; garbage collection alone is the best we can do.
        return
    if device_type == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif device_type == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    elif device_type == "mps" and torch.backends.mps.is_available():
        torch.mps.empty_cache()


@spaces_gpu
def image_to_text(image: Image) -> list[str]:
    """Generate text captions for an image using BLIP model.

    This function uses a BLIP (Bootstrapping Language-Image Pre-training) model
    to generate multiple caption candidates for the input image. The model is
    loaded, inference is performed, and then cleaned up to free GPU memory.

    Args:
        image: PIL Image object to generate captions for.

    Returns:
        List of string captions describing the image.

    Raises:
        gr.Error: If no image was provided or the IMAGE_TO_TEXT_MODEL
            environment variable is not set.

    Note:
        - The model ID is determined by the IMAGE_TO_TEXT_MODEL environment variable.
        - Uses safetensors for secure model loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Releases the model and device tensors, runs the garbage collector and
          empties the device cache after inference, including when inference fails.
        - Uses beam search with 3 beams, max length 20, min length 5.
    """
    if image is None:
        raise gr.Error("Please provide an image to caption.")
    image_to_text_model_id = getenv("IMAGE_TO_TEXT_MODEL")
    if not image_to_text_model_id:
        # Fail early with a readable message instead of from_pretrained(None).
        raise gr.Error("The IMAGE_TO_TEXT_MODEL environment variable is not set.")
    pytorch_device = get_pytorch_device()
    model = inputs = generated_ids = None
    try:
        processor = AutoProcessor.from_pretrained(image_to_text_model_id)
        model = BlipForConditionalGeneration.from_pretrained(
            image_to_text_model_id,
            use_safetensors=True,  # Use safetensors to avoid torch.load restriction.
        ).to(pytorch_device)
        inputs = processor(images=image, return_tensors="pt").to(pytorch_device)
        generated_ids = model.generate(
            pixel_values=inputs.pixel_values, num_beams=3, max_length=20, min_length=5
        )
        # batch_decode yields plain Python strings, so nothing returned stays on the device.
        return processor.batch_decode(generated_ids, skip_special_tokens=True)
    finally:
        # Drop every reference to device-resident objects before collecting.
        # This also runs on failure, when the traceback would otherwise keep
        # this frame (and therefore the model) alive.
        model = inputs = generated_ids = None
        _release_device_memory(pytorch_device)


def create_image_to_text_tab():
    """Create the image-to-text captioning tab in the Gradio interface.

    This function sets up all UI components for image captioning, including:
    - URL input textbox for fetching images from the web
    - Button to retrieve image from URL
    - Image preview component
    - Caption button and output list
    """
    gr.Markdown("Generate a text description of an image.")
    image_to_text_url_input = gr.Textbox(label="Image URL")
    image_to_text_image_request_button = gr.Button("Get Image")
    image_to_text_image_input = gr.Image(label="Image", type="pil")
    image_to_text_image_request_button.click(
        fn=request_image,
        inputs=image_to_text_url_input,
        outputs=image_to_text_image_input,
    )
    image_to_text_button = gr.Button("Caption")
    image_to_text_output = gr.List(label="Captions", headers=["Caption"])
    image_to_text_button.click(
        fn=image_to_text,
        inputs=image_to_text_image_input,
        outputs=image_to_text_output,
    )