Add support for exporting a ComfyUI-compatible checkpoint for diffusion models (e.g., LTX-2) #911
Conversation
Codecov Report ✅ All modified and coverable lines are covered by tests.

```
@@            Coverage Diff             @@
##             main     #911      +/-   ##
==========================================
- Coverage   73.11%   73.10%   -0.02%
==========================================
  Files         205      205
  Lines       22281    22281
==========================================
- Hits        16291    16288       -3
- Misses       5990     5993       +3
==========================================
```

View full report in Codecov by Sentry.
jingyu-ml left a comment:

Left some comments; overall it looks good to me.
```python
    return False


def _merge_diffusion_transformer_with_non_transformer_components(
```
For now, this seems to work only for LTX-2. Are these mapping relationships hard-coded? If so, we should move this logic into a model-dependent function, for example:

```python
model_type = LTX2
merge_function[LTX2](...)
```

This function also needs to be moved to https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/export/diffusers_utils.py
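The registry idea in the review can be sketched as a plain dict mapping model-type strings to merge callables. Everything below is illustrative: the names `DIFFUSION_MERGE_FUNCTIONS` and `_merge_ltx2` follow the PR discussion, but the signatures, the `model.diffusion_model.` key prefix, and the `merge_for` helper are assumptions, not the actual implementation:

```python
from collections.abc import Callable


def _merge_ltx2(transformer_state: dict, base_state: dict) -> dict:
    # Illustrative LTX-2 merge: keep all non-transformer base weights and
    # fold in the transformer weights under a hypothetical key prefix.
    merged = dict(base_state)
    merged.update(
        {f"model.diffusion_model.{k}": v for k, v in transformer_state.items()}
    )
    return merged


# Registry keyed by model type, so the export code stays model-agnostic.
DIFFUSION_MERGE_FUNCTIONS: dict[str, Callable[[dict, dict], dict]] = {
    "ltx2": _merge_ltx2,
}


def merge_for(model_type: str, transformer_state: dict, base_state: dict) -> dict:
    # Dispatch through the registry; unknown types fail with a clear message.
    try:
        merge_fn = DIFFUSION_MERGE_FUNCTIONS[model_type]
    except KeyError:
        raise ValueError(
            f"No merge function registered for model type {model_type!r}"
        ) from None
    return merge_fn(transformer_state, base_state)
```

Adding a new model then only requires registering one function, without touching the export pipeline.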
```python
metadata_full: dict[str, str] = {}
if merged_base_safetensor_path is not None:
    cpu_state_dict, metadata_full = (
        _merge_diffusion_transformer_with_non_transformer_components(
```
As noted above, this function should be model-dependent:

```python
merge_function[model_type](...)
```
```python
metadata["_export_format"] = "safetensors_state_dict"
metadata["_class_name"] = type(component).__name__

if hf_quant_config is not None:
```
We should add more checks to make this safer:

```python
if hf_quant_config is not None and merged_base_safetensor_path is not None:
```
…del(e.g., LTX-2) Signed-off-by: ynankani <ynankani@nvidia.com>
Signed-off-by: ynankani <ynankani@nvidia.com>
Force-pushed 69107c0 to 7e74e91
📝 Walkthrough

Introduces support for merging a base safetensors checkpoint into exported diffusion transformer models. Adds utility functions to detect the model type and merge transformer state dicts with base checkpoint data, then integrates this merge workflow into the export pipeline with quantization metadata support.
Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant Exporter as Export Pipeline
    participant TypeDetector as Model Type Detector
    participant MergeRegistry as Merge Registry
    participant Merger as Merge Function
    participant BaseCheckpoint as Base Safetensors
    participant Transformer as Transformer Dict
    participant MetadataHandler as Metadata Handler
    participant Output as Safetensors Output
    User->>Exporter: export_hf_checkpoint(model, merged_base_safetensor_path)
    Exporter->>TypeDetector: get_diffusion_model_type(pipe)
    TypeDetector-->>Exporter: model_type (e.g., "ltx2")
    Exporter->>MergeRegistry: DIFFUSION_MERGE_FUNCTIONS[model_type]
    MergeRegistry-->>Exporter: _merge_ltx2 function
    Exporter->>Merger: _merge_ltx2(transformer_state_dict, base_path)
    Merger->>BaseCheckpoint: read VAE, vocoder, embeddings
    BaseCheckpoint-->>Merger: base components
    Merger->>Transformer: merge base components with transformer keys
    Transformer-->>Merger: merged_state_dict
    Exporter->>MetadataHandler: attach quantization_config & metadata
    MetadataHandler-->>Exporter: enriched state_dict
    Exporter->>Output: save model.safetensors with metadata
    Output-->>User: exported checkpoint
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks: ✅ 3 passed
Actionable comments posted: 3
🧹 Nitpick comments (1)

modelopt/torch/export/diffusers_utils.py (1)

Lines 681-723: Double file read: load_file + safe_open on the same checkpoint.

The base safetensors file is read twice: once via load_file (line 681) to get tensors, and again via safe_open (line 719) to get metadata. This means parsing a potentially multi-GB file twice. You can read the metadata in the same safe_open call and also load the tensors from it.

Proposed: single-pass using safe_open

```diff
- base_state = load_file(merged_base_safetensor_path)
+ with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
+     base_metadata = f.metadata() or {}
+     base_state = {key: f.get_tensor(key) for key in f.keys()}
  non_transformer_prefixes = [ ... ]
  ...
  merged = dict(base_non_transformer)
  merged.update(base_connectors)
  merged.update(prefixed)
- with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
-     base_metadata = f.metadata() or {}
  del base_state
  return merged, base_metadata
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/diffusers_utils.py` around lines 681 - 723, The code currently calls load_file(merged_base_safetensor_path) to populate base_state and later re-opens the same file via safe_open(merged_base_safetensor_path, ...) just to read metadata, causing a double read; replace this by using safe_open once to both access tensors and metadata: open the safetensor with safe_open(merged_base_safetensor_path, framework="pt", device="cpu"), read tensors into base_state from that handle (instead of load_file), extract base_metadata from f.metadata(), then proceed to build base_non_transformer, base_connectors, prefixed, merged and return merged and base_metadata; update references to base_state accordingly and remove the redundant load_file call.
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
modelopt/torch/export/diffusers_utils.py
modelopt/torch/export/unified_export_hf.py
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 885-887: The code unconditionally calls
get_diffusion_model_type(pipe) when merged_base_safetensor_path is set, which
will raise a ValueError for non-LTX-2 diffusers; update export_hf_checkpoint to
first check the pipeline type (or a predicate like is_ltx2_pipeline(pipe))
before calling get_diffusion_model_type, and if merged_base_safetensor_path is
provided for an unsupported pipeline either raise a clearer, descriptive error
mentioning export_hf_checkpoint and merged_base_safetensor_path or document this
constraint in the function docstring so users aren’t met with an opaque
ValueError from get_diffusion_model_type.
- Around line 960-968: The current path calls
_save_component_state_dict_safetensors(component, component_export_dir,
merged_base_safetensor_path, model_type=...) for any component that doesn't
implement save_pretrained, which unintentionally applies the
merged_base_safetensor_path merge to non-quantized components; update the logic
so merged_base_safetensor_path is only passed when the component is quantized
(e.g., detect quantization via a flag or type check before calling
_save_component_state_dict_safetensors) or call
_save_component_state_dict_safetensors without merged_base_safetensor_path for
non-quantized components, ensuring references to
_save_component_state_dict_safetensors, merged_base_safetensor_path,
component.save_pretrained and model_type are used to locate and change the code.
- Around line 136-169: The computed metadata (metadata and metadata_full) is
discarded for non-merge exports because save_file is only given metadata when
merged_base_safetensor_path is not None; change the save_file call in
unified_export_hf.py to always pass the assembled metadata (use metadata_full
which is updated with metadata) instead of conditionally passing None — update
the save_file invocation (function save_file, variables metadata_full,
merged_base_safetensor_path, cpu_state_dict, component_export_dir, component) to
use metadata=metadata_full unconditionally so _export_format and _class_name are
preserved for all exports.
---
Nitpick comments:
In `@modelopt/torch/export/diffusers_utils.py`:
- Around line 681-723: The code currently calls
load_file(merged_base_safetensor_path) to populate base_state and later re-opens
the same file via safe_open(merged_base_safetensor_path, ...) just to read
metadata, causing a double read; replace this by using safe_open once to both
access tensors and metadata: open the safetensor with
safe_open(merged_base_safetensor_path, framework="pt", device="cpu"), read
tensors into base_state from that handle (instead of load_file), extract
base_metadata from f.metadata(), then proceed to build base_non_transformer,
base_connectors, prefixed, merged and return merged and base_metadata; update
references to base_state accordingly and remove the redundant load_file call.
```python
metadata: dict[str, str] = {}
metadata_full: dict[str, str] = {}
if merged_base_safetensor_path is not None and model_type is not None:
    merge_fn = DIFFUSION_MERGE_FUNCTIONS[model_type]
    cpu_state_dict, metadata_full = merge_fn(cpu_state_dict, merged_base_safetensor_path)
if hf_quant_config is not None:
    metadata_full["quantization_config"] = json.dumps(hf_quant_config)

    # Build per-layer _quantization_metadata for ComfyUI
    quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
    layer_metadata = {}
    for k in cpu_state_dict:
        if k.endswith((".weight_scale", ".weight_scale_2")):
            layer_name = k.rsplit(".", 1)[0]
            if layer_name.endswith(".weight"):
                layer_name = layer_name.rsplit(".", 1)[0]
            if layer_name not in layer_metadata:
                layer_metadata[layer_name] = {"format": quant_algo}
    metadata_full["_quantization_metadata"] = json.dumps(
        {
            "format_version": "1.0",
            "layers": layer_metadata,
        }
    )

metadata["_export_format"] = "safetensors_state_dict"
metadata["_class_name"] = type(component).__name__
metadata_full.update(metadata)

save_file(
    cpu_state_dict,
    str(component_export_dir / "model.safetensors"),
    metadata=metadata_full if merged_base_safetensor_path is not None else None,
)
```
Metadata is discarded for non-merge exports.
On line 168, metadata is only passed to save_file when merged_base_safetensor_path is not None. For non-merge exports through this function, the _export_format and _class_name metadata (lines 161-162) are computed but thrown away — save_file is called with metadata=None.
If metadata should always be attached (even for non-merge exports), pass it unconditionally:
Proposed fix

```diff
  save_file(
      cpu_state_dict,
      str(component_export_dir / "model.safetensors"),
-     metadata=metadata_full if merged_base_safetensor_path is not None else None,
+     metadata=metadata_full,
  )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
metadata: dict[str, str] = {}
metadata_full: dict[str, str] = {}
if merged_base_safetensor_path is not None and model_type is not None:
    merge_fn = DIFFUSION_MERGE_FUNCTIONS[model_type]
    cpu_state_dict, metadata_full = merge_fn(cpu_state_dict, merged_base_safetensor_path)
if hf_quant_config is not None:
    metadata_full["quantization_config"] = json.dumps(hf_quant_config)

    # Build per-layer _quantization_metadata for ComfyUI
    quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
    layer_metadata = {}
    for k in cpu_state_dict:
        if k.endswith((".weight_scale", ".weight_scale_2")):
            layer_name = k.rsplit(".", 1)[0]
            if layer_name.endswith(".weight"):
                layer_name = layer_name.rsplit(".", 1)[0]
            if layer_name not in layer_metadata:
                layer_metadata[layer_name] = {"format": quant_algo}
    metadata_full["_quantization_metadata"] = json.dumps(
        {
            "format_version": "1.0",
            "layers": layer_metadata,
        }
    )

metadata["_export_format"] = "safetensors_state_dict"
metadata["_class_name"] = type(component).__name__
metadata_full.update(metadata)

save_file(
    cpu_state_dict,
    str(component_export_dir / "model.safetensors"),
    metadata=metadata_full,
)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/export/unified_export_hf.py` around lines 136 - 169, The
computed metadata (metadata and metadata_full) is discarded for non-merge
exports because save_file is only given metadata when
merged_base_safetensor_path is not None; change the save_file call in
unified_export_hf.py to always pass the assembled metadata (use metadata_full
which is updated with metadata) instead of conditionally passing None — update
the save_file invocation (function save_file, variables metadata_full,
merged_base_safetensor_path, cpu_state_dict, component_export_dir, component) to
use metadata=metadata_full unconditionally so _export_format and _class_name are
preserved for all exports.
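The per-layer metadata construction discussed above can be exercised in isolation. The sketch below mirrors the loop from the PR (same suffix checks and `rsplit` logic); the helper name and the demo state-dict keys are made up for illustration:

```python
import json


def build_layer_metadata(state_dict_keys, quant_algo: str) -> str:
    """Mirror of the per-layer _quantization_metadata loop: every tensor named
    *.weight_scale or *.weight_scale_2 marks its parent layer as quantized."""
    layer_metadata = {}
    for k in state_dict_keys:
        if k.endswith((".weight_scale", ".weight_scale_2")):
            # Strip the scale suffix, then a trailing ".weight" if present.
            layer_name = k.rsplit(".", 1)[0]
            if layer_name.endswith(".weight"):
                layer_name = layer_name.rsplit(".", 1)[0]
            if layer_name not in layer_metadata:
                layer_metadata[layer_name] = {"format": quant_algo.lower()}
    return json.dumps({"format_version": "1.0", "layers": layer_metadata})


# Both scale tensors of a layer collapse onto one entry for that layer.
keys = [
    "blocks.0.attn.to_q.weight",
    "blocks.0.attn.to_q.weight_scale",
    "blocks.0.attn.to_q.weight_scale_2",
]
meta = build_layer_metadata(keys, "NVFP4")
```

The resulting JSON string is what gets stored under `_quantization_metadata` in the safetensors header, keyed per layer for ComfyUI to interpret.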
```python
# Resolve model type once (only needed when merging with a base checkpoint)
model_type = get_diffusion_model_type(pipe) if merged_base_safetensor_path else None
```
get_diffusion_model_type is called unconditionally when merged_base_safetensor_path is truthy — will raise ValueError for non-LTX-2 diffusers pipelines.
If a user passes merged_base_safetensor_path for a standard diffusers pipeline (e.g., StableDiffusion), get_diffusion_model_type(pipe) will raise ValueError with a somewhat opaque message. Consider validating earlier or documenting this limitation more prominently in export_hf_checkpoint's docstring.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/export/unified_export_hf.py` around lines 885 - 887, The code
unconditionally calls get_diffusion_model_type(pipe) when
merged_base_safetensor_path is set, which will raise a ValueError for non-LTX-2
diffusers; update export_hf_checkpoint to first check the pipeline type (or a
predicate like is_ltx2_pipeline(pipe)) before calling get_diffusion_model_type,
and if merged_base_safetensor_path is provided for an unsupported pipeline
either raise a clearer, descriptive error mentioning export_hf_checkpoint and
merged_base_safetensor_path or document this constraint in the function
docstring so users aren’t met with an opaque ValueError from
get_diffusion_model_type.
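One way to fail fast, as the comment suggests, is to resolve and validate the model type behind a single helper. This is a hedged sketch only: the supported-type set, the toy class-name detection, and the helper name are invented for illustration, not taken from the PR:

```python
# Hypothetical registry of model types that support base-checkpoint merging.
SUPPORTED_MERGE_MODEL_TYPES = {"ltx2"}


def resolve_merge_model_type(pipe_class_name: str, merged_base_safetensor_path):
    """Return a model type for checkpoint merging, or raise a descriptive error.

    Stands in for calling get_diffusion_model_type() unconditionally, which
    would surface an opaque ValueError for unsupported pipelines.
    """
    if merged_base_safetensor_path is None:
        return None
    # Toy detection by pipeline class name (an assumption for this sketch).
    model_type = {"LTX2Pipeline": "ltx2"}.get(pipe_class_name)
    if model_type not in SUPPORTED_MERGE_MODEL_TYPES:
        raise ValueError(
            f"merged_base_safetensor_path is only supported for "
            f"{sorted(SUPPORTED_MERGE_MODEL_TYPES)} pipelines in "
            f"export_hf_checkpoint; got pipeline class {pipe_class_name!r}."
        )
    return model_type
```

The error message names both the entry point and the offending argument, so a StableDiffusion user sees immediately why the merge path is rejected.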
```diff
  elif hasattr(component, "save_pretrained"):
      component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
  else:
-     _save_component_state_dict_safetensors(component, component_export_dir)
+     _save_component_state_dict_safetensors(
+         component,
+         component_export_dir,
+         merged_base_safetensor_path,
+         model_type=model_type,
+     )
```
Non-quantized components also receive merged_base_safetensor_path — unintentional merge?
When a non-quantized component falls through to _save_component_state_dict_safetensors (lines 963-968), it receives merged_base_safetensor_path and model_type. This means the merge function will run on the non-quantized component's state dict too, adding all non-transformer base weights (VAE, vocoder, etc.) into it.
In the current LTX-2 flow there's only one component, so this is harmless. But for future model types with multiple components, this would produce incorrect merged checkpoints for non-quantized components.
Consider guarding by only passing merged_base_safetensor_path for quantized components, or adding a comment clarifying the assumption:
Proposed safeguard

```diff
  else:
      _save_component_state_dict_safetensors(
          component,
          component_export_dir,
-         merged_base_safetensor_path,
-         model_type=model_type,
      )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
elif hasattr(component, "save_pretrained"):
    component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
else:
    _save_component_state_dict_safetensors(
        component,
        component_export_dir,
    )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/export/unified_export_hf.py` around lines 960 - 968, The
current path calls _save_component_state_dict_safetensors(component,
component_export_dir, merged_base_safetensor_path, model_type=...) for any
component that doesn't implement save_pretrained, which unintentionally applies
the merged_base_safetensor_path merge to non-quantized components; update the
logic so merged_base_safetensor_path is only passed when the component is
quantized (e.g., detect quantization via a flag or type check before calling
_save_component_state_dict_safetensors) or call
_save_component_state_dict_safetensors without merged_base_safetensor_path for
non-quantized components, ensuring references to
_save_component_state_dict_safetensors, merged_base_safetensor_path,
component.save_pretrained and model_type are used to locate and change the code.
What does this PR do
Add support for exporting a ComfyUI-compatible checkpoint for diffusion models (e.g., LTX-2)
Type of change:
Overview:
Add support for exporting a ComfyUI-compatible checkpoint for diffusion models (e.g., LTX-2)
Usage
Testing
a) initializing a twoStagePipeline object
b) calling mtq.quantize on transformer with NVFP4_DEFAULT_CFG
c) exporting with export_hf_checkpoint, passing the merged_base_safetensor_path parameter to generate the merged checkpoint
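After step (c), the merged checkpoint's metadata can be spot-checked without any ML dependencies: a safetensors file begins with an 8-byte little-endian header length followed by a JSON header whose `__metadata__` entry holds the string metadata. A stdlib-only sketch (the file path in the usage comment is an assumption about the export layout):

```python
import json
import struct


def read_safetensors_metadata(path: str) -> dict:
    """Read the __metadata__ block from a safetensors file using only the
    stdlib: an 8-byte little-endian u64 header size, then a JSON header."""
    with open(path, "rb") as f:
        (header_size,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_size))
    return header.get("__metadata__", {})


# Usage (hypothetical export directory):
# meta = read_safetensors_metadata("exported/transformer/model.safetensors")
# print(meta.get("_export_format"), meta.get("_class_name"))
# print(json.loads(meta["_quantization_metadata"])["layers"].keys())
```

This verifies that `_export_format`, `_class_name`, `quantization_config`, and `_quantization_metadata` actually landed in the header, which is exactly what ComfyUI will read.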
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Release Notes