Spaces: Running on Zero
| import gc | |
| from functools import partial | |
| import gradio as gr | |
| import torch | |
| from PIL.Image import Image | |
| import pandas as pd | |
| from pandas import DataFrame | |
| from transformers import pipeline | |
| from utils import get_pytorch_device, spaces_gpu, request_image, get_torch_dtype | |
def _release_accelerator_memory(device_type: str) -> None:
    """Ask the torch backend for ``device_type`` to return its cached memory.

    Looks up ``torch.<device_type>`` (e.g. ``torch.cuda``, ``torch.mps``,
    ``torch.xpu``) and calls its ``empty_cache()`` when that backend module
    exposes one; backends without it (such as CPU) are skipped.
    """
    backend = getattr(torch, device_type, None)
    empty_cache = getattr(backend, "empty_cache", None)
    if callable(empty_cache):
        empty_cache()


def image_classification(model: str, image: Image) -> DataFrame:
    """Classify an image using a vision transformer model.

    This function classifies a recyclable item image into categories:
    cardboard, glass, metal, paper, plastic, or other. The model is loaded,
    inference is performed, and then cleaned up to free accelerator memory.

    Args:
        model: Hugging Face model ID to use for image classification.
        image: PIL Image object to classify.

    Returns:
        Pandas DataFrame with two columns:
            - Label: The classification label (e.g., "cardboard", "glass")
            - Probability: The confidence score as a percentage string (e.g., "95.23%")

    Raises:
        gr.Error: If no image was provided (``image`` is None).

    Note:
        - Uses safetensors for secure model loading.
        - Runs on the device returned by ``get_pytorch_device()``
          (CUDA/XPU/MPS/CPU), passed straight to the pipeline.
        - Frees the pipeline and the accelerator cache after inference, even
          when inference raises.
    """
    if image is None:
        # Gradio passes None when the image component is empty; fail with a
        # readable message instead of loading the model and crashing inside it.
        raise gr.Error("Please provide an image to classify.")

    pytorch_device = get_pytorch_device()
    # Normalise a str or torch.device (e.g. "cuda", "cuda:0") to a bare
    # backend name such as "cuda" so the matching torch.<backend> can be found.
    device_type = torch.device(pytorch_device).type
    dtype = get_torch_dtype()

    model_kwargs = {"use_safetensors": True}
    if dtype is not None:
        model_kwargs["dtype"] = dtype

    classifier = pipeline(
        "image-classification",
        model=model,
        # Pass the selected device through unchanged; mapping every non-CUDA
        # device to -1 silently forced XPU/MPS inference onto the CPU.
        device=pytorch_device,
        model_kwargs=model_kwargs,
    )
    try:
        # Gradients are never needed for inference; disabling autograd avoids
        # keeping activations for a backward pass and reduces memory use.
        with torch.no_grad():
            results = classifier(image)
    finally:
        # Drop the model reference, collect any reference cycles so its
        # tensors are really freed, and only then empty the allocator cache
        # (emptying it first would miss memory still held by those cycles).
        del classifier
        gc.collect()
        _release_accelerator_memory(device_type)

    return pd.DataFrame({
        "Label": [result["label"] for result in results],
        "Probability": [f"{result['score']:.2%}" for result in results],
    })
def create_image_classification_tab(model: str):
    """Build the image-classification tab of the Gradio interface.

    Components are created in display order:
    - an introductory Markdown blurb,
    - a URL textbox plus a "Get Image" button that loads that URL into an
      image preview,
    - a "Classify" button whose output is a read-only table of labels and
      probabilities.

    Args:
        model: Hugging Face model ID forwarded to ``image_classification``.
    """
    intro = (
        "Classify a recyclable item as one of: cardboard, glass, metal, paper, "
        "plastic, or other using "
        "[Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net)."
    )
    gr.Markdown(intro)

    url_box = gr.Textbox(label="Image URL")
    fetch_button = gr.Button("Get Image")
    preview = gr.Image(label="Image", type="pil")
    # Fetching a URL fills the preview, which doubles as the classifier input.
    fetch_button.click(fn=request_image, inputs=url_box, outputs=preview)

    classify_button = gr.Button("Classify")
    results_table = gr.Dataframe(
        label="Classification",
        headers=["Label", "Probability"],
        interactive=False,
    )
    # Bind the model ID up front so the click handler only receives the image.
    classify_with_model = partial(image_classification, model)
    classify_button.click(
        fn=classify_with_model,
        inputs=preview,
        outputs=results_table,
    )