Spaces:
Running
on
Zero
Update ledits/pipeline_leditspp_stable_diffusion_xl.py
Browse files
ledits/pipeline_leditspp_stable_diffusion_xl.py
CHANGED
|
@@ -415,10 +415,11 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
| 415 |
editing_prompt: Optional[str] = None,
|
| 416 |
editing_prompt_embeds: Optional[torch.Tensor] = None,
|
| 417 |
editing_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 418 |
-
avg_diff
|
| 419 |
-
|
| 420 |
-
correlation_weight_factor
|
| 421 |
scale=2,
|
|
|
|
| 422 |
) -> object:
|
| 423 |
r"""
|
| 424 |
Encodes the prompt into text encoder hidden states.
|
|
@@ -538,9 +539,8 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
| 538 |
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
| 539 |
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
| 540 |
|
| 541 |
-
if avg_diff is not None
|
| 542 |
-
#scale=3
|
| 543 |
-
print("SHALOM neg")
|
| 544 |
normed_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
|
| 545 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
| 546 |
if j == 0:
|
|
@@ -549,15 +549,26 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
| 549 |
standard_weights = torch.ones_like(weights)
|
| 550 |
|
| 551 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 552 |
-
edit_concepts_embeds = negative_prompt_embeds + (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
else:
|
| 554 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
| 555 |
|
| 556 |
standard_weights = torch.ones_like(weights)
|
| 557 |
|
| 558 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 559 |
-
edit_concepts_embeds = negative_prompt_embeds + (
|
|
|
|
| 560 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 561 |
|
| 562 |
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
| 563 |
j+=1
|
|
@@ -878,10 +889,12 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
| 878 |
clip_skip: Optional[int] = None,
|
| 879 |
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 880 |
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 881 |
-
avg_diff
|
| 882 |
-
|
| 883 |
-
correlation_weight_factor
|
| 884 |
scale=2,
|
|
|
|
|
|
|
| 885 |
init_latents: [torch.Tensor] = None,
|
| 886 |
zs: [torch.Tensor] = None,
|
| 887 |
**kwargs,
|
|
@@ -1088,9 +1101,10 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
| 1088 |
editing_prompt_embeds=editing_prompt_embeddings,
|
| 1089 |
editing_pooled_prompt_embeds=editing_pooled_prompt_embeds,
|
| 1090 |
avg_diff = avg_diff,
|
| 1091 |
-
|
| 1092 |
correlation_weight_factor = correlation_weight_factor,
|
| 1093 |
scale=scale,
|
|
|
|
| 1094 |
)
|
| 1095 |
|
| 1096 |
# 4. Prepare timesteps
|
|
|
|
| 415 |
editing_prompt: Optional[str] = None,
|
| 416 |
editing_prompt_embeds: Optional[torch.Tensor] = None,
|
| 417 |
editing_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 418 |
+
avg_diff=None, # [0] -> text encoder 1,[1] ->text encoder 2
|
| 419 |
+
avg_diff_2nd=None, # text encoder 1,2
|
| 420 |
+
correlation_weight_factor=0.7,
|
| 421 |
scale=2,
|
| 422 |
+
scale_2nd=2,
|
| 423 |
) -> object:
|
| 424 |
r"""
|
| 425 |
Encodes the prompt into text encoder hidden states.
|
|
|
|
| 539 |
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
| 540 |
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
| 541 |
|
| 542 |
+
if avg_diff is not None:
|
| 543 |
+
# scale=3
|
|
|
|
| 544 |
normed_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
|
| 545 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
| 546 |
if j == 0:
|
|
|
|
| 549 |
standard_weights = torch.ones_like(weights)
|
| 550 |
|
| 551 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 552 |
+
edit_concepts_embeds = negative_prompt_embeds + (
|
| 553 |
+
weights * avg_diff[0][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
|
| 554 |
+
|
| 555 |
+
if avg_diff_2nd is not None:
|
| 556 |
+
edit_concepts_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1,
|
| 557 |
+
self.pipe.tokenizer.model_max_length,
|
| 558 |
+
1) * scale_2nd)
|
| 559 |
else:
|
| 560 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
| 561 |
|
| 562 |
standard_weights = torch.ones_like(weights)
|
| 563 |
|
| 564 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 565 |
+
edit_concepts_embeds = negative_prompt_embeds + (
|
| 566 |
+
weights * avg_diff[1][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
|
| 567 |
|
| 568 |
+
if avg_diff_2nd is not None:
|
| 569 |
+
edit_concepts_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1,
|
| 570 |
+
self.pipe.tokenizer_2.model_max_length,
|
| 571 |
+
1) * scale_2nd)
|
| 572 |
|
| 573 |
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
| 574 |
j+=1
|
|
|
|
| 889 |
clip_skip: Optional[int] = None,
|
| 890 |
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 891 |
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 892 |
+
avg_diff=None, # [0] -> text encoder 1,[1] ->text encoder 2
|
| 893 |
+
avg_diff_2nd=None, # text encoder 1,2
|
| 894 |
+
correlation_weight_factor=0.7,
|
| 895 |
scale=2,
|
| 896 |
+
scale_2nd=2,
|
| 897 |
+
correlation_weight_factor = 0.7,
|
| 898 |
init_latents: [torch.Tensor] = None,
|
| 899 |
zs: [torch.Tensor] = None,
|
| 900 |
**kwargs,
|
|
|
|
| 1101 |
editing_prompt_embeds=editing_prompt_embeddings,
|
| 1102 |
editing_pooled_prompt_embeds=editing_pooled_prompt_embeds,
|
| 1103 |
avg_diff = avg_diff,
|
| 1104 |
+
avg_diff_2nd = avg_diff_2nd,
|
| 1105 |
correlation_weight_factor = correlation_weight_factor,
|
| 1106 |
scale=scale,
|
| 1107 |
+
scale_2nd=scale_2nd
|
| 1108 |
)
|
| 1109 |
|
| 1110 |
# 4. Prepare timesteps
|