import os import torch from torch.utils.data import DataLoader from tqdm import tqdm import numpy as np import matplotlib.pyplot as plt import seaborn as sns from sklearn.metrics import roc_curve, auc, confusion_matrix, roc_auc_score from sklearn.preprocessing import label_binarize import json import sys sys.path.append('.') import pmc_clip from training.params import parse_args from training.data import PmcDataset from training.fusion_method import convert_model_to_cls # 标签映射 LABEL_MAP = { "Irritant dermatitis": 0, "Allergic contact dermatitis": 1, "Mechanical injury": 2, "Folliculitis": 3, "Fungal infection": 4, "Skin hyperplasia": 5, "Parastomal varices": 6, "Urate crystals": 7, "Cancerous metastasis": 8, "Pyoderma gangrenosum": 9, "Normal": 10 } REVERSE_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()} def main(): # 创建输出目录 output_dir = './evaluation_results_pmc_clip_cat' if not os.path.exists(output_dir): os.makedirs(output_dir) # 设置设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"使用设备: {device}") # 加载模型配置 model_path = "logs/0321-Stoma-clip-train-cls/2025_03_21-23_45_18-model_RN50_fusion4-lr_1e-05-b_256-j_8-p_amp/checkpoints/epoch_150.pt" model_name = "RN50_fusion4" args = parse_args() args.model = model_name args.pretrained = model_path args.device = device args.mlm = True args.train_data = "data/single_symptoms_test.jsonl" args.image_dir = "./data/cleaned_data" args.csv_img_key = "image" args.csv_caption_key = "caption" args.context_length = 77 args.num_classes = len(LABEL_MAP) args.output_dir = output_dir # 创建模型和预处理函数 model, _, preprocess = pmc_clip.create_model_and_transforms(args) model = convert_model_to_cls(model, num_classes=args.num_classes, fusion_method='concat') # 加载模型权重 state_dict = torch.load(model_path, map_location='cpu', weights_only=False) state_dict_real = {} for k, v in state_dict['state_dict'].items(): state_dict_real[k.replace("module.", "", 1)] = v print(model.load_state_dict(state_dict_real)) model.to(device=device) # 准备数据集 dataset = PmcDataset(args, input_filename=args.train_data, transforms=preprocess, is_train=False) test_loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4) print(f"测试集样本数: {len(dataset)}") # 收集预测结果 all_preds = [] all_probs = [] all_labels = [] print("开始评估...") model.eval() with torch.no_grad(): for batch in tqdm(test_loader): labels = batch["cls_label"].to(device) # 前向传播 outputs = model(batch) # 获取预测结果 probs = torch.softmax(outputs, dim=1) _, preds = torch.max(outputs, dim=1) all_preds.extend(preds.cpu().numpy()) all_probs.extend(probs.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) # 转换为numpy数组 all_preds = np.array(all_preds) all_probs = np.array(all_probs) all_labels = np.array(all_labels) # 计算整体AUC(使用one-vs-rest策略的平均) try: y_true_bin = label_binarize(all_labels, classes=range(args.num_classes)) if args.num_classes == 2: overall_fpr, overall_tpr, _ = roc_curve(y_true_bin[:, 1], all_probs[:, 1]) overall_auc = roc_auc_score(y_true_bin, all_probs[:, 1]) else: overall_fpr, overall_tpr, _ = roc_curve(y_true_bin.ravel(), all_probs.ravel()) overall_auc = roc_auc_score(y_true_bin, all_probs, multi_class='ovr', average='micro') except Exception as e: print(f"计算整体AUC时出错: {e}") return # 保存整体ROC曲线数据 roc_data = { "fpr": overall_fpr.tolist(), "tpr": overall_tpr.tolist(), "auc": overall_auc } roc_file = os.path.join(output_dir, "overall_roc_data.json") with open(roc_file, "w") as f: json.dump(roc_data, f) print(f"整体ROC曲线数据已保存至: {roc_file}") # 绘制ROC曲线 plt.figure(figsize=(8, 6)) plt.plot(overall_fpr, overall_tpr, label=f"Overall (AUC = {overall_auc:.4f})") plt.plot([0, 1], [0, 1], 'k--', label="Random Guess") plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel('False Positive Rate (1 - Specificity)', fontsize=12) plt.ylabel('True Positive Rate (Sensitivity)', fontsize=12) plt.title('Overall ROC Curve', fontsize=14) plt.legend(loc="lower right", fontsize=10) plt.grid(alpha=0.3) plt.tight_layout() plt.savefig(os.path.join(output_dir, 'overall_roc_curve.png'), dpi=300, bbox_inches='tight') print(f"整体ROC曲线图已保存至: {os.path.join(output_dir, 'overall_roc_curve.png')}") if __name__ == '__main__': main()