diff --git a/cycling_utils/saving.py b/cycling_utils/saving.py index d852749..80d48c4 100644 --- a/cycling_utils/saving.py +++ b/cycling_utils/saving.py @@ -63,11 +63,17 @@ class AtomicDirectory: - `strategy = "async"` will `force_save` the checkpoint if the saving process passes `force_save = True` - `strategy = "mono"` will `force_save` the checkpoint if the saving process passes `force_save = True` - Each of the strategies "sync_any", "sync_all", and "async" + Each of the strategies "sync_any", "sync_all", and "async" are intended for use by processes operating within a torchrun + distributed process group, and the AtomicDirectory will enforce that all group processes apply the same strategy. - Further, the `strategy = "mono"` argument should be passed if the AtomicDirectory saver is intended for use outside of a - torchrun distributed process group. In such cases, the user must ensure that all instances of the AtomicDirectory saver are - initialized with a unique 'name'. + The "sync_any" and "sync_all" strategies will cause the AtomicDirectory saver to behave as one distributed object, where the + saver for each process in the group will access the same checkpoint directory, and share a "name". + + The "async" strategy is intended for use as a group of independent savers, each creating and accessing independent checkpoint + directories, with "name" automatically generated according to the rank of the saver process in the group. + + The "mono" strategy is intended for use by processes operating outside of a torchrun distributed process group. In this case, + the user MUST ensure that all instances of the AtomicDirectory saver are initialized with a unique "name". Example usage of AtomicDirectory in synchronous mode on the Strong Compute ISC launching with torchrun as follows. @@ -121,11 +127,13 @@ def __init__( name="AtomicDirectory", keep_last=-1, strategy="sync_any", + device="cuda", ): self.output_directory = output_directory - self.is_master = is_master + self.is_master = is_master or strategy in ["async", "mono"] self.keep_last = keep_last self.strategy = strategy + self.device = device self.rank = os.getenv("RANK", "NONE") self.world_size = os.getenv("WORLD_SIZE", "NONE") @@ -137,6 +145,7 @@ def __init__( raise f"ERROR: AtomicDirectory saver must be initialized with strategy = 'sync_any', 'sync_all', 'async', or 'mono' but rank \ {self.rank} was passed '{strategy}'." + # if strategy == "mono" then do not validate group strategy if strategy != "mono": assert ( @@ -147,10 +156,12 @@ def __init__( ), "ERROR: AtomicDirectory requires WORLD_SIZE environment variable set if strategy is not 'mono'." local_strategy_tensor = torch.tensor( - strategy_int, dtype=torch.int64, requires_grad=False, device="cuda" + strategy_int, dtype=torch.int64, requires_grad=False, device=self.device ) global_strategy_list = [ - torch.zeros(1, dtype=torch.int64, requires_grad=False, device="cuda") + torch.zeros( + 1, dtype=torch.int64, requires_grad=False, device=self.device + ) for _ in range(int(self.world_size)) ] all_gather(global_strategy_list, local_strategy_tensor) @@ -194,6 +205,7 @@ def is_checkpoint_directory(self, path_str): def prepare_checkpoint_directory(self, force_save=False): + # block if sychronous saving if self.strategy in ["sync_any", "sync_all"]: barrier() @@ -213,7 +225,7 @@ def prepare_checkpoint_directory(self, force_save=False): if checkpoint_paths and not symlink_found: print( - "Found one or more checkpoint dirs but no symlink to latest. Will assume all should be deleted." + "Found one or more checkpoint dirs but no symlink to latest. Will assume all should be deleted. Please ensure that you are calling symlink_latest on your completed checkpoint." ) if symlink_found and not checkpoint_paths: @@ -224,6 +236,7 @@ def prepare_checkpoint_directory(self, force_save=False): latest_sequential_index = -1 deletable = [] + # retrieve latest checkpoint information from symlink if symlink_found: symlink_path = os.readlink( @@ -233,13 +246,22 @@ def prepare_checkpoint_directory(self, force_save=False): checkpoint_paths[Path(symlink_path).name].split("_")[0] ) - # determine directories that can be deleted + # determine which if any of the found checkpoint_paths should be discarded + if checkpoint_paths: + + # delete any checkpoint directories with index + # greater than the latest_sequential_index from the symlink + # because these are assumed to be incomplete incomplete_deletable = [ os.path.join(self.output_directory, path) for path, suffix in checkpoint_paths.items() if int(suffix.split("_")[0]) > latest_sequential_index ] + # delete any checkpoint directories with index + # less than the latest_sequential_index - keep_last + 2 + # and with no "force" extension because these are obsolete + # if keep_last < 0 assume keep all. obsolete_deletable = [] if self.keep_last > 0: obsolete_deletable = [ @@ -252,16 +274,18 @@ def prepare_checkpoint_directory(self, force_save=False): deletable = incomplete_deletable + obsolete_deletable - # Delete deletable + # block if sychronous saving if self.strategy in ["sync_any", "sync_all"]: barrier() + # master / responsible rank deletes deletable if self.is_master: for path in deletable: rmtree(path) for path in deletable: assert not Path(path).exists() + # block if sychronous saving if self.strategy in ["sync_any", "sync_all"]: barrier() @@ -274,7 +298,7 @@ def prepare_checkpoint_directory(self, force_save=False): 1 if force_save else 0, dtype=torch.int64, requires_grad=False, - device="cuda", + device=self.device, ) all_reduce(global_force) @@ -302,6 +326,7 @@ def prepare_checkpoint_directory(self, force_save=False): if Path(next_checkpoint_directory).exists(): break + # block if sychronous saving if self.strategy in ["sync_any", "sync_all"]: barrier() @@ -326,6 +351,7 @@ def prepare_checkpoint_directory(self, force_save=False): def symlink_latest(self, checkpoint_directory): + # block if sychronous saving if self.strategy in ["sync_any", "sync_all"]: barrier() @@ -342,5 +368,6 @@ def symlink_latest(self, checkpoint_directory): os.path.join(parent_dir, self.symlink_name), ) + # block if sychronous saving if self.strategy in ["sync_any", "sync_all"]: barrier()