from functools import partial
from huggingface_hub import InferenceClient
from os import path, unlink
import gradio as gr
from PIL.Image import Image
import pandas as pd
from pandas import DataFrame
from utils import save_image_to_temp_file, request_image


def image_classification(client: InferenceClient, model: str, image: Image) -> DataFrame:
    """Classify an image using Hugging Face Inference API.

    This function classifies a recyclable item image into categories:
    cardboard, glass, metal, paper, plastic, or other. The image is saved to a
    temporary file and that file's path is passed to the client, because
    InferenceClient does not accept PIL Image objects directly.

    Args:
        client: Hugging Face InferenceClient instance for API calls.
        model: Hugging Face model ID to use for image classification.
        image: PIL Image object to classify. May be ``None`` when the Gradio
            image component is empty.

    Returns:
        Pandas DataFrame with two columns:
        - Label: The classification label (e.g., "cardboard", "glass")
        - Probability: The confidence score as a percentage string (e.g., "95.23%")
        Both columns are present even when the API returns no classifications.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of an internal traceback.

    Note:
        - The temporary file is removed after classification, including when
          the API call fails.
        - Whether the temporary file preserves the original image format
          depends on ``save_image_to_temp_file`` (defined in ``utils``).
    """
    if image is None:
        raise gr.Error("Please provide an image to classify.")

    # Bind before the try block so the finally clause cannot raise
    # UnboundLocalError (masking the real error) if saving the image fails.
    temp_file_path = None
    try:
        # Needed because InferenceClient does not accept PIL Images directly.
        temp_file_path = save_image_to_temp_file(image)
        classifications = client.image_classification(temp_file_path, model=model)
        rows = [
            (classification.label, f"{classification.score:.2%}")
            for classification in classifications
        ]
        # Explicit columns keep the output schema stable for an empty result.
        return pd.DataFrame(rows, columns=["Label", "Probability"])
    finally:
        if temp_file_path and path.exists(temp_file_path):  # Clean up temporary file.
            try:
                unlink(temp_file_path)
            except OSError:
                pass  # Best-effort clean-up; never mask the classification outcome.


def create_image_classification_tab(client: InferenceClient, model: str):
    """Create the image classification tab in the Gradio interface.

    This function sets up all UI components for image classification, including:
    - URL input textbox for fetching images from the web
    - Button to retrieve image from URL
    - Image preview component
    - Classify button and output dataframe showing labels and probabilities

    Args:
        client: Hugging Face InferenceClient instance to pass to the
            image_classification function.
        model: Hugging Face model ID to use for image classification.
    """
    # NOTE(review): the description and link are hard-coded to Trash-Net even
    # though `model` is a parameter; confirm this is intended if `model` can vary.
    gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")

    image_classification_url_input = gr.Textbox(label="Image URL")
    image_classification_image_request_button = gr.Button("Get Image")
    # type="pil" makes Gradio hand a PIL Image (or None when empty) to callbacks.
    image_classification_image_input = gr.Image(label="Image", type="pil")
    image_classification_image_request_button.click(
        fn=request_image,
        inputs=image_classification_url_input,
        outputs=image_classification_image_input,
    )

    image_classification_button = gr.Button("Classify")
    image_classification_output = gr.Dataframe(
        label="Classification",
        headers=["Label", "Probability"],
        interactive=False,
    )
    # partial binds client and model so Gradio only supplies the image argument.
    image_classification_button.click(
        fn=partial(image_classification, client, model),
        inputs=image_classification_image_input,
        outputs=image_classification_output,
    )