Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions src/mcore_bridge/model/mm_gpts/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,11 @@ def __init__(
config.kv_channels = self.head_dim
config.num_query_groups = self.num_key_value_heads
if text_config.use_bidirectional_attention == 'vision':
kwargs['attn_mask_type'] = AttnMaskType.arbitrary
# In hybrid mode (backend != unfused), only full attention layers need arbitrary mask;
# sliding layers keep causal mask and use native flash sliding window.
# When backend is unfused, all layers use arbitrary mask.
if config.attention_backend.name == 'unfused' or not self.is_sliding:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent potential AttributeError if config.attention_backend is None or does not have a name attribute (e.g., in certain testing or custom configurations), use getattr to safely retrieve the backend name.

Suggested change
if config.attention_backend.name == 'unfused' or not self.is_sliding:
if getattr(config.attention_backend, 'name', None) == 'unfused' or not self.is_sliding:

kwargs['attn_mask_type'] = AttnMaskType.arbitrary
if self.is_kv_shared_layer:
submodules.k_layernorm = IdentityOp
try:
Expand Down Expand Up @@ -217,9 +221,10 @@ def _forward_core_attention(
attention_mask,
attention_bias: Optional[torch.Tensor] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
override_attn_mask_type=None,
):
nvtx_range_push(suffix='core_attention')
attn_mask_type = self.attn_mask_type
attn_mask_type = override_attn_mask_type if override_attn_mask_type is not None else self.attn_mask_type
if self.checkpoint_core_attention and self.training:
core_attn_out = self._checkpointed_attention_forward(
query,
Expand Down Expand Up @@ -298,6 +303,7 @@ def _apply_rotary(self, query, key, rotary_pos_emb, packed_seq_params):
return query, key

def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tuple[Tensor, Tensor]:
override_attn_mask_type = kwargs.pop('override_attn_mask_type', None)
shared_kv_states = kwargs['shared_kv_states']
rotary_pos_emb = kwargs.get('rotary_pos_emb')
packed_seq_params = kwargs.get('packed_seq_params')
Expand Down Expand Up @@ -353,7 +359,7 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tu
if self.store_full_length_kv:
shared_kv_states[self.layer_type] = key, value
core_attn_out = self._forward_core_attention(query, key, value, attention_mask, attention_bias,
packed_seq_params)
packed_seq_params, override_attn_mask_type)

nvtx_range_push(suffix='linear_proj')
output, bias = self.linear_proj(core_attn_out)
Expand Down Expand Up @@ -516,22 +522,22 @@ def __init__(self, config, *args, **kwargs):
# If set to "vision", pass attention_mask manually.
text_config = config.hf_config.text_config
if text_config.use_bidirectional_attention == 'vision':
if config.attention_backend.name != 'unfused':
logger.warning(
f'attention_backend {config.attention_backend.name} does not support use_bidirectional_attention '
'for vision. Setting `use_bidirectional_attention` to None. Note: This may cause computational '
'errors in multimodal scenarios. Please always pass pure text data.')
text_config.use_bidirectional_attention = None
else:
if config.attention_backend.name == 'unfused':

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Safely access the attention_backend name using getattr to avoid potential AttributeError if the backend is not configured or is None.

Suggested change
if config.attention_backend.name == 'unfused':
if getattr(config.attention_backend, 'name', None) == 'unfused':

# All layers use unfused: disable native sliding window, handle manually via mask
config.window_size = None
config.window_attn_skip_freq = None
# else: hybrid mode - full attention layers fallback to unfused via arbitrary mask type,
# sliding layers use the configured backend (flash/fused) with native sliding window.
super().__init__(config, *args, **kwargs)
self.num_query_groups_per_partition = self.decoder.layers[0].self_attention.num_query_groups_per_partition
self.text_config = text_config
self.num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0)
self.unique_layer_types = set(text_config.layer_types)
self.hidden_size_per_layer_input = text_config.hidden_size_per_layer_input
self.final_logit_softcapping = text_config.final_logit_softcapping
# hybrid_attention_mode: full layers use unfused (via arbitrary mask), sliding layers use flash
self.hybrid_attention_mode = (
text_config.use_bidirectional_attention == 'vision' and config.attention_backend.name != 'unfused')
Comment on lines +539 to +540

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Safely access the attention_backend name using getattr to avoid potential AttributeError if the backend is not configured or is None.

        self.hybrid_attention_mode = (
            text_config.use_bidirectional_attention == 'vision' and getattr(config.attention_backend, 'name', None) != 'unfused')

