From 1363878d7388cf2c70446ada0ca78a6b93edff79 Mon Sep 17 00:00:00 2001
From: ynankani
Date: Fri, 20 Feb 2026 05:34:43 -0800
Subject: [PATCH 1/4] Add support for exporting ComfyUI-compatible checkpoints
 for diffusion models (e.g., LTX-2)

Signed-off-by: ynankani
---
 modelopt/torch/export/unified_export_hf.py | 136 ++++++++++++++++++---
 1 file changed, 117 insertions(+), 19 deletions(-)

diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index ca80cb450..983e18c53 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -28,7 +28,7 @@
 
 import torch
 import torch.nn as nn
-from safetensors.torch import save_file
+from safetensors.torch import save_file, load_file, safe_open
 
 try:
     import diffusers
@@ -111,20 +111,108 @@ def _is_enabled_quantizer(quantizer):
     return False
 
 
+def _merge_diffusion_transformer_with_non_transformer_components(
+    diffusion_transformer_state_dict: dict[str, torch.Tensor],
+    merged_base_safetensor_path: str,
+) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
+    """Merge diffusion transformer weights with non-transformer components from a safetensors file.
+
+    Non-transformer components (VAE, vocoder, text encoders) and embeddings connectors are
+    taken from the base checkpoint. Transformer keys are prefixed with 'model.diffusion_model.'
+    for ComfyUI compatibility.
+
+    Args:
+        diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU).
+        merged_base_safetensor_path: Path to the full base model safetensors file containing
+            all components (transformer, VAE, vocoder, etc.).
+
+    Returns:
+        Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
+        safetensors metadata from the base checkpoint.
+    """
+
+    base_state = load_file(merged_base_safetensor_path)
+
+    non_transformer_prefixes = [
+        'vae.', 'audio_vae.', 'vocoder.', 'text_embedding_projection.',
+        'text_encoders.', 'first_stage_model.', 'cond_stage_model.', 'conditioner.',
+    ]
+    correct_prefix = 'model.diffusion_model.'
+    strip_prefixes = ['diffusion_model.', 'transformer.', '_orig_mod.', 'model.', 'velocity_model.']
+
+    base_non_transformer = {k: v for k, v in base_state.items()
+                            if any(k.startswith(p) for p in non_transformer_prefixes)}
+    base_connectors = {k: v for k, v in base_state.items()
+                       if 'embeddings_connector' in k and k.startswith(correct_prefix)}
+
+    prefixed = {}
+    for k, v in diffusion_transformer_state_dict.items():
+        clean_k = k
+        for prefix in strip_prefixes:
+            if clean_k.startswith(prefix):
+                clean_k = clean_k[len(prefix):]
+                break
+        prefixed[f"{correct_prefix}{clean_k}"] = v
+
+    merged = dict(base_non_transformer)
+    merged.update(base_connectors)
+    merged.update(prefixed)
+    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
+        base_metadata = f.metadata() or {}
+
+    del base_state
+    return merged, base_metadata
+
+
 def _save_component_state_dict_safetensors(
-    component: nn.Module, component_export_dir: Path
+    component: nn.Module,
+    component_export_dir: Path,
+    merged_base_safetensor_path: str | None = None,
+    hf_quant_config: dict | None = None
 ) -> None:
+    """Save component state dict as safetensors with optional base checkpoint merge.
+
+    Args:
+        component: The nn.Module to save.
+        component_export_dir: Directory to save model.safetensors and config.json.
+        merged_base_safetensor_path: If provided, merge with non-transformer components
+            from this base safetensors file.
+        hf_quant_config: If provided, embed quantization config in safetensors metadata
+            and per-layer _quantization_metadata for ComfyUI.
+    """
     cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()}
-    save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"))
-    with open(component_export_dir / "config.json", "w") as f:
-        json.dump(
-            {
-                "_class_name": type(component).__name__,
-                "_export_format": "safetensors_state_dict",
-            },
-            f,
-            indent=4,
+    metadata: dict[str, str] = {}
+    metadata_full: dict[str, str] = {}
+    if merged_base_safetensor_path is not None:
+        cpu_state_dict, metadata_full = _merge_diffusion_transformer_with_non_transformer_components(
+            cpu_state_dict, merged_base_safetensor_path
         )
+    metadata["_export_format"] = "safetensors_state_dict"
+    metadata["_class_name"] = type(component).__name__
+
+    if hf_quant_config is not None:
+        metadata_full["quantization_config"] = json.dumps(hf_quant_config)
+
+        # Build per-layer _quantization_metadata for ComfyUI
+        quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
+        layer_metadata = {}
+        for k in cpu_state_dict:
+            if k.endswith(".weight_scale") or k.endswith(".weight_scale_2"):
+                layer_name = k.rsplit(".", 1)[0]
+                if layer_name.endswith(".weight"):
+                    layer_name = layer_name.rsplit(".", 1)[0]
+                if layer_name not in layer_metadata:
+                    layer_metadata[layer_name] = {"format": quant_algo}
+        metadata_full["_quantization_metadata"] = json.dumps({
+            "format_version": "1.0",
+            "layers": layer_metadata,
+        })
+
+    metadata_full.update(metadata)
+    save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"), metadata=metadata_full if merged_base_safetensor_path is not None else None)
+
+    with open(component_export_dir / "config.json", "w") as f:
+        json.dump(metadata, f, indent=4)
 
 
 def _collect_shared_input_modules(
@@ -807,6 +895,7 @@ def _export_diffusers_checkpoint(
     dtype: torch.dtype | None,
     export_dir: Path,
     components: list[str] | None,
+    merged_base_safetensor_path: str | None = None,
     max_shard_size: int | str = "10GB",
 ) -> None:
     """Internal: Export diffusion(-like) model/pipeline checkpoint.
@@ -821,6 +910,8 @@ def _export_diffusers_checkpoint(
         export_dir: The directory to save the exported checkpoint.
         components: Optional list of component names to export. Only used for pipelines.
             If None, all components are exported.
+        merged_base_safetensor_path: If provided, merge the exported transformer with
+            non-transformer components from this base safetensors file.
         max_shard_size: Maximum size of each shard file. If the model exceeds this size,
             it will be sharded into multiple files and a .safetensors.index.json will be
             created. Use smaller values like "5GB" or "2GB" to force sharding.
@@ -879,7 +970,8 @@ def _export_diffusers_checkpoint(
 
         # Step 5: Build quantization config
         quant_config = get_quant_config(component, is_modelopt_qlora=False)
-
+        hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None
+        
         # Step 6: Save the component
         # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
         # - for non-diffusers modules (e.g., LTX-2 transformer), fall back to torch.save
@@ -888,12 +980,14 @@ def _export_diffusers_checkpoint(
             component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
         else:
             with hide_quantizers_from_state_dict(component):
-                _save_component_state_dict_safetensors(component, component_export_dir)
-
+                _save_component_state_dict_safetensors(
+                    component, 
+                    component_export_dir, 
+                    merged_base_safetensor_path,
+                    hf_quant_config,
+                )
         # Step 7: Update config.json with quantization info
-        if quant_config is not None:
-            hf_quant_config = convert_hf_quant_config_format(quant_config)
-
+        if hf_quant_config is not None:
             config_path = component_export_dir / "config.json"
             if config_path.exists():
                 with open(config_path) as file:
@@ -905,7 +999,7 @@ def _export_diffusers_checkpoint(
         elif hasattr(component, "save_pretrained"):
             component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
         else:
-            _save_component_state_dict_safetensors(component, component_export_dir)
+            _save_component_state_dict_safetensors(component, component_export_dir, merged_base_safetensor_path)
 
         print(f"  Saved to: {component_export_dir}")
@@ -985,6 +1079,7 @@ def export_hf_checkpoint(
     save_modelopt_state: bool = False,
     components: list[str] | None = None,
     extra_state_dict: dict[str, torch.Tensor] | None = None,
+    merged_base_safetensor_path: str | None = None,
 ):
     """Export quantized HuggingFace model checkpoint (transformers or diffusers).
@@ -1002,6 +1097,9 @@ def export_hf_checkpoint(
         components: Only used for diffusers pipelines. Optional list of component names
             to export. If None, all quantized components are exported.
         extra_state_dict: Extra state dictionary to add to the exported model.
+        merged_base_safetensor_path: If provided, merge the exported diffusion transformer
+            with non-transformer components (VAE, vocoder, etc.) from this base safetensors
+            file. Only used for diffusion model exports (e.g., LTX-2).
""" export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) @@ -1010,7 +1108,7 @@ def export_hf_checkpoint( if HAS_DIFFUSERS: is_diffusers_obj = is_diffusers_object(model) if is_diffusers_obj: - _export_diffusers_checkpoint(model, dtype, export_dir, components) + _export_diffusers_checkpoint(model, dtype, export_dir, components, merged_base_safetensor_path) return # Transformers model export From 6e922ae79f801fdba124dd5551164976b606d4b7 Mon Sep 17 00:00:00 2001 From: ynankani Date: Fri, 20 Feb 2026 06:07:40 -0800 Subject: [PATCH 2/4] Add support for export comfyui compatible checkpoint for diffusion model(e.g., LTX-2) Signed-off-by: ynankani --- modelopt/torch/export/unified_export_hf.py | 83 ++++++++++++++-------- 1 file changed, 53 insertions(+), 30 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 983e18c53..bd6df260c 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -28,7 +28,7 @@ import torch import torch.nn as nn -from safetensors.torch import save_file, load_file, safe_open +from safetensors.torch import load_file, safe_open, save_file try: import diffusers @@ -130,27 +130,38 @@ def _merge_diffusion_transformer_with_non_transformer_components( Tuple of (merged_state_dict, base_metadata) where base_metadata is the original safetensors metadata from the base checkpoint. """ - base_state = load_file(merged_base_safetensor_path) non_transformer_prefixes = [ - 'vae.', 'audio_vae.', 'vocoder.', 'text_embedding_projection.', - 'text_encoders.', 'first_stage_model.', 'cond_stage_model.', 'conditioner.', + "vae.", + "audio_vae.", + "vocoder.", + "text_embedding_projection.", + "text_encoders.", + "first_stage_model.", + "cond_stage_model.", + "conditioner.", ] - correct_prefix = 'model.diffusion_model.' - strip_prefixes = ['diffusion_model.', 'transformer.', '_orig_mod.', 'model.', 'velocity_model.'] + correct_prefix = "model.diffusion_model." + strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."] - base_non_transformer = {k: v for k, v in base_state.items() - if any(k.startswith(p) for p in non_transformer_prefixes)} - base_connectors = {k: v for k, v in base_state.items() - if 'embeddings_connector' in k and k.startswith(correct_prefix)} + base_non_transformer = { + k: v + for k, v in base_state.items() + if any(k.startswith(p) for p in non_transformer_prefixes) + } + base_connectors = { + k: v + for k, v in base_state.items() + if "embeddings_connector" in k and k.startswith(correct_prefix) + } prefixed = {} for k, v in diffusion_transformer_state_dict.items(): clean_k = k for prefix in strip_prefixes: if clean_k.startswith(prefix): - clean_k = clean_k[len(prefix):] + clean_k = clean_k[len(prefix) :] break prefixed[f"{correct_prefix}{clean_k}"] = v @@ -165,10 +176,10 @@ def _merge_diffusion_transformer_with_non_transformer_components( def _save_component_state_dict_safetensors( - component: nn.Module, - component_export_dir: Path, - merged_base_safetensor_path: str | None = None, - hf_quant_config: dict | None = None + component: nn.Module, + component_export_dir: Path, + merged_base_safetensor_path: str | None = None, + hf_quant_config: dict | None = None, ) -> None: """Save component state dict as safetensors with optional base checkpoint merge. 
@@ -184,10 +195,12 @@ def _save_component_state_dict_safetensors(
     metadata: dict[str, str] = {}
     metadata_full: dict[str, str] = {}
     if merged_base_safetensor_path is not None:
-        cpu_state_dict, metadata_full = _merge_diffusion_transformer_with_non_transformer_components(
-            cpu_state_dict, merged_base_safetensor_path
+        cpu_state_dict, metadata_full = (
+            _merge_diffusion_transformer_with_non_transformer_components(
+                cpu_state_dict, merged_base_safetensor_path
+            )
         )
-    metadata["_export_format"] = "safetensors_state_dict" 
+    metadata["_export_format"] = "safetensors_state_dict"
     metadata["_class_name"] = type(component).__name__
 
     if hf_quant_config is not None:
@@ -197,20 +210,26 @@ def _save_component_state_dict_safetensors(
         quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
         layer_metadata = {}
         for k in cpu_state_dict:
-            if k.endswith(".weight_scale") or k.endswith(".weight_scale_2"):
+            if k.endswith((".weight_scale", ".weight_scale_2")):
                 layer_name = k.rsplit(".", 1)[0]
                 if layer_name.endswith(".weight"):
                     layer_name = layer_name.rsplit(".", 1)[0]
                 if layer_name not in layer_metadata:
                     layer_metadata[layer_name] = {"format": quant_algo}
-        metadata_full["_quantization_metadata"] = json.dumps({
-            "format_version": "1.0",
-            "layers": layer_metadata,
-        })
+        metadata_full["_quantization_metadata"] = json.dumps(
+            {
+                "format_version": "1.0",
+                "layers": layer_metadata,
+            }
+        )
 
     metadata_full.update(metadata)
-    save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"), metadata=metadata_full if merged_base_safetensor_path is not None else None)
-
+    save_file(
+        cpu_state_dict,
+        str(component_export_dir / "model.safetensors"),
+        metadata=metadata_full if merged_base_safetensor_path is not None else None,
+    )
+    
     with open(component_export_dir / "config.json", "w") as f:
         json.dump(metadata, f, indent=4)
@@ -971,7 +990,7 @@ def _export_diffusers_checkpoint(
         # Step 5: Build quantization config
         quant_config = get_quant_config(component, is_modelopt_qlora=False)
         hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None
-        
+
         # Step 6: Save the component
         # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
         # - for non-diffusers modules (e.g., LTX-2 transformer), fall back to torch.save
@@ -981,8 +1000,8 @@ def _export_diffusers_checkpoint(
         else:
             with hide_quantizers_from_state_dict(component):
                 _save_component_state_dict_safetensors(
-                    component, 
-                    component_export_dir, 
+                    component,
+                    component_export_dir,
                     merged_base_safetensor_path,
                     hf_quant_config,
                 )
@@ -999,7 +1018,9 @@ def _export_diffusers_checkpoint(
         elif hasattr(component, "save_pretrained"):
             component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
         else:
-            _save_component_state_dict_safetensors(component, component_export_dir, merged_base_safetensor_path)
+            _save_component_state_dict_safetensors(
+                component, component_export_dir, merged_base_safetensor_path
+            )
 
         print(f"  Saved to: {component_export_dir}")
@@ -1108,7 +1129,9 @@ def export_hf_checkpoint(
     if HAS_DIFFUSERS:
         is_diffusers_obj = is_diffusers_object(model)
         if is_diffusers_obj:
-            _export_diffusers_checkpoint(model, dtype, export_dir, components, merged_base_safetensor_path)
+            _export_diffusers_checkpoint(
+                model, dtype, export_dir, components, merged_base_safetensor_path
+            )
             return
 
     # Transformers model export

From 09ede351f5c9a9ecfeaf53bd2b3ee629b09d2702 Mon Sep 17 00:00:00 2001
From: ynankani
Date: Sun, 22 Feb 2026 23:10:34 -0800
Subject: [PATCH 3/4] Updates based on review comments

Signed-off-by: ynankani
---
 modelopt/torch/export/diffusers_utils.py   |  94 +++++++++++++++
 modelopt/torch/export/unified_export_hf.py | 127 ++++++---------------
 2 files changed, 129 insertions(+), 92 deletions(-)

diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py
index a9bf13876..c59b80853 100644
--- a/modelopt/torch/export/diffusers_utils.py
+++ b/modelopt/torch/export/diffusers_utils.py
@@ -23,6 +23,7 @@
 
 import torch
 import torch.nn as nn
+from safetensors.torch import load_file, safe_open
 
 from .layer_utils import is_quantlinear
 
@@ -656,3 +657,96 @@ def infer_dtype_from_model(model: nn.Module) -> torch.dtype:
     for param in model.parameters():
         return param.dtype
     return torch.float16
+
+
+def _merge_ltx2(
+    diffusion_transformer_state_dict: dict[str, torch.Tensor],
+    merged_base_safetensor_path: str,
+) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
+    """Merge LTX-2 transformer weights with non-transformer components.
+
+    Non-transformer components (VAE, vocoder, text encoders) and embeddings
+    connectors are taken from the base checkpoint. Transformer keys are
+    re-prefixed with ``model.diffusion_model.`` for ComfyUI compatibility.
+
+    Args:
+        diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU).
+        merged_base_safetensor_path: Path to the full base model safetensors file containing
+            all components (transformer, VAE, vocoder, etc.).
+
+    Returns:
+        Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
+        safetensors metadata from the base checkpoint.
+    """
+    base_state = load_file(merged_base_safetensor_path)
+
+    non_transformer_prefixes = [
+        "vae.",
+        "audio_vae.",
+        "vocoder.",
+        "text_embedding_projection.",
+        "text_encoders.",
+        "first_stage_model.",
+        "cond_stage_model.",
+        "conditioner.",
+    ]
+    correct_prefix = "model.diffusion_model."
+    strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."]
+
+    base_non_transformer = {
+        k: v
+        for k, v in base_state.items()
+        if any(k.startswith(p) for p in non_transformer_prefixes)
+    }
+    base_connectors = {
+        k: v
+        for k, v in base_state.items()
+        if "embeddings_connector" in k and k.startswith(correct_prefix)
+    }
+
+    prefixed = {}
+    for k, v in diffusion_transformer_state_dict.items():
+        clean_k = k
+        for prefix in strip_prefixes:
+            if clean_k.startswith(prefix):
+                clean_k = clean_k[len(prefix) :]
+                break
+        prefixed[f"{correct_prefix}{clean_k}"] = v
+
+    merged = dict(base_non_transformer)
+    merged.update(base_connectors)
+    merged.update(prefixed)
+    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
+        base_metadata = f.metadata() or {}
+
+    del base_state
+    return merged, base_metadata
+
+
+DIFFUSION_MERGE_FUNCTIONS: dict[str, Callable] = {
+    "ltx2": _merge_ltx2,
+}
+
+
+def get_diffusion_model_type(pipe: Any) -> str:
+    """Detect the diffusion model type for merge function dispatch.
+
+    To add a new model type, add a detection clause here and a corresponding
+    merge function in ``DIFFUSION_MERGE_FUNCTIONS``.
+
+    Args:
+        pipe: The pipeline or component being exported.
+
+    Returns:
+        A string key into ``DIFFUSION_MERGE_FUNCTIONS``.
+
+    Raises:
+        ValueError: If the model type is not supported.
+    """
+    if TI2VidTwoStagesPipeline is not None and isinstance(pipe, TI2VidTwoStagesPipeline):
+        return "ltx2"
+
+    raise ValueError(
+        f"No merge function for model type '{type(pipe).__name__}'. "
+        "Add an entry to DIFFUSION_MERGE_FUNCTIONS in diffusers_utils.py."
+ ) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index bd6df260c..c8d1ed352 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -28,14 +28,16 @@ import torch import torch.nn as nn -from safetensors.torch import load_file, safe_open, save_file +from safetensors.torch import save_file try: import diffusers from .diffusers_utils import ( + DIFFUSION_MERGE_FUNCTIONS, generate_diffusion_dummy_forward_fn, get_diffusion_components, + get_diffusion_model_type, get_qkv_group_key, hide_quantizers_from_state_dict, infer_dtype_from_model, @@ -111,75 +113,12 @@ def _is_enabled_quantizer(quantizer): return False -def _merge_diffusion_transformer_with_non_transformer_components( - diffusion_transformer_state_dict: dict[str, torch.Tensor], - merged_base_safetensor_path: str, -) -> tuple[dict[str, torch.Tensor], dict[str, str]]: - """Merge diffusion transformer weights with non-transformer components from a safetensors file. - - Non-transformer components (VAE, vocoder, text encoders) and embeddings connectors are - taken from the base checkpoint. Transformer keys are prefixed with 'model.diffusion_model.' - for ComfyUI compatibility. - - Args: - diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU). - merged_base_safetensor_path: Path to the full base model safetensors file containing - all components (transformer, VAE, vocoder, etc.). - - Returns: - Tuple of (merged_state_dict, base_metadata) where base_metadata is the original - safetensors metadata from the base checkpoint. - """ - base_state = load_file(merged_base_safetensor_path) - - non_transformer_prefixes = [ - "vae.", - "audio_vae.", - "vocoder.", - "text_embedding_projection.", - "text_encoders.", - "first_stage_model.", - "cond_stage_model.", - "conditioner.", - ] - correct_prefix = "model.diffusion_model." - strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."] - - base_non_transformer = { - k: v - for k, v in base_state.items() - if any(k.startswith(p) for p in non_transformer_prefixes) - } - base_connectors = { - k: v - for k, v in base_state.items() - if "embeddings_connector" in k and k.startswith(correct_prefix) - } - - prefixed = {} - for k, v in diffusion_transformer_state_dict.items(): - clean_k = k - for prefix in strip_prefixes: - if clean_k.startswith(prefix): - clean_k = clean_k[len(prefix) :] - break - prefixed[f"{correct_prefix}{clean_k}"] = v - - merged = dict(base_non_transformer) - merged.update(base_connectors) - merged.update(prefixed) - with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f: - base_metadata = f.metadata() or {} - - del base_state - return merged, base_metadata - - def _save_component_state_dict_safetensors( component: nn.Module, component_export_dir: Path, merged_base_safetensor_path: str | None = None, hf_quant_config: dict | None = None, + model_type: str | None = None, ) -> None: """Save component state dict as safetensors with optional base checkpoint merge. @@ -190,40 +129,39 @@ def _save_component_state_dict_safetensors( from this base safetensors file. hf_quant_config: If provided, embed quantization config in safetensors metadata and per-layer _quantization_metadata for ComfyUI. + model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` for the model-specific merge. + Required when ``merged_base_safetensor_path`` is not None. 
""" cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()} metadata: dict[str, str] = {} metadata_full: dict[str, str] = {} - if merged_base_safetensor_path is not None: - cpu_state_dict, metadata_full = ( - _merge_diffusion_transformer_with_non_transformer_components( - cpu_state_dict, merged_base_safetensor_path + if merged_base_safetensor_path is not None and model_type is not None: + merge_fn = DIFFUSION_MERGE_FUNCTIONS[model_type] + cpu_state_dict, metadata_full = merge_fn(cpu_state_dict, merged_base_safetensor_path) + if hf_quant_config is not None: + metadata_full["quantization_config"] = json.dumps(hf_quant_config) + + # Build per-layer _quantization_metadata for ComfyUI + quant_algo = hf_quant_config.get("quant_algo", "unknown").lower() + layer_metadata = {} + for k in cpu_state_dict: + if k.endswith((".weight_scale", ".weight_scale_2")): + layer_name = k.rsplit(".", 1)[0] + if layer_name.endswith(".weight"): + layer_name = layer_name.rsplit(".", 1)[0] + if layer_name not in layer_metadata: + layer_metadata[layer_name] = {"format": quant_algo} + metadata_full["_quantization_metadata"] = json.dumps( + { + "format_version": "1.0", + "layers": layer_metadata, + } ) - ) + metadata["_export_format"] = "safetensors_state_dict" metadata["_class_name"] = type(component).__name__ - - if hf_quant_config is not None: - metadata_full["quantization_config"] = json.dumps(hf_quant_config) - - # Build per-layer _quantization_metadata for ComfyUI - quant_algo = hf_quant_config.get("quant_algo", "unknown").lower() - layer_metadata = {} - for k in cpu_state_dict: - if k.endswith((".weight_scale", ".weight_scale_2")): - layer_name = k.rsplit(".", 1)[0] - if layer_name.endswith(".weight"): - layer_name = layer_name.rsplit(".", 1)[0] - if layer_name not in layer_metadata: - layer_metadata[layer_name] = {"format": quant_algo} - metadata_full["_quantization_metadata"] = json.dumps( - { - "format_version": "1.0", - "layers": layer_metadata, - } - ) - metadata_full.update(metadata) + save_file( cpu_state_dict, str(component_export_dir / "model.safetensors"), @@ -944,6 +882,9 @@ def _export_diffusers_checkpoint( warnings.warn("No exportable components found in the model.") return + # Resolve model type once (only needed when merging with a base checkpoint) + model_type = get_diffusion_model_type(pipe) if merged_base_safetensor_path else None + # Separate nn.Module components for quantization-aware export module_components = { name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module) @@ -1004,6 +945,7 @@ def _export_diffusers_checkpoint( component_export_dir, merged_base_safetensor_path, hf_quant_config, + model_type, ) # Step 7: Update config.json with quantization info if hf_quant_config is not None: @@ -1019,7 +961,8 @@ def _export_diffusers_checkpoint( component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) else: _save_component_state_dict_safetensors( - component, component_export_dir, merged_base_safetensor_path + component, component_export_dir, merged_base_safetensor_path, + model_type=model_type, ) print(f" Saved to: {component_export_dir}") From 7e74e911995458edd0fe6f84b2cf5a84c5a7f271 Mon Sep 17 00:00:00 2001 From: ynankani Date: Sun, 22 Feb 2026 23:12:59 -0800 Subject: [PATCH 4/4] Updates based on review comments Signed-off-by: ynankani --- modelopt/torch/export/unified_export_hf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py 
b/modelopt/torch/export/unified_export_hf.py index c8d1ed352..78f779a1d 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -161,7 +161,7 @@ def _save_component_state_dict_safetensors( metadata["_export_format"] = "safetensors_state_dict" metadata["_class_name"] = type(component).__name__ metadata_full.update(metadata) - + save_file( cpu_state_dict, str(component_export_dir / "model.safetensors"), @@ -961,7 +961,9 @@ def _export_diffusers_checkpoint( component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) else: _save_component_state_dict_safetensors( - component, component_export_dir, merged_base_safetensor_path, + component, + component_export_dir, + merged_base_safetensor_path, model_type=model_type, )
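---

Usage sketch (supplementary; not part of the patch series): a minimal driver for
the new export path, assuming the series is applied. `quantize_ltx2_pipeline`
and both file paths are hypothetical placeholders; only `export_hf_checkpoint`
and its new `merged_base_safetensor_path` keyword come from the patches above.

    from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

    # Hypothetical helper: however the LTX-2 pipeline was built and quantized
    # (e.g., via mtq.quantize with calibration) is out of scope for this sketch.
    pipe = quantize_ltx2_pipeline()

    export_hf_checkpoint(
        pipe,
        export_dir="ltx2-fp8-comfyui",  # illustrative output directory
        # Full base checkpoint holding the transformer plus VAE/vocoder/text
        # encoders. Non-transformer weights are copied from it, transformer keys
        # are re-prefixed with "model.diffusion_model.", and quantization_config /
        # _quantization_metadata are embedded in the safetensors header.
        merged_base_safetensor_path="checkpoints/ltx2-base.safetensors",  # illustrative path
    )

Per the docstring of get_diffusion_model_type() in PATCH 3/4, supporting another
architecture means adding a detection clause there plus a registered merge
function. A sketch of the extension point follows; the "wan" key and the merge
body are invented for illustration, while the
(state_dict, base_path) -> (merged_state_dict, base_metadata) contract comes
from the patch:

    import torch
    from modelopt.torch.export.diffusers_utils import DIFFUSION_MERGE_FUNCTIONS

    def _merge_wan(
        state_dict: dict[str, torch.Tensor], base_path: str
    ) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
        # Model-specific: re-prefix transformer keys and pull non-transformer
        # weights and header metadata from the base checkpoint at base_path.
        ...

    DIFFUSION_MERGE_FUNCTIONS["wan"] = _merge_wan
    # A matching isinstance clause must also be added inside
    # get_diffusion_model_type() so that "wan" is actually returned.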