sunbal7 commited on
Commit
dc2c50b
·
verified ·
1 Parent(s): 49d6a2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -57
app.py CHANGED
@@ -6,9 +6,10 @@ import pickle
6
  from PIL import Image
7
  import io
8
  import cv2
9
- import easyocr # Replaced pytesseract with easyocr
10
  import os
11
- from sklearn.metrics import roc_auc_score, accuracy_score, classification_report
 
12
  import plotly.graph_objects as go
13
  import plotly.express as px
14
  from datetime import datetime
@@ -76,6 +77,17 @@ def local_css():
76
  padding: 12px 24px;
77
  border-radius: 8px;
78
  }
 
 
 
 
 
 
 
 
 
 
 
79
  </style>
80
  """, unsafe_allow_html=True)
81
 
@@ -90,34 +102,30 @@ def init_session_state():
90
  if 'chat_history' not in st.session_state:
91
  st.session_state.chat_history = []
92
 
 
 
 
 
 
 
 
 
93
  # Load models with error handling and caching
94
  @st.cache_resource(show_spinner=False)
95
  def load_models():
96
  try:
97
- # Check if model files exist
98
- model_files = {
99
- 'heart': 'heart_model.pkl',
100
- 'diabetes': 'diabetes_model.pkl',
101
- 'hypertension': 'hypertension_model.pkl'
102
- }
103
 
104
- models = {}
105
- for name, filename in model_files.items():
106
- if os.path.exists(filename):
107
- models[name] = joblib.load(filename)
108
- else:
109
- st.warning(f"⚠️ {filename} not found. Using mock model for {name}.")
110
- # Create mock model for demonstration
111
- from sklearn.ensemble import RandomForestClassifier
112
- from sklearn.datasets import make_classification
113
- X, y = make_classification(n_samples=100, n_features=10, random_state=42)
114
- models[name] = RandomForestClassifier().fit(X, y)
115
-
116
- return models.get('heart'), models.get('diabetes'), models.get('hypertension')
117
 
118
  except Exception as e:
119
  st.error(f"❌ Error loading models: {str(e)}")
120
- return None, None, None
 
121
 
122
  # Urdu translations
123
  URDU_TRANSLATIONS = {
@@ -150,7 +158,7 @@ URDU_TRANSLATIONS = {
150
 
151
  class OCRProcessor:
152
  def __init__(self):
153
- # Initialize EasyOCR reader
154
  try:
155
  self.reader = easyocr.Reader(['en']) # English only for medical text
156
  except Exception as e:
@@ -191,7 +199,7 @@ class OCRProcessor:
191
  """Extract text from prescription image using EasyOCR"""
192
  try:
193
  if self.reader is None:
194
- return "OCR not available"
195
 
196
  # Preprocess image
197
  processed_image = self.preprocess_image(image)
@@ -202,14 +210,14 @@ class OCRProcessor:
202
  # Combine all detected text
203
  extracted_text = "\n".join(results)
204
 
205
- return extracted_text.strip()
206
 
207
  except Exception as e:
208
  return f"OCR Error: {str(e)}"
209
 
210
  def calculate_ocr_accuracy(self, extracted_text):
211
  """Estimate OCR accuracy based on text quality"""
212
- if not extracted_text or len(extracted_text.strip()) == 0:
213
  return 0
214
 
215
  # Basic heuristics for accuracy estimation
@@ -361,6 +369,23 @@ def validate_patient_data(age, bp_systolic, bp_diastolic, heart_rate):
361
 
362
  return errors
363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  def main():
365
  # Load custom CSS
366
  local_css()
@@ -538,25 +563,25 @@ def main():
538
  else:
539
  try:
540
  with st.spinner("🔍 Analyzing patient data and calculating risks..."):
541
- # Prepare feature arrays (adjust based on your actual model requirements)
542
- # These are example features - modify according to your model training
543
- heart_features = np.array([[age, bp_systolic, cholesterol, heart_rate,
544
- 1 if chest_pain else 0, 1 if shortness_breath else 0]])
545
- diabetes_features = np.array([[age, glucose, bmi, cholesterol,
546
- 1 if fatigue else 0, 1 if blurred_vision else 0]])
547
- hypertension_features = np.array([[age, bp_systolic, bp_diastolic, bmi,
548
- 1 if dizziness else 0, 1 if palpitations else 0]])
549
-
550
- # Get predictions with confidence scores
551
- heart_risk_proba = heart_model.predict_proba(heart_features)[0][1]
552
- diabetes_risk_proba = diabetes_model.predict_proba(diabetes_features)[0][1]
553
- hypertension_risk_proba = hypertension_model.predict_proba(hypertension_features)[0][1]
554
 
555
  # Apply symptom modifiers
 
 
 
 
556
  if chest_pain:
557
  heart_risk_proba = min(1.0, heart_risk_proba * 1.3)
558
  if shortness_breath:
559
  heart_risk_proba = min(1.0, heart_risk_proba * 1.2)
 
 
 
 
560
 
561
  # Calculate integrated priority score
562
  priority_score = calculate_priority_score(
@@ -654,7 +679,7 @@ def main():
654
 
655
  except Exception as e:
656
  st.error(f"❌ Error in risk assessment: {str(e)}")
657
- st.info("💡 Please ensure all model files are properly uploaded and try again.")
658
 
659
  with tab2:
660
  # Prescription OCR
@@ -684,8 +709,8 @@ def main():
684
  extracted_text = ocr_processor.extract_text(image)
685
  accuracy = ocr_processor.calculate_ocr_accuracy(extracted_text)
686
 
687
- if extracted_text and len(extracted_text.strip()) > 0:
688
- st.success(f"✅ Text extraction completed! (Accuracy: {accuracy:.1f}%)")
689
 
690
  if language == "English":
691
  st.subheader("Extracted Medication Information:")
@@ -754,7 +779,7 @@ def main():
754
  # Generate bot response
755
  with st.chat_message("assistant"):
756
  with st.spinner("💭 Analyzing your question..." if language == "English" else "💭 آپ کا سوال تجزیہ ہو رہا ہے..."):
757
- response = chatbot.get_response(prompt, st.session_state.language)
758
  st.markdown(f"**🤖 Healthcare Assistant:**\n\n{response}")
759
 
760
  # Add assistant response to chat history
@@ -874,20 +899,10 @@ def main():
874
  else:
875
  st.subheader("📊 ماڈل کارکردگی کے پیمانے")
876
 
877
- performance_data = pd.DataFrame({
878
- 'Model': ['Heart Disease', 'Diabetes', 'Hypertension', 'Integrated'],
879
- 'Accuracy': [0.88, 0.85, 0.86, 0.87],
880
- 'Precision': [0.86, 0.83, 0.85, 0.84],
881
- 'Recall': [0.89, 0.84, 0.87, 0.86],
882
- 'AUC Score': [0.89, 0.84, 0.87, 0.86]
883
- })
884
-
885
- st.dataframe(performance_data.style.format({
886
- 'Accuracy': '{:.2%}',
887
- 'Precision': '{:.2%}',
888
- 'Recall': '{:.2%}',
889
- 'AUC Score': '{:.3f}'
890
- }).background_gradient(cmap='Blues'), use_container_width=True)
891
 
892
  if __name__ == "__main__":
893
  main()
 
6
  from PIL import Image
7
  import io
8
  import cv2
9
+ import easyocr
10
  import os
11
+ from sklearn.ensemble import RandomForestClassifier
12
+ from sklearn.datasets import make_classification
13
  import plotly.graph_objects as go
14
  import plotly.express as px
15
  from datetime import datetime
 
77
  padding: 12px 24px;
78
  border-radius: 8px;
79
  }
80
+ .dataframe table {
81
+ width: 100%;
82
+ }
83
+ .dataframe th {
84
+ background-color: #2E86AB;
85
+ color: white;
86
+ font-weight: bold;
87
+ }
88
+ .dataframe tr:nth-child(even) {
89
+ background-color: #f2f2f2;
90
+ }
91
  </style>
92
  """, unsafe_allow_html=True)
