Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions modelopt/torch/export/diffusers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import torch
import torch.nn as nn
from safetensors.torch import load_file, safe_open

from .layer_utils import is_quantlinear

Expand Down Expand Up @@ -656,3 +657,96 @@ def infer_dtype_from_model(model: nn.Module) -> torch.dtype:
for param in model.parameters():
return param.dtype
return torch.float16


def _merge_ltx2(
    diffusion_transformer_state_dict: dict[str, torch.Tensor],
    merged_base_safetensor_path: str,
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
    """Merge LTX-2 transformer weights with non-transformer components.

    Non-transformer components (VAE, vocoder, text encoders) and embeddings
    connectors are taken from the base checkpoint. Transformer keys are
    re-prefixed with ``model.diffusion_model.`` for ComfyUI compatibility.

    Args:
        diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU).
        merged_base_safetensor_path: Path to the full base model safetensors file containing
            all components (transformer, VAE, vocoder, etc.).

    Returns:
        Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
        safetensors metadata from the base checkpoint.
    """
    non_transformer_prefixes = [
        "vae.",
        "audio_vae.",
        "vocoder.",
        "text_embedding_projection.",
        "text_encoders.",
        "first_stage_model.",
        "cond_stage_model.",
        "conditioner.",
    ]
    correct_prefix = "model.diffusion_model."
    # Only the FIRST matching prefix is stripped from each key (hence the break
    # below); order matters, e.g. "diffusion_model." must precede "model.".
    strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."]

    # Single pass over the base checkpoint with safe_open: read the metadata and
    # lazily load only the tensors we keep (non-transformer components and the
    # embeddings connectors). This avoids materializing the full checkpoint —
    # load_file would also pull in the base transformer weights we discard —
    # and avoids opening the file a second time just for metadata.
    merged: dict[str, torch.Tensor] = {}
    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
        base_metadata = f.metadata() or {}
        for k in f.keys():
            keep_non_transformer = any(k.startswith(p) for p in non_transformer_prefixes)
            keep_connector = "embeddings_connector" in k and k.startswith(correct_prefix)
            if keep_non_transformer or keep_connector:
                merged[k] = f.get_tensor(k)

    # Re-prefix the (quantized) transformer keys; these override any colliding
    # base keys, matching the original dict.update ordering.
    for k, v in diffusion_transformer_state_dict.items():
        clean_k = k
        for prefix in strip_prefixes:
            if clean_k.startswith(prefix):
                clean_k = clean_k[len(prefix) :]
                break
        merged[f"{correct_prefix}{clean_k}"] = v

    return merged, base_metadata


# Registry mapping model-type keys (as returned by ``get_diffusion_model_type``)
# to their checkpoint-merge functions. Each function takes
# (transformer_state_dict, base_safetensor_path) and returns
# (merged_state_dict, base_metadata).
DIFFUSION_MERGE_FUNCTIONS: dict[str, Callable] = {
    "ltx2": _merge_ltx2,
}


def get_diffusion_model_type(pipe: Any) -> str:
    """Detect the diffusion model type for merge function dispatch.

    To add a new model type, add a detection clause here and a corresponding
    merge function in ``DIFFUSION_MERGE_FUNCTIONS``.

    Args:
        pipe: The pipeline or component being exported.

    Returns:
        A string key into ``DIFFUSION_MERGE_FUNCTIONS``.

    Raises:
        ValueError: If the model type is not supported.
    """
    # LTX-2: the two-stage text/image-to-video pipeline (guarded against the
    # class being unavailable at import time).
    is_ltx2 = TI2VidTwoStagesPipeline is not None and isinstance(pipe, TI2VidTwoStagesPipeline)
    if is_ltx2:
        return "ltx2"

    pipe_class_name = type(pipe).__name__
    raise ValueError(
        f"No merge function for model type '{pipe_class_name}'. "
        "Add an entry to DIFFUSION_MERGE_FUNCTIONS in diffusers_utils.py."
    )
100 changes: 83 additions & 17 deletions modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
import diffusers

from .diffusers_utils import (
DIFFUSION_MERGE_FUNCTIONS,
generate_diffusion_dummy_forward_fn,
get_diffusion_components,
get_diffusion_model_type,
get_qkv_group_key,
hide_quantizers_from_state_dict,
infer_dtype_from_model,
Expand Down Expand Up @@ -112,19 +114,62 @@ def _is_enabled_quantizer(quantizer):


def _save_component_state_dict_safetensors(
component: nn.Module, component_export_dir: Path
component: nn.Module,
component_export_dir: Path,
merged_base_safetensor_path: str | None = None,
hf_quant_config: dict | None = None,
model_type: str | None = None,
) -> None:
"""Save component state dict as safetensors with optional base checkpoint merge.

Args:
component: The nn.Module to save.
component_export_dir: Directory to save model.safetensors and config.json.
merged_base_safetensor_path: If provided, merge with non-transformer components
from this base safetensors file.
hf_quant_config: If provided, embed quantization config in safetensors metadata
and per-layer _quantization_metadata for ComfyUI.
model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` for the model-specific merge.
Required when ``merged_base_safetensor_path`` is not None.
"""
cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()}
save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"))
metadata: dict[str, str] = {}
metadata_full: dict[str, str] = {}
if merged_base_safetensor_path is not None and model_type is not None:
merge_fn = DIFFUSION_MERGE_FUNCTIONS[model_type]
cpu_state_dict, metadata_full = merge_fn(cpu_state_dict, merged_base_safetensor_path)
if hf_quant_config is not None:
metadata_full["quantization_config"] = json.dumps(hf_quant_config)

# Build per-layer _quantization_metadata for ComfyUI
quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
layer_metadata = {}
for k in cpu_state_dict:
if k.endswith((".weight_scale", ".weight_scale_2")):
layer_name = k.rsplit(".", 1)[0]
if layer_name.endswith(".weight"):
layer_name = layer_name.rsplit(".", 1)[0]
if layer_name not in layer_metadata:
layer_metadata[layer_name] = {"format": quant_algo}
metadata_full["_quantization_metadata"] = json.dumps(
{
"format_version": "1.0",
"layers": layer_metadata,
}
)

metadata["_export_format"] = "safetensors_state_dict"
metadata["_class_name"] = type(component).__name__
metadata_full.update(metadata)

save_file(
cpu_state_dict,
str(component_export_dir / "model.safetensors"),
metadata=metadata_full if merged_base_safetensor_path is not None else None,
)
Comment on lines +136 to +169
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Metadata is discarded for non-merge exports.

On line 168, metadata is only passed to save_file when merged_base_safetensor_path is not None. For non-merge exports through this function, the _export_format and _class_name metadata (lines 161-162) are computed but thrown away — save_file is called with metadata=None.

If metadata should always be attached (even for non-merge exports), pass it unconditionally:

Proposed fix
     save_file(
         cpu_state_dict,
         str(component_export_dir / "model.safetensors"),
-        metadata=metadata_full if merged_base_safetensor_path is not None else None,
+        metadata=metadata_full,
     )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
metadata: dict[str, str] = {}
metadata_full: dict[str, str] = {}
if merged_base_safetensor_path is not None and model_type is not None:
merge_fn = DIFFUSION_MERGE_FUNCTIONS[model_type]
cpu_state_dict, metadata_full = merge_fn(cpu_state_dict, merged_base_safetensor_path)
if hf_quant_config is not None:
metadata_full["quantization_config"] = json.dumps(hf_quant_config)
# Build per-layer _quantization_metadata for ComfyUI
quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
layer_metadata = {}
for k in cpu_state_dict:
if k.endswith((".weight_scale", ".weight_scale_2")):
layer_name = k.rsplit(".", 1)[0]
if layer_name.endswith(".weight"):
layer_name = layer_name.rsplit(".", 1)[0]
if layer_name not in layer_metadata:
layer_metadata[layer_name] = {"format": quant_algo}
metadata_full["_quantization_metadata"] = json.dumps(
{
"format_version": "1.0",
"layers": layer_metadata,
}
)
metadata["_export_format"] = "safetensors_state_dict"
metadata["_class_name"] = type(component).__name__
metadata_full.update(metadata)
save_file(
cpu_state_dict,
str(component_export_dir / "model.safetensors"),
metadata=metadata_full if merged_base_safetensor_path is not None else None,
)
metadata: dict[str, str] = {}
metadata_full: dict[str, str] = {}
if merged_base_safetensor_path is not None and model_type is not None:
merge_fn = DIFFUSION_MERGE_FUNCTIONS[model_type]
cpu_state_dict, metadata_full = merge_fn(cpu_state_dict, merged_base_safetensor_path)
if hf_quant_config is not None:
metadata_full["quantization_config"] = json.dumps(hf_quant_config)
# Build per-layer _quantization_metadata for ComfyUI
quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
layer_metadata = {}
for k in cpu_state_dict:
if k.endswith((".weight_scale", ".weight_scale_2")):
layer_name = k.rsplit(".", 1)[0]
if layer_name.endswith(".weight"):
layer_name = layer_name.rsplit(".", 1)[0]
if layer_name not in layer_metadata:
layer_metadata[layer_name] = {"format": quant_algo}
metadata_full["_quantization_metadata"] = json.dumps(
{
"format_version": "1.0",
"layers": layer_metadata,
}
)
metadata["_export_format"] = "safetensors_state_dict"
metadata["_class_name"] = type(component).__name__
metadata_full.update(metadata)
save_file(
cpu_state_dict,
str(component_export_dir / "model.safetensors"),
metadata=metadata_full,
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 136 - 169, The
computed metadata (metadata and metadata_full) is discarded for non-merge
exports because save_file is only given metadata when
merged_base_safetensor_path is not None; change the save_file call in
unified_export_hf.py to always pass the assembled metadata (use metadata_full
which is updated with metadata) instead of conditionally passing None — update
the save_file invocation (function save_file, variables metadata_full,
merged_base_safetensor_path, cpu_state_dict, component_export_dir, component) to
use metadata=metadata_full unconditionally so _export_format and _class_name are
preserved for all exports.


with open(component_export_dir / "config.json", "w") as f:
json.dump(
{
"_class_name": type(component).__name__,
"_export_format": "safetensors_state_dict",
},
f,
indent=4,
)
json.dump(metadata, f, indent=4)


def _collect_shared_input_modules(
Expand Down Expand Up @@ -807,6 +852,7 @@ def _export_diffusers_checkpoint(
dtype: torch.dtype | None,
export_dir: Path,
components: list[str] | None,
merged_base_safetensor_path: str | None = None,
max_shard_size: int | str = "10GB",
) -> None:
"""Internal: Export diffusion(-like) model/pipeline checkpoint.
Expand All @@ -821,6 +867,8 @@ def _export_diffusers_checkpoint(
export_dir: The directory to save the exported checkpoint.
components: Optional list of component names to export. Only used for pipelines.
If None, all components are exported.
merged_base_safetensor_path: If provided, merge the exported transformer with
non-transformer components from this base safetensors file.
max_shard_size: Maximum size of each shard file. If the model exceeds this size,
it will be sharded into multiple files and a .safetensors.index.json will be
created. Use smaller values like "5GB" or "2GB" to force sharding.
Expand All @@ -834,6 +882,9 @@ def _export_diffusers_checkpoint(
warnings.warn("No exportable components found in the model.")
return

# Resolve model type once (only needed when merging with a base checkpoint)
model_type = get_diffusion_model_type(pipe) if merged_base_safetensor_path else None

Comment on lines +885 to +887
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

get_diffusion_model_type is called unconditionally when merged_base_safetensor_path is truthy — will raise ValueError for non-LTX-2 diffusers pipelines.

If a user passes merged_base_safetensor_path for a standard diffusers pipeline (e.g., StableDiffusion), get_diffusion_model_type(pipe) will raise ValueError with a somewhat opaque message. Consider validating earlier or documenting this limitation more prominently in export_hf_checkpoint's docstring.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 885 - 887, The code
unconditionally calls get_diffusion_model_type(pipe) when
merged_base_safetensor_path is set, which will raise a ValueError for non-LTX-2
diffusers; update export_hf_checkpoint to first check the pipeline type (or a
predicate like is_ltx2_pipeline(pipe)) before calling get_diffusion_model_type,
and if merged_base_safetensor_path is provided for an unsupported pipeline
either raise a clearer, descriptive error mentioning export_hf_checkpoint and
merged_base_safetensor_path or document this constraint in the function
docstring so users aren’t met with an opaque ValueError from
get_diffusion_model_type.

# Separate nn.Module components for quantization-aware export
module_components = {
name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module)
Expand Down Expand Up @@ -879,6 +930,7 @@ def _export_diffusers_checkpoint(

# Step 5: Build quantization config
quant_config = get_quant_config(component, is_modelopt_qlora=False)
hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None

# Step 6: Save the component
# - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
Expand All @@ -888,12 +940,15 @@ def _export_diffusers_checkpoint(
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
else:
with hide_quantizers_from_state_dict(component):
_save_component_state_dict_safetensors(component, component_export_dir)

_save_component_state_dict_safetensors(
component,
component_export_dir,
merged_base_safetensor_path,
hf_quant_config,
model_type,
)
# Step 7: Update config.json with quantization info
if quant_config is not None:
hf_quant_config = convert_hf_quant_config_format(quant_config)

if hf_quant_config is not None:
config_path = component_export_dir / "config.json"
if config_path.exists():
with open(config_path) as file:
Expand All @@ -905,7 +960,12 @@ def _export_diffusers_checkpoint(
elif hasattr(component, "save_pretrained"):
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
else:
_save_component_state_dict_safetensors(component, component_export_dir)
_save_component_state_dict_safetensors(
component,
component_export_dir,
merged_base_safetensor_path,
model_type=model_type,
)
Comment on lines 960 to +968
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Non-quantized components also receive merged_base_safetensor_path — unintentional merge?

When a non-quantized component falls through to _save_component_state_dict_safetensors (lines 963-968), it receives merged_base_safetensor_path and model_type. This means the merge function will run on the non-quantized component's state dict too, adding all non-transformer base weights (VAE, vocoder, etc.) into it.

In the current LTX-2 flow there's only one component, so this is harmless. But for future model types with multiple components, this would produce incorrect merged checkpoints for non-quantized components.

Consider guarding by only passing merged_base_safetensor_path for quantized components, or adding a comment clarifying the assumption:

Proposed safeguard
         else:
             _save_component_state_dict_safetensors(
                 component,
                 component_export_dir,
-                merged_base_safetensor_path,
-                model_type=model_type,
             )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
elif hasattr(component, "save_pretrained"):
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
else:
_save_component_state_dict_safetensors(component, component_export_dir)
_save_component_state_dict_safetensors(
component,
component_export_dir,
merged_base_safetensor_path,
model_type=model_type,
)
elif hasattr(component, "save_pretrained"):
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
else:
_save_component_state_dict_safetensors(
component,
component_export_dir,
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 960 - 968, The
current path calls _save_component_state_dict_safetensors(component,
component_export_dir, merged_base_safetensor_path, model_type=...) for any
component that doesn't implement save_pretrained, which unintentionally applies
the merged_base_safetensor_path merge to non-quantized components; update the
logic so merged_base_safetensor_path is only passed when the component is
quantized (e.g., detect quantization via a flag or type check before calling
_save_component_state_dict_safetensors) or call
_save_component_state_dict_safetensors without merged_base_safetensor_path for
non-quantized components, ensuring references to
_save_component_state_dict_safetensors, merged_base_safetensor_path,
component.save_pretrained and model_type are used to locate and change the code.


print(f" Saved to: {component_export_dir}")

Expand Down Expand Up @@ -985,6 +1045,7 @@ def export_hf_checkpoint(
save_modelopt_state: bool = False,
components: list[str] | None = None,
extra_state_dict: dict[str, torch.Tensor] | None = None,
merged_base_safetensor_path: str | None = None,
):
"""Export quantized HuggingFace model checkpoint (transformers or diffusers).

Expand All @@ -1002,6 +1063,9 @@ def export_hf_checkpoint(
components: Only used for diffusers pipelines. Optional list of component names
to export. If None, all quantized components are exported.
extra_state_dict: Extra state dictionary to add to the exported model.
merged_base_safetensor_path: If provided, merge the exported diffusion transformer
with non-transformer components (VAE, vocoder, etc.) from this base safetensors
file. Only used for diffusion model exports (e.g., LTX-2).
"""
export_dir = Path(export_dir)
export_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -1010,7 +1074,9 @@ def export_hf_checkpoint(
if HAS_DIFFUSERS:
is_diffusers_obj = is_diffusers_object(model)
if is_diffusers_obj:
_export_diffusers_checkpoint(model, dtype, export_dir, components)
_export_diffusers_checkpoint(
model, dtype, export_dir, components, merged_base_safetensor_path
)
return

# Transformers model export
Expand Down