From be6b1991f1bbed762102dab063ced595acb7bab9 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 22 Jun 2026 17:15:23 +0800 Subject: [PATCH 1/3] patch dsa --- src/mcore_bridge/model/modules/__init__.py | 1 + src/mcore_bridge/model/modules/dsa_indexer.py | 154 ++++++++++++++++++ src/mcore_bridge/model/register.py | 5 +- src/mcore_bridge/patcher.py | 154 ------------------ 4 files changed, 159 insertions(+), 155 deletions(-) create mode 100644 src/mcore_bridge/model/modules/dsa_indexer.py diff --git a/src/mcore_bridge/model/modules/__init__.py b/src/mcore_bridge/model/modules/__init__.py index 1a9a5e2..87b7bc7 100644 --- a/src/mcore_bridge/model/modules/__init__.py +++ b/src/mcore_bridge/model/modules/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from .dsa_indexer import DSAIndexer from .gated_delta_net import GatedDeltaNet from .gated_self_attention import GatedSelfAttention from .mtp_layer import MultiTokenPredictionLayer diff --git a/src/mcore_bridge/model/modules/dsa_indexer.py b/src/mcore_bridge/model/modules/dsa_indexer.py new file mode 100644 index 0000000..2fa917e --- /dev/null +++ b/src/mcore_bridge/model/modules/dsa_indexer.py @@ -0,0 +1,154 @@ + +import torch +from typing import Optional +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region +from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb + +from typing import Tuple + +try: + from megatron.core.models.gpt.experimental_attention_variant_module_specs import DSAIndexer as McoreDSAIndexer +except ImportError: + McoreDSAIndexer = None + +class DSAIndexer(McoreDSAIndexer): + + def forward_before_topk( + self, + x: torch.Tensor, + qr: torch.Tensor, + packed_seq_params: Optional[PackedSeqParams] = None, + ): + """All computations before topk.""" + from megatron.core.transformer.experimental_attention_variant.dsa import rotate_activation + + # ========================================= + # Gather inputs if sp is enabled + # ========================================= + packed_seq_params, rotary_pos_emb = packed_seq_params # patch + assert packed_seq_params is None, 'Packed sequence is not supported for DSAttention' + + if self.config.sequence_parallel and self.pg_collection.tp.size() > 1: + x = gather_from_sequence_parallel_region(x, group=self.pg_collection.tp) + qr = gather_from_sequence_parallel_region(qr, group=self.pg_collection.tp) + + # ========================================= + # Get sequence length and batch size + # ========================================= + seqlen, bsz, _ = x.size() + + # ========================================= + # q linear and apply rope to q + # ========================================= + # [seqlen, batch, q_lora_rank] -> [seqlen, batch, index_n_heads * index_head_dim] + q, _ = self.linear_wq_b(qr) + # [seqlen, batch, index_n_heads * index_head_dim] + # -> [seqlen, batch, index_n_heads, index_head_dim] + q = q.reshape(seqlen, bsz, self.index_n_heads, self.index_head_dim) + q = self._apply_rope(q, rotary_pos_emb) # mscale will be passed in by patch + + # ========================================= + # k linear and apply rope to k + # ========================================= + # [seqlen, batch, hidden_size] -> [seqlen, batch, index_head_dim] + k, _ = self.linear_wk(x) + k = self.k_norm(k) + # [seqlen, batch, index_head_dim] -> [seqlen, batch, 1, index_head_dim] + k = k.reshape(seqlen, bsz, 1, self.index_head_dim) + k = self._apply_rope(k, rotary_pos_emb) + # [seqlen, batch, 1, index_head_dim] -> [seqlen, batch, index_head_dim] + k = k.reshape(seqlen, bsz, self.index_head_dim) + + # ========================================= + # Rotate activation + # ========================================= + q = rotate_activation(q) + k = rotate_activation(k) + + # ========================================= + # Prepare weights for index scores + # ========================================= + # [seqlen, batch, hidden_size] -> [seqlen, batch, index_n_heads] + weights, _ = self.linear_weights_proj(x) + weights = weights * (self.index_n_heads**-0.5) * self.softmax_scale + + return q, k, weights + + def _apply_rope(self, x: torch.Tensor, rotary_pos_emb: torch.Tensor): + """Apply RoPE to the input tensor.""" + # x_nope [seqlen, batch, *, index_head_dim - qk_pos_emb_head_dim] + # x_pe [seqlen, batch, *, qk_pos_emb_head_dim] + x_pe, x_nope = torch.split( + x, [self.index_head_dim - self.qk_pos_emb_head_dim, self.qk_pos_emb_head_dim], dim=-1) + origin_multi_latent_attention = self.config.multi_latent_attention + try: + self.config.multi_latent_attention = self.config.dsa_indexer_rotary_interleaved + x_pe = apply_rotary_pos_emb( + x_pe, + rotary_pos_emb, + config=self.config, + cu_seqlens=None, + cp_group=self.pg_collection.cp, + ) + finally: + self.config.multi_latent_attention = origin_multi_latent_attention + # [seqlen, batch, *, index_head_dim] + x = torch.cat([x_pe, x_nope], dim=-1) + return x + + def forward_with_scores( + self, + x: torch.Tensor, + qr: torch.Tensor, + mask: Optional[torch.Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for DSA Indexer that returns both index scores and top-k indices. + + This is used when KL loss is enabled to compare indexer scores with true attention scores. + + Args: + x: hidden states [seqlen, batch, hidden_size]. + qr: Low-rank query tensor [seqlen, batch, q_lora_rank]. + mask: Attention mask [batch, seqlen, seqlen]. + packed_seq_params: Packed sequence parameters for variable length sequences. + + Returns: + index_scores: Index scores [batch, seqlen, seqlen]. + topk_indices: Top-k indices [batch, seqlen, index_topk]. + """ + try: + from megatron.core.transformer.experimental_attention_variant.dsa import fused_qk_topk_naive + except ImportError: + raise ImportError('fused_qk_topk_naive is not available. Please install "megatron-core>=0.17.0"') + # [seqlen, batch, index_n_heads * index_head_dim] + # [seqlen, batch, index_head_dim] + # [seqlen, batch, index_n_heads] + q, k, weights = self.forward_before_topk(x, qr, packed_seq_params) + + # [batch, seqlen, seqlen], [batch, seqlen, index_topk] + index_scores, topk_indices = fused_qk_topk_naive(q, k, weights, self.index_topk, mask) + + return index_scores, topk_indices + + def forward(self, + x: torch.Tensor, + qr: torch.Tensor, + mask: Optional[torch.Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None): + """ + Forward pass for DSA Indexer. + + Args: + x: hidden states [seqlen, batch, hidden_size]. + qr: Low-rank query tensor [seqlen, batch, q_lora_rank]. + mask: Attention mask [batch, seqlen, seqlen]. + packed_seq_params: Packed sequence parameters for variable length sequences. + + Returns: + topk_indices: Top-k indices for sparse attention [batch, seqlen, index_topk]. + """ + _, topk_indices = self.forward_with_scores(x, qr, mask, packed_seq_params) + return topk_indices diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index be81845..c6a0aaa 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -20,7 +20,8 @@ from mcore_bridge.config import ModelConfig from mcore_bridge.utils import get_logger -from .modules import MLASelfAttention, MultiTokenPredictionLayer, TopKRouter, TransformerBlock, TransformerLayer +from .modules import (DSAIndexer, MLASelfAttention, MultiTokenPredictionLayer, TopKRouter, TransformerBlock, + TransformerLayer) if TYPE_CHECKING: from .gpt_model import GPTModel @@ -96,6 +97,8 @@ def _replace_spec_dsa(self, layer_spec): _get_backend_spec_provider, get_dsa_module_spec_for_backend) backend = _get_backend_spec_provider(config=self.config) dsa_spec = get_dsa_module_spec_for_backend(self.config, backend) + if getattr(dsa_spec.submodules.core_attention.submodules, 'indexer', None) is not None: + dsa_spec.submodules.core_attention.submodules.indexer = DSAIndexer if self.config.qk_layernorm: linear_q_up_proj = backend.column_parallel_linear() # fix megatron-core diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py index 4779cd0..9806d38 100644 --- a/src/mcore_bridge/patcher.py +++ b/src/mcore_bridge/patcher.py @@ -6,10 +6,7 @@ from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELinear from megatron.core.models.common.embeddings import rope_utils -from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb from megatron.core.models.common.embeddings.rotary_pos_embedding import MultimodalRotaryEmbedding -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import TransformerConfig from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionBlock, get_mtp_layer_offset from packaging import version @@ -235,153 +232,6 @@ def apply_rotary_pos_emb( rope_utils.apply_rotary_pos_emb = apply_rotary_pos_emb -def _patch_dsa(): - from megatron.core.models.gpt import experimental_attention_variant_module_specs - from megatron.core.transformer.experimental_attention_variant.dsa import rotate_activation - _DSAIndexer = experimental_attention_variant_module_specs.DSAIndexer - - class DSAIndexer(_DSAIndexer): - - def forward_before_topk( - self, - x: torch.Tensor, - qr: torch.Tensor, - packed_seq_params: Optional[PackedSeqParams] = None, - ): - """All computations before topk.""" - # ========================================= - # Gather inputs if sp is enabled - # ========================================= - packed_seq_params, rotary_pos_emb = packed_seq_params # patch - assert packed_seq_params is None, 'Packed sequence is not supported for DSAttention' - - if self.config.sequence_parallel and self.pg_collection.tp.size() > 1: - x = gather_from_sequence_parallel_region(x, group=self.pg_collection.tp) - qr = gather_from_sequence_parallel_region(qr, group=self.pg_collection.tp) - - # ========================================= - # Get sequence length and batch size - # ========================================= - seqlen, bsz, _ = x.size() - - # ========================================= - # q linear and apply rope to q - # ========================================= - # [seqlen, batch, q_lora_rank] -> [seqlen, batch, index_n_heads * index_head_dim] - q, _ = self.linear_wq_b(qr) - # [seqlen, batch, index_n_heads * index_head_dim] - # -> [seqlen, batch, index_n_heads, index_head_dim] - q = q.reshape(seqlen, bsz, self.index_n_heads, self.index_head_dim) - q = self._apply_rope(q, rotary_pos_emb) # mscale will be passed in by patch - - # ========================================= - # k linear and apply rope to k - # ========================================= - # [seqlen, batch, hidden_size] -> [seqlen, batch, index_head_dim] - k, _ = self.linear_wk(x) - k = self.k_norm(k) - # [seqlen, batch, index_head_dim] -> [seqlen, batch, 1, index_head_dim] - k = k.reshape(seqlen, bsz, 1, self.index_head_dim) - k = self._apply_rope(k, rotary_pos_emb) - # [seqlen, batch, 1, index_head_dim] -> [seqlen, batch, index_head_dim] - k = k.reshape(seqlen, bsz, self.index_head_dim) - - # ========================================= - # Rotate activation - # ========================================= - q = rotate_activation(q) - k = rotate_activation(k) - - # ========================================= - # Prepare weights for index scores - # ========================================= - # [seqlen, batch, hidden_size] -> [seqlen, batch, index_n_heads] - weights, _ = self.linear_weights_proj(x) - weights = weights * (self.index_n_heads**-0.5) * self.softmax_scale - - return q, k, weights - - def _apply_rope(self, x: torch.Tensor, rotary_pos_emb: torch.Tensor): - """Apply RoPE to the input tensor.""" - # x_nope [seqlen, batch, *, index_head_dim - qk_pos_emb_head_dim] - # x_pe [seqlen, batch, *, qk_pos_emb_head_dim] - x_pe, x_nope = torch.split( - x, [self.index_head_dim - self.qk_pos_emb_head_dim, self.qk_pos_emb_head_dim], dim=-1) - origin_multi_latent_attention = self.config.multi_latent_attention - try: - self.config.multi_latent_attention = self.config.dsa_indexer_rotary_interleaved - x_pe = apply_rotary_pos_emb( - x_pe, - rotary_pos_emb, - config=self.config, - cu_seqlens=None, - cp_group=self.pg_collection.cp, - ) - finally: - self.config.multi_latent_attention = origin_multi_latent_attention - # [seqlen, batch, *, index_head_dim] - x = torch.cat([x_pe, x_nope], dim=-1) - return x - - def forward_with_scores( - self, - x: torch.Tensor, - qr: torch.Tensor, - mask: Optional[torch.Tensor] = None, - packed_seq_params: Optional[PackedSeqParams] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Forward pass for DSA Indexer that returns both index scores and top-k indices. - - This is used when KL loss is enabled to compare indexer scores with true attention scores. - - Args: - x: hidden states [seqlen, batch, hidden_size]. - qr: Low-rank query tensor [seqlen, batch, q_lora_rank]. - mask: Attention mask [batch, seqlen, seqlen]. - packed_seq_params: Packed sequence parameters for variable length sequences. - - Returns: - index_scores: Index scores [batch, seqlen, seqlen]. - topk_indices: Top-k indices [batch, seqlen, index_topk]. - """ - try: - from megatron.core.transformer.experimental_attention_variant.dsa import fused_qk_topk_naive - except ImportError: - raise ImportError('fused_qk_topk_naive is not available. Please install "megatron-core>=0.17.0"') - # [seqlen, batch, index_n_heads * index_head_dim] - # [seqlen, batch, index_head_dim] - # [seqlen, batch, index_n_heads] - q, k, weights = self.forward_before_topk(x, qr, packed_seq_params) - - # [batch, seqlen, seqlen], [batch, seqlen, index_topk] - index_scores, topk_indices = fused_qk_topk_naive(q, k, weights, self.index_topk, mask) - - return index_scores, topk_indices - - def forward(self, - x: torch.Tensor, - qr: torch.Tensor, - mask: Optional[torch.Tensor] = None, - packed_seq_params: Optional[PackedSeqParams] = None): - """ - Forward pass for DSA Indexer. - - Args: - x: hidden states [seqlen, batch, hidden_size]. - qr: Low-rank query tensor [seqlen, batch, q_lora_rank]. - mask: Attention mask [batch, seqlen, seqlen]. - packed_seq_params: Packed sequence parameters for variable length sequences. - - Returns: - topk_indices: Top-k indices for sparse attention [batch, seqlen, index_topk]. - """ - _, topk_indices = self.forward_with_scores(x, qr, mask, packed_seq_params) - return topk_indices - - experimental_attention_variant_module_specs.DSAIndexer = DSAIndexer - - def _patch_mtp(): def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, hidden_states: torch.Tensor, @@ -444,7 +294,3 @@ def apply_patch(): _patch_mrope() _patch_mtp() from mcore_bridge import tuners # apply patch - try: - _patch_dsa() - except ImportError: - pass From c4ca01040d2da8c8dcc546ad709c85ddccc2311e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 22 Jun 2026 17:19:22 +0800 Subject: [PATCH 2/3] fix --- src/mcore_bridge/model/modules/dsa_indexer.py | 8 +++----- src/mcore_bridge/model/register.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/mcore_bridge/model/modules/dsa_indexer.py b/src/mcore_bridge/model/modules/dsa_indexer.py index 2fa917e..b1ecb81 100644 --- a/src/mcore_bridge/model/modules/dsa_indexer.py +++ b/src/mcore_bridge/model/modules/dsa_indexer.py @@ -1,17 +1,15 @@ - import torch -from typing import Optional +from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region -from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb - -from typing import Tuple +from typing import Optional, Tuple try: from megatron.core.models.gpt.experimental_attention_variant_module_specs import DSAIndexer as McoreDSAIndexer except ImportError: McoreDSAIndexer = None + class DSAIndexer(McoreDSAIndexer): def forward_before_topk( diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index c6a0aaa..5f39664 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -98,7 +98,7 @@ def _replace_spec_dsa(self, layer_spec): backend = _get_backend_spec_provider(config=self.config) dsa_spec = get_dsa_module_spec_for_backend(self.config, backend) if getattr(dsa_spec.submodules.core_attention.submodules, 'indexer', None) is not None: - dsa_spec.submodules.core_attention.submodules.indexer = DSAIndexer + dsa_spec.submodules.core_attention.submodules.indexer.module = DSAIndexer if self.config.qk_layernorm: linear_q_up_proj = backend.column_parallel_linear() # fix megatron-core From 5f583984285045547d7b5104e98a435ff1172072 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 22 Jun 2026 19:42:35 +0800 Subject: [PATCH 3/3] update --- src/mcore_bridge/model/modules/dsa_indexer.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/mcore_bridge/model/modules/dsa_indexer.py b/src/mcore_bridge/model/modules/dsa_indexer.py index b1ecb81..e148a53 100644 --- a/src/mcore_bridge/model/modules/dsa_indexer.py +++ b/src/mcore_bridge/model/modules/dsa_indexer.py @@ -1,7 +1,9 @@ import torch +import transformer_engine from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region +from megatron.core.transformer.spec_utils import build_module from typing import Optional, Tuple try: @@ -12,6 +14,22 @@ class DSAIndexer(McoreDSAIndexer): + def __init__(self, config, submodules, *args, **kwargs): + super().__init__(config, submodules, *args, **kwargs) + if getattr(config, 'fp8_param', False): + with transformer_engine.pytorch.fp8_model_init(enabled=False): + self.linear_weights_proj = build_module( + submodules.linear_weights_proj, + self.hidden_size, + self.index_n_heads, + config=self.config, + init_method=self.config.init_method, + bias=False, + skip_bias_add=False, + skip_weight_param_allocation=False, + parallel_mode='duplicated', + ) + def forward_before_topk( self, x: torch.Tensor,