Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 45 additions & 13 deletions src/mcore_bridge/model/modules/gated_delta_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
except ImportError:
_GatedDeltaNet = object

try:
from megatron.core.ssm.gated_delta_net import _build_thd_cp_a2a_perm
except ImportError:
_build_thd_cp_a2a_perm = None


# Code borrowed from NVIDIA/Megatron-LM
def _unpack_sequence(x, cu_seqlens, dim=1):
Expand Down Expand Up @@ -206,20 +211,34 @@ def forward(
qkvzba, _ = self.in_proj(hidden_states)
nvtx_range_pop(suffix='in_proj')

thd_cp_a2a_inv = None
if cp_size > 1:
from megatron.core.ssm.gated_delta_net import tensor_a2a_cp2hp, tensor_a2a_hp2cp
if cu_seqlens is not None:
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens // self.cp_size, dim=0)
outputs = []
for qkvzba_i in unpacked_qkvzba:
qkvzba_i = tensor_a2a_cp2hp(
qkvzba_i,
if _build_thd_cp_a2a_perm is not None:
# Fused THD AlltoAll: single a2a + sequence permutation
qkvzba = tensor_a2a_cp2hp(
qkvzba,
seq_dim=0,
head_dim=-1,
cp_group=self.pg_collection.cp,
undo_attention_load_balancing=False,
)
outputs.append(qkvzba_i)
qkvzba = torch.cat(outputs, dim=0)
thd_cp_a2a_idx, thd_cp_a2a_inv = _build_thd_cp_a2a_perm(cu_seqlens, cp_size, seq_len)
qkvzba = qkvzba.index_select(0, thd_cp_a2a_idx)
else:
# Fallback: per-sequence loop
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens // self.cp_size, dim=0)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Use the local variable cp_size instead of self.cp_size for consistency with the rest of the method and to avoid potential AttributeError if self.cp_size is not defined on the parent class.

Suggested change
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens // self.cp_size, dim=0)
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens // cp_size, dim=0)

outputs = []
for qkvzba_i in unpacked_qkvzba:
qkvzba_i = tensor_a2a_cp2hp(
qkvzba_i,
seq_dim=0,
head_dim=-1,
cp_group=self.pg_collection.cp,
)
outputs.append(qkvzba_i)
qkvzba = torch.cat(outputs, dim=0)
else:
# CP All to All: CP to HP
qkvzba = tensor_a2a_cp2hp(
Expand Down Expand Up @@ -365,12 +384,25 @@ def forward(
norm_out = norm_out.transpose(0, 1).contiguous()
if cp_size > 1:
if cu_seqlens is not None:
unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens, dim=0)
outputs = []
for norm_out_i in unpacked_norm_out:
norm_out_i = tensor_a2a_hp2cp(norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp)
outputs.append(norm_out_i)
norm_out = torch.cat(outputs, dim=0)
if thd_cp_a2a_inv is not None:
# Fused THD AlltoAll: inverse permutation + single a2a
norm_out = norm_out.index_select(0, thd_cp_a2a_inv)
norm_out = tensor_a2a_hp2cp(
norm_out,
seq_dim=0,
head_dim=-1,
cp_group=self.pg_collection.cp,
redo_attention_load_balancing=False,
)
else:
# Fallback: per-sequence loop
unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens, dim=0)
outputs = []
for norm_out_i in unpacked_norm_out:
norm_out_i = tensor_a2a_hp2cp(
norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp)
outputs.append(norm_out_i)
norm_out = torch.cat(outputs, dim=0)
else:
norm_out = tensor_a2a_hp2cp(norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp)

Expand Down
Loading