Syncnet_FCN / train_syncnet_fcn_complete.py
Shubham
Deploy clean version
579f772
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Training Script for SyncNetFCN on VoxCeleb2
Usage:
python train_syncnet_fcn_complete.py --data_dir E:/voxceleb2_dataset/VoxCeleb2/dev --pretrained_model data/syncnet_v2.model
"""
import argparse
import glob
import os
import random
import subprocess
import tempfile
import time
import traceback

import cv2
import numpy as np
import python_speech_features
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import wavfile
from torch.utils.data import Dataset, DataLoader

from SyncNetModel_FCN import StreamSyncFCN
class VoxCeleb2Dataset(Dataset):
    """VoxCeleb2 dataset loader for sync training with real preprocessing.

    Each item decodes one ``.mp4`` file: the audio track is converted to MFCC
    features (FFmpeg + python_speech_features) and the frames are read with
    OpenCV. A clip of ``video_length`` frames and its MFCC window are then cut
    out, either aligned (label 1) or shifted against each other by a random
    number of frames (label 0).
    """

    # MFCC frames per video frame. python_speech_features.mfcc uses a 10 ms
    # hop by default (100 feature frames per second); a ratio of 4 presumes
    # 25 fps video -- NOTE(review): confirm the dataset frame rate.
    MFCC_PER_FRAME = 4
    # Number of cepstral coefficients requested from the MFCC extractor.
    NUM_CEPSTRA = 13

    def __init__(self, data_dir, max_offset=15, video_length=25, temp_dir='temp_dataset'):
        """
        Args:
            data_dir: Path to VoxCeleb2 root directory
            max_offset: Maximum frame offset for negative samples
            video_length: Number of frames per clip
            temp_dir: Temporary directory for audio extraction
        """
        self.data_dir = data_dir
        self.max_offset = max_offset
        self.video_length = video_length
        self.temp_dir = temp_dir
        os.makedirs(temp_dir, exist_ok=True)
        # Find all video files
        self.video_files = glob.glob(os.path.join(data_dir, '**', '*.mp4'), recursive=True)
        print(f"Found {len(self.video_files)} videos in dataset")

    def __len__(self):
        """Return the number of video files found under ``data_dir``."""
        return len(self.video_files)

    def _extract_audio_mfcc(self, video_path):
        """Extract audio and compute MFCC features.

        Args:
            video_path: Path of the video whose audio track is decoded.

        Returns:
            Float tensor of shape [1, 1, 13, T].

        Raises:
            RuntimeError: If FFmpeg, WAV reading or MFCC extraction fails.
        """
        video_id = os.path.splitext(os.path.basename(video_path))[0]
        audio_path = None
        try:
            # Basenames repeat across dataset folders, so a name derived only
            # from the basename can collide between DataLoader workers.
            # mkstemp guarantees a unique file per call.
            fd, audio_path = tempfile.mkstemp(
                prefix=f'{video_id}_', suffix='_audio.wav', dir=self.temp_dir)
            # Close our handle so FFmpeg can overwrite the file (needed on Windows).
            os.close(fd)

            # Extract audio using FFmpeg (mono, 16 kHz, 16-bit PCM)
            cmd = ['ffmpeg', '-y', '-i', video_path, '-ac', '1', '-ar', '16000',
                   '-vn', '-acodec', 'pcm_s16le', audio_path]
            result = subprocess.run(cmd, stdout=subprocess.DEVNULL,
                                    stderr=subprocess.PIPE, timeout=30)
            if result.returncode != 0:
                raise RuntimeError(
                    f"FFmpeg failed for {video_path}: {result.stderr.decode(errors='ignore')}")

            # Read audio and compute MFCC
            try:
                sample_rate, audio = wavfile.read(audio_path)
            except Exception as e:
                raise RuntimeError(f"wavfile.read failed for {audio_path}: {e}") from e

            # Ensure audio is 1D
            if isinstance(audio, np.ndarray) and audio.ndim > 1:
                audio = audio.mean(axis=1)
            # Check for empty or invalid audio
            if not isinstance(audio, np.ndarray) or audio.size == 0:
                raise ValueError(f"Audio data is empty or invalid for {audio_path}")

            try:
                mfcc = python_speech_features.mfcc(audio, sample_rate, numcep=self.NUM_CEPSTRA)
            except Exception as e:
                raise RuntimeError(f"MFCC extraction failed for {audio_path}: {e}") from e

            # Shape: [T, 13] -> [13, T] -> [1, 1, 13, T]
            return torch.FloatTensor(mfcc.T).unsqueeze(0).unsqueeze(0)
        except Exception as e:
            raise RuntimeError(f"Failed to extract audio from {video_path}: {e}") from e
        finally:
            # Best-effort removal of the temp file on success and on failure.
            if audio_path is not None:
                try:
                    os.remove(audio_path)
                except OSError:
                    pass

    def _extract_video_frames(self, video_path, target_size=(112, 112)):
        """Extract video frames as tensor.

        Args:
            video_path: Path of the video to decode.
            target_size: (width, height) every frame is resized to.

        Returns:
            Float tensor of shape [1, 3, T, H, W] with RGB values in [0, 1].

        Raises:
            ValueError: If no frame could be read.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # Resize and normalize
                frame = cv2.resize(frame, target_size)
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame.astype(np.float32) / 255.0)
        finally:
            # Release the capture even when a frame fails to process.
            cap.release()
        if not frames:
            raise ValueError(f"No frames extracted from {video_path}")
        # Stack and convert to tensor [T, H, W, 3] -> [3, T, H, W] -> [1, 3, T, H, W]
        frames_array = np.stack(frames, axis=0)
        return torch.FloatTensor(frames_array).permute(3, 0, 1, 2).unsqueeze(0)

    def _crop_or_pad_video(self, video_tensor, target_length, start=None):
        """Crop or pad video to target length along the time axis (dim 2).

        Args:
            video_tensor: Tensor of shape [B, C, T, H, W].
            target_length: Number of frames in the result.
            start: First frame of the crop. ``None`` picks a random start when
                the video is longer than ``target_length``, otherwise 0.

        Returns:
            Tensor of shape [B, C, target_length, H, W]; short clips are padded
            by repeating the last available frame.

        Raises:
            ValueError: If the video has no frames.
        """
        total = video_tensor.shape[2]
        if total == 0:
            raise ValueError("Cannot crop or pad a video tensor with no frames")
        if start is None:
            start = random.randint(0, total - target_length) if total > target_length else 0
        # Keep at least one real frame in the slice so padding has a source.
        start = min(max(start, 0), total - 1)
        clip = video_tensor[:, :, start:start + target_length, :, :]
        missing = target_length - clip.shape[2]
        if missing > 0:
            tail = clip[:, :, -1:, :, :].repeat(1, 1, missing, 1, 1)
            clip = torch.cat([clip, tail], dim=2)
        return clip

    def _crop_or_pad_audio(self, audio_tensor, target_length, start=None):
        """Crop or pad audio to target length along the last (time) axis.

        Works for any rank, so both [B, C, T] and the [1, 1, 13, T] tensors
        produced by ``_extract_audio_mfcc`` are accepted.

        Args:
            audio_tensor: Tensor whose last axis is time.
            target_length: Number of time steps in the result.
            start: First time step of the crop. ``None`` picks a random start
                when the input is longer than ``target_length``, otherwise 0.

        Returns:
            Tensor with the same leading dimensions and ``target_length`` time
            steps; short inputs are zero-padded at the end.
        """
        total = audio_tensor.shape[-1]
        if start is None:
            start = random.randint(0, total - target_length) if total > target_length else 0
        start = max(start, 0)
        clip = audio_tensor[..., start:start + target_length]
        missing = target_length - clip.shape[-1]
        if missing > 0:
            padding = audio_tensor.new_zeros((*clip.shape[:-1], missing))
            clip = torch.cat([clip, padding], dim=-1)
        return clip

    def _align_clip(self, audio, video, offset):
        """Cut one audio/video training clip with a controlled temporal offset.

        The video window starts at frame ``v`` and the audio window at the MFCC
        index of frame ``v - offset``. ``offset == 0`` gives an aligned pair,
        ``offset > 0`` makes the video start later than the audio, and
        ``offset < 0`` makes the audio start later than the video.

        Args:
            audio: MFCC tensor with time on the last axis.
            video: Frame tensor of shape [B, C, T, H, W].
            offset: Signed shift in video frames.

        Returns:
            (audio_clip, video_clip) with ``video_length * MFCC_PER_FRAME`` and
            ``video_length`` time steps respectively.
        """
        ratio = self.MFCC_PER_FRAME
        length = self.video_length
        n_frames = video.shape[2]
        n_audio_frames = audio.shape[-1] // ratio

        # Valid video starts keep both windows fully inside their streams.
        lo = max(0, offset)
        hi = min(n_frames, n_audio_frames + offset) - length
        if hi >= lo:
            v_start = random.randint(lo, hi)
        else:
            # Recording too short for a full window: start as early as the
            # offset allows and let the crop helpers pad the remainder.
            v_start = min(lo, max(n_frames - 1, 0))
        a_start = max(0, (v_start - offset) * ratio)

        video = self._crop_or_pad_video(video, length, start=v_start)
        audio = self._crop_or_pad_audio(audio, length * ratio, start=a_start)
        return audio, video

    def __getitem__(self, idx):
        """
        Returns:
            audio: [1, 13, T] MFCC features
            video: [3, T_frames, H, W] video frames
            offset: Ground truth offset (0 for positive, non-zero for negative)
            label: 1 if in sync, 0 if out of sync
            dummy: True when preprocessing failed and random data was substituted
        """
        video_path = self.video_files[idx]
        t0 = time.time()
        # Randomly decide if this should be positive (sync) or negative (out-of-sync)
        is_positive = random.random() > 0.5
        if is_positive:
            offset = 0
            label = 1
        else:
            # Random offset between 1 and max_offset, in either direction
            offset = random.randint(1, self.max_offset) * random.choice([-1, 1])
            label = 0
        # Log offset/label distribution occasionally
        if random.random() < 0.01:
            print(f"[INFO][VoxCeleb2Dataset] idx={idx}, path={video_path}, offset={offset}, label={label}")
        try:
            # Extract audio MFCC features
            t_audio0 = time.time()
            audio = self._extract_audio_mfcc(video_path)
            t_audio1 = time.time()
            if random.random() < 0.01:
                print(f"[INFO][Audio] idx={idx}, path={video_path}, shape={audio.shape}, dtype={audio.dtype}, time={t_audio1-t_audio0:.2f}s")
            # Extract video frames
            t_vid0 = time.time()
            video = self._extract_video_frames(video_path)
            t_vid1 = time.time()
            if random.random() < 0.01:
                print(f"[INFO][Video] idx={idx}, path={video_path}, frames={video.shape[2] if video.dim()==5 else 'ERR'}, shape={video.shape}, dtype={video.dtype}, time={t_vid1-t_vid0:.2f}s")

            # Crop both streams from matching positions so positives really
            # are in sync and negatives are shifted by exactly `offset`.
            audio, video = self._align_clip(audio, video, offset)

            # Remove batch dimension (DataLoader will add it)
            audio = audio.squeeze(0)  # [1, 13, T]
            video = video.squeeze(0)  # [3, T, H, W]

            # Check for shape mismatches (coefficients live on dim 1 of [1, 13, T])
            if audio.dim() != 3 or audio.shape[1] != self.NUM_CEPSTRA:
                raise ValueError(f"Audio MFCC shape mismatch: {audio.shape} for {video_path}")
            if video.shape[0] != 3 or video.shape[2] != 112 or video.shape[3] != 112:
                raise ValueError(f"Video frame shape mismatch: {video.shape} for {video_path}")
            t1 = time.time()
            if random.random() < 0.01:
                print(f"[INFO][Sample] idx={idx}, path={video_path}, total_time={t1-t0:.2f}s")
            dummy = False
        except Exception as e:
            # Deliberate best-effort: substitute random data so one bad file
            # does not abort the epoch. The sample is flagged via 'dummy'.
            print(f"[WARN][VoxCeleb2Dataset] idx={idx}, path={video_path}, ERROR_STAGE=__getitem__, error={str(e)[:100]}")
            traceback.print_exc(limit=1)
            audio = torch.randn(1, self.NUM_CEPSTRA, self.video_length * self.MFCC_PER_FRAME)
            video = torch.randn(3, self.video_length, 112, 112)
            offset = 0
            label = 1
            dummy = True
        # Log dummy sample usage
        if dummy and random.random() < 0.5:
            print(f"[WARN][VoxCeleb2Dataset] idx={idx}, path={video_path}, DUMMY_SAMPLE_USED")
        return {
            'audio': audio,
            'video': video,
            'offset': torch.tensor(offset, dtype=torch.float32),
            'label': torch.tensor(label, dtype=torch.float32),
            'dummy': dummy,
        }
