import gc
from functools import partial

import gradio as gr
import torch
from PIL.Image import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

from utils import get_pytorch_device, get_torch_dtype, request_image, spaces_gpu
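
# `utils` is project-local and not shown here. For orientation, a minimal
# sketch of what the device/dtype helpers are assumed to look like (an
# assumption, not the actual implementation):
#
#     def get_pytorch_device() -> str:
#         """Pick the best available accelerator (CUDA/XPU/MPS/CPU)."""
#         if torch.cuda.is_available():
#             return "cuda"
#         if hasattr(torch, "xpu") and torch.xpu.is_available():
#             return "xpu"
#         if torch.backends.mps.is_available():
#             return "mps"
#         return "cpu"
#
#     def get_torch_dtype() -> torch.dtype:
#         """Half precision on accelerators, float32 on CPU."""
#         return torch.float32 if get_pytorch_device() == "cpu" else torch.float16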


# On a ZeroGPU ("Running on Zero") Space, the GPU-bound function must be
# decorated so a GPU is allocated for the duration of the call.
@spaces_gpu
def image_to_text(model: str, image: Image) -> list[str]:
    """Generate text captions for an image using a BLIP model.

    This function uses a BLIP (Bootstrapping Language-Image Pre-training) model
    to generate multiple caption candidates for the input image. The model is
    loaded, inference is performed, and the model is then cleaned up to free
    GPU memory.

    Args:
        model: Hugging Face model ID to use for image captioning.
        image: PIL Image object to generate captions for.

    Returns:
        List of string captions describing the image.

    Note:
        - Uses safetensors for secure model loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Cleans up model and GPU memory after inference.
        - Uses beam search with 3 beams, max length 20, min length 5.
    """
    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    processor = AutoProcessor.from_pretrained(model)
    model_instance = BlipForConditionalGeneration.from_pretrained(
        model,
        use_safetensors=True,  # Use safetensors to avoid the torch.load restriction.
        dtype=dtype,
    ).to(pytorch_device)
    # Cast pixel values to the model's dtype so half-precision weights do not
    # receive float32 inputs.
    inputs = processor(images=image, return_tensors="pt").to(pytorch_device, dtype=dtype)

    # During inference, gradient calculations are unnecessary. torch.no_grad()
    # avoids storing gradients, which significantly reduces memory consumption
    # during the inference phase.
    with torch.no_grad():
        generated_ids = model_instance.generate(
            pixel_values=inputs.pixel_values,
            num_beams=3,
            max_length=20,
            min_length=5,
        )
    results = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # Clean up GPU memory: drop Python references first, then release cached
    # blocks back to the driver.
    del model_instance, inputs, generated_ids
    gc.collect()
    if pytorch_device == "cuda":
        torch.cuda.empty_cache()

    return results
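
# Example direct call (a sketch; the model ID and URL are illustrative, and the
# checkpoint is assumed to ship safetensors weights):
#
#     image = request_image("https://example.com/photo.jpg")
#     captions = image_to_text("Salesforce/blip-image-captioning-base", image)
#     print(captions)  # e.g. ["a photograph of ..."]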


def create_image_to_text_tab(model: str):
    """Create the image-to-text captioning tab in the Gradio interface.

    This function sets up all UI components for image captioning, including:
    - URL input textbox for fetching images from the web
    - Button to retrieve the image from the URL
    - Image preview component
    - Caption button and output list

    Args:
        model: Hugging Face model ID to use for image captioning.
    """
    gr.Markdown("Generate a text description of an image.")
    image_to_text_url_input = gr.Textbox(label="Image URL")
    image_to_text_image_request_button = gr.Button("Get Image")
    image_to_text_image_input = gr.Image(label="Image", type="pil")
    image_to_text_image_request_button.click(
        fn=request_image,
        inputs=image_to_text_url_input,
        outputs=image_to_text_image_input,
    )

    image_to_text_button = gr.Button("Caption")
    image_to_text_output = gr.List(label="Captions", headers=["Caption"])
    image_to_text_button.click(
        # partial() binds the model ID so the click handler only receives the image.
        fn=partial(image_to_text, model),
        inputs=image_to_text_image_input,
        outputs=image_to_text_output,
    )
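

# Minimal wiring sketch (an assumption about how the tab is mounted elsewhere
# in the app; the model ID below is illustrative):
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Tab("Image to Text"):
            create_image_to_text_tab("Salesforce/blip-image-captioning-base")
    demo.launch()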