#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Fully Convolutional SyncNet (FCN-SyncNet) - CLASSIFICATION VERSION

Key differences from the regression version:
- Output: probability distribution over discrete offset classes
- Loss: CrossEntropyLoss instead of MSE
- Avoids the regression-to-mean problem

Offset classes: -max_offset to +max_offset frames (2 * max_offset + 1 classes).
With the default max_offset=125 (±5 seconds at 25 fps) there are 251 classes;
class index = offset + max_offset, so class 0 = -125 frames, class 125 = 0 frames,
and class 250 = +125 frames.

Author: Enhanced version based on original SyncNet
Date: 2025-12-04
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import cv2
import os
import subprocess
from scipy.io import wavfile
import python_speech_features


class TemporalCorrelation(nn.Module):
    """Compute correlation between audio and video features across time."""

    def __init__(self, max_displacement=15):
        super(TemporalCorrelation, self).__init__()
        self.max_displacement = max_displacement

    def forward(self, feat1, feat2):
        """
        Args:
            feat1: [B, C, T] - visual features
            feat2: [B, C, T] - audio features

        Returns:
            correlation: [B, 2*max_displacement+1, T] - correlation map
        """
        B, C, T = feat1.shape
        max_disp = self.max_displacement

        # Normalize features
        feat1 = F.normalize(feat1, dim=1)
        feat2 = F.normalize(feat2, dim=1)

        # Pad feat2 for shifting
        feat2_padded = F.pad(feat2, (max_disp, max_disp), mode='replicate')

        corr_list = []
        for offset in range(-max_disp, max_disp + 1):
            shifted_feat2 = feat2_padded[:, :, offset + max_disp:offset + max_disp + T]
            corr = (feat1 * shifted_feat2).sum(dim=1, keepdim=True)
            corr_list.append(corr)

        correlation = torch.cat(corr_list, dim=1)
        return correlation
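
# Illustrative shape check for TemporalCorrelation (a sketch, not part of the
# model): with max_displacement=2, two feature maps of shape [1, 512, 10]
# produce one correlation channel per candidate shift.
#
#     corr_layer = TemporalCorrelation(max_displacement=2)
#     corr = corr_layer(torch.randn(1, 512, 10), torch.randn(1, 512, 10))
#     assert corr.shape == (1, 5, 10)   # 2 * 2 + 1 = 5 offset channels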


class ChannelAttention(nn.Module):
    """Squeeze-and-Excitation style channel attention."""

    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, t = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1)
        return x * y.expand_as(x)


class TemporalAttention(nn.Module):
    """Self-attention over temporal dimension."""

    def __init__(self, channels):
        super(TemporalAttention, self).__init__()
        self.query_conv = nn.Conv1d(channels, channels // 8, 1)
        self.key_conv = nn.Conv1d(channels, channels // 8, 1)
        self.value_conv = nn.Conv1d(channels, channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, T = x.size()
        query = self.query_conv(x).permute(0, 2, 1)
        key = self.key_conv(x)
        value = self.value_conv(x)
        attention = torch.bmm(query, key)
        attention = F.softmax(attention, dim=-1)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = self.gamma * out + x
        return out


class FCN_AudioEncoder(nn.Module):
    """Fully convolutional audio encoder."""

    def __init__(self, output_channels=512):
        super(FCN_AudioEncoder, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 2)),

            nn.Conv2d(192, 384, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),

            nn.Conv2d(256, 512, kernel_size=(5, 1), stride=(5, 1), padding=(0, 0)),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        self.channel_conv = nn.Sequential(
            nn.Conv1d(512, 512, kernel_size=1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, output_channels, kernel_size=1),
            nn.BatchNorm1d(output_channels),
        )

        self.channel_attn = ChannelAttention(output_channels)

    def forward(self, x):
        x = self.conv_layers(x)
        # Fold the (singleton) frequency axis into channels: [B, C, F, T] -> [B, C*F, T].
        # Use `freq` rather than `F` to avoid shadowing torch.nn.functional.
        B, C, freq, T = x.size()
        x = x.view(B, C * freq, T)
        x = self.channel_conv(x)
        x = self.channel_attn(x)
        return x


class FCN_VideoEncoder(nn.Module):
    """Fully convolutional video encoder."""

    def __init__(self, output_channels=512):
        super(FCN_VideoEncoder, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv3d(3, 96, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3)),
            nn.BatchNorm3d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),

            nn.Conv3d(96, 256, kernel_size=(3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2)),
            nn.BatchNorm3d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),

            nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.BatchNorm3d(256),
            nn.ReLU(inplace=True),

            nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.BatchNorm3d(256),
            nn.ReLU(inplace=True),

            nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.BatchNorm3d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),

            nn.Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            nn.BatchNorm3d(512),
            nn.ReLU(inplace=True),

            nn.AdaptiveAvgPool3d((None, 1, 1))
        )

        self.channel_conv = nn.Sequential(
            nn.Conv1d(512, 512, kernel_size=1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, output_channels, kernel_size=1),
            nn.BatchNorm1d(output_channels),
        )

        self.channel_attn = ChannelAttention(output_channels)

    def forward(self, x):
        x = self.conv_layers(x)
        B, C, T, H, W = x.size()
        x = x.view(B, C, T)
        x = self.channel_conv(x)
        x = self.channel_attn(x)
        return x
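
# Rough I/O shapes for the two encoders (a sketch based on the defaults above,
# for a 1-second clip: 100 MFCC frames at 100 Hz, 25 video frames at 25 fps):
#
#     audio_enc = FCN_AudioEncoder()
#     video_enc = FCN_VideoEncoder()
#     a = audio_enc(torch.randn(1, 1, 13, 100))       # -> [1, 512, 24]  (time pooled ~4x)
#     v = video_enc(torch.randn(1, 3, 25, 112, 112))  # -> [1, 512, 25]  (one step per frame)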


class SyncNetFCN_Classification(nn.Module):
    """
    Fully Convolutional SyncNet with CLASSIFICATION output.

    Treats offset detection as a multi-class classification problem:
    - num_classes = 2 * max_offset + 1 (e.g., 251 classes for max_offset=125)
    - Class index = offset + max_offset (e.g., offset -5 → class 120)
    - Uses CrossEntropyLoss for training
    - Default: ±125 frames = ±5 seconds at 25 fps

    This avoids the regression-to-mean problem encountered with MSE loss.

    Architecture:
    1. Audio encoder: MFCC → temporal features
    2. Video encoder: frames → temporal features
    3. Correlation layer: compute audio-video similarity over time
    4. Classifier: predict offset class probabilities
    """

    def __init__(self, embedding_dim=512, max_offset=125, dropout=0.3):
        super(SyncNetFCN_Classification, self).__init__()

        self.embedding_dim = embedding_dim
        self.max_offset = max_offset
        self.num_classes = 2 * max_offset + 1  # e.g., 251 classes for the default max_offset=125

        # Encoders
        self.audio_encoder = FCN_AudioEncoder(output_channels=embedding_dim)
        self.video_encoder = FCN_VideoEncoder(output_channels=embedding_dim)

        # Temporal correlation
        self.correlation = TemporalCorrelation(max_displacement=max_offset)

        # Classifier head (replaces the regressor)
        self.classifier = nn.Sequential(
            nn.Conv1d(self.num_classes, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),

            nn.Conv1d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),

            # Output: class logits for each timestep
            nn.Conv1d(64, self.num_classes, kernel_size=1),
        )

        # Global classifier (for a single prediction from the whole sequence)
        self.global_classifier = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(self.num_classes, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, self.num_classes),
        )

    def forward_audio(self, audio_mfcc):
        """Extract audio features."""
        return self.audio_encoder(audio_mfcc)

    def forward_video(self, video_frames):
        """Extract video features."""
        return self.video_encoder(video_frames)

    def forward(self, audio_mfcc, video_frames, return_temporal=False):
        """
        Forward pass with audio-video offset classification.

        Args:
            audio_mfcc: [B, 1, F, T] - MFCC features
            video_frames: [B, 3, T', H, W] - video frames
            return_temporal: If True, also return per-timestep predictions

        Returns:
            class_logits: [B, num_classes] - global offset class logits
            temporal_logits: [B, num_classes, T] - per-timestep logits (if return_temporal)
            audio_features: [B, C, T_a] - audio embeddings
            video_features: [B, C, T_v] - video embeddings
        """
        # Extract features
        if audio_mfcc.dim() == 3:
            audio_mfcc = audio_mfcc.unsqueeze(1)
        audio_features = self.audio_encoder(audio_mfcc)
        video_features = self.video_encoder(video_frames)

        # Align temporal dimensions
        min_time = min(audio_features.size(2), video_features.size(2))
        audio_features = audio_features[:, :, :min_time]
        video_features = video_features[:, :, :min_time]

        # Compute correlation
        correlation = self.correlation(video_features, audio_features)

        # Per-timestep classification
        temporal_logits = self.classifier(correlation)

        # Global classification (aggregate over time)
        class_logits = self.global_classifier(temporal_logits)

        if return_temporal:
            return class_logits, temporal_logits, audio_features, video_features
        return class_logits, audio_features, video_features

    def predict_offset(self, class_logits):
        """
        Convert class logits to an offset prediction.

        Args:
            class_logits: [B, num_classes] - classification logits

        Returns:
            offsets: [B] - predicted offset in frames
            confidences: [B] - prediction confidence (softmax probability)
        """
        probs = F.softmax(class_logits, dim=1)
        predicted_class = probs.argmax(dim=1)
        offsets = predicted_class - self.max_offset  # Convert class index to offset
        confidences = probs.max(dim=1).values
        return offsets, confidences

    def offset_to_class(self, offset):
        """Convert offset value to class index."""
        return offset + self.max_offset

    def class_to_offset(self, class_idx):
        """Convert class index to offset value."""
        return class_idx - self.max_offset
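
# Worked example of the offset <-> class mapping used above (illustration only):
# with max_offset=125 there are 251 classes.
#
#     model = SyncNetFCN_Classification(max_offset=125)
#     model.offset_to_class(-5)    # -> 120
#     model.offset_to_class(0)     # -> 125
#     model.class_to_offset(250)   # -> +125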


class StreamSyncFCN_Classification(nn.Module):
    """
    Streaming-capable FCN SyncNet with classification output.

    Includes preprocessing, transfer learning, and inference utilities.
    """

    def __init__(self, embedding_dim=512, max_offset=125, window_size=25, stride=5,
                 buffer_size=100, pretrained_syncnet_path=None,
                 auto_load_pretrained=True, dropout=0.3):
        super(StreamSyncFCN_Classification, self).__init__()

        self.window_size = window_size
        self.stride = stride
        self.buffer_size = buffer_size
        self.max_offset = max_offset
        self.num_classes = 2 * max_offset + 1

        # Initialize classification model
        self.fcn_model = SyncNetFCN_Classification(
            embedding_dim=embedding_dim,
            max_offset=max_offset,
            dropout=dropout
        )

        # Auto-load pretrained weights
        if auto_load_pretrained and pretrained_syncnet_path:
            self.load_pretrained_syncnet(pretrained_syncnet_path)

        self.reset_buffers()

    def reset_buffers(self):
        """Reset temporal buffers."""
        self.logits_buffer = []
        self.frame_count = 0

    def load_pretrained_syncnet(self, syncnet_model_path, freeze_conv=True, verbose=True):
        """Load conv layers from the original SyncNet."""
        if verbose:
            print(f"Loading pretrained SyncNet from: {syncnet_model_path}")

        try:
            pretrained = torch.load(syncnet_model_path, map_location='cpu')
            if isinstance(pretrained, dict):
                pretrained_dict = pretrained.get('model_state_dict',
                                                 pretrained.get('state_dict', pretrained))
            else:
                pretrained_dict = pretrained.state_dict()

            fcn_dict = self.fcn_model.state_dict()
            loaded_count = 0

            for key in list(pretrained_dict.keys()):
                if key.startswith('netcnnaud.'):
                    idx = key.split('.')[1]
                    param = '.'.join(key.split('.')[2:])
                    new_key = f'audio_encoder.conv_layers.{idx}.{param}'
                    if new_key in fcn_dict and pretrained_dict[key].shape == fcn_dict[new_key].shape:
                        fcn_dict[new_key] = pretrained_dict[key]
                        loaded_count += 1
                elif key.startswith('netcnnlip.'):
                    idx = key.split('.')[1]
                    param = '.'.join(key.split('.')[2:])
                    new_key = f'video_encoder.conv_layers.{idx}.{param}'
                    if new_key in fcn_dict and pretrained_dict[key].shape == fcn_dict[new_key].shape:
                        fcn_dict[new_key] = pretrained_dict[key]
                        loaded_count += 1

            self.fcn_model.load_state_dict(fcn_dict, strict=False)
            if verbose:
                print(f"✓ Loaded {loaded_count} pretrained conv parameters")

            if freeze_conv:
                for name, param in self.fcn_model.named_parameters():
                    if 'conv_layers' in name:
                        param.requires_grad = False
                if verbose:
                    print("✓ Froze pretrained conv layers")

        except Exception as e:
            if verbose:
                print(f"⚠ Could not load pretrained weights: {e}")

    def load_fcn_checkpoint(self, checkpoint_path, verbose=True):
        """Load an FCN classification checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        # Try to load the full state dict directly first
        try:
            self.fcn_model.load_state_dict(state_dict, strict=True)
            if verbose:
                print(f"✓ Loaded full checkpoint from {checkpoint_path}")
        except RuntimeError:
            # Fall back to loading only the matching keys
            model_dict = self.fcn_model.state_dict()
            pretrained_dict = {k: v for k, v in state_dict.items()
                               if k in model_dict and v.shape == model_dict[k].shape}
            model_dict.update(pretrained_dict)
            self.fcn_model.load_state_dict(model_dict, strict=False)
            if verbose:
                print(f"✓ Loaded {len(pretrained_dict)}/{len(state_dict)} parameters from {checkpoint_path}")

        return checkpoint.get('epoch', None)

    def unfreeze_all_layers(self, verbose=True):
        """Unfreeze all layers for fine-tuning."""
        for param in self.fcn_model.parameters():
            param.requires_grad = True
        if verbose:
            print("✓ Unfroze all layers for fine-tuning")

    def forward(self, audio_mfcc, video_frames, return_temporal=False):
        """Forward pass through the FCN model."""
        return self.fcn_model(audio_mfcc, video_frames, return_temporal)

    def extract_audio_mfcc(self, video_path, temp_dir='temp'):
        """Extract audio with ffmpeg and compute MFCC features."""
        os.makedirs(temp_dir, exist_ok=True)
        audio_path = os.path.join(temp_dir, 'temp_audio.wav')

        cmd = ['ffmpeg', '-y', '-i', video_path, '-ac', '1', '-ar', '16000',
               '-vn', '-acodec', 'pcm_s16le', audio_path]
        subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)

        sample_rate, audio = wavfile.read(audio_path)
        mfcc = python_speech_features.mfcc(audio, sample_rate, numcep=13).T
        mfcc_tensor = torch.FloatTensor(mfcc).unsqueeze(0).unsqueeze(0)

        if os.path.exists(audio_path):
            os.remove(audio_path)

        return mfcc_tensor

    def extract_video_frames(self, video_path, target_size=(112, 112)):
        """Extract video frames as a tensor."""
        cap = cv2.VideoCapture(video_path)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.resize(frame, target_size)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame.astype(np.float32) / 255.0)
        cap.release()

        if not frames:
            raise ValueError(f"No frames extracted from {video_path}")

        frames_array = np.stack(frames, axis=0)
        video_tensor = torch.FloatTensor(frames_array).permute(3, 0, 1, 2).unsqueeze(0)
        return video_tensor

    def detect_offset(self, video_path, temp_dir='temp', verbose=True):
        """
        Detect AV offset using the classification approach.

        Args:
            video_path: Path to video file
            temp_dir: Temporary directory for audio extraction
            verbose: Print progress information

        Returns:
            offset: Predicted offset in frames (positive = audio ahead)
            confidence: Classification confidence (0-1)
            class_probs: Full probability distribution over offset classes
        """
        if verbose:
            print(f"Processing: {video_path}")

        # Extract features
        mfcc = self.extract_audio_mfcc(video_path, temp_dir)
        video = self.extract_video_frames(video_path)

        if verbose:
            print(f" Audio MFCC: {mfcc.shape}, Video: {video.shape}")

        # Run inference
        self.fcn_model.eval()
        with torch.no_grad():
            class_logits, _, _ = self.fcn_model(mfcc, video)
            offset, confidence = self.fcn_model.predict_offset(class_logits)
            class_probs = F.softmax(class_logits, dim=1)

        offset = offset.item()
        confidence = confidence.item()

        if verbose:
            print(f" Detected offset: {offset:+d} frames")
            print(f" Confidence: {confidence:.4f}")

        return offset, confidence, class_probs.squeeze(0).numpy()

    def process_video_file(self, video_path, temp_dir='temp', verbose=True):
        """Alias for detect_offset, kept for compatibility."""
        offset, confidence, _ = self.detect_offset(video_path, temp_dir, verbose)
        return offset, confidence
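
# Transfer-learning sketch (illustration only; 'data/syncnet_v2.model' is a
# hypothetical path to an original SyncNet checkpoint):
#
#     stream = StreamSyncFCN_Classification(
#         pretrained_syncnet_path='data/syncnet_v2.model',  # hypothetical file
#         auto_load_pretrained=True,
#     )
#     # Conv layers start out frozen; unfreeze them later for full fine-tuning:
#     stream.unfreeze_all_layers()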


def create_classification_criterion(max_offset=125, label_smoothing=0.1):
    """
    Create loss function for classification training.

    Args:
        max_offset: Maximum offset value
        label_smoothing: Label smoothing factor (0 = no smoothing)

    Returns:
        criterion: CrossEntropyLoss with optional label smoothing
    """
    return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
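
# Usage sketch: targets for the criterion are class indices, i.e. offset + max_offset
# (note that max_offset itself is not used inside create_classification_criterion).
#
#     criterion = create_classification_criterion(max_offset=125, label_smoothing=0.1)
#     logits = torch.randn(4, 251)                       # [B, num_classes]
#     target_class = torch.tensor([3, -5, 0, 10]) + 125  # offsets -> class indices
#     loss = criterion(logits, target_class)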


def train_step_classification(model, audio, video, target_offset, criterion, optimizer, device):
    """
    Single training step for the classification model.

    Args:
        model: SyncNetFCN_Classification or StreamSyncFCN_Classification
        audio: [B, 1, F, T] audio MFCC
        video: [B, 3, T, H, W] video frames
        target_offset: [B] target offset in frames (-max_offset to +max_offset)
        criterion: CrossEntropyLoss
        optimizer: Optimizer
        device: torch device

    Returns:
        loss: Training loss value
        accuracy: Classification accuracy
    """
    model.train()
    optimizer.zero_grad()

    audio = audio.to(device)
    video = video.to(device)

    # Convert offset to class index
    if hasattr(model, 'fcn_model'):
        target_class = target_offset + model.fcn_model.max_offset
    else:
        target_class = target_offset + model.max_offset
    target_class = target_class.long().to(device)

    # Forward pass (both model classes share the same call signature)
    class_logits, _, _ = model(audio, video)

    # Compute loss
    loss = criterion(class_logits, target_class)

    # Backward pass
    loss.backward()
    optimizer.step()

    # Compute accuracy
    predicted_class = class_logits.argmax(dim=1)
    accuracy = (predicted_class == target_class).float().mean().item()

    return loss.item(), accuracy


def validate_classification(model, dataloader, criterion, device, max_offset=125):
    """
    Validate the classification model.

    Returns:
        avg_loss: Average validation loss
        accuracy: Classification accuracy
        mean_error: Mean absolute error in frames
    """
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    total_error = 0

    with torch.no_grad():
        for audio, video, target_offset in dataloader:
            audio = audio.to(device)
            video = video.to(device)
            target_class = (target_offset + max_offset).long().to(device)

            class_logits, _, _ = model(audio, video)

            loss = criterion(class_logits, target_class)
            total_loss += loss.item() * audio.size(0)

            predicted_class = class_logits.argmax(dim=1)
            correct += (predicted_class == target_class).sum().item()
            total += audio.size(0)

            # Mean absolute error in frames
            predicted_offset = predicted_class - max_offset
            target_offset_dev = target_class - max_offset
            total_error += (predicted_offset - target_offset_dev).abs().sum().item()

    return total_loss / total, correct / total, total_error / total
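
# Minimal training-loop sketch tying the helpers above together. `train_loader`
# and `val_loader` are hypothetical DataLoaders yielding (audio, video, offset)
# batches; they are not defined in this file.
#
#     model = SyncNetFCN_Classification(max_offset=125)
#     criterion = create_classification_criterion(max_offset=125, label_smoothing=0.1)
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#     for epoch in range(10):
#         for audio, video, target_offset in train_loader:
#             loss, acc = train_step_classification(
#                 model, audio, video, target_offset, criterion, optimizer, 'cpu')
#         val_loss, val_acc, val_mae = validate_classification(
#             model, val_loader, criterion, 'cpu', max_offset=125)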


if __name__ == "__main__":
    print("Testing SyncNetFCN_Classification...")

    # Test model creation (the default max_offset=125 gives 251 classes;
    # pass a smaller max_offset here for a quicker smoke test)
    model = SyncNetFCN_Classification(embedding_dim=512, max_offset=125)
    print(f"Number of classes: {model.num_classes}")

    # Test forward pass
    audio_input = torch.randn(2, 1, 13, 100)
    video_input = torch.randn(2, 3, 25, 112, 112)

    class_logits, audio_feat, video_feat = model(audio_input, video_input)
    print(f"Class logits: {class_logits.shape}")
    print(f"Audio features: {audio_feat.shape}")
    print(f"Video features: {video_feat.shape}")

    # Test prediction
    offsets, confidences = model.predict_offset(class_logits)
    print(f"Predicted offsets: {offsets}")
    print(f"Confidences: {confidences}")

    # Test with temporal output
    class_logits, temporal_logits, _, _ = model(audio_input, video_input, return_temporal=True)
    print(f"Temporal logits: {temporal_logits.shape}")

    # Test training step
    print("\nTesting training step...")
    criterion = create_classification_criterion(max_offset=125, label_smoothing=0.1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    target_offset = torch.tensor([3, -5])  # Example target offsets

    loss, acc = train_step_classification(
        model, audio_input, video_input, target_offset, criterion, optimizer, 'cpu'
    )
    print(f"Training loss: {loss:.4f}, Accuracy: {acc:.2%}")

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    print("\nTesting StreamSyncFCN_Classification...")
    stream_model = StreamSyncFCN_Classification(
        embedding_dim=512,
        max_offset=125,
        pretrained_syncnet_path=None,
        auto_load_pretrained=False
    )
    class_logits, _, _ = stream_model(audio_input, video_input)
    print(f"Stream model class logits: {class_logits.shape}")

    print("\n✓ All tests passed!")
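
    # Optional end-to-end demo (a sketch; 'example.mp4' is a hypothetical local
    # video and ffmpeg must be on PATH). It is skipped when the file is absent.
    demo_video = 'example.mp4'
    if os.path.exists(demo_video):
        offset, confidence, _ = stream_model.detect_offset(demo_video)
        print(f"Demo offset: {offset:+d} frames (confidence {confidence:.3f})")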