From e9863fe313ca43e93e23995984ecb25b898c3115 Mon Sep 17 00:00:00 2001
From: wenxin0319
Date: Sun, 31 May 2026 21:19:55 -0700
Subject: [PATCH] Add torch_compile flag for training networks

Networks are compiled in place with nn.Module.compile() instead of being
replaced by the OptimizedModule wrapper returned by torch.compile(module),
so module types and state_dict keys are the same with and without the flag.
---
 fastgen/configs/config.py                     |   3 +
 fastgen/methods/distribution_matching/dmd2.py |   8 +
 fastgen/methods/model.py                      |  15 ++
 tests/test_torch_compile.py                   | 146 ++++++++++++++++++
 4 files changed, 172 insertions(+)
 create mode 100644 tests/test_torch_compile.py

diff --git a/fastgen/configs/config.py b/fastgen/configs/config.py
index c7cbf3a..58f6a44 100644
--- a/fastgen/configs/config.py
+++ b/fastgen/configs/config.py
@@ -160,5 +160,8 @@ class BaseModelConfig:
     # - however, it is required if the model has a discriminator or the net initializes unused modules (e.g., for logvar predictions)
     ddp_find_unused_parameters: bool = True
 
+    # compile the training networks in place with torch.compile
+    torch_compile: bool = False
+
     # precision variables (choose from "float64", "float32", "bfloat16", or "float16")
     # (precision of the time steps is handled in the noise scheduler, defaulting to float64 for numerical stability)
diff --git a/fastgen/methods/distribution_matching/dmd2.py b/fastgen/methods/distribution_matching/dmd2.py
index b106e61..26dfa7d 100644
--- a/fastgen/methods/distribution_matching/dmd2.py
+++ b/fastgen/methods/distribution_matching/dmd2.py
@@ -64,6 +64,14 @@ def build_model(self):
         synchronize()
         torch.cuda.empty_cache()
 
+    def _apply_torch_compile(self):
+        """Compile ``net`` (via the base class), ``teacher`` and ``fake_score`` in place."""
+        super()._apply_torch_compile()
+        logger.info("Applying torch.compile to teacher and fake_score")
+        # in-place compile keeps module types and state_dict keys unchanged
+        self.teacher.compile()
+        self.fake_score.compile()
+
     def _setup_grad_requirements(self, iteration: int) -> None:
         if iteration % self.config.student_update_freq == 0:
             # update the student
diff --git a/fastgen/methods/model.py b/fastgen/methods/model.py
index 72d6f05..9562795 100644
--- a/fastgen/methods/model.py
+++ b/fastgen/methods/model.py
@@ -62,6 +62,10 @@ def __init__(self, config: BaseModelConfig):
         # instantiate all necessary nets and submodules
         self.build_model()
 
+        # optionally compile networks with torch.compile
+        if self.config.torch_compile:
+            self._apply_torch_compile()
+
     def _setup_ema(self):
         """Initialize EMA networks. Only call during build_model(), before checkpoint loading."""
         for name in self.use_ema:
@@ -264,6 +268,17 @@ def build_model(self):
         if hasattr(self.net, "init_preprocessors") and self.config.enable_preprocessors:
             self.net.init_preprocessors()
 
+    def _apply_torch_compile(self):
+        """Compile ``self.net`` in place with ``torch.compile``.
+
+        ``nn.Module.compile()`` is used instead of ``self.net = torch.compile(self.net)``:
+        the latter returns an ``OptimizedModule`` wrapper whose ``state_dict`` keys carry
+        an ``_orig_mod.`` prefix, so they would no longer match a non-compiled network.
+        Subclasses with more networks should override this and call ``super()``.
+        """
+        logger.info("Applying torch.compile to net")
+        self.net.compile()
+
     def on_train_begin(self, is_fsdp=False):
         self._is_fsdp = is_fsdp  # Store for later use (e.g., to skip EMA during inference)
         ctx = dict(dtype=self.precision, device=self.device)
