
fix(rl): flush final partial micro-batch to avoid sample drop #1179

Open
pjo256 wants to merge 2 commits into google:main from pjo256:rl-microbatch-tail-flush

Conversation


@pjo256 pjo256 commented Mar 1, 2026

Changes

This PR fixes a data-loss bug in RL micro-batching: _create_micro_batch_iterator only emitted full micro-batches (e.g. rl_learner.py:503) and silently dropped the final remainder of fewer than micro_batch_size samples.
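The fix can be sketched as follows. This is a minimal, self-contained approximation over plain lists; the actual Tunix iterator operates on the learners' sample structures, and the function name and signature here are simplified for illustration:

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def create_micro_batch_iterator(
    samples: Iterable[T], micro_batch_size: int
) -> Iterator[List[T]]:
  """Yields micro-batches, including a final partial batch if one remains."""
  buffer: List[T] = []
  for sample in samples:
    buffer.append(sample)
    if len(buffer) == micro_batch_size:
      yield buffer
      buffer = []
  # Before this fix, iteration ended here and any tail samples in `buffer`
  # (fewer than micro_batch_size) were silently dropped.
  if buffer:
    yield buffer  # flush the final partial micro-batch
```

For example, 7 samples with micro_batch_size=3 now yield batches of sizes 3, 3, and 1 instead of dropping the last sample.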

Added unit tests covering final partial micro-batch handling; these fail on main.

Note: _create_micro_batch_iterator seems to be mostly duplicated in RL and agentic learners. I considered extracting a shared helper, but kept this PR small as a first contribution.

Validation:

  • python -m pytest tests/rl/rl_learner_test.py -q
  • python -m pytest tests/rl/experimental/agentic_grpo_learner_test.py -q

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

Signed-off-by: Philip Ottesen <phiott256@gmail.com>
@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical data integrity issue within the reinforcement learning micro-batching mechanism. By ensuring that all samples, including those forming incomplete final batches, are properly processed, it prevents potential data loss and improves the robustness of the learning algorithms. The changes enhance the reliability of training by guaranteeing that no valuable data points are inadvertently discarded.

Highlights

  • Bug Fix: Resolved a data-loss bug in the RL micro-batching process where final partial micro-batches were silently dropped, leading to incomplete sample processing.
  • Micro-batch Iterator Enhancement: Modified the _create_micro_batch_iterator function in both rl_learner.py and agentic_rl_learner.py to ensure that any remaining samples in the buffer are yielded as a final partial micro-batch.
  • New Unit Tests: Added comprehensive unit tests in rl_learner_test.py and agentic_grpo_learner_test.py to specifically cover the handling of final partial micro-batches, confirming the fix and preventing regressions.


Changelog
  • tests/rl/experimental/agentic_grpo_learner_test.py
    • Added a new test case test_create_micro_batch_iterator_preserves_tail_samples to verify that the micro-batch iterator correctly handles and yields partial batches at the end of an iteration.
  • tests/rl/rl_learner_test.py
    • Imported the numpy library.
    • Added a new parameterized test test_create_micro_batch_iterator to validate the micro-batch iterator's behavior with both exact and partial batch scenarios, ensuring all samples are processed.
  • tunix/rl/experimental/agentic_rl_learner.py
    • Modified the _create_micro_batch_iterator function to include a check for any remaining samples in the buffer after the main loop, ensuring they are yielded as a final partial micro-batch to prevent data loss.
  • tunix/rl/rl_learner.py
    • Modified the _create_micro_batch_iterator function to include a check for any remaining samples in the buffer after the main loop, ensuring they are yielded as a final partial micro-batch to prevent data loss.

@gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses a data loss bug in the micro-batching logic by ensuring the final partial batch is processed. The fix is correct and has been applied to both RLLearner and AgenticRLLearner. I appreciate that you've added comprehensive unit tests for both modules, which clearly demonstrate the issue and verify the fix. Your note about the code duplication in _create_micro_batch_iterator is well-taken; deferring the refactor to a separate change is a reasonable approach. I've added a couple of minor suggestions to improve the readability of the new tests.

Signed-off-by: Philip Ottesen <phiott256@gmail.com>