import diffusers
import torch
import random
from tqdm import tqdm
from constants import SUBJECTS, MEDIUMS
from PIL import Image

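
# CLIPSlider estimates a semantic direction in CLIP text-embedding space by
# averaging embedding differences over many antonym prompt pairs (target_word
# vs. opposite), then nudges a prompt's embeddings along that direction at
# generation time, acting as a continuous attribute slider.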
class CLIPSlider:
    def __init__(
        self,
        sd_pipe,
        device: torch.device,
        target_word: str = "",
        opposite: str = "",
        target_word_2nd: str = "",
        opposite_2nd: str = "",
        iterations: int = 300,
    ):
        self.device = device
        self.pipe = sd_pipe.to(self.device)
        self.iterations = iterations
        if target_word != "" or opposite != "":
            self.avg_diff = self.find_latent_direction(target_word, opposite)
        else:
            self.avg_diff = None
        if target_word_2nd != "" or opposite_2nd != "":
            self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
        else:
            self.avg_diff_2nd = None

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str,
                              num_iterations: int = None):
        if num_iterations is not None:
            iterations = num_iterations
        else:
            iterations = self.iterations
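
        # sample many prompt pairs that differ only in the target/opposite word;
        # averaging the pooled-embedding differences cancels out the random
        # subject/medium content and isolates the attribute direction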
        with torch.no_grad():
            positives = []
            negatives = []
            for i in tqdm(range(iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)
        diffs = positives - negatives
        avg_diff = diffs.mean(0, keepdim=True)
        return avg_diff

    def generate(self,
                 prompt = "a photo of a house",
                 scale = 2.,
                 scale_2nd = 0.,
                 seed = 15,
                 only_pooler = False,
                 normalize_scales = False,
                 correlation_weight_factor = 1.0,
                 avg_diff = None,
                 avg_diff_2nd = None,
                 **pipeline_kwargs
                 ):
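        # shift the prompt embeddings along the learned direction(s) by `scale`
        # (and `scale_2nd`): either only at the EOS token position (only_pooler),
        # or across the full token sequence, weighted by each token's similarity
        # to the EOS token and tempered by correlation_weight_factor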
        with torch.no_grad():
            toks = self.pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True,
                                       max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
            prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state

            # fall back to the directions computed in __init__ when none are passed
            if avg_diff is None:
                avg_diff = self.avg_diff
            if avg_diff_2nd is None:
                avg_diff_2nd = self.avg_diff_2nd

            if avg_diff_2nd is not None and normalize_scales:
                denominator = abs(scale) + abs(scale_2nd)
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator
            if only_pooler:
                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
                if avg_diff_2nd is not None:
                    prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
            else:
                normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)

                standard_weights = torch.ones_like(weights)

                weights = standard_weights + (weights - standard_weights) * correlation_weight_factor

                prompt_embeds = prompt_embeds + (
                        weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
                if avg_diff_2nd is not None:
                    prompt_embeds += weights * avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd

        torch.manual_seed(seed)
        image = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images[0]

        return image

    def spectrum(self,
                 prompt="a photo of a house",
                 low_scale=-2,
                 low_scale_2nd=-2,
                 high_scale=2,
                 high_scale_2nd=2,
                 steps=5,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
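        # sweep both scales linearly from low to high over `steps` generations
        # and tile the results into a single horizontal strip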
        images = []
        for i in range(steps):
            scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
            scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / (steps - 1)
            image = self.generate(prompt, scale, scale_2nd, seed, only_pooler, normalize_scales, correlation_weight_factor, **pipeline_kwargs)
            images.append(image)

        canvas = Image.new('RGB', (640 * steps, 640))  # assumes 640x640 outputs
        for i, im in enumerate(images):
            canvas.paste(im, (640 * i, 0))

        return canvas


class CLIPSliderXL(CLIPSlider):
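    # SDXL variant: a direction is found in each of the two text-encoder
    # embedding spaces (text_encoder's pooler_output and text_encoder_2's
    # text_embeds) and returned as a pair.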
    def find_latent_direction(self,
                              target_word: str,
                              opposite: str,
                              num_iterations: int = None):
        if num_iterations is not None:
            iterations = num_iterations
        else:
            iterations = self.iterations
        with torch.no_grad():
            positives = []
            negatives = []
            positives2 = []
            negatives2 = []
            for i in tqdm(range(iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"

                pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

                pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
                neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
                positives2.append(pos2)
                negatives2.append(neg2)

        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)
        diffs = positives - negatives
        avg_diff = diffs.mean(0, keepdim=True)

        positives2 = torch.cat(positives2, dim=0)
        negatives2 = torch.cat(negatives2, dim=0)
        diffs2 = positives2 - negatives2
        avg_diff2 = diffs2.mean(0, keepdim=True)
        return (avg_diff, avg_diff2)

    def generate(self,
                 prompt = "a photo of a house",
                 scale = 2,
                 scale_2nd = 2,
                 seed = 15,
                 only_pooler = False,
                 normalize_scales = False,
                 correlation_weight_factor = 1.0,
                 avg_diff = None,
                 avg_diff_2nd = None,
                 **pipeline_kwargs
                 ):
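        # encode the prompt with both SDXL text encoders, shift each encoder's
        # token embeddings along its own slider direction, then concatenate the
        # two sequences channel-wise (768 + 1280) as the SDXL pipeline expects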
        text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
        tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
        with torch.no_grad():
            # fall back to the direction pairs computed in __init__ when none are passed
            if avg_diff is None:
                avg_diff = self.avg_diff
            if avg_diff_2nd is None:
                avg_diff_2nd = self.avg_diff_2nd

            prompt_embeds_list = []

            for i, text_encoder in enumerate(text_encoders):
                tokenizer = tokenizers[i]
                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                toks = text_inputs.input_ids

                prompt_embeds = text_encoder(
                    toks.to(text_encoder.device),
                    output_hidden_states=True,
                )

                # keep the pooled output (the second encoder's survives the loop)
                # and use the penultimate hidden state as the token sequence
                pooled_prompt_embeds = prompt_embeds[0]
                prompt_embeds = prompt_embeds.hidden_states[-2]

                if avg_diff_2nd is not None and normalize_scales:
                    denominator = abs(scale) + abs(scale_2nd)
                    scale = scale / denominator
                    scale_2nd = scale_2nd / denominator
                if only_pooler:
                    # index the direction pair per encoder so dimensions match
                    prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff[i] * scale
                    if avg_diff_2nd is not None:
                        prompt_embeds[:, toks.argmax()] += avg_diff_2nd[i] * scale_2nd
                else:
                    normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                    sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T

                    if i == 0:
                        weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)

                        standard_weights = torch.ones_like(weights)

                        weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
                        prompt_embeds = prompt_embeds + (weights * avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
                        if avg_diff_2nd is not None:
                            prompt_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
                    else:
                        weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)

                        standard_weights = torch.ones_like(weights)

                        weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
                        prompt_embeds = prompt_embeds + (weights * avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
                        if avg_diff_2nd is not None:
                            prompt_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)

                bs_embed, seq_len, _ = prompt_embeds.shape
                prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
                prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)

        torch.manual_seed(seed)
        image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                          **pipeline_kwargs).images[0]

        return image
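

# Minimal usage sketch (not part of the original module). Assumes a Stable
# Diffusion 1.x pipeline, whose tokenizer/text_encoder layout CLIPSlider
# expects, and a constants.py providing SUBJECTS and MEDIUMS; the model id and
# output file names are illustrative.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    slider = CLIPSlider(pipe, device, target_word="happy", opposite="sad", iterations=50)

    # single image at a chosen slider strength
    img = slider.generate("a photo of a person", scale=1.5, num_inference_steps=25)
    img.save("happy_1.5.png")

    # horizontal strip sweeping the direction from -2 to 2
    strip = slider.spectrum("a photo of a person", low_scale=-2, high_scale=2, steps=5)
    strip.save("happy_sad_spectrum.png")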