diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index f9e4928ea..440c5b044 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -750,7 +750,8 @@ int main(int argc, const char* argv[]) { gen_params.pm_id_embed_path.c_str(), gen_params.pm_style_strength, }, // pm_params - ctx_params.vae_tiling_params, + ctx_params.get_tiling_params(gen_params.get_resolved_width(), + gen_params.get_resolved_height()), gen_params.cache_params, }; @@ -776,7 +777,8 @@ int main(int argc, const char* argv[]) { gen_params.seed, gen_params.video_frames, gen_params.vace_strength, - ctx_params.vae_tiling_params, + ctx_params.get_tiling_params(gen_params.get_resolved_width(), + gen_params.get_resolved_height()), gen_params.cache_params, }; diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 50f35aed8..ecffb0479 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -475,6 +475,7 @@ struct SDContextParams { prediction_t prediction = PREDICTION_COUNT; lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO; + int vae_tiling_threshold = 0; sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; bool force_sdxl_vae_conv_scale = false; @@ -588,10 +589,6 @@ struct SDContextParams { }; options.bool_options = { - {"", - "--vae-tiling", - "process vae in tiles to reduce memory usage", - true, &vae_tiling_params.enabled}, {"", "--force-sdxl-vae-conv-scale", "force use of conv scale on sdxl vae", @@ -728,6 +725,33 @@ struct SDContextParams { return 1; }; + auto on_tiling_threshold = [&](int argc, const char** argv, int index) { + vae_tiling_threshold = 1; + if (++index >= argc) { + return 0; + } + size_t pos = 0; + std::string threshold_str = argv[index]; + int result = -1; + try { + result = std::stoi(threshold_str, &pos); + } catch (const std::invalid_argument&) { + // check if it's likely to be another flag + return (threshold_str.rfind("-", 0) == 0) ? 0 : -1; + } catch (const std::out_of_range&) { + return -1; + } + if (pos != threshold_str.length() || result < 0) { + return -1; + } + if (result > 32768) { + // avoid overflow if the user disabled tiling by using a huge value + result = 0; + } + vae_tiling_threshold = result; + return 1; + }; + auto on_tile_size_arg = [&](int argc, const char** argv, int index) { if (++index >= argc) { return -1; @@ -800,6 +824,11 @@ struct SDContextParams { "but it usually offers faster inference speed and, in some cases, lower memory usage. " "The at_runtime mode, on the other hand, is exactly the opposite.", on_lora_apply_mode_arg}, + {"", + "--vae-tiling", + "process vae in tiles to reduce memory usage. Optionally receives a size threshold T, which will " + "turn on tiling only for images larger than TxT.", + on_tiling_threshold}, {"", "--vae-tile-size", "tile size for vae tiling, format [X]x[Y] (default: 32x32)", @@ -929,6 +958,7 @@ struct SDContextParams { << vae_tiling_params.target_overlap << ", " << vae_tiling_params.rel_size_x << ", " << vae_tiling_params.rel_size_y << " },\n" + << " vae_tiling_threshold: " << vae_tiling_threshold << ",\n" << " force_sdxl_vae_conv_scale: " << (force_sdxl_vae_conv_scale ? "true" : "false") << "\n" << "}"; return oss.str(); @@ -990,6 +1020,18 @@ struct SDContextParams { }; return sd_ctx_params; } + + sd_tiling_params_t get_tiling_params(int width, int height) { + sd_tiling_params_t params = vae_tiling_params; + if (vae_tiling_threshold == 0) { + params.enabled = false; + } else { + int area = width * height; + int threshold = vae_tiling_threshold * vae_tiling_threshold; + params.enabled = (area > threshold); + } + return params; + } }; template diff --git a/examples/server/main.cpp b/examples/server/main.cpp index 0fb10c7a3..1eceebca0 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -500,7 +500,7 @@ int main(int argc, const char** argv) { gen_params.pm_id_embed_path.c_str(), gen_params.pm_style_strength, }, // pm_params - ctx_params.vae_tiling_params, + ctx_params.get_tiling_params(gen_params.width, gen_params.height), gen_params.cache_params, }; @@ -741,7 +741,7 @@ int main(int argc, const char** argv) { gen_params.pm_id_embed_path.c_str(), gen_params.pm_style_strength, }, // pm_params - ctx_params.vae_tiling_params, + ctx_params.get_tiling_params(get_resolved_width(), get_resolved_height()), gen_params.cache_params, }; @@ -1055,7 +1055,7 @@ int main(int argc, const char** argv) { gen_params.pm_id_embed_path.c_str(), gen_params.pm_style_strength, }, // pm_params - ctx_params.vae_tiling_params, + ctx_params.get_tiling_params(get_resolved_width(), get_resolved_height()), gen_params.cache_params, };