feat: Support top-p and top-k #1938

Open

zhandaz wants to merge 2 commits into main from zhanda/top-p-k-reimpl

Conversation

@zhandaz (Contributor) commented Feb 12, 2026

What does this PR do?

This PR is a reimplementation of #1578 on the new main. It implements top-k and top-p sampling in training to ensure consistency between training and inference when sampling parameters are used.

Status: WIP

  1. The Megatron worker refactoring is not merged yet, so this section will be updated once it lands.
  2. End-to-end runs are still needed for testing.

Why is this tricky? In Tensor Parallelism (TP), the vocabulary is sharded across GPUs (vocab-parallel). Top-k and top-p filtering require probabilities across the full vocabulary, not individual shards. A naive all-gather would incur large memory overhead.

Solution: We convert from vocab-parallel to batch-sequence-parallel layout via all-to-all communication, apply filtering on the full vocabulary, then convert back. This avoids materializing the full vocabulary on any single rank.
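
To make the layout conversion concrete, here is a minimal sketch of the vocab-parallel → sequence-parallel direction. It is illustrative only: the helper name matches the all_to_all_vp2sq mentioned below, but the shapes, group handling, and divisibility assumptions are ours, not the PR's implementation.

```python
import torch
import torch.distributed as dist

def all_to_all_vp2sq(logits_vp: torch.Tensor, group=None) -> torch.Tensor:
    """Sketch: turn a vocab-parallel shard [B, S, V/tp] into a
    sequence-parallel shard [B, S/tp, V] via all-to-all.

    Assumes S is divisible by the TP world size and that vocab shards
    are laid out contiguously in rank order.
    """
    tp = dist.get_world_size(group)
    # Rank i sends its vocab shard of sequence chunk j to rank j...
    send = [c.contiguous() for c in logits_vp.chunk(tp, dim=1)]
    recv = [torch.empty_like(send[0]) for _ in range(tp)]
    dist.all_to_all(recv, send, group=group)
    # ...and stitches the tp received vocab shards of its own sequence
    # chunk back into the full vocabulary.
    return torch.cat(recv, dim=-1)
```

After filtering on the full vocabulary, the inverse direction (all_to_all_sq2vp) applies the same exchange with the chunk and concat dimensions swapped, so no rank ever materializes the full [B, S, V] tensor.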

Detailed Changes

  • Core sampling utilities (nemo_rl/models/policy/utils.py):

    • TrainingSamplingParams dataclass for standardizing top_k, top_p, temperature.
    • ApplyTopKTopP autograd function with correct gradient masking (sketched after this list).
  • Distributed logprob computation (nemo_rl/distributed/model_utils.py):

    • DistributedLogprobWithSampling and ChunkedDistributedLogprobWithSampling: perform vocab-parallel to batch-sequence-parallel conversion via all_to_all_vp2sq/all_to_all_sq2vp, apply filtering, then convert back.
    • compute_logprobs_from_logits to centralize logprob computation across all parallel setups (TP, CP, DTensor, packed sequences).
  • Loss functions (nemo_rl/algorithms/loss_functions.py):

    • ClippedPGLossFn, NLLLoss, DPOLossFn, DistillationLossFn, PreferenceLoss now accept sampling_params.
    • ClippedPGLossFn masks out -inf prev_logprobs and computes separate unfiltered logprobs for KL divergence.
  • Policy workers (dtensor_policy_worker.py, dtensor_policy_worker_v2.py, megatron_policy_worker.py):

    • Initialize self.sampling_params from generation config and pass through to loss/logprob functions.
    • use_reference_model() temporarily disables top-k/top-p (retaining temperature) for the reference model.
  • vLLM validation (vllm_generation.py):

    • Simplified: top_k must be None, -1, or >= 1; top_p must be > 0. Removed old threshold-based validation.
  • Configs & nightly:

    • Added megatron-sampling.yaml (temp=0.8, top_p=0.9, top_k=50) and megatron-temp0.6.yaml (temp=0.6) with test scripts, registered in nightly.txt.
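
For reference, a minimal single-device, unchunked sketch of what an ApplyTopKTopP-style autograd function can look like. The return tuple and gradient-masking behavior follow the description above, but the cutoff convention and details are assumptions, not this PR's implementation:

```python
import torch

class ApplyTopKTopP(torch.autograd.Function):
    """Filter logits to the top-k/top-p set; zero gradients elsewhere."""

    @staticmethod
    def forward(ctx, logits, top_k, top_p):
        probs = logits.softmax(dim=-1)
        sorted_probs, idx = probs.sort(dim=-1, descending=True)
        keep_sorted = torch.ones_like(sorted_probs, dtype=torch.bool)
        if top_k is not None:
            keep_sorted[..., top_k:] = False
        if top_p < 1.0:
            # Keep the smallest prefix whose exclusive cumulative mass is
            # below top_p; the highest-probability token always survives.
            cum = sorted_probs.cumsum(dim=-1)
            keep_sorted &= (cum - sorted_probs) < top_p
        # Unscramble the keep decisions back to original vocab order.
        keep = torch.zeros_like(keep_sorted).scatter(-1, idx, keep_sorted)
        ctx.save_for_backward(keep)
        return logits.masked_fill(~keep, float("-inf")), keep

    @staticmethod
    def backward(ctx, grad_logits, grad_mask):
        (keep,) = ctx.saved_tensors
        # Gradient masking: filtered-out positions get zero gradient.
        return grad_logits.masked_fill(~keep, 0.0), None, None
```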

Review Starting Points

  1. Start with the dataclass and autograd core in nemo_rl/models/policy/utils.py: this defines TrainingSamplingParams and ApplyTopKTopP, which everything else builds on.
  2. Then the distributed integration in nemo_rl/distributed/model_utils.py: focus on DistributedLogprobWithSampling and the all_to_all_vp2sq/all_to_all_sq2vp helpers to understand the TP layout conversion.
  3. Then the loss function changes in nemo_rl/algorithms/loss_functions.py: particularly ClippedPGLossFn, which has the most involved changes (handling -inf masking and separate KL logprobs); a minimal masking sketch follows this list.
  4. Policy workers are mostly mechanical plumbing: each one follows the same pattern of init, pass-through, and reference-model save/restore.
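
The -inf masking mentioned in point 3 can be pictured with a small sketch (names and shapes are ours; the PR's ClippedPGLossFn is more involved):

```python
import torch

def masked_ratio_inputs(
    curr_logprobs: torch.Tensor,   # [B, S] current-policy token logprobs
    prev_logprobs: torch.Tensor,   # [B, S] may contain -inf for filtered tokens
    token_mask: torch.Tensor,      # [B, S] bool validity mask
) -> tuple[torch.Tensor, torch.Tensor]:
    """Drop tokens filtered out under the previous policy (-inf logprobs)
    before forming importance ratios, so no NaN/inf leaks into the loss."""
    mask = token_mask & torch.isfinite(prev_logprobs)
    # Zero invalid log-ratios before exponentiating to avoid exp(inf).
    log_ratios = (curr_logprobs - prev_logprobs).masked_fill(~mask, 0.0)
    return torch.exp(log_ratios), mask
```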

Tests

  • test_sampling_params_top_k_top_p (tests/unit/distributed/test_model_utils.py): Tests apply_top_k_top_p forward correctness and gradient masking (gradients zero for filtered positions, preserved for non-filtered); see the sketch after this list.
  • test_sampling_params_distributed_logprob (tests/unit/distributed/test_model_utils.py): Tests DistributedLogprobWithSampling and ChunkedDistributedLogprobWithSampling in a 2-GPU TP setup, validating forward logprobs and backward gradients against a full-vocabulary baseline.
  • test_vllm_logprobs_mode (tests/unit/models/generation/test_vllm_logprobs_mode.py): Adapted to the new apply_top_k_top_p return signature (logits, mask tuple).
  • Nightly convergence tests: Two new 1N8G Megatron training runs, one with sampling (top_p=0.9, top_k=50, temp=0.8) and one with temperature only (temp=0.6), validating reward convergence and token_mult_prob_error metrics.
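
As a rough illustration of the gradient-masking check, a sketch against the ApplyTopKTopP snippet shown earlier (shapes and thresholds are arbitrary, not the actual test):

```python
import torch

def test_gradient_masking_sketch():
    logits = torch.randn(2, 4, 32, requires_grad=True)
    filtered, keep = ApplyTopKTopP.apply(logits, 5, 0.9)
    # Replace -inf with 0 so the scalar loss stays finite, then backprop.
    filtered.masked_fill(~keep, 0.0).sum().backward()
    assert torch.all(logits.grad[~keep] == 0)  # filtered: zero grad
    assert torch.all(logits.grad[keep] != 0)   # kept: grad preserved
```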

Issues

List issues that this PR closes:

Usage

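A hypothetical usage sketch (the dataclass and its fields come from this PR's description; the values mirror the new megatron-sampling.yaml recipe):

```python
# Assumed import path per the PR description (nemo_rl/models/policy/utils.py).
from nemo_rl.models.policy.utils import TrainingSamplingParams

# Match training-side filtering to the vLLM generation settings
# (temperature=0.8, top_p=0.9, top_k=50 as in megatron-sampling.yaml).
sampling_params = TrainingSamplingParams(top_k=50, top_p=0.9, temperature=0.8)
```

In practice the workers build this from the generation section of the config, so setting top_k/top_p/temperature there is the intended entry point.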

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

CC: @yuki-97 @terrykong

Summary by CodeRabbit

  • New Features

    • Added experimental GRPO training configurations for Llama-3.2-1B Instruct with Megatron sampling and temperature control variants.
    • Enabled top-k/top-p token filtering during distributed training with improved gradient handling for sampling-based approaches.
  • Bug Fixes

    • Simplified parameter validation for generation settings with clearer error messaging.
  • Tests

    • Added comprehensive test coverage for distributed sampling operations and token filtering behavior.

@zhandaz zhandaz requested review from a team as code owners February 12, 2026 20:46
Signed-off-by: Zhanda <zhandazhu@gmail.com>
@github-actions

ℹ️ File Consistency Check

Check based on commit: d2eac7a (PR #1938 from zhanda/top-p-k-reimpl)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@zhandaz zhandaz force-pushed the zhanda/top-p-k-reimpl branch from d2eac7a to ac0897f on February 12, 2026 20:47
@github-actions

ℹ️ File Consistency Check

Check based on commit: ac0897f (PR #1938 from zhanda/top-p-k-reimpl) — same DTensor Policy Worker Synchronization Check notice as above.

coderabbitai bot commented Feb 12, 2026

📝 Walkthrough

This PR introduces sampling-aware log-probability computation across the training pipeline, enabling top-k/top-p filtering with gradient support through new autograd functions. It threads a TrainingSamplingParams parameter through loss functions, distributed utilities, and policy workers, with new chunked rematerialization paths for memory efficiency.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **Configuration Files**<br>`examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-*.yaml` | New YAML configuration files for GRPO Llama-3.2 1B experiments with Megatron sampling and temperature-0.6 setups, including hyperparameters, generation settings, logging, and resource allocation. |
| **Core Sampling Infrastructure**<br>`nemo_rl/models/policy/utils.py` | Introduces the TrainingSamplingParams dataclass, helper functions for filtering detection, chunked top-k/top-p filtering logic with the custom autograd function ApplyTopKTopP, and modular sampling utilities. |
| **Distributed Log-Probability Computation**<br>`nemo_rl/distributed/model_utils.py` | Adds DistributedLogprobWithSampling and ChunkedDistributedLogprobWithSampling autograd functions for gradient-enabled sampling paths; centralizes logprob computation via compute_logprobs_from_logits; adds helpers for vocab-parallel layout conversion. |
| **Loss Function Integration**<br>`nemo_rl/algorithms/loss_functions.py` | Propagates sampling_params across loss functions (ClippedPGLoss, NLLLoss, DPO, Distillation, Preference, SequencePacking); adds masking for invalid positions when top-k/top-p filtering is active; uses unfiltered logprobs for KL calculations. |
| **Policy Worker Implementation (DTensor)**<br>`nemo_rl/models/policy/workers/dtensor_policy_worker*.py` | Integrates TrainingSamplingParams from config; adds temperature scaling and top-k/top-p filtering helpers; propagates sampling params to loss and logprob functions; temporarily disables top-k/top-p during reference model swaps. |
| **Policy Worker Implementation (Megatron)**<br>`nemo_rl/models/megatron/common.py`, `nemo_rl/models/policy/workers/megatron_policy_worker.py` | Replaces policy_cfg with sampling_params in the forward path; propagates sampling params through loss and logprob computations; manages sampling params during reference model context switches. |
| **Generation Configuration Validation**<br>`nemo_rl/models/generation/vllm/vllm_generation.py` | Simplifies top_k/top_p validation logic to accept None/-1 (no filtering) and allow flexible filtering ranges; removes hardcoded threshold constants. |
| **Distributed Testing Infrastructure**<br>`tests/unit/distributed/test_model_utils.py` | Adds SamplingParamsTestActor with comprehensive tests for top-k/top-p filtering forward/backward passes and distributed logprob computation with sampling across GPU clusters. |
| **Test Suite Updates**<br>`tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-*.sh`, `tests/test_suites/nightly.txt` | New shell scripts for running GRPO experiments with logging, checkpointing, and metrics validation; registers the scripts in the nightly test suite. |
| **Test Utility Updates**<br>`tests/unit/models/generation/test_vllm_logprobs_mode.py`, `tests/unit/test_utils.py` | Updates to handle apply_top_k_top_p returning a tuple instead of a single tensor; adds a sampling_params parameter to test loss function signatures. |
| **Removed Tests**<br>`tests/unit/models/generation/test_vllm_generation.py` | Removes the threshold-based validation test for top_p/top_k since the validation logic is now more flexible. |

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer
    participant PolicyWorker
    participant LossFunction
    participant DistributedUtils
    participant Sampling

    Trainer->>PolicyWorker: train(data, sampling_params)
    PolicyWorker->>Sampling: Initialize sampling_params from config
    PolicyWorker->>DistributedUtils: forward_step with sampling_params
    DistributedUtils->>DistributedUtils: Compute logits from model
    DistributedUtils->>Sampling: apply_top_k_top_p if filtering enabled
    Sampling->>Sampling: Filter logits, create keep_mask
    DistributedUtils->>DistributedUtils: compute_logprobs_from_logits
    DistributedUtils->>LossFunction: Pass logits, logprobs, sampling_params
    LossFunction->>LossFunction: Apply masking for invalid positions
    LossFunction->>LossFunction: Use unfiltered logprobs for KL calculations
    LossFunction->>LossFunction: Use filtered logprobs for actor loss
    LossFunction-->>Trainer: loss, metrics
    Trainer->>Trainer: Backward pass through sampling autograd functions
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

CI:L1

Suggested reviewers

  • terrykong
  • yuki-97
  • adil-a
🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 73.44%, below the required 80.00% threshold. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Test Results For Major Changes | ⚠️ Warning | The PR introduces major changes affecting training numerics and distributed gradient computation but explicitly states end-to-end testing is pending, with no regression validation documented. | Complete end-to-end testing and document results, including nightly convergence validation runs, unit test execution results, convergence comparisons, and loss/gradient correctness confirmation. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The PR title "feat: Support top-p and top-k" directly and concisely summarizes the main feature addition: implementing top-p and top-k sampling support throughout the codebase. |
| Merge Conflict Detection | ✅ Passed | No merge conflicts detected when merging into main. |


@github-actions

ℹ️ File Consistency Check

Check based on commit: 1d5f7c0 (PR #1938 from zhanda/top-p-k-reimpl) — same DTensor Policy Worker Synchronization Check notice as above.

coderabbitai bot left a comment

Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (7)
tests/unit/test_utils.py (2)

26-38: ⚠️ Potential issue | 🟡 Minor

Silence unused sampling_params to satisfy Ruff.

Ruff flags this as unused (ARG002). Add an explicit del to keep the signature but avoid lint noise.

💡 Suggested fix
     ) -> tuple[torch.Tensor, dict[str, Any]]:
+        del sampling_params  # unused test hook; keeps signature aligned
         # Just return mean of logprobs as the loss for testing
         loss = next_token_logits.mean()

51-62: ⚠️ Potential issue | 🟡 Minor

Silence unused sampling_params to satisfy Ruff.

Same ARG002 warning here; explicitly delete the argument.

💡 Suggested fix
     ) -> tuple[torch.Tensor, dict[str, Any]]:
+        del sampling_params  # unused test hook; keeps signature aligned
         # logits shape: [batch_size, seq_len, vocab_size]
nemo_rl/distributed/model_utils.py (1)

1-1: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

As per coding guidelines, "{src/,examples/,nemo_rl/**}/*.{py,sh}: Add the NVIDIA copyright header (with current year) to all Python files and shell scripts, excluding tests (files under tests/ or test-only scripts)".

nemo_rl/algorithms/loss_functions.py (2)

1-1: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

As per coding guidelines, "{src/,examples/,nemo_rl/**}/*.{py,sh}: Add the NVIDIA copyright header (with current year) to all Python files and shell scripts, excluding tests (files under tests/ or test-only scripts)".


391-399: ⚠️ Potential issue | 🟠 Major

Use the updated mask for sequence-level ratio averages.

After masking out -inf positions, token_mask still includes those tokens. This can skew sequence-level ratios under top‑k/top‑p mismatch scenarios. Use the already-updated mask instead.

🛠️ Suggested fix
-                seq_log_ratio_mean = masked_mean(
-                    log_ratios,
-                    token_mask,
-                    dim=-1,
-                ).unsqueeze(-1)
+                seq_log_ratio_mean = masked_mean(
+                    log_ratios,
+                    mask,
+                    dim=-1,
+                ).unsqueeze(-1)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)

1-1: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA copyright year to 2026.

As per coding guidelines, "{src/,examples/,nemo_rl/**}/*.{py,sh}: Add the NVIDIA copyright header (with current year) to all Python files and shell scripts, excluding tests (files under tests/ or test-only scripts)".

nemo_rl/models/policy/workers/megatron_policy_worker.py (1)

722-742: ⚠️ Potential issue | 🟠 Major

saved_sampling_params may be unbound in the finally block if an earlier line throws.

If self.model.load_state_dict(...) on Line 717 raises before saved_sampling_params is assigned on Line 722, the finally block on Line 742 will hit UnboundLocalError, masking the original exception.

Move the assignment before any code that can fail inside the try:

Proposed fix
         with torch.no_grad():
             try:
+                saved_sampling_params = self.sampling_params
+
                 # Save original references
                 model_state_dict = {}
                 for name, item in self.model.state_dict().items():
@@ -719,8 +721,7 @@
                 # self.model.state_dict()[name] = item.detach().to(device="cuda", non_blocking=True, copy=True)

-                saved_sampling_params = self.sampling_params
-                if saved_sampling_params is not None:
+                if self.sampling_params is not None:
                     self.sampling_params = TrainingSamplingParams(
                         top_k=None,
                         top_p=1.0,
🧹 Nitpick comments (3)
nemo_rl/models/policy/utils.py (1)

170-269: Verify logits_sort.scatter with self as source behaves correctly.

Line 254: chunk_filtered = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) uses logits_sort as both the input tensor and the src. This works because scatter creates a new output tensor (non-in-place variant), but it looks unusual. A brief inline comment clarifying that this "unscrambles" sorted values back to original vocab order via the argsort indices would improve readability.

nemo_rl/models/policy/workers/megatron_policy_worker.py (1)

233-242: Non-None fallback defaults for config values.

Lines 237–239 supply top_p=1.0 and temperature=1.0 via .get() when the keys are absent from the generation config. Per coding guidelines, YAML should be the single source of truth for configuration defaults. These happen to be the "disabled" sentinel values so the risk is low, but consider ensuring the exemplar YAML configs always define top_k, top_p, and temperature explicitly so the .get() fallbacks are never actually needed.

As per coding guidelines: "YAML is the single source of truth for configuration defaults; do not set non-None defaults in code for configuration values."

nemo_rl/models/policy/workers/dtensor_policy_worker.py (1)

493-509: New helper methods are correct and consistent with dtensor_policy_worker_v2.py.

_apply_temperature_scaling and _apply_top_k_top_p_filtering match the pattern in the v2 worker. The filtering helper correctly guards with need_top_k_filtering / need_top_p_filtering checks before calling apply_top_k_top_p. Note that _apply_top_k_top_p_filtering uses separate need_top_k_filtering and need_top_p_filtering checks while v2 uses need_top_k_or_top_p_filtering — functionally equivalent but slightly inconsistent.

Align with v2 worker's combined check
+from nemo_rl.models.policy.utils import (
+    TrainingSamplingParams,
+    apply_top_k_top_p,
+    configure_dynamo_cache,
+    get_runtime_env_for_policy_worker,
+    need_top_k_or_top_p_filtering,
+    resolve_model_class,
+)
...
     def _apply_top_k_top_p_filtering(self, logits: torch.Tensor) -> torch.Tensor:
         """Apply top-k and top-p filtering to the logits locally when TP is disabled."""
-        if self.sampling_params is not None and (
-            need_top_k_filtering(self.sampling_params.top_k)
-            or need_top_p_filtering(self.sampling_params.top_p)
-        ):
+        if self.sampling_params is not None and need_top_k_or_top_p_filtering(
+            self.sampling_params.top_k, self.sampling_params.top_p
+        ):
             logits, _ = apply_top_k_top_p(

Comment on lines +1544 to +1559
```python
        next_token_logits_wo_last = next_token_logits[
            :, :-1
        ]  # Remove last position's logits
        # Apply top-k and top-p filtering
        next_token_logits_wo_last, _ = apply_top_k_top_p(
            next_token_logits_wo_last,
            top_k=sampling_params.top_k if sampling_params is not None else None,
            top_p=sampling_params.top_p if sampling_params is not None else 1.0,
        )
        next_token_logprobs = torch.nn.functional.log_softmax(
            next_token_logits_wo_last, dim=-1
        )
        next_tokens = input_ids[:, 1:].cuda()  # Skip first token
        token_logprobs = next_token_logprobs.gather(
            dim=-1, index=next_tokens.unsqueeze(-1)
        ).squeeze(-1)
```

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain: verification scripts (wc, sed, rg, ast-grep) were run over nemo_rl/distributed/model_utils.py, nemo_rl/algorithms/loss_functions.py, and the callers of compute_logprobs_from_logits to confirm how the input_ids and logits devices are handled across code paths before concluding below.

Move input_ids to match the logits device instead of hard-coding .cuda().

Line 1556 forces execution to GPU 0, breaking CPU setups and non-default GPU allocation. Align the device with the logits tensor to ensure consistency across all code paths.

Suggested fix
-        next_tokens = input_ids[:, 1:].cuda()  # Skip first token
+        next_tokens = input_ids[:, 1:].to(next_token_logits_wo_last.device)  # Skip first token
🤖 Prompt for AI Agents
In `@nemo_rl/distributed/model_utils.py` around lines 1544 - 1559, The code moves
input_ids to GPU with a hard-coded .cuda() which breaks CPU or non-default-GPU
runs; change the conversion of input_ids[:, 1:] (used to produce next_tokens) to
use the same device as the logits tensor (e.g., next_token_logits_wo_last.device
or next_token_logprobs.device) instead of .cuda() so next_tokens and subsequent
gather operate on the correct device; update the expression that defines
next_tokens to call .to(<logits_device>) so device alignment is ensured for
next_token_logits_wo_last, next_token_logprobs, and token_logprobs.

Comment on lines +96 to 99
```diff
         if top_p <= 0:
             raise ValueError(
-                (
-                    f"top_p sampling with values < {TOP_P_THRESHOLD} is not supported because the vLLM V1 engine "
-                    "does not return logprobs after top_p filtering. Values >= {TOP_P_THRESHOLD} are allowed "
-                    "for token filtering purposes. If you understand the implications and still want to use "
-                    f"a lower top_p value, please manually comment out this check. Got top_p={top_p}. "
-                    "See https://github.com/NVIDIA-NeMo/RL/issues/69 for more details."
-                )
+                f"top_p valid values: i) 1.0: no filtering. ii) (0, 1]: top-p filtering. Got top_p={top_p}."
             )
```

⚠️ Potential issue | 🟡 Minor

top_p > 1.0 passes validation but contradicts the documented valid range.

The error message states valid values are (0, 1], but the condition only rejects top_p <= 0. A value like 1.5 would be silently accepted. Consider also rejecting top_p > 1.0 for consistency with the documented contract.

Proposed fix
-        if top_p <= 0:
+        if top_p <= 0 or top_p > 1.0:
📝 Committable suggestion

```python
        if top_p <= 0 or top_p > 1.0:
            raise ValueError(
                f"top_p valid values: i) 1.0: no filtering. ii) (0, 1]: top-p filtering. Got top_p={top_p}."
            )
```
🧰 Tools
🪛 Ruff (0.15.0)

[warning] 97-99: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@nemo_rl/models/generation/vllm/vllm_generation.py` around lines 96 - 99, The
validation for the top_p parameter in vllm_generation.py currently only rejects
top_p <= 0 but allows values > 1.0, contradicting the documented valid range (0,
1]; update the check that inspects the top_p variable so it raises a ValueError
when top_p <= 0 or top_p > 1.0, and keep the existing error message (or adjust
it to mention the closed upper bound) to reflect the contract; locate the
validation near the top_p check in the VLLM generation logic and modify that
conditional to enforce both bounds.

Comment on lines +245 to +251
if "generation" in self.cfg and self.cfg["generation"] is not None:
generation_cfg = self.cfg["generation"]
self.sampling_params = TrainingSamplingParams(
top_k=generation_cfg.get("top_k", None),
top_p=generation_cfg.get("top_p", 1.0),
temperature=generation_cfg.get("temperature", 1.0),
)

⚠️ Potential issue | 🟡 Minor

Avoid setting non‑None defaults in code for generation config.

Defaults should live in YAML; use required keys here and let the config own defaults.

🛠️ Suggested fix
             self.sampling_params = TrainingSamplingParams(
-                top_k=generation_cfg.get("top_k", None),
-                top_p=generation_cfg.get("top_p", 1.0),
-                temperature=generation_cfg.get("temperature", 1.0),
+                top_k=generation_cfg.get("top_k"),
+                top_p=generation_cfg["top_p"],
+                temperature=generation_cfg["temperature"],
             )
As per coding guidelines, "YAML is the single source of truth for configuration defaults; do not set non-None defaults in code for configuration values".
🤖 Prompt for AI Agents
In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py` around lines 245 -
251, The code is supplying non-None defaults for generation config values;
remove those defaults and read required keys directly from the config so YAML
owns defaults: when "generation" exists in self.cfg, assign generation_cfg =
self.cfg["generation"] and construct TrainingSamplingParams using the raw values
(e.g. generation_cfg["top_k"], generation_cfg["top_p"],
generation_cfg["temperature"]) instead of .get(..., default), and set
self.sampling_params accordingly; ensure any callers expect these keys to be
present (YAML should provide defaults) and do not introduce new in-code fallback
values.

Comment on lines +339 to +355
```python
    def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
        if self.sampling_params is not None and self.sampling_params.temperature != 1.0:
            logits.div_(self.sampling_params.temperature)
        return logits

    def _apply_top_k_top_p_filtering(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply top-k and top-p filtering to the logits locally when TP is disabled."""
        sampling_params = self.sampling_params
        if sampling_params is not None and need_top_k_or_top_p_filtering(
            sampling_params.top_k, sampling_params.top_p
        ):
            logits, _ = apply_top_k_top_p(
                logits,
                top_k=sampling_params.top_k,
                top_p=sampling_params.top_p,
            )
        return logits
```

⚠️ Potential issue | 🟡 Minor

Add Google‑style docstrings for the new helper methods.

These are new functions; please document them in Google style for consistency.

✍️ Suggested docstrings
     def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
+        """Apply temperature scaling to logits.
+
+        Args:
+            logits: Logits tensor to scale in-place.
+
+        Returns:
+            The scaled logits tensor.
+        """
         if self.sampling_params is not None and self.sampling_params.temperature != 1.0:
             logits.div_(self.sampling_params.temperature)
         return logits
 
     def _apply_top_k_top_p_filtering(self, logits: torch.Tensor) -> torch.Tensor:
-        """Apply top-k and top-p filtering to the logits locally when TP is disabled."""
+        """Apply top-k/top-p filtering to logits.
+
+        Args:
+            logits: Logits tensor to filter.
+
+        Returns:
+            Filtered logits tensor.
+        """
As per coding guidelines, "Use Google style docstrings for classes and functions, which can be parsed by Sphinx".
🤖 Prompt for AI Agents
In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py` around lines 339 -
355, Add Google-style docstrings to the two new helper methods
_apply_temperature_scaling and _apply_top_k_top_p_filtering: for each function
provide a one-line summary, a short description if needed, an Args section
describing parameters (e.g., logits: torch.Tensor and mention
self.sampling_params where relevant), a Returns section describing the returned
torch.Tensor, and any Notes or Raises if applicable (e.g., side effects from
in-place div_ in _apply_temperature_scaling). Place the docstrings immediately
under each def to match project Sphinx parsing and coding guidelines.

Comment on lines +831 to +847
```python
        # Save and adjust sampling_params for reference model
        saved_sampling_params = self.sampling_params
        if saved_sampling_params is not None:
            self.sampling_params = TrainingSamplingParams(
                top_k=None,
                top_p=1.0,
                temperature=saved_sampling_params.temperature,
            )
        else:
            self.sampling_params = None

        yield

    finally:
        # Restore sampling_params
        self.sampling_params = saved_sampling_params
```

⚠️ Potential issue | 🟠 Major

Prevent saved_sampling_params from being undefined on exceptions.

If an exception occurs before saved_sampling_params is assigned, the finally block will raise UnboundLocalError. Initialize it before the try.

🛠️ Suggested fix
         with torch.no_grad():
-            try:
+            saved_sampling_params = self.sampling_params
+            try:
                 # Save train model state_dict
                 curr_state_dict = get_cpu_state_dict(
                     self.model.state_dict().items(), pin_memory=True
                 )
@@
-                saved_sampling_params = self.sampling_params
                 if saved_sampling_params is not None:
                     self.sampling_params = TrainingSamplingParams(
                         top_k=None,
                         top_p=1.0,
                         temperature=saved_sampling_params.temperature,
                     )
🤖 Prompt for AI Agents
In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py` around lines 831 -
847, The finally block can raise UnboundLocalError if an exception occurs before
saved_sampling_params is set; initialize saved_sampling_params = None before the
try that modifies self.sampling_params so the finally always has a defined
variable. Update the context around the sampling_params manipulation (reference
symbols: saved_sampling_params, self.sampling_params, TrainingSamplingParams) to
set saved_sampling_params = None immediately before entering the try, then
proceed with the existing logic and restoration in the finally.

Comment on lines +179 to +187
if "generation" in self.cfg and self.cfg["generation"] is not None:
generation_cfg = self.cfg["generation"]
self.sampling_params = TrainingSamplingParams(
top_k=generation_cfg.get("top_k", None),
top_p=generation_cfg.get("top_p", 1.0),
temperature=generation_cfg.get("temperature", 1.0),
)
else:
self.sampling_params = None

⚠️ Potential issue | 🔴 Critical

self.cfg is not yet assigned — this will raise AttributeError on construction.

self.cfg is set at Line 198, but Line 179 references self.cfg. Every instantiation of DTensorPolicyWorker will crash here. The existing block at Line 176 correctly uses the config parameter; this new block should do the same.

Proposed fix
-        if "generation" in self.cfg and self.cfg["generation"] is not None:
-            generation_cfg = self.cfg["generation"]
+        if "generation" in config and config["generation"] is not None:
+            generation_cfg = config["generation"]
             self.sampling_params = TrainingSamplingParams(
                 top_k=generation_cfg.get("top_k", None),
                 top_p=generation_cfg.get("top_p", 1.0),
🤖 Prompt for AI Agents
In `@nemo_rl/models/policy/workers/dtensor_policy_worker.py` around lines 179 -
187, The generation config access uses self.cfg before it's assigned, causing
AttributeError in DTensorPolicyWorker __init__; update the block that constructs
self.sampling_params to read from the local config parameter (use config or
config["generation"]) or move this entire generation-handling block to after the
assignment to self.cfg so TrainingSamplingParams (top_k/top_p/temperature) is
created from the already-initialized config; target the DTensorPolicyWorker
constructor and the sampling_params/TrainingSamplingParams logic to fix the
reference.

Comment on lines +1663 to +1684
```python
        # Temporarily disable top-k/top-p filtering for reference policy logprobs.
        # The reference policy has different weights, so its top-k/top-p set is
        # inherently different from the current policy. Using filtered logprobs
        # would cause -inf mismatches that cannot be resolved by masking.
        # Note: We keep temperature scaling since it was applied to prev_logprobs.
        saved_sampling_params = self.sampling_params
        if saved_sampling_params is not None:
            self.sampling_params = TrainingSamplingParams(
                top_k=None,  # Disable top-k
                top_p=1.0,  # Disable top-p
                temperature=saved_sampling_params.temperature,  # Keep temperature
            )
        else:
            self.sampling_params = None

        # - self.model is the original reference_model, now on CUDA
        # - curr_state_dict is the train model, now on CPU
        yield

    finally:
        # Restore sampling_params
        self.sampling_params = saved_sampling_params
```

⚠️ Potential issue | 🟠 Major

Same UnboundLocalError risk as in the Megatron worker.

If any of the state-dict operations on Lines 1654–1661 throw before saved_sampling_params is assigned on Line 1668, the finally block on Line 1684 will fail with UnboundLocalError.

Move the assignment to the top of the try block:

Proposed fix
         with torch.no_grad():
             try:
+                # Save sampling_params early so finally can always restore them
+                saved_sampling_params = self.sampling_params
+
                 # Save train model state_dict
                 curr_state_dict = get_cpu_state_dict(
                     self.model.state_dict().items(), pin_memory=True
🤖 Prompt for AI Agents
In `@nemo_rl/models/policy/workers/dtensor_policy_worker.py` around lines 1663 -
1684, The try/finally can raise UnboundLocalError because saved_sampling_params
is assigned after operations; move the saved_sampling_params =
self.sampling_params assignment to the very start of the try block (before any
state-dict or model manipulation) so it is always defined for the finally, then
use that saved value when setting self.sampling_params to the temporary
TrainingSamplingParams (or None) and restore it in the finally; reference the
variables/methods: self.sampling_params, saved_sampling_params,
TrainingSamplingParams, and the surrounding try/finally in
dtensor_policy_worker.py.

Comment on lines +1157 to +1177
```python
SAMPLING_PARAMS_TEST_ACTOR_FQN = (
    f"{SamplingParamsTestActor.__module__}.SamplingParamsTestActor"
)


@pytest.fixture
def register_sampling_params_test_actor():
    """Register the SamplingParamsTestActor for use in tests."""
    original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get(
        SAMPLING_PARAMS_TEST_ACTOR_FQN
    )
    ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] = PY_EXECUTABLES.SYSTEM
    yield SAMPLING_PARAMS_TEST_ACTOR_FQN
    if SAMPLING_PARAMS_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY:
        if original_registry_value is None:
            del ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN]
        else:
            ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] = (
                original_registry_value
            )
```

⚠️ Potential issue | 🟡 Minor

Prefix new global with G_ to match the naming rule.

Globals should use the G_ prefix; rename the constant and its uses.

🔧 Suggested fix
-SAMPLING_PARAMS_TEST_ACTOR_FQN = (
+G_SAMPLING_PARAMS_TEST_ACTOR_FQN = (
     f"{SamplingParamsTestActor.__module__}.SamplingParamsTestActor"
 )
 
 @pytest.fixture
 def register_sampling_params_test_actor():
     """Register the SamplingParamsTestActor for use in tests."""
     original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get(
-        SAMPLING_PARAMS_TEST_ACTOR_FQN
+        G_SAMPLING_PARAMS_TEST_ACTOR_FQN
     )
-    ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] = PY_EXECUTABLES.SYSTEM
-    yield SAMPLING_PARAMS_TEST_ACTOR_FQN
-    if SAMPLING_PARAMS_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY:
+    ACTOR_ENVIRONMENT_REGISTRY[G_SAMPLING_PARAMS_TEST_ACTOR_FQN] = PY_EXECUTABLES.SYSTEM
+    yield G_SAMPLING_PARAMS_TEST_ACTOR_FQN
+    if G_SAMPLING_PARAMS_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY:
         if original_registry_value is None:
-            del ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN]
+            del ACTOR_ENVIRONMENT_REGISTRY[G_SAMPLING_PARAMS_TEST_ACTOR_FQN]
         else:
-            ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] = (
+            ACTOR_ENVIRONMENT_REGISTRY[G_SAMPLING_PARAMS_TEST_ACTOR_FQN] = (
                 original_registry_value
             )
As per coding guidelines, "Use upper snake_case with `G` prefix for global variables, e.g., `G_MY_GLOBAL`".
🤖 Prompt for AI Agents
In `@tests/unit/distributed/test_model_utils.py` around lines 1157 - 1177, Rename
the global constant SAMPLING_PARAMS_TEST_ACTOR_FQN to follow the G_ prefix
convention (e.g., G_SAMPLING_PARAMS_TEST_ACTOR_FQN) and update all references in
this test block: the fixture register_sampling_params_test_actor, the
ACTOR_ENVIRONMENT_REGISTRY key accesses, and any use with PY_EXECUTABLES.SYSTEM;
ensure the original_registry_value capture and cleanup logic remain unchanged
but use the new constant name everywhere.
