diff --git a/apps/main/modules/preprocess.py b/apps/main/modules/preprocess.py index 42e5362c..223aedb9 100644 --- a/apps/main/modules/preprocess.py +++ b/apps/main/modules/preprocess.py @@ -166,6 +166,58 @@ def center_crop_arr(pil_image, image_size): ) +def generate_crop_size_list(image_size, patch_size, max_ratio=2.0): + assert max_ratio >= 1.0 + min_wp, min_hp = image_size // patch_size, image_size // patch_size + crop_size_list = [] + wp, hp = min_wp, min_hp + while hp / wp <= max_ratio: + crop_size_list.append((wp * patch_size, hp * patch_size)) + hp += 1 + wp, hp = min_wp + 1, min_hp + while wp / hp <= max_ratio: + crop_size_list.append((wp * patch_size, hp * patch_size)) + wp += 1 + return crop_size_list + + +def is_valid_crop_size(cw, ch, orig_w, orig_h): + down_scale = max(cw / orig_w, ch / orig_h) + return cw <= orig_w * down_scale and ch <= orig_h * down_scale + + +def var_center_crop_size_fn(orig_img_shape, image_size=256, patch_size=16, max_ratio=2.0): + w, h, _ = orig_img_shape + crop_size_list = generate_crop_size_list( + image_size=image_size, + patch_size=patch_size, + max_ratio=max_ratio + ) + rem_percent = [ + min(cw / w, ch / h) / max(cw / w, ch / h) + if is_valid_crop_size(cw, ch, w, h) else 0 + for cw, ch in crop_size_list + ] + crop_size = sorted( + ((x, y) for x, y in zip(rem_percent, crop_size_list) if x > 0 and y[0] <= w and y[1] <= h), + reverse=True + )[0][1] + return np.array(crop_size, dtype=np.float32) + + +def downsample_resize_image(image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + pil_image = Image.fromarray(image) + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + return pil_image + + ######################## FOR TEXT ######################## diff --git a/apps/offline_inf_v2/configs/inference.yaml b/apps/offline_inf_v2/configs/inference.yaml new file mode 100644 index 00000000..f34c93b1 --- /dev/null +++ b/apps/offline_inf_v2/configs/inference.yaml @@ -0,0 +1,25 @@ +name: inference + +model: + gen_vae: + model_name: Hunyuan + pretrained_model_name_or_path: '/jfs/checkpoints/models--tencent--HunyuanVideo/snapshots/2a15b5574ee77888e51ae6f593b2ceed8ce813e5/vae' + enable_tiling: false + enable_slicing: false + autocast: true + +data: + data_path: /mnt/pollux/nemo/data/bucket-256-1-0/ + output_path: /mnt/pollux/nemo/data/bucket-256-1-0-latents + enable_checkpointing: true + id_col: key + batch_size: 64 + mini_batch_size: 16 + num_threads_per_worker: 16 + image_sizes: [256, 512] + patch_size: 16 + dynamic_crop_ratio: 1.0 + image_latent_column: image_latent + image_latent_shape_column: image_latent_shape + caption_column: caption + valid_column: valid diff --git a/apps/offline_inf_v2/data.py b/apps/offline_inf_v2/data.py new file mode 100644 index 00000000..1fbc3030 --- /dev/null +++ b/apps/offline_inf_v2/data.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass, field +from typing import List, Optional + +@dataclass +class DataArgs: + data_path: str = field(default="/mnt/pollux/nemo/sample/") + output_path: Optional[str] = field(default=None) + enable_checkpointing: Optional[str] = field(default=None) + id_col: str = field(default="key") + batch_size: int = field(default=1) + mini_batch_size: int = field(default=1) + num_threads_per_worker: int = field(default=4) + image_sizes: List[int] = 
field(default_factory=lambda: [256, 512])
+    patch_size: int = field(default=16)
+    dynamic_crop_ratio: float = field(default=1.0)
+    image_latent_column: str = field(default="image_latent")
+    image_latent_shape_column: str = field(default="image_latent_shape")
+    caption_column: str = field(default="caption")
+    valid_column: str = field(default="valid")
diff --git a/apps/offline_inf_v2/inference.py b/apps/offline_inf_v2/inference.py
new file mode 100644
index 00000000..99a522a9
--- /dev/null
+++ b/apps/offline_inf_v2/inference.py
@@ -0,0 +1,86 @@
+"""
+conda activate curator
+python -m apps.offline_inf_v2.inference config=apps/offline_inf_v2/configs/inference.yaml
+"""
+
+from dataclasses import dataclass, field
+from omegaconf import OmegaConf
+
+import dask_cudf
+from nemo_curator import get_client
+from nemo_curator.datasets import ImageTextPairDataset
+
+from apps.offline_inf_v2.data import DataArgs
+from apps.offline_inf_v2.model import ModelArgs
+from apps.offline_inf_v2.vae_latent_extractor import VAELatentExtractor
+
+
+@dataclass
+class InferenceArgs:
+    name: str = field(default="inference")
+    model: ModelArgs = field(default_factory=ModelArgs)
+    data: DataArgs = field(default_factory=DataArgs)
+
+
+def init_image_text_dataset(cfg: InferenceArgs):
+    metadata = dask_cudf.read_parquet(cfg.data.data_path, split_row_groups=False, blocksize=None)
+    if 'status' in metadata.columns:
+        metadata = metadata[metadata.status != "failed_to_download"]
+    metadata = metadata.map_partitions(ImageTextPairDataset._sort_partition, id_col=cfg.data.id_col)
+    tar_files = ImageTextPairDataset._get_tar_files(cfg.data.data_path)
+    return ImageTextPairDataset(
+        path=cfg.data.data_path,
+        metadata=metadata,
+        tar_files=tar_files,
+        id_col=cfg.data.id_col,
+    )
+
+
+def main():
+    cli_args = OmegaConf.from_cli()
+    file_cfg = OmegaConf.load(cli_args.config)
+    # Remove the 'config' attribute, since the underlying dataclass does not define it
+    del cli_args.config
+
+    default_cfg = OmegaConf.structured(InferenceArgs())
+    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
+    cfg = OmegaConf.to_object(cfg)
+
+    assert cfg.data.output_path is not None, f"Output path is required, otherwise the Parquet files in {cfg.data.data_path} would be overwritten"
+
+    client = get_client(
+        cluster_type="gpu",
+    )
+
+    print("Initiating dataset ...")
+    dataset = init_image_text_dataset(cfg)
+    print("Dataset initialized")
+
+    print("Initializing latent extractor ...")
+    latent_extractor = VAELatentExtractor(
+        model_args=cfg.model,
+        data_args=cfg.data,
+        use_index_files=True,
+    )
+    print("Latent extractor initialized")
+
+    dataset_with_latents = latent_extractor(dataset)
+
+    latent_columns = [
+        f"{cfg.data.image_latent_column}_{cfg.data.image_sizes[0]}",
+        f"{cfg.data.image_latent_shape_column}_{cfg.data.image_sizes[0]}",
+        f"{cfg.data.image_latent_column}_{cfg.data.image_sizes[1]}",
+        f"{cfg.data.image_latent_shape_column}_{cfg.data.image_sizes[1]}",
+    ]
+
+    print("Extracting latents ...")
+    # Metadata gains one latent column and one latent-shape column per image size
+    dataset_with_latents.save_metadata(
+        cfg.data.output_path, columns=[
+            cfg.data.id_col, "_id", cfg.data.caption_column, cfg.data.valid_column, *latent_columns
+        ]
+    )
+    print("Latents extracted")
+
+if __name__ == "__main__":
+    main()
diff --git a/apps/offline_inf_v2/model.py b/apps/offline_inf_v2/model.py
new file mode 100644
index 00000000..f4cd6df5
--- /dev/null
+++ b/apps/offline_inf_v2/model.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+ +from dataclasses import dataclass, field +import logging +import torch +from torch import nn +from apps.main.modules.vae import build_vae, LatentVideoVAEArgs + +logger = logging.getLogger() + + +@dataclass +class ModelArgs: + gen_vae: LatentVideoVAEArgs = field(default_factory=LatentVideoVAEArgs) + autocast: bool = field(default=False) + use_compile: bool = field(default=False) + +class VAE(nn.Module): + """ + VAE Model + """ + + VERSION: str = "v1.0" + + def __init__(self, args: ModelArgs): + super().__init__() + + self.vae_compressor = build_vae(args.gen_vae) + + @torch.no_grad() + def forward( + self, image: torch.Tensor + ) -> torch.Tensor: + + # Process latent code + image = image.cuda() + latent_code = self.vae_compressor.encode(image) + + return latent_code diff --git a/apps/offline_inf_v2/pipeline.py b/apps/offline_inf_v2/pipeline.py new file mode 100644 index 00000000..31a95184 --- /dev/null +++ b/apps/offline_inf_v2/pipeline.py @@ -0,0 +1,81 @@ +from io import BytesIO +import numpy as np +from pathlib import Path +import nvidia.dali.fn as fn +import nvidia.dali.types as types +from nvidia.dali.pipeline import pipeline_def +from PIL import Image + +from apps.main.modules.preprocess import var_center_crop_size_fn +from apps.offline_inf_v2.data import DataArgs + + +def decode_image_raw(image_raw): + try: + image = Image.open(BytesIO(image_raw)).convert('RGB') + return np.asarray(image, dtype=np.uint8), np.ones(1, dtype=np.uint8) + except Exception as e: + print(f"Error decoding image: {e}") + return np.zeros((1, 1, 3), dtype=np.uint8), np.zeros(1, dtype=np.uint8) + + +@pipeline_def +def image_loading_pipeline(_tar_path: str, use_index_files: bool, data_args: DataArgs): + if use_index_files: + index_paths = [str(Path(_tar_path).with_suffix(".idx"))] + else: + index_paths = [] + + images_raw, text, json = fn.readers.webdataset( + paths=_tar_path, + index_paths=index_paths, + ext=["jpg", "txt", "json"], + missing_component_behavior="error", + ) + + # images_gen = fn.experimental.decoders.image(images_raw, device="cpu", output_type=types.RGB) + images_gen, valid = fn.python_function(images_raw, function=decode_image_raw, device="cpu", num_outputs=2) + + images_gen = fn.resize( + images_gen.gpu(), + device="gpu", + resize_x=data_args.image_sizes[1], + resize_y=data_args.image_sizes[1], + mode="not_smaller", + interp_type=types.DALIInterpType.INTERP_CUBIC + ) + + # get the dynamic crop size + crop_size = fn.python_function( + images_gen.shape(device="cpu"), + data_args.image_sizes[1], + data_args.patch_size, + data_args.dynamic_crop_ratio, + function=var_center_crop_size_fn, + device="cpu" + ) + + images_gen = fn.crop_mirror_normalize( + images_gen, + device="gpu", + crop_h=crop_size[0], + crop_w=crop_size[1], + crop_pos_x=0.5, + crop_pos_y=0.5, + mirror=fn.random.coin_flip(probability=0.5), + dtype=types.DALIDataType.FLOAT, + mean=[0.5 * 255, 0.5 * 255, 0.5 * 255], + std=[0.5 * 255, 0.5 * 255, 0.5 * 255], + scale=1.0, + ) + + images_plan = fn.resize( + images_gen, + device="gpu", + resize_x=data_args.image_sizes[0], + resize_y=data_args.image_sizes[0], + mode="not_smaller", + interp_type=types.DALIInterpType.INTERP_CUBIC + ) + + return images_plan, images_gen, text, json, valid \ No newline at end of file diff --git a/apps/offline_inf_v2/vae_latent_extractor.py b/apps/offline_inf_v2/vae_latent_extractor.py new file mode 100644 index 00000000..8591a0de --- /dev/null +++ b/apps/offline_inf_v2/vae_latent_extractor.py @@ -0,0 +1,501 @@ +from collections import defaultdict +import json 
+import os +from pathlib import Path +from tqdm import tqdm +import cupy as cp +from crossfit.backend.cudf.series import create_list_series_from_1d_or_2d_ar + +import nvidia.dali.fn as fn +import nvidia.dali.types as types +import torch +from nvidia.dali import pipeline_def +from nvidia.dali.plugin.pytorch import feed_ndarray +from nvidia.dali.tensors import TensorListGPU + +from nemo_curator.datasets import ImageTextPairDataset +from nemo_curator.utils.distributed_utils import load_object_on_worker + +from apps.main.modules.preprocess import generate_crop_size_list, var_center_crop_size_fn +from apps.offline_inf_v2.data import DataArgs +from apps.offline_inf_v2.model import ModelArgs, VAE +from apps.offline_inf_v2.pipeline import image_loading_pipeline + + +class VAELatentExtractor: + def __init__( + self, + model_args: ModelArgs, + data_args: DataArgs, + use_index_files: bool = False, + ) -> None: + """ + Constructs the embedder. + + Args: + model_args (ModelArgs): The arguments for the VAE. + data_args (DataArgs): The arguments for the data. + use_index_files (bool): If True, tries to find and use index files generated + by DALI at the same path as the tar file shards. The index files must be + generated by DALI's wds2idx tool. See https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/dataloading_webdataset.html#Creating-an-index + for more information. Each index file must be of the form "shard_id.idx" + where shard_id is the same integer as the corresponding tar file for the + data. The index files must be in the same folder as the tar files. + """ + self.model_args = model_args + self.data_args = data_args + self.use_index_files = use_index_files + + def load_dataset_shard(self, tar_path: str): + """ + Loads a WebDataset tar shard using DALI. + + Args: + tar_path (str): The path of the tar shard to load. + + Returns: + Iterable: An iterator over the dataset. Each tar file + must have 3 files per record: a .jpg file, a .txt file, + and a .json file. The .jpg file must contain the image, the + .txt file must contain the associated caption, and the + .json must contain the metadata for the record (including + its ID). Images will be loaded using DALI. 
+ """ + + dali_pipeline_args = { + "batch_size": self.data_args.batch_size, + "num_threads": self.data_args.num_threads_per_worker, + "device_id": 0, + "exec_dynamic": True, + "prefetch_queue_depth": 4, + } + + if self.data_args.enable_checkpointing: + dali_pipeline_args["enable_checkpointing"] = True + checkpoint_path = f"{tar_path.rsplit('.', 1)[0]}.pth" + if os.path.exists(checkpoint_path): + print (f"Restoring checkpoint from {checkpoint_path}") + checkpoint = open(checkpoint_path, 'rb').read() + dali_pipeline_args["checkpoint"] = checkpoint + + image_dataset = image_loading_pipeline( + tar_path, + self.use_index_files, + self.data_args, + **dali_pipeline_args, + ) + + image_dataset.build() + + total_samples = image_dataset.epoch_size() + total_samples = total_samples[list(total_samples.keys())[0]] + + crop_size_list = generate_crop_size_list( + image_size=self.data_args.image_sizes[0], + patch_size=self.data_args.patch_size, + max_ratio=2.0 + ) + + bucket_img_plan = defaultdict(list) + bucket_img_gen = defaultdict(list) + bucket_text = defaultdict(list) + bucket_meta = defaultdict(list) + bucket_valid = defaultdict(list) + # pbar = tqdm(total=total_samples, desc="Loading dataset shard") + + samples_completed = 0 + while samples_completed < total_samples: + image_plan, image_gen, text, meta, valid = image_dataset.run() + for i in range(self.data_args.batch_size): + img_plan, img_gen = image_plan[i], image_gen[i] + _, w, h = img_plan.shape() + crop_size = (w, h) + bucket_img_plan[crop_size].append(img_plan) + bucket_img_gen[crop_size].append(img_gen) + bucket_text[crop_size].append(text.at(i).tostring().decode("utf-8")) + bucket_meta[crop_size].append(json.loads(meta.at(i).tostring().decode("utf-8"))) + bucket_valid[crop_size].append(valid.at(i).tolist()) + # if batch size is reached, yield the batch + if len(bucket_img_plan[crop_size]) == self.data_args.batch_size: + image_plan_batch = TensorListGPU(bucket_img_plan[crop_size]).as_tensor() + image_gen_batch = TensorListGPU(bucket_img_gen[crop_size]).as_tensor() + image_plan_torch = torch.empty(image_plan_batch.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_plan_batch, image_plan_torch) # COPY !!! + image_gen_torch = torch.empty(image_gen_batch.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_gen_batch, image_gen_torch) # COPY !!! + yield image_plan_torch, image_gen_torch, bucket_text[crop_size], bucket_meta[crop_size], bucket_valid[crop_size] + bucket_img_plan[crop_size] = [] + bucket_img_gen[crop_size] = [] + bucket_text[crop_size] = [] + bucket_meta[crop_size] = [] + bucket_valid[crop_size] = [] + samples_completed += self.data_args.batch_size + + # if samples_completed % (100 * self.data_args.batch_size) == 0 and self.data_args.enable_checkpointing is not None: + # pipeline.checkpoint(checkpoint_path) + # pbar.update(self.data_args.batch_size) + # pbar.close() + + for crop_size in crop_size_list: + if not bucket_img_plan[crop_size]: + continue + image_plan_batch = TensorListGPU(bucket_img_plan[crop_size]).as_tensor() + image_gen_batch = TensorListGPU(bucket_img_gen[crop_size]).as_tensor() + image_plan_torch = torch.empty(image_plan_batch.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_plan_batch, image_plan_torch) # COPY !!! + image_gen_torch = torch.empty(image_gen_batch.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_gen_batch, image_gen_torch) # COPY !!! 
+            yield image_plan_torch, image_gen_torch, bucket_text[crop_size], bucket_meta[crop_size], bucket_valid[crop_size]
+
+        # checkpoint final state
+        if self.data_args.enable_checkpointing:
+            image_dataset.checkpoint(checkpoint_path)
+
+    def load_model(self, model_args, device="cuda"):
+        """
+        Loads the VAE used to encode images into latents.
+
+        Args:
+            model_args (ModelArgs): The arguments used to build the VAE.
+            device (str): A PyTorch device identifier that specifies what GPU
+                to load the model on.
+
+        Returns:
+            VAE: A model loaded on the specified device in eval mode.
+                The model's forward call may be wrapped with torch.autocast().
+        """
+        model = VAE(model_args).eval().to(device)
+        model = self._configure_forward(model)
+        return model
+
+    def _configure_forward(self, model):
+        original_forward = model.forward
+
+        def custom_forward(*args, **kwargs):
+            if self.model_args.autocast:
+                with torch.amp.autocast(device_type="cuda"):
+                    image_features = original_forward(*args, **kwargs)
+            else:
+                image_features = original_forward(*args, **kwargs)
+
+            # Inference can be done in lower precision, but cuDF can only handle fp32
+            return image_features.to(torch.float32)
+
+        model.forward = custom_forward
+        return model
+
+    def _process_batch(self, model, batch):
+        """Encodes a batch in chunks of at most mini_batch_size to avoid OOM."""
+        if batch.shape[0] > self.data_args.mini_batch_size:
+            # torch.split also handles batches that are not divisible by mini_batch_size
+            sub_batches = batch.split(self.data_args.mini_batch_size)
+            sub_latents = []
+            for sub_batch in sub_batches:
+                sub_latent = model(sub_batch)
+                sub_latents.append(sub_latent.cpu())
+            return torch.cat(sub_latents, dim=0)
+        else:
+            latents = model(batch)
+            return latents.cpu()  # Move to CPU immediately
+
+    def _process_batch_pollux(self, model, image_plan_batch, image_gen_batch):
+        plan_latents = self._process_batch(model, image_plan_batch)
+        gen_latents = self._process_batch(model, image_gen_batch)
+        return plan_latents, gen_latents
+
+    def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset:
+        """
+        Generates VAE latents for all images in the dataset.
+
+        Args:
+            dataset (ImageTextPairDataset): The dataset to extract latents for.
+
+        Returns:
+            ImageTextPairDataset: A dataset whose metadata includes the flattened image
+                latents and their shapes for each configured image size.
+ """ + meta = dataset.metadata.dtypes.to_dict() + latent_columns = [ + f"{self.data_args.image_latent_column}_{self.data_args.image_sizes[0]}", + f"{self.data_args.image_latent_shape_column}_{self.data_args.image_sizes[0]}", + f"{self.data_args.image_latent_column}_{self.data_args.image_sizes[1]}", + f"{self.data_args.image_latent_shape_column}_{self.data_args.image_sizes[1]}", + ] + for col in latent_columns: + meta[col] = "object" + meta[self.data_args.caption_column] = "object" + meta[self.data_args.valid_column] = "object" + embedding_df = dataset.metadata.map_partitions( + self._run_inference, dataset.tar_files, dataset.id_col, meta=meta + ) + + return ImageTextPairDataset( + dataset.path, + metadata=embedding_df, + tar_files=dataset.tar_files, + id_col=dataset.id_col, + ) + + def _run_inference(self, partition, tar_paths, id_col, partition_info=None): + tar_path = tar_paths[partition_info["number"]] + device = "cuda" + + model = load_object_on_worker( + "model", + self.load_model, + {"model_args": self.model_args, "device": device}, + ) + + dataset = self.load_dataset_shard2(tar_path) + final_image_plan_latents = [] + final_image_gen_latents = [] + final_text_batch = [] + image_ids = [] + is_valid = [] + samples_completed = 0 + expected_samples = len(partition) + progress_bar = tqdm( + total=expected_samples, + desc=f"{tar_path} - Latent extraction with {self.model_args.gen_vae.model_name}", + ) + + # Process batches + with torch.no_grad(), torch.amp.autocast(device_type="cuda", enabled=self.model_args.autocast): + for image_plan_batch, image_gen_batch, text_batch, metadata, valid in dataset: + # Only process as many samples as we expect from metadata + if samples_completed >= expected_samples: + break + + # Calculate how many samples we should process from this batch + remaining = expected_samples - samples_completed + actual_batch_size = min(len(metadata), remaining) + + if actual_batch_size < len(metadata): + # Truncate batch and metadata if needed + image_plan_batch = image_plan_batch[:actual_batch_size] + image_gen_batch = image_gen_batch[:actual_batch_size] + text_batch = text_batch[:actual_batch_size] + metadata = metadata[:actual_batch_size] + valid = valid[:actual_batch_size] + + plan_latents, gen_latents = self._process_batch_pollux(model, image_plan_batch, image_gen_batch) + del image_plan_batch + del image_gen_batch + + final_image_plan_latents.append(plan_latents) + final_image_gen_latents.append(gen_latents) + final_text_batch.extend(text_batch) + image_ids.extend(m[id_col] for m in metadata) + is_valid.extend(valid) + samples_completed += actual_batch_size + progress_bar.update(actual_batch_size) + + # Clear CUDA cache frequently + if samples_completed % (self.data_args.batch_size * 5) == 0: + torch.cuda.empty_cache() + progress_bar.close() + + if samples_completed != len(partition): + raise RuntimeError( + f"Mismatch in sample count for partition {partition_info['number']}. " + f"{len(partition)} samples found in the metadata, but {samples_completed} found in {tar_path}." 
+ ) + + # Process embeddings in memory-efficient way + partition = self._process_embeddings(partition, final_image_plan_latents, image_ids, image_size=self.data_args.image_sizes[0]) + partition = self._process_embeddings(partition, final_image_gen_latents, image_ids, image_size=self.data_args.image_sizes[1]) + partition = self._process_captions(partition, final_text_batch, image_ids) + partition = self._process_valid(partition, is_valid, image_ids) + return partition + + def _process_captions(self, partition, final_text_batch, image_ids): + sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + sorted_text_batch = [final_text_batch[i] for i in sorted_indices] + partition[self.data_args.caption_column] = sorted_text_batch + return partition + + def _process_valid(self, partition, is_valid, image_ids): + sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + sorted_valid = [is_valid[i] for i in sorted_indices] + partition[self.data_args.valid_column] = sorted_valid + return partition + + def _process_embeddings(self, partition, final_image_latents, image_ids, image_size): + """Process embeddings in a memory-efficient way""" + + ''' + # Check if we need to handle variable-sized latents + if self.data_args.dynamic_crop_ratio > 1.0: + # Handle variable-sized latents + return self._process_variable_sized_embeddings(partition, final_image_latents, image_ids) + else: + ''' + + # Order the output of the shard + sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + # Process fixed-size latents as before + # Process in chunks to reduce memory usage + all_embeddings = torch.cat(final_image_latents, dim=0) + sorted_embeddings = all_embeddings[sorted_indices] + + embedding_shape_list = [emb.shape for emb in sorted_embeddings] + + # View the embeddings to be [N, 16*32*32] + sorted_embeddings = sorted_embeddings.view(sorted_embeddings.shape[0], -1) + + # Process in chunks to avoid OOM + chunk_size = 1000 # Adjust based on your GPU memory + concat_embedding_output = None + for i in range(0, sorted_embeddings.shape[0], chunk_size): + end_idx = min(i + chunk_size, sorted_embeddings.shape[0]) + chunk = sorted_embeddings[i:end_idx].cuda() # Move chunk to GPU + chunk_cp = cp.asarray(chunk) # Convert to CuPy + + if concat_embedding_output is None: + concat_embedding_output = chunk_cp + else: + concat_embedding_output = cp.concatenate([concat_embedding_output, chunk_cp], axis=0) + + # Free GPU memory + del chunk + torch.cuda.empty_cache() + + partition[f"{self.data_args.image_latent_column}_{image_size}"] = create_list_series_from_1d_or_2d_ar( + concat_embedding_output, index=partition.index + ) + + # Convert embedding_shape_list to a CuPy array before passing it + embedding_shape_array = cp.array(embedding_shape_list) + partition[f"{self.data_args.image_latent_shape_column}_{image_size}"] = create_list_series_from_1d_or_2d_ar( + embedding_shape_array, index=partition.index + ) + + del concat_embedding_output + del final_image_latents + del embedding_shape_array # Also clean up the new array + torch.cuda.empty_cache() + + return partition + + # def _process_variable_sized_embeddings(self, partition, final_image_latents, image_ids): + # """Process embeddings with variable sizes due to dynamic cropping""" + # # Order the output of the shard + # sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + + # # Flatten our list of tensors + # flat_embeddings = [] + # current_idx = 0 + + # for batch_tensors in final_image_latents: + # 
batch_size = batch_tensors.shape[0] + # for i in range(batch_size): + # flat_embeddings.append(batch_tensors[i]) + # current_idx += 1 + + # # Create sorted embeddings list + # sorted_embeddings = [flat_embeddings[idx] for idx in sorted_indices] + # del flat_embeddings + + # # Process each embedding individually and create a list of numpy arrays + # embedding_list = [] + # embedding_shape_list = [] + # for emb in sorted_embeddings: + # # Flatten the embedding to 1D + # embedding_shape_list.append(emb.shape) + # flat_emb = emb.view(-1).numpy() # Convert to numpy array + # embedding_list.append(flat_emb) + + # # Assign to partition + # partition[self.data_args.image_latent_column] = embedding_list + # partition[self.data_args.image_latent_shape_column] = embedding_shape_list + + # del embedding_list + # del sorted_embeddings + # del final_image_latents + # torch.cuda.empty_cache() + + # return partition + + + def load_dataset_shard2(self, tar_path: str): + """ + Loads a WebDataset tar shard using DALI. + + Args: + tar_path (str): The path of the tar shard to load. + + Returns: + Iterable: An iterator over the dataset. Each tar file + must have 3 files per record: a .jpg file, a .txt file, + and a .json file. The .jpg file must contain the image, the + .txt file must contain the associated caption, and the + .json must contain the metadata for the record (including + its ID). Images will be loaded using DALI. + """ + + dali_pipeline_args = { + "batch_size": self.data_args.batch_size, + "num_threads": self.data_args.num_threads_per_worker, + "device_id": 0, + "exec_dynamic": True, + "prefetch_queue_depth": 8, + } + + if self.data_args.enable_checkpointing: + dali_pipeline_args["enable_checkpointing"] = True + checkpoint_path = f"{tar_path.rsplit('.', 1)[0]}.pth" + if os.path.exists(checkpoint_path): + print (f"Restoring checkpoint from {checkpoint_path}") + checkpoint = open(checkpoint_path, 'rb').read() + dali_pipeline_args["checkpoint"] = checkpoint + + image_dataset = image_loading_pipeline( + tar_path, + self.use_index_files, + self.data_args, + **dali_pipeline_args, + ) + + image_dataset.build() + + total_samples = image_dataset.epoch_size() + total_samples = total_samples[list(total_samples.keys())[0]] + + # pbar = tqdm(total=total_samples, desc="Loading dataset shard") + + samples_completed = 0 + while samples_completed < total_samples: + image_plan, image_gen, text, meta, valid = image_dataset.run() + + image_plan = image_plan.as_tensor() + image_plan_torch = torch.empty(image_plan.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_plan, image_plan_torch) # COPY !!! + image_plan = image_plan_torch + + image_gen = image_gen.as_tensor() + image_gen_torch = torch.empty(image_gen.shape(), dtype=torch.float32, device="cuda") + feed_ndarray(image_gen, image_gen_torch) # COPY !!! 
+            image_gen = image_gen_torch
+
+            captions = [text.at(i).tostring().decode("utf-8") for i in range(len(text))]
+            metadata = [
+                json.loads(meta.at(i).tostring().decode("utf-8"))
+                for i in range(len(meta))
+            ]
+            valid = [valid.at(i).tolist() for i in range(len(valid))]
+            remaining_samples = total_samples - samples_completed
+            if image_plan.shape[0] >= remaining_samples:
+                image_plan = image_plan[:remaining_samples]
+                image_gen = image_gen[:remaining_samples]
+                captions = captions[:remaining_samples]
+                metadata = metadata[:remaining_samples]
+                valid = valid[:remaining_samples]
+
+            samples_completed += image_plan.shape[0]
+
+            yield image_plan, image_gen, captions, metadata, valid
+
+        # checkpoint final state
+        if self.data_args.enable_checkpointing:
+            image_dataset.checkpoint(checkpoint_path)
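
`generate_crop_size_list` enumerates every patch-aligned resolution whose aspect ratio stays within `max_ratio`, and `load_dataset_shard` uses that list as its bucket keys. Below is a minimal sketch of what the list looks like for the values used here (`image_size=256`, `patch_size=16`, and the hard-coded `max_ratio=2.0`); the function body is copied verbatim only so the snippet runs on its own.

```python
# Standalone copy of generate_crop_size_list from apps/main/modules/preprocess.py,
# reproduced here so the sketch runs without the repo on the path.
def generate_crop_size_list(image_size, patch_size, max_ratio=2.0):
    assert max_ratio >= 1.0
    min_wp, min_hp = image_size // patch_size, image_size // patch_size
    crop_size_list = []
    wp, hp = min_wp, min_hp
    while hp / wp <= max_ratio:
        crop_size_list.append((wp * patch_size, hp * patch_size))
        hp += 1
    wp, hp = min_wp + 1, min_hp
    while wp / hp <= max_ratio:
        crop_size_list.append((wp * patch_size, hp * patch_size))
        wp += 1
    return crop_size_list


# Values from configs/inference.yaml: image_sizes[0]=256, patch_size=16.
buckets = generate_crop_size_list(image_size=256, patch_size=16, max_ratio=2.0)
print(len(buckets))   # 33 buckets
print(buckets[0])     # (256, 256) -- the square bucket
print(buckets[16])    # (256, 512) -- tallest bucket allowed at max_ratio=2.0
print(buckets[-1])    # (512, 256) -- widest bucket allowed at max_ratio=2.0
```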
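`load_dataset_shard` groups decoded samples by their post-crop size so that each yielded batch stacks into a single dense tensor. The following framework-free sketch shows the same bucketing idea with random tensors standing in for DALI outputs; the helper name and the sample shapes are illustrative.

```python
# A minimal sketch of the size-bucketing used in load_dataset_shard: samples are
# grouped by their (H, W) after the dynamic crop so every yielded batch has one shape.
from collections import defaultdict

import torch


def bucketed_batches(samples, batch_size):
    buckets = defaultdict(list)
    for img in samples:
        key = tuple(img.shape[-2:])      # (H, W) decides the bucket
        buckets[key].append(img)
        if len(buckets[key]) == batch_size:
            yield torch.stack(buckets.pop(key))
    for imgs in buckets.values():        # flush partially filled buckets at the end
        yield torch.stack(imgs)


samples = [torch.rand(3, 256, 256) for _ in range(5)] + [torch.rand(3, 256, 320) for _ in range(3)]
for batch in bucketed_batches(samples, batch_size=4):
    print(batch.shape)
```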
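`inference.py` constructs the extractor with `use_index_files=True`, so `pipeline.py` expects a DALI index file next to every tar shard (`Path(tar).with_suffix(".idx")`). A hedged sketch for generating those files up front, assuming the `wds2idx` utility that ships with DALI is on `PATH`; the data path is the one from `configs/inference.yaml`.

```python
# A minimal sketch for generating DALI webdataset index files next to each tar shard.
# Assumes DALI's wds2idx tool is installed and on PATH; adjust the path as needed.
import subprocess
from pathlib import Path

data_path = Path("/mnt/pollux/nemo/data/bucket-256-1-0/")  # data.data_path from inference.yaml

for tar_file in sorted(data_path.glob("*.tar")):
    idx_file = tar_file.with_suffix(".idx")  # matches Path(_tar_path).with_suffix(".idx") in pipeline.py
    if not idx_file.exists():
        subprocess.run(["wds2idx", str(tar_file), str(idx_file)], check=True)
```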
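`main()` builds its configuration by merging the structured defaults, the YAML file, and any remaining CLI dot-list arguments, so later sources override earlier ones. A minimal sketch of that precedence with a trimmed-down `DataArgs`; the override values are made up for illustration.

```python
# A minimal sketch of the config precedence used in main(): structured defaults,
# then the YAML file, then CLI dot-list overrides (later sources win).
from dataclasses import dataclass, field
from typing import List

from omegaconf import OmegaConf


@dataclass
class DataArgs:
    batch_size: int = 1
    image_sizes: List[int] = field(default_factory=lambda: [256, 512])


default_cfg = OmegaConf.structured(DataArgs())
file_cfg = OmegaConf.create({"batch_size": 64})        # stands in for configs/inference.yaml
cli_cfg = OmegaConf.from_dotlist(["batch_size=16"])    # stands in for `batch_size=16` on the command line

cfg = OmegaConf.merge(default_cfg, file_cfg, cli_cfg)
print(cfg.batch_size)    # 16 -- the CLI override wins
print(cfg.image_sizes)   # [256, 512] -- untouched default
```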
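`save_metadata` writes one flattened latent column plus one shape column per image size, so consumers have to reshape the latents themselves. A sketch of reading a row back with pandas and restoring the latent's original shape; the Parquet file name is illustrative, and the `(16, 32, 32)` shape assumes the Hunyuan VAE on a 256x256 crop.

```python
# A minimal sketch for loading the flattened latents back out of the saved metadata
# and restoring their original shape. Column names follow configs/inference.yaml;
# the file name below is illustrative.
import numpy as np
import pandas as pd

df = pd.read_parquet("/mnt/pollux/nemo/data/bucket-256-1-0-latents/part.0.parquet")

row = df.iloc[0]
flat = np.asarray(row["image_latent_256"], dtype=np.float32)
shape = tuple(int(s) for s in row["image_latent_shape_256"])  # e.g. (16, 32, 32) for a 256x256 crop

latent = flat.reshape(shape)
print(latent.shape)
```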