from functools import partial
from os import getenv, path, unlink

import gradio as gr
import pandas as pd
from huggingface_hub import InferenceClient
from pandas import DataFrame
from PIL.Image import Image

from utils import request_image, save_image_to_temp_file


# Column names shared by the classification result and the output table.
_RESULT_COLUMNS = ["Label", "Probability"]


def image_classification(client: InferenceClient, image: Image) -> DataFrame:
    """Classify an image using Hugging Face Inference API.

    This function classifies a recyclable item image into categories:
    cardboard, glass, metal, paper, plastic, or other. The image is saved to
    a temporary file because InferenceClient is given a file path here rather
    than a PIL Image object.

    Args:
        client: Hugging Face InferenceClient instance for API calls.
        image: PIL Image object to classify. May be None when the Gradio
            Image component is empty; an empty result table is returned then.

    Returns:
        Pandas DataFrame with two columns (always present, even when empty):
            - Label: The classification label (e.g., "cardboard", "glass")
            - Probability: The confidence score as a percentage string
              (e.g., "95.23%")

    Note:
        - The model ID is read from the IMAGE_CLASSIFICATION_MODEL environment
          variable. If it is unset, None is passed as the model and the client
          presumably falls back to its own configured/default model.
        - Intended for the Trash-Net recyclable-item classifier.
        - The temporary file is removed after classification on a best-effort
          basis; clean-up failures are ignored.
    """
    # Gradio's Image component yields None when nothing is loaded; return an
    # empty table rather than handing None to the temp-file helper.
    if image is None:
        return pd.DataFrame(columns=_RESULT_COLUMNS)

    # Bound before the try so the finally clause never reads an unbound name
    # (which would mask the real error) if save_image_to_temp_file raises.
    temp_file_path = None
    try:
        # Needed because InferenceClient does not accept PIL Images directly.
        temp_file_path = save_image_to_temp_file(image)
        classifications = client.image_classification(
            temp_file_path, model=getenv("IMAGE_CLASSIFICATION_MODEL")
        )
        rows = [
            (classification.label, f"{classification.score:.2%}")
            for classification in classifications
        ]
        # Explicit columns keep the table shape stable even with zero rows.
        return pd.DataFrame(rows, columns=_RESULT_COLUMNS)
    finally:
        if temp_file_path and path.exists(temp_file_path):
            try:
                unlink(temp_file_path)
            except OSError:
                pass  # Best-effort clean-up; never mask the classification result.


def create_image_classification_tab(client: InferenceClient):
    """Create the image classification tab in the Gradio interface.

    This function sets up all UI components for image classification,
    including:
        - URL input textbox for fetching images from the web
        - Button to retrieve image from URL
        - Image preview component
        - Classify button and output dataframe showing labels and probabilities

    Args:
        client: Hugging Face InferenceClient instance to pass to the
            image_classification function.
    """
    gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")

    image_classification_url_input = gr.Textbox(label="Image URL")
    image_classification_image_request_button = gr.Button("Get Image")
    image_classification_image_input = gr.Image(label="Image", type="pil")
    # Fetch the image at the given URL and show it in the preview component.
    image_classification_image_request_button.click(
        fn=request_image,
        inputs=image_classification_url_input,
        outputs=image_classification_image_input,
    )

    image_classification_button = gr.Button("Classify")
    image_classification_output = gr.Dataframe(
        label="Classification",
        headers=list(_RESULT_COLUMNS),
        interactive=False,
    )
    # partial binds the client so Gradio only needs to supply the image.
    image_classification_button.click(
        fn=partial(image_classification, client),
        inputs=image_classification_image_input,
        outputs=image_classification_output,
    )