r3gm committed on
Commit 3f29218 · verified · 1 Parent(s): 5a1470c

Update app.py

Files changed (1)
app.py +225 -100
app.py CHANGED
@@ -1,25 +1,22 @@
 
  import spaces
- import torch
- from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
- from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
- from diffusers.utils.export_utils import export_to_video
- import gradio as gr
- import tempfile
- import numpy as np
- from PIL import Image
  import random
  import gc
- import copy
- import os
- import shutil
-
- from gradio_client import Client, handle_file
- from torchao.quantization import quantize_
- from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
- from torchao.quantization import Int8WeightOnlyConfig
 
- import aoti
 
  from diffusers import (
      FlowMatchEulerDiscreteScheduler,
      SASolverScheduler,
@@ -29,15 +26,145 @@ from diffusers import (
      DPMSolverMultistepScheduler,
      DPMSolverSinglestepScheduler,
  )
 
  MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
 
  MAX_DIM = 832
  MIN_DIM = 480
  SQUARE_DIM = 640
  MULTIPLE_OF = 16
-
  MAX_SEED = np.iinfo(np.int32).max
 
  FIXED_FPS = 16
@@ -47,8 +174,6 @@ MAX_FRAMES_MODEL = 160
  MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
  MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
 
- CACHE_DIR = os.path.expanduser("~/.cache/huggingface/")
-
  SCHEDULER_MAP = {
      "FlowMatchEulerDiscrete": FlowMatchEulerDiscreteScheduler,
      "SASolver": SASolverScheduler,
@@ -64,7 +189,6 @@ pipe = WanImageToVideoPipeline.from_pretrained(
      torch_dtype=torch.bfloat16,
  ).to('cuda')
  original_scheduler = copy.deepcopy(pipe.scheduler)
- print(original_scheduler)
 
  if os.path.exists(CACHE_DIR):
      shutil.rmtree(CACHE_DIR)
@@ -79,59 +203,46 @@ quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
  aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
  aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
 
-
  default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
  default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
 
 
  def resize_image(image: Image.Image) -> Image.Image:
-     """
-     Resizes an image to fit within the model's constraints, preserving aspect ratio as much as possible.
-     """
      width, height = image.size
-
-     # Handle square case
      if width == height:
          return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
-
      aspect_ratio = width / height
-
      MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
      MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM
 
      image_to_resize = image
-
      if aspect_ratio > MAX_ASPECT_RATIO:
-         # Very wide image -> crop width to fit 832x480 aspect ratio
          target_w, target_h = MAX_DIM, MIN_DIM
          crop_width = int(round(height * MAX_ASPECT_RATIO))
          left = (width - crop_width) // 2
          image_to_resize = image.crop((left, 0, left + crop_width, height))
      elif aspect_ratio < MIN_ASPECT_RATIO:
-         # Very tall image -> crop height to fit 480x832 aspect ratio
          target_w, target_h = MIN_DIM, MAX_DIM
          crop_height = int(round(width / MIN_ASPECT_RATIO))
          top = (height - crop_height) // 2
          image_to_resize = image.crop((0, top, width, top + crop_height))
      else:
-         if width > height:  # Landscape
              target_w = MAX_DIM
              target_h = int(round(target_w / aspect_ratio))
-         else:  # Portrait
              target_h = MAX_DIM
              target_w = int(round(target_h * aspect_ratio))
 
      final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
      final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
-
      final_w = max(MIN_DIM, min(MAX_DIM, final_w))
      final_h = max(MIN_DIM, min(MAX_DIM, final_h))
-
      return image_to_resize.resize((final_w, final_h), Image.LANCZOS)
 
 
  def resize_and_crop_to_match(target_image, reference_image):
-     """Resizes and center-crops the target image to match the reference image's dimensions."""
      ref_width, ref_height = reference_image.size
      target_width, target_height = target_image.size
      scale = max(ref_width / target_width, ref_height / target_height)
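For reference, the sizing rule in resize_image can be checked in isolation. The snippet below is a minimal standalone sketch that mirrors the crop/round/clamp arithmetic (constants copied from this file); expected_dims is a helper name introduced here for illustration and is not part of app.py.

# Standalone sketch of the resize_image sizing rule (constants copied from app.py).
MAX_DIM, MIN_DIM, SQUARE_DIM, MULTIPLE_OF = 832, 480, 640, 16

def expected_dims(width: int, height: int) -> tuple[int, int]:
    if width == height:
        return SQUARE_DIM, SQUARE_DIM
    aspect_ratio = width / height
    max_ar, min_ar = MAX_DIM / MIN_DIM, MIN_DIM / MAX_DIM
    if aspect_ratio > max_ar:          # very wide -> cropped toward 832x480
        target_w, target_h = MAX_DIM, MIN_DIM
    elif aspect_ratio < min_ar:        # very tall -> cropped toward 480x832
        target_w, target_h = MIN_DIM, MAX_DIM
    elif width > height:               # landscape
        target_w = MAX_DIM
        target_h = round(target_w / aspect_ratio)
    else:                              # portrait
        target_h = MAX_DIM
        target_w = round(target_h * aspect_ratio)
    # Round to multiples of 16, then clamp to the model's dimension range.
    final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
    final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
    return max(MIN_DIM, min(MAX_DIM, final_w)), max(MIN_DIM, min(MAX_DIM, final_h))

print(expected_dims(1920, 1080))  # (832, 480): 16:9 exceeds 832/480, so it is cropped
print(expected_dims(1024, 768))   # (832, 624): 4:3 landscape, rounded to multiples of 16
print(expected_dims(800, 800))    # (640, 640): square inputs map straight to 640x640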
@@ -161,6 +272,9 @@ def get_inference_duration(
      current_seed,
      scheduler_name,
      flow_shift,
      progress
  ):
      BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
@@ -168,7 +282,22 @@ def get_inference_duration(
      width, height = resized_image.size
      factor = num_frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
      step_duration = BASE_STEP_DURATION * factor ** 1.5
-     return 10 + int(steps) * step_duration
 
 
  @spaces.GPU(duration=get_inference_duration)
@@ -184,9 +313,11 @@ def run_inference(
      current_seed,
      scheduler_name,
      flow_shift,
      progress=gr.Progress(track_tqdm=True),
  ):
-
      scheduler_class = SCHEDULER_MAP.get(scheduler_name)
      if scheduler_class.__name__ != pipe.scheduler.config._class_name or flow_shift != pipe.scheduler.config.get("flow_shift", "shift"):
          config = copy.deepcopy(original_scheduler.config)
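The block above swaps schedulers per request: it looks the class up in SCHEDULER_MAP, rebuilds it from a copy of the startup config with flow_shift overridden, and restores the original scheduler after inference. A minimal sketch of that pattern, assuming an already-loaded diffusers pipeline (swap_scheduler is an illustrative helper name, not part of app.py):

# Minimal sketch of the scheduler-swap pattern used in run_inference.
import copy
from diffusers import UniPCMultistepScheduler

def swap_scheduler(pipe, original_scheduler, scheduler_class, flow_shift):
    # Copy into a plain, mutable dict (the scheduler's config object is read-only),
    # override only flow_shift, then rebuild; keys a scheduler doesn't define are
    # ignored by from_config.
    config = dict(original_scheduler.config)
    config["flow_shift"] = flow_shift
    pipe.scheduler = scheduler_class.from_config(config)

# Usage mirroring the diff: capture the startup scheduler once, swap, run, restore.
# original_scheduler = copy.deepcopy(pipe.scheduler)
# swap_scheduler(pipe, original_scheduler, UniPCMultistepScheduler, flow_shift=3.0)
# ... result = pipe(...) ...
# pipe.scheduler = original_scheduler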
@@ -196,6 +327,8 @@ def run_inference(
          config['flow_shift'] = flow_shift
          pipe.scheduler = scheduler_class.from_config(config)
 
      result = pipe(
          image=resized_image,
          last_image=processed_last_image,
@@ -208,11 +341,32 @@ def run_inference(
          guidance_scale_2=float(guidance_scale_2),
          num_inference_steps=int(steps),
          generator=torch.Generator(device="cuda").manual_seed(current_seed),
-     ).frames[0]
-
      pipe.scheduler = original_scheduler
-     return result
 
 
  def generate_video(
      input_image,
@@ -228,15 +382,14 @@ def generate_video(
      quality=5,
      scheduler="UniPCMultistep",
      flow_shift=6.0,
      progress=gr.Progress(track_tqdm=True),
  ):
      """
      Generate a video from an input image using the Wan 2.2 14B I2V model with Lightning LoRA.
-
      This function takes an input image and generates a video animation based on the provided
      prompt and parameters. It uses an FP8 quantized Wan 2.2 14B Image-to-Video model with Lightning LoRA
      for fast generation in 4-8 steps.
-
      Args:
          input_image (PIL.Image): The input image to animate. Will be resized to target dimensions.
          last_image (PIL.Image, optional): The optional last image for the video.
@@ -259,23 +412,22 @@ def generate_video(
              Highest quality is 10, lowest is 1.
          scheduler (str, optional): The name of the scheduler to use for inference. Defaults to "UniPCMultistep".
          flow_shift (float, optional): The flow shift value for compatible schedulers. Defaults to 6.0.
          progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
-
      Returns:
          tuple: A tuple containing:
              - video_path (str): Path for the video component.
              - video_path (str): Path for the file download component, returned again to avoid re-encoding in the video component.
              - current_seed (int): The seed used for generation.
-
      Raises:
          gr.Error: If input_image is None (no image uploaded).
-
      Note:
          - Frame count is calculated as duration_seconds * FIXED_FPS (16)
          - Output dimensions are adjusted to be multiples of MULTIPLE_OF (16)
          - The function uses GPU acceleration via the @spaces.GPU decorator
          - Generation time varies based on steps and duration (see get_inference_duration)
      """
 
      if input_image is None:
          raise gr.Error("Please upload an input image.")
 
@@ -287,7 +439,7 @@ def generate_video(
      if last_image:
          processed_last_image = resize_and_crop_to_match(last_image, resized_image)
 
-     output_frames_list = run_inference(
          resized_image,
          processed_last_image,
          prompt,
@@ -299,93 +451,66 @@ def generate_video(
          current_seed,
          scheduler,
          flow_shift,
          progress,
      )
-
-     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
-         video_path = tmpfile.name
-
-     export_to_video(output_frames_list, video_path, fps=FIXED_FPS, quality=quality)
 
      return video_path, video_path, current_seed
 
 
- def visibility_interpolation():
-     return (
-         gr.update(visible=True),
-         gr.update(visible=True)
-     )
-
-
- def interpolate_video(generated_video, multiplier, request: gr.Request):
-     x_ip_token = request.headers['x-ip-token']
-     client = Client("r3gm/FPS_Enhancer", headers={"x-ip-token": x_ip_token})
-     result = client.predict(
-         input_video=handle_file(generated_video),
-         frame_multiplier=multiplier,
-         time_exponent=1,
-         fixed_fps=0,
-         video_scale=1.0,
-         remove_duplicate_frames=False,
-         create_montage=False,
-         api_name="/run_rife"
-     )
-     return result
-
-
- def generate_interpolate(generated_video, multiplier, request: gr.Request):
-     return interpolate_video(generated_video, multiplier, request)
-
-
  with gr.Blocks(delete_cache=(3600, 10800)) as demo:
      gr.Markdown("# WAMU - Wan 2.2 I2V (14B)")
      gr.Markdown("## ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
      gr.Markdown("run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
 
      with gr.Row():
          with gr.Column():
              input_image_component = gr.Image(type="pil", label="Input Image")
              prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
-             duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
              steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
-
              with gr.Accordion("Advanced Settings", open=False):
                  last_image_component = gr.Image(type="pil", label="Last Image (Optional)")
-                 negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, info="Used if any Guidance Scale > 1.", lines=3)
                  quality_slider = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Video Quality")
                  seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                  randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
-                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage")
-                 guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")
                  scheduler_dropdown = gr.Dropdown(
                      label="Scheduler",
                      choices=list(SCHEDULER_MAP.keys()),
-                     value="UniPCMultistep",
-                     info="Select a custom scheduler."
                  )
                  flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
 
-                 frame_multi = gr.Number(
-                     minimum=2, maximum=6, step=1, value=4, label="Frame Multiplier for 'Generate Intermediate Frames'",
-                     info="2X = Double FPS (e.g. 16 -> 32). Higher multipliers create more intermediate frames."
-                 )
-
              generate_button = gr.Button("Generate Video", variant="primary")
 
          with gr.Column():
-             video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
              file_output = gr.File(label="Download Video")
 
-             frame_button = gr.Button("Generate Intermediate Frames", variant="secondary", visible=False)
-             frame_output = gr.File(label="Interpolated result", visible=False)
-
      ui_inputs = [
          input_image_component, last_image_component, prompt_input, steps_slider,
          negative_prompt_input, duration_seconds_input,
          guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox,
-         quality_slider, scheduler_dropdown, flow_shift_slider,
      ]
-     generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, file_output, seed_input]).success(visibility_interpolation, [], [frame_button, frame_output])
-     frame_button.click(fn=generate_interpolate, inputs=[file_output, frame_multi], outputs=[frame_output], api_visibility="private")
-
 
  if __name__ == "__main__":
      demo.queue().launch(mcp_server=True)
 
+ import os
  import spaces
+ import shutil
+ import subprocess
+ import sys
+ import copy
  import random
+ import tempfile
+ import warnings
+ import time
  import gc
 
+ import cv2
+ import numpy as np
+ import torch
+ from torch.nn import functional as F
+ from PIL import Image
 
+ import gradio as gr
  from diffusers import (
      FlowMatchEulerDiscreteScheduler,
      SASolverScheduler,
 
      DPMSolverMultistepScheduler,
      DPMSolverSinglestepScheduler,
  )
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
+ from diffusers.utils.export_utils import export_to_video
 
+ from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
+ import aoti
+
+ # os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ warnings.filterwarnings("ignore")
+
+ # RIFE
+ if not os.path.exists("RIFEv4.26_0921.zip"):
+     print("Downloading RIFE Model...")
+     subprocess.run([
+         "wget", "-q",
+         "https://huggingface.co/r3gm/RIFE/resolve/main/RIFEv4.26_0921.zip",
+         "-O", "RIFEv4.26_0921.zip"
+     ], check=True)
+     subprocess.run(["unzip", "-o", "RIFEv4.26_0921.zip"], check=True)
+
+ # sys.path.append(os.getcwd())
+
+ from train_log.RIFE_HDv3 import Model
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ rife_model = Model()
+ rife_model.load_model("train_log", -1)
+ rife_model.eval()
+ rife_model.device()
+
+ if torch.cuda.is_available():
+     rife_model.flownet = rife_model.flownet.half()
+
+ @torch.no_grad()
+ def interpolate_bits(frames_np, multiplier=2, scale=1.0):
+     """
+     Interpolation maintaining Numpy Float 0-1 format.
+     Args:
+         frames_np: Numpy Array (Time, Height, Width, Channels) - Float32 [0.0, 1.0]
+         multiplier: int (2, 4, 8)
+     Returns:
+         List of Numpy Arrays (Height, Width, Channels) - Float32 [0.0, 1.0]
+     """
+     # Handle input shape
+     if isinstance(frames_np, list):
+         # A list of arrays: just grab dims from the first frame
+         T = len(frames_np)
+         H, W, C = frames_np[0].shape
+     else:
+         T, H, W, C = frames_np.shape
+
+     # 1. No Interpolation Case
+     if multiplier < 2:
+         # Just convert 4D array to list of 3D arrays
+         if isinstance(frames_np, np.ndarray):
+             return list(frames_np)
+         return frames_np
+
+     n_interp = multiplier - 1
+
+     # Pre-calc padding for RIFE (requires dimensions divisible by 32/scale)
+     tmp = max(128, int(128 / scale))
+     ph = ((H - 1) // tmp + 1) * tmp
+     pw = ((W - 1) // tmp + 1) * tmp
+     padding = (0, pw - W, 0, ph - H)
+
+     # Helper: Numpy (H, W, C) Float -> Tensor (1, C, H, W) Half
+     def to_tensor(frame_np):
+         # frame_np is float32 0-1
+         t = torch.from_numpy(frame_np).to(device)
+         # HWC -> CHW
+         t = t.permute(2, 0, 1).unsqueeze(0)
+         return F.pad(t, padding).half()
+
+     # Helper: Tensor (1, C, H, W) Half -> Numpy (H, W, C) Float
+     def from_tensor(tensor):
+         # Crop padding
+         t = tensor[0, :, :H, :W]
+         # CHW -> HWC
+         t = t.permute(1, 2, 0)
+         # Keep as float32, range 0-1
+         return t.float().cpu().numpy()
+
+     def make_inference(I0, I1, n):
+         if rife_model.version >= 3.9:
+             res = []
+             for i in range(n):
+                 res.append(rife_model.inference(I0, I1, (i + 1) * 1. / (n + 1), scale))
+             return res
+         else:
+             middle = rife_model.inference(I0, I1, scale)
+             if n == 1:
+                 return [middle]
+             first_half = make_inference(I0, middle, n=n // 2)
+             second_half = make_inference(middle, I1, n=n // 2)
+             if n % 2:
+                 return [*first_half, middle, *second_half]
+             else:
+                 return [*first_half, *second_half]
+
+     output_frames = []
+
+     # Process frames: load first frame into GPU
+     I1 = to_tensor(frames_np[0])
+
+     for i in range(T - 1):
+         I0 = I1
+         # Add original frame to output
+         output_frames.append(from_tensor(I0))
+
+         # Load next frame
+         I1 = to_tensor(frames_np[i + 1])
+
+         # Generate intermediate frames
+         mid_tensors = make_inference(I0, I1, n_interp)
+
+         # Append intermediate frames
+         for mid in mid_tensors:
+             output_frames.append(from_tensor(mid))
+
+     # Add the very last frame
+     output_frames.append(from_tensor(I1))
+
+     # Cleanup
+     del I0, I1, mid_tensors
+     torch.cuda.empty_cache()
+
+     return output_frames
+
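Bookkeeping note for interpolate_bits: each frame is padded up to a multiple of 128 pixels before RIFE inference and cropped back afterwards, and for T input frames with multiplier m the loop emits every original frame plus m - 1 in-between frames, then appends the last frame, giving (T - 1) * m + 1 outputs; run_inference later exports them at FIXED_FPS * m. The sketch below only checks that arithmetic (the helper is hypothetical and does not load RIFE):

# Hypothetical frame-count check for interpolate_bits (no RIFE model needed):
# (T - 1) * m original+intermediate frames, plus the final frame.
def interpolated_frame_count(num_frames: int, multiplier: int) -> int:
    if multiplier < 2:
        return num_frames
    return (num_frames - 1) * multiplier + 1

FIXED_FPS = 16  # same constant as in app.py

for t, m in [(57, 1), (57, 2), (57, 4)]:   # 57 frames is roughly 3.5 s at 16 fps
    frames_out = interpolated_frame_count(t, m)
    fps_out = FIXED_FPS * m
    print(f"{t} frames x{m} -> {frames_out} frames at {fps_out} fps "
          f"(~{frames_out / fps_out:.2f} s)")
# 57 x1 -> 57 frames at 16 fps (~3.56 s)
# 57 x2 -> 113 frames at 32 fps (~3.53 s)
# 57 x4 -> 225 frames at 64 fps (~3.52 s)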
+ # WAN
 
  MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
+ CACHE_DIR = os.path.expanduser("~/.cache/huggingface/")
 
  MAX_DIM = 832
  MIN_DIM = 480
  SQUARE_DIM = 640
  MULTIPLE_OF = 16
  MAX_SEED = np.iinfo(np.int32).max
 
  FIXED_FPS = 16
 
  MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
  MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
 
  SCHEDULER_MAP = {
      "FlowMatchEulerDiscrete": FlowMatchEulerDiscreteScheduler,
      "SASolver": SASolverScheduler,
 
      torch_dtype=torch.bfloat16,
  ).to('cuda')
  original_scheduler = copy.deepcopy(pipe.scheduler)
 
  if os.path.exists(CACHE_DIR):
      shutil.rmtree(CACHE_DIR)
 
  aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
  aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
 
  default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
  default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
 
 
  def resize_image(image: Image.Image) -> Image.Image:
      width, height = image.size
      if width == height:
          return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
+
      aspect_ratio = width / height
      MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
      MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM
 
      image_to_resize = image
      if aspect_ratio > MAX_ASPECT_RATIO:
          target_w, target_h = MAX_DIM, MIN_DIM
          crop_width = int(round(height * MAX_ASPECT_RATIO))
          left = (width - crop_width) // 2
          image_to_resize = image.crop((left, 0, left + crop_width, height))
      elif aspect_ratio < MIN_ASPECT_RATIO:
          target_w, target_h = MIN_DIM, MAX_DIM
          crop_height = int(round(width / MIN_ASPECT_RATIO))
          top = (height - crop_height) // 2
          image_to_resize = image.crop((0, top, width, top + crop_height))
      else:
+         if width > height:
              target_w = MAX_DIM
              target_h = int(round(target_w / aspect_ratio))
+         else:
              target_h = MAX_DIM
              target_w = int(round(target_h * aspect_ratio))
 
      final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
      final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
      final_w = max(MIN_DIM, min(MAX_DIM, final_w))
      final_h = max(MIN_DIM, min(MAX_DIM, final_h))
      return image_to_resize.resize((final_w, final_h), Image.LANCZOS)
 
 
  def resize_and_crop_to_match(target_image, reference_image):
      ref_width, ref_height = reference_image.size
      target_width, target_height = target_image.size
      scale = max(ref_width / target_width, ref_height / target_height)
 
      current_seed,
      scheduler_name,
      flow_shift,
+     frame_multiplier,
+     quality,
+     duration_seconds,
      progress
  ):
      BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
 
      width, height = resized_image.size
      factor = num_frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
      step_duration = BASE_STEP_DURATION * factor ** 1.5
+     gen_time = int(steps) * step_duration
+     print(gen_time)
+     if guidance_scale > 1:
+         gen_time = gen_time * 1.5
+     if guidance_scale_2 > 1:
+         gen_time = gen_time * 1.5
+
+     if frame_multiplier > 1:
+         # total_out_frames = (num_frames * frame_multiplier)
+         # inter_time = (total_out_frames * 0.02)
+         inter_time = duration_seconds
+         print(inter_time)
+         gen_time += inter_time
+
+     print("Time GPU", gen_time + 10)
+     return 10 + gen_time
 
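get_inference_duration above requests GPU time from an estimate rather than a measurement: per-step time is scaled by (num_frames * width * height / (81 * 832 * 624)) ** 1.5, multiplied by 1.5 for each stage whose guidance scale is above 1, padded by roughly the clip duration when interpolation is enabled, plus a fixed 10 s overhead. Below is a standalone sketch of that arithmetic, with BASE_STEP_DURATION left as a parameter because its value is defined outside the lines shown in this diff; the printed number is only illustrative.

# Illustrative re-implementation of the duration estimate above.
# base_step_duration is a placeholder; the real BASE_STEP_DURATION constant is
# defined elsewhere in app.py and is not shown in this diff.
def estimate_gpu_seconds(steps, num_frames, width, height,
                         guidance_scale=1.0, guidance_scale_2=1.0,
                         frame_multiplier=1, duration_seconds=0.0,
                         base_step_duration=1.0):
    base = 81 * 832 * 624
    factor = num_frames * width * height / base
    gen_time = int(steps) * base_step_duration * factor ** 1.5
    if guidance_scale > 1:
        gen_time *= 1.5       # extra forward passes for CFG, high-noise stage
    if guidance_scale_2 > 1:
        gen_time *= 1.5       # extra forward passes for CFG, low-noise stage
    if frame_multiplier > 1:
        gen_time += duration_seconds   # crude allowance for RIFE interpolation
    return 10 + gen_time               # 10 s fixed overhead

# e.g. 6 steps, 57 frames at 832x480, no CFG, 2x interpolation of a 3.5 s clip:
print(estimate_gpu_seconds(6, 57, 832, 480, frame_multiplier=2,
                           duration_seconds=3.5, base_step_duration=1.0))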
 
  @spaces.GPU(duration=get_inference_duration)
 
      current_seed,
      scheduler_name,
      flow_shift,
+     frame_multiplier,
+     quality,
+     duration_seconds,
      progress=gr.Progress(track_tqdm=True),
  ):
      scheduler_class = SCHEDULER_MAP.get(scheduler_name)
      if scheduler_class.__name__ != pipe.scheduler.config._class_name or flow_shift != pipe.scheduler.config.get("flow_shift", "shift"):
          config = copy.deepcopy(original_scheduler.config)
 
          config['flow_shift'] = flow_shift
          pipe.scheduler = scheduler_class.from_config(config)
 
+     print(f"Generating {num_frames} frames with Wan...")
+     start = time.time()
      result = pipe(
          image=resized_image,
          last_image=processed_last_image,
 
          guidance_scale_2=float(guidance_scale_2),
          num_inference_steps=int(steps),
          generator=torch.Generator(device="cuda").manual_seed(current_seed),
+         output_type="np"
+     )
+     print("gen time passed:", time.time() - start)
+
+     raw_frames_np = result.frames[0]  # Returns (T, H, W, C) float32
      pipe.scheduler = original_scheduler
 
+     start = time.time()
+     if frame_multiplier > 1:
+         print(f"Processing frames (RIFE Multiplier: {frame_multiplier}x)...")
+         final_frames = interpolate_bits(raw_frames_np, multiplier=int(frame_multiplier))
+     else:
+         final_frames = list(raw_frames_np)
+     print("Interpolation time passed:", time.time() - start)
+
+     final_fps = FIXED_FPS * int(frame_multiplier)
+
+     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+         video_path = tmpfile.name
+
+     print(f"Exporting video at {final_fps} FPS...")
+     start = time.time()
+     export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
+     print("Export time passed:", time.time() - start)
+
+     return video_path
 
  def generate_video(
      input_image,
 
      quality=5,
      scheduler="UniPCMultistep",
      flow_shift=6.0,
+     frame_multiplier=1,
      progress=gr.Progress(track_tqdm=True),
  ):
      """
      Generate a video from an input image using the Wan 2.2 14B I2V model with Lightning LoRA.
      This function takes an input image and generates a video animation based on the provided
      prompt and parameters. It uses an FP8 quantized Wan 2.2 14B Image-to-Video model with Lightning LoRA
      for fast generation in 4-8 steps.
      Args:
          input_image (PIL.Image): The input image to animate. Will be resized to target dimensions.
          last_image (PIL.Image, optional): The optional last image for the video.
 
              Highest quality is 10, lowest is 1.
          scheduler (str, optional): The name of the scheduler to use for inference. Defaults to "UniPCMultistep".
          flow_shift (float, optional): The flow shift value for compatible schedulers. Defaults to 6.0.
+         frame_multiplier (int, optional): Frame-rate multiplier for the RIFE interpolation step. Defaults to 1 (disabled).
          progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
      Returns:
          tuple: A tuple containing:
              - video_path (str): Path for the video component.
              - video_path (str): Path for the file download component, returned again to avoid re-encoding in the video component.
              - current_seed (int): The seed used for generation.
      Raises:
          gr.Error: If input_image is None (no image uploaded).
      Note:
          - Frame count is calculated as duration_seconds * FIXED_FPS (16)
          - Output dimensions are adjusted to be multiples of MULTIPLE_OF (16)
          - The function uses GPU acceleration via the @spaces.GPU decorator
          - Generation time varies based on steps and duration (see get_inference_duration)
      """
+
      if input_image is None:
          raise gr.Error("Please upload an input image.")
 
 
      if last_image:
          processed_last_image = resize_and_crop_to_match(last_image, resized_image)
 
+     video_path = run_inference(
          resized_image,
          processed_last_image,
          prompt,
 
          current_seed,
          scheduler,
          flow_shift,
+         frame_multiplier,
+         quality,
+         duration_seconds,
          progress,
      )
+     print("GPU complete")
 
      return video_path, video_path, current_seed
 
 
  with gr.Blocks(delete_cache=(3600, 10800)) as demo:
      gr.Markdown("# WAMU - Wan 2.2 I2V (14B)")
      gr.Markdown("## ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
      gr.Markdown("run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
+
      with gr.Row():
          with gr.Column():
              input_image_component = gr.Image(type="pil", label="Input Image")
              prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
+             duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Base Duration (seconds)")
              steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
+             frame_multi = gr.Dropdown(
+                 choices=[1, 2, 4, 8],
+                 value=1,
+                 label="Frame Rate Enhancer (Interpolation)",
+                 info="2 = Double FPS (e.g. 16 -> 32). Higher multipliers create more intermediate frames."
+             )
              with gr.Accordion("Advanced Settings", open=False):
                  last_image_component = gr.Image(type="pil", label="Last Image (Optional)")
+                 negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                  quality_slider = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Video Quality")
                  seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                  randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
+                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale")
+                 guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2")
                  scheduler_dropdown = gr.Dropdown(
                      label="Scheduler",
                      choices=list(SCHEDULER_MAP.keys()),
+                     value="UniPCMultistep"
                  )
                  flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
 
              generate_button = gr.Button("Generate Video", variant="primary")
+
          with gr.Column():
+             video_output = gr.Video(label="Generated Video", autoplay=True)
              file_output = gr.File(label="Download Video")
 
      ui_inputs = [
          input_image_component, last_image_component, prompt_input, steps_slider,
          negative_prompt_input, duration_seconds_input,
          guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox,
+         quality_slider, scheduler_dropdown, flow_shift_slider, frame_multi
      ]
+
+     generate_button.click(
+         fn=generate_video,
+         inputs=ui_inputs,
+         outputs=[video_output, file_output, seed_input]
+     )
 
  if __name__ == "__main__":
      demo.queue().launch(mcp_server=True)
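Because the demo is launched with a queue and mcp_server=True, it can also be driven programmatically. Below is a hedged client-side sketch using gradio_client: the Space id is a placeholder, the endpoint name /generate_video is assumed from the wrapped function's name, and the positional argument order follows ui_inputs; check the Space's "Use via API" page for the exact signature.

# Client-side sketch (assumptions: placeholder Space id, default endpoint name).
from gradio_client import Client, handle_file

client = Client("<user>/<space-name>")          # placeholder Space id
video_path, download_path, used_seed = client.predict(
    handle_file("input.jpg"),   # input_image
    None,                       # last_image (optional)
    "make this image come alive, cinematic motion, smooth animation",  # prompt
    6,                          # steps
    "",                         # negative_prompt
    3.5,                        # duration_seconds
    1.0, 1.0,                   # guidance_scale, guidance_scale_2
    42, True,                   # seed, randomize_seed
    6,                          # quality
    "UniPCMultistep",           # scheduler
    3.0,                        # flow_shift
    2,                          # frame_multiplier
    api_name="/generate_video",
)
print(used_seed, video_path)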