diff --git a/apps/Castor/__init__.py b/apps/Castor/__init__.py
new file mode 100644
index 00000000..8fd387ab
--- /dev/null
+++ b/apps/Castor/__init__.py
@@ -0,0 +1,3 @@
+"""
+Castor package initialization
+"""
\ No newline at end of file
diff --git a/apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml b/apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml
index fdda0f29..1ec52476 100644
--- a/apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml
+++ b/apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml
@@ -67,7 +67,10 @@ model:
     config_name: "ViT-B/32"
     dtype: "bf16"
     text_seqlen: 77
-
+ema:
+  decay: 0.95
+  warmup_steps: 2000
+  update_buffers: false
 data:
   - stage: stage-1
     id: 1
diff --git a/apps/Castor/modules/__init__.py b/apps/Castor/modules/__init__.py
new file mode 100644
index 00000000..e69de29b
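
The new ema block above is consumed by the EMAArgs dataclass and EMA helper added in apps/Castor/modules/ema.py below. As a quick illustration of what these defaults do (editor's sketch, not part of the patch), the warmup scaling used there, decay * (1 - exp(-step / warmup_steps)), keeps the effective decay small for the first few hundred steps and only approaches the configured 0.95 after several thousand steps:

import math

decay, warmup_steps = 0.95, 2000  # values from the config block above

def effective_decay(step: int) -> float:
    # Same warmup scaling as EMA._compute_effective_decay() below.
    return decay * (1 - math.exp(-step / warmup_steps))

for step in (100, 500, 2000, 10000):
    print(step, round(effective_decay(step), 3))
# -> 100 0.046, 500 0.21, 2000 0.601, 10000 0.944
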
diff --git a/apps/Castor/modules/ema.py b/apps/Castor/modules/ema.py
new file mode 100644
index 00000000..864babf7
--- /dev/null
+++ b/apps/Castor/modules/ema.py
@@ -0,0 +1,98 @@
+import torch
+import math
+import copy
+from dataclasses import dataclass
+from collections import OrderedDict
+from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
+from torch.distributed.device_mesh import DeviceMesh
+from typing import Optional, List, Tuple
+
+@dataclass
+class EMAArgs:
+    decay: float = 0.95
+    warmup_steps: int = 2000
+    update_buffers: bool = False
+
+
+class EMA:
+    def __init__(self, model: torch.nn.Module, decay: float, warmup_steps: int = 0, update_buffers: bool = False):
+        """
+        Initializes EMA with warmup support.
+
+        Args:
+        - model (torch.nn.Module): The model to track.
+        - decay (float): Target decay rate (e.g., 0.95).
+        - warmup_steps (int): Number of steps for warmup (default is 0 for no warmup).
+        - update_buffers (bool): Whether to also track buffers such as BatchNorm stats.
+        """
+        self.ema_model = copy.deepcopy(model).eval()  # Duplicate the model for EMA
+        self.decay = decay
+        self.warmup_steps = warmup_steps
+        self.global_step = 0  # Starts at step 0
+        self.update_buffers = update_buffers
+        # Disable gradient computation for EMA model
+        for param in self.ema_model.parameters():
+            param.requires_grad = False
+
+    def _compute_effective_decay(self) -> float:
+        """
+        Compute the effective decay based on warmup steps.
+        """
+        if self.warmup_steps > 0:
+            return self.decay * (1 - math.exp(-self.global_step / self.warmup_steps))
+        return self.decay
+
+    @torch.no_grad()
+    def step(self, model: torch.nn.Module):
+        """
+        Updates the EMA model with the current model parameters.
+
+        Args:
+        - model (torch.nn.Module): Current model to update EMA from.
+
+        Buffers such as BatchNorm stats are only updated when self.update_buffers is True.
+
+        # https://github.com/pytorch/pytorch/issues/117742 based on this issue, it is okay to update the EMA model without summoning full parameters
+        """
+        self.global_step += 1
+        effective_decay = self._compute_effective_decay()  # Get the current decay rate
+
+        # Update parameters
+        params = OrderedDict(model.named_parameters())
+        ema_params = OrderedDict(self.ema_model.named_parameters())
+
+        assert set(ema_params.keys()) == set(params.keys())
+
+        for name in params:
+            ema_params[name].mul_(effective_decay).add_(params[name].data, alpha=1 - effective_decay)
+
+        # Update buffers (if needed)
+        if self.update_buffers:
+            buffers = OrderedDict(model.named_buffers())
+            ema_buffers = OrderedDict(self.ema_model.named_buffers())
+            assert set(ema_buffers.keys()) == set(buffers.keys())
+            for name in buffers:
+                if buffers[name].dtype.is_floating_point:
+                    ema_buffers[name].mul_(effective_decay).add_(
+                        buffers[name].data, alpha=1 - effective_decay
+                    )
+
+    def state_dict(self) -> dict:
+        """
+        Returns the state dictionary for the EMA model.
+        """
+        return self.ema_model.state_dict()
+
+    def load_state_dict(self, state_dict: dict):
+        """
+        Loads the state dictionary into the EMA model.
+        """
+        self.ema_model.load_state_dict(state_dict)
+
+    def to(self, device: torch.device):
+        """
+        Transfers the EMA model to a specified device.
+        """
+        self.ema_model.to(device)
diff --git a/apps/Castor/train.py b/apps/Castor/train.py
index 40e865bb..be7932c2 100644
--- a/apps/Castor/train.py
+++ b/apps/Castor/train.py
@@ -42,6 +42,8 @@
     get_world_size,
     get_local_rank,
     parallelize_model,
+    apply_activation_checkpointing,
+    apply_compile,
     setup_env,
     setup_torch_distributed,
     requeue_slurm_job,
@@ -68,12 +70,13 @@
     tp_parallelize,
     get_no_recompute_ops,
 )
-
+from apps.Castor.modules.ema import EMA, EMAArgs
 from apps.main.utils.cal_flops import get_num_flop_per_token
 
 logger = logging.getLogger()
 
+
 @dataclass
 class TrainArgs:
 
@@ -104,7 +107,7 @@ class TrainArgs:
     profiling: ProfilerArgs = field(default_factory=ProfilerArgs)
     logging: LoggingArgs = field(default_factory=LoggingArgs)
     scheduler: SchedulerArgs = field(default_factory=SchedulerArgs)
-
+    ema: EMAArgs = field(default_factory=EMAArgs)
     # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
     async_eval_gpus: Optional[int] = None
     eval: Optional[Any] = None
@@ -277,6 +280,7 @@ def train(args: TrainArgs):
 
     model = Castor(args.model)
     logger.info("Model is built !")
+    ema = EMA(model, decay=args.ema.decay, warmup_steps=args.ema.warmup_steps, update_buffers=args.ema.update_buffers)
 
     model_param_count = get_num_params(model)
 
@@ -291,6 +295,8 @@ def train(args: TrainArgs):
         tp_parallelize=tp_parallelize,
         no_recompute_ops=get_no_recompute_ops(),
     )
+    model = apply_activation_checkpointing(model, args.distributed)
+    model = apply_compile(model, args.distributed)
 
     model = model.to(device="cuda")
     check_model_value_range(model, range=10.0, std=1.0)
@@ -299,6 +305,15 @@ def train(args: TrainArgs):
 
     logger.info(f"Model size: {model_param_count:,} total parameters")
 
+    ema.ema_model = parallelize_model(
+        ema.ema_model,
+        world_mesh,
+        args.model,
+        args.distributed,
+        fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
+    )
+    ema.ema_model = ema.ema_model.to(device="cuda")
+
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
     logger.info(
         f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
@@ -414,6 +429,7 @@ def train(args: TrainArgs):
             optimizer.step()
             scheduler.step()
             optimizer.zero_grad()
+            ema.step(model)
 
             train_state.step += 1
 
             # updates the scale for next iteration
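
Taken together, the two files above build the EMA copy before parallelization, shard it with the same FSDP grouping plan, and update it once per optimizer step. Below is a minimal usage sketch of the EMA helper on its own (editor's illustration, not part of the patch; it assumes the repo root is on PYTHONPATH and uses a toy nn.Linear in place of Castor):

import torch
from apps.Castor.modules.ema import EMA

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ema = EMA(model, decay=0.95, warmup_steps=2000, update_buffers=False)

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model)  # mirrors the call added right after optimizer.step() in train.py

# For evaluation, the smoothed weights can be loaded into a separate module.
eval_model = torch.nn.Linear(4, 4)
eval_model.load_state_dict(ema.state_dict())

In train.py the same pattern runs against the FSDP-sharded Castor model, relying on the linked PyTorch issue's observation that sharded parameters can be updated in place without gathering full weights.
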
diff --git a/lingua/distributed.py b/lingua/distributed.py
index 0e0b0407..3abcac16 100644
--- a/lingua/distributed.py
+++ b/lingua/distributed.py
@@ -34,9 +34,6 @@
 )
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 
-# for no recompute ops
-import xformers.ops
-
 from lingua.float8 import convert_linears_to_fp8
 
 logger = logging.getLogger()
@@ -466,9 +463,13 @@ def parallelize_model(
         )
 
         model = fully_shard(model, **fsdp_config, reshard_after_forward=True)
+        return model
     else:
         raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")
+
+
+def apply_activation_checkpointing(model, distributed_args: DistributedArgs):
     if distributed_args.selective_activation_checkpointing:
         non_reentrant_wrapper = partial(
             checkpoint_wrapper,
@@ -479,11 +480,14 @@ def parallelize_model(
             checkpoint_wrapper_fn=non_reentrant_wrapper,
             check_fn=lambda submodule: submodule in model.get_checkpointing_wrap_module_list(),
         )
+    return model
+
+
+def apply_compile(model, distributed_args: DistributedArgs):
     if distributed_args.compile:
         torch._dynamo.config.cache_size_limit = (
             distributed_args.compile_cache_size_limit
         )
         model = torch.compile(model)
-    return model
+    return model
\ No newline at end of file
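
The lingua/distributed.py refactor above splits the old single parallelize_model pass into three composable steps, which is what allows train.py to shard the EMA copy without also wrapping it in activation checkpointing or torch.compile. A hedged sketch of calling the two new helpers directly (a SimpleNamespace stands in for DistributedArgs so the snippet stays self-contained; in the trainer these flags come from args.distributed, and the import assumes the patched lingua package is installed):

from types import SimpleNamespace

import torch
from lingua.distributed import apply_activation_checkpointing, apply_compile

flags = SimpleNamespace(
    selective_activation_checkpointing=False,  # leave checkpointing off for this toy model
    compile=True,
    compile_cache_size_limit=8,
)

model = torch.nn.Linear(8, 8)
model = apply_activation_checkpointing(model, flags)  # no-op while the flag is off
model = apply_compile(model, flags)                   # returns torch.compile(model)

This mirrors the order used in apps/Castor/train.py above: parallelize, then checkpoint, then compile.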