From 07e315f2c7ca16dc18b4c985d401116a37f8fd64 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Thu, 3 Apr 2025 03:13:08 +0000 Subject: [PATCH 1/3] ema support --- apps/Castor/__init__.py | 3 + ...in_bucket_256_Castor_image_cache_json.yaml | 3 + apps/Castor/modules/__init__.py | 0 apps/Castor/modules/ema.py | 97 +++++++++++++++++++ apps/Castor/train.py | 20 +++- lingua/distributed.py | 12 ++- 6 files changed, 129 insertions(+), 6 deletions(-) create mode 100644 apps/Castor/__init__.py create mode 100644 apps/Castor/modules/__init__.py create mode 100644 apps/Castor/modules/ema.py diff --git a/apps/Castor/__init__.py b/apps/Castor/__init__.py new file mode 100644 index 00000000..8fd387ab --- /dev/null +++ b/apps/Castor/__init__.py @@ -0,0 +1,3 @@ +""" +Castor package initialization +""" \ No newline at end of file diff --git a/apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml b/apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml index fdda0f29..0e634caa 100644 --- a/apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml +++ b/apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml @@ -67,6 +67,9 @@ model: config_name: "ViT-B/32" dtype: "bf16" text_seqlen: 77 + ema: + decay: 0.95 + warmup_steps: 2000 data: - stage: stage-1 diff --git a/apps/Castor/modules/__init__.py b/apps/Castor/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/Castor/modules/ema.py b/apps/Castor/modules/ema.py new file mode 100644 index 00000000..08e93b0e --- /dev/null +++ b/apps/Castor/modules/ema.py @@ -0,0 +1,97 @@ +import torch +import math +import copy +from dataclasses import dataclass +from collections import OrderedDict +from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard +from torch.distributed.device_mesh import DeviceMesh +from typing import Optional, List, Tuple + +@dataclass +class EMAArgs: + decay: float = 0.95 + warmup_steps: int = 2000 + + +class EMA: + def __init__(self, model: torch.nn.Module, decay: float, warmup_steps: int = 0): + """ + Initializes EMA with warmup support. + + Args: + - model (torch.nn.Module): The model to track. + - decay (float): Target decay rate (e.g., 0.95). + - warmup_steps (int): Number of steps for warmup (default is 0 for no warmup). + """ + self.ema_model = copy.deepcopy(model).eval() # Duplicate the model for EMA + self.decay = decay + self.warmup_steps = warmup_steps + self.global_step = 0 # Starts at step 0 + + # Disable gradient computation for EMA model + for param in self.ema_model.parameters(): + param.requires_grad = False + + + def _compute_effective_decay(self) -> float: + """ + Compute the effective decay based on warmup steps. + """ + if self.warmup_steps > 0: + return self.decay * (1 - math.exp(-self.global_step / self.warmup_steps)) + return self.decay + + + @torch.no_grad() + def step(self, model: torch.nn.Module, update_buffers: bool = False): + """ + Updates the EMA model with the current model parameters. + + Args: + - model (torch.nn.Module): Current model to update EMA from. + - update_buffers (bool): Whether to update buffers such as BatchNorm stats. 
+ + # https://github.com/pytorch/pytorch/issues/117742 based on this its okay to update ema model without summoning full parameters + """ + self.global_step += 1 + effective_decay = self._compute_effective_decay() # Get the current decay rate + + # Update parameters + params = OrderedDict(model.named_parameters()) + ema_params = OrderedDict(self.ema_model.named_parameters()) + + assert set(ema_params.keys()) == set(params.keys()) + + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + + for name in params: + ema_params[name].mul_(effective_decay).add_(params[name].data, alpha=1 - effective_decay) + + # Update buffers (if needed) + if update_buffers: + buffers = OrderedDict(model.named_buffers()) + ema_buffers = OrderedDict(self.ema_model.named_buffers()) + assert set(ema_buffers.keys()) == set(buffers.keys()) + for name in buffers: + if buffers[name].dtype.is_floating_point: + ema_buffers[name].mul_(effective_decay).add_( + buffers[name].data, alpha=1 - effective_decay + ) + + def state_dict(self) -> dict: + """ + Returns the state dictionary for the EMA model. + """ + return self.ema_model.state_dict() + + def load_state_dict(self, state_dict: dict): + """ + Loads the state dictionary into the EMA model. + """ + self.ema_model.load_state_dict(state_dict) + + def to(self, device: torch.device): + """ + Transfers the EMA model to a specified device. + """ + self.ema_model.to(device) diff --git a/apps/Castor/train.py b/apps/Castor/train.py index 40e865bb..be7932c2 100644 --- a/apps/Castor/train.py +++ b/apps/Castor/train.py @@ -42,6 +42,8 @@ get_world_size, get_local_rank, parallelize_model, + apply_activation_checkpointing, + apply_compile, setup_env, setup_torch_distributed, requeue_slurm_job, @@ -68,12 +70,13 @@ tp_parallelize, get_no_recompute_ops, ) - +from apps.Castor.modules.ema import EMA, EMAArgs from apps.main.utils.cal_flops import get_num_flop_per_token logger = logging.getLogger() + @dataclass class TrainArgs: @@ -104,7 +107,7 @@ class TrainArgs: profiling: ProfilerArgs = field(default_factory=ProfilerArgs) logging: LoggingArgs = field(default_factory=LoggingArgs) scheduler: SchedulerArgs = field(default_factory=SchedulerArgs) - + ema: EMAArgs = field(default_factory=EMAArgs) # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus async_eval_gpus: Optional[int] = None eval: Optional[Any] = None @@ -277,6 +280,7 @@ def train(args: TrainArgs): model = Castor(args.model) logger.info("Model is built !") + ema = EMA(model, decay=args.ema.decay, warmup_steps=args.ema.warmup_steps) model_param_count = get_num_params(model) @@ -291,6 +295,8 @@ def train(args: TrainArgs): tp_parallelize=tp_parallelize, no_recompute_ops=get_no_recompute_ops(), ) + model = apply_activation_checkpointing(model, args.distributed) + model = apply_compile(model, args.distributed) model = model.to(device="cuda") check_model_value_range(model, range=10.0, std=1.0) @@ -299,6 +305,15 @@ def train(args: TrainArgs): logger.info(f"Model size: {model_param_count:,} total parameters") + ema.ema_model = parallelize_model( + ema.ema_model, + world_mesh, + args.model, + args.distributed, + fsdp_grouping_plan=build_fsdp_grouping_plan(args.model), + ) + ema.ema_model = ema.ema_model.to(device="cuda") + gpu_memory_monitor = GPUMemoryMonitor("cuda") logger.info( f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) " @@ -414,6 +429,7 @@ def train(args: TrainArgs): optimizer.step() scheduler.step() 
optimizer.zero_grad() + ema.step(model) train_state.step += 1 # updates the scale for next iteration diff --git a/lingua/distributed.py b/lingua/distributed.py index 0e0b0407..3abcac16 100644 --- a/lingua/distributed.py +++ b/lingua/distributed.py @@ -34,9 +34,6 @@ ) from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -# for no recompute ops -import xformers.ops - from lingua.float8 import convert_linears_to_fp8 logger = logging.getLogger() @@ -466,9 +463,13 @@ def parallelize_model( ) model = fully_shard(model, **fsdp_config, reshard_after_forward=True) + return model else: raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}") + + +def apply_activation_checkpointing(model, distributed_args: DistributedArgs): if distributed_args.selective_activation_checkpointing: non_reentrant_wrapper = partial( checkpoint_wrapper, @@ -479,11 +480,14 @@ def parallelize_model( checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=lambda submodule: submodule in model.get_checkpointing_wrap_module_list(), ) + return model + +def apply_compile(model, distributed_args: DistributedArgs): if distributed_args.compile: torch._dynamo.config.cache_size_limit = ( distributed_args.compile_cache_size_limit ) model = torch.compile(model) - return model + return model \ No newline at end of file From 50b307170458793740d4b72b3e0e56e809bd292d Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Thu, 3 Apr 2025 03:31:40 +0000 Subject: [PATCH 2/3] minor change --- apps/Castor/modules/ema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/Castor/modules/ema.py b/apps/Castor/modules/ema.py index 08e93b0e..fc7c9f38 100644 --- a/apps/Castor/modules/ema.py +++ b/apps/Castor/modules/ema.py @@ -62,7 +62,7 @@ def step(self, model: torch.nn.Module, update_buffers: bool = False): assert set(ema_params.keys()) == set(params.keys()) - # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + for name in params: ema_params[name].mul_(effective_decay).add_(params[name].data, alpha=1 - effective_decay) From 4e6275e28b18bbeb58cc1b78e22052501387a46c Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Thu, 3 Apr 2025 03:35:02 +0000 Subject: [PATCH 3/3] minor interface change for ema- moved buffers arg to init for simplicity --- .../train_bucket_256_Castor_image_cache_json.yaml | 2 +- apps/Castor/modules/ema.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml b/apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml index 0e634caa..1ec52476 100644 --- a/apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml +++ b/apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml @@ -70,7 +70,7 @@ model: ema: decay: 0.95 warmup_steps: 2000 - + update_buffers: false data: - stage: stage-1 id: 1 diff --git a/apps/Castor/modules/ema.py b/apps/Castor/modules/ema.py index fc7c9f38..864babf7 100644 --- a/apps/Castor/modules/ema.py +++ b/apps/Castor/modules/ema.py @@ -11,10 +11,11 @@ class EMAArgs: decay: float = 0.95 warmup_steps: int = 2000 + update_buffers: bool = False class EMA: - def __init__(self, model: torch.nn.Module, decay: float, warmup_steps: int = 0): + def __init__(self, model: torch.nn.Module, decay: float, warmup_steps: int = 0, update_buffers: bool = False): """ Initializes EMA with warmup support. 
@@ -27,7 +28,7 @@ def __init__(self, model: torch.nn.Module, decay: float, warmup_steps: int = 0): self.decay = decay self.warmup_steps = warmup_steps self.global_step = 0 # Starts at step 0 - + self.update_buffers = update_buffers # Disable gradient computation for EMA model for param in self.ema_model.parameters(): param.requires_grad = False @@ -43,7 +44,7 @@ def _compute_effective_decay(self) -> float: @torch.no_grad() - def step(self, model: torch.nn.Module, update_buffers: bool = False): + def step(self, model: torch.nn.Module): """ Updates the EMA model with the current model parameters. @@ -68,7 +69,7 @@ def step(self, model: torch.nn.Module, update_buffers: bool = False): ema_params[name].mul_(effective_decay).add_(params[name].data, alpha=1 - effective_decay) # Update buffers (if needed) - if update_buffers: + if self.update_buffers: buffers = OrderedDict(model.named_buffers()) ema_buffers = OrderedDict(self.ema_model.named_buffers()) assert set(ema_buffers.keys()) == set(buffers.keys())
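
Usage sketch (illustrative, not part of the patch series): a minimal, self-contained example of how the EMA class from apps/Castor/modules/ema.py is meant to be driven, following the same ordering train.py uses (optimizer.step(), then ema.step(model)). The tiny Sequential model, AdamW optimizer, random batch, and "ema.pt" path below are placeholder assumptions; only EMA/EMAArgs and the warmup-adjusted decay, effective_decay = decay * (1 - exp(-global_step / warmup_steps)), come from the patch.

    import torch

    from apps.Castor.modules.ema import EMA, EMAArgs

    # Placeholder stand-in for the Castor model; any nn.Module works for this sketch.
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 16)
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    ema_args = EMAArgs(decay=0.95, warmup_steps=2000, update_buffers=False)
    ema = EMA(
        model,
        decay=ema_args.decay,
        warmup_steps=ema_args.warmup_steps,
        update_buffers=ema_args.update_buffers,
    )

    for _ in range(10):
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()  # dummy objective, just to produce gradients
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema.step(model)  # EMA update after the optimizer step, as in train.py

    # After 10 steps the warmup keeps the effective decay tiny
    # (0.95 * (1 - exp(-10 / 2000)) ~= 0.0047), so the EMA still hugs the raw
    # weights; it only approaches 0.95 once global_step passes warmup_steps.

    # EMA weights can be checkpointed independently of the live model
    # ("ema.pt" is an arbitrary example path).
    torch.save(ema.state_dict(), "ema.pt")

In the real training path the EMA copy is additionally passed through parallelize_model and moved to CUDA, so its FSDP sharding matches the live model's; that is what lets ema.step update the sharded parameters in place, per the pytorch/pytorch#117742 note in the code, without summoning full parameters.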