Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,587 Bytes
5c395b2 39d9406 5c395b2 d56b9d9 dc382c8 5c395b2 dc382c8 5c395b2 5bebd85 5c395b2 5bebd85 1c1b97a 5bebd85 5c395b2 5bebd85 5c395b2 39d9406 5c395b2 5bebd85 55d79e2 5bebd85 39d9406 5c395b2 39d9406 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import gc
from functools import partial
import gradio as gr
import torch
from PIL.Image import Image
import pandas as pd
from pandas import DataFrame
from transformers import pipeline
from utils import get_pytorch_device, spaces_gpu, request_image, get_torch_dtype
@spaces_gpu
def image_classification(model: str, image: Image) -> DataFrame:
    """Classify an image with a Hugging Face image-classification pipeline.

    The pipeline is built for every call, run once on ``image`` and then
    released again, so no model stays resident between requests.

    Args:
        model: Hugging Face model ID to use for image classification.
        image: PIL Image object to classify. Gradio passes ``None`` when the
            image component is empty; that case is rejected up front.

    Returns:
        Pandas DataFrame with two columns, one row per prediction returned
        by the pipeline:
        - Label: The classification label (e.g., "cardboard", "glass")
        - Probability: The confidence score as a percentage string
          (e.g., "95.23%")

    Raises:
        gr.Error: If ``image`` is ``None``.

    Note:
        - ``use_safetensors=True`` is forwarded to the model loader.
        - Only a device reported as ``"cuda"`` by ``get_pytorch_device()``
          is mapped to GPU index 0; any other value runs the pipeline on
          CPU (``device=-1``).
        - The dtype from ``get_torch_dtype()`` is forwarded only when it is
          not ``None``.
        - Cleanup runs even when loading or inference fails.
    """
    # Fail fast with a user-visible message instead of loading the model and
    # then crashing inside the pipeline on a missing input.
    if image is None:
        raise gr.Error("Please provide an image to classify.")

    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    model_kwargs = {"use_safetensors": True}
    if dtype is not None:
        model_kwargs["dtype"] = dtype

    classifier = None
    try:
        classifier = pipeline(
            "image-classification",
            model=model,
            device=0 if pytorch_device == "cuda" else -1,
            model_kwargs=model_kwargs,
        )
        # Gradients are not needed for inference; disabling them avoids
        # storing autograd state and reduces memory consumption.
        with torch.no_grad():
            results = classifier(image)
    finally:
        # Best-effort cleanup. Drop our reference first, then collect garbage
        # so tensors kept alive by reference cycles are actually freed, and
        # only then ask CUDA to release its now-unused cached blocks.
        del classifier
        gc.collect()
        if pytorch_device == "cuda":
            torch.cuda.empty_cache()

    return pd.DataFrame({
        "Label": [result["label"] for result in results],
        "Probability": [f"{result['score']:.2%}" for result in results],
    })
def create_image_classification_tab(model: str):
    """Build the widgets and event wiring for the image classification tab.

    Components are instantiated in the order they are meant to appear:
    - an introductory Markdown blurb
    - a textbox for an image URL and a "Get Image" button
    - an image component (PIL type) used as preview and as classifier input
    - a "Classify" button and a read-only dataframe of labels/probabilities

    The "Get Image" button feeds the URL textbox to ``request_image`` and
    writes the result into the image component. The "Classify" button runs
    ``image_classification`` with ``model`` pre-bound on the current image and
    writes the result into the dataframe.

    Args:
        model: Hugging Face model ID to use for image classification.
    """
    gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")

    # --- Components (creation order is kept identical to the layout order) ---
    url_box = gr.Textbox(label="Image URL")
    fetch_button = gr.Button("Get Image")
    image_box = gr.Image(label="Image", type="pil")
    classify_button = gr.Button("Classify")
    results_table = gr.Dataframe(
        label="Classification",
        headers=["Label", "Probability"],
        interactive=False,
    )

    # --- Event wiring (registered in the same order as before) ---
    fetch_button.click(fn=request_image, inputs=url_box, outputs=image_box)

    # Bind the model ID now so the click handler only receives the image.
    classify_with_model = partial(image_classification, model)
    classify_button.click(
        fn=classify_with_model,
        inputs=image_box,
        outputs=results_table,
    )
|