Ryan-PR committed on
Commit 67a974c · verified · 1 Parent(s): 491fe1d

Update app.py

Files changed (1)
  1. app.py +42 -72
app.py CHANGED
@@ -1,19 +1,13 @@
  import os
  import time
  import random
- import subprocess
- import importlib

  import gradio as gr
  import cv2
  import numpy as np
  from PIL import Image
-
- os.makedirs("./sam2/SAM2-Video-Predictor/checkpoints/", exist_ok=True)
- os.makedirs("./models/", exist_ok=True)
-
- from huggingface_hub import snapshot_download
-
+ import subprocess
+ import importlib

  def ensure_wan():
  try:
@@ -24,9 +18,6 @@ def ensure_wan():
  env = dict(os.environ)
  print(f"[setup] Installing wan2.1: {cmd}")
  subprocess.run(cmd, shell=True, check=True, env=env)
- importlib.invalidate_caches()
- import wan # noqa
- print("[setup] wan installed.")

  def ensure_flash_attn():
  try:
@@ -48,6 +39,9 @@ def ensure_flash_attn():
  ensure_flash_attn()
  ensure_wan()

+ os.makedirs("./sam2/SAM2-Video-Predictor/checkpoints/", exist_ok=True)
+
+ from huggingface_hub import snapshot_download

  def download_sam2():
  snapshot_download(
@@ -55,7 +49,7 @@
  local_dir="./sam2/SAM2-Video-Predictor/checkpoints/",
  )
  print("Download sam2 completed")
-
+
  def download_refacade():
  snapshot_download(
  repo_id="fishze/Refacade",
@@ -63,28 +57,25 @@
  )
  print("Download refacade completed")

+
  download_sam2()
  download_refacade()

-
  import torch
  import torch.nn.functional as F
  from decord import VideoReader, cpu
  from moviepy.editor import ImageSequenceClip
-
  from sam2.build_sam import build_sam2, build_sam2_video_predictor
  from sam2.sam2_image_predictor import SAM2ImagePredictor
-
  import spaces
-
+ from pipeline import RefacadePipeline
  from vace.models.wan.modules.model_mm import VaceMMModel
  from vace.models.wan.modules.model_tr import VaceWanModel
- from wan.text2video import FlowUniPCMultistepScheduler
  from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+ from wan.text2video import FlowUniPCMultistepScheduler
  from diffusers.utils import export_to_video, load_image, load_video
  from vae import WanVAE

