# Hugging Face Space status: Running on CPU Upgrade
import math
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import normalize, to_tensor
from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from diffusers import DDIMScheduler, DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput

# Optimization imports
try:
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe
    HAS_TRANSFORMER_ENGINE = True
except ImportError:
    HAS_TRANSFORMER_ENGINE = False

try:
    from torch._dynamo import config as dynamo_config
    HAS_TORCH_COMPILE = hasattr(torch, 'compile')
except ImportError:
    HAS_TORCH_COMPILE = False
# -----------------------------------------------------------------------------
# 1. Configuration (~8B-parameter scale)
# -----------------------------------------------------------------------------
class OmniMMDitV2Config(PretrainedConfig):
    """Hyper-parameters for the OmniMMDitV2 diffusion transformer.

    The defaults describe a Llama-style backbone (4096 wide, 32 layers,
    32 attention heads, SwiGLU width 11008) at roughly the 7B-8B parameter
    scale, followed by DiT-, multi-modal- and optimization-specific switches.

    Every model-specific argument is stored verbatim as an attribute of the
    same name *before* ``PretrainedConfig.__init__`` runs. The token ids,
    ``tie_word_embeddings`` and any extra ``**kwargs`` are handed to the
    ``PretrainedConfig`` constructor.
    """

    model_type = "omnimm_dit_v2"

    def __init__(
        self,
        vocab_size: int = 49408,
        hidden_size: int = 4096,  # ~7B-8B parameter scale
        intermediate_size: int = 11008,  # Llama-style MLP expansion
        num_hidden_layers: int = 32,
        num_attention_heads: int = 32,
        num_key_value_heads: Optional[int] = 8,  # grouped-query attention setting
        hidden_act: str = "silu",
        max_position_embeddings: int = 4096,
        initializer_range: float = 0.02,
        rms_norm_eps: float = 1e-5,
        use_cache: bool = True,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        rope_theta: float = 10000.0,
        # --- DiT specifics ---
        patch_size: int = 2,
        in_channels: int = 4,  # VAE latent channels
        out_channels: int = 4,  # double it if the variance is learned
        frequency_embedding_size: int = 256,
        # --- Multi-modal specifics ---
        max_condition_images: int = 3,  # 1-3 reference images
        visual_embed_dim: int = 1024,  # e.g. SigLIP / CLIP-vision width
        text_embed_dim: int = 4096,  # e.g. T5-XXL width
        use_temporal_attention: bool = True,  # video generation switch
        # --- Optimization switches ---
        use_fp8_quantization: bool = False,
        use_compilation: bool = False,
        compile_mode: str = "reduce-overhead",
        use_flash_attention: bool = True,
        **kwargs,
    ):
        # Model-specific fields, in declaration order. They are attached as
        # plain attributes before the base-class constructor is invoked.
        model_fields = {
            "vocab_size": vocab_size,
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "num_key_value_heads": num_key_value_heads,
            "hidden_act": hidden_act,
            "max_position_embeddings": max_position_embeddings,
            "initializer_range": initializer_range,
            "rms_norm_eps": rms_norm_eps,
            "use_cache": use_cache,
            "rope_theta": rope_theta,
            "patch_size": patch_size,
            "in_channels": in_channels,
            "out_channels": out_channels,
            "frequency_embedding_size": frequency_embedding_size,
            "max_condition_images": max_condition_images,
            "visual_embed_dim": visual_embed_dim,
            "text_embed_dim": text_embed_dim,
            "use_temporal_attention": use_temporal_attention,
            "use_fp8_quantization": use_fp8_quantization,
            "use_compilation": use_compilation,
            "compile_mode": compile_mode,
            "use_flash_attention": use_flash_attention,
        }
        for field_name, field_value in model_fields.items():
            setattr(self, field_name, field_value)

        # Generic token / embedding settings are owned by PretrainedConfig.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
# -----------------------------------------------------------------------------
# 2. Building Blocks (RMSNorm, RoPE, SwiGLU, timestep embedding)
# -----------------------------------------------------------------------------
class OmniRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The statistic is computed in float32 for stability. The normalized values
    are cast back to the input dtype before being scaled by the learnable
    per-channel ``weight`` (initialised to ones).
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.float()
        mean_square = as_fp32.square().mean(dim=-1, keepdim=True)
        normalized = as_fp32 * (mean_square + self.variance_epsilon).rsqrt()
        return self.weight * normalized.to(original_dtype)
class OmniRotaryEmbedding(nn.Module):
    """Cosine/sine tables for rotary positional embeddings (RoPE).

    For positions ``0 .. length-1`` the angles ``position * inv_freq`` are
    computed, where ``inv_freq = 1 / base ** (2i / dim)``. The frequency
    vector is duplicated so each table has width ``2 * ceil(dim / 2)``
    (``dim`` for even ``dim``). ``max_position_embeddings`` is stored but does
    not bound the requested length.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, dim, 2).float().to(device) / dim
        # Derived from dim/base, so it is not saved in checkpoints.
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables of shape ``[length, width]``.

        ``length`` is ``seq_len`` when it is truthy, otherwise ``x.shape[1]``.
        Only ``x``'s device and second dimension are read.
        """
        length = seq_len or x.shape[1]
        positions = torch.arange(length, device=x.device).type_as(self.inv_freq)
        angles = positions[:, None] * self.inv_freq[None, :]
        table = torch.cat([angles, angles], dim=-1)
        return table.cos(), table.sin()
