Description
There are widely used models on HF that ship with custom modeling code depending on the original flash-attn package, either by calling transformers.utils.is_flash_attn_2_available() or by relying on the AutoModel.from_pretrained() parameter attn_implementation being "flash_attention_2" (for example, Jina Embeddings v4: https://huggingface.co/jinaai/jina-embeddings-v4/blob/main/qwen2_5_vl.py).
I wanted to switch to kernels-community/flash-attn2 to avoid building flash-attn myself or using pre-built wheels from https://github.com/mjun0812/flash-attention-prebuild-wheels (which publishes no wheel index, so you have to hardcode the path to a particular .whl file), and I managed to do it with a few lines of code.
Have you thought about adding documentation or an example for this case? In my experience, model authors rarely update their bundled custom modeling code to support newer tooling (e.g., Transformers v5 or kernels), so unless you maintain your own local copy of the modeling code, you're stuck with whatever the authors used when they published the model.
Here's the code that "exposes" kernels-community/flash-attn2 as the original flash-attn:
```python
from pathlib import Path
import sys

from kernels import get_kernel


def setup_flash_attn_2() -> None:
    # Load the Hub kernel and register it so that `import flash_attn`
    # resolves to it instead of failing.
    flash_attn = get_kernel("kernels-community/flash-attn2", version=1)
    sys.modules["flash_attn"] = flash_attn

    # Write a minimal dist-info so importlib.metadata reports flash-attn
    # as installed (version 2.8.3), which satisfies version/availability
    # checks such as transformers.utils.is_flash_attn_2_available().
    dist_info = Path(flash_attn.__file__).parent / "flash_attn.dist-info"
    dist_info.mkdir(parents=True, exist_ok=True)
    (dist_info / "METADATA").write_text(
        "Metadata-Version: 2.5\nName: flash-attn\nVersion: 2.8.3"
    )
    sys.path.append(str(dist_info.parent))
```

Call setup_flash_attn_2() once, before AutoModel.from_pretrained(), so the custom modeling code sees the shimmed package.
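For context, the shim above works through two standard Python mechanisms: an entry in sys.modules satisfies any later `import flash_attn`, and importlib.metadata discovers hand-written `*.dist-info` directories on sys.path, which (as far as I can tell) is what transformers' availability check ultimately consults. A minimal stdlib-only sketch of both mechanisms, using a hypothetical package name `fake_pkg` rather than flash-attn:

```python
import importlib.metadata
import sys
import types
from pathlib import Path
from tempfile import mkdtemp

# 1) sys.modules shim: any module object registered here satisfies a
#    later `import` of that name, even if no package is installed.
shim = types.ModuleType("fake_pkg")
shim.flash_attn_func = lambda *args, **kwargs: None  # stand-in attribute
sys.modules["fake_pkg"] = shim
import fake_pkg  # resolves to the shim above

# 2) dist-info on sys.path: importlib.metadata scans sys.path entries
#    for *.dist-info directories, so a hand-written METADATA file makes
#    the package look installed with the version we declare.
site = Path(mkdtemp())
dist_info = site / "fake_pkg.dist-info"
dist_info.mkdir()
(dist_info / "METADATA").write_text(
    "Metadata-Version: 2.5\nName: fake-pkg\nVersion: 9.9.9"
)
sys.path.append(str(site))

print(importlib.metadata.version("fake_pkg"))  # prints 9.9.9
```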