import gc
from functools import partial

import gradio as gr
import torch
from PIL.Image import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

from utils import get_pytorch_device, get_torch_dtype, request_image, spaces_gpu


@spaces_gpu
def image_to_text(model: str, image: Image) -> list[str]:
    """Generate text captions for an image using a BLIP model.

    This function uses a BLIP (Bootstrapping Language-Image Pre-training) model
    to generate multiple caption candidates for the input image. The model is
    loaded, inference is performed, and then everything is cleaned up to free
    GPU memory.

    Args:
        model: Hugging Face model ID to use for image captioning.
        image: PIL Image object to generate captions for.

    Returns:
        List of string captions describing the image.

    Note:
        - Uses safetensors for secure model loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Cleans up model and GPU memory after inference.
        - Uses beam search with 3 beams, max length 20, min length 5.
    """
    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    processor = AutoProcessor.from_pretrained(model)
    model_instance = BlipForConditionalGeneration.from_pretrained(
        model,
        use_safetensors=True,  # Use safetensors to avoid the torch.load restriction.
        dtype=dtype,
    ).to(pytorch_device)

    inputs = processor(images=image, return_tensors="pt").to(pytorch_device)

    # During inference or evaluation, gradient calculations are unnecessary.
    # Using torch.no_grad() avoids storing gradients, which can significantly
    # reduce memory consumption during the inference phase.
    with torch.no_grad():
        generated_ids = model_instance.generate(
            pixel_values=inputs.pixel_values,
            num_beams=3,
            max_length=20,
            min_length=5,
        )

    results = processor.batch_decode(generated_ids, skip_special_tokens=True)

    # Clean up GPU memory.
    del model_instance, inputs, generated_ids
    if pytorch_device == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

    return results


def create_image_to_text_tab(model: str):
    """Create the image-to-text captioning tab in the Gradio interface.

    This function sets up all UI components for image captioning, including:
    - URL input textbox for fetching images from the web
    - Button to retrieve the image from the URL
    - Image preview component
    - Caption button and output list

    Args:
        model: Hugging Face model ID to use for image captioning.
    """
    gr.Markdown("Generate a text description of an image.")
    image_to_text_url_input = gr.Textbox(label="Image URL")
    image_to_text_image_request_button = gr.Button("Get Image")
    image_to_text_image_input = gr.Image(label="Image", type="pil")
    image_to_text_image_request_button.click(
        fn=request_image,
        inputs=image_to_text_url_input,
        outputs=image_to_text_image_input,
    )

    image_to_text_button = gr.Button("Caption")
    image_to_text_output = gr.List(label="Captions", headers=["Caption"])
    image_to_text_button.click(
        fn=partial(image_to_text, model),
        inputs=image_to_text_image_input,
        outputs=image_to_text_output,
    )
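

# Minimal usage sketch: this module only defines the tab; it is presumably
# composed into a larger Gradio app elsewhere in the repo. The gr.Blocks/gr.Tab
# wiring and the model ID below are illustrative assumptions, not part of this
# file's contract ("Salesforce/blip-image-captioning-base" is one compatible
# BLIP checkpoint).
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Tab("Image to Text"):
            create_image_to_text_tab("Salesforce/blip-image-captioning-base")
    demo.launch()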