LiKenun committed on
Commit
dc382c8
·
1 Parent(s): 979bbdf

Reorganize structure for less code clutter

Browse files
Files changed (4) hide show
  1. app.py +57 -103
  2. image_classification.py +45 -0
  3. text_to_image.py +7 -0
  4. utils.py +12 -0
app.py CHANGED
@@ -1,109 +1,63 @@
1
  from dotenv import load_dotenv
 
2
  import gradio as gr
3
  from huggingface_hub import InferenceClient
4
- from io import BytesIO
5
- from os import path, unlink
6
- import pandas as pd
7
- from pandas import DataFrame
8
- from PIL.Image import Image, open as open_image
9
- import requests
10
- import tempfile
11
-
12
- REQUEST_TIMEOUT = 45
13
-
14
- TEXT_TO_IMAGE_MODEL = "black-forest-labs/FLUX.1-dev"
15
- IMAGE_CLASSIFICATION_MODEL = "prithivMLmods/Trash-Net"
16
-
17
- # Load environment variables from .env file
18
- load_dotenv()
19
-
20
- client = InferenceClient()
21
-
22
- def save_image_to_temp_file(image: Image) -> str:
23
- image_format = image.format if image.format else 'PNG'
24
- format_extension = image_format.lower() if image_format else 'png'
25
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{format_extension}")
26
- temp_path = temp_file.name
27
- temp_file.close()
28
- image.save(temp_path, format=image_format)
29
- return temp_path
30
-
31
- def text_to_image(prompt: str) -> Image:
32
- return client.text_to_image(prompt, model=TEXT_TO_IMAGE_MODEL)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- def image_classification(image_url: str | None, uploaded_image: Image | None) -> tuple[Image | None, DataFrame]:
35
- temp_file_path = None
36
- try:
37
- if uploaded_image is not None and image_url and image_url.strip():
38
- raise gr.Error("Both an image URL and an uploaded image were provided. Please provide only one or the other.")
39
- elif uploaded_image is not None:
40
- temp_file_path = save_image_to_temp_file(uploaded_image)
41
- classifications = client.image_classification(temp_file_path, model=IMAGE_CLASSIFICATION_MODEL)
42
- image = None
43
- elif image_url and image_url.strip():
44
- try:
45
- response = requests.get(image_url, timeout=REQUEST_TIMEOUT)
46
- response.raise_for_status()
47
- image = open_image(BytesIO(response.content))
48
- temp_file_path = save_image_to_temp_file(image)
49
- classifications = client.image_classification(temp_file_path, model=IMAGE_CLASSIFICATION_MODEL)
50
- except Exception as e:
51
- raise gr.Error(f"Failed to fetch image from URL: {str(e)}")
52
- else:
53
- raise gr.Error("Please either provide an image URL or upload an image.")
54
- df = pd.DataFrame([
55
- {"Label": classification.label, "Probability": f"{classification.score:.2%}"}
56
- for classification in classifications
57
- ])
58
- return image, df
59
- finally:
60
- # Clean up temporary file.
61
- if temp_file_path and path.exists(temp_file_path):
62
- try:
63
- unlink(temp_file_path)
64
- except Exception:
65
- pass # Ignore clean-up errors.
66
 
