import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.utils import save_image


def register_attn_control(unet, controller, cache=None):
    def attn_forward(self):
        def forward(
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
            *args,
            **kwargs,
        ):
            residual = hidden_states

            if self.spatial_norm is not None:
                hidden_states = self.spatial_norm(hidden_states, temb)

            input_ndim = hidden_states.ndim

            if input_ndim == 4:
                # Flatten spatial dimensions: (B, C, H, W) -> (B, H*W, C).
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(
                    batch_size, channel, height * width
                ).transpose(1, 2)

            batch_size, sequence_length, _ = (
                hidden_states.shape
                if encoder_hidden_states is None
                else encoder_hidden_states.shape
            )

            if attention_mask is not None:
                attention_mask = self.prepare_attention_mask(
                    attention_mask, sequence_length, batch_size
                )
                # scaled_dot_product_attention expects the mask shaped as
                # (batch, heads, query_len, key_len).
                attention_mask = attention_mask.view(
                    batch_size, self.heads, -1, attention_mask.shape[-1]
                )

            if self.group_norm is not None:
                hidden_states = self.group_norm(
                    hidden_states.transpose(1, 2)
                ).transpose(1, 2)

            q = self.to_q(hidden_states)
            # Self-attention if no encoder states were passed, cross-attention otherwise.
            is_self = encoder_hidden_states is None

            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif self.norm_cross:
                encoder_hidden_states = self.norm_encoder_hidden_states(
                    encoder_hidden_states
                )

            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            inner_dim = k.shape[-1]
            head_dim = inner_dim // self.heads

            # Split heads: (B, L, inner_dim) -> (B, heads, L, head_dim).
            q = q.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

            hidden_states = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
            # Cache q/k/v and the attention output for the selected self-attention layers.
            if is_self and controller.cur_self_layer in controller.self_layers:
                cache.add(q, k, v, hidden_states)

            # Merge heads back: (B, heads, L, head_dim) -> (B, L, inner_dim).
            hidden_states = hidden_states.transpose(1, 2).reshape(
                batch_size, -1, self.heads * head_dim
            )
            hidden_states = hidden_states.to(q.dtype)

            # Output projection and dropout.
            hidden_states = self.to_out[0](hidden_states)
            hidden_states = self.to_out[1](hidden_states)

            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(
                    batch_size, channel, height, width
                )
            if self.residual_connection:
                hidden_states = hidden_states + residual

            hidden_states = hidden_states / self.rescale_output_factor

            if is_self:
                controller.cur_self_layer += 1

            return hidden_states

        return forward

    def modify_forward(net, count):
        # Walk the module tree recursively; once the recursion reaches an
        # Attention module (the check is on `net`, not `subnet`), patch its
        # forward method and count it.
        for name, subnet in net.named_children():
            if net.__class__.__name__ == "Attention":
                net.forward = attn_forward(net)
                return count + 1
            elif hasattr(net, "children"):
                count = modify_forward(subnet, count)
        return count

    cross_att_count = 0
    for net_name, net in unet.named_children():
        cross_att_count += modify_forward(net, 0)
    # Each transformer block contains one self-attention and one cross-attention
    # layer, so half of the patched Attention modules are self-attention layers.
    controller.num_self_layers = cross_att_count // 2
def load_image(image_path, size=None, mode="RGB"):
|
|
|
img = Image.open(image_path).convert(mode)
|
|
|
if size is None:
|
|
|
width, height = img.size
|
|
|
new_width = (width // 64) * 64
|
|
|
new_height = (height // 64) * 64
|
|
|
size = (new_width, new_height)
|
|
|
img = img.resize(size, Image.BICUBIC)
|
|
|
return ToTensor()(img).unsqueeze(0)
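
# Illustrative usage of load_image (paths are placeholders, not part of this module):
#   content = load_image("examples/content.png")                # (1, 3, H, W) in [0, 1]
#   style = load_image("examples/style.png", size=(512, 512))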


def adain(source, target, eps=1e-6):
    """Transfer the per-channel mean and std of `target` onto `source` (AdaIN)."""
    source_mean, source_std = (
        torch.mean(source, dim=(2, 3), keepdim=True),
        torch.std(source, dim=(2, 3), keepdim=True),
    )
    # Target statistics are additionally pooled over the batch dimension.
    target_mean, target_std = (
        torch.mean(target, dim=(0, 2, 3), keepdim=True),
        torch.std(target, dim=(0, 2, 3), keepdim=True),
    )
    normalized_source = (source - source_mean) / (source_std + eps)
    transferred_source = normalized_source * target_std + target_mean

    return transferred_source
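
# adain() above is the standard AdaIN statistic transfer,
#     out = (source - mu_s) / (sigma_s + eps) * sigma_t + mu_t,
# with source statistics computed per sample and per channel, and target statistics
# pooled over the whole target batch. A minimal sketch (shapes are illustrative
# assumptions, e.g. SD latents):
#   content_latents = torch.randn(1, 4, 64, 64)
#   style_latents = torch.randn(1, 4, 64, 64)
#   restyled = adain(content_latents, style_latents)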


class Controller:
    def __init__(self, self_layers=(0, 16)):
        self.num_self_layers = -1
        # Index of the self-attention layer currently being executed within a step.
        self.cur_self_layer = 0
        # Indices of the self-attention layers whose q/k/v/out should be cached.
        self.self_layers = list(range(*self_layers))

    def step(self):
        # Reset the layer counter at the beginning of every denoising step.
        self.cur_self_layer = 0


class DataCache:
    """Accumulates (q, k, v, out) tensors from the hooked self-attention layers."""

    def __init__(self):
        self.q = []
        self.k = []
        self.v = []
        self.out = []

    def clear(self):
        self.q.clear()
        self.k.clear()
        self.v.clear()
        self.out.clear()

    def add(self, q, k, v, out):
        self.q.append(q)
        self.k.append(k)
        self.v.append(v)
        self.out.append(out)

    def get(self):
        return self.q.copy(), self.k.copy(), self.v.copy(), self.out.copy()
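
# Sketch of how Controller, DataCache, and register_attn_control fit together.
# The `unet`, `scheduler`, `latents`, and `text_emb` names are placeholders for
# objects created elsewhere (e.g. from a diffusers StableDiffusionPipeline):
#   controller = Controller(self_layers=(0, 16))
#   cache = DataCache()
#   register_attn_control(unet, controller, cache=cache)
#   for t in scheduler.timesteps:
#       controller.step()  # reset the per-step self-attention layer counter
#       noise_pred = unet(latents, t, encoder_hidden_states=text_emb).sample
#   q, k, v, out = cache.get()  # tensors cached from the selected self-attention layers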


def show_image(path, title, display_height=3, title_fontsize=12):
    img = Image.open(path)
    img_width, img_height = img.size

    # Keep the figure's aspect ratio equal to the image's aspect ratio.
    aspect_ratio = img_width / img_height
    display_width = display_height * aspect_ratio

    plt.figure(figsize=(display_width, display_height))
    plt.imshow(img)
    plt.title(title, fontsize=title_fontsize, fontweight="bold", pad=20)
    plt.axis("off")
    plt.tight_layout()
    plt.show()