Extending scanned to unscanned support for GPT-OSS like models.#1189

Draft
NicoGrande wants to merge 1 commit into main from nicogrande/improve-scanned-unrolling

Conversation

@NicoGrande
Collaborator

Adds support to transfer_state_directly for handling models with multiple scanned blocks.

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly improves the transfer_state_directly utility by enabling it to correctly handle weight transfers for large language models, particularly those with complex scanned layer structures like GPT-OSS models. It introduces a flexible mechanism to define how layers are mapped from a scanned source to an unscanned destination, which is crucial for maintaining model integrity during state transfer. This enhancement ensures broader compatibility and more robust weight management within the system.

Highlights

  • Enhanced transfer_state_directly function: The transfer_state_directly utility now supports models with multiple scanned blocks, specifically addressing GPT-OSS like architectures where source layers might be structured with a non-axis-0 scan dimension.
  • New scan_group_ordering parameter: A new parameter, scan_group_ordering, has been introduced to control how layers are mapped from scanned source models to unscanned target models. It supports 'sequential' (default) and 'interleaved' ordering, allowing flexible weight transfer strategies.
  • VLLM Integration: The new scan_group_ordering functionality has been integrated into VllmConfig and RolloutConfig, ensuring that the VLLM sampler and rollout engine can correctly handle and transfer weights from models with various scanned layer configurations.
  • Comprehensive Unit Tests: New unit tests have been added to validate the correct behavior of transfer_state_directly for both sequential and interleaved scan group orderings in GPT-OSS like scanned formats.
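To make the two orderings concrete, here is a minimal, self-contained sketch of the index arithmetic the new parameter controls. The function name `map_layer` and the example values for `n_groups` and `n_scan_steps` are illustrative, not identifiers from the PR.

```python
def map_layer(layer_idx: int, n_groups: int, n_scan_steps: int,
              scan_group_ordering: str = 'sequential') -> tuple[int, int]:
  """Maps a flat unscanned layer index to (src_group_idx, local_idx)."""
  if scan_group_ordering == 'interleaved':
    # Layers 0, n_groups, 2*n_groups, ... fall in group 0; 1, n_groups+1, ... in group 1.
    return layer_idx % n_groups, layer_idx // n_groups
  # Sequential: the first n_scan_steps layers fill group 0, the next fill group 1, etc.
  return layer_idx // n_scan_steps, layer_idx % n_scan_steps

# With 2 scan groups of 3 steps each (6 unscanned layers total):
print(map_layer(4, 2, 3, 'sequential'))   # (1, 1)
print(map_layer(4, 2, 3, 'interleaved'))  # (0, 2)
```

The same flat layer index lands in a different scan group under each ordering, which is why a mismatch between the two would silently scramble weights during transfer.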


Changelog
  • tests/generate/utils_test.py
    • Added test_transfer_state_directly_gpt_oss_scanned_format to verify sequential scan group ordering for GPT-OSS models.
    • Added test_transfer_state_directly_gpt_oss_scanned_format_interleaved to verify interleaved scan group ordering for GPT-OSS models.
  • tunix/generate/utils.py
    • Updated transfer_state_directly function signature to include scan_group_ordering parameter.
    • Expanded docstring for transfer_state_directly to explain scan_group_ordering options ('sequential' and 'interleaved').
    • Implemented logic within intersect_trees to detect and correctly slice parameters from GPT-OSS scanned format sources based on the specified scan_group_ordering.
  • tunix/generate/vllm_sampler.py
    • Added scan_group_ordering field to VllmConfig with a default value of 'sequential'.
    • Modified update_params method to pass the scan_group_ordering from VllmConfig to utils.transfer_state_directly.
  • tunix/rl/rollout/base_rollout.py
    • Introduced rollout_vllm_scan_group_ordering field to RolloutConfig with a default of 'sequential' and a descriptive docstring.
  • tunix/rl/rollout/vllm_rollout.py
    • Updated VllmRollout constructor to pass rollout_vllm_scan_group_ordering from RolloutConfig to the VllmSampler's VllmConfig.
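Based on the changelog above, the new option is plumbed from RolloutConfig into VllmConfig and from there into transfer_state_directly. The following is a hypothetical sketch using simplified stand-in dataclasses; only the field names (scan_group_ordering, rollout_vllm_scan_group_ordering) come from the PR description, everything else is illustrative.

```python
from dataclasses import dataclass


# Simplified stand-ins for tunix's VllmConfig and RolloutConfig; the real
# classes carry many more fields. Field names follow the changelog above.
@dataclass
class VllmConfig:
  scan_group_ordering: str = 'sequential'


@dataclass
class RolloutConfig:
  rollout_vllm_scan_group_ordering: str = 'sequential'


# The rollout constructor forwards the rollout-level setting into VllmConfig,
# which the sampler later passes to utils.transfer_state_directly.
rollout_cfg = RolloutConfig(rollout_vllm_scan_group_ordering='interleaved')
vllm_cfg = VllmConfig(
    scan_group_ordering=rollout_cfg.rollout_vllm_scan_group_ordering
)
print(vllm_cfg.scan_group_ordering)  # interleaved
```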

@gemini-code-assist bot left a comment


Code Review

This pull request extends the transfer_state_directly utility to support models with multiple scanned blocks, such as the GPT-OSS format. The changes introduce a scan_group_ordering parameter to handle both 'sequential' and 'interleaved' layer mappings. The implementation is well-tested with new unit tests covering both ordering scenarios. My main feedback is to refactor the new complex logic block in tunix/generate/utils.py into a helper function to improve readability and maintainability, in line with the repository's style guide.

Comment on lines +975 to +1017
# Candidate C: GPT-OSS scanned format — source has layers.layers_Y with
# a non-axis-0 scan dim. Count total groups, then map layer_idx to
# (src_group_idx, local_idx) based on scan_group_ordering.
prefix = key_tuple[:match_index]
suffix = key_tuple[match_index + 1:]

# Count available scan groups.
n_groups = 0
while prefix + ('layers', f'layers_{n_groups}') + suffix in src_flat:
  n_groups += 1

if n_groups > 0:
  first_val = src_flat[prefix + ('layers', 'layers_0') + suffix]
  if hasattr(first_val, 'shape') and hasattr(tgt_val, 'shape'):
    if len(first_val.shape) == len(tgt_val.shape) + 1:
      for axis in range(len(first_val.shape)):
        remaining = first_val.shape[:axis] + first_val.shape[axis + 1:]
        # Check if removing this axis matches the target shape.
        if remaining == tgt_val.shape:
          n_scan_steps = first_val.shape[axis]
          if scan_group_ordering == 'interleaved':
            # Interleaved: layer_idx 0, n_groups, 2*n_groups, ... map to
            # group 0; layer_idx 1, n_groups+1, ... map to group 1, etc.
            src_group_idx = layer_idx % n_groups
            local_idx = layer_idx // n_groups
          else:
            # Sequential: layers 0..N-1 in group 0, N..2N-1 in group 1, etc.
            src_group_idx = layer_idx // n_scan_steps
            local_idx = layer_idx % n_scan_steps
          # Only slice if the computed group and local indices are in bounds.
          if src_group_idx < n_groups and local_idx < n_scan_steps:
            candidate_c = prefix + ('layers', f'layers_{src_group_idx}') + suffix
            src_val = src_flat[candidate_c]
            sliced_val = _slice_scanned_param(
                src_val, tgt_val, local_idx, str(key_tuple)
            )
            sliced_val = _apply_dtype_cast(
                sliced_val, tgt_val.dtype, str(key_tuple)
            )
            filtered_src_flat[key_tuple] = sliced_val
            filtered_tgt_flat[key_tuple] = tgt_val
          break


Severity: medium

This new logic block for "Candidate C" is quite long and complex, with multiple levels of nesting. According to the repository style guide, code should be readable and maintainable (lines 11-12).

To improve readability and maintainability, I suggest refactoring this block into a separate helper function. This would encapsulate the logic for handling the GPT-OSS scanned format, making the intersect_trees function easier to follow.

A helper function could look something like this:

def _map_gpt_oss_scanned_layer(
    key_tuple: tuple[str, ...],
    tgt_val: Any,
    src_flat: dict[tuple[str, ...], Any],
    layer_idx: int,
    match_index: int,
    scan_group_ordering: str,
) -> tuple[Any | None, Any | None]:
    """Attempts to map a target layer to a source scanned GPT-OSS-style layer."""
    # ... implementation from lines 978-1016 ...
    # On success, return (sliced_val, tgt_val)
    # On failure, return (None, None)

Then, in intersect_trees, you could simplify the main loop:

        # ... after checking candidates A and B ...
        sliced_val, mapped_tgt_val = _map_gpt_oss_scanned_layer(...)
        if sliced_val is not None:
            filtered_src_flat[key_tuple] = sliced_val
            filtered_tgt_flat[key_tuple] = mapped_tgt_val
            continue

This change would make the control flow of intersect_trees much clearer.

Additionally, when refactoring, you could clarify the inline comments (lines 995-996 and 1000-1001) to be more descriptive, similar to the excellent comment added in tunix/rl/rollout/base_rollout.py (lines 161-166).

References
  1. Code should be easy to understand for all maintainers and users. Code should be easy to modify and extend. (link)