67
- with gr.Blocks(title="AI Building Blocks") as demo:
68
- gr.Markdown("# AI Building Blocks")
69
- gr.Markdown("A gallery of building blocks for building AI applications")
70
- with gr.Tabs():
71
- with gr.Tab("Text-to-image Generation"):
72
- gr.Markdown("Generate an image from a text prompt.")
73
- text_to_image_prompt = gr.Textbox(label="Prompt", value="A panda under a giant mushroom next to a pumpkin")
74
- text_to_image_generate_button = gr.Button("Generate")
75
- text_to_image_output = gr.Image(label="Image", type="pil")
76
- text_to_image_generate_button.click(
77
- fn=text_to_image,
78
- inputs=text_to_image_prompt,
79
- outputs=text_to_image_output
80
- )
81
- with gr.Tab("Image Classification"):
82
- gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")
83
- with gr.Row():
84
- with gr.Column():
85
- image_classification_url_input = gr.Textbox(
86
- label="Image URL",
87
- value="https://campuslifeservices.ucsf.edu/upload/facilities/galleries/cardboard_0.jpg",
88
- placeholder="Enter the URL of the image to classify",
89
- scale=2
90
- )
91
- image_classification_image_preview = gr.Image(label="Image Preview", type="pil")
92
- image_classification_upload_input = gr.Image(
93
- label="Or Upload Image",
94
- type="pil",
95
- scale=2
96
- )
97
- image_classification_button = gr.Button("Classify")
98
- image_classification_output = gr.Dataframe(
99
- label="Classification Results",
100
- headers=["Label", "Probability"],
101
- interactive=False
102
- )
103
- image_classification_button.click(
104
- fn=image_classification,
105
- inputs=[image_classification_url_input, image_classification_upload_input],
106
- outputs=[image_classification_image_preview, image_classification_output]
107
- )
108
 
109
- demo.launch()
 
 
 
 
1
  from dotenv import load_dotenv
2
+ from functools import partial
3
  import gradio as gr
4
  from huggingface_hub import InferenceClient
5
+ from image_classification import image_classification
6
+ from text_to_image import text_to_image
7
+
8
+
9
class App:
    """Gradio front-end that exposes the AI building blocks as tabs."""

    def __init__(self, client: InferenceClient):
        # Shared Hugging Face inference client; bound into each tab's callback.
        self.client = client

    def _build_text_to_image_tab(self):
        # Tab 1: free-text prompt -> generated PIL image.
        with gr.Tab("Text-to-image Generation"):
            gr.Markdown("Generate an image from a text prompt.")
            prompt_box = gr.Textbox(label="Prompt", value="A panda under a giant mushroom next to a pumpkin")
            generate_button = gr.Button("Generate")
            image_output = gr.Image(label="Image", type="pil")
            generate_button.click(
                fn=partial(text_to_image, self.client),
                inputs=prompt_box,
                outputs=image_output,
            )

    def _build_image_classification_tab(self):
        # Tab 2: image URL or upload -> Trash-Net classification table.
        # NOTE(review): the exact Row/Column nesting of the preview and upload
        # widgets is reconstructed from a whitespace-stripped diff — confirm
        # against the rendered layout.
        with gr.Tab("Image Classification"):
            gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")
            with gr.Row():
                with gr.Column():
                    url_box = gr.Textbox(
                        label="Image URL",
                        value="https://campuslifeservices.ucsf.edu/upload/facilities/galleries/cardboard_0.jpg",
                        placeholder="Enter the URL of the image to classify",
                        scale=2,
                    )
                    preview = gr.Image(label="Image Preview", type="pil")
                upload = gr.Image(
                    label="Or Upload Image",
                    type="pil",
                    scale=2,
                )
            classify_button = gr.Button("Classify")
            results_table = gr.Dataframe(
                label="Classification Results",
                headers=["Label", "Probability"],
                interactive=False,
            )
            classify_button.click(
                fn=partial(image_classification, self.client),
                inputs=[url_box, upload],
                outputs=[preview, results_table],
            )

    def run(self):
        """Assemble the UI and start the Gradio server (blocks until shutdown)."""
        with gr.Blocks(title="AI Building Blocks") as demo:
            gr.Markdown("# AI Building Blocks")
            gr.Markdown("A gallery of building blocks for building AI applications")
            with gr.Tabs():
                self._build_text_to_image_tab()
                self._build_image_classification_tab()

        demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
if __name__ == "__main__":
    # Pull configuration (model names, request timeout) from a local .env file,
    # then hand a fresh inference client to the app and block on the UI.
    load_dotenv()
    App(InferenceClient()).run()
