diff --git a/nemo_rl/distributed/batched_data_dict.py b/nemo_rl/distributed/batched_data_dict.py index 20b39f2b50..2f77598b6d 100644 --- a/nemo_rl/distributed/batched_data_dict.py +++ b/nemo_rl/distributed/batched_data_dict.py @@ -539,6 +539,7 @@ def _get_padded_seqlen(seqlen: int) -> int: # Group data by shard position across all chunks for shard_idx in range(shards): + shard_ranges: list[tuple[int, int]] = [] for chunk_idx in range(num_chunks): # Calculate indices for this particular sub-shard within the chunk chunk_start = chunk_idx * batch_size @@ -549,41 +550,43 @@ def _get_padded_seqlen(seqlen: int) -> int: # or if shard_end calculation goes beyond total_batch_size shard_start = min(shard_start, total_batch_size) shard_end = min(shard_end, total_batch_size) - indices = torch.arange(shard_start, shard_end) - - for k in data: - if k not in aggregated_shards[shard_idx]: - # First time seeing this key for this shard, initialize it - if torch.is_tensor(data[k]): - aggregated_shards[shard_idx][k] = data[k][indices].clone() - elif isinstance(data[k], PackedTensor): - aggregated_shards[shard_idx][k] = data[k].slice( - indices.tolist() - ) - else: - aggregated_shards[shard_idx][k] = [ - data[k][i] for i in indices - ] - else: - # Append to existing data - concatenate tensors or extend lists - if torch.is_tensor(data[k]): - aggregated_shards[shard_idx][k] = torch.cat( - [ - aggregated_shards[shard_idx][k], - data[k][indices].clone(), - ] - ) - elif isinstance(data[k], PackedTensor): - aggregated_shards[shard_idx][k] = PackedTensor.concat( - [ - aggregated_shards[shard_idx][k], - data[k].slice(indices.tolist()), - ] - ) - else: - aggregated_shards[shard_idx][k].extend( - [data[k][i] for i in indices] - ) + + if shard_start < shard_end: + shard_ranges.append((shard_start, shard_end)) + + for k, v in data.items(): + if torch.is_tensor(v): + # Pre-allocate and copy each chunk once + rows = sum(end - start for start, end in shard_ranges) + shard_tensor = torch.empty( + (rows, *v.shape[1:]), + dtype=v.dtype, + device=v.device, + ) + offset = 0 + for start, end in shard_ranges: + span = end - start + shard_tensor[offset : offset + span].copy_(v[start:end]) + offset += span + + aggregated_shards[shard_idx][k] = shard_tensor + elif isinstance(v, PackedTensor): + # PackedTensor is collected per chunk then concatenated once + packed_slices = [ + v.slice(list(range(start, end))) for start, end in shard_ranges + ] + + aggregated_shards[shard_idx][k] = ( + PackedTensor.concat(packed_slices) + if packed_slices + else PackedTensor.empty_like(v) + ) + else: + shard_values = [] + for start, end in shard_ranges: + shard_values.extend([v[i] for i in range(start, end)]) + + aggregated_shards[shard_idx][k] = shard_values # map inputs to microbatches such that the total number tokens in # a microbatch is as close to (including padding tokens) 'max_tokens_per_microbatch' diff --git a/tests/unit/distributed/test_batched_data_dict.py b/tests/unit/distributed/test_batched_data_dict.py index 9c982c1a11..b34693df1a 100644 --- a/tests/unit/distributed/test_batched_data_dict.py +++ b/tests/unit/distributed/test_batched_data_dict.py @@ -506,6 +506,36 @@ def test_shard_by_batch_size_with_packed_multimodal(): assert tuple(shards[1]["pixel_values"].as_tensor().shape) == (6, 3, 8, 8) +def test_shard_by_batch_size_allow_uneven_empty_shards_preserve_all_keys(): + """Empty trailing shards should preserve all keys with empty values.""" + batch = BatchedDataDict( + { + "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]), + "pixel_values": PackedTensor( + [torch.randn(2, 3, 8, 8), torch.randn(1, 3, 8, 8)], dim_to_pack=0 + ), + "labels": [0, 1], + } + ) + + # total_batch_size=2, shards=4 -> trailing two shards are empty + shards = batch.shard_by_batch_size(shards=4, allow_uneven_shards=True) + assert len(shards) == 4 + + # Empty trailing shards should preserve all keys and use empty values. + for empty_shard in shards[2:]: + for key, original_value in batch.items(): + assert key in empty_shard + shard_value = empty_shard[key] + if torch.is_tensor(original_value): + assert shard_value.shape[0] == 0 + elif isinstance(original_value, PackedTensor): + assert isinstance(shard_value, PackedTensor) + assert shard_value.as_tensor() is None + else: + assert shard_value == [] + + def test_get_multimodal_dict_mixed_content_and_device_move(): """get_multimodal_dict should include PackedTensor and optional keys, and support device movement.""" images = [torch.randn(2, 3, 8, 8), torch.randn(1, 3, 8, 8)]