Conversation
Signed-off-by: Zhanda <zhandazhu@gmail.com>
ℹ️ File Consistency Check (commit d2eac7a, PR #1938) — ✅ DTensor Policy Worker Synchronization Check: Both DTensor policy worker files were modified in this PR. Please ensure that the changes are consistent between both files where applicable. This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
Force-pushed from d2eac7a to ac0897f (Compare)
ℹ️ File Consistency Check (commit ac0897f): same DTensor policy worker synchronization notice as above.
📝 Walkthrough

This PR introduces sampling-aware log-probability computation across the training pipeline, enabling top-k/top-p filtering with gradient support through new autograd functions. It threads a `sampling_params` argument through the policy workers, distributed utilities, and loss functions.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer
    participant PolicyWorker
    participant LossFunction
    participant DistributedUtils
    participant Sampling

    Trainer->>PolicyWorker: train(data, sampling_params)
    PolicyWorker->>Sampling: Initialize sampling_params from config
    PolicyWorker->>DistributedUtils: forward_step with sampling_params
    DistributedUtils->>DistributedUtils: Compute logits from model
    DistributedUtils->>Sampling: apply_top_k_top_p if filtering enabled
    Sampling->>Sampling: Filter logits, create keep_mask
    DistributedUtils->>DistributedUtils: compute_logprobs_from_logits
    DistributedUtils->>LossFunction: Pass logits, logprobs, sampling_params
    LossFunction->>LossFunction: Apply masking for invalid positions
    LossFunction->>LossFunction: Use unfiltered logprobs for KL calculations
    LossFunction->>LossFunction: Use filtered logprobs for actor loss
    LossFunction-->>Trainer: loss, metrics
    Trainer->>Trainer: Backward pass through sampling autograd functions
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 2 failed (2 warnings)
ℹ️ File Consistency Check (commit 1d5f7c0): same DTensor policy worker synchronization notice as above.
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (7)
tests/unit/test_utils.py (2)
26-38: ⚠️ Potential issue | 🟡 Minor: Silence unused `sampling_params` to satisfy Ruff.

Ruff flags this as unused (ARG002). Add an explicit `del` to keep the signature but avoid lint noise.

💡 Suggested fix
```diff
 ) -> tuple[torch.Tensor, dict[str, Any]]:
+    del sampling_params  # unused test hook; keeps signature aligned
     # Just return mean of logprobs as the loss for testing
     loss = next_token_logits.mean()
```
51-62: ⚠️ Potential issue | 🟡 Minor: Silence unused `sampling_params` to satisfy Ruff.

Same ARG002 warning here; explicitly delete the argument.

💡 Suggested fix
```diff
 ) -> tuple[torch.Tensor, dict[str, Any]]:
+    del sampling_params  # unused test hook; keeps signature aligned
     # logits shape: [batch_size, seq_len, vocab_size]
```

nemo_rl/distributed/model_utils.py (1)
1-1: ⚠️ Potential issue | 🟡 Minor: Update the NVIDIA copyright year to 2026.

As per coding guidelines, "{src/,examples/,nemo_rl/**}/*.{py,sh}: Add the NVIDIA copyright header (with current year) to all Python files and shell scripts, excluding tests (files under `tests/` or test-only scripts)".

nemo_rl/algorithms/loss_functions.py (2)
1-1: ⚠️ Potential issue | 🟡 Minor: Update the NVIDIA copyright year to 2026.

As per coding guidelines, "{src/,examples/,nemo_rl/**}/*.{py,sh}: Add the NVIDIA copyright header (with current year) to all Python files and shell scripts, excluding tests (files under `tests/` or test-only scripts)".
391-399: ⚠️ Potential issue | 🟠 Major: Use the updated `mask` for sequence-level ratio averages.

After masking out `-inf` positions, `token_mask` still includes those tokens. This can skew sequence-level ratios under top-k/top-p mismatch scenarios. Use the already-updated `mask` instead.

🛠️ Suggested fix
```diff
-        seq_log_ratio_mean = masked_mean(
-            log_ratios,
-            token_mask,
-            dim=-1,
-        ).unsqueeze(-1)
+        seq_log_ratio_mean = masked_mean(
+            log_ratios,
+            mask,
+            dim=-1,
+        ).unsqueeze(-1)
```

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
1-1: ⚠️ Potential issue | 🟡 Minor: Update the NVIDIA copyright year to 2026.

As per coding guidelines, "{src/,examples/,nemo_rl/**}/*.{py,sh}: Add the NVIDIA copyright header (with current year) to all Python files and shell scripts, excluding tests (files under `tests/` or test-only scripts)".

nemo_rl/models/policy/workers/megatron_policy_worker.py (1)
722-742: ⚠️ Potential issue | 🟠 Major: `saved_sampling_params` may be unbound in the `finally` block if an earlier line throws.

If `self.model.load_state_dict(...)` on Line 717 raises before `saved_sampling_params` is assigned on Line 722, the `finally` block on Line 742 will hit `UnboundLocalError`, masking the original exception.

Move the assignment before any code that can fail inside the `try`:

Proposed fix
```diff
     with torch.no_grad():
         try:
+            saved_sampling_params = self.sampling_params
+
             # Save original references
             model_state_dict = {}
             for name, item in self.model.state_dict().items():
@@ -719,8 +721,7 @@
             # self.model.state_dict()[name] = item.detach().to(device="cuda", non_blocking=True, copy=True)

-            saved_sampling_params = self.sampling_params
-            if saved_sampling_params is not None:
+            if self.sampling_params is not None:
                 self.sampling_params = TrainingSamplingParams(
                     top_k=None,
                     top_p=1.0,
```
🤖 Fix all issues with AI agents
In `@nemo_rl/distributed/model_utils.py`:
- Around line 1544-1559: The code moves input_ids to GPU with a hard-coded
.cuda() which breaks CPU or non-default-GPU runs; change the conversion of
input_ids[:, 1:] (used to produce next_tokens) to use the same device as the
logits tensor (e.g., next_token_logits_wo_last.device or
next_token_logprobs.device) instead of .cuda() so next_tokens and subsequent
gather operate on the correct device; update the expression that defines
next_tokens to call .to(<logits_device>) so device alignment is ensured for
next_token_logits_wo_last, next_token_logprobs, and token_logprobs.
In `@nemo_rl/models/generation/vllm/vllm_generation.py`:
- Around line 96-99: The validation for the top_p parameter in
vllm_generation.py currently only rejects top_p <= 0 but allows values > 1.0,
contradicting the documented valid range (0, 1]; update the check that inspects
the top_p variable so it raises a ValueError when top_p <= 0 or top_p > 1.0, and
keep the existing error message (or adjust it to mention the closed upper bound)
to reflect the contract; locate the validation near the top_p check in the VLLM
generation logic and modify that conditional to enforce both bounds.
In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py`:
- Around line 339-355: Add Google-style docstrings to the two new helper methods
_apply_temperature_scaling and _apply_top_k_top_p_filtering: for each function
provide a one-line summary, a short description if needed, an Args section
describing parameters (e.g., logits: torch.Tensor and mention
self.sampling_params where relevant), a Returns section describing the returned
torch.Tensor, and any Notes or Raises if applicable (e.g., side effects from
in-place div_ in _apply_temperature_scaling). Place the docstrings immediately
under each def to match project Sphinx parsing and coding guidelines.
- Around line 245-251: The code is supplying non-None defaults for generation
config values; remove those defaults and read required keys directly from the
config so YAML owns defaults: when "generation" exists in self.cfg, assign
generation_cfg = self.cfg["generation"] and construct TrainingSamplingParams
using the raw values (e.g. generation_cfg["top_k"], generation_cfg["top_p"],
generation_cfg["temperature"]) instead of .get(..., default), and set
self.sampling_params accordingly; ensure any callers expect these keys to be
present (YAML should provide defaults) and do not introduce new in-code fallback
values.
- Around line 831-847: The finally block can raise UnboundLocalError if an
exception occurs before saved_sampling_params is set; initialize
saved_sampling_params = None before the try that modifies self.sampling_params
so the finally always has a defined variable. Update the context around the
sampling_params manipulation (reference symbols: saved_sampling_params,
self.sampling_params, TrainingSamplingParams) to set saved_sampling_params =
None immediately before entering the try, then proceed with the existing logic
and restoration in the finally.
In `@nemo_rl/models/policy/workers/dtensor_policy_worker.py`:
- Around line 179-187: The generation config access uses self.cfg before it's
assigned, causing AttributeError in DTensorPolicyWorker __init__; update the
block that constructs self.sampling_params to read from the local config
parameter (use config or config["generation"]) or move this entire
generation-handling block to after the assignment to self.cfg so
TrainingSamplingParams (top_k/top_p/temperature) is created from the
already-initialized config; target the DTensorPolicyWorker constructor and the
sampling_params/TrainingSamplingParams logic to fix the reference.
- Around line 1663-1684: The try/finally can raise UnboundLocalError because
saved_sampling_params is assigned after operations; move the
saved_sampling_params = self.sampling_params assignment to the very start of the
try block (before any state-dict or model manipulation) so it is always defined
for the finally, then use that saved value when setting self.sampling_params to
the temporary TrainingSamplingParams (or None) and restore it in the finally;
reference the variables/methods: self.sampling_params, saved_sampling_params,
TrainingSamplingParams, and the surrounding try/finally in
dtensor_policy_worker.py.
In `@tests/unit/distributed/test_model_utils.py`:
- Around line 1157-1177: Rename the global constant
SAMPLING_PARAMS_TEST_ACTOR_FQN to follow the G_ prefix convention (e.g.,
G_SAMPLING_PARAMS_TEST_ACTOR_FQN) and update all references in this test block:
the fixture register_sampling_params_test_actor, the ACTOR_ENVIRONMENT_REGISTRY
key accesses, and any use with PY_EXECUTABLES.SYSTEM; ensure the
original_registry_value capture and cleanup logic remain unchanged but use the
new constant name everywhere.
🧹 Nitpick comments (3)
nemo_rl/models/policy/utils.py (1)
170-269: Verify `logits_sort.scatter` with self as source behaves correctly.

Line 254: `chunk_filtered = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)` uses `logits_sort` as both the input tensor and the `src`. This works because `scatter` creates a new output tensor (non-in-place variant), but it looks unusual. A brief inline comment clarifying that this "unscrambles" sorted values back to original vocab order via the argsort indices would improve readability.
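For illustration, the out-of-place scatter round-trip in isolation (a standalone toy, not the repo's `apply_top_k_top_p`):

```python
import torch

logits = torch.tensor([[2.0, 0.5, 3.0, 1.0]])
logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
logits_sort[..., 2:] = float("-inf")  # toy top-2 filtering in sorted order

# Out-of-place scatter: clones the base, then writes the sorted values back to
# their original vocab positions via the argsort indices ("unscrambling").
unscrambled = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
print(unscrambled)  # tensor([[2., -inf, 3., -inf]])
```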
nemo_rl/models/policy/workers/megatron_policy_worker.py (1)

233-242: Non-None fallback defaults for config values.

Lines 237–239 supply `top_p=1.0` and `temperature=1.0` via `.get()` when the keys are absent from the generation config. Per coding guidelines, YAML should be the single source of truth for configuration defaults. These happen to be the "disabled" sentinel values, so the risk is low, but consider ensuring the exemplar YAML configs always define `top_k`, `top_p`, and `temperature` explicitly so the `.get()` fallbacks are never actually needed.

As per coding guidelines: "YAML is the single source of truth for configuration defaults; do not set non-None defaults in code for configuration values."
nemo_rl/models/policy/workers/dtensor_policy_worker.py (1)
493-509: New helper methods are correct and consistent with `dtensor_policy_worker_v2.py`.

`_apply_temperature_scaling` and `_apply_top_k_top_p_filtering` match the pattern in the v2 worker. The filtering helper correctly guards with `need_top_k_filtering`/`need_top_p_filtering` checks before calling `apply_top_k_top_p`. Note that `_apply_top_k_top_p_filtering` uses separate `need_top_k_filtering` and `need_top_p_filtering` checks while v2 uses `need_top_k_or_top_p_filtering`: functionally equivalent but slightly inconsistent.

Align with v2 worker's combined check
```diff
+from nemo_rl.models.policy.utils import (
+    TrainingSamplingParams,
+    apply_top_k_top_p,
+    configure_dynamo_cache,
+    get_runtime_env_for_policy_worker,
+    need_top_k_or_top_p_filtering,
+    resolve_model_class,
+)
 ...
     def _apply_top_k_top_p_filtering(self, logits: torch.Tensor) -> torch.Tensor:
         """Apply top-k and top-p filtering to the logits locally when TP is disabled."""
-        if self.sampling_params is not None and (
-            need_top_k_filtering(self.sampling_params.top_k)
-            or need_top_p_filtering(self.sampling_params.top_p)
-        ):
+        if self.sampling_params is not None and need_top_k_or_top_p_filtering(
+            self.sampling_params.top_k, self.sampling_params.top_p
+        ):
             logits, _ = apply_top_k_top_p(
```
```python
    next_token_logits_wo_last = next_token_logits[
        :, :-1
    ]  # Remove last position's logits
    # Apply top-k and top-p filtering
    next_token_logits_wo_last, _ = apply_top_k_top_p(
        next_token_logits_wo_last,
        top_k=sampling_params.top_k if sampling_params is not None else None,
        top_p=sampling_params.top_p if sampling_params is not None else 1.0,
    )
    next_token_logprobs = torch.nn.functional.log_softmax(
        next_token_logits_wo_last, dim=-1
    )
    next_tokens = input_ids[:, 1:].cuda()  # Skip first token
    token_logprobs = next_token_logprobs.gather(
        dim=-1, index=next_tokens.unsqueeze(-1)
    ).squeeze(-1)
```
🧩 Analysis chain

🏁 Scripts executed:

```shell
wc -l nemo_rl/distributed/model_utils.py

sed -n '1530,1570p' nemo_rl/distributed/model_utils.py

# Get more context around the function to understand device handling
sed -n '1500,1560p' nemo_rl/distributed/model_utils.py

# Find the function definition that contains these lines
sed -n '1400,1560p' nemo_rl/distributed/model_utils.py | head -100

# Search for function definition around line 1544
ast-grep --pattern 'def $FUNC($_) { $$$ }'

# Check the function signature and initial device handling
sed -n '1478,1520p' nemo_rl/distributed/model_utils.py

# Search for other .cuda() usages in this file to see if there's a pattern
rg '\.cuda\(\)' nemo_rl/distributed/model_utils.py -n

# Check if input_ids device is explicitly managed elsewhere in the file
rg 'input_ids.*\.to\(' nemo_rl/distributed/model_utils.py -n -A 1

# Check where compute_logprobs_from_logits is called to understand device assumptions
rg 'compute_logprobs_from_logits' nemo_rl/ -A 2 -B 2

# Check if input_ids is explicitly moved elsewhere before being passed to this function
rg 'compute_logprobs_from_logits' nemo_rl/ --type py -l | head -5 | xargs grep -l 'input_ids'

# Check the calls to compute_logprobs_from_logits in loss_functions.py
sed -n '1,100p' nemo_rl/algorithms/loss_functions.py | head -40

# Find where compute_logprobs_from_logits is called and check device context
rg -B 5 'compute_logprobs_from_logits' nemo_rl/algorithms/loss_functions.py | head -60

# Check what device next_token_logits is on when passed to this function
rg -B 10 'compute_logprobs_from_logits' nemo_rl/algorithms/loss_functions.py | grep -E '(device|cuda|\.to\(|next_token_logits)' | head -20

# Get context around the first call to compute_logprobs_from_logits in loss_functions.py
rg -B 20 'curr_logprobs = compute_logprobs_from_logits' nemo_rl/algorithms/loss_functions.py | head -40

# Check if input_ids is explicitly on CPU or if there's any device handling
rg 'data\["input_ids"\]' nemo_rl/algorithms/loss_functions.py -B 3 -A 3 | head -30

# Check the parallel paths to see if they handle device differently
sed -n '1515,1540p' nemo_rl/distributed/model_utils.py

# Verify how input_ids is used in parallel paths to confirm device handling
sed -n '1520,1560p' nemo_rl/distributed/model_utils.py

# Check from_parallel_logits_to_logprobs to see if it handles device explicitly
rg -A 30 'def from_parallel_logits_to_logprobs' nemo_rl/distributed/model_utils.py | head -40
```
Move input_ids to match the logits device instead of hard-coding .cuda().
Line 1556 forces execution to GPU 0, breaking CPU setups and non-default GPU allocation. Align the device with the logits tensor to ensure consistency across all code paths.
Suggested fix

```diff
-    next_tokens = input_ids[:, 1:].cuda()  # Skip first token
+    next_tokens = input_ids[:, 1:].to(next_token_logits_wo_last.device)  # Skip first token
```

🤖 Prompt for AI Agents
In `@nemo_rl/distributed/model_utils.py` around lines 1544 - 1559, The code moves
input_ids to GPU with a hard-coded .cuda() which breaks CPU or non-default-GPU
runs; change the conversion of input_ids[:, 1:] (used to produce next_tokens) to
use the same device as the logits tensor (e.g., next_token_logits_wo_last.device
or next_token_logprobs.device) instead of .cuda() so next_tokens and subsequent
gather operate on the correct device; update the expression that defines
next_tokens to call .to(<logits_device>) so device alignment is ensured for
next_token_logits_wo_last, next_token_logprobs, and token_logprobs.
```diff
         if top_p <= 0:
             raise ValueError(
-                (
-                    f"top_p sampling with values < {TOP_P_THRESHOLD} is not supported because the vLLM V1 engine "
-                    "does not return logprobs after top_p filtering. Values >= {TOP_P_THRESHOLD} are allowed "
-                    "for token filtering purposes. If you understand the implications and still want to use "
-                    f"a lower top_p value, please manually comment out this check. Got top_p={top_p}. "
-                    "See https://github.com/NVIDIA-NeMo/RL/issues/69 for more details."
-                )
+                f"top_p valid values: i) 1.0: no filtering. ii) (0, 1]: top-p filtering. Got top_p={top_p}."
             )
```
top_p > 1.0 passes validation but contradicts the documented valid range.
The error message states valid values are (0, 1], but the condition only rejects top_p <= 0. A value like 1.5 would be silently accepted. Consider also rejecting top_p > 1.0 for consistency with the documented contract.
Proposed fix

```diff
-        if top_p <= 0:
+        if top_p <= 0 or top_p > 1.0:
```

📝 Committable suggestion
```python
        if top_p <= 0 or top_p > 1.0:
            raise ValueError(
                f"top_p valid values: i) 1.0: no filtering. ii) (0, 1]: top-p filtering. Got top_p={top_p}."
            )
```
🧰 Tools

🪛 Ruff (0.15.0)

[warning] 97-99: Avoid specifying long messages outside the exception class (TRY003)
🤖 Prompt for AI Agents
In `@nemo_rl/models/generation/vllm/vllm_generation.py` around lines 96 - 99, The
validation for the top_p parameter in vllm_generation.py currently only rejects
top_p <= 0 but allows values > 1.0, contradicting the documented valid range (0,
1]; update the check that inspects the top_p variable so it raises a ValueError
when top_p <= 0 or top_p > 1.0, and keep the existing error message (or adjust
it to mention the closed upper bound) to reflect the contract; locate the
validation near the top_p check in the VLLM generation logic and modify that
conditional to enforce both bounds.
```python
        if "generation" in self.cfg and self.cfg["generation"] is not None:
            generation_cfg = self.cfg["generation"]
            self.sampling_params = TrainingSamplingParams(
                top_k=generation_cfg.get("top_k", None),
                top_p=generation_cfg.get("top_p", 1.0),
                temperature=generation_cfg.get("temperature", 1.0),
            )
```
Avoid setting non‑None defaults in code for generation config.
Defaults should live in YAML; use required keys here and let the config own defaults.
🛠️ Suggested fix

```diff
         self.sampling_params = TrainingSamplingParams(
-            top_k=generation_cfg.get("top_k", None),
-            top_p=generation_cfg.get("top_p", 1.0),
-            temperature=generation_cfg.get("temperature", 1.0),
+            top_k=generation_cfg.get("top_k"),
+            top_p=generation_cfg["top_p"],
+            temperature=generation_cfg["temperature"],
         )
```

🤖 Prompt for AI Agents
In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py` around lines 245 -
251, The code is supplying non-None defaults for generation config values;
remove those defaults and read required keys directly from the config so YAML
owns defaults: when "generation" exists in self.cfg, assign generation_cfg =
self.cfg["generation"] and construct TrainingSamplingParams using the raw values
(e.g. generation_cfg["top_k"], generation_cfg["top_p"],
generation_cfg["temperature"]) instead of .get(..., default), and set
self.sampling_params accordingly; ensure any callers expect these keys to be
present (YAML should provide defaults) and do not introduce new in-code fallback
values.
```python
    def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
        if self.sampling_params is not None and self.sampling_params.temperature != 1.0:
            logits.div_(self.sampling_params.temperature)
        return logits

    def _apply_top_k_top_p_filtering(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply top-k and top-p filtering to the logits locally when TP is disabled."""
        sampling_params = self.sampling_params
        if sampling_params is not None and need_top_k_or_top_p_filtering(
            sampling_params.top_k, sampling_params.top_p
        ):
            logits, _ = apply_top_k_top_p(
                logits,
                top_k=sampling_params.top_k,
                top_p=sampling_params.top_p,
            )
        return logits
```
Add Google‑style docstrings for the new helper methods.
These are new functions; please document them in Google style for consistency.
✍️ Suggested docstrings
```diff
     def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
+        """Apply temperature scaling to logits.
+
+        Args:
+            logits: Logits tensor to scale in-place.
+
+        Returns:
+            The scaled logits tensor.
+        """
         if self.sampling_params is not None and self.sampling_params.temperature != 1.0:
             logits.div_(self.sampling_params.temperature)
         return logits

     def _apply_top_k_top_p_filtering(self, logits: torch.Tensor) -> torch.Tensor:
-        """Apply top-k and top-p filtering to the logits locally when TP is disabled."""
+        """Apply top-k/top-p filtering to logits.
+
+        Args:
+            logits: Logits tensor to filter.
+
+        Returns:
+            Filtered logits tensor.
+        """
```
In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py` around lines 339 -
355, Add Google-style docstrings to the two new helper methods
_apply_temperature_scaling and _apply_top_k_top_p_filtering: for each function
provide a one-line summary, a short description if needed, an Args section
describing parameters (e.g., logits: torch.Tensor and mention
self.sampling_params where relevant), a Returns section describing the returned
torch.Tensor, and any Notes or Raises if applicable (e.g., side effects from
in-place div_ in _apply_temperature_scaling). Place the docstrings immediately
under each def to match project Sphinx parsing and coding guidelines.
```python
            # Save and adjust sampling_params for reference model
            saved_sampling_params = self.sampling_params
            if saved_sampling_params is not None:
                self.sampling_params = TrainingSamplingParams(
                    top_k=None,
                    top_p=1.0,
                    temperature=saved_sampling_params.temperature,
                )
            else:
                self.sampling_params = None

            yield

        finally:
            # Restore sampling_params
            self.sampling_params = saved_sampling_params
```
Prevent saved_sampling_params from being undefined on exceptions.
If an exception occurs before saved_sampling_params is assigned, the finally block will raise UnboundLocalError. Initialize it before the try.
🛠️ Suggested fix
```diff
     with torch.no_grad():
-        try:
+        saved_sampling_params = self.sampling_params
+        try:
             # Save train model state_dict
             curr_state_dict = get_cpu_state_dict(
                 self.model.state_dict().items(), pin_memory=True
             )
@@
-            saved_sampling_params = self.sampling_params
             if saved_sampling_params is not None:
                 self.sampling_params = TrainingSamplingParams(
                     top_k=None,
                     top_p=1.0,
                     temperature=saved_sampling_params.temperature,
                 )
```
In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py` around lines 831 -
847, The finally block can raise UnboundLocalError if an exception occurs before
saved_sampling_params is set; initialize saved_sampling_params = None before the
try that modifies self.sampling_params so the finally always has a defined
variable. Update the context around the sampling_params manipulation (reference
symbols: saved_sampling_params, self.sampling_params, TrainingSamplingParams) to
set saved_sampling_params = None immediately before entering the try, then
proceed with the existing logic and restoration in the finally.
```python
        if "generation" in self.cfg and self.cfg["generation"] is not None:
            generation_cfg = self.cfg["generation"]
            self.sampling_params = TrainingSamplingParams(
                top_k=generation_cfg.get("top_k", None),
                top_p=generation_cfg.get("top_p", 1.0),
                temperature=generation_cfg.get("temperature", 1.0),
            )
        else:
            self.sampling_params = None
```
self.cfg is not yet assigned — this will raise AttributeError on construction.
self.cfg is set at Line 198, but Line 179 references self.cfg. Every instantiation of DTensorPolicyWorker will crash here. The existing block at Line 176 correctly uses the config parameter; this new block should do the same.
Proposed fix

```diff
-        if "generation" in self.cfg and self.cfg["generation"] is not None:
-            generation_cfg = self.cfg["generation"]
+        if "generation" in config and config["generation"] is not None:
+            generation_cfg = config["generation"]
             self.sampling_params = TrainingSamplingParams(
                 top_k=generation_cfg.get("top_k", None),
                 top_p=generation_cfg.get("top_p", 1.0),
```

🤖 Prompt for AI Agents
In `@nemo_rl/models/policy/workers/dtensor_policy_worker.py` around lines 179 -
187, The generation config access uses self.cfg before it's assigned, causing
AttributeError in DTensorPolicyWorker __init__; update the block that constructs
self.sampling_params to read from the local config parameter (use config or
config["generation"]) or move this entire generation-handling block to after the
assignment to self.cfg so TrainingSamplingParams (top_k/top_p/temperature) is
created from the already-initialized config; target the DTensorPolicyWorker
constructor and the sampling_params/TrainingSamplingParams logic to fix the
reference.
```python
            # Temporarily disable top-k/top-p filtering for reference policy logprobs.
            # The reference policy has different weights, so its top-k/top-p set is
            # inherently different from the current policy. Using filtered logprobs
            # would cause -inf mismatches that cannot be resolved by masking.
            # Note: We keep temperature scaling since it was applied to prev_logprobs.
            saved_sampling_params = self.sampling_params
            if saved_sampling_params is not None:
                self.sampling_params = TrainingSamplingParams(
                    top_k=None,  # Disable top-k
                    top_p=1.0,  # Disable top-p
                    temperature=saved_sampling_params.temperature,  # Keep temperature
                )
            else:
                self.sampling_params = None

            # - self.model is the original reference_model, now on CUDA
            # - curr_state_dict is the train model, now on CPU
            yield

        finally:
            # Restore sampling_params
            self.sampling_params = saved_sampling_params
```
Same UnboundLocalError risk as in the Megatron worker.
If any of the state-dict operations on Lines 1654–1661 throw before saved_sampling_params is assigned on Line 1668, the finally block on Line 1684 will fail with UnboundLocalError.
Move the assignment to the top of the try block:
Proposed fix

```diff
     with torch.no_grad():
         try:
+            # Save sampling_params early so finally can always restore them
+            saved_sampling_params = self.sampling_params
+
             # Save train model state_dict
             curr_state_dict = get_cpu_state_dict(
                 self.model.state_dict().items(), pin_memory=True
```

🤖 Prompt for AI Agents
In `@nemo_rl/models/policy/workers/dtensor_policy_worker.py` around lines 1663 -
1684, The try/finally can raise UnboundLocalError because saved_sampling_params
is assigned after operations; move the saved_sampling_params =
self.sampling_params assignment to the very start of the try block (before any
state-dict or model manipulation) so it is always defined for the finally, then
use that saved value when setting self.sampling_params to the temporary
TrainingSamplingParams (or None) and restore it in the finally; reference the
variables/methods: self.sampling_params, saved_sampling_params,
TrainingSamplingParams, and the surrounding try/finally in
dtensor_policy_worker.py.
```python
SAMPLING_PARAMS_TEST_ACTOR_FQN = (
    f"{SamplingParamsTestActor.__module__}.SamplingParamsTestActor"
)


@pytest.fixture
def register_sampling_params_test_actor():
    """Register the SamplingParamsTestActor for use in tests."""
    original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get(
        SAMPLING_PARAMS_TEST_ACTOR_FQN
    )
    ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] = PY_EXECUTABLES.SYSTEM
    yield SAMPLING_PARAMS_TEST_ACTOR_FQN
    if SAMPLING_PARAMS_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY:
        if original_registry_value is None:
            del ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN]
        else:
            ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] = (
                original_registry_value
            )
```
Prefix new global with G_ to match the naming rule.
Globals should use the G_ prefix; rename the constant and its uses.
🔧 Suggested fix
```diff
-SAMPLING_PARAMS_TEST_ACTOR_FQN = (
+G_SAMPLING_PARAMS_TEST_ACTOR_FQN = (
     f"{SamplingParamsTestActor.__module__}.SamplingParamsTestActor"
 )


 @pytest.fixture
 def register_sampling_params_test_actor():
     """Register the SamplingParamsTestActor for use in tests."""
     original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get(
-        SAMPLING_PARAMS_TEST_ACTOR_FQN
+        G_SAMPLING_PARAMS_TEST_ACTOR_FQN
     )
-    ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] = PY_EXECUTABLES.SYSTEM
-    yield SAMPLING_PARAMS_TEST_ACTOR_FQN
-    if SAMPLING_PARAMS_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY:
+    ACTOR_ENVIRONMENT_REGISTRY[G_SAMPLING_PARAMS_TEST_ACTOR_FQN] = PY_EXECUTABLES.SYSTEM
+    yield G_SAMPLING_PARAMS_TEST_ACTOR_FQN
+    if G_SAMPLING_PARAMS_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY:
         if original_registry_value is None:
-            del ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN]
+            del ACTOR_ENVIRONMENT_REGISTRY[G_SAMPLING_PARAMS_TEST_ACTOR_FQN]
         else:
-            ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] = (
+            ACTOR_ENVIRONMENT_REGISTRY[G_SAMPLING_PARAMS_TEST_ACTOR_FQN] = (
                 original_registry_value
             )
```
In `@tests/unit/distributed/test_model_utils.py` around lines 1157 - 1177, Rename
the global constant SAMPLING_PARAMS_TEST_ACTOR_FQN to follow the G_ prefix
convention (e.g., G_SAMPLING_PARAMS_TEST_ACTOR_FQN) and update all references in
this test block: the fixture register_sampling_params_test_actor, the
ACTOR_ENVIRONMENT_REGISTRY key accesses, and any use with PY_EXECUTABLES.SYSTEM;
ensure the original_registry_value capture and cleanup logic remain unchanged
but use the new constant name everywhere.
What does this PR do?
This PR is a reimplementation of #1578 on the new main. It implements top-k and top-p sampling in training to ensure consistency between training and inference when sampling parameters are used.
Status: WIP
Why is this tricky? In Tensor Parallelism (TP), the vocabulary is sharded across GPUs (vocab-parallel). Top-k and top-p filtering require probabilities across the full vocabulary, not individual shards. A naive all-gather would incur large memory overhead.
Solution: We convert from vocab-parallel to batch-sequence-parallel layout via all-to-all communication, apply filtering on the full vocabulary, then convert back. This avoids materializing the full vocabulary on any single rank.
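For intuition, the layout swap can be simulated in a single process with plain chunking (illustrative shapes only; the real `all_to_all_vp2sq`/`all_to_all_sq2vp` use `torch.distributed` all-to-all collectives):

```python
import torch

# Single-process simulation of the vocab-parallel -> batch-sequence-parallel swap.
tp, S, V = 2, 4, 8
full_logits = torch.randn(S, V)

# Vocab-parallel layout: rank r holds every token but only its vocab shard.
vp_shards = list(full_logits.chunk(tp, dim=-1))

# Simulated all-to-all: rank dst receives sequence-chunk dst of every rank's
# vocab shard and concatenates them back into the full vocabulary.
sq_shards = [
    torch.cat([vp_shards[src].chunk(tp, dim=0)[dst] for src in range(tp)], dim=-1)
    for dst in range(tp)
]

# Each rank now sees the full vocab for S/tp tokens, so top-k/top-p can be
# applied locally; the inverse exchange restores the vocab-parallel layout.
assert torch.equal(torch.cat(sq_shards, dim=0), full_logits)
```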
Detailed Changes
Core sampling utilities (`nemo_rl/models/policy/utils.py`):

- `TrainingSamplingParams` dataclass for standardizing `top_k`, `top_p`, `temperature`.
- `ApplyTopKTopP` autograd function with correct gradient masking; see the sketch below.
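A minimal, hedged sketch of the gradient-masking pattern (a simplified stand-in; the real `ApplyTopKTopP` also implements top-p and integrates with the distributed paths):

```python
import torch

class TopKFilterSketch(torch.autograd.Function):
    """Simplified stand-in for ApplyTopKTopP: keep the top-k logits and block
    gradients to filtered positions (ties may keep more than k entries)."""

    @staticmethod
    def forward(ctx, logits: torch.Tensor, k: int) -> torch.Tensor:
        kth_value = logits.topk(k, dim=-1).values[..., -1:]
        keep_mask = logits >= kth_value
        ctx.save_for_backward(keep_mask)
        return logits.masked_fill(~keep_mask, float("-inf"))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (keep_mask,) = ctx.saved_tensors
        # Filtered entries were overwritten with a constant, so their gradient
        # is zeroed; surviving entries pass gradients through unchanged.
        return grad_output * keep_mask, None

logits = torch.randn(2, 6, requires_grad=True)
logprobs = torch.log_softmax(TopKFilterSketch.apply(logits, 3), dim=-1)
chosen = logits.argmax(dim=-1, keepdim=True)  # a token that survives filtering
logprobs.gather(-1, chosen).sum().backward()
assert torch.all(logits.grad[logprobs.isinf()] == 0)  # no grad to dropped entries
```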
Distributed logprob computation (`nemo_rl/distributed/model_utils.py`):

- `DistributedLogprobWithSampling` and `ChunkedDistributedLogprobWithSampling`: perform vocab-parallel to batch-sequence-parallel conversion via `all_to_all_vp2sq`/`all_to_all_sq2vp`, apply filtering, then convert back.
- `compute_logprobs_from_logits` centralizes logprob computation across all parallel setups (TP, CP, DTensor, packed sequences).
Loss functions (`nemo_rl/algorithms/loss_functions.py`):

- `ClippedPGLossFn`, `NLLLoss`, `DPOLossFn`, `DistillationLossFn`, `PreferenceLoss` now accept `sampling_params`.
- `ClippedPGLossFn` masks out `-inf` prev_logprobs and computes separate unfiltered logprobs for KL divergence; see the sketch below.
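As a toy illustration of the masking step (variable names here are illustrative, not `ClippedPGLossFn`'s actual internals):

```python
import torch

curr_logprobs = torch.tensor([[-1.2, -0.7, -2.0]], requires_grad=True)
prev_logprobs = torch.tensor([[-1.0, float("-inf"), -1.8]])  # middle token was filtered
token_mask = torch.ones_like(prev_logprobs)

valid = torch.isfinite(prev_logprobs)
mask = token_mask * valid  # drop -inf positions from the loss entirely

# Zero out the invalid prev logprob before subtracting so no inf/NaN is created.
prev_safe = prev_logprobs.masked_fill(~valid, 0.0)
ratios = torch.exp((curr_logprobs - prev_safe) * mask)
loss = -(ratios * mask).sum() / mask.sum().clamp(min=1)
loss.backward()  # gradient at the masked position is exactly zero
```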
Policy workers (`dtensor_policy_worker.py`, `dtensor_policy_worker_v2.py`, `megatron_policy_worker.py`):

- Initialize `self.sampling_params` from generation config and pass it through to loss/logprob functions.
- `use_reference_model()` temporarily disables top-k/top-p (retains temperature) for the reference model; see the sketch below.
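The save/disable/restore behavior can be sketched as a context manager (simplified; the actual `use_reference_model()` also swaps model weights, and the `dataclasses.replace` call assumes `TrainingSamplingParams` is a plain dataclass):

```python
from contextlib import contextmanager
from dataclasses import replace

@contextmanager
def reference_sampling(worker):
    # Capture before anything that can fail, so finally can always restore.
    saved = worker.sampling_params
    try:
        if saved is not None:
            # Disable top-k/top-p but keep temperature, since temperature
            # scaling was applied when prev_logprobs were produced.
            worker.sampling_params = replace(saved, top_k=None, top_p=1.0)
        yield
    finally:
        worker.sampling_params = saved
```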
vLLM validation (`vllm_generation.py`):

- `top_k` must be `None`, `-1`, or `>= 1`; `top_p` must be `> 0`. Removed old threshold-based validation.

Configs & nightly:

- `megatron-sampling.yaml` (temp=0.8, top_p=0.9, top_k=50) and `megatron-temp0.6.yaml` (temp=0.6) with test scripts, registered in `nightly.txt`.

Review Starting Points
- `nemo_rl/models/policy/utils.py`: this defines `TrainingSamplingParams` and `ApplyTopKTopP`, which everything else builds on.
- `nemo_rl/distributed/model_utils.py`: focus on `DistributedLogprobWithSampling` and the `all_to_all_vp2sq`/`all_to_all_sq2vp` helpers to understand the TP layout conversion.
- `nemo_rl/algorithms/loss_functions.py`: particularly `ClippedPGLossFn`, which has the most involved changes (handling `-inf` masking and separate KL logprobs).

Tests
- `test_sampling_params_top_k_top_p` (`tests/unit/distributed/test_model_utils.py`): tests `apply_top_k_top_p` forward correctness and gradient masking (gradients zero for filtered positions, preserved for non-filtered); a minimal sketch of this check follows the list.
- `test_sampling_params_distributed_logprob` (`tests/unit/distributed/test_model_utils.py`): tests `DistributedLogprobWithSampling` and `ChunkedDistributedLogprobWithSampling` in a 2-GPU TP setup, validating forward logprobs and backward gradients against a full-vocabulary baseline.
- `test_vllm_logprobs_mode` (`tests/unit/models/generation/test_vllm_logprobs_mode.py`): adapted to the new `apply_top_k_top_p` return signature (a (logits, mask) tuple).
- `token_mult_prob_error` metrics.
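In the same spirit, a minimal standalone check of the gradient-masking property (not the actual test body; it reuses the `TopKFilterSketch` class from the sketch shown earlier):

```python
import torch

def test_topk_gradient_masking():
    logits = torch.randn(2, 8, requires_grad=True)
    filtered = TopKFilterSketch.apply(logits, 3)  # sketch class defined above
    keep = torch.isfinite(filtered)
    filtered.masked_fill(~keep, 0.0).sum().backward()
    # Gradients are zero exactly at filtered positions, nonzero at kept ones.
    assert torch.all(logits.grad[~keep] == 0)
    assert torch.all(logits.grad[keep] != 0)
```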
Issues

List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"
Pre checks:
Additional Information
CC: @yuki-97 @terrykong
Summary by CodeRabbit
New Features
Bug Fixes
Tests