
Add support for exporting ComfyUI-compatible checkpoints for diffusion models (e.g., LTX-2) #911

Open
ynankani wants to merge 4 commits into main from ynankani/ltx2_comfyui_checkpoint
Conversation

@ynankani (Contributor) commented Feb 20, 2026

What does this PR do

Add support for exporting ComfyUI-compatible checkpoints for diffusion models (e.g., LTX-2).

Type of change:

Overview:
Add support for exporting ComfyUI-compatible checkpoints for diffusion models (e.g., LTX-2):

  1. Added a parameter for merging the base VAE, vocoder, and connectors into the quantized checkpoint
  2. Storing quantization metadata and the export tool name (modelopt) in the checkpoint, as required for ComfyUI compatibility
  3. Internally updating the transformer block prefixes to match what ComfyUI expects
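As a rough illustration of item 3, prefix rewriting amounts to re-keying the transformer entries of the state dict. The prefix strings below are assumptions for the demo, not the exact ones used in the PR:

```python
def remap_prefixes(
    state_dict: dict,
    old_prefix: str = "transformer_blocks.",
    new_prefix: str = "model.diffusion_model.transformer_blocks.",
) -> dict:
    """Re-key transformer entries under a ComfyUI-style prefix (illustrative)."""
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(old_prefix):
            # Re-home transformer keys under the new prefix
            remapped[new_prefix + key[len(old_prefix):]] = value
        else:
            # Non-transformer keys (VAE, vocoder, ...) pass through unchanged
            remapped[key] = value
    return remapped
```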

Usage

    export_hf_checkpoint(
        pipeline,
        export_dir=EXPORT_DIR,
        merged_base_safetensor_path=BASE_CKPT,  # merge VAE/vocoder from base
    )

Testing

  1. Tested with the LTX-2 model:
    a) initialized a twoStagePipeline object
    b) called mtq.quantize on the transformer with NVFP4_DEFAULT_CFG
    c) exported with export_hf_checkpoint, passing merged_base_safetensor_path to generate the merged checkpoint
  2. Ran the checkpoint generated in step 1 on ComfyUI to validate
  3. Ran step 1 without merged_base_safetensor_path to check backward compatibility

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: NA
  • Did you add or update any necessary documentation?: NA
  • Did you update Changelog?: NA

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features
    • Added support for exporting LTX-2 diffusion models with merged base checkpoint integration
    • Enhanced export functionality to preserve and attach quantization metadata during model export
    • Extended model export capabilities with automatic model type detection for improved export handling

@ynankani ynankani requested a review from a team as a code owner February 20, 2026 14:40
@ynankani ynankani requested a review from Edwardf0t1 February 20, 2026 14:40

codecov bot commented Feb 20, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.10%. Comparing base (02fa362) to head (7e74e91).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #911      +/-   ##
==========================================
- Coverage   73.11%   73.10%   -0.02%     
==========================================
  Files         205      205              
  Lines       22281    22281              
==========================================
- Hits        16291    16288       -3     
- Misses       5990     5993       +3     

☔ View full report in Codecov by Sentry.

@jingyu-ml (Contributor) left a comment

Left some comments, overall it looks good to me.

return False


def _merge_diffusion_transformer_with_non_transformer_components(
For now, this seems to work only for LTX2.

Are these mapping relationships hard-coded? If so, we should move this logic into a model-dependent function, for example:

model_type = LTX2
merge_function[LTX2](...)
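The dispatch the reviewer sketches above can be made concrete with a small registry. This is an illustrative sketch with made-up merge logic, not the actual modelopt code (which reads safetensors files):

```python
from typing import Callable

def _merge_ltx2(state_dict: dict, base_path: str) -> dict:
    # Stand-in for folding base VAE/vocoder/connector tensors into the dict
    merged = dict(state_dict)
    merged["_merged_from"] = base_path
    return merged

# Registry mapping a model-type string to its model-specific merge function
DIFFUSION_MERGE_FUNCTIONS: dict[str, Callable[[dict, str], dict]] = {
    "ltx2": _merge_ltx2,
}

def merge_checkpoint(model_type: str, state_dict: dict, base_path: str) -> dict:
    # Dispatch to the registered merger, failing loudly for unsupported types
    if model_type not in DIFFUSION_MERGE_FUNCTIONS:
        raise ValueError(f"No merge function registered for {model_type!r}")
    return DIFFUSION_MERGE_FUNCTIONS[model_type](state_dict, base_path)
```

Adding support for another model type then only requires registering a new entry, rather than touching the export path.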


    metadata_full: dict[str, str] = {}
    if merged_base_safetensor_path is not None:
        cpu_state_dict, metadata_full = (
            _merge_diffusion_transformer_with_non_transformer_components(

As noted above, this function should be model-dependent

merge_function[model_type](...)

    metadata["_export_format"] = "safetensors_state_dict"
    metadata["_class_name"] = type(component).__name__

    if hf_quant_config is not None:
We should add more checks to make this safer:

    if hf_quant_config is not None and merged_base_safetensor_path is not None:

…del(e.g., LTX-2)

Signed-off-by: ynankani <ynankani@nvidia.com>
…del(e.g., LTX-2)

Signed-off-by: ynankani <ynankani@nvidia.com>
Signed-off-by: ynankani <ynankani@nvidia.com>
Signed-off-by: ynankani <ynankani@nvidia.com>
@ynankani ynankani force-pushed the ynankani/ltx2_comfyui_checkpoint branch from 69107c0 to 7e74e91 Compare February 23, 2026 07:24

coderabbitai bot commented Feb 23, 2026

📝 Walkthrough

Walkthrough

Introduces support for merging a base safetensors checkpoint into exported diffusion transformer models. Adds utility functions to detect model type and merge transformer state dicts with base checkpoint data, then integrates this merge workflow into the export pipeline with quantization metadata support.

Changes

Diffusion Merge Infrastructure: modelopt/torch/export/diffusers_utils.py
  New _merge_ltx2 function merges the diffusion transformer state_dict with a base safetensors checkpoint, reading VAE, vocoder, embeddings connectors, and other components. Adds a DIFFUSION_MERGE_FUNCTIONS registry mapping "ltx2" to the merge function. Introduces get_diffusion_model_type(pipe) to detect model types and dispatch to the appropriate merger. Extends imports for safetensors file operations.

Export Pipeline Integration: modelopt/torch/export/unified_export_hf.py
  Integrates diffusion merging into the export workflow. Updates _save_component_state_dict_safetensors to accept merged_base_safetensor_path, hf_quant_config, and model_type parameters; conditionally merges the base checkpoint and attaches quantization metadata. Extends _export_diffusers_checkpoint and export_hf_checkpoint signatures to accept merged_base_safetensor_path and thread it through the component export routines. Imports merge utilities and model-type detection from the diffusers infrastructure.

Sequence Diagram

sequenceDiagram
    actor User
    participant Exporter as Export Pipeline
    participant TypeDetector as Model Type Detector
    participant MergeRegistry as Merge Registry
    participant Merger as Merge Function
    participant BaseCheckpoint as Base Safetensors
    participant Transformer as Transformer Dict
    participant MetadataHandler as Metadata Handler
    participant Output as Safetensors Output

    User->>Exporter: export_hf_checkpoint(model, merged_base_safetensor_path)
    Exporter->>TypeDetector: get_diffusion_model_type(pipe)
    TypeDetector-->>Exporter: model_type (e.g., "ltx2")
    
    Exporter->>MergeRegistry: DIFFUSION_MERGE_FUNCTIONS[model_type]
    MergeRegistry-->>Exporter: _merge_ltx2 function
    
    Exporter->>Merger: _merge_ltx2(transformer_state_dict, base_path)
    
    Merger->>BaseCheckpoint: read VAE, vocoder, embeddings
    BaseCheckpoint-->>Merger: base components
    
    Merger->>Transformer: merge base components with transformer keys
    Transformer-->>Merger: merged_state_dict
    
    Exporter->>MetadataHandler: attach quantization_config & metadata
    MetadataHandler-->>Exporter: enriched state_dict
    
    Exporter->>Output: save model.safetensors with metadata
    Output-->>User: exported checkpoint

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: 3 of 3 passed
  • Description check: Passed (check skipped; CodeRabbit's high-level summary is enabled)
  • Title check: Passed (the title accurately describes the main change: adding support for exporting ComfyUI-compatible checkpoints for diffusion models, with LTX-2 as an example)
  • Docstring coverage: Passed (100.00%, above the required 80.00% threshold)


@coderabbitai bot left a comment (Contributor)
Actionable comments posted: 3

🧹 Nitpick comments (1)
modelopt/torch/export/diffusers_utils.py (1)

681-723: Double file read: load_file + safe_open on the same checkpoint.

The base safetensors file is read twice: once via load_file (line 681) to get tensors, and again via safe_open (line 719) to get metadata. This means parsing a potentially multi-GB file twice.

You can read metadata in the same safe_open call and also load tensors from it, or use safe_open for both purposes:

Proposed: single-pass using safe_open
-    base_state = load_file(merged_base_safetensor_path)
+    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
+        base_metadata = f.metadata() or {}
+        base_state = {key: f.get_tensor(key) for key in f.keys()}
 
     non_transformer_prefixes = [
         ...
     ]
     ...
 
     merged = dict(base_non_transformer)
     merged.update(base_connectors)
     merged.update(prefixed)
-    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
-        base_metadata = f.metadata() or {}
 
     del base_state
     return merged, base_metadata
ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 02fa362 and 7e74e91.

📒 Files selected for processing (2)
  • modelopt/torch/export/diffusers_utils.py
  • modelopt/torch/export/unified_export_hf.py

Comment on lines +136 to +169
    metadata: dict[str, str] = {}
    metadata_full: dict[str, str] = {}
    if merged_base_safetensor_path is not None and model_type is not None:
        merge_fn = DIFFUSION_MERGE_FUNCTIONS[model_type]
        cpu_state_dict, metadata_full = merge_fn(cpu_state_dict, merged_base_safetensor_path)

    if hf_quant_config is not None:
        metadata_full["quantization_config"] = json.dumps(hf_quant_config)

        # Build per-layer _quantization_metadata for ComfyUI
        quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
        layer_metadata = {}
        for k in cpu_state_dict:
            if k.endswith((".weight_scale", ".weight_scale_2")):
                layer_name = k.rsplit(".", 1)[0]
                if layer_name.endswith(".weight"):
                    layer_name = layer_name.rsplit(".", 1)[0]
                if layer_name not in layer_metadata:
                    layer_metadata[layer_name] = {"format": quant_algo}
        metadata_full["_quantization_metadata"] = json.dumps(
            {
                "format_version": "1.0",
                "layers": layer_metadata,
            }
        )

    metadata["_export_format"] = "safetensors_state_dict"
    metadata["_class_name"] = type(component).__name__
    metadata_full.update(metadata)

    save_file(
        cpu_state_dict,
        str(component_export_dir / "model.safetensors"),
        metadata=metadata_full if merged_base_safetensor_path is not None else None,
    )
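The scale-key scan in the quoted code can be checked in isolation with a standalone rerun of the same logic; the key names below are invented for the demo:

```python
import json

def build_quant_metadata(keys, quant_algo: str = "nvfp4") -> str:
    """Derive per-layer quantization metadata from weight-scale key names."""
    layer_metadata = {}
    for k in keys:
        if k.endswith((".weight_scale", ".weight_scale_2")):
            layer_name = k.rsplit(".", 1)[0]  # strip ".weight_scale" / "._2"
            if layer_name.endswith(".weight"):
                layer_name = layer_name.rsplit(".", 1)[0]  # strip ".weight"
            if layer_name not in layer_metadata:
                layer_metadata[layer_name] = {"format": quant_algo}
    return json.dumps({"format_version": "1.0", "layers": layer_metadata})

keys = [
    "blocks.0.attn.to_q.weight",
    "blocks.0.attn.to_q.weight_scale",
    "blocks.0.attn.to_q.weight_scale_2",
]
```

Both scale keys for a layer collapse to the same layer name, so each quantized layer appears once in the metadata.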
⚠️ Potential issue | 🟡 Minor

Metadata is discarded for non-merge exports.

On line 168, metadata is only passed to save_file when merged_base_safetensor_path is not None. For non-merge exports through this function, the _export_format and _class_name metadata (lines 161-162) are computed but thrown away — save_file is called with metadata=None.

If metadata should always be attached (even for non-merge exports), pass it unconditionally:

Proposed fix
     save_file(
         cpu_state_dict,
         str(component_export_dir / "model.safetensors"),
-        metadata=metadata_full if merged_base_safetensor_path is not None else None,
+        metadata=metadata_full,
     )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (same code as quoted above, with the save_file call passing the metadata unconditionally):

    save_file(
        cpu_state_dict,
        str(component_export_dir / "model.safetensors"),
        metadata=metadata_full,
    )

Comment on lines +885 to +887
    # Resolve model type once (only needed when merging with a base checkpoint)
    model_type = get_diffusion_model_type(pipe) if merged_base_safetensor_path else None

⚠️ Potential issue | 🟡 Minor

get_diffusion_model_type is called unconditionally when merged_base_safetensor_path is truthy — will raise ValueError for non-LTX-2 diffusers pipelines.

If a user passes merged_base_safetensor_path for a standard diffusers pipeline (e.g., StableDiffusion), get_diffusion_model_type(pipe) will raise ValueError with a somewhat opaque message. Consider validating earlier or documenting this limitation more prominently in export_hf_checkpoint's docstring.


Comment on lines 960 to +968
        elif hasattr(component, "save_pretrained"):
            component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
        else:
-           _save_component_state_dict_safetensors(component, component_export_dir)
+           _save_component_state_dict_safetensors(
+               component,
+               component_export_dir,
+               merged_base_safetensor_path,
+               model_type=model_type,
+           )
⚠️ Potential issue | 🟡 Minor

Non-quantized components also receive merged_base_safetensor_path — unintentional merge?

When a non-quantized component falls through to _save_component_state_dict_safetensors (lines 963-968), it receives merged_base_safetensor_path and model_type. This means the merge function will run on the non-quantized component's state dict too, adding all non-transformer base weights (VAE, vocoder, etc.) into it.

In the current LTX-2 flow there's only one component, so this is harmless. But for future model types with multiple components, this would produce incorrect merged checkpoints for non-quantized components.

Consider guarding by only passing merged_base_safetensor_path for quantized components, or adding a comment clarifying the assumption:

Proposed safeguard
         else:
             _save_component_state_dict_safetensors(
                 component,
                 component_export_dir,
-                merged_base_safetensor_path,
-                model_type=model_type,
             )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

        elif hasattr(component, "save_pretrained"):
            component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
        else:
            _save_component_state_dict_safetensors(
                component,
                component_export_dir,
            )
