# ai-building-blocks / image_classification.py
# Last change by LiKenun in commit 5c395b2: "Switch to use GPU instead of inference client"
import gc
from functools import partial
import gradio as gr
import torch
from PIL.Image import Image
import pandas as pd
from pandas import DataFrame
from transformers import pipeline
from utils import get_pytorch_device, spaces_gpu, request_image, get_torch_dtype
def _empty_device_cache(device) -> None:
    """Return cached, currently-unused accelerator memory for ``device``.

    Handles "cuda", "xpu" and "mps" (optionally with an index such as
    "cuda:0"); any other device type (e.g. "cpu") is a no-op. Backends whose
    ``empty_cache`` is unavailable in the installed torch build are skipped.
    """
    device_type = str(device).split(":", 1)[0]
    if device_type not in ("cuda", "xpu", "mps"):
        return
    backend = getattr(torch, device_type, None)
    empty_cache = getattr(backend, "empty_cache", None)
    if empty_cache is not None:
        empty_cache()


@spaces_gpu
def image_classification(model: str, image: Image) -> DataFrame:
    """Classify an image with a Hugging Face image-classification pipeline.

    The pipeline for ``model`` is loaded on every call and run once on
    ``image``. It is then released so that accelerator memory is freed
    between calls. With the Trash-Net model used by this app, the labels are
    recyclable-item categories (cardboard, glass, metal, paper, plastic, or
    other). In general, the labels are whatever the given model predicts.

    Args:
        model: Hugging Face model ID to use for image classification.
        image: PIL Image object to classify. ``None`` (e.g. "Classify" was
            clicked before an image was provided) yields an empty table
            without loading the model.

    Returns:
        Pandas DataFrame with one row per prediction and two columns:
        - Label: The classification label (e.g., "cardboard", "glass")
        - Probability: The confidence score as a percentage string
          (e.g., "95.23%")

    Note:
        - Loads weights with ``use_safetensors=True``.
        - Runs on the device reported by ``get_pytorch_device()`` (the value
          is passed straight to ``pipeline``), using ``get_torch_dtype()``
          when it returns a dtype.
        - Releases the pipeline and cached accelerator memory even if
          inference raises.
    """
    if image is None:
        # Nothing to classify: skip the expensive model load entirely.
        return pd.DataFrame({"Label": [], "Probability": []})

    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    model_kwargs = {"use_safetensors": True}
    if dtype is not None:
        model_kwargs["dtype"] = dtype

    # Pass the device through unchanged. Mapping it to "CUDA index 0 or CPU"
    # would silently run XPU/MPS-capable machines on the CPU.
    classifier = pipeline(
        "image-classification",
        model=model,
        device=pytorch_device,
        model_kwargs=model_kwargs,
    )
    try:
        # Gradients are unnecessary for inference. torch.no_grad() avoids
        # storing them, which reduces memory use during the forward pass.
        with torch.no_grad():
            results = classifier(image)
    finally:
        # Drop the model even if inference failed. Collect garbage first so
        # tensors held only by unreachable objects are actually freed, then
        # hand the cached device memory back.
        del classifier
        gc.collect()
        _empty_device_cache(pytorch_device)

    return pd.DataFrame({
        "Label": [result["label"] for result in results],
        "Probability": [f"{result['score']:.2%}" for result in results],
    })
def create_image_classification_tab(model: str):
    """Build the widgets and event handlers for the image-classification tab.

    Widgets, in creation order:
    - a description linking to the Trash-Net model card
    - an "Image URL" textbox and a "Get Image" button, which calls
      ``request_image`` with the URL and shows the result in the preview
    - an "Image" preview holding a PIL image
    - a "Classify" button and a read-only "Classification" table with
      "Label" and "Probability" columns, filled by ``image_classification``

    This function does not open its own ``gr.Blocks``. Presumably the app
    calls it inside an existing Blocks/Tab context (confirm at the call site).

    Args:
        model: Hugging Face model ID to use for image classification.
    """
    description = (
        "Classify a recyclable item as one of: cardboard, glass, metal, "
        "paper, plastic, or other using "
        "[Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net)."
    )
    gr.Markdown(description)

    # Create every component first, in the original order. Event handlers
    # are registered afterwards, also in the original order.
    url_box = gr.Textbox(label="Image URL")
    fetch_button = gr.Button("Get Image")
    preview = gr.Image(label="Image", type="pil")
    classify_button = gr.Button("Classify")
    results_table = gr.Dataframe(
        label="Classification",
        headers=["Label", "Probability"],
        interactive=False,
    )

    fetch_button.click(fn=request_image, inputs=url_box, outputs=preview)

    # Bind the model ID up front so the handler only receives the image.
    classify = partial(image_classification, model)
    classify_button.click(fn=classify, inputs=preview, outputs=results_table)