8 changes: 4 additions & 4 deletions apps/Castor/generate.py
@@ -48,10 +48,10 @@ def __init__(
self.model = model
self.vae_scale_factor = cfg.vae_scale_factor
self.resolution = int(cfg.resolution // self.vae_scale_factor)
self.cond_resolution = int(cfg.cond_resolution // self.vae_scale_factor)
self.cond_resolution = int(cfg.cond_resolution // self.vae_scale_factor) # TODO (Mingchen review comment): deprecate this in this version?
self.device = cfg.device
self.guidance_scale = cfg.guidance_scale
self.show_progress = cfg.show_progress
self.show_progress = cfg.show_progress # TODO (Mingchen review comment): deprecate this in this version?
self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype]
self.in_channel = model.diffusion_transformer.in_channels
self.sigma = cfg.sigma
Expand All @@ -61,7 +61,7 @@ def __init__(

def prepare_latent(self, context, device):
bsz = len(context["caption"])
latent_size = (bsz, self.in_channel, self.resolution, self.resolution)
latent_size = (bsz, self.in_channel, self.resolution, self.resolution) # TODO (Mingchen review comment): in the dynamic-resolution case, is the latent still shaped like this?
latents = randn_tensor(latent_size, device=device, dtype=self.dtype)
return latents
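
For the dynamic-resolution question in the comment above, a minimal sketch of one way per-sample latents could be prepared; the `resolutions` entry in `context` is a hypothetical field, not part of this PR:

# Hypothetical sketch (not in this PR): per-sample latent shapes when resolution is dynamic.
# Assumes context["resolutions"] carries one (h, w) pair per sample, already in latent space.
def prepare_latent_dynamic(self, context, device):
    latents = []
    for h, w in context["resolutions"]:
        size = (1, self.in_channel, h, w)
        latents.append(randn_tensor(size, device=device, dtype=self.dtype))
    return latents  # a list of per-sample latents instead of one stacked tensor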

Expand All @@ -71,7 +71,7 @@ def return_seq_len(self):
@torch.no_grad()
def forward(self, context: Dict[str, Any]) -> torch.Tensor:
cur_device = next(self.model.parameters()).device
image_seq_len = self.return_seq_len()
image_seq_len = self.return_seq_len()
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
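For context on the `mu` shift above: in Flux-style pipelines this helper is typically a linear interpolation over the latent sequence length; a sketch of the usual form (this repo's exact implementation may differ):

# Sketch of the usual calculate_shift (default constants as in the diffusers Flux pipeline):
# interpolate the timestep shift between base_shift and max_shift as the sequence length grows.
def calculate_shift_sketch(image_seq_len, base_seq_len=256, max_seq_len=4096,
                           base_shift=0.5, max_shift=1.16):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b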
24 changes: 14 additions & 10 deletions apps/Castor/model.py
@@ -9,22 +9,22 @@
import torch.nn.functional as F
from torch import nn

from .modules.schedulers import RectifiedFlow, SchedulerArgs
from .modules.schedulers import SchedulerArgs, RectifiedFlow # TODO (mingchen review comment): could unify as `create_scheduler`
from .modules.text_encoder import TextEncoderArgs, create_text_encoder
from .modules.transformer import DiffusionTransformer, TransformerArgs
from .modules.transformer import TransformerArgs, DiffusionTransformer # TODO (mingchen review comment): could unify as `create_dit`
from .modules.vae import VideoVAEArgs, create_vae
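A sketch of the unification the two comments above suggest, mirroring the existing `create_text_encoder` / `create_vae` factories; the factory names and constructor signatures are assumptions, not this repo's code:

# Sketch (assumed names and signatures): factories so model.py does not instantiate concrete classes directly.
def create_scheduler(args: SchedulerArgs) -> RectifiedFlow:
    return RectifiedFlow(args)

def create_dit(args: TransformerArgs) -> DiffusionTransformer:
    return DiffusionTransformer(args)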

logger = logging.getLogger()


@dataclass
class ModelArgs:
diffusion_model: TransformerArgs = field(default_factory=TransformerArgs)
with_vae: bool = False
vae_args: VideoVAEArgs = field(default_factory=VideoVAEArgs)
pre_trained_weight: Optional[str] = None
text_encoder: TextEncoderArgs = field(default_factory=TextEncoderArgs)
diffusion_model: TransformerArgs = field(default_factory=TransformerArgs) # TODO (mingchen review comment): the name `diffusion_model` is not aligned with the type `TransformerArgs`, but we can keep it as is.
scheduler: SchedulerArgs = field(default_factory=SchedulerArgs)
text_encoder: TextEncoderArgs = field(default_factory=TextEncoderArgs)
vae_args: VideoVAEArgs = field(default_factory=VideoVAEArgs) # TODO (mingchen review comment): needs to be named `vae`
with_vae: bool = False
pre_trained_weight: Optional[str] = None # TODO (mingchen review comment): this is confusing because the dit, text_encoder, and vae all have pre-trained weights.
text_cfg_ratio: float = 0.1
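
The naming comments above could be addressed along these lines; a sketch only, with field names as suggestions rather than the PR's actual code:

# Sketch (suggested naming, not this PR's code): one args block per module, and the
# pre-trained weight field scoped to the DiT so it is no longer ambiguous.
@dataclass
class ModelArgsSuggestion:
    diffusion_model: TransformerArgs = field(default_factory=TransformerArgs)
    scheduler: SchedulerArgs = field(default_factory=SchedulerArgs)
    text_encoder: TextEncoderArgs = field(default_factory=TextEncoderArgs)
    vae: VideoVAEArgs = field(default_factory=VideoVAEArgs)  # renamed from vae_args
    with_vae: bool = False
    dit_pre_trained_weight: Optional[str] = None
    text_cfg_ratio: float = 0.1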


@@ -45,13 +45,13 @@ def forward(self, batch: dict[str:any]) -> dict[str:any]:
if hasattr(self, "compressor"):
if isinstance(batch["image"], list):
batch["latent_code"] = [
self.compressor.encode(img[None])[0] for img in batch["image"]
self.compressor.encode(img[None])[0] for img in batch["image"] # TODO (mingchen review comment): this conflicts with l51 below. Would it be more consistent, and possibly faster, to write it as: self.compressor.encode(torch.stack([img for img in batch["image"]]))
]
else:
batch["latent_code"] = self.compressor.encode(batch["image"])

if "text_embedding" not in batch:
batch["text_embedding"], batch["attention_mask"] = self.text_encoder(batch)
batch["text_embedding"], batch["attention_mask"] = self.text_encoder(batch) # TODO: 这样的写法会导致text_encoder.py不太通用了(我们应该把不灵活的判断东西都在更顶层model.py解决,而不是下面的module)。如: 直接送入self.text_encoder(batch['caption'])而不是整个batch,在这里处理输入更直接
conditional_signal, conditional_mask = (
batch["text_embedding"],
batch["attention_mask"],
@@ -67,7 +67,7 @@ def forward(self, batch: dict[str:any]) -> dict[str:any]:

latent_code = batch["latent_code"]
noised_x, t, target = self.scheduler.sample_noised_input(latent_code)
output = self.diffusion_transformer(
output = self.diffusion_transformer(
x=noised_x,
time_steps=t,
condition=conditional_signal,
@@ -91,6 +91,7 @@ def set_eval(self):
self.diffusion_transformer.eval()

def init_weights(self, args: ModelArgs):

if args.pre_trained_weight:
self.diffusion_transformer.init_weights()
logger.info(f"Loading pre-trained weights from {args.pre_trained_weight}")
@@ -114,6 +115,9 @@ def get_no_recompute_ops():

# Optional and only used when the fully-sharded option (fsdp) is chosen. Highly recommended for large models
def build_fsdp_grouping_plan(model_args: ModelArgs):

# TODO (mingchen review comment): this needs to be confirmed, although it may not matter that much. In my view DiT and the TextEncoder need fsdp the most; currently the text_encoder is not covered.

group_plan: Tuple[int, bool] = []
if model_args.with_vae:
for i in range(4): # Specific for Hunyuan's VAE
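Regarding the fsdp comment above, a sketch of what adding the text encoder to the grouping plan might look like; the module paths and the n_layers field are assumptions about the text encoder's structure, not this repo's code:

# Sketch (assumed module paths / fields): group the text encoder's blocks as well,
# since DiT and the text encoder are the memory-heavy modules under fsdp.
for i in range(model_args.text_encoder.n_layers):
    group_plan.append((f"text_encoder.layers.{i}", False))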
28 changes: 18 additions & 10 deletions apps/Castor/train.py
@@ -26,13 +26,15 @@
import torch.distributed
import wandb
import xformers.profiler

from apps.Castor.model import (Castor, ModelArgs, build_fsdp_grouping_plan,
get_no_recompute_ops, tp_parallelize)
from apps.main.data import AutoDataLoader, DataArgs
from apps.main.modules.schedulers import SchedulerArgs
from apps.main.utils.cal_flops import get_num_flop_per_token
from apps.main.utils.dict_tensor_data_load import DictTensorBatchIterator
from apps.main.utils.sampler import StatefulDistributedSampler

from lingua.args import dataclass_from_dict, dump_config, flatten_dict
from lingua.checkpoint import (CheckpointArgs, CheckpointManager,
load_from_checkpoint)
@@ -92,8 +94,8 @@ class TrainArgs:

@dataclass
class TrainState(Stateful):
step: int # Nb of steps taken by the optimizer
acc_step: int # Nb of accumulation steps done since last optimizer step
step: int # Number of steps taken by the optimizer
acc_step: int # Number of accumulation steps done since last optimizer step
scheduler: lr_scheduler.LambdaLR
sampler: StatefulDistributedSampler

@@ -119,11 +121,9 @@ def load_state_dict(self, state_dict):


def validate_train_args(args: TrainArgs):
# assert args.dump_dir, "Dump dir not set" # Mingchen: no need any more

# Mingchen: generate the dump dir according to the config
if not args.dump_dir:
# args.dump_dir = f"/mnt/data/dump/{args.name}"
args.dump_dir = str(Path(args.output_dir) / f"{args.name}")

logger.info(f"Dump dir set to {args.dump_dir}")
@@ -139,7 +139,8 @@ def validate_train_args(args: TrainArgs):
logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")

# TODO: Mingchen: here need to support multiple source later as in the original lingua codebase
# (Deprecated) Mingchen: we need to support multiple sources later, as in the original lingua codebase
# TODO (New comment): we now store the data in the same cloud storage bucket. But it seems this is not used here?
for data_args in args.data:
if data_args.use:
if data_args.source == "local" and not os.path.exists(data_args.root_dir):
@@ -255,6 +256,7 @@ def train(args: TrainArgs):
torch.manual_seed(args.seed)
logger.info("Building model")

# with torch.device("meta"): # TODO (Mingchen): double-check whether we need "meta", because the original lingua has it
model = Castor(args.model)
logger.info("Model is built !")

@@ -271,12 +273,11 @@
tp_parallelize=tp_parallelize,
no_recompute_ops=get_no_recompute_ops(),
)
model = model.to(device="cuda")

check_model_value_range(model, range=10.0, std=1.0)

# log model size
model = model.to(device="cuda")

check_model_value_range(model, range=10.0, std=1.0)
logger.info(f"Model size: {model_param_count:,} total parameters")

gpu_memory_monitor = GPUMemoryMonitor("cuda")
@@ -309,8 +310,6 @@

checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
checkpoint.load(model, optimizer, train_state, world_mesh)
# Either load from latest checkpoint or start from scratch

gc.disable()

# train loop
@@ -350,13 +349,16 @@
batch, active_data[0].dataloader.batch_size
)
batch = next(parquet_iterator)

if "_id" in batch:
failure_rate = batch["_id"].count("-1") / len(batch["_id"])

if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
logger.info("garbage collection")
# we do garbage collection manually otherwise different processes
# run the GC at different times so they slow down the whole pipeline
gc.collect()

if "latent_code" in batch:
if isinstance(batch["latent_code"], list):
batch["latent_code"] = [
@@ -368,6 +370,7 @@
else:
batch["latent_code"] = batch["latent_code"].cuda()
nwords_since_last_log += batch["latent_code"].numel()

elif "image" in batch:
if isinstance(batch["image"], list):
batch["image"] = [
@@ -379,8 +382,11 @@
else:
batch["image"] = batch["image"].to(device="cuda")
nwords_since_last_log += batch["image"].numel()

else:
raise ValueError("No image or latent code in batch")


data_load_time = round(timer() - data_load_start, 4)

# forward
@@ -397,6 +403,8 @@
# For logging we undo that scaling
loss = loss.detach() * args.grad_acc_steps

# TODO (Mingchen): do we need to clip on every micro-step here? The original lingua code only clips when step=0:
# https://github.com/facebookresearch/lingua/blob/main/apps/main/train.py
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=args.optim.clip, foreach=True
)
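
A sketch of the alternative raised in the comment above (clip only when the accumulated gradient is about to be applied); the surrounding accumulation and optimizer handling is assumed, not copied from this repo:

# Sketch (assumed surrounding loop): clip once per optimizer step, i.e. only on the
# micro-step where acc_step wraps to 0, as in the original lingua train loop.
if train_state.acc_step == 0:
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=args.optim.clip, foreach=True
    )
    optimizer.step()
    optimizer.zero_grad()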