-
Notifications
You must be signed in to change notification settings - Fork 282
Add 2:4 Sparse Attention #916
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,7 @@ | |
| from modelopt.torch.sparsity.attention_sparsity.config import ( | ||
| SKIP_SOFTMAX_CALIB, | ||
| SKIP_SOFTMAX_DEFAULT, | ||
| SPARSE24_TRITON, | ||
| ) | ||
| from modelopt.torch.utils.memory_monitor import launch_memory_monitor | ||
|
|
||
|
|
@@ -43,6 +44,7 @@ | |
| SPARSE_ATTN_CFG_CHOICES = { | ||
| "skip_softmax": SKIP_SOFTMAX_DEFAULT, | ||
| "skip_softmax_calib": SKIP_SOFTMAX_CALIB, | ||
| "sparse24_triton": SPARSE24_TRITON, | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -144,12 +146,14 @@ def main(args): | |
|
|
||
| print(f"Loading model: {args.pyt_ckpt_path}") | ||
|
|
||
| # Load model and tokenizer | ||
| # Note: attn_implementation="eager" is required for calibration to work properly | ||
| # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection) | ||
| # Select attn_implementation based on sparse method: | ||
| # - skip_softmax methods require "eager" (softmax patching bypassed by flash/sdpa) | ||
| # - sparse24_triton requires "modelopt_triton" (fused Triton kernel) | ||
| # No need to specify attn_implementation here — mtsa.sparsify() handles it | ||
| # automatically based on the sparse config (sets "modelopt_triton" for triton | ||
| # backend, keeps "eager" for pytorch backend). | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| args.pyt_ckpt_path, | ||
| attn_implementation="eager", | ||
| torch_dtype=torch.bfloat16, | ||
| ) | ||
|
Comment on lines
+149
to
158
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Before/after comparison uses different attention backends for the "before" run. Consider documenting this limitation in the comment block at lines 149-154, or conditionally set the attention implementation. 🤖 Prompt for AI Agents |
||
| tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path) | ||
|
|
@@ -246,8 +250,8 @@ def main(args): | |
| "--backend", | ||
| type=str, | ||
| default="pytorch", | ||
| choices=["pytorch"], | ||
| help="Backend for sparse attention (default: pytorch). More backends coming soon.", | ||
| choices=["pytorch", "triton"], | ||
| help="Backend for sparse attention (default: pytorch). Use 'triton' with sparse24_triton.", | ||
| ) | ||
|
|
||
| # Sequence length arguments | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -72,8 +72,8 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): | |||||
| title="Backend implementation.", | ||||||
| description=( | ||||||
| "Backend to use for sparse attention computation. " | ||||||
| "Only 'pytorch' is supported, which uses softmax patching with F.softmax. " | ||||||
| "Requires model to be loaded with attn_implementation='eager'." | ||||||
| "'pytorch' uses softmax patching with F.softmax (requires attn_implementation='eager'). " | ||||||
| "'triton' uses the fused Triton kernel (requires attn_implementation='modelopt_triton')." | ||||||
| ), | ||||||
| ) | ||||||
|
|
||||||
|
|
@@ -89,10 +89,20 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): | |||||
| description=( | ||||||
| "Whether the model uses causal (autoregressive) attention. " | ||||||
| "If True, sparsity statistics are calculated over the lower triangle only. " | ||||||
| "Set to False for cross-attention models. " | ||||||
| "Defaults to True for decoder-only models like GPT, LLaMA, etc." | ||||||
| ), | ||||||
| ) | ||||||
|
|
||||||
| skip_diagonal_blocks: bool = ModeloptField( | ||||||
| default=True, | ||||||
| title="Skip diagonal blocks.", | ||||||
| description=( | ||||||
| "When True, keep diagonal tiles dense for 2:4 sparse attention. " | ||||||
| "Only used by sparse24_triton method. Defaults to True." | ||||||
| ), | ||||||
| ) | ||||||
|
|
||||||
| @field_validator("method") | ||||||
| @classmethod | ||||||
| def validate_method(cls, v): | ||||||
|
|
@@ -104,11 +114,12 @@ def validate_method(cls, v): | |||||
| @field_validator("backend") | ||||||
| @classmethod | ||||||
| def validate_backend(cls, v): | ||||||
| """Validate backend is pytorch.""" | ||||||
| if v != "pytorch": | ||||||
| """Validate backend is pytorch or triton.""" | ||||||
| if v not in ("pytorch", "triton"): | ||||||
| raise ValueError( | ||||||
| f"Invalid backend: {v}. Only 'pytorch' backend is supported. " | ||||||
| f"Model must be loaded with attn_implementation='eager'." | ||||||
| f"Invalid backend: {v}. Supported backends: 'pytorch' (requires " | ||||||
| f"attn_implementation='eager'), 'triton' (requires " | ||||||
| f"attn_implementation='modelopt_triton')." | ||||||
| ) | ||||||
| return v | ||||||
|
|
||||||
|
|
@@ -416,10 +427,24 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): | |||||
| }, | ||||||
| } | ||||||
|
|
||||||
| # 2:4 structured sparsity via Triton prefill kernel (prefill-only) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Comment says "prefill-only" but the kernel supports both prefill and decode. The PR description explicitly states the unified Triton kernel supports both prefill (2D kernel) and decode (3D kernel) paths with paged KV cache. The comment at line 429 is inaccurate and should be corrected to avoid misleading users. 📝 Proposed fix-# 2:4 structured sparsity via Triton prefill kernel (prefill-only)
+# 2:4 structured sparsity via Triton unified attention kernel (prefill + decode)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||
| SPARSE24_TRITON = { | ||||||
| "sparse_cfg": { | ||||||
| "*attn*": { | ||||||
| "method": "sparse24_triton", | ||||||
| "backend": "triton", | ||||||
| "skip_diagonal_blocks": True, | ||||||
| "enable": True, | ||||||
| }, | ||||||
| "default": {"enable": False}, | ||||||
| }, | ||||||
| } | ||||||
|
|
||||||
|
|
||||||
| __all__ = [ | ||||||
| "SKIP_SOFTMAX_CALIB", | ||||||
| "SKIP_SOFTMAX_DEFAULT", | ||||||
| "SPARSE24_TRITON", | ||||||
| "CalibrationConfig", | ||||||
| "FlashSkipSoftmaxConfig", | ||||||
| "SparseAttentionAttributeConfig", | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Triton attention kernels for sparse attention optimization.""" | ||
|
|
||
| import torch | ||
|
|
||
| from modelopt.torch.utils import import_plugin | ||
|
|
||
| IS_AVAILABLE = False | ||
| context_attention_fwd = None | ||
| register_triton_attention = None | ||
| set_sparse24 = None | ||
| unified_attention = None | ||
|
|
||
| if torch.cuda.is_available(): | ||
| with import_plugin( | ||
| "triton", | ||
| msg_if_missing=( | ||
| "Your device is potentially capable of using the triton attention " | ||
| "kernel. Try to install triton with `pip install triton`." | ||
| ), | ||
| ): | ||
| from .triton_unified_attention import context_attention_fwd as _context_attention_fwd | ||
| from .triton_unified_attention import unified_attention as _unified_attention | ||
|
|
||
| context_attention_fwd = _context_attention_fwd | ||
| unified_attention = _unified_attention | ||
| IS_AVAILABLE = True | ||
| with import_plugin("transformers"): | ||
| from .hf_triton_attention import register_triton_attention as _register_triton_attention | ||
| from .hf_triton_attention import set_sparse24 as _set_sparse24 | ||
|
|
||
| register_triton_attention = _register_triton_attention | ||
| set_sparse24 = _set_sparse24 | ||
| _register_triton_attention() | ||
|
|
||
| __all__ = [ | ||
| "IS_AVAILABLE", | ||
| "context_attention_fwd", | ||
| "register_triton_attention", | ||
| "set_sparse24", | ||
| "unified_attention", | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Do we need to update anything in example readme or changelog?