52 changes: 52 additions & 0 deletions apps/main/modules/preprocess.py
@@ -166,6 +166,58 @@ def center_crop_arr(pil_image, image_size):
)


def generate_crop_size_list(image_size, patch_size, max_ratio=2.0):
    assert max_ratio >= 1.0
    min_wp, min_hp = image_size // patch_size, image_size // patch_size
    crop_size_list = []
    # Fix width at the minimum and grow height up to the max aspect ratio ...
    wp, hp = min_wp, min_hp
    while hp / wp <= max_ratio:
        crop_size_list.append((wp * patch_size, hp * patch_size))
        hp += 1
    # ... then fix height and grow width (the square case was already added above).
    wp, hp = min_wp + 1, min_hp
    while wp / hp <= max_ratio:
        crop_size_list.append((wp * patch_size, hp * patch_size))
        wp += 1
    return crop_size_list


def is_valid_crop_size(cw, ch, orig_w, orig_h):
    # Smallest scale at which the original image covers the crop window.
    down_scale = max(cw / orig_w, ch / orig_h)
    return cw <= orig_w * down_scale and ch <= orig_h * down_scale


def var_center_crop_size_fn(orig_img_shape, image_size=256, patch_size=16, max_ratio=2.0):
    w, h, _ = orig_img_shape
    crop_size_list = generate_crop_size_list(
        image_size=image_size,
        patch_size=patch_size,
        max_ratio=max_ratio,
    )
    # Aspect-ratio match score for each candidate crop (0 if the crop is invalid).
    rem_percent = [
        min(cw / w, ch / h) / max(cw / w, ch / h)
        if is_valid_crop_size(cw, ch, w, h) else 0
        for cw, ch in crop_size_list
    ]
    # Among candidates that fit inside the image, pick the best aspect-ratio match.
    crop_size = sorted(
        ((x, y) for x, y in zip(rem_percent, crop_size_list) if x > 0 and y[0] <= w and y[1] <= h),
        reverse=True,
    )[0][1]
    return np.array(crop_size, dtype=np.float32)


def downsample_resize_image(image, image_size):
    """
    Progressive BOX downsampling, as in ADM's center-crop implementation:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    pil_image = Image.fromarray(image)
    # Repeatedly halve the image until the short side drops below 2 * image_size.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )
    return pil_image


######################## FOR TEXT ########################


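For orientation, a minimal usage sketch of the three helpers above (illustrative only, not part of the diff; the shape tuple follows the function's own (w, h, _) unpacking):

from apps.main.modules.preprocess import generate_crop_size_list, var_center_crop_size_fn

# 33 patch-aligned candidates for the defaults above:
# (256, 256), (256, 272), ..., (256, 512), (272, 256), ..., (512, 256)
sizes = generate_crop_size_list(256, 16, max_ratio=2.0)
assert len(sizes) == 33

# For (w, h) = (640, 480), the best-fitting candidate is (336, 256),
# whose aspect ratio is closest to the image's (min/max ratio ~0.984).
crop = var_center_crop_size_fn((640, 480, 3), image_size=256, patch_size=16)
# crop == array([336., 256.], dtype=float32)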
25 changes: 25 additions & 0 deletions apps/offline_inf_v2/configs/inference.yaml
@@ -0,0 +1,25 @@
name: inference

model:
  gen_vae:
    model_name: Hunyuan
    pretrained_model_name_or_path: '/jfs/checkpoints/models--tencent--HunyuanVideo/snapshots/2a15b5574ee77888e51ae6f593b2ceed8ce813e5/vae'
    enable_tiling: false
    enable_slicing: false
  autocast: true

data:
  data_path: /mnt/pollux/nemo/data/bucket-256-1-0/
  output_path: /mnt/pollux/nemo/data/bucket-256-1-0-latents
  enable_checkpointing: true
  id_col: key
  batch_size: 64
  mini_batch_size: 16
  num_threads_per_worker: 16
  image_sizes: [256, 512]
  patch_size: 16
  dynamic_crop_ratio: 1.0
  image_latent_column: image_latent
  image_latent_shape_column: image_latent_shape
  caption_column: caption
  valid_column: valid
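For reference, a minimal sketch of how this file is consumed, mirroring main() in apps/offline_inf_v2/inference.py (dataclass defaults are overridden by the YAML, which in turn is overridden by dotted CLI keys):

from omegaconf import OmegaConf
from apps.offline_inf_v2.inference import InferenceArgs

default_cfg = OmegaConf.structured(InferenceArgs())
file_cfg = OmegaConf.load("apps/offline_inf_v2/configs/inference.yaml")
cfg = OmegaConf.to_object(OmegaConf.merge(default_cfg, file_cfg))
assert cfg.data.image_sizes == [256, 512]
assert cfg.model.autocast is True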
19 changes: 19 additions & 0 deletions apps/offline_inf_v2/data.py
@@ -0,0 +1,19 @@
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DataArgs:
    data_path: str = field(default="/mnt/pollux/nemo/sample/")
    output_path: Optional[str] = field(default=None)
    enable_checkpointing: Optional[bool] = field(default=None)
    id_col: str = field(default="key")
    batch_size: int = field(default=1)
    mini_batch_size: int = field(default=1)
    num_threads_per_worker: int = field(default=4)
    image_sizes: List[int] = field(default_factory=lambda: [256, 512])
    patch_size: int = field(default=16)
    dynamic_crop_ratio: float = field(default=1.0)
    image_latent_column: str = field(default="image_latent")
    image_latent_shape_column: str = field(default="image_latent_shape")
    caption_column: str = field(default="caption")
    valid_column: str = field(default="valid")
86 changes: 86 additions & 0 deletions apps/offline_inf_v2/inference.py
@@ -0,0 +1,86 @@
"""
conda activate curator
python -m apps.offline_inf_v2.inference.py config=apps/offline_inf_v2/configs/inference.yaml
"""

from dataclasses import dataclass, field
from omegaconf import OmegaConf

import dask_cudf
from nemo_curator import get_client
from nemo_curator.datasets import ImageTextPairDataset

from apps.offline_inf_v2.data import DataArgs
from apps.offline_inf_v2.model import ModelArgs
from apps.offline_inf_v2.vae_latent_extractor import VAELatentExtractor


@dataclass
class InferenceArgs:
    name: str = field(default="inference")
    model: ModelArgs = field(default_factory=ModelArgs)
    data: DataArgs = field(default_factory=DataArgs)


def init_image_text_dataset(cfg: InferenceArgs):
    metadata = dask_cudf.read_parquet(cfg.data.data_path, split_row_groups=False, blocksize=None)
    # Skip samples whose download failed upstream.
    if 'status' in metadata.columns:
        metadata = metadata[metadata.status != "failed_to_download"]
    metadata = metadata.map_partitions(ImageTextPairDataset._sort_partition, id_col=cfg.data.id_col)
    tar_files = ImageTextPairDataset._get_tar_files(cfg.data.data_path)
    return ImageTextPairDataset(
        path=cfg.data.data_path,
        metadata=metadata,
        tar_files=tar_files,
        id_col=cfg.data.id_col,
    )


def main():
    cli_args = OmegaConf.from_cli()
    file_cfg = OmegaConf.load(cli_args.config)
    # Remove the 'config' attribute, as the underlying dataclass does not define it.
    del cli_args.config

    default_cfg = OmegaConf.structured(InferenceArgs())
    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
    cfg = OmegaConf.to_object(cfg)

    assert cfg.data.output_path is not None, (
        f"Output path is required, otherwise the parquets in {cfg.data.data_path} will be overwritten"
    )

    # Start the GPU Dask cluster used by dask_cudf and nemo_curator.
    client = get_client(
        cluster_type="gpu",
    )

    print("Initiating dataset ...")
    dataset = init_image_text_dataset(cfg)
    print("Dataset initialized")

    print("Initializing latent extractor ...")
    latent_extractor = VAELatentExtractor(
        model_args=cfg.model,
        data_args=cfg.data,
        use_index_files=True,
    )
    print("Latent extractor initialized")

    dataset_with_latents = latent_extractor(dataset)

    # One latent column and one shape column per image size, e.g.
    # image_latent_256 / image_latent_shape_256 and image_latent_512 / image_latent_shape_512.
    latent_columns = [
        f"{cfg.data.image_latent_column}_{cfg.data.image_sizes[0]}",
        f"{cfg.data.image_latent_shape_column}_{cfg.data.image_sizes[0]}",
        f"{cfg.data.image_latent_column}_{cfg.data.image_sizes[1]}",
        f"{cfg.data.image_latent_shape_column}_{cfg.data.image_sizes[1]}",
    ]

    print("Extracting latents ...")
    dataset_with_latents.save_metadata(
        cfg.data.output_path, columns=[
            cfg.data.id_col, "_id", cfg.data.caption_column, cfg.data.valid_column, *latent_columns
        ]
    )
    print("Latents extracted")


if __name__ == "__main__":
    main()
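A hedged sketch of inspecting the result: the column names follow the save_metadata call above, but reading the output directory with pandas is an assumption about how save_metadata lays out its parquet files.

import pandas as pd

df = pd.read_parquet("/mnt/pollux/nemo/data/bucket-256-1-0-latents")  # output_path from the config
print(list(df.columns))
# Expected: ['key', '_id', 'caption', 'valid',
#            'image_latent_256', 'image_latent_shape_256',
#            'image_latent_512', 'image_latent_shape_512']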
39 changes: 39 additions & 0 deletions apps/offline_inf_v2/model.py
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from dataclasses import dataclass, field
import logging
import torch
from torch import nn
from apps.main.modules.vae import build_vae, LatentVideoVAEArgs

logger = logging.getLogger()


@dataclass
class ModelArgs:
    gen_vae: LatentVideoVAEArgs = field(default_factory=LatentVideoVAEArgs)
    autocast: bool = field(default=False)
    use_compile: bool = field(default=False)


class VAE(nn.Module):
    """
    Thin wrapper that encodes image batches into VAE latent codes.
    """

    VERSION: str = "v1.0"

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.vae_compressor = build_vae(args.gen_vae)

    @torch.no_grad()
    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Move the batch to the GPU and encode it into its latent code.
        image = image.cuda()
        latent_code = self.vae_compressor.encode(image)
        return latent_code
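A minimal usage sketch, assuming a (B, C, H, W) float batch already normalized to [-1, 1] (as the DALI pipeline in pipeline.py produces) and bfloat16 autocast; neither the dtype nor the gating of autocast is confirmed by this diff.

import torch
from apps.offline_inf_v2.model import VAE, ModelArgs

model = VAE(ModelArgs()).cuda().eval()
images = torch.randn(4, 3, 256, 256)  # stand-in for a normalized image batch
with torch.autocast("cuda", dtype=torch.bfloat16):  # assumed; gated by ModelArgs.autocast in practice
    latents = model(images)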
81 changes: 81 additions & 0 deletions apps/offline_inf_v2/pipeline.py
@@ -0,0 +1,81 @@
from io import BytesIO
import numpy as np
from pathlib import Path
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.pipeline import pipeline_def
from PIL import Image

from apps.main.modules.preprocess import var_center_crop_size_fn
from apps.offline_inf_v2.data import DataArgs


def decode_image_raw(image_raw):
    # Decode with PIL; on failure return a 1x1 dummy image and a zero "valid" flag
    # so the sample can be filtered out downstream instead of crashing the pipeline.
    try:
        image = Image.open(BytesIO(image_raw)).convert('RGB')
        return np.asarray(image, dtype=np.uint8), np.ones(1, dtype=np.uint8)
    except Exception as e:
        print(f"Error decoding image: {e}")
        return np.zeros((1, 1, 3), dtype=np.uint8), np.zeros(1, dtype=np.uint8)


@pipeline_def
def image_loading_pipeline(_tar_path: str, use_index_files: bool, data_args: DataArgs):
    if use_index_files:
        index_paths = [str(Path(_tar_path).with_suffix(".idx"))]
    else:
        index_paths = []

    images_raw, text, json_data = fn.readers.webdataset(
        paths=_tar_path,
        index_paths=index_paths,
        ext=["jpg", "txt", "json"],
        missing_component_behavior="error",
    )

    # Decode on CPU via PIL (rather than fn.experimental.decoders.image) so that
    # corrupt images yield a dummy sample plus a "valid" flag instead of an error.
    images_gen, valid = fn.python_function(images_raw, function=decode_image_raw, device="cpu", num_outputs=2)

    # Resize so the short side reaches the larger target size.
    images_gen = fn.resize(
        images_gen.gpu(),
        device="gpu",
        resize_x=data_args.image_sizes[1],
        resize_y=data_args.image_sizes[1],
        mode="not_smaller",
        interp_type=types.DALIInterpType.INTERP_CUBIC,
    )

    # Pick the patch-aligned crop size whose aspect ratio best matches the image.
    crop_size = fn.python_function(
        images_gen.shape(device="cpu"),
        data_args.image_sizes[1],
        data_args.patch_size,
        data_args.dynamic_crop_ratio,
        function=var_center_crop_size_fn,
        device="cpu",
    )

    # Center-crop to the dynamic size, randomly mirror, and normalize to [-1, 1].
    images_gen = fn.crop_mirror_normalize(
        images_gen,
        device="gpu",
        crop_h=crop_size[0],
        crop_w=crop_size[1],
        crop_pos_x=0.5,
        crop_pos_y=0.5,
        mirror=fn.random.coin_flip(probability=0.5),
        dtype=types.DALIDataType.FLOAT,
        mean=[0.5 * 255, 0.5 * 255, 0.5 * 255],
        std=[0.5 * 255, 0.5 * 255, 0.5 * 255],
        scale=1.0,
    )

    # Second output stream downscaled to the smaller target size.
    images_plan = fn.resize(
        images_gen,
        device="gpu",
        resize_x=data_args.image_sizes[0],
        resize_y=data_args.image_sizes[0],
        mode="not_smaller",
        interp_type=types.DALIInterpType.INTERP_CUBIC,
    )

    return images_plan, images_gen, text, json_data, valid
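A hedged sketch of driving this pipeline for one shard: the shard path is hypothetical, and batch size, thread count, and device id mirror the DataArgs defaults rather than the actual driver in VAELatentExtractor.

from apps.offline_inf_v2.data import DataArgs
from apps.offline_inf_v2.pipeline import image_loading_pipeline

args = DataArgs()
pipe = image_loading_pipeline(
    "/path/to/shard-00000.tar",  # hypothetical shard path
    use_index_files=False,       # no .idx file next to the tar
    data_args=args,
    batch_size=args.mini_batch_size,
    num_threads=args.num_threads_per_worker,
    device_id=0,
)
pipe.build()
images_plan, images_gen, text, json_data, valid = pipe.run()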