From fda8f39d5366ce32a10bf7838c48f83d3f9b471e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 22 Jun 2026 20:30:21 +0800 Subject: [PATCH 1/7] fix deepseek fp8 --- src/mcore_bridge/model/gpts/deepseek_v4.py | 7 +++++++ src/mcore_bridge/model/modules/__init__.py | 1 + src/mcore_bridge/model/modules/compressor.py | 5 +++++ 3 files changed, 13 insertions(+) create mode 100644 src/mcore_bridge/model/modules/compressor.py diff --git a/src/mcore_bridge/model/gpts/deepseek_v4.py b/src/mcore_bridge/model/gpts/deepseek_v4.py index 65c9270..73279ec 100644 --- a/src/mcore_bridge/model/gpts/deepseek_v4.py +++ b/src/mcore_bridge/model/gpts/deepseek_v4.py @@ -367,6 +367,13 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): transformer_layer_spec = get_transformer_block_with_experimental_attention_variant_spec(self.config, vp_stage) for layer_spec in transformer_layer_spec.layer_specs: layer_spec.submodules.self_attention.module = DSv4HybridSelfAttention + for layer_spec in transformer_layer_spec.layer_specs: + self._set_mlp_spec(layer_spec.submodules, DSv4HybridMLP) + core_attention_submodules = layer_spec.submodules.self_attention.submodules.core_attention.submodules + if getattr(core_attention_submodules, 'compressor') is not None: + core_attention_submodules.compressor.module = Compressor + if getattr(core_attention_submodules, 'indexer') is not None: + core_attention_submodules.compressor.indexer.module = Compressor return transformer_layer_spec diff --git a/src/mcore_bridge/model/modules/__init__.py b/src/mcore_bridge/model/modules/__init__.py index 87b7bc7..6cb056d 100644 --- a/src/mcore_bridge/model/modules/__init__.py +++ b/src/mcore_bridge/model/modules/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .dsa_indexer import DSAIndexer +# from .compressor import Compressor from .gated_delta_net import GatedDeltaNet from .gated_self_attention import GatedSelfAttention from .mtp_layer import MultiTokenPredictionLayer diff --git a/src/mcore_bridge/model/modules/compressor.py b/src/mcore_bridge/model/modules/compressor.py new file mode 100644 index 0000000..3f2ff2d --- /dev/null +++ b/src/mcore_bridge/model/modules/compressor.py @@ -0,0 +1,5 @@ + + + + + From 4be1679198b445d1725863ba5b742694aecf77a1 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 22 Jun 2026 20:37:41 +0800 Subject: [PATCH 2/7] update --- src/mcore_bridge/model/gpts/deepseek_v4.py | 4 +- src/mcore_bridge/model/modules/__init__.py | 2 +- src/mcore_bridge/model/modules/compressor.py | 53 ++++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/src/mcore_bridge/model/gpts/deepseek_v4.py b/src/mcore_bridge/model/gpts/deepseek_v4.py index 73279ec..019fa9d 100644 --- a/src/mcore_bridge/model/gpts/deepseek_v4.py +++ b/src/mcore_bridge/model/gpts/deepseek_v4.py @@ -12,6 +12,7 @@ from ..constant import ModelType from ..gpt_model import GPTModel +from ..modules.compressor import Compressor, CSAIndexer from ..register import ModelLoader, ModelMeta, register_model from ..rope import get_rope_inv_freq @@ -373,7 +374,8 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): if getattr(core_attention_submodules, 'compressor') is not None: core_attention_submodules.compressor.module = Compressor if getattr(core_attention_submodules, 'indexer') is not None: - core_attention_submodules.compressor.indexer.module = Compressor + core_attention_submodules.indexer.module = CSAIndexer + core_attention_submodules.indexer.compressor.module = Compressor return transformer_layer_spec diff --git a/src/mcore_bridge/model/modules/__init__.py b/src/mcore_bridge/model/modules/__init__.py index 6cb056d..edf406d 100644 --- a/src/mcore_bridge/model/modules/__init__.py +++ b/src/mcore_bridge/model/modules/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from .compressor import Compressor, CSAIndexer from .dsa_indexer import DSAIndexer -# from .compressor import Compressor from .gated_delta_net import GatedDeltaNet from .gated_self_attention import GatedSelfAttention from .mtp_layer import MultiTokenPredictionLayer diff --git a/src/mcore_bridge/model/modules/compressor.py b/src/mcore_bridge/model/modules/compressor.py index 3f2ff2d..897a11c 100644 --- a/src/mcore_bridge/model/modules/compressor.py +++ b/src/mcore_bridge/model/modules/compressor.py @@ -1,5 +1,58 @@ +import transformer_engine +from megatron.core.transformer.spec_utils import build_module +try: + from megatron.core.transformer.experimental_attention_variant.csa import Compressor as McoreCompressor + from megatron.core.transformer.experimental_attention_variant.csa import CSAIndexer as McoreCSAIndexer +except ImportError: + McoreCompressor = object + McoreCSAIndexer = object +class Compressor(McoreCompressor): + def __init__(self, config, submodules, *args, **kwargs): + super().__init__(config, submodules, *args, **kwargs) + if getattr(config, 'fp8_param', False): + with transformer_engine.pytorch.fp8_model_init(enabled=False): + self.linear_wkv = build_module( + submodules.linear_wkv, + config.hidden_size, + self.coff * self.head_dim, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + skip_weight_param_allocation=False, + parallel_mode='duplicated', + ) + self.linear_wgate = build_module( + submodules.linear_wgate, + config.hidden_size, + self.coff * self.head_dim, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + skip_weight_param_allocation=False, + parallel_mode='duplicated', + ) + +class CSAIndexer(McoreCSAIndexer): + + def __init__(self, config, submodules, *args, **kwargs): + super().__init__(config, submodules, *args, **kwargs) + if getattr(config, 'fp8_param', False): + with transformer_engine.pytorch.fp8_model_init(enabled=False): + self.linear_weights_proj = build_module( + submodules.linear_weights_proj, + self.hidden_size, + self.index_n_heads, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + skip_weight_param_allocation=False, + parallel_mode='duplicated', + ) From 023f141377d8b5a416d21a06f34aa772f0cb3e0c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 22 Jun 2026 20:46:02 +0800 Subject: [PATCH 3/7] fix --- src/mcore_bridge/model/gpts/deepseek_v4.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/mcore_bridge/model/gpts/deepseek_v4.py b/src/mcore_bridge/model/gpts/deepseek_v4.py index 019fa9d..03837f1 100644 --- a/src/mcore_bridge/model/gpts/deepseek_v4.py +++ b/src/mcore_bridge/model/gpts/deepseek_v4.py @@ -369,13 +369,12 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): for layer_spec in transformer_layer_spec.layer_specs: layer_spec.submodules.self_attention.module = DSv4HybridSelfAttention for layer_spec in transformer_layer_spec.layer_specs: - self._set_mlp_spec(layer_spec.submodules, DSv4HybridMLP) core_attention_submodules = layer_spec.submodules.self_attention.submodules.core_attention.submodules if getattr(core_attention_submodules, 'compressor') is not None: core_attention_submodules.compressor.module = Compressor if getattr(core_attention_submodules, 'indexer') is not None: core_attention_submodules.indexer.module = CSAIndexer - core_attention_submodules.indexer.compressor.module = Compressor + core_attention_submodules.indexer.submodules.compressor.module = Compressor return transformer_layer_spec From 1b145d56acfd6a2f358c34d8bee163fdcff1c096 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 23 Jun 2026 11:26:02 +0800 Subject: [PATCH 4/7] update --- src/mcore_bridge/model/gpts/deepseek_v4.py | 70 ++++++++++++++++++++-- 1 file changed, 65 insertions(+), 5 deletions(-) diff --git a/src/mcore_bridge/model/gpts/deepseek_v4.py b/src/mcore_bridge/model/gpts/deepseek_v4.py index 03837f1..d7e32b9 100644 --- a/src/mcore_bridge/model/gpts/deepseek_v4.py +++ b/src/mcore_bridge/model/gpts/deepseek_v4.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import copy import torch +import transformer_engine.pytorch as te from contextlib import contextmanager from megatron.core import tensor_parallel from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb @@ -69,6 +70,18 @@ def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) self.layer_type = self.config.hf_config.layer_types[self.layer_number - 1] self.rope_layer_type = 'main' if self.layer_type == 'sliding_attention' else 'compress' + if getattr(config, 'fp8_param', False): + group_proj_in_size = self.query_projection_size // config.o_groups + del self.linear_o_group_proj + self.linear_o_group_proj = te.GroupedLinear( + num_gemms=config.o_groups, + in_features=group_proj_in_size, + out_features=config.o_lora_rank, + bias=False, + ) + self._o_group_proj_is_grouped_linear = True + else: + self._o_group_proj_is_grouped_linear = False def get_query_key_value_tensors( self, @@ -313,10 +326,23 @@ def forward( core_attn_out = core_attn_out.view(seq_len, core_attn_out.size(1), -1) # Grouped output - core_attn_out = core_attn_out.view(core_attn_out.size(0), core_attn_out.size(1), self.o_local_groups, -1) - wo_a_weight = self.linear_o_group_proj.view(self.o_local_groups, self.config.o_lora_rank, -1) - core_attn_out = torch.einsum('...gd,grd->...gr', core_attn_out, wo_a_weight) - core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1) + if self._o_group_proj_is_grouped_linear: + s, b = core_attn_out.size(0), core_attn_out.size(1) + # [s, b, G*D] -> [G, s*b, D] -> [G*s*b, D] + core_attn_out = core_attn_out.view(s, b, self.o_local_groups, -1) + core_attn_out = core_attn_out.permute(2, 0, 1, 3).contiguous() + core_attn_out = core_attn_out.reshape(-1, core_attn_out.size(-1)) + m_splits = [s * b] * self.o_local_groups + core_attn_out, _ = self.linear_o_group_proj(core_attn_out, m_splits) + # [G*s*b, R] -> [G, s, b, R] -> [s, b, G*R] + core_attn_out = core_attn_out.view(self.o_local_groups, s, b, -1) + core_attn_out = core_attn_out.permute(1, 2, 0, 3).contiguous() + core_attn_out = core_attn_out.reshape(s, b, -1) + else: + core_attn_out = core_attn_out.view(core_attn_out.size(0), core_attn_out.size(1), self.o_local_groups, -1) + wo_a_weight = self.linear_o_group_proj.view(self.o_local_groups, self.config.o_lora_rank, -1) + core_attn_out = torch.einsum('...gd,grd->...gr', core_attn_out, wo_a_weight) + core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1) # ================= # Output. [sq, b, h] @@ -389,6 +415,37 @@ class DeepseekV4Bridge(GPTBridge): hf_post_attention_layernorm_key = 'ffn_norm.weight' hf_expert_bias_key = 'gate.bias' + def _set_o_group_proj_grouped(self, mg_attn, hf_state_dict, to_mcore): + """Handle GroupedLinear state dict for linear_o_group_proj in fp8 mode. + + HF stores a single wo_a.weight of shape [G*R, D]. + GroupedLinear stores per-gemm weight{i} each of shape [R, D]. + """ + o_groups = self.config.o_groups + if to_mcore: + hf_weight = hf_state_dict['wo_a.weight'].load() + hf_scale_inv = None + if 'wo_a.weight_scale_inv' in hf_state_dict: + hf_scale_inv = hf_state_dict['wo_a.weight_scale_inv'].load() + weights = hf_weight.chunk(o_groups, dim=0) + scale_invs = hf_scale_inv.chunk(o_groups, dim=0) if hf_scale_inv is not None else [None] * o_groups + for i, (w, s) in enumerate(zip(weights, scale_invs)): + param = getattr(mg_attn.linear_o_group_proj, f'weight{i}') + self._set_param(param, w, s) + else: + weights = [] + scale_invs = [] + for i in range(o_groups): + param = getattr(mg_attn.linear_o_group_proj, f'weight{i}') + if self._is_fp8_param(param): + weights.append(param._rowwise_data) + scale_invs.append(param._rowwise_scale_inv) + else: + weights.append(param.data) + hf_state_dict['wo_a.weight'] = torch.cat(weights, dim=0) + if scale_invs: + hf_state_dict['wo_a.weight_scale_inv'] = torch.cat(scale_invs, dim=0) + def _convert_hf_state_dict(self, hf_state_dict, to_mcore): res = super()._convert_hf_state_dict(hf_state_dict, to_mcore) if to_mcore: @@ -444,7 +501,10 @@ def _set_mla_attn_state( else: hf_state_dict = {} self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'wo_b.weight', to_mcore) - self._set_state_dict(mg_attn, 'linear_o_group_proj', hf_state_dict, 'wo_a.weight', to_mcore) + if self.config.fp8_param: + self._set_o_group_proj_grouped(mg_attn, hf_state_dict, to_mcore) + else: + self._set_state_dict(mg_attn, 'linear_o_group_proj', hf_state_dict, 'wo_a.weight', to_mcore) self._set_state_dict(mg_attn, 'linear_q_down_proj.weight', hf_state_dict, 'wq_a.weight', to_mcore) self._set_state_dict(mg_attn, 'linear_q_up_proj.weight', hf_state_dict, 'wq_b.weight', to_mcore) self._set_state_dict(mg_attn, 'linear_kv_proj.weight', hf_state_dict, 'wkv.weight', to_mcore) From e1882b9e2c1a397baf59872ad436939dc1b99a79 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 23 Jun 2026 11:39:07 +0800 Subject: [PATCH 5/7] fix --- src/mcore_bridge/model/gpts/deepseek_v4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mcore_bridge/model/gpts/deepseek_v4.py b/src/mcore_bridge/model/gpts/deepseek_v4.py index d7e32b9..d6dae2e 100644 --- a/src/mcore_bridge/model/gpts/deepseek_v4.py +++ b/src/mcore_bridge/model/gpts/deepseek_v4.py @@ -394,7 +394,6 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): transformer_layer_spec = get_transformer_block_with_experimental_attention_variant_spec(self.config, vp_stage) for layer_spec in transformer_layer_spec.layer_specs: layer_spec.submodules.self_attention.module = DSv4HybridSelfAttention - for layer_spec in transformer_layer_spec.layer_specs: core_attention_submodules = layer_spec.submodules.self_attention.submodules.core_attention.submodules if getattr(core_attention_submodules, 'compressor') is not None: core_attention_submodules.compressor.module = Compressor From 5d7a62cb1d2c3514b9239816fdd696b64ad8f6e4 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 23 Jun 2026 11:43:18 +0800 Subject: [PATCH 6/7] update --- src/mcore_bridge/model/gpts/deepseek_v4.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcore_bridge/model/gpts/deepseek_v4.py b/src/mcore_bridge/model/gpts/deepseek_v4.py index d6dae2e..aa7edf3 100644 --- a/src/mcore_bridge/model/gpts/deepseek_v4.py +++ b/src/mcore_bridge/model/gpts/deepseek_v4.py @@ -395,9 +395,9 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): for layer_spec in transformer_layer_spec.layer_specs: layer_spec.submodules.self_attention.module = DSv4HybridSelfAttention core_attention_submodules = layer_spec.submodules.self_attention.submodules.core_attention.submodules - if getattr(core_attention_submodules, 'compressor') is not None: + if getattr(core_attention_submodules, 'compressor', None) is not None: core_attention_submodules.compressor.module = Compressor - if getattr(core_attention_submodules, 'indexer') is not None: + if getattr(core_attention_submodules, 'indexer', None) is not None: core_attention_submodules.indexer.module = CSAIndexer core_attention_submodules.indexer.submodules.compressor.module = Compressor return transformer_layer_spec From 39dad9723eaca1ffda470c3f235ef8038b65d98e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 23 Jun 2026 14:46:14 +0800 Subject: [PATCH 7/7] fix --- src/mcore_bridge/model/gpts/deepseek_v4.py | 3 ++- src/mcore_bridge/patcher.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcore_bridge/model/gpts/deepseek_v4.py b/src/mcore_bridge/model/gpts/deepseek_v4.py index aa7edf3..14eb047 100644 --- a/src/mcore_bridge/model/gpts/deepseek_v4.py +++ b/src/mcore_bridge/model/gpts/deepseek_v4.py @@ -78,6 +78,7 @@ def __init__(self, config, *args, **kwargs): in_features=group_proj_in_size, out_features=config.o_lora_rank, bias=False, + params_dtype=config.params_dtype, ) self._o_group_proj_is_grouped_linear = True else: @@ -333,7 +334,7 @@ def forward( core_attn_out = core_attn_out.permute(2, 0, 1, 3).contiguous() core_attn_out = core_attn_out.reshape(-1, core_attn_out.size(-1)) m_splits = [s * b] * self.o_local_groups - core_attn_out, _ = self.linear_o_group_proj(core_attn_out, m_splits) + core_attn_out = self.linear_o_group_proj(core_attn_out, m_splits) # [G*s*b, R] -> [G, s, b, R] -> [s, b, G*R] core_attn_out = core_attn_out.view(self.o_local_groups, s, b, -1) core_attn_out = core_attn_out.permute(1, 2, 0, 3).contiguous() diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py index 9806d38..37390ac 100644 --- a/src/mcore_bridge/patcher.py +++ b/src/mcore_bridge/patcher.py @@ -207,6 +207,7 @@ def _apply_rotary_pos_emb_thd(t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: logger.warning_once('Using non-batched RoPE, which may affect performance.') return _origin_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, *args, **kwargs) + kwargs.pop('max_seqlen', None) # compat megatron-lm dev branch return rope_utils._apply_rotary_pos_emb_bshd(t.unsqueeze(1), freqs, *args, **kwargs).squeeze(1) rope_utils._apply_rotary_pos_emb_thd = _apply_rotary_pos_emb_thd