class SyncLoss(nn.Module):
    """Binary cross-entropy loss for sync/no-sync classification.

    The score tensor is reduced to one value per sample (its maximum over the
    offset and time axes) and that value is passed to ``BCEWithLogitsLoss``,
    which applies a sigmoid internally.

    NOTE(review): the input is named ``sync_probs`` but is consumed as a logit
    here -- confirm against the model whether it already outputs probabilities.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, sync_probs, labels):
        """
        Args:
            sync_probs: [B, 2*K+1, T] sync score distribution
            labels: [B] binary labels (1=sync, 0=out-of-sync)

        Returns:
            Scalar loss tensor.
        """
        # Strongest offset at every time step: [B, T]
        best_per_step, _ = sync_probs.max(dim=1)
        # Strongest time step per sample: [B]
        best_per_sample, _ = best_per_step.max(dim=1)
        return self.bce(best_per_sample, labels)
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Train for one epoch.

    Args:
        model: Module called as ``model(audio, video)``; the first element of
            its return value is the sync score tensor.
        dataloader: Iterable of batches, each a dict with 'audio', 'video' and
            'label' tensors and optionally a 'dummy' flag per sample.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable ``criterion(sync_probs, labels)`` returning the loss.
        device: Device the batch tensors are moved to.

    Returns:
        (avg_loss, accuracy): mean per-batch loss and accuracy in percent.
        Both are 0.0 when the dataloader yields no batches.
    """
    import gc  # local import: gc is not imported at module level

    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for batch_idx, batch in enumerate(dataloader):
        audio = batch['audio'].to(device)
        video = batch['video'].to(device)
        labels = batch['label'].to(device)
        # Log dummy data in batch
        if 'dummy' in batch:
            flags = batch['dummy']
            num_dummy = flags.sum().item() if hasattr(flags, 'sum') else int(sum(flags))
            if num_dummy > 0:
                print(f"[WARN][train_epoch] Batch {batch_idx}: {num_dummy}/{len(labels)} dummy samples in batch!")
        # Forward pass
        optimizer.zero_grad()
        sync_probs, _, _ = model(audio, video)
        # Log tensor shapes
        if batch_idx % 50 == 0:
            print(f"[INFO][train_epoch] Batch {batch_idx}: audio {audio.shape}, video {video.shape}, sync_probs {sync_probs.shape}")
        # Compute loss
        loss = criterion(sync_probs, labels)
        # Backward pass
        loss.backward()
        optimizer.step()
        # Statistics
        total_loss += loss.item()
        num_batches += 1
        # NOTE(review): the 0.5 threshold treats the peak score as a
        # probability; if the criterion consumes it as a logit the matching
        # threshold would be 0 -- confirm against the model output.
        pred = (sync_probs.max(dim=1)[0].max(dim=1)[0] > 0.5).float()
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        # Log memory usage occasionally
        if batch_idx % 100 == 0 and torch.cuda.is_available():
            mem = torch.cuda.memory_allocated() / 1024**2
            print(f"[INFO][train_epoch] Batch {batch_idx}: GPU memory used: {mem:.2f} MB")
        if batch_idx % 10 == 0:
            print(f' Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}, Acc: {100*correct/total:.2f}%')
        # Clean up aggressively after every batch (kept from the original;
        # trades speed for a lower peak memory footprint).
        del audio, video, labels
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    # Guard against an empty dataloader instead of dividing by zero.
    if num_batches == 0 or total == 0:
        return 0.0, 0.0
    avg_loss = total_loss / num_batches
    accuracy = 100 * correct / total
    return avg_loss, accuracy
