LiKenun committed
Commit 4c71b8b · 1 Parent(s): bb6107f

Add AI-generated chat sample

Files changed (4)
  1. app.py +28 -0
  2. chatbot.py +66 -0
  3. image_to_text.py +4 -1
  4. text_to_speech.py +6 -2
app.py CHANGED
@@ -3,6 +3,7 @@ from functools import partial
 import gradio as gr
 from huggingface_hub import InferenceClient
 from automatic_speech_recognition import automatic_speech_recognition
+from chatbot import chat
 from image_classification import image_classification
 from image_to_text import image_to_text
 from text_to_image import text_to_image
@@ -91,6 +92,33 @@ class App:
                    inputs=audio_transcription_audio_input,
                    outputs=audio_transcription_output
                )
+            with gr.Tab("Chat"):
+                gr.Markdown("Have a conversation with an AI chatbot.")
+                chatbot_history = gr.State(value=None)  # Store the conversation history.
+                chatbot_output = gr.Chatbot(label="Conversation")
+                chatbot_input = gr.Textbox(label="Your message")
+                chatbot_send_button = gr.Button("Send")
+
+                def chat_interface(message: str, history: list | None, conversation_state: list[dict] | None):
+                    """Handle chatbot interaction with Gradio chat format."""
+                    if not message.strip():
+                        return history, conversation_state, ""
+                    response, updated_conversation = chat(message, conversation_state)  # Get response from chatbot.
+                    if history is None:  # Update Gradio chat history format: list of [user_message, bot_message] pairs.
+                        history = []
+                    history.append([message, response])
+                    return history, updated_conversation, ""  # Clear input field for the next message from the user.
+
+                chatbot_send_button.click(
+                    fn=chat_interface,
+                    inputs=[chatbot_input, chatbot_output, chatbot_history],
+                    outputs=[chatbot_output, chatbot_history, chatbot_input]
+                )
+                chatbot_input.submit(
+                    fn=chat_interface,
+                    inputs=[chatbot_input, chatbot_output, chatbot_history],
+                    outputs=[chatbot_output, chatbot_history, chatbot_input]
+                )
 
        demo.launch()
 
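For reference, the Send-button and Enter-to-submit wiring added above can be exercised on its own. The sketch below is a minimal standalone approximation, with a hypothetical echo_chat stub standing in for chatbot.chat so it runs without downloading a model; every name in it is illustrative and not part of this commit.

import gradio as gr

def echo_chat(message, conversation_state):  # stub standing in for chatbot.chat
    conversation_state = (conversation_state or []) + [{"role": "user", "content": message}]
    reply = f"You said: {message}"
    conversation_state.append({"role": "assistant", "content": reply})
    return reply, conversation_state

def chat_interface(message, history, conversation_state):
    if not message.strip():
        return history, conversation_state, ""
    reply, conversation_state = echo_chat(message, conversation_state)
    history = (history or []) + [[message, reply]]  # Chatbot history as [user, bot] pairs
    return history, conversation_state, ""          # clear the textbox for the next turn

with gr.Blocks() as demo:
    state = gr.State(value=None)
    chat_box = gr.Chatbot(label="Conversation")
    msg_box = gr.Textbox(label="Your message")
    send = gr.Button("Send")
    send.click(chat_interface, [msg_box, chat_box, state], [chat_box, state, msg_box])
    msg_box.submit(chat_interface, [msg_box, chat_box, state], [chat_box, state, msg_box])

if __name__ == "__main__":
    demo.launch()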
chatbot.py ADDED
@@ -0,0 +1,66 @@
+from os import getenv
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from utils import get_pytorch_device, spaces_gpu
+
+# Global chatbot instance (initialized once)
+_chatbot = None
+_tokenizer = None
+
+def get_chatbot():
+    global _chatbot, _tokenizer
+    if _chatbot is None:
+        model_id = getenv("CHAT_MODEL")
+        device = get_pytorch_device()
+        _tokenizer = AutoTokenizer.from_pretrained(model_id)
+        _chatbot = AutoModelForSeq2SeqLM.from_pretrained(
+            model_id,
+            use_safetensors=True  # Use safetensors to avoid torch.load restriction
+        ).to(device)
+    return _chatbot, _tokenizer
+
+@spaces_gpu
+def chat(message: str, conversation_history: list[dict] | None) -> tuple[str, list[dict]]:
+    model, tokenizer = get_chatbot()
+
+    # Initialize conversation history if this is the first message
+    if conversation_history is None:
+        conversation_history = []
+
+    # Add the user's message
+    conversation_history.append({"role": "user", "content": message})
+
+    # For BlenderBot models, format conversation as dialogue history
+    # Build the full conversation context as a string
+    dialogue_text = ""
+    for msg in conversation_history:
+        if msg["role"] == "user":
+            dialogue_text += f"User: {msg['content']}\n"
+        elif msg["role"] == "assistant":
+            dialogue_text += f"Assistant: {msg['content']}\n"
+
+    # Tokenize the input
+    inputs = tokenizer([dialogue_text], return_tensors="pt", truncation=True, max_length=512)
+    device = get_pytorch_device()
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    # Generate response
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=128,
+        do_sample=True,
+        temperature=0.7,
+        pad_token_id=tokenizer.eos_token_id
+    )
+
+    # Decode the generated tokens - for seq2seq models, this should be just the assistant's response
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Clean up the response - remove any "Assistant:" prefix if present
+    response = response.strip()
+    if response.startswith("Assistant:"):
+        response = response[len("Assistant:"):].strip()
+
+    # Add the assistant's response to history
+    conversation_history.append({"role": "assistant", "content": response})
+
+    return response, conversation_history
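As a usage note, chat is stateless apart from the history it returns, so callers thread the returned list back in on the next turn. A hypothetical example, assuming CHAT_MODEL points at a seq2seq conversational checkpoint such as facebook/blenderbot-400M-distill (the model id below is an assumption, not part of this commit):

from os import environ
from chatbot import chat

environ.setdefault("CHAT_MODEL", "facebook/blenderbot-400M-distill")  # assumed checkpoint

reply, history = chat("Hi! What can you do?", None)         # first turn: no prior history
reply, history = chat("Tell me more about that.", history)  # later turns reuse the returned history
print(reply)
print(history)  # alternating {"role": ..., "content": ...} dicts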
image_to_text.py CHANGED
@@ -10,7 +10,10 @@ def image_to_text(image: Image) -> list[str]:
     image_to_text_model_id = getenv("IMAGE_TO_TEXT_MODEL")
     pytorch_device = get_pytorch_device()
     processor = AutoProcessor.from_pretrained(image_to_text_model_id)
-    model = BlipForConditionalGeneration.from_pretrained(image_to_text_model_id).to(pytorch_device)
+    model = BlipForConditionalGeneration.from_pretrained(
+        image_to_text_model_id,
+        use_safetensors=True  # Use safetensors to avoid torch.load restriction.
+    ).to(pytorch_device)
     inputs = processor(images=image, return_tensors="pt").to(pytorch_device)
     generated_ids = model.generate(pixel_values=inputs.pixel_values, num_beams=3, max_length=20, min_length=5)
     results = processor.batch_decode(generated_ids, skip_special_tokens=True)
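For completeness, a hypothetical call to the updated function, assuming IMAGE_TO_TEXT_MODEL names a BLIP captioning checkpoint (the model id and file path below are illustrative assumptions):

from os import environ
from PIL import Image
from image_to_text import image_to_text

environ.setdefault("IMAGE_TO_TEXT_MODEL", "Salesforce/blip-image-captioning-base")  # assumed checkpoint
captions = image_to_text(Image.open("example.jpg"))  # assumed local image
print(captions)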
text_to_speech.py CHANGED
@@ -6,8 +6,12 @@ from utils import spaces_gpu
 
 @spaces_gpu
 def text_to_speech(text: str) -> tuple[int, bytes]:
-    narrator = pipeline("text-to-speech", getenv("TEXT_TO_SPEECH_MODEL"))
+    narrator = pipeline(
+        "text-to-speech",
+        getenv("TEXT_TO_SPEECH_MODEL"),
+        model_kwargs={"use_safetensors": True}  # Use safetensors to avoid torch.load restriction.
+    )
+    result = narrator(text)
     del narrator
     gc.collect()
-    result = narrator(text)
     return (result["sampling_rate"], result["audio"][0])
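A hypothetical round trip through the reordered function, assuming TEXT_TO_SPEECH_MODEL names a pipeline-compatible TTS checkpoint (the model id below is an assumption) and that the returned audio is a NumPy waveform as the transformers text-to-speech pipeline produces:

from os import environ
from scipy.io import wavfile
from text_to_speech import text_to_speech

environ.setdefault("TEXT_TO_SPEECH_MODEL", "facebook/mms-tts-eng")  # assumed checkpoint
sampling_rate, audio = text_to_speech("Hello from the demo app.")
wavfile.write("hello.wav", sampling_rate, audio)  # write the waveform to disk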