import os
import sys

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import wandb
from einops import rearrange
from librosa.filters import mel as librosa_mel_fn
from librosa.util import pad_center, tiny
import librosa.util as librosa_util
from scipy.signal import get_window

import tools.torch_tools as torch_tools


class AttrDict(dict):
    """Dict subclass whose entries are also accessible as attributes (e.g. h.num_mels)."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    # "Same" padding for a stride-1 dilated conv: dilation * (kernel_size - 1) / 2.
    return int((kernel_size * dilation - dilation) / 2)
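
# e.g. get_padding(3, dilation=5) == (3*5 - 5) // 2 == 5, which keeps a
# stride-1 conv length-preserving: out_len = in_len + 2*pad - dilation*(k-1).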


LRELU_SLOPE = 0.1


class ResBlock(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__()
        self.h = h
        # Dilated convolutions (dilations 1, 3, 5 by default).
        self.convs1 = nn.ModuleList(
            [
                torch.nn.utils.weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        # Matching non-dilated convolutions.
        self.convs2 = nn.ModuleList(
            [
                torch.nn.utils.weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in range(3)
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.convs2:
            torch.nn.utils.remove_weight_norm(l)
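
# A rough standalone usage sketch (hedged: the `h` config is never read inside
# the block, so `None` works for a quick shape check):
#
#     block = ResBlock(None, channels=64, kernel_size=3, dilation=(1, 3, 5))
#     y = block(torch.randn(1, 64, 100))   # -> torch.Size([1, 64, 100])
#
# Every dilated conv keeps the sequence length fixed, so the residual sums line up.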


class Generator_old(torch.nn.Module):
    def __init__(self, h):
        super(Generator_old, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = torch.nn.utils.weight_norm(
            nn.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                torch.nn.utils.weight_norm(
                    nn.ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = torch.nn.utils.weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = torch.nn.functional.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        for l in self.ups:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        torch.nn.utils.remove_weight_norm(self.conv_pre)
        torch.nn.utils.remove_weight_norm(self.conv_post)
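
# The total upsampling factor is the product of `upsample_rates`; with the
# 48 kHz config below (6 * 5 * 4 * 2 * 2 = 480) one mel frame maps to exactly
# one hop_size (480) of output samples, so a (B, num_mels, T) input yields a
# (B, 1, T * 480) waveform.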


def nonlinearity(x):
    # Swish / SiLU activation.
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # No built-in padding here; asymmetric padding is applied in forward().
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class DownsampleTimeStride4(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # Stride 4 along time, 2 along frequency.
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class UpsampleTimeStride4(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=5, stride=1, padding=2
            )

    def forward(self, x):
        # Mirror of DownsampleTimeStride4: 4x along time, 2x along frequency.
        x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x
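
# The asymmetric (4, 2) stride pairs let the Encoder/Decoder below compress
# the time axis more aggressively than the mel-frequency axis at selected
# levels (see `downsample_time_stride4_levels`).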


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Compute attention weights over the flattened h*w spatial grid.
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # (b, hw, c)
        k = k.reshape(b, c, h * w).contiguous()  # (b, c, hw)
        w_ = torch.bmm(q, k).contiguous()  # (b, hw, hw); w_[b, i, j] = <q[b, i], k[b, :, j]>
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # Attend to the values.
        v = v.reshape(b, c, h * w).contiguous()
        w_ = w_.permute(0, 2, 1).contiguous()  # (b, hw of k, hw of q)
        h_ = torch.bmm(v, w_).contiguous()  # (b, c, hw of q)
        h_ = h_.reshape(b, c, h, w).contiguous()

        h_ = self.proj_out(h_)

        return x + h_
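
# AttnBlock is plain single-head self-attention, softmax(Q K^T / sqrt(C)) V,
# with 1x1 convs as the Q/K/V/output projections and a residual connection.
# A hedged equivalence sketch using the modern PyTorch helper:
#
#     q2 = q.reshape(b, c, h * w).permute(0, 2, 1)   # (b, hw, c)
#     k2 = k.reshape(b, c, h * w).permute(0, 2, 1)
#     v2 = v.reshape(b, c, h * w).permute(0, 2, 1)
#     out = torch.nn.functional.scaled_dot_product_attention(q2, k2, v2)
#     # out.permute(0, 2, 1).reshape(b, c, h, w) should match h_ before proj_out.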


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"

    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "none":
        # nn.Identity ignores its constructor arguments.
        return nn.Identity(in_channels)
    else:
        # "linear" passes the assert but is not implemented here.
        raise ValueError(attn_type)


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        downsample_time_stride4_levels=[],
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The levels at which to perform the stride-4 time downsampling must "
                "be smaller than the total number of resolutions (%s)"
                % str(self.num_resolutions)
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                if i_level in self.downsample_time_stride4_levels:
                    down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # timestep embedding is unused in the autoencoder
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
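
# Shape sketch (hedged, using the 48 kHz ddconfig assembled further below:
# ch_mult=(1, 2, 4, 8), i.e. three downsampling transitions, no stride-4 levels):
#
#     enc = Encoder(ch=128, out_ch=1, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
#                   attn_resolutions=[], in_channels=1, resolution=256,
#                   z_channels=16, double_z=True)
#     x = torch.randn(1, 1, 1024, 256)   # (B, 1, T, mel_bins)
#     enc(x).shape                       # torch.Size([1, 32, 128, 32]): 2*z_channels, T/8, F/8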


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        downsample_time_stride4_levels=[],
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The levels at which to perform the stride-4 time downsampling must "
                "be smaller than the total number of resolutions (%s)"
                % str(self.num_resolutions)
            )

        # channel count and spatial size at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                if i_level - 1 in self.downsample_time_stride4_levels:
                    up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to keep index order consistent

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        self.last_z_shape = z.shape

        # timestep embedding is unused in the autoencoder
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean
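
# kl() above is the closed-form KL divergence between diagonal Gaussians,
# KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2),
# except that this implementation averages (torch.mean) over the latent
# dimensions instead of summing, which only rescales the KL term.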


def get_vocoder_config_48k():
    return {
        "resblock": "1",
        "num_gpus": 8,
        "batch_size": 128,
        "learning_rate": 0.0001,
        "adam_b1": 0.8,
        "adam_b2": 0.99,
        "lr_decay": 0.999,
        "seed": 1234,
        "upsample_rates": [6, 5, 4, 2, 2],
        "upsample_kernel_sizes": [12, 10, 8, 4, 4],
        "upsample_initial_channel": 1536,
        "resblock_kernel_sizes": [3, 7, 11, 15],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "segment_size": 15360,
        "num_mels": 256,
        "n_fft": 2048,
        "hop_size": 480,
        "win_size": 2048,
        "sampling_rate": 48000,
        "fmin": 20,
        "fmax": 24000,
        "fmax_for_loss": None,
        "num_workers": 8,
        "dist_config": {
            "dist_backend": "nccl",
            "dist_url": "tcp://localhost:18273",
            "world_size": 1,
        },
    }
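
# With hop_size=480 at sampling_rate=48000 the mel spectrogram runs at exactly
# 100 frames per second; num_mels=256 is what get_vocoder() below keys on to
# select this config.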


def get_vocoder(config, device, mel_bins):
    name = "HiFi-GAN"
    speaker = ""
    if name == "MelGAN":
        if speaker == "LJSpeech":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
            )
        elif speaker == "universal":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
            )
        vocoder.mel2wav.eval()
        vocoder.mel2wav.to(device)
    elif name == "HiFi-GAN":
        if mel_bins == 256:
            config = get_vocoder_config_48k()
            config = AttrDict(config)
            vocoder = Generator_old(config)
            # No vocoder checkpoint is loaded here, so the weights stay random
            # until the enclosing AutoencoderKL loads a state_dict.
            vocoder.eval()
            vocoder.remove_weight_norm()
            vocoder = vocoder.to(device)
        else:
            raise ValueError(mel_bins)
    return vocoder


def vocoder_infer(mels, vocoder, lengths=None):
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)

    wavs = wavs.cpu().numpy()

    if lengths is not None:
        wavs = wavs[:, :lengths]

    return wavs


@torch.no_grad()
def vocoder_chunk_infer(mels, vocoder, lengths=None):
    # Vocode long mels chunk by chunk and crossfade the overlapping regions.
    chunk_size = 256 * 4  # mel frames per chunk
    shift_size = 256 * 1  # mel frames to advance per step
    ov_size = chunk_size - shift_size  # overlapping mel frames

    for cinx in range(0, mels.shape[2], shift_size):
        if cinx == 0:
            wavs = vocoder(mels[:, :, cinx : cinx + chunk_size]).squeeze(1).float()
            num_samples = int(wavs.shape[-1] / chunk_size) * chunk_size
            wavs = wavs[:, 0:num_samples]
            ov_sample = int(float(wavs.shape[-1]) * ov_size / chunk_size)
            # Linear fade-in / fade-out windows for the crossfade.
            ov_win = torch.linspace(0, 1, ov_sample, device=mels.device).unsqueeze(0)
            ov_win = torch.cat([ov_win, 1 - ov_win], -1)
            if cinx + chunk_size >= mels.shape[2]:
                break
        else:
            cur_wav = (
                vocoder(mels[:, :, cinx : cinx + chunk_size])
                .squeeze(1)[:, 0:num_samples]
                .float()
            )
            # Crossfade: fade out the existing tail, fade in the new head.
            wavs[:, -ov_sample:] = (
                wavs[:, -ov_sample:] * ov_win[:, -ov_sample:]
                + cur_wav[:, 0:ov_sample] * ov_win[:, 0:ov_sample]
            )

            wavs = torch.cat([wavs, cur_wav[:, ov_sample:]], -1)
            if cinx + chunk_size >= mels.shape[2]:
                break

    wavs = wavs.cpu().numpy()

    if lengths is not None:
        wavs = wavs[:, :lengths]

    return wavs
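
# Because shift_size is 1/4 of chunk_size, each step appends only the last
# shift_size frames' worth of new audio and blends the preceding ov_size
# region with linear ramps that sum to 1, so chunk boundaries stay click-free.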


def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    if vocoder is not None:
        wav_reconstruction = vocoder_infer(
            mel_input.permute(0, 2, 1),
            vocoder,
        )
        wav_prediction = vocoder_infer(
            mel_prediction.permute(0, 2, 1),
            vocoder,
        )
    else:
        wav_reconstruction = wav_prediction = None

    return wav_reconstruction, wav_prediction


class AutoencoderKL(nn.Module):
    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        batchsize=None,
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        sampling_rate=16000,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=[],
        image_key="fbank",
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
        scale_factor=1,
    ):
        super().__init__()
        # Leftover from the original PyTorch Lightning implementation; it has
        # no effect on a plain nn.Module.
        self.automatic_optimization = False
        assert (
            "mel_bins" in ddconfig.keys()
        ), "mel_bins is not specified in the Autoencoder config"
        num_mel = ddconfig["mel_bins"]
        self.image_key = image_key
        self.sampling_rate = sampling_rate
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)

        self.loss = None
        self.subband = int(subband)

        if self.subband > 1:
            print("Using subband decomposition %s" % self.subband)

        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

        if self.image_key == "fbank":
            # Note: the vocoder is created directly on CUDA.
            self.vocoder = get_vocoder(None, torch.device("cuda"), num_mel)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.learning_rate = float(base_learning_rate)

        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None

        self.feature_cache = None
        self.flag_first_run = True
        self.train_step = 0

        self.logger_save_dir = None
        self.logger_exp_name = None
        self.scale_factor = scale_factor

        print("Num parameters:")
        print("Encoder : ", sum(p.numel() for p in self.encoder.parameters()))
        print("Decoder : ", sum(p.numel() for p in self.decoder.parameters()))
        print("Vocoder : ", sum(p.numel() for p in self.vocoder.parameters()))

    def get_log_dir(self):
        if self.logger_save_dir is None and self.logger_exp_name is None:
            return os.path.join(self.logger.save_dir, self.logger._project)
        else:
            return os.path.join(self.logger_save_dir, self.logger_exp_name)

    def set_log_dir(self, save_dir, exp_name):
        self.logger_save_dir = save_dir
        self.logger_exp_name = exp_name

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def decode_to_waveform(self, dec):
        if self.image_key == "fbank":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = vocoder_chunk_infer(dec, self.vocoder)
        elif self.image_key == "stft":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = self.wave_decoder(dec)
        return wav_reconstruction

    def mel_spectrogram_to_waveform(
        self, mel, savepath=".", bs=None, name="outwav", save=True
    ):
        # Expects mel of shape (B, 1, T, n_mels) or (B, T, n_mels).
        if len(mel.size()) == 4:
            mel = mel.squeeze(1)
        mel = mel.permute(0, 2, 1)
        waveform = self.vocoder(mel)
        waveform = waveform.cpu().detach().numpy()
        return waveform

    @torch.no_grad()
    def encode_first_stage(self, x):
        return self.encode(x)

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, "b h w c -> b c h w").contiguous()

        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def decode_first_stage_withgrad(self, z):
        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def get_first_stage_encoding(self, encoder_posterior, use_mode=False):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution) and not use_mode:
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, DiagonalGaussianDistribution) and use_mode:
            z = encoder_posterior.mode()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z
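
    # Latent scaling convention: get_first_stage_encoding multiplies by
    # scale_factor and decode_first_stage divides it back out, so a round trip
    # is (up to the posterior sampling):
    #
    #     posterior = model.encode_first_stage(mel)       # (B, 1, T, F) in
    #     z = model.get_first_stage_encoding(posterior)   # scaled latent
    #     mel_hat = model.decode_first_stage(z)           # (B, 1, T, F) out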

    def visualize_latent(self, input):
        import matplotlib.pyplot as plt

        np.save("input.npy", input.cpu().detach().numpy())

        # Zero out (and floor) the first 32 mel bins, then inspect the latent.
        time_input = input.clone()
        time_input[:, :, :, :32] *= 0
        time_input[:, :, :, :32] -= 11.59

        np.save("time_input.npy", time_input.cpu().detach().numpy())

        posterior = self.encode(time_input)
        latent = posterior.sample()
        np.save("time_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("freq_%s.png" % i)
            plt.close()

        # Zero out (and floor) the first 512 time frames, then inspect the latent.
        freq_input = input.clone()
        freq_input[:, :, :512, :] *= 0
        freq_input[:, :, :512, :] -= 11.59

        np.save("freq_input.npy", freq_input.cpu().detach().numpy())

        posterior = self.encode(freq_input)
        latent = posterior.sample()
        np.save("freq_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("time_%s.png" % i)
            plt.close()

    def get_input(self, batch):
        fname, text, label_indices, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["label_vector"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )

        ret = {}

        ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
            fbank.unsqueeze(1),
            stft.unsqueeze(1),
            fname,
            waveform.unsqueeze(1),
        )

        return ret

    def save_wave(self, batch_wav, fname, save_dir):
        os.makedirs(save_dir, exist_ok=True)

        for wav, name in zip(batch_wav, fname):
            name = os.path.basename(name)

            sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
        # NOTE: this method assumes a Lightning-style setup; it relies on
        # self.device, self.logger and a forward() returning (xrec, posterior),
        # none of which are defined on this plain nn.Module.
        log = dict()
        x = batch.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            log["samples"] = self.decode(posterior.sample())
            log["reconstructions"] = xrec

        log["inputs"] = x
        wavs = self._log_img(log, train=train, index=0, waveform=waveform)
        return wavs

    def _log_img(self, log, train=True, index=0, waveform=None):
        images_input = self.tensor2numpy(log["inputs"][index, 0]).T
        images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
        images_samples = self.tensor2numpy(log["samples"][index, 0]).T

        name = "train" if train else "val"

        if self.logger is not None:
            self.logger.log_image(
                "img_%s" % name,
                [images_input, images_reconstruct, images_samples],
                caption=["input", "reconstruct", "samples"],
            )

        inputs, reconstructions, samples = (
            log["inputs"],
            log["reconstructions"],
            log["samples"],
        )

        if self.image_key == "fbank":
            wav_original, wav_prediction = synth_one_sample(
                inputs[index],
                reconstructions[index],
                labels="validation",
                vocoder=self.vocoder,
            )
            wav_original, wav_samples = synth_one_sample(
                inputs[index], samples[index], labels="validation", vocoder=self.vocoder
            )
            wav_original, wav_samples, wav_prediction = (
                wav_original[0],
                wav_samples[0],
                wav_prediction[0],
            )
        elif self.image_key == "stft":
            wav_prediction = (
                self.decode_to_waveform(reconstructions)[index, 0]
                .cpu()
                .detach()
                .numpy()
            )
            wav_samples = (
                self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
            )
            wav_original = waveform[index, 0].cpu().detach().numpy()

        if self.logger is not None:
            self.logger.experiment.log(
                {
                    "original_%s"
                    % name: wandb.Audio(
                        wav_original, caption="original", sample_rate=self.sampling_rate
                    ),
                    "reconstruct_%s"
                    % name: wandb.Audio(
                        wav_prediction,
                        caption="reconstruct",
                        sample_rate=self.sampling_rate,
                    ),
                    "samples_%s"
                    % name: wandb.Audio(
                        wav_samples, caption="samples", sample_rate=self.sampling_rate
                    ),
                }
            )

        return wav_original, wav_prediction, wav_samples

    def tensor2numpy(self, tensor):
        return tensor.cpu().detach().numpy()

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = torch.nn.functional.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x


def window_sumsquare(
    window,
    n_frames,
    hop_length,
    win_length,
    n_fft,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x


def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return normalize_fun(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

    def __init__(self, filter_length, hop_length, win_length, window="hann"):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        if window is not None:
            assert filter_length >= win_length

            # get window and zero-center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())

    def transform(self, input_data):
        device = self.forward_basis.device
        input_data = input_data.to(device)

        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input by filter_length // 2
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = torch.nn.functional.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode="reflect",
        )
        input_data = input_data.squeeze(1)

        forward_transform = torch.nn.functional.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0,
        )

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.atan2(imag_part.data, real_part.data)

        return magnitude, phase

    def inverse(self, magnitude, phase):
        device = self.forward_basis.device
        magnitude, phase = magnitude.to(device), phase.to(device)

        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )

        inverse_transform = torch.nn.functional.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0,
        )

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )

            # remove modulation effects where the window envelope is non-zero
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            )
            window_sum = torch.from_numpy(window_sum).to(device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2)]

        return inverse_transform

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
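
# STFT implements the forward/inverse short-time Fourier transform as strided
# 1-D convolutions against fixed Fourier bases. A hedged round-trip sketch:
#
#     stft = STFT(filter_length=2048, hop_length=480, win_length=2048)
#     audio = torch.randn(1, 48000)
#     mag, phase = stft.transform(audio)   # (1, 1025, n_frames) each
#     rec = stft.inverse(mag, phase)       # approximately the input audio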


class TacotronSTFT(torch.nn.Module):
    def __init__(
        self,
        filter_length,
        hop_length,
        win_length,
        n_mel_channels,
        sampling_rate,
        mel_fmin,
        mel_fmax,
    ):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)

    def spectral_normalize(self, magnitudes, normalize_fun):
        output = dynamic_range_compression(magnitudes, normalize_fun)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y, normalize_fun=torch.log):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: torch.FloatTensor with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert torch.min(y.data) >= -1, torch.min(y.data)
        assert torch.max(y.data) <= 1, torch.max(y.data)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output, normalize_fun)
        energy = torch.norm(magnitudes, dim=1)

        log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)

        return mel_output, log_magnitudes, energy
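
# End-to-end feature sketch (hedged, using the 48 kHz settings from
# build_pretrained_models below):
#
#     fn_stft = TacotronSTFT(2048, 480, 2048, 256, 48000, 20, 24000)
#     y = torch.clamp(torch.randn(1, 48000), -1, 1)   # 1 s of audio in [-1, 1]
#     mel, log_mag, energy = fn_stft.mel_spectrogram(y)
#     # mel: (1, 256, ~101) log-mel frames at ~100 frames per second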


def build_pretrained_models(ckpt):
    checkpoint = torch.load(ckpt, map_location="cpu")
    scale_factor = checkpoint["state_dict"]["scale_factor"].item()
    print("scale_factor: ", scale_factor)

    # Strip the "first_stage_model." prefix (18 characters) to get the VAE weights.
    vae_state_dict = {
        k[18:]: v
        for k, v in checkpoint["state_dict"].items()
        if "first_stage_model." in k
    }

    config = {
        "preprocessing": {
            "audio": {
                "sampling_rate": 48000,
                "max_wav_value": 32768,
                "duration": 10.24,
            },
            "stft": {"filter_length": 2048, "hop_length": 480, "win_length": 2048},
            "mel": {"n_mel_channels": 256, "mel_fmin": 20, "mel_fmax": 24000},
        },
        "model": {
            "params": {
                "first_stage_config": {
                    "params": {
                        "sampling_rate": 48000,
                        "batchsize": 4,
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 16,
                        "time_shuffle": 1,
                        "lossconfig": {
                            "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
                            "params": {
                                "disc_start": 50001,
                                "kl_weight": 1000,
                                "disc_weight": 0.5,
                                "disc_in_channels": 1,
                            },
                        },
                        "ddconfig": {
                            "double_z": True,
                            "mel_bins": 256,
                            "z_channels": 16,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4, 8],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0,
                        },
                    }
                },
            },
        },
    }
    vae_config = config["model"]["params"]["first_stage_config"]["params"]
    vae_config["scale_factor"] = scale_factor

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(vae_state_dict)

    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )

    vae.eval()
    fn_STFT.eval()
    return vae, fn_STFT


if __name__ == "__main__":
    # build_pretrained_models requires a checkpoint path; take it from the
    # command line rather than hard-coding one.
    vae, stft = build_pretrained_models(sys.argv[1])
    vae, stft = vae.cuda(), stft.cuda()

    json_file = "outputs/wav.scp"
    out_path = "outputs/Music_inverse"

    waveform_in = torch.randn(2, int(48000 * 10.24))
    mel, _, waveform = torch_tools.wav_to_fbank2(
        waveform_in, target_length=-1, fn_STFT=stft
    )
    mel = mel.unsqueeze(1).cuda()
    print(mel.shape)

    true_latent = vae.get_first_stage_encoding(vae.encode_first_stage(mel))
    print(true_latent.shape)
    true_latent = true_latent.reshape(
        true_latent.shape[0] // 2, -1, true_latent.shape[2], true_latent.shape[3]
    ).detach()

    true_latent = true_latent.reshape(
        true_latent.shape[0] * 2, -1, true_latent.shape[2], true_latent.shape[3]
    )
    print("111", true_latent.size())

    mel = vae.decode_first_stage(true_latent)
    print("222", mel.size())
    audio = vae.decode_to_waveform(mel)
    print("333", audio.shape)