image_classification.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+ from io import BytesIO
4
+ from os import path, unlink, getenv
5
+ from PIL.Image import Image, open as open_image
6
+ import pandas as pd
7
+ from pandas import DataFrame
8
+ import requests
9
+ from utils import save_image_to_temp_file
10
+
11
+
12
def image_classification(client: InferenceClient, image_url: str | None, image: Image | None) -> tuple[Image | None, DataFrame]:
    """Classify a single image with the configured image-classification model.

    Exactly one of *image_url* or *image* must be provided.

    Parameters:
        client: Hugging Face inference client used for the classification call.
        image_url: URL of the image to fetch and classify, or None/blank.
        image: an already-uploaded PIL image, or None.

    Returns:
        (preview, df): the fetched image for preview (None when the image was
        uploaded, since the upload widget already shows it) and a DataFrame
        with "Label" and formatted "Probability" columns.

    Raises:
        gr.Error: when both or neither input is given, or the URL fetch fails.
    """
    # Fall back to the app's previous hard-coded defaults so a missing .env
    # does not crash with model=None or int(None).
    model = getenv("IMAGE_CLASSIFICATION_MODEL", "prithivMLmods/Trash-Net")
    temp_file_path = None
    try:
        has_url = bool(image_url and image_url.strip())
        if image is not None and has_url:
            raise gr.Error("Both an image URL and an uploaded image were provided. Please provide only one or the other.")
        elif image is not None:
            temp_file_path = save_image_to_temp_file(image)
            classifications = client.image_classification(temp_file_path, model=model)
            image = None  # Upload path: no separate preview is produced.
        elif has_url:
            try:
                # Keep the try narrow: only fetch/decode failures should be
                # reported as URL errors, not classification failures.
                response = requests.get(image_url, timeout=int(getenv("REQUEST_TIMEOUT", "45")))
                response.raise_for_status()
                image = open_image(BytesIO(response.content))
            except Exception as e:
                raise gr.Error(f"Failed to fetch image from URL: {str(e)}") from e
            temp_file_path = save_image_to_temp_file(image)
            classifications = client.image_classification(temp_file_path, model=model)
        else:
            raise gr.Error("Please either provide an image URL or upload an image.")
        df = pd.DataFrame({
                "Label": classification.label,
                "Probability": f"{classification.score:.2%}"
            }
            for classification
            in classifications)
        return image, df
    finally:
        # Clean up temporary file.
        if temp_file_path and path.exists(temp_file_path):
            try:
                unlink(temp_file_path)
            except OSError:
                pass  # Ignore clean-up errors.
text_to_image.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from os import getenv
2
+ from PIL.Image import Image
3
+ from huggingface_hub import InferenceClient
4
+
5
+
6
+ def text_to_image(client: InferenceClient, prompt: str) -> Image:
7
+ return client.text_to_image(prompt, model=getenv("TEXT_TO_IMAGE_MODEL"))
utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL.Image import Image
2
+ from tempfile import NamedTemporaryFile
3
+
4
+
5
def save_image_to_temp_file(image: Image) -> str:
    """Write *image* to a fresh temporary file and return its path.

    The file is NOT deleted automatically (delete=False); the caller is
    responsible for removing it when done.

    Parameters:
        image: PIL image to persist; in-memory images may have format=None.

    Returns:
        Absolute path of the temporary file, suffixed with the image format.
    """
    # Images created in memory carry no format; default to PNG. After this
    # line image_format is always truthy, so no second fallback is needed
    # (the original's `if image_format else 'png'` branch was dead code).
    image_format = image.format or 'PNG'
    with NamedTemporaryFile(delete=False, suffix=f".{image_format.lower()}") as temp_file:
        temp_path = temp_file.name
    image.save(temp_path, format=image_format)
    return temp_path