# LIA-X/networks/generator.py
import torch
from torch import nn
from networks.encoder import Encoder
from networks.decoder import Decoder
import numpy as np
from tqdm import tqdm
from einops import rearrange, repeat
class Generator(nn.Module):
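    """LIA-X generator: an Encoder/Decoder pair that encodes a source image into
    a latent code plus a motion code (alpha), edits or transfers selected motion
    dimensions, and decodes the result back into images or video."""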
    def __init__(self, size, style_dim=512, motion_dim=40, scale=1):
        super().__init__()
        style_dim = style_dim * scale
        self._device = None  # lazily cached by the `device` property

        # encoder / decoder
        self.enc = Encoder(style_dim, motion_dim, scale)
        self.dec = Decoder(style_dim, motion_dim, scale)
@property
def device(self):
if self._device is None:
self._device = next(self.parameters()).device
return self._device
def get_alpha(self, x):
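        """Extract the motion code alpha (B, motion_dim) from an image batch."""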
return self.enc.enc_motion(x)
def edit_img(self, img_source, d_l, v_l):
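        # Encode the source once, offset the selected motion dims d_l by the
        # values v_l, then decode a single edited image.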
z_s2r, feat_rgb = self.enc.enc_2r(img_source)
alpha_r2s = self.enc.enc_r2t(z_s2r)
        # build the edit vector on the same device/dtype as alpha_r2s (avoids the
        # hard-coded 'cuda' dependency)
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
img_recon = self.dec(z_s2r, [alpha_r2s], feat_rgb)
return img_recon
def animate(self, img_source, vid_target, d_l, v_l):
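        # Reenact img_source with the motion of vid_target: motion is transferred
        # relative to the driving video's first frame (alpha_start), with the
        # d_l/v_l edit applied on top of the source's motion code.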
alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])
vid_target_recon = []
z_s2r, feat_rgb = self.enc.enc_2r(img_source)
alpha_r2s = self.enc.enc_r2t(z_s2r)
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
for i in tqdm(range(vid_target.size(1))):
img_target = vid_target[:, i, :, :, :]
alpha = self.enc.enc_transfer_vid(alpha_r2s, img_target, alpha_start)
img_recon = self.dec(z_s2r, alpha, feat_rgb)
vid_target_recon.append(img_recon.unsqueeze(2))
vid_target_recon = torch.cat(vid_target_recon, dim=2) # BCTHW
return vid_target_recon
def animate_batch(self, img_source, vid_target, d_l, v_l, chunk_size):
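        # Chunked variant of animate: decodes chunk_size driving frames per
        # forward pass instead of one at a time.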
        b, t, c, h, w = vid_target.size()
        alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])  # (1, motion_dim)
vid_target_recon = []
z_s2r, feat_rgb = self.enc.enc_2r(img_source)
alpha_r2s = self.enc.enc_r2t(z_s2r)
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
        bs = chunk_size
        chunks = t // bs  # number of full chunks; a shorter remainder may follow

        # tile the per-source tensors to chunk_size so a whole chunk of frames can
        # be decoded as one batch
        alpha_start_r = repeat(alpha_start, 'b c -> (repeat b) c', repeat=bs)
        alpha_r2s_r = repeat(alpha_r2s, 'b c -> (repeat b) c', repeat=bs)
        feat_rgb_r = [repeat(feat, 'b c h w -> (repeat b) c h w', repeat=bs) for feat in feat_rgb]
        z_s2r_r = repeat(z_s2r, 'b c -> (repeat b) c', repeat=bs)

        for i in range(chunks + 1):
            if i == chunks:
                # remainder chunk; skip when t is an exact multiple of chunk_size
                img_target = vid_target[:, i * bs:, :, :, :]
                bs = t - i * bs
                if bs == 0:
                    break
                alpha_start_r = alpha_start_r[:bs]
                alpha_r2s_r = alpha_r2s_r[:bs]
                feat_rgb_r = [feat[:bs] for feat in feat_rgb_r]
                z_s2r_r = z_s2r_r[:bs]
            else:
                img_target = vid_target[:, i * bs:(i + 1) * bs, :, :, :]
            # squeeze(0) assumes a single driving video (b == 1)
            alpha = self.enc.enc_transfer_vid(alpha_r2s_r, img_target.squeeze(0), alpha_start_r)
            img_recon = self.dec(z_s2r_r, alpha, feat_rgb_r)  # (bs, 3, h, w)
            vid_target_recon.append(img_recon)

        vid_target_recon = torch.cat(vid_target_recon, dim=0).unsqueeze(0)  # (1, T, C, H, W)
vid_target_recon = rearrange(vid_target_recon, 'b t c h w -> b c t h w')
return vid_target_recon # BCTHW
def edit_vid(self, vid_target, d_l, v_l):
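        # Apply the d_l/v_l edit to a whole clip, using its first frame as the
        # source image.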
img_source = vid_target[:, 0, :, :, :]
        alpha_start = self.get_alpha(img_source)
vid_target_recon = []
z_s2r, feat_rgb = self.enc.enc_2r(img_source)
alpha_r2s = self.enc.enc_r2t(z_s2r)
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
for i in tqdm(range(vid_target.size(1))):
img_target = vid_target[:, i, :, :, :]
alpha = self.enc.enc_transfer_vid(alpha_r2s, img_target, alpha_start)
img_recon = self.dec(z_s2r, alpha, feat_rgb)
vid_target_recon.append(img_recon.unsqueeze(2))
vid_target_recon = torch.cat(vid_target_recon, dim=2) # BCTHW
return vid_target_recon
def edit_vid_batch(self, vid_target, d_l, v_l, chunk_size):
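        # Chunked variant of edit_vid (same chunking scheme as animate_batch).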
        b, t, c, h, w = vid_target.size()
        img_source = vid_target[:, 0, :, :, :]
        alpha_start = self.get_alpha(img_source)  # (1, motion_dim)
vid_target_recon = []
z_s2r, feat_rgb = self.enc.enc_2r(img_source)
alpha_r2s = self.enc.enc_r2t(z_s2r)
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
        bs = chunk_size
        chunks = t // bs  # number of full chunks; a shorter remainder may follow

        # tile the per-source tensors to chunk_size so a whole chunk of frames can
        # be decoded as one batch
        alpha_start_r = repeat(alpha_start, 'b c -> (repeat b) c', repeat=bs)
        alpha_r2s_r = repeat(alpha_r2s, 'b c -> (repeat b) c', repeat=bs)
        feat_rgb_r = [repeat(feat, 'b c h w -> (repeat b) c h w', repeat=bs) for feat in feat_rgb]
        z_s2r_r = repeat(z_s2r, 'b c -> (repeat b) c', repeat=bs)

        for i in range(chunks + 1):
            if i == chunks:
                # remainder chunk; skip when t is an exact multiple of chunk_size
                img_target = vid_target[:, i * bs:, :, :, :]
                bs = t - i * bs
                if bs == 0:
                    break
                alpha_start_r = alpha_start_r[:bs]
                alpha_r2s_r = alpha_r2s_r[:bs]
                feat_rgb_r = [feat[:bs] for feat in feat_rgb_r]
                z_s2r_r = z_s2r_r[:bs]
            else:
                img_target = vid_target[:, i * bs:(i + 1) * bs, :, :, :]
            # squeeze(0) assumes a single driving video (b == 1)
            alpha = self.enc.enc_transfer_vid(alpha_r2s_r, img_target.squeeze(0), alpha_start_r)
            img_recon = self.dec(z_s2r_r, alpha, feat_rgb_r)  # (bs, 3, h, w)
            vid_target_recon.append(img_recon)

        vid_target_recon = torch.cat(vid_target_recon, dim=0).unsqueeze(0)  # (1, T, C, H, W)
vid_target_recon = rearrange(vid_target_recon, 'b t c h w -> b c t h w')
return vid_target_recon # BCTHW
def interpolate_img(self, img_source, d_l, v_l):
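        # Render a clip that ramps the edit from neutral to v_l and back over
        # 16-frame segments, then swings once more in a direction that depends
        # on which motion dims are edited.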
        vid_target_recon = []
        step = 16
        v_start = np.array([0.] * len(v_l))
        v_end = np.array(v_l)
        stride = (v_end - v_start) / step

        z_s2r, feat_rgb = self.enc.enc_2r(img_source)

        def run_ramp(v_tmp, direction):
            # decode `step` frames while moving the edit vector by +/- stride per frame
            for _ in range(step):
                v_tmp = v_tmp + direction * stride
                alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
                img_recon = self.dec(z_s2r, alpha, feat_rgb)
                vid_target_recon.append(img_recon.unsqueeze(2))
            return v_tmp

        v_tmp = run_ramp(v_start, +1)  # neutral -> full edit
        v_tmp = run_ramp(v_tmp, -1)    # full edit -> neutral
        if any(v_l[i] != 0 for i in (6, 7, 8, 9)):
            # for edits touching motion dims 6-9, repeat the positive swing
            v_tmp = run_ramp(v_tmp, +1)
            run_ramp(v_tmp, -1)
        else:
            # otherwise swing through the negative direction and back
            v_tmp = run_ramp(v_tmp, -1)
            run_ramp(v_tmp, +1)
vid_target_recon = torch.cat(vid_target_recon, dim=2) # BCTHW
return vid_target_recon
def enc_img(self, img_source, d_l, v_l):
"""Core edit_img logic without timing - can be compiled"""
z_s2r, feat_rgb = self.enc.enc_2r(img_source)
alpha_r2s = self.enc.enc_r2t(z_s2r)
# Create tensor directly on the same device as alpha_r2s
v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
return z_s2r, alpha_r2s, feat_rgb
def dec_img(self, z_s2r, alpha_r2s, feat_rgb):
return self.dec(z_s2r, [alpha_r2s], feat_rgb)
def dec_vid(self, z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch):
        # z_s2r:            (B, C) source latent
        # alpha_r2s:        (B, C) edited motion code
        # feat_rgb:         list of (B, C, H, W) feature maps from the encoder
        # img_start:        (B, 3, H, W) first driving frame (motion anchor)
        # img_target_batch: (bs, 3, H, W) chunk of driving frames
bs = img_target_batch.size(0)
alpha_start = self.get_alpha(img_start)
alpha_start_r = repeat(alpha_start, 'b c -> (repeat b) c', repeat=bs)
alpha_r2s_r = repeat(alpha_r2s, 'b c -> (repeat b) c', repeat=bs)
feat_rgb_r = [repeat(feat, 'b c h w -> (repeat b) c h w', repeat=bs) for feat in feat_rgb]
z_s2r_r = repeat(z_s2r, 'b c -> (repeat b) c', repeat=bs)
alpha = self.enc.enc_transfer_vid(alpha_r2s_r, img_target_batch, alpha_start_r)
img_batch_recon = self.dec(z_s2r_r, alpha, feat_rgb_r) # bs x 3 x h x w
return img_batch_recon
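

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original file). It assumes the
    # repo's Encoder/Decoder accept the constructor arguments used above and a
    # 256x256 input; real use would load pretrained weights first.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gen = Generator(size=256).to(device).eval()
    img_source = torch.randn(1, 3, 256, 256, device=device)  # dummy source image
    with torch.no_grad():
        # illustrative edit: shift motion dims 0 and 1 by +0.5 / -0.5
        img_edit = gen.edit_img(img_source, d_l=[0, 1], v_l=[0.5, -0.5])
    print(img_edit.shape)  # edited reconstruction of the source image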