From 46d2ffe21bec3dddd5ba55a6a573b5f3c015f507 Mon Sep 17 00:00:00 2001 From: Philip Ottesen Date: Sat, 21 Feb 2026 16:09:36 +0100 Subject: [PATCH 1/6] fix: avoid repeated concat copies in shard_by_batch_size Signed-off-by: Philip Ottesen --- nemo_rl/distributed/batched_data_dict.py | 60 ++++++++++++------------ 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/nemo_rl/distributed/batched_data_dict.py b/nemo_rl/distributed/batched_data_dict.py index 20b39f2b50..b4fc970a66 100644 --- a/nemo_rl/distributed/batched_data_dict.py +++ b/nemo_rl/distributed/batched_data_dict.py @@ -536,6 +536,12 @@ def _get_padded_seqlen(seqlen: int) -> int: data = self.data aggregated_shards = [SlicedDataDict() for _ in range(shards)] + shard_tensor_chunks: list[dict[str, list[torch.Tensor]]] = [ + {} for _ in range(shards) + ] + shard_packed_chunks: list[dict[str, list[PackedTensor]]] = [ + {} for _ in range(shards) + ] # Group data by shard position across all chunks for shard_idx in range(shards): @@ -552,38 +558,30 @@ def _get_padded_seqlen(seqlen: int) -> int: indices = torch.arange(shard_start, shard_end) for k in data: - if k not in aggregated_shards[shard_idx]: - # First time seeing this key for this shard, initialize it - if torch.is_tensor(data[k]): - aggregated_shards[shard_idx][k] = data[k][indices].clone() - elif isinstance(data[k], PackedTensor): - aggregated_shards[shard_idx][k] = data[k].slice( - indices.tolist() - ) - else: - aggregated_shards[shard_idx][k] = [ - data[k][i] for i in indices - ] + if torch.is_tensor(data[k]): + shard_tensor_chunks[shard_idx].setdefault(k, []).append( + data[k][indices].clone() + ) + elif isinstance(data[k], PackedTensor): + shard_packed_chunks[shard_idx].setdefault(k, []).append( + data[k].slice(indices.tolist()) + ) + elif k not in aggregated_shards[shard_idx]: + aggregated_shards[shard_idx][k] = [data[k][i] for i in indices] else: - # Append to existing data - concatenate tensors or extend lists - if torch.is_tensor(data[k]): - aggregated_shards[shard_idx][k] = torch.cat( - [ - aggregated_shards[shard_idx][k], - data[k][indices].clone(), - ] - ) - elif isinstance(data[k], PackedTensor): - aggregated_shards[shard_idx][k] = PackedTensor.concat( - [ - aggregated_shards[shard_idx][k], - data[k].slice(indices.tolist()), - ] - ) - else: - aggregated_shards[shard_idx][k].extend( - [data[k][i] for i in indices] - ) + aggregated_shards[shard_idx][k].extend( + [data[k][i] for i in indices] + ) + + for shard_idx in range(shards): + for k, chunks in shard_tensor_chunks[shard_idx].items(): + aggregated_shards[shard_idx][k] = ( + torch.cat(chunks, dim=0) if len(chunks) > 1 else chunks[0] + ) + for k, chunks in shard_packed_chunks[shard_idx].items(): + aggregated_shards[shard_idx][k] = ( + PackedTensor.concat(chunks) if len(chunks) > 1 else chunks[0] + ) # map inputs to microbatches such that the total number tokens in # a microbatch is as close to (including padding tokens) 'max_tokens_per_microbatch' From fbef9a3ff79a92c1f85f362725bb25799dcba01f Mon Sep 17 00:00:00 2001 From: Philip Ottesen Date: Sat, 21 Feb 2026 17:01:23 +0100 Subject: [PATCH 2/6] fix: preallocate shard tensors to remove chunk concat copy overhead Signed-off-by: Philip Ottesen --- nemo_rl/distributed/batched_data_dict.py | 95 +++++++++++++++--------- 1 file changed, 59 insertions(+), 36 deletions(-) diff --git a/nemo_rl/distributed/batched_data_dict.py b/nemo_rl/distributed/batched_data_dict.py index b4fc970a66..e23ce55121 100644 --- a/nemo_rl/distributed/batched_data_dict.py +++ b/nemo_rl/distributed/batched_data_dict.py @@ -536,52 +536,75 @@ def _get_padded_seqlen(seqlen: int) -> int: data = self.data aggregated_shards = [SlicedDataDict() for _ in range(shards)] - shard_tensor_chunks: list[dict[str, list[torch.Tensor]]] = [ - {} for _ in range(shards) - ] - shard_packed_chunks: list[dict[str, list[PackedTensor]]] = [ - {} for _ in range(shards) - ] - - # Group data by shard position across all chunks + tensor_keys = [k for k, v in data.items() if torch.is_tensor(v)] + packed_keys = [k for k, v in data.items() if isinstance(v, PackedTensor)] + other_keys = [k for k in data if k not in tensor_keys and k not in packed_keys] + + # Group data by shard position across all chunks. + # Tensor fields are pre-allocated and filled exactly once per chunk. for shard_idx in range(shards): + shard_tensor_offsets: dict[str, int] = {} + for k in tensor_keys: + if allow_uneven_shards: + rows = 0 + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * batch_size + shard_start = min( + chunk_start + shard_idx * shard_size, total_batch_size + ) + shard_end = min( + chunk_start + (shard_idx + 1) * shard_size, total_batch_size + ) + rows += max(0, shard_end - shard_start) + else: + rows = num_chunks * shard_size + + aggregated_shards[shard_idx][k] = torch.empty( + (rows, *data[k].shape[1:]), + dtype=data[k].dtype, + device=data[k].device, + ) + shard_tensor_offsets[k] = 0 + + shard_packed_chunks: dict[str, list[PackedTensor]] = { + k: [] for k in packed_keys + } + for k in other_keys: + aggregated_shards[shard_idx][k] = [] + for chunk_idx in range(num_chunks): - # Calculate indices for this particular sub-shard within the chunk chunk_start = chunk_idx * batch_size shard_start = chunk_start + shard_idx * shard_size shard_end = chunk_start + (shard_idx + 1) * shard_size if allow_uneven_shards: - # Cap the end index at the total batch size for the last shard - # or if shard_end calculation goes beyond total_batch_size shard_start = min(shard_start, total_batch_size) shard_end = min(shard_end, total_batch_size) - indices = torch.arange(shard_start, shard_end) + if shard_start >= shard_end: + continue + + slice_len = shard_end - shard_start + for k in tensor_keys: + shard_tensor = aggregated_shards[shard_idx][k] + offset = shard_tensor_offsets[k] + shard_tensor[offset : offset + slice_len].copy_( + data[k][shard_start:shard_end] + ) + shard_tensor_offsets[k] = offset + slice_len - for k in data: - if torch.is_tensor(data[k]): - shard_tensor_chunks[shard_idx].setdefault(k, []).append( - data[k][indices].clone() - ) - elif isinstance(data[k], PackedTensor): - shard_packed_chunks[shard_idx].setdefault(k, []).append( - data[k].slice(indices.tolist()) - ) - elif k not in aggregated_shards[shard_idx]: - aggregated_shards[shard_idx][k] = [data[k][i] for i in indices] - else: - aggregated_shards[shard_idx][k].extend( - [data[k][i] for i in indices] - ) + packed_slice_indices = list(range(shard_start, shard_end)) + for k in packed_keys: + shard_packed_chunks[k].append(data[k].slice(packed_slice_indices)) - for shard_idx in range(shards): - for k, chunks in shard_tensor_chunks[shard_idx].items(): - aggregated_shards[shard_idx][k] = ( - torch.cat(chunks, dim=0) if len(chunks) > 1 else chunks[0] - ) - for k, chunks in shard_packed_chunks[shard_idx].items(): - aggregated_shards[shard_idx][k] = ( - PackedTensor.concat(chunks) if len(chunks) > 1 else chunks[0] - ) + for k in other_keys: + aggregated_shards[shard_idx][k].extend( + [data[k][i] for i in packed_slice_indices] + ) + + for k, chunks in shard_packed_chunks.items(): + if chunks: + aggregated_shards[shard_idx][k] = ( + PackedTensor.concat(chunks) if len(chunks) > 1 else chunks[0] + ) # map inputs to microbatches such that the total number tokens in # a microbatch is as close to (including padding tokens) 'max_tokens_per_microbatch' From 6b7d11f7bd941741087744624d494997f5e189b4 Mon Sep 17 00:00:00 2001 From: Philip Ottesen Date: Sat, 21 Feb 2026 17:12:34 +0100 Subject: [PATCH 3/6] fix: simplify with per-shard range aggregation Signed-off-by: Philip Ottesen --- nemo_rl/distributed/batched_data_dict.py | 92 +++++++++--------------- 1 file changed, 35 insertions(+), 57 deletions(-) diff --git a/nemo_rl/distributed/batched_data_dict.py b/nemo_rl/distributed/batched_data_dict.py index e23ce55121..d9d51589fb 100644 --- a/nemo_rl/distributed/batched_data_dict.py +++ b/nemo_rl/distributed/batched_data_dict.py @@ -536,42 +536,10 @@ def _get_padded_seqlen(seqlen: int) -> int: data = self.data aggregated_shards = [SlicedDataDict() for _ in range(shards)] - tensor_keys = [k for k, v in data.items() if torch.is_tensor(v)] - packed_keys = [k for k, v in data.items() if isinstance(v, PackedTensor)] - other_keys = [k for k in data if k not in tensor_keys and k not in packed_keys] # Group data by shard position across all chunks. - # Tensor fields are pre-allocated and filled exactly once per chunk. for shard_idx in range(shards): - shard_tensor_offsets: dict[str, int] = {} - for k in tensor_keys: - if allow_uneven_shards: - rows = 0 - for chunk_idx in range(num_chunks): - chunk_start = chunk_idx * batch_size - shard_start = min( - chunk_start + shard_idx * shard_size, total_batch_size - ) - shard_end = min( - chunk_start + (shard_idx + 1) * shard_size, total_batch_size - ) - rows += max(0, shard_end - shard_start) - else: - rows = num_chunks * shard_size - - aggregated_shards[shard_idx][k] = torch.empty( - (rows, *data[k].shape[1:]), - dtype=data[k].dtype, - device=data[k].device, - ) - shard_tensor_offsets[k] = 0 - - shard_packed_chunks: dict[str, list[PackedTensor]] = { - k: [] for k in packed_keys - } - for k in other_keys: - aggregated_shards[shard_idx][k] = [] - + shard_ranges: list[tuple[int, int]] = [] for chunk_idx in range(num_chunks): chunk_start = chunk_idx * batch_size shard_start = chunk_start + shard_idx * shard_size @@ -579,32 +547,42 @@ def _get_padded_seqlen(seqlen: int) -> int: if allow_uneven_shards: shard_start = min(shard_start, total_batch_size) shard_end = min(shard_end, total_batch_size) - if shard_start >= shard_end: - continue - - slice_len = shard_end - shard_start - for k in tensor_keys: - shard_tensor = aggregated_shards[shard_idx][k] - offset = shard_tensor_offsets[k] - shard_tensor[offset : offset + slice_len].copy_( - data[k][shard_start:shard_end] - ) - shard_tensor_offsets[k] = offset + slice_len + if shard_start < shard_end: + shard_ranges.append((shard_start, shard_end)) - packed_slice_indices = list(range(shard_start, shard_end)) - for k in packed_keys: - shard_packed_chunks[k].append(data[k].slice(packed_slice_indices)) - - for k in other_keys: - aggregated_shards[shard_idx][k].extend( - [data[k][i] for i in packed_slice_indices] - ) - - for k, chunks in shard_packed_chunks.items(): - if chunks: - aggregated_shards[shard_idx][k] = ( - PackedTensor.concat(chunks) if len(chunks) > 1 else chunks[0] + # Process each key by data type. + for k, v in data.items(): + if torch.is_tensor(v): + # Pre-allocate and copy each chunk exactly once. + rows = sum(end - start for start, end in shard_ranges) + shard_tensor = torch.empty( + (rows, *v.shape[1:]), + dtype=v.dtype, + device=v.device, ) + offset = 0 + for start, end in shard_ranges: + span = end - start + shard_tensor[offset : offset + span].copy_(v[start:end]) + offset += span + aggregated_shards[shard_idx][k] = shard_tensor + elif isinstance(v, PackedTensor): + # PackedTensor is collected per chunk then concatenated once. + packed_slices = [ + v.slice(list(range(start, end))) for start, end in shard_ranges + ] + if packed_slices: + aggregated_shards[shard_idx][k] = ( + PackedTensor.concat(packed_slices) + if len(packed_slices) > 1 + else packed_slices[0] + ) + else: + # Append list-like data in shard order across chunks. + shard_values = [] + for start, end in shard_ranges: + shard_values.extend([v[i] for i in range(start, end)]) + aggregated_shards[shard_idx][k] = shard_values # map inputs to microbatches such that the total number tokens in # a microbatch is as close to (including padding tokens) 'max_tokens_per_microbatch' From 950c1c936980bb5fedfa5ce520e793fd3f330301 Mon Sep 17 00:00:00 2001 From: Philip Ottesen Date: Sat, 21 Feb 2026 17:36:31 +0100 Subject: [PATCH 4/6] docs: cleanup comments Signed-off-by: Philip Ottesen --- nemo_rl/distributed/batched_data_dict.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nemo_rl/distributed/batched_data_dict.py b/nemo_rl/distributed/batched_data_dict.py index d9d51589fb..3cd4dda8dc 100644 --- a/nemo_rl/distributed/batched_data_dict.py +++ b/nemo_rl/distributed/batched_data_dict.py @@ -541,16 +541,19 @@ def _get_padded_seqlen(seqlen: int) -> int: for shard_idx in range(shards): shard_ranges: list[tuple[int, int]] = [] for chunk_idx in range(num_chunks): + # Calculate indices for this particular sub-shard within the chunk chunk_start = chunk_idx * batch_size shard_start = chunk_start + shard_idx * shard_size shard_end = chunk_start + (shard_idx + 1) * shard_size if allow_uneven_shards: + # Cap the end index at the total batch size for the last shard + # or if shard_end calculation goes beyond total_batch_size shard_start = min(shard_start, total_batch_size) shard_end = min(shard_end, total_batch_size) + if shard_start < shard_end: shard_ranges.append((shard_start, shard_end)) - # Process each key by data type. for k, v in data.items(): if torch.is_tensor(v): # Pre-allocate and copy each chunk exactly once. @@ -578,10 +581,10 @@ def _get_padded_seqlen(seqlen: int) -> int: else packed_slices[0] ) else: - # Append list-like data in shard order across chunks. shard_values = [] for start, end in shard_ranges: shard_values.extend([v[i] for i in range(start, end)]) + aggregated_shards[shard_idx][k] = shard_values # map inputs to microbatches such that the total number tokens in From aca061dab4bef0e300b81eb42be76f1323e32b81 Mon Sep 17 00:00:00 2001 From: Philip Ottesen Date: Sat, 21 Feb 2026 18:28:10 +0100 Subject: [PATCH 5/6] tests: fix missing empty PackedTensor, add a unit test for it Signed-off-by: Philip Ottesen --- nemo_rl/distributed/batched_data_dict.py | 17 +++++------ .../distributed/test_batched_data_dict.py | 30 +++++++++++++++++++ 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/nemo_rl/distributed/batched_data_dict.py b/nemo_rl/distributed/batched_data_dict.py index 3cd4dda8dc..5296b599ee 100644 --- a/nemo_rl/distributed/batched_data_dict.py +++ b/nemo_rl/distributed/batched_data_dict.py @@ -537,7 +537,7 @@ def _get_padded_seqlen(seqlen: int) -> int: aggregated_shards = [SlicedDataDict() for _ in range(shards)] - # Group data by shard position across all chunks. + # Group data by shard position across all chunks for shard_idx in range(shards): shard_ranges: list[tuple[int, int]] = [] for chunk_idx in range(num_chunks): @@ -556,7 +556,7 @@ def _get_padded_seqlen(seqlen: int) -> int: for k, v in data.items(): if torch.is_tensor(v): - # Pre-allocate and copy each chunk exactly once. + # Pre-allocate and copy each chunk once rows = sum(end - start for start, end in shard_ranges) shard_tensor = torch.empty( (rows, *v.shape[1:]), @@ -568,23 +568,20 @@ def _get_padded_seqlen(seqlen: int) -> int: span = end - start shard_tensor[offset : offset + span].copy_(v[start:end]) offset += span + aggregated_shards[shard_idx][k] = shard_tensor elif isinstance(v, PackedTensor): - # PackedTensor is collected per chunk then concatenated once. + # PackedTensor is collected per chunk then concatenated once packed_slices = [ v.slice(list(range(start, end))) for start, end in shard_ranges ] - if packed_slices: - aggregated_shards[shard_idx][k] = ( - PackedTensor.concat(packed_slices) - if len(packed_slices) > 1 - else packed_slices[0] - ) + + aggregated_shards[shard_idx][k] = PackedTensor.concat(packed_slices) if packed_slices else PackedTensor.empty_like(v) else: shard_values = [] for start, end in shard_ranges: shard_values.extend([v[i] for i in range(start, end)]) - + aggregated_shards[shard_idx][k] = shard_values # map inputs to microbatches such that the total number tokens in diff --git a/tests/unit/distributed/test_batched_data_dict.py b/tests/unit/distributed/test_batched_data_dict.py index 9c982c1a11..b34693df1a 100644 --- a/tests/unit/distributed/test_batched_data_dict.py +++ b/tests/unit/distributed/test_batched_data_dict.py @@ -506,6 +506,36 @@ def test_shard_by_batch_size_with_packed_multimodal(): assert tuple(shards[1]["pixel_values"].as_tensor().shape) == (6, 3, 8, 8) +def test_shard_by_batch_size_allow_uneven_empty_shards_preserve_all_keys(): + """Empty trailing shards should preserve all keys with empty values.""" + batch = BatchedDataDict( + { + "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]), + "pixel_values": PackedTensor( + [torch.randn(2, 3, 8, 8), torch.randn(1, 3, 8, 8)], dim_to_pack=0 + ), + "labels": [0, 1], + } + ) + + # total_batch_size=2, shards=4 -> trailing two shards are empty + shards = batch.shard_by_batch_size(shards=4, allow_uneven_shards=True) + assert len(shards) == 4 + + # Empty trailing shards should preserve all keys and use empty values. + for empty_shard in shards[2:]: + for key, original_value in batch.items(): + assert key in empty_shard + shard_value = empty_shard[key] + if torch.is_tensor(original_value): + assert shard_value.shape[0] == 0 + elif isinstance(original_value, PackedTensor): + assert isinstance(shard_value, PackedTensor) + assert shard_value.as_tensor() is None + else: + assert shard_value == [] + + def test_get_multimodal_dict_mixed_content_and_device_move(): """get_multimodal_dict should include PackedTensor and optional keys, and support device movement.""" images = [torch.randn(2, 3, 8, 8), torch.randn(1, 3, 8, 8)] From 534d1c5a2b38adb489ef126d20179e2e151382f9 Mon Sep 17 00:00:00 2001 From: Philip Ottesen Date: Sat, 21 Feb 2026 20:45:56 +0100 Subject: [PATCH 6/6] fix: lint Signed-off-by: Philip Ottesen --- nemo_rl/distributed/batched_data_dict.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nemo_rl/distributed/batched_data_dict.py b/nemo_rl/distributed/batched_data_dict.py index 5296b599ee..2f77598b6d 100644 --- a/nemo_rl/distributed/batched_data_dict.py +++ b/nemo_rl/distributed/batched_data_dict.py @@ -550,7 +550,7 @@ def _get_padded_seqlen(seqlen: int) -> int: # or if shard_end calculation goes beyond total_batch_size shard_start = min(shard_start, total_batch_size) shard_end = min(shard_end, total_batch_size) - + if shard_start < shard_end: shard_ranges.append((shard_start, shard_end)) @@ -576,7 +576,11 @@ def _get_padded_seqlen(seqlen: int) -> int: v.slice(list(range(start, end))) for start, end in shard_ranges ] - aggregated_shards[shard_idx][k] = PackedTensor.concat(packed_slices) if packed_slices else PackedTensor.empty_like(v) + aggregated_shards[shard_idx][k] = ( + PackedTensor.concat(packed_slices) + if packed_slices + else PackedTensor.empty_like(v) + ) else: shard_values = [] for start, end in shard_ranges: