r3gm committed
Commit 6ab11e0 · verified · 1 Parent(s): 330958f

Update app.py

Files changed (1):
  1. app.py +31 -16
app.py CHANGED
@@ -9,6 +9,7 @@ import tempfile
 import warnings
 import time
 import gc
+import uuid
 
 import cv2
 import numpy as np
@@ -53,6 +54,7 @@ function() {
 }
 """
 
+
 def extract_frame(video_path, timestamp):
     # Safety check: if no video is present
     if not video_path:
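The body of extract_frame is unchanged by this commit and mostly elided from the hunk above. For context, here is a minimal sketch of timestamp-based frame grabbing with OpenCV (cv2 is already imported in app.py); the helper name and exact logic are illustrative, not the app's actual implementation:

import cv2

def grab_frame_at(video_path, timestamp):
    # Illustrative only: seek to `timestamp` seconds and decode one frame.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None
    cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000.0)
    ok, frame_bgr = cap.read()
    cap.release()
    if not ok:
        return None
    # OpenCV decodes to BGR; convert to RGB for display in Gradio.
    return cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)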
@@ -89,6 +91,11 @@ def extract_frame(video_path, timestamp):
 # --- END FRAME EXTRACTION LOGIC ---
 
 
+def clear_vram():
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
 # RIFE
 if not os.path.exists("RIFEv4.26_0921.zip"):
     print("Downloading RIFE Model...")
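The new clear_vram() helper pairs Python garbage collection with emptying the CUDA allocator cache, and the commit calls it just before the Wan pipeline runs (see the run_inference hunk below). A minimal sketch of the same pattern, with an added CUDA-availability guard that is not in the commit:

import gc
import torch

def clear_vram():
    # Collect unreferenced Python objects first so their CUDA tensors become freeable,
    # then return cached blocks to the allocator.
    gc.collect()
    if torch.cuda.is_available():  # guard added here for CPU-only runs; app.py omits it
        torch.cuda.empty_cache()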
@@ -111,6 +118,7 @@ rife_model.device()
 if torch.cuda.is_available():
     rife_model.flownet = rife_model.flownet.half()
 
+
 @torch.no_grad()
 def interpolate_bits(frames_np, multiplier=2, scale=1.0):
     """
@@ -181,7 +189,7 @@ def interpolate_bits(frames_np, multiplier=2, scale=1.0):
         return [*first_half, *second_half]
 
     output_frames = []
-
+
     # Process Frames
     # Load first frame into GPU
     I1 = to_tensor(frames_np[0])
@@ -210,6 +218,7 @@ def interpolate_bits(frames_np, multiplier=2, scale=1.0):
 
     return output_frames
 
+
 # WAN
 
 MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
@@ -257,6 +266,9 @@ quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
 aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
 aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
 
+# pipe.vae.enable_slicing()
+# pipe.vae.enable_tiling()
+
 default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
 default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
 
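The VAE slicing/tiling calls are added commented out, so this commit does not change decode behavior. enable_slicing() and enable_tiling() are the usual diffusers VAE memory savers (decode one batch item at a time / decode in spatial tiles), trading speed for lower peak VRAM. A hedged sketch of toggling them behind a hypothetical LOW_VRAM environment flag, assuming the Wan VAE exposes both methods:

import os

if os.getenv("LOW_VRAM") == "1":
    # Hypothetical toggle; both calls are left commented out in this commit.
    pipe.vae.enable_slicing()   # decode the batch one item at a time
    pipe.vae.enable_tiling()    # decode each frame in overlapping spatial tiles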
@@ -339,9 +351,7 @@ def get_inference_duration(
     gen_time = int(steps) * step_duration
     print(gen_time)
     if guidance_scale > 1:
-        gen_time = gen_time * 1.55
-    if guidance_scale_2 > 1:
-        gen_time = gen_time * 1.55
+        gen_time = gen_time * 1.8
 
     if frame_multiplier > 1:
         total_out_frames = (num_frames * frame_multiplier) - num_frames
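The estimator change drops the second, stacked 1.55x multiplier for the low-noise guidance scale and applies a single 1.8x whenever the high-noise guidance scale is above 1. A self-contained sketch of the updated estimate; step_duration and the per-interpolated-frame cost are placeholder values here, since app.py computes its own:

def estimate_duration(steps, guidance_scale, num_frames, frame_multiplier,
                      step_duration=9.0, interp_frame_cost=0.05):
    # Placeholder constants: app.py derives step_duration itself.
    gen_time = int(steps) * step_duration        # base denoising cost
    if guidance_scale > 1:
        gen_time *= 1.8                          # CFG roughly doubles per-step work
    if frame_multiplier > 1:
        # Frames RIFE has to synthesize on top of the generated ones.
        new_frames = (num_frames * frame_multiplier) - num_frames
        gen_time += new_frames * interp_frame_cost
    return gen_time

# e.g. 6 steps at ~9 s/step is roughly 54 s without CFG, roughly 97 s with guidance_scale > 1.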
@@ -380,7 +390,10 @@ def run_inference(
     config['flow_shift'] = flow_shift
     pipe.scheduler = scheduler_class.from_config(config)
 
-    print(f"Generating {num_frames} frames with Wan...")
+    clear_vram()
+
+    task_name = str(uuid.uuid4())[:8]
+    print(f"Generating {num_frames} frames, task: {task_name}, {duration_seconds}, {resized_image.size}")
     start = time.time()
     result = pipe(
         image=resized_image,
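task_name is just the first 8 hex characters of a UUID4, enough to correlate the "Generating ..." line with the later "GPU complete: ..." line when several requests share one log stream:

import uuid

task_name = str(uuid.uuid4())[:8]  # short, log-friendly id, e.g. "3f9a1c2b"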
@@ -398,7 +411,7 @@
     )
     print("gen time passed:", time.time() - start)
 
-    raw_frames_np = result.frames[0]  # Returns (T, H, W, C) float32
+    raw_frames_np = result.frames[0]  # Returns (T, H, W, C) float32
     pipe.scheduler = original_scheduler
 
     if frame_multiplier > 1:
@@ -417,8 +430,9 @@ def run_inference(
     start = time.time()
     export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
     print(f"Export time passed, {final_fps} FPS:", time.time() - start)
-
-    return video_path
+
+    return video_path, task_name
+
 
 def generate_video(
     input_image,
@@ -494,7 +508,7 @@ def generate_video(
     if last_image:
         processed_last_image = resize_and_crop_to_match(last_image, resized_image)
 
-    video_path = run_inference(
+    video_path, task_n = run_inference(
         resized_image,
         processed_last_image,
         prompt,
@@ -511,7 +525,7 @@ def generate_video(
         duration_seconds,
         progress,
     )
-    print("GPU complete")
+    print(f"GPU complete: {task_n}")
 
     return (video_path if video_component else None), video_path, current_seed
 
@@ -531,13 +545,13 @@ CSS = """
 
 
 with gr.Blocks(delete_cache=(3600, 10800)) as demo:
-    gr.Markdown("## WAMU - Wan 2.2 I2V (14B)")
+    gr.Markdown("## WAMU - Wan 2.2 I2V (14B) 🐌")
     gr.Markdown("#### ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
-    gr.Markdown("run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
+    gr.Markdown("Run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU")
 
     with gr.Row():
         with gr.Column():
-            input_image_component = gr.Image(type="pil", label="Input Image")
+            input_image_component = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"])
             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
             duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
             steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
@@ -554,7 +568,7 @@ with gr.Blocks(delete_cache=(3600, 10800)) as demo:
             seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
             randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
             guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage", info="Values above 1 increase GPU usage and may take longer to process.")
-            guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage", info="Values above 1 increase GPU usage and may take longer to process.")
+            guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")
             scheduler_dropdown = gr.Dropdown(
                 label="Scheduler",
                 choices=list(SCHEDULER_MAP.keys()),
@@ -563,13 +577,14 @@ with gr.Blocks(delete_cache=(3600, 10800)) as demo:
             )
             flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
             play_result_video = gr.Checkbox(label="Display result", value=True, interactive=True)
-            gr.Markdown("[ZeroGPU Help, Tips, and Troubleshooting](https://huggingface.co/datasets/TestOrganizationPleaseIgnore/help/blob/main/gpu_help.md)")
+            org_name = "TestOrganizationPleaseIgnore"
+            gr.Markdown(f"[ZeroGPU Help, Tips, and Troubleshooting](https://huggingface.co/datasets/{org_name}/help/blob/main/gpu_help.md)")
 
             generate_button = gr.Button("Generate Video", variant="primary")
 
         with gr.Column():
             # ASSIGNED elem_id="generated-video" so JS can find it
-            video_output = gr.Video(label="Generated Video", autoplay=True, elem_id="generated-video")
+            video_output = gr.Video(label="Generated Video", autoplay=True, sources=["upload"], buttons=["download", "share"], interactive=True, elem_id="generated-video")
 
             # --- Frame Grabbing UI ---
             with gr.Row():
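The UI changes are all constructor arguments on existing components. A minimal standalone sketch of the two reworked components, assuming a Gradio release recent enough to support sources= on Image and Video (the buttons= argument used on gr.Video in the commit is left out of this sketch, since it may require a newer Gradio version):

import gradio as gr

with gr.Blocks() as demo:
    # Limit image input to file upload and clipboard paste (no webcam source).
    input_image = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"])
    # elem_id lets the page's custom JS locate the player for frame grabbing.
    video_output = gr.Video(label="Generated Video", autoplay=True, elem_id="generated-video")

demo.launch()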
 