import gc
from functools import partial
import gradio as gr
import torch
from PIL.Image import Image
from transformers import AutoProcessor, BlipForConditionalGeneration
from utils import get_pytorch_device, spaces_gpu, request_image, get_torch_dtype


@spaces_gpu
def image_to_text(model: str, image: Image) -> list[str]:
    """Generate text captions for an image using BLIP model.

    This function uses a BLIP (Bootstrapping Language-Image Pre-training) model
    to generate multiple caption candidates for the input image. The model is
    loaded, inference is performed, and then cleaned up to free GPU memory.

    Args:
        model: Hugging Face model ID to use for image captioning.
        image: PIL Image object to generate captions for.

    Returns:
        List of string captions describing the image.

    Note:
        - Uses safetensors for secure model loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Cleans up model and GPU memory after inference.
        - Uses beam search with 3 beams, max length 20, min length 5.
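
    Example (illustrative; the model ID and URL are assumptions, not defaults of this module):
        >>> img = request_image("https://example.com/cat.jpg")
        >>> captions = image_to_text("Salesforce/blip-image-captioning-base", img)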
    """
    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()
    
    processor = AutoProcessor.from_pretrained(model)
    model_instance = BlipForConditionalGeneration.from_pretrained(
        model,
        use_safetensors=True,  # Use safetensors to avoid the torch.load restriction.
        dtype=dtype
    ).to(pytorch_device)
    inputs = processor(images=image, return_tensors="pt").to(pytorch_device)

    # Gradient calculations are unnecessary during inference; torch.no_grad() avoids
    # storing gradients, which significantly reduces memory use while generating.
    with torch.no_grad():
        generated_ids = model_instance.generate(
            pixel_values=inputs.pixel_values,
            num_beams=3,
            max_length=20,
            min_length=5
        )
    results = processor.batch_decode(generated_ids, skip_special_tokens=True)
    
    # Clean up GPU memory
    del model_instance, inputs, generated_ids
    if pytorch_device == "cuda":
        torch.cuda.empty_cache()
    gc.collect()
    return results


def create_image_to_text_tab(model: str):
    """Create the image-to-text captioning tab in the Gradio interface.
    
    This function sets up all UI components for image captioning, including:
    - URL input textbox for fetching images from the web
    - Button to retrieve image from URL
    - Image preview component
    - Caption button and output list
    
    Args:
        model: Hugging Face model ID to use for image captioning.
    """
    gr.Markdown("Generate a text description of an image.")
    image_to_text_url_input = gr.Textbox(label="Image URL")
    image_to_text_image_request_button = gr.Button("Get Image")
    image_to_text_image_input = gr.Image(label="Image", type="pil")
    image_to_text_image_request_button.click(
        fn=request_image,
        inputs=image_to_text_url_input,
        outputs=image_to_text_image_input
    )
    image_to_text_button = gr.Button("Caption")
    image_to_text_output = gr.List(label="Captions", headers=["Caption"])
    image_to_text_button.click(
        fn=partial(image_to_text, model),
        inputs=image_to_text_image_input,
        outputs=image_to_text_output
    )
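

# A minimal usage sketch (not part of the original module): wire the tab builder
# into a gr.Blocks() layout and launch the app. The model ID below is an assumed
# default for illustration only; substitute the BLIP checkpoint you actually use.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Tab("Image to Text"):
            create_image_to_text_tab("Salesforce/blip-image-captioning-base")
    demo.launch()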