""" |
|
|
Main model for using CodecLM. This will combine all the required components |
|
|
and provide easy access to the generation API. |
|
|
""" |
|
|
|
|
|
import typing as tp |
|
|
import warnings |
|
|
import sys |
|
|
import time |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import functional as F |
|
|
import torchaudio |
|
|
import numpy as np |
|
|
import lightning as pl |
|
|
from torchmetrics.classification import MulticlassAccuracy |
|
|
import pdb |
|
|
from codeclm.models import builders |
|
|
import math |
|
|
from torch.optim import Optimizer |
|
|
from torch.optim.lr_scheduler import _LRScheduler |
|
|
from peft import LoraConfig, get_peft_model |
|
|
from datetime import datetime |
|
|
import os |
|
|
os.environ['TOKENIZERS_PARALLELISM'] = "false" |
|
|
|
|
|
|
|
|
class CodecLM_PL(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        # Audio tokenizer(s) are frozen; only the language model is trained.
        self.audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
        if self.audio_tokenizer is not None:
            for param in self.audio_tokenizer.parameters():
                param.requires_grad = False
        if "audio_tokenizer_checkpoint_sep" in self.cfg.keys():
            self.seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
            for param in self.seperate_tokenizer.parameters():
                param.requires_grad = False
        else:
            self.seperate_tokenizer = None

        self.audiolm = builders.get_lm_model(self.cfg)
        print(self.audiolm)
        print('Number of parameters: ', sum(p.numel() for p in self.audiolm.parameters()))

        if self.cfg.use_pretrained == 'deepspeed':
            checkpoint = torch.load(self.cfg.pretrained.deepspeed_checkpoint, map_location='cpu')
            missing, unexpected = self.load_state_dict(checkpoint, strict=False)
            print(f'-------------Missing--------------\n{missing}')
            print(f'-------------Unexpected--------------\n{unexpected}')
            print("Successfully loaded deepspeed pretrained model {}".format(self.cfg.pretrained.deepspeed_checkpoint))
            self.missing = missing
        else:
            self.missing = []

        if hasattr(self.cfg, 'lora'):
            peft_config = LoraConfig(
                r=self.cfg.lora.r,
                lora_alpha=self.cfg.lora.lora_alpha,
                target_modules=self.cfg.lora.target_modules,
                lora_dropout=self.cfg.lora.lora_dropout,
                bias=self.cfg.lora.bias,
                task_type=self.cfg.lora.task_type,
            )
            self.audiolm = get_peft_model(self.audiolm, peft_config)

        self.val_steps = []
        self.train_slide_acc = []
        self.train_steps = []
        self.top1_acc_metric = nn.ModuleList([MulticlassAccuracy(
            self.audiolm.code_size,
            top_k=1,
            average="micro", multidim_average="global",
            ignore_index=self.cfg.lm.code_size,
        ) for _ in range(self.audiolm.code_depth)])
        self.top10_acc_metric = nn.ModuleList([MulticlassAccuracy(
            self.audiolm.code_size,
            top_k=10,
            average="micro", multidim_average="global",
            ignore_index=self.cfg.lm.code_size,
        ) for _ in range(self.audiolm.code_depth)])

        self.epoch = 0
        print("++++++++++++++++ training <song> +++++++++++++++++")

    def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384):
        """Append an end token at each sequence's end and build the validity mask.

        Positions beyond the (incremented) sequence lengths are filled with ``end_id + 1``.
        """
        batch_size = sequence_lengths.size(0)
        max_length = x.size(2)

        # Pad by one frame if the longest sequence already fills the tensor,
        # so there is room for the end token.
        if max_length == sequence_lengths.max():
            x = F.pad(x, (0, 1), value=end_id)
            max_length = x.size(2)

        # Clamp the lengths so the end token always fits inside the tensor.
        if max_length <= sequence_lengths.max() + 1:
            sequence_lengths = sequence_lengths - (sequence_lengths.max() + 1 - max_length)

        # Write the end token at each sequence's end and extend the valid length by one.
        x[torch.arange(batch_size), :, sequence_lengths] = end_id
        sequence_lengths += 1

        mask = torch.arange(max_length).expand(batch_size, max_length) < sequence_lengths.unsqueeze(1)
        mask = mask.to(x.device)
        mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length)
        x = torch.where(mask_3d, x, end_id + 1)
        return x, mask_3d
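
    # Illustrative sketch (comments only, not executed), assuming code_depth=1, end_id=4,
    # a [2, 1, 3] token tensor and sequence_lengths=[2, 3]:
    #   x = [[[7, 8, 9]], [[5, 6, 7]]]
    #   -> the tensor is padded by one frame so the end token fits, the end token is
    #      written at each sequence's end, and invalid positions become end_id + 1:
    #      x = [[[7, 8, 4, 5]], [[5, 6, 7, 4]]]
    #   -> mask_3d = [[[True, True, True, False]], [[True, True, True, True]]]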
    @torch.no_grad()
    def preprocess_batch(self, batch):
        audio, text_lyric, time_stamp, structure_dur, prompt_audio, structure_labels = batch
        dur, valid_st, valid_et = zip(*time_stamp)

        if self.audio_tokenizer is not None:
            # Tokenize the raw audio with the frozen tokenizer in full precision.
            self.audio_tokenizer.eval()
            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=False):
                    audio_tokens, scale = self.audio_tokenizer.encode(audio)
                    audio_tokens = audio_tokens[:, :self.cfg.lm.code_depth, :]
                    audio_tokens = audio_tokens.long()
        else:
            audio_tokens = audio.long()

        token_dur = (torch.Tensor(dur) * self.cfg.audio_tokenizer_frame_rate).int()
        audio_tokens, audio_padding_mask = self.generate_mask_and_end_token(audio_tokens, token_dur,
                                                                            end_id=self.audiolm.eos_token_id)
        condition_tensors = self.audiolm.prepare_condition_tensors(batch_size=len(text_lyric),
                                                                   text=text_lyric, audio_qt_emb=prompt_audio)

        return condition_tensors, audio_tokens, audio_padding_mask

    def get_time(self):
        now = datetime.now()
        formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f")
        return formatted_now

    def training_step(self, batch, batch_idx):
        condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)

        model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors,
                                                        training_steps=self.global_step)
        logits = model_output.logits.float()
        mask = padding_mask & model_output.mask

        with torch.cuda.amp.autocast(enabled=False):
            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)

        total_loss = ce
        if torch.isnan(total_loss):
            # Skip the update if the loss is NaN and print enough context to debug it.
            print(self.trainer.global_rank, ce, padding_mask, batch[1])
            print('--------------------------------------------------------------')
            return None

        metrics = {}
        self.log('ce', ce, prog_bar=True)
        metrics['ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'ce_q{k + 1}'] = ce_q
            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)

        masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
        metrics['acc'] = []
        for k in range(self.audiolm.code_depth):
            metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1, 2).detach(),
                                                          masked_labels[:, k]).item())
        metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])).item()

        self.train_steps.append({'ce': ce.detach().cpu().item(), 'acc': metrics['acc']})
        self.log('train_acc', metrics['acc'] + 1e-8, prog_bar=True)
        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True)
        self.log_dict(metrics)

        return total_loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)

        model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors)
        logits = model_output.logits
        mask = padding_mask & model_output.mask

        ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
        metrics = {}
        metrics['val_ce'] = ce
        metrics['val_ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'val_ce_q{k + 1}'] = ce_q
            metrics[f'val_ppl_q{k + 1}'] = torch.exp(ce_q)
        masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)

        for k in range(self.audiolm.code_depth):
            self.top1_acc_metric[k].update(logits[:, k].transpose(1, 2).detach(), masked_labels[:, k])
            self.top10_acc_metric[k].update(logits[:, k].transpose(1, 2).detach(), masked_labels[:, k])
        self.val_steps.append(metrics)

        metrics['acc'] = []
        metrics['acc_top10'] = []
        for k in range(self.audiolm.code_depth):
            metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1, 2).detach(), masked_labels[:, k]).item())
            metrics['acc_top10'].append(self.top10_acc_metric[k](logits[:, k].transpose(1, 2).detach(), masked_labels[:, k]).item())
        metrics['acc'] = torch.mean(torch.Tensor(metrics['acc']))
        metrics['acc_top10'] = torch.mean(torch.Tensor(metrics['acc_top10']))

        return metrics['acc']

    def on_validation_epoch_end(self) -> None:
        final_metrics = {}
        for i in self.val_steps:
            for k in i:
                final_metrics[k] = final_metrics.get(k, []) + [i[k]]
        final_metrics = {k: sum(v) / len(v) for k, v in list(final_metrics.items())}
        self.log_dict(final_metrics)

        q_acc = []
        q_acc10 = []
        for i in range(self.audiolm.code_depth):
            q_acc.append(self.top1_acc_metric[i].compute())
            q_acc10.append(self.top10_acc_metric[i].compute())
            self.log(f"val_Top1Acc_{i}", q_acc[-1])
            self.log(f"val_Top10Acc_{i}", q_acc10[-1])
            self.top1_acc_metric[i].reset()
            self.top10_acc_metric[i].reset()

        self.log('val_Top1Acc', sum(q_acc) / self.audiolm.code_depth)
        self.log('val_Top10Acc', sum(q_acc10) / self.audiolm.code_depth)

        return super().on_validation_epoch_end()

    def on_validation_epoch_start(self) -> None:
        self.val_steps = []
        for i in range(self.audiolm.code_depth):
            self.top1_acc_metric[i].reset()
            self.top10_acc_metric[i].reset()

        if len(self.train_steps) > 0:
            train_metrics = {}
            for i in self.train_steps:
                for k in i:
                    train_metrics[k] = train_metrics.get(k, []) + [i[k]]
            train_metrics = {k: sum(v) / len(v) for k, v in list(train_metrics.items())}
            self.log('train_summary_Top1Acc', train_metrics['acc'])
            self.log('train_summary_ce', train_metrics['ce'])
            self.train_steps = []

        return super().on_validation_epoch_start()

    def configure_optimizers(self):
        total_updates = self.cfg.optim.epochs * self.cfg.optim.updates_per_epoch
        optim_dict = {}

        # Parameters that were missing from the pretrained checkpoint (newly initialized)
        # get their own learning rate; all remaining parameters use the "old" rate.
        param_groups = []
        missing_params = []
        other_params = []
        cnt = 0

        print('before missing len', len(self.missing))
        self.missing = [name.replace('audiolm.', '') for name in self.missing]
        print('after missing len', len(self.missing))
        for name, param in self.audiolm.named_parameters():
            if name in self.missing:
                cnt += 1
                print(name)
                missing_params.append(param)
            else:
                other_params.append(param)
        print('matched missing params:', cnt)
        assert cnt == len(self.missing)
        param_groups.append({'params': other_params, 'lr': self.cfg.optim.old_lr})
        param_groups.append({
            'params': missing_params,
            'lr': self.cfg.optim.new_lr
        })

        if self.cfg.optim.optimizer == "adamw":
            optim_dict['optimizer'] = torch.optim.AdamW(
                param_groups,
                betas=tuple(self.cfg.optim.adam.betas),
                weight_decay=self.cfg.optim.adam.weight_decay,
                eps=self.cfg.optim.adam.eps,
            )
        else:
            raise NotImplementedError

        if self.cfg.schedule is None:
            pass
        elif self.cfg.schedule.lr_scheduler == "cosine":
            scheduler = CosineLRScheduler(optim_dict['optimizer'],
                                          total_steps=total_updates,
                                          warmup_steps=self.cfg.schedule.cosine.warmup,
                                          lr_min_ratio=self.cfg.schedule.cosine.lr_min_ratio,
                                          cycle_length=self.cfg.schedule.cosine.cycle_length,
                                          )
            optim_dict['lr_scheduler'] = {"scheduler": scheduler, "interval": "step"}
        else:
            raise NotImplementedError

        return optim_dict
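
    # Config fields read by configure_optimizers (documentation comment; values are left
    # unspecified, and any concrete layout is an assumption rather than a project default):
    #   cfg.optim: epochs, updates_per_epoch, optimizer ("adamw"), old_lr, new_lr,
    #              adam.betas, adam.weight_decay, adam.eps
    #   cfg.schedule: lr_scheduler ("cosine"), cosine.warmup, cosine.lr_min_ratio,
    #                 cosine.cycle_length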
    def _compute_cross_entropy(
        self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Compute cross entropy between multi-codebook targets and the model's logits.

        The cross entropy is computed per codebook to provide codebook-level cross entropy.
        Valid timesteps for each codebook are selected with the mask; invalid timesteps
        are excluded from the loss.

        Args:
            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
            targets (torch.Tensor): Target codes, of shape [B, K, T].
            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
        Returns:
            ce (torch.Tensor): Cross entropy averaged over the codebooks.
            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
        """
        B, K, T = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))
            targets_k = targets[:, k, ...].contiguous().view(-1)
            mask_k = mask[:, k, ...].contiguous().view(-1)
            ce_targets = targets_k[mask_k]
            ce_logits = logits_k[mask_k]
            q_ce = F.cross_entropy(ce_logits, ce_targets)
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())

        ce = ce / K
        return ce, ce_per_codebook
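
    # Usage sketch (comment only; shapes assumed): with logits of shape [B, K, T, card]
    # from compute_predictions and the combined padding/prediction mask,
    #   ce, ce_per_codebook = self._compute_cross_entropy(logits.float(), audio_tokens, mask)
    # mirrors how training_step and validation_step above call this helper.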
class CodecLM_PL_FT(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        # The audio tokenizer is frozen; only the language model is fine-tuned.
        self.audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg)
        if self.audio_tokenizer is not None:
            for param in self.audio_tokenizer.parameters():
                param.requires_grad = False

        self.audiolm = builders.get_lm_model(self.cfg)

        if self.cfg.use_pretrained == 'deepspeed':
            checkpoint = torch.load(self.cfg.pretrained.deepspeed_checkpoint, map_location='cpu')
            missing, unexpected = self.load_state_dict(checkpoint, strict=False)
            print(f'-------------Missing--------------\n{missing}')
            print(f'-------------Unexpected--------------\n{unexpected}')
            print("Successfully loaded deepspeed pretrained model {}".format(self.cfg.pretrained.deepspeed_checkpoint))

        self.val_steps = []
        self.train_slide_acc = []
        self.train_steps = []
        self.top1_acc_metric = nn.ModuleList([MulticlassAccuracy(
            self.audiolm.code_size,
            top_k=1,
            average="micro", multidim_average="global",
            ignore_index=self.cfg.lm.code_size,
        ) for _ in range(self.audiolm.code_depth)])
        self.top10_acc_metric = nn.ModuleList([MulticlassAccuracy(
            self.audiolm.code_size,
            top_k=10,
            average="micro", multidim_average="global",
            ignore_index=self.cfg.lm.code_size,
        ) for _ in range(self.audiolm.code_depth)])

        self.epoch = 0
        print("++++++++++++++++ training <song> +++++++++++++++++")

    def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384):
        batch_size = sequence_lengths.size(0)
        max_length = x.size(2)

        if max_length == sequence_lengths.max():
            x = F.pad(x, (0, 1), value=end_id)
            max_length = x.size(2)

        if max_length <= sequence_lengths.max() + 1:
            sequence_lengths = sequence_lengths - (sequence_lengths.max() + 1 - max_length)

        x[torch.arange(batch_size), :, sequence_lengths] = end_id
        sequence_lengths += 1

        mask = torch.arange(max_length).expand(batch_size, max_length) < sequence_lengths.unsqueeze(1)
        mask = mask.to(x.device)
        mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length)
        x = torch.where(mask_3d, x, end_id + 1)
        return x, mask_3d

    @torch.no_grad()
    def preprocess_batch(self, batch):
        audio, text_lyric, time_stamp, lang_type, prompt_audio = batch
        dur, valid_st, valid_et = zip(*time_stamp)

        if self.audio_tokenizer is not None:
            self.audio_tokenizer.eval()
            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=False):
                    audio_tokens, scale = self.audio_tokenizer.encode(audio)
                    audio_tokens = audio_tokens[:, :self.cfg.lm.code_depth, :]
                    audio_tokens = audio_tokens.long()
        else:
            audio_tokens = audio.long()

        token_dur = (torch.Tensor(dur) * self.cfg.audio_tokenizer_frame_rate).int()

        audio_tokens, audio_padding_mask = self.generate_mask_and_end_token(audio_tokens, token_dur,
                                                                            end_id=self.audiolm.eos_token_id)
        condition_tensors = self.audiolm.prepare_condition_tensors(batch_size=len(text_lyric),
                                                                   text=text_lyric, audio_qt_emb=prompt_audio)

        return condition_tensors, audio_tokens, audio_padding_mask

    def get_time(self):
        now = datetime.now()
        formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f")
        return formatted_now

    def training_step(self, batch, batch_idx):
        condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)

        model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors,
                                                        training_steps=self.global_step)
        logits = model_output.logits.float()
        mask = padding_mask & model_output.mask

        with torch.cuda.amp.autocast(enabled=False):
            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)

        total_loss = ce
        if torch.isnan(total_loss):
            print(self.trainer.global_rank, ce, padding_mask, batch[1])

            # Debugging aid: dump the offending audio and drop into pdb before skipping the step.
            torchaudio.save("error_rank{}.wav".format(self.trainer.global_rank), batch[0][:, 0].cpu(), 24000)
            import pdb; pdb.set_trace()
            return None

        metrics = {}
        self.log('ce', ce, prog_bar=True)
        metrics['ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'ce_q{k + 1}'] = ce_q
            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)

        masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
        metrics['acc'] = []
        for k in range(self.audiolm.code_depth):
            metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1, 2).detach(),
                                                          masked_labels[:, k]).item())
        metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])).item()

        self.train_steps.append({'ce': ce.detach().cpu().item(), 'acc': metrics['acc']})
        self.log('train_acc', metrics['acc'] + 1e-8, prog_bar=True)
        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True)
        self.log_dict(metrics)

        return total_loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)

        model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors)
        logits = model_output.logits
        mask = padding_mask & model_output.mask

        ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
        metrics = {}
        metrics['val_ce'] = ce
        metrics['val_ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'val_ce_q{k + 1}'] = ce_q
            metrics[f'val_ppl_q{k + 1}'] = torch.exp(ce_q)
        masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)

        for k in range(self.audiolm.code_depth):
            self.top1_acc_metric[k].update(logits[:, k].transpose(1, 2).detach(), masked_labels[:, k])
            self.top10_acc_metric[k].update(logits[:, k].transpose(1, 2).detach(), masked_labels[:, k])
        self.val_steps.append(metrics)

        metrics['acc'] = []
        metrics['acc_top10'] = []
        for k in range(self.audiolm.code_depth):
            metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1, 2).detach(), masked_labels[:, k]).item())
            metrics['acc_top10'].append(self.top10_acc_metric[k](logits[:, k].transpose(1, 2).detach(), masked_labels[:, k]).item())
        metrics['acc'] = torch.mean(torch.Tensor(metrics['acc']))
        metrics['acc_top10'] = torch.mean(torch.Tensor(metrics['acc_top10']))

        return metrics['acc']

    def on_validation_epoch_end(self) -> None:
        final_metrics = {}
        for i in self.val_steps:
            for k in i:
                final_metrics[k] = final_metrics.get(k, []) + [i[k]]
        final_metrics = {k: sum(v) / len(v) for k, v in list(final_metrics.items())}
        self.log_dict(final_metrics)

        q_acc = []
        q_acc10 = []
        for i in range(self.audiolm.code_depth):
            q_acc.append(self.top1_acc_metric[i].compute())
            q_acc10.append(self.top10_acc_metric[i].compute())
            self.log(f"val_Top1Acc_{i}", q_acc[-1])
            self.log(f"val_Top10Acc_{i}", q_acc10[-1])
            self.top1_acc_metric[i].reset()
            self.top10_acc_metric[i].reset()

        self.log('val_Top1Acc', sum(q_acc) / self.audiolm.code_depth)
        self.log('val_Top10Acc', sum(q_acc10) / self.audiolm.code_depth)

        return super().on_validation_epoch_end()

    def on_validation_epoch_start(self) -> None:
        self.val_steps = []
        for i in range(self.audiolm.code_depth):
            self.top1_acc_metric[i].reset()
            self.top10_acc_metric[i].reset()

        if len(self.train_steps) > 0:
            train_metrics = {}
            for i in self.train_steps:
                for k in i:
                    train_metrics[k] = train_metrics.get(k, []) + [i[k]]
            train_metrics = {k: sum(v) / len(v) for k, v in list(train_metrics.items())}
            self.log('train_summary_Top1Acc', train_metrics['acc'])
            self.log('train_summary_ce', train_metrics['ce'])
            self.train_steps = []

        return super().on_validation_epoch_start()

    def configure_optimizers(self):
        total_updates = self.cfg.optim.epochs * self.cfg.optim.updates_per_epoch
        optim_dict = {}

        if self.cfg.optim.optimizer == "adamw":
            optim_dict['optimizer'] = torch.optim.AdamW(
                self.audiolm.parameters(),
                lr=self.cfg.optim.lr,
                betas=tuple(self.cfg.optim.adam.betas),
                weight_decay=self.cfg.optim.adam.weight_decay,
                eps=self.cfg.optim.adam.eps,
            )
        else:
            raise NotImplementedError

        if self.cfg.schedule is None:
            pass
        elif self.cfg.schedule.lr_scheduler == "cosine":
            scheduler = CosineLRScheduler(optim_dict['optimizer'],
                                          total_steps=total_updates,
                                          warmup_steps=self.cfg.schedule.cosine.warmup,
                                          lr_min_ratio=self.cfg.schedule.cosine.lr_min_ratio,
                                          cycle_length=self.cfg.schedule.cosine.cycle_length,
                                          )
            optim_dict['lr_scheduler'] = {"scheduler": scheduler, "interval": "step"}
        else:
            raise NotImplementedError

        return optim_dict

    def _compute_cross_entropy(
        self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Compute cross entropy between multi-codebook targets and the model's logits.

        The cross entropy is computed per codebook to provide codebook-level cross entropy.
        Valid timesteps for each codebook are selected with the mask; invalid timesteps
        are excluded from the loss.

        Args:
            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
            targets (torch.Tensor): Target codes, of shape [B, K, T].
            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
        Returns:
            ce (torch.Tensor): Cross entropy averaged over the codebooks.
            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
        """
        B, K, T = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))
            targets_k = targets[:, k, ...].contiguous().view(-1)
            mask_k = mask[:, k, ...].contiguous().view(-1)
            ce_targets = targets_k[mask_k]
            ce_logits = logits_k[mask_k]
            q_ce = F.cross_entropy(ce_logits, ce_targets)
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())

        ce = ce / K
        return ce, ce_per_codebook

class CosineLRScheduler(_LRScheduler):
    """Cosine LR scheduler.

    Args:
        optimizer (Optimizer): Torch optimizer.
        warmup_steps (int): Number of warmup steps.
        total_steps (int): Total number of steps.
        lr_min_ratio (float): Minimum learning rate, as a ratio of the base learning rate.
        cycle_length (float): Cycle length.
    """
    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
        self.warmup_steps = warmup_steps
        assert self.warmup_steps >= 0
        self.total_steps = total_steps
        assert self.total_steps >= 0
        self.lr_min_ratio = lr_min_ratio
        self.cycle_length = cycle_length
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        if step < self.warmup_steps:
            # Linear warmup from 0 to the base learning rate.
            lr_ratio = step / self.warmup_steps
            lr = lr_ratio * lr
        elif step <= self.total_steps:
            # Cosine decay from the base learning rate down to lr_min_ratio * base.
            s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
                (1. + math.cos(math.pi * s / self.cycle_length))
            lr = lr_ratio * lr
        else:
            lr_ratio = self.lr_min_ratio
            lr = lr_ratio * lr
        return lr

    def get_lr(self):
        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
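

# Minimal sanity sketch (hypothetical, not used anywhere above): illustrates the
# warmup-then-cosine schedule produced by CosineLRScheduler on a dummy parameter group.
if __name__ == "__main__":
    _param = nn.Parameter(torch.zeros(1))
    _opt = torch.optim.AdamW([_param], lr=1.0)
    _sched = CosineLRScheduler(_opt, total_steps=10, warmup_steps=2)
    for _ in range(10):
        _opt.step()
        _sched.step()
        # The LR ramps linearly to the base value over the warmup steps, then decays
        # along a cosine toward lr_min_ratio * base (0.0 here, the default).
        print(_sched.get_last_lr())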
|