Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions fastgen/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ class BaseModelConfig:
# - however, it is required if the model has a discriminator or the net initializes unused modules (e.g., for logvar predictions)
ddp_find_unused_parameters: bool = True

# enable torch.compile for training networks
torch_compile: bool = False

# precision variables (choose from "float64", "float32", "bfloat16", or "float16")
# (precision of the time steps is handled in the noise scheduler, defaulting to float64 for numerical stability)

Expand Down
6 changes: 6 additions & 0 deletions fastgen/methods/distribution_matching/dmd2.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def build_model(self):
synchronize()
torch.cuda.empty_cache()

def _apply_torch_compile(self):
    """Compile the teacher and fake-score networks on top of the parent's targets.

    Runs the parent implementation first, then rebinds ``self.teacher`` and
    ``self.fake_score`` to the objects returned by ``torch.compile``.
    """
    super()._apply_torch_compile()
    logger.info("Applying torch.compile to teacher and fake_score")
    # Wrap the teacher first, then the fake score network.
    for attr_name in ("teacher", "fake_score"):
        setattr(self, attr_name, torch.compile(getattr(self, attr_name)))

def _setup_grad_requirements(self, iteration: int) -> None:
if iteration % self.config.student_update_freq == 0:
# update the student
Expand Down
8 changes: 8 additions & 0 deletions fastgen/methods/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def __init__(self, config: BaseModelConfig):
# instantiate all necessary nets and submodules
self.build_model()

# optionally compile networks with torch.compile
if self.config.torch_compile:
self._apply_torch_compile()

def _setup_ema(self):
"""Initialize EMA networks. Only call during build_model(), before checkpoint loading."""
for name in self.use_ema:
Expand Down Expand Up @@ -264,6 +268,10 @@ def build_model(self):
if hasattr(self.net, "init_preprocessors") and self.config.enable_preprocessors:
self.net.init_preprocessors()

def _apply_torch_compile(self):
    """Replace ``self.net`` with the result of ``torch.compile(self.net)``.

    Subclasses that own additional networks can override this method and
    call ``super()._apply_torch_compile()`` before compiling their own.
    """
    logger.info("Applying torch.compile to net")
    # NOTE(review): torch.compile returns a wrapper module around the original
    # network; presumably its state_dict keys gain an `_orig_mod.` prefix --
    # confirm that checkpoint save/load elsewhere handles compiled nets.
    compiled_net = torch.compile(self.net)
    self.net = compiled_net

def on_train_begin(self, is_fsdp=False):
self._is_fsdp = is_fsdp # Store for later use (e.g., to skip EMA during inference)
ctx = dict(dtype=self.precision, device=self.device)
Expand Down
154 changes: 154 additions & 0 deletions tests/test_torch_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import gc
import torch
import pytest
from fastgen.methods import FastGenModel, DMD2Model
from fastgen.configs.methods.config_sft import ModelConfig as SFTModelConfig
from fastgen.configs.methods.config_dmd2 import ModelConfig as DMD2ModelConfig
from fastgen.configs.config_utils import override_config_with_opts
from fastgen.methods.fine_tuning.sft import SFTModel


def _is_compiled(module):
return isinstance(module, torch._dynamo.OptimizedModule)


@pytest.fixture
def sft_model_compiled():
    """Provide an ``SFTModel`` built from a config with ``torch_compile=True``.

    The config uses CUDA with bfloat16 when CUDA is available and falls back
    to CPU with float32 otherwise.
    """
    gc.collect()
    cfg = SFTModelConfig()
    net_overrides = [
        "-",
        "img_resolution=8",
        "channel_mult=[1]",
        "channel_mult_noise=1",
        "r_timestep=False",
    ]
    cfg.net = override_config_with_opts(cfg.net, net_overrides)

    cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if cfg.device == torch.device("cpu"):
        cfg.precision = "float32"
    else:
        cfg.precision = "bfloat16"

    # Remaining plain settings, applied in declaration order.
    remaining_settings = {
        "pretrained_model_path": "",
        "input_shape": [3, 8, 8],
        "torch_compile": True,
        "cond_dropout_prob": 0.1,
        "cond_keys_no_dropout": [],
        "guidance_scale": None,
    }
    for field_name, value in remaining_settings.items():
        setattr(cfg, field_name, value)
    return SFTModel(cfg)


