Compatibility with modeling code relying on flash-attn package #327

@pmalic

Description

There are widely used models on the Hugging Face Hub that ship with custom modeling code relying on the original flash-attn package, either by calling transformers.utils.is_flash_attn_2_available() or by expecting the attn_implementation parameter of AutoModel.from_pretrained() to be "flash_attention_2" (for example, Jina Embeddings v4: https://huggingface.co/jinaai/jina-embeddings-v4/blob/main/qwen2_5_vl.py).
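For context, an availability check of this kind ultimately boils down to two things: the flash_attn module must be importable, and the installed-package metadata must report a 2.x version. A simplified, stdlib-only sketch of that kind of check (not the exact transformers implementation, just the pattern it depends on):

```python
import importlib.metadata
import importlib.util


def flash_attn_2_available() -> bool:
    # The module must be importable under the name "flash_attn"...
    if importlib.util.find_spec("flash_attn") is None:
        return False
    # ...and package metadata must be discoverable, which is why the
    # fake dist-info in the workaround below is needed.
    try:
        version = importlib.metadata.version("flash_attn")
    except importlib.metadata.PackageNotFoundError:
        return False
    # Finally, the reported version must be a 2.x release.
    return version.startswith("2.")
```

This is why merely injecting a module into sys.modules is not enough on its own: the metadata lookup would still fail.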

I wanted to switch to kernels-community/flash-attn2 to avoid building flash-attn from source or pulling pre-built wheels from https://github.com/mjun0812/flash-attention-prebuild-wheels (which has no wheel index, so you have to hardcode the path to a particular .whl file), and I managed to do it with a few lines of code.

Have you thought about adding documentation or an example for this case? In my experience, model authors rarely update their bundled custom modeling code to support newer tooling (e.g., Transformers v5 or kernels), so unless you maintain your own local copy of the modeling code, you're stuck with whatever the authors used when they published the model.

Here's the code that "exposes" kernels-community/flash-attn2 as the original flash-attn:

from pathlib import Path
import sys

from kernels import get_kernel


def setup_flash_attn_2() -> None:
    # Load the Hub kernel and make it importable under the module name
    # the custom modeling code expects.
    flash_attn = get_kernel("kernels-community/flash-attn2", version=1)
    sys.modules["flash_attn"] = flash_attn

    # Fake an installed distribution so importlib.metadata (and thus
    # transformers.utils.is_flash_attn_2_available()) reports flash-attn
    # as present with a 2.x version.
    dist_info = Path(flash_attn.__file__).parent / "flash_attn.dist-info"
    dist_info.mkdir(parents=True, exist_ok=True)

    (dist_info / "METADATA").write_text(
        "Metadata-Version: 2.5\nName: flash-attn\nVersion: 2.8.3"
    )

    # The dist-info's parent must be on sys.path for metadata discovery.
    sys.path.append(str(dist_info.parent))
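To see why the fake dist-info satisfies the metadata check, here is a self-contained demonstration of just that part, using a temporary directory instead of the kernel's install path: once a *.dist-info directory containing a METADATA file sits under an entry on sys.path, importlib.metadata treats the package as installed.

```python
import importlib.metadata
import sys
import tempfile
from pathlib import Path

# Stand-in for the kernel's install directory.
site = Path(tempfile.mkdtemp())
dist_info = site / "flash_attn.dist-info"
dist_info.mkdir()
(dist_info / "METADATA").write_text(
    "Metadata-Version: 2.5\nName: flash-attn\nVersion: 2.8.3"
)

# Make the directory containing the dist-info discoverable.
sys.path.append(str(site))

# importlib.metadata now resolves the package from the fake metadata.
print(importlib.metadata.version("flash_attn"))
```

With this in place, call setup_flash_attn_2() once before AutoModel.from_pretrained(..., attn_implementation="flash_attention_2") so both the module import and the availability check succeed.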
