diff --git a/recipes_source/distributed_async_checkpoint_recipe.rst b/recipes_source/distributed_async_checkpoint_recipe.rst
index e959883a25b..0cb22fc71d5 100644
--- a/recipes_source/distributed_async_checkpoint_recipe.rst
+++ b/recipes_source/distributed_async_checkpoint_recipe.rst
@@ -257,12 +257,14 @@ checkpoint requests users can take advantage of direct memory access to speed up
         for step in range(10):
             optimizer.zero_grad()
             model(torch.rand(8, 16, device="cuda")).sum().backward()
-            optimizer.step()
-            state_dict = { "app": AppState(model, optimizer) }
+            # Wait for the previous checkpoint to finish before optimizer.step() modifies weights in-place
             if checkpoint_future is not None:
-                # waits for checkpointing to finish, avoiding queuing more then one checkpoint request at a time
                 checkpoint_future.result()
+
+            optimizer.step()
+
+            state_dict = { "app": AppState(model, optimizer) }
             checkpoint_future = dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")
 
         cleanup()
 
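
For context: the hunk reorders the training loop so that waiting on the previous ``async_save`` future happens before ``optimizer.step()``, because ``step()`` mutates the parameters in-place and a checkpoint that is still staging could otherwise capture partially-updated weights. Below is a minimal single-process sketch of the resulting pattern; it is not part of the patch. It assumes recent PyTorch with ``torch.distributed.checkpoint.async_save``, a 1-rank ``gloo`` process group (async save wants a CPU-capable backend when a process group is initialized), an illustrative checkpoint path, and a plain ``nn.Linear`` in place of the recipe's FSDP-wrapped ``ToyModel`` and pinned-memory ``StorageWriter``::

    # Standalone sketch of the wait-before-step ordering (assumed setup,
    # not taken from the recipe).
    import os
    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp
    import torch.nn as nn

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = nn.Linear(16, 16)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    checkpoint_future = None
    for step in range(3):
        optimizer.zero_grad()
        model(torch.rand(8, 16)).sum().backward()

        # Wait for the in-flight checkpoint BEFORE step(): once the future
        # resolves, the previous snapshot is safely staged, so the in-place
        # parameter update below cannot corrupt it.
        if checkpoint_future is not None:
            checkpoint_future.result()

        optimizer.step()

        state_dict = {"model": model.state_dict(), "optim": optimizer.state_dict()}
        checkpoint_future = dcp.async_save(
            state_dict,
            checkpoint_id=f"/tmp/checkpoint_step{step}",  # illustrative path
        )

    # Drain the last request before tearing down the process group.
    if checkpoint_future is not None:
        checkpoint_future.result()
    dist.destroy_process_group()

As a side effect of the reordering, the wait also still serves its original purpose of keeping at most one checkpoint request in flight at a time.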