@pytest.fixture
def sft_model_not_compiled():
    """Provide an ``SFTModel`` built from a config with ``torch_compile=False``.

    The config uses CUDA with bfloat16 when CUDA is available and falls back
    to CPU with float32 otherwise.
    """
    gc.collect()
    cfg = SFTModelConfig()
    net_overrides = [
        "-",
        "img_resolution=8",
        "channel_mult=[1]",
        "channel_mult_noise=1",
        "r_timestep=False",
    ]
    cfg.net = override_config_with_opts(cfg.net, net_overrides)

    cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if cfg.device == torch.device("cpu"):
        cfg.precision = "float32"
    else:
        cfg.precision = "bfloat16"

    # Remaining plain settings, applied in declaration order.
    remaining_settings = {
        "pretrained_model_path": "",
        "input_shape": [3, 8, 8],
        "torch_compile": False,
        "cond_dropout_prob": 0.1,
        "cond_keys_no_dropout": [],
        "guidance_scale": None,
    }
    for field_name, value in remaining_settings.items():
        setattr(cfg, field_name, value)
    return SFTModel(cfg)


@pytest.fixture
def dmd2_model_compiled():
    """Provide a ``DMD2Model`` built from a config with ``torch_compile=True``.

    Both the net and the discriminator sub-configs receive overrides. The
    config uses CUDA with bfloat16 when CUDA is available and falls back to
    CPU with float32 otherwise.
    """
    gc.collect()
    cfg = DMD2ModelConfig()

    net_overrides = ["-", "img_resolution=8", "channel_mult=[1]", "channel_mult_noise=1"]
    cfg.net = override_config_with_opts(cfg.net, net_overrides)
    discriminator_overrides = ["-", "feature_indices=[0]", "all_res=[8]", "in_channels=128"]
    cfg.discriminator = override_config_with_opts(cfg.discriminator, discriminator_overrides)

    cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if cfg.device == torch.device("cpu"):
        cfg.precision = "float32"
    else:
        cfg.precision = "bfloat16"

    # Remaining plain settings, applied in declaration order.
    remaining_settings = {
        "pretrained_model_path": "",
        "student_update_freq": 2,
        "input_shape": [3, 8, 8],
        "torch_compile": True,
    }
    for field_name, value in remaining_settings.items():
        setattr(cfg, field_name, value)
    return DMD2Model(cfg)


@pytest.fixture
def dmd2_model_not_compiled():
    """Provide a ``DMD2Model`` built from a config with ``torch_compile=False``.

    Both the net and the discriminator sub-configs receive overrides. The
    config uses CUDA with bfloat16 when CUDA is available and falls back to
    CPU with float32 otherwise.
    """
    gc.collect()
    cfg = DMD2ModelConfig()

    net_overrides = ["-", "img_resolution=8", "channel_mult=[1]", "channel_mult_noise=1"]
    cfg.net = override_config_with_opts(cfg.net, net_overrides)
    discriminator_overrides = ["-", "feature_indices=[0]", "all_res=[8]", "in_channels=128"]
    cfg.discriminator = override_config_with_opts(cfg.discriminator, discriminator_overrides)

    cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if cfg.device == torch.device("cpu"):
        cfg.precision = "float32"
    else:
        cfg.precision = "bfloat16"

    # Remaining plain settings, applied in declaration order.
    remaining_settings = {
        "pretrained_model_path": "",
        "student_update_freq": 2,
        "input_shape": [3, 8, 8],
        "torch_compile": False,
    }
    for field_name, value in remaining_settings.items():
        setattr(cfg, field_name, value)
    return DMD2Model(cfg)


