Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 73 additions & 5 deletions src/mcore_bridge/model/gpts/deepseek_v4.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import copy
import torch
import transformer_engine.pytorch as te
from contextlib import contextmanager
from megatron.core import tensor_parallel
from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
Expand All @@ -12,6 +13,7 @@

from ..constant import ModelType
from ..gpt_model import GPTModel
from ..modules.compressor import Compressor, CSAIndexer
Comment thread
Jintao-Huang marked this conversation as resolved.
from ..register import ModelLoader, ModelMeta, register_model
from ..rope import get_rope_inv_freq

Expand Down Expand Up @@ -68,6 +70,19 @@ def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.layer_type = self.config.hf_config.layer_types[self.layer_number - 1]
self.rope_layer_type = 'main' if self.layer_type == 'sliding_attention' else 'compress'
if getattr(config, 'fp8_param', False):
group_proj_in_size = self.query_projection_size // config.o_groups
del self.linear_o_group_proj
self.linear_o_group_proj = te.GroupedLinear(
num_gemms=config.o_groups,
in_features=group_proj_in_size,
out_features=config.o_lora_rank,
bias=False,
params_dtype=config.params_dtype,
)
self._o_group_proj_is_grouped_linear = True
else:
self._o_group_proj_is_grouped_linear = False
Comment thread
Jintao-Huang marked this conversation as resolved.

def get_query_key_value_tensors(
self,
Expand Down Expand Up @@ -312,10 +327,23 @@ def forward(
core_attn_out = core_attn_out.view(seq_len, core_attn_out.size(1), -1)

# Grouped output
core_attn_out = core_attn_out.view(core_attn_out.size(0), core_attn_out.size(1), self.o_local_groups, -1)
wo_a_weight = self.linear_o_group_proj.view(self.o_local_groups, self.config.o_lora_rank, -1)
core_attn_out = torch.einsum('...gd,grd->...gr', core_attn_out, wo_a_weight)
core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1)
if self._o_group_proj_is_grouped_linear:
s, b = core_attn_out.size(0), core_attn_out.size(1)
# [s, b, G*D] -> [G, s*b, D] -> [G*s*b, D]
core_attn_out = core_attn_out.view(s, b, self.o_local_groups, -1)
core_attn_out = core_attn_out.permute(2, 0, 1, 3).contiguous()
core_attn_out = core_attn_out.reshape(-1, core_attn_out.size(-1))
m_splits = [s * b] * self.o_local_groups
core_attn_out = self.linear_o_group_proj(core_attn_out, m_splits)
# [G*s*b, R] -> [G, s, b, R] -> [s, b, G*R]
core_attn_out = core_attn_out.view(self.o_local_groups, s, b, -1)
core_attn_out = core_attn_out.permute(1, 2, 0, 3).contiguous()
core_attn_out = core_attn_out.reshape(s, b, -1)
else:
core_attn_out = core_attn_out.view(core_attn_out.size(0), core_attn_out.size(1), self.o_local_groups, -1)
wo_a_weight = self.linear_o_group_proj.view(self.o_local_groups, self.config.o_lora_rank, -1)
core_attn_out = torch.einsum('...gd,grd->...gr', core_attn_out, wo_a_weight)
core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1)

# =================
# Output. [sq, b, h]
Expand Down Expand Up @@ -367,6 +395,12 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
transformer_layer_spec = get_transformer_block_with_experimental_attention_variant_spec(self.config, vp_stage)
for layer_spec in transformer_layer_spec.layer_specs:
layer_spec.submodules.self_attention.module = DSv4HybridSelfAttention
core_attention_submodules = layer_spec.submodules.self_attention.submodules.core_attention.submodules
if getattr(core_attention_submodules, 'compressor', None) is not None:
core_attention_submodules.compressor.module = Compressor
if getattr(core_attention_submodules, 'indexer', None) is not None:
core_attention_submodules.indexer.module = CSAIndexer
core_attention_submodules.indexer.submodules.compressor.module = Compressor
return transformer_layer_spec


Expand All @@ -381,6 +415,37 @@ class DeepseekV4Bridge(GPTBridge):
hf_post_attention_layernorm_key = 'ffn_norm.weight'
hf_expert_bias_key = 'gate.bias'

