282 changes: 276 additions & 6 deletions tests/pytorch/test_cuda_graphs.py
@@ -630,12 +630,230 @@ def test_make_graphed_callables_with_kwargs(
assert_all_equal(outputs, graph_outputs)


def test_make_graphed_callables_returns_owned_parameter_grads() -> None:
"""Parameter grads returned from graph replay must not alias static graph buffers."""
reset_rng_states()
model_config = model_configs["small"]
dtype = torch.float32
model = torch.nn.Linear(
model_config.hidden_size,
model_config.hidden_size,
bias=False,
device="cuda",
dtype=dtype,
)
model = make_graphed_callables(
model,
(generate_data(model_config, dtype, warmup=True, requires_grad=False),),
)

seen_grads = []

def save_grad(grad):
seen_grads.append(grad)
return grad

hook = model.weight.register_hook(save_grad)
try:
output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 1
first_grad = seen_grads[0]
first_grad_ptr = first_grad.data_ptr()
first_grad_snapshot = first_grad.clone()

model.zero_grad(set_to_none=True)

output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 2
assert first_grad.data_ptr() == first_grad_ptr
assert seen_grads[1].data_ptr() != first_grad_ptr
torch.testing.assert_close(first_grad, first_grad_snapshot, rtol=0, atol=0)
finally:
hook.remove()
reset_graphs(model)


def test_make_graphed_callables_accumulates_owned_parameter_grads() -> None:
"""Parameter grad accumulation must not reuse overwritten static graph buffers."""
reset_rng_states()
model_config = model_configs["small"]
dtype = torch.float32
model = torch.nn.Linear(
model_config.hidden_size,
model_config.hidden_size,
bias=False,
device="cuda",
dtype=dtype,
)
model = make_graphed_callables(
model,
(generate_data(model_config, dtype, warmup=True, requires_grad=False),),
)

input_1 = generate_data(model_config, dtype, requires_grad=False)
grad_1 = generate_data(model_config, dtype, requires_grad=False)
input_2 = generate_data(model_config, dtype, requires_grad=False)
grad_2 = generate_data(model_config, dtype, requires_grad=False)
expected_grad = torch.einsum("...o,...i->oi", grad_1, input_1) + torch.einsum(
"...o,...i->oi", grad_2, input_2
)

try:
model.zero_grad(set_to_none=True)
model(input_1).backward(grad_1)
model(input_2).backward(grad_2)
torch.testing.assert_close(model.weight.grad, expected_grad, rtol=0, atol=0)
finally:
reset_graphs(model)


def test_make_graphed_callables_preserves_skipped_parameter_grad_alias() -> None:
"""Delayed-wgrad parameters are excluded from returned-grad clone handling."""
reset_rng_states()
model_config = model_configs["small"]
dtype = torch.float32
model = torch.nn.Linear(
model_config.hidden_size,
model_config.hidden_size,
bias=False,
device="cuda",
dtype=dtype,
)
model.weight.skip_backward_post_hook = True
model = make_graphed_callables(
model,
(generate_data(model_config, dtype, warmup=True, requires_grad=False),),
)

seen_grads = []

def save_grad(grad):
seen_grads.append(grad)
return grad

hook = model.weight.register_hook(save_grad)
try:
output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 1
first_grad_ptr = seen_grads[0].data_ptr()

model.zero_grad(set_to_none=True)

output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 2
assert seen_grads[1].data_ptr() == first_grad_ptr
finally:
hook.remove()
reset_graphs(model)


def test_make_graphed_callables_can_skip_returned_parameter_grad_clone() -> None:
"""Parameter grad clone handling can be disabled for callers that manage lifetimes."""
reset_rng_states()
model_config = model_configs["small"]
dtype = torch.float32
model = torch.nn.Linear(
model_config.hidden_size,
model_config.hidden_size,
bias=False,
device="cuda",
dtype=dtype,
)
model = make_graphed_callables(
model,
(generate_data(model_config, dtype, warmup=True, requires_grad=False),),
_clone_param_grads_on_return=False,
)

seen_grads = []

def save_grad(grad):
seen_grads.append(grad)
return grad

hook = model.weight.register_hook(save_grad)
try:
output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 1
first_grad_ptr = seen_grads[0].data_ptr()

model.zero_grad(set_to_none=True)

output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 2
assert seen_grads[1].data_ptr() == first_grad_ptr
finally:
hook.remove()
reset_graphs(model)


def test_make_graphed_callables_snapshots_parameter_grad_clone_policy() -> None:
"""Parameter grad clone policy is fixed at capture time."""
reset_rng_states()
model_config = model_configs["small"]
dtype = torch.float32
model = torch.nn.Linear(
model_config.hidden_size,
model_config.hidden_size,
bias=False,
device="cuda",
dtype=dtype,
)
model = make_graphed_callables(
model,
(generate_data(model_config, dtype, warmup=True, requires_grad=False),),
)
model.weight.skip_backward_post_hook = True

seen_grads = []

def save_grad(grad):
seen_grads.append(grad)
return grad

hook = model.weight.register_hook(save_grad)
try:
output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 1
first_grad = seen_grads[0]
first_grad_ptr = first_grad.data_ptr()
first_grad_snapshot = first_grad.clone()

model.zero_grad(set_to_none=True)

output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 2
assert seen_grads[1].data_ptr() != first_grad_ptr
torch.testing.assert_close(first_grad, first_grad_snapshot, rtol=0, atol=0)
finally:
hook.remove()
reset_graphs(model)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
*,
with_graph: bool,
model_config: ModelConfig,
dtype: torch.dtype,
) -> List[torch.Tensor]:
reuse_graph_input_output_buffers: bool = False,
clone_param_grads_on_return: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Simulate Megatron-LM interleaved pipeline parallelism."""
reset_rng_states()

@@ -675,6 +893,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
sample_args,
allow_unused_input=True,
_order=layer_order,
_reuse_graph_input_output_buffers=reuse_graph_input_output_buffers,
_clone_param_grads_on_return=clone_param_grads_on_return,
)
layer_forwards = {
(i // num_microbatches, i % num_microbatches): forward
@@ -701,11 +921,15 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(

# Cache for layer outputs.
outputs = {}
output_snapshots = {} if reuse_graph_input_output_buffers else None

def forward(layer_idx: int, microbatch_idx: int):
"""Helper function for forward steps"""
idxs = (layer_idx, microbatch_idx)
outputs[idxs] = layer_forwards[idxs](inputs[idxs])
if output_snapshots is not None:
# Reused graph output buffers are only valid until their corresponding backward.
output_snapshots[idxs] = outputs[idxs].detach().clone()

def backward(layer_idx: int, microbatch_idx: int):
"""Helper function for backward steps"""
Expand All @@ -728,11 +952,13 @@ def backward(layer_idx: int, microbatch_idx: int):
# Optimizer step.
optimizer.step()

outputs = [y for _, y in sorted(outputs.items())]
outputs = get_outputs(model, outputs)
output_values = output_snapshots if output_snapshots is not None else outputs
output_values = [y for _, y in sorted(output_values.items())]
outputs = get_outputs(model, output_values)
final_weights = [param.detach().clone() for param in model.parameters()]
if with_graph:
reset_graphs(layer_forwards)
return outputs
return outputs, final_weights


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
Expand All @@ -743,12 +969,56 @@ def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
"""Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
model_config = model_configs[model_config]
kwargs = dict(model_config=model_config, dtype=dtype)
outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
outputs, weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=False,
**kwargs,
)
graph_outputs, graph_weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=True,
**kwargs,
)
assert_all_equal(outputs, graph_outputs)
assert_all_equal(weights, graph_weights)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers(
*,
model_config: str = "small",
dtype: torch.dtype = torch.float16,
) -> None:
"""Test CUDA graphs with reused input/output buffers."""
model_config = model_configs[model_config]
kwargs = dict(model_config=model_config, dtype=dtype)
outputs, weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=False,
**kwargs,
)
graph_outputs, graph_weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=True,
reuse_graph_input_output_buffers=True,
**kwargs,
)
assert_all_equal(outputs, graph_outputs)
Comment on lines +984 to +1001
Contributor:

P2 Reused-buffer test only validates forward outputs, not gradient correctness

test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers compares output_snapshots (forward tensors cloned before the corresponding backward) against the eager baseline. If the clone-on-return logic in Graphed.backward had a bug specifically in the _reuse_graph_input_output_buffers + pipeline path (e.g., gradient accumulation or an incorrect static buffer being read), weights would diverge but the test would still pass. A weight-equality check after one full schedule would strengthen confidence in the gradient path for this mode.
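As an illustration of the suggested check (not code from this PR), here is a minimal sketch; `run_schedule(use_graph)` is a hypothetical stand-in for a helper like `_test_cuda_graphs_with_interleaved_pipeline_parallelism` that runs one full interleaved schedule, including the optimizer step, and returns the model's final parameter tensors:

```python
import torch


def assert_final_weights_match(run_schedule) -> None:
    """Compare final weights from an eager run and a graphed run of one schedule."""
    # run_schedule(use_graph) is assumed to run one full pipeline schedule
    # (forwards, backwards, optimizer step) and return the final parameters
    # as a list of tensors.
    eager_weights = run_schedule(use_graph=False)
    graph_weights = run_schedule(use_graph=True)
    # Exact equality: a bug in the gradient path (e.g. reading a stale or
    # aliased static grad buffer) perturbs the optimizer update and therefore
    # shows up here even when the forward outputs still agree.
    for eager, graphed in zip(eager_weights, graph_weights):
        torch.testing.assert_close(graphed, eager, rtol=0, atol=0)
```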

Contributor Author:

Addressed in 4077b85. The interleaved pipeline helper now returns final weights in addition to outputs, and the reused-buffer test compares graph/eager final weights to cover gradient correctness. Full tests/pytorch/test_cuda_graphs.py passed on H100: 415 passed, 423 skipped.

assert_all_equal(weights, graph_weights)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers_no_param_grad_clone(
*,
model_config: str = "small",
dtype: torch.dtype = torch.float16,
) -> None:
"""Test reused input/output buffers when returned parameter grad clones are disabled."""
model_config = model_configs[model_config]
kwargs = dict(model_config=model_config, dtype=dtype)
outputs, weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=False,
**kwargs,
)
graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
graph_outputs, graph_weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=True,
reuse_graph_input_output_buffers=True,
clone_param_grads_on_return=False,
**kwargs,
)
assert_all_equal(outputs, graph_outputs)
assert_all_equal(weights, graph_weights)