# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from openfold.model.triangular_attention import (
    TriangleAttentionEndingNode,
    TriangleAttentionStartingNode,
)
from openfold.model.triangular_multiplicative_update import (
    TriangleMultiplicationIncoming,
    TriangleMultiplicationOutgoing,
)
from torch import nn

from esm.esmfold.v1.misc import (
    Attention,
    Dropout,
    PairToSequence,
    ResidueMLP,
    SequenceToPair,
)


class TriangularSelfAttentionBlock(nn.Module):
    def __init__(
        self,
        sequence_state_dim,
        pairwise_state_dim,
        sequence_head_width,
        pairwise_head_width,
        dropout=0,
        **__kwargs,
    ):
        super().__init__()

        assert sequence_state_dim % sequence_head_width == 0
        assert pairwise_state_dim % pairwise_head_width == 0
        sequence_num_heads = sequence_state_dim // sequence_head_width
        pairwise_num_heads = pairwise_state_dim // pairwise_head_width
        assert sequence_state_dim == sequence_num_heads * sequence_head_width
        assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width
        assert pairwise_state_dim % 2 == 0

        self.sequence_state_dim = sequence_state_dim
        self.pairwise_state_dim = pairwise_state_dim

        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)

        self.sequence_to_pair = SequenceToPair(
            sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim
        )
        self.pair_to_sequence = PairToSequence(pairwise_state_dim, sequence_num_heads)

        self.seq_attention = Attention(
            sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True
        )
        self.tri_mul_out = TriangleMultiplicationOutgoing(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_att_start = TriangleAttentionStartingNode(
            pairwise_state_dim,
            pairwise_head_width,
            pairwise_num_heads,
            inf=1e9,
        )  # type: ignore
        self.tri_att_end = TriangleAttentionEndingNode(
            pairwise_state_dim,
            pairwise_head_width,
            pairwise_num_heads,
            inf=1e9,
        )  # type: ignore

        self.mlp_seq = ResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=dropout)
        self.mlp_pair = ResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout)

        assert dropout < 0.4
        self.drop = nn.Dropout(dropout)
        self.row_drop = Dropout(dropout * 2, 2)
        self.col_drop = Dropout(dropout * 2, 1)

        # Zero-initialize the output projection of every residual branch, so
        # each residual update starts out as zero at initialization.
        torch.nn.init.zeros_(self.tri_mul_in.linear_z.weight)
        torch.nn.init.zeros_(self.tri_mul_in.linear_z.bias)
        torch.nn.init.zeros_(self.tri_mul_out.linear_z.weight)
        torch.nn.init.zeros_(self.tri_mul_out.linear_z.bias)
        torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.weight)
        torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.bias)
        torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.weight)
        torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.bias)

        torch.nn.init.zeros_(self.sequence_to_pair.o_proj.weight)
        torch.nn.init.zeros_(self.sequence_to_pair.o_proj.bias)
        torch.nn.init.zeros_(self.pair_to_sequence.linear.weight)
        torch.nn.init.zeros_(self.seq_attention.o_proj.weight)
        torch.nn.init.zeros_(self.seq_attention.o_proj.bias)
        torch.nn.init.zeros_(self.mlp_seq.mlp[-2].weight)
        torch.nn.init.zeros_(self.mlp_seq.mlp[-2].bias)
        torch.nn.init.zeros_(self.mlp_pair.mlp[-2].weight)
        torch.nn.init.zeros_(self.mlp_pair.mlp[-2].bias)

    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
        """
        Inputs:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
          mask: B x L boolean tensor of valid positions

        Output:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
        """
        assert len(sequence_state.shape) == 3
        assert len(pairwise_state.shape) == 4
        if mask is not None:
            assert len(mask.shape) == 2

        batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
        pairwise_state_dim = pairwise_state.shape[3]
        assert sequence_state_dim == self.sequence_state_dim
        assert pairwise_state_dim == self.pairwise_state_dim
        assert batch_dim == pairwise_state.shape[0]
        assert seq_dim == pairwise_state.shape[1]
        assert seq_dim == pairwise_state.shape[2]

        # Update sequence state
        bias = self.pair_to_sequence(pairwise_state)

        # Self attention with bias + mlp.
        y = self.layernorm_1(sequence_state)
        y, _ = self.seq_attention(y, mask=mask, bias=bias)
        sequence_state = sequence_state + self.drop(y)
        sequence_state = self.mlp_seq(sequence_state)

        # Update pairwise state
        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)

        # Axial attention with triangular bias.
        tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_mul_out(pairwise_state, mask=tri_mask)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_mul_in(pairwise_state, mask=tri_mask)
        )
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )

        # MLP over pairs.
        pairwise_state = self.mlp_pair(pairwise_state)

        return sequence_state, pairwise_state
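

# ---------------------------------------------------------------------------
# Minimal usage sketch: illustrates the tensor shapes a single block expects
# and returns, per the forward() docstring above. The state dimensions, head
# widths, dropout value, and all-ones mask below are arbitrary assumptions
# chosen for demonstration, not ESMFold's actual configuration.
if __name__ == "__main__":
    block = TriangularSelfAttentionBlock(
        sequence_state_dim=64,  # must be divisible by sequence_head_width
        pairwise_state_dim=32,  # must be even and divisible by pairwise_head_width
        sequence_head_width=16,
        pairwise_head_width=16,
        dropout=0.1,
    )
    block.eval()  # disable dropout so the shapes-only demo is deterministic

    B, L = 2, 16
    sequence_state = torch.zeros(B, L, 64)
    pairwise_state = torch.zeros(B, L, L, 32)
    mask = torch.ones(B, L)  # 1.0 marks valid positions

    with torch.no_grad():
        seq_out, pair_out = block(sequence_state, pairwise_state, mask=mask)

    print(seq_out.shape)   # torch.Size([2, 16, 64])
    print(pair_out.shape)  # torch.Size([2, 16, 16, 32])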