Update app.py
app.py
CHANGED
@@ -1,256 +1,237 @@
-
 """
- … (old module docstring elided in the diff view)
 """
-from __future__ import annotations
-
-import argparse
-import json
-import sys
 import os
-import …
 import tempfile
-
-from …
-from urllib.parse import urlparse
-
-import cv2
-import yaml
-import numpy as np
-from dataclasses import dataclass
-from ultralytics import YOLO
-import requests
-from PIL import Image
-import io
-
-# ───── Data Classes ──────────────────────────────────────────────────────────
-@dataclass
-class PostPart:
-    name: str
-    x: float       # normalized center x
-    y: float       # normalized center y
-    width: float
-    height: float
-    confidence: float = 1.0
-
-@dataclass
-class PostAnalysis:
-    image_path: Path
-    parts: List[PostPart]
-    anomalies: List[PostPart]
-    violations: List[str]
-    is_conform: bool
-    confidence_score: float
 
-# … (YAML config loader; only its tail survives in the diff view)
-        sys.exit(f"Required {yaml_path} was not found – aborting.")
-    with yaml_path.open("r", encoding="utf-8") as fh:
-        data = yaml.safe_load(fh)
-    if "names" not in data:
-        sys.exit("'names' field missing in data.yaml – unable to continue.")
-    return {
-        "names": data["names"],
-        "class_to_name": {i: n for i, n in enumerate(data["names"])},
-        "name_to_class": {n: i for i, n in enumerate(data["names"])},
-    }
 
-# …
 try:
-    import sam_2
-    import importlib
     sys.modules['sam2'] = sam_2
     for sub in ['build_sam','automatic_mask_generator','modeling.sam2_base']:
         sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
 except ImportError:
     pass
 
-
-
     try:
         from sam2.build_sam import build_sam2
-        return True
     except ImportError:
-        # Clone
- …
-            "https://github.com/facebookresearch/segment-anything-2.git"
-        ], check=True)
         # Install editable
         cwd = os.getcwd()
-        os.chdir(…)
         subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
         os.chdir(cwd)
-        # …
-        path = os.path.abspath("segment-anything-2")
-        if path not in sys.path:
-            sys.path.insert(0, path)
         try:
             import sam_2, importlib
             sys.modules['sam2'] = sam_2
             for sub in ['build_sam','automatic_mask_generator','modeling.sam2_base']:
                 sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
         except ImportError:
-            return False
-
 
-SAM2_AVAILABLE, SAM2_STATUS = check_and_install_sam2()
-print(f"SAM-2 Status: {SAM2_STATUS}")
 if SAM2_AVAILABLE:
     from sam2.build_sam import build_sam2
     from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
     from sam2.modeling.sam2_base import SAM2Base
 
-# … (model loading and analysis helpers, old lines 114-166, largely elided)
-    return
-
-def …
-)
-# … (JSON report writer; only its tail survives)
-    for a in analyses:
-        report.append({
-            'image': str(a.image_path), 'is_conform': a.is_conform,
-            'confidence_score': a.confidence_score, 'violations': a.violations,
-            'parts': [vars(p) for p in a.parts], 'anomalies': [vars(d) for d in a.anomalies]
-        })
-    fp = output_dir/'post_analysis.json'
-    with fp.open('w',encoding='utf-8') as f: json.dump(report,f,indent=2)
-    return fp
-
-# ───── Image Download Helper ─────────────────────────────────────────────────
-def download_image(url: str) -> Union[Path,None]:
     try:
- … (download body elided)
     try:
- …
-        with …
- …
-def process_uploaded_images(images: List[Union[str,bytes,Path]], output_dir: Path, data_yaml: Path, weights: str, quiet: bool=False):
-    ci=load_yaml_config(data_yaml); model=YOLO(weights); output_dir.mkdir(parents=True,exist_ok=True)
-    analyses=[]
-    for img in images:
-        try: analyses.append(process_uploaded_image(img,model,ci,output_dir,quiet))
-        except Exception as e: print(f"Error: {e}")
-    print(f"Processed {len(analyses)} uploads.")
-    return analyses
-
-# ───── CLI Entrypoint ───────────────────────────────────────────────────────
-def main(argv=None):
-    p=argparse.ArgumentParser(description="Enhanced post analysis tool")
-    p.add_argument("--images",type=Path,help="Directory of images")
-    p.add_argument("--upload",nargs="+",help="URLs, paths, or bytes to analyze")
-    p.add_argument("--output",type=Path,default="post_analysis_results")
-    p.add_argument("--data",type=Path,default="data.yaml")
-    p.add_argument("--weights",type=str,default="yolov8n.pt")
-    p.add_argument("-q","--quiet",action="store_true")
-    args=p.parse_args(argv)
-    if args.upload:
-        process_uploaded_images(args.upload,args.output,args.data,args.weights,args.quiet)
-    elif args.images:
-        process_directory(args.images,args.output,args.data,args.weights,args.quiet)
-    else:
-        p.error("Specify --images or --upload")
-
-if __name__ == "__main__": main()
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
 """
+Combined Medical-VLM, SAM-2 Automatic Masking, and CheXagent Demo
+=================================================================
+
+Features:
+- Qwen2.5-VL Instruct medical vision-language Q&A
+- SAM-2 segmentation with alias patch for Hugging Face
+- Simple fallback segmentation
+- CheXagent structured report & visual grounding
+- Automatic dependency checking & installation for SAM-2
+
+Usage:
+    python medical_ai_app.py    # launches Gradio UI on port 7860
+Requires:
+    torch, transformers, PIL, gradio, ultralytics, requests, opencv-python, pyyaml
 """
 import os
+import sys
+import uuid
 import tempfile
+import subprocess
+import warnings
+from pathlib import Path   # Path is used by check_and_install_sam2 below
+from threading import Thread
 
+# Environment setup
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+warnings.filterwarnings("ignore", message=r".*upsample_bicubic2d.*")
 
+# Third-party libs
+import torch
+import numpy as np
+import cv2
+from PIL import Image, ImageDraw
+import gradio as gr
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+
+# =============================================================================
+# SAM-2 Alias Patch & Installer
+# =============================================================================
+# Alias sam_2 package to sam2 namespace
 try:
+    import sam_2, importlib
     sys.modules['sam2'] = sam_2
     for sub in ['build_sam','automatic_mask_generator','modeling.sam2_base']:
         sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
 except ImportError:
     pass
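
The alias patch above is plain `sys.modules` bookkeeping. A self-contained sketch of the same trick, with made-up module names (`real_pkg` and `alias_pkg` are not part of the app), shows why a later `import sam2` resolves to the already-imported `sam_2` package:

```python
import sys, types

real_pkg = types.ModuleType("real_pkg")   # stands in for the installed sam_2
sys.modules["alias_pkg"] = real_pkg       # register it under a second name

import alias_pkg                          # no such file on disk; found in sys.modules
assert alias_pkg is real_pkg
```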
 
+def check_and_install_sam2():
+    """Ensure SAM-2 is installed and aliased as sam2."""
     try:
         from sam2.build_sam import build_sam2
+        return True
     except ImportError:
+        # Clone repo
+        repo_dir = Path("segment-anything-2")
+        if not repo_dir.exists():
+            subprocess.run(["git", "clone", "https://github.com/facebookresearch/segment-anything-2.git"], check=True)
         # Install editable
         cwd = os.getcwd()
+        os.chdir(repo_dir)
         subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
         os.chdir(cwd)
+        # Re-alias
         try:
             import sam_2, importlib
             sys.modules['sam2'] = sam_2
             for sub in ['build_sam','automatic_mask_generator','modeling.sam2_base']:
                 sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
+            return True
         except ImportError:
+            return False
+
+SAM2_AVAILABLE = check_and_install_sam2()
 
 if SAM2_AVAILABLE:
     from sam2.build_sam import build_sam2
     from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
     from sam2.modeling.sam2_base import SAM2Base
 
+# =============================================================================
+# Utility: device selection
+# =============================================================================
+def get_device():
+    if torch.cuda.is_available(): return torch.device('cuda')
+    if torch.backends.mps.is_available(): return torch.device('mps')
+    return torch.device('cpu')
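
A quick usage sketch of the fallback order (CUDA, then Apple MPS, then CPU); the tensor below is only illustrative:

```python
device = get_device()
x = torch.zeros(2, 2, device=device)   # tensors and models follow the same pick
print(x.device)                        # cuda:0, mps:0, or cpu depending on the host
```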
+
+# =============================================================================
+# Qwen-VLM: loading & agent
+# =============================================================================
+_qwen_model = None
+_qwen_processor = None
+_qwen_device = None
+
+def load_qwen_model_and_processor(hf_token=None):
+    global _qwen_model, _qwen_processor, _qwen_device
+    if _qwen_model is None:
+        _qwen_device = get_device()
+        auth = {"use_auth_token": hf_token} if hf_token else {}
+        _qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True,
+            torch_dtype=torch.float32, low_cpu_mem_usage=True, **auth
+        ).to(_qwen_device)
+        _qwen_processor = AutoProcessor.from_pretrained(
+            "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True, **auth
+        )
+    return _qwen_model, _qwen_processor, _qwen_device
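
The loader caches into module-level globals, so repeated calls are cheap; a usage sketch (assuming the checkpoint can actually be downloaded):

```python
model, processor, device = load_qwen_model_and_processor()
model_again, _, _ = load_qwen_model_and_processor()
assert model is model_again   # the second call reuses the cached singleton
```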
+
+class MedicalVLMAgent:
+    def __init__(self, model, processor, device):
+        self.model = model; self.processor = processor; self.device = device
+        self.sys_prompt = (
+            "You are a medical information assistant with vision capabilities.\n"
+            "Disclaimer: I am not a licensed medical professional."
+        )
+    def run(self, text, image=None):
+        msgs = [{"role":"system","content":[{"type":"text","text":self.sys_prompt}]}]
+        user_cont = []
+        if image:
+            tmp = f"/tmp/{uuid.uuid4()}.png"; image.save(tmp)
+            user_cont.append({"type":"image","image":tmp})
+        user_cont.append({"type":"text","text": text or ""})
+        msgs.append({"role":"user","content":user_cont})
+        prompt = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
+        # Hand the PIL image itself to the processor; videos are unused here
+        img_in = [image] if image else None
+        inputs = self.processor(text=[prompt], images=img_in, videos=None,
+                                padding=True, return_tensors='pt').to(self.device)
+        out = self.model.generate(**inputs, max_new_tokens=128)
+        resp = out[0][inputs.input_ids.shape[1]:]
+        return self.processor.decode(resp, skip_special_tokens=True).strip()
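
A usage sketch of the agent; `chest_xray.png` is an illustrative path, and the image argument is optional:

```python
model, proc, dev = load_qwen_model_and_processor()
agent = MedicalVLMAgent(model, proc, dev)
answer = agent.run("Describe any visible abnormality.",
                   image=Image.open("chest_xray.png"))
print(answer)
```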
+
+# =============================================================================
+# SAM-2 segmentation interface
+# =============================================================================
+_sam2_model, _mask_generator = (None, None)
+if SAM2_AVAILABLE:
     try:
+        # Initialize model (checkpoint/config paths are relative to the repo)
+        CKPT = "checkpoints/sam2.1_hiera_large.pt"; CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
+        cwd = os.getcwd()
+        os.chdir("segment-anything-2/sam2/sam2")
+        try:
+            _sam2_model = build_sam2(CFG, CKPT, device=get_device(), apply_postprocessing=False)
+            _mask_generator = SAM2AutomaticMaskGenerator(_sam2_model)
+        finally:
+            os.chdir(cwd)  # restore the working directory even if loading fails
+    except Exception:
+        _mask_generator = None
+
+
+def segmentation_interface(image):
+    if image is None: return None, "Upload an image"
+    if not _mask_generator: return None, "SAM-2 unavailable"
+    arr = np.array(image.convert('RGB'))
+    anns = _mask_generator.generate(arr)
+    overlay = arr.copy()
+    for ann in sorted(anns, key=lambda x: x['area'], reverse=True):
+        m = ann['segmentation']; color = np.random.randint(0, 255, 3)
+        overlay[m] = (overlay[m]*0.5 + color*0.5).astype(np.uint8)
+    return Image.fromarray(overlay), f"{len(anns)} masks"
+
+# =============================================================================
+# Fallback segmentation
+# =============================================================================
+def fallback_segmentation(image):
+    if image is None: return None, "Upload an image"
+    arr = np.array(image.convert('RGB'))
+    gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
+    _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
+    overlay = arr.copy(); overlay[th > 0] = [255, 0, 0]
+    blended = cv2.addWeighted(arr, 0.7, overlay, 0.3, 0)
+    return Image.fromarray(blended), "Fallback applied"
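
Both segmentation paths return the same `(PIL image, status string)` pair, which is what lets the UI below swap them without rewiring; a dispatch sketch (`scan.png` is an illustrative path):

```python
img = Image.open("scan.png")
seg_fn = segmentation_interface if _mask_generator else fallback_segmentation
masked, status = seg_fn(img)
print(status)   # e.g. "12 masks" or "Fallback applied"
```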
+
+# =============================================================================
+# CheXagent: structured report & grounding
+# =============================================================================
+try:
+    chex_tok = AutoTokenizer.from_pretrained("StanfordAIMI/CheXagent-2-3b", trust_remote_code=True)
+    chex_model = AutoModelForCausalLM.from_pretrained(
+        "StanfordAIMI/CheXagent-2-3b", device_map='auto', trust_remote_code=True
+    )
+    if torch.cuda.is_available(): chex_model = chex_model.half()
+    chex_model.eval(); CHEX_AVAILABLE = True
+except Exception:
+    CHEX_AVAILABLE = False
+
+@torch.no_grad()
+def report_generation(im1, im2):
+    if not CHEX_AVAILABLE: yield "CheXagent unavailable"; return
+    streamer = TextIteratorStreamer(chex_tok, skip_prompt=True)
+    # ... streaming report generation elided in this snippet ...
+    yield "Report not implemented in snippet"
+
+@torch.no_grad()
+def phrase_grounding(image, prompt):
+    if not CHEX_AVAILABLE: return "CheXagent unavailable", None
+    # Placeholder grounding: draw a fixed central box
+    w, h = image.size; draw = ImageDraw.Draw(image)
+    draw.rectangle([(w*0.25, h*0.25), (w*0.75, h*0.75)], outline='red', width=3)
+    return prompt, image
+
+# =============================================================================
+# Gradio UI
+# =============================================================================
+def create_ui():
+    # Load agents
     try:
+        q_model, q_proc, q_dev = load_qwen_model_and_processor()
+        med_agent = MedicalVLMAgent(q_model, q_proc, q_dev); QWEN_OK = True
+    except Exception:
+        QWEN_OK = False; med_agent = None
+
+    with gr.Blocks() as demo:
+        gr.Markdown("# Medical AI Assistant")
+        gr.Markdown(f"- Qwen VLM: {'✅' if QWEN_OK else '❌'} "
+                    f"- SAM-2: {'✅' if _mask_generator else '❌'} "
+                    f"- CheXagent: {'✅' if CHEX_AVAILABLE else '❌'}")
+        with gr.Tab("Medical Q&A"):
+            txt = gr.Textbox(); img = gr.Image(type='pil'); out = gr.Textbox(); btn = gr.Button("Ask")
+            # Guard against a failed Qwen load instead of binding a None agent
+            qa_fn = med_agent.run if QWEN_OK else (lambda t, i: "Qwen VLM unavailable")
+            btn.click(qa_fn, [txt, img], out)
+        with gr.Tab("Segmentation"):
+            segin = gr.Image(type='pil'); segout = gr.Image(); stat = gr.Textbox()
+            fn = segmentation_interface if _mask_generator else fallback_segmentation
+            gr.Button("Segment").click(fn, segin, [segout, stat])
+        with gr.Tab("CheXagent Report"):
+            c1 = gr.Image(type='pil'); c2 = gr.Image(type='pil'); rout = gr.Markdown()
+            gr.Interface(fn=report_generation, inputs=[c1, c2], outputs=rout, live=True).render()
+        with gr.Tab("CheXagent Grounding"):
+            gi = gr.Image(type='pil'); gp = gr.Textbox(); gout = gr.Textbox(); goimg = gr.Image()
+            gr.Interface(fn=phrase_grounding, inputs=[gi, gp], outputs=[gout, goimg]).render()
+    return demo
+
+if __name__ == "__main__":
+    ui = create_ui()
+    ui.launch(server_name='0.0.0.0', server_port=7860, share=True)