import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from einops import rearrange, repeat

from networks.encoder import Encoder
from networks.decoder import Decoder


class Generator(nn.Module):
    def __init__(self, size, style_dim=512, motion_dim=40, scale=1):
        super().__init__()

        style_dim = style_dim * scale

        # encoder / decoder pair
        self.enc = Encoder(style_dim, motion_dim, scale)
        self.dec = Decoder(style_dim, motion_dim, scale)

        # cached device, resolved lazily from the parameters
        self._device = None

    @property
    def device(self):
        if self._device is None:
            self._device = next(self.parameters()).device
        return self._device

    def get_alpha(self, x):
        return self.enc.enc_motion(x)

    def edit_img(self, img_source, d_l, v_l):
        z_s2r, feat_rgb = self.enc.enc_2r(img_source)
        alpha_r2s = self.enc.enc_r2t(z_s2r)
        # add the requested offsets v_l to the selected motion dimensions d_l;
        # build the tensor on the same device/dtype as alpha_r2s instead of
        # hard-coding 'cuda'
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
        img_recon = self.dec(z_s2r, [alpha_r2s], feat_rgb)
        return img_recon

    def animate(self, img_source, vid_target, d_l, v_l):
        alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])

        vid_target_recon = []
        z_s2r, feat_rgb = self.enc.enc_2r(img_source)
        alpha_r2s = self.enc.enc_r2t(z_s2r)
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor

        # transfer the motion of each target frame onto the edited source
        for i in tqdm(range(vid_target.size(1))):
            img_target = vid_target[:, i, :, :, :]
            alpha = self.enc.enc_transfer_vid(alpha_r2s, img_target, alpha_start)
            img_recon = self.dec(z_s2r, alpha, feat_rgb)
            vid_target_recon.append(img_recon.unsqueeze(2))

        vid_target_recon = torch.cat(vid_target_recon, dim=2)  # B x C x T x H x W
        return vid_target_recon

    def animate_batch(self, img_source, vid_target, d_l, v_l, chunk_size):
        b, t, c, h, w = vid_target.size()
        alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])  # 1 x 40

        vid_target_recon = []
        z_s2r, feat_rgb = self.enc.enc_2r(img_source)
        alpha_r2s = self.enc.enc_r2t(z_s2r)
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor

        bs = chunk_size
        chunks = t // bs
        # tile the per-source tensors so a whole chunk of frames is decoded in one pass
        alpha_start_r = repeat(alpha_start, 'b c -> (repeat b) c', repeat=bs)
        alpha_r2s_r = repeat(alpha_r2s, 'b c -> (repeat b) c', repeat=bs)
        feat_rgb_r = [repeat(feat, 'b c h w -> (repeat b) c h w', repeat=bs) for feat in feat_rgb]
        z_s2r_r = repeat(z_s2r, 'b c -> (repeat b) c', repeat=bs)

        for i in range(chunks + 1):
            if i == chunks:
                # trailing partial chunk; skip it when t divides evenly by chunk_size
                img_target = vid_target[:, i * bs:, :, :, :]
                bs = t - i * bs
                if bs == 0:
                    break
                alpha_start_r = alpha_start_r[:bs]
                alpha_r2s_r = alpha_r2s_r[:bs]
                feat_rgb_r = [feat[:bs] for feat in feat_rgb_r]
                z_s2r_r = z_s2r_r[:bs]
            else:
                img_target = vid_target[:, i * bs:(i + 1) * bs, :, :, :]
            # squeeze(0) folds the time chunk into the batch dim (assumes b == 1)
            alpha = self.enc.enc_transfer_vid(alpha_r2s_r, img_target.squeeze(0), alpha_start_r)
            img_recon = self.dec(z_s2r_r, alpha, feat_rgb_r)  # bs x 3 x h x w
            vid_target_recon.append(img_recon)

        vid_target_recon = torch.cat(vid_target_recon, dim=0).unsqueeze(0)  # 1 x T x C x H x W
        vid_target_recon = rearrange(vid_target_recon, 'b t c h w -> b c t h w')
        return vid_target_recon  # B x C x T x H x W

    def edit_vid(self, vid_target, d_l, v_l):
        # the first frame doubles as the source image
        img_source = vid_target[:, 0, :, :, :]
        alpha_start = self.get_alpha(img_source)

        vid_target_recon = []
        z_s2r, feat_rgb = self.enc.enc_2r(img_source)
        alpha_r2s = self.enc.enc_r2t(z_s2r)
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor

        for i in tqdm(range(vid_target.size(1))):
            img_target = vid_target[:, i, :, :, :]
            alpha = self.enc.enc_transfer_vid(alpha_r2s, img_target, alpha_start)
            img_recon = self.dec(z_s2r, alpha, feat_rgb)
            vid_target_recon.append(img_recon.unsqueeze(2))

        vid_target_recon = torch.cat(vid_target_recon, dim=2)  # B x C x T x H x W
        return vid_target_recon
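    # Usage sketch for the editing/animation paths above (assumptions, not
    # enforced here: a trained checkpoint at 'generator.pt' is a hypothetical
    # path, frames are normalized the way Encoder expects, and d_l / v_l are
    # equal-length lists of motion-dimension indices and offsets):
    #
    #   gen = Generator(256).to('cuda').eval()
    #   gen.load_state_dict(torch.load('generator.pt'))
    #   with torch.no_grad():
    #       out = gen.edit_vid(vid, d_l=[0, 3], v_l=[1.5, -0.5])  # B x C x T x H x W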
    def edit_vid_batch(self, vid_target, d_l, v_l, chunk_size):
        b, t, c, h, w = vid_target.size()
        # the first frame doubles as the source image
        img_source = vid_target[:, 0, :, :, :]
        alpha_start = self.get_alpha(img_source)  # 1 x 40

        vid_target_recon = []
        z_s2r, feat_rgb = self.enc.enc_2r(img_source)
        alpha_r2s = self.enc.enc_r2t(z_s2r)
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor

        bs = chunk_size
        chunks = t // bs
        # tile the per-source tensors so a whole chunk of frames is decoded in one pass
        alpha_start_r = repeat(alpha_start, 'b c -> (repeat b) c', repeat=bs)
        alpha_r2s_r = repeat(alpha_r2s, 'b c -> (repeat b) c', repeat=bs)
        feat_rgb_r = [repeat(feat, 'b c h w -> (repeat b) c h w', repeat=bs) for feat in feat_rgb]
        z_s2r_r = repeat(z_s2r, 'b c -> (repeat b) c', repeat=bs)

        for i in range(chunks + 1):
            if i == chunks:
                # trailing partial chunk; skip it when t divides evenly by chunk_size
                img_target = vid_target[:, i * bs:, :, :, :]
                bs = t - i * bs
                if bs == 0:
                    break
                alpha_start_r = alpha_start_r[:bs]
                alpha_r2s_r = alpha_r2s_r[:bs]
                feat_rgb_r = [feat[:bs] for feat in feat_rgb_r]
                z_s2r_r = z_s2r_r[:bs]
            else:
                img_target = vid_target[:, i * bs:(i + 1) * bs, :, :, :]
            # squeeze(0) folds the time chunk into the batch dim (assumes b == 1)
            alpha = self.enc.enc_transfer_vid(alpha_r2s_r, img_target.squeeze(0), alpha_start_r)
            img_recon = self.dec(z_s2r_r, alpha, feat_rgb_r)  # bs x 3 x h x w
            vid_target_recon.append(img_recon)

        vid_target_recon = torch.cat(vid_target_recon, dim=0).unsqueeze(0)  # 1 x T x C x H x W
        vid_target_recon = rearrange(vid_target_recon, 'b t c h w -> b c t h w')
        return vid_target_recon  # B x C x T x H x W

    def interpolate_img(self, img_source, d_l, v_l):
        vid_target_recon = []

        step = 16
        v_start = np.array([0.] * len(v_l))
        v_end = np.array(v_l)
        stride = (v_end - v_start) / step

        z_s2r, feat_rgb = self.enc.enc_2r(img_source)

        v_tmp = v_start

        def ramp(sign):
            # walk v_tmp `step` times in the given direction, decoding a frame per step
            nonlocal v_tmp
            for _ in range(step):
                v_tmp = v_tmp + sign * stride
                alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
                img_recon = self.dec(z_s2r, alpha, feat_rgb)
                vid_target_recon.append(img_recon.unsqueeze(2))

        # ramp the edit up to v_l and back down to zero
        ramp(+1)
        ramp(-1)
        # dimensions 6-9 are treated specially (assumes len(v_l) >= 10)
        if (v_l[6] != 0) or (v_l[7] != 0) or (v_l[8] != 0) or (v_l[9] != 0):
            # repeat the positive ramp up to v_l and back
            ramp(+1)
            ramp(-1)
        else:
            # otherwise swing past zero to -v_l and back
            ramp(-1)
            ramp(+1)

        vid_target_recon = torch.cat(vid_target_recon, dim=2)  # B x C x T x H x W
        return vid_target_recon

    def enc_img(self, img_source, d_l, v_l):
        """Core edit_img logic without timing - can be compiled."""
        z_s2r, feat_rgb = self.enc.enc_2r(img_source)
        alpha_r2s = self.enc.enc_r2t(z_s2r)
        # create the offset tensor directly on the same device as alpha_r2s
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
        return z_s2r, alpha_r2s, feat_rgb

    def dec_img(self, z_s2r, alpha_r2s, feat_rgb):
        return self.dec(z_s2r, [alpha_r2s], feat_rgb)
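    # Sketch of how the enc_img / dec_img split can be used with torch.compile
    # (assumption: PyTorch 2.x; splitting encode from decode keeps any
    # Python-side timing code out of the captured graphs):
    #
    #   enc_fn = torch.compile(gen.enc_img)
    #   dec_fn = torch.compile(gen.dec_img)
    #   z_s2r, alpha_r2s, feat_rgb = enc_fn(img_source, d_l, v_l)
    #   img_out = dec_fn(z_s2r, alpha_r2s, feat_rgb)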
    def dec_vid(self, z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch):
        # z_s2r:            B x C
        # alpha_r2s:        B x C
        # feat_rgb:         list of B x C x H x W
        # img_target_batch: bs x 3 x H x W
        bs = img_target_batch.size(0)
        alpha_start = self.get_alpha(img_start)  # B x C
        # tile the per-source tensors across the frame batch
        alpha_start_r = repeat(alpha_start, 'b c -> (repeat b) c', repeat=bs)
        alpha_r2s_r = repeat(alpha_r2s, 'b c -> (repeat b) c', repeat=bs)
        feat_rgb_r = [repeat(feat, 'b c h w -> (repeat b) c h w', repeat=bs) for feat in feat_rgb]
        z_s2r_r = repeat(z_s2r, 'b c -> (repeat b) c', repeat=bs)
        alpha = self.enc.enc_transfer_vid(alpha_r2s_r, img_target_batch, alpha_start_r)
        img_batch_recon = self.dec(z_s2r_r, alpha, feat_rgb_r)  # bs x 3 x h x w
        return img_batch_recon
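

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumptions: networks.encoder / networks.decoder
    # build for 256x256 inputs and the latent shapes used above; random tensors
    # stand in for real, normalized frames, so outputs are not meaningful images).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gen = Generator(256).to(device).eval()
    img = torch.randn(1, 3, 256, 256, device=device)
    vid = torch.randn(1, 8, 3, 256, 256, device=device)  # B x T x C x H x W
    with torch.no_grad():
        edited = gen.edit_img(img, d_l=[0], v_l=[1.0])
        vid_out = gen.edit_vid_batch(vid, d_l=[0], v_l=[1.0], chunk_size=4)
    print(edited.shape, vid_out.shape)  # expected: (1, 3, 256, 256) and (1, 3, 8, 256, 256)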