selfitcamera committed
Commit 397c271 · 1 Parent(s): 8f954d0
__lib__/i18n/ar.pyc CHANGED
Binary files a/__lib__/i18n/ar.pyc and b/__lib__/i18n/ar.pyc differ
 
__lib__/i18n/da.pyc CHANGED
Binary files a/__lib__/i18n/da.pyc and b/__lib__/i18n/da.pyc differ
 
__lib__/i18n/de.pyc CHANGED
Binary files a/__lib__/i18n/de.pyc and b/__lib__/i18n/de.pyc differ
 
__lib__/i18n/en.pyc CHANGED
Binary files a/__lib__/i18n/en.pyc and b/__lib__/i18n/en.pyc differ
 
__lib__/i18n/es.pyc CHANGED
Binary files a/__lib__/i18n/es.pyc and b/__lib__/i18n/es.pyc differ
 
__lib__/i18n/fi.pyc CHANGED
Binary files a/__lib__/i18n/fi.pyc and b/__lib__/i18n/fi.pyc differ
 
__lib__/i18n/fr.pyc CHANGED
Binary files a/__lib__/i18n/fr.pyc and b/__lib__/i18n/fr.pyc differ
 
__lib__/i18n/he.pyc CHANGED
Binary files a/__lib__/i18n/he.pyc and b/__lib__/i18n/he.pyc differ
 
__lib__/i18n/hi.pyc CHANGED
Binary files a/__lib__/i18n/hi.pyc and b/__lib__/i18n/hi.pyc differ
 
__lib__/i18n/id.pyc CHANGED
Binary files a/__lib__/i18n/id.pyc and b/__lib__/i18n/id.pyc differ
 
__lib__/i18n/it.pyc CHANGED
Binary files a/__lib__/i18n/it.pyc and b/__lib__/i18n/it.pyc differ
 
__lib__/i18n/ja.pyc CHANGED
Binary files a/__lib__/i18n/ja.pyc and b/__lib__/i18n/ja.pyc differ
 
__lib__/i18n/nl.pyc CHANGED
Binary files a/__lib__/i18n/nl.pyc and b/__lib__/i18n/nl.pyc differ
 
__lib__/i18n/no.pyc CHANGED
Binary files a/__lib__/i18n/no.pyc and b/__lib__/i18n/no.pyc differ
 
__lib__/i18n/pt.pyc CHANGED
Binary files a/__lib__/i18n/pt.pyc and b/__lib__/i18n/pt.pyc differ
 
__lib__/i18n/ru.pyc CHANGED
Binary files a/__lib__/i18n/ru.pyc and b/__lib__/i18n/ru.pyc differ
 
__lib__/i18n/sv.pyc CHANGED
Binary files a/__lib__/i18n/sv.pyc and b/__lib__/i18n/sv.pyc differ
 
__lib__/i18n/tr.pyc CHANGED
Binary files a/__lib__/i18n/tr.pyc and b/__lib__/i18n/tr.pyc differ
 
__lib__/i18n/uk.pyc CHANGED
Binary files a/__lib__/i18n/uk.pyc and b/__lib__/i18n/uk.pyc differ
 
__lib__/i18n/vi.pyc CHANGED
Binary files a/__lib__/i18n/vi.pyc and b/__lib__/i18n/vi.pyc differ
 
__lib__/i18n/zh.pyc CHANGED
Binary files a/__lib__/i18n/zh.pyc and b/__lib__/i18n/zh.pyc differ
 
__lib__/pipeline.pyc CHANGED
Binary files a/__lib__/pipeline.pyc and b/__lib__/pipeline.pyc differ
 
pipeline.py CHANGED
@@ -1058,3 +1058,877 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
             return (output_images,)
 
         return BaseOutput(images=output_images)
+
+# -----------------------------------------------------------------------------
+# 6. Advanced Multi-Modal Window Attention Block (Audio + Video + Image)
+# -----------------------------------------------------------------------------
+
+@dataclass
+class MultiModalInput:
+    """Container for multi-modal inputs"""
+    image_embeds: Optional[torch.Tensor] = None      # [B, L_img, D]
+    video_embeds: Optional[torch.Tensor] = None      # [B, T_video, L_vid, D]
+    audio_embeds: Optional[torch.Tensor] = None      # [B, T_audio, L_aud, D]
+    attention_mask: Optional[torch.Tensor] = None    # [B, total_length]
+
+
+class TemporalWindowPartition(nn.Module):
+    """
+    Partition temporal sequences into windows for efficient attention.
+    Supports both uniform and adaptive windowing strategies.
+    """
+    def __init__(
+        self,
+        window_size: int = 8,
+        shift_size: int = 0,
+        use_adaptive_window: bool = False,
+    ):
+        super().__init__()
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.use_adaptive_window = use_adaptive_window
+
+    def partition(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, Any]]:
+        """
+        Partition sequence into windows.
+
+        Args:
+            x: Input tensor [B, T, L, D] or [B, L, D]
+
+        Returns:
+            windowed: [B * num_windows, window_size, L, D]
+            info: Dictionary with partition information
+        """
+        if x.ndim == 3:  # Static input (image)
+            return x, {"is_temporal": False, "original_shape": x.shape}
+
+        B, T, L, D = x.shape
+
+        # Apply temporal shift for shifted window attention (Swin-Transformer style)
+        if self.shift_size > 0:
+            x = torch.roll(x, shifts=-self.shift_size, dims=1)
+
+        # Pad if necessary
+        pad_t = (self.window_size - T % self.window_size) % self.window_size
+        if pad_t > 0:
+            x = F.pad(x, (0, 0, 0, 0, 0, pad_t))
+
+        T_padded = T + pad_t
+        num_windows = T_padded // self.window_size
+
+        # Reshape into windows: [B, num_windows, window_size, L, D]
+        x_windowed = x.view(B, num_windows, self.window_size, L, D)
+
+        # Merge batch and window dims: [B * num_windows, window_size, L, D]
+        x_windowed = x_windowed.view(B * num_windows, self.window_size, L, D)
+
+        info = {
+            "is_temporal": True,
+            "original_shape": (B, T, L, D),
+            "num_windows": num_windows,
+            "pad_t": pad_t,
+        }
+
+        return x_windowed, info
+
+    def merge(self, x_windowed: torch.Tensor, info: Dict[str, Any]) -> torch.Tensor:
+        """
+        Merge windows back to original sequence.
+
+        Args:
+            x_windowed: Windowed tensor [B * num_windows, window_size, L, D]
+            info: Partition information from partition()
+
+        Returns:
+            x: Merged tensor [B, T, L, D] or [B, L, D]
+        """
+        if not info["is_temporal"]:
+            return x_windowed
+
+        B, T, L, D = info["original_shape"]
+        num_windows = info["num_windows"]
+        pad_t = info["pad_t"]
+
+        # Reshape: [B * num_windows, window_size, L, D] -> [B, num_windows, window_size, L, D]
+        x = x_windowed.view(B, num_windows, self.window_size, L, D)
+
+        # Merge windows: [B, T_padded, L, D]
+        x = x.view(B, num_windows * self.window_size, L, D)
+
+        # Remove padding
+        if pad_t > 0:
+            x = x[:, :-pad_t, :, :]
+
+        # Reverse temporal shift
+        if self.shift_size > 0:
+            x = torch.roll(x, shifts=self.shift_size, dims=1)
+
+        return x
+
+
+class WindowCrossAttention(nn.Module):
+    """
+    Window-based Cross Attention with support for temporal sequences.
+    Performs attention within local windows for computational efficiency.
+    """
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        window_size: int = 8,
+        qkv_bias: bool = True,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+        use_relative_position_bias: bool = True,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+
+        # Query, Key, Value projections
+        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
+        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
+        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
+
+        # QK Normalization for stability
+        self.q_norm = OmniRMSNorm(self.head_dim)
+        self.k_norm = OmniRMSNorm(self.head_dim)
+
+        # Attention dropout
+        self.attn_drop = nn.Dropout(attn_drop)
+
+        # Output projection
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        # Relative position bias (for temporal coherence)
+        self.use_relative_position_bias = use_relative_position_bias
+        if use_relative_position_bias:
+            # Temporal relative position bias
+            self.relative_position_bias_table = nn.Parameter(
+                torch.zeros((2 * window_size - 1), num_heads)
+            )
+            nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
+
+            # Get relative position index
+            coords = torch.arange(window_size)
+            relative_coords = coords[:, None] - coords[None, :]  # [window_size, window_size]
+            relative_coords += window_size - 1  # Shift to start from 0
+            self.register_buffer("relative_position_index", relative_coords)
+
+    def get_relative_position_bias(self, window_size: int) -> torch.Tensor:
+        """Generate relative position bias for attention"""
+        if not self.use_relative_position_bias:
+            return None
+
+        relative_position_bias = self.relative_position_bias_table[
+            self.relative_position_index[:window_size, :window_size].reshape(-1)
+        ].reshape(window_size, window_size, -1)
+
+        # Permute to [num_heads, window_size, window_size]
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+        return relative_position_bias
+
+    def forward(
+        self,
+        query: torch.Tensor,   # [B, T_q, L_q, D] or [B, L_q, D]
+        key: torch.Tensor,     # [B, T_k, L_k, D] or [B, L_k, D]
+        value: torch.Tensor,   # [B, T_v, L_v, D] or [B, L_v, D]
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Perform windowed cross attention.
+
+        Args:
+            query: Query tensor
+            key: Key tensor
+            value: Value tensor
+            attention_mask: Optional attention mask
+
+        Returns:
+            Output tensor with same shape as query
+        """
+        # Handle both temporal and non-temporal inputs
+        is_temporal = query.ndim == 4
+
+        if is_temporal:
+            B, T_q, L_q, D = query.shape
+            _, T_k, L_k, _ = key.shape
+
+            # Flatten temporal and spatial dims for cross attention
+            query_flat = query.reshape(B, T_q * L_q, D)
+            key_flat = key.reshape(B, T_k * L_k, D)
+            value_flat = value.reshape(B, T_k * L_k, D)
+        else:
+            B, L_q, D = query.shape
+            _, L_k, _ = key.shape
+            query_flat = query
+            key_flat = key
+            value_flat = value
+
+        # Project to Q, K, V
+        q = self.q_proj(query_flat)  # [B, N_q, D]
+        k = self.k_proj(key_flat)    # [B, N_k, D]
+        v = self.v_proj(value_flat)  # [B, N_v, D]
+
+        # Reshape for multi-head attention
+        q = q.reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, N_q, head_dim]
+        k = k.reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, N_k, head_dim]
+        v = v.reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, N_v, head_dim]
+
+        # Apply QK normalization
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        # Scaled dot-product attention
+        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, H, N_q, N_k]
+
+        # Add relative position bias if temporal
+        if is_temporal and self.use_relative_position_bias:
+            # Apply per-window bias
+            rel_bias = self.get_relative_position_bias(min(T_q, self.window_size))
+            if rel_bias is not None:
+                # Broadcast bias across spatial dimensions
+                attn = attn + rel_bias.unsqueeze(0).unsqueeze(2)
+
+        # Apply attention mask
+        if attention_mask is not None:
+            attn = attn.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2) == 0, float('-inf'))
+
+        # Softmax and dropout
+        attn = F.softmax(attn, dim=-1)
+        attn = self.attn_drop(attn)
+
+        # Apply attention to values
+        out = (attn @ v).transpose(1, 2).reshape(B, -1, D)  # [B, N_q, D]
+
+        # Output projection
+        out = self.proj(out)
+        out = self.proj_drop(out)
+
+        # Reshape back to original shape
+        if is_temporal:
+            out = out.reshape(B, T_q, L_q, D)
+        else:
+            out = out.reshape(B, L_q, D)
+
+        return out
+
+
+class MultiModalFusionLayer(nn.Module):
+    """
+    Fuses multiple modalities (audio, video, image) with learnable fusion weights.
+    """
+    def __init__(
+        self,
+        dim: int,
+        num_modalities: int = 3,
+        fusion_type: str = "weighted",  # "weighted", "gated", "adaptive"
+    ):
+        super().__init__()
+        self.dim = dim
+        self.num_modalities = num_modalities
+        self.fusion_type = fusion_type
+
+        if fusion_type == "weighted":
+            # Learnable fusion weights
+            self.fusion_weights = nn.Parameter(torch.ones(num_modalities) / num_modalities)
+
+        elif fusion_type == "gated":
+            # Gated fusion with cross-modal interactions
+            self.gate_proj = nn.Sequential(
+                nn.Linear(dim * num_modalities, dim * 2),
+                nn.GELU(),
+                nn.Linear(dim * 2, num_modalities),
+                nn.Softmax(dim=-1)
+            )
+
+        elif fusion_type == "adaptive":
+            # Adaptive fusion with per-token gating
+            self.adaptive_gate = nn.Sequential(
+                nn.Linear(dim, dim // 2),
+                nn.GELU(),
+                nn.Linear(dim // 2, num_modalities),
+                nn.Sigmoid()
+            )
+
+    def forward(self, modality_features: List[torch.Tensor]) -> torch.Tensor:
+        """
+        Fuse multiple modality features.
+
+        Args:
+            modality_features: List of [B, L, D] tensors for each modality
+
+        Returns:
+            fused: Fused features [B, L, D]
+        """
+        if self.fusion_type == "weighted":
+            # Simple weighted sum
+            weights = F.softmax(self.fusion_weights, dim=0)
+            fused = sum(w * feat for w, feat in zip(weights, modality_features))
+
+        elif self.fusion_type == "gated":
+            # Concatenate and compute gates
+            concat_features = torch.cat(modality_features, dim=-1)  # [B, L, D * num_modalities]
+            gates = self.gate_proj(concat_features)  # [B, L, num_modalities]
+
+            # Apply gates
+            stacked = torch.stack(modality_features, dim=-1)  # [B, L, D, num_modalities]
+            fused = (stacked * gates.unsqueeze(2)).sum(dim=-1)  # [B, L, D]
+
+        elif self.fusion_type == "adaptive":
+            # Adaptive per-token fusion
+            fused_list = []
+            for feat in modality_features:
+                gate = self.adaptive_gate(feat)  # [B, L, num_modalities]
+                fused_list.append(feat.unsqueeze(-1) * gate.unsqueeze(2))
+
+            fused = torch.cat(fused_list, dim=-1).sum(dim=-1)  # [B, L, D]
+
+        return fused
+
+
+class FancyMultiModalWindowAttentionBlock(nn.Module):
+    """
+    🎯 Fancy Multi-Modal Window Attention Block
+
+    A state-of-the-art block that processes audio, video, and image embeddings
+    with temporal window-based cross-attention for efficient multi-modal fusion.
+
+    Features:
+    - ✨ Temporal windowing for audio and video (frame-by-frame processing)
+    - 🪟 Shifted window attention for better temporal coherence (Swin-style)
+    - 🔄 Cross-modal attention between all modality pairs
+    - 🎭 Adaptive multi-modal fusion with learnable gates
+    - 🚀 Efficient computation with window partitioning
+    - 💎 QK normalization for training stability
+
+    Architecture:
+        1. Temporal Partitioning (audio/video frames → windows)
+        2. Intra-Modal Self-Attention (within each modality)
+        3. Inter-Modal Cross-Attention (audio ↔ video ↔ image)
+        4. Multi-Modal Fusion (adaptive weighted combination)
+        5. Feed-Forward Network (SwiGLU activation)
+        6. Window Merging (reconstruct temporal sequences)
+    """
+
+    def __init__(
+        self,
+        dim: int = 1024,
+        num_heads: int = 16,
+        window_size: int = 8,
+        shift_size: int = 4,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        drop: float = 0.0,
+        attn_drop: float = 0.0,
+        drop_path: float = 0.1,
+        use_relative_position_bias: bool = True,
+        fusion_type: str = "adaptive",  # "weighted", "gated", "adaptive"
+        use_shifted_window: bool = True,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size if use_shifted_window else 0
+        self.mlp_ratio = mlp_ratio
+
+        # =============== Temporal Window Partitioning ===============
+        self.window_partition = TemporalWindowPartition(
+            window_size=window_size,
+            shift_size=self.shift_size,
+        )
+
+        # =============== Intra-Modal Self-Attention ===============
+        self.norm_audio_self = OmniRMSNorm(dim)
+        self.norm_video_self = OmniRMSNorm(dim)
+        self.norm_image_self = OmniRMSNorm(dim)
+
+        self.audio_self_attn = WindowCrossAttention(
+            dim=dim,
+            num_heads=num_heads,
+            window_size=window_size,
+            qkv_bias=qkv_bias,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            use_relative_position_bias=use_relative_position_bias,
+        )
+
+        self.video_self_attn = WindowCrossAttention(
+            dim=dim,
+            num_heads=num_heads,
+            window_size=window_size,
+            qkv_bias=qkv_bias,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            use_relative_position_bias=use_relative_position_bias,
+        )
+
+        self.image_self_attn = WindowCrossAttention(
+            dim=dim,
+            num_heads=num_heads,
+            window_size=window_size,
+            qkv_bias=qkv_bias,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            use_relative_position_bias=False,  # No temporal bias for static images
+        )
+
+        # =============== Inter-Modal Cross-Attention ===============
+        # Audio → Video/Image
+        self.norm_audio_cross = OmniRMSNorm(dim)
+        self.audio_to_visual = WindowCrossAttention(
+            dim=dim, num_heads=num_heads, window_size=window_size,
+            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+        )
+
+        # Video → Audio/Image
+        self.norm_video_cross = OmniRMSNorm(dim)
+        self.video_to_others = WindowCrossAttention(
+            dim=dim, num_heads=num_heads, window_size=window_size,
+            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+        )
+
+        # Image → Audio/Video
+        self.norm_image_cross = OmniRMSNorm(dim)
+        self.image_to_temporal = WindowCrossAttention(
+            dim=dim, num_heads=num_heads, window_size=window_size,
+            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+        )
+
+        # =============== Multi-Modal Fusion ===============
+        self.multimodal_fusion = MultiModalFusionLayer(
+            dim=dim,
+            num_modalities=3,
+            fusion_type=fusion_type,
+        )
+
+        # =============== Feed-Forward Network ===============
+        self.norm_ffn = OmniRMSNorm(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.ffn = nn.Sequential(
+            nn.Linear(dim, mlp_hidden_dim, bias=False),
+            nn.GELU(),
+            nn.Dropout(drop),
+            nn.Linear(mlp_hidden_dim, dim, bias=False),
+            nn.Dropout(drop),
+        )
+
+        # =============== Stochastic Depth (Drop Path) ===============
+        self.drop_path = nn.Identity() if drop_path <= 0. else nn.Dropout(drop_path)
+
+        # =============== Output Projections ===============
+        self.output_projection = nn.ModuleDict({
+            'audio': nn.Linear(dim, dim),
+            'video': nn.Linear(dim, dim),
+            'image': nn.Linear(dim, dim),
+        })
+
+    def forward(
+        self,
+        audio_embeds: Optional[torch.Tensor] = None,   # [B, T_audio, L_audio, D]
+        video_embeds: Optional[torch.Tensor] = None,   # [B, T_video, L_video, D]
+        image_embeds: Optional[torch.Tensor] = None,   # [B, L_image, D]
+        attention_mask: Optional[torch.Tensor] = None,
+        return_intermediates: bool = False,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Forward pass of the Fancy Multi-Modal Window Attention Block.
+
+        Args:
+            audio_embeds: Audio embeddings [B, T_audio, L_audio, D]
+                T_audio: number of audio frames
+                L_audio: sequence length per frame
+            video_embeds: Video embeddings [B, T_video, L_video, D]
+                T_video: number of video frames
+                L_video: sequence length per frame (e.g., patches)
+            image_embeds: Image embeddings [B, L_image, D]
+                L_image: sequence length (e.g., image patches)
+            attention_mask: Optional attention mask
+            return_intermediates: Whether to return intermediate features
+
+        Returns:
+            outputs: Dictionary containing processed embeddings for each modality
+                - 'audio': [B, T_audio, L_audio, D]
+                - 'video': [B, T_video, L_video, D]
+                - 'image': [B, L_image, D]
+                - 'fused': [B, L_total, D] (optional)
+        """
+        intermediates = {} if return_intermediates else None
+
+        # ========== Stage 1: Temporal Window Partitioning ==========
+        partitioned_audio, audio_info = None, None
+        partitioned_video, video_info = None, None
+
+        if audio_embeds is not None:
+            partitioned_audio, audio_info = self.window_partition.partition(audio_embeds)
+            if return_intermediates:
+                intermediates['audio_windows'] = partitioned_audio
+
+        if video_embeds is not None:
+            partitioned_video, video_info = self.window_partition.partition(video_embeds)
+            if return_intermediates:
+                intermediates['video_windows'] = partitioned_video
+
+        # ========== Stage 2: Intra-Modal Self-Attention ==========
+        audio_self_out, video_self_out, image_self_out = None, None, None
+
+        if audio_embeds is not None:
+            audio_normed = self.norm_audio_self(partitioned_audio)
+            audio_self_out = self.audio_self_attn(audio_normed, audio_normed, audio_normed)
+            audio_self_out = partitioned_audio + self.drop_path(audio_self_out)
+
+        if video_embeds is not None:
+            video_normed = self.norm_video_self(partitioned_video)
+            video_self_out = self.video_self_attn(video_normed, video_normed, video_normed)
+            video_self_out = partitioned_video + self.drop_path(video_self_out)
+
+        if image_embeds is not None:
+            image_normed = self.norm_image_self(image_embeds)
+            image_self_out = self.image_self_attn(image_normed, image_normed, image_normed)
+            image_self_out = image_embeds + self.drop_path(image_self_out)
+
+        # ========== Stage 3: Inter-Modal Cross-Attention ==========
+        audio_cross_out, video_cross_out, image_cross_out = None, None, None
+
+        # Prepare context (merge windows temporarily for cross-attention)
+        if audio_self_out is not None:
+            audio_merged = self.window_partition.merge(audio_self_out, audio_info)
+        if video_self_out is not None:
+            video_merged = self.window_partition.merge(video_self_out, video_info)
+
+        # Audio attends to Video and Image
+        if audio_embeds is not None:
+            audio_q = self.norm_audio_cross(audio_merged)
+
+            # Create key-value context from other modalities
+            kv_list = []
+            if video_embeds is not None:
+                kv_list.append(video_merged)
+            if image_embeds is not None:
+                # Expand image to match temporal dimension
+                B, L_img, D = image_self_out.shape
+                T_audio = audio_merged.shape[1]
+                image_expanded = image_self_out.unsqueeze(1).expand(B, T_audio, L_img, D)
+                kv_list.append(image_expanded)
+
+            if kv_list:
+                # Concatenate along sequence dimension
+                kv_context = torch.cat([kv.flatten(1, 2) for kv in kv_list], dim=1)
+                kv_context = kv_context.reshape(B, -1, D)
+
+                audio_cross_out = self.audio_to_visual(
+                    audio_q.flatten(1, 2),
+                    kv_context,
+                    kv_context,
+                    attention_mask
+                )
+                audio_cross_out = audio_cross_out.reshape_as(audio_merged)
+                audio_cross_out = audio_merged + self.drop_path(audio_cross_out)
+            else:
+                audio_cross_out = audio_merged
+
+        # Video attends to Audio and Image
+        if video_embeds is not None:
+            video_q = self.norm_video_cross(video_merged)
+
+            kv_list = []
+            if audio_embeds is not None:
+                kv_list.append(audio_merged if audio_cross_out is None else audio_cross_out)
+            if image_embeds is not None:
+                B, L_img, D = image_self_out.shape
+                T_video = video_merged.shape[1]
+                image_expanded = image_self_out.unsqueeze(1).expand(B, T_video, L_img, D)
+                kv_list.append(image_expanded)
+
+            if kv_list:
+                kv_context = torch.cat([kv.flatten(1, 2) for kv in kv_list], dim=1)
+                kv_context = kv_context.reshape(B, -1, D)
+
+                video_cross_out = self.video_to_others(
+                    video_q.flatten(1, 2),
+                    kv_context,
+                    kv_context,
+                    attention_mask
+                )
+                video_cross_out = video_cross_out.reshape_as(video_merged)
+                video_cross_out = video_merged + self.drop_path(video_cross_out)
+            else:
+                video_cross_out = video_merged
+
+        # Image attends to Audio and Video
+        if image_embeds is not None:
+            image_q = self.norm_image_cross(image_self_out)
+
+            kv_list = []
+            if audio_embeds is not None:
+                # Average pool audio over time for image
+                audio_pooled = (audio_merged if audio_cross_out is None else audio_cross_out).mean(dim=1)
+                kv_list.append(audio_pooled)
+            if video_embeds is not None:
+                # Average pool video over time for image
+                video_pooled = (video_merged if video_cross_out is None else video_cross_out).mean(dim=1)
+                kv_list.append(video_pooled)
+
+            if kv_list:
+                kv_context = torch.cat(kv_list, dim=1)
+
+                image_cross_out = self.image_to_temporal(
+                    image_q,
+                    kv_context,
+                    kv_context,
+                    attention_mask
+                )
+                image_cross_out = image_self_out + self.drop_path(image_cross_out)
+            else:
+                image_cross_out = image_self_out
+
+        # ========== Stage 4: Multi-Modal Fusion ==========
+        # Collect features from all modalities for fusion
+        fusion_features = []
+        if audio_cross_out is not None:
+            audio_flat = audio_cross_out.flatten(1, 2)  # [B, T*L, D]
+            fusion_features.append(audio_flat)
+        if video_cross_out is not None:
+            video_flat = video_cross_out.flatten(1, 2)  # [B, T*L, D]
+            fusion_features.append(video_flat)
+        if image_cross_out is not None:
+            fusion_features.append(image_cross_out)  # [B, L, D]
+
+        # Pad/align sequence lengths for fusion
+        if len(fusion_features) > 1:
+            max_len = max(f.shape[1] for f in fusion_features)
+            aligned_features = []
+            for feat in fusion_features:
+                if feat.shape[1] < max_len:
+                    pad_len = max_len - feat.shape[1]
+                    feat = F.pad(feat, (0, 0, 0, pad_len))
+                aligned_features.append(feat)
+
+            # Fuse modalities
+            fused_features = self.multimodal_fusion(aligned_features)
+        else:
+            fused_features = fusion_features[0] if fusion_features else None
+
+        # ========== Stage 5: Feed-Forward Network ==========
+        if fused_features is not None:
+            fused_normed = self.norm_ffn(fused_features)
+            fused_ffn = self.ffn(fused_normed)
+            fused_features = fused_features + self.drop_path(fused_ffn)
+
+        # ========== Stage 6: Prepare Outputs ==========
+        outputs = {}
+
+        # Project back to original shapes
+        if audio_embeds is not None and audio_cross_out is not None:
+            # Partition again for consistency
+            audio_final, _ = self.window_partition.partition(audio_cross_out)
+            audio_final = self.output_projection['audio'](audio_final)
+            audio_final = self.window_partition.merge(audio_final, audio_info)
+            outputs['audio'] = audio_final
+
+        if video_embeds is not None and video_cross_out is not None:
+            video_final, _ = self.window_partition.partition(video_cross_out)
+            video_final = self.output_projection['video'](video_final)
+            video_final = self.window_partition.merge(video_final, video_info)
+            outputs['video'] = video_final
+
+        if image_embeds is not None and image_cross_out is not None:
+            image_final = self.output_projection['image'](image_cross_out)
+            outputs['image'] = image_final
+
+        if fused_features is not None:
+            outputs['fused'] = fused_features
+
+        if return_intermediates:
+            outputs['intermediates'] = intermediates
+
+        return outputs
+
+
+# -----------------------------------------------------------------------------
+# 7. Optimization Utilities (FP8, Compilation, Mixed Precision)
+# -----------------------------------------------------------------------------
+
+@dataclass
+class FP8Config:
+    """Configuration for FP8 quantization"""
+    enabled: bool = False
+    margin: int = 0
+    fp8_format: str = "hybrid"  # "e4m3", "e5m2", "hybrid"
+    amax_history_len: int = 1024
+    amax_compute_algo: str = "max"
+
+
+@dataclass
+class CompilationConfig:
+    """Configuration for torch.compile"""
+    enabled: bool = False
+    mode: str = "reduce-overhead"  # "default", "reduce-overhead", "max-autotune"
+    fullgraph: bool = False
+    dynamic: bool = True
+    backend: str = "inductor"
+
+
+@dataclass
+class MixedPrecisionConfig:
+    """Configuration for mixed precision training/inference"""
+    enabled: bool = True
+    dtype: str = "bfloat16"  # "float16", "bfloat16"
+    use_amp: bool = True
+
+
+class ModelOptimizer:
+    """
+    Unified model optimizer supporting FP8 quantization, torch.compile,
+    and mixed precision inference.
+    """
+    def __init__(
+        self,
+        fp8_config: Optional[FP8Config] = None,
+        compilation_config: Optional[CompilationConfig] = None,
+        mixed_precision_config: Optional[MixedPrecisionConfig] = None,
+    ):
+        self.fp8_config = fp8_config or FP8Config()
+        self.compilation_config = compilation_config or CompilationConfig()
+        self.mixed_precision_config = mixed_precision_config or MixedPrecisionConfig()
+
+        # Setup mixed precision
+        self._setup_mixed_precision()
+
+    def _setup_mixed_precision(self):
+        """Setup mixed precision context"""
+        if self.mixed_precision_config.enabled:
+            dtype_map = {
+                "float16": torch.float16,
+                "bfloat16": torch.bfloat16,
+            }
+            self.dtype = dtype_map.get(self.mixed_precision_config.dtype, torch.bfloat16)
+        else:
+            self.dtype = torch.float32
+
+    @contextmanager
+    def autocast_context(self):
+        """Context manager for automatic mixed precision"""
+        if self.mixed_precision_config.enabled and self.mixed_precision_config.use_amp:
+            with torch.autocast(device_type='cuda', dtype=self.dtype):
+                yield
+        else:
+            yield
+
+    def _compile_model(self, model: nn.Module) -> nn.Module:
+        """Compile model using torch.compile"""
+        if not self.compilation_config.enabled or not HAS_TORCH_COMPILE:
+            return model
+
+        return torch.compile(
+            model,
+            mode=self.compilation_config.mode,
+            fullgraph=self.compilation_config.fullgraph,
+            dynamic=self.compilation_config.dynamic,
+            backend=self.compilation_config.backend,
+        )
+
+    def _quantize_model_fp8(self, model: nn.Module) -> nn.Module:
+        """Apply FP8 quantization using Transformer Engine"""
+        if not self.fp8_config.enabled or not HAS_TRANSFORMER_ENGINE:
+            return model
+
+        # Convert compatible layers to FP8
+        for name, module in model.named_modules():
+            if isinstance(module, nn.Linear):
+                # Replace with TE FP8 Linear
+                fp8_linear = te.Linear(
+                    module.in_features,
+                    module.out_features,
+                    bias=module.bias is not None,
+                )
+                # Copy weights
+                fp8_linear.weight.data.copy_(module.weight.data)
+                if module.bias is not None:
+                    fp8_linear.bias.data.copy_(module.bias.data)
+
+                # Replace module
+                parent_name = '.'.join(name.split('.')[:-1])
+                child_name = name.split('.')[-1]
+                if parent_name:
+                    parent = dict(model.named_modules())[parent_name]
+                    setattr(parent, child_name, fp8_linear)
+
+        return model
+
+    def optimize_model(
+        self,
+        model: nn.Module,
+        apply_compilation: bool = True,
+        apply_quantization: bool = True,
+        apply_mixed_precision: bool = True,
+    ) -> nn.Module:
+        """
+        Apply all optimizations to model.
+
+        Args:
+            model: Model to optimize
+            apply_compilation: Whether to compile with torch.compile
+            apply_quantization: Whether to apply FP8 quantization
+            apply_mixed_precision: Whether to convert to mixed precision dtype
+
+        Returns:
+            Optimized model
+        """
+        # Apply FP8 quantization first
+        if apply_quantization and self.fp8_config.enabled:
+            model = self._quantize_model_fp8(model)
+
+        # Convert to mixed precision dtype
+        if apply_mixed_precision and self.mixed_precision_config.enabled:
+            model = model.to(dtype=self.dtype)
+
+        # Compile model last
+        if apply_compilation and self.compilation_config.enabled:
+            model = self._compile_model(model)
+
+        return model
+
+
+@contextmanager
+def optimized_inference_mode(
+    enable_cudnn_benchmark: bool = True,
+    enable_tf32: bool = True,
+    enable_flash_sdp: bool = True,
+):
+    """
+    Context manager for optimized inference with various PyTorch optimizations.
+
+    Args:
+        enable_cudnn_benchmark: Enable cuDNN autotuner
+        enable_tf32: Enable TF32 for faster matmul on Ampere+ GPUs
+        enable_flash_sdp: Enable Flash Attention in scaled_dot_product_attention
+    """
+    # Save original states
+    orig_benchmark = torch.backends.cudnn.benchmark
+    orig_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
+    orig_tf32_cudnn = torch.backends.cudnn.allow_tf32
+    orig_sdp_flash = torch.backends.cuda.flash_sdp_enabled()
+
+    try:
+        # Enable optimizations
+        torch.backends.cudnn.benchmark = enable_cudnn_benchmark
+        torch.backends.cuda.matmul.allow_tf32 = enable_tf32
+        torch.backends.cudnn.allow_tf32 = enable_tf32
+
+        if enable_flash_sdp:
+            torch.backends.cuda.enable_flash_sdp(True)
+
+        yield
+
+    finally:
+        # Restore original states
+        torch.backends.cudnn.benchmark = orig_benchmark
+        torch.backends.cuda.matmul.allow_tf32 = orig_tf32_matmul
+        torch.backends.cudnn.allow_tf32 = orig_tf32_cudnn
+        torch.backends.cuda.enable_flash_sdp(orig_sdp_flash)
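
For reference, a minimal usage sketch of the FancyMultiModalWindowAttentionBlock and optimized_inference_mode utilities added in this commit. It is illustrative only and not part of the commit: the tensor shapes and sizes are made up, it assumes pipeline.py is importable with its existing dependencies (torch, OmniRMSNorm, etc.), and the relative position bias is disabled here because its broadcast appears to assume a per-frame sequence length of 1.

# Illustrative usage sketch (not part of the commit). Assumes pipeline.py is
# importable and that OmniRMSNorm and the other symbols it references exist.
import torch
from pipeline import FancyMultiModalWindowAttentionBlock, optimized_inference_mode

block = FancyMultiModalWindowAttentionBlock(
    dim=256,
    num_heads=8,
    window_size=8,
    # Disabled here: the relative-position bias broadcast appears to assume a
    # per-frame sequence length of 1, so larger L would not broadcast.
    use_relative_position_bias=False,
)
block.eval()  # drop_path (implemented as nn.Dropout) becomes a no-op in eval mode

B, D = 2, 256
audio = torch.randn(B, 24, 16, D)   # [B, T_audio, L_audio, D]
video = torch.randn(B, 12, 64, D)   # [B, T_video, L_video, D]; T is padded internally to a multiple of window_size
image = torch.randn(B, 256, D)      # [B, L_image, D]

with torch.no_grad(), optimized_inference_mode():
    out = block(audio_embeds=audio, video_embeds=video, image_embeds=image)

# Expected keys: 'audio', 'video', 'image', 'fused', each preserving its input shape
for name, tensor in out.items():
    if torch.is_tensor(tensor):
        print(name, tuple(tensor.shape))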