import gc
from functools import partial

import gradio as gr
import pandas as pd
import torch
from pandas import DataFrame
from PIL.Image import Image
from transformers import pipeline

from utils import get_pytorch_device, spaces_gpu, request_image, get_torch_dtype


@spaces_gpu
def image_classification(model: str, image: Image) -> DataFrame:
    """Classify an image using a Hugging Face image-classification model.

    The model is loaded for each call, inference is run on the image, and
    the pipeline is released afterwards (even if inference fails) so that
    GPU memory is not held between requests.

    Args:
        model: Hugging Face model ID to use for image classification.
        image: PIL Image object to classify.

    Returns:
        Pandas DataFrame with two columns:
            - Label: The classification label reported by the model
              (e.g., "cardboard", "glass").
            - Probability: The confidence score as a percentage string
              (e.g., "95.23%").

    Raises:
        gr.Error: If no image was provided (``image`` is ``None``).

    Note:
        - Uses safetensors for model loading.
        - The pipeline is placed on CUDA device 0 only when
          ``get_pytorch_device()`` returns ``"cuda"``; any other value
          runs the pipeline on CPU.
        - Releases the model and the CUDA cache after inference.
    """
    # The Gradio image component passes None when nothing has been loaded;
    # fail early with a readable message instead of loading the model first.
    if image is None:
        raise gr.Error("Please provide an image to classify.")

    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    model_kwargs = {"use_safetensors": True}
    if dtype is not None:
        model_kwargs["dtype"] = dtype

    classifier = pipeline(
        "image-classification",
        model=model,
        device=0 if pytorch_device == "cuda" else -1,
        model_kwargs=model_kwargs,
    )
    try:
        # During inference, gradient calculations are unnecessary. Using
        # torch.no_grad() reduces memory consumption by not storing
        # gradients.
        with torch.no_grad():
            results = classifier(image)
    finally:
        # Always release the model, including when inference raises.
        # Collect garbage first so tensors kept alive only by reference
        # cycles are freed before the CUDA cache is emptied.
        del classifier
        gc.collect()
        if pytorch_device == "cuda":
            torch.cuda.empty_cache()

    return pd.DataFrame({
        "Label": [result["label"] for result in results],
        "Probability": [f"{result['score']:.2%}" for result in results],
    })


def create_image_classification_tab(model: str) -> None:
    """Create the image classification tab in the Gradio interface.

    This function sets up all UI components for image classification,
    including:
        - URL input textbox for fetching images from the web
        - Button to retrieve image from URL
        - Image preview component
        - Classify button and output dataframe showing labels and
          probabilities

    Args:
        model: Hugging Face model ID to use for image classification.
    """
    gr.Markdown(
        "Classify a recyclable item as one of: cardboard, glass, metal, paper, "
        "plastic, or other using "
        "[Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net)."
    )

    # Fetch an image from a URL into the image component.
    image_classification_url_input = gr.Textbox(label="Image URL")
    image_classification_image_request_button = gr.Button("Get Image")
    image_classification_image_input = gr.Image(label="Image", type="pil")
    image_classification_image_request_button.click(
        fn=request_image,
        inputs=image_classification_url_input,
        outputs=image_classification_image_input,
    )

    # Run classification on the current image; the model ID is bound up
    # front so the click handler only receives the image.
    image_classification_button = gr.Button("Classify")
    image_classification_output = gr.Dataframe(
        label="Classification",
        headers=["Label", "Probability"],
        interactive=False,
    )
    image_classification_button.click(
        fn=partial(image_classification, model),
        inputs=image_classification_image_input,
        outputs=image_classification_output,
    )