Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,8 @@ int main(int argc, const char* argv[]) {
gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength,
}, // pm_params
ctx_params.vae_tiling_params,
ctx_params.get_tiling_params(gen_params.get_resolved_width(),
gen_params.get_resolved_height()),
gen_params.cache_params,
};

Expand All @@ -776,7 +777,8 @@ int main(int argc, const char* argv[]) {
gen_params.seed,
gen_params.video_frames,
gen_params.vace_strength,
ctx_params.vae_tiling_params,
ctx_params.get_tiling_params(gen_params.get_resolved_width(),
gen_params.get_resolved_height()),
gen_params.cache_params,
};

Expand Down
50 changes: 46 additions & 4 deletions examples/common/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ struct SDContextParams {
prediction_t prediction = PREDICTION_COUNT;
lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;

int vae_tiling_threshold = 0;
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
bool force_sdxl_vae_conv_scale = false;

Expand Down Expand Up @@ -588,10 +589,6 @@ struct SDContextParams {
};

options.bool_options = {
{"",
"--vae-tiling",
"process vae in tiles to reduce memory usage",
true, &vae_tiling_params.enabled},
{"",
"--force-sdxl-vae-conv-scale",
"force use of conv scale on sdxl vae",
Expand Down Expand Up @@ -728,6 +725,33 @@ struct SDContextParams {
return 1;
};

auto on_tiling_threshold = [&](int argc, const char** argv, int index) {
vae_tiling_threshold = 1;
if (++index >= argc) {
return 0;
}
size_t pos = 0;
std::string threshold_str = argv[index];
int result = -1;
try {
result = std::stoi(threshold_str, &pos);
} catch (const std::invalid_argument&) {
// check if it's likely to be another flag
return (threshold_str.rfind("-", 0) == 0) ? 0 : -1;
} catch (const std::out_of_range&) {
return -1;
}
if (pos != threshold_str.length() || result < 0) {
return -1;
}
if (result > 32768) {
// avoid overflow if the user disabled tiling by using a huge value
result = 0;
}
vae_tiling_threshold = result;
return 1;
};

auto on_tile_size_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
Expand Down Expand Up @@ -800,6 +824,11 @@ struct SDContextParams {
"but it usually offers faster inference speed and, in some cases, lower memory usage. "
"The at_runtime mode, on the other hand, is exactly the opposite.",
on_lora_apply_mode_arg},
{"",
"--vae-tiling",
"process vae in tiles to reduce memory usage. Optionally receives a size threshold T, which will "
"turn on tiling only for images larger than TxT.",
on_tiling_threshold},
{"",
"--vae-tile-size",
"tile size for vae tiling, format [X]x[Y] (default: 32x32)",
Expand Down Expand Up @@ -929,6 +958,7 @@ struct SDContextParams {
<< vae_tiling_params.target_overlap << ", "
<< vae_tiling_params.rel_size_x << ", "
<< vae_tiling_params.rel_size_y << " },\n"
<< " vae_tiling_threshold: " << vae_tiling_threshold << ",\n"
<< " force_sdxl_vae_conv_scale: " << (force_sdxl_vae_conv_scale ? "true" : "false") << "\n"
<< "}";
return oss.str();
Expand Down Expand Up @@ -990,6 +1020,18 @@ struct SDContextParams {
};
return sd_ctx_params;
}

sd_tiling_params_t get_tiling_params(int width, int height) {
sd_tiling_params_t params = vae_tiling_params;
if (vae_tiling_threshold == 0) {
params.enabled = false;
} else {
int area = width * height;
int threshold = vae_tiling_threshold * vae_tiling_threshold;
params.enabled = (area > threshold);
}
return params;
}
};

template <typename T>
Expand Down
6 changes: 3 additions & 3 deletions examples/server/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ int main(int argc, const char** argv) {
gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength,
}, // pm_params
ctx_params.vae_tiling_params,
ctx_params.get_tiling_params(gen_params.width, gen_params.height),
gen_params.cache_params,
};

Expand Down Expand Up @@ -741,7 +741,7 @@ int main(int argc, const char** argv) {
gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength,
}, // pm_params
ctx_params.vae_tiling_params,
ctx_params.get_tiling_params(get_resolved_width(), get_resolved_height()),
gen_params.cache_params,
};

Expand Down Expand Up @@ -1055,7 +1055,7 @@ int main(int argc, const char** argv) {
gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength,
}, // pm_params
ctx_params.vae_tiling_params,
ctx_params.get_tiling_params(get_resolved_width(), get_resolved_height()),
gen_params.cache_params,
};

Expand Down
Loading