-
Notifications
You must be signed in to change notification settings - Fork 20
support gemma4 use_hybrid_sliding #129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -179,7 +179,11 @@ def __init__( | |||||
| config.kv_channels = self.head_dim | ||||||
| config.num_query_groups = self.num_key_value_heads | ||||||
| if text_config.use_bidirectional_attention == 'vision': | ||||||
| kwargs['attn_mask_type'] = AttnMaskType.arbitrary | ||||||
| # In hybrid mode (backend != unfused), only full attention layers need arbitrary mask; | ||||||
| # sliding layers keep causal mask and use native flash sliding window. | ||||||
| # When backend is unfused, all layers use arbitrary mask. | ||||||
| if config.attention_backend.name == 'unfused' or not self.is_sliding: | ||||||
| kwargs['attn_mask_type'] = AttnMaskType.arbitrary | ||||||
| if self.is_kv_shared_layer: | ||||||
| submodules.k_layernorm = IdentityOp | ||||||
| try: | ||||||
|
|
@@ -217,9 +221,10 @@ def _forward_core_attention( | |||||
| attention_mask, | ||||||
| attention_bias: Optional[torch.Tensor] = None, | ||||||
| packed_seq_params: Optional[PackedSeqParams] = None, | ||||||
| override_attn_mask_type=None, | ||||||
| ): | ||||||
| nvtx_range_push(suffix='core_attention') | ||||||
| attn_mask_type = self.attn_mask_type | ||||||
| attn_mask_type = override_attn_mask_type if override_attn_mask_type is not None else self.attn_mask_type | ||||||
| if self.checkpoint_core_attention and self.training: | ||||||
| core_attn_out = self._checkpointed_attention_forward( | ||||||
| query, | ||||||
|
|
@@ -298,6 +303,7 @@ def _apply_rotary(self, query, key, rotary_pos_emb, packed_seq_params): | |||||
| return query, key | ||||||
|
|
||||||
| def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tuple[Tensor, Tensor]: | ||||||
| override_attn_mask_type = kwargs.pop('override_attn_mask_type', None) | ||||||
| shared_kv_states = kwargs['shared_kv_states'] | ||||||
| rotary_pos_emb = kwargs.get('rotary_pos_emb') | ||||||
| packed_seq_params = kwargs.get('packed_seq_params') | ||||||
|
|
@@ -353,7 +359,7 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tu | |||||
| if self.store_full_length_kv: | ||||||
| shared_kv_states[self.layer_type] = key, value | ||||||
| core_attn_out = self._forward_core_attention(query, key, value, attention_mask, attention_bias, | ||||||
| packed_seq_params) | ||||||
| packed_seq_params, override_attn_mask_type) | ||||||
|
|
||||||
| nvtx_range_push(suffix='linear_proj') | ||||||
| output, bias = self.linear_proj(core_attn_out) | ||||||
|
|
@@ -516,22 +522,22 @@ def __init__(self, config, *args, **kwargs): | |||||
| # If set to "vision", pass attention_mask manually. | ||||||
| text_config = config.hf_config.text_config | ||||||
| if text_config.use_bidirectional_attention == 'vision': | ||||||
| if config.attention_backend.name != 'unfused': | ||||||
| logger.warning( | ||||||
| f'attention_backend {config.attention_backend.name} does not support use_bidirectional_attention ' | ||||||
| 'for vision. Setting `use_bidirectional_attention` to None. Note: This may cause computational ' | ||||||
| 'errors in multimodal scenarios. Please always pass pure text data.') | ||||||
| text_config.use_bidirectional_attention = None | ||||||
| else: | ||||||
| if config.attention_backend.name == 'unfused': | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||
| # All layers use unfused: disable native sliding window, handle manually via mask | ||||||
| config.window_size = None | ||||||
| config.window_attn_skip_freq = None | ||||||
| # else: hybrid mode - full attention layers fallback to unfused via arbitrary mask type, | ||||||
| # sliding layers use the configured backend (flash/fused) with native sliding window. | ||||||
| super().__init__(config, *args, **kwargs) | ||||||
| self.num_query_groups_per_partition = self.decoder.layers[0].self_attention.num_query_groups_per_partition | ||||||
| self.text_config = text_config | ||||||
| self.num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) | ||||||
| self.unique_layer_types = set(text_config.layer_types) | ||||||
| self.hidden_size_per_layer_input = text_config.hidden_size_per_layer_input | ||||||
| self.final_logit_softcapping = text_config.final_logit_softcapping | ||||||
| # hybrid_attention_mode: full layers use unfused (via arbitrary mask), sliding layers use flash | ||||||
| self.hybrid_attention_mode = ( | ||||||
| text_config.use_bidirectional_attention == 'vision' and config.attention_backend.name != 'unfused') | ||||||
|
Comment on lines
+539
to
+540
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||
| if self.hidden_size_per_layer_input and self.pre_process: | ||||||
| total_dim = self.config.num_layers * self.hidden_size_per_layer_input | ||||||
| self.embed_tokens_per_layer = VocabParallelEmbedding( | ||||||
|
|
@@ -633,12 +639,16 @@ def forward(self, *args, **kwargs): | |||||
| def _update_attention_mask(self, attention_mask, mm_token_type_ids): | ||||||
| sliding_attention = attention_mask['sliding_attention'] | ||||||
| full_attention = attention_mask['full_attention'] | ||||||
| # sliding | ||||||
| window_size = self.text_config.sliding_window - 1 | ||||||
| seq_len = sliding_attention.shape[-1] | ||||||
| window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=sliding_attention.device) | ||||||
| window_mask = ~torch.triu(window_mask, diagonal=-window_size) | ||||||
| sliding_attention = sliding_attention | window_mask | ||||||
| # In hybrid mode, only skip sliding mask when no vision tokens (pure text batch). | ||||||
| # When vision tokens are present, fall back to all-unfused behavior for correctness. | ||||||
| use_hybrid_sliding = self.hybrid_attention_mode and mm_token_type_ids is None | ||||||
|
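There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If `mm_token_type_ids` is provided but contains no vision tokens (all entries are 0), the batch is still pure text, yet `mm_token_type_ids is None` is false and hybrid sliding is skipped unnecessarily. Check for actual vision tokens instead: `has_vision = mm_token_type_ids is not None and (mm_token_type_ids > 0).any()`, then `use_hybrid_sliding = self.hybrid_attention_mode and not has_vision` |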
||||||
| if not use_hybrid_sliding: | ||||||
| # Construct sliding window mask manually (all-unfused or vision batch) | ||||||
| window_size = self.text_config.sliding_window - 1 | ||||||
| seq_len = sliding_attention.shape[-1] | ||||||
| window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=sliding_attention.device) | ||||||
| window_mask = ~torch.triu(window_mask, diagonal=-window_size) | ||||||
| sliding_attention = sliding_attention | window_mask | ||||||
| if mm_token_type_ids is not None: | ||||||
| is_vision = mm_token_type_ids > 0 | ||||||
| is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1) | ||||||
|
|
@@ -652,6 +662,8 @@ def _update_attention_mask(self, attention_mask, mm_token_type_ids): | |||||
| full_attention = full_attention & ~same_vision_group | ||||||
| attention_mask['sliding_attention'] = sliding_attention | ||||||
| attention_mask['full_attention'] = full_attention | ||||||
| # Signal sliding layers to override attn_mask_type to arbitrary in hybrid mode with vision | ||||||
| attention_mask['_sliding_override_arbitrary'] = self.hybrid_attention_mode and mm_token_type_ids is not None | ||||||
|
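There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Use the same vision-token check that decides `use_hybrid_sliding` here (rather than `mm_token_type_ids is not None`), so the override flag is only set when the batch actually contains vision tokens.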
Suggested change
|
||||||
|
|
||||||
| def _pack_pp_output(self, hidden_states, per_layer_inputs, shared_kv_states): | ||||||
| per_layer_inputs = per_layer_inputs.view(*hidden_states.shape[:2], -1) | ||||||
|
|
@@ -767,8 +779,12 @@ def __init__(self, config, submodules, *args, **kwargs): | |||||
| TENorm, hidden_size=hidden_size, config=self.config, eps=eps) | ||||||
|
|
||||||
| def _forward_attention(self, hidden_states: Tensor, **kwargs): | ||||||
| attn_mask_dict = kwargs['attention_mask'] | ||||||
| kwargs['rotary_pos_emb'] = kwargs['rotary_pos_emb'][self.layer_type] | ||||||
| kwargs['attention_mask'] = kwargs['attention_mask'][self.layer_type] | ||||||
| kwargs['attention_mask'] = attn_mask_dict[self.layer_type] | ||||||
| # In hybrid mode with vision, sliding layers need to override to arbitrary (unfused fallback) | ||||||
| if attn_mask_dict.get('_sliding_override_arbitrary') and self.layer_type == 'sliding_attention': | ||||||
| kwargs['override_attn_mask_type'] = AttnMaskType.arbitrary | ||||||
| context = kwargs.pop('context', None) | ||||||
| residual = hidden_states | ||||||
| input_layernorm_output = self.input_layernorm(hidden_states) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To prevent a potential
`AttributeError` if `config.attention_backend` is `None` or does not have a `name` attribute (e.g., in certain testing or custom configurations), use `getattr` to safely retrieve the backend name.