class OmniSwiGLU(nn.Module):
    """SwiGLU feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` (gate) and ``w3`` (value) project ``hidden_size -> intermediate_size``
    and ``w2`` projects back. None of the projections has a bias.
    """

    def __init__(self, config: "OmniMMDitV2Config"):
        super().__init__()
        d_model, d_ff = config.hidden_size, config.intermediate_size
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class TimestepEmbedder(nn.Module):
    """Embed diffusion timesteps into ``hidden_size``-wide vectors.

    Timesteps are mapped to sinusoidal (cos/sin) features of width
    ``frequency_embedding_size``, which are then fed through a two-layer
    SiLU MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    # Must be a staticmethod: forward() calls it through ``self``, and without
    # the decorator the module itself was bound to ``t``, so every call crashed.
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return sinusoidal timestep features.

        Args:
            t: 1-D tensor of N (possibly fractional) timesteps. A 0-d tensor,
                such as one element of ``scheduler.timesteps``, is treated as N=1.
            dim: Output width. An odd width gets one trailing zero column.
            max_period: Sets the lowest frequency, ``1 / max_period``.

        Returns:
            float32 tensor of shape ``[N, dim]`` laid out as
            ``[cos(t * f_0), ..., cos(t * f_{half-1}), sin(t * f_0), ...]``.
        """
        if t.ndim == 0:
            t = t[None]
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype):
        """Embed timesteps ``t`` and return ``[N, hidden_size]`` MLP features.

        The sinusoidal features are cast to ``dtype`` before the MLP.
        """
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        return self.mlp(t_freq)
# -----------------------------------------------------------------------------
# 2.5. Data Processing Utilities
# -----------------------------------------------------------------------------
class OmniImageProcessor:
    """Image pre-/post-processing for the multi-modal diffusion pipeline.

    ``preprocess`` turns PIL images, numpy arrays or tensors (or a list of them)
    into a normalised ``[B, 3, H, W]`` float tensor. ``postprocess`` undoes the
    normalisation and converts model output to PIL images, numpy arrays or
    tensors.

    Args:
        image_mean: Per-channel mean for normalisation. Defaults to
            ``[0.485, 0.456, 0.406]``.
        image_std: Per-channel std for normalisation. Defaults to
            ``[0.229, 0.224, 0.225]``.
        size: Target ``(height, width)`` passed to torchvision.
        interpolation: ``"bilinear"``, ``"bicubic"`` or ``"lanczos"``. Unknown
            names fall back to bicubic.
        do_normalize: Apply mean/std normalisation and undo it in ``postprocess``.
        do_center_crop: If True, resize with the aspect ratio preserved so the
            image covers ``size``, then center-crop to ``size``. If False,
            resize directly to ``size``.
    """

    _DEFAULT_MEAN = (0.485, 0.456, 0.406)
    _DEFAULT_STD = (0.229, 0.224, 0.225)

    def __init__(
        self,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        size: Tuple[int, int] = (512, 512),
        interpolation: str = "bicubic",
        do_normalize: bool = True,
        do_center_crop: bool = False,
    ):
        # Fresh lists per instance: the previous mutable list defaults were
        # shared by every processor created without explicit statistics.
        self.image_mean = list(image_mean if image_mean is not None else self._DEFAULT_MEAN)
        self.image_std = list(image_std if image_std is not None else self._DEFAULT_STD)
        self.size = size
        self.do_normalize = do_normalize
        self.do_center_crop = do_center_crop
        self.interpolation_mode = {
            "bilinear": T.InterpolationMode.BILINEAR,
            "bicubic": T.InterpolationMode.BICUBIC,
            "lanczos": T.InterpolationMode.LANCZOS,
        }.get(interpolation, T.InterpolationMode.BICUBIC)

        if do_center_crop:
            # Resize first, then crop. Cropping the raw image first zero-padded
            # small images and kept only a small central patch of large ones.
            transforms_list = [T.Lambda(self._resize_to_cover), T.CenterCrop(size)]
        else:
            transforms_list = [T.Resize(size, interpolation=self.interpolation_mode, antialias=True)]
        self.transform = T.Compose(transforms_list)

    def _resize_to_cover(self, img: "Image.Image") -> "Image.Image":
        """Aspect-preserving resize so both sides are at least the target size."""
        target_h, target_w = (self.size, self.size) if isinstance(self.size, int) else self.size
        width, height = img.size
        scale = max(target_h / height, target_w / width)
        # max() guards against rounding a side just below the target.
        new_h = max(target_h, round(height * scale))
        new_w = max(target_w, round(width * scale))
        return T.Resize((new_h, new_w), interpolation=self.interpolation_mode, antialias=True)(img)

    @staticmethod
    def _to_pil(img):
        """Convert a numpy array / tensor / PIL image to an RGB PIL image."""
        if isinstance(img, np.ndarray):
            if img.dtype != np.uint8:
                # Non-uint8 arrays are treated as [0, 1] floats. Clip so that
                # out-of-range values saturate instead of wrapping in uint8.
                img = (np.clip(img, 0.0, 1.0) * 255).round().astype(np.uint8)
            img = Image.fromarray(img)
        elif isinstance(img, torch.Tensor):
            img = T.ToPILImage()(img)
        # Grayscale / RGBA / palette images would otherwise produce 1 or 4
        # channels, which the 3-entry mean/std cannot normalise.
        return img.convert("RGB")

    def preprocess(
        self,
        images: Union["Image.Image", np.ndarray, torch.Tensor, List[Union["Image.Image", np.ndarray, torch.Tensor]]],
        return_tensors: str = "pt",
    ) -> torch.Tensor:
        """
        Preprocess images for model input.

        Args:
            images: Single image or list of images (PIL, numpy, or torch).
            return_tensors: ``"pt"`` stacks into one tensor; anything else
                returns the list of per-image ``[3, H, W]`` tensors.

        Returns:
            Preprocessed image tensor ``[B, 3, H, W]`` (for ``"pt"``).
        """
        if not isinstance(images, list):
            images = [images]

        processed = []
        for img in images:
            img = self.transform(self._to_pil(img))
            if not isinstance(img, torch.Tensor):
                img = to_tensor(img)  # PIL -> float tensor in [0, 1]
            if self.do_normalize:
                img = normalize(img, self.image_mean, self.image_std)
            processed.append(img)

        if return_tensors == "pt":
            return torch.stack(processed, dim=0)
        return processed

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detach and move to CPU; upcast bfloat16, which numpy cannot represent."""
        tensor = tensor.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.numpy()

    def postprocess(
        self,
        images: torch.Tensor,
        output_type: str = "pil",
    ) -> Union[List["Image.Image"], np.ndarray, torch.Tensor]:
        """
        Postprocess model output to the requested format.

        Args:
            images: Model output tensor ``[B, C, H, W]`` in normalised space.
            output_type: ``"pil"``, ``"np"`` or anything else for a tensor.

        Returns:
            A list of PIL images, a ``[B, C, H, W]`` numpy array, or the tensor,
            after de-normalisation and clamping to ``[0, 1]``.
        """
        if self.do_normalize:
            mean = torch.tensor(self.image_mean).view(1, -1, 1, 1).to(images.device)
            std = torch.tensor(self.image_std).view(1, -1, 1, 1).to(images.device)
            images = images * std + mean

        images = torch.clamp(images, 0, 1)

        if output_type == "pil":
            arrays = self._to_numpy(images.permute(0, 2, 3, 1))
            arrays = (arrays * 255).round().astype(np.uint8)
            return [Image.fromarray(arr) for arr in arrays]
        if output_type == "np":
            return self._to_numpy(images)
        return images
class OmniVideoProcessor:
    """Frame sampling and per-frame (de)normalisation for video data.

    Frames are delegated to ``image_processor``, which must provide
    ``preprocess(frame, return_tensors="pt")`` returning a batch whose first
    element is a ``[C, H, W]`` tensor, and ``postprocess(batch, output_type=...)``
    returning an iterable of per-frame results. ``OmniImageProcessor``
    offers both.

    Args:
        image_processor: Per-frame processor (see above).
        num_frames: Number of frames produced by ``preprocess_video``.
        frame_stride: Keep every ``frame_stride``-th input frame before
            temporal resampling (1 keeps all frames).
    """

    def __init__(
        self,
        image_processor: "OmniImageProcessor",
        num_frames: int = 16,
        frame_stride: int = 1,
    ):
        self.image_processor = image_processor
        self.num_frames = num_frames
        self.frame_stride = frame_stride

    def preprocess_video(
        self,
        video_frames: "Union[List[Image.Image], np.ndarray, torch.Tensor]",
        temporal_interpolation: bool = True,
    ) -> torch.Tensor:
        """
        Preprocess video frames for the temporal model.

        Args:
            video_frames: List of frames, numpy array ``[T, H, W, C]`` or
                tensor ``[T, C, H, W]``.
            temporal_interpolation: Evenly resample to ``num_frames`` frames
                when the (strided) frame count differs.

        Returns:
            Preprocessed video tensor ``[1, C, T, H, W]``.

        Raises:
            ValueError: If an array/tensor is not 4-D, or no frames are given.
        """
        if isinstance(video_frames, (np.ndarray, torch.Tensor)):
            if video_frames.ndim != 4:
                kind = "numpy array" if isinstance(video_frames, np.ndarray) else "tensor"
                raise ValueError(f"Expected 4D {kind}, got shape {video_frames.shape}")
        # Iterating a 4-D array/tensor yields its frames. The image processor
        # converts numpy (uint8, or float in [0, 1]) and tensor frames itself.
        frames = list(video_frames)

        if self.frame_stride > 1:
            frames = frames[:: self.frame_stride]
        if not frames:
            raise ValueError("video_frames must contain at least one frame")

        if temporal_interpolation and len(frames) != self.num_frames:
            indices = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
            frames = [frames[i] for i in indices]

        processed = [
            self.image_processor.preprocess(frame, return_tensors="pt")[0]
            for frame in frames[: self.num_frames]
        ]
        # T tensors of [C, H, W] -> [C, T, H, W] -> [1, C, T, H, W]
        return torch.stack(processed, dim=1).unsqueeze(0)

    def postprocess_video(
        self,
        video_tensor: torch.Tensor,
        output_type: str = "pil",
    ) -> "Union[List[Image.Image], np.ndarray, torch.Tensor]":
        """
        Postprocess video output.

        Args:
            video_tensor: Model output ``[B, C, T, H, W]`` (detected by 3 or 4
                channels at dim 1) or ``[B, T, C, H, W]``.
            output_type: Forwarded to ``image_processor.postprocess``.

        Returns:
            For a batch of one, the list of processed frames; otherwise one
            such list per video.
        """
        if video_tensor.ndim == 5 and video_tensor.shape[1] in (3, 4):
            video_tensor = video_tensor.permute(0, 2, 1, 3, 4)  # -> [B, T, C, H, W]
        batch_size, num_frames = video_tensor.shape[:2]

        # One batched call instead of one call per frame. The per-frame
        # results are identical because postprocessing is element-wise.
        flat = video_tensor.reshape(batch_size * num_frames, *video_tensor.shape[2:])
        frames = list(self.image_processor.postprocess(flat, output_type=output_type))

        videos = [frames[b * num_frames:(b + 1) * num_frames] for b in range(batch_size)]
        return videos[0] if batch_size == 1 else videos
class OmniLatentProcessor:
    """Move images and videos between pixel space and scaled VAE latent space.

    The wrapped ``vae`` must provide ``encode(x).latent_dist.sample(generator=...)``
    and ``decode(z).sample``.

    Args:
        vae: The autoencoder.
        scaling_factor: Latents are multiplied by this after encoding and
            divided by it before decoding.
        do_normalize_latents: Additionally standardise encoded latents with
            the mean/std of the current batch (see ``encode``).
    """

    def __init__(
        self,
        vae: Any,
        scaling_factor: float = 0.18215,
        do_normalize_latents: bool = True,
    ):
        self.vae = vae
        self.scaling_factor = scaling_factor
        self.do_normalize_latents = do_normalize_latents

    def encode(
        self,
        images: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = False,
    ) -> torch.Tensor:
        """
        Encode images to latent space.

        Args:
            images: Input images ``[B, C, H, W]``. If every value is
                non-negative the input is treated as ``[0, 1]`` and mapped to
                ``[-1, 1]``.
            generator: Random generator for sampling the latent distribution.
            return_dict: Return ``{"latents": ...}`` instead of the tensor.

        Returns:
            Scaled (and optionally standardised) latents.

        NOTE(review): with ``do_normalize_latents`` the latents are
        standardised using statistics of the current batch. Those statistics
        are not stored, so ``decode`` cannot invert this step.
        """
        if images.min() >= 0:
            images = images * 2.0 - 1.0

        latents = self.vae.encode(images).latent_dist.sample(generator=generator)
        latents = latents * self.scaling_factor

        if self.do_normalize_latents:
            latents = (latents - latents.mean()) / (latents.std() + 1e-6)

        return {"latents": latents} if return_dict else latents

    def decode(
        self,
        latents: torch.Tensor,
        return_dict: bool = False,
    ) -> torch.Tensor:
        """
        Decode latents to image space.

        Only the VAE scaling is undone. The optional batch standardisation
        applied by ``encode`` cannot be reversed here.

        Args:
            latents: Latent codes.
            return_dict: Return ``{"images": ...}`` instead of the tensor.

        Returns:
            Whatever ``vae.decode(...).sample`` produces.
        """
        images = self.vae.decode(latents / self.scaling_factor).sample
        return {"images": images} if return_dict else images

    def encode_video(
        self,
        video_frames: torch.Tensor,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Encode video frames to latent space, one frame at a time.

        Args:
            video_frames: ``[B, C, T, H, W]`` (recognised by 3 or 4 channels at
                dim 1, the same rule ``OmniVideoProcessor.postprocess_video``
                uses) or ``[B, T, C, H, W]`` (3 or 4 channels at dim 2). Inputs
                matching neither rule are treated as ``[B, C, T, H, W]``.
            generator: Random generator.

        Returns:
            Video latents ``[B, C_latent, T, h, w]``.
        """
        # The layout checks were previously swapped, so both layouts were
        # reshaped incorrectly before reaching the VAE.
        is_btc = video_frames.shape[1] not in (3, 4) and video_frames.shape[2] in (3, 4)
        if not is_btc:
            video_frames = video_frames.permute(0, 2, 1, 3, 4)  # -> [B, T, C, H, W]

        batch, frames = video_frames.shape[:2]
        flat = video_frames.reshape(batch * frames, *video_frames.shape[2:])
        latents = self.encode(flat, generator=generator)
        latents = latents.reshape(batch, frames, *latents.shape[1:])
        return latents.permute(0, 2, 1, 3, 4)  # -> [B, C_latent, T, h, w]
