perf: shard concat overhead #2002

Open
pjo256 wants to merge 6 commits into NVIDIA-NeMo:main from pjo256:shard-concat-overhead

Conversation


@pjo256 pjo256 commented Feb 21, 2026

What does this PR do ?

Optimize BatchedDataDict.shard_by_batch_size() by replacing expensive per-chunk concats with single-allocation shard aggregation (pre-allocated tensor + a single PackedTensor concat), reducing allocation overhead.

Issues

Closes #2001.

Additional Information

I tried a single concat for both tensors and PackedTensor, but storing N chunk torch.Tensors before the concat made it slower than the baseline at smaller chunk sizes. Instead, we pre-allocate a per-shard output tensor and fill it with in-place copies.
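The trade-off above can be sketched with a toy comparison (illustrative names only, not the PR's actual code): repeated `torch.cat` allocates a fresh tensor on every step, while pre-allocating the output once and using `copy_` writes each chunk in place.

```python
import torch

def aggregate_chunks_concat(chunks: list[torch.Tensor]) -> torch.Tensor:
    # Baseline pattern: each torch.cat allocates a new tensor, so N chunks
    # cost N allocations and repeated copying of earlier data.
    out = chunks[0]
    for c in chunks[1:]:
        out = torch.cat([out, c], dim=0)
    return out

def aggregate_chunks_prealloc(chunks: list[torch.Tensor]) -> torch.Tensor:
    # Optimized pattern: one allocation, then in-place slice copies.
    total = sum(c.shape[0] for c in chunks)
    out = torch.empty((total, *chunks[0].shape[1:]), dtype=chunks[0].dtype)
    offset = 0
    for c in chunks:
        out[offset : offset + c.shape[0]].copy_(c)
        offset += c.shape[0]
    return out
```

Both produce the same result as `torch.cat(chunks, dim=0)`; only the allocation behavior differs.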

Performance benchmark

Click to expand benchmark script

```python
import time
import torch
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

CHUNK_BATCH_SIZE = 512
SHARDS = 8
ITERS = 20
SAMPLE_SIZES = [512, 1024, 2048, 4096, 8192, 16384]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def cuda_sync():
    if DEVICE == "cuda":
        torch.cuda.synchronize()

def bench(samples):
    d = {
        "input_ids": torch.randint(0, 1000, (samples, 1024), dtype=torch.int64, device=DEVICE),
        "input_lengths": torch.randint(64, 1024, (samples,), dtype=torch.int64, device=DEVICE),
        "advantages": torch.randn(samples, 1024, device=DEVICE),
    }
    b = BatchedDataDict(d)

    # Warm up before timing to exclude one-time setup costs.
    for _ in range(5):
        _ = b.shard_by_batch_size(shards=SHARDS, batch_size=CHUNK_BATCH_SIZE)
    cuda_sync()
    t0 = time.perf_counter()
    for _ in range(ITERS):
        _ = b.shard_by_batch_size(shards=SHARDS, batch_size=CHUNK_BATCH_SIZE)
    cuda_sync()

    return (time.perf_counter() - t0) / ITERS * 1000.0

print(f"config: device={DEVICE}, chunk_batch_size={CHUNK_BATCH_SIZE}, shards={SHARDS}, iters={ITERS}")
print("samples\tchunks\tavg_ms")
for samples in SAMPLE_SIZES:
    avg_ms = bench(samples)
    chunks = samples // CHUNK_BATCH_SIZE
    print(f"{samples}\t{chunks}\t{avg_ms:.3f}")
```

cuda

| samples | chunks | before | after | speedup |
|--------:|-------:|-------:|------:|--------:|
| 512 | 1 | 2.330 ms | 1.021 ms | 2.28x |
| 1024 | 2 | 5.531 ms | 1.769 ms | 3.13x |
| 2048 | 4 | 11.724 ms | 2.189 ms | 5.36x |
| 4096 | 8 | 21.993 ms | 4.031 ms | 5.46x |
| 8192 | 16 | 43.707 ms | 9.813 ms | 4.45x |
| 16384 | 32 | 108.990 ms | 12.968 ms | 8.40x |

cpu

| samples | chunks | before | after | speedup |
|--------:|-------:|-------:|------:|--------:|
| 512 | 1 | 3.758 ms | 0.473 ms | 7.94x |
| 1024 | 2 | 7.455 ms | 6.203 ms | 1.20x |
| 2048 | 4 | 12.575 ms | 12.624 ms | 1.00x |
| 4096 | 8 | 36.089 ms | 25.394 ms | 1.42x |
| 8192 | 16 | 77.890 ms | 47.339 ms | 1.65x |
| 16384 | 32 | 358.891 ms | 117.545 ms | 3.05x |

Testing

pytest -q tests/unit/distributed/test_batched_data_dict.py

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Signed-off-by: Philip Ottesen <phiott256@gmail.com>
@pjo256 pjo256 requested a review from a team as a code owner February 21, 2026 16:37
@pjo256 pjo256 changed the title from "Shard concat overhead" to "perf: shard concat overhead" Feb 21, 2026

coderabbitai bot commented Feb 21, 2026

📝 Walkthrough

Optimized shard aggregation in _get_padded_seqlen by replacing incremental concatenation with a two-phase approach: first collect valid shard ranges, then preallocate and copy tensor slices and process PackedTensors in single passes rather than per-chunk concatenations.

Changes

| Cohort / File(s) | Summary |
|---|---|
| **Shard Aggregation Optimization**<br>`nemo_rl/distributed/batched_data_dict.py` | Reworked `_get_padded_seqlen` aggregation logic: replaced incremental per-key concatenation with a two-phase approach using `shard_ranges` collection, tensor preallocation with slice copying, PackedTensor single-pass concatenation, and streamlined value accumulation to eliminate repeated concatenation overhead. |
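The two-phase pattern described in the summary can be sketched roughly as follows. Note that the function names, the even-split assumption, and the range layout are illustrative, not the actual nemo_rl implementation:

```python
import torch

def shard_ranges_for(total_rows: int, batch_size: int, shards: int) -> list[list[tuple[int, int]]]:
    # Phase 1: for each shard, collect the (start, end) row ranges it owns.
    # Assumes each batch_size-sized chunk splits evenly across shards.
    per_shard = batch_size // shards
    ranges: list[list[tuple[int, int]]] = [[] for _ in range(shards)]
    for chunk_start in range(0, total_rows, batch_size):
        for s in range(shards):
            start = chunk_start + s * per_shard
            ranges[s].append((start, start + per_shard))
    return ranges

def gather_shard(data: torch.Tensor, ranges: list[tuple[int, int]]) -> torch.Tensor:
    # Phase 2: pre-allocate the shard's output once, then copy each range in.
    rows = sum(e - s for s, e in ranges)
    out = torch.empty((rows, *data.shape[1:]), dtype=data.dtype)
    off = 0
    for s, e in ranges:
        out[off : off + (e - s)].copy_(data[s:e])
        off += e - s
    return out
```

Collecting the ranges up front is what makes the single allocation possible: the total output size per shard is known before any data is moved.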

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 5 passed | ❌ 1 failed

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Test Results For Major Changes | ⚠️ Warning | Bug analysis: PackedTensor with empty shard_ranges does not assign to aggregated_shards due to the `if packed_slices` guard, leaving the key missing in the result. | The PackedTensor case requires fixing to handle empty shards consistently with the tensor and list cases by assigning an empty PackedTensor when shard_ranges is empty. |

✅ Passed checks (5 passed)

| Check name | Status | Explanation |
|---|---|---|
| Linked Issues check | ✅ Passed | The code changes directly address issue #2001 by replacing incremental chunk concatenation with a two-phase shard aggregation approach that preallocates tensors and performs single-pass copies, eliminating repeated concat overhead. |
| Out of Scope Changes check | ✅ Passed | All changes are focused on optimizing the shard aggregation logic in BatchedDataDict._get_padded_seqlen to address the repeated concatenation issue described in issue #2001. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%. |
| Title check | ✅ Passed | The title "perf: shard concat overhead" is directly related to the main change, which optimizes shard aggregation by eliminating incremental concatenation overhead. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@nemo_rl/distributed/batched_data_dict.py`:
- Around line 572-582: The PackedTensor branch currently skips assignment when
shard_ranges is empty, leaving aggregated_shards[shard_idx][k] missing and
causing downstream KeyError; fix the elif isinstance(v, PackedTensor) block in
batched_data_dict.py to detect empty shard_ranges (or empty packed_slices) and
assign PackedTensor.empty_like(v) instead of skipping or attempting concat/slice
on empty lists so that aggregated_shards always contains the key with an
empty-like PackedTensor representation.
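The guard bug described in this comment can be reproduced with a small stand-in class. `FakePacked` below is a mock, and its `empty_like`/`concat` interface is assumed from the comment above rather than taken from nemo_rl's real PackedTensor:

```python
class FakePacked:
    """Minimal stand-in for PackedTensor: a list of packed parts."""

    def __init__(self, parts):
        self.parts = list(parts)

    @classmethod
    def empty_like(cls, other):
        return cls([])

    @classmethod
    def concat(cls, slices):
        merged = []
        for s in slices:
            merged.extend(s.parts)
        return cls(merged)

def aggregate(v: FakePacked, shard_ranges, buggy: bool = True) -> dict:
    aggregated = {}
    packed_slices = [FakePacked(v.parts[a:b]) for a, b in shard_ranges]
    if buggy:
        if packed_slices:  # bug: empty shard_ranges -> key never assigned
            aggregated["k"] = FakePacked.concat(packed_slices)
    else:
        # fix: always assign, falling back to an empty-like value
        aggregated["k"] = (
            FakePacked.concat(packed_slices)
            if packed_slices
            else FakePacked.empty_like(v)
        )
    return aggregated
```

With `buggy=True` and no shard ranges, the key is silently dropped and any downstream lookup raises KeyError; the fixed branch always assigns, matching how the tensor and list cases handle empty shards.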

Signed-off-by: Philip Ottesen <phiott256@gmail.com>
@pjo256 pjo256 requested a review from a team as a code owner February 21, 2026 17:28
Signed-off-by: Philip Ottesen <phiott256@gmail.com>
@pjo256 pjo256 force-pushed the shard-concat-overhead branch from 17e358b to 534d1c5 on February 21, 2026 19:47

Development

Successfully merging this pull request may close these issues.

BatchDataDict chunking overhead from repeated concats