import math
import random
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from PIL import Image

from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput
from diffusers.loaders import WanLoraLoaderMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor

from vae import WanVAE
from vace.models.wan.modules.model_mm import VaceMMModel
from vace.models.wan.modules.model_tr import VaceWanModel


@dataclass
class RefacadePipelineOutput(BaseOutput):
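    r"""
    Output class for the Refacade pipeline.

    Args:
        frames (`torch.Tensor`):
            The generated, retextured video (the final latents when `output_type="latent"`).
        meshes (`torch.Tensor`):
            The decoded texture-removed ("mesh") video used as conditioning.
        ref_img (`torch.Tensor`):
            The processed reference image that was actually fed to the transformer.
    """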
    frames: torch.Tensor
    meshes: torch.Tensor
    ref_img: torch.Tensor


logger = logging.get_logger(__name__)


@torch.no_grad()
def _pad_to_multiple(x: torch.Tensor, multiple: int, mode: str = "reflect"):
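    """Pad the last two (H, W) dims on the right/bottom so both become multiples of `multiple`.

    Returns the padded tensor and the `(left, right, top, bottom)` pad tuple that was applied.
    """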
    H, W = x.shape[-2], x.shape[-1]
    pad_h = (multiple - H % multiple) % multiple
    pad_w = (multiple - W % multiple) % multiple
    pad = (0, pad_w, 0, pad_h)
    if pad_h or pad_w:
        x = F.pad(x, pad, mode=mode)
    return x, pad


@torch.no_grad()
def _unpad(x: torch.Tensor, pad):
    """Undo a `(left, right, top, bottom)` padding tuple produced by `_pad_to_multiple`."""
    l, r, t, b = pad
    H, W = x.shape[-2], x.shape[-1]
    # A zero pad on a side leaves that side untouched.
    return x[..., t:H - b if b > 0 else H, l:W - r if r > 0 else W]


@torch.no_grad()
def _resize(x: torch.Tensor, size: tuple, is_mask: bool):
    # Masks are resized with nearest-neighbor so they stay binary; images use bilinear.
    mode = "nearest" if is_mask else "bilinear"
    if is_mask:
        return F.interpolate(x, size=size, mode=mode)
    else:
        return F.interpolate(x, size=size, mode=mode, align_corners=False)


@torch.no_grad()
def _center_scale_foreground_to_canvas(
    x_f: torch.Tensor,
    m_f: torch.Tensor,
    target_hw: tuple,
    bg_value: float = 1.0,
):
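    """
    Crop the foreground selected by `m_f > 0.5` out of `x_f`, rescale it to fit `target_hw`
    while preserving aspect ratio, and paste it centered on a `bg_value` canvas.

    Returns the composed image canvas and the matching binary mask canvas.
    """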
    C, H, W = x_f.shape
    H2, W2 = target_hw
    device = x_f.device
    ys, xs = (m_f > 0.5).nonzero(as_tuple=True)
    canvas = torch.full((C, H2, W2), bg_value, dtype=x_f.dtype, device=device)
    mask_canvas = torch.zeros((1, H2, W2), dtype=x_f.dtype, device=device)
    if ys.numel() == 0:
        return canvas, mask_canvas

    y0, y1 = ys.min().item(), ys.max().item()
    x0, x1 = xs.min().item(), xs.max().item()
    crop_img = x_f[:, y0:y1 + 1, x0:x1 + 1]
    crop_msk = m_f[y0:y1 + 1, x0:x1 + 1].unsqueeze(0)
    hc, wc = crop_msk.shape[-2], crop_msk.shape[-1]
    s = min(H2 / max(1, hc), W2 / max(1, wc))
    Ht = max(1, min(H2, int(math.floor(hc * s))))
    Wt = max(1, min(W2, int(math.floor(wc * s))))
    crop_img_up = _resize(crop_img.unsqueeze(0), (Ht, Wt), is_mask=False).squeeze(0)
    crop_msk_up = _resize(crop_msk.unsqueeze(0), (Ht, Wt), is_mask=True).squeeze(0)
    crop_msk_up = (crop_msk_up > 0.5).to(crop_msk_up.dtype)

    top = (H2 - Ht) // 2
    left = (W2 - Wt) // 2
    canvas[:, top:top + Ht, left:left + Wt] = crop_img_up
    mask_canvas[:, top:top + Ht, left:left + Wt] = crop_msk_up
    return canvas, mask_canvas


@torch.no_grad()
def _sample_patch_size_from_hw(
    H: int,
    W: int,
    ratio: float = 0.2,
    min_px: int = 16,
    max_px: Optional[int] = None,
) -> int:
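    """Pick a square patch size of roughly `ratio * min(H, W)` pixels, clamped to
    `[min_px, max_px]` (where `max_px` defaults to `min(192, min(H, W))`)."""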
    raw = ratio * min(H, W)
    if max_px is None:
        max_px = min(192, min(H, W))
    P = int(round(raw))
    P = max(min_px, min(P, max_px))
    return P


@torch.no_grad()
def _masked_patch_pack_to_center_rectangle(
    x_f: torch.Tensor,
    m_f: torch.Tensor,
    patch: int,
    fg_thresh: float = 0.8,
    bg_value: float = 1.0,
    min_patches: int = 4,
    flip_prob: float = 0.5,
    use_morph_erode: bool = False,
):
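    """
    Build a shape-agnostic texture reference by shuffling foreground patches into a centered rectangle.

    The image is cut into `patch`-sized tiles; tiles whose mask coverage reaches `fg_thresh`
    (relaxed in two steps if fewer than `min_patches` qualify) are randomly permuted,
    optionally flipped, packed into a roughly square rectangle, rescaled to the canvas, and
    centered on a `bg_value` background.

    Returns `(image, mask, used_fallback)`; `used_fallback` is True when packing was not
    possible and the plain centered foreground composition is returned instead.
    """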
    C, H, W = x_f.shape
    device = x_f.device
    P = int(patch)

    # Pad image and mask so H and W are multiples of the patch size.
    x_pad, pad = _pad_to_multiple(x_f, P, mode="reflect")
    l, r, t, b = pad
    H2, W2 = x_pad.shape[-2], x_pad.shape[-1]
    m_pad = F.pad(m_f.unsqueeze(0).unsqueeze(0), (l, r, t, b), mode="constant", value=0.0).squeeze(0)

    # Center the foreground on a clean canvas before patchifying.
    cs_img, cs_msk = _center_scale_foreground_to_canvas(x_pad, m_pad.squeeze(0), (H2, W2), bg_value)
    if (cs_msk > 0.5).sum() == 0:
        out_img = _unpad(cs_img, pad).clamp_(-1, 1)
        out_msk = _unpad(cs_msk, pad).clamp_(0, 1)
        return out_img, out_msk, True

    # Optionally erode the mask so border patches do not leak background pixels.
    m_eff = cs_msk
    if use_morph_erode:
        erode_px = int(max(1, min(6, round(P * 0.03))))
        m_eff = 1.0 - F.max_pool2d(1.0 - cs_msk, kernel_size=2 * erode_px + 1, stride=1, padding=erode_px)

    x_pad2, pad2 = _pad_to_multiple(cs_img, P, mode="reflect")
    m_pad2 = F.pad(m_eff, pad2, mode="constant", value=0.0)
    H3, W3 = x_pad2.shape[-2], x_pad2.shape[-1]

    # Fraction of foreground pixels inside each PxP tile.
    m_pool = F.avg_pool2d(m_pad2, kernel_size=P, stride=P).view(-1)

    base_thr = float(fg_thresh)
    thr_candidates = [base_thr, max(base_thr - 0.05, 0.75), max(base_thr - 0.10, 0.60)]

    x_unf = F.unfold(x_pad2.unsqueeze(0), kernel_size=P, stride=P)
    N = x_unf.shape[-1]

    # Select tiles that pass the coverage threshold, relaxing it until enough tiles qualify.
    sel = None
    for thr in thr_candidates:
        idx = (m_pool >= (thr - 1e-6)).nonzero(as_tuple=False).squeeze(1)
        if idx.numel() >= min_patches:
            sel = idx
            break
    if sel is None:
        img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
        msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
        return img_fallback, msk_fallback, True

    sel = sel.to(device=device, dtype=torch.long)
    sel = sel[(sel >= 0) & (sel < N)]
    if sel.numel() == 0:
        img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
        msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
        return img_fallback, msk_fallback, True

    # Shuffle the selected tiles.
    perm = torch.randperm(sel.numel(), device=device, dtype=torch.long)
    sel = sel[perm]
    chosen_x = x_unf[:, :, sel]
    K = chosen_x.shape[-1]
    if K == 0:
        img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
        msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
        return img_fallback, msk_fallback, True

    # Randomly flip some tiles horizontally or vertically.
    if flip_prob > 0:
        cx4 = chosen_x.view(1, C, P, P, K)
        do_flip = (torch.rand(K, device=device) < flip_prob)
        coin = (torch.rand(K, device=device) < 0.5)
        flip_h = do_flip & coin
        flip_v = do_flip & (~coin)
        if flip_h.any():
            cx4[..., flip_h] = cx4[..., flip_h].flip(dims=[3])
        if flip_v.any():
            cx4[..., flip_v] = cx4[..., flip_v].flip(dims=[2])
        chosen_x = cx4.view(1, C * P * P, K)

    # Pack the tiles into a roughly square grid that fits the padded canvas.
    max_cols = max(1, W3 // P)
    max_rows = max(1, H3 // P)
    capacity = max_rows * max_cols
    K_cap = min(K, capacity)
    cols = int(max(1, min(int(math.floor(math.sqrt(K_cap))), max_cols)))
    rows_full = min(max_rows, K_cap // cols)
    K_used = rows_full * cols
    if K_used == 0:
        img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
        msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
        return img_fallback, msk_fallback, True

    chosen_x = chosen_x[:, :, :K_used]
    rect_unf = torch.full((1, C * P * P, rows_full * cols), bg_value, device=device, dtype=x_f.dtype)
    rect_unf[:, :, :K_used] = chosen_x
    rect = F.fold(rect_unf, output_size=(rows_full * P, cols * P), kernel_size=P, stride=P).squeeze(0)

    ones_patch = torch.ones((1, 1 * P * P, K_used), device=device, dtype=x_f.dtype)
    mask_rect_unf = torch.zeros((1, 1 * P * P, rows_full * cols), device=device, dtype=x_f.dtype)
    mask_rect_unf[:, :, :K_used] = ones_patch
    rect_mask = F.fold(mask_rect_unf, output_size=(rows_full * P, cols * P), kernel_size=P, stride=P).squeeze(0)

    # Rescale the packed rectangle to fit the padded canvas and center it.
    Hr, Wr = rect.shape[-2], rect.shape[-1]
    s = min(H3 / max(1, Hr), W3 / max(1, Wr))
    Ht = min(max(1, int(math.floor(Hr * s))), H3)
    Wt = min(max(1, int(math.floor(Wr * s))), W3)

    rect_up = _resize(rect.unsqueeze(0), (Ht, Wt), is_mask=False).squeeze(0)
    rect_mask_up = _resize(rect_mask.unsqueeze(0), (Ht, Wt), is_mask=True).squeeze(0)

    canvas_x = torch.full((C, H3, W3), bg_value, device=device, dtype=x_f.dtype)
    canvas_m = torch.zeros((1, H3, W3), device=device, dtype=x_f.dtype)
    top, left = (H3 - Ht) // 2, (W3 - Wt) // 2
    canvas_x[:, top:top + Ht, left:left + Wt] = rect_up
    canvas_m[:, top:top + Ht, left:left + Wt] = rect_mask_up

    # Remove both rounds of padding before returning.
    out_img = _unpad(_unpad(canvas_x, pad2), pad).clamp_(-1, 1)
    out_msk = _unpad(_unpad(canvas_m, pad2), pad).clamp_(0, 1)
    return out_img, out_msk, False


@torch.no_grad()
def _compose_centered_foreground(x_f: torch.Tensor, m_f3: torch.Tensor, target_hw: Tuple[int, int], bg_value: float = 1.0):
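    """Binarize the 3-channel mask `m_f3` and delegate to `_center_scale_foreground_to_canvas`."""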
    m_bin = (m_f3 > 0.5).float().mean(dim=0)
    m_bin = (m_bin > 0.5).float()
    return _center_scale_foreground_to_canvas(x_f, m_bin, target_hw, bg_value)


class RefacadePipeline(DiffusionPipeline, WanLoraLoaderMixin):
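    r"""
    Pipeline for reference-guided video retexturing built on Wan/VACE components.

    Args:
        vae: Video VAE used to encode the conditioning videos and decode the generated latents.
        scheduler (`FlowMatchEulerDiscreteScheduler`): Scheduler used for the denoising loop.
        transformer (`VaceMMModel`): VACE multimodal transformer that predicts the noise given
            the video, mask, and reference conditioning.
        texture_remover (`VaceWanModel`): Auxiliary VACE model that strips texture from the
            masked object to produce the "mesh" latents.

    Example (a minimal sketch; how the sub-models are instantiated and how the input frames,
    masks, and reference images are loaded is assumed here, not defined in this file):

        >>> pipe = RefacadePipeline(vae=vae, scheduler=scheduler, transformer=transformer, texture_remover=texture_remover)
        >>> out = pipe(video=frames, mask=masks, reference_image=ref_image, reference_mask=ref_mask)
        >>> result = out.frames
    """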

    model_cpu_offload_seq = "texture_remover->transformer->vae"

    def __init__(
        self,
        vae,
        scheduler: FlowMatchEulerDiscreteScheduler,
        transformer: VaceMMModel = None,
        texture_remover: VaceWanModel = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            texture_remover=texture_remover,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor_temporal = 4
        self.vae_scale_factor_spatial = 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
        # Precomputed text embeddings for the empty and negative prompts; no text encoder is loaded.
        self.empty_embedding = torch.load(
            "./text_embedding/empty.pt",
            map_location="cpu",
        )
        self.negative_embedding = torch.load(
            "./text_embedding/negative.pt",
            map_location="cpu",
        )

    def vace_encode_masks(self, masks: torch.Tensor):
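        """Rearrange a pixel-space mask video (B, C, T, H, W) to latent resolution: each
        `vae_scale_factor_spatial`-sized spatial patch is folded into the channel dim and the
        temporal axis is reduced to the latent frame count."""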
        masks = masks[:, :1, :, :, :]
        B, C, D, H, W = masks.shape
        patch_h, patch_w = self.vae_scale_factor_spatial, self.vae_scale_factor_spatial
        stride_t = self.vae_scale_factor_temporal
        patch_count = patch_h * patch_w
        new_D = (D + stride_t - 1) // stride_t
        new_H = 2 * (H // (patch_h * 2))
        new_W = 2 * (W // (patch_w * 2))
        masks = masks[:, 0]
        masks = masks.view(B, D, new_H, patch_h, new_W, patch_w)
        masks = masks.permute(0, 3, 5, 1, 2, 4)
        masks = masks.reshape(B, patch_count, D, new_H, new_W)
        masks = F.interpolate(
            masks,
            size=(new_D, new_H, new_W),
            mode="nearest-exact",
        )
        return masks

    def preprocess_conditions(
        self,
        video: Optional[List[PipelineImageInput]] = None,
        mask: Optional[List[PipelineImageInput]] = None,
        reference_image: Optional[PIL.Image.Image] = None,
        reference_mask: Optional[PIL.Image.Image] = None,
        batch_size: int = 1,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        reference_patch_ratio: float = 0.2,
        fg_thresh: float = 0.9,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
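        """
        Resize and normalize the input video and mask, and build the reference image/mask pair.

        When `reference_patch_ratio` is 1.0 the reference foreground is simply centered on a white
        canvas; otherwise it is patch-shuffled via `_masked_patch_pack_to_center_rectangle` so the
        spatial layout of the reference object is scrambled while its texture is preserved.

        Returns the preprocessed `(video, mask, ref_image, ref_mask)` tensors, with the reference
        expanded to a single-frame, batched video tensor.
        """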
        base = self.vae_scale_factor_spatial * 2
        video_height, video_width = self.video_processor.get_default_height_width(video[0])

        # Downscale if the input video has more pixels than the requested height * width budget.
        if video_height * video_width > height * width:
            scale_w = width / video_width
            scale_h = height / video_height
            video_height, video_width = int(video_height * scale_h), int(video_width * scale_w)

        if video_height % base != 0 or video_width % base != 0:
            logger.warning(
                f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}. "
                "Dimensions will be rounded down to the nearest multiple."
            )
            video_height = (video_height // base) * base
            video_width = (video_width // base) * base

        assert video_height * video_width <= height * width

        video = self.video_processor.preprocess_video(video, video_height, video_width)
        image_size = (video_height, video_width)

        # The mask video is normalized from [-1, 1] back to [0, 1].
        mask = self.video_processor.preprocess_video(mask, video_height, video_width)
        mask = torch.clamp((mask + 1) / 2, min=0, max=1)

        video = video.to(dtype=dtype, device=device)
        mask = mask.to(dtype=dtype, device=device)

        if reference_image is None:
            raise ValueError("reference_image must be provided when using IMAGE_CONTROL mode.")

        if isinstance(reference_image, (list, tuple)):
            ref_img_pil = reference_image[0]
        else:
            ref_img_pil = reference_image

        if reference_mask is not None and isinstance(reference_mask, (list, tuple)):
            ref_mask_pil = reference_mask[0]
        else:
            ref_mask_pil = reference_mask

        ref_img_t = self.video_processor.preprocess(ref_img_pil, image_size[0], image_size[1])
        if ref_img_t.dim() == 4 and ref_img_t.shape[0] == 1:
            ref_img_t = ref_img_t[0]
        if ref_img_t.shape[0] == 1:
            ref_img_t = ref_img_t.repeat(3, 1, 1)
        ref_img_t = ref_img_t.to(dtype=dtype, device=device)

        H, W = image_size
        if ref_mask_pil is not None:
            if not isinstance(ref_mask_pil, Image.Image):
                ref_mask_pil = Image.fromarray(np.array(ref_mask_pil))
            ref_mask_pil = ref_mask_pil.convert("L")
            ref_mask_pil = ref_mask_pil.resize((W, H), Image.NEAREST)
            mask_arr = np.array(ref_mask_pil)
            m = torch.from_numpy(mask_arr).float() / 255.0
            m = (m > 0.5).float()
            ref_msk3 = m.unsqueeze(0).repeat(3, 1, 1)
        else:
            # No reference mask: treat the whole reference image as foreground.
            ref_msk3 = torch.ones(3, H, W, dtype=dtype)

        ref_msk3 = ref_msk3.to(dtype=dtype, device=device)

        if math.isclose(reference_patch_ratio, 1.0, rel_tol=1e-6, abs_tol=1e-6):
            # A ratio of 1.0 disables patch shuffling: just center the foreground on a white canvas.
            cs_img, cs_m = _compose_centered_foreground(
                x_f=ref_img_t,
                m_f3=ref_msk3,
                target_hw=image_size,
                bg_value=1.0,
            )
            ref_img_out = cs_img
            ref_mask_out = cs_m
        else:
            patch = _sample_patch_size_from_hw(
                H=image_size[0],
                W=image_size[1],
                ratio=reference_patch_ratio,
            )

            m_bin = (ref_msk3 > 0.5).float().mean(dim=0)
            m_bin = (m_bin > 0.5).float()
            reshuffled, reshuf_mask, used_fb = _masked_patch_pack_to_center_rectangle(
                x_f=ref_img_t,
                m_f=m_bin,
                patch=patch,
                fg_thresh=fg_thresh,
                bg_value=1.0,
                min_patches=4,
            )

            ref_img_out = reshuffled
            ref_mask_out = reshuf_mask

        B = video.shape[0]
        if batch_size is not None:
            B = batch_size

        # Add batch and frame dims: (C, H, W) -> (B, C, 1, H, W).
        ref_image = ref_img_out.unsqueeze(0).unsqueeze(2).expand(B, -1, -1, -1, -1).contiguous()
        ref_mask = ref_mask_out.unsqueeze(0).unsqueeze(2).expand(B, 3, -1, -1, -1).contiguous()

        ref_image = ref_image.to(dtype=dtype, device=device)
        ref_mask = ref_mask.to(dtype=dtype, device=device)

        return video[:, :, :num_frames], mask[:, :, :num_frames], ref_image, ref_mask

    @torch.no_grad()
    def texture_remove(self, foreground_latent):
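        """Run a short 3-step flow-matching loop with `texture_remover` to turn the masked-object
        latents into texture-free "mesh" latents."""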
        # Use a dedicated scheduler so the main denoising schedule is left untouched.
        sample_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1)
        # The texture remover is conditioned on an all-zero (empty) text embedding.
        text_embedding = torch.zeros(
            [256, 4096],
            device=foreground_latent.device,
            dtype=foreground_latent.dtype,
        )
        context = text_embedding.unsqueeze(0).expand(
            foreground_latent.shape[0], -1, -1
        ).to(foreground_latent.device)
        sample_scheduler.set_timesteps(3, device=foreground_latent.device)
        timesteps = sample_scheduler.timesteps
        noise = torch.randn_like(
            foreground_latent,
            dtype=foreground_latent.dtype,
            device=foreground_latent.device,
        )
        seq_len = math.ceil(
            noise.shape[2] * noise.shape[3] * noise.shape[4] / 4
        )
        latents = noise
        arg_c = {"context": context, "seq_len": seq_len}
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            for _, t in enumerate(timesteps):
                timestep = torch.stack([t]).to(foreground_latent.device)
                noise_pred_cond = self.texture_remover(
                    latents,
                    t=timestep,
                    vace_context=foreground_latent,
                    vace_context_scale=1,
                    **arg_c,
                )[0]
                temp_x0 = sample_scheduler.step(
                    noise_pred_cond, t, latents, return_dict=False
                )[0]
                latents = temp_x0
        return latents

    def dilate_mask_hw(self, mask: torch.Tensor, radius: int = 3) -> torch.Tensor:
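        """Binary-dilate a (B, C, T, H, W) mask spatially with a square kernel of size
        `2 * radius + 1`, frame by frame."""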
        B, C, F_, H, W = mask.shape
        k = 2 * radius + 1
        mask_2d = mask.permute(0, 2, 1, 3, 4).reshape(B * F_, C, H, W)
        kernel = torch.ones(
            (C, 1, k, k),
            device=mask.device,
            dtype=mask.dtype,
        )
        dilated_2d = F.conv2d(
            mask_2d,
            weight=kernel,
            bias=None,
            stride=1,
            padding=radius,
            groups=C,
        )
        dilated_2d = (dilated_2d > 0).to(mask.dtype)
        dilated = dilated_2d.view(B, F_, C, H, W).permute(0, 2, 1, 3, 4)
        return dilated

    def prepare_vace_latents(
        self,
        dilate_radius: int,
        video: torch.Tensor,
        mask: torch.Tensor,
        reference_image: Optional[torch.Tensor] = None,
        reference_mask: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
    ):
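        """
        Encode the conditioning inputs into latent space.

        Returns `(inactive_latent, mesh_latent, ref_latent, neg_ref_latent, mask, ref_mask)`:
        the VAE latents of the background (object region dilated and zeroed out), the
        texture-removed object, the reference image, an all-white negative reference, and the
        latent-resolution encodings of the dilated mask and the reference mask.
        """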
        device = device or self._execution_device

        vae_dtype = self.vae.dtype
        video = video.to(dtype=vae_dtype)
        mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
        mask_clone = mask.clone()
        # The dilated mask defines the region to regenerate; the original mask selects the object pixels.
        mask = self.dilate_mask_hw(mask, dilate_radius)
        inactive = video * (1 - mask)
        reactive = video * mask_clone
        reactive_latent = self.vae.encode(reactive)
        mesh_latent = self.texture_remove(reactive_latent)

        inactive_latent = self.vae.encode(inactive)
        ref_latent = self.vae.encode(reference_image)
        # An all-ones (white) reference serves as the unconditional branch for guidance.
        neg_ref_latent = self.vae.encode(torch.ones_like(reference_image))

        reference_mask = torch.where(reference_mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
        mask = self.vace_encode_masks(mask)
        ref_mask = self.vace_encode_masks(reference_mask)

        return inactive_latent, mesh_latent, ref_latent, neg_ref_latent, mask, ref_mask

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
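        """Create the initial noise latents, or move user-supplied `latents` to the requested
        device and dtype."""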
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_channels_latents,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @torch.no_grad()
    def __call__(
        self,
        video: Optional[PipelineImageInput] = None,
        mask: Optional[PipelineImageInput] = None,
        reference_image: Optional[PipelineImageInput] = None,
        reference_mask: Optional[PipelineImageInput] = None,
        conditioning_scale: float = 1.0,
        dilate_radius: int = 3,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        num_inference_steps: int = 20,
        guidance_scale: float = 1.5,
        num_videos_per_prompt: Optional[int] = 1,
        reference_patch_ratio: float = 0.2,
        fg_thresh: float = 0.9,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
    ):
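        r"""
        Generate a retextured video: the masked region of `video` is regenerated so that its
        texture follows `reference_image` (masked by `reference_mask`), while the underlying
        geometry is carried by the texture-removed "mesh" conditioning.

        `conditioning_scale` scales the VACE context, `guidance_scale > 1` enables
        classifier-free guidance against an all-white reference, and `output_type="latent"`
        skips VAE decoding.

        Returns a `RefacadePipelineOutput` with `frames`, `meshes`, and `ref_img`, or the
        corresponding tuple when `return_dict=False`.
        """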
        if num_frames % self.vae_scale_factor_temporal != 1:
            logger.warning(
                f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
            )
            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        num_frames = max(num_frames, 1)

        self._guidance_scale = guidance_scale

        device = self._execution_device
        batch_size = 1

        vae_dtype = self.vae.dtype

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        video, mask, reference_image, reference_mask = self.preprocess_conditions(
            video,
            mask,
            reference_image,
            reference_mask,
            batch_size,
            height,
            width,
            num_frames,
            reference_patch_ratio,
            fg_thresh,
            torch.float16,
            device,
        )

        # Encode conditioning: background, texture-free object, reference image, and masks.
        inactive_latent, mesh_latent, ref_latent, neg_ref_latent, mask, ref_mask = self.prepare_vace_latents(
            dilate_radius, video, mask, reference_image, reference_mask, device
        )
        c = torch.cat([inactive_latent, mesh_latent, mask], dim=1)
        c1 = torch.cat([ref_latent, ref_mask], dim=1)
        c1_negative = torch.cat(
            [neg_ref_latent, torch.zeros_like(ref_mask)],
            dim=1,
        )

        num_channels_latents = 16
        noise = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float16,
            device,
            generator,
            latents,
        )

        # The reference latent is prepended as an extra "frame" and kept fixed during denoising.
        latents_cond = torch.cat([ref_latent, noise], dim=2)
        latents_uncond = torch.cat([neg_ref_latent, noise], dim=2)

        seq_len = math.ceil(
            latents_cond.shape[2] *
            latents_cond.shape[3] *
            latents_cond.shape[4] / 4
        )
        seq_len_ref = math.ceil(
            ref_latent.shape[2] *
            ref_latent.shape[3] *
            ref_latent.shape[4] / 4
        )
        context = self.empty_embedding.unsqueeze(0).expand(batch_size, -1, -1).to(device)
        context_neg = self.negative_embedding.unsqueeze(0).expand(batch_size, -1, -1).to(device)
        arg_c = {
            "context": context,
            "seq_len": seq_len,
            "seq_len_ref": seq_len_ref,
        }
        arg_c_null = {
            "context": context_neg,
            "seq_len": seq_len,
            "seq_len_ref": seq_len_ref,
        }

        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                self._current_timestep = t
                timestep = t.expand(batch_size)

                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    noise_pred = self.transformer(
                        latents_cond,
                        t=timestep,
                        vace_context=c,
                        ref_context=c1,
                        vace_context_scale=conditioning_scale,
                        **arg_c,
                    )[0]

                    if self.do_classifier_free_guidance:
                        noise_pred_uncond = self.transformer(
                            latents_uncond,
                            t=timestep,
                            vace_context=c,
                            ref_context=c1_negative,
                            vace_context_scale=0,
                            **arg_c_null,
                        )[0]
                        noise_pred = (noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)).unsqueeze(0)

                # Only the video frames (index 1 onward along the frame dim) are stepped;
                # the fixed reference slot is re-attached afterwards.
                temp_x0 = self.scheduler.step(
                    noise_pred[:, :, 1:],
                    t,
                    latents_cond[:, :, 1:],
                    return_dict=False,
                )[0]
                latents_cond = torch.cat([ref_latent, temp_x0], dim=2)
                latents_uncond = torch.cat([neg_ref_latent, temp_x0], dim=2)
                progress_bar.update()

        self._current_timestep = None

        if output_type != "latent":
            latents = temp_x0
            latents = latents.to(vae_dtype)
            video = self.vae.decode(latents)
            video = self.video_processor.postprocess_video(video, output_type=output_type)
            mesh = self.vae.decode(mesh_latent.to(vae_dtype))
            mesh = self.video_processor.postprocess_video(mesh, output_type=output_type)
            ref_img = reference_image.cpu().squeeze(0).squeeze(1).permute(1, 2, 0).numpy()
            ref_img = ((ref_img + 1) * 255 / 2).astype(np.uint8)
        else:
            video = temp_x0
            mesh = mesh_latent
            ref_img = ref_latent

        self.maybe_free_model_hooks()

        if not return_dict:
            return (video, mesh, ref_img)

        return RefacadePipelineOutput(frames=video, meshes=mesh, ref_img=ref_img)