diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 50f35aed8..310a44b72 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -1,4 +1,5 @@ +#include #include #include #include @@ -18,6 +19,7 @@ namespace fs = std::filesystem; #endif // _WIN32 #include "stable-diffusion.h" +#include "model.h" // For SDVersion enum #define STB_IMAGE_IMPLEMENTATION #define STB_IMAGE_STATIC @@ -443,6 +445,7 @@ struct SDContextParams { std::string control_net_path; std::string embedding_dir; std::string photo_maker_path; + std::string model_type; // Manual model version override (sd1, sd2, sdxl, flux, etc.) sd_type_t wtype = SD_TYPE_COUNT; std::string tensor_type_rules; std::string lora_model_dir = "."; @@ -487,6 +490,10 @@ struct SDContextParams { "--model", "path to full model", &model_path}, + {"", + "--model-type", + "force model type (sd1, sd2, sdxl, flux, sdxl_inpaint, etc). Auto-detect if not specified.", + &model_type}, {"", "--clip_l", "path to the clip-l text encoder", &clip_l_path}, @@ -944,6 +951,38 @@ struct SDContextParams { embedding_vec.emplace_back(item); } + // Parse model_type string to SDVersion enum + int version_override = VERSION_COUNT; // Auto-detect by default + if (!model_type.empty()) { + std::string mt = model_type; + // Convert to lowercase for case-insensitive matching + std::transform(mt.begin(), mt.end(), mt.begin(), ::tolower); + + if (mt == "sd1" || mt == "sd1.5" || mt == "sd1.x") { + version_override = VERSION_SD1; + } else if (mt == "sd1_inpaint") { + version_override = VERSION_SD1_INPAINT; + } else if (mt == "sd2" || mt == "sd2.0" || mt == "sd2.1" || mt == "sd2.x") { + version_override = VERSION_SD2; + } else if (mt == "sd2_inpaint") { + version_override = VERSION_SD2_INPAINT; + } else if (mt == "sdxl" || mt == "sdxl1.0") { + version_override = VERSION_SDXL; + } else if (mt == "sdxl_inpaint") { + version_override = VERSION_SDXL_INPAINT; + } else if (mt == "sdxl_pix2pix") { + version_override = VERSION_SDXL_PIX2PIX; + } else if (mt == "flux" || mt == "flux1") { + version_override = VERSION_FLUX; + } else if (mt == "sd3" || mt == "sd3.5") { + version_override = VERSION_SD3; + } else if (mt == "svd") { + version_override = VERSION_SVD; + } else { + fprintf(stderr, "Warning: Unknown model type '%s', using auto-detect\n", model_type.c_str()); + } + } + sd_ctx_params_t sd_ctx_params = { model_path.c_str(), clip_l_path.c_str(), @@ -969,6 +1008,7 @@ struct SDContextParams { sampler_rng_type, prediction, lora_apply_mode, + version_override, // Add version_override parameter offload_params_to_cpu, enable_mmap, clip_on_cpu, diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index cb966d7e8..5b57e4939 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -184,6 +184,7 @@ typedef struct { enum rng_type_t sampler_rng_type; enum prediction_t prediction; enum lora_apply_mode_t lora_apply_mode; + int version_override; // SDVersion enum value, VERSION_COUNT = auto-detect bool offload_params_to_cpu; bool enable_mmap; bool keep_clip_on_cpu; diff --git a/src/model.cpp b/src/model.cpp index 58d71d9e4..81c39151b 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -655,11 +655,11 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s LOG_WARN("Couldn't find working VAE in %s", file_path.c_str()); // return false; } - if (!init_from_safetensors_file(clip_path, "te.")) { + if (!init_from_safetensors_file(clip_path, "cond_stage_model.transformer.")) { LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str()); // return false; } - if (!init_from_safetensors_file(clip_g_path, "te.1.")) { + if (!init_from_safetensors_file(clip_g_path, "cond_stage_model.1.transformer.")) { LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str()); } return true; @@ -1028,6 +1028,11 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight, input_block_weight; + // Return cached version if already detected as SDXL in earlier component + if (version_ == VERSION_SDXL || version_ == VERSION_SDXL_INPAINT || version_ == VERSION_SDXL_PIX2PIX) { + LOG_DEBUG("Returning cached SDXL version"); + return version_; + } bool has_multiple_encoders = false; bool is_unet = false; @@ -1089,8 +1094,10 @@ SDVersion ModelLoader::get_sd_version() { tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) { has_multiple_encoders = true; + // Return SDXL immediately to prevent later components from overriding if (is_unet) { - is_xl = true; + LOG_DEBUG("Detected SDXL (multiple text encoders in UNET model)"); + return VERSION_SDXL; } } if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { @@ -1122,6 +1129,11 @@ SDVersion ModelLoader::get_sd_version() { input_block_weight = tensor_storage; } } + + // Ensure SDXL is detected even if early return was not reached + if (has_multiple_encoders && is_unet) { + is_xl = true; + } if (is_wan) { LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels); if (patch_embedding_channels == 184320 && !has_img_emb) { diff --git a/src/name_conversion.cpp b/src/name_conversion.cpp index d3e863b8a..9a1eda514 100644 --- a/src/name_conversion.cpp +++ b/src/name_conversion.cpp @@ -920,6 +920,8 @@ std::vector cond_stage_model_prefix_vec = { "cond_stage_model.", "conditioner.embedders.", "text_encoders.", + "te.1.", // diffusers SDXL text_encoder_2 (clip_g) + "te.", // diffusers text_encoder (clip_l) }; std::vector diffuison_model_prefix_vec = { diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index c0ee1182d..0c2e2cd01 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -326,15 +326,28 @@ class StableDiffusionGGML { model_loader.convert_tensors_name(); - version = model_loader.get_sd_version(); - if (version == VERSION_COUNT) { - LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path)); - return false; + // Check for manual version override first + if (sd_ctx_params->version_override != VERSION_COUNT) { + version = (SDVersion)sd_ctx_params->version_override; + LOG_INFO("Version overridden to: %s", model_version_to_str[version]); + } else { + // Auto-detect version - don't overwrite if already detected as SDXL in earlier component + SDVersion detected_version = model_loader.get_sd_version(); + if (version != VERSION_SDXL && version != VERSION_SDXL_INPAINT && version != VERSION_SDXL_PIX2PIX) { + version = detected_version; + } else { + LOG_INFO("Keeping previous SDXL version, detected version: %s", model_version_to_str[detected_version]); + } + if (version == VERSION_COUNT) { + LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path)); + return false; + } } auto& tensor_storage_map = model_loader.get_tensor_storage_map(); LOG_INFO("Version: %s ", model_version_to_str[version]); + ggml_type wtype = (int)sd_ctx_params->wtype < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT) ? (ggml_type)sd_ctx_params->wtype : GGML_TYPE_COUNT; @@ -2918,6 +2931,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT; sd_ctx_params->prediction = PREDICTION_COUNT; sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO; + sd_ctx_params->version_override = VERSION_COUNT; // Auto-detect sd_ctx_params->offload_params_to_cpu = false; sd_ctx_params->enable_mmap = false; sd_ctx_params->keep_clip_on_cpu = false;