def test_default_torch_compile_is_false():
    """A freshly constructed ``BaseModelConfig`` has ``torch_compile`` set to False."""
    from fastgen.configs.config import BaseModelConfig

    default_flag = BaseModelConfig().torch_compile
    assert default_flag is False


def test_sft_compile_enabled(sft_model_compiled):
    """With ``torch_compile=True`` the SFT model's ``net`` is a compiled module."""
    net = sft_model_compiled.net
    assert _is_compiled(net)


def test_sft_compile_disabled(sft_model_not_compiled):
    """With ``torch_compile=False`` the SFT model's ``net`` is not a compiled module."""
    net = sft_model_not_compiled.net
    assert not _is_compiled(net)


def test_dmd2_compile_enabled(dmd2_model_compiled):
    """With ``torch_compile=True`` the net, teacher and fake score are all compiled."""
    for attr_name in ("net", "teacher", "fake_score"):
        assert _is_compiled(getattr(dmd2_model_compiled, attr_name))


def test_dmd2_compile_disabled(dmd2_model_not_compiled):
    """With ``torch_compile=False`` none of net, teacher or fake score is compiled."""
    for attr_name in ("net", "teacher", "fake_score"):
        assert not _is_compiled(getattr(dmd2_model_not_compiled, attr_name))


def test_sft_compiled_train_step(sft_model_compiled):
    """Run one training step plus a backward pass on the compiled SFT model.

    Checks that the returned loss map contains a non-NaN ``total_loss``.
    """
    model = sft_model_compiled
    model.on_train_begin()
    model.init_optimizers()

    def _to_model(tensor):
        # Move a tensor to the model's device and precision.
        return tensor.to(model.device, model.precision)

    n_samples, n_classes = 1, 10
    # Labels are drawn before the image noise, so the RNG is consumed in that order.
    class_ids = torch.randint(0, n_classes, (n_samples,))
    one_hot_labels = torch.nn.functional.one_hot(class_ids, num_classes=n_classes).float()
    data = {
        "real": _to_model(torch.randn(n_samples, 3, 8, 8)),
        "condition": _to_model(one_hot_labels),
        "neg_condition": _to_model(torch.zeros(n_samples, n_classes)),
    }

    loss_map, _ = model.single_train_step(data, 0)
    assert "total_loss" in loss_map
    assert not torch.isnan(loss_map["total_loss"])
    loss_map["total_loss"].backward()


def test_dmd2_compiled_train_step(dmd2_model_compiled):
    """Run training steps at iterations 0 and 1 on the compiled DMD2 model.

    Each step must return a loss map containing a non-NaN ``total_loss``.
    """
    model = dmd2_model_compiled
    model.on_train_begin()
    model.init_optimizers()

    def _to_model(tensor):
        # Move a tensor to the model's device and precision.
        return tensor.to(model.device, model.precision)

    def _step_and_check(iteration):
        # Run one training step and validate the reported total loss.
        loss_map, _ = model.single_train_step(data, iteration)
        assert "total_loss" in loss_map
        assert not torch.isnan(loss_map["total_loss"])

    n_samples, n_classes = 1, 10
    # Labels are drawn before the image noise, so the RNG is consumed in that order.
    class_ids = torch.randint(0, n_classes, (n_samples,))
    one_hot_labels = torch.nn.functional.one_hot(class_ids, num_classes=n_classes)
    data = {
        "real": _to_model(torch.randn(n_samples, 3, 8, 8)),
        "condition": _to_model(one_hot_labels),
        "neg_condition": _to_model(torch.zeros(n_samples, n_classes)),
    }

    # Student update step
    _step_and_check(0)

    # Fake score update step
    model.optimizers_zero_grad(1)
    _step_and_check(1)