"""Gradio demo app: a gallery of AI building blocks.

Each tab wires a small UI to a handler function. The handler receives the
shared ``InferenceClient`` as its first positional argument, bound with
``functools.partial``.
"""

from functools import partial
from typing import Any

import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import InferenceClient

from image_classification import image_classification
from text_to_image import text_to_image


class App:
    """Gradio UI that exposes each building-block demo as its own tab."""

    def __init__(self, client: InferenceClient):
        """Store the inference client shared by every tab's handler.

        Args:
            client: Client bound as the first argument of each handler.
        """
        self.client = client

    def run(self, **launch_kwargs: Any) -> None:
        """Build the Gradio interface and launch it.

        Args:
            **launch_kwargs: Forwarded unchanged to ``demo.launch()``, for
                example ``share=True`` or ``server_name="0.0.0.0"``. With no
                arguments this launches exactly as before.
        """
        with gr.Blocks(title="AI Building Blocks") as demo:
            gr.Markdown("# AI Building Blocks")
            gr.Markdown("A gallery of building blocks for building AI applications")
            with gr.Tabs():
                # Each helper opens its own gr.Tab. Gradio attaches it to the
                # enclosing gr.Tabs through its layout context stack.
                self._build_text_to_image_tab()
                self._build_image_classification_tab()
        demo.launch(**launch_kwargs)

    def _build_text_to_image_tab(self) -> None:
        """Add the tab that generates a PIL image from a text prompt."""
        with gr.Tab("Text-to-image Generation"):
            gr.Markdown("Generate an image from a text prompt.")
            prompt = gr.Textbox(
                label="Prompt",
                value="A panda under a giant mushroom next to a pumpkin",
            )
            generate_button = gr.Button("Generate")
            output = gr.Image(label="Image", type="pil")
            # The handler is called as text_to_image(client, prompt).
            generate_button.click(
                fn=partial(text_to_image, self.client),
                inputs=prompt,
                outputs=output,
            )

    def _build_image_classification_tab(self) -> None:
        """Add the tab that classifies a recyclable item from a URL or an upload.

        The handler is called as ``image_classification(client, url, upload)``.
        Its two return values fill the image preview and the results table, in
        that order. NOTE(review): which input wins when both a URL and an
        upload are given is decided inside ``image_classification`` and is not
        visible here.
        """
        with gr.Tab("Image Classification"):
            gr.Markdown(
                "Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net)."
            )
            with gr.Row():
                with gr.Column():
                    url_input = gr.Textbox(
                        label="Image URL",
                        value="https://campuslifeservices.ucsf.edu/upload/facilities/galleries/cardboard_0.jpg",
                        placeholder="Enter the URL of the image to classify",
                        scale=2,
                    )
                    # Output-only component: shows the image that was classified.
                    image_preview = gr.Image(label="Image Preview", type="pil")
                upload_input = gr.Image(
                    label="Or Upload Image",
                    type="pil",
                    scale=2,
                )
            classify_button = gr.Button("Classify")
            results = gr.Dataframe(
                label="Classification Results",
                headers=["Label", "Probability"],
                interactive=False,
            )
            classify_button.click(
                fn=partial(image_classification, self.client),
                inputs=[url_input, upload_input],
                outputs=[image_preview, results],
            )


if __name__ == "__main__":
    # Load .env into the process environment *before* constructing the client.
    # Any credentials defined there are then visible to InferenceClient()
    # (presumably an HF token; confirm against the deployment's .env).
    load_dotenv()
    app = App(InferenceClient())
    app.run()