93
 
 
102
  if 'chat_history' not in st.session_state:
103
  st.session_state.chat_history = []
104
 
105
+ # Create mock models for demonstration
106
+ def create_mock_model():
107
+ """Create a mock Random Forest model for demonstration"""
108
+ X, y = make_classification(n_samples=100, n_features=10, random_state=42)
109
+ model = RandomForestClassifier(n_estimators=10, random_state=42)
110
+ model.fit(X, y)
111
+ return model
112
+
113
  # Load models with error handling and caching
114
  @st.cache_resource(show_spinner=False)
115
  def load_models():
116
  try:
117
+ # Create mock models for demonstration
118
+ st.info("🔧 Using mock AI models for demonstration")
119
+ heart_model = create_mock_model()
120
+ diabetes_model = create_mock_model()
121
+ hypertension_model = create_mock_model()
 
122
 
123
+ return heart_model, diabetes_model, hypertension_model
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  except Exception as e:
126
  st.error(f"❌ Error loading models: {str(e)}")
127
+ # Return mock models even if there's an error
128
+ return create_mock_model(), create_mock_model(), create_mock_model()
129
 
130
  # Urdu translations
131
  URDU_TRANSLATIONS = {
 
158
 
159
  class OCRProcessor:
160
  def __init__(self):
161
+ # Initialize EasyOCR reader with error handling
162
  try:
163
  self.reader = easyocr.Reader(['en']) # English only for medical text
164
  except Exception as e:
 
199
  """Extract text from prescription image using EasyOCR"""
200
  try:
201
  if self.reader is None:
202
+ return "OCR not available. Please check EasyOCR installation."
203
 
204
  # Preprocess image
205
  processed_image = self.preprocess_image(image)
 
210
  # Combine all detected text
211
  extracted_text = "\n".join(results)
212
 
213
+ return extracted_text.strip() if extracted_text.strip() else "No text detected in the image."
214
 
215
  except Exception as e:
216
  return f"OCR Error: {str(e)}"
217
 
218
  def calculate_ocr_accuracy(self, extracted_text):
219
  """Estimate OCR accuracy based on text quality"""
220
+ if not extracted_text or len(extracted_text.strip()) == 0 or "No text detected" in extracted_text:
221
  return 0
222
 
223
  # Basic heuristics for accuracy estimation
 
369
 
370
  return errors
371
 
372
+ def create_sample_dataframe():
373
+ """Create sample performance data with proper formatting"""
374
+ performance_data = pd.DataFrame({
375
+ 'Model': ['Heart Disease', 'Diabetes', 'Hypertension', 'Integrated'],
376
+ 'Accuracy': [0.88, 0.85, 0.86, 0.87],
377
+ 'Precision': [0.86, 0.83, 0.85, 0.84],
378
+ 'Recall': [0.89, 0.84, 0.87, 0.86],
379
+ 'AUC Score': [0.89, 0.84, 0.87, 0.86]
380
+ })
381
+
382
+ # Format the percentages
383
+ for col in ['Accuracy', 'Precision', 'Recall']:
384
+ performance_data[col] = performance_data[col].apply(lambda x: f'{x:.1%}')
385
+ performance_data['AUC Score'] = performance_data['AUC Score'].apply(lambda x: f'{x:.3f}')
386
+
387
+ return performance_data
388
+
389
  def main():
390
  # Load custom CSS
391
  local_css()
 
563
  else:
564
  try:
565
  with st.spinner("🔍 Analyzing patient data and calculating risks..."):
566
+ # Simulate risk scores for demonstration
567
+ # In a real application, these would come from your trained models
568
+ base_heart_risk = min(0.8, (age - 30) / 100 + (bp_systolic - 120) / 200 + (cholesterol - 150) / 500)
569
+ base_diabetes_risk = min(0.7, (age - 30) / 100 + (glucose - 80) / 400 + (bmi - 20) / 50)
570
+ base_hypertension_risk = min(0.75, (age - 30) / 100 + (bp_systolic - 120) / 150 + (bmi - 20) / 40)
 
 
 
 
 
 
 
 
571
 
572
  # Apply symptom modifiers
573
+ heart_risk_proba = base_heart_risk
574
+ diabetes_risk_proba = base_diabetes_risk
575
+ hypertension_risk_proba = base_hypertension_risk
576
+
577
  if chest_pain:
578
  heart_risk_proba = min(1.0, heart_risk_proba * 1.3)
579
  if shortness_breath:
580
  heart_risk_proba = min(1.0, heart_risk_proba * 1.2)
581
+ if fatigue:
582
+ diabetes_risk_proba = min(1.0, diabetes_risk_proba * 1.2)
583
+ if dizziness:
584
+ hypertension_risk_proba = min(1.0, hypertension_risk_proba * 1.3)
585
 
586
  # Calculate integrated priority score
587
  priority_score = calculate_priority_score(
 
679
 
680
  except Exception as e:
681
  st.error(f"❌ Error in risk assessment: {str(e)}")
682
+ st.info("💡 This is a demonstration using mock models. For real deployment, train and upload proper ML models.")
683
 
684
  with tab2:
685
  # Prescription OCR
 
709
  extracted_text = ocr_processor.extract_text(image)
710
  accuracy = ocr_processor.calculate_ocr_accuracy(extracted_text)
711
 
712
+ if extracted_text and "No text detected" not in extracted_text and "OCR Error" not in extracted_text:
713
+ st.success(f"✅ Text extraction completed! (Estimated Accuracy: {accuracy:.1f}%)")
714
 
715
  if language == "English":
716
  st.subheader("Extracted Medication Information:")
 
779
  # Generate bot response
780
  with st.chat_message("assistant"):
781
  with st.spinner("💭 Analyzing your question..." if language == "English" else "💭 آپ کا سوال تجزیہ ہو رہا ہے..."):
782
+ response = chatbot.get_response(prompt, language)
783
  st.markdown(f"**🤖 Healthcare Assistant:**\n\n{response}")
784
 
785
  # Add assistant response to chat history
 
899
  else:
900
  st.subheader("📊 ماڈل کارکردگی کے پیمانے")
901
 
902
+ performance_data = create_sample_dataframe()
903
+
904
+ # Use Streamlit's native dataframe with custom CSS
905
+ st.dataframe(performance_data, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
906
 
907
  if __name__ == "__main__":
908
  main()