#!/usr/bin/python
#-*- coding: utf-8 -*-
"""
Transfer Learning Implementation for SyncNet
This module provides pre-trained backbone integration for improved performance.
Supported backbones:
- Video: 3D ResNet (Kinetics), I3D, SlowFast, X3D
- Audio: VGGish (AudioSet), wav2vec 2.0, HuBERT
Author: Enhanced version
Date: 2025-11-22
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# ==================== VIDEO BACKBONES ====================
class ResNet3D_Backbone(nn.Module):
"""
3D ResNet backbone pre-trained on Kinetics-400.
Uses torchvision's video models.
"""
def __init__(self, embedding_dim=512, pretrained=True, model_type='r3d_18'):
super(ResNet3D_Backbone, self).__init__()
try:
import torchvision.models.video as video_models
            # Load pre-trained model (torchvision >= 0.13 prefers the
            # `weights=` argument; `pretrained=` still works but is deprecated)
if model_type == 'r3d_18':
backbone = video_models.r3d_18(pretrained=pretrained)
elif model_type == 'mc3_18':
backbone = video_models.mc3_18(pretrained=pretrained)
elif model_type == 'r2plus1d_18':
backbone = video_models.r2plus1d_18(pretrained=pretrained)
else:
raise ValueError(f"Unknown model type: {model_type}")
# Remove final FC and pooling layers
self.features = nn.Sequential(*list(backbone.children())[:-2])
# Add custom head
self.conv_head = nn.Sequential(
nn.Conv3d(512, embedding_dim, kernel_size=1),
nn.BatchNorm3d(embedding_dim),
nn.ReLU(inplace=True),
)
print(f"Loaded {model_type} with pretrained={pretrained}")
except ImportError:
print("Warning: torchvision not found. Using random initialization.")
self.features = self._build_simple_3dcnn()
self.conv_head = nn.Conv3d(512, embedding_dim, 1)
def _build_simple_3dcnn(self):
"""Fallback if torchvision not available."""
return nn.Sequential(
nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3)),
nn.BatchNorm3d(64),
nn.ReLU(inplace=True),
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
nn.Conv3d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm3d(128),
nn.ReLU(inplace=True),
nn.Conv3d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm3d(256),
nn.ReLU(inplace=True),
nn.Conv3d(256, 512, kernel_size=3, padding=1),
nn.BatchNorm3d(512),
nn.ReLU(inplace=True),
)
def forward(self, x):
"""
Args:
x: [B, 3, T, H, W]
Returns:
features: [B, C, T', H', W']
"""
x = self.features(x)
x = self.conv_head(x)
return x
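# Shape sanity sketch (illustrative helper, not part of the original module):
# with torchvision's r3d_18, the stem halves H/W only and stages 2-4 each halve
# T/H/W, so a [1, 3, 16, 112, 112] clip comes out as [1, 512, 2, 7, 7].
def _demo_resnet3d_shapes():
    enc = ResNet3D_Backbone(embedding_dim=512, pretrained=False)
    out = enc(torch.randn(1, 3, 16, 112, 112))
    print(out.shape)  # expected: torch.Size([1, 512, 2, 7, 7])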
class I3D_Backbone(nn.Module):
"""
Inflated 3D ConvNet (I3D) backbone.
Requires external I3D implementation.
"""
def __init__(self, embedding_dim=512, pretrained=True):
super(I3D_Backbone, self).__init__()
try:
# Try to import I3D (needs to be installed separately)
from i3d import InceptionI3d
self.i3d = InceptionI3d(400, in_channels=3)
if pretrained:
# Load pre-trained weights
state_dict = torch.load('models/rgb_imagenet.pt', map_location='cpu')
self.i3d.load_state_dict(state_dict)
print("Loaded I3D with ImageNet+Kinetics pre-training")
# Adaptation layer
self.adapt = nn.Conv3d(1024, embedding_dim, kernel_size=1)
        except (ImportError, OSError, RuntimeError):
            # Missing package, missing checkpoint file, or incompatible weights
            print("Warning: I3D not available. Install from: https://github.com/piergiaj/pytorch-i3d")
# Fallback to simple 3D CNN
self.i3d = self._build_fallback()
self.adapt = nn.Conv3d(512, embedding_dim, 1)
def _build_fallback(self):
return nn.Sequential(
nn.Conv3d(3, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3)),
nn.BatchNorm3d(64),
nn.ReLU(inplace=True),
nn.Conv3d(64, 512, kernel_size=3, padding=1),
nn.BatchNorm3d(512),
nn.ReLU(inplace=True),
)
    def forward(self, x):
        # The real I3D exposes extract_features(); the fallback is a plain Sequential
        if hasattr(self.i3d, 'extract_features'):
            features = self.i3d.extract_features(x)
        else:
            features = self.i3d(x)
        features = self.adapt(features)
        return features
# ==================== AUDIO BACKBONES ====================
class VGGish_Backbone(nn.Module):
"""
VGGish audio encoder pre-trained on AudioSet.
Processes log-mel spectrograms.
"""
def __init__(self, embedding_dim=512, pretrained=True):
super(VGGish_Backbone, self).__init__()
try:
import torchvggish
# Load VGGish
self.vggish = torchvggish.vggish()
if pretrained:
# Download and load pre-trained weights
self.vggish.load_state_dict(
torch.hub.load_state_dict_from_url(
'https://github.com/harritaylor/torchvggish/releases/download/v0.1/vggish-10086976.pth',
map_location='cpu'
)
)
print("Loaded VGGish pre-trained on AudioSet")
# Use convolutional part only
self.features = self.vggish.features
# Adaptation layer
self.adapt = nn.Sequential(
nn.Conv2d(512, embedding_dim, kernel_size=1),
nn.BatchNorm2d(embedding_dim),
nn.ReLU(inplace=True),
)
except ImportError:
print("Warning: torchvggish not found. Install: pip install torchvggish")
self.features = self._build_fallback()
self.adapt = nn.Conv2d(512, embedding_dim, 1)
def _build_fallback(self):
"""Simple audio CNN if VGGish unavailable."""
return nn.Sequential(
nn.Conv2d(1, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, kernel_size=3, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
)
def forward(self, x):
"""
Args:
            x: [B, 1, F, T] log-mel spectrogram (VGGish expects F = 96)
Returns:
features: [B, C, F', T']
"""
x = self.features(x)
x = self.adapt(x)
return x
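# Sizing sketch (illustrative helper): VGGish's four 2x2 max-pools divide F and
# T by 16, so the canonical [B, 1, 96, 64] log-mel patch yields [B, 512, 6, 4];
# the fallback CNN pools only twice, giving [B, 512, 24, 16] instead.
def _demo_vggish_shapes():
    enc = VGGish_Backbone(embedding_dim=512, pretrained=False)
    out = enc(torch.randn(1, 1, 96, 64))
    print(out.shape)  # [1, 512, 6, 4] with VGGish; [1, 512, 24, 16] with the fallback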
class Wav2Vec_Backbone(nn.Module):
"""
wav2vec 2.0 backbone for speech representation.
Processes raw waveforms.
"""
def __init__(self, embedding_dim=512, pretrained=True, model_name='facebook/wav2vec2-base'):
super(Wav2Vec_Backbone, self).__init__()
try:
from transformers import Wav2Vec2Model
if pretrained:
self.wav2vec = Wav2Vec2Model.from_pretrained(model_name)
print(f"Loaded {model_name} from HuggingFace")
else:
from transformers import Wav2Vec2Config
config = Wav2Vec2Config()
self.wav2vec = Wav2Vec2Model(config)
# Freeze early layers for fine-tuning
self._freeze_layers(num_layers_to_freeze=6)
# Adaptation layer
wav2vec_dim = self.wav2vec.config.hidden_size
self.adapt = nn.Sequential(
nn.Linear(wav2vec_dim, embedding_dim),
nn.LayerNorm(embedding_dim),
nn.ReLU(),
)
except ImportError:
print("Warning: transformers not found. Install: pip install transformers")
raise
def _freeze_layers(self, num_layers_to_freeze):
"""Freeze early transformer layers."""
for param in self.wav2vec.feature_extractor.parameters():
param.requires_grad = False
for i, layer in enumerate(self.wav2vec.encoder.layers):
if i < num_layers_to_freeze:
for param in layer.parameters():
param.requires_grad = False
def forward(self, waveform):
"""
Args:
waveform: [B, T] - raw audio waveform (16kHz)
Returns:
features: [B, C, T'] - temporal features
"""
# Extract features from wav2vec
outputs = self.wav2vec(waveform, output_hidden_states=True)
features = outputs.last_hidden_state # [B, T', D]
# Adapt to target dimension
features = self.adapt(features) # [B, T', embedding_dim]
# Reshape to [B, C, T']
features = features.transpose(1, 2)
return features
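# Frame-rate sketch (illustrative helper, assumes the transformers package is
# installed): the wav2vec 2.0 convolutional encoder has a total stride of 320
# samples (~20 ms at 16 kHz), so one second of audio yields roughly 49 frames.
def _demo_wav2vec_frames(num_seconds=1.0, sample_rate=16000):
    enc = Wav2Vec_Backbone(embedding_dim=512, pretrained=False)
    waveform = torch.randn(1, int(num_seconds * sample_rate))
    feats = enc(waveform)
    print(feats.shape)  # roughly [1, 512, 49] for one second of 16 kHz audio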
# ==================== INTEGRATED SYNCNET WITH TRANSFER LEARNING ====================
class SyncNet_TransferLearning(nn.Module):
"""
SyncNet with transfer learning from pre-trained backbones.
Args:
video_backbone: 'resnet3d', 'i3d', 'simple'
audio_backbone: 'vggish', 'wav2vec', 'simple'
embedding_dim: Dimension of shared embedding space
max_offset: Maximum temporal offset to consider
freeze_backbone: Whether to freeze backbone weights
"""
def __init__(self,
video_backbone='resnet3d',
audio_backbone='vggish',
embedding_dim=512,
max_offset=15,
freeze_backbone=False):
super(SyncNet_TransferLearning, self).__init__()
self.embedding_dim = embedding_dim
self.max_offset = max_offset
# Initialize video encoder
if video_backbone == 'resnet3d':
self.video_encoder = ResNet3D_Backbone(embedding_dim, pretrained=True)
elif video_backbone == 'i3d':
self.video_encoder = I3D_Backbone(embedding_dim, pretrained=True)
else:
from SyncNetModel_FCN import FCN_VideoEncoder
self.video_encoder = FCN_VideoEncoder(embedding_dim)
# Initialize audio encoder
if audio_backbone == 'vggish':
self.audio_encoder = VGGish_Backbone(embedding_dim, pretrained=True)
elif audio_backbone == 'wav2vec':
self.audio_encoder = Wav2Vec_Backbone(embedding_dim, pretrained=True)
else:
from SyncNetModel_FCN import FCN_AudioEncoder
self.audio_encoder = FCN_AudioEncoder(embedding_dim)
# Freeze backbones if requested
if freeze_backbone:
self._freeze_backbones()
# Temporal pooling to handle variable spatial/frequency dimensions
self.video_temporal_pool = nn.AdaptiveAvgPool3d((None, 1, 1))
self.audio_temporal_pool = nn.AdaptiveAvgPool2d((1, None))
# Correlation and sync prediction (from FCN model)
from SyncNetModel_FCN import TemporalCorrelation
self.correlation = TemporalCorrelation(max_displacement=max_offset)
self.sync_predictor = nn.Sequential(
nn.Conv1d(2*max_offset+1, 128, kernel_size=3, padding=1),
nn.BatchNorm1d(128),
nn.ReLU(inplace=True),
nn.Conv1d(128, 64, kernel_size=3, padding=1),
nn.BatchNorm1d(64),
nn.ReLU(inplace=True),
nn.Conv1d(64, 2*max_offset+1, kernel_size=1),
)
def _freeze_backbones(self):
"""Freeze backbone parameters for fine-tuning only the head."""
for param in self.video_encoder.parameters():
param.requires_grad = False
for param in self.audio_encoder.parameters():
param.requires_grad = False
print("Backbones frozen. Only training sync predictor.")
def forward_video(self, video):
"""
Extract video features.
Args:
video: [B, 3, T, H, W]
Returns:
features: [B, C, T']
"""
features = self.video_encoder(video) # [B, C, T', H', W']
features = self.video_temporal_pool(features) # [B, C, T', 1, 1]
B, C, T, _, _ = features.shape
features = features.view(B, C, T) # [B, C, T']
return features
def forward_audio(self, audio):
"""
Extract audio features.
Args:
audio: [B, 1, F, T] or [B, T] (raw waveform for wav2vec)
Returns:
features: [B, C, T']
"""
if isinstance(self.audio_encoder, Wav2Vec_Backbone):
# wav2vec expects [B, T]
if audio.dim() == 4:
# Convert from spectrogram to waveform (placeholder - need actual audio)
raise NotImplementedError("Need raw waveform for wav2vec")
features = self.audio_encoder(audio)
else:
features = self.audio_encoder(audio) # [B, C, F', T']
features = self.audio_temporal_pool(features) # [B, C, 1, T']
B, C, _, T = features.shape
features = features.view(B, C, T) # [B, C, T']
return features
def forward(self, audio, video):
"""
Full forward pass with sync prediction.
Args:
audio: [B, 1, F, T] - audio features
video: [B, 3, T', H, W] - video frames
Returns:
sync_probs: [B, 2K+1, T''] - sync probabilities
audio_features: [B, C, T_a]
video_features: [B, C, T_v]
"""
# Extract features
audio_features = self.forward_audio(audio)
video_features = self.forward_video(video)
# Align temporal dimensions
min_time = min(audio_features.size(2), video_features.size(2))
audio_features = audio_features[:, :, :min_time]
video_features = video_features[:, :, :min_time]
# Compute correlation
correlation = self.correlation(video_features, audio_features)
# Predict sync probabilities
sync_logits = self.sync_predictor(correlation)
sync_probs = F.softmax(sync_logits, dim=1)
return sync_probs, audio_features, video_features
def compute_offset(self, sync_probs):
"""
Compute offset from sync probability map.
Args:
sync_probs: [B, 2K+1, T] - sync probabilities
Returns:
offsets: [B, T] - predicted offset for each frame
confidences: [B, T] - confidence scores
"""
max_probs, max_indices = torch.max(sync_probs, dim=1)
offsets = self.max_offset - max_indices
median_probs = torch.median(sync_probs, dim=1)[0]
confidences = max_probs - median_probs
return offsets, confidences
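# Decoding sketch (illustrative helper): channel k of sync_probs maps to offset
# (max_offset - k), so channel 0 means "+max_offset" and the centre channel
# means "in sync". Confidence is max-minus-median probability, analogous to the
# confidence score used in the original SyncNet pipeline.
def _demo_offset_decoding(max_offset=15):
    probs = torch.zeros(1, 2 * max_offset + 1, 4)
    probs[0, max_offset + 3, :] = 1.0   # put the peak at channel k = max_offset + 3
    offsets = max_offset - probs.argmax(dim=1)
    print(offsets)                      # tensor([[-3, -3, -3, -3]])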
# ==================== TRAINING UTILITIES ====================
def fine_tune_with_transfer_learning(model,
train_loader,
val_loader,
num_epochs=10,
lr=1e-4,
device='cuda'):
"""
Fine-tune pre-trained model on SyncNet task.
Strategy:
1. Freeze backbones, train head (2-3 epochs)
2. Unfreeze last layers, train with small lr (5 epochs)
3. Unfreeze all, train with very small lr (2-3 epochs)
"""
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
    for epoch in range(num_epochs):
        # Phase 1: freeze backbones, train only the sync head
        if epoch < 3:
            model._freeze_backbones()
        # Phase 2: unfreeze everything and drop the learning rate 10x
        elif epoch == 3:
            for param in model.parameters():
                param.requires_grad = True
            optimizer = torch.optim.Adam(model.parameters(), lr=lr / 10)
            # Rebuild the scheduler as well: the old one still steps the
            # discarded optimizer
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs - epoch)
model.train()
total_loss = 0
for batch_idx, (audio, video, labels) in enumerate(train_loader):
audio, video = audio.to(device), video.to(device)
labels = labels.to(device)
# Forward pass
sync_probs, _, _ = model(audio, video)
            # Loss: the model returns softmax probabilities, so take NLL on
            # log-probs (F.cross_entropy expects raw logits). Flatten
            # [B, C, T] -> [B*T, C] via permute; a bare .view() would scramble
            # the class dimension.
            log_probs = torch.log(
                sync_probs.permute(0, 2, 1).reshape(-1, sync_probs.size(1)) + 1e-8
            )
            loss = F.nll_loss(log_probs, labels.view(-1))
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
# Validation
model.eval()
val_loss = 0
correct = 0
total = 0
with torch.no_grad():
for audio, video, labels in val_loader:
audio, video = audio.to(device), video.to(device)
labels = labels.to(device)
sync_probs, _, _ = model(audio, video)
                val_loss += F.nll_loss(
                    torch.log(
                        sync_probs.permute(0, 2, 1).reshape(-1, sync_probs.size(1)) + 1e-8
                    ),
                    labels.view(-1)
                ).item()
                # Accuracy in class-index space: labels are channel indices in
                # [0, 2*max_offset], so compare against argmax indices rather
                # than signed offsets
                pred_indices = sync_probs.argmax(dim=1)
                correct += (pred_indices == labels).sum().item()
                total += labels.numel()
scheduler.step()
print(f"Epoch {epoch+1}/{num_epochs}")
print(f" Train Loss: {total_loss/len(train_loader):.4f}")
print(f" Val Loss: {val_loss/len(val_loader):.4f}")
print(f" Val Accuracy: {100*correct/total:.2f}%")
# ==================== EXAMPLE USAGE ====================
if __name__ == "__main__":
print("Testing Transfer Learning SyncNet...")
# Create model with pre-trained backbones
model = SyncNet_TransferLearning(
video_backbone='resnet3d', # or 'i3d'
audio_backbone='vggish', # or 'wav2vec'
embedding_dim=512,
max_offset=15,
freeze_backbone=False
)
print(f"\nModel architecture:")
print(f" Video encoder: {type(model.video_encoder).__name__}")
print(f" Audio encoder: {type(model.audio_encoder).__name__}")
    # Test forward pass. 96 mel bins match VGGish's expected input; a 13-bin
    # MFCC input (as in the original SyncNet) collapses to zero height under
    # VGGish's four pooling stages.
    dummy_audio = torch.randn(2, 1, 96, 100)
    dummy_video = torch.randn(2, 3, 25, 112, 112)
try:
sync_probs, audio_feat, video_feat = model(dummy_audio, dummy_video)
print(f"\nForward pass successful!")
print(f" Sync probs: {sync_probs.shape}")
print(f" Audio features: {audio_feat.shape}")
print(f" Video features: {video_feat.shape}")
offsets, confidences = model.compute_offset(sync_probs)
print(f" Offsets: {offsets.shape}")
print(f" Confidences: {confidences.shape}")
except Exception as e:
print(f"Error: {e}")
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nParameters:")
print(f" Total: {total_params:,}")
print(f" Trainable: {trainable_params:,}")