diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 94b6e12..d34a125 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -179,7 +179,11 @@ def __init__( config.kv_channels = self.head_dim config.num_query_groups = self.num_key_value_heads if text_config.use_bidirectional_attention == 'vision': - kwargs['attn_mask_type'] = AttnMaskType.arbitrary + # In hybrid mode (backend != unfused), only full attention layers need arbitrary mask; + # sliding layers keep causal mask and use native flash sliding window. + # When backend is unfused, all layers use arbitrary mask. + if config.attention_backend.name == 'unfused' or not self.is_sliding: + kwargs['attn_mask_type'] = AttnMaskType.arbitrary if self.is_kv_shared_layer: submodules.k_layernorm = IdentityOp try: @@ -217,9 +221,10 @@ def _forward_core_attention( attention_mask, attention_bias: Optional[torch.Tensor] = None, packed_seq_params: Optional[PackedSeqParams] = None, + override_attn_mask_type=None, ): nvtx_range_push(suffix='core_attention') - attn_mask_type = self.attn_mask_type + attn_mask_type = override_attn_mask_type if override_attn_mask_type is not None else self.attn_mask_type if self.checkpoint_core_attention and self.training: core_attn_out = self._checkpointed_attention_forward( query, @@ -298,6 +303,7 @@ def _apply_rotary(self, query, key, rotary_pos_emb, packed_seq_params): return query, key def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tuple[Tensor, Tensor]: + override_attn_mask_type = kwargs.pop('override_attn_mask_type', None) shared_kv_states = kwargs['shared_kv_states'] rotary_pos_emb = kwargs.get('rotary_pos_emb') packed_seq_params = kwargs.get('packed_seq_params') @@ -353,7 +359,7 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tu if self.store_full_length_kv: shared_kv_states[self.layer_type] = key, value core_attn_out = self._forward_core_attention(query, key, value, attention_mask, attention_bias, - packed_seq_params) + packed_seq_params, override_attn_mask_type) nvtx_range_push(suffix='linear_proj') output, bias = self.linear_proj(core_attn_out) @@ -516,15 +522,12 @@ def __init__(self, config, *args, **kwargs): # If set to "vision", pass attention_mask manually. text_config = config.hf_config.text_config if text_config.use_bidirectional_attention == 'vision': - if config.attention_backend.name != 'unfused': - logger.warning( - f'attention_backend {config.attention_backend.name} does not support use_bidirectional_attention ' - 'for vision. Setting `use_bidirectional_attention` to None. Note: This may cause computational ' - 'errors in multimodal scenarios. Please always pass pure text data.') - text_config.use_bidirectional_attention = None - else: + if config.attention_backend.name == 'unfused': + # All layers use unfused: disable native sliding window, handle manually via mask config.window_size = None config.window_attn_skip_freq = None + # else: hybrid mode - full attention layers fallback to unfused via arbitrary mask type, + # sliding layers use the configured backend (flash/fused) with native sliding window. super().__init__(config, *args, **kwargs) self.num_query_groups_per_partition = self.decoder.layers[0].self_attention.num_query_groups_per_partition self.text_config = text_config @@ -532,6 +535,9 @@ def __init__(self, config, *args, **kwargs): self.unique_layer_types = set(text_config.layer_types) self.hidden_size_per_layer_input = text_config.hidden_size_per_layer_input self.final_logit_softcapping = text_config.final_logit_softcapping + # hybrid_attention_mode: full layers use unfused (via arbitrary mask), sliding layers use flash + self.hybrid_attention_mode = ( + text_config.use_bidirectional_attention == 'vision' and config.attention_backend.name != 'unfused') if self.hidden_size_per_layer_input and self.pre_process: total_dim = self.config.num_layers * self.hidden_size_per_layer_input self.embed_tokens_per_layer = VocabParallelEmbedding( @@ -633,12 +639,16 @@ def forward(self, *args, **kwargs): def _update_attention_mask(self, attention_mask, mm_token_type_ids): sliding_attention = attention_mask['sliding_attention'] full_attention = attention_mask['full_attention'] - # sliding - window_size = self.text_config.sliding_window - 1 - seq_len = sliding_attention.shape[-1] - window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=sliding_attention.device) - window_mask = ~torch.triu(window_mask, diagonal=-window_size) - sliding_attention = sliding_attention | window_mask + # In hybrid mode, only skip sliding mask when no vision tokens (pure text batch). + # When vision tokens are present, fall back to all-unfused behavior for correctness. + use_hybrid_sliding = self.hybrid_attention_mode and mm_token_type_ids is None + if not use_hybrid_sliding: + # Construct sliding window mask manually (all-unfused or vision batch) + window_size = self.text_config.sliding_window - 1 + seq_len = sliding_attention.shape[-1] + window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=sliding_attention.device) + window_mask = ~torch.triu(window_mask, diagonal=-window_size) + sliding_attention = sliding_attention | window_mask if mm_token_type_ids is not None: is_vision = mm_token_type_ids > 0 is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1) @@ -652,6 +662,8 @@ def _update_attention_mask(self, attention_mask, mm_token_type_ids): full_attention = full_attention & ~same_vision_group attention_mask['sliding_attention'] = sliding_attention attention_mask['full_attention'] = full_attention + # Signal sliding layers to override attn_mask_type to arbitrary in hybrid mode with vision + attention_mask['_sliding_override_arbitrary'] = self.hybrid_attention_mode and mm_token_type_ids is not None def _pack_pp_output(self, hidden_states, per_layer_inputs, shared_kv_states): per_layer_inputs = per_layer_inputs.view(*hidden_states.shape[:2], -1) @@ -767,8 +779,12 @@ def __init__(self, config, submodules, *args, **kwargs): TENorm, hidden_size=hidden_size, config=self.config, eps=eps) def _forward_attention(self, hidden_states: Tensor, **kwargs): + attn_mask_dict = kwargs['attention_mask'] kwargs['rotary_pos_emb'] = kwargs['rotary_pos_emb'][self.layer_type] - kwargs['attention_mask'] = kwargs['attention_mask'][self.layer_type] + kwargs['attention_mask'] = attn_mask_dict[self.layer_type] + # In hybrid mode with vision, sliding layers need to override to arbitrary (unfused fallback) + if attn_mask_dict.get('_sliding_override_arbitrary') and self.layer_type == 'sliding_attention': + kwargs['override_attn_mask_type'] = AttnMaskType.arbitrary context = kwargs.pop('context', None) residual = hidden_states input_layernorm_output = self.input_layernorm(hidden_states)