Ali2206 committed on
Commit
9aeb1dd
·
verified ·
1 Parent(s): a09aba3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -41
app.py CHANGED
@@ -1,22 +1,21 @@
1
- import random
2
- import datetime
3
- import sys
4
  import os
 
5
  import torch
6
- import logging
7
  import json
 
 
8
  from importlib.resources import files
9
  from txagent import TxAgent
10
  from tooluniverse import ToolUniverse
11
- import gradio as gr
12
 
13
- # Set up logging
14
  logging.basicConfig(
15
  level=logging.INFO,
16
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17
  )
18
  logger = logging.getLogger(__name__)
19
 
 
20
  current_dir = os.path.dirname(os.path.abspath(__file__))
21
  os.environ["MKL_THREADING_LAYER"] = "GNU"
22
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -39,22 +38,12 @@ chat_css = """
39
  .gr-button svg { width: 32px !important; height: 32px !important; }
40
  """
41
 
42
def validate_message(message: str) -> bool:
    """Return True when *message* is a string whose stripped length is at least 10."""
    if not message or not isinstance(message, str):
        return False
    # Length requirement ignores leading/trailing whitespace.
    return len(message.strip()) >= 10
50
  def safe_load_embeddings(filepath: str) -> any:
51
  try:
52
- # First try with weights_only=True (secure mode)
53
  return torch.load(filepath, weights_only=True)
54
  except Exception as e:
55
  logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
56
  try:
57
- # If that fails, try with weights_only=False (less secure)
58
  return torch.load(filepath, weights_only=False)
59
  except Exception as e:
60
  logger.error(f"Failed to load embeddings: {str(e)}")
@@ -85,7 +74,6 @@ def patch_embedding_loading():
85
 
86
  if current_count != embedding_count:
87
  logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
88
-
89
  if current_count < embedding_count:
90
  self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
91
  logger.info(f"Truncated embeddings to match {current_count} tools")
@@ -149,18 +137,22 @@ def create_agent():
149
  logger.error(f"Failed to create agent: {str(e)}")
150
  raise
151
 
152
- def respond(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
153
- # Validate the message first
154
- if not validate_message(message):
155
- error_msg = "Please provide a valid message with a string longer than 10 characters."
156
- logger.warning(f"Message validation failed: {message}")
157
- return history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_msg}]
 
 
158
 
159
  updated_history = history + [{"role": "user", "content": message}]
160
- logger.debug(f"\n==== DEBUG ====\nUser Message: {message}\nFull History: {updated_history}\n================\n")
 
 
 
161
 
162
  try:
163
- # Ensure correct format for run_gradio_chat
164
  formatted_history = [(m["role"], m["content"]) for m in updated_history]
165
 
166
  response_generator = agent.run_gradio_chat(
@@ -173,29 +165,22 @@ def respond(message, history, temperature, max_new_tokens, max_tokens, multi_age
173
  max_round
174
  )
175
  except Exception as e:
176
- error_msg = f"Error processing your request: {str(e)}"
177
- logger.error(f"Error in respond function: {str(e)}")
178
- return history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_msg}]
179
-
180
- collected = ""
181
- try:
182
  for chunk in response_generator:
183
  if isinstance(chunk, dict):
184
  collected += chunk.get("content", "")
185
  else:
186
  collected += str(chunk)
187
- except Exception as e:
188
- error_msg = f"Error generating response: {str(e)}"
189
- logger.error(f"Error in response generation: {str(e)}")
190
- return history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_msg}]
191
 
192
- return history + [{"role": "user", "content": message}, {"role": "assistant", "content": collected}]
 
193
 
194
  def create_demo(agent):
195
  with gr.Blocks(css=chat_css) as demo:
196
  chatbot = gr.Chatbot(label="TxAgent", type="messages")
197
- with gr.Row():
198
- msg = gr.Textbox(label="Your question", placeholder="Enter your biomedical question here (minimum 10 characters)...")
199
  with gr.Row():
200
  temp = gr.Slider(0, 1, value=0.3, label="Temperature")
201
  max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
@@ -207,9 +192,10 @@ def create_demo(agent):
207
 
