
perf: Fuse sequence packing for loss function #1904

Open

nujoug wants to merge 10 commits into main from mloh/seqpack_fusion

Conversation


@nujoug nujoug commented Feb 10, 2026

What does this PR do?

Applies the loss function once to all packed sequences instead of calling it separately for each individual sequence.

Issues

Issue #1247

Usage

Set the flag policy.sequence_packing.fuse_loss to true to enable this feature.
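
A minimal sketch of what that override might look like, written as a Python dict for illustration (the fuse_loss key is from this PR; the surrounding keys and the dict-vs-YAML form are assumptions):

```python
# Illustrative only: enable the fused packing + loss path for the policy.
policy_config = {
    "sequence_packing": {
        "enabled": True,     # assumed prerequisite: sequence packing itself is enabled
        "fuse_loss": True,   # flag added by this PR: fuse packing + logprob + loss
    },
}
```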

Results

Observed up to a 15% speedup in policy training FLOPs.

[W&B chart, 2/20/2026, 2:26:59 PM]

Validation results show a similar accuracy curve.

[W&B chart, 2/20/2026, 1:46:27 PM]

Additional Information

See this report for a more detailed analysis.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added fused sequence packing optimization that combines packing and logprob computation in a single forward pass for improved performance.
    • New configuration option to enable fused loss computation path when available.
  • Tests

    • Added comprehensive distributed tests validating fused sequence packing optimizations across multiple parallelism configurations.

@nujoug nujoug force-pushed the mloh/seqpack_fusion branch 2 times, most recently from 197b783 to c6d4fbb on February 11, 2026 00:37
@nujoug nujoug marked this pull request as ready for review February 11, 2026 18:42
@nujoug nujoug requested review from a team as code owners February 11, 2026 18:42
@nujoug nujoug requested a review from guyueh1 February 11, 2026 18:42

coderabbitai bot commented Feb 11, 2026

📝 Walkthrough

This PR introduces a fused sequence packing optimization that combines sequence packing with logprob computation in a single forward pass. It refactors ClippedPGLossFn to separate logprob and loss computation, adds SequencePackingFusionLossWrapper for the fused path, extends distributed utilities to support pre-rolled targets, updates the training pipeline to conditionally use the fusion wrapper, and provides comprehensive distributed tests.

Changes

  • Loss function refactoring (nemo_rl/algorithms/loss_functions.py): Extracted _compute_curr_logprobs() and _compute_loss_from_logprobs() internal methods from ClippedPGLossFn.__call__() to support external logprob sources. Added a chunk_size attribute and a new SequencePackingFusionLossWrapper class that fuses sequence packing and logprob computation before delegating loss computation to the wrapped loss function (a minimal sketch of this wrapper shape follows the list).
  • Distributed utilities (nemo_rl/distributed/model_utils.py): Added a target_is_pre_rolled parameter to from_parallel_logits_to_logprobs_packed_sequences() to skip internal rolling/CP-sharding when targets are pre-processed. Renamed the cu_seqlens parameter to cu_seqlens_padded for clarity.
  • Training pipeline integration (nemo_rl/models/megatron/train.py): Updated LossPostProcessor to conditionally select between SequencePackingFusionLossWrapper and SequencePackingLossWrapper based on the cfg["sequence_packing"]["fuse_loss"] configuration flag.
  • Distributed test suite (tests/unit/algorithms/test_sequence_packing_fusion.py): New comprehensive test module with Ray-based actor infrastructure validating the fusion wrapper against the baseline wrapper across six CP/TP configurations (1x1, 1x2, 2x1, 2x2, 2x4, 4x2), comparing losses and gradients for correctness.
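
As referenced in the list above, here is a minimal sketch of the fused wrapper shape. The class, method, and helper names come from this walkthrough and the review comments below; the call signatures and data-dict keys are simplified assumptions, not the PR's actual code.

```python
# A sketch only, not the PR's implementation: the fused wrapper computes
# logprobs directly from packed, vocab-parallel logits, then hands them to the
# wrapped loss function.
from nemo_rl.distributed.model_utils import (
    from_parallel_logits_to_logprobs_packed_sequences,
)


class FusedPackingLossWrapperSketch:
    def __init__(self, loss_fn, cu_seqlens_q, cu_seqlens_q_padded=None):
        # loss_fn is expected to expose _compute_loss_from_logprobs,
        # e.g. the refactored ClippedPGLossFn.
        self.loss_fn = loss_fn
        self.cu_seqlens_q = cu_seqlens_q
        self.cu_seqlens_q_padded = cu_seqlens_q_padded

    def __call__(self, next_token_logits, data):
        # Targets are rolled and packed up front, so the distributed helper is
        # told to skip its internal rolling/CP-sharding.
        curr_logprobs = from_parallel_logits_to_logprobs_packed_sequences(
            next_token_logits,
            data["input_ids"],
            cu_seqlens_padded=self.cu_seqlens_q_padded,
            target_is_pre_rolled=True,
        )
        # Loss and metrics remain the responsibility of the wrapped loss function.
        return self.loss_fn._compute_loss_from_logprobs(curr_logprobs, data)
```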

Sequence Diagram(s)

sequenceDiagram
    participant Client as Trainer/Client
    participant Fusion as SequencePackingFusionLossWrapper
    participant Distributed as from_parallel_logits_to_logprobs_packed_sequences
    participant Loss as ClippedPGLossFn

    Client->>Fusion: __call__(next_token_logits, data)
    Fusion->>Fusion: Pack input sequences (nvtx-wrapped)
    Fusion->>Distributed: Compute logprobs from packed logits
    Distributed-->>Fusion: curr_logprobs (pre-rolled)
    Fusion->>Loss: _compute_loss_from_logprobs(curr_logprobs, data)
    Loss-->>Fusion: loss, metrics
    Fusion-->>Client: loss, metrics

    rect rgba(200, 100, 100, 0.5)
    Note over Fusion,Loss: Fused path: single forward pass combines<br/>packing + logprob + loss computation
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: 3 passed, 1 failed

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 68.42%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them to satisfy the coverage threshold.

✅ Passed checks (3 passed)

  • Description Check ✅ Passed: Check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title 'perf: Fuse sequence packing for loss function' directly and clearly describes the main performance optimization implemented in the PR: fusing sequence packing with loss computation.
  • Test Results For Major Changes ✅ Passed: The PR documentation includes performance charts, FLOPs metrics, a validation accuracy comparison, and quantified speedup claims for this major 700+ line feature addition.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/loss_functions.py`:
- Around line 1134-1144: The call to
from_parallel_logits_to_logprobs_packed_sequences uses vocab_parallel_rank
without guarding for None, which will raise a TypeError when None is provided;
add an explicit assertion (like the one used in _compute_curr_logprobs) that
vocab_parallel_rank is not None before computing
vocab_start_index/vocab_end_index, and raise a clear error message if it is None
so callers know the parallel-vocab requirement; update the block around the
curr_logprobs computation (the call to
from_parallel_logits_to_logprobs_packed_sequences and the variables
vocab_parallel_rank/vocab_parallel_group) to perform this check first.
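
A hypothetical standalone helper illustrating the guard this prompt asks for; the helper name and error wording are not from the PR, only the requirement to check vocab_parallel_rank before computing vocab_start_index/vocab_end_index is.

```python
def require_vocab_parallel_rank(vocab_parallel_rank):
    """Fail fast with a clear message instead of a TypeError downstream."""
    if vocab_parallel_rank is None:
        raise ValueError(
            "Fused sequence-packing logprob computation requires vocab-parallel "
            "logits: vocab_parallel_rank (and vocab_parallel_group) must be set "
            "before vocab_start_index/vocab_end_index can be computed."
        )
    return vocab_parallel_rank
```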

In `@nemo_rl/models/megatron/common.py`:
- Around line 130-148: Add the new `fuse_loss` key to the SequencePackingConfig
TypedDict in nemo_rl/models/policy/__init__.py with a short docstring describing
purpose, valid values (bool), and the recommended default; remove the in-code
default by updating the usage in the megatron wrapper to read fuse_loss =
policy_cfg.get("sequence_packing", {}).get("fuse_loss") (no fallback to False)
so the code no longer embeds a default; place the default value for
sequence_packing.fuse_loss into the exemplar YAMLs under examples/configs/*.yaml
(so YAML is the single source of truth); keep the existing behavior that selects
SequencePackingFusionLossWrapper vs SequencePackingLossWrapper and the
conditional data_dict["packed_input_ids"] assignment, but rely on the
YAML-provided default rather than .get(..., False).
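
A sketch of the requested TypedDict change, assuming Python 3.11+ typing; only fuse_loss is the field under discussion, and the other key shown is a placeholder rather than the real schema.

```python
from typing import NotRequired, TypedDict  # NotRequired: Python 3.11+, else typing_extensions


class SequencePackingConfig(TypedDict):
    enabled: bool                 # placeholder for the existing keys
    fuse_loss: NotRequired[bool]  # fuse packing + logprob + loss; default lives in examples/configs/*.yaml
```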

In `@tests/unit/algorithms/test_sequence_packing_fusion.py`:
- Line 445: Remove the unused variable assignment to `world_size = cp_size *
tp_size` in the test (the `world_size` local is never referenced); simply delete
that line (or if the value was intended to be used, replace the assignment by
using `world_size` where needed) so that `world_size`, `cp_size`, and `tp_size`
are not assigned without use in the test `test_sequence_packing_fusion.py`.
🧹 Nitpick comments (1)
nemo_rl/algorithms/loss_functions.py (1)

1147-1149: Duck-typing on _compute_loss_from_logprobs — consider a runtime check or Protocol.

self.loss_fn._compute_loss_from_logprobs(...) will raise AttributeError if the wrapped loss_fn doesn't implement this method (e.g., NLLLoss, DPOLossFn). While the docstring documents the requirement, a hasattr check in __init__ would produce a clear error at construction time rather than mid-training.

♻️ Proposed guard in __init__
     def __init__(
         self,
         loss_fn: LossFunction,
         cu_seqlens_q: torch.Tensor,
         cu_seqlens_q_padded: Optional[torch.Tensor] = None,
     ):
+        if not hasattr(loss_fn, "_compute_loss_from_logprobs"):
+            raise TypeError(
+                f"{type(loss_fn).__name__} does not implement _compute_loss_from_logprobs. "
+                "SequencePackingFusionLossWrapper requires a loss function with this method "
+                "(e.g., ClippedPGLossFn)."
+            )
         self.loss_fn = loss_fn

@nujoug nujoug force-pushed the mloh/seqpack_fusion branch from 69e416a to 371c308 on February 20, 2026 21:40
@nujoug nujoug added and then removed the CI:L2 label (Run doctests, unit tests, functional tests, and convergence tests) on Feb 20, 2026

nujoug commented Feb 21, 2026

@CodeRabbit review


coderabbitai bot commented Feb 21, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/models/megatron/train.py (1)

313-328: 🛠️ Refactor suggestion | 🟠 Major

Add fuse_loss to SequencePackingConfig TypedDict and document in exemplar YAMLs.

The fuse_loss key is used in code (line 314–316) but not declared in the SequencePackingConfig TypedDict definition. Per coding guidelines, new config keys must be documented in the TypedDict with type and NotRequired marker where appropriate, and reflected in exemplar YAMLs under examples/configs/*.yaml with recommended defaults.

Additionally, the self.cfg is not None guard on line 314 is redundant—self.cfg is already dereferenced without a None check on line 311.

Minor cleanup for the redundant check
-            fuse_loss = self.cfg is not None and self.cfg.get(
-                "sequence_packing", {}
-            ).get("fuse_loss", None)
+            fuse_loss = self.cfg.get("sequence_packing", {}).get("fuse_loss", None)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/models/megatron/train.py` around lines 313 - 328, Add the missing
fuse_loss entry to the SequencePackingConfig TypedDict (as NotRequired[bool])
and update the example config YAMLs under examples/configs/*.yaml to document
the key and a recommended default (false); then remove the redundant self.cfg is
not None guard in the fuse_loss conditional in the megatron train logic where
fuse_loss is read (the code that picks SequencePackingFusionLossWrapper vs
SequencePackingLossWrapper and constructs the wrapper with
cu_seqlens_q/cu_seqlens_q_padded) so the config is consistently typed and
documented and the conditional is simplified.
🧹 Nitpick comments (2)
nemo_rl/algorithms/loss_functions.py (2)

1188-1252: Verify that input_ids.roll(-1, dims=1) on padded [B, S] is equivalent to per-sequence rolling.

Line 1228 rolls the entire [B, S] tensor row-wise, which wraps the first token of each row into the last position (including padding positions). The non-fusion SequencePackingLossWrapper path instead rolls each sequence individually within its actual length boundaries via from_parallel_logits_to_logprobs_packed_sequences (model_utils.py, lines 599-600).

The semantic difference is at position seq_len - 1 of each sequence: the fusion path places input_ids[i, seq_len] (a padding zero) there, while the non-fusion path wraps input_ids[packed_start] (the first real token). However, position seq_len - 1 is excluded from the output at line 662 (probs[start_idx : end_idx - 1]), so the difference is harmless.

This is a subtle correctness invariant worth a brief inline comment.

💡 Suggested inline comment
         # Roll targets on [B, S] (each row shifts independently), then CP-shard and pack.
+        # NOTE: Full-row roll wraps padding into the last real position, but that
+        # position is excluded by from_parallel_logits_to_logprobs_packed_sequences
+        # (which drops the last token per sequence), so the result is equivalent to
+        # per-sequence rolling done by the non-fused path.
         rolled_ids = input_ids.roll(-1, dims=1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/algorithms/loss_functions.py` around lines 1188 - 1252,
SequencePackingFusionLossWrapper.__call__ currently uses input_ids.roll(-1,
dims=1) which wraps padded zeros into the last token position for each row
(after _get_tokens_on_this_cp_rank and before packing), differing from the
per-sequence roll used by from_parallel_logits_to_logprobs_packed_sequences in
the non-fusion path; add a concise inline comment near the input_ids.roll call
explaining this semantic difference and why it is safe (the seq_len-1 position
is later excluded by the unpacking logic / probs[start_idx:end_idx-1]),
referencing input_ids.roll, _get_tokens_on_this_cp_rank, and
from_parallel_logits_to_logprobs_packed_sequences so future readers understand
the invariant.
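
A tiny standalone illustration of the invariant discussed above, using toy values that are not from the PR:

```python
import torch

seq_len = 4
row = torch.tensor([11, 12, 13, 14, 0, 0])      # one padded sequence of length 4

full_row_roll = row.roll(-1)                     # fused path: roll the whole padded row
per_seq_roll = row.clone()
per_seq_roll[:seq_len] = row[:seq_len].roll(-1)  # non-fused path: roll only the real tokens

print(full_row_roll[:seq_len])  # tensor([12, 13, 14,  0])
print(per_seq_roll[:seq_len])   # tensor([12, 13, 14, 11])
# The two differ only at index seq_len - 1, which the unpacking logic drops
# (probs[start_idx : end_idx - 1]), so both paths produce the same logprobs.
```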

128-128: chunk_size is never set from configuration.

self.chunk_size is initialized to None and never populated from ClippedPGLossConfig or any external source. While the attribute exists as a hook for SequencePackingFusionLossWrapper (line 1246: getattr(self.loss_fn, "chunk_size", None)), it will always be None in practice unless something externally mutates it. If chunked logprob computation during training loss is intended to be supported, consider wiring this through the config or the caller.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/algorithms/loss_functions.py` at line 128, self.chunk_size is
initialized to None in the ClippedPGLoss/constructor but never populated from
ClippedPGLossConfig or any caller, so SequencePackingFusionLossWrapper's
getattr(self.loss_fn, "chunk_size", None) will always return None; update the
constructor or config wiring to read a chunk_size value from ClippedPGLossConfig
(or an explicit constructor param) and assign it to self.chunk_size (e.g.,
accept chunk_size in ClippedPGLossConfig or the ClippedPGLoss __init__ and set
self.chunk_size = config.chunk_size) so that downstream callers like
SequencePackingFusionLossWrapper and loss_fn can detect and use chunked logprob
computation.
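
A hypothetical wiring sketch for the fix this prompt describes; the config key name logprob_chunk_size is an assumption, and the review only asks that chunk_size be populated from configuration rather than left as None.

```python
class ClippedPGLossFnSketch:
    def __init__(self, cfg: dict):
        # ... the rest of the real __init__ is elided in this sketch ...
        self.chunk_size = cfg.get("logprob_chunk_size")  # None keeps chunking disabled
```
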
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/unit/algorithms/test_sequence_packing_fusion.py`:
- Around line 195-230: The packing currently breaks autograd because you assign
into leaf tensors created with torch.zeros (packed_logits and tmp) using
in-place slices; instead, build packed_logits from tensors derived from
logits_local so gradients flow: in make_logits_and_packed_logits, stop creating
tmp and packed_logits as torch.zeros and avoid in-place slice assignments — for
each sequence produce a slice (e.g., tmp_slice = logits_local[i:i+1, :seq_len,
:] padded via differentiable ops or use indexing + torch.nn.functional.pad), run
it through _get_tokens_on_this_cp_rank, collect those outputs in a list, then
torch.cat the list into packed_logits (or set packed_logits = torch.cat(...)) so
the result is connected to logits_local; alternatively ensure the final
packed_logits has requires_grad_(True) and is produced by non-inplace ops
referencing logits_local rather than assigning into zero tensors.
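
A minimal standalone sketch of the differentiable packing this prompt asks for: pad each sequence's logits and concatenate rather than assigning slices into a zeros buffer, so gradients reach logits_local. Shapes, lengths, and names other than logits_local/packed_logits are illustrative assumptions, and the CP-sharding step (_get_tokens_on_this_cp_rank) is omitted.

```python
import torch
import torch.nn.functional as F

batch, max_len, vocab = 3, 8, 16
seq_lens = [5, 8, 3]         # real (unpadded) lengths
padded_lens = [8, 8, 4]      # per-sequence packed lengths (illustrative)

logits_local = torch.randn(batch, max_len, vocab, requires_grad=True)

pieces = []
for i, (seq_len, pad_len) in enumerate(zip(seq_lens, padded_lens)):
    piece = logits_local[i, :seq_len, :]                # differentiable slice of the leaf
    piece = F.pad(piece, (0, 0, 0, pad_len - seq_len))  # pad along the sequence dim
    pieces.append(piece)

packed_logits = torch.cat(pieces, dim=0).unsqueeze(0)   # [1, sum(padded_lens), vocab]
packed_logits.sum().backward()
assert logits_local.grad is not None                    # gradients flow back to the leaf
```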


@nujoug nujoug added the CI:L2 label (Run doctests, unit tests, functional tests, and convergence tests) on Feb 21, 2026
