Spaces:
Sleeping
Sleeping
| """ | |
| HuatuoGPT-based Medical Agent with strict question limits | |
| Uses HuatuoGPT-o1-8B for medical consultation | |
| """ | |
| import logging | |
| import torch | |
| from typing import Dict, Optional | |
| from collections import defaultdict | |
| logger = logging.getLogger(__name__) | |
| class HuatuoMedicalAgent: | |
| """Medical agent using HuatuoGPT with strict question limits""" | |
| def __init__(self, max_questions: int = 3, max_words_per_question: int = 5): | |
| """ | |
| Initialize HuatuoGPT Medical Agent | |
| Args: | |
| max_questions: Maximum number of follow-up questions (default: 3) | |
| max_words_per_question: Maximum words per question (default: 5) | |
| """ | |
| self.max_questions = max_questions | |
| self.max_words_per_question = max_words_per_question | |
| self.sessions = defaultdict(lambda: { | |
| 'question_count': 0, | |
| 'symptoms': [], | |
| 'conversation_history': [] | |
| }) | |
| # Try to load HuatuoGPT model | |
| self.model = None | |
| self.tokenizer = None | |
| self._load_model() | |
| def _load_model(self): | |
| """Load HuatuoGPT model with error handling""" | |
| try: | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| model_name = "FreedomIntelligence/HuatuoGPT-o1-8B" | |
| logger.info(f"π Loading HuatuoGPT from {model_name}...") | |
| # Load tokenizer | |
| self.tokenizer = AutoTokenizer.from_pretrained( | |
| model_name, | |
| trust_remote_code=True | |
| ) | |
| # Load model with optimizations | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| torch_dtype=torch.float16, | |
| device_map="auto", | |
| trust_remote_code=True, | |
| low_cpu_mem_usage=True | |
| ) | |
| logger.info("β HuatuoGPT model loaded successfully") | |
| except Exception as e: | |
| logger.warning(f"β οΈ Failed to load HuatuoGPT: {e}") | |
| logger.info("π Falling back to rule-based medical questions") | |
| self.model = None | |
| self.tokenizer = None | |
| def _generate_question_with_huatuo(self, symptoms: list, history: list) -> str: | |
| """Generate medical question using HuatuoGPT""" | |
| if not self.model or not self.tokenizer: | |
| return self._generate_rule_based_question(symptoms) | |
| try: | |
| # Build prompt for HuatuoGPT | |
| symptoms_text = ", ".join(symptoms) if symptoms else "general consultation" | |
| prompt = f"""You are a medical doctor. A patient reported: {symptoms_text}. | |
| Ask ONE brief follow-up question (maximum {self.max_words_per_question} words) to better understand their condition. | |
| Question:""" | |
| # Generate response | |
| inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) | |
| with torch.no_grad(): | |
| outputs = self.model.generate( | |
| **inputs, | |
| max_new_tokens=20, # Limit tokens for short questions | |
| temperature=0.7, | |
| do_sample=True, | |
| top_p=0.9 | |
| ) | |
| question = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| # Extract just the question part | |
| if "Question:" in question: | |
| question = question.split("Question:")[-1].strip() | |
| # Enforce word limit | |
| words = question.split() | |
| if len(words) > self.max_words_per_question: | |
| question = " ".join(words[:self.max_words_per_question]) + "?" | |
| return question | |
| except Exception as e: | |
| logger.error(f"β HuatuoGPT generation failed: {e}") | |
| return self._generate_rule_based_question(symptoms) | |
| def _generate_rule_based_question(self, symptoms: list) -> str: | |
| """Fallback: Generate rule-based medical questions""" | |
| # Medical question templates (max 5 words each) | |
| question_templates = [ | |
| "How long feeling this?", | |
| "Pain level one-to-ten?", | |
| "Any other symptoms present?", | |
| "Taking any medications now?", | |
| "When did this start?" | |
| ] | |
| if not symptoms: | |
| return question_templates[0] | |
| # Select based on symptom keywords | |
| symptom_text = " ".join(symptoms).lower() | |
| if any(word in symptom_text for word in ["pain", "hurt", "ache"]): | |
| return "Pain level one-to-ten?" | |
| elif any(word in symptom_text for word in ["fever", "temperature", "hot"]): | |
| return "How high is fever?" | |
| elif any(word in symptom_text for word in ["cough", "cold", "flu"]): | |
| return "Cough with mucus present?" | |
| else: | |
| return question_templates[min(len(symptoms), len(question_templates)-1)] | |
| def process_input(self, patient_input: str, session_id: str = "default") -> Dict: | |
| """ | |
| Process patient input and generate medical response | |
| Args: | |
| patient_input: Patient's symptom/response in English | |
| session_id: Session identifier | |
| Returns: | |
| Dict with 'response', 'question_count', 'state' | |
| """ | |
| session = self.sessions[session_id] | |
| # Add to symptoms | |
| if patient_input.strip(): | |
| session['symptoms'].append(patient_input) | |
| session['conversation_history'].append(f"Patient: {patient_input}") | |
| # Check if reached question limit | |
| if session['question_count'] >= self.max_questions: | |
| response = f"Thank you. I have {self.max_questions} questions answered. Please see a doctor for detailed examination." | |
| return { | |
| 'response': response, | |
| 'question_count': session['question_count'], | |
| 'state': 'complete' | |
| } | |
| # Generate follow-up question | |
| if self.model: | |
| question = self._generate_question_with_huatuo( | |
| session['symptoms'], | |
| session['conversation_history'] | |
| ) | |
| else: | |
| question = self._generate_rule_based_question(session['symptoms']) | |
| # Increment question count | |
| session['question_count'] += 1 | |
| session['conversation_history'].append(f"Doctor: {question}") | |
| return { | |
| 'response': question, | |
| 'question_count': session['question_count'], | |
| 'state': f'question_{session["question_count"]}' | |
| } | |
| def process_doctor_input(self, doctor_text: str) -> str: | |
| """Process doctor's voice input""" | |
| # Enforce word limit | |
| words = doctor_text.split() | |
| if len(words) > self.max_words_per_question: | |
| return " ".join(words[:self.max_words_per_question]) + "?" | |
| return doctor_text | |
| def reset_session(self, session_id: str = "default"): | |
| """Reset a session""" | |
| if session_id in self.sessions: | |
| del self.sessions[session_id] | |