Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions cycling_utils/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,17 @@ class AtomicDirectory:
- `strategy = "async"` will `force_save` the checkpoint if the saving process passes `force_save = True`
- `strategy = "mono"` will `force_save` the checkpoint if the saving process passes `force_save = True`

Each of the strategies "sync_any", "sync_all", and "async"
Each of the strategies "sync_any", "sync_all", and "async" is intended for use by processes operating within a torchrun
distributed process group, and the AtomicDirectory will enforce that all group processes apply the same strategy.

Further, the `strategy = "mono"` argument should be passed if the AtomicDirectory saver is intended for use outside of a
torchrun distributed process group. In such cases, the user must ensure that all instances of the AtomicDirectory saver are
initialized with a unique 'name'.
The "sync_any" and "sync_all" strategies will cause the AtomicDirectory saver to behave as one distributed object, where the
saver for each process in the group will access the same checkpoint directory, and share a "name".

The "async" strategy is intended for use as a group of independent savers, each creating and accessing independent checkpoint
directories, with "name" automatically generated according to the rank of the saver process in the group.

The "mono" strategy is intended for use by processes operating outside of a torchrun distributed process group. In this case,
the user MUST ensure that all instances of the AtomicDirectory saver are initialized with a unique "name".

Example usage of AtomicDirectory in synchronous mode on the Strong Compute ISC, launched with torchrun, is as follows.

Expand Down Expand Up @@ -121,11 +127,13 @@ def __init__(
name="AtomicDirectory",
keep_last=-1,
strategy="sync_any",
device="cuda",
):
self.output_directory = output_directory
self.is_master = is_master
self.is_master = is_master or strategy in ["async", "mono"]
self.keep_last = keep_last
self.strategy = strategy
self.device = device
self.rank = os.getenv("RANK", "NONE")
self.world_size = os.getenv("WORLD_SIZE", "NONE")

Expand All @@ -137,6 +145,7 @@ def __init__(
raise f"ERROR: AtomicDirectory saver must be initialized with strategy = 'sync_any', 'sync_all', 'async', or 'mono' but rank \
{self.rank} was passed '{strategy}'."

# if strategy == "mono" then do not validate group strategy
if strategy != "mono":

assert (
Expand All @@ -147,10 +156,12 @@ def __init__(
), "ERROR: AtomicDirectory requires WORLD_SIZE environment variable set if strategy is not 'mono'."

local_strategy_tensor = torch.tensor(
strategy_int, dtype=torch.int64, requires_grad=False, device="cuda"
strategy_int, dtype=torch.int64, requires_grad=False, device=self.device
)
global_strategy_list = [
torch.zeros(1, dtype=torch.int64, requires_grad=False, device="cuda")
torch.zeros(
1, dtype=torch.int64, requires_grad=False, device=self.device
)
for _ in range(int(self.world_size))
]
all_gather(global_strategy_list, local_strategy_tensor)
Expand Down Expand Up @@ -194,6 +205,7 @@ def is_checkpoint_directory(self, path_str):

def prepare_checkpoint_directory(self, force_save=False):

# block if synchronous saving
if self.strategy in ["sync_any", "sync_all"]:
barrier()

Expand All @@ -213,7 +225,7 @@ def prepare_checkpoint_directory(self, force_save=False):

if checkpoint_paths and not symlink_found:
print(
"Found one or more checkpoint dirs but no symlink to latest. Will assume all should be deleted."
"Found one or more checkpoint dirs but no symlink to latest. Will assume all should be deleted. Please ensure that you are calling symlink_latest on your completed checkpoint."
)

if symlink_found and not checkpoint_paths:
Expand All @@ -224,6 +236,7 @@ def prepare_checkpoint_directory(self, force_save=False):
latest_sequential_index = -1
deletable = []

# retrieve latest checkpoint information from symlink
if symlink_found:

symlink_path = os.readlink(
Expand All @@ -233,13 +246,22 @@ def prepare_checkpoint_directory(self, force_save=False):
checkpoint_paths[Path(symlink_path).name].split("_")[0]
)

# determine directories that can be deleted
# determine which if any of the found checkpoint_paths should be discarded
if checkpoint_paths:

# delete any checkpoint directories with index
# greater than the latest_sequential_index from the symlink
# because these are assumed to be incomplete
incomplete_deletable = [
os.path.join(self.output_directory, path)
for path, suffix in checkpoint_paths.items()
if int(suffix.split("_")[0]) > latest_sequential_index
]

# delete any checkpoint directories with index
# less than the latest_sequential_index - keep_last + 2
# and with no "force" extension because these are obsolete
# if keep_last < 0 assume keep all.
obsolete_deletable = []
if self.keep_last > 0:
obsolete_deletable = [
Expand All @@ -252,16 +274,18 @@ def prepare_checkpoint_directory(self, force_save=False):

deletable = incomplete_deletable + obsolete_deletable

# Delete deletable
# block if synchronous saving
if self.strategy in ["sync_any", "sync_all"]:
barrier()

# master / responsible rank deletes deletable
if self.is_master:
for path in deletable:
rmtree(path)
for path in deletable:
assert not Path(path).exists()

# block if synchronous saving
if self.strategy in ["sync_any", "sync_all"]:
barrier()

Expand All @@ -274,7 +298,7 @@ def prepare_checkpoint_directory(self, force_save=False):
1 if force_save else 0,
dtype=torch.int64,
requires_grad=False,
device="cuda",
device=self.device,
)
all_reduce(global_force)

Expand Down Expand Up @@ -302,6 +326,7 @@ def prepare_checkpoint_directory(self, force_save=False):
if Path(next_checkpoint_directory).exists():
break

# block if synchronous saving
if self.strategy in ["sync_any", "sync_all"]:
barrier()

Expand All @@ -326,6 +351,7 @@ def prepare_checkpoint_directory(self, force_save=False):

def symlink_latest(self, checkpoint_directory):

# block if synchronous saving
if self.strategy in ["sync_any", "sync_all"]:
barrier()

Expand All @@ -342,5 +368,6 @@ def symlink_latest(self, checkpoint_directory):
os.path.join(parent_dir, self.symlink_name),
)

# block if synchronous saving
if self.strategy in ["sync_any", "sync_all"]:
barrier()