diff --git a/tests/test_torch_compile.py b/tests/test_torch_compile.py
new file mode 100644
index 0000000..51f05ca
--- /dev/null
+++ b/tests/test_torch_compile.py
@@ -0,0 +1,146 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+import torch
+import pytest
+from fastgen.methods import FastGenModel, DMD2Model
+from fastgen.configs.methods.config_sft import ModelConfig as SFTModelConfig
+from fastgen.configs.methods.config_dmd2 import ModelConfig as DMD2ModelConfig
+from fastgen.configs.config_utils import override_config_with_opts
+from fastgen.methods.fine_tuning.sft import SFTModel
+
+
+def _is_compiled(module):
+    """Return True if ``module`` was compiled in place via ``nn.Module.compile()``."""
+    # nn.Module.compile() stores the compiled callable on this attribute (None otherwise)
+    return getattr(module, "_compiled_call_impl", None) is not None
+
+
+def _apply_common_settings(instance, torch_compile):
+    """Set the device, precision, and tiny input shape shared by every test model."""
+    instance.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    instance.precision = "float32" if instance.device == torch.device("cpu") else "bfloat16"
+    instance.pretrained_model_path = ""
+    instance.input_shape = [3, 8, 8]
+    instance.torch_compile = torch_compile
+
+
+def _build_sft_model(torch_compile):
+    """Build a tiny SFT model with ``torch_compile`` switched on or off."""
+    gc.collect()
+    instance = SFTModelConfig()
+    opts = ["-", "img_resolution=8", "channel_mult=[1]", "channel_mult_noise=1", "r_timestep=False"]
+    instance.net = override_config_with_opts(instance.net, opts)
+    _apply_common_settings(instance, torch_compile)
+    instance.cond_dropout_prob = 0.1
+    instance.cond_keys_no_dropout = []
+    instance.guidance_scale = None
+    return SFTModel(instance)
+
+
+def _build_dmd2_model(torch_compile):
+    """Build a tiny DMD2 model with ``torch_compile`` switched on or off."""
+    gc.collect()
+    instance = DMD2ModelConfig()
+    opts = ["-", "img_resolution=8", "channel_mult=[1]", "channel_mult_noise=1"]
+    instance.net = override_config_with_opts(instance.net, opts)
+    opts_discriminator = ["-", "feature_indices=[0]", "all_res=[8]", "in_channels=128"]
+    instance.discriminator = override_config_with_opts(instance.discriminator, opts_discriminator)
+    _apply_common_settings(instance, torch_compile)
+    instance.student_update_freq = 2
+    return DMD2Model(instance)
+
+
+def _make_batch(model, batch_size=1):
+    """Create a random class-conditional batch on the model's device and precision."""
+    labels = torch.nn.functional.one_hot(torch.randint(0, 10, (batch_size,)), num_classes=10).float()
+    return {
+        "real": torch.randn(batch_size, 3, 8, 8).to(model.device, model.precision),
+        "condition": labels.to(model.device, model.precision),
+        "neg_condition": torch.zeros(batch_size, 10).to(model.device, model.precision),
+    }
+
+
+@pytest.fixture
+def sft_model_compiled():
+    return _build_sft_model(torch_compile=True)
+
+
+@pytest.fixture
+def sft_model_not_compiled():
+    return _build_sft_model(torch_compile=False)
+
+
+@pytest.fixture
+def dmd2_model_compiled():
+    return _build_dmd2_model(torch_compile=True)
+
+
+@pytest.fixture
+def dmd2_model_not_compiled():
+    return _build_dmd2_model(torch_compile=False)
+
+
+def test_default_torch_compile_is_false():
+    from fastgen.configs.config import BaseModelConfig
+
+    config = BaseModelConfig()
+    assert config.torch_compile is False
+
+
+def test_sft_compile_enabled(sft_model_compiled):
+    assert _is_compiled(sft_model_compiled.net)
+
+
+def test_sft_compile_disabled(sft_model_not_compiled):
+    assert not _is_compiled(sft_model_not_compiled.net)
+
+
+def test_dmd2_compile_enabled(dmd2_model_compiled):
+    assert _is_compiled(dmd2_model_compiled.net)
+    assert _is_compiled(dmd2_model_compiled.teacher)
+    assert _is_compiled(dmd2_model_compiled.fake_score)
+
+
+def test_dmd2_compile_disabled(dmd2_model_not_compiled):
+    assert not _is_compiled(dmd2_model_not_compiled.net)
+    assert not _is_compiled(dmd2_model_not_compiled.teacher)
+    assert not _is_compiled(dmd2_model_not_compiled.fake_score)
+
+
+def test_compile_preserves_state_dict_keys(sft_model_compiled, sft_model_not_compiled):
+    # checkpoints must stay interchangeable between compiled and non-compiled runs
+    compiled_keys = set(sft_model_compiled.net.state_dict())
+    plain_keys = set(sft_model_not_compiled.net.state_dict())
+    assert compiled_keys == plain_keys
+    assert not any(key.startswith("_orig_mod.") for key in compiled_keys)
+
+
+def test_sft_compiled_train_step(sft_model_compiled):
+    model = sft_model_compiled
+    model.on_train_begin()
+    model.init_optimizers()
+
+    loss_map, outputs = model.single_train_step(_make_batch(model), 0)
+    assert "total_loss" in loss_map
+    assert not torch.isnan(loss_map["total_loss"])
+    loss_map["total_loss"].backward()
+
+
+def test_dmd2_compiled_train_step(dmd2_model_compiled):
+    model = dmd2_model_compiled
+    model.on_train_begin()
+    model.init_optimizers()
+    data = _make_batch(model)
+
+    # Student update step
+    loss_map, outputs = model.single_train_step(data, 0)
+    assert "total_loss" in loss_map
+    assert not torch.isnan(loss_map["total_loss"])
+
+    # Fake score update step
+    model.optimizers_zero_grad(1)
+    loss_map, outputs = model.single_train_step(data, 1)
+    assert "total_loss" in loss_map
+    assert not torch.isnan(loss_map["total_loss"])