feat: use ray.put to reduce overhead of repeated arguments #1944
guyueh1 wants to merge 1 commit into NVIDIA-NeMo:main
Conversation
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
📝 Walkthrough

Adds optional conversion of worker arguments to Ray remote object references when the `NRL_WG_USE_RAY_REF` environment variable is enabled.
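The idea in the walkthrough can be sketched as follows. This is a hedged illustration, not the NeMo-RL code: `put` and `object_store` are stand-ins that mimic `ray.put` so the snippet runs without Ray installed; the real change calls `ray.put` directly.

```python
# Stand-in object store mimicking ray.put: store a value once, hand out a
# small reference that can be shipped to many workers cheaply.
object_store = {}

def put(value):
    ref = f"ref-{id(value)}"
    object_store[ref] = value
    return ref

def convert_kwargs(kwargs, use_ray_ref):
    # When the toggle is off, kwargs pass through unchanged.
    if not use_ray_ref:
        return kwargs
    # When on, each value is stored once and replaced by its reference.
    return {key: put(value) for key, value in kwargs.items()}

big_payload = list(range(1000))
converted = convert_kwargs({"data": big_payload}, use_ray_ref=True)
# Every worker now receives the same small reference instead of a fresh copy.
```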
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@nemo_rl/distributed/worker_groups.py`:
- Around line 865-869: The loop is wrong: it reassigns _args each iteration and
then appends the list to itself, creating a circular structure; fix by using a
new outer list (e.g., new_args = []) and for each element `arg` create an inner
list `inner = [ray.put(a) for a in arg]`, then `new_args.append(inner)`, and
after the loop set `args = tuple(new_args)` (preserving the original intent of
converting per-arg lists to a tuple of lists); ensure you do not reuse `_args`
as both the accumulator and the inner list and keep using `ray.put` for each
item.
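The corrected accumulation the agent prompt describes can be sketched as below. `put` here is a hypothetical stand-in for `ray.put` so the snippet is self-contained and runnable without Ray.

```python
# Stand-in for ray.put: wrap the value so we can see it was "converted".
def put(x):
    return ("ref", x)

args = ([1, 2], [3, 4])

new_args = []                       # fresh outer accumulator, never reassigned
for arg in args:
    inner = [put(a) for a in arg]   # convert each item into the inner list
    new_args.append(inner)          # append the inner list, not the accumulator
args = tuple(new_args)              # tuple of per-arg lists, as intended
```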
🧹 Nitpick comments (3)
nemo_rl/distributed/worker_groups.py (3)
779-781: `args` conversion on line 780 is dead code — the assert on line 773 guarantees `args` is empty.

The `ray.put` conversion for `args` will never execute on any actual data because the assertion at line 773 enforces `len(args) == 0`. The `kwargs` conversion is the only useful part here. This isn't harmful but is misleading.

Suggested fix:

```diff
 if os.getenv("NRL_WG_USE_RAY_REF", "0") == "1":
-    args = [ray.put(arg) for arg in args]
     kwargs = {key: ray.put(value) for key, value in kwargs.items()}
```
860-874: The `ray.put` optimization only applies when `replicate_degrees > 1`, but sharded axes also send the same slice to replicated workers — consider whether the optimization should apply more broadly.

Currently, the `ray.put` conversion in `run_all_workers_sharded_data` is gated on `replicate_degrees > 1`. When `replicate_degrees == 1` (no replication axes), no `ray.put` is performed on `kwargs` even though the same data slice may still be dispatched to multiple workers (e.g., via the `in_sharded_axes` logic at lines 923-930). Consider whether `ray.put` should also apply to `common_kwargs` and to the kwargs/args after slicing, not just before.

Also, unlike `run_all_workers_single_data`, `common_kwargs` is never converted via `ray.put` in either path, which means the same common data is re-serialized per worker call.
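One way to address the `common_kwargs` point above could look like the sketch below. The function name is hypothetical, and `ray.put` is stubbed as `put` so the snippet runs standalone.

```python
# Stand-in for ray.put.
def put(value):
    return ("ref", value)

def prepare_common_kwargs(common_kwargs, use_ray_ref):
    # Hypothetical helper: convert common kwargs once, before dispatch, so
    # every worker call reuses the same stored objects instead of
    # re-serializing them per call.
    if not use_ray_ref:
        return common_kwargs
    return {key: put(value) for key, value in common_kwargs.items()}

common = prepare_common_kwargs({"config": {"lr": 1e-4}}, use_ray_ref=True)
```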
779-781: Environment-variable-driven feature toggle is fragile for production use.

Using `os.getenv("NRL_WG_USE_RAY_REF", "0") == "1"` as a feature gate is acceptable for a POC, but for a production path consider promoting this to a constructor parameter or a configuration option so it is explicit, testable, and discoverable. The env var name is also not documented anywhere.

Also applies to: 864-864
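One possible shape for the constructor-parameter suggestion above; the class and parameter names here are hypothetical, with the env var kept as a fallback.

```python
import os

class WorkerGroup:
    """Hypothetical sketch: explicit flag wins, env var remains a fallback."""

    def __init__(self, use_ray_ref=None):
        if use_ray_ref is None:
            # Fall back to the existing env-var toggle for compatibility.
            use_ray_ref = os.getenv("NRL_WG_USE_RAY_REF", "0") == "1"
        self.use_ray_ref = use_ray_ref

group = WorkerGroup(use_ray_ref=True)
```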
```python
_args = []
for arg in args:
    _args = [ray.put(a) for a in arg]
    _args.append(_args)
args = tuple(_args)
```
Critical bug: _args is reassigned each iteration and then appended to itself, creating a circular reference.
There are two defects in this loop:
- Line 867 reassigns `_args` on every iteration instead of using a separate inner variable, so all previous iterations' results are discarded.
- Line 868: `_args.append(_args)` appends the list to itself, creating an infinite recursive data structure instead of appending the inner list.
Additionally (same as the other call site), args is always empty due to the assert on line 840, so this entire block is dead code today — but if args support is ever re-enabled, this will corrupt data.
Proposed fix (assuming `args` could be non-empty in the future):

```diff
 if replicate_degrees > 1:
     if os.getenv("NRL_WG_USE_RAY_REF", "0") == "1":
         _args = []
         for arg in args:
-            _args = [ray.put(a) for a in arg]
-            _args.append(_args)
+            _args.append([ray.put(a) for a in arg])
         args = tuple(_args)
         _kwargs = dict()
         for key, value in kwargs.items():
```
What does this PR do?
In worker groups, we sometimes pass the same argument to multiple workers. Use `ray.put` to avoid repeated serialization overhead when the argument is large.
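A back-of-envelope illustration of the saving the PR targets (stand-in code, not NeMo-RL): without `ray.put` every worker call serializes the payload again, whereas with it the payload is serialized once and each call ships only a tiny reference.

```python
import pickle

payload = list(range(100_000))
n_workers = 8

per_call = len(pickle.dumps(payload))

# Naive dispatch: the payload is serialized once per worker.
naive_total = n_workers * per_call

# ray.put-style dispatch: payload serialized once, plus one small
# reference string per worker (reference format is illustrative).
ref_total = per_call + n_workers * len(pickle.dumps("ref-0000000001"))
```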
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit