File size: 3,587 Bytes
5c395b2
39d9406
 
5c395b2
d56b9d9
dc382c8
 
5c395b2
 
dc382c8
 
5c395b2
 
 
5bebd85
 
5c395b2
 
5bebd85
 
1c1b97a
5bebd85
 
 
 
 
 
 
 
5c395b2
 
 
5bebd85
5c395b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39d9406
 
5c395b2
5bebd85
 
 
 
 
 
 
 
 
55d79e2
5bebd85
39d9406
 
 
 
 
 
 
 
 
 
 
 
5c395b2
39d9406
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gc
from functools import partial
import gradio as gr
import torch
from PIL.Image import Image
import pandas as pd
from pandas import DataFrame
from transformers import pipeline
from utils import get_pytorch_device, spaces_gpu, request_image, get_torch_dtype


@spaces_gpu
def image_classification(model: str, image: Image) -> DataFrame:
    """Classify an image with a Hugging Face image-classification pipeline.

    A fresh pipeline is built for every call, run once on ``image``, and then
    released so that accelerator memory is not held between requests.

    Args:
        model: Hugging Face model ID to use for image classification.
        image: PIL Image object to classify. ``None`` (e.g. the Gradio image
            component is still empty) yields an empty result table instead of
            an error.

    Returns:
        Pandas DataFrame with two columns:
            - Label: The classification label reported by the model.
            - Probability: The score formatted as a percentage string
              (e.g., "95.23%").

    Note:
        - Requests safetensors weights via ``use_safetensors=True``.
        - Runs on CUDA device 0 when ``get_pytorch_device()`` reports
          ``"cuda"``; every other device value falls back to ``-1`` (CPU).
        - The pipeline is released even when inference raises.
    """
    # Nothing to classify: return an empty table with the expected columns
    # rather than letting the pipeline fail on a missing image.
    if image is None:
        return pd.DataFrame({"Label": [], "Probability": []})

    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    # Only forward a dtype when the helper actually picked one, so the
    # model's own default applies otherwise.
    model_kwargs = {"use_safetensors": True}
    if dtype is not None:
        model_kwargs["dtype"] = dtype

    classifier = pipeline(
        "image-classification",
        model=model,
        device=0 if pytorch_device == "cuda" else -1,
        model_kwargs=model_kwargs
    )
    try:
        # Gradients are unnecessary for inference; torch.no_grad() avoids
        # storing them and so reduces memory use during the forward pass.
        with torch.no_grad():
            results = classifier(image)
    finally:
        # Drop the pipeline first, then collect garbage so reference cycles
        # let go of their tensors *before* the CUDA cache is emptied;
        # emptying the cache first would leave those blocks allocated.
        del classifier
        gc.collect()
        if pytorch_device == "cuda":
            torch.cuda.empty_cache()

    return pd.DataFrame({
        "Label": [result["label"] for result in results],
        "Probability": [f"{result['score']:.2%}" for result in results]
    })


def create_image_classification_tab(model: str):
    """Build the Gradio widgets and event wiring for the classification tab.

    Components are created in display order:
    - an introductory Markdown blurb
    - a textbox for an image URL plus a "Get Image" button that loads it
    - an image component holding the picture to classify
    - a "Classify" button and a read-only dataframe for the results

    Args:
        model: Hugging Face model ID to use for image classification.
    """
    gr.Markdown(
        "Classify a recyclable item as one of: cardboard, glass, metal, paper, "
        "plastic, or other using "
        "[Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net)."
    )

    # Image acquisition: URL -> request_image -> image component.
    url_box = gr.Textbox(label="Image URL")
    fetch_button = gr.Button("Get Image")
    image_box = gr.Image(label="Image", type="pil")
    fetch_button.click(fn=request_image, inputs=url_box, outputs=image_box)

    # Classification: image component -> image_classification -> results table.
    classify_button = gr.Button("Classify")
    results_table = gr.Dataframe(
        label="Classification",
        headers=["Label", "Probability"],
        interactive=False,
    )
    # Bind the model ID up front so the handler only receives the image.
    classify_with_model = partial(image_classification, model)
    classify_button.click(
        fn=classify_with_model,
        inputs=image_box,
        outputs=results_table,
    )