# ImageCritic/src/layers.py
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.models.attention_processor import Attention
# Global counters for attention-map bookkeeping in the *_visual processors.
step = 0
global_timestep = 0   # advanced once per double-stream *_visual call
global_timestep2 = 0  # advanced once per single-stream *_visual call
def scaled_dot_product_average_attention_map(query, key, attn_mask=None, is_causal=False, scale=None) -> torch.Tensor:
    """Compute raw (pre-softmax) attention scores and average them over heads
    and query positions, yielding a (batch, key_len) map. Masking semantics
    mirror ``F.scaled_dot_product_attention``."""
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # A boolean mask marks positions to keep; fill the bias, not the mask.
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    # Average over attention heads (dim 1) and query positions (dim 2).
    return attn_weight.mean(dim=(1, 2))
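# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, self-contained check of the helper above, with hypothetical
# sizes: batch 1, 4 heads, 8 query tokens, 16 key tokens, head dim 32.
# The result holds one averaged score per key position.
def _average_attention_map_demo() -> torch.Tensor:
    q = torch.randn(1, 4, 8, 32)
    k = torch.randn(1, 4, 16, 32)
    amap = scaled_dot_product_average_attention_map(q, k)
    assert amap.shape == (1, 16)  # (batch, key_len)
    return amap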
class LoRALinearLayer(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 4,
network_alpha: Optional[float] = None,
device: Optional[Union[torch.device, str]] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank
self.out_features = out_features
self.in_features = in_features
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank
return up_hidden_states.to(orig_dtype)
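# --- Usage sketch (illustrative, not part of the original module) ---
# LoRALinearLayer is an additive low-rank update: the base projection stays
# frozen and the LoRA delta is added on top. Dimensions here are hypothetical.
def _lora_linear_demo() -> torch.Tensor:
    base = nn.Linear(64, 64, bias=False)
    base.requires_grad_(False)
    lora = LoRALinearLayer(64, 64, rank=4, network_alpha=8.0)
    x = torch.randn(2, 16, 64)
    # `up` is zero-initialized, so the LoRA delta starts at exactly zero.
    return base(x) + lora(x)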
class MultiSingleStreamBlockLoraProcessor(nn.Module):
def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
super().__init__()
# Initialize a list to store the LoRA layers
self.n_loras = n_loras
self.q_loras = nn.ModuleList([
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
for i in range(n_loras)
])
self.k_loras = nn.ModuleList([
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
for i in range(n_loras)
])
self.v_loras = nn.ModuleList([
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
for i in range(n_loras)
])
self.lora_weights = lora_weights
def __call__(self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond=False,
) -> torch.FloatTensor:
batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
for i in range(self.n_loras):
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
return hidden_states
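# --- Installation sketch (illustrative) ---
# How a processor like the one above is typically attached to the
# single-stream blocks of a FLUX-style diffusers transformer. The module-name
# filter and the 3072 width are assumptions drawn from this file, not a
# guaranteed API contract.
def _install_single_stream_loras(transformer, rank: int = 4, lora_weight: float = 1.0):
    procs = {}
    for name, proc in transformer.attn_processors.items():
        if "single_transformer_blocks" in name:
            procs[name] = MultiSingleStreamBlockLoraProcessor(
                in_features=3072, out_features=3072,
                ranks=[rank], lora_weights=[lora_weight],
                network_alphas=[rank], n_loras=1,
            )
        else:
            procs[name] = proc
    transformer.set_attn_processor(procs)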
class MultiDoubleStreamBlockLoraProcessor(nn.Module):
def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
super().__init__()
# Initialize a list to store the LoRA layers
self.n_loras = n_loras
        self.q_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.k_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.v_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.proj_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
self.lora_weights = lora_weights
def __call__(self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
use_cond=False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `context` projections.
        inner_dim = 3072  # joint attention width of FLUX-style blocks (attn.heads * head_dim)
head_dim = inner_dim // attn.heads
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
for i in range(self.n_loras):
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# Linear projection (with LoRA weight applied to each proj layer)
hidden_states = attn.to_out[0](hidden_states)
for i in range(self.n_loras):
hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return (hidden_states, encoder_hidden_states)
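# --- Layout sketch (illustrative) ---
# The double-stream processor attends over the concatenation
# [text tokens | image tokens] (dim 2 after the heads transpose) and then
# splits the joint result back by the text length. A toy round-trip with
# hypothetical sizes:
def _joint_sequence_split_demo():
    txt_len, img_len, dim = 512, 1024, 3072
    joint = torch.randn(1, txt_len + img_len, dim)
    txt_out, img_out = joint[:, :txt_len], joint[:, txt_len:]
    assert txt_out.shape[1] == txt_len and img_out.shape[1] == img_len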
class MultiSingleStreamBlockLoraProcessorWithLoss(nn.Module):
def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
super().__init__()
# Initialize a list to store the LoRA layers
self.n_loras = n_loras
self.q_loras = nn.ModuleList([
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
for i in range(n_loras)
])
self.k_loras = nn.ModuleList([
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
for i in range(n_loras)
])
self.v_loras = nn.ModuleList([
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
for i in range(n_loras)
])
self.lora_weights = lora_weights
def __call__(self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond=False,
) -> torch.FloatTensor:
batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
        # FLUX single-stream blocks receive [512 text tokens | 3 * length image tokens];
        # these offsets are kept for the (currently disabled) map extraction below.
        encoder_hidden_length = 512
        length = (hidden_states.shape[-2] - encoder_hidden_length) // 3
for i in range(self.n_loras):
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
# query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
# query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
# key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]
# attention_probs_query_a_key_noise = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
# attention_probs_query_b_key_noise = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)
# attn.attention_probs_query_a_key_noise = attention_probs_query_a_key_noise
# attn.attention_probs_query_b_key_noise = attention_probs_query_b_key_noise
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
return hidden_states
class MultiDoubleStreamBlockLoraProcessorWithLoss(nn.Module):
def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
super().__init__()
# Initialize a list to store the LoRA layers
self.n_loras = n_loras
        self.q_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.k_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.v_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.proj_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
self.lora_weights = lora_weights
def __call__(self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
use_cond=False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `context` projections.
        inner_dim = 3072  # joint attention width of FLUX-style blocks (attn.heads * head_dim)
head_dim = inner_dim // attn.heads
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
        length = hidden_states.shape[-2] // 3  # image stream = [noise | cond_a | cond_b], three equal chunks
for i in range(self.n_loras):
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
        encoder_hidden_length = encoder_hidden_states.shape[1]  # 512 text tokens for FLUX's T5 context
query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]
attention_probs_query_a_key_noise = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
attention_probs_query_b_key_noise = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)
attn.attention_probs_query_a_key_noise = attention_probs_query_a_key_noise
attn.attention_probs_query_b_key_noise = attention_probs_query_b_key_noise
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# Linear projection (with LoRA weight applied to each proj layer)
hidden_states = attn.to_out[0](hidden_states)
for i in range(self.n_loras):
hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return (hidden_states, encoder_hidden_states)
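# --- Loss-consumption sketch (illustrative) ---
# The *WithLoss double-stream processor stashes averaged cond->noise attention
# scores on each Attention module. A hedged sketch of folding them into a
# scalar; the MSE reduction is an assumption, not the original training
# objective.
def _attention_alignment_loss(transformer) -> torch.Tensor:
    losses = []
    for module in transformer.modules():
        a = getattr(module, "attention_probs_query_a_key_noise", None)
        b = getattr(module, "attention_probs_query_b_key_noise", None)
        if torch.is_tensor(a) and torch.is_tensor(b):
            losses.append(F.mse_loss(a, b))
    return torch.stack(losses).mean() if losses else torch.zeros(())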
class MultiDoubleStreamBlockLoraProcessor_visual(nn.Module):
def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
super().__init__()
# Initialize a list to store the LoRA layers
self.n_loras = n_loras
        self.q_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.k_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.v_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.proj_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
self.lora_weights = lora_weights
def __call__(self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
use_cond=False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `context` projections.
        inner_dim = 3072  # joint attention width of FLUX-style blocks (attn.heads * head_dim)
head_dim = inner_dim // attn.heads
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
        length = hidden_states.shape[-2] // 3  # image stream = [noise | cond_a | cond_b], three equal chunks
for i in range(self.n_loras):
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
        encoder_hidden_length = encoder_hidden_states.shape[1]  # 512 text tokens for FLUX's T5 context
query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]
attention_probs_query_a_key_noise = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
attention_probs_query_b_key_noise = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)
if not hasattr(attn, 'attention_probs_query_a_key_noise'):
attn.attention_probs_query_a_key_noise = []
if not hasattr(attn, 'attention_probs_query_b_key_noise'):
attn.attention_probs_query_b_key_noise = []
        global global_timestep
        # FLUX has 19 double-stream blocks, so integer division by 19 recovers
        # the denoising-step index from the per-call counter.
        attn.attention_probs_query_a_key_noise.append((global_timestep // 19, attention_probs_query_a_key_noise))
        attn.attention_probs_query_b_key_noise.append((global_timestep // 19, attention_probs_query_b_key_noise))
        print(f"Global Timestep: {global_timestep // 19}")
        global_timestep += 1
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# Linear projection (with LoRA weight applied to each proj layer)
hidden_states = attn.to_out[0](hidden_states)
for i in range(self.n_loras):
hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return (hidden_states, encoder_hidden_states)
class MultiSingleStreamBlockLoraProcessor_visual(nn.Module):
def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
super().__init__()
# Initialize a list to store the LoRA layers
self.n_loras = n_loras
self.q_loras = nn.ModuleList([
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
for i in range(n_loras)
])
self.k_loras = nn.ModuleList([
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
for i in range(n_loras)
])
self.v_loras = nn.ModuleList([
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
for i in range(n_loras)
])
self.lora_weights = lora_weights
def __call__(self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond=False,
) -> torch.FloatTensor:
batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
        # FLUX single-stream blocks receive [512 text tokens | 3 * length image tokens].
        encoder_hidden_length = 512
        length = (hidden_states.shape[-2] - encoder_hidden_length) // 3
for i in range(self.n_loras):
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
if not hasattr(attn, 'attention_probs_query_a_key_noise2'):
attn.attention_probs_query_a_key_noise2 = []
if not hasattr(attn, 'attention_probs_query_b_key_noise2'):
attn.attention_probs_query_b_key_noise2 = []
query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]
attention_probs_query_a_key_noise2 = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
attention_probs_query_b_key_noise2 = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)
        global global_timestep2
        # FLUX has 38 single-stream blocks, so integer division by 38 recovers
        # the denoising-step index from the single-stream counter.
        attn.attention_probs_query_a_key_noise2.append((global_timestep2 // 38, attention_probs_query_a_key_noise2))
        attn.attention_probs_query_b_key_noise2.append((global_timestep2 // 38, attention_probs_query_b_key_noise2))
        print(f"Global Timestep2: {global_timestep2 // 38}")
        global_timestep2 += 1
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
return hidden_states
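# --- Visualization sketch (illustrative) ---
# The *_visual processors append (step_index, averaged_map) tuples on their
# Attention modules. A hedged sketch of grouping the stored maps by denoising
# step for later plotting; attribute names follow this file.
def _collect_attention_maps(transformer, attr="attention_probs_query_a_key_noise"):
    per_step = {}
    for module in transformer.modules():
        for record in getattr(module, attr, []):
            if isinstance(record, tuple) and len(record) == 2:
                step_idx, amap = record
                per_step.setdefault(step_idx, []).append(amap)
    return per_step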