if self.hidden_size_per_layer_input and self.pre_process:
total_dim = self.config.num_layers * self.hidden_size_per_layer_input
self.embed_tokens_per_layer = VocabParallelEmbedding(
Expand Down Expand Up @@ -633,12 +639,16 @@ def forward(self, *args, **kwargs):
def _update_attention_mask(self, attention_mask, mm_token_type_ids):
sliding_attention = attention_mask['sliding_attention']
full_attention = attention_mask['full_attention']
# sliding
window_size = self.text_config.sliding_window - 1
seq_len = sliding_attention.shape[-1]
window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=sliding_attention.device)
window_mask = ~torch.triu(window_mask, diagonal=-window_size)
sliding_attention = sliding_attention | window_mask
# In hybrid mode, only skip sliding mask when no vision tokens (pure text batch).
# When vision tokens are present, fall back to all-unfused behavior for correctness.
use_hybrid_sliding = self.hybrid_attention_mode and mm_token_type_ids is None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If mm_token_type_ids is not None but contains no vision tokens (e.g., a tensor of all zeros, which is common for text-only batches in multimodal pipelines), the check mm_token_type_ids is None evaluates to False even though the batch is pure text. Hybrid sliding is then disabled and the code falls back to the slower manual sliding-window mask. We should instead check whether any vision tokens are actually present in the batch.

        has_vision = mm_token_type_ids is not None and (mm_token_type_ids > 0).any()
        use_hybrid_sliding = self.hybrid_attention_mode and not has_vision

if not use_hybrid_sliding:
# Construct sliding window mask manually (all-unfused or vision batch)
window_size = self.text_config.sliding_window - 1
seq_len = sliding_attention.shape[-1]
window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=sliding_attention.device)
window_mask = ~torch.triu(window_mask, diagonal=-window_size)
sliding_attention = sliding_attention | window_mask
if mm_token_type_ids is not None:
is_vision = mm_token_type_ids > 0
is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1)
Expand All @@ -652,6 +662,8 @@ def _update_attention_mask(self, attention_mask, mm_token_type_ids):
full_attention = full_attention & ~same_vision_group
attention_mask['sliding_attention'] = sliding_attention
attention_mask['full_attention'] = full_attention
# Signal sliding layers to override attn_mask_type to arbitrary in hybrid mode with vision
attention_mask['_sliding_override_arbitrary'] = self.hybrid_attention_mode and mm_token_type_ids is not None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Use the has_vision boolean flag defined earlier so that sliding layers do not have their attention mask type overridden to arbitrary when the batch contains no actual vision tokens.

Suggested change
attention_mask['_sliding_override_arbitrary'] = self.hybrid_attention_mode and mm_token_type_ids is not None
attention_mask['_sliding_override_arbitrary'] = self.hybrid_attention_mode and has_vision


def _pack_pp_output(self, hidden_states, per_layer_inputs, shared_kv_states):
per_layer_inputs = per_layer_inputs.view(*hidden_states.shape[:2], -1)
Expand Down Expand Up @@ -767,8 +779,12 @@ def __init__(self, config, submodules, *args, **kwargs):
TENorm, hidden_size=hidden_size, config=self.config, eps=eps)

def _forward_attention(self, hidden_states: Tensor, **kwargs):
attn_mask_dict = kwargs['attention_mask']
kwargs['rotary_pos_emb'] = kwargs['rotary_pos_emb'][self.layer_type]
kwargs['attention_mask'] = kwargs['attention_mask'][self.layer_type]
kwargs['attention_mask'] = attn_mask_dict[self.layer_type]
# In hybrid mode with vision, sliding layers need to override to arbitrary (unfused fallback)
if attn_mask_dict.get('_sliding_override_arbitrary') and self.layer_type == 'sliding_attention':
kwargs['override_attn_mask_type'] = AttnMaskType.arbitrary
context = kwargs.pop('context', None)
residual = hidden_states
input_layernorm_output = self.input_layernorm(hidden_states)
Expand Down
Loading