import os
import time
import random
import gradio as gr
import cv2
import numpy as np
from PIL import Image
import subprocess
import importlib


# Install the Wan2.1 package at startup if it is missing.
def ensure_wan():
    try:
        import wan  # noqa

        print("[setup] wan already installed.")
    except ImportError:
        cmd = "pip install --no-build-isolation 'wan@git+https://github.com/Wan-Video/Wan2.1'"
        env = dict(os.environ)
        print(f"[setup] Installing wan2.1: {cmd}")
        subprocess.run(cmd, shell=True, check=True, env=env)


# Rebuild flash-attn from source if the installed wheel is broken.
def ensure_flash_attn():
    try:
        import flash_attn  # noqa
        from flash_attn.flash_attn_interface import flash_attn_func  # type: ignore

        print("[setup] flash-attn seems OK.")
        return
    except Exception as e:
        print("[setup] flash-attn broken, will rebuild from source:", repr(e))
    cmd = (
        "pip uninstall -y flash-attn flash_attn || true && "
        "pip install flash-attn==2.7.2.post1 --no-build-isolation"
    )
    print(f"[setup] Rebuilding flash-attn: {cmd}")
    subprocess.run(cmd, shell=True, check=True)
    importlib.invalidate_caches()


ensure_flash_attn()
ensure_wan()

os.makedirs("./sam2/SAM2-Video-Predictor/checkpoints/", exist_ok=True)

from huggingface_hub import snapshot_download


# Download model checkpoints from the Hugging Face Hub.
def download_sam2():
    snapshot_download(
        repo_id="facebook/sam2-hiera-large",
        local_dir="./sam2/SAM2-Video-Predictor/checkpoints/",
    )
    print("Download sam2 completed")


def download_refacade():
    snapshot_download(
        repo_id="fishze/Refacade",
        local_dir="./models/",
    )
    print("Download refacade completed")


download_sam2()
download_refacade()

import torch
import torch.nn.functional as F
from decord import VideoReader, cpu
from moviepy.editor import ImageSequenceClip
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
import spaces
from pipeline import RefacadePipeline
from vace.models.wan.modules.model_mm import VaceMMModel
from vace.models.wan.modules.model_tr import VaceWanModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from wan.text2video import FlowUniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image, load_video
from vae import WanVAE

# Colors used to paint segmentation overlays.
COLOR_PALETTE = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (255, 128, 0),
    (128, 0, 255),
    (0, 128, 255),
    (128, 255, 0),
]

# Working resolution for interactive segmentation and the devices used by each model.
video_length = 81
W = 1024
H = W
device = "cuda"
sam_device = "cpu"


def get_pipe_image_and_video_predictor():
    """Build the Refacade pipeline plus the SAM2 image and video predictors."""
    vae = WanVAE(
        vae_pth="./models/vae/Wan2.1_VAE.pth",
        dtype=torch.float16,
    )
    pipe_device = "cuda"

    texture_remover = VaceWanModel.from_config(
        "./models/texture_remover/texture_remover.json"
    )
    ckpt = torch.load(
        "./models/texture_remover/texture_remover.pth",
        map_location="cpu",
    )
    texture_remover.load_state_dict(ckpt)
    texture_remover = texture_remover.to(dtype=torch.float16, device=pipe_device)

    model = VaceMMModel.from_config(
        "./models/refacade/refacade.json"
    )
    ckpt = torch.load(
        "./models/refacade/refacade.pth",
        map_location="cpu",
    )
    model.load_state_dict(ckpt)
    model = model.to(dtype=torch.float16, device=pipe_device)

    sample_scheduler = FlowUniPCMultistepScheduler(
        num_train_timesteps=1000,
        shift=1,
    )

    pipe = RefacadePipeline(
        vae=vae,
        transformer=model,
        texture_remover=texture_remover,
        scheduler=sample_scheduler,
    )
    pipe.to(pipe_device)

    sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
    config = "sam2_hiera_l.yaml"
    video_predictor = build_sam2_video_predictor(config, sam2_checkpoint, device=sam_device)
    model_sam = build_sam2(config, sam2_checkpoint, device=sam_device)
    model_sam.image_size = 1024
    image_predictor = SAM2ImagePredictor(sam_model=model_sam)

    return pipe, image_predictor, video_predictor
