Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions examples/common/common.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

#include <algorithm>
#include <filesystem>
#include <iostream>
#include <map>
Expand All @@ -18,6 +19,7 @@ namespace fs = std::filesystem;
#endif // _WIN32

#include "stable-diffusion.h"
#include "model.h" // For SDVersion enum

#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
Expand Down Expand Up @@ -443,6 +445,7 @@ struct SDContextParams {
std::string control_net_path;
std::string embedding_dir;
std::string photo_maker_path;
std::string model_type; // Manual model version override (sd1, sd2, sdxl, flux, etc.)
sd_type_t wtype = SD_TYPE_COUNT;
std::string tensor_type_rules;
std::string lora_model_dir = ".";
Expand Down Expand Up @@ -487,6 +490,10 @@ struct SDContextParams {
"--model",
"path to full model",
&model_path},
{"",
"--model-type",
"force model type (sd1, sd2, sdxl, flux, sdxl_inpaint, etc). Auto-detect if not specified.",
&model_type},
{"",
"--clip_l",
"path to the clip-l text encoder", &clip_l_path},
Expand Down Expand Up @@ -944,6 +951,38 @@ struct SDContextParams {
embedding_vec.emplace_back(item);
}

// Translate the optional --model-type string into an SDVersion value.
// VERSION_COUNT is the sentinel meaning "no override, auto-detect".
int version_override = VERSION_COUNT;
if (!model_type.empty()) {
    // Lowercase a copy so matching is case-insensitive; model_type itself is
    // left untouched so the warning below reports exactly what the user typed.
    std::string mt = model_type;
    // Route every byte through unsigned char: calling tolower() with a negative
    // char value (any non-ASCII byte on platforms where char is signed) is
    // undefined behavior.
    std::transform(mt.begin(), mt.end(), mt.begin(),
                   [](unsigned char c) { return static_cast<char>(::tolower(c)); });

    if (mt == "sd1" || mt == "sd1.5" || mt == "sd1.x") {
        version_override = VERSION_SD1;
    } else if (mt == "sd1_inpaint") {
        version_override = VERSION_SD1_INPAINT;
    } else if (mt == "sd2" || mt == "sd2.0" || mt == "sd2.1" || mt == "sd2.x") {
        version_override = VERSION_SD2;
    } else if (mt == "sd2_inpaint") {
        version_override = VERSION_SD2_INPAINT;
    } else if (mt == "sdxl" || mt == "sdxl1.0") {
        version_override = VERSION_SDXL;
    } else if (mt == "sdxl_inpaint") {
        version_override = VERSION_SDXL_INPAINT;
    } else if (mt == "sdxl_pix2pix") {
        version_override = VERSION_SDXL_PIX2PIX;
    } else if (mt == "flux" || mt == "flux1") {
        version_override = VERSION_FLUX;
    } else if (mt == "sd3" || mt == "sd3.5") {
        version_override = VERSION_SD3;
    } else if (mt == "svd") {
        version_override = VERSION_SVD;
    } else {
        // Unrecognized names are not fatal: version_override stays at
        // VERSION_COUNT, so the model version is auto-detected instead.
        fprintf(stderr, "Warning: Unknown model type '%s', using auto-detect\n", model_type.c_str());
    }
}

sd_ctx_params_t sd_ctx_params = {
model_path.c_str(),
clip_l_path.c_str(),
Expand All @@ -969,6 +1008,7 @@ struct SDContextParams {
sampler_rng_type,
prediction,
lora_apply_mode,
version_override, // Add version_override parameter
offload_params_to_cpu,
enable_mmap,
clip_on_cpu,
Expand Down
1 change: 1 addition & 0 deletions include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ typedef struct {
enum rng_type_t sampler_rng_type;
enum prediction_t prediction;
enum lora_apply_mode_t lora_apply_mode;
int version_override; // SDVersion enum value, VERSION_COUNT = auto-detect
bool offload_params_to_cpu;
bool enable_mmap;
bool keep_clip_on_cpu;
Expand Down
18 changes: 15 additions & 3 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -655,11 +655,11 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
// return false;
}
if (!init_from_safetensors_file(clip_path, "te.")) {
if (!init_from_safetensors_file(clip_path, "cond_stage_model.transformer.")) {
LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
// return false;
}
if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
if (!init_from_safetensors_file(clip_g_path, "cond_stage_model.1.transformer.")) {
LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
}
return true;
Expand Down Expand Up @@ -1028,6 +1028,11 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s

SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight, input_block_weight;
// Return cached version if already detected as SDXL in earlier component
if (version_ == VERSION_SDXL || version_ == VERSION_SDXL_INPAINT || version_ == VERSION_SDXL_PIX2PIX) {
LOG_DEBUG("Returning cached SDXL version");
return version_;
}

bool has_multiple_encoders = false;
bool is_unet = false;
Expand Down Expand Up @@ -1089,8 +1094,10 @@ SDVersion ModelLoader::get_sd_version() {
tensor_storage.name.find("cond_stage_model.1") != std::string::npos ||
tensor_storage.name.find("te.1") != std::string::npos) {
has_multiple_encoders = true;
// Return SDXL immediately to prevent later components from overriding
if (is_unet) {
is_xl = true;
LOG_DEBUG("Detected SDXL (multiple text encoders in UNET model)");
return VERSION_SDXL;
}
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
Expand Down Expand Up @@ -1122,6 +1129,11 @@ SDVersion ModelLoader::get_sd_version() {
input_block_weight = tensor_storage;
}
}

// Ensure SDXL is detected even if early return was not reached
if (has_multiple_encoders && is_unet) {
is_xl = true;
}
if (is_wan) {
LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels);
if (patch_embedding_channels == 184320 && !has_img_emb) {
Expand Down
2 changes: 2 additions & 0 deletions src/name_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,8 @@ std::vector<std::string> cond_stage_model_prefix_vec = {
"cond_stage_model.",
"conditioner.embedders.",
"text_encoders.",
"te.1.", // diffusers SDXL text_encoder_2 (clip_g)
"te.", // diffusers text_encoder (clip_l)
};

std::vector<std::string> diffuison_model_prefix_vec = {
Expand Down
22 changes: 18 additions & 4 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,15 +326,28 @@ class StableDiffusionGGML {

model_loader.convert_tensors_name();

version = model_loader.get_sd_version();
if (version == VERSION_COUNT) {
LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path));
return false;
// Check for manual version override first
if (sd_ctx_params->version_override != VERSION_COUNT) {
version = (SDVersion)sd_ctx_params->version_override;
LOG_INFO("Version overridden to: %s", model_version_to_str[version]);
} else {
// Auto-detect version - don't overwrite if already detected as SDXL in earlier component
SDVersion detected_version = model_loader.get_sd_version();
if (version != VERSION_SDXL && version != VERSION_SDXL_INPAINT && version != VERSION_SDXL_PIX2PIX) {
version = detected_version;
} else {
LOG_INFO("Keeping previous SDXL version, detected version: %s", model_version_to_str[detected_version]);
}
if (version == VERSION_COUNT) {
LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path));
return false;
}
}

auto& tensor_storage_map = model_loader.get_tensor_storage_map();

LOG_INFO("Version: %s ", model_version_to_str[version]);

ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
? (ggml_type)sd_ctx_params->wtype
: GGML_TYPE_COUNT;
Expand Down Expand Up @@ -2918,6 +2931,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT;
sd_ctx_params->prediction = PREDICTION_COUNT;
sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
sd_ctx_params->version_override = VERSION_COUNT; // Auto-detect
sd_ctx_params->offload_params_to_cpu = false;
sd_ctx_params->enable_mmap = false;
sd_ctx_params->keep_clip_on_cpu = false;
Expand Down