Update pipelines/flux_pipeline/pipeline.py

pipelines/flux_pipeline/pipeline.py CHANGED

@@ -17,6 +17,11 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
+from diffusers import FluxPipeline
+from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.models.transformers import FluxTransformer2DModel
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import is_torch_xla_available
 from transformers import (
     CLIPImageProcessor,
     CLIPTextModel,
@@ -26,14 +31,6 @@ from transformers import (
     T5TokenizerFast,
 )
 
-from diffusers import FluxPipeline
-from diffusers.image_processor import VaeImageProcessor
-from diffusers.loaders import FluxLoraLoaderMixin
-from diffusers.models.autoencoders import AutoencoderKL
-from diffusers.models.transformers import FluxTransformer2DModel
-from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
-from diffusers.utils import USE_PEFT_BACKEND, is_torch_xla_available
-
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
@@ -42,7 +39,6 @@ else:
 XLA_AVAILABLE = False
 
 
-
 def calculate_shift(
     image_seq_len,
     base_seq_len: int = 256,
@@ -102,16 +98,17 @@ def normalized_guidance_image(neg_noise_pred, noise_pred, image_noise_pred, true
     diff_img = image_noise_pred - neg_noise_pred
     diff_txt = noise_pred - image_noise_pred
 
-    diff_norm_txt = diff_txt.norm(p=2, dim=[-1, -2], keepdim=True)
-    diff_norm_img = diff_img.norm(p=2, dim=[-1, -2], keepdim=True)
+    diff_norm_txt = diff_txt.norm(p=2, dim=[-1, -2], keepdim=True)
+    diff_norm_img = diff_img.norm(p=2, dim=[-1, -2], keepdim=True)
     min_norm = torch.minimum(diff_norm_img, diff_norm_txt)
     diff_txt = diff_txt * torch.minimum(torch.ones_like(diff_txt), min_norm / diff_norm_txt)
     diff_img = diff_img * torch.minimum(torch.ones_like(diff_txt), min_norm / diff_norm_img)
-    pred_guided = image_noise_pred + image_cfg_scale * diff_img + true_cfg_scale *
+    pred_guided = image_noise_pred + image_cfg_scale * diff_img + true_cfg_scale * diff_txt
     return pred_guided
 
+
 class SynCDFluxPipeline(FluxPipeline):
-
+
     model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
     _optional_components = []
     _callback_tensor_inputs = ["latents", "prompt_embeds"]
@@ -127,7 +124,7 @@ class SynCDFluxPipeline(FluxPipeline):
         transformer: FluxTransformer2DModel,
         image_encoder: CLIPVisionModelWithProjection = None,
         feature_extractor: CLIPImageProcessor = None,
-        ###
+        ###
         num=2,
     ):
         super().__init__(
@@ -173,8 +170,8 @@ class SynCDFluxPipeline(FluxPipeline):
         #####
         latents_ref: Optional[torch.Tensor] = None,
         latents_mask: Optional[torch.Tensor] = None,
-        return_latents: bool=False,
-        image_cfg_scale: float=0.0,
+        return_latents: bool = False,
+        image_cfg_scale: float = 0.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -386,7 +383,7 @@ class SynCDFluxPipeline(FluxPipeline):
                 self._current_timestep = t
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                self.joint_attention_kwargs.update({'timestep': t/1000
+                self.joint_attention_kwargs.update({'timestep': t/1000})
                 if self.joint_attention_kwargs is not None and self.joint_attention_kwargs['shared_attn'] and latents_ref is not None and latents_mask is not None:
                     latents = (1 - latents_mask) * latents_ref + latents_mask * latents
 
@@ -427,13 +424,12 @@ class SynCDFluxPipeline(FluxPipeline):
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
-
+
                     if image_cfg_scale == 0:
                         noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
                     else:
                         noise_pred = normalized_guidance_image(neg_noise_pred, noise_pred, image_noise_pred, true_cfg_scale, image_cfg_scale)
 
-
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -471,4 +467,4 @@ class SynCDFluxPipeline(FluxPipeline):
         # Offload all models
         self.maybe_free_model_hooks()
 
-        return (image,)
+        return (image,)
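
For reference, the guidance rule in normalized_guidance_image can be read in isolation: both difference directions are clipped to the smaller of their two L2 norms before being added to the image-conditioned prediction, so neither the text nor the image guidance term can dominate. The minimal sketch below restates the function from the hunk above, reading the argument names at face value; the toy inputs, the (1, 4096, 64) packed-latent shape, and the scale values are assumptions for illustration only, not part of the repository.

import torch

def normalized_guidance_image(neg_noise_pred, noise_pred, image_noise_pred, true_cfg_scale, image_cfg_scale):
    # Direction contributed by the image condition (image-conditioned minus unconditional)
    diff_img = image_noise_pred - neg_noise_pred
    # Direction contributed by the text prompt (fully conditioned minus image-conditioned)
    diff_txt = noise_pred - image_noise_pred

    # Per-sample L2 norms over the token and channel dimensions
    diff_norm_txt = diff_txt.norm(p=2, dim=[-1, -2], keepdim=True)
    diff_norm_img = diff_img.norm(p=2, dim=[-1, -2], keepdim=True)

    # Clip the larger direction down to the smaller norm; the minimum with ones
    # leaves the smaller direction unchanged.
    min_norm = torch.minimum(diff_norm_img, diff_norm_txt)
    diff_txt = diff_txt * torch.minimum(torch.ones_like(diff_txt), min_norm / diff_norm_txt)
    diff_img = diff_img * torch.minimum(torch.ones_like(diff_txt), min_norm / diff_norm_img)

    # Combine both scaled directions on top of the image-conditioned prediction
    return image_noise_pred + image_cfg_scale * diff_img + true_cfg_scale * diff_txt

# Toy inputs shaped like packed Flux latents: (batch, tokens, channels)
neg, full, img = (torch.randn(1, 4096, 64) for _ in range(3))
guided = normalized_guidance_image(neg, full, img, true_cfg_scale=4.0, image_cfg_scale=1.5)
print(guided.shape)  # torch.Size([1, 4096, 64])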
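The other notable branch is the shared-attention latent blend: when shared_attn is set and a reference is supplied, the pipeline recomposes the latents at every denoising step so that regions where latents_mask is zero stay pinned to latents_ref. A toy rehearsal of that compositing step follows; the shapes and the half-and-half mask layout are hypothetical and say nothing about how the caller actually builds these tensors.

import torch

# Hypothetical packed-latent shapes (batch, tokens, channels); values are illustrative only.
latents = torch.randn(1, 4096, 64)        # current denoising state
latents_ref = torch.randn(1, 4096, 64)    # reference latents to keep fixed
latents_mask = torch.zeros(1, 4096, 1)    # 0 = keep reference, 1 = keep current
latents_mask[:, 2048:] = 1.0              # pretend only the second half may change

# The same blend the pipeline applies inside the loop before the transformer call
latents = (1 - latents_mask) * latents_ref + latents_mask * latents

# Wherever the mask is zero the result equals the reference exactly
assert torch.equal(latents[:, :2048], latents_ref[:, :2048])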