3 changes: 3 additions & 0 deletions apps/Castor/__init__.py
@@ -0,0 +1,3 @@
"""
Castor package initialization
"""
@@ -67,7 +67,10 @@ model:
  config_name: "ViT-B/32"
  dtype: "bf16"
  text_seqlen: 77

ema:
  decay: 0.95
  warmup_steps: 2000
  update_buffers: false
data:
  - stage: stage-1
    id: 1
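For context, with the defaults above (decay 0.95, warmup over 2000 steps) the effective EMA decay ramps from roughly 0 toward 0.95 following the warmup schedule implemented in apps/Castor/modules/ema.py below. A minimal sketch of that schedule at a few illustrative step counts (not part of the PR; values computed directly from the formula):

import math

decay, warmup_steps = 0.95, 2000  # values from the config above

def effective_decay(step: int) -> float:
    # Mirrors EMA._compute_effective_decay below: decay * (1 - exp(-step / warmup_steps))
    return decay * (1 - math.exp(-step / warmup_steps))

for step in (100, 1000, 2000, 10000):
    print(step, round(effective_decay(step), 3))
# prints roughly: 100 -> 0.046, 1000 -> 0.374, 2000 -> 0.601, 10000 -> 0.944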
Empty file added apps/Castor/modules/__init__.py
Empty file.
98 changes: 98 additions & 0 deletions apps/Castor/modules/ema.py
@@ -0,0 +1,98 @@
import torch
import math
import copy
from dataclasses import dataclass
from collections import OrderedDict
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.device_mesh import DeviceMesh
from typing import Optional, List, Tuple


@dataclass
class EMAArgs:
    decay: float = 0.95
    warmup_steps: int = 2000
    update_buffers: bool = False


class EMA:
    def __init__(self, model: torch.nn.Module, decay: float, warmup_steps: int = 0, update_buffers: bool = False):
        """
        Initializes EMA with warmup support.

        Args:
            - model (torch.nn.Module): The model to track.
            - decay (float): Target decay rate (e.g., 0.95).
            - warmup_steps (int): Number of steps for warmup (default is 0 for no warmup).
            - update_buffers (bool): Whether to also update buffers such as BatchNorm stats.
        """
        self.ema_model = copy.deepcopy(model).eval()  # Duplicate the model for EMA
        self.decay = decay
        self.warmup_steps = warmup_steps
        self.global_step = 0  # Starts at step 0
        self.update_buffers = update_buffers
        # Disable gradient computation for EMA model
        for param in self.ema_model.parameters():
            param.requires_grad = False

    def _compute_effective_decay(self) -> float:
        """
        Compute the effective decay based on warmup steps.
        """
        if self.warmup_steps > 0:
            return self.decay * (1 - math.exp(-self.global_step / self.warmup_steps))
        return self.decay

    @torch.no_grad()
    def step(self, model: torch.nn.Module):
        """
        Updates the EMA model with the current model parameters.

        Args:
            - model (torch.nn.Module): Current model to update EMA from.

        Note:
            https://github.com/pytorch/pytorch/issues/117742 suggests it is okay to
            update the EMA model without summoning full parameters under FSDP.
        """
        self.global_step += 1
        effective_decay = self._compute_effective_decay()  # Get the current decay rate

        # Update parameters
        params = OrderedDict(model.named_parameters())
        ema_params = OrderedDict(self.ema_model.named_parameters())

        assert set(ema_params.keys()) == set(params.keys())

        for name in params:
            ema_params[name].mul_(effective_decay).add_(params[name].data, alpha=1 - effective_decay)

        # Update buffers (if needed)
        if self.update_buffers:
            buffers = OrderedDict(model.named_buffers())
            ema_buffers = OrderedDict(self.ema_model.named_buffers())
            assert set(ema_buffers.keys()) == set(buffers.keys())
            for name in buffers:
                if buffers[name].dtype.is_floating_point:
                    ema_buffers[name].mul_(effective_decay).add_(
                        buffers[name].data, alpha=1 - effective_decay
                    )

    def state_dict(self) -> dict:
        """
        Returns the state dictionary for the EMA model.
        """
        return self.ema_model.state_dict()

    def load_state_dict(self, state_dict: dict):
        """
        Loads the state dictionary into the EMA model.
        """
        self.ema_model.load_state_dict(state_dict)

    def to(self, device: torch.device):
        """
        Transfers the EMA model to a specified device.
        """
        self.ema_model.to(device)
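For reference, a minimal usage sketch of the class above, driven from a plain training loop (the toy model and optimizer here are hypothetical, not part of the PR):

import torch
from apps.Castor.modules.ema import EMA

model = torch.nn.Linear(16, 16)          # hypothetical stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema = EMA(model, decay=0.95, warmup_steps=2000)

for _ in range(10):
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model)                      # update EMA weights after each optimizer step

ema_state = ema.state_dict()             # EMA weights for checkpointing or evaluation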
20 changes: 18 additions & 2 deletions apps/Castor/train.py
@@ -42,6 +42,8 @@
    get_world_size,
    get_local_rank,
    parallelize_model,
    apply_activation_checkpointing,
    apply_compile,
    setup_env,
    setup_torch_distributed,
    requeue_slurm_job,
@@ -68,12 +70,13 @@
    tp_parallelize,
    get_no_recompute_ops,
)

from apps.Castor.modules.ema import EMA, EMAArgs
from apps.main.utils.cal_flops import get_num_flop_per_token

logger = logging.getLogger()



@dataclass
class TrainArgs:

@@ -104,7 +107,7 @@ class TrainArgs:
    profiling: ProfilerArgs = field(default_factory=ProfilerArgs)
    logging: LoggingArgs = field(default_factory=LoggingArgs)
    scheduler: SchedulerArgs = field(default_factory=SchedulerArgs)

    ema: EMAArgs = field(default_factory=EMAArgs)
    # If set to None, eval is run locally; otherwise it launches a new job with the given number of GPUs
    async_eval_gpus: Optional[int] = None
    eval: Optional[Any] = None
@@ -277,6 +280,7 @@ def train(args: TrainArgs):

    model = Castor(args.model)
    logger.info("Model is built !")
    ema = EMA(
        model,
        decay=args.ema.decay,
        warmup_steps=args.ema.warmup_steps,
        update_buffers=args.ema.update_buffers,
    )

    model_param_count = get_num_params(model)

@@ -291,6 +295,8 @@
        tp_parallelize=tp_parallelize,
        no_recompute_ops=get_no_recompute_ops(),
    )
    model = apply_activation_checkpointing(model, args.distributed)
    model = apply_compile(model, args.distributed)
    model = model.to(device="cuda")

    check_model_value_range(model, range=10.0, std=1.0)
@@ -299,6 +305,15 @@

    logger.info(f"Model size: {model_param_count:,} total parameters")

    ema.ema_model = parallelize_model(
        ema.ema_model,
        world_mesh,
        args.model,
        args.distributed,
        fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
    )
    ema.ema_model = ema.ema_model.to(device="cuda")

    gpu_memory_monitor = GPUMemoryMonitor("cuda")
    logger.info(
        f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
@@ -414,6 +429,7 @@ def train(args: TrainArgs):
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        ema.step(model)
        train_state.step += 1

        # updates the scale for next iteration
12 changes: 8 additions & 4 deletions lingua/distributed.py
@@ -34,9 +34,6 @@
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

# for no recompute ops
import xformers.ops

from lingua.float8 import convert_linears_to_fp8

logger = logging.getLogger()
@@ -466,9 +463,13 @@ def parallelize_model(
        )

        model = fully_shard(model, **fsdp_config, reshard_after_forward=True)
        return model
    else:
        raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")



def apply_activation_checkpointing(model, distributed_args: DistributedArgs):
    if distributed_args.selective_activation_checkpointing:
        non_reentrant_wrapper = partial(
            checkpoint_wrapper,
@@ -479,11 +480,14 @@ def parallelize_model(
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=lambda submodule: submodule in model.get_checkpointing_wrap_module_list(),
        )
    return model


def apply_compile(model, distributed_args: DistributedArgs):
    if distributed_args.compile:
        torch._dynamo.config.cache_size_limit = (
            distributed_args.compile_cache_size_limit
        )
        model = torch.compile(model)

    return model
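Taken together with the train.py changes above, parallelize_model now returns right after sharding, and activation checkpointing and compilation are applied as separate passes. A rough sketch of the intended call order (names follow the diff above; this is not a standalone script):

model = parallelize_model(
    model,
    world_mesh,
    args.model,
    args.distributed,
    fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
    tp_parallelize=tp_parallelize,
    no_recompute_ops=get_no_recompute_ops(),
)
model = apply_activation_checkpointing(model, args.distributed)  # no-op unless selective_activation_checkpointing is set
model = apply_compile(model, args.distributed)                   # no-op unless compile is set
model = model.to(device="cuda")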