diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py index d88fb02..c85c2ca 100644 --- a/src/mcore_bridge/model/modules/gated_delta_net.py +++ b/src/mcore_bridge/model/modules/gated_delta_net.py @@ -29,6 +29,11 @@ except ImportError: _GatedDeltaNet = object +try: + from megatron.core.ssm.gated_delta_net import _build_thd_cp_a2a_perm +except ImportError: + _build_thd_cp_a2a_perm = None + # Code borrowed from NVIDIA/Megatron-LM def _unpack_sequence(x, cu_seqlens, dim=1): @@ -206,20 +211,34 @@ def forward( qkvzba, _ = self.in_proj(hidden_states) nvtx_range_pop(suffix='in_proj') + thd_cp_a2a_inv = None if cp_size > 1: from megatron.core.ssm.gated_delta_net import tensor_a2a_cp2hp, tensor_a2a_hp2cp if cu_seqlens is not None: - unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens // self.cp_size, dim=0) - outputs = [] - for qkvzba_i in unpacked_qkvzba: - qkvzba_i = tensor_a2a_cp2hp( - qkvzba_i, + if _build_thd_cp_a2a_perm is not None: + # Fused THD AlltoAll: single a2a + sequence permutation + qkvzba = tensor_a2a_cp2hp( + qkvzba, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp, + undo_attention_load_balancing=False, ) - outputs.append(qkvzba_i) - qkvzba = torch.cat(outputs, dim=0) + thd_cp_a2a_idx, thd_cp_a2a_inv = _build_thd_cp_a2a_perm(cu_seqlens, cp_size, seq_len) + qkvzba = qkvzba.index_select(0, thd_cp_a2a_idx) + else: + # Fallback: per-sequence loop + unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens // self.cp_size, dim=0) + outputs = [] + for qkvzba_i in unpacked_qkvzba: + qkvzba_i = tensor_a2a_cp2hp( + qkvzba_i, + seq_dim=0, + head_dim=-1, + cp_group=self.pg_collection.cp, + ) + outputs.append(qkvzba_i) + qkvzba = torch.cat(outputs, dim=0) else: # CP All to All: CP to HP qkvzba = tensor_a2a_cp2hp( @@ -365,12 +384,25 @@ def forward( norm_out = norm_out.transpose(0, 1).contiguous() if cp_size > 1: if cu_seqlens is not None: - unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens, dim=0) - outputs = [] - for norm_out_i in unpacked_norm_out: - norm_out_i = tensor_a2a_hp2cp(norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp) - outputs.append(norm_out_i) - norm_out = torch.cat(outputs, dim=0) + if thd_cp_a2a_inv is not None: + # Fused THD AlltoAll: inverse permutation + single a2a + norm_out = norm_out.index_select(0, thd_cp_a2a_inv) + norm_out = tensor_a2a_hp2cp( + norm_out, + seq_dim=0, + head_dim=-1, + cp_group=self.pg_collection.cp, + redo_attention_load_balancing=False, + ) + else: + # Fallback: per-sequence loop + unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens, dim=0) + outputs = [] + for norm_out_i in unpacked_norm_out: + norm_out_i = tensor_a2a_hp2cp( + norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp) + outputs.append(norm_out_i) + norm_out = torch.cat(outputs, dim=0) else: norm_out = tensor_a2a_hp2cp(norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp)