# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import types
from copy import deepcopy
from typing import List

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from einops import rearrange
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.loaders import PeftAdapterMixin

from ...distributed.sequence_parallel import (
    distributed_attention,
    gather_forward,
    get_rank,
    get_world_size,
)
from ..model import (
    Head,
    WanAttentionBlock,
    WanLayerNorm,
    WanRMSNorm,
    WanModel,
    WanSelfAttention,
    flash_attention,
    rope_params,
    sinusoidal_embedding_1d,
    rope_apply,
)
from .face_blocks import FaceEncoder, FaceAdapter
from .motion_encoder import Generator


class HeadAnimate(Head):

    def forward(self, x, e):
        """
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]
        """
        assert e.dtype == torch.float32
        with amp.autocast(dtype=torch.float32):
            e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
            x = self.head(self.norm(x) * (1 + e[1]) + e[0])
        return x


class WanAnimateSelfAttention(WanSelfAttention):

    def forward(self, x, seq_lens, grid_sizes, freqs):
        """
        Args:
            x(Tensor): Shape [B, L, C]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        x = flash_attention(
            q=rope_apply(q, grid_sizes, freqs),
            k=rope_apply(k, grid_sizes, freqs),
            v=v,
            k_lens=seq_lens,
            window_size=self.window_size)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x


class WanAnimateCrossAttention(WanSelfAttention):

    def __init__(
        self,
        dim,
        num_heads,
        window_size=(-1, -1),
        qk_norm=True,
        eps=1e-6,
        use_img_emb=True
    ):
        super().__init__(
            dim,
            num_heads,
            window_size,
            qk_norm,
            eps
        )
        self.use_img_emb = use_img_emb
        if use_img_emb:
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, context, context_lens):
        """
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
""" if self.use_img_emb: context_img = context[:, :257] context = context[:, 257:] else: context = context b, n, d = x.size(0), self.num_heads, self.head_dim # compute query, key, value q = self.norm_q(self.q(x)).view(b, -1, n, d) k = self.norm_k(self.k(context)).view(b, -1, n, d) v = self.v(context).view(b, -1, n, d) if self.use_img_emb: k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) v_img = self.v_img(context_img).view(b, -1, n, d) img_x = flash_attention(q, k_img, v_img, k_lens=None) # compute attention x = flash_attention(q, k, v, k_lens=context_lens) # output x = x.flatten(2) if self.use_img_emb: img_x = img_x.flatten(2) x = x + img_x x = self.o(x) return x class WanAnimateAttentionBlock(nn.Module): def __init__(self, dim, ffn_dim, num_heads, window_size=(-1, -1), qk_norm=True, cross_attn_norm=True, eps=1e-6, use_img_emb=True): super().__init__() self.dim = dim self.ffn_dim = ffn_dim self.num_heads = num_heads self.window_size = window_size self.qk_norm = qk_norm self.cross_attn_norm = cross_attn_norm self.eps = eps # layers self.norm1 = WanLayerNorm(dim, eps) self.self_attn = WanAnimateSelfAttention(dim, num_heads, window_size, qk_norm, eps) self.norm3 = WanLayerNorm( dim, eps, elementwise_affine=True ) if cross_attn_norm else nn.Identity() self.cross_attn = WanAnimateCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps, use_img_emb=use_img_emb) self.norm2 = WanLayerNorm(dim, eps) self.ffn = nn.Sequential( nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), nn.Linear(ffn_dim, dim) ) # modulation self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5) def forward( self, x, e, seq_lens, grid_sizes, freqs, context, context_lens, ): """ Args: x(Tensor): Shape [B, L, C] e(Tensor): Shape [B, L1, 6, C] seq_lens(Tensor): Shape [B], length of each sequence in batch grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ assert e.dtype == torch.float32 with amp.autocast(dtype=torch.float32): e = (self.modulation + e).chunk(6, dim=1) assert e[0].dtype == torch.float32 # self-attention y = self.self_attn( self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs ) with amp.autocast(dtype=torch.float32): x = x + y * e[2] # cross-attention & ffn function def cross_attn_ffn(x, context, context_lens, e): x = x + self.cross_attn(self.norm3(x), context, context_lens) y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) with amp.autocast(dtype=torch.float32): x = x + y * e[5] return x x = cross_attn_ffn(x, context, context_lens, e) return x class MLPProj(torch.nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.proj = torch.nn.Sequential( torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), torch.nn.LayerNorm(out_dim), ) def forward(self, image_embeds): clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _no_split_modules = ['WanAttentionBlock'] @register_to_config def __init__(self, patch_size=(1, 2, 2), text_len=512, in_dim=36, dim=5120, ffn_dim=13824, freq_dim=256, text_dim=4096, out_dim=16, num_heads=40, num_layers=40, window_size=(-1, -1), qk_norm=True, cross_attn_norm=True, eps=1e-6, motion_encoder_dim=512, use_context_parallel=False, use_img_emb=True): super().__init__() self.patch_size = patch_size self.text_len = text_len self.in_dim = in_dim self.dim = dim self.ffn_dim = ffn_dim 


class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):

    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=36,
                 dim=5120,
                 ffn_dim=13824,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=40,
                 num_layers=40,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 motion_encoder_dim=512,
                 use_context_parallel=False,
                 use_img_emb=True):
        super().__init__()

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps
        self.motion_encoder_dim = motion_encoder_dim
        self.use_context_parallel = use_context_parallel
        self.use_img_emb = use_img_emb

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.pose_patch_embedding = nn.Conv3d(
            16, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        self.blocks = nn.ModuleList([
            WanAnimateAttentionBlock(dim, ffn_dim, num_heads, window_size,
                                     qk_norm, cross_attn_norm, eps, use_img_emb)
            for _ in range(num_layers)
        ])

        # head
        self.head = HeadAnimate(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        _freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ], dim=1)
        self.register_buffer("freqs", _freqs, persistent=False)

        self.img_emb = MLPProj(1280, dim)

        # initialize weights
        self.init_weights()

        self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
        self.face_adapter = FaceAdapter(
            heads_num=self.num_heads,
            hidden_dim=self.dim,
            num_adapter_layers=self.num_layers // 5,
        )
        self.face_encoder = FaceEncoder(
            in_dim=motion_encoder_dim,
            hidden_dim=self.dim,
            num_heads=4,
        )

    def after_patch_embedding(self, x: List[torch.Tensor], pose_latents,
                              face_pixel_values):
        # add pose latents to every latent frame except the first one
        pose_latents = [
            self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents
        ]
        for x_, pose_latents_ in zip(x, pose_latents):
            x_[:, :, 1:] += pose_latents_

        # encode face frames into motion vectors in mini-batches to bound memory
        b, c, T, h, w = face_pixel_values.shape
        face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
        encode_bs = 8
        face_pixel_values_tmp = []
        for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)):
            face_pixel_values_tmp.append(
                self.motion_encoder.get_motion(
                    face_pixel_values[i * encode_bs:(i + 1) * encode_bs]))
        motion_vec = torch.cat(face_pixel_values_tmp)

        motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
        motion_vec = self.face_encoder(motion_vec)

        # prepend an all-zero motion token at position 0
        B, L, H, C = motion_vec.shape
        pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
        motion_vec = torch.cat([pad_face, motion_vec], dim=1)
        return x, motion_vec

    def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
        # inject face motion through the face adapter every 5 blocks
        if block_idx % 5 == 0:
            adapter_args = [x, motion_vec, motion_masks, self.use_context_parallel]
            residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
            x = residual_out + x
        return x

    def forward(
        self,
        x,
        t,
        clip_fea,
        context,
        seq_len,
        y=None,
        pose_latents=None,
        face_pixel_values=None
    ):
        # params
        freqs = self.freqs

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)

        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
            for u in x
        ])

        # time embeddings
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        if self.use_img_emb:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=freqs,
            context=context,
            context_lens=context_lens)

        if self.use_context_parallel:
            x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]

        for idx, block in enumerate(self.blocks):
            x = block(x, **kwargs)
            x = self.after_transformer_block(idx, x, motion_vec)

        # head
        x = self.head(x, e)

        if self.use_context_parallel:
            x = gather_forward(x, dim=1)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape
                [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
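

# ---------------------------------------------------------------------------
# Illustrative sketch, not used by the model above. It only demonstrates two
# pieces of arithmetic from this file: how the per-head channel budget is
# split across the temporal / height / width RoPE tables built in
# WanAnimateModel.__init__, and how `unpatchify` maps a patch grid back to a
# video latent. The defaults (dim=5120, num_heads=40, patch_size=(1, 2, 2))
# come from __init__; the grid size below is a made-up example.
def _rope_split_and_patch_example(dim=5120, num_heads=40, patch_size=(1, 2, 2)):
    d = dim // num_heads                      # per-head channels (128)
    t_ch = d - 4 * (d // 6)                   # temporal RoPE channels (44)
    hw_ch = 2 * (d // 6)                      # height / width RoPE channels (42 each)
    assert t_ch + 2 * hw_ch == d              # the three tables cover every channel

    # unpatchify arithmetic: an (F, H, W) patch grid expands by patch_size
    # along each axis, giving a latent video of shape [C_out, F * 1, H * 2, W * 2].
    grid = (21, 30, 52)                       # hypothetical patch grid
    latent_shape = tuple(g * p for g, p in zip(grid, patch_size))
    return (t_ch, hw_ch, hw_ch), latent_shape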