def get_video_info(video_path, video_state):
    """Load and resize the first frame of the uploaded video and reset click state."""
    video_state["input_points"] = []
    video_state["scaled_points"] = []
    video_state["input_labels"] = []
    video_state["frame_idx"] = 0

    vr = VideoReader(video_path, ctx=cpu(0))
    first_frame = vr[0].asnumpy()
    del vr

    if first_frame.shape[0] > first_frame.shape[1]:
        W_ = W
        H_ = int(W_ * first_frame.shape[0] / first_frame.shape[1])
    else:
        H_ = H
        W_ = int(H_ * first_frame.shape[1] / first_frame.shape[0])
    first_frame = cv2.resize(first_frame, (W_, H_))

    video_state["origin_images"] = np.expand_dims(first_frame, axis=0)
    video_state["inference_state"] = None
    video_state["video_path"] = video_path
    video_state["masks"] = None
    video_state["painted_images"] = None

    image = Image.fromarray(first_frame)
    return image


def segment_frame(evt: gr.SelectData, label, video_state):
    """Segment the first frame with SAM2 from the accumulated positive/negative clicks."""
    if video_state["origin_images"] is None:
        return None

    x, y = evt.index
    new_point = [x, y]
    label_value = 1 if label == "Positive" else 0
    video_state["input_points"].append(new_point)
    video_state["input_labels"].append(label_value)

    height, width = video_state["origin_images"][0].shape[0:2]
    scaled_points = []
    for pt in video_state["input_points"]:
        sx = pt[0] / width
        sy = pt[1] / height
        scaled_points.append([sx, sy])
    video_state["scaled_points"] = scaled_points

    image_predictor.set_image(video_state["origin_images"][0])
    mask, _, _ = image_predictor.predict(
        point_coords=video_state["scaled_points"],
        point_labels=video_state["input_labels"],
        multimask_output=False,
        normalize_coords=False,
    )
    mask = np.squeeze(mask)
    mask = cv2.resize(mask, (width, height))
    mask = mask[:, :, None]

    color = (
        np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32)
        / 255.0
    )
    color = color[None, None, :]
    org_image = video_state["origin_images"][0].astype(np.float32) / 255.0
    painted_image = (1 - mask * 0.5) * org_image + mask * 0.5 * color
    painted_image = np.uint8(np.clip(painted_image * 255, 0, 255))

    video_state["painted_images"] = np.expand_dims(painted_image, axis=0)
    video_state["masks"] = np.expand_dims(mask[:, :, 0], axis=0)

    # Draw the click points: red for positive, blue for negative.
    for i in range(len(video_state["input_points"])):
        point = video_state["input_points"][i]
        if video_state["input_labels"][i] == 0:
            cv2.circle(painted_image, point, radius=3, color=(0, 0, 255), thickness=-1)
        else:
            cv2.circle(painted_image, point, radius=3, color=(255, 0, 0), thickness=-1)

    return Image.fromarray(painted_image)


def clear_clicks(video_state):
    """Reset all clicks and masks for the source video."""
    video_state["input_points"] = []
    video_state["input_labels"] = []
    video_state["scaled_points"] = []
    video_state["inference_state"] = None
    video_state["masks"] = None
    video_state["painted_images"] = None
    return (
        Image.fromarray(video_state["origin_images"][0])
        if video_state["origin_images"] is not None
        else None
    )


def set_ref_image(ref_img, ref_state):
    """Store the uploaded reference image and reset its click state."""
    if ref_img is None:
        return None
    if isinstance(ref_img, Image.Image):
        img_np = np.array(ref_img)
    else:
        img_np = ref_img
    ref_state["origin_image"] = img_np
    ref_state["input_points"] = []
    ref_state["input_labels"] = []
    ref_state["scaled_points"] = []
    ref_state["mask"] = None
    return Image.fromarray(img_np)


def segment_ref_frame(evt: gr.SelectData, label, ref_state):
    """Segment the reference image with SAM2 from the accumulated clicks."""
    if ref_state["origin_image"] is None:
        return None

    x, y = evt.index
    new_point = [x, y]
    label_value = 1 if label == "Positive" else 0
    ref_state["input_points"].append(new_point)
    ref_state["input_labels"].append(label_value)

    img = ref_state["origin_image"]
    h, w = img.shape[:2]
    scaled_points = []
    for pt in ref_state["input_points"]:
        sx = pt[0] / w
        sy = pt[1] / h
        scaled_points.append([sx, sy])
