# ImageCritic/src/attention_processor.py
import os

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import Optional, Tuple

from diffusers.models.attention_processor import FluxAttnProcessor2_0
from diffusers.models.embeddings import apply_rotary_emb


class VisualFluxAttnProcessor2_0(FluxAttnProcessor2_0):
    """
    Custom Flux attention processor that saves attention maps for visualization.
    """

    def __init__(self, save_attention=True, save_dir="attention_maps"):
        super().__init__()
        self.save_attention = save_attention
        self.save_dir = save_dir
        self.step_counter = 0
        # Create the output directory for saved attention maps
        if self.save_attention:
            os.makedirs(self.save_dir, exist_ok=True)

    def save_attention_map(self, attn_weights, layer_name="", step=None):
        """Save an attention map as a heatmap image."""
        if not self.save_attention:
            return
        if step is None:
            step = self.step_counter
        # Take the attention weights of the first batch item and first head;
        # cast to float32 first, since half/bfloat16 tensors cannot be
        # converted to NumPy directly.
        attn_map = attn_weights[0, 0].detach().float().cpu().numpy()  # [seq_len, seq_len]
        # Render the heatmap
        plt.figure(figsize=(12, 10))
        plt.imshow(attn_map, cmap='hot', interpolation='nearest')
        plt.colorbar()
        plt.title(f'Attention Map - {layer_name} - Step {step}')
        plt.xlabel('Key Position')
        plt.ylabel('Query Position')
        # Write the figure to disk
        save_path = os.path.join(self.save_dir, f"attention_{layer_name}_step_{step}.png")
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()
        print(f"Attention map saved to: {save_path}")

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cond: bool = False,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Expand the mask to [batch, heads, query_len, key_len]
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        # Reshape to [batch, heads, seq_len, head_dim]
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
        # Apply rotary position embeddings (via diffusers' apply_rotary_emb,
        # as in the stock FluxAttnProcessor2_0)
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # Compute attention weights explicitly so they can be visualized.
        # Note: this materializes the full [batch, heads, seq, seq] matrix,
        # which uses far more memory than scaled_dot_product_attention.
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask
        attention_probs = F.softmax(attention_scores, dim=-1)

        # Save the attention map once every 10 steps
        if self.save_attention and self.step_counter % 10 == 0:
            layer_name = f"layer_{self.step_counter // 10}"
            self.save_attention_map(attention_probs, layer_name, self.step_counter)

        # Apply dropout
        attention_probs = F.dropout(attention_probs, p=attn.dropout, training=attn.training)

        # Compute the attention output and merge heads back to [batch, seq, heads * head_dim]
        hidden_states = torch.matmul(attention_probs, value)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        if use_cond:
            # Handle the conditional branch: the sequence is assumed to hold
            # the original hidden_states in its first half and the conditional
            # hidden_states in its second half.
            seq_len = hidden_states.shape[1]
            if seq_len % 2 == 0:
                mid_point = seq_len // 2
                original_hidden_states = hidden_states[:, :mid_point, :]
                cond_hidden_states = hidden_states[:, mid_point:, :]
                # Project each half separately
                original_output = attn.to_out[0](original_hidden_states)
                cond_output = attn.to_out[0](cond_hidden_states)
                if len(attn.to_out) > 1:
                    original_output = attn.to_out[1](original_output)
                    cond_output = attn.to_out[1](cond_output)
                self.step_counter += 1
                return original_output, cond_output

        # Standard output projection (linear layer, then dropout if present)
        hidden_states = attn.to_out[0](hidden_states)
        if len(attn.to_out) > 1:
            hidden_states = attn.to_out[1](hidden_states)
        self.step_counter += 1
        return hidden_states
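

# Usage sketch (illustrative, not part of the original file): one way this
# processor might be attached to a Flux pipeline. The checkpoint name, prompt,
# and step count below are assumptions; `set_attn_processor` is the standard
# diffusers hook for swapping a model's attention processors, and here a single
# processor instance (with one shared step counter) is installed on all layers.
if __name__ == "__main__":
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # assumed checkpoint
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Replace every attention processor in the transformer with the
    # visualizing processor so attention maps are dumped during sampling.
    pipe.transformer.set_attn_processor(
        VisualFluxAttnProcessor2_0(save_attention=True, save_dir="attention_maps")
    )

    image = pipe("a photo of a cat").images[0]
    image.save("output.png")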