
feat: use ray.put to reduce overhead of repeated arguments #1944

Open
guyueh1 wants to merge 1 commit into NVIDIA-NeMo:main from guyueh1:ray_put

Conversation

Contributor

@guyueh1 guyueh1 commented Feb 14, 2026

What does this PR do ?

In a worker group, we sometimes pass the same argument to multiple workers. Using ray.put places the argument in the Ray object store once, avoiding the overhead of serializing it separately for each worker when the argument is large.
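As a general illustration (not code from this PR), the pattern looks like this; a minimal sketch with a toy Ray task:

import ray
import numpy as np

ray.init()

@ray.remote
def work(data):
    return data.sum()

big = np.ones((10_000, 1_000))  # a large argument reused across workers

# Naive dispatch: `big` is serialized once per remote call.
naive = ray.get([work.remote(big) for _ in range(8)])

# With ray.put: `big` enters the object store once; each call ships only
# a small ObjectRef, which Ray dereferences before invoking the task.
big_ref = ray.put(big)
shared = ray.get([work.remote(big_ref) for _ in range(8)])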

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
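A plausible usage sketch: per the walkthrough, the feature is gated by the NRL_WG_USE_RAY_REF environment variable, so opting in from Python could look like the following (the surrounding training-launch code is assumed, not shown in this PR):

import os

# Must be set before the worker group dispatches work so that
# run_all_workers_single_data / run_all_workers_sharded_data see it.
os.environ["NRL_WG_USE_RAY_REF"] = "1"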

Before your PR is "Ready for review"

Pre checks:

  • Make sure you have read and followed the Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • Chores
    • Added a conditional optimization to distributed worker dispatch that reduces argument-serialization overhead in remote execution scenarios.

Signed-off-by: Guyue Huang <guyueh@nvidia.com>
@guyueh1 guyueh1 requested a review from a team as a code owner February 14, 2026 06:36
Contributor

coderabbitai bot commented Feb 14, 2026

📝 Walkthrough

Adds optional Ray remote object reference conversion for worker arguments when the NRL_WG_USE_RAY_REF environment variable is enabled. Both single-data and sharded-data execution paths convert positional and keyword arguments via ray.put before worker dispatch. The sharded-data path additionally implements replication-aware preprocessing based on replicate_degrees.
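Concretely, the gate described above matches the pattern visible in the review diffs further down; a minimal sketch as a standalone helper (the helper name is illustrative, not the actual function layout):

import os
import ray

def maybe_to_refs(args, kwargs):
    # When the env-var gate is on, replace each value with a Ray
    # ObjectRef so the object store holds one copy shared by all workers.
    if os.getenv("NRL_WG_USE_RAY_REF", "0") == "1":
        args = [ray.put(arg) for arg in args]
        kwargs = {key: ray.put(value) for key, value in kwargs.items()}
    return args, kwargs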

Changes

Cohort / File(s): Ray Reference Conversion (nemo_rl/distributed/worker_groups.py)

Summary: Adds conditional Ray object reference conversion in run_all_workers_single_data and run_all_workers_sharded_data. When the NRL_WG_USE_RAY_REF environment variable is set, positional and keyword arguments are converted to Ray references via ray.put before worker dispatch. The sharded-data path additionally implements replication-aware preprocessing that computes replicate_degrees and applies conversions only when replication is greater than 1.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3 passed | ❌ 2 warnings

❌ Failed checks (2 warnings)

Merge Conflict Detection: ⚠️ Warning. Merge conflicts detected (14 files):

⚔️ docker/Dockerfile (content)
⚔️ docs/guides/use-custom-vllm.md (content)
⚔️ examples/configs/grpo_math_1B.yaml (content)
⚔️ examples/configs/vlm_grpo_3B.yaml (content)
⚔️ examples/configs/vlm_grpo_3B_megatron.yaml (content)
⚔️ examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml (content)
⚔️ nemo_rl/algorithms/grpo.py (content)
⚔️ nemo_rl/distributed/worker_groups.py (content)
⚔️ nemo_rl/environments/nemo_gym.py (content)
⚔️ nemo_rl/models/generation/vllm/vllm_worker.py (content)
⚔️ tests/functional/grpo_non_colocated.sh (content)
⚔️ tests/unit/algorithms/test_grpo.py (content)
⚔️ tests/unit/environments/test_nemo_gym.py (content)
⚔️ tools/build-custom-vllm.sh (content)

These conflicts must be resolved before merging into main.
Resolve conflicts locally and push changes to this branch.
Test Results For Major Changes: ⚠️ Warning. The PR introduces a performance optimization using ray.put() but lacks completed tests, documentation, and performance benchmarks, and contains a critical bug requiring a fix. Resolution: add tests validating the ray.put behavior, provide before-and-after performance benchmarks, fix the critical bug at lines 865-869, and finish the placeholder documentation items.
✅ Passed checks (3 passed)
Description Check: ✅ Passed. Check skipped: CodeRabbit’s high-level summary is enabled.
Title Check: ✅ Passed. The title accurately describes the main change: using ray.put to reduce serialization overhead of repeated arguments passed to multiple workers.
Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which meets the required threshold of 80.00%.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@nemo_rl/distributed/worker_groups.py`:
- Around line 865-869: The loop is wrong: it reassigns _args each iteration and
then appends the list to itself, creating a circular structure; fix by using a
new outer list (e.g., new_args = []) and for each element `arg` create an inner
list `inner = [ray.put(a) for a in arg]`, then `new_args.append(inner)`, and
after the loop set `args = tuple(new_args)` (preserving the original intent of
converting per-arg lists to a tuple of lists); ensure you do not reuse `_args`
as both the accumulator and the inner list and keep using `ray.put` for each
item.
🧹 Nitpick comments (3)
nemo_rl/distributed/worker_groups.py (3)

779-781: args conversion on line 780 is dead code — the assert on line 773 guarantees args is empty.

The ray.put conversion for args will never execute on any actual data because the assertion at line 773 enforces len(args) == 0. The kwargs conversion is the only useful part here. This isn't harmful but is misleading.

Suggested fix
         if os.getenv("NRL_WG_USE_RAY_REF", "0") == "1":
-            args = [ray.put(arg) for arg in args]
             kwargs = {key: ray.put(value) for key, value in kwargs.items()}

860-874: The ray.put optimization only applies when replicate_degrees > 1, but sharded axes also send the same slice to replicated workers — consider whether the optimization should apply more broadly.

Currently, the ray.put conversion in run_all_workers_sharded_data is gated on replicate_degrees > 1. When replicate_degrees == 1 (no replication axes), no ray.put is performed on kwargs even though the same data slice may still be dispatched to multiple workers (e.g., via in_sharded_axes logic at lines 923-930). Consider whether ray.put should also apply to common_kwargs and to the kwargs/args after slicing, not just before.

Also, unlike run_all_workers_single_data, common_kwargs is never converted via ray.put in either path, which means the same common data is re-serialized per worker call.
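A sketch of the broader conversion this comment suggests, converting common_kwargs a single time before the per-worker loop (the helper and its call site are assumptions, not the current source):

import os
import ray

def put_common_kwargs(common_kwargs):
    # Convert shared kwargs to ObjectRefs once, so every worker call
    # ships a small reference instead of re-serializing the same data.
    if os.getenv("NRL_WG_USE_RAY_REF", "0") == "1":
        return {k: ray.put(v) for k, v in common_kwargs.items()}
    return common_kwargs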


779-781: Environment-variable–driven feature toggle is fragile for production use.

Using os.getenv("NRL_WG_USE_RAY_REF", "0") == "1" as a feature gate is acceptable for a POC, but for a production path consider promoting this to a constructor parameter or a configuration option so it is explicit, testable, and discoverable. The env var name is also not documented anywhere.

Also applies to: 864-864
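One way to make the toggle explicit, per this suggestion: a hypothetical constructor parameter with the env var kept as a fallback (the parameter name and class shell are illustrative, not the project's API):

import os

class RayWorkerGroup:  # illustrative shell, not the real class
    def __init__(self, *, use_ray_refs=None):
        # An explicit argument wins; the env var remains a fallback so
        # existing launch scripts keep working unchanged.
        if use_ray_refs is None:
            use_ray_refs = os.getenv("NRL_WG_USE_RAY_REF", "0") == "1"
        self.use_ray_refs = use_ray_refs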

Comment on lines +865 to +869
_args = []
for arg in args:
    _args = [ray.put(a) for a in arg]
    _args.append(_args)
args = tuple(_args)

⚠️ Potential issue | 🔴 Critical

Critical bug: _args is reassigned each iteration and then appended to itself, creating a circular reference.

There are two defects in this loop:

  1. Line 867 reassigns _args on every iteration instead of using a separate inner variable, so all previous iterations' results are discarded.
  2. Line 868 _args.append(_args) appends the list to itself, creating an infinite recursive data structure instead of appending the inner list.

Additionally (same as the other call site), args is always empty due to the assert on line 840, so this entire block is dead code today — but if args support is ever re-enabled, this will corrupt data.
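The self-append defect is easy to reproduce standalone (ray.put swapped for a plain copy so the snippet runs without Ray):

_args = []
for arg in [[1, 2], [3, 4]]:
    _args = [a for a in arg]  # reassignment discards earlier results
    _args.append(_args)       # the list now contains itself
print(_args[-1] is _args)     # True: a circular reference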

Proposed fix (assuming args could be non-empty in the future)
         if replicate_degrees > 1:
             if os.getenv("NRL_WG_USE_RAY_REF", "0") == "1":
                 _args = []
                 for arg in args:
-                    _args = [ray.put(a) for a in arg]
-                    _args.append(_args)
+                    _args.append([ray.put(a) for a in arg])
                 args = tuple(_args)
                 _kwargs = dict()
                 for key, value in kwargs.items():

