Syncnet_FCN / train_continue_epoch2.py
Shubham
Deploy clean version
579f772
#!/usr/bin/env python
"""
Continue training from epoch 2 checkpoint.
This script resumes training from checkpoints/syncnet_fcn_epoch2.pth
which uses SyncNet_TransferLearning with 31-class classification (±15 frames).
Usage:
python train_continue_epoch2.py --data_dir "E:\voxc2\vox2_dev_mp4_partaa~\dev\mp4" --hours 5
"""
import argparse
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path

import cv2
import numpy as np
import python_speech_features
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.io import wavfile
from torch.utils.data import Dataset, DataLoader

from SyncNet_TransferLearning import SyncNet_TransferLearning
class AVSyncDataset(Dataset):
    """Dataset of (audio MFCC, video frames, offset) samples for sync classification.

    Every video found under ``video_dir`` contributes ``num_samples_per_video``
    samples, each paired with a random integer offset drawn from
    ``[-max_offset, max_offset]`` when the dataset is constructed. Features are
    decoded lazily in ``__getitem__``; a sample that cannot be decoded yields
    ``None`` so that the collate function can drop it.

    Args:
        video_dir: Directory searched recursively for .mp4/.avi/.mov/.mkv files.
        max_offset: Largest absolute audio/video offset, in video frames.
        num_samples_per_video: Number of (video, offset) pairs generated per video.
        frame_size: Target size handed to ``cv2.resize`` (i.e. ``(width, height)``).
        num_frames: Number of video frames in each sample.
        max_videos: Optional cap on the number of videos; a random subset is kept.

    Raises:
        ValueError: If no video files are found under ``video_dir``.
    """

    # Number of MFCC frames consumed per video frame throughout this class.
    _MFCC_PER_VIDEO_FRAME = 4
    # Audio samples per MFCC step; assumes the 16 kHz audio requested from
    # ffmpeg and the MFCC extractor's default 10 ms step -- TODO confirm.
    _AUDIO_SAMPLES_PER_MFCC = 160

    def __init__(self, video_dir, max_offset=15, num_samples_per_video=2,
                 frame_size=(112, 112), num_frames=25, max_videos=None):
        self.video_dir = video_dir
        self.max_offset = max_offset
        self.num_samples_per_video = num_samples_per_video
        self.frame_size = frame_size
        self.num_frames = num_frames

        # Find all video files
        self.video_files = []
        for ext in ['*.mp4', '*.avi', '*.mov', '*.mkv']:
            self.video_files.extend(Path(video_dir).glob(f'**/{ext}'))

        # Limit number of videos if specified
        if max_videos and len(self.video_files) > max_videos:
            np.random.shuffle(self.video_files)
            self.video_files = self.video_files[:max_videos]

        if not self.video_files:
            raise ValueError(f"No video files found in {video_dir}")
        print(f"Using {len(self.video_files)} video files")

        # Generate sample list: (index into video_files, offset in video frames)
        self.samples = []
        for vid_idx in range(len(self.video_files)):
            for _ in range(num_samples_per_video):
                offset = np.random.randint(-max_offset, max_offset + 1)
                self.samples.append((vid_idx, offset))
        print(f"Generated {len(self.samples)} training samples")

    def __len__(self):
        """Return the number of (video, offset) samples."""
        return len(self.samples)

    def extract_features(self, video_path):
        """Extract audio MFCCs and resized RGB frames from one video.

        Args:
            video_path: Path to the video file.

        Returns:
            Tuple ``(mfcc, frames)``: the MFCC matrix (13 coefficients per row)
            and a float32 array of stacked RGB frames scaled to ``[0, 1]``.

        Raises:
            ValueError: If the audio, MFCC sequence or video is too short to
                cover ``num_frames`` plus ``max_offset``.
            subprocess.CalledProcessError: If ffmpeg exits with an error.
        """
        video_path = str(video_path)
        mfcc_per_frame = self._MFCC_PER_VIDEO_FRAME

        # Extract audio. A private temporary directory gives a collision-free
        # file name and guarantees cleanup even when ffmpeg or the WAV read fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_audio = os.path.join(tmp_dir, 'audio.wav')
            cmd = ['ffmpeg', '-y', '-i', video_path, '-ac', '1', '-ar', '16000',
                   '-vn', '-acodec', 'pcm_s16le', temp_audio]
            subprocess.run(cmd, stdout=subprocess.DEVNULL,
                           stderr=subprocess.DEVNULL, check=True)
            sample_rate, audio = wavfile.read(temp_audio)

        # Validate audio length
        min_audio_samples = ((self.num_frames * mfcc_per_frame
                              + self.max_offset * mfcc_per_frame)
                             * self._AUDIO_SAMPLES_PER_MFCC)
        if len(audio) < min_audio_samples:
            raise ValueError(f"Audio too short: {len(audio)} samples")

        mfcc = python_speech_features.mfcc(audio, sample_rate, numcep=13)
        min_mfcc_frames = (self.num_frames * mfcc_per_frame
                           + abs(self.max_offset) * mfcc_per_frame)
        if len(mfcc) < min_mfcc_frames:
            raise ValueError(f"MFCC too short: {len(mfcc)} frames")

        # Extract video frames; only as many as any offset could need (+10 slack).
        max_frames = self.num_frames + abs(self.max_offset) + 10
        frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            while len(frames) < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.resize(frame, self.frame_size)
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame.astype(np.float32) / 255.0)
        finally:
            # Release the capture handle even if resizing/conversion raises.
            cap.release()

        if len(frames) < self.num_frames + abs(self.max_offset):
            raise ValueError(f"Video too short: {len(frames)} frames")

        return mfcc, np.stack(frames)

    def apply_offset(self, mfcc, frames, offset):
        """Cut an audio/video pair that is misaligned by ``offset`` video frames.

        For ``offset >= 0`` the video starts at frame 0 and the audio starts
        ``offset * 4`` MFCC frames in; for ``offset < 0`` the video starts
        ``abs(offset)`` frames in and the audio starts at 0. Segments shorter
        than the target length are padded by repeating their last element.

        Args:
            mfcc: MFCC matrix, one row per MFCC frame.
            frames: Array of video frames, first axis is time.
            offset: Signed offset in video frames.

        Returns:
            Tuple ``(mfcc_segment, video_segment)`` with ``num_frames * 4`` MFCC
            rows and ``num_frames`` video frames.

        Raises:
            ValueError: If the inputs are too short to yield any frame for the
                requested offset (padding would have nothing to repeat).
        """
        mfcc_per_frame = self._MFCC_PER_VIDEO_FRAME
        mfcc_offset = offset * mfcc_per_frame
        num_video_frames = min(self.num_frames, len(frames) - abs(offset))
        if num_video_frames <= 0:
            raise ValueError(
                f"Not enough video frames ({len(frames)}) for offset {offset}")
        num_mfcc_frames = num_video_frames * mfcc_per_frame

        if offset >= 0:
            video_start = 0
            mfcc_start = mfcc_offset
        else:
            video_start = abs(offset)
            mfcc_start = 0

        video_segment = frames[video_start:video_start + num_video_frames]
        mfcc_segment = mfcc[mfcc_start:mfcc_start + num_mfcc_frames]
        if len(mfcc_segment) == 0:
            raise ValueError(
                f"Not enough MFCC frames ({len(mfcc)}) for offset {offset}")

        # Pad if needed
        if len(video_segment) < self.num_frames:
            pad_frames = self.num_frames - len(video_segment)
            video_segment = np.concatenate([
                video_segment,
                np.repeat(video_segment[-1:], pad_frames, axis=0),
            ], axis=0)

        target_mfcc_len = self.num_frames * mfcc_per_frame
        if len(mfcc_segment) < target_mfcc_len:
            pad_mfcc = target_mfcc_len - len(mfcc_segment)
            mfcc_segment = np.concatenate([
                mfcc_segment,
                np.repeat(mfcc_segment[-1:], pad_mfcc, axis=0),
            ], axis=0)

        return mfcc_segment[:target_mfcc_len], video_segment[:self.num_frames]

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            Tuple ``(audio_tensor, video_tensor, offset_tensor)``, or ``None``
            when the video could not be decoded. Failures are reported on
            stderr rather than raised so one bad file does not stop training.
        """
        vid_idx, offset = self.samples[idx]
        video_path = self.video_files[vid_idx]
        try:
            mfcc, frames = self.extract_features(video_path)
            mfcc, frames = self.apply_offset(mfcc, frames, offset)
            audio_tensor = torch.FloatTensor(mfcc.T).unsqueeze(0)  # [1, 13, T]
            video_tensor = torch.FloatTensor(frames).permute(3, 0, 1, 2)  # [3, T, H, W]
            offset_tensor = torch.tensor(offset, dtype=torch.long)
            return audio_tensor, video_tensor, offset_tensor
        except Exception as exc:
            # Best-effort loading: keep skipping bad samples, but say why so a
            # systematic problem (e.g. ffmpeg missing) is not mistaken for
            # an empty dataset.
            print(f"Skipping {video_path} ({type(exc).__name__}: {exc})",
                  file=sys.stderr)
            return None
def collate_fn_skip_none(batch):
    """Collate a batch, dropping samples that failed to load (``None``).

    Returns ``None`` when nothing usable is left; otherwise a tuple
    ``(audio, video, offset)`` where each entry is the corresponding field of
    every remaining sample stacked along a new leading dimension.
    """
    usable = [sample for sample in batch if sample is not None]
    if not usable:
        return None

    stacked = []
    for field in range(3):
        stacked.append(torch.stack([sample[field] for sample in usable]))
    audio, video, offset = stacked
    return audio, video, offset
def train_epoch(model, dataloader, criterion, optimizer, device, max_offset):
    """Train for one epoch.

    Args:
        model: Module called as ``model(audio, video)``; the first element of
            its return value is treated as per-class probabilities with a
            trailing time axis (it is log-ed and averaged over ``dim=2``).
        dataloader: Iterable of ``(audio, video, target_offset)`` batches; a
            batch may be ``None`` (all samples failed to load) and is skipped.
        criterion: Loss called as ``criterion(logits, target_class)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batch tensors are moved to.
        max_offset: Added to the signed offsets to obtain class indices in
            ``[0, 2 * max_offset]``.

    Returns:
        Tuple ``(mean_loss, accuracy)`` over all samples seen, with accuracy as
        a fraction in ``[0, 1]``. ``(0.0, 0.0)`` when no sample was usable.
    """
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0

    for batch_idx, batch in enumerate(dataloader):
        if batch is None:
            continue

        audio, video, target_offset = batch
        audio = audio.to(device)
        video = video.to(device)
        # Shift signed offsets so they can be used as class indices.
        target_class = (target_offset + max_offset).long().to(device)

        optimizer.zero_grad()

        # Forward pass
        sync_probs, _, _ = model(audio, video)

        # Global average pooling over time; epsilon keeps log() finite at 0.
        sync_logits = torch.log(sync_probs + 1e-8).mean(dim=2)  # [B, 31]

        # Compute loss
        loss = criterion(sync_logits, target_class)

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Track metrics (loss is weighted by batch size, batches may be ragged)
        batch_size = audio.size(0)
        total_loss += loss.item() * batch_size
        predicted_class = sync_logits.argmax(dim=1)
        total_correct += (predicted_class == target_class).sum().item()
        total_samples += batch_size

        if batch_idx % 10 == 0:
            acc = 100.0 * total_correct / total_samples if total_samples > 0 else 0
            print(f"  Batch {batch_idx}/{len(dataloader)}: Loss={loss.item():.4f}, Acc={acc:.2f}%")

    if total_samples == 0:
        # Every batch was skipped (or the loader was empty); report it instead
        # of dividing by zero.
        print("  Warning: no valid samples were processed in this epoch")
        return 0.0, 0.0

    return total_loss / total_samples, total_correct / total_samples
def main():
    """Resume training from a checkpoint and keep training within a time budget.

    Parses the command line, rebuilds the model, loads the checkpoint weights
    (stripping a leading ``fcn_model.`` from parameter names), then runs whole
    epochs until ``--hours`` has elapsed. After every epoch a checkpoint is
    written to ``--output_dir``; the epoch with the highest training accuracy
    so far is additionally saved as ``syncnet_fcn_best.pth``.

    Note: only model weights are restored; the optimizer starts fresh.
    """
    parser = argparse.ArgumentParser(description='Continue training from epoch 2')
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--checkpoint', type=str, default='checkpoints/syncnet_fcn_epoch2.pth')
    parser.add_argument('--output_dir', type=str, default='checkpoints')
    parser.add_argument('--hours', type=float, default=5.0, help='Training time in hours')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--max_videos', type=int, default=None,
                        help='Limit number of videos (for faster training)')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    max_offset = 15  # ±15 frames, 31 classes

    # Create model
    print("Creating model...")
    model = SyncNet_TransferLearning(
        video_backbone='fcn',
        audio_backbone='fcn',
        embedding_dim=512,
        max_offset=max_offset,
        freeze_backbone=False
    )

    # Load checkpoint
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    print(f"Loading checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location=device)

    # Load model state, removing the 'fcn_model.' prefix if present
    model_state = checkpoint['model_state_dict']
    prefix = 'fcn_model.'
    new_state = {}
    for k, v in model_state.items():
        if k.startswith(prefix):
            new_state[k[len(prefix):]] = v
        else:
            new_state[k] = v

    # strict=False tolerates name mismatches, so report them: a checkpoint
    # that matches nothing would otherwise silently train from random weights.
    load_result = model.load_state_dict(new_state, strict=False)
    missing_keys = list(getattr(load_result, 'missing_keys', []))
    unexpected_keys = list(getattr(load_result, 'unexpected_keys', []))
    if missing_keys:
        print(f"Warning: {len(missing_keys)} model parameters not found in checkpoint "
              f"(e.g. {missing_keys[:5]})")
    if unexpected_keys:
        print(f"Warning: {len(unexpected_keys)} checkpoint entries not used by the model "
              f"(e.g. {unexpected_keys[:5]})")

    start_epoch = checkpoint.get('epoch', 2)
    print(f"Resuming from epoch {start_epoch}")

    model = model.to(device)

    # Dataset
    print("Loading dataset...")
    dataset = AVSyncDataset(
        video_dir=args.data_dir,
        max_offset=max_offset,
        num_samples_per_video=2,
        max_videos=args.max_videos
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn_skip_none,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
        pin_memory=(device.type == 'cuda')
    )

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)

    # Training loop with time limit
    os.makedirs(args.output_dir, exist_ok=True)
    max_seconds = args.hours * 3600
    start_time = time.time()
    epoch = start_epoch
    best_acc = 0

    print(f"\n{'='*60}")
    print(f"Starting training for {args.hours} hours...")
    print(f"{'='*60}")

    # NOTE: the budget is only checked between epochs, so the epoch that is
    # running when the time is up still finishes before the loop exits.
    while True:
        elapsed = time.time() - start_time
        remaining = max_seconds - elapsed
        if remaining <= 0:
            print(f"\nTime limit reached ({args.hours} hours)")
            break

        epoch += 1
        print(f"\nEpoch {epoch} (Time remaining: {remaining/3600:.2f} hours)")
        print("-" * 40)

        train_loss, train_acc = train_epoch(
            model, dataloader, criterion, optimizer, device, max_offset
        )
        print(f"Epoch {epoch}: Loss={train_loss:.4f}, Acc={100*train_acc:.2f}%")

        # Same payload is used for the per-epoch and the "best" checkpoint.
        snapshot = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
            'accuracy': train_acc * 100,
        }

        # Save checkpoint
        checkpoint_path = os.path.join(args.output_dir, f'syncnet_fcn_epoch{epoch}.pth')
        torch.save(snapshot, checkpoint_path)
        print(f"Saved: {checkpoint_path}")

        # Save best (by training accuracy)
        if train_acc > best_acc:
            best_acc = train_acc
            best_path = os.path.join(args.output_dir, 'syncnet_fcn_best.pth')
            torch.save(snapshot, best_path)
            print(f"New best model saved: {best_path}")

    print(f"\n{'='*60}")
    print("Training complete!")
    print(f"Final epoch: {epoch}")
    print(f"Best accuracy: {100*best_acc:.2f}%")
    print(f"{'='*60}")
# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()