def _build_arg_parser():
    """Build the command-line parser for the training script."""
    cli = argparse.ArgumentParser(description='Train SyncNetFCN')
    cli.add_argument('--data_dir', type=str, required=True, help='VoxCeleb2 root directory')
    cli.add_argument('--pretrained_model', type=str, default='data/syncnet_v2.model',
                     help='Pretrained SyncNet model')
    cli.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
    cli.add_argument('--epochs', type=int, default=10, help='Number of epochs')
    cli.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    cli.add_argument('--output_dir', type=str, default='checkpoints', help='Output directory')
    cli.add_argument('--use_attention', action='store_true', help='Use attention model')
    cli.add_argument('--num_workers', type=int, default=2, help='DataLoader workers')
    return cli


def main():
    """Parse the CLI, build model/data/optimizer, train, and checkpoint each epoch."""
    args = _build_arg_parser().parse_args()

    # Device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(f'Using device: {device}')

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Create model with transfer learning
    print('Creating model...')
    model = StreamSyncFCN(
        pretrained_syncnet_path=args.pretrained_model,
        auto_load_pretrained=True,
        use_attention=args.use_attention,
    ).to(device)
    print('Model created. Pretrained conv layers loaded and frozen.')

    # Dataset and dataloader
    print('Loading dataset...')
    dataset = VoxCeleb2Dataset(args.data_dir)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # Loss and optimizer; only parameters that still require gradients are optimized
    criterion = SyncLoss()
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=args.lr)
    n_trainable = sum(param.numel() for param in trainable_params)
    n_frozen = sum(param.numel() for param in model.parameters() if not param.requires_grad)
    print(f'Trainable parameters: {n_trainable:,}')
    print(f'Frozen parameters: {n_frozen:,}')

    # Training loop
    heavy_rule = '=' * 80
    light_rule = '-' * 80
    print('\nStarting training...')
    print(heavy_rule)
    for epoch_number in range(1, args.epochs + 1):
        print(f'\nEpoch {epoch_number}/{args.epochs}')
        print(light_rule)
        avg_loss, accuracy = train_epoch(model, dataloader, optimizer, criterion, device)
        print(f'\nEpoch {epoch_number} Summary:')
        print(f' Average Loss: {avg_loss:.4f}')
        print(f' Accuracy: {accuracy:.2f}%')

        # Save checkpoint
        checkpoint_path = os.path.join(args.output_dir, f'syncnet_fcn_epoch{epoch_number}.pth')
        state = {
            'epoch': epoch_number,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
            'accuracy': accuracy,
        }
        torch.save(state, checkpoint_path)
        print(f' Checkpoint saved: {checkpoint_path}')

    print('\n' + heavy_rule)
    print('Training complete!')
    print(f'Final model saved to: {args.output_dir}')


if __name__ == '__main__':
    main()