# -----------------------------------------------------------------------------
# 3. Core Architecture: OmniMMDitBlock (self-/cross-attention + AdaLN modulation)
# -----------------------------------------------------------------------------
class OmniMMDitBlock(nn.Module):
    """One DiT layer: modulated self-attention, multimodal cross-attention, SwiGLU FFN.

    The timestep embedding is projected to six per-channel vectors: shift,
    scale and gate for the self-attention branch and for the FFN branch. The
    cross-attention branch is neither modulated nor gated.
    """

    def __init__(self, config: "OmniMMDitV2Config", layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads

        # Self-attention.
        # NOTE(review): q_norm / k_norm are created (so they appear in the
        # state dict), but nn.MultiheadAttention does not expose per-head
        # q/k, so forward() does not apply them.
        self.norm1 = OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True
        )
        self.q_norm = OmniRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = OmniRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # Cross-attention: keys/values are text tokens, optionally followed by
        # reference-image tokens.
        self.norm2 = OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.cross_attn = nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True
        )

        # Feed-forward network with SwiGLU activation.
        self.norm3 = OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.ffn = OmniSwiGLU(config)

        # Timestep -> 6 modulation vectors. No zero-initialisation happens
        # here; any AdaLN-Zero init must be applied by the owning model.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True)
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,  # Text embeddings
        visual_context: Optional[torch.Tensor],  # Reference image embeddings
        timestep_emb: torch.Tensor,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Apply the block to ``hidden_states`` (batch-first token sequences).

        Args:
            hidden_states: Latent tokens.
            encoder_hidden_states: Text tokens used as cross-attention context.
            visual_context: Optional image tokens appended to the context along
                dim 1.
            timestep_emb: Timestep embedding. Presumably ``[B, hidden_size]``
                given the ``[:, None]`` broadcast below (TODO confirm).
            rotary_emb: Accepted for interface compatibility; currently unused.

        Returns:
            Updated latent tokens with the same shape as ``hidden_states``.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.adaLN_modulation(timestep_emb)[:, None].chunk(6, dim=-1)
        )

        # Self-attention branch. The attention weights are discarded, so they
        # are not requested: need_weights=False lets PyTorch take its
        # optimized scaled-dot-product attention path.
        modulated = self.norm1(hidden_states) * (1 + scale_msa) + shift_msa
        attn_output, _ = self.attn(modulated, modulated, modulated, need_weights=False)
        hidden_states = hidden_states + gate_msa * attn_output

        # Cross-attention branch over text (+ optional visual) context.
        if visual_context is not None:
            context = torch.cat([encoder_hidden_states, visual_context], dim=1)
        else:
            context = encoder_hidden_states
        queries = self.norm2(hidden_states)
        cross_output, _ = self.cross_attn(queries, context, context, need_weights=False)
        hidden_states = hidden_states + cross_output

        # Feed-forward branch.
        modulated = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp
        hidden_states = hidden_states + gate_mlp * self.ffn(modulated)
        return hidden_states
