From 19841529bce6fd378d8996b788efffc3d3af3e8a Mon Sep 17 00:00:00 2001
From: mczhuge
Date: Sat, 12 Apr 2025 18:21:20 +0000
Subject: [PATCH 1/2] review the model; will continue reviewing and learning,
 and help improve next

---
 apps/Castor/model.py | 27 ++++++++++++++++-----------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/apps/Castor/model.py b/apps/Castor/model.py
index 5fc7014..c5a88b8 100644
--- a/apps/Castor/model.py
+++ b/apps/Castor/model.py
@@ -9,9 +9,9 @@
 import torch.nn.functional as F
 from torch import nn

-from .modules.schedulers import RectifiedFlow, SchedulerArgs
+from .modules.schedulers import SchedulerArgs, RectifiedFlow # TODO (mingchen review comment): could unify as `create_scheduler`
 from .modules.text_encoder import TextEncoderArgs, create_text_encoder
-from .modules.transformer import DiffusionTransformer, TransformerArgs
+from .modules.transformer import TransformerArgs, DiffusionTransformer # TODO (mingchen review comment): could unify as `create_dit`
 from .modules.vae import VideoVAEArgs, create_vae

 logger = logging.getLogger()
@@ -19,12 +19,12 @@

 @dataclass
 class ModelArgs:
-    diffusion_model: TransformerArgs = field(default_factory=TransformerArgs)
-    with_vae: bool = False
-    vae_args: VideoVAEArgs = field(default_factory=VideoVAEArgs)
-    pre_trained_weight: Optional[str] = None
-    text_encoder: TextEncoderArgs = field(default_factory=TextEncoderArgs)
+    diffusion_model: TransformerArgs = field(default_factory=TransformerArgs) # TODO (mingchen review comment): the name `diffusion_model` and the type `TransformerArgs` are not aligned, but we can keep it as is.
     scheduler: SchedulerArgs = field(default_factory=SchedulerArgs)
+    text_encoder: TextEncoderArgs = field(default_factory=TextEncoderArgs)
+    vae_args: VideoVAEArgs = field(default_factory=VideoVAEArgs) # TODO (mingchen review comment): should be named `vae`
+    with_vae: bool = False
+    pre_trained_weight: Optional[str] = None # TODO (mingchen review comment): this is confusing because dit, text_encoder, and vae all have pre-trained weights.

     text_cfg_ratio: float = 0.1

@@ -45,13 +45,13 @@ def forward(self, batch: dict[str:any]) -> dict[str:any]:
         if hasattr(self, "compressor"):
             if isinstance(batch["image"], list):
                 batch["latent_code"] = [
-                    self.compressor.encode(img[None])[0] for img in batch["image"]
+                    self.compressor.encode(img[None])[0] for img in batch["image"] # TODO (mingchen review comment): this is inconsistent with l51 below. Would it be more consistent, and possibly faster, to write it as self.compressor.encode(torch.stack([img for img in batch["image"]]))? (see review note 1 below)
                 ]
             else:
                 batch["latent_code"] = self.compressor.encode(batch["image"])

         if "text_embedding" not in batch:
-            batch["text_embedding"], batch["attention_mask"] = self.text_encoder(batch)
+            batch["text_embedding"], batch["attention_mask"] = self.text_encoder(batch) # TODO: written this way, text_encoder.py becomes less generic (inflexible, case-specific handling should be resolved in the top-level model.py, not in the lower-level module). E.g. feed self.text_encoder(batch['caption']) directly instead of the whole batch; handling the input here is more direct (see review note 2 below)
         conditional_signal, conditional_mask = (
             batch["text_embedding"],
             batch["attention_mask"],
@@ -67,7 +67,7 @@
         latent_code = batch["latent_code"]
         noised_x, t, target = self.scheduler.sample_noised_input(latent_code)

-        output = self.diffusion_transformer(
+        output = self.diffusion_transformer(
             x=noised_x,
             time_steps=t,
             condition=conditional_signal,
@@ -91,8 +91,10 @@ def set_eval(self):
         self.diffusion_transformer.eval()

     def init_weights(self, args: ModelArgs):
+
+        # TODO (mingchen review comment): I suspect the current pre_trained_weight is already obsolete?
         if args.pre_trained_weight:
-            args.diffusion_model.pre_trained_path = None
+            args.diffusion_model.pre_trained_path = None # TODO (mingchen review comment): this reads confusingly: if args has pre_trained_weight, setting the original pre_trained_path = None does not seem logically correct (see review note 3 below).
             self.diffusion_transformer.init_weights(args=args.diffusion_model)
             logger.info(f"Loading pre-trained weights from {args.pre_trained_weight}")
             pre_trained_state_dict = torch.load(args.pre_trained_weight)
@@ -115,6 +117,9 @@ def get_no_recompute_ops():

 # Optional and only used for fully shard options (fsdp) is choose. Highly recommanded for large models
 def build_fsdp_grouping_plan(model_args: ModelArgs):
+
+    # TODO (mingchen review comment): this needs confirmation, though it may not be that important. In my view, DiT and TextEncoder need fsdp the most; currently text_encoder is not considered.
+
     group_plan: Tuple[int, bool] = []
     if model_args.with_vae:
         for i in range(4):  # Specific for Hunyuan's VAE

From e6395021e80645463f53240d6a4a3344745ae7d7 Mon Sep 17 00:00:00 2001
From: mczhuge
Date: Mon, 14 Apr 2025 13:59:59 +0000
Subject: [PATCH 2/2] review train.py

---
 apps/Castor/train.py | 28 ++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/apps/Castor/train.py b/apps/Castor/train.py
index cf24514..f27a801 100644
--- a/apps/Castor/train.py
+++ b/apps/Castor/train.py
@@ -26,6 +26,7 @@
 import torch.distributed
 import wandb
 import xformers.profiler
+
 from apps.Castor.model import (Castor, ModelArgs, build_fsdp_grouping_plan,
                                get_no_recompute_ops, tp_parallelize)
 from apps.main.data import AutoDataLoader, DataArgs
@@ -33,6 +34,7 @@
 from apps.main.utils.cal_flops import get_num_flop_per_token
 from apps.main.utils.dict_tensor_data_load import DictTensorBatchIterator
 from apps.main.utils.sampler import StatefulDistributedSampler
+
 from lingua.args import dataclass_from_dict, dump_config, flatten_dict
 from lingua.checkpoint import (CheckpointArgs, CheckpointManager,
                                load_from_checkpoint)
@@ -92,8 +94,8 @@ class TrainArgs:

 @dataclass
 class TrainState(Stateful):
-    step: int  # Nb of steps taken by the optimizer
-    acc_step: int  # Nb of accumulation steps done since last optimizer step
+    step: int  # Number of steps taken by the optimizer
+    acc_step: int  # Number of accumulation steps done since last optimizer step
     scheduler: lr_scheduler.LambdaLR
     sampler: StatefulDistributedSampler
@@ -119,11 +121,9 @@ def load_state_dict(self, state_dict):

 def validate_train_args(args: TrainArgs):
-    # assert args.dump_dir, "Dump dir not set" # Mingchen: no need any more
     # Minchen: generate the dump dir according to the config
     if not args.dump_dir:
-        # args.dump_dir = f"/mnt/data/dump/{args.name}"
         args.dump_dir = str(Path(args.output_dir) / f"{args.name}")
         logger.info(f"Dump dir set to {args.dump_dir}")
@@ -139,7 +139,8 @@
         logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
         args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")

-    # TODO: Mingchen: here need to support multiple source later as in the original lingua codebase
+    # (Deprecated): Mingchen: need to support multiple sources here later, as in the original lingua codebase
+    # TODO (New comment): now we store the data in the same cloud storage bucket. But it seems this is not used here?
     for data_args in args.data:
         if data_args.use:
             if data_args.source == "local" and not os.path.exists(data_args.root_dir):
@@ -255,6 +256,7 @@ def train(args: TrainArgs):
     torch.manual_seed(args.seed)

     logger.info("Building model")
+    #with torch.device("meta"): # TODO (Mingchen): double-check whether we need "meta", because the original lingua has it (see review note 4 below)
     model = Castor(args.model)
     logger.info("Model is built !")
@@ -271,12 +273,11 @@
         tp_parallelize=tp_parallelize,
         no_recompute_ops=get_no_recompute_ops(),
     )
-    model = model.to(device="cuda")
-    check_model_value_range(model, range=10.0, std=1.0)
-
     # log model size
+    model = model.to(device="cuda")
+    check_model_value_range(model, range=10.0, std=1.0)
     logger.info(f"Model size: {model_param_count:,} total parameters")

     gpu_memory_monitor = GPUMemoryMonitor("cuda")
@@ -309,8 +310,6 @@
     checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
     checkpoint.load(model, optimizer, train_state, world_mesh)

-    # Either load from latest checkpoint or start from scratch
-
     gc.disable()

     # train loop
@@ -350,13 +349,16 @@
                         batch, active_data[0].dataloader.batch_size
                     )
                    batch = next(parquet_iterator)
+
                if "_id" in batch:
                    failure_rate = batch["_id"].count("-1") / len(batch["_id"])
+
                if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
                    logger.info("garbage collection")
                    # we do garbage collection manually otherwise different processes
                    # run the GC at different times so they slow down the whole pipeline
                    gc.collect()
+
                if "latent_code" in batch:
                    if isinstance(batch["latent_code"], list):
                        batch["latent_code"] = [
@@ -368,6 +370,7 @@
                    else:
                        batch["latent_code"] = batch["latent_code"].cuda()
                    nwords_since_last_log += batch["latent_code"].numel()
+
                elif "image" in batch:
                    if isinstance(batch["image"], list):
                        batch["image"] = [
@@ -379,8 +382,11 @@
                    else:
                        batch["image"] = batch["image"].to(device="cuda")
                    nwords_since_last_log += batch["image"].numel()
+
                else:
                    raise ValueError("No image or latent code in batch")
+
+
                data_load_time = round(timer() - data_load_start, 4)

                # forward
@@ -397,6 +403,8 @@
                # For logging we undo that scaling
                loss = loss.detach() * args.grad_acc_steps

+                # TODO (Mingchen): do we need to clip on every step here? The original lingua code only clips when step=0 (see review note 5 below)
+                # https://github.com/facebookresearch/lingua/blob/main/apps/main/train.py
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=args.optim.clip, foreach=True
                )
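
Review notes (code sketches for the TODO comments above; minimal sketches, not tested against the Castor codebase):

Review note 1 (batched VAE encode, model.py forward, l45/l51). A sketch of the batched
encode suggested in the TODO: stack the images and call encode() once when the shapes
match, and fall back to per-image encoding otherwise. It assumes, as the non-list
branch already does, that compressor.encode accepts a batched tensor and returns
latents with the batch dimension first; the helper name is made up.

    import torch

    def encode_images(compressor, images):
        """Encode a list of images in one batched call when shapes allow it."""
        if isinstance(images, list):
            if len({tuple(img.shape) for img in images}) == 1:
                # every image has the same shape: stack and encode once,
                # consistent with the batched branch and likely faster
                return list(compressor.encode(torch.stack(images)))
            # ragged shapes: keep the original per-image fallback
            return [compressor.encode(img[None])[0] for img in images]
        return compressor.encode(images)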
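
Review note 2 (text_encoder interface, model.py forward). A sketch of the suggestion
to keep text_encoder.py generic: unpack the batch at the model level and pass only
the captions down. The 'caption' key comes from the TODO itself; the (embedding, mask)
return pair is an assumption for illustration, not the actual Castor API.

    def encode_text(text_encoder, batch):
        """Unpack the batch in model.py and pass only what the module needs."""
        text_embedding, attention_mask = text_encoder(batch["caption"])
        batch["text_embedding"] = text_embedding
        batch["attention_mask"] = attention_mask
        return batch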
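
Review note 3 (pre_trained_weight, model.py init_weights). A sketch of one reading of
the current logic with the intent spelled out: setting pre_trained_path = None looks
like an attempt to avoid loading the transformer's own weights twice before the full
checkpoint is applied. The load_state_dict call, the strict=False flag, and the helper
name are assumptions; whether pre_trained_weight is still used at all needs confirming.

    import torch

    def init_model_weights(model, args):
        if args.pre_trained_weight:
            # disable the transformer's own pre-trained path so weights are not
            # loaded twice, then load the full-model checkpoint on top
            args.diffusion_model.pre_trained_path = None
            model.diffusion_transformer.init_weights(args=args.diffusion_model)
            state_dict = torch.load(args.pre_trained_weight, map_location="cpu")
            model.load_state_dict(state_dict, strict=False)
        else:
            model.diffusion_transformer.init_weights(args=args.diffusion_model)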
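
Review note 4 (meta-device build, train.py). A sketch of the "meta" pattern the
commented-out line refers to: construct the model on the meta device so parameters
allocate no real memory, then materialize on GPU and run the model's own init.
Whether every Castor submodule supports meta-device construction still needs to be
checked; the helper name is made up.

    import torch

    from apps.Castor.model import Castor, ModelArgs

    def build_model_on_meta(model_args: ModelArgs) -> Castor:
        with torch.device("meta"):
            model = Castor(model_args)           # no real parameter storage yet
        model = model.to_empty(device="cuda")    # allocate uninitialized storage on GPU
        model.init_weights(model_args)           # then initialize for real
        return model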
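
Review note 5 (gradient clipping, train.py). A sketch of clipping once per optimizer
step instead of on every accumulation micro-step, which is what the TODO with the
lingua link asks about. The helper name and the way acc_step is counted here are
assumptions; they only illustrate the control flow.

    import torch

    def clip_and_step(model, optimizer, acc_step, grad_acc_steps, max_norm):
        """Clip and apply the optimizer only when the accumulation window is full."""
        if (acc_step + 1) % grad_acc_steps == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=max_norm, foreach=True
            )
            optimizer.step()
            optimizer.zero_grad()
            return grad_norm
        return None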