Spaces · Running on CPU Upgrade

Commit 4c55d00 · committed by selfitcamera
1 Parent(s): 61f70d4

init
- __lib__/i18n/ar.pyc +0 -0
- __lib__/i18n/da.pyc +0 -0
- __lib__/i18n/de.pyc +0 -0
- __lib__/i18n/en.pyc +0 -0
- __lib__/i18n/es.pyc +0 -0
- __lib__/i18n/fi.pyc +0 -0
- __lib__/i18n/fr.pyc +0 -0
- __lib__/i18n/he.pyc +0 -0
- __lib__/i18n/hi.pyc +0 -0
- __lib__/i18n/id.pyc +0 -0
- __lib__/i18n/it.pyc +0 -0
- __lib__/i18n/ja.pyc +0 -0
- __lib__/i18n/nl.pyc +0 -0
- __lib__/i18n/no.pyc +0 -0
- __lib__/i18n/pt.pyc +0 -0
- __lib__/i18n/ru.pyc +0 -0
- __lib__/i18n/sv.pyc +0 -0
- __lib__/i18n/tr.pyc +0 -0
- __lib__/i18n/uk.pyc +0 -0
- __lib__/i18n/vi.pyc +0 -0
- __lib__/i18n/zh.pyc +0 -0
- __lib__/pipeline.pyc +0 -0
- pipeline.py +206 -9
__lib__/i18n/ar.pyc CHANGED
Binary files a/__lib__/i18n/ar.pyc and b/__lib__/i18n/ar.pyc differ

__lib__/i18n/da.pyc CHANGED
Binary files a/__lib__/i18n/da.pyc and b/__lib__/i18n/da.pyc differ

__lib__/i18n/de.pyc CHANGED
Binary files a/__lib__/i18n/de.pyc and b/__lib__/i18n/de.pyc differ

__lib__/i18n/en.pyc CHANGED
Binary files a/__lib__/i18n/en.pyc and b/__lib__/i18n/en.pyc differ

__lib__/i18n/es.pyc CHANGED
Binary files a/__lib__/i18n/es.pyc and b/__lib__/i18n/es.pyc differ

__lib__/i18n/fi.pyc CHANGED
Binary files a/__lib__/i18n/fi.pyc and b/__lib__/i18n/fi.pyc differ

__lib__/i18n/fr.pyc CHANGED
Binary files a/__lib__/i18n/fr.pyc and b/__lib__/i18n/fr.pyc differ

__lib__/i18n/he.pyc CHANGED
Binary files a/__lib__/i18n/he.pyc and b/__lib__/i18n/he.pyc differ

__lib__/i18n/hi.pyc CHANGED
Binary files a/__lib__/i18n/hi.pyc and b/__lib__/i18n/hi.pyc differ

__lib__/i18n/id.pyc CHANGED
Binary files a/__lib__/i18n/id.pyc and b/__lib__/i18n/id.pyc differ

__lib__/i18n/it.pyc CHANGED
Binary files a/__lib__/i18n/it.pyc and b/__lib__/i18n/it.pyc differ

__lib__/i18n/ja.pyc CHANGED
Binary files a/__lib__/i18n/ja.pyc and b/__lib__/i18n/ja.pyc differ

__lib__/i18n/nl.pyc CHANGED
Binary files a/__lib__/i18n/nl.pyc and b/__lib__/i18n/nl.pyc differ

__lib__/i18n/no.pyc CHANGED
Binary files a/__lib__/i18n/no.pyc and b/__lib__/i18n/no.pyc differ

__lib__/i18n/pt.pyc CHANGED
Binary files a/__lib__/i18n/pt.pyc and b/__lib__/i18n/pt.pyc differ

__lib__/i18n/ru.pyc CHANGED
Binary files a/__lib__/i18n/ru.pyc and b/__lib__/i18n/ru.pyc differ

__lib__/i18n/sv.pyc CHANGED
Binary files a/__lib__/i18n/sv.pyc and b/__lib__/i18n/sv.pyc differ

__lib__/i18n/tr.pyc CHANGED
Binary files a/__lib__/i18n/tr.pyc and b/__lib__/i18n/tr.pyc differ

__lib__/i18n/uk.pyc CHANGED
Binary files a/__lib__/i18n/uk.pyc and b/__lib__/i18n/uk.pyc differ

__lib__/i18n/vi.pyc CHANGED
Binary files a/__lib__/i18n/vi.pyc and b/__lib__/i18n/vi.pyc differ

__lib__/i18n/zh.pyc CHANGED
Binary files a/__lib__/i18n/zh.pyc and b/__lib__/i18n/zh.pyc differ

__lib__/pipeline.pyc CHANGED
Binary files a/__lib__/pipeline.pyc and b/__lib__/pipeline.pyc differ
pipeline.py CHANGED

@@ -1,12 +1,15 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Optional, Tuple, Union, List, Dict, Any
+from typing import Optional, Tuple, Union, List, Dict, Any, Callable
 from dataclasses import dataclass
 import numpy as np
 from PIL import Image
 import torchvision.transforms as T
 from torchvision.transforms.functional import to_tensor, normalize
+import warnings
+from contextlib import contextmanager
+from functools import wraps
 
 from transformers import PretrainedConfig, PreTrainedModel, CLIPTextModel, CLIPTokenizer
 from transformers.modeling_outputs import BaseModelOutputWithPooling
@@ -15,6 +18,20 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.utils import BaseOutput
 
+# Optimization imports
+try:
+    import transformer_engine.pytorch as te
+    from transformer_engine.common import recipe
+    HAS_TRANSFORMER_ENGINE = True
+except ImportError:
+    HAS_TRANSFORMER_ENGINE = False
+
+try:
+    from torch._dynamo import config as dynamo_config
+    HAS_TORCH_COMPILE = hasattr(torch, 'compile')
+except ImportError:
+    HAS_TORCH_COMPILE = False
+
 # -----------------------------------------------------------------------------
 # 1. Advanced Configuration (8B Scale)
 # -----------------------------------------------------------------------------
@@ -50,6 +67,11 @@ class OmniMMDitV2Config(PretrainedConfig):
         visual_embed_dim: int = 1024,  # e.g., SigLIP or CLIP Vision
         text_embed_dim: int = 4096,  # T5-XXL or similar
         use_temporal_attention: bool = True,  # For Video generation
+        # Optimization Configs
+        use_fp8_quantization: bool = False,
+        use_compilation: bool = False,
+        compile_mode: str = "reduce-overhead",
+        use_flash_attention: bool = True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -72,6 +94,10 @@ class OmniMMDitV2Config(PretrainedConfig):
         self.visual_embed_dim = visual_embed_dim
         self.text_embed_dim = text_embed_dim
         self.use_temporal_attention = use_temporal_attention
+        self.use_fp8_quantization = use_fp8_quantization
+        self.use_compilation = use_compilation
+        self.compile_mode = compile_mode
+        self.use_flash_attention = use_flash_attention
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
@@ -567,6 +593,19 @@ class OmniMMDitV2(ModelMixin, PreTrainedModel):
         super().__init__(config)
         self.config = config
 
+        # Initialize optimizer for advanced features
+        self.optimizer = ModelOptimizer(
+            fp8_config=FP8Config(enabled=config.use_fp8_quantization),
+            compilation_config=CompilationConfig(
+                enabled=config.use_compilation,
+                mode=config.compile_mode,
+            ),
+            mixed_precision_config=MixedPrecisionConfig(
+                enabled=True,
+                dtype="bfloat16",
+            ),
+        )
+
         # Input Latent Projection (Patchify)
         self.x_embedder = nn.Linear(config.in_channels * config.patch_size * config.patch_size, config.hidden_size, bias=True)
 
@@ -595,6 +634,30 @@ class OmniMMDitV2(ModelMixin, PreTrainedModel):
         )
 
         self.initialize_weights()
+
+        # Apply optimizations if enabled
+        if config.use_fp8_quantization or config.use_compilation:
+            self._apply_optimizations()
+
+    def _apply_optimizations(self):
+        """Apply FP8 quantization and compilation optimizations"""
+        # Quantize transformer blocks
+        if self.config.use_fp8_quantization:
+            for i, block in enumerate(self.blocks):
+                self.blocks[i] = self.optimizer.optimize_model(
+                    block,
+                    apply_compilation=False,
+                    apply_quantization=True,
+                    apply_mixed_precision=True,
+                )
+
+        # Compile forward method
+        if self.config.use_compilation and HAS_TORCH_COMPILE:
+            self.forward = torch.compile(
+                self.forward,
+                mode=self.config.compile_mode,
+                dynamic=True,
+            )
 
     def initialize_weights(self):
         def _basic_init(module):
@@ -719,6 +782,83 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
             vae=vae,
             scaling_factor=0.18215,
         )
+
+        # Initialize model optimizer
+        self.model_optimizer = ModelOptimizer(
+            fp8_config=FP8Config(enabled=False),  # Can be enabled via enable_fp8()
+            compilation_config=CompilationConfig(enabled=False),  # Can be enabled via compile()
+            mixed_precision_config=MixedPrecisionConfig(enabled=True, dtype="bfloat16"),
+        )
+
+        self._is_compiled = False
+        self._is_fp8_enabled = False
+
+    def enable_fp8_quantization(self):
+        """Enable FP8 quantization for faster inference"""
+        if not HAS_TRANSFORMER_ENGINE:
+            warnings.warn("Transformer Engine not available. Install with: pip install transformer-engine")
+            return self
+
+        self.model_optimizer.fp8_config.enabled = True
+        self.model = self.model_optimizer.optimize_model(
+            self.model,
+            apply_compilation=False,
+            apply_quantization=True,
+            apply_mixed_precision=False,
+        )
+        self._is_fp8_enabled = True
+        return self
+
+    def compile_model(
+        self,
+        mode: str = "reduce-overhead",
+        fullgraph: bool = False,
+        dynamic: bool = True,
+    ):
+        """
+        Compile model using torch.compile for faster inference.
+
+        Args:
+            mode: Compilation mode - "default", "reduce-overhead", "max-autotune"
+            fullgraph: Whether to compile the entire model as one graph
+            dynamic: Whether to enable dynamic shapes
+        """
+        if not HAS_TORCH_COMPILE:
+            warnings.warn("torch.compile not available. Upgrade to PyTorch 2.0+")
+            return self
+
+        self.model_optimizer.compilation_config = CompilationConfig(
+            enabled=True,
+            mode=mode,
+            fullgraph=fullgraph,
+            dynamic=dynamic,
+        )
+
+        self.model = self.model_optimizer._compile_model(self.model)
+        self._is_compiled = True
+        return self
+
+    def enable_optimizations(
+        self,
+        enable_fp8: bool = False,
+        enable_compilation: bool = False,
+        compilation_mode: str = "reduce-overhead",
+    ):
+        """
+        Enable all optimizations at once.
+
+        Args:
+            enable_fp8: Enable FP8 quantization
+            enable_compilation: Enable torch.compile
+            compilation_mode: Compilation mode for torch.compile
+        """
+        if enable_fp8:
+            self.enable_fp8_quantization()
+
+        if enable_compilation:
+            self.compile_model(mode=compilation_mode)
+
+        return self
 
     @torch.no_grad()
     def __call__(
@@ -737,6 +877,55 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
         latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+        callback_steps: int = 1,
+        use_optimized_inference: bool = True,
+        **kwargs,
+    ):
+        # Use optimized inference context
+        with optimized_inference_mode(
+            enable_cudnn_benchmark=use_optimized_inference,
+            enable_tf32=use_optimized_inference,
+            enable_flash_sdp=use_optimized_inference,
+        ):
+            return self._forward_impl(
+                prompt=prompt,
+                input_images=input_images,
+                height=height,
+                width=width,
+                num_frames=num_frames,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                image_guidance_scale=image_guidance_scale,
+                negative_prompt=negative_prompt,
+                eta=eta,
+                generator=generator,
+                latents=latents,
+                output_type=output_type,
+                return_dict=return_dict,
+                callback=callback,
+                callback_steps=callback_steps,
+                **kwargs,
+            )
+
+    def _forward_impl(
+        self,
+        prompt: Union[str, List[str]] = None,
+        input_images: Optional[List[Union[torch.Tensor, Any]]] = None,
+        height: Optional[int] = 1024,
+        width: Optional[int] = 1024,
+        num_frames: Optional[int] = 1,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        image_guidance_scale: float = 1.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+        callback_steps: int = 1,
         **kwargs,
     ):
         # Validate and set default dimensions
@@ -800,25 +989,33 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
         latents = torch.randn(shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
         latents = latents * self.scheduler.init_noise_sigma
 
-        # Denoising loop
+        # Denoising loop with optimizations
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                noise_pred = self.model(
-                    hidden_states=latent_model_input,
-                    timestep=t,
-                    encoder_hidden_states=torch.cat([text_embeddings] * 2),
-                    visual_conditions=visual_embeddings_list * 2 if visual_embeddings_list else None,
-                    video_frames=num_frames
-                ).sample
+                # Use mixed precision autocast
+                with self.model_optimizer.autocast_context():
+                    noise_pred = self.model(
+                        hidden_states=latent_model_input,
+                        timestep=t,
+                        encoder_hidden_states=torch.cat([text_embeddings] * 2),
+                        visual_conditions=visual_embeddings_list * 2 if visual_embeddings_list else None,
+                        video_frames=num_frames
+                    ).sample
 
                 # Apply classifier-free guidance
                 if guidance_scale > 1.0:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
                 latents = self.scheduler.step(noise_pred, t, latents, eta=eta).prev_sample
+
+                # Call callback if provided
+                if callback is not None and i % callback_steps == 0:
+                    callback(i, t, latents)
+
                 progress_bar.update()
 
         # Decode latents with proper post-processing
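
The new surface area added by this commit is easiest to see from the caller's side. Below is a minimal usage sketch, not part of the repository: the import path, checkpoint location, and prompt are placeholder assumptions, while the class name, the enable_optimizations/compile_model/enable_fp8_quantization methods, and the callback, callback_steps, and use_optimized_inference arguments all come from the diff above.

import torch
from pipeline import OmniMMDitV2Pipeline  # assumed import path for this Space's pipeline.py

# Hypothetical checkpoint path; construction itself is unchanged by this commit.
pipe = OmniMMDitV2Pipeline.from_pretrained("path/to/checkpoint").to("cuda")

# Opt in to the new optimizations; per the diff, each call warns and returns the
# pipeline unchanged if transformer_engine or torch.compile is unavailable.
pipe.enable_optimizations(
    enable_fp8=False,
    enable_compilation=True,
    compilation_mode="reduce-overhead",
)

def on_step(step: int, timestep: int, latents: torch.Tensor) -> None:
    # Invoked every `callback_steps` denoising iterations with the current latents.
    print(f"step {step}, timestep {timestep}, latents {tuple(latents.shape)}")

result = pipe(
    prompt="a placeholder prompt",
    num_inference_steps=50,
    guidance_scale=7.5,
    callback=on_step,
    callback_steps=10,
    use_optimized_inference=True,
)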
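The ModelOptimizer, FP8Config, CompilationConfig, and MixedPrecisionConfig names used by the new code are not defined in the hunks shown above. The sketch below only illustrates one way a wrapper could satisfy the interface the diff relies on (fp8_config.enabled, optimize_model(...), autocast_context(), _compile_model(...)); it is an assumption, not the repository's implementation, and the FP8 path is left as a pass-through.

import contextlib
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class FP8Config:
    enabled: bool = False


@dataclass
class CompilationConfig:
    enabled: bool = False
    mode: str = "reduce-overhead"
    fullgraph: bool = False
    dynamic: bool = True


@dataclass
class MixedPrecisionConfig:
    enabled: bool = True
    dtype: str = "bfloat16"


class ModelOptimizer:
    """Illustrative stand-in exposing the interface used by the diff above."""

    def __init__(self, fp8_config: FP8Config, compilation_config: CompilationConfig,
                 mixed_precision_config: MixedPrecisionConfig):
        self.fp8_config = fp8_config
        self.compilation_config = compilation_config
        self.mixed_precision_config = mixed_precision_config

    def autocast_context(self):
        # Mixed-precision autocast on CUDA when enabled; otherwise a no-op context.
        if self.mixed_precision_config.enabled and torch.cuda.is_available():
            dtype = getattr(torch, self.mixed_precision_config.dtype)
            return torch.autocast(device_type="cuda", dtype=dtype)
        return contextlib.nullcontext()

    def _compile_model(self, model: nn.Module) -> nn.Module:
        cfg = self.compilation_config
        return torch.compile(model, mode=cfg.mode, fullgraph=cfg.fullgraph, dynamic=cfg.dynamic)

    def optimize_model(self, model: nn.Module, apply_compilation: bool = False,
                       apply_quantization: bool = False, apply_mixed_precision: bool = False) -> nn.Module:
        # Real FP8 quantization would swap nn.Linear layers for transformer_engine
        # equivalents; that step is omitted here because it is not shown in the diff.
        if apply_compilation and self.compilation_config.enabled:
            model = self._compile_model(model)
        return model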
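Similarly, optimized_inference_mode is referenced in the new __call__ but not defined in the hunks shown. A minimal sketch, assuming it merely toggles the PyTorch backend flags named by its keyword arguments and restores the previous values on exit:

from contextlib import contextmanager

import torch


@contextmanager
def optimized_inference_mode(
    enable_cudnn_benchmark: bool = True,
    enable_tf32: bool = True,
    enable_flash_sdp: bool = True,
):
    # Save current backend settings so they can be restored afterwards.
    prev_benchmark = torch.backends.cudnn.benchmark
    prev_tf32 = torch.backends.cuda.matmul.allow_tf32
    prev_flash = torch.backends.cuda.flash_sdp_enabled()
    try:
        torch.backends.cudnn.benchmark = enable_cudnn_benchmark
        torch.backends.cuda.matmul.allow_tf32 = enable_tf32
        torch.backends.cuda.enable_flash_sdp(enable_flash_sdp)
        yield
    finally:
        torch.backends.cudnn.benchmark = prev_benchmark
        torch.backends.cuda.matmul.allow_tf32 = prev_tf32
        torch.backends.cuda.enable_flash_sdp(prev_flash)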