import gc
from functools import partial

import gradio as gr
import torch
from PIL.Image import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

from utils import get_pytorch_device, spaces_gpu, request_image, get_torch_dtype
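
# get_pytorch_device(), get_torch_dtype(), spaces_gpu and request_image are provided by
# the Space's utils module (not shown in this file). A minimal sketch of what the device
# helper might look like -- the name matches the import above, but the fallback order is
# an assumption, not the actual utils implementation:
#
#     def get_pytorch_device() -> str:
#         if torch.cuda.is_available():
#             return "cuda"
#         if hasattr(torch, "xpu") and torch.xpu.is_available():
#             return "xpu"
#         if torch.backends.mps.is_available():
#             return "mps"
#         return "cpu"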


@spaces_gpu
def image_to_text(model: str, image: Image) -> list[str]:
    """Generate text captions for an image using a BLIP model.

    This function uses a BLIP (Bootstrapping Language-Image Pre-training) model
    to generate multiple caption candidates for the input image. The model is
    loaded, inference is performed, and the model is then cleaned up to free
    GPU memory.

    Args:
        model: Hugging Face model ID to use for image captioning.
        image: PIL Image object to generate captions for.

    Returns:
        List of string captions describing the image.

    Note:
        - Uses safetensors for secure model loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Cleans up the model and GPU memory after inference.
        - Uses beam search with 3 beams, max length 20, min length 5.
    """
    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    # During inference or evaluation, gradient calculations are unnecessary. Using
    # torch.no_grad() reduces memory consumption by not storing gradients, which can
    # significantly reduce the amount of memory used during the inference phase.
    processor = AutoProcessor.from_pretrained(model)
    model_instance = BlipForConditionalGeneration.from_pretrained(
        model,
        use_safetensors=True,  # Use safetensors to avoid the torch.load restriction.
        dtype=dtype
    ).to(pytorch_device)
    # Cast the inputs to the model's device and dtype so pixel_values match the weights.
    inputs = processor(images=image, return_tensors="pt").to(pytorch_device, dtype)
    with torch.no_grad():
        generated_ids = model_instance.generate(
            pixel_values=inputs.pixel_values, num_beams=3, max_length=20, min_length=5
        )
    results = processor.batch_decode(generated_ids, skip_special_tokens=True)

    # Clean up the model and GPU memory.
    del model_instance, inputs, generated_ids
    if pytorch_device == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

    return results
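
# Example call (the model ID below is an illustrative BLIP captioning checkpoint, an
# assumption rather than necessarily the one this Space is configured with):
#
#     captions = image_to_text("Salesforce/blip-image-captioning-base", pil_image)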


def create_image_to_text_tab(model: str):
    """Create the image-to-text captioning tab in the Gradio interface.

    This function sets up all UI components for image captioning, including:
    - URL input textbox for fetching images from the web
    - Button to retrieve image from URL
    - Image preview component
    - Caption button and output list

    Args:
        model: Hugging Face model ID to use for image captioning.
    """
    gr.Markdown("Generate a text description of an image.")
    image_to_text_url_input = gr.Textbox(label="Image URL")
    image_to_text_image_request_button = gr.Button("Get Image")
    image_to_text_image_input = gr.Image(label="Image", type="pil")
    image_to_text_image_request_button.click(
        fn=request_image,
        inputs=image_to_text_url_input,
        outputs=image_to_text_image_input
    )
    image_to_text_button = gr.Button("Caption")
    image_to_text_output = gr.List(label="Captions", headers=["Caption"])
    image_to_text_button.click(
        fn=partial(image_to_text, model),
        inputs=image_to_text_image_input,
        outputs=image_to_text_output
    )
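
# A minimal sketch of how this tab could be mounted in a Gradio Blocks app -- the model
# ID and surrounding layout are assumptions for illustration, not necessarily this
# Space's actual entry point:
#
#     with gr.Blocks() as demo:
#         with gr.Tab("Image to Text"):
#             create_image_to_text_tab("Salesforce/blip-image-captioning-base")
#     demo.launch()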