-
  COLOR_PALETTE = [
  (255, 0, 0),
  (0, 255, 0),
@@ -101,10 +92,8 @@ COLOR_PALETTE = [
  video_length = 201
  W = 1024
  H = W
-
- DEVICE_SAM = "cpu"
- DEVICE_PIPE = "cuda"
-
+ device = "cuda"
+ sam_device = "cpu"

  def get_pipe_image_and_video_predictor():
  vae = WanVAE(
@@ -112,61 +101,50 @@ def get_pipe_image_and_video_predictor():
  dtype=torch.float16,
  )

+ pipe_device = "cuda"
+
  texture_remover = VaceWanModel.from_config(
  "./models/texture_remover/texture_remover.json"
  )
- ckpt_tr = torch.load(
+ ckpt = torch.load(
  "./models/texture_remover/texture_remover.pth",
  map_location="cpu",
  )
- texture_remover.load_state_dict(ckpt_tr)
- texture_remover = texture_remover.to(dtype=torch.float16, device=DEVICE_PIPE)
+ texture_remover.load_state_dict(ckpt)
+ texture_remover = texture_remover.to(dtype=torch.float16, device=pipe_device)

  model = VaceMMModel.from_config(
  "./models/refacade/refacade.json"
  )
- ckpt_ref = torch.load(
+ ckpt = torch.load(
  "./models/refacade/refacade.pth",
  map_location="cpu",
  )
- model.load_state_dict(ckpt_ref)
- model = model.to(dtype=torch.float16, device=DEVICE_PIPE)
+ model.load_state_dict(ckpt)
+ model = model.to(dtype=torch.float16, device=pipe_device)

  sample_scheduler = FlowUniPCMultistepScheduler(
  num_train_timesteps=1000,
  shift=1,
  )
-
- from pipeline import RefacadePipeline
  pipe = RefacadePipeline(
  vae=vae,
  transformer=model,
  texture_remover=texture_remover,
  scheduler=sample_scheduler,
  )
- pipe.to(DEVICE_PIPE)
+ pipe.to(pipe_device)

  sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
  config = "sam2_hiera_l.yaml"

- video_predictor = build_sam2_video_predictor(
- config,
- sam2_checkpoint,
- device="cuda",
- )
-
- model_sam = build_sam2(
- config,
- sam2_checkpoint,
- device=DEVICE_SAM,
- )
+ video_predictor = build_sam2_video_predictor(config, sam2_checkpoint, device=sam_device)
+ model_sam = build_sam2(config, sam2_checkpoint, device=sam_device)
  model_sam.image_size = 1024
  image_predictor = SAM2ImagePredictor(sam_model=model_sam)

  return pipe, image_predictor, video_predictor

- pipe, image_predictor, video_predictor = get_pipe_image_and_video_predictor()
-

  def get_video_info(video_path, video_state):
  video_state["input_points"] = []
@@ -194,27 +172,26 @@ def get_video_info(video_path, video_state):
  image = Image.fromarray(first_frame)
  return image

+
  def segment_frame(evt: gr.SelectData, label, video_state):
  if video_state["origin_images"] is None:
  return None
-
  x, y = evt.index
  new_point = [x, y]
  label_value = 1 if label == "Positive" else 0

  video_state["input_points"].append(new_point)
  video_state["input_labels"].append(label_value)
-
  height, width = video_state["origin_images"][0].shape[0:2]
  scaled_points = []
  for pt in video_state["input_points"]:
  sx = pt[0] / width
  sy = pt[1] / height
  scaled_points.append([sx, sy])
+
  video_state["scaled_points"] = scaled_points

- img0 = video_state["origin_images"][0]
- image_predictor.set_image(img0)
+ image_predictor.set_image(video_state["origin_images"][0])
  mask, _, _ = image_predictor.predict(
  point_coords=video_state["scaled_points"],
  point_labels=video_state["input_labels"],
@@ -231,10 +208,9 @@ def segment_frame(evt: gr.SelectData, label, video_state):
  / 255.0
  )
  color = color[None, None, :]
- org_image = img0.astype(np.float32) / 255.0
+ org_image = video_state["origin_images"][0].astype(np.float32) / 255.0
  painted_image = (1 - mask * 0.5) * org_image + mask * 0.5 * color
  painted_image = np.uint8(np.clip(painted_image * 255, 0, 255))
-
  video_state["painted_images"] = np.expand_dims(painted_image, axis=0)
  video_state["masks"] = np.expand_dims(mask[:, :, 0], axis=0)

@@ -247,6 +223,7 @@ def segment_frame(evt: gr.SelectData, label, video_state):

  return Image.fromarray(painted_image)

+
  def clear_clicks(video_state):
  video_state["input_points"] = []
  video_state["input_labels"] = []
@@ -260,6 +237,7 @@ def clear_clicks(video_state):
  else None
  )

+
  def set_ref_image(ref_img, ref_state):
  if ref_img is None:
  return None
@@ -277,6 +255,7 @@ def set_ref_image(ref_img, ref_state):

  return Image.fromarray(img_np)

+
  def segment_ref_frame(evt: gr.SelectData, label, ref_state):
  if ref_state["origin_image"] is None:
  return None
@@ -320,7 +299,7 @@ def segment_ref_frame(evt: gr.SelectData, label, ref_state):
  painted = (1 - mask * 0.5) * org_image + mask * 0.5 * color
  painted = np.uint8(np.clip(painted * 255, 0, 255))

- for i in range(len(ref_state["input_points"])):
+ for i in range(len(ref_state["input_points"])):
  point = ref_state["input_points"][i]
  if ref_state["input_labels"][i] == 0:
  cv2.circle(painted, point, radius=3, color=(0, 0, 255), thickness=-1)
@@ -329,6 +308,7 @@ def segment_ref_frame(evt: gr.SelectData, label, ref_state):

  return Image.fromarray(painted)

+
  def clear_ref_clicks(ref_state):
  ref_state["input_points"] = []
  ref_state["input_labels"] = []
@@ -366,11 +346,11 @@ def track_video(n_frames, video_state):
  sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
  config = "sam2_hiera_l.yaml"
  video_predictor_local = build_sam2_video_predictor(
- config, sam2_checkpoint, device="cuda"
+ config, sam2_checkpoint, device=sam_device
  )

  inference_state = video_predictor_local.init_state(
- images=images / 255, device="cuda"
+ images=images / 255, device=sam_device
  )

  if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
@@ -417,6 +397,7 @@
  print("Tracking done")
  return video_file, video_state

+
  @spaces.GPU(duration=50)
  def inference_and_return_video(
  dilate_radius,
@@ -477,7 +458,7 @@
  ref_mask_bin = (ref_mask_np > 0.5).astype(np.uint8) * 255
  ref_mask_pil = Image.fromarray(ref_mask_bin, mode="L")

- pipe.to(DEVICE_PIPE)
+ pipe.to("cuda")
  with torch.no_grad():
  retex_frames, mesh_frames, ref_img_out = pipe(
  video=video_frames,
@@ -493,7 +474,7 @@
  guidance_scale=float(guidance_scale),
  reference_patch_ratio=float(ref_patch_ratio),
  fg_thresh=float(fg_threshold),
- generator=torch.Generator(device=DEVICE_PIPE).manual_seed(seed),
+ generator=torch.Generator(device="cuda").manual_seed(seed),
  return_dict=False,
  )

@@ -522,7 +503,6 @@

  return retex_video_file, mesh_video_file, ref_image_to_show

- # ================== Gradio UI ==================

  text = """
  <div style='text-align:center; font-size:32px; font-family: Arial, Helvetica, sans-serif;'>
@@ -533,6 +513,8 @@ text = """
  </div>
  """

+ pipe, image_predictor, video_predictor = get_pipe_image_and_video_predictor()
+
  with gr.Blocks() as demo:
  video_state = gr.State(
  {
@@ -564,7 +546,7 @@ with gr.Blocks() as demo:
  with gr.Column():
  video_input = gr.Video(label="Upload Video", elem_id="my-video1")
  get_info_btn = gr.Button("Extract First Frame", elem_id="my-btn")
-
+
  gr.Examples(
  examples=[
  ["./examples/1.mp4"],
@@ -576,7 +558,7 @@
  ],
  inputs=[video_input],
  label="You can upload or choose a source video below to retexture.",
- elem_id="my-btn2",
+ elem_id="my-btn2"
  )

  image_output = gr.Image(
@@ -623,18 +605,6 @@ with gr.Blocks() as demo:
  width: 60% !important;
  margin: 0 auto;
  }
- #my-btn3 button {
- width: 120px !important;
- max-width: 120px !important;
- min-width: 120px !important;
- height: 70px !important;
- max-height: 70px !important;
- min-height: 70px !important;
- margin: 8px !important;
- border-radius: 8px !important;
- overflow: hidden !important;
- white-space: normal !important;
- }
  #ref_title {
  text-align: center;
  }
@@ -686,7 +656,7 @@ with gr.Blocks() as demo:
  ],
  inputs=[ref_image_input],
  label="You can upload or choose a reference image below to retexture.",
- elem_id="my-btn3",
+ elem_id="my-btn3"
  )
  ref_image_display = gr.Image(
  label="Reference Mask Segmentation",
@@ -742,7 +712,7 @@ with gr.Blocks() as demo:
  maximum=2147483647,
  value=42,
  step=1,
- label="Seed",
+ label="Seed",
  )

  remove_btn = gr.Button("Retexture", elem_id="my-btn")