208
  submit.click(
209
  respond,
210
- inputs=[msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds],
211
  outputs=[chatbot]
212
  )
 
213
  return demo
214
 
215
  def main():
@@ -217,10 +203,10 @@ def main():
217
  global agent
218
  agent = create_agent()
219
  demo = create_demo(agent)
220
- demo.launch(server_name="0.0.0.0", server_port=7860)
221
  except Exception as e:
222
  logger.error(f"Application failed to start: {str(e)}")
223
  raise
224
 
225
  if __name__ == "__main__":
226
- main()
 
 
 
 
1
  import os
2
+ import sys
3
  import torch
 
4
  import json
5
+ import logging
6
+ import gradio as gr
7
  from importlib.resources import files
8
  from txagent import TxAgent
9
  from tooluniverse import ToolUniverse
 
10
 
11
+ # Logging setup
12
  logging.basicConfig(
13
  level=logging.INFO,
14
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
15
  )
16
  logger = logging.getLogger(__name__)
17
 
18
+ # Paths and environment
19
  current_dir = os.path.dirname(os.path.abspath(__file__))
20
  os.environ["MKL_THREADING_LAYER"] = "GNU"
21
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
38
  .gr-button svg { width: 32px !important; height: 32px !important; }
39
  """
40
 
 
 
 
 
 
 
 
 
41
  def safe_load_embeddings(filepath: str) -> any:
42
  try:
 
43
  return torch.load(filepath, weights_only=True)
44
  except Exception as e:
45
  logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
46
  try:
 
47
  return torch.load(filepath, weights_only=False)
48
  except Exception as e:
49
  logger.error(f"Failed to load embeddings: {str(e)}")
 
74
 
75
  if current_count != embedding_count:
76
  logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
 
77
  if current_count < embedding_count:
78
  self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
79
  logger.info(f"Truncated embeddings to match {current_count} tools")
 
137
  logger.error(f"Failed to create agent: {str(e)}")
138
  raise
139
 
140
+ def respond(chat_history, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
141
+ if not chat_history:
142
+ return [["assistant", "Please provide a message."]]
143
+
144
+ message = chat_history[-1][1] if isinstance(chat_history[-1], (list, tuple)) else chat_history[-1]
145
+
146
+ if not isinstance(message, str) or len(message.strip()) <= 10:
147
+ return chat_history + [["assistant", "Please provide a valid message with a string longer than 10 characters."]]
148
 
149
  updated_history = history + [{"role": "user", "content": message}]
150
+ print("\n==== DEBUG ====")
151
+ print("User Message:", message)
152
+ print("Full History:", updated_history)
153
+ print("================\n")
154
 
155
  try:
 
156
  formatted_history = [(m["role"], m["content"]) for m in updated_history]
157
 
158
  response_generator = agent.run_gradio_chat(
 
165
  max_round
166
  )
167
  except Exception as e:
168
+ updated_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
169
+ else:
170
+ collected = ""
 
 
 
171
  for chunk in response_generator:
172
  if isinstance(chunk, dict):
173
  collected += chunk.get("content", "")
174
  else:
175
  collected += str(chunk)
176
+ updated_history.append({"role": "assistant", "content": collected})
 
 
 
177
 
178
+ # Return formatted history to Gradio
179
+ return [(m["role"], m["content"]) for m in updated_history]
180
 
181
  def create_demo(agent):
182
  with gr.Blocks(css=chat_css) as demo:
183
  chatbot = gr.Chatbot(label="TxAgent", type="messages")
 
 
184
  with gr.Row():
185
  temp = gr.Slider(0, 1, value=0.3, label="Temperature")
186
  max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
 
192
 
193
  submit.click(
194
  respond,
195
+ inputs=[chatbot, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds],
196
  outputs=[chatbot]
197
  )
198
+
199
  return demo
200
 
201
  def main():
 
203
  global agent
204
  agent = create_agent()
205
  demo = create_demo(agent)
206
+ demo.launch()
207
  except Exception as e:
208
  logger.error(f"Application failed to start: {str(e)}")
209
  raise
210
 
211
  if __name__ == "__main__":
212
+ main()