Extending scanned to unscanned support for GPT-OSS-like models. #1189
NicoGrande wants to merge 1 commit into main
Conversation
Code Review
This pull request extends the transfer_state_directly utility to support models with multiple scanned blocks, such as the GPT-OSS format. The changes introduce a scan_group_ordering parameter to handle both 'sequential' and 'interleaved' layer mappings. The implementation is well-tested with new unit tests covering both ordering scenarios. My main feedback is to refactor the new complex logic block in tunix/generate/utils.py into a helper function to improve readability and maintainability, in line with the repository's style guide.
# Candidate C: GPT-OSS scanned format — source has layers.layers_Y with
# a non-axis-0 scan dim. Count total groups, then map layer_idx to
# (src_group_idx, local_idx) based on scan_group_ordering.
prefix = key_tuple[:match_index]
suffix = key_tuple[match_index + 1:]

# Count available scan groups.
n_groups = 0
while prefix + ('layers', f'layers_{n_groups}') + suffix in src_flat:
  n_groups += 1

if n_groups > 0:
  first_val = src_flat[prefix + ('layers', 'layers_0') + suffix]
  if hasattr(first_val, 'shape') and hasattr(tgt_val, 'shape'):
    if len(first_val.shape) == len(tgt_val.shape) + 1:
      for axis in range(len(first_val.shape)):
        remaining = first_val.shape[:axis] + first_val.shape[axis + 1:]
        # Check whether removing this axis matches the target shape.
        if remaining == tgt_val.shape:
          n_scan_steps = first_val.shape[axis]
          # Interleaved ordering: layer_idx 0, n_groups, 2*n_groups, ...
          # go to group 0; layer_idx 1, n_groups+1, 2*n_groups+1, ...
          # go to group 1, etc.
          if scan_group_ordering == 'interleaved':
            src_group_idx = layer_idx % n_groups
            local_idx = layer_idx // n_groups
          # Sequential ordering: layers 0..N-1 in group 0, then
          # N..2N-1 in group 1, etc.
          else:
            src_group_idx = layer_idx // n_scan_steps
            local_idx = layer_idx % n_scan_steps
          # Only slice if the computed group and local indices are in bounds.
          if src_group_idx < n_groups and local_idx < n_scan_steps:
            candidate_c = prefix + ('layers', f'layers_{src_group_idx}') + suffix
            src_val = src_flat[candidate_c]
            sliced_val = _slice_scanned_param(
                src_val, tgt_val, local_idx, str(key_tuple)
            )
            sliced_val = _apply_dtype_cast(
                sliced_val, tgt_val.dtype, str(key_tuple)
            )
            filtered_src_flat[key_tuple] = sliced_val
            filtered_tgt_flat[key_tuple] = tgt_val
          break
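To make the two orderings concrete, here is a small self-contained sketch of the index arithmetic above; the values n_groups=2 and n_scan_steps=12 are hypothetical, chosen only for illustration:

def map_layer(layer_idx, n_groups, n_scan_steps, scan_group_ordering):
  """Returns (src_group_idx, local_idx) for a target layer index."""
  if scan_group_ordering == 'interleaved':
    # Consecutive target layers alternate across groups.
    return layer_idx % n_groups, layer_idx // n_groups
  # Sequential: fill one group completely before moving to the next.
  return layer_idx // n_scan_steps, layer_idx % n_scan_steps

# 24 target layers drawn from 2 scanned source groups of 12 steps each.
assert map_layer(13, 2, 12, 'interleaved') == (1, 6)  # group 1, scan step 6
assert map_layer(13, 2, 12, 'sequential') == (1, 1)   # group 1, scan step 1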
This new logic block for "Candidate C" is quite long and complex, with multiple levels of nesting. According to the repository style guide, code should be readable and maintainable (lines 11-12).
To improve readability and maintainability, I suggest refactoring this block into a separate helper function. This would encapsulate the logic for handling the GPT-OSS scanned format, making the intersect_trees function easier to follow.
A helper function could look something like this:
def _map_gpt_oss_scanned_layer(
    key_tuple: tuple[str, ...],
    tgt_val: Any,
    src_flat: dict[tuple[str, ...], Any],
    layer_idx: int,
    match_index: int,
    scan_group_ordering: str,
) -> tuple[Any | None, Any | None]:
  """Attempts to map a target layer to a source scanned GPT-OSS-style layer."""
  # ... implementation from lines 978-1016 ...
  # On success, return (sliced_val, tgt_val).
  # On failure, return (None, None).

Then, in intersect_trees, you could simplify the main loop:
# ... after checking candidates A and B ...
sliced_val, mapped_tgt_val = _map_gpt_oss_scanned_layer(...)
if sliced_val is not None:
  filtered_src_flat[key_tuple] = sliced_val
  filtered_tgt_flat[key_tuple] = mapped_tgt_val
  continue

This change would make the control flow of intersect_trees much clearer.
Additionally, when refactoring, you could clarify the inline comments (lines 995-996 and 1000-1001) to be more descriptive, similar to the excellent comment added in tunix/rl/rollout/base_rollout.py (lines 161-166).
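By way of illustration, the two ordering comments could spell out the mapping with concrete indices; the wording below is a reviewer-style suggestion, not text from the PR:

if scan_group_ordering == 'interleaved':
  # Interleaved: consecutive target layers alternate across source groups,
  # so with n_groups=2, target layers 0, 2, 4, ... come from layers_0 and
  # layers 1, 3, 5, ... come from layers_1; the scan step advances once
  # every n_groups target layers.
  src_group_idx = layer_idx % n_groups
  local_idx = layer_idx // n_groups
else:
  # Sequential: each group is consumed completely before the next one, so
  # target layers 0..n_scan_steps-1 come from layers_0, the next
  # n_scan_steps layers from layers_1, and so on.
  src_group_idx = layer_idx // n_scan_steps
  local_idx = layer_idx % n_scan_steps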
References
- Code should be easy to understand for all maintainers and users. Code should be easy to modify and extend.
Adding support to transfer_state_directly to handle models with multiple scanned blocks.
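For orientation, a minimal sketch of how the new option might be passed; only the scan_group_ordering parameter and its two values are confirmed by this PR, while the surrounding arguments and their names are assumptions:

# Hypothetical call; only scan_group_ordering is introduced by this PR.
# src_state (scanned, with layers.layers_{i} groups) and tgt_state
# (unscanned, one entry per layer) are assumed argument names.
tgt_state = transfer_state_directly(
    src_state,
    tgt_state,
    scan_group_ordering='interleaved',  # or 'sequential'
)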