Conversation
📝 Walkthrough

This PR introduces a fused sequence packing optimization that combines sequence packing with logprob computation in a single forward pass.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client as Trainer/Client
    participant Fusion as SequencePackingFusionLossWrapper
    participant Distributed as from_parallel_logits_to_logprobs_packed_sequences
    participant Loss as ClippedPGLossFn
    Client->>Fusion: __call__(next_token_logits, data)
    Fusion->>Fusion: Pack input sequences (nvtx-wrapped)
    Fusion->>Distributed: Compute logprobs from packed logits
    Distributed-->>Fusion: curr_logprobs (pre-rolled)
    Fusion->>Loss: _compute_loss_from_logprobs(curr_logprobs, data)
    Loss-->>Fusion: loss, metrics
    Fusion-->>Client: loss, metrics
    rect rgba(200, 100, 100, 0.5)
        Note over Fusion,Loss: Fused path: single forward pass combines<br/>packing + logprob + loss computation
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 3 passed
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/loss_functions.py`:
- Around lines 1134-1144: The call to `from_parallel_logits_to_logprobs_packed_sequences` uses `vocab_parallel_rank` without guarding for `None`, which will raise a `TypeError` when `None` is provided. Add an explicit assertion (like the one used in `_compute_curr_logprobs`) that `vocab_parallel_rank` is not `None` before computing `vocab_start_index`/`vocab_end_index`, and raise a clear error message if it is `None` so callers know about the parallel-vocab requirement. Update the block around the `curr_logprobs` computation (the call to `from_parallel_logits_to_logprobs_packed_sequences` and the `vocab_parallel_rank`/`vocab_parallel_group` variables) to perform this check first. A guard sketch follows.
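A minimal sketch of the kind of guard this prompt asks for. The helper name and `partition_vocab_size` parameter are invented for illustration; they are not the actual code in `loss_functions.py`.

```python
from typing import Optional, Tuple


def vocab_shard_range(
    vocab_parallel_rank: Optional[int], partition_vocab_size: int
) -> Tuple[int, int]:
    # Hypothetical helper: fail fast with a clear message instead of letting a
    # None rank raise a TypeError inside the index arithmetic below.
    if vocab_parallel_rank is None:
        raise ValueError(
            "Fused sequence-packing logprobs require vocab-parallel info: "
            "pass vocab_parallel_rank (and vocab_parallel_group) when logits "
            "are sharded across tensor-parallel ranks."
        )
    vocab_start_index = vocab_parallel_rank * partition_vocab_size
    vocab_end_index = vocab_start_index + partition_vocab_size
    return vocab_start_index, vocab_end_index
```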
In `@nemo_rl/models/megatron/common.py`:
- Around lines 130-148: Add the new `fuse_loss` key to the `SequencePackingConfig` TypedDict in `nemo_rl/models/policy/__init__.py`, with a short docstring describing its purpose, valid values (bool), and the recommended default. Remove the in-code default by updating the usage in the megatron wrapper to read `fuse_loss = policy_cfg.get("sequence_packing", {}).get("fuse_loss")` (no fallback to `False`) so the code no longer embeds a default, and place the default value for `sequence_packing.fuse_loss` in the exemplar YAMLs under `examples/configs/*.yaml` so YAML is the single source of truth. Keep the existing behavior that selects `SequencePackingFusionLossWrapper` vs `SequencePackingLossWrapper` and the conditional `data_dict["packed_input_ids"]` assignment, but rely on the YAML-provided default rather than `.get(..., False)`. A TypedDict sketch follows this item.
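A minimal sketch of what the requested TypedDict change could look like. The `enabled` field is a placeholder, the real definition lives in `nemo_rl/models/policy/__init__.py`, and `NotRequired` comes from `typing_extensions` on Python versions before 3.11.

```python
from typing_extensions import NotRequired, TypedDict


class SequencePackingConfig(TypedDict):
    """Illustrative shape only; not the actual definition in nemo_rl/models/policy."""

    enabled: bool  # placeholder field, assumed for illustration
    # Fuse packing + logprob + loss into one pass; recommended default lives in
    # the exemplar YAMLs rather than in code.
    fuse_loss: NotRequired[bool]
```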
In `@tests/unit/algorithms/test_sequence_packing_fusion.py`:
- Line 445: Remove the unused assignment `world_size = cp_size * tp_size` in the test (the `world_size` local is never referenced). Simply delete that line, or, if the value was intended to be used, use `world_size` where needed, so that `world_size`, `cp_size`, and `tp_size` are not assigned without use in `test_sequence_packing_fusion.py`.
🧹 Nitpick comments (1)
nemo_rl/algorithms/loss_functions.py (1)
1147-1149: Duck-typing on `_compute_loss_from_logprobs`: consider a runtime check or Protocol.

`self.loss_fn._compute_loss_from_logprobs(...)` will raise `AttributeError` if the wrapped `loss_fn` doesn't implement this method (e.g., `NLLLoss`, `DPOLossFn`). While the docstring documents the requirement, a `hasattr` check in `__init__` would produce a clear error at construction time rather than mid-training.

♻️ Proposed guard in `__init__`:
```diff
 def __init__(
     self,
     loss_fn: LossFunction,
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_q_padded: Optional[torch.Tensor] = None,
 ):
+    if not hasattr(loss_fn, "_compute_loss_from_logprobs"):
+        raise TypeError(
+            f"{type(loss_fn).__name__} does not implement _compute_loss_from_logprobs. "
+            "SequencePackingFusionLossWrapper requires a loss function with this method "
+            "(e.g., ClippedPGLossFn)."
+        )
     self.loss_fn = loss_fn
```
@CodeRabbit review
✅ Actions performed: Review triggered.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/models/megatron/train.py (1)
313-328: 🛠️ Refactor suggestion | 🟠 Major

Add `fuse_loss` to the `SequencePackingConfig` TypedDict and document it in the exemplar YAMLs. The `fuse_loss` key is used in code (lines 314-316) but not declared in the `SequencePackingConfig` TypedDict definition. Per coding guidelines, new config keys must be documented in the TypedDict with a type and a `NotRequired` marker where appropriate, and reflected in exemplar YAMLs under `examples/configs/*.yaml` with recommended defaults.

Additionally, the `self.cfg is not None` guard on line 314 is redundant: `self.cfg` is already dereferenced without a None check on line 311.

Minor cleanup for the redundant check:
```diff
-fuse_loss = self.cfg is not None and self.cfg.get(
-    "sequence_packing", {}
-).get("fuse_loss", None)
+fuse_loss = self.cfg.get("sequence_packing", {}).get("fuse_loss", None)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nemo_rl/models/megatron/train.py` around lines 313-328, add the missing `fuse_loss` entry to the `SequencePackingConfig` TypedDict (as `NotRequired[bool]`) and update the example config YAMLs under `examples/configs/*.yaml` to document the key and a recommended default (false). Then remove the redundant `self.cfg is not None` guard in the `fuse_loss` conditional in the megatron train logic where `fuse_loss` is read (the code that picks `SequencePackingFusionLossWrapper` vs `SequencePackingLossWrapper` and constructs the wrapper with `cu_seqlens_q`/`cu_seqlens_q_padded`), so the config is consistently typed and documented and the conditional is simplified.
🧹 Nitpick comments (2)
nemo_rl/algorithms/loss_functions.py (2)
1188-1252: Verify that `input_ids.roll(-1, dims=1)` on padded [B, S] is equivalent to per-sequence rolling.

Line 1228 rolls the entire [B, S] tensor row-wise, which wraps the first token of each row into the last position (including padding positions). The non-fusion `SequencePackingLossWrapper` path instead rolls each sequence individually within its actual length boundaries via `from_parallel_logits_to_logprobs_packed_sequences` (model_utils.py, lines 599-600).

The semantic difference is at position `seq_len - 1` of each sequence: the fusion path places `input_ids[i, seq_len]` (a padding zero) there, while the non-fusion path wraps `input_ids[packed_start]` (the first real token). However, position `seq_len - 1` is excluded from the output at line 662 (`probs[start_idx : end_idx - 1]`), so the difference is harmless.

This is a subtle correctness invariant worth a brief inline comment; a toy check of the equivalence appears after this item.
💡 Suggested inline comment
```diff
 # Roll targets on [B, S] (each row shifts independently), then CP-shard and pack.
+# NOTE: Full-row roll wraps padding into the last real position, but that
+# position is excluded by from_parallel_logits_to_logprobs_packed_sequences
+# (which drops the last token per sequence), so the result is equivalent to
+# per-sequence rolling done by the non-fused path.
 rolled_ids = input_ids.roll(-1, dims=1)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nemo_rl/algorithms/loss_functions.py` around lines 1188-1252, `SequencePackingFusionLossWrapper.__call__` currently uses `input_ids.roll(-1, dims=1)`, which wraps padded zeros into the last token position of each row (after `_get_tokens_on_this_cp_rank` and before packing), differing from the per-sequence roll used by `from_parallel_logits_to_logprobs_packed_sequences` in the non-fusion path. Add a concise inline comment near the `input_ids.roll` call explaining this semantic difference and why it is safe (the `seq_len - 1` position is later excluded by the unpacking logic, `probs[start_idx:end_idx-1]`), referencing `input_ids.roll`, `_get_tokens_on_this_cp_rank`, and `from_parallel_logits_to_logprobs_packed_sequences` so future readers understand the invariant.
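The toy check referenced above. Shapes, right-padding with zeros, and the slice boundaries are assumed from the review text, not taken from `model_utils.py`.

```python
import torch

# Two right-padded rows: seq_len 3 and 4, padded to length 5 with zeros.
input_ids = torch.tensor([[11, 12, 13, 0, 0],
                          [21, 22, 23, 24, 0]])
seq_lens = [3, 4]

full_roll = input_ids.roll(-1, dims=1)  # fusion path: roll whole rows at once

for row, seq_len in enumerate(seq_lens):
    per_seq_roll = input_ids[row, :seq_len].roll(-1)  # non-fusion path
    # Only positions 0 .. seq_len - 2 survive the probs[start:end - 1] slice,
    # and on that range both rolls agree; they differ only at seq_len - 1.
    assert torch.equal(full_roll[row, : seq_len - 1], per_seq_roll[: seq_len - 1])
    assert full_roll[row, seq_len - 1] != per_seq_roll[seq_len - 1]
```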
128-128: `chunk_size` is never set from configuration.

`self.chunk_size` is initialized to `None` and never populated from `ClippedPGLossConfig` or any external source. While the attribute exists as a hook for `SequencePackingFusionLossWrapper` (line 1246: `getattr(self.loss_fn, "chunk_size", None)`), it will always be `None` in practice unless something externally mutates it. If chunked logprob computation during training loss is intended to be supported, consider wiring this through the config or the caller; a sketch follows this item.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nemo_rl/algorithms/loss_functions.py` at line 128, `self.chunk_size` is initialized to `None` in the `ClippedPGLossFn` constructor but never populated from `ClippedPGLossConfig` or any caller, so `SequencePackingFusionLossWrapper`'s `getattr(self.loss_fn, "chunk_size", None)` will always return `None`. Update the constructor or config wiring to read a `chunk_size` value from `ClippedPGLossConfig` (or an explicit constructor parameter) and assign it to `self.chunk_size` (e.g., accept `chunk_size` in `ClippedPGLossConfig` or in `ClippedPGLossFn.__init__` and set `self.chunk_size = config.chunk_size`) so that downstream callers like `SequencePackingFusionLossWrapper` can detect and use chunked logprob computation.
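One possible wiring for the nitpick above, sketched with assumed names: the real `ClippedPGLossConfig` has many more fields than shown, and this stub class is not the actual loss implementation.

```python
from typing import Optional

from typing_extensions import NotRequired, TypedDict


class ClippedPGLossConfig(TypedDict):
    # Only the new key is shown; the existing config fields are omitted here.
    chunk_size: NotRequired[Optional[int]]  # tokens per logprob chunk; None disables chunking


class ClippedPGLossFn:
    def __init__(self, cfg: ClippedPGLossConfig):
        # Read the optional knob from config instead of hard-coding None, so the
        # fused wrapper's getattr(self.loss_fn, "chunk_size", None) can pick it up.
        self.chunk_size = cfg.get("chunk_size", None)
```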
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/unit/algorithms/test_sequence_packing_fusion.py`:
- Around lines 195-230: The packing currently breaks autograd because you assign into leaf tensors created with `torch.zeros` (`packed_logits` and `tmp`) using in-place slices. Instead, build `packed_logits` from tensors derived from `logits_local` so gradients flow: in `make_logits_and_packed_logits`, stop creating `tmp` and `packed_logits` as `torch.zeros` and avoid in-place slice assignments. For each sequence, produce a slice (e.g., `tmp_slice = logits_local[i:i+1, :seq_len, :]`, padded via differentiable ops or indexing plus `torch.nn.functional.pad`), run it through `_get_tokens_on_this_cp_rank`, collect those outputs in a list, and then `torch.cat` the list into `packed_logits` so the result stays connected to `logits_local`. Alternatively, ensure the final `packed_logits` has `requires_grad_(True)` and is produced by non-inplace ops referencing `logits_local` rather than by assigning into zero tensors. A differentiable-packing sketch follows below.
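The sketch referenced in the prompt above. It assumes a trivial identity stand-in for `_get_tokens_on_this_cp_rank` and made-up shapes; it only demonstrates that cat-based packing keeps the autograd graph connected.

```python
import torch


def _get_tokens_on_this_cp_rank(x: torch.Tensor) -> torch.Tensor:
    # Stand-in only: the real helper shards along the sequence dim per CP rank.
    return x


def pack_logits(logits_local: torch.Tensor, seq_lens: list[int]) -> torch.Tensor:
    # Build the packed tensor from slices of logits_local (no in-place writes
    # into a fresh torch.zeros), so autograd stays connected to logits_local.
    pieces = [
        _get_tokens_on_this_cp_rank(logits_local[i, :seq_len, :])
        for i, seq_len in enumerate(seq_lens)
    ]
    return torch.cat(pieces, dim=0).unsqueeze(0)  # [1, total_tokens, vocab]


logits_local = torch.randn(2, 5, 8, requires_grad=True)
packed_logits = pack_logits(logits_local, [3, 4])
packed_logits.sum().backward()
assert logits_local.grad is not None  # gradients flow back through the packing
```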
What does this PR do?
Apply the loss function in a single call over all packed sequences instead of calling it once per individual sequence (see the conceptual sketch below).
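A conceptual sketch of the difference, not the actual NeMo RL code: the stand-in `loss_fn`, the per-sequence split, and the metric handling are all simplified for illustration.

```python
import torch


def loss_fn(logprobs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
    # Stand-in loss; the real path goes through ClippedPGLossFn._compute_loss_from_logprobs.
    return -(logprobs * advantages).sum()


def per_sequence_loss(logprobs_per_seq, advantages_per_seq):
    # Before: one loss call per unpacked sequence, results accumulated.
    return torch.stack(
        [loss_fn(lp, adv) for lp, adv in zip(logprobs_per_seq, advantages_per_seq)]
    ).sum()


def fused_loss(packed_logprobs, packed_advantages):
    # With fuse_loss: logprobs are computed on the packed logits and the loss
    # runs once over all sequences in the pack.
    return loss_fn(packed_logprobs, packed_advantages)


# Toy demo: two sequences of lengths 3 and 4 give the same total either way.
lp = [torch.randn(3), torch.randn(4)]
adv = [torch.ones(3), torch.ones(4)]
assert torch.allclose(per_sequence_loss(lp, adv), fused_loss(torch.cat(lp), torch.cat(adv)))
```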
Issues
Issue #1247
Usage
Set the flag `policy.sequence_packing.fuse_loss` to `true` to turn on this feature (a hedged config sketch follows).
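An illustration of the flag in config form. The nested keys mirror the dotted path above; the `enabled` companion flag is an assumption, and the real defaults live in the exemplar YAMLs under `examples/configs/` and the `SequencePackingConfig` TypedDict.

```python
# Illustrative only, not the actual exemplar YAML contents.
policy_cfg = {
    "sequence_packing": {
        "enabled": True,    # assumed companion flag for packing itself
        "fuse_loss": True,  # enables the fused packing + logprob + loss path
    },
}

# Mirrors how this review describes the megatron wrapper reading the flag.
fuse_loss = policy_cfg.get("sequence_packing", {}).get("fuse_loss", False)
```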
Results
Observed up to a 15% speedup in policy training FLOPs.
Validation results show a similar accuracy curve.
Additional Information
Check out this report for a more detailed analysis.