| import sys | |
| import os | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from copy import deepcopy | |
| from models.efficientnet_b0 import EfficientNetB0Classifier | |
| from utils.dataset_loader import ImageNpyDataset | |
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-sample loss.

    Args:
        model: network to train; switched to train mode here.
        dataloader: yields ``(inputs, targets)`` batches.
        criterion: per-batch loss function (assumed to return a batch mean).
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.

    Returns:
        float: loss averaged over every sample in ``dataloader.dataset``.
    """
    model.train()
    total_loss = 0.0
    for i, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so the final division by
        # the dataset length yields a true per-sample average even when the
        # last batch is smaller.
        total_loss += loss.item() * x.size(0)
        # BUGFIX: the original progress string contained mojibake-garbled
        # characters; restored as plain ASCII.
        print(f"\rBatch {i+1}/{len(dataloader)}", end="")
    return total_loss / len(dataloader.dataset)
def eval_model(model, dataloader, criterion, device):
    """Evaluate the model on a dataset without gradient tracking.

    Accuracy thresholds the raw outputs at 0.5, so the model is expected
    to emit per-element probabilities for a binary task.

    Returns:
        tuple[float, float]: (mean per-sample loss, accuracy).
    """
    model.eval()
    running_loss = 0.0
    num_correct = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)
            # Scale the batch-mean loss back up so uneven final batches
            # do not skew the per-sample average.
            running_loss += batch_loss.item() * inputs.size(0)
            predictions = (outputs > 0.5).float()
            num_correct += (predictions == targets).sum().item()
    dataset_size = len(dataloader.dataset)
    return running_loss / dataset_size, num_correct / dataset_size
def main():
    """Two-stage training of an EfficientNet-B0 binary classifier.

    Stage 1 trains only the classifier head (backbone frozen) for up to
    20 epochs, halving the LR after 3 epochs without validation-accuracy
    improvement and stopping early after 10. Stage 2 unfreezes the top
    backbone feature stages and fine-tunes for up to 15 epochs. The best
    checkpoint of each stage is saved to disk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

    train_ds = ImageNpyDataset("train_paths.npy", "train_labels.npy", augment=True)
    val_ds = ImageNpyDataset("val_paths.npy", "val_labels.npy")
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=32, num_workers=4, pin_memory=True)

    model = EfficientNetB0Classifier(train_base=False).to(device)
    # BCELoss: the classifier is expected to output sigmoid probabilities
    # (eval_model thresholds them at 0.5) — TODO confirm against the model class.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    best_model_wts = deepcopy(model.state_dict())
    best_acc = 0.0
    patience = 10      # epochs without improvement before early stopping
    lr_patience = 3    # epochs without improvement before halving the LR
    # BUGFIX: the original used one shared counter for both LR scheduling and
    # early stopping; resetting it on every LR reduction (every 3 stale epochs)
    # meant the patience threshold of 10 could never be reached, so early
    # stopping never fired. Track the two conditions independently.
    epochs_no_improve = 0
    lr_cooldown = 0

    print("Training EfficientNetB0 head only...")
    for epoch in range(20):
        print(f"\nStarting Epoch {epoch+1}...")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = eval_model(model, val_loader, criterion, device)
        print(f"  Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = deepcopy(model.state_dict())
            torch.save(model.state_dict(), "efficientnet_best.pth")
            epochs_no_improve = 0
            lr_cooldown = 0
            print("Saved new best model")
        else:
            epochs_no_improve += 1
            lr_cooldown += 1
            if lr_cooldown >= lr_patience:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.5
                print("Reduced LR")
                lr_cooldown = 0
            if epochs_no_improve >= patience:
                print("Early stopping")
                break

    print("\nFine-tuning top EfficientNetB0 layers...")
    # Restore the best head-only weights before unfreezing.
    model.load_state_dict(best_model_wts)
    # Unfreeze the last 5 backbone feature stages for fine-tuning.
    for param in model.base_model.features[-5:].parameters():
        param.requires_grad = True
    # Fresh optimizer so the newly-unfrozen parameters are registered with
    # clean Adam state.
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    best_finetune_acc = 0.0
    best_finetune_wts = deepcopy(model.state_dict())
    for epoch in range(15):
        print(f"\n[Fine-tune] Epoch {epoch+1}...")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        _, val_acc = eval_model(model, val_loader, criterion, device)
        print(f"  Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Acc={val_acc:.4f}")
        if val_acc > best_finetune_acc:
            best_finetune_acc = val_acc
            best_finetune_wts = deepcopy(model.state_dict())
            torch.save(best_finetune_wts, "efficientnet_best_finetuned.pth")
            print("Saved fine-tuned best model")

    model.load_state_dict(best_finetune_wts)
    torch.save(model.state_dict(), "efficientnet_final.pth")
    print("Training complete. Final model saved as efficientnet_final.pth")
# Script entry point: only run training when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()