From 12f9da8bb7f2d839fd146a3d3ea6d5e826ed581e Mon Sep 17 00:00:00 2001
From: Murali Nandan
Date: Tue, 11 Mar 2025 17:13:15 +0000
Subject: [PATCH 1/8] fast latent extraction with Nemo Curator

---
 apps/offline_inf_v2/configs/inference.yaml  |  18 ++
 apps/offline_inf_v2/data.py                 |  12 +
 apps/offline_inf_v2/inference.py            |  53 ++++
 apps/offline_inf_v2/model.py                |  39 +++
 apps/offline_inf_v2/vae_latent_extractor.py | 317 ++++++++++++++++++++
 5 files changed, 439 insertions(+)
 create mode 100644 apps/offline_inf_v2/configs/inference.yaml
 create mode 100644 apps/offline_inf_v2/data.py
 create mode 100644 apps/offline_inf_v2/inference.py
 create mode 100644 apps/offline_inf_v2/model.py
 create mode 100644 apps/offline_inf_v2/vae_latent_extractor.py

diff --git a/apps/offline_inf_v2/configs/inference.yaml b/apps/offline_inf_v2/configs/inference.yaml
new file mode 100644
index 00000000..a14eb358
--- /dev/null
+++ b/apps/offline_inf_v2/configs/inference.yaml
@@ -0,0 +1,18 @@
+name: inference
+
+model:
+  gen_vae:
+    model_name: Hunyuan
+    pretrained_model_name_or_path: '/jfs/checkpoints/models--tencent--HunyuanVideo/snapshots/2a15b5574ee77888e51ae6f593b2ceed8ce813e5/vae'
+    enable_tiling: false
+    enable_slicing: false
+  autocast: true
+
+data:
+  data_path: /mnt/pollux/nemo/sample
+  output_path: /mnt/pollux/nemo/sample_latents_256/
+  id_col: key
+  batch_size: 256
+  num_threads_per_worker: 16
+  image_size: 256
+  image_latent_column: image_latent_256
diff --git a/apps/offline_inf_v2/data.py b/apps/offline_inf_v2/data.py
new file mode 100644
index 00000000..628b5d73
--- /dev/null
+++ b/apps/offline_inf_v2/data.py
@@ -0,0 +1,12 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+@dataclass
+class DataArgs:
+    data_path: str = field(default="/mnt/pollux/nemo/sample/")
+    output_path: Optional[str] = field(default=None)
+    id_col: str = field(default="key")
+    batch_size: int = field(default=1)
+    num_threads_per_worker: int = field(default=4)
+    image_size: int = field(default=256)
+    image_latent_column: str = field(default="image_latent")
diff --git a/apps/offline_inf_v2/inference.py b/apps/offline_inf_v2/inference.py
new file mode 100644
index 00000000..4132d48b
--- /dev/null
+++ b/apps/offline_inf_v2/inference.py
@@ -0,0 +1,53 @@
+"""
+python -m apps.offline_inf_v2.inference.py config=apps/offline_inf_v2/configs/inference.yaml
+"""
+
+from dataclasses import dataclass, field
+from omegaconf import OmegaConf
+
+from nemo_curator import get_client
+from nemo_curator.datasets import ImageTextPairDataset
+
+from apps.offline_inf_v2.data import DataArgs
+from apps.offline_inf_v2.model import ModelArgs
+from apps.offline_inf_v2.vae_latent_extractor import VAELatentExtractor
+
+
+@dataclass
+class InferenceArgs:
+    name: str = field(default="inference")
+    model: ModelArgs = field(default_factory=ModelArgs)
+    data: DataArgs = field(default_factory=DataArgs)
+
+
+def main():
+    cli_args = OmegaConf.from_cli()
+    file_cfg = OmegaConf.load(cli_args.config)
+    # We remove 'config' attribute from config as the underlying DataClass does not have it
+    del cli_args.config
+
+    default_cfg = OmegaConf.structured(InferenceArgs())
+    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
+    cfg = OmegaConf.to_object(cfg)
+
+    client = get_client(cluster_type="gpu")
+
+    dataset = ImageTextPairDataset.from_webdataset(
+        path=cfg.data.data_path, id_col=cfg.data.id_col
+    )
+
+    latent_extractor = VAELatentExtractor(
+        model_args=cfg.model,
+        data_args=cfg.data,
+    )
+
+    dataset_with_latents = latent_extractor(dataset)
+
+    # Metadata will have a new column named "image_latent"
+    dataset_with_latents.save_metadata(
+        cfg.data.output_path, columns=[cfg.data.id_col, "doc_id", cfg.data.image_latent_column]
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/apps/offline_inf_v2/model.py b/apps/offline_inf_v2/model.py
new file mode 100644
index 00000000..f4cd6df5
--- /dev/null
+++ b/apps/offline_inf_v2/model.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from dataclasses import dataclass, field
+import logging
+import torch
+from torch import nn
+from apps.main.modules.vae import build_vae, LatentVideoVAEArgs
+
+logger = logging.getLogger()
+
+
+@dataclass
+class ModelArgs:
+    gen_vae: LatentVideoVAEArgs = field(default_factory=LatentVideoVAEArgs)
+    autocast: bool = field(default=False)
+    use_compile: bool = field(default=False)
+
+class VAE(nn.Module):
+    """
+    VAE Model
+    """
+
+    VERSION: str = "v1.0"
+
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.vae_compressor = build_vae(args.gen_vae)
+
+    @torch.no_grad()
+    def forward(
+        self, image: torch.Tensor
+    ) -> torch.Tensor:
+
+        # Process latent code
+        image = image.cuda()
+        latent_code = self.vae_compressor.encode(image)
+
+        return latent_code
diff --git a/apps/offline_inf_v2/vae_latent_extractor.py b/apps/offline_inf_v2/vae_latent_extractor.py
new file mode 100644
index 00000000..7a4ce879
--- /dev/null
+++ b/apps/offline_inf_v2/vae_latent_extractor.py
@@ -0,0 +1,317 @@
+import json
+from tqdm import tqdm
+from PIL import Image
+import cupy as cp
+from crossfit.backend.cudf.series import create_list_series_from_1d_or_2d_ar
+
+import nvidia.dali.fn as fn
+import nvidia.dali.types as types
+import torch
+from nvidia.dali import pipeline_def
+from nvidia.dali.plugin.pytorch import feed_ndarray
+
+from nemo_curator.datasets import ImageTextPairDataset
+from nemo_curator.utils.distributed_utils import load_object_on_worker
+
+from apps.offline_inf_v2.data import DataArgs
+from apps.offline_inf_v2.model import ModelArgs, VAE
+
+
+class VAELatentExtractor:
+    def __init__(
+        self,
+        model_args: ModelArgs,
+        data_args: DataArgs,
+        use_index_files: bool = False,
+    ) -> None:
+        """
+        Constructs the embedder.
+
+        Args:
+            model_args (ModelArgs): The arguments for the VAE.
+            data_args (DataArgs): The arguments for the data.
+            use_index_files (bool): If True, tries to find and use index files generated
+                by DALI at the same path as the tar file shards. The index files must be
+                generated by DALI's wds2idx tool. See https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/dataloading_webdataset.html#Creating-an-index
+                for more information. Each index file must be of the form "shard_id.idx"
+                where shard_id is the same integer as the corresponding tar file for the
+                data. The index files must be in the same folder as the tar files.
+        """
+        self.model_args = model_args
+        self.data_args = data_args
+        self.use_index_files = use_index_files
+
+        # torch_transforms = transforms.Compose(
+        #     [
+        #         transforms.RandomHorizontalFlip(),
+        #         transforms.ToTensor(),
+        #         transforms.Normalize(
+        #             mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True
+        #         ),
+        #     ]
+        # )
+        # self.dali_transforms = convert_transforms_to_dali(torch_transforms)
+
+    def load_dataset_shard(self, tar_path: str):
+        """
+        Loads a WebDataset tar shard using DALI.
+
+        Args:
+            tar_path (str): The path of the tar shard to load.
+
+        Returns:
+            Iterable: An iterator over the dataset. Each tar file
+                must have 3 files per record: a .jpg file, a .txt file,
+                and a .json file. The .jpg file must contain the image, the
+                .txt file must contain the associated caption, and the
+                .json must contain the metadata for the record (including
+                its ID). Images will be loaded using DALI.
+        """
+
+
+        def downsample_resize_image(image):
+            """
+            Center cropping implementation from ADM.
+            https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+            """
+            image = cp.asnumpy(image)
+            pil_image = Image.fromarray(image)
+            while min(*pil_image.size) >= 2 * self.data_args.image_size:
+                pil_image = pil_image.resize(
+                    tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+                )
+            return cp.asarray(pil_image)
+
+
+        # Create the DALI pipeline
+        @pipeline_def(
+            batch_size=self.data_args.batch_size,
+            num_threads=self.data_args.num_threads_per_worker,
+            device_id=0,
+        )
+        def webdataset_pipeline(_tar_path: str):
+            if self.use_index_files:
+                index_paths = [f"{_tar_path.rsplit('.', 1)[0]}.idx"]
+            else:
+                index_paths = []
+
+            images_raw, text, json = fn.readers.webdataset(
+                paths=_tar_path,
+                index_paths=index_paths,
+                ext=["jpg", "txt", "json"],
+                missing_component_behavior="error",
+            )
+            images = fn.decoders.image(images_raw, device="mixed", output_type=types.RGB)
+
+            # images = fn.python_function(
+            #     images,
+            #     function=downsample_resize_image,
+            # )
+            images = fn.resize(
+                images,
+                device="gpu",
+                resize_x=self.data_args.image_size,
+                resize_y=self.data_args.image_size,
+                mode="not_smaller",
+                interp_type=types.DALIInterpType.INTERP_CUBIC
+            )
+            images = fn.crop_mirror_normalize(
+                images,
+                device="gpu",
+                crop_h=self.data_args.image_size,
+                crop_w=self.data_args.image_size,
+                crop_pos_x=0.5,
+                crop_pos_y=0.5,
+                mirror=fn.random.coin_flip(probability=0.5),
+                dtype=types.DALIDataType.FLOAT,
+                mean=[0.5 * 255, 0.5 * 255, 0.5 * 255],
+                std=[0.5 * 255, 0.5 * 255, 0.5 * 255],
+                scale=1.0,
+            )
+
+            return images, text, json
+
+        pipeline = webdataset_pipeline(tar_path)
+        pipeline.build()
+
+        total_samples = pipeline.epoch_size()
+        total_samples = total_samples[list(total_samples.keys())[0]]
+
+        samples_completed = 0
+        while samples_completed < total_samples:
+            image, text, meta = pipeline.run()
+            image = image.as_tensor()
+
+            image_torch = torch.empty(image.shape(), dtype=torch.float32, device="cuda")
+            feed_ndarray(image, image_torch) # COPY !!!
+            image = image_torch
+
+            captions = [text.at(i).tostring().decode("utf-8") for i in range(len(text))]
+            metadata = [
+                json.loads(meta.at(i).tostring().decode("utf-8"))
+                for i in range(len(meta))
+            ]
+
+            remaining_samples = total_samples - samples_completed
+            if image.shape[0] >= remaining_samples:
+                image = image[:remaining_samples]
+                captions = captions[:remaining_samples]
+                metadata = metadata[:remaining_samples]
+
+            samples_completed += min(image.shape[0], remaining_samples)
+
+            yield image, metadata
+
+    def load_model(self, model_args, device="cuda"):
+        """
+        Loads the model used to generate image embeddings.
+
+        Args:
+            device (str): A PyTorch device identifier that specifies what GPU
+                to load the model on.
+
+        Returns:
+            Callable: A timm model loaded on the specified device.
+                The model's forward call may be augmented with torch.autocast()
+                or embedding normalization if specified in the constructor.
+ """ + model = VAE(model_args).eval().to(device) + model = self._configure_forward(model) + + return model + + def _configure_forward(self, model): + original_forward = model.forward + + def custom_forward(*args, **kwargs): + if self.model_args.autocast: + with torch.amp.autocast(device_type="cuda"): + image_features = original_forward(*args, **kwargs) + else: + image_features = original_forward(*args, **kwargs) + + # Inference can be done in lower precision, but cuDF can only handle fp32 + return image_features.to(torch.float32) + + model.forward = custom_forward + return model + + def _process_batch(self, model, batch): + """Helper method to process a batch with appropriate chunking""" + if batch.shape[0] > 16 and batch.shape[0] % 16 == 0: + # Process in chunks of 16 to avoid OOM + sub_batches = batch.chunk(batch.shape[0] // 16) + sub_latents = [] + for sub_batch in sub_batches: + sub_latent = model(sub_batch) + sub_latents.append(sub_latent.cpu()) + return torch.cat(sub_latents, dim=0) + else: + latents = model(batch) + return latents.cpu() # Move to CPU immediately + + def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: + """ + Generates image embeddings for all images in the dataset. + + Args: + dataset (ImageTextPairDataset): The dataset to create image embeddings for. + + Returns: + ImageTextPairDataset: A dataset with image embeddings and potentially + classifier scores. + """ + meta = dataset.metadata.dtypes.to_dict() + meta[self.data_args.image_latent_column] = "object" + + embedding_df = dataset.metadata.map_partitions( + self._run_inference, dataset.tar_files, dataset.id_col, meta=meta + ) + + return ImageTextPairDataset( + dataset.path, + metadata=embedding_df, + tar_files=dataset.tar_files, + id_col=dataset.id_col, + ) + + def _run_inference(self, partition, tar_paths, id_col, partition_info=None): + tar_path = tar_paths[partition_info["number"]] + device = "cuda" + + model = load_object_on_worker( + "model", + self.load_model, + {"model_args": self.model_args, "device": device}, + ) + + dataset = self.load_dataset_shard(tar_path) + final_image_latents = [] + image_ids = [] + samples_completed = 0 + progress_bar = tqdm( + total=len(partition), + desc=f"{tar_path} - Latent extraction with {self.model_args.gen_vae.model_name}", + ) + + # Process batches + with torch.no_grad(), torch.amp.autocast(device_type="cuda", enabled=self.model_args.autocast): + for batch, metadata in dataset: + image_latents = self._process_batch(model, batch) + + final_image_latents.append(image_latents) + image_ids.extend(m[id_col] for m in metadata) + + batch_size = len(image_latents) + samples_completed += batch_size + progress_bar.update(batch_size) + + # Clear CUDA cache less frequently + if samples_completed % (self.data_args.batch_size * 50) == 0: + torch.cuda.empty_cache() + progress_bar.close() + + if samples_completed != len(partition): + raise RuntimeError( + f"Mismatch in sample count for partition {partition_info['number']}. " + f"{len(partition)} samples found in the metadata, but {samples_completed} found in {tar_path}." 
+ ) + + # Process embeddings in memory-efficient way + return self._process_embeddings(partition, final_image_latents, image_ids) + + def _process_embeddings(self, partition, final_image_latents, image_ids): + """Process embeddings in a memory-efficient way""" + # Order the output of the shard + sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + + # Process in chunks to reduce memory usage + all_embeddings = torch.cat(final_image_latents, dim=0) + sorted_embeddings = all_embeddings[sorted_indices] + + # View the embeddings to be [N, 16*32*32] + sorted_embeddings = sorted_embeddings.view(sorted_embeddings.shape[0], -1) + + # Process in chunks to avoid OOM + chunk_size = 1000 # Adjust based on your GPU memory + concat_embedding_output = None + + for i in range(0, sorted_embeddings.shape[0], chunk_size): + end_idx = min(i + chunk_size, sorted_embeddings.shape[0]) + chunk = sorted_embeddings[i:end_idx].cuda() # Move chunk to GPU + chunk_cp = cp.asarray(chunk) # Convert to CuPy + + if concat_embedding_output is None: + concat_embedding_output = chunk_cp + else: + concat_embedding_output = cp.concatenate([concat_embedding_output, chunk_cp], axis=0) + + # Free GPU memory + del chunk + torch.cuda.empty_cache() + + partition[self.data_args.image_latent_column] = create_list_series_from_1d_or_2d_ar( + concat_embedding_output, index=partition.index + ) + + return partition From c9a0f3e112b6a6724f6b337f711a4a1f268dc708 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Tue, 11 Mar 2025 17:56:23 +0000 Subject: [PATCH 2/8] config --- .../{inference.yaml => inference_256.yaml} | 0 apps/offline_inf_v2/configs/inference_512.yaml | 18 ++++++++++++++++++ 2 files changed, 18 insertions(+) rename apps/offline_inf_v2/configs/{inference.yaml => inference_256.yaml} (100%) create mode 100644 apps/offline_inf_v2/configs/inference_512.yaml diff --git a/apps/offline_inf_v2/configs/inference.yaml b/apps/offline_inf_v2/configs/inference_256.yaml similarity index 100% rename from apps/offline_inf_v2/configs/inference.yaml rename to apps/offline_inf_v2/configs/inference_256.yaml diff --git a/apps/offline_inf_v2/configs/inference_512.yaml b/apps/offline_inf_v2/configs/inference_512.yaml new file mode 100644 index 00000000..248dba1b --- /dev/null +++ b/apps/offline_inf_v2/configs/inference_512.yaml @@ -0,0 +1,18 @@ +name: inference + +model: + gen_vae: + model_name: Hunyuan + pretrained_model_name_or_path: '/jfs/checkpoints/models--tencent--HunyuanVideo/snapshots/2a15b5574ee77888e51ae6f593b2ceed8ce813e5/vae' + enable_tiling: false + enable_slicing: false + autocast: true + +data: + data_path: /mnt/pollux/nemo/sample + output_path: /mnt/pollux/nemo/sample_latents_512/ + id_col: key + batch_size: 256 + num_threads_per_worker: 16 + image_size: 512 + image_latent_column: image_latent_512 From e546e3210fe5eb72116f967c31d9a8333ca220af Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Tue, 11 Mar 2025 17:56:45 +0000 Subject: [PATCH 3/8] assert cfg.data.output_path --- apps/offline_inf_v2/inference.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/apps/offline_inf_v2/inference.py b/apps/offline_inf_v2/inference.py index 4132d48b..aa837e38 100644 --- a/apps/offline_inf_v2/inference.py +++ b/apps/offline_inf_v2/inference.py @@ -1,4 +1,5 @@ """ +conda activate curator python -m apps.offline_inf_v2.inference.py config=apps/offline_inf_v2/configs/inference.yaml """ @@ -30,7 +31,9 @@ def main(): cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) cfg = OmegaConf.to_object(cfg) - 
client = get_client(cluster_type="gpu") + assert cfg.data.output_path is not None, f"Output path is required, otherwise the parquets in {cfg.data.data_path} will be overwritten" + + client = get_client(cluster_type="gpu", nvlink_only=True) dataset = ImageTextPairDataset.from_webdataset( path=cfg.data.data_path, id_col=cfg.data.id_col From 307eb9a2c1f79e16855e796f181b9047aac34908 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Wed, 12 Mar 2025 07:28:07 +0000 Subject: [PATCH 4/8] clean docstring --- apps/offline_inf_v2/vae_latent_extractor.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/apps/offline_inf_v2/vae_latent_extractor.py b/apps/offline_inf_v2/vae_latent_extractor.py index 7a4ce879..7565a975 100644 --- a/apps/offline_inf_v2/vae_latent_extractor.py +++ b/apps/offline_inf_v2/vae_latent_extractor.py @@ -171,13 +171,11 @@ def load_model(self, model_args, device="cuda"): to load the model on. Returns: - Callable: A timm model loaded on the specified device. + Callable: A model loaded on the specified device. The model's forward call may be augmented with torch.autocast() - or embedding normalization if specified in the constructor. """ model = VAE(model_args).eval().to(device) model = self._configure_forward(model) - return model def _configure_forward(self, model): @@ -245,6 +243,8 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): {"model_args": self.model_args, "device": device}, ) + print(f"Model loaded on {device}") + dataset = self.load_dataset_shard(tar_path) final_image_latents = [] image_ids = [] @@ -258,7 +258,8 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): with torch.no_grad(), torch.amp.autocast(device_type="cuda", enabled=self.model_args.autocast): for batch, metadata in dataset: image_latents = self._process_batch(model, batch) - + del batch + final_image_latents.append(image_latents) image_ids.extend(m[id_col] for m in metadata) @@ -266,8 +267,8 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): samples_completed += batch_size progress_bar.update(batch_size) - # Clear CUDA cache less frequently - if samples_completed % (self.data_args.batch_size * 50) == 0: + # Clear CUDA cache frequently + if samples_completed % (self.data_args.batch_size * 5) == 0: torch.cuda.empty_cache() progress_bar.close() @@ -314,4 +315,8 @@ def _process_embeddings(self, partition, final_image_latents, image_ids): concat_embedding_output, index=partition.index ) - return partition + del concat_embedding_output + del final_image_latents + torch.cuda.empty_cache() + + return partition \ No newline at end of file From de6a05879c9f974de3f624c03a9ac22d93e6d974 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Sat, 15 Mar 2025 06:54:37 +0000 Subject: [PATCH 5/8] dynamic resolution updates --- apps/main/modules/preprocess.py | 52 +++ .../offline_inf_v2/configs/inference_256.yaml | 6 +- .../offline_inf_v2/configs/inference_512.yaml | 6 +- apps/offline_inf_v2/data.py | 4 + apps/offline_inf_v2/inference.py | 4 +- apps/offline_inf_v2/vae_latent_extractor.py | 374 ++++++++++++------ 6 files changed, 316 insertions(+), 130 deletions(-) diff --git a/apps/main/modules/preprocess.py b/apps/main/modules/preprocess.py index 42e5362c..223aedb9 100644 --- a/apps/main/modules/preprocess.py +++ b/apps/main/modules/preprocess.py @@ -166,6 +166,58 @@ def center_crop_arr(pil_image, image_size): ) +def generate_crop_size_list(image_size, patch_size, max_ratio=2.0): + assert max_ratio >= 1.0 
+ min_wp, min_hp = image_size // patch_size, image_size // patch_size + crop_size_list = [] + wp, hp = min_wp, min_hp + while hp / wp <= max_ratio: + crop_size_list.append((wp * patch_size, hp * patch_size)) + hp += 1 + wp, hp = min_wp + 1, min_hp + while wp / hp <= max_ratio: + crop_size_list.append((wp * patch_size, hp * patch_size)) + wp += 1 + return crop_size_list + + +def is_valid_crop_size(cw, ch, orig_w, orig_h): + down_scale = max(cw / orig_w, ch / orig_h) + return cw <= orig_w * down_scale and ch <= orig_h * down_scale + + +def var_center_crop_size_fn(orig_img_shape, image_size=256, patch_size=16, max_ratio=2.0): + w, h, _ = orig_img_shape + crop_size_list = generate_crop_size_list( + image_size=image_size, + patch_size=patch_size, + max_ratio=max_ratio + ) + rem_percent = [ + min(cw / w, ch / h) / max(cw / w, ch / h) + if is_valid_crop_size(cw, ch, w, h) else 0 + for cw, ch in crop_size_list + ] + crop_size = sorted( + ((x, y) for x, y in zip(rem_percent, crop_size_list) if x > 0 and y[0] <= w and y[1] <= h), + reverse=True + )[0][1] + return np.array(crop_size, dtype=np.float32) + + +def downsample_resize_image(image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + pil_image = Image.fromarray(image) + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + return pil_image + + ######################## FOR TEXT ######################## diff --git a/apps/offline_inf_v2/configs/inference_256.yaml b/apps/offline_inf_v2/configs/inference_256.yaml index a14eb358..34c0ff67 100644 --- a/apps/offline_inf_v2/configs/inference_256.yaml +++ b/apps/offline_inf_v2/configs/inference_256.yaml @@ -11,8 +11,12 @@ model: data: data_path: /mnt/pollux/nemo/sample output_path: /mnt/pollux/nemo/sample_latents_256/ + enable_checkpointing: true id_col: key - batch_size: 256 + batch_size: 16 num_threads_per_worker: 16 image_size: 256 + patch_size: 16 + dynamic_crop_ratio: 1.0 image_latent_column: image_latent_256 + image_latent_shape_column: image_latent_shape_256 diff --git a/apps/offline_inf_v2/configs/inference_512.yaml b/apps/offline_inf_v2/configs/inference_512.yaml index 248dba1b..f8f5e93e 100644 --- a/apps/offline_inf_v2/configs/inference_512.yaml +++ b/apps/offline_inf_v2/configs/inference_512.yaml @@ -11,8 +11,12 @@ model: data: data_path: /mnt/pollux/nemo/sample output_path: /mnt/pollux/nemo/sample_latents_512/ + enable_checkpointing: true id_col: key - batch_size: 256 + batch_size: 16 num_threads_per_worker: 16 image_size: 512 + patch_size: 16 + dynamic_crop_ratio: 1.0 image_latent_column: image_latent_512 + image_latent_shape_column: image_latent_shape_512 diff --git a/apps/offline_inf_v2/data.py b/apps/offline_inf_v2/data.py index 628b5d73..fb426517 100644 --- a/apps/offline_inf_v2/data.py +++ b/apps/offline_inf_v2/data.py @@ -5,8 +5,12 @@ class DataArgs: data_path: str = field(default="/mnt/pollux/nemo/sample/") output_path: Optional[str] = field(default=None) + enable_checkpointing: Optional[str] = field(default=None) id_col: str = field(default="key") batch_size: int = field(default=1) num_threads_per_worker: int = field(default=4) image_size: int = field(default=256) + patch_size: int = field(default=16) + dynamic_crop_ratio: float = field(default=1.0) image_latent_column: str = field(default="image_latent") + image_latent_shape_column: str = 
field(default="image_latent_shape") \ No newline at end of file diff --git a/apps/offline_inf_v2/inference.py b/apps/offline_inf_v2/inference.py index aa837e38..5b5c4edb 100644 --- a/apps/offline_inf_v2/inference.py +++ b/apps/offline_inf_v2/inference.py @@ -48,7 +48,9 @@ def main(): # Metadata will have a new column named "image_latent" dataset_with_latents.save_metadata( - cfg.data.output_path, columns=[cfg.data.id_col, "doc_id", cfg.data.image_latent_column] + cfg.data.output_path, columns=[ + cfg.data.id_col, "doc_id", cfg.data.image_latent_column, cfg.data.image_latent_shape_column + ] ) diff --git a/apps/offline_inf_v2/vae_latent_extractor.py b/apps/offline_inf_v2/vae_latent_extractor.py index 7565a975..1c8c1aba 100644 --- a/apps/offline_inf_v2/vae_latent_extractor.py +++ b/apps/offline_inf_v2/vae_latent_extractor.py @@ -1,4 +1,6 @@ +from collections import defaultdict import json +import os from tqdm import tqdm from PIL import Image import cupy as cp @@ -9,14 +11,84 @@ import torch from nvidia.dali import pipeline_def from nvidia.dali.plugin.pytorch import feed_ndarray +from nvidia.dali.tensors import TensorListGPU from nemo_curator.datasets import ImageTextPairDataset from nemo_curator.utils.distributed_utils import load_object_on_worker +from apps.main.modules.preprocess import generate_crop_size_list, var_center_crop_size_fn from apps.offline_inf_v2.data import DataArgs from apps.offline_inf_v2.model import ModelArgs, VAE +@pipeline_def +def image_loading_pipeline(_tar_path: str, use_index_files: bool, data_args: DataArgs): + if use_index_files: + index_paths = [f"{_tar_path.rsplit('.', 1)[0]}.idx"] + else: + index_paths = [] + + images_raw, text, json = fn.readers.webdataset( + paths=_tar_path, + index_paths=index_paths, + ext=["jpg", "txt", "json"], + missing_component_behavior="error", + ) + images = fn.decoders.image(images_raw, device="mixed", output_type=types.RGB) + return images, text, json + + +@pipeline_def +def webdataset_pipeline(_tar_path: str, use_index_files: bool, data_args: DataArgs): + if use_index_files: + index_paths = [f"{_tar_path.rsplit('.', 1)[0]}.idx"] + else: + index_paths = [] + + images_raw, text, json = fn.readers.webdataset( + paths=_tar_path, + index_paths=index_paths, + ext=["jpg", "txt", "json"], + missing_component_behavior="error", + ) + images = fn.decoders.image(images_raw, device="mixed", output_type=types.RGB) + + images = fn.resize( + images, + device="gpu", + resize_x=data_args.image_size, + resize_y=data_args.image_size, + mode="not_smaller", + interp_type=types.DALIInterpType.INTERP_CUBIC + ) + + # get the dynamic crop size + crop_size = fn.python_function( + images.shape(device="cpu"), + data_args.image_size, + data_args.patch_size, + data_args.dynamic_crop_ratio, + function=var_center_crop_size_fn, + device="cpu" + ) + + images = fn.crop_mirror_normalize( + images, + device="gpu", + crop_h=crop_size[0], + crop_w=crop_size[1], + crop_pos_x=0.5, + crop_pos_y=0.5, + mirror=fn.random.coin_flip(probability=0.5), + dtype=types.DALIDataType.FLOAT, + mean=[0.5 * 255, 0.5 * 255, 0.5 * 255], + std=[0.5 * 255, 0.5 * 255, 0.5 * 255], + scale=1.0, + ) + + return images, text, json + + class VAELatentExtractor: def __init__( self, @@ -41,17 +113,6 @@ def __init__( self.data_args = data_args self.use_index_files = use_index_files - # torch_transforms = transforms.Compose( - # [ - # transforms.RandomHorizontalFlip(), - # transforms.ToTensor(), - # transforms.Normalize( - # mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True - # ), - # ] - # ) - # 
self.dali_transforms = convert_transforms_to_dali(torch_transforms) - def load_dataset_shard(self, tar_path: str): """ Loads a WebDataset tar shard using DALI. @@ -67,100 +128,92 @@ def load_dataset_shard(self, tar_path: str): .json must contain the metadata for the record (including its ID). Images will be loaded using DALI. """ - - - def downsample_resize_image(image): - """ - Center cropping implementation from ADM. - https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 - """ - image = cp.asnumpy(image) - pil_image = Image.fromarray(image) - while min(*pil_image.size) >= 2 * self.data_args.image_size: - pil_image = pil_image.resize( - tuple(x // 2 for x in pil_image.size), resample=Image.BOX - ) - return cp.asarray(pil_image) - - - # Create the DALI pipeline - @pipeline_def( - batch_size=self.data_args.batch_size, - num_threads=self.data_args.num_threads_per_worker, - device_id=0, + # # Create the DALI pipeline + # @pipeline_def( + # batch_size=self.data_args.batch_size, + # num_threads=self.data_args.num_threads_per_worker, + # device_id=0, + # exec_dynamic=True, + # ) + # def webdataset_pipeline_wrapper(_tar_path: str, use_index_files: bool, data_args: DataArgs): + # return webdataset_pipeline(_tar_path, use_index_files, data_args) + + dali_pipeline_args = { + "batch_size": self.data_args.batch_size, + "num_threads": self.data_args.num_threads_per_worker, + "device_id": 0, + "exec_dynamic": True, + } + + if self.data_args.enable_checkpointing: + dali_pipeline_args["enable_checkpointing"] = True + checkpoint_path = f"{tar_path.rsplit('.', 1)[0]}_{self.data_args.image_size}.pth" + if os.path.exists(checkpoint_path): + print (f"Restoring checkpoint from {checkpoint_path}") + checkpoint = open(checkpoint_path, 'rb').read() + dali_pipeline_args["checkpoint"] = checkpoint + + pipeline = webdataset_pipeline( + tar_path, + self.use_index_files, + self.data_args, + **dali_pipeline_args, ) - def webdataset_pipeline(_tar_path: str): - if self.use_index_files: - index_paths = [f"{_tar_path.rsplit('.', 1)[0]}.idx"] - else: - index_paths = [] - - images_raw, text, json = fn.readers.webdataset( - paths=_tar_path, - index_paths=index_paths, - ext=["jpg", "txt", "json"], - missing_component_behavior="error", - ) - images = fn.decoders.image(images_raw, device="mixed", output_type=types.RGB) - - # images = fn.python_function( - # images, - # function=downsample_resize_image, - # ) - images = fn.resize( - images, - device="gpu", - resize_x=self.data_args.image_size, - resize_y=self.data_args.image_size, - mode="not_smaller", - interp_type=types.DALIInterpType.INTERP_CUBIC - ) - images = fn.crop_mirror_normalize( - images, - device="gpu", - crop_h=self.data_args.image_size, - crop_w=self.data_args.image_size, - crop_pos_x=0.5, - crop_pos_y=0.5, - mirror=fn.random.coin_flip(probability=0.5), - dtype=types.DALIDataType.FLOAT, - mean=[0.5 * 255, 0.5 * 255, 0.5 * 255], - std=[0.5 * 255, 0.5 * 255, 0.5 * 255], - scale=1.0, - ) - - return images, text, json - - pipeline = webdataset_pipeline(tar_path) pipeline.build() total_samples = pipeline.epoch_size() total_samples = total_samples[list(total_samples.keys())[0]] - samples_completed = 0 - while samples_completed < total_samples: - image, text, meta = pipeline.run() - image = image.as_tensor() - - image_torch = torch.empty(image.shape(), dtype=torch.float32, device="cuda") - feed_ndarray(image, image_torch) # COPY !!! 
- image = image_torch - - captions = [text.at(i).tostring().decode("utf-8") for i in range(len(text))] - metadata = [ - json.loads(meta.at(i).tostring().decode("utf-8")) - for i in range(len(meta)) - ] + crop_size_list = generate_crop_size_list( + image_size=self.data_args.image_size, + patch_size=self.data_args.patch_size, + max_ratio=2.0 + ) - remaining_samples = total_samples - samples_completed - if image.shape[0] >= remaining_samples: - image = image[:remaining_samples] - captions = captions[:remaining_samples] - metadata = metadata[:remaining_samples] + bucket_img = defaultdict(list) + # bucket_text_cpu = defaultdict(list) + bucket_meta = defaultdict(list) - samples_completed += min(image.shape[0], remaining_samples) + # pbar = tqdm(total=total_samples, desc="Loading dataset shard") - yield image, metadata + samples_completed = 0 + while samples_completed < total_samples: + image, text, meta = pipeline.run() + for i in range(self.data_args.batch_size): + img = image[i] + _, w, h = img.shape() + crop_size = (w, h) + bucket_img[crop_size].append(img) + # bucket_text_cpu[crop_size].append(text.at(i).tostring().decode("utf-8")) + bucket_meta[crop_size].append(json.loads(meta.at(i).tostring().decode("utf-8"))) + + # if batch size is reached, yield the batch + if len(bucket_img[crop_size]) == self.data_args.batch_size: + image_batch = TensorListGPU(bucket_img[crop_size]).as_tensor() + image_torch = torch.empty(image_batch.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_batch, image_torch) # COPY !!! + yield image_torch, bucket_meta[crop_size] + bucket_img[crop_size] = [] + bucket_meta[crop_size] = [] + + samples_completed += self.data_args.batch_size + + # if samples_completed % (100 * self.data_args.batch_size) == 0 and self.data_args.enable_checkpointing is not None: + # pipeline.checkpoint(checkpoint_path) + # pbar.update(self.data_args.batch_size) + # pbar.close() + + for crop_size in crop_size_list: + if not bucket_img[crop_size]: + continue + image_batch = TensorListGPU(bucket_img[crop_size]).as_tensor() + image_torch = torch.empty(image_batch.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_batch, image_torch) # COPY !!! 
+ yield image_torch, bucket_meta[crop_size] + + # checkpoint final state + if self.data_args.enable_checkpointing: + pipeline.checkpoint(checkpoint_path) def load_model(self, model_args, device="cuda"): """ @@ -221,6 +274,7 @@ def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: """ meta = dataset.metadata.dtypes.to_dict() meta[self.data_args.image_latent_column] = "object" + meta[self.data_args.image_latent_shape_column] = "object" embedding_df = dataset.metadata.map_partitions( self._run_inference, dataset.tar_files, dataset.id_col, meta=meta @@ -249,23 +303,36 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): final_image_latents = [] image_ids = [] samples_completed = 0 + expected_samples = len(partition) progress_bar = tqdm( - total=len(partition), + total=expected_samples, desc=f"{tar_path} - Latent extraction with {self.model_args.gen_vae.model_name}", ) # Process batches with torch.no_grad(), torch.amp.autocast(device_type="cuda", enabled=self.model_args.autocast): for batch, metadata in dataset: + # Only process as many samples as we expect from metadata + if samples_completed >= expected_samples: + break + + # Calculate how many samples we should process from this batch + remaining = expected_samples - samples_completed + actual_batch_size = min(len(metadata), remaining) + + if actual_batch_size < len(metadata): + # Truncate batch and metadata if needed + batch = batch[:actual_batch_size] + metadata = metadata[:actual_batch_size] + image_latents = self._process_batch(model, batch) del batch final_image_latents.append(image_latents) image_ids.extend(m[id_col] for m in metadata) - batch_size = len(image_latents) - samples_completed += batch_size - progress_bar.update(batch_size) + samples_completed += actual_batch_size + progress_bar.update(actual_batch_size) # Clear CUDA cache frequently if samples_completed % (self.data_args.batch_size * 5) == 0: @@ -283,40 +350,93 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): def _process_embeddings(self, partition, final_image_latents, image_ids): """Process embeddings in a memory-efficient way""" + + # Check if we need to handle variable-sized latents + if self.data_args.dynamic_crop_ratio > 1.0: + # Handle variable-sized latents + return self._process_variable_sized_embeddings(partition, final_image_latents, image_ids) + else: + # Order the output of the shard + sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + # Process fixed-size latents as before + # Process in chunks to reduce memory usage + all_embeddings = torch.cat(final_image_latents, dim=0) + sorted_embeddings = all_embeddings[sorted_indices] + + embedding_shape_list = [emb.shape for emb in sorted_embeddings] + + # View the embeddings to be [N, 16*32*32] + sorted_embeddings = sorted_embeddings.view(sorted_embeddings.shape[0], -1) + + # Process in chunks to avoid OOM + chunk_size = 1000 # Adjust based on your GPU memory + concat_embedding_output = None + for i in range(0, sorted_embeddings.shape[0], chunk_size): + end_idx = min(i + chunk_size, sorted_embeddings.shape[0]) + chunk = sorted_embeddings[i:end_idx].cuda() # Move chunk to GPU + chunk_cp = cp.asarray(chunk) # Convert to CuPy + + if concat_embedding_output is None: + concat_embedding_output = chunk_cp + else: + concat_embedding_output = cp.concatenate([concat_embedding_output, chunk_cp], axis=0) + + # Free GPU memory + del chunk + torch.cuda.empty_cache() + + partition[self.data_args.image_latent_column] = 
create_list_series_from_1d_or_2d_ar( + concat_embedding_output, index=partition.index + ) + + # Convert embedding_shape_list to a CuPy array before passing it + embedding_shape_array = cp.array(embedding_shape_list) + partition[self.data_args.image_latent_shape_column] = create_list_series_from_1d_or_2d_ar( + embedding_shape_array, index=partition.index + ) + + del concat_embedding_output + del final_image_latents + del embedding_shape_array # Also clean up the new array + torch.cuda.empty_cache() + + return partition + + def _process_variable_sized_embeddings(self, partition, final_image_latents, image_ids): + """Process embeddings with variable sizes due to dynamic cropping""" # Order the output of the shard sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + + # Flatten our list of tensors + flat_embeddings = [] + current_idx = 0 - # Process in chunks to reduce memory usage - all_embeddings = torch.cat(final_image_latents, dim=0) - sorted_embeddings = all_embeddings[sorted_indices] - - # View the embeddings to be [N, 16*32*32] - sorted_embeddings = sorted_embeddings.view(sorted_embeddings.shape[0], -1) + for batch_tensors in final_image_latents: + batch_size = batch_tensors.shape[0] + for i in range(batch_size): + flat_embeddings.append(batch_tensors[i]) + current_idx += 1 - # Process in chunks to avoid OOM - chunk_size = 1000 # Adjust based on your GPU memory - concat_embedding_output = None + # Create sorted embeddings list + sorted_embeddings = [flat_embeddings[idx] for idx in sorted_indices] + del flat_embeddings - for i in range(0, sorted_embeddings.shape[0], chunk_size): - end_idx = min(i + chunk_size, sorted_embeddings.shape[0]) - chunk = sorted_embeddings[i:end_idx].cuda() # Move chunk to GPU - chunk_cp = cp.asarray(chunk) # Convert to CuPy + # Process each embedding individually and create a list of numpy arrays + embedding_list = [] + embedding_shape_list = [] + for emb in sorted_embeddings: + # Flatten the embedding to 1D + embedding_shape_list.append(emb.shape) + flat_emb = emb.view(-1).numpy() # Convert to numpy array + embedding_list.append(flat_emb) - if concat_embedding_output is None: - concat_embedding_output = chunk_cp - else: - concat_embedding_output = cp.concatenate([concat_embedding_output, chunk_cp], axis=0) - - # Free GPU memory - del chunk - torch.cuda.empty_cache() + # Assign to partition + partition[self.data_args.image_latent_column] = embedding_list + partition[self.data_args.image_latent_shape_column] = embedding_shape_list - partition[self.data_args.image_latent_column] = create_list_series_from_1d_or_2d_ar( - concat_embedding_output, index=partition.index - ) - - del concat_embedding_output + del embedding_list + del sorted_embeddings del final_image_latents torch.cuda.empty_cache() - + return partition \ No newline at end of file From 4d5f2599aba7b915da59d9f4a874462881f1e6a4 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Sun, 16 Mar 2025 06:16:14 +0000 Subject: [PATCH 6/8] latent extraction improvements --- .../{inference_256.yaml => inference.yaml} | 12 +- .../offline_inf_v2/configs/inference_512.yaml | 22 - apps/offline_inf_v2/data.py | 7 +- apps/offline_inf_v2/inference.py | 10 +- apps/offline_inf_v2/vae_latent_extractor.py | 408 +++++++++++------- 5 files changed, 278 insertions(+), 181 deletions(-) rename apps/offline_inf_v2/configs/{inference_256.yaml => inference.yaml} (60%) delete mode 100644 apps/offline_inf_v2/configs/inference_512.yaml diff --git a/apps/offline_inf_v2/configs/inference_256.yaml 
b/apps/offline_inf_v2/configs/inference.yaml similarity index 60% rename from apps/offline_inf_v2/configs/inference_256.yaml rename to apps/offline_inf_v2/configs/inference.yaml index 34c0ff67..35947475 100644 --- a/apps/offline_inf_v2/configs/inference_256.yaml +++ b/apps/offline_inf_v2/configs/inference.yaml @@ -9,14 +9,14 @@ model: autocast: true data: - data_path: /mnt/pollux/nemo/sample - output_path: /mnt/pollux/nemo/sample_latents_256/ + data_path: /mnt/pollux/nemo/data/bucket-256-1-1/ + output_path: /mnt/pollux/nemo/data/bucket-256-1-1-latents enable_checkpointing: true id_col: key batch_size: 16 - num_threads_per_worker: 16 - image_size: 256 + num_threads_per_worker: 32 + image_sizes: [256, 512] patch_size: 16 dynamic_crop_ratio: 1.0 - image_latent_column: image_latent_256 - image_latent_shape_column: image_latent_shape_256 + image_latent_column: image_latent + image_latent_shape_column: image_latent_shape diff --git a/apps/offline_inf_v2/configs/inference_512.yaml b/apps/offline_inf_v2/configs/inference_512.yaml deleted file mode 100644 index f8f5e93e..00000000 --- a/apps/offline_inf_v2/configs/inference_512.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: inference - -model: - gen_vae: - model_name: Hunyuan - pretrained_model_name_or_path: '/jfs/checkpoints/models--tencent--HunyuanVideo/snapshots/2a15b5574ee77888e51ae6f593b2ceed8ce813e5/vae' - enable_tiling: false - enable_slicing: false - autocast: true - -data: - data_path: /mnt/pollux/nemo/sample - output_path: /mnt/pollux/nemo/sample_latents_512/ - enable_checkpointing: true - id_col: key - batch_size: 16 - num_threads_per_worker: 16 - image_size: 512 - patch_size: 16 - dynamic_crop_ratio: 1.0 - image_latent_column: image_latent_512 - image_latent_shape_column: image_latent_shape_512 diff --git a/apps/offline_inf_v2/data.py b/apps/offline_inf_v2/data.py index fb426517..8809a4ce 100644 --- a/apps/offline_inf_v2/data.py +++ b/apps/offline_inf_v2/data.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional +from typing import List, Optional @dataclass class DataArgs: @@ -9,8 +9,9 @@ class DataArgs: id_col: str = field(default="key") batch_size: int = field(default=1) num_threads_per_worker: int = field(default=4) - image_size: int = field(default=256) + image_sizes: List[int] = field(default_factory=lambda: [256, 512]) patch_size: int = field(default=16) dynamic_crop_ratio: float = field(default=1.0) image_latent_column: str = field(default="image_latent") - image_latent_shape_column: str = field(default="image_latent_shape") \ No newline at end of file + image_latent_shape_column: str = field(default="image_latent_shape") + caption_column: str = field(default="caption") diff --git a/apps/offline_inf_v2/inference.py b/apps/offline_inf_v2/inference.py index 5b5c4edb..15fd2b0d 100644 --- a/apps/offline_inf_v2/inference.py +++ b/apps/offline_inf_v2/inference.py @@ -42,14 +42,22 @@ def main(): latent_extractor = VAELatentExtractor( model_args=cfg.model, data_args=cfg.data, + use_index_files=True, ) dataset_with_latents = latent_extractor(dataset) + latent_columns = [ + f"{cfg.data.image_latent_column}_{cfg.data.image_sizes[0]}", + f"{cfg.data.image_latent_shape_column}_{cfg.data.image_sizes[0]}", + f"{cfg.data.image_latent_column}_{cfg.data.image_sizes[1]}", + f"{cfg.data.image_latent_shape_column}_{cfg.data.image_sizes[1]}", + ] + # Metadata will have a new column named "image_latent" dataset_with_latents.save_metadata( cfg.data.output_path, columns=[ - cfg.data.id_col, "doc_id", 
cfg.data.image_latent_column, cfg.data.image_latent_shape_column + cfg.data.id_col, "_id", cfg.data.caption_column, *latent_columns ] ) diff --git a/apps/offline_inf_v2/vae_latent_extractor.py b/apps/offline_inf_v2/vae_latent_extractor.py index 1c8c1aba..9a41ec61 100644 --- a/apps/offline_inf_v2/vae_latent_extractor.py +++ b/apps/offline_inf_v2/vae_latent_extractor.py @@ -1,8 +1,8 @@ from collections import defaultdict import json import os +from pathlib import Path from tqdm import tqdm -from PIL import Image import cupy as cp from crossfit.backend.cudf.series import create_list_series_from_1d_or_2d_ar @@ -24,7 +24,7 @@ @pipeline_def def image_loading_pipeline(_tar_path: str, use_index_files: bool, data_args: DataArgs): if use_index_files: - index_paths = [f"{_tar_path.rsplit('.', 1)[0]}.idx"] + index_paths = [str(Path(_tar_path).with_suffix(".idx"))] else: index_paths = [] @@ -32,48 +32,32 @@ def image_loading_pipeline(_tar_path: str, use_index_files: bool, data_args: Dat paths=_tar_path, index_paths=index_paths, ext=["jpg", "txt", "json"], - missing_component_behavior="error", + missing_component_behavior="skip", ) - images = fn.decoders.image(images_raw, device="mixed", output_type=types.RGB) - return images, text, json - - -@pipeline_def -def webdataset_pipeline(_tar_path: str, use_index_files: bool, data_args: DataArgs): - if use_index_files: - index_paths = [f"{_tar_path.rsplit('.', 1)[0]}.idx"] - else: - index_paths = [] - - images_raw, text, json = fn.readers.webdataset( - paths=_tar_path, - index_paths=index_paths, - ext=["jpg", "txt", "json"], - missing_component_behavior="error", - ) - images = fn.decoders.image(images_raw, device="mixed", output_type=types.RGB) - images = fn.resize( - images, + images_gen = fn.experimental.decoders.image(images_raw, device="mixed", output_type=types.RGB) + + images_gen = fn.resize( + images_gen, device="gpu", - resize_x=data_args.image_size, - resize_y=data_args.image_size, + resize_x=data_args.image_sizes[1], + resize_y=data_args.image_sizes[1], mode="not_smaller", interp_type=types.DALIInterpType.INTERP_CUBIC ) # get the dynamic crop size crop_size = fn.python_function( - images.shape(device="cpu"), - data_args.image_size, + images_gen.shape(device="cpu"), + data_args.image_sizes[1], data_args.patch_size, data_args.dynamic_crop_ratio, function=var_center_crop_size_fn, device="cpu" ) - images = fn.crop_mirror_normalize( - images, + images_gen = fn.crop_mirror_normalize( + images_gen, device="gpu", crop_h=crop_size[0], crop_w=crop_size[1], @@ -85,8 +69,17 @@ def webdataset_pipeline(_tar_path: str, use_index_files: bool, data_args: DataAr std=[0.5 * 255, 0.5 * 255, 0.5 * 255], scale=1.0, ) + + images_plan = fn.resize( + images_gen, + device="gpu", + resize_x=data_args.image_sizes[0], + resize_y=data_args.image_sizes[0], + mode="not_smaller", + interp_type=types.DALIInterpType.INTERP_CUBIC + ) - return images, text, json + return images_plan, images_gen, text, json class VAELatentExtractor: @@ -128,72 +121,72 @@ def load_dataset_shard(self, tar_path: str): .json must contain the metadata for the record (including its ID). Images will be loaded using DALI. 
""" - # # Create the DALI pipeline - # @pipeline_def( - # batch_size=self.data_args.batch_size, - # num_threads=self.data_args.num_threads_per_worker, - # device_id=0, - # exec_dynamic=True, - # ) - # def webdataset_pipeline_wrapper(_tar_path: str, use_index_files: bool, data_args: DataArgs): - # return webdataset_pipeline(_tar_path, use_index_files, data_args) dali_pipeline_args = { "batch_size": self.data_args.batch_size, "num_threads": self.data_args.num_threads_per_worker, "device_id": 0, "exec_dynamic": True, + "prefetch_queue_depth": 8, } if self.data_args.enable_checkpointing: dali_pipeline_args["enable_checkpointing"] = True - checkpoint_path = f"{tar_path.rsplit('.', 1)[0]}_{self.data_args.image_size}.pth" + checkpoint_path = f"{tar_path.rsplit('.', 1)[0]}.pth" if os.path.exists(checkpoint_path): print (f"Restoring checkpoint from {checkpoint_path}") checkpoint = open(checkpoint_path, 'rb').read() dali_pipeline_args["checkpoint"] = checkpoint - pipeline = webdataset_pipeline( + image_dataset = image_loading_pipeline( tar_path, - self.use_index_files, + self.use_index_files, self.data_args, **dali_pipeline_args, ) - pipeline.build() - total_samples = pipeline.epoch_size() + image_dataset.build() + + total_samples = image_dataset.epoch_size() total_samples = total_samples[list(total_samples.keys())[0]] crop_size_list = generate_crop_size_list( - image_size=self.data_args.image_size, + image_size=self.data_args.image_sizes[0], patch_size=self.data_args.patch_size, max_ratio=2.0 ) - bucket_img = defaultdict(list) - # bucket_text_cpu = defaultdict(list) + bucket_img_plan = defaultdict(list) + bucket_img_gen = defaultdict(list) + bucket_text = defaultdict(list) bucket_meta = defaultdict(list) # pbar = tqdm(total=total_samples, desc="Loading dataset shard") samples_completed = 0 while samples_completed < total_samples: - image, text, meta = pipeline.run() + image_plan, image_gen, text, meta = image_dataset.run() for i in range(self.data_args.batch_size): - img = image[i] - _, w, h = img.shape() + img_plan, img_gen = image_plan[i], image_gen[i] + _, w, h = img_plan.shape() crop_size = (w, h) - bucket_img[crop_size].append(img) - # bucket_text_cpu[crop_size].append(text.at(i).tostring().decode("utf-8")) + bucket_img_plan[crop_size].append(img_plan) + bucket_img_gen[crop_size].append(img_gen) + bucket_text[crop_size].append(text.at(i).tostring().decode("utf-8")) bucket_meta[crop_size].append(json.loads(meta.at(i).tostring().decode("utf-8"))) # if batch size is reached, yield the batch - if len(bucket_img[crop_size]) == self.data_args.batch_size: - image_batch = TensorListGPU(bucket_img[crop_size]).as_tensor() - image_torch = torch.empty(image_batch.shape(), dtype=torch.float32, device="cuda") - feed_ndarray(image_batch, image_torch) # COPY !!! - yield image_torch, bucket_meta[crop_size] - bucket_img[crop_size] = [] + if len(bucket_img_plan[crop_size]) == self.data_args.batch_size: + image_plan_batch = TensorListGPU(bucket_img_plan[crop_size]).as_tensor() + image_gen_batch = TensorListGPU(bucket_img_gen[crop_size]).as_tensor() + image_plan_torch = torch.empty(image_plan_batch.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_plan_batch, image_plan_torch) # COPY !!! + image_gen_torch = torch.empty(image_gen_batch.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_gen_batch, image_gen_torch) # COPY !!! 
+ yield image_plan_torch, image_gen_torch, bucket_text[crop_size], bucket_meta[crop_size] + bucket_img_plan[crop_size] = [] + bucket_img_gen[crop_size] = [] + bucket_text[crop_size] = [] bucket_meta[crop_size] = [] samples_completed += self.data_args.batch_size @@ -204,16 +197,19 @@ def load_dataset_shard(self, tar_path: str): # pbar.close() for crop_size in crop_size_list: - if not bucket_img[crop_size]: + if not bucket_img_plan[crop_size]: continue - image_batch = TensorListGPU(bucket_img[crop_size]).as_tensor() - image_torch = torch.empty(image_batch.shape(), dtype=torch.float32, device="cuda") - feed_ndarray(image_batch, image_torch) # COPY !!! - yield image_torch, bucket_meta[crop_size] + image_plan_batch = TensorListGPU(bucket_img_plan[crop_size]).as_tensor() + image_gen_batch = TensorListGPU(bucket_img_gen[crop_size]).as_tensor() + image_plan_torch = torch.empty(image_plan_batch.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_plan_batch, image_plan_torch) # COPY !!! + image_gen_torch = torch.empty(image_gen_batch.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_gen_batch, image_gen_torch) # COPY !!! + yield image_plan_torch, image_gen_torch, bucket_text[crop_size], bucket_meta[crop_size] # checkpoint final state if self.data_args.enable_checkpointing: - pipeline.checkpoint(checkpoint_path) + image_dataset.checkpoint(checkpoint_path) def load_model(self, model_args, device="cuda"): """ @@ -260,6 +256,11 @@ def _process_batch(self, model, batch): else: latents = model(batch) return latents.cpu() # Move to CPU immediately + + def _process_batch_pollux(self, model, image_plan_batch, image_gen_batch): + plan_latents = self._process_batch(model, image_plan_batch) + gen_latents = self._process_batch(model, image_gen_batch) + return plan_latents, gen_latents def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: """ @@ -273,8 +274,15 @@ def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: classifier scores. 
""" meta = dataset.metadata.dtypes.to_dict() - meta[self.data_args.image_latent_column] = "object" - meta[self.data_args.image_latent_shape_column] = "object" + latent_columns = [ + f"{self.data_args.image_latent_column}_{self.data_args.image_sizes[0]}", + f"{self.data_args.image_latent_shape_column}_{self.data_args.image_sizes[0]}", + f"{self.data_args.image_latent_column}_{self.data_args.image_sizes[1]}", + f"{self.data_args.image_latent_shape_column}_{self.data_args.image_sizes[1]}", + ] + for col in latent_columns: + meta[col] = "object" + meta[self.data_args.caption_column] = "object" embedding_df = dataset.metadata.map_partitions( self._run_inference, dataset.tar_files, dataset.id_col, meta=meta @@ -297,10 +305,10 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): {"model_args": self.model_args, "device": device}, ) - print(f"Model loaded on {device}") - dataset = self.load_dataset_shard(tar_path) - final_image_latents = [] + final_image_plan_latents = [] + final_image_gen_latents = [] + final_text_batch = [] image_ids = [] samples_completed = 0 expected_samples = len(partition) @@ -311,7 +319,7 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): # Process batches with torch.no_grad(), torch.amp.autocast(device_type="cuda", enabled=self.model_args.autocast): - for batch, metadata in dataset: + for image_plan_batch, image_gen_batch, text_batch, metadata in dataset: # Only process as many samples as we expect from metadata if samples_completed >= expected_samples: break @@ -322,13 +330,18 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): if actual_batch_size < len(metadata): # Truncate batch and metadata if needed - batch = batch[:actual_batch_size] + image_plan_batch = image_plan_batch[:actual_batch_size] + image_gen_batch = image_gen_batch[:actual_batch_size] + text_batch = text_batch[:actual_batch_size] metadata = metadata[:actual_batch_size] - image_latents = self._process_batch(model, batch) - del batch + plan_latents, gen_latents = self._process_batch_pollux(model, image_plan_batch, image_gen_batch) + del image_plan_batch + del image_gen_batch - final_image_latents.append(image_latents) + final_image_plan_latents.append(plan_latents) + final_image_gen_latents.append(gen_latents) + final_text_batch.extend(text_batch) image_ids.extend(m[id_col] for m in metadata) samples_completed += actual_batch_size @@ -346,97 +359,194 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): ) # Process embeddings in memory-efficient way - return self._process_embeddings(partition, final_image_latents, image_ids) - - def _process_embeddings(self, partition, final_image_latents, image_ids): + partition = self._process_embeddings(partition, final_image_plan_latents, image_ids, image_size=self.data_args.image_sizes[0]) + partition = self._process_embeddings(partition, final_image_gen_latents, image_ids, image_size=self.data_args.image_sizes[1]) + partition = self._process_captions(partition, final_text_batch, image_ids) + return partition + + def _process_captions(self, partition, final_text_batch, image_ids): + sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + sorted_text_batch = [final_text_batch[i] for i in sorted_indices] + partition[self.data_args.caption_column] = sorted_text_batch + return partition + + def _process_embeddings(self, partition, final_image_latents, image_ids, image_size): """Process embeddings in a memory-efficient way""" + ''' # Check if we 
need to handle variable-sized latents if self.data_args.dynamic_crop_ratio > 1.0: # Handle variable-sized latents return self._process_variable_sized_embeddings(partition, final_image_latents, image_ids) else: - # Order the output of the shard - sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) - # Process fixed-size latents as before - # Process in chunks to reduce memory usage - all_embeddings = torch.cat(final_image_latents, dim=0) - sorted_embeddings = all_embeddings[sorted_indices] - - embedding_shape_list = [emb.shape for emb in sorted_embeddings] + ''' - # View the embeddings to be [N, 16*32*32] - sorted_embeddings = sorted_embeddings.view(sorted_embeddings.shape[0], -1) - - # Process in chunks to avoid OOM - chunk_size = 1000 # Adjust based on your GPU memory - concat_embedding_output = None - for i in range(0, sorted_embeddings.shape[0], chunk_size): - end_idx = min(i + chunk_size, sorted_embeddings.shape[0]) - chunk = sorted_embeddings[i:end_idx].cuda() # Move chunk to GPU - chunk_cp = cp.asarray(chunk) # Convert to CuPy - - if concat_embedding_output is None: - concat_embedding_output = chunk_cp - else: - concat_embedding_output = cp.concatenate([concat_embedding_output, chunk_cp], axis=0) - - # Free GPU memory - del chunk - torch.cuda.empty_cache() + # Order the output of the shard + sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + # Process fixed-size latents as before + # Process in chunks to reduce memory usage + all_embeddings = torch.cat(final_image_latents, dim=0) + sorted_embeddings = all_embeddings[sorted_indices] + + embedding_shape_list = [emb.shape for emb in sorted_embeddings] - partition[self.data_args.image_latent_column] = create_list_series_from_1d_or_2d_ar( - concat_embedding_output, index=partition.index - ) + # View the embeddings to be [N, 16*32*32] + sorted_embeddings = sorted_embeddings.view(sorted_embeddings.shape[0], -1) + + # Process in chunks to avoid OOM + chunk_size = 1000 # Adjust based on your GPU memory + concat_embedding_output = None + for i in range(0, sorted_embeddings.shape[0], chunk_size): + end_idx = min(i + chunk_size, sorted_embeddings.shape[0]) + chunk = sorted_embeddings[i:end_idx].cuda() # Move chunk to GPU + chunk_cp = cp.asarray(chunk) # Convert to CuPy - # Convert embedding_shape_list to a CuPy array before passing it - embedding_shape_array = cp.array(embedding_shape_list) - partition[self.data_args.image_latent_shape_column] = create_list_series_from_1d_or_2d_ar( - embedding_shape_array, index=partition.index - ) - - del concat_embedding_output - del final_image_latents - del embedding_shape_array # Also clean up the new array + if concat_embedding_output is None: + concat_embedding_output = chunk_cp + else: + concat_embedding_output = cp.concatenate([concat_embedding_output, chunk_cp], axis=0) + + # Free GPU memory + del chunk torch.cuda.empty_cache() - return partition - - def _process_variable_sized_embeddings(self, partition, final_image_latents, image_ids): - """Process embeddings with variable sizes due to dynamic cropping""" - # Order the output of the shard - sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + partition[f"{self.data_args.image_latent_column}_{image_size}"] = create_list_series_from_1d_or_2d_ar( + concat_embedding_output, index=partition.index + ) + + # Convert embedding_shape_list to a CuPy array before passing it + embedding_shape_array = cp.array(embedding_shape_list) + 
partition[f"{self.data_args.image_latent_shape_column}_{image_size}"] = create_list_series_from_1d_or_2d_ar( + embedding_shape_array, index=partition.index + ) + + del concat_embedding_output + del final_image_latents + del embedding_shape_array # Also clean up the new array + torch.cuda.empty_cache() - # Flatten our list of tensors - flat_embeddings = [] - current_idx = 0 + return partition + + # def _process_variable_sized_embeddings(self, partition, final_image_latents, image_ids): + # """Process embeddings with variable sizes due to dynamic cropping""" + # # Order the output of the shard + # sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + + # # Flatten our list of tensors + # flat_embeddings = [] + # current_idx = 0 - for batch_tensors in final_image_latents: - batch_size = batch_tensors.shape[0] - for i in range(batch_size): - flat_embeddings.append(batch_tensors[i]) - current_idx += 1 + # for batch_tensors in final_image_latents: + # batch_size = batch_tensors.shape[0] + # for i in range(batch_size): + # flat_embeddings.append(batch_tensors[i]) + # current_idx += 1 - # Create sorted embeddings list - sorted_embeddings = [flat_embeddings[idx] for idx in sorted_indices] - del flat_embeddings + # # Create sorted embeddings list + # sorted_embeddings = [flat_embeddings[idx] for idx in sorted_indices] + # del flat_embeddings - # Process each embedding individually and create a list of numpy arrays - embedding_list = [] - embedding_shape_list = [] - for emb in sorted_embeddings: - # Flatten the embedding to 1D - embedding_shape_list.append(emb.shape) - flat_emb = emb.view(-1).numpy() # Convert to numpy array - embedding_list.append(flat_emb) + # # Process each embedding individually and create a list of numpy arrays + # embedding_list = [] + # embedding_shape_list = [] + # for emb in sorted_embeddings: + # # Flatten the embedding to 1D + # embedding_shape_list.append(emb.shape) + # flat_emb = emb.view(-1).numpy() # Convert to numpy array + # embedding_list.append(flat_emb) - # Assign to partition - partition[self.data_args.image_latent_column] = embedding_list - partition[self.data_args.image_latent_shape_column] = embedding_shape_list + # # Assign to partition + # partition[self.data_args.image_latent_column] = embedding_list + # partition[self.data_args.image_latent_shape_column] = embedding_shape_list - del embedding_list - del sorted_embeddings - del final_image_latents - torch.cuda.empty_cache() + # del embedding_list + # del sorted_embeddings + # del final_image_latents + # torch.cuda.empty_cache() + + # return partition + + + def load_dataset_shard2(self, tar_path: str): + """ + Loads a WebDataset tar shard using DALI. + + Args: + tar_path (str): The path of the tar shard to load. + + Returns: + Iterable: An iterator over the dataset. Each tar file + must have 3 files per record: a .jpg file, a .txt file, + and a .json file. The .jpg file must contain the image, the + .txt file must contain the associated caption, and the + .json must contain the metadata for the record (including + its ID). Images will be loaded using DALI. 
+ """ + + dali_pipeline_args = { + "batch_size": self.data_args.batch_size, + "num_threads": self.data_args.num_threads_per_worker, + "device_id": 0, + "exec_dynamic": True, + "prefetch_queue_depth": 8, + } + + if self.data_args.enable_checkpointing: + dali_pipeline_args["enable_checkpointing"] = True + checkpoint_path = f"{tar_path.rsplit('.', 1)[0]}.pth" + if os.path.exists(checkpoint_path): + print (f"Restoring checkpoint from {checkpoint_path}") + checkpoint = open(checkpoint_path, 'rb').read() + dali_pipeline_args["checkpoint"] = checkpoint + + image_dataset = image_loading_pipeline( + tar_path, + self.use_index_files, + self.data_args, + **dali_pipeline_args, + ) + + image_dataset.build() + + total_samples = image_dataset.epoch_size() + total_samples = total_samples[list(total_samples.keys())[0]] + + # pbar = tqdm(total=total_samples, desc="Loading dataset shard") + + samples_completed = 0 + while samples_completed < total_samples: + image_plan, image_gen, text, meta = image_dataset.run() + + image_plan = image_plan.as_tensor() + image_plan_torch = torch.empty(image_plan.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_plan, image_plan_torch) # COPY !!! + image_plan = image_plan_torch + + image_gen = image_gen.as_tensor() + image_gen_torch = torch.empty(image_gen.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_gen, image_gen_torch) # COPY !!! + image_gen = image_gen_torch + + captions = [text.at(i).tostring().decode("utf-8") for i in range(len(text))] + metadata = [ + json.loads(meta.at(i).tostring().decode("utf-8")) + for i in range(len(meta)) + ] + + remaining_samples = total_samples - samples_completed + if image_plan.shape[0] >= remaining_samples: + image_plan = image_plan[:remaining_samples] + image_gen = image_gen[:remaining_samples] + captions = captions[:remaining_samples] + metadata = metadata[:remaining_samples] + + remaining_samples = total_samples - samples_completed + + samples_completed += min(image_plan.shape[0], remaining_samples) + + yield image_plan, image_gen, captions, metadata - return partition \ No newline at end of file + # checkpoint final state + if self.data_args.enable_checkpointing: + image_dataset.checkpoint(checkpoint_path) + \ No newline at end of file From 856e4f932a001d9ca4903828bcf956fc71b2fe6d Mon Sep 17 00:00:00 2001 From: sippycoder Date: Mon, 17 Mar 2025 05:20:18 +0000 Subject: [PATCH 7/8] bug fixes --- apps/offline_inf_v2/configs/inference.yaml | 11 +- apps/offline_inf_v2/data.py | 2 + apps/offline_inf_v2/inference.py | 34 +++++-- apps/offline_inf_v2/pipeline.py | 81 +++++++++++++++ apps/offline_inf_v2/vae_latent_extractor.py | 105 +++++--------------- 5 files changed, 144 insertions(+), 89 deletions(-) create mode 100644 apps/offline_inf_v2/pipeline.py diff --git a/apps/offline_inf_v2/configs/inference.yaml b/apps/offline_inf_v2/configs/inference.yaml index 35947475..f34c93b1 100644 --- a/apps/offline_inf_v2/configs/inference.yaml +++ b/apps/offline_inf_v2/configs/inference.yaml @@ -9,14 +9,17 @@ model: autocast: true data: - data_path: /mnt/pollux/nemo/data/bucket-256-1-1/ - output_path: /mnt/pollux/nemo/data/bucket-256-1-1-latents + data_path: /mnt/pollux/nemo/data/bucket-256-1-0/ + output_path: /mnt/pollux/nemo/data/bucket-256-1-0-latents enable_checkpointing: true id_col: key - batch_size: 16 - num_threads_per_worker: 32 + batch_size: 64 + mini_batch_size: 16 + num_threads_per_worker: 16 image_sizes: [256, 512] patch_size: 16 dynamic_crop_ratio: 1.0 image_latent_column: image_latent 
image_latent_shape_column: image_latent_shape + caption_column: caption + valid_column: valid diff --git a/apps/offline_inf_v2/data.py b/apps/offline_inf_v2/data.py index 8809a4ce..1fbc3030 100644 --- a/apps/offline_inf_v2/data.py +++ b/apps/offline_inf_v2/data.py @@ -8,6 +8,7 @@ class DataArgs: enable_checkpointing: Optional[str] = field(default=None) id_col: str = field(default="key") batch_size: int = field(default=1) + mini_batch_size: int = field(default=1) num_threads_per_worker: int = field(default=4) image_sizes: List[int] = field(default_factory=lambda: [256, 512]) patch_size: int = field(default=16) @@ -15,3 +16,4 @@ class DataArgs: image_latent_column: str = field(default="image_latent") image_latent_shape_column: str = field(default="image_latent_shape") caption_column: str = field(default="caption") + valid_column: str = field(default="valid") diff --git a/apps/offline_inf_v2/inference.py b/apps/offline_inf_v2/inference.py index 15fd2b0d..99a522a9 100644 --- a/apps/offline_inf_v2/inference.py +++ b/apps/offline_inf_v2/inference.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field from omegaconf import OmegaConf +import dask_cudf from nemo_curator import get_client from nemo_curator.datasets import ImageTextPairDataset @@ -21,6 +22,20 @@ class InferenceArgs: data: DataArgs = field(default_factory=DataArgs) +def init_image_text_dataset(cfg: InferenceArgs): + metadata = dask_cudf.read_parquet(cfg.data.data_path, split_row_groups=False, blocksize=None) + if 'status' in metadata.columns: + metadata = metadata[metadata.status != "failed_to_download"] + metadata = metadata.map_partitions(ImageTextPairDataset._sort_partition, id_col=cfg.data.id_col) + tar_files = ImageTextPairDataset._get_tar_files(cfg.data.data_path) + return ImageTextPairDataset( + path=cfg.data.data_path, + metadata=metadata, + tar_files=tar_files, + id_col=cfg.data.id_col, + ) + + def main(): cli_args = OmegaConf.from_cli() file_cfg = OmegaConf.load(cli_args.config) @@ -33,17 +48,21 @@ def main(): assert cfg.data.output_path is not None, f"Output path is required, otherwise the parquets in {cfg.data.data_path} will be overwritten" - client = get_client(cluster_type="gpu", nvlink_only=True) - - dataset = ImageTextPairDataset.from_webdataset( - path=cfg.data.data_path, id_col=cfg.data.id_col + client = get_client( + cluster_type="gpu", ) + print("Inititiating dataset ...") + dataset = init_image_text_dataset(cfg) + print("Dataset initialized") + + print("Initializing latent extractor ...") latent_extractor = VAELatentExtractor( model_args=cfg.model, data_args=cfg.data, use_index_files=True, ) + print("Latent extractor initialized") dataset_with_latents = latent_extractor(dataset) @@ -53,14 +72,15 @@ def main(): f"{cfg.data.image_latent_column}_{cfg.data.image_sizes[1]}", f"{cfg.data.image_latent_shape_column}_{cfg.data.image_sizes[1]}", ] - + + print("Extracting latents ...") # Metadata will have a new column named "image_latent" dataset_with_latents.save_metadata( cfg.data.output_path, columns=[ - cfg.data.id_col, "_id", cfg.data.caption_column, *latent_columns + cfg.data.id_col, "_id", cfg.data.caption_column, cfg.data.valid_column, *latent_columns ] ) - + print("Latents extracted") if __name__ == "__main__": main() diff --git a/apps/offline_inf_v2/pipeline.py b/apps/offline_inf_v2/pipeline.py new file mode 100644 index 00000000..d5c73d21 --- /dev/null +++ b/apps/offline_inf_v2/pipeline.py @@ -0,0 +1,81 @@ +from io import BytesIO +import numpy as np +from pathlib import Path +import nvidia.dali.fn as fn 
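# A short map of the new pipeline module, as read from the definitions below
# (descriptive only): fn.readers.webdataset pulls (jpg, txt, json) records,
# decode_image_raw decodes on the CPU and emits a validity flag instead of
# raising on corrupt images, the gen branch is resized to image_sizes[1],
# center-cropped to the size chosen by var_center_crop_size_fn, randomly
# mirrored and normalized to [-1, 1], and the plan branch is a further resize
# of that result down to image_sizes[0].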
+import nvidia.dali.types as types +from nvidia.dali.pipeline import pipeline_def +from PIL import Image + +from apps.main.modules.preprocess import var_center_crop_size_fn +from apps.offline_inf_v2.data import DataArgs + + +def decode_image_raw(image_raw): + try: + image = Image.open(BytesIO(image_raw)) + return np.asarray(image, dtype=np.uint8), np.ones(1, dtype=np.uint8) + except Exception as e: + print(f"Error decoding image: {e}") + return np.zeros((1, 1, 3), dtype=np.uint8), np.zeros(1, dtype=np.uint8) + + +@pipeline_def +def image_loading_pipeline(_tar_path: str, use_index_files: bool, data_args: DataArgs): + if use_index_files: + index_paths = [str(Path(_tar_path).with_suffix(".idx"))] + else: + index_paths = [] + + images_raw, text, json = fn.readers.webdataset( + paths=_tar_path, + index_paths=index_paths, + ext=["jpg", "txt", "json"], + missing_component_behavior="error", + ) + + # images_gen = fn.experimental.decoders.image(images_raw, device="cpu", output_type=types.RGB) + images_gen, valid = fn.python_function(images_raw, function=decode_image_raw, device="cpu", num_outputs=2) + + images_gen = fn.resize( + images_gen.gpu(), + device="gpu", + resize_x=data_args.image_sizes[1], + resize_y=data_args.image_sizes[1], + mode="not_smaller", + interp_type=types.DALIInterpType.INTERP_CUBIC + ) + + # get the dynamic crop size + crop_size = fn.python_function( + images_gen.shape(device="cpu"), + data_args.image_sizes[1], + data_args.patch_size, + data_args.dynamic_crop_ratio, + function=var_center_crop_size_fn, + device="cpu" + ) + + images_gen = fn.crop_mirror_normalize( + images_gen, + device="gpu", + crop_h=crop_size[0], + crop_w=crop_size[1], + crop_pos_x=0.5, + crop_pos_y=0.5, + mirror=fn.random.coin_flip(probability=0.5), + dtype=types.DALIDataType.FLOAT, + mean=[0.5 * 255, 0.5 * 255, 0.5 * 255], + std=[0.5 * 255, 0.5 * 255, 0.5 * 255], + scale=1.0, + ) + + images_plan = fn.resize( + images_gen, + device="gpu", + resize_x=data_args.image_sizes[0], + resize_y=data_args.image_sizes[0], + mode="not_smaller", + interp_type=types.DALIInterpType.INTERP_CUBIC + ) + + return images_plan, images_gen, text, json, valid \ No newline at end of file diff --git a/apps/offline_inf_v2/vae_latent_extractor.py b/apps/offline_inf_v2/vae_latent_extractor.py index 9a41ec61..8591a0de 100644 --- a/apps/offline_inf_v2/vae_latent_extractor.py +++ b/apps/offline_inf_v2/vae_latent_extractor.py @@ -19,67 +19,7 @@ from apps.main.modules.preprocess import generate_crop_size_list, var_center_crop_size_fn from apps.offline_inf_v2.data import DataArgs from apps.offline_inf_v2.model import ModelArgs, VAE - - -@pipeline_def -def image_loading_pipeline(_tar_path: str, use_index_files: bool, data_args: DataArgs): - if use_index_files: - index_paths = [str(Path(_tar_path).with_suffix(".idx"))] - else: - index_paths = [] - - images_raw, text, json = fn.readers.webdataset( - paths=_tar_path, - index_paths=index_paths, - ext=["jpg", "txt", "json"], - missing_component_behavior="skip", - ) - - images_gen = fn.experimental.decoders.image(images_raw, device="mixed", output_type=types.RGB) - - images_gen = fn.resize( - images_gen, - device="gpu", - resize_x=data_args.image_sizes[1], - resize_y=data_args.image_sizes[1], - mode="not_smaller", - interp_type=types.DALIInterpType.INTERP_CUBIC - ) - - # get the dynamic crop size - crop_size = fn.python_function( - images_gen.shape(device="cpu"), - data_args.image_sizes[1], - data_args.patch_size, - data_args.dynamic_crop_ratio, - function=var_center_crop_size_fn, - device="cpu" - 
) - - images_gen = fn.crop_mirror_normalize( - images_gen, - device="gpu", - crop_h=crop_size[0], - crop_w=crop_size[1], - crop_pos_x=0.5, - crop_pos_y=0.5, - mirror=fn.random.coin_flip(probability=0.5), - dtype=types.DALIDataType.FLOAT, - mean=[0.5 * 255, 0.5 * 255, 0.5 * 255], - std=[0.5 * 255, 0.5 * 255, 0.5 * 255], - scale=1.0, - ) - - images_plan = fn.resize( - images_gen, - device="gpu", - resize_x=data_args.image_sizes[0], - resize_y=data_args.image_sizes[0], - mode="not_smaller", - interp_type=types.DALIInterpType.INTERP_CUBIC - ) - - return images_plan, images_gen, text, json +from apps.offline_inf_v2.pipeline import image_loading_pipeline class VAELatentExtractor: @@ -127,7 +67,7 @@ def load_dataset_shard(self, tar_path: str): "num_threads": self.data_args.num_threads_per_worker, "device_id": 0, "exec_dynamic": True, - "prefetch_queue_depth": 8, + "prefetch_queue_depth": 4, } if self.data_args.enable_checkpointing: @@ -160,12 +100,12 @@ def load_dataset_shard(self, tar_path: str): bucket_img_gen = defaultdict(list) bucket_text = defaultdict(list) bucket_meta = defaultdict(list) - + bucket_valid = defaultdict(list) # pbar = tqdm(total=total_samples, desc="Loading dataset shard") samples_completed = 0 while samples_completed < total_samples: - image_plan, image_gen, text, meta = image_dataset.run() + image_plan, image_gen, text, meta, valid = image_dataset.run() for i in range(self.data_args.batch_size): img_plan, img_gen = image_plan[i], image_gen[i] _, w, h = img_plan.shape() @@ -174,7 +114,7 @@ def load_dataset_shard(self, tar_path: str): bucket_img_gen[crop_size].append(img_gen) bucket_text[crop_size].append(text.at(i).tostring().decode("utf-8")) bucket_meta[crop_size].append(json.loads(meta.at(i).tostring().decode("utf-8"))) - + bucket_valid[crop_size].append(valid.at(i).tolist()) # if batch size is reached, yield the batch if len(bucket_img_plan[crop_size]) == self.data_args.batch_size: image_plan_batch = TensorListGPU(bucket_img_plan[crop_size]).as_tensor() @@ -183,12 +123,12 @@ def load_dataset_shard(self, tar_path: str): feed_ndarray(image_plan_batch, image_plan_torch) # COPY !!! image_gen_torch = torch.empty(image_gen_batch.shape(), dtype=torch.float32, device="cuda") feed_ndarray(image_gen_batch, image_gen_torch) # COPY !!! - yield image_plan_torch, image_gen_torch, bucket_text[crop_size], bucket_meta[crop_size] + yield image_plan_torch, image_gen_torch, bucket_text[crop_size], bucket_meta[crop_size], bucket_valid[crop_size] bucket_img_plan[crop_size] = [] bucket_img_gen[crop_size] = [] bucket_text[crop_size] = [] bucket_meta[crop_size] = [] - + bucket_valid[crop_size] = [] samples_completed += self.data_args.batch_size # if samples_completed % (100 * self.data_args.batch_size) == 0 and self.data_args.enable_checkpointing is not None: @@ -205,7 +145,7 @@ def load_dataset_shard(self, tar_path: str): feed_ndarray(image_plan_batch, image_plan_torch) # COPY !!! image_gen_torch = torch.empty(image_gen_batch.shape(), dtype=torch.float32, device="cuda") feed_ndarray(image_gen_batch, image_gen_torch) # COPY !!! 
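# Note on the two feed_ndarray calls above (assumed to be
# nvidia.dali.plugin.pytorch.feed_ndarray): they copy the DALI-owned GPU
# buffers into freshly allocated torch tensors, since DALI reuses its output
# buffers on the next pipeline.run(). Batches are also accumulated per
# crop_size bucket, so each yielded batch has a uniform spatial shape.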
- yield image_plan_torch, image_gen_torch, bucket_text[crop_size], bucket_meta[crop_size] + yield image_plan_torch, image_gen_torch, bucket_text[crop_size], bucket_meta[crop_size], bucket_valid[crop_size] # checkpoint final state if self.data_args.enable_checkpointing: @@ -245,9 +185,9 @@ def custom_forward(*args, **kwargs): def _process_batch(self, model, batch): """Helper method to process a batch with appropriate chunking""" - if batch.shape[0] > 16 and batch.shape[0] % 16 == 0: - # Process in chunks of 16 to avoid OOM - sub_batches = batch.chunk(batch.shape[0] // 16) + if batch.shape[0] > self.data_args.mini_batch_size and batch.shape[0] % self.data_args.mini_batch_size == 0: + # Process in chunks of mini_batch_size to avoid OOM + sub_batches = batch.chunk(batch.shape[0] // self.data_args.mini_batch_size) sub_latents = [] for sub_batch in sub_batches: sub_latent = model(sub_batch) @@ -283,7 +223,7 @@ def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: for col in latent_columns: meta[col] = "object" meta[self.data_args.caption_column] = "object" - + meta[self.data_args.valid_column] = "object" embedding_df = dataset.metadata.map_partitions( self._run_inference, dataset.tar_files, dataset.id_col, meta=meta ) @@ -305,11 +245,12 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): {"model_args": self.model_args, "device": device}, ) - dataset = self.load_dataset_shard(tar_path) + dataset = self.load_dataset_shard2(tar_path) final_image_plan_latents = [] final_image_gen_latents = [] final_text_batch = [] image_ids = [] + is_valid = [] samples_completed = 0 expected_samples = len(partition) progress_bar = tqdm( @@ -319,7 +260,7 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): # Process batches with torch.no_grad(), torch.amp.autocast(device_type="cuda", enabled=self.model_args.autocast): - for image_plan_batch, image_gen_batch, text_batch, metadata in dataset: + for image_plan_batch, image_gen_batch, text_batch, metadata, valid in dataset: # Only process as many samples as we expect from metadata if samples_completed >= expected_samples: break @@ -334,6 +275,7 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): image_gen_batch = image_gen_batch[:actual_batch_size] text_batch = text_batch[:actual_batch_size] metadata = metadata[:actual_batch_size] + valid = valid[:actual_batch_size] plan_latents, gen_latents = self._process_batch_pollux(model, image_plan_batch, image_gen_batch) del image_plan_batch @@ -343,7 +285,7 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): final_image_gen_latents.append(gen_latents) final_text_batch.extend(text_batch) image_ids.extend(m[id_col] for m in metadata) - + is_valid.extend(valid) samples_completed += actual_batch_size progress_bar.update(actual_batch_size) @@ -362,6 +304,7 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): partition = self._process_embeddings(partition, final_image_plan_latents, image_ids, image_size=self.data_args.image_sizes[0]) partition = self._process_embeddings(partition, final_image_gen_latents, image_ids, image_size=self.data_args.image_sizes[1]) partition = self._process_captions(partition, final_text_batch, image_ids) + partition = self._process_valid(partition, is_valid, image_ids) return partition def _process_captions(self, partition, final_text_batch, image_ids): @@ -369,6 +312,12 @@ def _process_captions(self, partition, final_text_batch, image_ids): 
sorted_text_batch = [final_text_batch[i] for i in sorted_indices] partition[self.data_args.caption_column] = sorted_text_batch return partition + + def _process_valid(self, partition, is_valid, image_ids): + sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + sorted_valid = [is_valid[i] for i in sorted_indices] + partition[self.data_args.valid_column] = sorted_valid + return partition def _process_embeddings(self, partition, final_image_latents, image_ids, image_size): """Process embeddings in a memory-efficient way""" @@ -515,7 +464,7 @@ def load_dataset_shard2(self, tar_path: str): samples_completed = 0 while samples_completed < total_samples: - image_plan, image_gen, text, meta = image_dataset.run() + image_plan, image_gen, text, meta, valid = image_dataset.run() image_plan = image_plan.as_tensor() image_plan_torch = torch.empty(image_plan.shape(), dtype=torch.float32, device="cuda") @@ -532,7 +481,7 @@ def load_dataset_shard2(self, tar_path: str): json.loads(meta.at(i).tostring().decode("utf-8")) for i in range(len(meta)) ] - + valid = [valid.at(i).tolist() for i in range(len(valid))] remaining_samples = total_samples - samples_completed if image_plan.shape[0] >= remaining_samples: image_plan = image_plan[:remaining_samples] @@ -544,7 +493,7 @@ def load_dataset_shard2(self, tar_path: str): samples_completed += min(image_plan.shape[0], remaining_samples) - yield image_plan, image_gen, captions, metadata + yield image_plan, image_gen, captions, metadata, valid # checkpoint final state if self.data_args.enable_checkpointing: From f91e7965ac96191054ccae31fef7d5d358f6e537 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Mon, 17 Mar 2025 06:44:19 +0000 Subject: [PATCH 8/8] RGB conversion --- apps/offline_inf_v2/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/offline_inf_v2/pipeline.py b/apps/offline_inf_v2/pipeline.py index d5c73d21..31a95184 100644 --- a/apps/offline_inf_v2/pipeline.py +++ b/apps/offline_inf_v2/pipeline.py @@ -12,7 +12,7 @@ def decode_image_raw(image_raw): try: - image = Image.open(BytesIO(image_raw)) + image = Image.open(BytesIO(image_raw)).convert('RGB') return np.asarray(image, dtype=np.uint8), np.ones(1, dtype=np.uint8) except Exception as e: print(f"Error decoding image: {e}")
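
For orientation only, a minimal sketch of how the metadata written by
save_metadata() could be read back downstream. It assumes the column names
from the inference.yaml in this patch (image_latent_256 / image_latent_shape_256,
caption, valid), that latents were flattened to 1-D with their original shapes
stored in the parallel *_shape_* column, and that the output directory is the
one configured above; none of this code is part of the patch.

    import numpy as np
    import pandas as pd

    LATENT_COL = "image_latent_256"
    SHAPE_COL = "image_latent_shape_256"

    # save_metadata() writes a directory of parquet files; pandas/pyarrow can
    # read the whole directory as one frame.
    df = pd.read_parquet("/mnt/pollux/nemo/data/bucket-256-1-0-latents")

    # Drop records whose image failed to decode (flag set by decode_image_raw).
    df = df[df["valid"].apply(lambda v: bool(np.asarray(v).ravel()[0]))]

    def to_latent(row):
        # Latents were flattened to 1-D before writing; the parallel shape
        # column restores the original [C, H, W] layout.
        flat = np.asarray(row[LATENT_COL], dtype=np.float32)
        return flat.reshape(tuple(int(s) for s in row[SHAPE_COL]))

    sample = df.head(4)
    print([to_latent(r).shape for _, r in sample.iterrows()])
    print(sample["caption"].tolist())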