diff --git a/src/mcore_bridge/model/gpts/deepseek_v4.py b/src/mcore_bridge/model/gpts/deepseek_v4.py index 65c9270..14eb047 100644 --- a/src/mcore_bridge/model/gpts/deepseek_v4.py +++ b/src/mcore_bridge/model/gpts/deepseek_v4.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import copy import torch +import transformer_engine.pytorch as te from contextlib import contextmanager from megatron.core import tensor_parallel from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb @@ -12,6 +13,7 @@ from ..constant import ModelType from ..gpt_model import GPTModel +from ..modules.compressor import Compressor, CSAIndexer from ..register import ModelLoader, ModelMeta, register_model from ..rope import get_rope_inv_freq @@ -68,6 +70,19 @@ def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) self.layer_type = self.config.hf_config.layer_types[self.layer_number - 1] self.rope_layer_type = 'main' if self.layer_type == 'sliding_attention' else 'compress' + if getattr(config, 'fp8_param', False): + group_proj_in_size = self.query_projection_size // config.o_groups + del self.linear_o_group_proj + self.linear_o_group_proj = te.GroupedLinear( + num_gemms=config.o_groups, + in_features=group_proj_in_size, + out_features=config.o_lora_rank, + bias=False, + params_dtype=config.params_dtype, + ) + self._o_group_proj_is_grouped_linear = True + else: + self._o_group_proj_is_grouped_linear = False def get_query_key_value_tensors( self, @@ -312,10 +327,23 @@ def forward( core_attn_out = core_attn_out.view(seq_len, core_attn_out.size(1), -1) # Grouped output - core_attn_out = core_attn_out.view(core_attn_out.size(0), core_attn_out.size(1), self.o_local_groups, -1) - wo_a_weight = self.linear_o_group_proj.view(self.o_local_groups, self.config.o_lora_rank, -1) - core_attn_out = torch.einsum('...gd,grd->...gr', core_attn_out, wo_a_weight) - core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1) + if self._o_group_proj_is_grouped_linear: + s, b = core_attn_out.size(0), core_attn_out.size(1) + # [s, b, G*D] -> [G, s*b, D] -> [G*s*b, D] + core_attn_out = core_attn_out.view(s, b, self.o_local_groups, -1) + core_attn_out = core_attn_out.permute(2, 0, 1, 3).contiguous() + core_attn_out = core_attn_out.reshape(-1, core_attn_out.size(-1)) + m_splits = [s * b] * self.o_local_groups + core_attn_out = self.linear_o_group_proj(core_attn_out, m_splits) + # [G*s*b, R] -> [G, s, b, R] -> [s, b, G*R] + core_attn_out = core_attn_out.view(self.o_local_groups, s, b, -1) + core_attn_out = core_attn_out.permute(1, 2, 0, 3).contiguous() + core_attn_out = core_attn_out.reshape(s, b, -1) + else: + core_attn_out = core_attn_out.view(core_attn_out.size(0), core_attn_out.size(1), self.o_local_groups, -1) + wo_a_weight = self.linear_o_group_proj.view(self.o_local_groups, self.config.o_lora_rank, -1) + core_attn_out = torch.einsum('...gd,grd->...gr', core_attn_out, wo_a_weight) + core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1) # ================= # Output. [sq, b, h] @@ -367,6 +395,12 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): transformer_layer_spec = get_transformer_block_with_experimental_attention_variant_spec(self.config, vp_stage) for layer_spec in transformer_layer_spec.layer_specs: layer_spec.submodules.self_attention.module = DSv4HybridSelfAttention + core_attention_submodules = layer_spec.submodules.self_attention.submodules.core_attention.submodules + if getattr(core_attention_submodules, 'compressor', None) is not None: + core_attention_submodules.compressor.module = Compressor + if getattr(core_attention_submodules, 'indexer', None) is not None: + core_attention_submodules.indexer.module = CSAIndexer + core_attention_submodules.indexer.submodules.compressor.module = Compressor return transformer_layer_spec @@ -381,6 +415,37 @@ class DeepseekV4Bridge(GPTBridge): hf_post_attention_layernorm_key = 'ffn_norm.weight' hf_expert_bias_key = 'gate.bias' + def _set_o_group_proj_grouped(self, mg_attn, hf_state_dict, to_mcore): + """Handle GroupedLinear state dict for linear_o_group_proj in fp8 mode. + + HF stores a single wo_a.weight of shape [G*R, D]. + GroupedLinear stores per-gemm weight{i} each of shape [R, D]. + """ + o_groups = self.config.o_groups + if to_mcore: + hf_weight = hf_state_dict['wo_a.weight'].load() + hf_scale_inv = None + if 'wo_a.weight_scale_inv' in hf_state_dict: + hf_scale_inv = hf_state_dict['wo_a.weight_scale_inv'].load() + weights = hf_weight.chunk(o_groups, dim=0) + scale_invs = hf_scale_inv.chunk(o_groups, dim=0) if hf_scale_inv is not None else [None] * o_groups + for i, (w, s) in enumerate(zip(weights, scale_invs)): + param = getattr(mg_attn.linear_o_group_proj, f'weight{i}') + self._set_param(param, w, s) + else: + weights = [] + scale_invs = [] + for i in range(o_groups): + param = getattr(mg_attn.linear_o_group_proj, f'weight{i}') + if self._is_fp8_param(param): + weights.append(param._rowwise_data) + scale_invs.append(param._rowwise_scale_inv) + else: + weights.append(param.data) + hf_state_dict['wo_a.weight'] = torch.cat(weights, dim=0) + if scale_invs: + hf_state_dict['wo_a.weight_scale_inv'] = torch.cat(scale_invs, dim=0) + def _convert_hf_state_dict(self, hf_state_dict, to_mcore): res = super()._convert_hf_state_dict(hf_state_dict, to_mcore) if to_mcore: @@ -436,7 +501,10 @@ def _set_mla_attn_state( else: hf_state_dict = {} self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'wo_b.weight', to_mcore) - self._set_state_dict(mg_attn, 'linear_o_group_proj', hf_state_dict, 'wo_a.weight', to_mcore) + if self.config.fp8_param: + self._set_o_group_proj_grouped(mg_attn, hf_state_dict, to_mcore) + else: + self._set_state_dict(mg_attn, 'linear_o_group_proj', hf_state_dict, 'wo_a.weight', to_mcore) self._set_state_dict(mg_attn, 'linear_q_down_proj.weight', hf_state_dict, 'wq_a.weight', to_mcore) self._set_state_dict(mg_attn, 'linear_q_up_proj.weight', hf_state_dict, 'wq_b.weight', to_mcore) self._set_state_dict(mg_attn, 'linear_kv_proj.weight', hf_state_dict, 'wkv.weight', to_mcore) diff --git a/src/mcore_bridge/model/modules/__init__.py b/src/mcore_bridge/model/modules/__init__.py index 87b7bc7..edf406d 100644 --- a/src/mcore_bridge/model/modules/__init__.py +++ b/src/mcore_bridge/model/modules/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from .compressor import Compressor, CSAIndexer from .dsa_indexer import DSAIndexer from .gated_delta_net import GatedDeltaNet from .gated_self_attention import GatedSelfAttention diff --git a/src/mcore_bridge/model/modules/compressor.py b/src/mcore_bridge/model/modules/compressor.py new file mode 100644 index 0000000..897a11c --- /dev/null +++ b/src/mcore_bridge/model/modules/compressor.py @@ -0,0 +1,58 @@ +import transformer_engine +from megatron.core.transformer.spec_utils import build_module + +try: + from megatron.core.transformer.experimental_attention_variant.csa import Compressor as McoreCompressor + from megatron.core.transformer.experimental_attention_variant.csa import CSAIndexer as McoreCSAIndexer +except ImportError: + McoreCompressor = object + McoreCSAIndexer = object + + +class Compressor(McoreCompressor): + + def __init__(self, config, submodules, *args, **kwargs): + super().__init__(config, submodules, *args, **kwargs) + if getattr(config, 'fp8_param', False): + with transformer_engine.pytorch.fp8_model_init(enabled=False): + self.linear_wkv = build_module( + submodules.linear_wkv, + config.hidden_size, + self.coff * self.head_dim, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + skip_weight_param_allocation=False, + parallel_mode='duplicated', + ) + self.linear_wgate = build_module( + submodules.linear_wgate, + config.hidden_size, + self.coff * self.head_dim, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + skip_weight_param_allocation=False, + parallel_mode='duplicated', + ) + + +class CSAIndexer(McoreCSAIndexer): + + def __init__(self, config, submodules, *args, **kwargs): + super().__init__(config, submodules, *args, **kwargs) + if getattr(config, 'fp8_param', False): + with transformer_engine.pytorch.fp8_model_init(enabled=False): + self.linear_weights_proj = build_module( + submodules.linear_weights_proj, + self.hidden_size, + self.index_n_heads, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + skip_weight_param_allocation=False, + parallel_mode='duplicated', + ) diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py index 9806d38..37390ac 100644 --- a/src/mcore_bridge/patcher.py +++ b/src/mcore_bridge/patcher.py @@ -207,6 +207,7 @@ def _apply_rotary_pos_emb_thd(t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: logger.warning_once('Using non-batched RoPE, which may affect performance.') return _origin_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, *args, **kwargs) + kwargs.pop('max_seqlen', None) # compat megatron-lm dev branch return rope_utils._apply_rotary_pos_emb_bshd(t.unsqueeze(1), freqs, *args, **kwargs).squeeze(1) rope_utils._apply_rotary_pos_emb_thd = _apply_rotary_pos_emb_thd