# -----------------------------------------------------------------------------
# 4. The Model: OmniMMDitV2
# -----------------------------------------------------------------------------
class OmniMMDitV2(ModelMixin, PreTrainedModel):
    """
    Omni-Modal Multi-Dimensional Diffusion Transformer V2.

    Intended tasks: Text-to-Image, Image-to-Image (Edit), Image-to-Video.
    ``forward`` as written consumes 4-D latents ``[B, C, H, W]``; its
    ``video_frames`` argument is accepted but not used.

    Latents are split into non-overlapping ``patch_size`` patches (row-major),
    linearly embedded and offset by a fixed sinusoidal position table. They
    then pass through ``num_hidden_layers`` ``OmniMMDitBlock`` layers
    conditioned on the timestep, the text tokens and optional reference-image
    tokens, and are finally projected back to ``out_channels`` latent channels.
    """

    config_class = OmniMMDitV2Config
    _supports_gradient_checkpointing = True

    def __init__(self, config: OmniMMDitV2Config):
        # NOTE(review): a cooperative ``super().__init__(config)`` resolves to
        # diffusers' ``ModelMixin.__init__`` first in the MRO. In the diffusers
        # releases reviewed, that method accepts no arguments, so the call
        # raised TypeError. Initialise through PreTrainedModel instead; its own
        # ``super().__init__()`` continues on to nn.Module. Confirm against the
        # pinned diffusers / transformers versions.
        PreTrainedModel.__init__(self, config)
        self.config = config

        # Optimization helper. ModelOptimizer and the *Config classes are not
        # defined in this part of the file.
        self.optimizer = ModelOptimizer(
            fp8_config=FP8Config(enabled=config.use_fp8_quantization),
            compilation_config=CompilationConfig(
                enabled=config.use_compilation,
                mode=config.compile_mode,
            ),
            mixed_precision_config=MixedPrecisionConfig(
                enabled=True,
                dtype="bfloat16",
            ),
        )

        # Input latent projection (patchify): C*p*p -> hidden_size per patch.
        self.x_embedder = nn.Linear(config.in_channels * config.patch_size * config.patch_size, config.hidden_size, bias=True)

        # Timestep embedding.
        self.t_embedder = TimestepEmbedder(config.hidden_size, config.frequency_embedding_size)

        # Visual condition projector (applied to the concatenated 1-3 images).
        self.visual_projector = nn.Sequential(
            nn.Linear(config.visual_embed_dim, config.hidden_size),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.hidden_size)
        )

        # Fixed absolute position table indexed by flattened patch position.
        # Filled with sin-cos values in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, config.max_position_embeddings, config.hidden_size), requires_grad=False)

        # Transformer backbone.
        self.blocks = nn.ModuleList([
            OmniMMDitBlock(config, i) for i in range(config.num_hidden_layers)
        ])

        # Final layer: RMSNorm + linear projection to patch pixels.
        self.final_layer = nn.Sequential(
            OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
            nn.Linear(config.hidden_size, config.patch_size * config.patch_size * config.out_channels, bias=True)
        )

        self.initialize_weights()

        # Apply optimizations if enabled.
        if config.use_fp8_quantization or config.use_compilation:
            self._apply_optimizations()

    def _apply_optimizations(self):
        """Apply FP8 quantization to each block and/or torch.compile to forward."""
        if self.config.use_fp8_quantization:
            for i, block in enumerate(self.blocks):
                self.blocks[i] = self.optimizer.optimize_model(
                    block,
                    apply_compilation=False,
                    apply_quantization=True,
                    apply_mixed_precision=True,
                )

        if self.config.use_compilation and HAS_TORCH_COMPILE:
            self.forward = torch.compile(
                self.forward,
                mode=self.config.compile_mode,
                dynamic=True,
            )

    @staticmethod
    def _sincos_position_table(num_positions: int, dim: int) -> torch.Tensor:
        """Return a ``[num_positions, dim]`` float32 table of 1-D sin-cos encodings."""
        half = dim // 2
        positions = torch.arange(num_positions, dtype=torch.float32)[:, None]
        inv_freq = 1.0 / (10000.0 ** (torch.arange(half, dtype=torch.float32) / max(half, 1)))
        angles = positions * inv_freq[None]
        table = torch.cat([angles.sin(), angles.cos()], dim=1)
        if dim % 2:
            table = torch.cat([table, torch.zeros(num_positions, 1)], dim=1)
        return table

    def initialize_weights(self):
        """Initialise parameters.

        - Every nn.Linear gets Xavier-uniform weights and zero bias.
        - The frozen position table is filled with sin-cos encodings. Left at
          zeros, it added nothing, so the model had no positional signal.
        - AdaLN-Zero, as the layer comments intend: each block's modulation
          projection and the final projection are zeroed. The gated
          self-attention and FFN branches therefore start disabled, and the
          initial output is zero.
        """
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        with torch.no_grad():
            table = self._sincos_position_table(self.pos_embed.shape[1], self.pos_embed.shape[2])
            self.pos_embed.copy_(table.unsqueeze(0))

        for block in self.blocks:
            nn.init.zeros_(block.adaLN_modulation[-1].weight)
            nn.init.zeros_(block.adaLN_modulation[-1].bias)
        nn.init.zeros_(self.final_layer[-1].weight)
        nn.init.zeros_(self.final_layer[-1].bias)

    def unpatchify(self, x, h, w):
        """Map ``[B, (h/p)*(w/p), p*p*C_out]`` patch tokens back to ``[B, C_out, h, w]``."""
        c = self.config.out_channels
        p = self.config.patch_size
        h_ = h // p
        w_ = w // p
        x = x.reshape(shape=(x.shape[0], h_, w_, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h, w))
        return imgs

    def forward(
        self,
        hidden_states: torch.Tensor,  # Noisy latents [B, C, H, W]
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,  # Text embeddings
        visual_conditions: Optional[List[torch.Tensor]] = None,  # List of [B, L, D]
        video_frames: Optional[int] = None,  # accepted; currently unused
        return_dict: bool = True,
    ) -> Union[torch.Tensor, BaseOutput]:
        """Predict the denoising target for ``hidden_states`` at ``timestep``.

        Raises:
            ValueError: If the latents are not 4-D, their spatial size is not
                divisible by ``patch_size``, or they yield more patches than
                ``max_position_embeddings``.
        """
        if hidden_states.ndim != 4:
            raise ValueError(
                "OmniMMDitV2.forward expects latents of shape [B, C, H, W]; "
                f"got shape {tuple(hidden_states.shape)}"
            )
        batch_size, channels, h, w = hidden_states.shape
        p = self.config.patch_size
        if h % p or w % p:
            raise ValueError(f"Latent size ({h}, {w}) must be divisible by patch_size={p}")

        # Patchify: [B, C, H, W] -> [B, (H/p)*(W/p), C*p*p], row-major patches.
        x = hidden_states.unfold(2, p, p).unfold(3, p, p)
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
        x = x.view(batch_size, -1, channels * p * p)

        x = self.x_embedder(x)
        num_tokens = x.shape[1]
        if num_tokens > self.pos_embed.shape[1]:
            raise ValueError(
                f"{num_tokens} patches exceed max_position_embeddings="
                f"{self.pos_embed.shape[1]}"
            )
        x = x + self.pos_embed[:, :num_tokens, :]
        t = self.t_embedder(timestep, x.dtype)

        # Reference images: concatenate along the token axis, then project.
        visual_emb = None
        if visual_conditions is not None and len(visual_conditions) > 0:
            visual_emb = self.visual_projector(torch.cat(visual_conditions, dim=1))

        for block in self.blocks:
            x = block(
                hidden_states=x,
                encoder_hidden_states=encoder_hidden_states,
                visual_context=visual_emb,
                timestep_emb=t
            )

        x = self.final_layer(x)  # RMSNorm, then projection to patch pixels
        output = self.unpatchify(x, h, w)

        if not return_dict:
            return (output,)
        return BaseOutput(sample=output)
