selfitcamera committed
Commit 4c55d00 · 1 Parent(s): 61f70d4
__lib__/i18n/ar.pyc CHANGED
Binary files a/__lib__/i18n/ar.pyc and b/__lib__/i18n/ar.pyc differ
 
__lib__/i18n/da.pyc CHANGED
Binary files a/__lib__/i18n/da.pyc and b/__lib__/i18n/da.pyc differ
 
__lib__/i18n/de.pyc CHANGED
Binary files a/__lib__/i18n/de.pyc and b/__lib__/i18n/de.pyc differ
 
__lib__/i18n/en.pyc CHANGED
Binary files a/__lib__/i18n/en.pyc and b/__lib__/i18n/en.pyc differ
 
__lib__/i18n/es.pyc CHANGED
Binary files a/__lib__/i18n/es.pyc and b/__lib__/i18n/es.pyc differ
 
__lib__/i18n/fi.pyc CHANGED
Binary files a/__lib__/i18n/fi.pyc and b/__lib__/i18n/fi.pyc differ
 
__lib__/i18n/fr.pyc CHANGED
Binary files a/__lib__/i18n/fr.pyc and b/__lib__/i18n/fr.pyc differ
 
__lib__/i18n/he.pyc CHANGED
Binary files a/__lib__/i18n/he.pyc and b/__lib__/i18n/he.pyc differ
 
__lib__/i18n/hi.pyc CHANGED
Binary files a/__lib__/i18n/hi.pyc and b/__lib__/i18n/hi.pyc differ
 
__lib__/i18n/id.pyc CHANGED
Binary files a/__lib__/i18n/id.pyc and b/__lib__/i18n/id.pyc differ
 
__lib__/i18n/it.pyc CHANGED
Binary files a/__lib__/i18n/it.pyc and b/__lib__/i18n/it.pyc differ
 
__lib__/i18n/ja.pyc CHANGED
Binary files a/__lib__/i18n/ja.pyc and b/__lib__/i18n/ja.pyc differ
 
__lib__/i18n/nl.pyc CHANGED
Binary files a/__lib__/i18n/nl.pyc and b/__lib__/i18n/nl.pyc differ
 
__lib__/i18n/no.pyc CHANGED
Binary files a/__lib__/i18n/no.pyc and b/__lib__/i18n/no.pyc differ
 
__lib__/i18n/pt.pyc CHANGED
Binary files a/__lib__/i18n/pt.pyc and b/__lib__/i18n/pt.pyc differ
 
__lib__/i18n/ru.pyc CHANGED
Binary files a/__lib__/i18n/ru.pyc and b/__lib__/i18n/ru.pyc differ
 
__lib__/i18n/sv.pyc CHANGED
Binary files a/__lib__/i18n/sv.pyc and b/__lib__/i18n/sv.pyc differ
 
__lib__/i18n/tr.pyc CHANGED
Binary files a/__lib__/i18n/tr.pyc and b/__lib__/i18n/tr.pyc differ
 
__lib__/i18n/uk.pyc CHANGED
Binary files a/__lib__/i18n/uk.pyc and b/__lib__/i18n/uk.pyc differ
 
__lib__/i18n/vi.pyc CHANGED
Binary files a/__lib__/i18n/vi.pyc and b/__lib__/i18n/vi.pyc differ
 
__lib__/i18n/zh.pyc CHANGED
Binary files a/__lib__/i18n/zh.pyc and b/__lib__/i18n/zh.pyc differ
 
__lib__/pipeline.pyc CHANGED
Binary files a/__lib__/pipeline.pyc and b/__lib__/pipeline.pyc differ
 
pipeline.py CHANGED
@@ -1,12 +1,15 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Optional, Tuple, Union, List, Dict, Any
+from typing import Optional, Tuple, Union, List, Dict, Any, Callable
 from dataclasses import dataclass
 import numpy as np
 from PIL import Image
 import torchvision.transforms as T
 from torchvision.transforms.functional import to_tensor, normalize
+import warnings
+from contextlib import contextmanager
+from functools import wraps
 
 from transformers import PretrainedConfig, PreTrainedModel, CLIPTextModel, CLIPTokenizer
 from transformers.modeling_outputs import BaseModelOutputWithPooling
@@ -15,6 +18,20 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.utils import BaseOutput
 
+# Optimization imports
+try:
+    import transformer_engine.pytorch as te
+    from transformer_engine.common import recipe
+    HAS_TRANSFORMER_ENGINE = True
+except ImportError:
+    HAS_TRANSFORMER_ENGINE = False
+
+try:
+    from torch._dynamo import config as dynamo_config
+    HAS_TORCH_COMPILE = hasattr(torch, 'compile')
+except ImportError:
+    HAS_TORCH_COMPILE = False
+
 # -----------------------------------------------------------------------------
 # 1. Advanced Configuration (8B Scale)
 # -----------------------------------------------------------------------------
@@ -50,6 +67,11 @@ class OmniMMDitV2Config(PretrainedConfig):
         visual_embed_dim: int = 1024,  # e.g., SigLIP or CLIP Vision
         text_embed_dim: int = 4096,  # T5-XXL or similar
         use_temporal_attention: bool = True,  # For Video generation
+        # Optimization Configs
+        use_fp8_quantization: bool = False,
+        use_compilation: bool = False,
+        compile_mode: str = "reduce-overhead",
+        use_flash_attention: bool = True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -72,6 +94,10 @@ class OmniMMDitV2Config(PretrainedConfig):
         self.visual_embed_dim = visual_embed_dim
         self.text_embed_dim = text_embed_dim
         self.use_temporal_attention = use_temporal_attention
+        self.use_fp8_quantization = use_fp8_quantization
+        self.use_compilation = use_compilation
+        self.compile_mode = compile_mode
+        self.use_flash_attention = use_flash_attention
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
@@ -567,6 +593,19 @@ class OmniMMDitV2(ModelMixin, PreTrainedModel):
         super().__init__(config)
         self.config = config
 
+        # Initialize optimizer for advanced features
+        self.optimizer = ModelOptimizer(
+            fp8_config=FP8Config(enabled=config.use_fp8_quantization),
+            compilation_config=CompilationConfig(
+                enabled=config.use_compilation,
+                mode=config.compile_mode,
+            ),
+            mixed_precision_config=MixedPrecisionConfig(
+                enabled=True,
+                dtype="bfloat16",
+            ),
+        )
+
         # Input Latent Projection (Patchify)
         self.x_embedder = nn.Linear(config.in_channels * config.patch_size * config.patch_size, config.hidden_size, bias=True)
 
@@ -595,6 +634,30 @@ class OmniMMDitV2(ModelMixin, PreTrainedModel):
         )
 
         self.initialize_weights()
+
+        # Apply optimizations if enabled
+        if config.use_fp8_quantization or config.use_compilation:
+            self._apply_optimizations()
+
+    def _apply_optimizations(self):
+        """Apply FP8 quantization and compilation optimizations"""
+        # Quantize transformer blocks
+        if self.config.use_fp8_quantization:
+            for i, block in enumerate(self.blocks):
+                self.blocks[i] = self.optimizer.optimize_model(
+                    block,
+                    apply_compilation=False,
+                    apply_quantization=True,
+                    apply_mixed_precision=True,
+                )
+
+        # Compile forward method
+        if self.config.use_compilation and HAS_TORCH_COMPILE:
+            self.forward = torch.compile(
+                self.forward,
+                mode=self.config.compile_mode,
+                dynamic=True,
+            )
 
     def initialize_weights(self):
         def _basic_init(module):
@@ -719,6 +782,83 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
             vae=vae,
             scaling_factor=0.18215,
         )
+
+        # Initialize model optimizer
+        self.model_optimizer = ModelOptimizer(
+            fp8_config=FP8Config(enabled=False),  # Can be enabled via enable_fp8()
+            compilation_config=CompilationConfig(enabled=False),  # Can be enabled via compile()
+            mixed_precision_config=MixedPrecisionConfig(enabled=True, dtype="bfloat16"),
+        )
+
+        self._is_compiled = False
+        self._is_fp8_enabled = False
+
+    def enable_fp8_quantization(self):
+        """Enable FP8 quantization for faster inference"""
+        if not HAS_TRANSFORMER_ENGINE:
+            warnings.warn("Transformer Engine not available. Install with: pip install transformer-engine")
+            return self
+
+        self.model_optimizer.fp8_config.enabled = True
+        self.model = self.model_optimizer.optimize_model(
+            self.model,
+            apply_compilation=False,
+            apply_quantization=True,
+            apply_mixed_precision=False,
+        )
+        self._is_fp8_enabled = True
+        return self
+
+    def compile_model(
+        self,
+        mode: str = "reduce-overhead",
+        fullgraph: bool = False,
+        dynamic: bool = True,
+    ):
+        """
+        Compile model using torch.compile for faster inference.
+
+        Args:
+            mode: Compilation mode - "default", "reduce-overhead", "max-autotune"
+            fullgraph: Whether to compile the entire model as one graph
+            dynamic: Whether to enable dynamic shapes
+        """
+        if not HAS_TORCH_COMPILE:
+            warnings.warn("torch.compile not available. Upgrade to PyTorch 2.0+")
+            return self
+
+        self.model_optimizer.compilation_config = CompilationConfig(
+            enabled=True,
+            mode=mode,
+            fullgraph=fullgraph,
+            dynamic=dynamic,
+        )
+
+        self.model = self.model_optimizer._compile_model(self.model)
+        self._is_compiled = True
+        return self
+
+    def enable_optimizations(
+        self,
+        enable_fp8: bool = False,
+        enable_compilation: bool = False,
+        compilation_mode: str = "reduce-overhead",
+    ):
+        """
+        Enable all optimizations at once.
+
+        Args:
+            enable_fp8: Enable FP8 quantization
+            enable_compilation: Enable torch.compile
+            compilation_mode: Compilation mode for torch.compile
+        """
+        if enable_fp8:
+            self.enable_fp8_quantization()
+
+        if enable_compilation:
+            self.compile_model(mode=compilation_mode)
+
+        return self
 
     @torch.no_grad()
     def __call__(
@@ -737,6 +877,55 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
         latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+        callback_steps: int = 1,
+        use_optimized_inference: bool = True,
+        **kwargs,
+    ):
+        # Use optimized inference context
+        with optimized_inference_mode(
+            enable_cudnn_benchmark=use_optimized_inference,
+            enable_tf32=use_optimized_inference,
+            enable_flash_sdp=use_optimized_inference,
+        ):
+            return self._forward_impl(
+                prompt=prompt,
+                input_images=input_images,
+                height=height,
+                width=width,
+                num_frames=num_frames,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                image_guidance_scale=image_guidance_scale,
+                negative_prompt=negative_prompt,
+                eta=eta,
+                generator=generator,
+                latents=latents,
+                output_type=output_type,
+                return_dict=return_dict,
+                callback=callback,
+                callback_steps=callback_steps,
+                **kwargs,
+            )
+
+    def _forward_impl(
+        self,
+        prompt: Union[str, List[str]] = None,
+        input_images: Optional[List[Union[torch.Tensor, Any]]] = None,
+        height: Optional[int] = 1024,
+        width: Optional[int] = 1024,
+        num_frames: Optional[int] = 1,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        image_guidance_scale: float = 1.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+        callback_steps: int = 1,
         **kwargs,
     ):
         # Validate and set default dimensions
@@ -800,25 +989,33 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
         latents = torch.randn(shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
        latents = latents * self.scheduler.init_noise_sigma
 
-        # Denoising loop
+        # Denoising loop with optimizations
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                noise_pred = self.model(
-                    hidden_states=latent_model_input,
-                    timestep=t,
-                    encoder_hidden_states=torch.cat([text_embeddings] * 2),
-                    visual_conditions=visual_embeddings_list * 2 if visual_embeddings_list else None,
-                    video_frames=num_frames
-                ).sample
+                # Use mixed precision autocast
+                with self.model_optimizer.autocast_context():
+                    noise_pred = self.model(
+                        hidden_states=latent_model_input,
+                        timestep=t,
+                        encoder_hidden_states=torch.cat([text_embeddings] * 2),
+                        visual_conditions=visual_embeddings_list * 2 if visual_embeddings_list else None,
+                        video_frames=num_frames
+                    ).sample
 
                 # Apply classifier-free guidance
                 if guidance_scale > 1.0:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
                 latents = self.scheduler.step(noise_pred, t, latents, eta=eta).prev_sample
+
+                # Call callback if provided
+                if callback is not None and i % callback_steps == 0:
+                    callback(i, t, latents)
+
                 progress_bar.update()
 
         # Decode latents with proper post-processing
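
Taken together, the pipeline.py changes add opt-in FP8 quantization (via Transformer Engine), torch.compile support, a per-step callback, and an optimized-inference context around the denoising loop. Below is a minimal usage sketch of that new surface; it assumes the helper types referenced in the diff (ModelOptimizer, FP8Config, CompilationConfig, MixedPrecisionConfig, optimized_inference_mode) are defined elsewhere in pipeline.py, and the checkpoint path is a placeholder.

# Minimal usage sketch (not part of the commit); checkpoint path is a placeholder.
import torch
from pipeline import OmniMMDitV2Pipeline

pipe = OmniMMDitV2Pipeline.from_pretrained("path/to/checkpoint", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Opt in to the new optimizations; each falls back with a warning if unsupported.
pipe.enable_optimizations(
    enable_fp8=True,                      # needs transformer-engine installed
    enable_compilation=True,              # needs PyTorch 2.0+ (torch.compile)
    compilation_mode="reduce-overhead",
)

# New per-step callback: invoked every `callback_steps` denoising steps.
def log_step(step, timestep, latents):
    print(f"step {step}: latent shape {tuple(latents.shape)}")

result = pipe(
    prompt="a watercolor fox in a snowy forest",
    num_inference_steps=30,
    callback=log_step,
    callback_steps=5,
    use_optimized_inference=True,         # wraps the run in optimized_inference_mode()
)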