def _set_o_group_proj_grouped(self, mg_attn, hf_state_dict, to_mcore):
"""Handle GroupedLinear state dict for linear_o_group_proj in fp8 mode.

HF stores a single wo_a.weight of shape [G*R, D].
GroupedLinear stores per-gemm weight{i} each of shape [R, D].
"""
o_groups = self.config.o_groups
if to_mcore:
hf_weight = hf_state_dict['wo_a.weight'].load()
hf_scale_inv = None
if 'wo_a.weight_scale_inv' in hf_state_dict:
hf_scale_inv = hf_state_dict['wo_a.weight_scale_inv'].load()
weights = hf_weight.chunk(o_groups, dim=0)
scale_invs = hf_scale_inv.chunk(o_groups, dim=0) if hf_scale_inv is not None else [None] * o_groups
for i, (w, s) in enumerate(zip(weights, scale_invs)):
param = getattr(mg_attn.linear_o_group_proj, f'weight{i}')
self._set_param(param, w, s)
else:
weights = []
scale_invs = []
for i in range(o_groups):
param = getattr(mg_attn.linear_o_group_proj, f'weight{i}')
if self._is_fp8_param(param):
weights.append(param._rowwise_data)
scale_invs.append(param._rowwise_scale_inv)
else:
weights.append(param.data)
hf_state_dict['wo_a.weight'] = torch.cat(weights, dim=0)
if scale_invs:
hf_state_dict['wo_a.weight_scale_inv'] = torch.cat(scale_invs, dim=0)
Comment thread
Jintao-Huang marked this conversation as resolved.

def _convert_hf_state_dict(self, hf_state_dict, to_mcore):
res = super()._convert_hf_state_dict(hf_state_dict, to_mcore)
if to_mcore:
Expand Down Expand Up @@ -436,7 +501,10 @@ def _set_mla_attn_state(
else:
hf_state_dict = {}
self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'wo_b.weight', to_mcore)
self._set_state_dict(mg_attn, 'linear_o_group_proj', hf_state_dict, 'wo_a.weight', to_mcore)
if self.config.fp8_param:
self._set_o_group_proj_grouped(mg_attn, hf_state_dict, to_mcore)
else:
self._set_state_dict(mg_attn, 'linear_o_group_proj', hf_state_dict, 'wo_a.weight', to_mcore)
self._set_state_dict(mg_attn, 'linear_q_down_proj.weight', hf_state_dict, 'wq_a.weight', to_mcore)
self._set_state_dict(mg_attn, 'linear_q_up_proj.weight', hf_state_dict, 'wq_b.weight', to_mcore)
self._set_state_dict(mg_attn, 'linear_kv_proj.weight', hf_state_dict, 'wkv.weight', to_mcore)
Expand Down
1 change: 1 addition & 0 deletions src/mcore_bridge/model/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .compressor import Compressor, CSAIndexer
from .dsa_indexer import DSAIndexer
from .gated_delta_net import GatedDeltaNet
from .gated_self_attention import GatedSelfAttention
Expand Down
58 changes: 58 additions & 0 deletions src/mcore_bridge/model/modules/compressor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import transformer_engine
from megatron.core.transformer.spec_utils import build_module

try:
from megatron.core.transformer.experimental_attention_variant.csa import Compressor as McoreCompressor
from megatron.core.transformer.experimental_attention_variant.csa import CSAIndexer as McoreCSAIndexer
except ImportError:
McoreCompressor = object
McoreCSAIndexer = object


class Compressor(McoreCompressor):

def __init__(self, config, submodules, *args, **kwargs):
super().__init__(config, submodules, *args, **kwargs)
Comment thread
Jintao-Huang marked this conversation as resolved.
if getattr(config, 'fp8_param', False):
with transformer_engine.pytorch.fp8_model_init(enabled=False):
self.linear_wkv = build_module(
submodules.linear_wkv,
config.hidden_size,
self.coff * self.head_dim,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
skip_weight_param_allocation=False,
parallel_mode='duplicated',
)
self.linear_wgate = build_module(
submodules.linear_wgate,
config.hidden_size,
self.coff * self.head_dim,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
skip_weight_param_allocation=False,
parallel_mode='duplicated',
)


class CSAIndexer(McoreCSAIndexer):

def __init__(self, config, submodules, *args, **kwargs):
super().__init__(config, submodules, *args, **kwargs)
Comment thread
Jintao-Huang marked this conversation as resolved.
if getattr(config, 'fp8_param', False):
with transformer_engine.pytorch.fp8_model_init(enabled=False):
self.linear_weights_proj = build_module(
submodules.linear_weights_proj,
self.hidden_size,
self.index_n_heads,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
skip_weight_param_allocation=False,
parallel_mode='duplicated',
)
1 change: 1 addition & 0 deletions src/mcore_bridge/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def _apply_rotary_pos_emb_thd(t: torch.Tensor, cu_seqlens: torch.Tensor, freqs:
logger.warning_once('Using non-batched RoPE, which may affect performance.')
return _origin_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, *args, **kwargs)

kwargs.pop('max_seqlen', None) # compat megatron-lm dev branch
return rope_utils._apply_rotary_pos_emb_bshd(t.unsqueeze(1), freqs, *args, **kwargs).squeeze(1)

rope_utils._apply_rotary_pos_emb_thd = _apply_rotary_pos_emb_thd
Expand Down
Loading