# -----------------------------------------------------------------------------
# 5. Pipeline: OmniMMDitV2Pipeline
# -----------------------------------------------------------------------------
| class OmniMMDitV2Pipeline(DiffusionPipeline): | |
| """ | |
| Omni-Modal Diffusion Transformer Pipeline. | |
| Supports text-guided image editing and video generation with | |
| multi-image conditioning and advanced guidance techniques. | |
| """ | |
| model: OmniMMDitV2 | |
| tokenizer: CLIPTokenizer | |
| text_encoder: CLIPTextModel | |
| vae: Any # AutoencoderKL | |
| scheduler: DDIMScheduler | |
| _optional_components = ["visual_encoder"] | |
def __init__(
    self,
    model: OmniMMDitV2,
    vae: Any,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    scheduler: DDIMScheduler,
    visual_encoder: Optional[Any] = None,
):
    """Register the pipeline components and build the processing helpers.

    Args:
        model: The OmniMMDitV2 denoiser.
        vae: Autoencoder. Its ``config.block_out_channels`` sets the spatial
            down-scaling factor, and ``config.scaling_factor`` (when present)
            sets the latent scale.
        text_encoder: Text encoder for prompts.
        tokenizer: Tokenizer matching ``text_encoder``.
        scheduler: Diffusion scheduler.
        visual_encoder: Optional image encoder for reference images.
    """
    super().__init__()
    self.register_modules(
        model=model,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        visual_encoder=visual_encoder
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    # Data processors.
    self.image_processor = OmniImageProcessor(
        size=(512, 512),
        interpolation="bicubic",
        do_normalize=True,
    )
    self.video_processor = OmniVideoProcessor(
        image_processor=self.image_processor,
        num_frames=16,
    )
    # Use the VAE's own latent scale when its config provides one; a
    # hard-coded 0.18215 mis-scales latents for VAEs trained with another
    # factor. 0.18215 remains the fallback.
    vae_scaling = getattr(self.vae.config, "scaling_factor", None)
    self.latent_processor = OmniLatentProcessor(
        vae=vae,
        scaling_factor=vae_scaling or 0.18215,
    )

    # Model optimizer; FP8 and compilation are opt-in via
    # enable_fp8_quantization() / compile_model().
    self.model_optimizer = ModelOptimizer(
        fp8_config=FP8Config(enabled=False),
        compilation_config=CompilationConfig(enabled=False),
        mixed_precision_config=MixedPrecisionConfig(enabled=True, dtype="bfloat16"),
    )
    self._is_compiled = False
    self._is_fp8_enabled = False
def enable_fp8_quantization(self):
    """Enable FP8 quantization of ``self.model`` for faster inference.

    Returns:
        ``self``, for chaining.

    When Transformer Engine is unavailable, a warning is issued and nothing
    changes. Repeated calls are no-ops once FP8 is enabled, so the already
    quantized model is not passed through the optimizer a second time.
    """
    if not HAS_TRANSFORMER_ENGINE:
        warnings.warn("Transformer Engine not available. Install with: pip install transformer-engine")
        return self
    if self._is_fp8_enabled:
        return self

    self.model_optimizer.fp8_config.enabled = True
    self.model = self.model_optimizer.optimize_model(
        self.model,
        apply_compilation=False,
        apply_quantization=True,
        apply_mixed_precision=False,
    )
    self._is_fp8_enabled = True
    return self
def compile_model(
    self,
    mode: str = "reduce-overhead",
    fullgraph: bool = False,
    dynamic: bool = True,
):
    """Wrap ``self.model`` with torch.compile through the pipeline's optimizer.

    Args:
        mode: Compilation mode: "default", "reduce-overhead" or "max-autotune".
        fullgraph: Require the whole model to compile as a single graph.
        dynamic: Enable dynamic-shape compilation.

    Returns:
        ``self``, for chaining. When torch.compile is unavailable, a warning is
        emitted and the model is left untouched.
    """
    if HAS_TORCH_COMPILE:
        settings = CompilationConfig(
            enabled=True,
            mode=mode,
            fullgraph=fullgraph,
            dynamic=dynamic,
        )
        self.model_optimizer.compilation_config = settings
        self.model = self.model_optimizer._compile_model(self.model)
        self._is_compiled = True
    else:
        warnings.warn("torch.compile not available. Upgrade to PyTorch 2.0+")
    return self
def enable_optimizations(
    self,
    enable_fp8: bool = False,
    enable_compilation: bool = False,
    compilation_mode: str = "reduce-overhead",
):
    """Turn on the requested optimizations in one call.

    FP8 quantization (if requested) always runs before compilation.

    Args:
        enable_fp8: Run ``enable_fp8_quantization()``.
        enable_compilation: Run ``compile_model(mode=compilation_mode)``.
        compilation_mode: torch.compile mode used when compiling.

    Returns:
        ``self``, for chaining.
    """
    plan = (
        (enable_fp8, lambda: self.enable_fp8_quantization()),
        (enable_compilation, lambda: self.compile_model(mode=compilation_mode)),
    )
    for requested, apply_step in plan:
        if requested:
            apply_step()
    return self
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    input_images: Optional[List[Union[torch.Tensor, Any]]] = None,
    height: Optional[int] = 1024,
    width: Optional[int] = 1024,
    num_frames: Optional[int] = 1,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    image_guidance_scale: float = 1.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    use_optimized_inference: bool = True,
    **kwargs,
):
    """Run the pipeline.

    ``use_optimized_inference`` toggles the cuDNN-benchmark, TF32 and
    flash-SDP switches of ``optimized_inference_mode``. Every other argument
    is forwarded unchanged to ``_forward_impl``.

    Autograd is disabled for the whole call. Sampling needs no gradients, and
    without this the text encoder and every denoising step recorded an
    autograd graph.
    """
    with torch.no_grad(), optimized_inference_mode(
        enable_cudnn_benchmark=use_optimized_inference,
        enable_tf32=use_optimized_inference,
        enable_flash_sdp=use_optimized_inference,
    ):
        return self._forward_impl(
            prompt=prompt,
            input_images=input_images,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            image_guidance_scale=image_guidance_scale,
            negative_prompt=negative_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )
