Move flash_attn assert from __init__ into calling func
#32 by rogerxfeng8 - opened

modeling_phi3_small.py  +2 -1

modeling_phi3_small.py CHANGED
@@ -215,7 +215,6 @@ class Phi3SmallSelfAttention(nn.Module):
                 f"Layer {layer_idx + 1} is using dense attention since it is divisible by "
                 f"{self.config.dense_attention_every_n_layers}"
             )
-            assert is_flash_attention_available, "Flash Attention is not available, but is needed for dense attention"
         else:
             # BlockSparse related Parameters
             self.blocksparse_params = BlockSparseParams.from_config(config)
@@ -419,6 +418,8 @@ class Phi3SmallSelfAttention(nn.Module):
         avoid doing that.
 
         """
+        assert is_flash_attention_available, "Flash Attention is not available, but is needed for dense attention"
+
         attention_dropout_prob = self.attention_dropout_rate if self.training else 0.0
         # Get into the correct shape for the Flash Attention API
         # shape: (bs, seq_len, nqp, hn)
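For context, the general pattern this change applies is deferring an optional-dependency check from __init__ to the method that actually needs it, so the layer can still be constructed on machines where flash-attn is not installed. The snippet below is a minimal, hypothetical sketch of that pattern, not the actual Phi3SmallSelfAttention code; the class name and forward body are placeholders, while is_flash_attention_available mirrors the flag used in modeling_phi3_small.py.

import importlib.util

# Mirrors the `is_flash_attention_available` flag in modeling_phi3_small.py
# (computed naively here; the real file guards the flash_attn import instead).
is_flash_attention_available = importlib.util.find_spec("flash_attn") is not None


class DenseAttentionSketch:
    """Hypothetical stand-in for Phi3SmallSelfAttention, illustrating the moved check."""

    def __init__(self, layer_idx):
        # Before this PR the assert lived here, so merely constructing a dense
        # layer failed without flash-attn, even if it was never run.
        self.layer_idx = layer_idx

    def forward(self, hidden_states):
        # After this PR the check runs only when dense attention is actually executed.
        assert is_flash_attention_available, \
            "Flash Attention is not available, but is needed for dense attention"
        # ... flash-attn kernel call would go here ...
        return hidden_states

The practical effect is that the model can still be instantiated (for example, for config inspection or weight conversion) without flash-attn present; the hard failure is deferred to the first dense-attention forward pass.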