Mendota commited on
Commit
5faac9a
·
verified ·
1 Parent(s): 1ebc342

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +369 -0
main.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import datasets, models, transforms
5
+ from torch.utils.data import DataLoader, random_split, Dataset
6
+ from torch.optim import lr_scheduler
7
+ import pandas as pd
8
+ import numpy as np
9
+ import time
10
+ import copy
11
+ import os
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+ from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
15
+ from sklearn.preprocessing import label_binarize
16
+
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
+ root_dir = r"path"
20
+ data_dir = os.path.join(root_dir, 'Training')
21
+ save_dir = "./improved_results"
22
+ os.makedirs(save_dir, exist_ok=True)
23
+
24
+ CONFIG = {
25
+ 'model_name': 'ResNet50_Improved',
26
+ 'batch_size': 32,
27
+ 'lr': 0.001,
28
+ 'epochs': 25,
29
+ 'scheduler_step': 7,
30
+ 'gamma': 0.1,
31
+ 'weight_decay': 5e-4,
32
+ 'dropout_rate': 0.6,
33
+ 'early_stopping_patience': 5,
34
+ 'early_stopping_min_delta': 0.001
35
+ }
36
+
37
+ train_transforms = transforms.Compose([
38
+ transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
39
+ transforms.RandomHorizontalFlip(),
40
+ transforms.RandomRotation(20),
41
+ transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
42
+ transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
43
+ transforms.ToTensor(),
44
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
45
+ ])
46
+
47
+ test_transforms = transforms.Compose([
48
+ transforms.Resize((224, 224)),
49
+ transforms.ToTensor(),
50
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
51
+ ])
52
+
53
+
54
+ class EarlyStopping:
55
+ def __init__(self, patience=5, min_delta=0.001, mode='max'):
56
+ self.patience = patience
57
+ self.min_delta = min_delta
58
+ self.mode = mode
59
+ self.counter = 0
60
+ self.best_score = None
61
+ self.early_stop = False
62
+
63
+ def __call__(self, score):
64
+ if self.best_score is None:
65
+ self.best_score = score
66
+ return False
67
+
68
+ if self.mode == 'max':
69
+ if score > self.best_score + self.min_delta:
70
+ self.best_score = score
71
+ self.counter = 0
72
+ else:
73
+ self.counter += 1
74
+ else:
75
+ if score < self.best_score - self.min_delta:
76
+ self.best_score = score
77
+ self.counter = 0
78
+ else:
79
+ self.counter += 1
80
+
81
+ if self.counter >= self.patience:
82
+ self.early_stop = True
83
+ print(f"\nEarly stopping triggered! No improvement for {self.patience} epochs.")
84
+ return True
85
+ return False
86
+
87
+
88
+ class TransformedSubset(Dataset):
89
+ def __init__(self, subset, transform=None):
90
+ self.subset = subset
91
+ self.transform = transform
92
+
93
+ def __getitem__(self, index):
94
+ x, y = self.subset[index]
95
+ if self.transform:
96
+ x = self.transform(x)
97
+ return x, y
98
+
99
+ def __len__(self):
100
+ return len(self.subset)
101
+
102
+
103
+ base_dataset = datasets.ImageFolder(root=data_dir)
104
+ class_names = base_dataset.classes
105
+ num_classes = len(class_names)
106
+
107
+ train_size = int(0.8 * len(base_dataset))
108
+ test_size = len(base_dataset) - train_size
109
+ train_indices, test_indices = random_split(base_dataset, [train_size, test_size])
110
+
111
+ train_dataset = TransformedSubset(train_indices, transform=train_transforms)
112
+ test_dataset = TransformedSubset(test_indices, transform=test_transforms)
113
+
114
+ dataloaders = {
115
+ 'train': DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=0),
116
+ 'test': DataLoader(test_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=0)
117
+ }
118
+ dataset_sizes = {'train': train_size, 'test': test_size}
119
+
120
+
121
+ def get_model():
122
+ model = models.resnet50(pretrained=True)
123
+ num_ftrs = model.fc.in_features
124
+ model.fc = nn.Sequential(
125
+ nn.Dropout(CONFIG['dropout_rate']),
126
+ nn.Linear(num_ftrs, num_classes)
127
+ )
128
+ return model.to(device)
129
+
130
+
131
+ model = get_model()
132
+ criterion = nn.CrossEntropyLoss()
133
+ optimizer = optim.Adam(
134
+ model.parameters(),
135
+ lr=CONFIG['lr'],
136
+ weight_decay=CONFIG['weight_decay']
137
+ )
138
+ exp_lr_scheduler = lr_scheduler.StepLR(
139
+ optimizer,
140
+ step_size=CONFIG['scheduler_step'],
141
+ gamma=CONFIG['gamma']
142
+ )
143
+
144
+
145
+ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
146
+ since = time.time()
147
+ best_model_wts = copy.deepcopy(model.state_dict())
148
+ best_acc = 0.0
149
+ history = []
150
+
151
+ early_stopping = EarlyStopping(
152
+ patience=CONFIG['early_stopping_patience'],
153
+ min_delta=CONFIG['early_stopping_min_delta'],
154
+ mode='max'
155
+ )
156
+
157
+ for epoch in range(num_epochs):
158
+ print(f'\n{"="*50}')
159
+ print(f'Epoch {epoch+1}/{num_epochs}')
160
+ print("="*50)
161
+
162
+ epoch_stats = {'Epoch': epoch+1}
163
+
164
+ for phase in ['train', 'test']:
165
+ if phase == 'train':
166
+ model.train()
167
+ else:
168
+ model.eval()
169
+
170
+ running_loss = 0.0
171
+ running_corrects = 0
172
+
173
+ for inputs, labels in dataloaders[phase]:
174
+ inputs = inputs.to(device)
175
+ labels = labels.to(device)
176
+ optimizer.zero_grad()
177
+
178
+ with torch.set_grad_enabled(phase == 'train'):
179
+ outputs = model(inputs)
180
+ _, preds = torch.max(outputs, 1)
181
+ loss = criterion(outputs, labels)
182
+
183
+ if phase == 'train':
184
+ loss.backward()
185
+ optimizer.step()
186
+
187
+ running_loss += loss.item() * inputs.size(0)
188
+ running_corrects += torch.sum(preds == labels.data)
189
+
190
+ if phase == 'train':
191
+ scheduler.step()
192
+
193
+ epoch_loss = running_loss / dataset_sizes[phase]
194
+ epoch_acc = running_corrects.double() / dataset_sizes[phase]
195
+
196
+ print(f'{phase.upper():5s} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f} ({epoch_acc*100:.2f}%)')
197
+
198
+ epoch_stats[f'{phase}_loss'] = epoch_loss
199
+ epoch_stats[f'{phase}_acc'] = epoch_acc.item()
200
+
201
+ if phase == 'test':
202
+ if epoch_acc > best_acc:
203
+ best_acc = epoch_acc
204
+ best_model_wts = copy.deepcopy(model.state_dict())
205
+ torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))
206
+ print(f"✅ New Record! Test Acc: {best_acc:.4f}")
207
+
208
+ if early_stopping(epoch_acc.item()):
209
+ print(f"\nTraining stopped (Epoch {epoch+1})")
210
+ model.load_state_dict(best_model_wts)
211
+ df = pd.DataFrame(history)
212
+ df.to_csv(os.path.join(save_dir, 'training_logs.csv'), index=False)
213
+ return model, df
214
+
215
+ history.append(epoch_stats)
216
+
217
+ time_elapsed = time.time() - since
218
+ print(f'\n{"="*50}')
219
+ print(f'Training completed: {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
220
+ print(f'Best Test Accuracy: {best_acc:.4f} ({best_acc*100:.2f}%)')
221
+ print("="*50)
222
+
223
+ model.load_state_dict(best_model_wts)
224
+ df = pd.DataFrame(history)
225
+ df.to_csv(os.path.join(save_dir, 'training_logs.csv'), index=False)
226
+
227
+ return model, df
228
+
229
+
230
+ def evaluate_model(model, dataloader, class_names):
231
+ model.eval()
232
+ all_preds = []
233
+ all_labels = []
234
+ all_probs = []
235
+
236
+ with torch.no_grad():
237
+ for inputs, labels in dataloader:
238
+ inputs = inputs.to(device)
239
+ outputs = model(inputs)
240
+ probs = torch.softmax(outputs, dim=1)
241
+ _, preds = torch.max(outputs, 1)
242
+
243
+ all_preds.extend(preds.cpu().numpy())
244
+ all_labels.extend(labels.numpy())
245
+ all_probs.extend(probs.cpu().numpy())
246
+
247
+ all_preds = np.array(all_preds)
248
+ all_labels = np.array(all_labels)
249
+ all_probs = np.array(all_probs)
250
+
251
+ cm = confusion_matrix(all_labels, all_preds)
252
+ plt.figure(figsize=(10, 8))
253
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
254
+ xticklabels=class_names, yticklabels=class_names,
255
+ cbar_kws={'label': 'Count'})
256
+ plt.title('Confusion Matrix', fontsize=16, fontweight='bold')
257
+ plt.ylabel('True Class', fontsize=12)
258
+ plt.xlabel('Predicted Class', fontsize=12)
259
+ plt.tight_layout()
260
+ plt.savefig(os.path.join(save_dir, 'confusion_matrix.png'), dpi=300)
261
+ plt.show()
262
+
263
+ print("\n" + "="*60)
264
+ print("DETAILED PERFORMANCE REPORT")
265
+ print("="*60)
266
+ report = classification_report(all_labels, all_preds,
267
+ target_names=class_names,
268
+ digits=4)
269
+ print(report)
270
+
271
+ report_dict = classification_report(all_labels, all_preds,
272
+ target_names=class_names,
273
+ output_dict=True)
274
+
275
+ metrics = ['precision', 'recall', 'f1-score']
276
+ class_metrics = {metric: [] for metric in metrics}
277
+
278
+ for class_name in class_names:
279
+ for metric in metrics:
280
+ class_metrics[metric].append(report_dict[class_name][metric])
281
+
282
+ fig, ax = plt.subplots(figsize=(12, 6))
283
+ x = np.arange(len(class_names))
284
+ width = 0.25
285
+
286
+ for i, metric in enumerate(metrics):
287
+ ax.bar(x + i*width, class_metrics[metric], width,
288
+ label=metric.capitalize(), alpha=0.8)
289
+
290
+ ax.set_xlabel('Classes', fontsize=12)
291
+ ax.set_ylabel('Score', fontsize=12)
292
+ ax.set_title('Per-Class Performance Metrics', fontsize=14, fontweight='bold')
293
+ ax.set_xticks(x + width)
294
+ ax.set_xticklabels(class_names, rotation=45, ha='right')
295
+ ax.legend()
296
+ ax.set_ylim([0, 1.05])
297
+ ax.grid(axis='y', alpha=0.3)
298
+ plt.tight_layout()
299
+ plt.savefig(os.path.join(save_dir, 'class_metrics.png'), dpi=300)
300
+ plt.show()
301
+
302
+ try:
303
+ y_bin = label_binarize(all_labels, classes=range(num_classes))
304
+ auc_scores = []
305
+ for i in range(num_classes):
306
+ auc = roc_auc_score(y_bin[:, i], all_probs[:, i])
307
+ auc_scores.append(auc)
308
+ print(f"ROC-AUC ({class_names[i]}): {auc:.4f}")
309
+ print(f"Mean ROC-AUC: {np.mean(auc_scores):.4f}")
310
+ except:
311
+ print("ROC-AUC could not be calculated")
312
+
313
+ return cm, report
314
+
315
+
316
+ def plot_training_results(df):
317
+ sns.set_style("whitegrid")
318
+ fig, axes = plt.subplots(1, 2, figsize=(15, 5))
319
+
320
+ axes[0].plot(df['Epoch'], df['train_loss'], 'o-', label='Train Loss', linewidth=2, markersize=6)
321
+ axes[0].plot(df['Epoch'], df['test_loss'], 's-', label='Test Loss', linewidth=2, markersize=6)
322
+ axes[0].set_title('Loss Evolution', fontsize=14, fontweight='bold')
323
+ axes[0].set_xlabel('Epoch', fontsize=12)
324
+ axes[0].set_ylabel('Loss', fontsize=12)
325
+ axes[0].legend(fontsize=11)
326
+ axes[0].grid(True, alpha=0.3)
327
+
328
+ axes[1].plot(df['Epoch'], df['train_acc'], 'o-', label='Train Acc', linewidth=2, markersize=6, color='green')
329
+ axes[1].plot(df['Epoch'], df['test_acc'], 's-', label='Test Acc', linewidth=2, markersize=6, color='orange')
330
+ axes[1].set_title('Accuracy Evolution', fontsize=14, fontweight='bold')
331
+ axes[1].set_xlabel('Epoch', fontsize=12)
332
+ axes[1].set_ylabel('Accuracy', fontsize=12)
333
+ axes[1].legend(fontsize=11)
334
+ axes[1].grid(True, alpha=0.3)
335
+ axes[1].set_ylim([0, 1.05])
336
+
337
+ plt.tight_layout()
338
+ plt.savefig(os.path.join(save_dir, 'training_curves.png'), dpi=300)
339
+ plt.show()
340
+
341
+ df['overfit_gap'] = df['train_acc'] - df['test_acc']
342
+ print(f"\nOverfitting Analysis:")
343
+ print(f"Mean Train-Test Gap: {df['overfit_gap'].mean():.4f}")
344
+ print(f"Max Gap: {df['overfit_gap'].max():.4f} (Epoch {df.loc[df['overfit_gap'].idxmax(), 'Epoch']:.0f})")
345
+
346
+
347
+ print("\nStarting training...\n")
348
+ model_ft, logs = train_model(
349
+ model, criterion, optimizer, exp_lr_scheduler,
350
+ num_epochs=CONFIG['epochs']
351
+ )
352
+
353
+ print("\nVisualizing results...")
354
+ plot_training_results(logs)
355
+
356
+ print("\nPerforming detailed evaluation...")
357
+ cm, report = evaluate_model(model_ft, dataloaders['test'], class_names)
358
+
359
+ print("\n" + "="*60)
360
+ print("SUMMARY REPORT")
361
+ print("="*60)
362
+ print(f"Model: {CONFIG['model_name']}")
363
+ print(f"Total Epochs: {len(logs)}")
364
+ print(f"Best Test Accuracy: {logs['test_acc'].max():.4f} ({logs['test_acc'].max()*100:.2f}%)")
365
+ print(f"Final Test Accuracy: {logs['test_acc'].iloc[-1]:.4f}")
366
+ print(f"Final Train Accuracy: {logs['train_acc'].iloc[-1]:.4f}")
367
+ print(f"Overfitting Gap: {logs['train_acc'].iloc[-1] - logs['test_acc'].iloc[-1]:.4f}")
368
+ print(f"\nAll results saved to '{save_dir}'")
369
+ print("="*60)