def _forward_impl(
    self,
    prompt: Union[str, List[str]] = None,
    input_images: Optional[List[Union[torch.Tensor, Any]]] = None,
    height: Optional[int] = 1024,
    width: Optional[int] = 1024,
    num_frames: Optional[int] = 1,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    image_guidance_scale: float = 1.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    **kwargs,
):
    """Denoise from noise (or caller-supplied latents) and decode the result.

    Text prompts are encoded with ``self.tokenizer`` / ``self.text_encoder``.
    Up to three reference images are encoded with ``self.visual_encoder``, or,
    when that is ``None``, with ``self.vae`` followed by a linear projection.
    The latents are then iteratively denoised by ``self.model`` under
    ``self.scheduler`` and decoded with ``self.latent_processor``.

    Args:
        prompt: One prompt or a list of prompts; the list length is the batch
            size.
        input_images: Optional reference image(s). Tensors are used as-is;
            anything else goes through ``self.image_processor.preprocess``.
            At most 3 are accepted.
        height, width: Output size in pixels. Falsy values fall back to
            ``self.model.config.sample_size * self.vae_scale_factor``.
        num_frames: Values > 1 produce 5-D video latents
            ``(B, C, num_frames, h, w)``; ``None`` is treated as 1.
        num_inference_steps: Number of timesteps requested from the scheduler.
        guidance_scale: Classifier-free guidance weight; guidance (and the
            doubled CFG batch) is used only when it is > 1.0.
        image_guidance_scale: Accepted for API compatibility; not used here.
        negative_prompt: Prompt(s) for the unconditional CFG branch. Defaults
            to the empty string for every sample.
        eta: Forwarded to ``self.scheduler.step``.
        generator: A single generator, or one generator per prompt, used for
            the initial noise (and for ``scheduler.step`` if it accepts one).
        latents: Optional initial latents with the expected shape. When
            omitted they are sampled from ``generator``.
        output_type: ``"latent"`` returns latents, ``"pil"`` / ``"np"`` are
            post-processed, anything else returns the decoded tensor.
        return_dict: When False, return a 1-tuple instead of ``BaseOutput``.
        callback: Called as ``callback(i, t, latents)`` every
            ``callback_steps`` steps.
        callback_steps: Callback period; must be >= 1 when ``callback`` is
            given.
        **kwargs: Accepted and ignored.

    Returns:
        ``BaseOutput(images=...)``, or ``(images,)`` when ``return_dict`` is
        False.

    Raises:
        ValueError: On a missing or empty prompt, more than 3 reference
            images, a batch-size mismatch in ``negative_prompt`` or
            ``generator``, a wrongly shaped ``latents`` tensor, or
            ``callback_steps < 1``.
    """
    import inspect

    def _encode_text(texts: List[str]) -> torch.Tensor:
        # Fixed-length padding keeps conditional and unconditional embeddings
        # the same sequence length, so they can be concatenated for CFG.
        tokens = self.tokenizer(
            texts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        return self.text_encoder(tokens.input_ids.to(self.device))[0]

    # ---- Argument validation / defaults ----
    if isinstance(prompt, str):
        prompt = [prompt]
    if not prompt:
        raise ValueError("`prompt` must be a non-empty string or list of strings")
    batch_size = len(prompt)
    if callback is not None and (callback_steps is None or callback_steps < 1):
        raise ValueError(f"`callback_steps` must be a positive integer, got {callback_steps!r}")
    num_frames = num_frames or 1
    height = height or self.model.config.sample_size * self.vae_scale_factor
    width = width or self.model.config.sample_size * self.vae_scale_factor
    do_cfg = guidance_scale > 1.0

    # ---- Text conditioning ----
    text_embeddings = _encode_text(prompt)
    if do_cfg:
        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if len(negative_prompt) != batch_size:
            raise ValueError(
                f"`negative_prompt` has {len(negative_prompt)} entries but `prompt` has {batch_size}"
            )
        uncond_embeddings = _encode_text(list(negative_prompt))
        # Unconditional half first, matching the `noise_pred.chunk(2)` order below.
        encoder_hidden_states = torch.cat([uncond_embeddings, text_embeddings])
    else:
        # Without CFG the model sees a single batch of B samples.
        encoder_hidden_states = text_embeddings

    # ---- Visual conditioning ----
    visual_embeddings_list = []
    # `is not None` rather than truthiness: `bool(tensor)` raises for multi-element tensors.
    if input_images is not None:
        if not isinstance(input_images, (list, tuple)):
            input_images = [input_images]
        if len(input_images) > 3:
            raise ValueError("Maximum 3 reference images supported")
        for img in input_images:
            if isinstance(img, torch.Tensor):
                img_tensor = img
            else:
                img_tensor = self.image_processor.preprocess(img, return_tensors="pt")
            img_tensor = img_tensor.to(device=self.device, dtype=text_embeddings.dtype)

            if self.visual_encoder is not None:
                vis_emb = self.visual_encoder(img_tensor).last_hidden_state
            else:
                # Fallback: VAE-encode the image, then flatten the spatial dims into tokens.
                with torch.no_grad():
                    # NOTE(review): `* 2 - 1` assumes img_tensor is in [0, 1]; confirm
                    # against the range returned by image_processor.preprocess.
                    latent_features = self.vae.encode(img_tensor * 2 - 1).latent_dist.mode()
                vis_emb = latent_features.flatten(2).transpose(1, 2)  # [B, H*W, C]
                target_dim = self.model.config.visual_embed_dim
                in_dim = vis_emb.shape[-1]
                if in_dim != target_dim:
                    # Create the projection once and reuse it, so repeated (seeded) calls
                    # see the same weights instead of a fresh random layer each time.
                    proj = getattr(self, "_visual_fallback_proj", None)
                    if proj is None or proj.in_features != in_dim or proj.out_features != target_dim:
                        warnings.warn(
                            "No visual_encoder is loaded; projecting VAE latents with an "
                            "untrained, randomly initialized linear layer.",
                            stacklevel=2,
                        )
                        proj = nn.Linear(in_dim, target_dim)
                        self._visual_fallback_proj = proj
                    # Match device *and* dtype; a float32 layer on half-precision input raises.
                    proj.to(device=vis_emb.device, dtype=vis_emb.dtype)
                    with torch.no_grad():
                        vis_emb = proj(vis_emb)

            # NOTE(review): assumes the model wants one condition tensor per reference image,
            # batched like the latents; a single-image batch is broadcast to the prompt batch.
            if vis_emb.shape[0] == 1 and batch_size > 1:
                vis_emb = vis_emb.expand(batch_size, *vis_emb.shape[1:])
            visual_embeddings_list.append(vis_emb)

    visual_conditions = None
    if visual_embeddings_list:
        # Duplicate each tensor along the batch dim for CFG (like the latents), instead of
        # repeating the list, which would double the number of reference images.
        visual_conditions = [torch.cat([v, v]) if do_cfg else v for v in visual_embeddings_list]

    # ---- Timesteps ----
    self.scheduler.set_timesteps(num_inference_steps, device=self.device)
    timesteps = self.scheduler.timesteps

    # ---- Initial latents ----
    latent_dtype = text_embeddings.dtype
    latent_h = height // self.vae_scale_factor
    latent_w = width // self.vae_scale_factor
    num_channels_latents = self.model.config.in_channels
    if num_frames > 1:
        shape = (batch_size, num_channels_latents, num_frames, latent_h, latent_w)
    else:
        shape = (batch_size, num_channels_latents, latent_h, latent_w)

    def _randn(sample_shape, gen):
        # Sample on the generator's own device (torch rejects a device mismatch), then move.
        sample_device = gen.device if gen is not None else self.device
        noise = torch.randn(sample_shape, generator=gen, device=sample_device, dtype=latent_dtype)
        return noise.to(self.device)

    if latents is None:
        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(
                    f"Got {len(generator)} generators for a batch of {batch_size} prompts"
                )
            latents = torch.cat([_randn((1, *shape[1:]), g) for g in generator])
        else:
            latents = _randn(shape, generator)
    else:
        if tuple(latents.shape) != shape:
            raise ValueError(f"Unexpected latents shape {tuple(latents.shape)}, expected {shape}")
        latents = latents.to(device=self.device, dtype=latent_dtype)
    latents = latents * self.scheduler.init_noise_sigma

    # Forward the generator to scheduler.step only if its signature accepts one.
    step_kwargs = {"eta": eta}
    if "generator" in inspect.signature(self.scheduler.step).parameters:
        step_kwargs["generator"] = generator

    # ---- Denoising loop ----
    with self.progress_bar(total=len(timesteps)) as progress_bar:
        for i, t in enumerate(timesteps):
            latent_model_input = torch.cat([latents] * 2) if do_cfg else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            # Use mixed precision autocast
            with self.model_optimizer.autocast_context():
                noise_pred = self.model(
                    hidden_states=latent_model_input,
                    timestep=t,
                    encoder_hidden_states=encoder_hidden_states,
                    visual_conditions=visual_conditions,
                    video_frames=num_frames,
                ).sample
            # Apply classifier-free guidance
            if do_cfg:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = self.scheduler.step(noise_pred, t, latents, **step_kwargs).prev_sample
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)
            progress_bar.update()

    # ---- Decoding ----
    if output_type == "latent":
        output_images = latents
    else:
        with torch.no_grad():
            if num_frames > 1:
                # Video: fold frames into the batch, decode as images, then unfold.
                b, c, n_frames, h, w = latents.shape
                frames_2d = latents.permute(0, 2, 1, 3, 4).reshape(b * n_frames, c, h, w)
                decoded = self.latent_processor.decode(frames_2d)
                # Use the decoder's actual output size/channels rather than assuming 8x and RGB.
                decoded = decoded.reshape(b, n_frames, *decoded.shape[1:])
                decoded = (decoded / 2 + 0.5).clamp(0, 1)
                if output_type == "pil":
                    # NOTE(review): diffusers' VideoProcessor.postprocess_video expects
                    # (B, C, T, H, W) and may de-normalize itself; this passes (B, T, C, H, W)
                    # already in [0, 1] — confirm against the video_processor in use.
                    output_images = self.video_processor.postprocess_video(decoded, output_type="pil")
                elif output_type == "np":
                    # `.float()` because numpy has no bfloat16.
                    output_images = decoded.cpu().float().numpy()
                else:
                    output_images = decoded
            else:
                decoded = self.latent_processor.decode(latents)
                decoded = (decoded / 2 + 0.5).clamp(0, 1)
                if output_type == "pil":
                    # NOTE(review): if image_processor is diffusers' VaeImageProcessor, its
                    # postprocess() may de-normalize again by default — confirm.
                    output_images = self.image_processor.postprocess(decoded, output_type="pil")
                elif output_type == "np":
                    output_images = decoded.cpu().float().numpy()
                else:
                    output_images = decoded

    if not return_dict:
        return (output_images,)
    return BaseOutput(images=output_images)