import torch
import torch.nn.functional as F
from typing import Optional, Tuple, Union
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from diffusers.models.attention_processor import FluxAttnProcessor2_0
from diffusers.models.embeddings import apply_rotary_emb


class VisualFluxAttnProcessor2_0(FluxAttnProcessor2_0):
    """
    Custom Flux attention processor that saves attention maps for visualization.
    """

    def __init__(self, save_attention=True, save_dir="attention_maps"):
        super().__init__()
        self.save_attention = save_attention
        self.save_dir = save_dir
        self.step_counter = 0
        # Create the directory the attention maps will be written to.
        if self.save_attention:
            os.makedirs(self.save_dir, exist_ok=True)

    def save_attention_map(self, attn_weights, layer_name="", step=None):
        """Save one attention map as a heatmap image."""
        if not self.save_attention:
            return
        if step is None:
            step = self.step_counter

        # Take the attention weights of the first batch element and first head;
        # cast to float32 so .numpy() also works for bf16/fp16 tensors.
        attn_map = attn_weights[0, 0].detach().float().cpu().numpy()  # [seq_len, seq_len]

        # Render the heatmap.
        plt.figure(figsize=(12, 10))
        plt.imshow(attn_map, cmap='hot', interpolation='nearest')
        plt.colorbar()
        plt.title(f'Attention Map - {layer_name} - Step {step}')
        plt.xlabel('Key Position')
        plt.ylabel('Query Position')

        # Save the figure to disk.
        save_path = os.path.join(self.save_dir, f"attention_{layer_name}_step_{step}.png")
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()
        print(f"Attention map saved to: {save_path}")

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cond: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply rotary positional embeddings (as in the stock FluxAttnProcessor2_0).
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # Compute the attention weights explicitly (instead of scaled_dot_product_attention)
        # so that the softmax probabilities are available for visualization.
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask
        attention_probs = F.softmax(attention_scores, dim=-1)

        # Save an attention map every 10 calls.
        if self.save_attention and self.step_counter % 10 == 0:
            layer_name = f"layer_{self.step_counter // 10}"
            self.save_attention_map(attention_probs, layer_name, self.step_counter)

        # Apply dropout.
        attention_probs = F.dropout(attention_probs, p=attn.dropout, training=attn.training)

        # Compute the attention output.
        hidden_states = torch.matmul(attention_probs, value)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if use_cond:
            # Conditional branch: the sequence is assumed to be the original hidden
            # states followed by the conditional hidden states, concatenated along
            # the sequence dimension.
            seq_len = hidden_states.shape[1]
            if seq_len % 2 == 0:
                mid_point = seq_len // 2
                original_hidden_states = hidden_states[:, :mid_point, :]
                cond_hidden_states = hidden_states[:, mid_point:, :]

                # Project each half through the output layers separately.
                original_output = attn.to_out[0](original_hidden_states)
                cond_output = attn.to_out[0](cond_hidden_states)
                if len(attn.to_out) > 1:
                    original_output = attn.to_out[1](original_output)
                    cond_output = attn.to_out[1](cond_output)

                self.step_counter += 1
                return original_output, cond_output

        # Standard output projection (linear layer, then dropout if present).
        hidden_states = attn.to_out[0](hidden_states)
        if len(attn.to_out) > 1:
            hidden_states = attn.to_out[1](hidden_states)

        self.step_counter += 1
        return hidden_states
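

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the standalone `Attention` configuration
# below is an assumption chosen so that the module exposes the attributes this
# processor reads (to_q/to_k/to_v, to_out, RMS-normed q/k); it is not tied to
# any particular Flux checkpoint. In a full pipeline the processor would
# typically be attached via `transformer.set_attn_processor(...)` on layers
# that follow this single-tensor calling convention.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers.models.attention_processor import Attention

    attn = Attention(
        query_dim=64,
        heads=4,
        dim_head=16,
        bias=True,
        qk_norm="rms_norm",
        eps=1e-6,
        out_dim=64,
        processor=VisualFluxAttnProcessor2_0(save_attention=True, save_dir="attention_maps"),
    )

    # A random [batch, seq_len, dim] input; call 0 triggers a saved heatmap.
    hidden_states = torch.randn(1, 32, 64)
    out = attn(hidden_states)
    print(out.shape)  # torch.Size([1, 32, 64])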