Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 38 additions & 35 deletions nemo_rl/distributed/batched_data_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ def _get_padded_seqlen(seqlen: int) -> int:

# Group data by shard position across all chunks
for shard_idx in range(shards):
shard_ranges: list[tuple[int, int]] = []
for chunk_idx in range(num_chunks):
# Calculate indices for this particular sub-shard within the chunk
chunk_start = chunk_idx * batch_size
Expand All @@ -549,41 +550,43 @@ def _get_padded_seqlen(seqlen: int) -> int:
# or if shard_end calculation goes beyond total_batch_size
shard_start = min(shard_start, total_batch_size)
shard_end = min(shard_end, total_batch_size)
indices = torch.arange(shard_start, shard_end)

for k in data:
if k not in aggregated_shards[shard_idx]:
# First time seeing this key for this shard, initialize it
if torch.is_tensor(data[k]):
aggregated_shards[shard_idx][k] = data[k][indices].clone()
elif isinstance(data[k], PackedTensor):
aggregated_shards[shard_idx][k] = data[k].slice(
indices.tolist()
)
else:
aggregated_shards[shard_idx][k] = [
data[k][i] for i in indices
]
else:
# Append to existing data - concatenate tensors or extend lists
if torch.is_tensor(data[k]):
aggregated_shards[shard_idx][k] = torch.cat(
[
aggregated_shards[shard_idx][k],
data[k][indices].clone(),
]
)
elif isinstance(data[k], PackedTensor):
aggregated_shards[shard_idx][k] = PackedTensor.concat(
[
aggregated_shards[shard_idx][k],
data[k].slice(indices.tolist()),
]
)
else:
aggregated_shards[shard_idx][k].extend(
[data[k][i] for i in indices]
)

if shard_start < shard_end:
shard_ranges.append((shard_start, shard_end))

for k, v in data.items():
if torch.is_tensor(v):
# Pre-allocate and copy each chunk once
rows = sum(end - start for start, end in shard_ranges)
shard_tensor = torch.empty(
(rows, *v.shape[1:]),
dtype=v.dtype,
device=v.device,
)
offset = 0
for start, end in shard_ranges:
span = end - start
shard_tensor[offset : offset + span].copy_(v[start:end])
offset += span

aggregated_shards[shard_idx][k] = shard_tensor
elif isinstance(v, PackedTensor):
# PackedTensor is collected per chunk then concatenated once
packed_slices = [
v.slice(list(range(start, end))) for start, end in shard_ranges
]

aggregated_shards[shard_idx][k] = (
PackedTensor.concat(packed_slices)
if packed_slices
else PackedTensor.empty_like(v)
)
else:
shard_values = []
for start, end in shard_ranges:
shard_values.extend([v[i] for i in range(start, end)])

aggregated_shards[shard_idx][k] = shard_values

# map inputs to microbatches such that the total number tokens in
# a microbatch is as close to (including padding tokens) 'max_tokens_per_microbatch'
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/distributed/test_batched_data_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,36 @@ def test_shard_by_batch_size_with_packed_multimodal():
assert tuple(shards[1]["pixel_values"].as_tensor().shape) == (6, 3, 8, 8)


def test_shard_by_batch_size_allow_uneven_empty_shards_preserve_all_keys():
    """With more shards than rows, the row-less trailing shards must still
    carry every key from the source batch, each mapped to an empty value of
    the matching kind (0-row tensor, empty PackedTensor, or empty list)."""
    source = BatchedDataDict(
        {
            "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
            "pixel_values": PackedTensor(
                [torch.randn(2, 3, 8, 8), torch.randn(1, 3, 8, 8)], dim_to_pack=0
            ),
            "labels": [0, 1],
        }
    )

    # A total batch of 2 split across 4 shards leaves shards 2 and 3 empty.
    result = source.shard_by_batch_size(shards=4, allow_uneven_shards=True)
    assert len(result) == 4

    for trailing_shard in result[2:]:
        for key, source_value in source.items():
            # Every key survives into the empty shard.
            assert key in trailing_shard
            sharded_value = trailing_shard[key]
            if torch.is_tensor(source_value):
                # Tensors collapse to zero rows but keep their key.
                assert sharded_value.shape[0] == 0
            elif isinstance(source_value, PackedTensor):
                # PackedTensor stays a PackedTensor holding no data.
                assert isinstance(sharded_value, PackedTensor)
                assert sharded_value.as_tensor() is None
            else:
                # Plain lists become empty lists.
                assert sharded_value == []


def test_get_multimodal_dict_mixed_content_and_device_move():
"""get_multimodal_dict should include PackedTensor and optional keys, and support device movement."""
images = [torch.randn(2, 3, 8, 8), torch.randn(1, 3, 8, 8)]
Expand Down