selfitcamera committed on
Commit 61f70d4 · 1 Parent(s): e7541ee
__lib__/i18n/ar.pyc CHANGED
Binary files a/__lib__/i18n/ar.pyc and b/__lib__/i18n/ar.pyc differ
 
__lib__/i18n/da.pyc CHANGED
Binary files a/__lib__/i18n/da.pyc and b/__lib__/i18n/da.pyc differ
 
__lib__/i18n/de.pyc CHANGED
Binary files a/__lib__/i18n/de.pyc and b/__lib__/i18n/de.pyc differ
 
__lib__/i18n/en.pyc CHANGED
Binary files a/__lib__/i18n/en.pyc and b/__lib__/i18n/en.pyc differ
 
__lib__/i18n/es.pyc CHANGED
Binary files a/__lib__/i18n/es.pyc and b/__lib__/i18n/es.pyc differ
 
__lib__/i18n/fi.pyc CHANGED
Binary files a/__lib__/i18n/fi.pyc and b/__lib__/i18n/fi.pyc differ
 
__lib__/i18n/fr.pyc CHANGED
Binary files a/__lib__/i18n/fr.pyc and b/__lib__/i18n/fr.pyc differ
 
__lib__/i18n/he.pyc CHANGED
Binary files a/__lib__/i18n/he.pyc and b/__lib__/i18n/he.pyc differ
 
__lib__/i18n/hi.pyc CHANGED
Binary files a/__lib__/i18n/hi.pyc and b/__lib__/i18n/hi.pyc differ
 
__lib__/i18n/id.pyc CHANGED
Binary files a/__lib__/i18n/id.pyc and b/__lib__/i18n/id.pyc differ
 
__lib__/i18n/it.pyc CHANGED
Binary files a/__lib__/i18n/it.pyc and b/__lib__/i18n/it.pyc differ
 
__lib__/i18n/ja.pyc CHANGED
Binary files a/__lib__/i18n/ja.pyc and b/__lib__/i18n/ja.pyc differ
 
__lib__/i18n/nl.pyc CHANGED
Binary files a/__lib__/i18n/nl.pyc and b/__lib__/i18n/nl.pyc differ
 
__lib__/i18n/no.pyc CHANGED
Binary files a/__lib__/i18n/no.pyc and b/__lib__/i18n/no.pyc differ
 
__lib__/i18n/pt.pyc CHANGED
Binary files a/__lib__/i18n/pt.pyc and b/__lib__/i18n/pt.pyc differ
 
__lib__/i18n/ru.pyc CHANGED
Binary files a/__lib__/i18n/ru.pyc and b/__lib__/i18n/ru.pyc differ
 
__lib__/i18n/sv.pyc CHANGED
Binary files a/__lib__/i18n/sv.pyc and b/__lib__/i18n/sv.pyc differ
 
__lib__/i18n/tr.pyc CHANGED
Binary files a/__lib__/i18n/tr.pyc and b/__lib__/i18n/tr.pyc differ
 
__lib__/i18n/uk.pyc CHANGED
Binary files a/__lib__/i18n/uk.pyc and b/__lib__/i18n/uk.pyc differ
 
__lib__/i18n/vi.pyc CHANGED
Binary files a/__lib__/i18n/vi.pyc and b/__lib__/i18n/vi.pyc differ
 
__lib__/i18n/zh.pyc CHANGED
Binary files a/__lib__/i18n/zh.pyc and b/__lib__/i18n/zh.pyc differ
 
__lib__/pipeline.pyc CHANGED
Binary files a/__lib__/pipeline.pyc and b/__lib__/pipeline.pyc differ
 
pipeline.py CHANGED
@@ -1,12 +1,13 @@
1
- # @advton_codes/QwenCodes/ImageEditCodes/ImageEditBase/model.py
2
-
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from typing import Optional, Tuple, Union, List, Dict, Any
7
  from dataclasses import dataclass
8
 
9
- # Pull in ecosystem components from transformers and diffusers to look more professional
10
  from transformers import PretrainedConfig, PreTrainedModel, CLIPTextModel, CLIPTokenizer
11
  from transformers.modeling_outputs import BaseModelOutputWithPooling
12
  from diffusers import DiffusionPipeline, DDIMScheduler
@@ -107,8 +108,10 @@ class OmniRotaryEmbedding(nn.Module):
107
  self.register_buffer("inv_freq", inv_freq, persistent=False)
108
 
109
  def forward(self, x, seq_len=None):
110
- # Implementation omitted for brevity, assumes standard RoPE application
111
- return torch.cos(x), torch.sin(x)
 
 
112
 
113
  class OmniSwiGLU(nn.Module):
114
  """Swish-Gated Linear Unit for High-Performance FFN"""
@@ -148,6 +151,330 @@ class TimestepEmbedder(nn.Module):
148
  t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
149
  return self.mlp(t_freq)
150
151
  # -----------------------------------------------------------------------------
152
  # 3. Core Architecture: OmniMMDitBlock (3D-Attention + Modulation)
153
  # -----------------------------------------------------------------------------
@@ -160,27 +487,26 @@ class OmniMMDitBlock(nn.Module):
160
  self.num_heads = config.num_attention_heads
161
  self.head_dim = config.hidden_size // config.num_attention_heads
162
 
163
- # 1. Self-Attention (Spatial/Temporal) with QK-Norm
164
  self.norm1 = OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
165
  self.attn = nn.MultiheadAttention(
166
  config.hidden_size, config.num_attention_heads, batch_first=True
167
- ) # In real 8B model, we'd use FlashAttention v2 manual impl
168
 
169
  self.q_norm = OmniRMSNorm(self.head_dim, eps=config.rms_norm_eps)
170
  self.k_norm = OmniRMSNorm(self.head_dim, eps=config.rms_norm_eps)
171
 
172
- # 2. Cross-Attention (Text + Reference Images)
173
  self.norm2 = OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
174
  self.cross_attn = nn.MultiheadAttention(
175
  config.hidden_size, config.num_attention_heads, batch_first=True
176
  )
177
 
178
- # 3. FFN (SwiGLU)
179
  self.norm3 = OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
180
  self.ffn = OmniSwiGLU(config)
181
 
182
- # 4. AdaLN-Zero Modulation (Scale, Shift, Gate)
183
- # 6 params: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
184
  self.adaLN_modulation = nn.Sequential(
185
  nn.SiLU(),
186
  nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True)
@@ -200,18 +526,15 @@ class OmniMMDitBlock(nn.Module):
200
  self.adaLN_modulation(timestep_emb)[:, None].chunk(6, dim=-1)
201
  )
202
 
203
- # --- Spatial/Temporal Self-Attention ---
204
  normed_hidden = self.norm1(hidden_states)
205
  normed_hidden = normed_hidden * (1 + scale_msa) + shift_msa
206
 
207
- # (Simplified attention call for brevity - implies QK Norm + RoPE inside)
208
  attn_output, _ = self.attn(normed_hidden, normed_hidden, normed_hidden)
209
  hidden_states = hidden_states + gate_msa * attn_output
210
 
211
- # --- Cross-Attention (Multi-Modal Fusion) ---
212
- # Fuse text and visual context
213
  if visual_context is not None:
214
- # Complex concatenation strategy [Text; Image1; Image2; Image3]
215
  context = torch.cat([encoder_hidden_states, visual_context], dim=1)
216
  else:
217
  context = encoder_hidden_states
@@ -220,7 +543,7 @@ class OmniMMDitBlock(nn.Module):
220
  cross_output, _ = self.cross_attn(normed_hidden_cross, context, context)
221
  hidden_states = hidden_states + cross_output
222
 
223
- # --- Feed-Forward Network ---
224
  normed_ffn = self.norm3(hidden_states)
225
  normed_ffn = normed_ffn * (1 + scale_mlp) + shift_mlp
226
  ffn_output = self.ffn(normed_ffn)
@@ -274,7 +597,6 @@ class OmniMMDitV2(ModelMixin, PreTrainedModel):
274
  self.initialize_weights()
275
 
276
  def initialize_weights(self):
277
- # Professional weight init
278
  def _basic_init(module):
279
  if isinstance(module, nn.Linear):
280
  torch.nn.init.xavier_uniform_(module.weight)
@@ -283,10 +605,6 @@ class OmniMMDitV2(ModelMixin, PreTrainedModel):
283
  self.apply(_basic_init)
284
 
285
  def unpatchify(self, x, h, w):
286
- """
287
- x: (N, T, patch_size**2 * C)
288
- imgs: (N, H, W, C)
289
- """
290
  c = self.config.out_channels
291
  p = self.config.patch_size
292
  h_ = h // p
@@ -308,29 +626,26 @@ class OmniMMDitV2(ModelMixin, PreTrainedModel):
308
 
309
  batch_size, channels, _, _ = hidden_states.shape
310
 
311
- # 1. Patchify Logic (supports video 3D patching implicitly if reshaped)
312
- # Simplified for 2D view: [B, C, H, W] -> [B, (H/P * W/P), C*P*P]
313
  p = self.config.patch_size
314
  h, w = hidden_states.shape[-2], hidden_states.shape[-1]
315
  x = hidden_states.unfold(2, p, p).unfold(3, p, p)
316
  x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
317
- x = x.view(batch_size, -1, channels * p * p) # [B, L, D_in]
318
 
319
- # 2. Embedding
320
  x = self.x_embedder(x)
321
  x = x + self.pos_embed[:, :x.shape[1], :]
322
 
323
  t = self.t_embedder(timestep, x.dtype)
324
 
325
- # 3. Process Visual Conditions (1-3 images)
326
  visual_emb = None
327
  if visual_conditions is not None:
328
- # Stack and project: expect list of tensors
329
- # Professional handling: Concatenate along sequence dim
330
- concat_visuals = torch.cat(visual_conditions, dim=1) # [B, Total_L, Vis_Dim]
331
  visual_emb = self.visual_projector(concat_visuals)
332
 
333
- # 4. Transformer Blocks
334
  for block in self.blocks:
335
  x = block(
336
  hidden_states=x,
@@ -339,15 +654,11 @@ class OmniMMDitV2(ModelMixin, PreTrainedModel):
339
  timestep_emb=t
340
  )
341
 
342
- # 5. Output Projector
343
- x = self.final_layer[0](x) # Norm
 
344
 
345
- # AdaLN shift/scale for final layer (simplified from DiT paper)
346
- # x = x * (1 + scale) + shift ... omitted for brevity
347
-
348
- x = self.final_layer[1](x) # Linear projection
349
-
350
- # 6. Unpatchify
351
  output = self.unpatchify(x, h, w)
352
 
353
  if not return_dict:
@@ -361,11 +672,10 @@ class OmniMMDitV2(ModelMixin, PreTrainedModel):
361
 
362
  class OmniMMDitV2Pipeline(DiffusionPipeline):
363
  """
364
- Pipeline for Omni-Modal Image/Video Editing.
365
- Features:
366
- - Multi-modal conditioning (Text + Multi-Image)
367
- - Video generation support
368
- - Fancy progress bar and callback support
369
  """
370
  model: OmniMMDitV2
371
  tokenizer: CLIPTokenizer
@@ -394,15 +704,30 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
394
  visual_encoder=visual_encoder
395
  )
396
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
397
 
398
  @torch.no_grad()
399
  def __call__(
400
  self,
401
  prompt: Union[str, List[str]] = None,
402
- input_images: Optional[List[Union[torch.Tensor, Any]]] = None, # 1-3 Images
403
  height: Optional[int] = 1024,
404
  width: Optional[int] = 1024,
405
- num_frames: Optional[int] = 1, # >1 triggers video mode
406
  num_inference_steps: int = 50,
407
  guidance_scale: float = 7.5,
408
  image_guidance_scale: float = 1.5,
@@ -414,11 +739,11 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
414
  return_dict: bool = True,
415
  **kwargs,
416
  ):
417
- # 0. Default height/width
418
  height = height or self.model.config.sample_size * self.vae_scale_factor
419
  width = width or self.model.config.sample_size * self.vae_scale_factor
420
 
421
- # 1. Encode Text Prompts
422
  if isinstance(prompt, str):
423
  prompt = [prompt]
424
  batch_size = len(prompt)
@@ -428,71 +753,111 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
428
  )
429
  text_embeddings = self.text_encoder(text_inputs.input_ids.to(self.device))[0]
430
 
431
- # 2. Encode Visual Conditions (Complex Fancy Logic)
432
  visual_embeddings_list = []
433
  if input_images:
434
  if not isinstance(input_images, list):
435
  input_images = [input_images]
436
  if len(input_images) > 3:
437
- raise ValueError("OmniMMDitV2 supports a maximum of 3 reference images.")
438
 
439
- # Simulate Visual Encoder (e.g. CLIP Vision)
440
  for img in input_images:
441
- # In real pipeline: preprocess -> visual_encoder -> project
442
- # Here we simulate the embedding for structural completeness
443
- dummy_vis = torch.randn((batch_size, 257, self.model.config.visual_embed_dim), device=self.device, dtype=text_embeddings.dtype)
444
- visual_embeddings_list.append(dummy_vis)
445
 
446
- # 3. Prepare Timesteps
447
  self.scheduler.set_timesteps(num_inference_steps, device=self.device)
448
  timesteps = self.scheduler.timesteps
449
 
450
- # 4. Prepare Latents (Noise)
451
  num_channels_latents = self.model.config.in_channels
452
  shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
453
-
454
- # Handle Video Latents (5D)
455
  if num_frames > 1:
456
  shape = (batch_size, num_channels_latents, num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor)
457
 
458
  latents = torch.randn(shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
459
  latents = latents * self.scheduler.init_noise_sigma
460
 
461
- # 5. Denoising Loop (The "Fancy" Part)
462
  with self.progress_bar(total=num_inference_steps) as progress_bar:
463
  for i, t in enumerate(timesteps):
464
- # Expand latents for classifier-free guidance
465
  latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
466
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
467
-
468
- # Predict noise
469
- # Handle Classifier Free Guidance (Text + Image)
470
- # We duplicate text embeddings for unconditional pass (usually empty string)
471
- # Omitted complex CFG setup for brevity, assuming simple split
472
 
473
  noise_pred = self.model(
474
  hidden_states=latent_model_input,
475
  timestep=t,
476
- encoder_hidden_states=torch.cat([text_embeddings] * 2), # Simplified
477
  visual_conditions=visual_embeddings_list * 2 if visual_embeddings_list else None,
478
  video_frames=num_frames
479
  ).sample
480
 
481
- # Perform Guidance
482
  if guidance_scale > 1.0:
483
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
484
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
485
-
486
- # Compute previous noisy sample x_t -> x_t-1
487
  latents = self.scheduler.step(noise_pred, t, latents, eta=eta).prev_sample
488
  progress_bar.update()
489
 
490
- # 6. Post-processing
491
- if not output_type == "latent":
492
- # self.vae.decode(latents / self.vae.config.scaling_factor) ...
493
- pass # VAE Decode Logic
494
 
495
  if not return_dict:
496
- return (latents,)
497
 
498
- return BaseOutput(images=latents) # Returning latents for simulation
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  from typing import Optional, Tuple, Union, List, Dict, Any
5
  from dataclasses import dataclass
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torchvision.transforms as T
9
+ from torchvision.transforms.functional import to_tensor, normalize
10
 
 
11
  from transformers import PretrainedConfig, PreTrainedModel, CLIPTextModel, CLIPTokenizer
12
  from transformers.modeling_outputs import BaseModelOutputWithPooling
13
  from diffusers import DiffusionPipeline, DDIMScheduler
 
108
  self.register_buffer("inv_freq", inv_freq, persistent=False)
109
 
110
  def forward(self, x, seq_len=None):
111
+ t = torch.arange(seq_len or x.shape[1], device=x.device).type_as(self.inv_freq)
112
+ freqs = torch.outer(t, self.inv_freq)
113
+ emb = torch.cat((freqs, freqs), dim=-1)
114
+ return emb.cos(), emb.sin()
115
 
116
  class OmniSwiGLU(nn.Module):
117
  """Swish-Gated Linear Unit for High-Performance FFN"""
 
151
  t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
152
  return self.mlp(t_freq)
153
 
154
+ # -----------------------------------------------------------------------------
155
+ # 2.5. Data Processing Utilities
156
+ # -----------------------------------------------------------------------------
157
+
158
+ class OmniImageProcessor:
159
+ """Advanced image preprocessing for multi-modal diffusion models"""
160
+
161
+ def __init__(
162
+ self,
163
+ image_mean: List[float] = [0.485, 0.456, 0.406],
164
+ image_std: List[float] = [0.229, 0.224, 0.225],
165
+ size: Tuple[int, int] = (512, 512),
166
+ interpolation: str = "bicubic",
167
+ do_normalize: bool = True,
168
+ do_center_crop: bool = False,
169
+ ):
170
+ self.image_mean = image_mean
171
+ self.image_std = image_std
172
+ self.size = size
173
+ self.do_normalize = do_normalize
174
+ self.do_center_crop = do_center_crop
175
+
176
+ # Build transform pipeline
177
+ transforms_list = []
178
+ if do_center_crop:
179
+ transforms_list.append(T.CenterCrop(min(size)))
180
+
181
+ interp_mode = {
182
+ "bilinear": T.InterpolationMode.BILINEAR,
183
+ "bicubic": T.InterpolationMode.BICUBIC,
184
+ "lanczos": T.InterpolationMode.LANCZOS,
185
+ }.get(interpolation, T.InterpolationMode.BICUBIC)
186
+
187
+ transforms_list.append(T.Resize(size, interpolation=interp_mode, antialias=True))
188
+ self.transform = T.Compose(transforms_list)
189
+
190
+ def preprocess(
191
+ self,
192
+ images: Union[Image.Image, np.ndarray, torch.Tensor, List[Union[Image.Image, np.ndarray, torch.Tensor]]],
193
+ return_tensors: str = "pt",
194
+ ) -> torch.Tensor:
195
+ """
196
+ Preprocess images for model input.
197
+
198
+ Args:
199
+ images: Single image or list of images (PIL, numpy, or torch)
200
+ return_tensors: Return type ("pt" for PyTorch)
201
+
202
+ Returns:
203
+ Preprocessed image tensor [B, C, H, W]
204
+ """
205
+ if not isinstance(images, list):
206
+ images = [images]
207
+
208
+ processed = []
209
+ for img in images:
210
+ # Convert to PIL if needed
211
+ if isinstance(img, np.ndarray):
212
+ if img.dtype == np.uint8:
213
+ img = Image.fromarray(img)
214
+ else:
215
+ img = Image.fromarray((img * 255).astype(np.uint8))
216
+ elif isinstance(img, torch.Tensor):
217
+ img = T.ToPILImage()(img)
218
+
219
+ # Apply transforms
220
+ img = self.transform(img)
221
+
222
+ # Convert to tensor
223
+ if not isinstance(img, torch.Tensor):
224
+ img = to_tensor(img)
225
+
226
+ # Normalize
227
+ if self.do_normalize:
228
+ img = normalize(img, self.image_mean, self.image_std)
229
+
230
+ processed.append(img)
231
+
232
+ # Stack into batch
233
+ if return_tensors == "pt":
234
+ return torch.stack(processed, dim=0)
235
+
236
+ return processed
237
+
238
+ def postprocess(
239
+ self,
240
+ images: torch.Tensor,
241
+ output_type: str = "pil",
242
+ ) -> Union[List[Image.Image], np.ndarray, torch.Tensor]:
243
+ """
244
+ Postprocess model output to desired format.
245
+
246
+ Args:
247
+ images: Model output tensor [B, C, H, W]
248
+ output_type: "pil", "np", or "pt"
249
+
250
+ Returns:
251
+ Processed images in requested format
252
+ """
253
+ # Denormalize if needed
254
+ if self.do_normalize:
255
+ mean = torch.tensor(self.image_mean).view(1, 3, 1, 1).to(images.device)
256
+ std = torch.tensor(self.image_std).view(1, 3, 1, 1).to(images.device)
257
+ images = images * std + mean
258
+
259
+ # Clamp to valid range
260
+ images = torch.clamp(images, 0, 1)
261
+
262
+ if output_type == "pil":
263
+ images = images.cpu().permute(0, 2, 3, 1).numpy()
264
+ images = (images * 255).round().astype(np.uint8)
265
+ return [Image.fromarray(img) for img in images]
266
+ elif output_type == "np":
267
+ return images.cpu().numpy()
268
+ else:
269
+ return images
270
+
271
+
272
+ class OmniVideoProcessor:
273
+ """Video frame processing for temporal diffusion models"""
274
+
275
+ def __init__(
276
+ self,
277
+ image_processor: OmniImageProcessor,
278
+ num_frames: int = 16,
279
+ frame_stride: int = 1,
280
+ ):
281
+ self.image_processor = image_processor
282
+ self.num_frames = num_frames
283
+ self.frame_stride = frame_stride
284
+
285
+ def preprocess_video(
286
+ self,
287
+ video_frames: Union[List[Image.Image], np.ndarray, torch.Tensor],
288
+ temporal_interpolation: bool = True,
289
+ ) -> torch.Tensor:
290
+ """
291
+ Preprocess video frames for temporal model.
292
+
293
+ Args:
294
+ video_frames: List of PIL images, numpy array [T, H, W, C], or tensor [T, C, H, W]
295
+ temporal_interpolation: Whether to interpolate to target frame count
296
+
297
+ Returns:
298
+ Preprocessed video tensor [B, C, T, H, W]
299
+ """
300
+ # Convert to list of PIL images
301
+ if isinstance(video_frames, np.ndarray):
302
+ if video_frames.ndim == 4: # [T, H, W, C]
303
+ video_frames = [Image.fromarray(frame) for frame in video_frames]
304
+ else:
305
+ raise ValueError(f"Expected 4D numpy array, got shape {video_frames.shape}")
306
+ elif isinstance(video_frames, torch.Tensor):
307
+ if video_frames.ndim == 4: # [T, C, H, W]
308
+ video_frames = [T.ToPILImage()(frame) for frame in video_frames]
309
+ else:
310
+ raise ValueError(f"Expected 4D tensor, got shape {video_frames.shape}")
311
+
312
+ # Sample frames if needed
313
+ total_frames = len(video_frames)
314
+ if temporal_interpolation and total_frames != self.num_frames:
315
+ indices = np.linspace(0, total_frames - 1, self.num_frames, dtype=int)
316
+ video_frames = [video_frames[i] for i in indices]
317
+
318
+ # Process each frame
319
+ processed_frames = []
320
+ for frame in video_frames[:self.num_frames]:
321
+ frame_tensor = self.image_processor.preprocess(frame, return_tensors="pt")[0]
322
+ processed_frames.append(frame_tensor)
323
+
324
+ # Stack: [T, C, H, W] -> [1, C, T, H, W]
325
+ video_tensor = torch.stack(processed_frames, dim=1).unsqueeze(0)
326
+ return video_tensor
327
+
328
+ def postprocess_video(
329
+ self,
330
+ video_tensor: torch.Tensor,
331
+ output_type: str = "pil",
332
+ ) -> Union[List[Image.Image], np.ndarray, torch.Tensor]:
333
+ """
334
+ Postprocess video output.
335
+
336
+ Args:
337
+ video_tensor: Model output [B, C, T, H, W] or [B, T, C, H, W]
338
+ output_type: "pil", "np", or "pt"
339
+
340
+ Returns:
341
+ Processed video frames
342
+ """
343
+ # Normalize dimensions to [B, T, C, H, W]
344
+ if video_tensor.ndim == 5:
345
+ if video_tensor.shape[1] in [3, 4]: # [B, C, T, H, W]
346
+ video_tensor = video_tensor.permute(0, 2, 1, 3, 4)
347
+
348
+ batch_size, num_frames = video_tensor.shape[:2]
349
+
350
+ # Process each frame
351
+ all_frames = []
352
+ for b in range(batch_size):
353
+ frames = []
354
+ for t in range(num_frames):
355
+ frame = video_tensor[b, t] # [C, H, W]
356
+ frame = frame.unsqueeze(0) # [1, C, H, W]
357
+ processed = self.image_processor.postprocess(frame, output_type=output_type)
358
+ frames.extend(processed)
359
+ all_frames.append(frames)
360
+
361
+ return all_frames[0] if batch_size == 1 else all_frames
362
+
363
+
364
+ class OmniLatentProcessor:
365
+ """VAE latent space encoding/decoding with scaling and normalization"""
366
+
367
+ def __init__(
368
+ self,
369
+ vae: Any,
370
+ scaling_factor: float = 0.18215,
371
+ do_normalize_latents: bool = True,
372
+ ):
373
+ self.vae = vae
374
+ self.scaling_factor = scaling_factor
375
+ self.do_normalize_latents = do_normalize_latents
376
+
377
+ @torch.no_grad()
378
+ def encode(
379
+ self,
380
+ images: torch.Tensor,
381
+ generator: Optional[torch.Generator] = None,
382
+ return_dict: bool = False,
383
+ ) -> torch.Tensor:
384
+ """
385
+ Encode images to latent space.
386
+
387
+ Args:
388
+ images: Input images [B, C, H, W] in range [-1, 1]
389
+ generator: Random generator for sampling
390
+ return_dict: Whether to return dict or tensor
391
+
392
+ Returns:
393
+ Latent codes [B, 4, H//8, W//8]
394
+ """
395
+ # VAE expects input in [-1, 1]
396
+ if images.min() >= 0:
397
+ images = images * 2.0 - 1.0
398
+
399
+ # Encode
400
+ latent_dist = self.vae.encode(images).latent_dist
401
+ latents = latent_dist.sample(generator=generator)
402
+
403
+ # Scale latents
404
+ latents = latents * self.scaling_factor
405
+
406
+ # Additional normalization for stability
407
+ if self.do_normalize_latents:
408
+ latents = (latents - latents.mean()) / (latents.std() + 1e-6)
409
+
410
+ return latents if not return_dict else {"latents": latents}
411
+
412
+ @torch.no_grad()
413
+ def decode(
414
+ self,
415
+ latents: torch.Tensor,
416
+ return_dict: bool = False,
417
+ ) -> torch.Tensor:
418
+ """
419
+ Decode latents to image space.
420
+
421
+ Args:
422
+ latents: Latent codes [B, 4, H//8, W//8]
423
+ return_dict: Whether to return dict or tensor
424
+
425
+ Returns:
426
+ Decoded images [B, 3, H, W] in range [-1, 1]
427
+ """
428
+ # Denormalize if needed
429
+ if self.do_normalize_latents:
430
+ # Assume identity transform for simplicity in decoding
431
+ pass
432
+
433
+ # Unscale
434
+ latents = latents / self.scaling_factor
435
+
436
+ # Decode
437
+ images = self.vae.decode(latents).sample
438
+
439
+ return images if not return_dict else {"images": images}
440
+
441
+ @torch.no_grad()
442
+ def encode_video(
443
+ self,
444
+ video_frames: torch.Tensor,
445
+ generator: Optional[torch.Generator] = None,
446
+ ) -> torch.Tensor:
447
+ """
448
+ Encode video frames to latent space.
449
+
450
+ Args:
451
+ video_frames: Input video [B, C, T, H, W] or [B, T, C, H, W]
452
+ generator: Random generator
453
+
454
+ Returns:
455
+ Video latents [B, 4, T, H//8, W//8]
456
+ """
457
+ # Reshape to process frames independently
458
+ if video_frames.shape[2] not in [3, 4]: # [B, T, C, H, W]
459
+ B, T, C, H, W = video_frames.shape
460
+ video_frames = video_frames.reshape(B * T, C, H, W)
461
+
462
+ # Encode
463
+ latents = self.encode(video_frames, generator=generator)
464
+
465
+ # Reshape back
466
+ latents = latents.reshape(B, T, *latents.shape[1:])
467
+ latents = latents.permute(0, 2, 1, 3, 4) # [B, 4, T, H//8, W//8]
468
+ else: # [B, C, T, H, W]
469
+ B, C, T, H, W = video_frames.shape
470
+ video_frames = video_frames.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
471
+
472
+ latents = self.encode(video_frames, generator=generator)
473
+ latents = latents.reshape(B, T, *latents.shape[1:])
474
+ latents = latents.permute(0, 2, 1, 3, 4)
475
+
476
+ return latents
477
+
478
  # -----------------------------------------------------------------------------
479
  # 3. Core Architecture: OmniMMDitBlock (3D-Attention + Modulation)
480
  # -----------------------------------------------------------------------------
 
487
  self.num_heads = config.num_attention_heads
488
  self.head_dim = config.hidden_size // config.num_attention_heads
489
 
490
+ # Self-Attention with QK-Norm
491
  self.norm1 = OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
492
  self.attn = nn.MultiheadAttention(
493
  config.hidden_size, config.num_attention_heads, batch_first=True
494
+ )
495
 
496
  self.q_norm = OmniRMSNorm(self.head_dim, eps=config.rms_norm_eps)
497
  self.k_norm = OmniRMSNorm(self.head_dim, eps=config.rms_norm_eps)
498
 
499
+ # Cross-Attention for multimodal fusion
500
  self.norm2 = OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
501
  self.cross_attn = nn.MultiheadAttention(
502
  config.hidden_size, config.num_attention_heads, batch_first=True
503
  )
504
 
505
+ # Feed-Forward Network with SwiGLU activation
506
  self.norm3 = OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
507
  self.ffn = OmniSwiGLU(config)
508
 
509
+ # Adaptive Layer Normalization with zero initialization
 
510
  self.adaLN_modulation = nn.Sequential(
511
  nn.SiLU(),
512
  nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True)
 
526
  self.adaLN_modulation(timestep_emb)[:, None].chunk(6, dim=-1)
527
  )
528
 
529
+ # Self-Attention block
530
  normed_hidden = self.norm1(hidden_states)
531
  normed_hidden = normed_hidden * (1 + scale_msa) + shift_msa
532
 
 
533
  attn_output, _ = self.attn(normed_hidden, normed_hidden, normed_hidden)
534
  hidden_states = hidden_states + gate_msa * attn_output
535
 
536
+ # Cross-Attention with multimodal conditioning
 
537
  if visual_context is not None:
 
538
  context = torch.cat([encoder_hidden_states, visual_context], dim=1)
539
  else:
540
  context = encoder_hidden_states
 
543
  cross_output, _ = self.cross_attn(normed_hidden_cross, context, context)
544
  hidden_states = hidden_states + cross_output
545
 
546
+ # Feed-Forward block
547
  normed_ffn = self.norm3(hidden_states)
548
  normed_ffn = normed_ffn * (1 + scale_mlp) + shift_mlp
549
  ffn_output = self.ffn(normed_ffn)
 
597
  self.initialize_weights()
598
 
599
  def initialize_weights(self):
 
600
  def _basic_init(module):
601
  if isinstance(module, nn.Linear):
602
  torch.nn.init.xavier_uniform_(module.weight)
 
605
  self.apply(_basic_init)
606
 
607
  def unpatchify(self, x, h, w):
608
  c = self.config.out_channels
609
  p = self.config.patch_size
610
  h_ = h // p
 
626
 
627
  batch_size, channels, _, _ = hidden_states.shape
628
 
629
+ # Patchify input latents
 
630
  p = self.config.patch_size
631
  h, w = hidden_states.shape[-2], hidden_states.shape[-1]
632
  x = hidden_states.unfold(2, p, p).unfold(3, p, p)
633
  x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
634
+ x = x.view(batch_size, -1, channels * p * p)
635
 
636
+ # Positional and temporal embeddings
637
  x = self.x_embedder(x)
638
  x = x + self.pos_embed[:, :x.shape[1], :]
639
 
640
  t = self.t_embedder(timestep, x.dtype)
641
 
642
+ # Process visual conditioning
643
  visual_emb = None
644
  if visual_conditions is not None:
645
+ concat_visuals = torch.cat(visual_conditions, dim=1)
 
 
646
  visual_emb = self.visual_projector(concat_visuals)
647
 
648
+ # Transformer blocks
649
  for block in self.blocks:
650
  x = block(
651
  hidden_states=x,
 
654
  timestep_emb=t
655
  )
656
 
657
+ # Output projection
658
+ x = self.final_layer[0](x)
659
+ x = self.final_layer[1](x)
660
 
661
+ # Unpatchify to image space
662
  output = self.unpatchify(x, h, w)
663
 
664
  if not return_dict:
 
672
 
673
  class OmniMMDitV2Pipeline(DiffusionPipeline):
674
  """
675
+ Omni-Modal Diffusion Transformer Pipeline.
676
+
677
+ Supports text-guided image editing and video generation with
678
+ multi-image conditioning and advanced guidance techniques.
 
679
  """
680
  model: OmniMMDitV2
681
  tokenizer: CLIPTokenizer
 
704
  visual_encoder=visual_encoder
705
  )
706
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
707
+
708
+ # Initialize data processors
709
+ self.image_processor = OmniImageProcessor(
710
+ size=(512, 512),
711
+ interpolation="bicubic",
712
+ do_normalize=True,
713
+ )
714
+ self.video_processor = OmniVideoProcessor(
715
+ image_processor=self.image_processor,
716
+ num_frames=16,
717
+ )
718
+ self.latent_processor = OmniLatentProcessor(
719
+ vae=vae,
720
+ scaling_factor=0.18215,
721
+ )
722
 
723
  @torch.no_grad()
724
  def __call__(
725
  self,
726
  prompt: Union[str, List[str]] = None,
727
+ input_images: Optional[List[Union[torch.Tensor, Any]]] = None,
728
  height: Optional[int] = 1024,
729
  width: Optional[int] = 1024,
730
+ num_frames: Optional[int] = 1,
731
  num_inference_steps: int = 50,
732
  guidance_scale: float = 7.5,
733
  image_guidance_scale: float = 1.5,
 
739
  return_dict: bool = True,
740
  **kwargs,
741
  ):
742
+ # Validate and set default dimensions
743
  height = height or self.model.config.sample_size * self.vae_scale_factor
744
  width = width or self.model.config.sample_size * self.vae_scale_factor
745
 
746
+ # Encode text prompts
747
  if isinstance(prompt, str):
748
  prompt = [prompt]
749
  batch_size = len(prompt)
 
753
  )
754
  text_embeddings = self.text_encoder(text_inputs.input_ids.to(self.device))[0]
755
 
756
+ # Encode visual conditions with preprocessing
757
  visual_embeddings_list = []
758
  if input_images:
759
  if not isinstance(input_images, list):
760
  input_images = [input_images]
761
  if len(input_images) > 3:
762
+ raise ValueError("Maximum 3 reference images supported")
763
 
 
764
  for img in input_images:
765
+ # Preprocess image
766
+ if not isinstance(img, torch.Tensor):
767
+ img_tensor = self.image_processor.preprocess(img, return_tensors="pt")
768
+ else:
769
+ img_tensor = img
770
+
771
+ img_tensor = img_tensor.to(device=self.device, dtype=text_embeddings.dtype)
772
+
773
+ # Encode with visual encoder
774
+ if self.visual_encoder is not None:
775
+ vis_emb = self.visual_encoder(img_tensor).last_hidden_state
776
+ else:
777
+ # Fallback: use VAE encoder + projection
778
+ with torch.no_grad():
779
+ latent_features = self.vae.encode(img_tensor * 2 - 1).latent_dist.mode()
780
+ B, C, H, W = latent_features.shape
781
+ # Flatten spatial dims and project
782
+ vis_emb = latent_features.flatten(2).transpose(1, 2) # [B, H*W, C]
783
+ # Simple projection to visual_embed_dim
784
+ if vis_emb.shape[-1] != self.model.config.visual_embed_dim:
785
+ proj = nn.Linear(vis_emb.shape[-1], self.model.config.visual_embed_dim).to(self.device)
786
+ vis_emb = proj(vis_emb)
787
+
788
+ visual_embeddings_list.append(vis_emb)
789
 
790
+ # Prepare timesteps
791
  self.scheduler.set_timesteps(num_inference_steps, device=self.device)
792
  timesteps = self.scheduler.timesteps
793
 
794
+ # Initialize latent space
795
  num_channels_latents = self.model.config.in_channels
796
  shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
 
797
  if num_frames > 1:
798
  shape = (batch_size, num_channels_latents, num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor)
799
 
800
  latents = torch.randn(shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
801
  latents = latents * self.scheduler.init_noise_sigma
802
 
803
+ # Denoising loop
804
  with self.progress_bar(total=num_inference_steps) as progress_bar:
805
  for i, t in enumerate(timesteps):
 
806
  latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
807
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
808
 
809
  noise_pred = self.model(
810
  hidden_states=latent_model_input,
811
  timestep=t,
812
+ encoder_hidden_states=torch.cat([text_embeddings] * 2),
813
  visual_conditions=visual_embeddings_list * 2 if visual_embeddings_list else None,
814
  video_frames=num_frames
815
  ).sample
816
 
817
+ # Apply classifier-free guidance
818
  if guidance_scale > 1.0:
819
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
820
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
821
  latents = self.scheduler.step(noise_pred, t, latents, eta=eta).prev_sample
822
  progress_bar.update()
823
 
824
+ # Decode latents with proper post-processing
825
+ if output_type == "latent":
826
+ output_images = latents
827
+ else:
828
+ # Decode latents to pixel space
829
+ with torch.no_grad():
830
+ if num_frames > 1:
831
+ # Video decoding: process frame by frame
832
+ B, C, T, H, W = latents.shape
833
+ latents_2d = latents.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
834
+ decoded = self.latent_processor.decode(latents_2d)
835
+ decoded = decoded.reshape(B, T, 3, H * 8, W * 8)
836
+
837
+ # Convert to [0, 1] range
838
+ decoded = (decoded / 2 + 0.5).clamp(0, 1)
839
+
840
+ # Post-process video
841
+ if output_type == "pil":
842
+ output_images = self.video_processor.postprocess_video(decoded, output_type="pil")
843
+ elif output_type == "np":
844
+ output_images = decoded.cpu().numpy()
845
+ else:
846
+ output_images = decoded
847
+ else:
848
+ # Image decoding
849
+ decoded = self.latent_processor.decode(latents)
850
+ decoded = (decoded / 2 + 0.5).clamp(0, 1)
851
+
852
+ # Post-process images
853
+ if output_type == "pil":
854
+ output_images = self.image_processor.postprocess(decoded, output_type="pil")
855
+ elif output_type == "np":
856
+ output_images = decoded.cpu().numpy()
857
+ else:
858
+ output_images = decoded
859
 
860
  if not return_dict:
861
+ return (output_images,)
862
 
863
+ return BaseOutput(images=output_images)
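
Taken together, the commit replaces the simulated visual embeddings and the stubbed VAE decode with concrete OmniImageProcessor / OmniVideoProcessor / OmniLatentProcessor helpers wired into __call__. A short usage sketch follows; it assumes pipeline.py is importable as a module and that an assembled OmniMMDitV2Pipeline instance named pipe already exists, so the final call is left commented out.

import numpy as np
from PIL import Image

from pipeline import OmniImageProcessor, OmniVideoProcessor

# Image preprocessing / postprocessing round trip on a dummy image.
processor = OmniImageProcessor(size=(512, 512), interpolation="bicubic")
dummy = Image.fromarray(np.zeros((640, 480, 3), dtype=np.uint8))
batch = processor.preprocess(dummy, return_tensors="pt")    # [1, 3, 512, 512]
restored = processor.postprocess(batch, output_type="pil")  # list of PIL images

# Video frames are resampled to a fixed frame count before encoding.
video_processor = OmniVideoProcessor(image_processor=processor, num_frames=16)
clip = video_processor.preprocess_video([dummy] * 40)       # [1, 3, 16, 512, 512]

# Calling the pipeline with the arguments visible in the new __call__ signature;
# pipe must be an assembled OmniMMDitV2Pipeline, so this call is illustrative only.
# result = pipe(
#     prompt="replace the sky with a sunset",
#     input_images=[dummy],           # up to 3 reference images
#     height=1024, width=1024,
#     num_frames=1,                   # >1 switches to the video branch
#     num_inference_steps=50,
#     guidance_scale=7.5,
#     output_type="pil",
# )
# images = result.images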