arabic-sign-language-yolo / utils /medical_agent.py
Mr-HASSAN
Add word building, Arabic translation, HuatuoGPT agent with 3Q/5W limits
b68720c
"""
HuatuoGPT-based Medical Agent with strict question limits
Uses HuatuoGPT-o1-8B for medical consultation
"""
import logging
import torch
from typing import Dict, Optional
from collections import defaultdict
logger = logging.getLogger(__name__)
class HuatuoMedicalAgent:
    """Medical agent using HuatuoGPT with strict question limits.

    The agent asks at most ``max_questions`` follow-up questions per session,
    each truncated to at most ``max_words_per_question`` words. Questions come
    from the HuatuoGPT model when it could be loaded, otherwise from a small
    set of rule-based templates.
    """

    # Punctuation removed from the end of a truncated question before the
    # closing "?" is appended, so truncation never yields "pain,?" or "hurt??".
    _TRAILING_PUNCTUATION = ".,;:!?"

    # Prefix used when the agent's own questions are stored in the history.
    _DOCTOR_PREFIX = "Doctor: "

    def __init__(self, max_questions: int = 3, max_words_per_question: int = 5):
        """
        Initialize HuatuoGPT Medical Agent

        Args:
            max_questions: Maximum number of follow-up questions (default: 3)
            max_words_per_question: Maximum words per question (default: 5)
        """
        self.max_questions = max_questions
        self.max_words_per_question = max_words_per_question
        # Sessions are created lazily on first access by session id.
        self.sessions = defaultdict(self._new_session)

        # Try to load HuatuoGPT model; both stay None if loading fails.
        self.model = None
        self.tokenizer = None
        self._load_model()

    @staticmethod
    def _new_session() -> Dict:
        """Return the initial state of a fresh session."""
        return {
            'question_count': 0,
            'symptoms': [],
            'conversation_history': []
        }

    def _load_model(self):
        """Load HuatuoGPT model with error handling.

        Any failure (missing package, download error, out of memory, ...) is
        logged and leaves ``self.model`` / ``self.tokenizer`` as ``None`` so the
        agent falls back to rule-based questions instead of crashing.
        """
        try:
            # Imported lazily so the agent still works without transformers.
            from transformers import AutoModelForCausalLM, AutoTokenizer

            model_name = "FreedomIntelligence/HuatuoGPT-o1-8B"
            logger.info("🔄 Loading HuatuoGPT from %s...", model_name)

            # NOTE(security): trust_remote_code=True executes Python code
            # shipped with the model repository. Kept as in the original;
            # review before pointing model_name at an untrusted repository.
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True
            )

            # Load model with optimizations
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            logger.info("✅ HuatuoGPT model loaded successfully")
        except Exception as e:
            # Deliberate best-effort: the rule-based path needs no model.
            logger.warning("⚠️ Failed to load HuatuoGPT: %s", e)
            logger.info("📋 Falling back to rule-based medical questions")
            self.model = None
            self.tokenizer = None

    def _asked_questions(self, history: Optional[list]) -> list:
        """Return the questions the agent already asked, oldest first.

        Questions are recovered from history entries that start with the
        "Doctor: " prefix written by ``process_input``.
        """
        prefix = self._DOCTOR_PREFIX
        return [
            entry[len(prefix):]
            for entry in (history or ())
            if isinstance(entry, str) and entry.startswith(prefix)
        ]

    def _limit_words(self, text: str) -> str:
        """Truncate ``text`` to ``max_words_per_question`` words.

        Text within the limit is returned unchanged. Longer text is cut to the
        first N whitespace-separated words, trailing punctuation is removed and
        a single "?" is appended.
        """
        words = text.split()
        if len(words) <= self.max_words_per_question:
            return text
        kept = " ".join(words[:self.max_words_per_question])
        return kept.rstrip(self._TRAILING_PUNCTUATION) + "?"

    def _generate_question_with_huatuo(self, symptoms: list, history: list) -> str:
        """Generate medical question using HuatuoGPT.

        Args:
            symptoms: Patient statements collected so far.
            history: Conversation history; used to find questions that were
                already asked so the rule-based fallback does not repeat them.

        Returns:
            A question of at most ``max_words_per_question`` words. Falls back
            to the rule-based question when no model is loaded, when generation
            raises, or when the model produces no text.
        """
        asked = self._asked_questions(history)
        if self.model is None or self.tokenizer is None:
            return self._generate_rule_based_question(symptoms, asked)

        try:
            # Build prompt for HuatuoGPT
            symptoms_text = ", ".join(symptoms) if symptoms else "general consultation"
            prompt = (
                f"You are a medical doctor. A patient reported: {symptoms_text}.\n"
                f"Ask ONE brief follow-up question (maximum {self.max_words_per_question} words) "
                "to better understand their condition.\n"
                "Question:"
            )

            # Generate response
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=20,  # Limit tokens for short questions
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9
                )

            # generate() returns the prompt tokens followed by the new tokens;
            # decode only the continuation so the prompt is never echoed back.
            prompt_length = inputs["input_ids"].shape[-1]
            question = self.tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True
            )

            # The model may still repeat the "Question:" marker itself.
            if "Question:" in question:
                question = question.split("Question:")[-1]
            question = question.strip()
            if question:
                # Keep only the first line; anything after it is not the question.
                question = question.splitlines()[0].strip()
            if not question:
                return self._generate_rule_based_question(symptoms, asked)

            # Enforce word limit
            return self._limit_words(question)
        except Exception as e:
            logger.error("❌ HuatuoGPT generation failed: %s", e)
            return self._generate_rule_based_question(symptoms, asked)

    def _generate_rule_based_question(self, symptoms: list, asked: Optional[list] = None) -> str:
        """Fallback: Generate rule-based medical questions.

        Args:
            symptoms: Patient statements collected so far.
            asked: Questions already asked in this session. When the preferred
                question is among them, the first template not yet asked is
                returned instead. Omitting it keeps the original behavior.

        Returns:
            A short question chosen by symptom keywords, or by the number of
            symptoms when no keyword matches.
        """
        # Medical question templates (max 5 words each)
        question_templates = [
            "How long feeling this?",
            "Pain level one-to-ten?",
            "Any other symptoms present?",
            "Taking any medications now?",
            "When did this start?"
        ]

        if not symptoms:
            preferred = question_templates[0]
        else:
            # Select based on symptom keywords (substring match, so "ache"
            # also matches "headache").
            symptom_text = " ".join(symptoms).lower()
            if any(word in symptom_text for word in ("pain", "hurt", "ache")):
                preferred = "Pain level one-to-ten?"
            elif any(word in symptom_text for word in ("fever", "temperature", "hot")):
                preferred = "How high is fever?"
            elif any(word in symptom_text for word in ("cough", "cold", "flu")):
                preferred = "Cough with mucus present?"
            else:
                index = min(len(symptoms), len(question_templates) - 1)
                preferred = question_templates[index]

        already_asked = set(asked or ())
        if preferred not in already_asked:
            return preferred

        # Symptoms accumulate over a session, so the same keyword keeps
        # matching; move on to the first template that was not asked yet.
        for template in question_templates:
            if template not in already_asked:
                return template
        return preferred

    def process_input(self, patient_input: str, session_id: str = "default") -> Dict:
        """
        Process patient input and generate medical response

        Args:
            patient_input: Patient's symptom/response in English. ``None`` or
                whitespace-only input is not recorded.
            session_id: Session identifier

        Returns:
            Dict with 'response', 'question_count', 'state'. 'state' is
            'question_<n>' while questions remain and 'complete' once
            ``max_questions`` questions have been asked.
        """
        session = self.sessions[session_id]

        # Add to symptoms
        if patient_input and patient_input.strip():
            session['symptoms'].append(patient_input)
            session['conversation_history'].append(f"Patient: {patient_input}")

        # Check if reached question limit
        if session['question_count'] >= self.max_questions:
            response = f"Thank you. I have {self.max_questions} questions answered. Please see a doctor for detailed examination."
            return {
                'response': response,
                'question_count': session['question_count'],
                'state': 'complete'
            }

        # Generate follow-up question; this falls back to the rule-based
        # templates by itself when no model is loaded.
        question = self._generate_question_with_huatuo(
            session['symptoms'],
            session['conversation_history']
        )

        # Increment question count
        session['question_count'] += 1
        session['conversation_history'].append(f"{self._DOCTOR_PREFIX}{question}")
        return {
            'response': question,
            'question_count': session['question_count'],
            'state': f'question_{session["question_count"]}'
        }

    def process_doctor_input(self, doctor_text: str) -> str:
        """Process doctor's voice input.

        Returns ``doctor_text`` unchanged when it is within the word limit,
        otherwise the first ``max_words_per_question`` words followed by "?".
        ``None`` is treated as empty text.
        """
        if doctor_text is None:
            return ""
        # Enforce word limit
        return self._limit_words(doctor_text)

    def reset_session(self, session_id: str = "default"):
        """Reset a session; unknown session ids are ignored."""
        self.sessions.pop(session_id, None)