diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 05c4dd81b7..ef29ea8a62 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -65,7 +65,10 @@ When analyzing a Pull Request, follow this protocol: - **Keep headers self-contained but minimal**: each header must compile on its own, but should not pull in transitive dependencies that callers don't need. - **Prefer opaque types / Pimpl**: for complex implementation details, consider the Pimpl idiom to keep implementation-only types out of the public header entirely. - **Never include a header solely for a typedef or enum**: forward-declare the enum (`enum class Foo;` in C++17) or relocate the typedef to a lightweight `fwd.hpp`-style header. -13. Be mindful when accepting `const T&` in constructors or functions that store the reference: verify that the referenced object's lifetime outlives the usage to avoid dangling references. +13. **No dangling references or temporaries bound to `const T&`**: + - Never use `const T&` parameters with default arguments that construct temporaries (e.g. `const std::string& param = ""`). This binds a reference to a temporary — use a function overload instead, or pass by value. + - When accepting `const T&` in constructors or functions that store the reference, verify that the referenced object's lifetime outlives the usage to avoid dangling references. + - Prefer overloads over default arguments for non-trivial types passed by reference. 
## Build System diff --git a/demos/common/export_models/export_model.py b/demos/common/export_models/export_model.py index 5aa81b0c81..aa3d05516f 100644 --- a/demos/common/export_models/export_model.py +++ b/demos/common/export_models/export_model.py @@ -86,6 +86,11 @@ def add_common_arguments(parser): parser_image_generation.add_argument('--max_num_images_per_prompt', type=int, default=0, help='Max allowed number of images client is allowed to request for a given prompt', dest='max_num_images_per_prompt') parser_image_generation.add_argument('--default_num_inference_steps', type=int, default=0, help='Default number of inference steps when not specified by client', dest='default_num_inference_steps') parser_image_generation.add_argument('--max_num_inference_steps', type=int, default=0, help='Max allowed number of inference steps client is allowed to request for a given prompt', dest='max_num_inference_steps') +parser_image_generation.add_argument('--source_loras', default=None, + help='LoRA adapters to apply. Format: alias1=org1/repo1,alias2=org2/repo2@lora_file.safetensors ' + 'where @filename is optional and specifies which .safetensors file to use from the downloaded repo ' + '(auto-detected when repo contains exactly one). 
Only for image_generation task.', + dest='source_loras') parser_text2speech = subparsers.add_parser('text2speech', help='export model for text2speech endpoint') add_common_arguments(parser_text2speech) @@ -323,6 +328,9 @@ def add_common_arguments(parser): default_num_inference_steps: {{default_num_inference_steps}},{% endif %} {%- if max_num_inference_steps > 0 %} max_num_inference_steps: {{max_num_inference_steps}},{% endif %} + {%- for lora in lora_adapters %} + lora_adapters { alias: "{{lora.alias}}" path: "{{lora.path}}" } + {%- endfor %} } } }""" @@ -600,7 +608,7 @@ def export_rerank_model(model_repository_path, source_model, model_name, precisi add_servable_to_config(config_file_path, model_name, os.path.relpath(os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path))) -def export_image_generation_model(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path, num_streams): +def export_image_generation_model(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path, num_streams, source_loras): model_path = "./" target_path = os.path.join(model_repository_path, model_name) model_index_path = os.path.join(target_path, 'model_index.json') @@ -613,6 +621,41 @@ def export_image_generation_model(model_repository_path, source_model, model_nam if os.system(optimum_command): raise ValueError("Failed to export image generation model", source_model) + # Download and resolve LoRA adapters + lora_adapters = [] + if source_loras: + from huggingface_hub import snapshot_download + entries = source_loras.split(',') + for entry in entries: + entry = entry.strip() + if '=' in entry: + alias, repo_and_file = entry.split('=', 1) + else: + repo_and_file = entry + alias = entry.split('/')[-1] if '/' in entry else entry + safetensors_file = '' + if '@' in repo_and_file: + repo, safetensors_file = repo_and_file.rsplit('@', 1) + else: + repo = repo_and_file + lora_dir = 
os.path.join(target_path, 'loras', repo) + if not os.path.isdir(lora_dir): + print(f"Downloading LoRA adapter: {repo} to {lora_dir}") + snapshot_download(repo_id=repo, local_dir=lora_dir) + else: + print(f"LoRA adapter directory already exists: {lora_dir}") + if not safetensors_file: + st_files = [f for f in os.listdir(lora_dir) if f.endswith('.safetensors')] + if len(st_files) == 0: + raise ValueError(f"No .safetensors files found in LoRA adapter: {repo}") + if len(st_files) > 1: + raise ValueError(f"Multiple .safetensors files in LoRA adapter: {repo}. Use @filename to specify.") + safetensors_file = st_files[0] + lora_path = 'loras/' + repo + '/' + safetensors_file + lora_adapters.append({'alias': alias, 'path': lora_path}) + print(f"LoRA adapter: {alias} -> {lora_path}") + task_parameters['lora_adapters'] = lora_adapters + plugin_config = {} assert num_streams >= 0, "num_streams should be a non-negative integer" if num_streams > 0: @@ -695,4 +738,4 @@ def export_image_generation_model(model_repository_path, source_model, model_nam 'max_num_inference_steps', 'extra_quantization_params' ]} - export_image_generation_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path'], args['num_streams']) + export_image_generation_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path'], args['num_streams'], args['source_loras']) diff --git a/demos/image_generation/README.md b/demos/image_generation/README.md index 4943524c84..f37257e2de 100644 --- a/demos/image_generation/README.md +++ b/demos/image_generation/README.md @@ -528,6 +528,12 @@ Output file (`edit_output.png`): Inpainting replaces a masked region in an image based on the prompt. The `mask` is a black-and-white image where white pixels mark the area to repaint. 
+Download sample images: +```console +curl -O https://raw.githubusercontent.com/openvinotoolkit/model_server/main/demos/image_generation/cat.png +curl -O https://raw.githubusercontent.com/openvinotoolkit/model_server/main/demos/image_generation/cat_mask.png +``` + ![cat](./cat.png) ![cat_mask](./cat_mask.png) ::::{tab-set} @@ -599,6 +605,12 @@ Outpainting extends an image beyond its original borders. Prepare two images: - **outpaint_input.png** — the original image centered on a larger canvas (e.g. 768×768) with black borders - **outpaint_mask.png** — white where the new content should be generated (the borders), black where the original image is +Download sample images: +```console +curl -O https://raw.githubusercontent.com/openvinotoolkit/model_server/main/demos/image_generation/outpaint_input.png +curl -O https://raw.githubusercontent.com/openvinotoolkit/model_server/main/demos/image_generation/outpaint_mask.png +``` + ![outpaint_input](./outpaint_input.png) ![outpaint_mask](./outpaint_mask.png) ::::{tab-set} diff --git a/src/BUILD b/src/BUILD index ea624f5e59..743c23303e 100644 --- a/src/BUILD +++ b/src/BUILD @@ -3067,6 +3067,7 @@ cc_library( ":test_light_test_utils", ":test_test_with_temp_dir", "//src/graph_export:graph_export", + "//src/graph_export:image_generation_graph_cli_parser", "//src:libovms_server_settings", "@com_google_googletest//:gtest", ], diff --git a/src/capi_frontend/server_settings.hpp b/src/capi_frontend/server_settings.hpp index 5b8a3dce54..df3ae007b2 100644 --- a/src/capi_frontend/server_settings.hpp +++ b/src/capi_frontend/server_settings.hpp @@ -143,6 +143,29 @@ struct RerankGraphSettingsImpl { uint64_t maxAllowedChunks = 10000; }; +enum class LoraSourceType { + HF_REPO, + DIRECT_URL, + LOCAL_FILE +}; + +struct LoraAdapterSettings { + std::string alias; + std::string sourceLora; // HF repo, direct URL, or local file path + std::string safetensorsFile; // resolved filename, empty = auto-detect (HF only) + LoraSourceType sourceType = 
LoraSourceType::HF_REPO; +}; + +struct CompositeLoraComponent { + std::string adapterAlias; // references a LoraAdapterSettings alias + float weight = 1.0f; +}; + +struct CompositeLoraSettings { + std::string alias; + std::vector components; +}; + struct ImageGenerationGraphSettingsImpl { std::string resolution = ""; std::string maxResolution = ""; @@ -152,6 +175,8 @@ struct ImageGenerationGraphSettingsImpl { std::optional maxNumberImagesPerPrompt; std::optional defaultNumInferenceSteps; std::optional maxNumInferenceSteps; + std::vector loraAdapters; + std::vector compositeLoraAdapters; }; struct ExportSettings { @@ -169,6 +194,7 @@ struct HFSettingsImpl { std::string sourceModel = ""; std::optional ggufFilename; std::string downloadPath = ""; + std::string sourceLoras = ""; // raw --source_loras value, parsed by image gen CLI parser bool overwriteModels = false; ModelDownlaodType downloadType = GIT_CLONE_DOWNLOAD; GraphExportType task = TEXT_GENERATION_GRAPH; diff --git a/src/cli_parser.cpp b/src/cli_parser.cpp index 4f968c5294..e99770f123 100644 --- a/src/cli_parser.cpp +++ b/src/cli_parser.cpp @@ -115,11 +115,11 @@ std::variant> CLIParser::parse(int argc, char* cxxopts::value(), "GRPC_CHANNEL_ARGUMENTS") ("file_system_poll_wait_seconds", "Time interval between config and model versions changes detection. Default is 1. Zero or negative value disables changes monitoring.", - cxxopts::value()->default_value("1"), + cxxopts::value()->default_value("1"), "FILE_SYSTEM_POLL_WAIT_SECONDS") ("sequence_cleaner_poll_wait_minutes", "Time interval between two consecutive sequence cleanup scans. Default is 5. Zero value disables sequence cleaner. It also sets the schedule for releasing free memory from the heap.", - cxxopts::value()->default_value("5"), + cxxopts::value()->default_value("5"), "SEQUENCE_CLEANER_POLL_WAIT_MINUTES") ("custom_node_resources_cleaner_interval_seconds", "Time interval between two consecutive resources cleanup scans. Default is 300. 
Zero value disables resources cleaner.", @@ -213,6 +213,10 @@ std::variant> CLIParser::parse(int argc, char* "HF source model path", cxxopts::value(), "HF_SOURCE") + ("source_loras", + "LoRA adapters for image generation. Format: alias1=org1/repo1,alias2=org2/repo2@file.safetensors,alias3=https://url/file.safetensors,alias4=/local/path/file.safetensors", + cxxopts::value(), + "SOURCE_LORAS") ("gguf_filename", "Name of the GGUF file", cxxopts::value(), @@ -715,6 +719,9 @@ void CLIParser::prepareGraph(ServerSettingsImpl& serverSettings, HFSettingsImpl& } else if (result->count("model_name")) { hfSettings.sourceModel = result->operator[]("model_name").as(); } + if (result->count("source_loras")) { + hfSettings.sourceLoras = result->operator[]("source_loras").as(); + } if ((result->count("weight-format") || result->count("extra_quantization_params")) && isOptimumCliDownload(hfSettings.sourceModel, hfSettings.ggufFilename)) { hfSettings.downloadType = OPTIMUM_CLI_DOWNLOAD; } diff --git a/src/graph_export/graph_export.cpp b/src/graph_export/graph_export.cpp index dadbd57777..48b6c7b1c3 100644 --- a/src/graph_export/graph_export.cpp +++ b/src/graph_export/graph_export.cpp @@ -467,6 +467,37 @@ node: { max_num_inference_steps: )" << graphSettings.maxNumInferenceSteps.value(); } + for (const auto& adapter : graphSettings.loraAdapters) { + std::string loraPath; + if (adapter.sourceType == LoraSourceType::LOCAL_FILE) { + loraPath = adapter.sourceLora; + } else if (adapter.sourceType == LoraSourceType::HF_REPO) { + loraPath = "loras/" + adapter.sourceLora + "/" + adapter.safetensorsFile; + } else { // cURL direct link + loraPath = "loras/" + adapter.alias + "/" + adapter.safetensorsFile; + } + oss << R"( + lora_adapters { alias: ")" << adapter.alias << R"(" path: ")" << loraPath << R"(")"; + // Only omit alpha when default (1.0) - let proto handle it + oss << R"( })"; + } + + for (const auto& composite : graphSettings.compositeLoraAdapters) { + oss << R"( + 
composite_lora_adapters { + alias: ")" << composite.alias << R"(" +)"; + for (const auto& component : composite.components) { + oss << R"( components { adapter_alias: ")" << component.adapterAlias << R"(")"; + if (component.weight != 1.0f) { + oss << R"( weight: )" << component.weight; + } + oss << R"( } +)"; + } + oss << R"( })"; + } + oss << R"( } } diff --git a/src/graph_export/image_generation_graph_cli_parser.cpp b/src/graph_export/image_generation_graph_cli_parser.cpp index ed0d1b91ef..399ae62c3c 100644 --- a/src/graph_export/image_generation_graph_cli_parser.cpp +++ b/src/graph_export/image_generation_graph_cli_parser.cpp @@ -16,9 +16,11 @@ #include "image_generation_graph_cli_parser.hpp" #include +#include #include #include #include +#include #include #include #include @@ -27,6 +29,7 @@ #include "../capi_frontend/server_settings.hpp" #include "../ovms_exit_codes.hpp" #include "../status.hpp" +#include "src/stringutils.hpp" namespace ovms { @@ -164,6 +167,126 @@ void ImageGenerationGraphCLIParser::prepare(ServerSettingsImpl& serverSettings, } } + // Parse --source_loras + // Supports three source types plus composite aliases: + // alias=org/repo (HF_REPO) + // alias=org/repo@file.safetensors (HF_REPO with explicit file) + // alias=https://url/file.safetensors (DIRECT_URL) + // alias=/path/to/file.safetensors (LOCAL_FILE) + // alias=@ref1:0.7+@ref2:0.5 (COMPOSITE - references other aliases) + if (!hfSettings.sourceLoras.empty()) { + auto entries = ovms::tokenize(hfSettings.sourceLoras, ','); + // First pass: collect all real adapters + for (const auto& entry : entries) { + auto eqPos = entry.find('='); + if (eqPos == std::string::npos) { + throw std::invalid_argument("Missing alias in --source_loras entry: '" + entry + "'. 
Expected format: alias=source"); + } + std::string alias = entry.substr(0, eqPos); + std::string source = entry.substr(eqPos + 1); + if (alias.empty() || source.empty()) { + throw std::invalid_argument("Invalid --source_loras entry: '" + entry + "'. Alias and source must not be empty."); + } + // Skip composite entries in first pass + if (source[0] == '@') { + continue; + } + + LoraAdapterSettings adapter; + adapter.alias = alias; + // Detect source type + if (source.substr(0, 8) == "https://" || source.substr(0, 7) == "http://") { + adapter.sourceType = LoraSourceType::DIRECT_URL; + adapter.sourceLora = source; + auto lastSlash = source.rfind('/'); + if (lastSlash == std::string::npos || lastSlash == source.size() - 1) { + throw std::invalid_argument("Cannot extract filename from URL in --source_loras entry: '" + entry + "'"); + } + adapter.safetensorsFile = source.substr(lastSlash + 1); + if (!endsWith(adapter.safetensorsFile, ".safetensors")) { + throw std::invalid_argument("URL must point to a .safetensors file in --source_loras entry: '" + entry + "'"); + } + } else if (ovms::isLocalFilePath(source)) { + adapter.sourceType = LoraSourceType::LOCAL_FILE; + adapter.sourceLora = source; + if (!endsWith(source, ".safetensors")) { + throw std::invalid_argument("Local path must point to a .safetensors file in --source_loras entry: '" + entry + "'"); + } + if (!std::filesystem::exists(source)) { + throw std::invalid_argument("Local LoRA file does not exist: '" + source + "' in --source_loras entry: '" + entry + "'"); + } + auto lastSlash = source.find_last_of("/\\"); + adapter.safetensorsFile = (lastSlash != std::string::npos) ? 
source.substr(lastSlash + 1) : source; + } else { + adapter.sourceType = LoraSourceType::HF_REPO; + auto atPos = source.find('@'); + if (atPos != std::string::npos) { + adapter.sourceLora = source.substr(0, atPos); + adapter.safetensorsFile = source.substr(atPos + 1); + if (adapter.safetensorsFile.empty()) { + throw std::invalid_argument("Empty filename after @ in --source_loras entry: '" + entry + "'"); + } + } else { + adapter.sourceLora = source; + } + if (adapter.sourceLora.empty()) { + throw std::invalid_argument("Invalid --source_loras entry: '" + entry + "'. HF repo source must not be empty."); + } + } + imageGenerationGraphSettings.loraAdapters.push_back(std::move(adapter)); + } + + // Collect known adapter aliases for validation + std::set knownAliases; + for (const auto& adapter : imageGenerationGraphSettings.loraAdapters) { + knownAliases.insert(adapter.alias); + } + + // Second pass: parse composite entries (source starts with @) + for (const auto& entry : entries) { + auto eqPos = entry.find('='); + std::string alias = entry.substr(0, eqPos); + std::string source = entry.substr(eqPos + 1); + if (source[0] != '@') { + continue; + } + CompositeLoraSettings composite; + composite.alias = alias; + // Parse @ref1:0.7+@ref2:0.5 + auto componentTokens = ovms::tokenize(source, '+'); + for (const auto& compToken : componentTokens) { + if (compToken.empty() || compToken[0] != '@') { + throw std::invalid_argument("Invalid composite LoRA component '" + compToken + "' in entry: '" + entry + "'. Each component must start with @"); + } + CompositeLoraComponent component; + std::string ref = compToken.substr(1); // strip @ + auto colonPos = ref.find(':'); + if (colonPos != std::string::npos) { + component.adapterAlias = ref.substr(0, colonPos); + std::string weightStr = ref.substr(colonPos + 1); + try { + component.weight = std::stof(weightStr); + } catch (...) 
{ + throw std::invalid_argument("Invalid weight '" + weightStr + "' in composite LoRA component: '" + compToken + "'"); + } + } else { + component.adapterAlias = ref; + } + if (component.adapterAlias.empty()) { + throw std::invalid_argument("Empty adapter reference in composite LoRA component: '" + compToken + "'"); + } + if (knownAliases.find(component.adapterAlias) == knownAliases.end()) { + throw std::invalid_argument("Composite LoRA references unknown adapter '" + component.adapterAlias + "' in entry: '" + entry + "'"); + } + composite.components.push_back(std::move(component)); + } + if (composite.components.empty()) { + throw std::invalid_argument("Composite LoRA entry has no components: '" + entry + "'"); + } + imageGenerationGraphSettings.compositeLoraAdapters.push_back(std::move(composite)); + } + } + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); } diff --git a/src/http_payload.hpp b/src/http_payload.hpp index b4415a1616..1a8465d8f7 100644 --- a/src/http_payload.hpp +++ b/src/http_payload.hpp @@ -32,6 +32,7 @@ namespace ovms { struct HttpPayload { std::string uri; + std::string modelName; // resolved model name from request (JSON model field, multipart, or URI) std::unordered_map headers; std::string body; // always std::shared_ptr parsedJson; // pre-parsed body = null diff --git a/src/http_rest_api_handler.cpp b/src/http_rest_api_handler.cpp index 33a81cb429..ffcd448462 100644 --- a/src/http_rest_api_handler.cpp +++ b/src/http_rest_api_handler.cpp @@ -566,6 +566,7 @@ static Status createV3HttpPayload( request.body = request_body; request.parsedJson = std::move(parsedJson); request.uri = std::string(uri); + request.modelName = modelName; request.client = std::make_shared(serverReaderWriter); request.multipartParser = std::move(multiPartParser); diff --git a/src/image_gen/http_image_gen_calculator.cc b/src/image_gen/http_image_gen_calculator.cc index 4e4381f2a6..5136524558 100644 --- a/src/image_gen/http_image_gen_calculator.cc +++ 
b/src/image_gen/http_image_gen_calculator.cc @@ -30,6 +30,7 @@ #include "pipelines.hpp" #include "imagegenutils.hpp" +#include #pragma warning(push) #pragma warning(disable : 6001 4324 6385 6386) @@ -45,6 +46,68 @@ using ImageGenerationPipelinesMap = std::unordered_map& loraAdapters, + const std::unordered_map>>& compositeLoraAdapters, + const ImageGenPipelineArgs& args, + ov::AnyMap& requestOptions, + const std::unordered_map& loraWeightsOverride = {}) { + if (loraAdapters.empty()) { + return; + } + // All adapters were registered at compile time (alpha=1.0 each). + // At generate time we must explicitly set the adapter config: + // - If modelName matches a composite alias: activate all component adapters with their weights. + // - If modelName matches a single adapter alias: activate that adapter. + // - Otherwise: disable all adapters (alpha=0) so the base model runs clean. + // lora_weights from request body can override default weights. + ov::genai::AdapterConfig adapterConfig; + + auto compositeIt = compositeLoraAdapters.find(modelName); + if (compositeIt != compositeLoraAdapters.end()) { + // Composite adapter — activate multiple adapters + for (const auto& [compAlias, defaultWeight] : compositeIt->second) { + auto adapterIt = loraAdapters.find(compAlias); + if (adapterIt == loraAdapters.end()) { + SPDLOG_LOGGER_WARN(llm_calculator_logger, "Composite LoRA '{}' references unknown adapter '{}', skipping", modelName, compAlias); + continue; + } + float weight = defaultWeight; + auto overrideIt = loraWeightsOverride.find(compAlias); + if (overrideIt != loraWeightsOverride.end()) { + weight = overrideIt->second; + } + adapterConfig.add(adapterIt->second, weight); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Composite LoRA '{}': applied adapter '{}' with weight: {}", modelName, compAlias, weight); + } + } else { + auto adapterIt = loraAdapters.find(modelName); + if (adapterIt != loraAdapters.end()) { + float alpha = 1.0f; + auto overrideIt = 
loraWeightsOverride.find(modelName); + if (overrideIt != loraWeightsOverride.end()) { + alpha = overrideIt->second; + } else { + for (const auto& info : args.loraAdapters) { + if (info.alias == modelName) { + alpha = info.alpha; + break; + } + } + } + adapterConfig.add(adapterIt->second, alpha); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Applied LoRA adapter: {} with alpha: {}", modelName, alpha); + } else { + // Disable all adapters that were registered at compile time + for (const auto& [alias, adapter] : loraAdapters) { + adapterConfig.add(adapter, 0.0f); + } + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "No LoRA adapter matched for model: {}, disabling all adapters", modelName); + } + } + requestOptions[ov::genai::adapters.name()] = adapterConfig; +} + static bool progress_bar(size_t step, size_t num_steps, ov::Tensor&) { SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Image Generation Step: {}/{}", step + 1, num_steps); return false; @@ -179,10 +242,26 @@ class ImageGenCalculator : public CalculatorBase { SET_OR_RETURN(std::string, prompt, getPromptField(*payload.parsedJson)); SET_OR_RETURN(ov::AnyMap, requestOptions, getImageGenerationRequestOptions(*payload.parsedJson, pipe->args)); + // Parse optional lora_weights from request body + std::unordered_map loraWeightsOverride; + auto loraWeightsIt = payload.parsedJson->FindMember("lora_weights"); + if (loraWeightsIt != payload.parsedJson->MemberEnd() && loraWeightsIt->value.IsObject()) { + for (auto member = loraWeightsIt->value.MemberBegin(); member != loraWeightsIt->value.MemberEnd(); ++member) { + if (member->value.IsNumber()) { + loraWeightsOverride[member->name.GetString()] = member->value.GetFloat(); + } + } + } + + // Apply LoRA adapter if the requested model name matches an alias + applyLoraAdapterIfNeeded(payload.modelName, pipe->loraAdapters, pipe->compositeLoraAdapters, pipe->args, requestOptions, loraWeightsOverride); if (!pipe->text2ImagePipeline) return absl::FailedPreconditionError("Text-to-image 
pipeline is not available for this model"); - auto t2i = pipe->text2ImagePipeline->clone(); - auto status = generateTensor(t2i, prompt, requestOptions, images); + absl::Status status; + { + auto t2i = pipe->text2ImagePipeline->clone(); + status = generateTensor(t2i, prompt, requestOptions, images); + } if (!status.ok()) { return status; } @@ -203,6 +282,9 @@ class ImageGenCalculator : public CalculatorBase { SET_OR_RETURN(ov::AnyMap, requestOptions, getImageEditRequestOptions(*payload.multipartParser, pipe->args)); + // Apply LoRA adapter if the requested model name matches an alias + applyLoraAdapterIfNeeded(payload.modelName, pipe->loraAdapters, pipe->compositeLoraAdapters, pipe->args, requestOptions); + SET_OR_RETURN(std::optional, mask, getFileFromPayload(*payload.multipartParser, "mask")); SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "ImageGenCalculator [Node: {}] Mask present: {}", cc->NodeName(), mask.has_value() && !mask.value().empty()); @@ -218,14 +300,16 @@ class ImageGenCalculator : public CalculatorBase { return status; } SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "ImageGenCalculator [Node: {}] Inpainting: mask tensor decoded, acquiring inpainting queue slot", cc->NodeName()); - InpaintingQueueGuard inpaintingGuard(*pipe->inpaintingQueue); + PipelineSlotGuard inpaintingGuard(*pipe->inpaintingQueue); SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "ImageGenCalculator [Node: {}] Inpainting: queue slot acquired, invoking generate()", cc->NodeName()); status = generateTensorInpainting(*pipe->inpaintingPipeline, prompt, imageTensor, maskTensor, requestOptions, images); } else { if (!pipe->image2ImagePipeline) return absl::FailedPreconditionError("Image-to-image pipeline is not available for this model"); - auto i2i = pipe->image2ImagePipeline->clone(); - status = generateTensorImg2Img(i2i, prompt, imageTensor, requestOptions, images); + { + auto i2i = pipe->image2ImagePipeline->clone(); + status = generateTensorImg2Img(i2i, prompt, imageTensor, requestOptions, 
images); + } } if (!status.ok()) { return status; diff --git a/src/image_gen/image_gen_calculator.proto b/src/image_gen/image_gen_calculator.proto index c69f3ce97e..f73f01e0d9 100644 --- a/src/image_gen/image_gen_calculator.proto +++ b/src/image_gen/image_gen_calculator.proto @@ -43,4 +43,26 @@ message ImageGenCalculatorOptions { optional string resolution = 9; optional int64 num_images_per_prompt = 10; optional float guidance_scale = 11; + + // LoRA adapters + repeated LoraAdapterEntry lora_adapters = 12; + + // Composite LoRA adapters (multi-LoRA presets) + repeated CompositeLoraAdapterEntry composite_lora_adapters = 13; +} + +message LoraAdapterEntry { + required string alias = 1; + required string path = 2; + optional float alpha = 3 [default = 1.0]; +} + +message CompositeLoraComponent { + required string adapter_alias = 1; + optional float weight = 2 [default = 1.0]; +} + +message CompositeLoraAdapterEntry { + required string alias = 1; + repeated CompositeLoraComponent components = 2; } diff --git a/src/image_gen/imagegen_init.cpp b/src/image_gen/imagegen_init.cpp index a96cbd764c..ae9d8c4b45 100644 --- a/src/image_gen/imagegen_init.cpp +++ b/src/image_gen/imagegen_init.cpp @@ -258,6 +258,31 @@ std::variant prepareImageGenPipelineArgs(const goo args.maxNumImagesPerPrompt = nodeOptions.max_num_images_per_prompt(); args.defaultNumInferenceSteps = nodeOptions.default_num_inference_steps(); args.maxNumInferenceSteps = nodeOptions.max_num_inference_steps(); + + for (int i = 0; i < nodeOptions.lora_adapters_size(); ++i) { + const auto& loraEntry = nodeOptions.lora_adapters(i); + LoraAdapterInfo info; + info.alias = loraEntry.alias(); + auto fsLoraPath = std::filesystem::path(loraEntry.path()); + if (fsLoraPath.is_relative()) { + info.path = (std::filesystem::path(graphPath) / fsLoraPath).string(); + } else { + info.path = fsLoraPath.string(); + } + info.alpha = loraEntry.alpha(); + args.loraAdapters.push_back(std::move(info)); + } + + for (int i = 0; i < 
nodeOptions.composite_lora_adapters_size(); ++i) { + const auto& compositeEntry = nodeOptions.composite_lora_adapters(i); + std::vector> components; + for (int j = 0; j < compositeEntry.components_size(); ++j) { + const auto& comp = compositeEntry.components(j); + components.emplace_back(comp.adapter_alias(), comp.weight()); + } + args.compositeLoraAdapters.emplace(compositeEntry.alias(), std::move(components)); + } + return std::move(args); } } // namespace ovms diff --git a/src/image_gen/imagegenpipelineargs.hpp b/src/image_gen/imagegenpipelineargs.hpp index 25d46860ac..4c446ce070 100644 --- a/src/image_gen/imagegenpipelineargs.hpp +++ b/src/image_gen/imagegenpipelineargs.hpp @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -39,6 +40,12 @@ struct StaticReshapeSettingsArgs { guidanceScale(guidance) {} }; +struct LoraAdapterInfo { + std::string alias; + std::string path; // absolute path to .safetensors file + float alpha = 1.0f; +}; + struct ImageGenPipelineArgs { std::string modelsPath; std::vector device; @@ -51,5 +58,9 @@ struct ImageGenPipelineArgs { uint64_t maxNumInferenceSteps; std::optional staticReshapeSettings; + std::vector loraAdapters; + // Maps a composite alias to its component (adapter alias, weight) pairs. 
+ using CompositeLoraMap = std::unordered_map>>; + CompositeLoraMap compositeLoraAdapters; }; } // namespace ovms diff --git a/src/image_gen/imagegenutils.cpp b/src/image_gen/imagegenutils.cpp index 2235e9fa44..df0b8ba221 100644 --- a/src/image_gen/imagegenutils.cpp +++ b/src/image_gen/imagegenutils.cpp @@ -414,7 +414,8 @@ std::variant getImageGenerationRequestOptions(const ra "size", "height", "width", "n", "num_images_per_prompt", "response_format", // allowed, however only b64_json is supported - "num_inference_steps", "rng_seed", "strength", "guidance_scale", "max_sequence_length", "model"}; + "num_inference_steps", "rng_seed", "strength", "guidance_scale", "max_sequence_length", "model", + "lora_weights"}; // per-request LoRA weight overrides for (auto it = parser.MemberBegin(); it != parser.MemberEnd(); ++it) { if (acceptedFields.find(it->name.GetString()) == acceptedFields.end()) { return absl::InvalidArgumentError(absl::StrCat("Unhandled parameter: ", it->name.GetString())); @@ -532,7 +533,8 @@ std::variant getImageEditRequestOptions(const ovms::Mu "size", "height", "width", "n", "num_images_per_prompt", "response_format", // allowed, however only b64_json is supported - "num_inference_steps", "rng_seed", "strength", "guidance_scale", "max_sequence_length", "model"}; + "num_inference_steps", "rng_seed", "strength", "guidance_scale", "max_sequence_length", "model", + "lora_weights"}; // per-request LoRA weight overrides auto fieldNames = parser.getAllFieldNames(); for (const auto& fieldName : fieldNames) { if (acceptedFields.find(fieldName) == acceptedFields.end()) { diff --git a/src/image_gen/pipelines.cpp b/src/image_gen/pipelines.cpp index 65071fef60..6560e5aa35 100644 --- a/src/image_gen/pipelines.cpp +++ b/src/image_gen/pipelines.cpp @@ -22,6 +22,7 @@ #include #include "src/logging.hpp" +#include "src/stringutils.hpp" namespace ovms { @@ -30,7 +31,8 @@ namespace ovms { template static void reshapeAndCompile(PipelineT& pipeline, const 
ImageGenPipelineArgs& args, - const std::vector& device) { + const std::vector& device, + const ov::AnyMap& properties) { if (args.staticReshapeSettings.has_value() && args.staticReshapeSettings.value().resolution.size() == 1) { auto numImagesPerPrompt = args.staticReshapeSettings.value().numImagesPerPrompt.value_or(ov::genai::ImageGenerationConfig().num_images_per_prompt); auto guidanceScale = args.staticReshapeSettings.value().guidanceScale.value_or(ov::genai::ImageGenerationConfig().guidance_scale); @@ -47,10 +49,10 @@ static void reshapeAndCompile(PipelineT& pipeline, if (device.size() == 1) { SPDLOG_DEBUG("Image Generation Pipeline compiling to device: {}", device[0]); - pipeline.compile(device[0], args.pluginConfig); + pipeline.compile(device[0], properties); } else { SPDLOG_DEBUG("Image Generation Pipeline compiling to devices: text_encode={} denoise={} vae={}", device[0], device[1], device[2]); - pipeline.compile(device[0], device[1], device[2], args.pluginConfig); + pipeline.compile(device[0], device[1], device[2], properties); } } @@ -65,6 +67,36 @@ ImageGenerationPipelines::ImageGenerationPipelines(const ImageGenPipelineArgs& a SPDLOG_DEBUG("Image Generation Pipelines weights loading from: {}", args.modelsPath); + // --- Load LoRA adapters before pipeline compilation --- + // Adapters must be registered at compile time so that the AdapterController + // is initialized and can apply/disable them at inference time. 
+ for (const auto& loraInfo : args.loraAdapters) { + SPDLOG_INFO("Loading LoRA adapter: {} from: {}", loraInfo.alias, loraInfo.path); + try { + loraAdapters.emplace(loraInfo.alias, ov::genai::Adapter(loraInfo.path)); + SPDLOG_INFO("LoRA adapter loaded: {}", loraInfo.alias); + } catch (const std::exception& e) { + throw std::runtime_error("Failed to load LoRA adapter '" + loraInfo.alias + "' from " + loraInfo.path + ": " + e.what()); + } + } + + // Build compile-time adapter properties so the pipeline's AdapterController + // knows about all adapters. At generate time we select which to activate. + ov::AnyMap compileProperties = args.pluginConfig; + if (!loraAdapters.empty()) { + ov::genai::AdapterConfig adapterConfig; + for (const auto& [alias, adapter] : loraAdapters) { + adapterConfig.add(adapter, 1.0f); + } + compileProperties.insert(ov::genai::adapters(adapterConfig)); + } + + // Populate composite LoRA map from args + compositeLoraAdapters = args.compositeLoraAdapters; + for (const auto& [alias, components] : compositeLoraAdapters) { + SPDLOG_INFO("Registered composite LoRA adapter: {} with {} components", alias, components.size()); + } + // Pipeline construction strategy: // Preferred chain (weight-sharing, single model load): // INP(disk) → reshape+compile → I2I(INP) → T2I(I2I) @@ -78,7 +110,7 @@ ImageGenerationPipelines::ImageGenerationPipelines(const ImageGenPipelineArgs& a // --- Step 1: InpaintingPipeline from disk --- try { inpaintingPipeline = std::make_unique(args.modelsPath); - reshapeAndCompile(*inpaintingPipeline, args, device); + reshapeAndCompile(*inpaintingPipeline, args, device, compileProperties); SPDLOG_DEBUG("InpaintingPipeline created from disk"); } catch (const std::exception& e) { SPDLOG_WARN("Failed to create InpaintingPipeline from disk: {}", e.what()); @@ -97,7 +129,7 @@ ImageGenerationPipelines::ImageGenerationPipelines(const ImageGenPipelineArgs& a if (!image2ImagePipeline) { try { image2ImagePipeline = 
std::make_unique(args.modelsPath); - reshapeAndCompile(*image2ImagePipeline, args, device); + reshapeAndCompile(*image2ImagePipeline, args, device, compileProperties); SPDLOG_DEBUG("Image2ImagePipeline created from disk (fallback)"); } catch (const std::exception& e) { SPDLOG_WARN("Failed to create Image2ImagePipeline from disk: {}", e.what()); @@ -125,7 +157,7 @@ ImageGenerationPipelines::ImageGenerationPipelines(const ImageGenPipelineArgs& a if (!text2ImagePipeline) { try { text2ImagePipeline = std::make_unique(args.modelsPath); - reshapeAndCompile(*text2ImagePipeline, args, device); + reshapeAndCompile(*text2ImagePipeline, args, device, compileProperties); SPDLOG_DEBUG("Text2ImagePipeline created from disk (fallback)"); } catch (const std::exception& e) { SPDLOG_WARN("Failed to create Text2ImagePipeline from disk: {}", e.what()); @@ -144,9 +176,10 @@ ImageGenerationPipelines::ImageGenerationPipelines(const ImageGenPipelineArgs& a inpaintingQueue = std::make_unique>(1); } - SPDLOG_INFO("Image Generation Pipelines ready — T2I: {} | I2I: {} | INP: {}", + SPDLOG_INFO("Image Generation Pipelines ready — T2I: {} | I2I: {} | INP: {} | LoRAs: {}", text2ImagePipeline ? "OK" : "N/A", image2ImagePipeline ? "OK" : "N/A", - inpaintingPipeline ? "OK" : "N/A"); + inpaintingPipeline ? "OK" : "N/A", + loraAdapters.size()); } } // namespace ovms diff --git a/src/image_gen/pipelines.hpp b/src/image_gen/pipelines.hpp index cda14396a7..7f98563b39 100644 --- a/src/image_gen/pipelines.hpp +++ b/src/image_gen/pipelines.hpp @@ -17,10 +17,14 @@ #include #include +#include +#include +#include #include #include #include +#include #include "imagegenpipelineargs.hpp" #include "src/queue.hpp" @@ -28,19 +32,19 @@ namespace ovms { // RAII guard that acquires a slot from a Queue(1) on construction -// and returns it on destruction, serializing concurrent inpainting requests. -class InpaintingQueueGuard { +// and returns it on destruction, serializing concurrent pipeline access. 
+class PipelineSlotGuard { public: - // Blocks until an inpainting slot becomes available. - explicit InpaintingQueueGuard(Queue& queue) : + // Blocks until a pipeline slot becomes available. + explicit PipelineSlotGuard(Queue& queue) : queue_(queue), streamId_(queue_.getIdleStream().get()) {} - ~InpaintingQueueGuard() { + ~PipelineSlotGuard() { queue_.returnStream(streamId_); } - InpaintingQueueGuard(const InpaintingQueueGuard&) = delete; - InpaintingQueueGuard& operator=(const InpaintingQueueGuard&) = delete; + PipelineSlotGuard(const PipelineSlotGuard&) = delete; + PipelineSlotGuard& operator=(const PipelineSlotGuard&) = delete; private: Queue& queue_; @@ -51,6 +55,9 @@ struct ImageGenerationPipelines { std::unique_ptr image2ImagePipeline; std::unique_ptr text2ImagePipeline; std::unique_ptr inpaintingPipeline; + std::unordered_map loraAdapters; // alias -> loaded adapter + // composite alias -> [(component adapter alias, weight)] + std::unordered_map>> compositeLoraAdapters; ImageGenPipelineArgs args; // Serializes concurrent inpainting requests (InpaintingPipeline lacks clone()). 
diff --git a/src/mediapipe_internal/mediapipefactory.cpp b/src/mediapipe_internal/mediapipefactory.cpp index aa3689ae31..86bc05a3ae 100644 --- a/src/mediapipe_internal/mediapipefactory.cpp +++ b/src/mediapipe_internal/mediapipefactory.cpp @@ -74,6 +74,12 @@ Status MediapipeFactory::createDefinition(const std::string& pipelineName, } std::unique_lock lock(definitionsMtx); definitions.insert({pipelineName, std::move(graphDefinition)}); + // Register LoRA aliases discovered during validation (image gen graphs) + auto* def = definitions[pipelineName].get(); + for (const auto& alias : def->getLoraAliases()) { + loraAliases[alias] = pipelineName; + SPDLOG_LOGGER_INFO(modelmanager_logger, "Registered LoRA alias: {} -> {}", alias, pipelineName); + } return stat; } @@ -86,6 +92,14 @@ MediapipeGraphDefinition* MediapipeFactory::findDefinitionByName(const std::stri std::shared_lock lock(definitionsMtx); auto it = definitions.find(name); if (it == std::end(definitions)) { + // Check LoRA aliases + auto aliasIt = loraAliases.find(name); + if (aliasIt != loraAliases.end()) { + it = definitions.find(aliasIt->second); + if (it != std::end(definitions)) { + return it->second.get(); + } + } return nullptr; } else { return it->second.get(); @@ -109,6 +123,13 @@ Status MediapipeFactory::create(std::unique_ptr& pipelin ModelManager& manager) const { std::shared_lock lock(definitionsMtx); auto it = definitions.find(name); + if (it == definitions.end()) { + // Check LoRA aliases + auto aliasIt = loraAliases.find(name); + if (aliasIt != loraAliases.end()) { + it = definitions.find(aliasIt->second); + } + } if (it == definitions.end()) { SPDLOG_LOGGER_DEBUG(dag_executor_logger, "Mediapipe with requested name: {} does not exist", name); return StatusCode::MEDIAPIPE_DEFINITION_NAME_MISSING; @@ -149,8 +170,33 @@ const std::vector MediapipeFactory::getNamesOfAvailableMediapipePip names.push_back(definition->getName()); } } + // Add LoRA aliases that point to available definitions + for (const 
auto& [alias, graphName] : loraAliases) { + auto it = definitions.find(graphName); + if (it != definitions.end() && it->second->getStatus().isAvailable()) { + names.push_back(alias); + } + } return names; } +void MediapipeFactory::registerLoraAlias(const std::string& alias, const std::string& graphName) { + std::unique_lock lock(definitionsMtx); + loraAliases[alias] = graphName; + SPDLOG_LOGGER_INFO(modelmanager_logger, "Registered LoRA alias: {} -> {}", alias, graphName); +} + +void MediapipeFactory::clearLoraAliases(const std::string& graphName) { + std::unique_lock lock(definitionsMtx); + for (auto it = loraAliases.begin(); it != loraAliases.end();) { + if (it->second == graphName) { + SPDLOG_LOGGER_DEBUG(modelmanager_logger, "Removing LoRA alias: {} -> {}", it->first, graphName); + it = loraAliases.erase(it); + } else { + ++it; + } + } +} + MediapipeFactory::~MediapipeFactory() = default; } // namespace ovms diff --git a/src/mediapipe_internal/mediapipefactory.hpp b/src/mediapipe_internal/mediapipefactory.hpp index e48146b0f0..cf5c6f6b77 100644 --- a/src/mediapipe_internal/mediapipefactory.hpp +++ b/src/mediapipe_internal/mediapipefactory.hpp @@ -45,6 +45,7 @@ class PythonBackend; class MediapipeFactory { std::map> definitions; + std::map loraAliases; // alias -> real graph definition name mutable std::shared_mutex definitionsMtx; PythonBackend* pythonBackend{nullptr}; @@ -71,6 +72,8 @@ class MediapipeFactory { ModelManager& manager) const; MediapipeGraphDefinition* findDefinitionByName(const std::string& name) const; + void registerLoraAlias(const std::string& alias, const std::string& graphName); + void clearLoraAliases(const std::string& graphName); Status reloadDefinition(const std::string& pipelineName, const MediapipeGraphConfig& config, ModelManager& manager); diff --git a/src/mediapipe_internal/mediapipegraphdefinition.cpp b/src/mediapipe_internal/mediapipegraphdefinition.cpp index 9047765e75..ec8c9715e7 100644 --- 
a/src/mediapipe_internal/mediapipegraphdefinition.cpp +++ b/src/mediapipe_internal/mediapipegraphdefinition.cpp @@ -511,6 +511,11 @@ Status MediapipeGraphDefinition::initializeNodes() { } imageGenPipelinesMap.insert(std::pair>(nodeName, std::move(servable))); guard.disableCleaning(); + // Register LoRA aliases for routing + const auto& loraAdapters = std::get(statusOrArgs).loraAdapters; + for (const auto& adapter : loraAdapters) { + this->loraAliases_.push_back(adapter.alias); + } } if (endsWith(config.node(i).calculator(), EMBEDDINGS_NODE_CALCULATOR_NAME)) { auto& embeddingsServableMap = this->sidePacketMaps.embeddingsServableMap; diff --git a/src/mediapipe_internal/mediapipegraphdefinition.hpp b/src/mediapipe_internal/mediapipegraphdefinition.hpp index 14c9e0679f..755ee56869 100644 --- a/src/mediapipe_internal/mediapipegraphdefinition.hpp +++ b/src/mediapipe_internal/mediapipegraphdefinition.hpp @@ -112,6 +112,7 @@ class MediapipeGraphDefinition { const PipelineDefinitionStatus& getStatus() const { return this->status; } + const std::vector& getLoraAliases() const { return loraAliases_; } const PipelineDefinitionStateCode getStateCode() const { return status.getStateCode(); } const model_version_t getVersion() const { return VERSION; } @@ -204,6 +205,8 @@ class MediapipeGraphDefinition { std::vector outputNames; std::vector inputSidePacketNames; + std::vector loraAliases_; + std::atomic requestsHandlesCounter = 0; PythonBackend* pythonBackend; diff --git a/src/pull_module/BUILD b/src/pull_module/BUILD index 7b6b0b0588..e22e2d5a3f 100644 --- a/src/pull_module/BUILD +++ b/src/pull_module/BUILD @@ -54,21 +54,32 @@ ovms_cc_library( ], visibility = ["//visibility:public"], ) +ovms_cc_library( + name = "curl_downloader", + srcs = ["curl_downloader.cpp"], + hdrs = ["curl_downloader.hpp"], + deps = [ + "//third_party:curl", + "@ovms//src:libovmslogging", + "@ovms//src:libovmsstatus", + "@ovms//src:libovms_version", + ], + visibility = ["//visibility:public"], +) + 
ovms_cc_library( name = "gguf_downloader", srcs = ["gguf_downloader.cpp"], hdrs = ["gguf_downloader.hpp"], deps = [ + ":curl_downloader", ":model_downloader", - "//third_party:curl", - "@nlohmann_json//:json", "@ovms//src:libovmslogging", "@ovms//src:libovmsstatus", "@ovms//src:libovmsstring_utils", "@ovms//src:libovms_server_settings", "@ovms//src:libovmsfilesystem", "@ovms//src:libovmslocalfilesystem", - "@ovms//src:libovms_version", ], visibility = ["//visibility:public"], ) @@ -96,6 +107,7 @@ ovms_cc_library( srcs = ["hf_pull_model_module.cpp"], hdrs = ["hf_pull_model_module.hpp"], deps = [ + ":curl_downloader", ":libgit2", "gguf_downloader", ":optimum_export", @@ -105,6 +117,9 @@ ovms_cc_library( "@ovms//src:libovmslogging", "@ovms//src:libovms_server_settings", "@ovms//src:libovms_module", + "@ovms//src:libovms_version", + "//third_party:curl", + "@nlohmann_json//:json", ], visibility = ["//visibility:public"], ) diff --git a/src/pull_module/curl_downloader.cpp b/src/pull_module/curl_downloader.cpp new file mode 100644 index 0000000000..e7ff643c2b --- /dev/null +++ b/src/pull_module/curl_downloader.cpp @@ -0,0 +1,205 @@ +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** +#include "curl_downloader.hpp" + +#include +#include +#include +#include +#include + +#include +#include + +#include "src/logging.hpp" +#include "src/status.hpp" +#include "src/version.hpp" + +namespace ovms { + +static const char* sizeUnits[] = {"B", "KB", "MB", "GB", "TB", NULL}; + +static void print_download_speed_info(size_t received_size, size_t elapsed_time) { + double recv_len = (double)received_size; + uint64_t elapsed = (uint64_t)elapsed_time; + double rate; + rate = elapsed ? recv_len / elapsed : received_size; + + size_t rate_unit_idx = 0; + while (rate > 1000 && sizeUnits[rate_unit_idx + 1]) { + rate /= 1000.0; + rate_unit_idx++; + } + printf(" [%.2f %s/s] ", rate, sizeUnits[rate_unit_idx]); +} + +static void print_progress(size_t count, size_t max, bool first_run, size_t elapsed_time) { + float progress = (float)count / max; + if (!first_run && progress < 0.01 && count > 0) + return; + + const int bar_width = 50; + int bar_length = progress * bar_width; + + printf("\rProgress: ["); + int i; + for (i = 0; i < bar_length; ++i) { + printf("#"); + } + for (i = bar_length; i < bar_width; ++i) { + printf(" "); + } + size_t totalSizeUnitId = 0; + double totalSize = max; + while (totalSize > 1000 && sizeUnits[totalSizeUnitId + 1]) { + totalSize /= 1000.0; + totalSizeUnitId++; + } + printf("] %.2f%% of %.2f %s", progress * 100, totalSize, sizeUnits[totalSizeUnitId]); + print_download_speed_info(count, elapsed_time); + if (progress == 1.0) + printf("\n"); + fflush(stdout); +} + +struct CurlDownloadFile { + const char* filename; + FILE* stream; + CurlDownloadFile() = delete; + CurlDownloadFile(const CurlDownloadFile&) = delete; + CurlDownloadFile& operator=(const CurlDownloadFile&) = delete; + CurlDownloadFile(const char* fname, FILE* str) : + filename(fname), + stream(str) {} + ~CurlDownloadFile() { + if (stream) { + fclose(stream); + } + if (!success) { + 
std::filesystem::remove(filename); + } + } + bool success = false; +}; + +static size_t file_write_callback(void* buffer, size_t size, size_t nmemb, void* stream) { + CurlDownloadFile* out = static_cast(stream); + if (!out->stream) { + out->stream = fopen(out->filename, "wb"); + if (!out->stream) { + fprintf(stderr, "failure, cannot open file to write: %s\n", + out->filename); + return 0; + } + } + return fwrite(buffer, size, nmemb, out->stream); +} + +#define CHECK_CURL_CALL(call) \ + do { \ + CURLcode curlCode = call; \ + if (curlCode != CURLE_OK) { \ + SPDLOG_ERROR("curl error: {}. Error code: {}", curl_easy_strerror(curlCode), (int)curlCode); \ + return StatusCode::INTERNAL_ERROR; \ + } \ + } while (0) + +struct ProgressData { + time_t started_download; + time_t last_print_time; + bool fullDownloadPrinted = false; +}; + +static int progress_callback(void* clientp, + curl_off_t dltotal, + curl_off_t dlnow, + curl_off_t ultotal, + curl_off_t ulnow) { + ProgressData* pcs = reinterpret_cast(clientp); + if (dlnow == 0) { + pcs->started_download = time(NULL); + pcs->last_print_time = time(NULL); + } + time_t currentTime = time(NULL); + bool shouldPrintDueToTime = (currentTime - pcs->last_print_time >= 1); + if ((dltotal == dlnow) && dltotal < 10000) { + return 0; + } + if (pcs->fullDownloadPrinted) { + return 0; + } + if (!shouldPrintDueToTime && (dltotal != dlnow)) { + return 0; + } + pcs->fullDownloadPrinted = (dltotal == dlnow); + pcs->last_print_time = currentTime; + print_progress(dlnow, dltotal, (dlnow == 0), currentTime - pcs->started_download); + std::cout.flush(); + return 0; +} + +Status downloadFileWithCurl(const std::string& url, const std::string& filePath) { + return downloadFileWithCurl(url, filePath, ""); +} + +Status downloadFileWithCurl(const std::string& url, const std::string& filePath, const std::string& authTokenHF) { + std::string agentString = std::string(PROJECT_NAME) + "/" + std::string(PROJECT_VERSION); + + CURL* curl = nullptr; + 
CHECK_CURL_CALL(curl_global_init(CURL_GLOBAL_DEFAULT)); + auto globalCurlGuard = std::unique_ptr( + nullptr, [](void*) { curl_global_cleanup(); }); + curl = curl_easy_init(); + if (!curl) { + SPDLOG_ERROR("Failed to initialize cURL."); + return StatusCode::INTERNAL_ERROR; + } + auto handleGuard = std::unique_ptr(curl, curl_easy_cleanup); + CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_URL, url.c_str())); + CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, file_write_callback)); + CurlDownloadFile downloadFile{filePath.c_str(), NULL}; + CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_WRITEDATA, &downloadFile)); + CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_USERAGENT, agentString.c_str())); + struct curl_slist* headers = nullptr; + std::string authHeader; + if (!authTokenHF.empty()) { + authHeader = "Authorization: Bearer " + authTokenHF; + headers = curl_slist_append(headers, authHeader.c_str()); + CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers)); + } + auto headersGuard = std::unique_ptr(headers, curl_slist_free_all); + ProgressData progressData; + CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L)); + CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, progress_callback)); + CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &progressData)); + CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA)); + CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L)); + CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_USE_SSL, CURLUSESSL_ALL)); + CHECK_CURL_CALL(curl_easy_perform(curl)); + int32_t http_code = 0; + CHECK_CURL_CALL(curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code)); + SPDLOG_TRACE("HTTP response code: {}", http_code); + if (http_code != 200) { + SPDLOG_ERROR("Failed to download file from URL: {} HTTP response code: {}", url, http_code); + return StatusCode::PATH_INVALID; + } + downloadFile.success = true; + return StatusCode::OK; +} + +#undef 
CHECK_CURL_CALL + +} // namespace ovms diff --git a/src/pull_module/curl_downloader.hpp b/src/pull_module/curl_downloader.hpp new file mode 100644 index 0000000000..9232dbeb6b --- /dev/null +++ b/src/pull_module/curl_downloader.hpp @@ -0,0 +1,25 @@ +#pragma once +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#include + +namespace ovms { +class Status; + +Status downloadFileWithCurl(const std::string& url, const std::string& filePath); +Status downloadFileWithCurl(const std::string& url, const std::string& filePath, const std::string& authTokenHF); + +} // namespace ovms diff --git a/src/pull_module/gguf_downloader.cpp b/src/pull_module/gguf_downloader.cpp index 1a6c5355fd..0952eb05a4 100644 --- a/src/pull_module/gguf_downloader.cpp +++ b/src/pull_module/gguf_downloader.cpp @@ -15,23 +15,18 @@ //***************************************************************************** #include "gguf_downloader.hpp" -#include -#include #include +#include +#include #include -#include -#include - #include "../capi_frontend/server_settings.hpp" #include "../filesystem.hpp" #include "../localfilesystem.hpp" #include "../logging.hpp" -#include "../stringutils.hpp" #include "../status.hpp" -#include "../version.hpp" - -#include +#include "../stringutils.hpp" +#include 
"curl_downloader.hpp" namespace ovms { @@ -128,177 +123,6 @@ Status GGUFDownloader::downloadModel() { return StatusCode::OK; } -static const char* sizeUnits[] = {"B", "KB", "MB", "GB", "TB", NULL}; -static void print_download_speed_info(size_t received_size, size_t elapsed_time) { - double recv_len = (double)received_size; - uint64_t elapsed = (uint64_t)elapsed_time; - double rate; - rate = elapsed ? recv_len / elapsed : received_size; - - size_t rate_unit_idx = 0; - while (rate > 1000 && sizeUnits[rate_unit_idx + 1]) { - rate /= 1000.0; - rate_unit_idx++; - } - printf(" [%.2f %s/s] ", rate, sizeUnits[rate_unit_idx]); -} - -void print_progress(size_t count, size_t max, bool first_run, size_t elapsed_time) { - float progress = (float)count / max; - if (!first_run && progress < 0.01 && count > 0) - return; - - const int bar_width = 50; - int bar_length = progress * bar_width; - - printf("\rProgress: ["); - int i; - for (i = 0; i < bar_length; ++i) { - printf("#"); - } - for (i = bar_length; i < bar_width; ++i) { - printf(" "); - } - size_t totalSizeUnitId = 0; - double totalSize = max; - while (totalSize > 1000 && sizeUnits[totalSizeUnitId + 1]) { - totalSize /= 1000.0; - totalSizeUnitId++; - } - printf("] %.2f%% of %.2f %s", progress * 100, totalSize, sizeUnits[totalSizeUnitId]); - print_download_speed_info(count, elapsed_time); - if (progress == 1.0) - printf("\n"); - fflush(stdout); -} - -struct FtpFile { - const char* filename; - FILE* stream; - FtpFile() = delete; - FtpFile(const FtpFile&) = delete; - FtpFile& operator=(const FtpFile&) = delete; - FtpFile(const char* fname, FILE* str) : - filename(fname), - stream(str) {} - ~FtpFile() { - if (stream) { - fclose(stream); - } - if (!success) { - std::filesystem::remove(filename); - } - } - bool success = false; -}; - -void fileClose(FILE* file) { - if (file) { - fclose(file); - } -} - -static size_t file_write_callback(void* buffer, size_t size, size_t nmemb, void* stream) { - struct FtpFile* out = (struct 
FtpFile*)stream; - if (!out->stream) { - out->stream = fopen(out->filename, "wb"); - if (!out->stream) { - fprintf(stderr, "failure, cannot open file to write: %s\n", - out->filename); - return 0; - } - } - return fwrite(buffer, size, nmemb, out->stream); -} - -#define CHECK_CURL_CALL(call) \ - do { \ - CURLcode curlCode = call; \ - if (curlCode != CURLE_OK) { \ - SPDLOG_ERROR("curl error: {}. Error code: {}", curl_easy_strerror(curlCode), (int)curlCode); \ - return StatusCode::INTERNAL_ERROR; \ - } \ - } while (0) - -struct ProgressData { - time_t started_download; - time_t last_print_time; - bool fullDownloadPrinted = false; -}; -int progress_callback(void* clientp, - curl_off_t dltotal, - curl_off_t dlnow, - curl_off_t ultotal, - curl_off_t ulnow) { - ProgressData* pcs = reinterpret_cast(clientp); - if (dlnow == 0) { - pcs->started_download = time(NULL); - pcs->last_print_time = time(NULL); - } - time_t currentTime = time(NULL); - bool shouldPrintDueToTime = (currentTime - pcs->last_print_time >= 1); - if ((dltotal == dlnow) && dltotal < 10000) { - // Usually with first messages we don't get the full size and we don't want to print progress bar - // so we assume that until dltotal is less than 1000 we don't have full size - // otherwise we would print 100% progress bar - return 0; - } - // called multiple times, so we want to print progress bar only once reached 100% - if (pcs->fullDownloadPrinted) { - return 0; - } - if (!shouldPrintDueToTime && (dltotal != dlnow)) { - // we dont want to skip printing progress bar for the 100% but we don't want to spam stdout either - return 0; - } - pcs->fullDownloadPrinted = (dltotal == dlnow); - pcs->last_print_time = currentTime; - print_progress(dlnow, dltotal, (dlnow == 0), currentTime - pcs->started_download); - std::cout.flush(); - return 0; -} - -static Status downloadSingleFileWithCurl(const std::string& filePath, const std::string& url) { - // agent string required to avoid 403 Forbidden error on modelscope - 
std::string agentString = std::string(PROJECT_NAME) + "/" + std::string(PROJECT_VERSION); - - CURL* curl = nullptr; - CHECK_CURL_CALL(curl_global_init(CURL_GLOBAL_DEFAULT)); - auto globalCurlGuard = std::unique_ptr( - nullptr, [](void*) { curl_global_cleanup(); }); - curl = curl_easy_init(); - if (!curl) { - SPDLOG_ERROR("Failed to initialize cURL."); - return StatusCode::INTERNAL_ERROR; - } - auto handleGuard = std::unique_ptr(curl, curl_easy_cleanup); - // set impl options - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_URL, url.c_str())); - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, file_write_callback)); - struct FtpFile ftpFile = {filePath.c_str(), NULL}; - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_WRITEDATA, &ftpFile)); - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_USERAGENT, agentString.c_str())); - // progress bar options - ProgressData progressData; - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L)); - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, progress_callback)); - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &progressData)); - // other options - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA)); - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L)); - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_USE_SSL, CURLUSESSL_ALL)); - CHECK_CURL_CALL(curl_easy_perform(curl)); - int32_t http_code = 0; - CHECK_CURL_CALL(curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code)); - SPDLOG_TRACE("HTTP response code: {}", http_code); - if (http_code != 200) { - SPDLOG_ERROR("Failed to download file from URL: {} HTTP response code: {}", url, http_code); - return StatusCode::PATH_INVALID; - } - ftpFile.success = true; - return StatusCode::OK; -} - std::variant> GGUFDownloader::createGGUFFilenamesToDownload(const std::string& ggufFilename) { std::vector filesToDownload; // we need to check if ggufFilename is of multipart type (contains 
00001-of-N string) @@ -364,7 +188,7 @@ Status GGUFDownloader::downloadWithCurl(const std::string& hfEndpoint, const std // construct filepath auto filePath = FileSystem::joinPath({downloadPath, file}); SPDLOG_DEBUG("Downloading part {}/{} filename: {} url:{}", partNo, filesToDownload.size(), file, url); - auto status = downloadSingleFileWithCurl(filePath, url); + auto status = downloadFileWithCurl(url, filePath); if (!status.ok()) { return status; } diff --git a/src/pull_module/hf_pull_model_module.cpp b/src/pull_module/hf_pull_model_module.cpp index b73cad6638..a3891966a0 100644 --- a/src/pull_module/hf_pull_model_module.cpp +++ b/src/pull_module/hf_pull_model_module.cpp @@ -16,19 +16,27 @@ #include "hf_pull_model_module.hpp" #include +#include #include #include #include +#include + +#include +#include #include "../config.hpp" +#include "../filesystem.hpp" #include "libgit2.hpp" #include "optimum_export.hpp" +#include "curl_downloader.hpp" #include "gguf_downloader.hpp" #include "../graph_export/graph_export.hpp" #include "../logging.hpp" #include "../module_names.hpp" #include "../status.hpp" #include "../stringutils.hpp" +#include "../version.hpp" namespace ovms { const std::string DEFAULT_EMPTY_ENV_VALUE{""}; @@ -110,7 +118,144 @@ Status HfPullModelModule::start(const ovms::Config& config) { return StatusCode::OK; } -Status HfPullModelModule::clone() const { +Status HfPullModelModule::resolveHfLoraFilenames() { + if (!std::holds_alternative(this->hfSettings.graphSettings)) { + return StatusCode::OK; + } + auto& graphSettings = std::get(this->hfSettings.graphSettings); + for (auto& adapter : graphSettings.loraAdapters) { + if (adapter.sourceType != LoraSourceType::HF_REPO) { + continue; + } + if (!adapter.safetensorsFile.empty()) { + continue; + } + // Query HF API to find the .safetensors file in the LoRA repo + std::string apiUrl = this->GetHfEndpoint() + "api/models/" + adapter.sourceLora; + SPDLOG_DEBUG("Querying HF API for LoRA adapter files: {}", apiUrl); 
+ std::string agentString = std::string(PROJECT_NAME) + "/" + std::string(PROJECT_VERSION); + std::string responseBody; + CURL* curl = curl_easy_init(); + if (!curl) { + SPDLOG_ERROR("Failed to initialize cURL for HF API query"); + return StatusCode::INTERNAL_ERROR; + } + auto handleGuard = std::unique_ptr(curl, curl_easy_cleanup); + auto writeCallback = +[](void* buffer, size_t size, size_t nmemb, void* userData) -> size_t { + auto* body = static_cast(userData); + body->append(static_cast(buffer), size * nmemb); + return size * nmemb; + }; + curl_easy_setopt(curl, CURLOPT_URL, apiUrl.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &responseBody); + curl_easy_setopt(curl, CURLOPT_USERAGENT, agentString.c_str()); + curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + std::string hfToken = this->GetHfToken(); + struct curl_slist* headers = nullptr; + if (!hfToken.empty()) { + std::string authHeader = "Authorization: Bearer " + hfToken; + headers = curl_slist_append(headers, authHeader.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + } + CURLcode res = curl_easy_perform(curl); + if (headers) { + curl_slist_free_all(headers); + } + if (res != CURLE_OK) { + SPDLOG_ERROR("cURL error querying HF API for LoRA {}: {}", adapter.sourceLora, curl_easy_strerror(res)); + return StatusCode::INTERNAL_ERROR; + } + int32_t httpCode = 0; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &httpCode); + if (httpCode != 200) { + SPDLOG_ERROR("HF API returned HTTP {} for LoRA adapter: {}", httpCode, adapter.sourceLora); + return StatusCode::PATH_INVALID; + } + // Parse JSON response to find .safetensors files in siblings array + // Example: { "siblings": [{"rfilename": "file1.safetensors"}, ...] 
} + try { + auto json = nlohmann::json::parse(responseBody); + std::vector safetensorsFiles; + if (json.contains("siblings") && json["siblings"].is_array()) { + for (const auto& sibling : json["siblings"]) { + if (sibling.contains("rfilename") && sibling["rfilename"].is_string()) { + const std::string& filename = sibling["rfilename"].get_ref(); + if (endsWith(filename, ".safetensors")) { + safetensorsFiles.push_back(filename); + } + } + } + } + if (safetensorsFiles.empty()) { + SPDLOG_ERROR("No .safetensors files found via HF API for LoRA adapter: {}", adapter.sourceLora); + return StatusCode::PATH_INVALID; + } + if (safetensorsFiles.size() > 1) { + SPDLOG_ERROR("Multiple .safetensors files found for LoRA adapter: {}. Use @filename to specify.", adapter.sourceLora); + return StatusCode::PATH_INVALID; + } + adapter.safetensorsFile = safetensorsFiles[0]; + SPDLOG_DEBUG("Resolved LoRA safetensors file for {}: {}", adapter.sourceLora, adapter.safetensorsFile); + } catch (const nlohmann::json::exception& e) { + SPDLOG_ERROR("Failed to parse HF API JSON response for LoRA adapter {}: {}", adapter.sourceLora, e.what()); + return StatusCode::INTERNAL_ERROR; + } + } + return StatusCode::OK; +} + +Status HfPullModelModule::pullLoraAdapters(const std::string& graphDirectory) { + if (!std::holds_alternative(this->hfSettings.graphSettings)) { + return StatusCode::OK; + } + auto status = this->resolveHfLoraFilenames(); + if (!status.ok()) { + return status; + } + const auto& graphSettings = std::get(this->hfSettings.graphSettings); + for (const auto& adapter : graphSettings.loraAdapters) { + if (adapter.sourceType == LoraSourceType::LOCAL_FILE) { + std::cout << "LoRA adapter: " << adapter.alias << " using local file: " << adapter.sourceLora << std::endl; + continue; + } + std::string loraDownloadPath; + std::string loraUrl; + std::string authTokenHF; + if (adapter.sourceType == LoraSourceType::HF_REPO) { + loraDownloadPath = FileSystem::joinPath({graphDirectory, "loras", 
adapter.sourceLora});
+            loraUrl = this->GetHfEndpoint() + adapter.sourceLora + "/resolve/main/" + adapter.safetensorsFile;
+            authTokenHF = this->GetHfToken();
+        } else if (adapter.sourceType == LoraSourceType::DIRECT_URL) {
+            // Direct URLs are fetched anonymously: authTokenHF stays empty on purpose.
+            loraDownloadPath = FileSystem::joinPath({graphDirectory, "loras", adapter.alias});
+            loraUrl = adapter.sourceLora;
+        } else {
+            SPDLOG_ERROR("Unknown LoRA source type for adapter: {}", adapter.alias);
+            return StatusCode::INTERNAL_ERROR;
+        }
+        auto loraFilePath = FileSystem::joinPath({loraDownloadPath, adapter.safetensorsFile});
+        // Skip re-download unless --overwrite_models was requested.
+        if (!this->hfSettings.overwriteModels && std::filesystem::exists(loraFilePath)) {
+            std::cout << "LoRA adapter: " << adapter.alias << " already exists, skipping download." << std::endl;
+            continue;
+        }
+        if (!std::filesystem::exists(loraDownloadPath)) {
+            if (!std::filesystem::create_directories(loraDownloadPath)) {
+                SPDLOG_ERROR("Failed to create LoRA directory: {}", loraDownloadPath);
+                return StatusCode::DIRECTORY_NOT_CREATED;
+            }
+        }
+        status = downloadFileWithCurl(loraUrl, loraFilePath, authTokenHF);
+        if (!status.ok()) {
+            SPDLOG_ERROR("Failed to download LoRA adapter: {} from: {}", adapter.alias, loraUrl);
+            return status;
+        }
+        std::cout << "LoRA adapter: " << adapter.alias << " downloaded to: " << loraDownloadPath << std::endl;
+    }
+    return StatusCode::OK;
+}
+
+Status HfPullModelModule::clone() {
     std::string graphDirectory = "";
     std::unique_ptr downloader;
     std::variant> guardOrError;
@@ -150,6 +295,12 @@ Status HfPullModelModule::clone() const {
     std::cout << "Draft model: " << GraphExport::getDraftModelDirectoryName(graphSettings.draftModelDirName.value()) << " downloaded to: " << GraphExport::getDraftModelDirectoryPath(graphDirectory, graphSettings.draftModelDirName.value()) << std::endl;
     }
+    // Image gen with LoRA adapters case - resolve filenames and download safetensors files
+    status = this->pullLoraAdapters(graphDirectory);
+    if (!status.ok()) {
+        return status;
+    }
+
     GraphExport graphExporter;
status = graphExporter.createServableConfig(graphDirectory, this->hfSettings); if (!status.ok()) { diff --git a/src/pull_module/hf_pull_model_module.hpp b/src/pull_module/hf_pull_model_module.hpp index be42887b39..296300b474 100644 --- a/src/pull_module/hf_pull_model_module.hpp +++ b/src/pull_module/hf_pull_model_module.hpp @@ -35,10 +35,14 @@ class HfPullModelModule : public Module { Status start(const ovms::Config& config) override; void shutdown() override; - Status clone() const; + Status clone(); static const std::string GIT_SERVER_CONNECT_TIMEOUT_ENV; static const std::string GIT_SERVER_TIMEOUT_ENV; static const std::string GIT_SSL_CERT_LOCATIONS_ENV; + +protected: + Status resolveHfLoraFilenames(); + Status pullLoraAdapters(const std::string& graphDirectory); }; std::variant> createLibGitGuard(); diff --git a/src/server.cpp b/src/server.cpp index ec0a7e4b10..54e1648cb9 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -370,7 +370,7 @@ Status Server::startModules(ovms::Config& config) { if (!status.ok()) { return status; } - auto hfModule = dynamic_cast(it->second.get()); + auto hfModule = dynamic_cast(it->second.get()); status = hfModule->clone(); // Return from modules only in --pull mode or error, otherwise start the rest of modules if (config.getServerSettings().serverMode == HF_PULL_MODE || !status.ok()) diff --git a/src/stringutils.cpp b/src/stringutils.cpp index 6d9f98fe5f..60c92f951c 100644 --- a/src/stringutils.cpp +++ b/src/stringutils.cpp @@ -301,4 +301,20 @@ void escapeSpecialCharacters(std::string& text) { text = std::move(escaped); } +bool isLocalFilePath(const std::string& path) { + if (path.empty()) { + return false; + } + if (path[0] == '/') { + return true; + } + if (path.size() >= 2 && (path.substr(0, 2) == "./" || path.substr(0, 2) == ".\\")) { + return true; + } + if (path.size() >= 3 && std::isalpha(static_cast(path[0])) && path[1] == ':' && (path[2] == '\\' || path[2] == '/')) { + return true; + } + return false; +} + } // namespace 
ovms diff --git a/src/stringutils.hpp b/src/stringutils.hpp index 9990ce38e4..4eda3efff2 100644 --- a/src/stringutils.hpp +++ b/src/stringutils.hpp @@ -123,4 +123,6 @@ bool stringsOverlap(const std::string& lhs, const std::string& rhs); void escapeSpecialCharacters(std::string& text); +bool isLocalFilePath(const std::string& path); + } // namespace ovms diff --git a/src/test/graph_export_test.cpp b/src/test/graph_export_test.cpp index 777792e7d3..cf6ed55253 100644 --- a/src/test/graph_export_test.cpp +++ b/src/test/graph_export_test.cpp @@ -13,6 +13,8 @@ // See the License for the specific language governing permissions and // limitations under the License. //***************************************************************************** +#include +#include #include #include @@ -23,6 +25,7 @@ #include "light_test_utils.hpp" #include "../capi_frontend/server_settings.hpp" #include "../graph_export/graph_export.hpp" +#include "../graph_export/image_generation_graph_cli_parser.hpp" #include "../filesystem.hpp" #include "../status.hpp" #include "../version.hpp" @@ -1103,3 +1106,482 @@ TEST_F(GraphCreationTest, pluginConfigNegative) { ASSERT_TRUE(std::holds_alternative(res)); ASSERT_EQ(std::get(res), ovms::StatusCode::PLUGIN_CONFIG_CONFLICTING_PARAMETERS); } + +// ===================== LoRA Graph Export Tests ===================== + +const std::string expectedImageGenWithOneLora = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "CPU" + lora_adapters { alias: "pokemon" path: "loras/juliensimon/sd-pokemon-lora/pytorch_lora_weights.safetensors" } + } + } +} + +)"; + +const std::string 
expectedImageGenWithTwoLoras = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "GPU" + max_resolution: "1024x1024" + lora_adapters { alias: "pokemon" path: "loras/juliensimon/sd-pokemon-lora/model.safetensors" } + lora_adapters { alias: "anime-style" path: "loras/org2/anime-lora/weights.safetensors" } + } + } +} + +)"; + +TEST_F(GraphCreationTest, imageGenerationWithOneLora) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + imageGenerationGraphSettings.loraAdapters.push_back({"pokemon", "juliensimon/sd-pokemon-lora", "pytorch_lora_weights.safetensors"}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenWithOneLora, removeVersionString(graphContents)) << graphContents; +} + +TEST_F(GraphCreationTest, imageGenerationWithTwoLoras) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + hfSettings.exportSettings.targetDevice = "GPU"; + imageGenerationGraphSettings.maxResolution = "1024x1024"; + imageGenerationGraphSettings.loraAdapters.push_back({"pokemon", "juliensimon/sd-pokemon-lora", 
"model.safetensors"}); + imageGenerationGraphSettings.loraAdapters.push_back({"anime-style", "org2/anime-lora", "weights.safetensors"}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenWithTwoLoras, removeVersionString(graphContents)) << graphContents; +} + +TEST_F(GraphCreationTest, imageGenerationNoLorasRemainsUnchanged) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenerationGraphContentsDefault, removeVersionString(graphContents)) << graphContents; +} + +// ===================== LoRA CLI-to-Settings Tests ===================== + +TEST(ImageGenCLILoraParsingTest, SingleLoraWithAlias) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=juliensimon/sd-pokemon-lora"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "pokemon"); 
+ EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, "juliensimon/sd-pokemon-lora"); + EXPECT_TRUE(graphSettings.loraAdapters[0].safetensorsFile.empty()); +} + +TEST(ImageGenCLILoraParsingTest, MissingAliasThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "juliensimon/sd-pokemon-lora"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, SingleLoraWithAliasAndFilename) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=juliensimon/sd-pokemon-lora@custom_lora.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, "juliensimon/sd-pokemon-lora"); + EXPECT_EQ(graphSettings.loraAdapters[0].safetensorsFile, "custom_lora.safetensors"); +} + +TEST(ImageGenCLILoraParsingTest, MultipleLoras) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org1/repo1,anime=org2/repo2@weights.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 2); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, "org1/repo1"); + EXPECT_TRUE(graphSettings.loraAdapters[0].safetensorsFile.empty()); + 
EXPECT_EQ(graphSettings.loraAdapters[1].alias, "anime"); + EXPECT_EQ(graphSettings.loraAdapters[1].sourceLora, "org2/repo2"); + EXPECT_EQ(graphSettings.loraAdapters[1].safetensorsFile, "weights.safetensors"); +} + +TEST(ImageGenCLILoraParsingTest, EmptySourceLorasProducesNoAdapters) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = ""; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 0); +} + +TEST(ImageGenCLILoraParsingTest, InvalidEmptyAlias) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "=org/repo"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, InvalidEmptyFilenameAfterAt) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/repo@"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, MissingAliasWithFilenameThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "org1/repo1@special.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +// ===================== LoRA Source Type Tests ===================== + +TEST(ImageGenCLILoraParsingTest, DirectUrlWithAlias) { + ovms::ServerSettingsImpl 
serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=https://huggingface.co/juliensimon/sd-pokemon-lora/resolve/main/pytorch_lora_weights.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, "https://huggingface.co/juliensimon/sd-pokemon-lora/resolve/main/pytorch_lora_weights.safetensors"); + EXPECT_EQ(graphSettings.loraAdapters[0].safetensorsFile, "pytorch_lora_weights.safetensors"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::DIRECT_URL); +} + +TEST(ImageGenCLILoraParsingTest, DirectUrlHttpWithAlias) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=http://example.com/weights.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::DIRECT_URL); + EXPECT_EQ(graphSettings.loraAdapters[0].safetensorsFile, "weights.safetensors"); +} + +TEST(ImageGenCLILoraParsingTest, DirectUrlMissingAliasThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "https://example.com/weights.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, DirectUrlNotSafetensorsThrows) { + 
ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=https://example.com/model.bin"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +class ImageGenCLILoraParsingWithTempDir : public TestWithTempDir {}; + +TEST_F(ImageGenCLILoraParsingWithTempDir, LocalFileWithAlias) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + std::string tmpFile = ovms::FileSystem::joinPath({this->directoryPath, "test_weights.safetensors"}); + { + std::ofstream f(tmpFile); + f << "test"; + } + hfSettings.sourceLoras = "pokemon=" + tmpFile; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, tmpFile); + EXPECT_EQ(graphSettings.loraAdapters[0].safetensorsFile, "test_weights.safetensors"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::LOCAL_FILE); +} + +TEST_F(ImageGenCLILoraParsingWithTempDir, MixedSourceTypes) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + std::string tmpFile = ovms::FileSystem::joinPath({this->directoryPath, "local.safetensors"}); + { + std::ofstream f(tmpFile); + f << "test"; + } + hfSettings.sourceLoras = "hf=org/repo,url=https://example.com/remote.safetensors,local=" + tmpFile; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 3); + 
EXPECT_EQ(graphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::HF_REPO); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "hf"); + EXPECT_EQ(graphSettings.loraAdapters[1].sourceType, ovms::LoraSourceType::DIRECT_URL); + EXPECT_EQ(graphSettings.loraAdapters[1].alias, "url"); + EXPECT_EQ(graphSettings.loraAdapters[1].safetensorsFile, "remote.safetensors"); + EXPECT_EQ(graphSettings.loraAdapters[2].sourceType, ovms::LoraSourceType::LOCAL_FILE); + EXPECT_EQ(graphSettings.loraAdapters[2].alias, "local"); + EXPECT_EQ(graphSettings.loraAdapters[2].safetensorsFile, "local.safetensors"); +} + +TEST(ImageGenCLILoraParsingTest, LocalFileMissingAliasThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "/tmp/some_weights.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, LocalFileNotSafetensorsThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=/tmp/model.bin"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, LocalFileDoesNotExistThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=/nonexistent/path/weights.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +// ===================== Graph Export with Different Source Types ===================== + +const std::string expectedImageGenWithUrlLora = R"( +input_stream: 
"HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "CPU" + lora_adapters { alias: "pokemon" path: "loras/pokemon/pytorch_lora_weights.safetensors" } + } + } +} + +)"; + +TEST_F(GraphCreationTest, imageGenerationWithUrlLora) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + imageGenerationGraphSettings.loraAdapters.push_back({"pokemon", "https://huggingface.co/juliensimon/sd-pokemon-lora/resolve/main/pytorch_lora_weights.safetensors", "pytorch_lora_weights.safetensors", ovms::LoraSourceType::DIRECT_URL}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenWithUrlLora, removeVersionString(graphContents)) << graphContents; +} + +const std::string expectedImageGenWithLocalLora = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "CPU" + lora_adapters { alias: 
"pokemon" path: "/path/to/weights.safetensors" } + } + } +} + +)"; + +TEST_F(GraphCreationTest, imageGenerationWithLocalLora) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + imageGenerationGraphSettings.loraAdapters.push_back({"pokemon", "/path/to/weights.safetensors", "weights.safetensors", ovms::LoraSourceType::LOCAL_FILE}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenWithLocalLora, removeVersionString(graphContents)) << graphContents; +} + +// ===================== Composite LoRA Tests ===================== + +TEST(ImageGenCLILoraParsingTest, CompositeLoraBasic) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora,anime=org/anime-lora,pokemon_anime=@pokemon+@anime"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 2); + ASSERT_EQ(graphSettings.compositeLoraAdapters.size(), 1); + EXPECT_EQ(graphSettings.compositeLoraAdapters[0].alias, "pokemon_anime"); + ASSERT_EQ(graphSettings.compositeLoraAdapters[0].components.size(), 2); + EXPECT_EQ(graphSettings.compositeLoraAdapters[0].components[0].adapterAlias, "pokemon"); + EXPECT_FLOAT_EQ(graphSettings.compositeLoraAdapters[0].components[0].weight, 1.0f); + EXPECT_EQ(graphSettings.compositeLoraAdapters[0].components[1].adapterAlias, 
"anime"); + EXPECT_FLOAT_EQ(graphSettings.compositeLoraAdapters[0].components[1].weight, 1.0f); +} + +TEST(ImageGenCLILoraParsingTest, CompositeLoraWithWeights) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora,anime=org/anime-lora,blend=@pokemon:0.7+@anime:0.5"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.compositeLoraAdapters.size(), 1); + EXPECT_EQ(graphSettings.compositeLoraAdapters[0].alias, "blend"); + ASSERT_EQ(graphSettings.compositeLoraAdapters[0].components.size(), 2); + EXPECT_FLOAT_EQ(graphSettings.compositeLoraAdapters[0].components[0].weight, 0.7f); + EXPECT_FLOAT_EQ(graphSettings.compositeLoraAdapters[0].components[1].weight, 0.5f); +} + +TEST(ImageGenCLILoraParsingTest, CompositeLoraUnknownRefThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora,blend=@pokemon+@nonexistent"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, CompositeLoraInvalidComponentThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora,blend=@pokemon+noatsign"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +const std::string expectedImageGenWithCompositeLora = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: 
"ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "CPU" + lora_adapters { alias: "pokemon" path: "loras/org/pokemon-lora/weights.safetensors" } + lora_adapters { alias: "anime" path: "loras/org/anime-lora/weights.safetensors" } + composite_lora_adapters { + alias: "blend" + components { adapter_alias: "pokemon" weight: 0.7 } + components { adapter_alias: "anime" weight: 0.5 } + } + } + } +} + +)"; + +TEST_F(GraphCreationTest, imageGenerationWithCompositeLora) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + imageGenerationGraphSettings.loraAdapters.push_back({"pokemon", "org/pokemon-lora", "weights.safetensors", ovms::LoraSourceType::HF_REPO}); + imageGenerationGraphSettings.loraAdapters.push_back({"anime", "org/anime-lora", "weights.safetensors", ovms::LoraSourceType::HF_REPO}); + imageGenerationGraphSettings.compositeLoraAdapters.push_back({"blend", {{"pokemon", 0.7f}, {"anime", 0.5f}}}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenWithCompositeLora, removeVersionString(graphContents)) << graphContents; +} diff --git a/src/test/ovmsconfig_test.cpp b/src/test/ovmsconfig_test.cpp index 5e6f694f56..c8fffad803 100644 --- a/src/test/ovmsconfig_test.cpp +++ b/src/test/ovmsconfig_test.cpp @@ -528,6 +528,60 @@ TEST_F(OvmsConfigDeathTest, 
negativeImageGenerationGraph_MaxNumInferenceStepsZer EXPECT_THROW(ovms::Config::instance().parse(arg_count, n_argv), std::invalid_argument); } +TEST(OvmsGraphConfigTest, negativeImageGenerationGraph_SourceLorasEmptyAlias) { + char* n_argv[] = { + (char*)"ovms", + (char*)"--pull", + (char*)"--source_model", + (char*)"some/model", + (char*)"--model_repository_path", + (char*)"/some/path", + (char*)"--task", + (char*)"image_generation", + (char*)"--source_loras", + (char*)"=org/repo", + }; + int arg_count = 10; + ConstructorEnabledConfig config; + EXPECT_THROW(config.parse(arg_count, n_argv), std::invalid_argument); +} + +TEST(OvmsGraphConfigTest, negativeImageGenerationGraph_SourceLorasEmptyRepo) { + char* n_argv[] = { + (char*)"ovms", + (char*)"--pull", + (char*)"--source_model", + (char*)"some/model", + (char*)"--model_repository_path", + (char*)"/some/path", + (char*)"--task", + (char*)"image_generation", + (char*)"--source_loras", + (char*)"alias=", + }; + int arg_count = 10; + ConstructorEnabledConfig config; + EXPECT_THROW(config.parse(arg_count, n_argv), std::invalid_argument); +} + +TEST(OvmsGraphConfigTest, negativeImageGenerationGraph_SourceLorasEmptyFilenameAfterAt) { + char* n_argv[] = { + (char*)"ovms", + (char*)"--pull", + (char*)"--source_model", + (char*)"some/model", + (char*)"--model_repository_path", + (char*)"/some/path", + (char*)"--task", + (char*)"image_generation", + (char*)"--source_loras", + (char*)"pokemon=org/repo@", + }; + int arg_count = 10; + ConstructorEnabledConfig config; + EXPECT_THROW(config.parse(arg_count, n_argv), std::invalid_argument); +} + TEST_F(OvmsConfigDeathTest, hfBadEmbeddingsGraphParameter) { char* n_argv[] = { "ovms", @@ -1711,6 +1765,39 @@ TEST(OvmsGraphConfigTest, positiveAllChangedImageGeneration) { ASSERT_EQ(exportSettings.pluginConfig.manualString.value(), "{\"SOME_KEY\":\"SOME_VALUE\"}"); } +TEST(OvmsGraphConfigTest, positiveImageGenerationWithSourceLoras) { + std::string modelName = 
"OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov"; + std::string downloadPath = "test/repository"; + char* n_argv[] = { + (char*)"ovms", + (char*)"--pull", + (char*)"--source_model", + (char*)modelName.c_str(), + (char*)"--model_repository_path", + (char*)downloadPath.c_str(), + (char*)"--task", + (char*)"image_generation", + (char*)"--source_loras=pokemon=juliensimon/sd-pokemon-lora@weights.safetensors,anime=org/anime-lora", + }; + + int arg_count = 9; + ConstructorEnabledConfig config; + config.parse(arg_count, n_argv); + + auto& hfSettings = config.getServerSettings().hfSettings; + ASSERT_EQ(hfSettings.task, ovms::IMAGE_GENERATION_GRAPH); + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters.size(), 2); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[0].alias, "pokemon"); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[0].sourceLora, "juliensimon/sd-pokemon-lora"); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[0].safetensorsFile, "weights.safetensors"); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::HF_REPO); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[1].alias, "anime"); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[1].sourceLora, "org/anime-lora"); + ASSERT_TRUE(imageGenerationGraphSettings.loraAdapters[1].safetensorsFile.empty()); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[1].sourceType, ovms::LoraSourceType::HF_REPO); +} + TEST(OvmsGraphConfigTest, positiveDefaultImageGeneration) { std::string modelName = "OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov"; std::string downloadPath = "test/repository"; diff --git a/src/test/pull_hf_model_test.cpp b/src/test/pull_hf_model_test.cpp index 1fbd0798f6..3dc3eea8b8 100644 --- a/src/test/pull_hf_model_test.cpp +++ b/src/test/pull_hf_model_test.cpp @@ -37,6 +37,7 @@ #include "src/pull_module/optimum_export.hpp" #include 
"src/servables_config_manager_module/listmodels.hpp" #include "src/modelextensions.hpp" +#include "src/capi_frontend/server_settings.hpp" #include "../module.hpp" #include "../server.hpp" @@ -301,9 +302,7 @@ TEST_F(HfDownloaderPullHfModel, Resume) { // Fails because we want clean and it has the graph.pbtxt after download ASSERT_EQ(hfDownloader->CheckRepositoryStatus(true).getCode(), ovms::StatusCode::HF_GIT_STATUS_UNCLEAN); - exit(0); - }, - ::testing::ExitedWithCode(0), ""); + exit(0); }, ::testing::ExitedWithCode(0), ""); std::error_code ec; ec.clear(); @@ -503,9 +502,7 @@ TEST(HfDownloaderClassTest, RepositoryStatusCheckErrors) { // Fails without libgit init ASSERT_EQ(hfDownloader->CheckRepositoryStatus(true).getCode(), ovms::StatusCode::HF_GIT_LIBGIT2_NOT_INITIALIZED); ASSERT_EQ(hfDownloader->CheckRepositoryStatus(false).getCode(), ovms::StatusCode::HF_GIT_LIBGIT2_NOT_INITIALIZED); - exit(0); - }, - ::testing::ExitedWithCode(0), ""); + exit(0); }, ::testing::ExitedWithCode(0), ""); EXPECT_EXIT({ std::unique_ptr hfDownloader = std::make_unique(modelName, ovms::IModelDownloader::getGraphDirectory(downloadPath, modelName), hfEndpoint, hfToken, httpProxy, false); @@ -523,9 +520,7 @@ TEST(HfDownloaderClassTest, RepositoryStatusCheckErrors) { std::unique_ptr existingHfDownloader = std::make_unique(modelName, downloadPath, hfEndpoint, hfToken, httpProxy, false); ASSERT_EQ(existingHfDownloader->CheckRepositoryStatus(true).getCode(), ovms::StatusCode::HF_GIT_STATUS_FAILED); ASSERT_EQ(existingHfDownloader->CheckRepositoryStatus(false).getCode(), ovms::StatusCode::HF_GIT_STATUS_FAILED); - exit(0); - }, - ::testing::ExitedWithCode(0), ""); + exit(0); }, ::testing::ExitedWithCode(0), ""); } class TestOptimumDownloaderSetup : public ::testing::Test { @@ -809,9 +804,7 @@ TEST(Libgt2InitGuardTest, LfsFilterCaptureDefaultResumeOptions) { } EXPECT_THAT(output, ::testing::HasSubstr("[INFO] LFS resume: attempts=5 interval=10 s")); - exit(0); - }, - ::testing::ExitedWithCode(0), 
""); + exit(0); }, ::testing::ExitedWithCode(0), ""); } TEST(Libgt2InitGuardTest, LfsFilterCaptureNonDefaultResumeOptions) { @@ -834,9 +827,7 @@ TEST(Libgt2InitGuardTest, LfsFilterCaptureNonDefaultResumeOptions) { } EXPECT_THAT(output, ::testing::HasSubstr("[INFO] LFS resume: attempts=3 interval=20 s")); - exit(0); - }, - ::testing::ExitedWithCode(0), ""); + exit(0); }, ::testing::ExitedWithCode(0), ""); } TEST_F(HfDownloaderHfEnvTest, Methods) { @@ -1036,3 +1027,130 @@ TEST(ServerModulesBehaviorTests, PullAndStartModeErrorAndExpectFailAndNoOtherMod ASSERT_EQ(server.getModule(ovms::SERVABLE_MANAGER_MODULE_NAME), nullptr); ASSERT_EQ(server.getModule(ovms::SERVABLES_CONFIG_MANAGER_MODULE_NAME), nullptr); } + +// ===================== LoRA Pull Module Tests ===================== + +class TestHfPullModelModuleForLora : public ovms::HfPullModelModule { +public: + ovms::HFSettingsImpl& getHfSettings() { return this->hfSettings; } + ovms::Status testResolveHfLoraFilenames() { return this->resolveHfLoraFilenames(); } + ovms::Status testPullLoraAdapters(const std::string& graphDirectory) { return this->pullLoraAdapters(graphDirectory); } +}; + +class HfPullModelModuleLoraTest : public TestWithTempDir {}; + +TEST_F(HfPullModelModuleLoraTest, ResolveHfLoraFilenames) { + SKIP_AND_EXIT_IF_NOT_RUNNING_UNSTABLE(); + const char* hfToken = std::getenv("HF_TOKEN"); + if (!hfToken || std::string(hfToken).empty()) { + GTEST_SKIP() << "Skipping: HF_TOKEN not set (required for HF API resolution)"; + } + TestHfPullModelModuleForLora module; + auto& settings = module.getHfSettings(); + settings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl graphSettings; + ovms::LoraAdapterSettings adapter; + adapter.alias = "pokemon"; + adapter.sourceLora = "juliensimon/sd-pokemon-lora"; + adapter.safetensorsFile = ""; + adapter.sourceType = ovms::LoraSourceType::HF_REPO; + graphSettings.loraAdapters.push_back(adapter); + settings.graphSettings = graphSettings; + + auto 
status = module.testResolveHfLoraFilenames(); + ASSERT_TRUE(status.ok()) << status.string(); + + const auto& resolved = std::get(settings.graphSettings); + ASSERT_EQ(resolved.loraAdapters.size(), 1); + EXPECT_EQ(resolved.loraAdapters[0].safetensorsFile, "pytorch_lora_weights.safetensors"); +} + +TEST_F(HfPullModelModuleLoraTest, PullLoraAdaptersFromHfRepo) { + SKIP_AND_EXIT_IF_NOT_RUNNING_UNSTABLE(); + const char* hfToken = std::getenv("HF_TOKEN"); + if (!hfToken || std::string(hfToken).empty()) { + GTEST_SKIP() << "Skipping: HF_TOKEN not set (required for HF download)"; + } + TestHfPullModelModuleForLora module; + auto& settings = module.getHfSettings(); + settings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl graphSettings; + ovms::LoraAdapterSettings adapter; + adapter.alias = "pokemon"; + adapter.sourceLora = "juliensimon/sd-pokemon-lora"; + adapter.safetensorsFile = "pytorch_lora_weights.safetensors"; // explicit filename — skips HF API resolve + adapter.sourceType = ovms::LoraSourceType::HF_REPO; + graphSettings.loraAdapters.push_back(adapter); + settings.graphSettings = graphSettings; + + auto status = module.testPullLoraAdapters(this->directoryPath); + ASSERT_TRUE(status.ok()) << status.string(); + + auto loraFilePath = ovms::FileSystem::joinPath({this->directoryPath, "loras", "juliensimon/sd-pokemon-lora", "pytorch_lora_weights.safetensors"}); + ASSERT_TRUE(std::filesystem::exists(loraFilePath)) << loraFilePath; + EXPECT_GT(std::filesystem::file_size(loraFilePath), 0); +} + +TEST_F(HfPullModelModuleLoraTest, PullLoraAdaptersSkipsLocalFile) { + TestHfPullModelModuleForLora module; + auto& settings = module.getHfSettings(); + settings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl graphSettings; + ovms::LoraAdapterSettings adapter; + adapter.alias = "local_lora"; + adapter.sourceLora = "/some/local/path/model.safetensors"; + adapter.safetensorsFile = "model.safetensors"; + adapter.sourceType = 
ovms::LoraSourceType::LOCAL_FILE; + graphSettings.loraAdapters.push_back(adapter); + settings.graphSettings = graphSettings; + + auto status = module.testPullLoraAdapters(this->directoryPath); + ASSERT_TRUE(status.ok()) << status.string(); + // No files should have been downloaded to the temp directory + EXPECT_TRUE(std::filesystem::is_empty(this->directoryPath)); +} + +TEST_F(HfPullModelModuleLoraTest, PullLoraAdaptersNonImageGenGraphIsNoOp) { + TestHfPullModelModuleForLora module; + auto& settings = module.getHfSettings(); + settings.task = ovms::TEXT_GENERATION_GRAPH; + settings.graphSettings = ovms::TextGenGraphSettingsImpl{}; + + auto status = module.testPullLoraAdapters(this->directoryPath); + ASSERT_TRUE(status.ok()) << status.string(); +} + +// Full-flow test: download SD model + LoRA via --pull mode, verify files and graph.pbtxt. +// This exercises: CLI parsing -> source_loras -> HF resolution -> LoRA download -> graph.pbtxt generation. +// Runtime clone()+LoRA behavior is guaranteed by the GenAI API: clone() "reuses underlying models" +// which share the AdapterController. Adapters are selected per-request via generate() properties. 
+TEST_F(HfDownloaderPullHfModel, DownloadImageGenModelWithLoRA) { + SKIP_AND_EXIT_IF_NOT_RUNNING_UNSTABLE(); + const char* hfToken = std::getenv("HF_TOKEN"); + if (!hfToken || std::string(hfToken).empty()) { + GTEST_SKIP() << "Skipping: HF_TOKEN not set (required for HF LoRA download)"; + } + this->filesToPrintInCaseOfFailure.emplace_back("graph.pbtxt"); + std::string modelName = "OpenVINO/stable-diffusion-v1-5-int8-ov"; + std::string downloadPath = ovms::FileSystem::joinPath({this->directoryPath, "repository"}); + std::string task = "image_generation"; + std::string sourceLoras = "pokemon=juliensimon/sd-pokemon-lora@pytorch_lora_weights.safetensors"; + ::SetUpServerForDownloadWithLoras(this->t, this->server, modelName, downloadPath, task, sourceLoras); + + std::string basePath = ovms::FileSystem::joinPath({downloadPath, "OpenVINO", "stable-diffusion-v1-5-int8-ov"}); + std::string graphPath = ovms::FileSystem::appendSlash(basePath) + "graph.pbtxt"; + + // Verify model was downloaded + ASSERT_TRUE(std::filesystem::exists(basePath)) << basePath; + ASSERT_TRUE(std::filesystem::exists(graphPath)) << graphPath; + + // Verify LoRA adapter was downloaded + std::string loraDir = ovms::FileSystem::joinPath({basePath, "loras", "juliensimon", "sd-pokemon-lora"}); + auto loraFiles = searchFilesRecursively(loraDir, {"pytorch_lora_weights.safetensors"}); + ASSERT_FALSE(loraFiles.empty()) << "LoRA .safetensors not found in: " << loraDir; + + // Verify graph.pbtxt contains the LoRA adapter entry + std::string graphContents = GetFileContents(graphPath); + EXPECT_NE(graphContents.find("lora_adapters"), std::string::npos) << "graph.pbtxt should contain lora_adapters"; + EXPECT_NE(graphContents.find("pokemon"), std::string::npos) << "graph.pbtxt should reference pokemon alias"; +} diff --git a/src/test/test_utils.cpp b/src/test/test_utils.cpp index a80f924f9f..bafa8bafd8 100644 --- a/src/test/test_utils.cpp +++ b/src/test/test_utils.cpp @@ -793,6 +793,27 @@ void 
SetUpServerForDownloadWithDraft(std::unique_ptr& t, ovms::Serv EnsureServerModelDownloadFinishedWithTimeout(server, timeoutSeconds); } +void SetUpServerForDownloadWithLoras(std::unique_ptr& t, ovms::Server& server, std::string& source_model, std::string& download_path, std::string& task, std::string& source_loras, int expected_code, int timeoutSeconds) { + server.setShutdownRequest(0); + char* argv[] = {(char*)"ovms", + (char*)"--pull", + (char*)"--source_model", + (char*)source_model.c_str(), + (char*)"--model_repository_path", + (char*)download_path.c_str(), + (char*)"--task", + (char*)task.c_str(), + (char*)"--source_loras", + (char*)source_loras.c_str()}; + + int argc = 10; + t.reset(new std::thread([&argc, &argv, &server, expected_code]() { + EXPECT_EQ(expected_code, server.start(argc, argv)); + })); + + EnsureServerModelDownloadFinishedWithTimeout(server, timeoutSeconds); +} + void SetUpServerForDownloadAndStart(std::unique_ptr& t, ovms::Server& server, std::string& source_model, std::string& download_path, std::string& task, int timeoutSeconds) { server.setShutdownRequest(0); std::string port = "9133"; diff --git a/src/test/test_utils.hpp b/src/test/test_utils.hpp index 879ab1313e..f9e6cd970f 100644 --- a/src/test/test_utils.hpp +++ b/src/test/test_utils.hpp @@ -786,6 +786,11 @@ void SetUpServerForDownloadWithDraft(std::unique_ptr& t, ovms::Serv * --source_model Qwen/Qwen3-8B-GGUF --model_repository_path /models --gguf_filename Qwen3-8B-Q4_K_M.gguf */ void SetUpServerForDownloadAndStartGGUF(std::unique_ptr& t, ovms::Server& server, std::string& ggufFilename, std::string& sourceModel, std::string& downloadPath, std::string& task, int timeoutSeconds = 4 * SERVER_START_FROM_CONFIG_TIMEOUT_SECONDS); +/* + * starts loading OVMS on separate thread but waits until it is shutdowned or model is downloaded and check if model is downloaded in ovms + * --pull --source_model org/model --model_repository_path /models --task image_generation --source_loras alias=org/repo + 
*/ +void SetUpServerForDownloadWithLoras(std::unique_ptr& t, ovms::Server& server, std::string& source_model, std::string& download_path, std::string& task, std::string& source_loras, int expected_code = EXIT_SUCCESS, int timeoutSeconds = 4 * SERVER_START_FROM_CONFIG_TIMEOUT_SECONDS); /* * starts loading OVMS on separate thread but waits until it is shutdowned or model is downloaded and check if model is started in ovms * --source_model OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov --model_repository_path /models diff --git a/src/test/text2image_test.cpp b/src/test/text2image_test.cpp index 3d61d3e3cd..ff616fecf1 100644 --- a/src/test/text2image_test.cpp +++ b/src/test/text2image_test.cpp @@ -1496,5 +1496,115 @@ TEST(Text2ImageTest, ResponseFromOvTensorBatch3) { uint16_t n = 3; testResponseFromOvTensor(n); } +// ===================== LoRA Proto Parsing Tests ===================== + +TEST(ImageGenCalculatorOptionsTest, LoraAdaptersAbsolutePath) { +#ifdef _WIN32 + const std::string dummyLocation = dummy_model_location; +#else + const std::string dummyLocation = "/ovms/src/test/dummy"; +#endif + std::ostringstream oss; + oss << R"pb( + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")pb" + << dummyLocation; + oss << R"(")"; + oss << R"pb( + lora_adapters { alias: "pokemon" path: "/absolute/path/to/lora.safetensors" } # Shariar00/stable-diffusion-v1-5_Finetune_Custom_Fashion_dataset_v1.0 + lora_adapters { alias: "anime" path: "/another/path/weights.safetensors" alpha: 0.5 } + } + } +)pb"; + auto nodePbtxt = oss.str(); + auto node = mediapipe::ParseTextProtoOrDie(nodePbtxt); + const std::string graphPath = ""; + auto nodeOptions = node.node_options(0); + auto imageGenArgsOrStatus = 
prepareImageGenPipelineArgs(nodeOptions, graphPath); + ASSERT_TRUE(std::holds_alternative(imageGenArgsOrStatus)); + auto imageGenArgs = std::get(imageGenArgsOrStatus); + ASSERT_EQ(imageGenArgs.loraAdapters.size(), 2); + EXPECT_EQ(imageGenArgs.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(imageGenArgs.loraAdapters[0].path, "/absolute/path/to/lora.safetensors"); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[0].alpha, 1.0f); + EXPECT_EQ(imageGenArgs.loraAdapters[1].alias, "anime"); + EXPECT_EQ(imageGenArgs.loraAdapters[1].path, "/another/path/weights.safetensors"); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[1].alpha, 0.5f); +} + +TEST(ImageGenCalculatorOptionsTest, LoraAdaptersRelativePath) { +#ifdef _WIN32 + const std::string dummyLocation = dummy_model_location; +#else + const std::string dummyLocation = "/ovms/src/test/dummy"; +#endif + std::ostringstream oss; + oss << R"pb( + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")pb" + << dummyLocation; + oss << R"(")"; + oss << R"pb( + lora_adapters { alias: "pokemon" path: "loras/org/repo/model.safetensors" } + } + } +)pb"; + auto nodePbtxt = oss.str(); + auto node = mediapipe::ParseTextProtoOrDie(nodePbtxt); + const std::string graphPath = "/ovms/graph_dir"; + auto nodeOptions = node.node_options(0); + auto imageGenArgsOrStatus = prepareImageGenPipelineArgs(nodeOptions, graphPath); + ASSERT_TRUE(std::holds_alternative(imageGenArgsOrStatus)); + auto imageGenArgs = std::get(imageGenArgsOrStatus); + ASSERT_EQ(imageGenArgs.loraAdapters.size(), 1); + EXPECT_EQ(imageGenArgs.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(imageGenArgs.loraAdapters[0].path, "/ovms/graph_dir/loras/org/repo/model.safetensors"); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[0].alpha, 1.0f); +} 
+ +TEST(ImageGenCalculatorOptionsTest, NoLoraAdaptersProducesEmptyVector) { +#ifdef _WIN32 + const std::string dummyLocation = dummy_model_location; +#else + const std::string dummyLocation = "/ovms/src/test/dummy"; +#endif + std::ostringstream oss; + oss << R"pb( + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")pb" + << dummyLocation; + oss << R"(")"; + oss << R"pb( + } + } + )pb"; + auto nodePbtxt = oss.str(); + auto node = mediapipe::ParseTextProtoOrDie(nodePbtxt); + const std::string graphPath = ""; + auto nodeOptions = node.node_options(0); + auto imageGenArgsOrStatus = prepareImageGenPipelineArgs(nodeOptions, graphPath); + ASSERT_TRUE(std::holds_alternative(imageGenArgsOrStatus)); + auto imageGenArgs = std::get(imageGenArgsOrStatus); + ASSERT_TRUE(imageGenArgs.loraAdapters.empty()); +} + // TODO: // -> test for all unhandled OpenAI fields define what to do - ignore/error imageVariation