# Refacade/pipeline.py
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from PIL import Image

from diffusers.image_processor import PipelineImageInput
from diffusers.loaders import WanLoraLoaderMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor

from vae import WanVAE
from vace.models.wan.modules.model_mm import VaceMMModel
from vace.models.wan.modules.model_tr import VaceWanModel
@dataclass
class RefacadePipelineOutput(BaseOutput):
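    """Output of `RefacadePipeline.__call__`.

    `frames` holds the retextured video, `meshes` the decoded texture-free
    ("mesh") video, and `ref_img` the processed reference image. When
    `output_type="latent"` all three are returned as latents instead of
    decoded arrays.
    """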
frames: torch.Tensor
meshes: torch.Tensor
ref_img: torch.Tensor
logger = logging.get_logger(__name__)
@torch.no_grad()
def _pad_to_multiple(x: torch.Tensor, multiple: int, mode: str = "reflect"):
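    """Pad the last two dims of `x` up to a multiple of `multiple`.

    Returns the padded tensor and the `(left, right, top, bottom)` padding
    tuple in `F.pad` order, so `_unpad` can undo the operation.
    """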
H, W = x.shape[-2], x.shape[-1]
pad_h = (multiple - H % multiple) % multiple
pad_w = (multiple - W % multiple) % multiple
pad = (0, pad_w, 0, pad_h)
if pad_h or pad_w:
x = F.pad(x, pad, mode=mode)
return x, pad
@torch.no_grad()
def _unpad(x: torch.Tensor, pad):
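    """Undo `_pad_to_multiple` given the `(left, right, top, bottom)` tuple."""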
l, r, t, b = pad
H, W = x.shape[-2], x.shape[-1]
return x[..., t:H - b if b > 0 else H, l:W - r if r > 0 else W]
@torch.no_grad()
def _resize(x: torch.Tensor, size: tuple, is_mask: bool):
mode = "nearest" if is_mask else "bilinear"
if is_mask:
return F.interpolate(x, size=size, mode=mode)
else:
return F.interpolate(x, size=size, mode=mode, align_corners=False)
@torch.no_grad()
def _center_scale_foreground_to_canvas(
x_f: torch.Tensor,
m_f: torch.Tensor,
target_hw: tuple,
bg_value: float = 1.0,
):
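    """Crop the foreground selected by `m_f`, scale it to fit `target_hw`, and
    paste it centered on a `bg_value` canvas.

    Returns the composed image `(C, H2, W2)` and its binary mask `(1, H2, W2)`;
    both are background-only when the mask has no foreground pixels.
    """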
C, H, W = x_f.shape
H2, W2 = target_hw
device = x_f.device
ys, xs = (m_f > 0.5).nonzero(as_tuple=True)
canvas = torch.full((C, H2, W2), bg_value, dtype=x_f.dtype, device=device)
mask_canvas = torch.zeros((1, H2, W2), dtype=x_f.dtype, device=device)
if ys.numel() == 0:
return canvas, mask_canvas
y0, y1 = ys.min().item(), ys.max().item()
x0, x1 = xs.min().item(), xs.max().item()
crop_img = x_f[:, y0:y1 + 1, x0:x1 + 1]
crop_msk = m_f[y0:y1 + 1, x0:x1 + 1].unsqueeze(0)
hc, wc = crop_msk.shape[-2], crop_msk.shape[-1]
s = min(H2 / max(1, hc), W2 / max(1, wc))
Ht = max(1, min(H2, int(math.floor(hc * s))))
Wt = max(1, min(W2, int(math.floor(wc * s))))
crop_img_up = _resize(crop_img.unsqueeze(0), (Ht, Wt), is_mask=False).squeeze(0)
crop_msk_up = _resize(crop_msk.unsqueeze(0), (Ht, Wt), is_mask=True).squeeze(0)
crop_msk_up = (crop_msk_up > 0.5).to(crop_msk_up.dtype)
top = (H2 - Ht) // 2
left = (W2 - Wt) // 2
canvas[:, top:top + Ht, left:left + Wt] = crop_img_up
mask_canvas[:, top:top + Ht, left:left + Wt] = crop_msk_up
return canvas, mask_canvas
@torch.no_grad()
def _sample_patch_size_from_hw(
H: int,
W: int,
ratio: float = 0.2,
min_px: int = 16,
max_px: Optional[int] = None,
) -> int:
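    """Derive a square patch size from the frame size: `ratio * min(H, W)`,
    clamped to `[min_px, max_px]` with `max_px` defaulting to `min(192, min(H, W))`.
    """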
    raw = ratio * min(H, W)
    if max_px is None:
        max_px = min(192, min(H, W))
    return max(min_px, min(int(round(raw)), max_px))
@torch.no_grad()
def _masked_patch_pack_to_center_rectangle(
x_f: torch.Tensor,
m_f: torch.Tensor,
patch: int,
fg_thresh: float = 0.8,
bg_value: float = 1.0,
min_patches: int = 4,
flip_prob: float = 0.5,
use_morph_erode: bool = False,
):
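    """Shuffle the masked foreground into a patch-packed reference image.

    The centered foreground is cut into `patch` x `patch` tiles; tiles that are
    mostly foreground (>= `fg_thresh`, relaxed if too few survive) are randomly
    permuted and optionally flipped, packed into a rectangle, and scaled back
    onto a `bg_value` canvas of the original size. Returns
    `(image, mask, used_fallback)`, where the fallback is the plain centered
    foreground when not enough patches qualify.
    """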
C, H, W = x_f.shape
device = x_f.device
P = int(patch)
x_pad, pad = _pad_to_multiple(x_f, P, mode="reflect")
l, r, t, b = pad
H2, W2 = x_pad.shape[-2], x_pad.shape[-1]
m_pad = F.pad(m_f.unsqueeze(0).unsqueeze(0), (l, r, t, b), mode="constant", value=0.0).squeeze(0)
cs_img, cs_msk = _center_scale_foreground_to_canvas(x_pad, m_pad.squeeze(0), (H2, W2), bg_value)
if (cs_msk > 0.5).sum() == 0:
out_img = _unpad(cs_img, pad).clamp_(-1, 1)
out_msk = _unpad(cs_msk, pad).clamp_(0, 1)
return out_img, out_msk, True
m_eff = cs_msk
if use_morph_erode:
erode_px = int(max(1, min(6, round(P * 0.03))))
m_eff = 1.0 - F.max_pool2d(1.0 - cs_msk, kernel_size=2 * erode_px + 1, stride=1, padding=erode_px)
x_pad2, pad2 = _pad_to_multiple(cs_img, P, mode="reflect")
m_pad2 = F.pad(m_eff, pad2, mode="constant", value=0.0)
H3, W3 = x_pad2.shape[-2], x_pad2.shape[-1]
m_pool = F.avg_pool2d(m_pad2, kernel_size=P, stride=P).view(-1)
base_thr = float(fg_thresh)
thr_candidates = [base_thr, max(base_thr - 0.05, 0.75), max(base_thr - 0.10, 0.60)]
x_unf = F.unfold(x_pad2.unsqueeze(0), kernel_size=P, stride=P)
N = x_unf.shape[-1]
sel = None
for thr in thr_candidates:
idx = (m_pool >= (thr - 1e-6)).nonzero(as_tuple=False).squeeze(1)
if idx.numel() >= min_patches:
sel = idx
break
if sel is None:
img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
return img_fallback, msk_fallback, True
sel = sel.to(device=device, dtype=torch.long)
sel = sel[(sel >= 0) & (sel < N)]
if sel.numel() == 0:
img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
return img_fallback, msk_fallback, True
perm = torch.randperm(sel.numel(), device=device, dtype=torch.long)
sel = sel[perm]
chosen_x = x_unf[:, :, sel]
K = chosen_x.shape[-1]
if K == 0:
img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
return img_fallback, msk_fallback, True
if flip_prob > 0:
cx4 = chosen_x.view(1, C, P, P, K)
do_flip = (torch.rand(K, device=device) < flip_prob)
coin = (torch.rand(K, device=device) < 0.5)
flip_h = do_flip & coin
flip_v = do_flip & (~coin)
if flip_h.any():
cx4[..., flip_h] = cx4[..., flip_h].flip(dims=[3])
if flip_v.any():
cx4[..., flip_v] = cx4[..., flip_v].flip(dims=[2])
chosen_x = cx4.view(1, C * P * P, K)
max_cols = max(1, W3 // P)
max_rows = max(1, H3 // P)
capacity = max_rows * max_cols
K_cap = min(K, capacity)
cols = int(max(1, min(int(math.floor(math.sqrt(K_cap))), max_cols)))
rows_full = min(max_rows, K_cap // cols)
K_used = rows_full * cols
if K_used == 0:
img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
return img_fallback, msk_fallback, True
chosen_x = chosen_x[:, :, :K_used]
rect_unf = torch.full((1, C * P * P, rows_full * cols), bg_value, device=device, dtype=x_f.dtype)
rect_unf[:, :, :K_used] = chosen_x
rect = F.fold(rect_unf, output_size=(rows_full * P, cols * P), kernel_size=P, stride=P).squeeze(0)
ones_patch = torch.ones((1, 1 * P * P, K_used), device=device, dtype=x_f.dtype)
mask_rect_unf = torch.zeros((1, 1 * P * P, rows_full * cols), device=device, dtype=x_f.dtype)
mask_rect_unf[:, :, :K_used] = ones_patch
rect_mask = F.fold(mask_rect_unf, output_size=(rows_full * P, cols * P), kernel_size=P, stride=P).squeeze(0)
Hr, Wr = rect.shape[-2], rect.shape[-1]
s = min(H3 / max(1, Hr), W3 / max(1, Wr))
Ht = min(max(1, int(math.floor(Hr * s))), H3)
Wt = min(max(1, int(math.floor(Wr * s))), W3)
rect_up = _resize(rect.unsqueeze(0), (Ht, Wt), is_mask=False).squeeze(0)
rect_mask_up = _resize(rect_mask.unsqueeze(0), (Ht, Wt), is_mask=True).squeeze(0)
canvas_x = torch.full((C, H3, W3), bg_value, device=device, dtype=x_f.dtype)
canvas_m = torch.zeros((1, H3, W3), device=device, dtype=x_f.dtype)
top, left = (H3 - Ht) // 2, (W3 - Wt) // 2
canvas_x[:, top:top + Ht, left:left + Wt] = rect_up
canvas_m[:, top:top + Ht, left:left + Wt] = rect_mask_up
out_img = _unpad(_unpad(canvas_x, pad2), pad).clamp_(-1, 1)
out_msk = _unpad(_unpad(canvas_m, pad2), pad).clamp_(0, 1)
return out_img, out_msk, False
@torch.no_grad()
def _compose_centered_foreground(x_f: torch.Tensor, m_f3: torch.Tensor, target_hw: Tuple[int, int], bg_value: float = 1.0):
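    """Binarize the 3-channel mask and center-scale the foreground onto a canvas
    via `_center_scale_foreground_to_canvas` (used when `reference_patch_ratio == 1.0`).
    """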
m_bin = (m_f3 > 0.5).float().mean(dim=0)
m_bin = (m_bin > 0.5).float()
return _center_scale_foreground_to_canvas(x_f, m_bin, target_hw, bg_value)
class RefacadePipeline(DiffusionPipeline, WanLoraLoaderMixin):
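    """Reference-guided video object retexturing pipeline built on Wan/VACE components.

    `texture_remover` (a `VaceWanModel`) strips the texture from the masked
    object to obtain a "mesh" latent, and `transformer` (a `VaceMMModel`)
    repaints the object following a patch-shuffled reference image. Text
    conditioning is fixed: precomputed embeddings (`empty.pt`, `negative.pt`)
    are loaded from `./text_embedding/` at construction time.
    """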
model_cpu_offload_seq = "texture_remover->transformer->vae"
    def __init__(
        self,
        vae: WanVAE,
        scheduler: FlowMatchEulerDiscreteScheduler,
        transformer: Optional[VaceMMModel] = None,
        texture_remover: Optional[VaceWanModel] = None,
    ):
super().__init__()
self.register_modules(
vae=vae,
texture_remover=texture_remover,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.empty_embedding = torch.load(
"./text_embedding/empty.pt",
map_location="cpu"
)
self.negative_embedding = torch.load(
"./text_embedding/negative.pt",
map_location="cpu"
)
def vace_encode_masks(self, masks: torch.Tensor):
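        """Rearrange a pixel-space mask `(B, C, D, H, W)` into latent resolution.

        Only the first channel is used; each 8x8 spatial patch becomes a
        channel (64 total) and the frame axis is resampled to the number of
        latent frames with nearest-exact interpolation.
        """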
masks = masks[:, :1, :, :, :]
B, C, D, H, W = masks.shape
patch_h, patch_w = self.vae_scale_factor_spatial, self.vae_scale_factor_spatial
stride_t = self.vae_scale_factor_temporal
patch_count = patch_h * patch_w
new_D = (D + stride_t - 1) // stride_t
new_H = 2 * (H // (patch_h * 2))
new_W = 2 * (W // (patch_w * 2))
masks = masks[:, 0]
masks = masks.view(B, D, new_H, patch_h, new_W, patch_w)
masks = masks.permute(0, 3, 5, 1, 2, 4)
masks = masks.reshape(B, patch_count, D, new_H, new_W)
masks = F.interpolate(
masks,
size=(new_D, new_H, new_W),
mode="nearest-exact"
)
return masks
def preprocess_conditions(
self,
video: Optional[List[PipelineImageInput]] = None,
mask: Optional[List[PipelineImageInput]] = None,
reference_image: Optional[PIL.Image.Image] = None,
reference_mask: Optional[PIL.Image.Image] = None,
batch_size: int = 1,
height: int = 480,
width: int = 832,
num_frames: int = 81,
reference_patch_ratio: float = 0.2,
fg_thresh: float = 0.9,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
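        """Resize and normalize the input video/mask and build the reference pair.

        The video is downscaled when larger than `height * width`, snapped to a
        multiple of 16, and the mask is binarized to [0, 1]. The reference image
        is either center-scaled (`reference_patch_ratio == 1.0`) or
        patch-shuffled via `_masked_patch_pack_to_center_rectangle`. Returns
        `(video, mask, ref_image, ref_mask)` as 5D tensors on `device`, with
        video and mask cropped to `num_frames`.
        """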
base = self.vae_scale_factor_spatial * 2
video_height, video_width = self.video_processor.get_default_height_width(video[0])
if video_height * video_width > height * width:
scale_w = width / video_width
scale_h = height / video_height
video_height, video_width = int(video_height * scale_h), int(video_width * scale_w)
if video_height % base != 0 or video_width % base != 0:
logger.warning(
f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}. "
)
video_height = (video_height // base) * base
video_width = (video_width // base) * base
assert video_height * video_width <= height * width
video = self.video_processor.preprocess_video(video, video_height, video_width)
image_size = (video_height, video_width)
mask = self.video_processor.preprocess_video(mask, video_height, video_width)
mask = torch.clamp((mask + 1) / 2, min=0, max=1)
video = video.to(dtype=dtype, device=device)
mask = mask.to(dtype=dtype, device=device)
if reference_image is None:
raise ValueError("reference_image must be provided when using IMAGE_CONTROL mode.")
if isinstance(reference_image, (list, tuple)):
ref_img_pil = reference_image[0]
else:
ref_img_pil = reference_image
if reference_mask is not None and isinstance(reference_mask, (list, tuple)):
ref_mask_pil = reference_mask[0]
else:
ref_mask_pil = reference_mask
ref_img_t = self.video_processor.preprocess(ref_img_pil, image_size[0], image_size[1])
if ref_img_t.dim() == 4 and ref_img_t.shape[0] == 1:
ref_img_t = ref_img_t[0]
if ref_img_t.shape[0] == 1:
ref_img_t = ref_img_t.repeat(3, 1, 1)
ref_img_t = ref_img_t.to(dtype=dtype, device=device)
H, W = image_size
if ref_mask_pil is not None:
if not isinstance(ref_mask_pil, Image.Image):
ref_mask_pil = Image.fromarray(np.array(ref_mask_pil))
ref_mask_pil = ref_mask_pil.convert("L")
ref_mask_pil = ref_mask_pil.resize((W, H), Image.NEAREST)
mask_arr = np.array(ref_mask_pil)
m = torch.from_numpy(mask_arr).float() / 255.0
m = (m > 0.5).float()
ref_msk3 = m.unsqueeze(0).repeat(3, 1, 1)
else:
ref_msk3 = torch.ones(3, H, W, dtype=dtype)
ref_msk3 = ref_msk3.to(dtype=dtype, device=device)
if math.isclose(reference_patch_ratio, 1.0, rel_tol=1e-6, abs_tol=1e-6):
cs_img, cs_m = _compose_centered_foreground(
x_f=ref_img_t,
m_f3=ref_msk3,
target_hw=image_size,
bg_value=1.0,
)
ref_img_out = cs_img
ref_mask_out = cs_m
else:
patch = _sample_patch_size_from_hw(
H=image_size[0],
W=image_size[1],
ratio=reference_patch_ratio,
)
m_bin = (ref_msk3 > 0.5).float().mean(dim=0)
m_bin = (m_bin > 0.5).float()
reshuffled, reshuf_mask, used_fb = _masked_patch_pack_to_center_rectangle(
x_f=ref_img_t,
m_f=m_bin,
patch=patch,
fg_thresh=fg_thresh,
bg_value=1.0,
min_patches=4,
)
ref_img_out = reshuffled
ref_mask_out = reshuf_mask
B = video.shape[0]
if batch_size is not None:
B = batch_size
ref_image = ref_img_out.unsqueeze(0).unsqueeze(2).expand(B, -1, -1, -1, -1).contiguous()
ref_mask = ref_mask_out.unsqueeze(0).unsqueeze(2).expand(B, 3, -1, -1, -1).contiguous()
ref_image = ref_image.to(dtype=dtype, device=device)
ref_mask = ref_mask.to(dtype=dtype, device=device)
return video[:, :, :num_frames], mask[:, :, :num_frames], ref_image, ref_mask
@torch.no_grad()
def texture_remove(self, foreground_latent):
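        """Turn the reactive (masked-object) latent into a texture-free "mesh" latent.

        Runs a short 3-step flow-matching loop on the texture remover with an
        all-zero text context, conditioning on `foreground_latent` as the VACE
        context.
        """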
sample_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1)
text_embedding = torch.zeros(
[256, 4096],
device=foreground_latent.device,
dtype=foreground_latent.dtype
)
context = text_embedding.unsqueeze(0).expand(
foreground_latent.shape[0], -1, -1
).to(foreground_latent.device)
sample_scheduler.set_timesteps(3, device=foreground_latent.device)
timesteps = sample_scheduler.timesteps
noise = torch.randn_like(
foreground_latent,
dtype=foreground_latent.dtype,
device=foreground_latent.device
)
seq_len = math.ceil(
noise.shape[2] * noise.shape[3] * noise.shape[4] / 4
)
latents = noise
arg_c = {"context": context, "seq_len": seq_len}
with torch.autocast(device_type="cuda", dtype=torch.float16):
for _, t in enumerate(timesteps):
timestep = torch.stack([t]).to(foreground_latent.device)
noise_pred_cond = self.texture_remover(
latents,
t=timestep,
vace_context=foreground_latent,
vace_context_scale=1,
**arg_c
)[0]
temp_x0 = sample_scheduler.step(
noise_pred_cond, t, latents, return_dict=False
)[0]
latents = temp_x0
return latents
def dilate_mask_hw(self, mask: torch.Tensor, radius: int = 3) -> torch.Tensor:
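        """Dilate a binary `(B, C, F, H, W)` mask spatially with a `(2*radius+1)` square kernel."""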
B, C, F_, H, W = mask.shape
k = 2 * radius + 1
mask_2d = mask.permute(0, 2, 1, 3, 4).reshape(B * F_, C, H, W)
kernel = torch.ones(
(C, 1, k, k),
device=mask.device,
dtype=mask.dtype
)
dilated_2d = F.conv2d(
mask_2d,
weight=kernel,
bias=None,
stride=1,
padding=radius,
groups=C
)
dilated_2d = (dilated_2d > 0).to(mask.dtype)
dilated = dilated_2d.view(B, F_, C, H, W).permute(0, 2, 1, 3, 4)
return dilated
def prepare_vace_latents(
self,
dilate_radius: int,
video: torch.Tensor,
mask: torch.Tensor,
reference_image: Optional[torch.Tensor] = None,
reference_mask: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
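        """Encode all VACE conditioning inputs to latent space.

        The mask is dilated to carve out the inactive background, the reactive
        (object) region is encoded and passed through `texture_remove`, and the
        reference image plus an all-white negative reference are encoded
        separately. Returns
        `(inactive_latent, mesh_latent, ref_latent, neg_ref_latent, mask, ref_mask)`.
        """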
device = device or self._execution_device
vae_dtype = self.vae.dtype
video = video.to(dtype=vae_dtype)
mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
mask_clone = mask.clone()
mask = self.dilate_mask_hw(mask, dilate_radius)
inactive = video * (1 - mask)
reactive = video * mask_clone
reactive_latent = self.vae.encode(reactive)
mesh_latent = self.texture_remove(reactive_latent)
inactive_latent = self.vae.encode(inactive)
ref_latent = self.vae.encode(reference_image)
neg_ref_latent = self.vae.encode(torch.ones_like(reference_image))
reference_mask = torch.where(reference_mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
mask = self.vace_encode_masks(mask)
ref_mask = self.vace_encode_masks(reference_mask)
return inactive_latent, mesh_latent, ref_latent, neg_ref_latent, mask, ref_mask
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int = 16,
height: int = 480,
width: int = 832,
num_frames: int = 81,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
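        """Sample the initial noise latents `(B, 16, T_lat, H/8, W/8)`, or pass
        `latents` through when provided."""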
if latents is not None:
return latents.to(device=device, dtype=dtype)
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
shape = (
batch_size,
num_channels_latents,
num_latent_frames,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@torch.no_grad()
def __call__(
self,
video: Optional[PipelineImageInput] = None,
mask: Optional[PipelineImageInput] = None,
reference_image: Optional[PipelineImageInput] = None,
reference_mask: Optional[PipelineImageInput] = None,
conditioning_scale: float = 1.0,
dilate_radius: int = 3,
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_inference_steps: int = 20,
guidance_scale: float = 1.5,
num_videos_per_prompt: Optional[int] = 1,
reference_patch_ratio: float = 0.2,
fg_thresh: float = 0.9,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
):
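        """Generate a retextured video.

        The reference latent is prepended as an extra frame along the temporal
        axis of the noise latents; at each denoising step the transformer
        predicts the flow for the full sequence, classifier-free guidance
        contrasts the full conditioning against a blank reference with the
        VACE context disabled, and only the video portion (frames after the
        reference slot) is stepped by the scheduler. Returns a
        `RefacadePipelineOutput` (or a tuple when `return_dict=False`) with the
        decoded frames, the texture-free mesh video, and the processed
        reference image.
        """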
if num_frames % self.vae_scale_factor_temporal != 1:
logger.warning(
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
)
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
self._guidance_scale = guidance_scale
device = self._execution_device
batch_size = 1
vae_dtype = self.vae.dtype
        transformer_dtype = self.transformer.dtype
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
video, mask, reference_image, reference_mask = self.preprocess_conditions(
video,
mask,
reference_image,
reference_mask,
batch_size,
height,
width,
num_frames,
reference_patch_ratio,
fg_thresh,
torch.float16,
device,
)
inactive_latent, mesh_latent, ref_latent, neg_ref_latent, mask, ref_mask = self.prepare_vace_latents(dilate_radius, video, mask, reference_image, reference_mask, device)
c = torch.cat([inactive_latent, mesh_latent, mask], dim=1)
c1 = torch.cat([ref_latent, ref_mask], dim=1)
c1_negative = torch.cat(
[neg_ref_latent, torch.zeros_like(ref_mask)],
dim=1
)
num_channels_latents = 16
noise = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float16,
device,
generator,
latents,
)
latents_cond = torch.cat([ref_latent, noise], dim=2)
latents_uncond = torch.cat([neg_ref_latent, noise], dim=2)
seq_len = math.ceil(
latents_cond.shape[2] *
latents_cond.shape[3] *
latents_cond.shape[4] / 4
)
seq_len_ref = math.ceil(
ref_latent.shape[2] *
ref_latent.shape[3] *
ref_latent.shape[4] / 4
)
context = self.empty_embedding.unsqueeze(0).expand(batch_size, -1, -1).to(device)
context_neg = self.negative_embedding.unsqueeze(0).expand(batch_size, -1, -1).to(device)
arg_c = {
"context": context,
"seq_len": seq_len,
"seq_len_ref": seq_len_ref
}
arg_c_null = {
"context": context_neg,
"seq_len": seq_len,
"seq_len_ref": seq_len_ref
}
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
self._current_timestep = t
timestep = t.expand(batch_size)
with torch.autocast(device_type="cuda", dtype=torch.float16):
noise_pred = self.transformer(
latents_cond,
t=timestep,
vace_context=c,
ref_context=c1,
vace_context_scale=conditioning_scale,
**arg_c,
)[0]
if self.do_classifier_free_guidance:
noise_pred_uncond = self.transformer(
latents_uncond,
t=timestep,
vace_context=c,
ref_context=c1_negative,
vace_context_scale=0,
**arg_c_null,
)[0]
noise_pred = (noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)).unsqueeze(0)
temp_x0 = self.scheduler.step(noise_pred[:, :, 1:],
t,
latents_cond[:, :, 1:],
return_dict=False)[0]
latents_cond = torch.cat([ref_latent, temp_x0], dim=2)
latents_uncond = torch.cat([neg_ref_latent, temp_x0], dim=2)
progress_bar.update()
self._current_timestep = None
        if output_type != "latent":
latents = temp_x0
latents = latents.to(vae_dtype)
video = self.vae.decode(latents)
video = self.video_processor.postprocess_video(video, output_type=output_type)
mesh = self.vae.decode(mesh_latent.to(vae_dtype))
mesh = self.video_processor.postprocess_video(mesh, output_type=output_type)
ref_img = reference_image.cpu().squeeze(0).squeeze(1).permute(1, 2, 0).numpy()
            ref_img = ((ref_img + 1.0) * 127.5).astype(np.uint8)
else:
video = temp_x0
mesh = mesh_latent
ref_img = ref_latent
self.maybe_free_model_hooks()
if not return_dict:
return (video, mesh, ref_img)
return RefacadePipelineOutput(frames=video, meshes=mesh, ref_img=ref_img)
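

# ---------------------------------------------------------------------------
# Usage sketch (an assumption for illustration, not part of the shipped code):
# constructing the VAE, transformer, texture remover and scheduler is handled
# by the surrounding repo, so only the call below reflects the signature
# defined in this file.
#
#   pipe = RefacadePipeline(
#       vae=vae,
#       scheduler=scheduler,
#       transformer=transformer,
#       texture_remover=texture_remover,
#   )
#   out = pipe(
#       video=video_frames,          # list of PIL.Image frames
#       mask=mask_frames,            # per-frame object masks (same length)
#       reference_image=ref_image,   # PIL.Image with the target appearance
#       reference_mask=ref_mask,     # optional PIL.Image mask of the reference
#       height=480, width=832, num_frames=81,
#       num_inference_steps=20, guidance_scale=1.5,
#       output_type="np",
#   )
#   frames = out.frames              # decoded video frames
# ---------------------------------------------------------------------------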