update support for flash attn
Browse files
- modeling_qwen.py +6 -3
modeling_qwen.py
CHANGED
|
@@ -87,10 +87,13 @@ def _import_flash_attn():
|
|
| 87 |
|
| 88 |
try:
|
| 89 |
import flash_attn
|
| 90 |
-
if int(flash_attn.__version__.split(".")[0]) >= 2:
|
| 91 |
-
from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
|
| 92 |
-
else:
|
| 93 |
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
flash_attn_unpadded_func = __flash_attn_unpadded_func
|
| 95 |
except ImportError:
|
| 96 |
logger.warn(
|
|
|
|
| 87 |
|
| 88 |
try:
|
| 89 |
import flash_attn
|
| 90 |
+
if not hasattr(flash_attn, '__version__'):
|
|
|
|
|
|
|
| 91 |
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
|
| 92 |
+
else:
|
| 93 |
+
if int(flash_attn.__version__.split(".")[0]) >= 2:
|
| 94 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
|
| 95 |
+
else:
|
| 96 |
+
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
|
| 97 |
flash_attn_unpadded_func = __flash_attn_unpadded_func
|
| 98 |
except ImportError:
|
| 99 |
logger.warn(
|