From f914e75f08befa2146d56a955afa78278db5ddf9 Mon Sep 17 00:00:00 2001 From: Eduardo Patrocinio Date: Mon, 8 Dec 2025 18:11:33 -0500 Subject: [PATCH] Fix async checkpoint timing in DCP recipe Move checkpoint_future.result() before optimizer.step() to ensure the previous checkpoint completes before weights are modified in-place. This allows better overlap of checkpointing with forward/backward passes. Fixes #3584 --- recipes_source/distributed_async_checkpoint_recipe.rst | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/recipes_source/distributed_async_checkpoint_recipe.rst b/recipes_source/distributed_async_checkpoint_recipe.rst index e959883a25b..0cb22fc71d5 100644 --- a/recipes_source/distributed_async_checkpoint_recipe.rst +++ b/recipes_source/distributed_async_checkpoint_recipe.rst @@ -257,12 +257,14 @@ checkpoint requests users can take advantage of direct memory access to speed up for step in range(10): optimizer.zero_grad() model(torch.rand(8, 16, device="cuda")).sum().backward() - optimizer.step() - state_dict = { "app": AppState(model, optimizer) } + # Wait for the previous checkpoint to finish before optimizer.step() modifies weights in-place if checkpoint_future is not None: - # waits for checkpointing to finish, avoiding queuing more then one checkpoint request at a time checkpoint_future.result() + + optimizer.step() + + state_dict = { "app": AppState(model, optimizer) } checkpoint_future = dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}") cleanup()