ref_state["scaled_points"] = scaled_points image_predictor.set_image(img) mask, _, _ = image_predictor.predict( point_coords=scaled_points, point_labels=ref_state["input_labels"], multimask_output=False, normalize_coords=False, ) mask = np.squeeze(mask) mask = cv2.resize(mask, (w, h)) mask = mask[:, :, None] ref_state["mask"] = mask[:, :, 0] color = ( np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0 ) color = color[None, None, :] org_image = img.astype(np.float32) / 255.0 painted = (1 - mask * 0.5) * org_image + mask * 0.5 * color painted = np.uint8(np.clip(painted * 255, 0, 255)) for i in range(len(ref_state["input_points"])): point = ref_state["input_points"][i] if ref_state["input_labels"][i] == 0: cv2.circle(painted, point, radius=3, color=(0, 0, 255), thickness=-1) else: cv2.circle(painted, point, radius=3, color=(255, 0, 0), thickness=-1) return Image.fromarray(painted) def clear_ref_clicks(ref_state): ref_state["input_points"] = [] ref_state["input_labels"] = [] ref_state["scaled_points"] = [] ref_state["mask"] = None if ref_state["origin_image"] is None: return None return Image.fromarray(ref_state["origin_image"]) @spaces.GPU(duration=40) @torch.no_grad() def track_video(n_frames, video_state): input_points = video_state["input_points"] input_labels = video_state["input_labels"] frame_idx = video_state["frame_idx"] obj_id = video_state["obj_id"] scaled_points = video_state["scaled_points"] vr = VideoReader(video_state["video_path"], ctx=cpu(0)) height, width = vr[0].shape[0:2] images = [vr[i].asnumpy() for i in range(min(len(vr), n_frames))] del vr if images[0].shape[0] > images[0].shape[1]: W_ = W H_ = int(W_ * images[0].shape[0] / images[0].shape[1]) else: H_ = H W_ = int(H_ * images[0].shape[1] / images[0].shape[0]) images = [cv2.resize(img, (W_, H_)) for img in images] video_state["origin_images"] = images images_np = np.array(images) sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt" config = "sam2_hiera_l.yaml" with torch.autocast(device_type="cuda", dtype=torch.bfloat16): video_predictor_local = build_sam2_video_predictor( config, sam2_checkpoint, device="cuda" ) inference_state = video_predictor_local.init_state( images=images_np / 255, device="cuda" ) if len(torch.from_numpy(video_state["masks"][0]).shape) == 3: mask0 = torch.from_numpy(video_state["masks"][0])[:, :, 0] else: mask0 = torch.from_numpy(video_state["masks"][0]) video_predictor_local.add_new_mask( inference_state=inference_state, frame_idx=0, obj_id=obj_id, mask=mask0, ) output_frames = [] mask_frames = [] color = ( np.array( COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32, ) / 255.0 ) color = color[None, None, :] for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video( inference_state ): frame = images_np[out_frame_idx].astype(np.float32) / 255.0 mask = np.zeros((H, W, 3), dtype=np.float32) for i, logit in enumerate(out_mask_logits): out_mask = logit.cpu().squeeze().detach().numpy() out_mask = (out_mask[:, :, None] > 0).astype(np.float32) mask += out_mask mask = np.clip(mask, 0, 1) mask = cv2.resize(mask, (W_, H_)) mask_frames.append(mask) painted = (1 - mask * 0.5) * frame + mask * 0.5 * color painted = np.uint8(np.clip(painted * 255, 0, 255)) output_frames.append(painted) video_file = f"/tmp/{time.time()}-{random.random()}-tracked_output.mp4" clip = ImageSequenceClip(output_frames, fps=15) clip.write_videofile( video_file, codec="libx264", audio=False, verbose=False, logger=None ) 
print("Tracking done") print("Tracking done, file:", video_file) try: exists = os.path.exists(video_file) size = os.path.getsize(video_file) if exists else -1 print("File exists?", exists, "size:", size) except Exception as e: print("Error checking video file:", repr(e)) return video_file, images, mask_frames @spaces.GPU(duration=100) def inference_and_return_video( dilate_radius, num_inference_steps, guidance_scale, ref_patch_ratio, fg_threshold, seed, video_frames, mask_frames, ref_state, ): if video_frames is None or mask_frames is None: print("No video frames or video masks.") return None, None, None if ref_state["origin_image"] is None or ref_state["mask"] is None: print("Reference image or reference mask missing.") return None, None, None images = video_frames masks = mask_frames video_frames_pil = [] mask_frames_pil = [] for img, msk in zip(images, masks): if not isinstance(img, np.ndarray): img = np.asarray(img) img_pil = Image.fromarray(img.astype(np.uint8)) if isinstance(msk, np.ndarray): if msk.ndim == 3: m2 = msk[..., 0] else: m2 = msk else: m2 = np.asarray(msk) m2 = (m2 > 0.5).astype(np.uint8) * 255 msk_pil = Image.fromarray(m2, mode="L") video_frames_pil.append(img_pil) mask_frames_pil.append(msk_pil) num_frames = len(video_frames_pil) h0, w0 = images[0].shape[:2] if h0 > w0: height = 832 width = 480 else: height = 480 width = 832 ref_img_np = ref_state["origin_image"] ref_mask_np = ref_state["mask"] ref_img_pil = Image.fromarray(ref_img_np.astype(np.uint8)) ref_mask_bin = (ref_mask_np > 0.5).astype(np.uint8) * 255 ref_mask_pil = Image.fromarray(ref_mask_bin, mode="L") pipe.to("cuda") with torch.no_grad(): retex_frames, mesh_frames, ref_img_out = pipe( video=video_frames_pil, mask=mask_frames_pil, reference_image=ref_img_pil, reference_mask=ref_mask_pil, conditioning_scale=1.0, height=height, width=width, num_frames=num_frames, dilate_radius=int(dilate_radius), num_inference_steps=int(num_inference_steps), guidance_scale=float(guidance_scale), reference_patch_ratio=float(ref_patch_ratio), fg_thresh=float(fg_threshold), generator=torch.Generator(device="cuda").manual_seed(seed), return_dict=False, ) retex_frames_uint8 = (np.clip(retex_frames[0], 0.0, 1.0) * 255).astype(np.uint8) mesh_frames_uint8 = (np.clip(mesh_frames[0], 0.0, 1.0) * 255).astype(np.uint8) retex_output_frames = [frame for frame in retex_frames_uint8] mesh_output_frames = [frame for frame in mesh_frames_uint8] if ref_img_out.dtype != np.uint8: ref_img_out = (np.clip(ref_img_out, 0.0, 1.0) * 255).astype(np.uint8) retex_video_file = f"/tmp/{time.time()}-{random.random()}-refacade_output.mp4" retex_clip = ImageSequenceClip(retex_output_frames, fps=16) retex_clip.write_videofile( retex_video_file, codec="libx264", audio=False, verbose=False, logger=None ) mesh_video_file = f"/tmp/{time.time()}-{random.random()}-mesh_output.mp4" mesh_clip = ImageSequenceClip(mesh_output_frames, fps=16) mesh_clip.write_videofile( mesh_video_file, codec="libx264", audio=False, verbose=False, logger=None ) ref_image_to_show = ref_img_out return retex_video_file, mesh_video_file, ref_image_to_show text = """
Refaçade Video Retexture Demo
Huggingface Model | Github | arXiv | Demo Page
Youze Huang*, Penghui Ruan*, Bojia Zi*, Xianbiao Qi, Jianan Wang, Rong Xiao
* Equal contribution     Corresponding author
""" pipe, image_predictor, video_predictor = get_pipe_image_and_video_predictor() css = """ #my-btn { width: 60% !important; margin: 0 auto; } #my-video1 { width: 60% !important; height: 35% !important; margin: 0 auto; } #my-video { width: 60% !important; height: 35% !important; margin: 0 auto; } #my-md { margin: 0 auto; } #my-btn2 { width: 60% !important; margin: 0 auto; } #my-btn2 button { width: 120px !important; max-width: 120px !important; min-width: 120px !important; height: 70px !important; max-height: 70px !important; min-height: 70px !important; margin: 8px !important; border-radius: 8px !important; overflow: hidden !important; white-space: normal !important; } #my-btn3 { width: 60% !important; margin: 0 auto; } #ref_title { text-align: center; } #ref-image { width: 60% !important; height: 35% !important; margin: 0 auto; } #ref-mask { width: 60% !important; height: 35% !important; margin: 0 auto; } #mesh-row { width: 60% !important; margin: 0 auto; } """ with gr.Blocks() as demo: gr.HTML(f"") video_state = gr.State( { "origin_images": None, "inference_state": None, "masks": None, "painted_images": None, "video_path": None, "input_points": [], "scaled_points": [], "input_labels": [], "frame_idx": 0, "obj_id": 1, } ) ref_state = gr.State( { "origin_image": None, "input_points": [], "input_labels": [], "scaled_points": [], "mask": None, } ) video_frames_state = gr.State(None) mask_frames_state = gr.State(None) gr.Markdown(f"
{text}
") with gr.Column(): gr.Markdown("Step1: Upload a Source Video", elem_id="ref_title") video_input = gr.Video(label="Upload Video", elem_id="my-video1") gr.Examples( examples=[ ["./examples/1.mp4"], ["./examples/2.mp4"], ["./examples/3.mp4"], ["./examples/4.mp4"], ["./examples/5.mp4"], ["./examples/6.mp4"], ], inputs=[video_input], label="You can upload or choose a source video below to retexture.", elem_id="my-btn2" ) gr.Markdown("Step2: Extract the First Frame & Click for Segmentation", elem_id="ref_title") get_info_btn = gr.Button("Extract First Frame", elem_id="my-btn") image_output = gr.Image( label="First Frame Segmentation", interactive=True, elem_id="my-video", ) with gr.Row(elem_id="my-btn"): point_prompt = gr.Radio( ["Positive", "Negative"], label="Click Type", value="Positive" ) clear_btn = gr.Button("Clear All Clicks") gr.Markdown("Step3: Track to Get Video Mask", elem_id="ref_title") with gr.Row(elem_id="my-btn"): n_frames_slider = gr.Slider( minimum=1, maximum=81, value=33, step=1, label="Tracking Frames (4N+1)" ) track_btn = gr.Button("Tracking") video_output = gr.Video(label="Tracking Result", elem_id="my-video") gr.Markdown("Step4: Upload a Reference Image & Click for Reference Segmentation", elem_id="ref_title") ref_image_input = gr.Image( label="Upload Reference Image", elem_id="ref-image", interactive=True ) gr.Examples( examples=[ ["./examples/reference_image/1.png"], ["./examples/reference_image/2.png"], ["./examples/reference_image/3.png"], ["./examples/reference_image/4.png"], ["./examples/reference_image/5.png"], ["./examples/reference_image/6.png"], ["./examples/reference_image/7.png"], ["./examples/reference_image/8.png"], ["./examples/reference_image/9.png"], ], inputs=[ref_image_input], label="You can upload or choose a reference image below to retexture.", elem_id="my-btn3" ) ref_image_display = gr.Image( label="Reference Mask Segmentation", elem_id="ref-mask", interactive=True, ) with gr.Row(elem_id="my-btn"): ref_point_prompt = gr.Radio( ["Positive", "Negative"], label="Ref Click Type", value="Positive" ) ref_clear_btn = gr.Button("Clear Ref Clicks") gr.Markdown("Step5: Retexture", elem_id="ref_title") with gr.Column(elem_id="my-btn"): dilate_radius_slider = gr.Slider( minimum=1, maximum=10, value=3, step=1, label="Mask Dilation Radius", ) inference_steps_slider = gr.Slider( minimum=10, maximum=30, value=20, step=1, label="Num Inference Steps", ) guidance_slider = gr.Slider( minimum=1.0, maximum=3.0, value=1.5, step=0.1, label="Guidance Scale", ) ref_patch_slider = gr.Slider( minimum=0.05, maximum=1.0, value=0.1, step=0.05, label="Reference Patch Ratio", ) fg_threshold_slider = gr.Slider( minimum=0.7, maximum=1.0, value=1.0, step=0.01, label="Jigsaw Patches' Foreground Coverage Threshold", ) seed_slider = gr.Slider( minimum=0, maximum=2147483647, value=42, step=1, label="Seed", ) remove_btn = gr.Button("Retexture", elem_id="my-btn") with gr.Row(elem_id="mesh-row"): mesh_video = gr.Video(label="Untextured Object") ref_image_final = gr.Image( label="Jigsawed Reference Image", interactive=False, ) remove_video = gr.Video(label="Retexture Results", elem_id="my-video") remove_btn.click( inference_and_return_video, inputs=[ dilate_radius_slider, inference_steps_slider, guidance_slider, ref_patch_slider, fg_threshold_slider, seed_slider, video_frames_state, mask_frames_state, ref_state, ], outputs=[remove_video, mesh_video, ref_image_final], ) get_info_btn.click( get_video_info, inputs=[video_input, video_state], outputs=image_output, ) image_output.select( 
        fn=segment_frame,
        inputs=[point_prompt, video_state],
        outputs=image_output,
    )

    clear_btn.click(clear_clicks, inputs=video_state, outputs=image_output)

    track_btn.click(
        track_video,
        inputs=[n_frames_slider, video_state],
        outputs=[video_output, video_frames_state, mask_frames_state],
    )

    ref_image_input.change(
        set_ref_image,
        inputs=[ref_image_input, ref_state],
        outputs=ref_image_display,
    )

    ref_image_display.select(
        fn=segment_ref_frame,
        inputs=[ref_point_prompt, ref_state],
        outputs=ref_image_display,
    )

    ref_clear_btn.click(
        clear_ref_clicks, inputs=ref_state, outputs=ref_image_display
    )

demo.launch()