nupurkmr9 committed · Commit da28f5f · verified · 1 Parent(s): a30afc5

Update pipelines/flux_pipeline/pipeline.py

Files changed (1):
  1. pipelines/flux_pipeline/pipeline.py (+16 -20)
pipelines/flux_pipeline/pipeline.py CHANGED
```diff
@@ -17,6 +17,11 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
+from diffusers import FluxPipeline
+from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.models.transformers import FluxTransformer2DModel
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import is_torch_xla_available
 from transformers import (
     CLIPImageProcessor,
     CLIPTextModel,
@@ -26,14 +31,6 @@ from transformers import (
     T5TokenizerFast,
 )
 
-from diffusers import FluxPipeline
-from diffusers.image_processor import VaeImageProcessor
-from diffusers.loaders import FluxLoraLoaderMixin
-from diffusers.models.autoencoders import AutoencoderKL
-from diffusers.models.transformers import FluxTransformer2DModel
-from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
-from diffusers.utils import USE_PEFT_BACKEND, is_torch_xla_available
-
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
@@ -42,7 +39,6 @@ else:
     XLA_AVAILABLE = False
 
 
-
 def calculate_shift(
     image_seq_len,
     base_seq_len: int = 256,
```
```diff
@@ -102,16 +98,17 @@ def normalized_guidance_image(neg_noise_pred, noise_pred, image_noise_pred, true_cfg_scale, image_cfg_scale):
     diff_img = image_noise_pred - neg_noise_pred
     diff_txt = noise_pred - image_noise_pred
 
-    diff_norm_txt = diff_txt.norm(p=2, dim=[-1, -2], keepdim=True)
-    diff_norm_img = diff_img.norm(p=2, dim=[-1, -2], keepdim=True)
+    diff_norm_txt = diff_txt.norm(p=2, dim=[-1, -2], keepdim=True)
+    diff_norm_img = diff_img.norm(p=2, dim=[-1, -2], keepdim=True)
     min_norm = torch.minimum(diff_norm_img, diff_norm_txt)
     diff_txt = diff_txt * torch.minimum(torch.ones_like(diff_txt), min_norm / diff_norm_txt)
     diff_img = diff_img * torch.minimum(torch.ones_like(diff_txt), min_norm / diff_norm_img)
-    pred_guided = image_noise_pred + image_cfg_scale * diff_img + true_cfg_scale * diff_txt
+    pred_guided = image_noise_pred + image_cfg_scale * diff_img + true_cfg_scale * diff_txt
     return pred_guided
 
+
 class SynCDFluxPipeline(FluxPipeline):
-
+
     model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
     _optional_components = []
     _callback_tensor_inputs = ["latents", "prompt_embeds"]
```
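For readers skimming the diff, here is the touched guidance helper assembled as a self-contained sketch. The function body is taken from the hunk above and the full signature from the call site later in this diff; the demo tensors at the bottom are toy placeholders, not the packed Flux latent layout.

```python
import torch


def normalized_guidance_image(neg_noise_pred, noise_pred, image_noise_pred,
                              true_cfg_scale, image_cfg_scale):
    # Two guidance directions: image-vs-negative and text-vs-image.
    diff_img = image_noise_pred - neg_noise_pred
    diff_txt = noise_pred - image_noise_pred

    # L2 norms over the last two dims, kept broadcastable via keepdim.
    diff_norm_txt = diff_txt.norm(p=2, dim=[-1, -2], keepdim=True)
    diff_norm_img = diff_img.norm(p=2, dim=[-1, -2], keepdim=True)

    # Shrink whichever direction has the larger norm down to the smaller one;
    # the minimum with ones caps the scale factor at 1, so neither direction
    # is ever amplified (both tensors share a shape, as in the source).
    min_norm = torch.minimum(diff_norm_img, diff_norm_txt)
    diff_txt = diff_txt * torch.minimum(torch.ones_like(diff_txt), min_norm / diff_norm_txt)
    diff_img = diff_img * torch.minimum(torch.ones_like(diff_txt), min_norm / diff_norm_img)

    # Both scaled directions are applied on top of the image-conditioned prediction.
    return image_noise_pred + image_cfg_scale * diff_img + true_cfg_scale * diff_txt


# Toy shapes only (batch, tokens, channels).
neg, txt, img = (torch.randn(1, 4096, 64) for _ in range(3))
guided = normalized_guidance_image(neg, txt, img, true_cfg_scale=4.0, image_cfg_scale=1.5)
print(guided.shape)  # torch.Size([1, 4096, 64])
```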
```diff
@@ -127,7 +124,7 @@ class SynCDFluxPipeline(FluxPipeline):
         transformer: FluxTransformer2DModel,
         image_encoder: CLIPVisionModelWithProjection = None,
         feature_extractor: CLIPImageProcessor = None,
-        ###
+        ###
         num=2,
     ):
         super().__init__(
@@ -173,8 +170,8 @@ class SynCDFluxPipeline(FluxPipeline):
         #####
         latents_ref: Optional[torch.Tensor] = None,
         latents_mask: Optional[torch.Tensor] = None,
-        return_latents: bool=False,
-        image_cfg_scale: float=0.0,
+        return_latents: bool = False,
+        image_cfg_scale: float = 0.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
```
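The reformatted keyword arguments above can be exercised roughly as follows. This is a usage sketch only: the checkpoint id, device, and prompt are hypothetical, and it assumes the subclass loads like a stock `FluxPipeline`; only `SynCDFluxPipeline` and the four keyword names come from this diff.

```python
import torch

from pipelines.flux_pipeline.pipeline import SynCDFluxPipeline

# Assumed base checkpoint; substitute whatever weights this repo ships with.
pipe = SynCDFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

out = pipe(
    prompt="a ceramic mug on a wooden desk",
    latents_ref=None,       # optional reference latents for shared attention
    latents_mask=None,      # optional binary mask over the latent grid
    return_latents=False,   # default after this commit: bool = False
    image_cfg_scale=0.0,    # 0.0 selects the plain CFG branch
)
```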
```diff
@@ -386,7 +383,7 @@ class SynCDFluxPipeline(FluxPipeline):
                 self._current_timestep = t
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                self.joint_attention_kwargs.update({'timestep': t/1000, 'val': True})
+                self.joint_attention_kwargs.update({'timestep': t/1000})
                 if self.joint_attention_kwargs is not None and self.joint_attention_kwargs['shared_attn'] and latents_ref is not None and latents_mask is not None:
                     latents = (1 - latents_mask) * latents_ref + latents_mask * latents
 
```
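When shared attention is enabled, the masked update above pins reference regions in place while the rest keeps denoising. A minimal numeric illustration of that single line, on toy shapes with hypothetical values:

```python
import torch

# latents = (1 - latents_mask) * latents_ref + latents_mask * latents
latents = torch.full((1, 4), 2.0)       # current denoising state
latents_ref = torch.full((1, 4), 7.0)   # reference latents to preserve
latents_mask = torch.tensor([[0.0, 0.0, 1.0, 1.0]])  # 0 = keep reference, 1 = keep current

latents = (1 - latents_mask) * latents_ref + latents_mask * latents
print(latents)  # tensor([[7., 7., 2., 2.]])
```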
```diff
@@ -427,13 +424,12 @@ class SynCDFluxPipeline(FluxPipeline):
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
-
+
                     if image_cfg_scale == 0:
                         noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
                     else:
                         noise_pred = normalized_guidance_image(neg_noise_pred, noise_pred, image_noise_pred, true_cfg_scale, image_cfg_scale)
 
-
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -471,4 +467,4 @@ class SynCDFluxPipeline(FluxPipeline):
         # Offload all models
         self.maybe_free_model_hooks()
 
-        return (image,)
+        return (image,)
```
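Why keep a separate plain-CFG branch rather than calling `normalized_guidance_image` with `image_cfg_scale=0`? Because the normalized path anchors at the image-conditioned prediction rather than reducing to vanilla CFG. A quick check on toy tensors, reusing the sketch of `normalized_guidance_image` from earlier:

```python
import torch

torch.manual_seed(0)
neg, txt, img = (torch.randn(1, 16, 16) for _ in range(3))
true_cfg_scale = 4.0

plain_cfg = neg + true_cfg_scale * (txt - neg)
normed = normalized_guidance_image(neg, txt, img, true_cfg_scale, image_cfg_scale=0.0)

# Generally not equal: the normalized path starts from image_noise_pred,
# so the explicit image_cfg_scale == 0 branch is not redundant.
print(torch.allclose(plain_cfg, normed))  # False for random inputs
```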
 