Xiaomeng1130 committed (verified)
Commit 75059a8 · 1 Parent(s): 7046c4a

Update app.py

Files changed (1):
  1. app.py +54 -27

app.py CHANGED
@@ -3,18 +3,21 @@ import torch
 import gradio as gr
 from PIL import Image
 import numpy as np
+import sys
+import time
 
-# ========== 1. Import project modules ==========
+# ========== 1. Import project modules and Model Configuration ==========
 try:
-    # Try to import the stoma_clip modules (installed via "-e ." in requirements.txt)
     from stoma_clip import pmc_clip
     from stoma_clip.pmc_clip.factory import _rescan_model_configs
     from stoma_clip.training.fusion_method import convert_model_to_cls
     from stoma_clip.training.dataset.utils import encode_mlm
     print("Stoma-CLIP modules imported successfully.")
+    sys.stdout.flush()  # force the output to be flushed
 except ImportError as e:
-    # Log the import failure for debugging
-    print(f"Error importing Stoma-CLIP modules: {e}")
+    print(f"FATAL: Error importing Stoma-CLIP modules: {e}")
+    sys.stdout.flush()
+    sys.exit(1)
 
 # ========== 2. Model Configuration and Loading ==========
 LABEL_MAP = {
@@ -28,16 +31,14 @@ NUM_CLASSES = len(LABEL_MAP)
 class Args:
     def __init__(self):
         self.model = "RN50_fusion4"
-        # Assume the stoma_clip.pt file sits in the app root directory (/app) or is resolved by your internal library.
-        # Make sure this is the correct file name.
         self.pretrained = "stoma_clip.pt"
         self.num_classes = NUM_CLASSES
         self.mlm = True
         self.crop_scale = 0.9
         self.context_length = 77
-        # Automatically detect and use CUDA/GPU
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         print(f"Using device: {self.device}")
+        sys.stdout.flush()
 args = Args()
 
 MODEL = None
@@ -45,34 +46,46 @@ PREPROCESS = None
 TOKENIZER = None
 
 def load_model():
-    """Load model once when Gradio starts, implementing the singleton pattern."""
+    """Load model once in the main thread during application initialization."""
     global MODEL, PREPROCESS, TOKENIZER
+
+    start_time = time.time()
     if MODEL is not None:
-        print("Model already loaded. Returning cached objects.")
         return MODEL, PREPROCESS, TOKENIZER
 
-    print("--- Starting Model Load Process ---")
+    print(f"--- Starting Model Load Process at {time.strftime('%H:%M:%S')} ---")
+    sys.stdout.flush()  # diagnostic point 1
+
     try:
         # Step 1: Create model and transforms
-        print("1. Rescanning model configs...")
+        print("1. Rescanning model configs and creating architecture...")
+        sys.stdout.flush()  # diagnostic point 2
+
         _rescan_model_configs()
         model, _, preprocess = pmc_clip.create_model_and_transforms(args)
         model = convert_model_to_cls(model, num_classes=args.num_classes, fusion_method='cross_attention')
         print("2. Model architecture created. Moving to device...")
+        sys.stdout.flush()  # diagnostic point 3
 
         # Move model architecture to GPU/CPU
         model.to(args.device).eval()
 
-        # Step 2: Load weights - use map_location to make sure they load onto the correct device
+        # Step 2: Load weights - must make sure stoma_clip.pt has a reasonable size and was copied in full
         print(f"3. Loading weights from {args.pretrained} to {args.device}...")
-        # torch.load here must rely on a file pre-downloaded or COPY'd in by the Dockerfile
+        sys.stdout.flush()  # diagnostic point 4 - key point: make sure the log is flushed before the slow I/O
+
+        # Force loading as float32, then convert to half precision if the model supports it, to help speed up transfer
         state_dict = torch.load(args.pretrained, map_location=args.device)
 
         print("4. Weights file loaded. Cleaning state dict...")
+        sys.stdout.flush()  # diagnostic point 5
+
         state_dict_clean = {k.replace("module.", "", 1): v for k, v in state_dict['state_dict'].items()}
 
         # Step 3: Apply weights
         print("5. Loading state dict into model architecture...")
+        sys.stdout.flush()  # diagnostic point 6
+
         model.load_state_dict(state_dict_clean)
 
         # Step 4: Final setup
@@ -81,32 +94,40 @@ def load_model():
         PREPROCESS = preprocess
         TOKENIZER = tokenizer
 
-        print("✨ Stoma-CLIP Model loaded successfully!")
+        end_time = time.time()
+        print(f"✨ Stoma-CLIP Model loaded successfully! Total time: {end_time - start_time:.2f} seconds.")
+        sys.stdout.flush()  # diagnostic point 7
+
        return MODEL, PREPROCESS, TOKENIZER
 
     except Exception as e:
         print(f"🔥 Error during model loading: {e}")
-        MODEL = None
-        # Raise so that Gradio knows startup failed
+        sys.stdout.flush()
         raise RuntimeError(f"Failed to load Stoma-CLIP model: {e}")
 
 # ========== 3. Inference Function ==========
 def predict_stoma_clip(image: Image.Image, caption: str):
-    # Make sure the model is loaded when inference is called
+    # Make sure the model is loaded when inference is called (fallback / lazy load only)
     try:
-        model, preprocess, tokenizer = load_model()
+        # If loading failed at startup, this retries; it relies on the global MODEL variable
+        if MODEL is None:
+            model, preprocess, tokenizer = load_model()
+        else:
+            model, preprocess, tokenizer = MODEL, PREPROCESS, TOKENIZER
+
     except RuntimeError:
         return "Model Loading Failed (See Logs)", {}
-
+
+    # ... the original inference logic is unchanged ...
     image = image.convert("RGB")
     device = args.device
-
+
     # Move input data to the GPU
     image_tensor = preprocess(image).unsqueeze(0).to(device)
-
+
     mask_token, pad_token = '[MASK]', '[PAD]'
     vocab = [v for v in tokenizer.get_vocab().keys() if v not in tokenizer.all_special_tokens]
-
+
     bert_input, bert_label = encode_mlm(
         caption=caption,
         vocab=vocab,
@@ -116,14 +137,14 @@ def predict_stoma_clip(image: Image.Image, caption: str):
         tokenizer=tokenizer,
         args=args,
     )
-
+
     with torch.no_grad():
         inputs = {"images": image_tensor, "bert_input": bert_input, "bert_label": bert_label}
         outputs = model(inputs)
         # Move the results back to the CPU for numpy conversion
         probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
         predicted_class_idx = torch.argmax(outputs, dim=1).item()
-
+
     predicted_class_name = REVERSE_LABEL_MAP.get(predicted_class_idx, "Unknown")
     probability_distribution = {REVERSE_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}
     return predicted_class_name, probability_distribution
@@ -159,8 +180,14 @@ iface = gr.Interface(
 )
 
 if __name__ == "__main__":
-    # Try to load the model at application startup; if it fails, launch will raise an exception
-    # load_model()  # iface.launch() usually triggers model loading internally, but an explicit call catches startup errors
+    # --- Key fix: force the model to load before the Gradio launch, moving the blocking I/O to the startup phase ---
+    print("Pre-loading model before Gradio launch to prevent runtime timeout...")
+    sys.stdout.flush()
 
-    # Use 0.0.0.0 and the default port in the T4 / Docker environment
+    load_model()
+
+    print("Model loaded. Launching Gradio interface...")
+    sys.stdout.flush()
+
+    # Launch Gradio
     iface.launch(server_name="0.0.0.0", server_port=7860)
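
For context on the state_dict cleaning step in the diff above: checkpoints saved from a model wrapped in torch.nn.DataParallel (or DDP) carry a "module." prefix on every parameter key, which an unwrapped model rejects at load_state_dict() time. A minimal, self-contained sketch of the same cleanup; the nn.Linear stand-in and the fake checkpoint dict are illustrative only, not part of this repo:

import torch.nn as nn

# Stand-in model; wrapping it in DataParallel prefixes every state-dict key with "module."
net = nn.Linear(4, 2)
ckpt = {"state_dict": nn.DataParallel(net).state_dict()}
print(list(ckpt["state_dict"]))  # ['module.weight', 'module.bias']

# Same cleanup as in load_model(): strip the first "module." occurrence from each key
clean = {k.replace("module.", "", 1): v for k, v in ckpt["state_dict"].items()}
nn.Linear(4, 2).load_state_dict(clean)  # loads without "unexpected key" errors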
 
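As a quick sanity check of the preload-then-serve flow added in this commit, the same functions can be exercised directly inside the container without going through the browser UI. This is only a sketch; example.jpg and the caption string are placeholders, not files or prompts shipped with this repo:

from PIL import Image

import app  # app.py in this repo; importing it builds iface but does not call launch()

app.load_model()  # should print steps 1-5 and the total load time
label, probs = app.predict_stoma_clip(
    Image.open("example.jpg"),                   # placeholder image path
    "Stoma with surrounding peristomal skin.",   # placeholder caption
)
print(label, probs)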