diff --git a/Examples/Examples/Chat/ChatExampleVertex.cs b/Examples/Examples/Chat/ChatExampleVertex.cs new file mode 100644 index 00000000..0618d52b --- /dev/null +++ b/Examples/Examples/Chat/ChatExampleVertex.cs @@ -0,0 +1,24 @@ +using Examples.Utils; +using MaIN.Core.Hub; +using MaIN.Domain.Configuration.BackendInferenceParams; +using MaIN.Domain.Models; + +namespace Examples.Chat; + +public class ChatExampleVertex : IExample +{ + public async Task Start() + { + VertexExample.Setup(); //We need to provide Google service account config + Console.WriteLine("(Vertex AI) ChatExample is running!"); + + await AIHub.Chat() + .WithModel(Models.Vertex.Gemini2_5Pro) + .WithMessage("Is the killer whale the smartest animal?") + .WithInferenceParams(new VertexInferenceParams + { + Location = "europe-central2" + }) + .CompleteAsync(interactive: true); + } +} diff --git a/Examples/Examples/Program.cs b/Examples/Examples/Program.cs index 7283210b..ba0c8523 100644 --- a/Examples/Examples/Program.cs +++ b/Examples/Examples/Program.cs @@ -73,6 +73,7 @@ static void RegisterExamples(IServiceCollection services) services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); services.AddTransient(); services.AddTransient(); services.AddTransient(); @@ -186,6 +187,7 @@ public class ExampleRegistry(IServiceProvider serviceProvider) ("\u25a0 Gemini Chat with grammar", serviceProvider.GetRequiredService()), ("\u25a0 Gemini Chat with image", serviceProvider.GetRequiredService()), ("\u25a0 Gemini Chat with files", serviceProvider.GetRequiredService()), + ("\u25a0 Vertex Chat", serviceProvider.GetRequiredService()), ("\u25a0 DeepSeek Chat with reasoning", serviceProvider.GetRequiredService()), ("\u25a0 GroqCloud Chat", serviceProvider.GetRequiredService()), ("\u25a0 Anthropic Chat", serviceProvider.GetRequiredService()), diff --git a/Examples/Examples/Utils/VertexExample.cs b/Examples/Examples/Utils/VertexExample.cs new file mode 100644 index 
00000000..1fa68fa8 --- /dev/null +++ b/Examples/Examples/Utils/VertexExample.cs @@ -0,0 +1,22 @@ +using MaIN.Core; +using MaIN.Domain.Configuration; +using MaIN.Domain.Configuration.Vertex; + +namespace Examples.Utils; + +public class VertexExample +{ + public static void Setup() + { + MaINBootstrapper.Initialize(configureSettings: options => + { + options.BackendType = BackendType.Vertex; + options.GoogleServiceAccountAuth = new GoogleServiceAccountConfig + { + ProjectId = "", + ClientEmail = "", + PrivateKey = @"" + }; + }); + } +} \ No newline at end of file diff --git a/Releases/0.10.4.md b/Releases/0.10.4.md new file mode 100644 index 00000000..54102b7a --- /dev/null +++ b/Releases/0.10.4.md @@ -0,0 +1,3 @@ +# 0.10.4 release + +Adds Google Vertex AI as a backend with authentication, MCP support, and new models including image generation, along with UI configuration and example usage. \ No newline at end of file diff --git a/src/MaIN.Core/.nuspec b/src/MaIN.Core/.nuspec index 48635786..7f341170 100644 --- a/src/MaIN.Core/.nuspec +++ b/src/MaIN.Core/.nuspec @@ -2,7 +2,7 @@ MaIN.NET - 0.10.3 + 0.10.4 Wisedev Wisedev favicon.png diff --git a/src/MaIN.Domain/Configuration/BackendInferenceParams/BackendParamsFactory.cs b/src/MaIN.Domain/Configuration/BackendInferenceParams/BackendParamsFactory.cs index fa8d729b..70ce5077 100644 --- a/src/MaIN.Domain/Configuration/BackendInferenceParams/BackendParamsFactory.cs +++ b/src/MaIN.Domain/Configuration/BackendInferenceParams/BackendParamsFactory.cs @@ -14,6 +14,7 @@ public static class BackendParamsFactory BackendType.Gemini => new GeminiInferenceParams(), BackendType.Anthropic => new AnthropicInferenceParams(), BackendType.Ollama => new OllamaInferenceParams(), + BackendType.Vertex => new VertexInferenceParams(), _ => new LocalInferenceParams() }; } diff --git a/src/MaIN.Domain/Configuration/BackendInferenceParams/VertexInferenceParams.cs b/src/MaIN.Domain/Configuration/BackendInferenceParams/VertexInferenceParams.cs new 
file mode 100644 index 00000000..0e5f2273 --- /dev/null +++ b/src/MaIN.Domain/Configuration/BackendInferenceParams/VertexInferenceParams.cs @@ -0,0 +1,18 @@ +using MaIN.Domain.Entities; +using Grammar = MaIN.Domain.Models.Grammar; + +namespace MaIN.Domain.Configuration.BackendInferenceParams; + +public class VertexInferenceParams : IBackendInferenceParams +{ + public BackendType Backend => BackendType.Vertex; + + public string Location { get; init; } = "us-central1"; + + public float? Temperature { get; init; } + public int? MaxTokens { get; init; } + public float? TopP { get; init; } + public string[]? StopSequences { get; init; } + public Grammar? Grammar { get; set; } + public Dictionary? AdditionalParams { get; init; } +} diff --git a/src/MaIN.Domain/Configuration/MaINSettings.cs b/src/MaIN.Domain/Configuration/MaINSettings.cs index d763f088..81018957 100644 --- a/src/MaIN.Domain/Configuration/MaINSettings.cs +++ b/src/MaIN.Domain/Configuration/MaINSettings.cs @@ -1,3 +1,4 @@ +using MaIN.Domain.Configuration.Vertex; namespace MaIN.Domain.Configuration; @@ -18,6 +19,7 @@ public class MaINSettings public SqliteSettings? SqliteSettings { get; set; } public SqlSettings? SqlSettings { get; set; } public string? VoicesPath { get; set; } + public GoogleServiceAccountConfig? 
GoogleServiceAccountAuth { get; set; } } public enum BackendType @@ -30,4 +32,5 @@ public enum BackendType Anthropic = 5, Xai = 6, Ollama = 7, + Vertex = 8, } \ No newline at end of file diff --git a/src/MaIN.Domain/Configuration/Vertex/GoogleServiceAccountConfig.cs b/src/MaIN.Domain/Configuration/Vertex/GoogleServiceAccountConfig.cs new file mode 100644 index 00000000..444829ef --- /dev/null +++ b/src/MaIN.Domain/Configuration/Vertex/GoogleServiceAccountConfig.cs @@ -0,0 +1,9 @@ +namespace MaIN.Domain.Configuration.Vertex; + +public class GoogleServiceAccountConfig +{ + public required string ProjectId { get; init; } + public required string ClientEmail { get; init; } + public required string PrivateKey { get; init; } + public string TokenUri { get; init; } = "https://oauth2.googleapis.com/token"; +} diff --git a/src/MaIN.Domain/Entities/Mcp.cs b/src/MaIN.Domain/Entities/Mcp.cs index 11816f33..32fea8de 100644 --- a/src/MaIN.Domain/Entities/Mcp.cs +++ b/src/MaIN.Domain/Entities/Mcp.cs @@ -8,6 +8,7 @@ public class Mcp public required List Arguments { get; init; } public required string Command { get; init; } public required string Model { get; init; } + public string Location { get; set; } = "us-central1"; public Dictionary Properties { get; set; } = []; public BackendType? Backend { get; set; } public Dictionary EnvironmentVariables { get; set; } = []; diff --git a/src/MaIN.Domain/Models/Concrete/CloudModels.cs b/src/MaIN.Domain/Models/Concrete/CloudModels.cs index 002e8d2f..1697e8b2 100644 --- a/src/MaIN.Domain/Models/Concrete/CloudModels.cs +++ b/src/MaIN.Domain/Models/Concrete/CloudModels.cs @@ -93,6 +93,59 @@ public sealed record Gemini2_0Flash() : CloudModel( public string? MMProjectName => null; } +public sealed record Gemini2_5Pro() : CloudModel( + Models.Gemini.Gemini2_5Pro, + BackendType.Gemini, + "Gemini 2.5 Pro", + 1000000, + "Google's most capable Gemini model"), IVisionModel +{ + public string? 
MMProjectName => null; +} + +public sealed record GeminiImagen4_0FastGenerate() : CloudModel( + Models.Gemini.Imagen4_0_FastGenerate, + BackendType.Gemini, + "Imagen 4.0 Fast (Gemini)", + 4000, + "Google's fast image generation model via Gemini API"), IImageGenerationModel; + +// ===== Vertex AI Models ===== + +public sealed record VertexGemini2_5Pro() : CloudModel( + Models.Vertex.Gemini2_5Pro, + BackendType.Vertex, + "Gemini 2.5 Pro (Vertex)", + 1000000, + "Fast and efficient Gemini model served via Vertex AI"), IVisionModel +{ + public string? MMProjectName => null; +} + +public sealed record VertexGemini2_5Flash() : CloudModel( + Models.Vertex.Gemini2_5Flash, + BackendType.Vertex, + "Gemini 2.5 Flash (Vertex)", + 1000000, + "Fast and efficient Gemini model served via Vertex AI"), IVisionModel +{ + public string? MMProjectName => null; +} + +public sealed record VertexVeo2_0Generate() : CloudModel( + Models.Vertex.Veo2_0_Generate, + BackendType.Vertex, + "Veo 2.0 Generate", + 4000, + "Google's video generation model available through Vertex AI"), IImageGenerationModel; + +public sealed record VertexImagen4_0Generate() : CloudModel( + Models.Vertex.Imagen4_0_Generate, + BackendType.Vertex, + "Imagen 4.0 (Vertex)", + 4000, + "Google's latest image generation model available through Vertex AI"), IImageGenerationModel; + // ===== xAI Models ===== public sealed record Grok3Beta() : CloudModel( diff --git a/src/MaIN.Domain/Models/Concrete/LLMApiRegistry.cs b/src/MaIN.Domain/Models/Concrete/LLMApiRegistry.cs index 726aa434..32a8b03a 100644 --- a/src/MaIN.Domain/Models/Concrete/LLMApiRegistry.cs +++ b/src/MaIN.Domain/Models/Concrete/LLMApiRegistry.cs @@ -11,6 +11,7 @@ public static class LLMApiRegistry public static readonly LLMApiRegistryEntry Anthropic = new("Anthropic", "ANTHROPIC_API_KEY"); public static readonly LLMApiRegistryEntry Xai = new("Xai", "XAI_API_KEY"); public static readonly LLMApiRegistryEntry Ollama = new("Ollama", "OLLAMA_API_KEY"); + public static 
readonly LLMApiRegistryEntry Vertex = new("Vertex", "GOOGLE_APPLICATION_CREDENTIALS"); public static LLMApiRegistryEntry? GetEntry(BackendType backendType) => backendType switch { @@ -21,6 +22,7 @@ public static class LLMApiRegistry BackendType.Anthropic => Anthropic, BackendType.Xai => Xai, BackendType.Ollama => Ollama, + BackendType.Vertex => Vertex, _ => null }; } diff --git a/src/MaIN.Domain/Models/Concrete/LocalModels.cs b/src/MaIN.Domain/Models/Concrete/LocalModels.cs index 7132e3b5..42e89206 100644 --- a/src/MaIN.Domain/Models/Concrete/LocalModels.cs +++ b/src/MaIN.Domain/Models/Concrete/LocalModels.cs @@ -293,6 +293,16 @@ public sealed record Olmo2_7b() : LocalModel( 8192, "Open-source 7B model for research, benchmarking, and academic studies"); +// ===== Image Generation ===== + +public sealed record Flux1Shnell() : LocalModel( + Models.Local.Flux1Shnell, + "FLUX.1_Shnell", + null, + "FLUX.1 Schnell", + 4096, + "Fast local image generation model"), IImageGenerationModel; + // ===== Embedding Model ===== public sealed record Mxbai_Embedding() : LocalModel( diff --git a/src/MaIN.Domain/Models/Models.cs b/src/MaIN.Domain/Models/Models.cs index bd79619a..c5d68cc7 100644 --- a/src/MaIN.Domain/Models/Models.cs +++ b/src/MaIN.Domain/Models/Models.cs @@ -23,8 +23,10 @@ public static class Anthropic public static class Gemini { + public const string Gemini2_5Pro = "gemini-2.5-pro"; public const string Gemini2_5Flash = "gemini-2.5-flash"; public const string Gemini2_0Flash = "gemini-2.0-flash"; + public const string Imagen4_0_FastGenerate = "imagen-4.0-fast-generate-001"; } public static class Xai @@ -49,6 +51,14 @@ public static class Ollama public const string Gemma3_4b = "gemma3:4b"; } + public static class Vertex + { + public const string Gemini2_5Pro = "google/gemini-2.5-pro"; + public const string Gemini2_5Flash = "google/gemini-2.5-flash"; + public const string Veo2_0_Generate = "google/veo-2.0-generate-001"; + public const string Imagen4_0_Generate = 
"google/imagen-4.0-generate-001"; + } + public static class Local { // Gemma diff --git a/src/MaIN.InferPage/Components/Pages/Home.razor b/src/MaIN.InferPage/Components/Pages/Home.razor index 36ec3e99..729acc2c 100644 --- a/src/MaIN.InferPage/Components/Pages/Home.razor +++ b/src/MaIN.InferPage/Components/Pages/Home.razor @@ -3,7 +3,7 @@ @inject IJSRuntime JS @inject SettingsService SettingsStorage @inject SettingsStateService SettingsState -@inject MaIN.Domain.Configuration.MaINSettings MaINSettings +@inject MaINSettings MaINSettings @implements IDisposable @using MaIN.Core.Hub @using MaIN.Core.Hub.Contexts.Interfaces.ChatContext @@ -427,12 +427,32 @@ if (backendType != BackendType.Self) { var backendKey = backendType == BackendType.Ollama - ? (settings.IsOllamaCloud ? "OllamaCloud" : "OllamaLocal") - : backendType.ToString(); + ? (settings.IsOllamaCloud + ? "OllamaCloud" + : "OllamaLocal") + : backendType == BackendType.Vertex + ? "Vertex" + : backendType.ToString(); apiKey = await SettingsStorage.GetApiKeyForBackendAsync(backendKey); } + // Load Vertex auth from localStorage if applicable + Domain.Configuration.Vertex.GoogleServiceAccountConfig? 
vertexAuth = null; + if (backendType == BackendType.Vertex) + { + var stored = await SettingsStorage.GetVertexAuthAsync(); + if (stored != null) + { + vertexAuth = new Domain.Configuration.Vertex.GoogleServiceAccountConfig + { + ProjectId = stored.ProjectId, + ClientEmail = stored.ClientEmail, + PrivateKey = stored.PrivateKey + }; + } + } + Utils.ApplySettings( backendType, settings.Model!, @@ -442,7 +462,8 @@ settings.HasImageGen, settings.MmProjName, MaINSettings, - apiKey); + apiKey, + vertexAuth); } private void ShowSettingsFromGear() diff --git a/src/MaIN.InferPage/Components/Pages/Settings.razor b/src/MaIN.InferPage/Components/Pages/Settings.razor index c6d5456f..d8086e80 100644 --- a/src/MaIN.InferPage/Components/Pages/Settings.razor +++ b/src/MaIN.InferPage/Components/Pages/Settings.razor @@ -1,4 +1,5 @@ @using MaIN.Domain.Configuration +@using MaIN.Domain.Configuration.Vertex @using MaIN.Domain.Models.Abstract @inject SettingsService SettingsStorage @inject MaINSettings MaINSettings @@ -64,6 +65,49 @@ } + @if (_selectedBackend?.BackendType == BackendType.Vertex) + { +
+ + +
+ +
+ + +
+ +
+ +
+ + + +
+ private_key field from the service account JSON file +
+ +
+ + + Optional — defaults to us-central1 +
+ } + @if (_selectedBackend?.BackendType == BackendType.Self) {
@@ -155,6 +199,13 @@ private string? _savedKeyPreview; private bool _showApiKey; + // Vertex AI auth fields + private string? _vertexProjectId; + private string? _vertexClientEmail; + private string? _vertexPrivateKey; + private string? _vertexLocation; + private bool _showVertexKey; + // "Will load:" path preview shown below the model path field (Self backend only) private string? ResolvedModelPathPreview { @@ -236,9 +287,14 @@ private string? _mmProjName; private bool RequiresApiKey => _selectedBackend?.RequiresApiKey == true; + private bool IsVertexBackend => _selectedBackend?.BackendType == BackendType.Vertex; + private bool HasVertexRequiredFields => !string.IsNullOrWhiteSpace(_vertexProjectId) + && !string.IsNullOrWhiteSpace(_vertexClientEmail) + && !string.IsNullOrWhiteSpace(_vertexPrivateKey); private bool CanSave => !string.IsNullOrWhiteSpace(_modelName) && _selectedBackend != null - && (!RequiresApiKey || !string.IsNullOrEmpty(_apiKeyInput) || !string.IsNullOrEmpty(_savedKeyPreview)); + && (!RequiresApiKey || !string.IsNullOrEmpty(_apiKeyInput) || !string.IsNullOrEmpty(_savedKeyPreview)) + && (!IsVertexBackend || HasVertexRequiredFields); protected override async Task OnInitializedAsync() { @@ -251,7 +307,8 @@ new(5, "Anthropic", BackendType.Anthropic, true), new(6, "xAI", BackendType.Xai, true), new(7, "Ollama (Local)", BackendType.Ollama, false), - new(8, "Ollama (Cloud)", BackendType.Ollama, true) + new(8, "Ollama (Cloud)", BackendType.Ollama, true), + new(9, "Vertex AI", BackendType.Vertex, false) }.OrderBy(x => x.DisplayName) .Prepend(new BackendOption(0, "Local", BackendType.Self, false)) .ToList(); @@ -281,6 +338,10 @@ ? 
_backendOptions.First(o => o.Id == 8) : _backendOptions.First(o => o.Id == 7); } + else if (backendType == BackendType.Vertex) + { + _selectedBackend = _backendOptions.First(o => o.Id == 9); + } else { _selectedBackend = _backendOptions.FirstOrDefault(o => o.BackendType == backendType && o.Id < 7); @@ -293,6 +354,18 @@ _manualImageGen = settings.HasImageGen; _mmProjName = settings.MmProjName; + if (backendType == BackendType.Vertex) + { + var vertexAuth = await SettingsStorage.GetVertexAuthAsync(); + if (vertexAuth != null) + { + _vertexProjectId = vertexAuth.ProjectId; + _vertexClientEmail = vertexAuth.ClientEmail; + _vertexPrivateKey = vertexAuth.PrivateKey; + } + _vertexLocation = settings.VertexLocation; + } + OnModelNameChanged(); } else if (!Utils.NeedsConfiguration) @@ -302,7 +375,9 @@ ? (Utils.HasApiKey ? _backendOptions.First(o => o.Id == 8) : _backendOptions.First(o => o.Id == 7)) - : _backendOptions.FirstOrDefault(o => o.BackendType == Utils.BackendType && o.Id < 7); + : Utils.BackendType == BackendType.Vertex + ? _backendOptions.First(o => o.Id == 9) + : _backendOptions.FirstOrDefault(o => o.BackendType == Utils.BackendType && o.Id < 7); _modelName = Utils.Model; _modelPath = Utils.Path; @@ -343,6 +418,26 @@ } } + if (IsVertexBackend) + { + var vertexAuth = await SettingsStorage.GetVertexAuthAsync(); + if (vertexAuth != null) + { + _vertexProjectId = vertexAuth.ProjectId; + _vertexClientEmail = vertexAuth.ClientEmail; + _vertexPrivateKey = vertexAuth.PrivateKey; + } + else + { + _vertexProjectId = null; + _vertexClientEmail = null; + _vertexPrivateKey = null; + } + + var settings = await SettingsStorage.LoadSettingsAsync(); + _vertexLocation = settings?.VertexLocation; + } + await LoadApiKeyPreview(); _apiKeyInput = null; } @@ -425,7 +520,8 @@ HasReasoning = hasReasoning, HasImageGen = hasImageGen, ModelPath = _modelPath, - MmProjName = _mmProjName + MmProjName = _mmProjName, + VertexLocation = IsVertexBackend ? 
_vertexLocation : null }; await SettingsStorage.SaveSettingsAsync(settings); @@ -447,6 +543,19 @@ } } + // Vertex AI: persist auth and build config + GoogleServiceAccountConfig? vertexAuth = null; + if (IsVertexBackend && HasVertexRequiredFields) + { + await SettingsStorage.SaveVertexAuthAsync(_vertexProjectId!, _vertexClientEmail!, _vertexPrivateKey!); + vertexAuth = new GoogleServiceAccountConfig + { + ProjectId = _vertexProjectId!, + ClientEmail = _vertexClientEmail!, + PrivateKey = _vertexPrivateKey! + }; + } + Utils.ApplySettings( _selectedBackend.BackendType, _modelName, @@ -456,7 +565,8 @@ hasImageGen, _mmProjName, MaINSettings, - apiKey); + apiKey, + vertexAuth); await OnSettingsApplied.InvokeAsync(); } @@ -466,6 +576,7 @@ if (_selectedBackend == null) return "Self"; if (_selectedBackend.Id == 7) return "OllamaLocal"; if (_selectedBackend.Id == 8) return "OllamaCloud"; + if (_selectedBackend.Id == 9) return "Vertex"; return _selectedBackend.BackendType.ToString(); } diff --git a/src/MaIN.InferPage/Program.cs b/src/MaIN.InferPage/Program.cs index 86728b75..ef07369d 100644 --- a/src/MaIN.InferPage/Program.cs +++ b/src/MaIN.InferPage/Program.cs @@ -39,10 +39,15 @@ "anthropic" => BackendType.Anthropic, "xai" => BackendType.Xai, "ollama" => BackendType.Ollama, + "vertex" => BackendType.Vertex, _ => BackendType.Self }; - if (Utils.BackendType != BackendType.Self) + if (Utils.BackendType == BackendType.Vertex) + { + Console.WriteLine("Vertex AI requires service account credentials. Configure them via the Settings page."); + } + else if (Utils.BackendType != BackendType.Self) { var apiKeyVariable = LLMApiRegistry.GetEntry(Utils.BackendType)?.ApiKeyEnvName ?? 
string.Empty; var key = Environment.GetEnvironmentVariable(apiKeyVariable); diff --git a/src/MaIN.InferPage/Services/InferPageSettings.cs b/src/MaIN.InferPage/Services/InferPageSettings.cs index 48bddd4c..55378858 100644 --- a/src/MaIN.InferPage/Services/InferPageSettings.cs +++ b/src/MaIN.InferPage/Services/InferPageSettings.cs @@ -10,4 +10,5 @@ public class InferPageSettings public bool HasImageGen { get; set; } public string? ModelPath { get; set; } public string? MmProjName { get; set; } + public string? VertexLocation { get; set; } } \ No newline at end of file diff --git a/src/MaIN.InferPage/Services/SettingsService.cs b/src/MaIN.InferPage/Services/SettingsService.cs index db52934c..ff697851 100644 --- a/src/MaIN.InferPage/Services/SettingsService.cs +++ b/src/MaIN.InferPage/Services/SettingsService.cs @@ -41,6 +41,18 @@ public async Task SaveProfileForBackendAsync(string backend, string model, return profiles?.GetValueOrDefault(backend); } + // Vertex AI auth (stored separately — PrivateKey should not be in general settings) + private const string VertexAuthKey = "inferpage-vertex-auth"; + + public async Task SaveVertexAuthAsync(string projectId, string clientEmail, string privateKey) + { + var auth = new VertexAuthStorage(projectId, clientEmail, privateKey); + await js.InvokeVoidAsync("settingsManager.save", VertexAuthKey, auth); + } + + public async Task GetVertexAuthAsync() + => await js.InvokeAsync("settingsManager.load", VertexAuthKey); + private async Task SetInDictAsync(string storageKey, string key, string value) { var dict = await LoadDictAsync(storageKey); @@ -56,3 +68,5 @@ private async Task> LoadDictAsync(string storageKey) } public record BackendProfile(string Model, bool Vision, bool Reasoning, bool ImageGen, string? 
MmProjName = null); + +public record VertexAuthStorage(string ProjectId, string ClientEmail, string PrivateKey); diff --git a/src/MaIN.InferPage/Utils.cs b/src/MaIN.InferPage/Utils.cs index 49c297b5..582fe267 100644 --- a/src/MaIN.InferPage/Utils.cs +++ b/src/MaIN.InferPage/Utils.cs @@ -1,4 +1,5 @@ using MaIN.Domain.Configuration; +using MaIN.Domain.Configuration.Vertex; using MaIN.Domain.Entities; using MaIN.Domain.Models.Abstract; using MaIN.Domain.Models.Concrete; @@ -45,7 +46,8 @@ public static void ApplySettings( bool hasImageGen, string? mmProjName, MaINSettings mainSettings, - string? apiKey) + string? apiKey, + GoogleServiceAccountConfig? vertexAuth = null) { BackendType = backendType; Model = model; @@ -89,6 +91,10 @@ public static void ApplySettings( case BackendType.GroqCloud: mainSettings.GroqCloudKey = apiKey; break; case BackendType.Ollama: mainSettings.OllamaKey = apiKey; break; case BackendType.Xai: mainSettings.XaiKey = apiKey; break; + case BackendType.Vertex: + if (vertexAuth != null) + mainSettings.GoogleServiceAccountAuth = vertexAuth; + break; } } diff --git a/src/MaIN.Services/Bootstrapper.cs b/src/MaIN.Services/Bootstrapper.cs index 6e218c2c..09ad2022 100644 --- a/src/MaIN.Services/Bootstrapper.cs +++ b/src/MaIN.Services/Bootstrapper.cs @@ -111,6 +111,7 @@ private static IServiceCollection AddHttpClients(this IServiceCollection service services.AddHttpClient(ServiceConstants.HttpClients.GroqCloudClient); services.AddHttpClient(ServiceConstants.HttpClients.OllamaClient); services.AddHttpClient(ServiceConstants.HttpClients.OllamaLocalClient); + services.AddHttpClient(ServiceConstants.HttpClients.VertexClient); services.AddHttpClient(ServiceConstants.HttpClients.ImageDownloadClient); services.AddHttpClient(ServiceConstants.HttpClients.ModelContextDownloadClient, client => { diff --git a/src/MaIN.Services/Constants/ServiceConstants.cs b/src/MaIN.Services/Constants/ServiceConstants.cs index ead2d5ba..bdd0d0a8 100644 --- 
a/src/MaIN.Services/Constants/ServiceConstants.cs +++ b/src/MaIN.Services/Constants/ServiceConstants.cs @@ -13,6 +13,7 @@ public static class HttpClients public const string XaiClient = "XaiClient"; public const string OllamaClient = "OllamaClient"; public const string OllamaLocalClient = "OllamaLocalClient"; + public const string VertexClient = "VertexClient"; public const string ImageDownloadClient = "ImageDownloadClient"; public const string ModelContextDownloadClient = "ModelContextDownloadClient"; } diff --git a/src/MaIN.Services/Mappers/ChatMapper.cs b/src/MaIN.Services/Mappers/ChatMapper.cs index 3955beca..c051eed4 100644 --- a/src/MaIN.Services/Mappers/ChatMapper.cs +++ b/src/MaIN.Services/Mappers/ChatMapper.cs @@ -1,7 +1,7 @@ using MaIN.Domain.Entities; using MaIN.Domain.Models; +using MaIN.Domain.Models.Abstract; using MaIN.Services.Dtos; -using MaIN.Services.Services.ImageGenServices; using FileInfo = MaIN.Domain.Entities.FileInfo; namespace MaIN.Services.Mappers; @@ -44,7 +44,7 @@ public static Chat ToDomain(this ChatDto chat) Name = chat.Name!, ModelId = chat.Model!, Messages = chat.Messages?.Select(m => m.ToDomain()).ToList()!, - ImageGen = chat.Model == ImageGenService.LocalImageModels.FLUX, + ImageGen = ModelRegistry.TryGetById(chat.Model!, out var m) && m!.HasImageGeneration, Type = Enum.Parse(chat.Type.ToString()), Properties = chat.Properties }; diff --git a/src/MaIN.Services/Services/AgentService.cs b/src/MaIN.Services/Services/AgentService.cs index 02895507..7536bdf7 100644 --- a/src/MaIN.Services/Services/AgentService.cs +++ b/src/MaIN.Services/Services/AgentService.cs @@ -9,7 +9,6 @@ using MaIN.Domain.Repositories; using MaIN.Services.Constants; using MaIN.Services.Services.Abstract; -using MaIN.Services.Services.ImageGenServices; using MaIN.Services.Services.LLMService.Factory; using MaIN.Services.Services.Models.Commands; using MaIN.Services.Services.Steps.Commands.Abstract; @@ -101,7 +100,7 @@ public async Task CreateAgent(Agent agent, 
bool flow = false, bool intera Id = Guid.NewGuid().ToString(), ModelId = agent.Model, Name = agent.Name, - ImageGen = agent.Model == ImageGenService.LocalImageModels.FLUX, + ImageGen = ModelRegistry.TryGetById(agent.Model, out var agentModel) && agentModel!.HasImageGeneration, ToolsConfiguration = agent.ToolsConfiguration, BackendParams = inferenceParams, MemoryParams = memoryParams ?? new MemoryParams(), diff --git a/src/MaIN.Services/Services/ChatService.cs b/src/MaIN.Services/Services/ChatService.cs index 21805d21..7c017172 100644 --- a/src/MaIN.Services/Services/ChatService.cs +++ b/src/MaIN.Services/Services/ChatService.cs @@ -6,7 +6,6 @@ using MaIN.Domain.Models.Abstract; using MaIN.Domain.Repositories; using MaIN.Services.Services.Abstract; -using MaIN.Services.Services.ImageGenServices; using MaIN.Services.Services.LLMService; using MaIN.Services.Services.LLMService.Factory; using MaIN.Services.Services.Models; @@ -34,14 +33,14 @@ public async Task Completions( Func? changeOfValue = null, CancellationToken cancellationToken = default) { - if (chat.ModelId == ImageGenService.LocalImageModels.FLUX) + if (!ModelRegistry.TryGetById(chat.ModelId, out var model)) { - chat.ImageGen = true; + throw new ChatModelNotAvailableException(chat.Id, chat.ModelId); } - if (!ModelRegistry.TryGetById(chat.ModelId, out var model)) + if (model!.HasImageGeneration) { - throw new ChatModelNotAvailableException(chat.Id, chat.ModelId); + chat.ImageGen = true; } var backend = model!.Backend; diff --git a/src/MaIN.Services/Services/ImageGenServices/GeminiImageGenService.cs b/src/MaIN.Services/Services/ImageGenServices/GeminiImageGenService.cs index 1cd58b92..4a4f8cc3 100644 --- a/src/MaIN.Services/Services/ImageGenServices/GeminiImageGenService.cs +++ b/src/MaIN.Services/Services/ImageGenServices/GeminiImageGenService.cs @@ -3,6 +3,7 @@ using MaIN.Services.Constants; using MaIN.Services.Services.Abstract; using MaIN.Services.Services.Models; +using ModelIds = 
MaIN.Domain.Models.Models; using System.Net.Http.Headers; using System.Net.Http.Json; using System.Text.Json.Serialization; @@ -22,14 +23,11 @@ internal class GeminiImageGenService(IHttpClientFactory httpClientFactory, MaINS string apiKey = _settings.GeminiKey ?? Environment.GetEnvironmentVariable(LLMApiRegistry.Gemini.ApiKeyEnvName) ?? throw new APIKeyNotConfiguredException(LLMApiRegistry.Gemini.ApiName); - if (string.IsNullOrEmpty(chat.ModelId)) - { - chat.ModelId = Models.IMAGEN_GENERATE; - } + var model = string.IsNullOrEmpty(chat.ModelId) ? ModelIds.Gemini.Imagen4_0_FastGenerate : chat.ModelId; client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey); var requestBody = new { - model = chat.ModelId, + model, prompt = BuildPromptFromChat(chat), response_format = "b64_json", // necessary for gemini api size = ServiceConstants.Defaults.ImageSize, @@ -37,7 +35,7 @@ internal class GeminiImageGenService(IHttpClientFactory httpClientFactory, MaINS using var response = await client.PostAsJsonAsync(ServiceConstants.ApiUrls.GeminiImageGenerations, requestBody); var imageBytes = await ProcessGeminiResponse(response); - return CreateChatResult(imageBytes); + return CreateChatResult(imageBytes, model); } private static string BuildPromptFromChat(Chat chat) @@ -61,7 +59,7 @@ private async Task ProcessGeminiResponse(HttpResponseMessage response) return Convert.FromBase64String(base64Image); } - private static ChatResult CreateChatResult(byte[] imageBytes) + private static ChatResult CreateChatResult(byte[] imageBytes, string model) { return new ChatResult { @@ -73,15 +71,10 @@ private static ChatResult CreateChatResult(byte[] imageBytes) Image = imageBytes, Type = MessageType.Image }, - Model = Models.IMAGEN_GENERATE, + Model = model, CreatedAt = DateTime.UtcNow }; } - - private struct Models - { - public const string IMAGEN_GENERATE = "imagen-4.0-fast-generate-001"; - } } file class GeminiImageResponse diff --git 
a/src/MaIN.Services/Services/ImageGenServices/ImageGenService.cs b/src/MaIN.Services/Services/ImageGenServices/ImageGenService.cs index 56e67055..6143de9d 100644 --- a/src/MaIN.Services/Services/ImageGenServices/ImageGenService.cs +++ b/src/MaIN.Services/Services/ImageGenServices/ImageGenService.cs @@ -3,6 +3,7 @@ using MaIN.Services.Constants; using MaIN.Services.Services.Abstract; using MaIN.Services.Services.Models; +using ModelIds = MaIN.Domain.Models.Models; namespace MaIN.Services.Services.ImageGenServices; @@ -48,13 +49,8 @@ private static ChatResult CreateChatResult(byte[] imageBytes) Image = imageBytes, Type = MessageType.Image }, - Model = LocalImageModels.FLUX, + Model = ModelIds.Local.Flux1Shnell, CreatedAt = DateTime.UtcNow }; } - - internal struct LocalImageModels - { - public const string FLUX = "FLUX.1_Shnell"; - } } \ No newline at end of file diff --git a/src/MaIN.Services/Services/ImageGenServices/OpenAiImageGenService.cs b/src/MaIN.Services/Services/ImageGenServices/OpenAiImageGenService.cs index da42e658..ef87d4ee 100644 --- a/src/MaIN.Services/Services/ImageGenServices/OpenAiImageGenService.cs +++ b/src/MaIN.Services/Services/ImageGenServices/OpenAiImageGenService.cs @@ -5,6 +5,7 @@ using MaIN.Services.Constants; using MaIN.Services.Services.Abstract; using MaIN.Services.Services.Models; +using ModelIds = MaIN.Domain.Models.Models; using System.Net.Http.Headers; using System.Net.Http.Json; using System.Text.Json.Serialization; @@ -22,13 +23,14 @@ public class OpenAiImageGenService( public async Task Send(Chat chat) { var client = _httpClientFactory.CreateClient(ServiceConstants.HttpClients.OpenAiClient); - string apiKey = _settings.OpenAiKey ?? Environment.GetEnvironmentVariable(LLMApiRegistry.OpenAi.ApiKeyEnvName) + string apiKey = _settings.OpenAiKey ?? Environment.GetEnvironmentVariable(LLMApiRegistry.OpenAi.ApiKeyEnvName) ?? 
throw new APIKeyNotConfiguredException(LLMApiRegistry.OpenAi.ApiName); - + + var model = string.IsNullOrEmpty(chat.ModelId) ? ModelIds.OpenAi.DallE3 : chat.ModelId; client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey); var requestBody = new { - model = chat.ModelId, + model, prompt = BuildPromptFromChat(chat), size = ServiceConstants.Defaults.ImageSize }; @@ -36,7 +38,7 @@ public class OpenAiImageGenService( using var response = await client.PostAsJsonAsync(ServiceConstants.ApiUrls.OpenAiImageGenerations, requestBody); byte[] imageBytes = await ProcessOpenAiResponse(response); - return CreateChatResult(imageBytes); + return CreateChatResult(imageBytes, model); } private static string BuildPromptFromChat(Chat chat) @@ -73,7 +75,7 @@ private async Task ProcessOpenAiResponse(HttpResponseMessage response) throw new InvalidOperationException("No image URL or base64 data returned from OpenAI"); } - private static ChatResult CreateChatResult(byte[] imageBytes) + private static ChatResult CreateChatResult(byte[] imageBytes, string model) { return new ChatResult { @@ -85,15 +87,10 @@ private static ChatResult CreateChatResult(byte[] imageBytes) Image = imageBytes, Type = MessageType.Image }, - Model = Models.DALLE, + Model = model, CreatedAt = DateTime.UtcNow }; } - - private struct Models - { - public const string DALLE = "dall-e-3"; - } } file class OpenAiImageResponse diff --git a/src/MaIN.Services/Services/ImageGenServices/VertexImageGenService.cs b/src/MaIN.Services/Services/ImageGenServices/VertexImageGenService.cs new file mode 100644 index 00000000..6495bb8c --- /dev/null +++ b/src/MaIN.Services/Services/ImageGenServices/VertexImageGenService.cs @@ -0,0 +1,117 @@ +using MaIN.Domain.Configuration; +using MaIN.Domain.Configuration.BackendInferenceParams; +using MaIN.Domain.Entities; +using MaIN.Services.Constants; +using MaIN.Services.Services.Abstract; +using MaIN.Services.Services.LLMService.Auth; +using 
MaIN.Services.Services.Models; +using ModelIds = MaIN.Domain.Models.Models; +using System.Net.Http.Headers; +using System.Net.Http.Json; +using System.Text.Json.Serialization; + +namespace MaIN.Services.Services.ImageGenServices; + +internal class VertexImageGenService(IHttpClientFactory httpClientFactory, MaINSettings settings) : IImageGenService +{ + private const string DefaultLocation = "us-central1"; + + public async Task Send(Chat chat) + { + var auth = settings.GoogleServiceAccountAuth + ?? throw new InvalidOperationException("Vertex AI service account is not configured."); + + var location = chat.BackendParams is VertexInferenceParams vp + ? vp.Location + : DefaultLocation; + + using var tokenProvider = new GoogleServiceAccountTokenProvider(auth); + var httpClient = httpClientFactory.CreateClient(ServiceConstants.HttpClients.VertexClient); + var accessToken = await tokenProvider.GetAccessTokenAsync(httpClient); + + var model = ExtractModelName(chat.ModelId); + var endpoint = $"https://{location}-aiplatform.googleapis.com/v1/projects/{auth.ProjectId}/locations/{location}/publishers/google/models/{model}:predict"; + + var requestBody = new + { + instances = new[] + { + new { prompt = BuildPromptFromChat(chat) } + }, + parameters = new + { + sampleCount = 1, + aspectRatio = "1:1" + } + }; + + using var request = new HttpRequestMessage(HttpMethod.Post, endpoint); + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken); + request.Content = JsonContent.Create(requestBody); + + using var response = await httpClient.SendAsync(request); + + if (!response.IsSuccessStatusCode) + { + var error = await response.Content.ReadAsStringAsync(); + throw new InvalidOperationException( + $"Vertex AI Imagen request failed ({response.StatusCode}): {error}"); + } + + var result = await response.Content.ReadFromJsonAsync(); + var base64Image = result?.Predictions?.FirstOrDefault()?.BytesBase64Encoded; + + if (string.IsNullOrEmpty(base64Image)) + throw 
new InvalidOperationException("No image returned from Vertex AI Imagen."); + + var imageBytes = Convert.FromBase64String(base64Image); + + return new ChatResult + { + Done = true, + Message = new Message + { + Content = ServiceConstants.Messages.GeneratedImageContent, + Role = ServiceConstants.Roles.Assistant, + Image = imageBytes, + Type = MessageType.Image + }, + Model = string.IsNullOrEmpty(chat.ModelId) ? ModelIds.Vertex.Imagen4_0_Generate : chat.ModelId, + CreatedAt = DateTime.UtcNow + }; + } + + private static string BuildPromptFromChat(Chat chat) + { + return chat.Messages + .Select((msg, index) => index == 0 ? msg.Content : $"&& {msg.Content}") + .Aggregate((current, next) => $"{current} {next}"); + } + + /// + /// Strips the "google/" publisher prefix if present (Vertex predict endpoint doesn't use it). + /// + private static string ExtractModelName(string? modelId) + { + var resolved = string.IsNullOrEmpty(modelId) ? ModelIds.Vertex.Imagen4_0_Generate : modelId; + + return resolved.StartsWith("google/", StringComparison.OrdinalIgnoreCase) + ? resolved["google/".Length..] + : resolved; + } +} + +file class ImagenResponse +{ + [JsonPropertyName("predictions")] + public ImagenPrediction[]? Predictions { get; set; } +} + +file class ImagenPrediction +{ + [JsonPropertyName("bytesBase64Encoded")] + public string? BytesBase64Encoded { get; set; } + + [JsonPropertyName("mimeType")] + public string? 
MimeType { get; set; } +} diff --git a/src/MaIN.Services/Services/ImageGenServices/XaiImageGenService.cs b/src/MaIN.Services/Services/ImageGenServices/XaiImageGenService.cs index 92e0a199..a513a9b9 100644 --- a/src/MaIN.Services/Services/ImageGenServices/XaiImageGenService.cs +++ b/src/MaIN.Services/Services/ImageGenServices/XaiImageGenService.cs @@ -3,6 +3,7 @@ using MaIN.Services.Constants; using MaIN.Services.Services.Abstract; using MaIN.Services.Services.Models; +using ModelIds = MaIN.Domain.Models.Models; using System.Net.Http.Headers; using System.Net.Http.Json; using System.Text.Json; @@ -25,10 +26,11 @@ public class XaiImageGenService( string apiKey = _settings.XaiKey ?? Environment.GetEnvironmentVariable(LLMApiRegistry.Xai.ApiKeyEnvName) ?? throw new APIKeyNotConfiguredException(LLMApiRegistry.Xai.ApiName); + var model = string.IsNullOrWhiteSpace(chat.ModelId) ? ModelIds.Xai.GrokImage : chat.ModelId; client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey); var requestBody = new { - model = string.IsNullOrWhiteSpace(chat.ModelId) ? 
Models.GROK_IMAGE : chat.ModelId, + model, prompt = BuildPromptFromChat(chat), n = 1, response_format = "b64_json" //or "url" @@ -36,7 +38,7 @@ public class XaiImageGenService( using var response = await client.PostAsJsonAsync(ServiceConstants.ApiUrls.XaiImageGenerations, requestBody); var imageBytes = await ProcessXaiResponse(response); - return CreateChatResult(imageBytes); + return CreateChatResult(imageBytes, model); } private static string BuildPromptFromChat(Chat chat) @@ -77,7 +79,7 @@ private async Task DownloadImageAsync(string imageUrl) return await imageResponse.Content.ReadAsByteArrayAsync(); } - private static ChatResult CreateChatResult(byte[] imageBytes) + private static ChatResult CreateChatResult(byte[] imageBytes, string model) { return new ChatResult { @@ -89,15 +91,10 @@ private static ChatResult CreateChatResult(byte[] imageBytes) Image = imageBytes, Type = MessageType.Image }, - Model = Models.GROK_IMAGE, + Model = model, CreatedAt = DateTime.UtcNow }; } - - private struct Models - { - public const string GROK_IMAGE = "grok-2-image"; - } } diff --git a/src/MaIN.Services/Services/LLMService/Auth/GoogleServiceAccountTokenProvider.cs b/src/MaIN.Services/Services/LLMService/Auth/GoogleServiceAccountTokenProvider.cs new file mode 100644 index 00000000..664143a4 --- /dev/null +++ b/src/MaIN.Services/Services/LLMService/Auth/GoogleServiceAccountTokenProvider.cs @@ -0,0 +1,121 @@ +using System.Collections.Concurrent; +using System.Net.Http.Json; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json.Serialization; +using MaIN.Domain.Configuration.Vertex; + +namespace MaIN.Services.Services.LLMService.Auth; + +internal sealed class GoogleServiceAccountTokenProvider : IDisposable +{ + private const string Scope = "https://www.googleapis.com/auth/cloud-platform"; + private const int TokenLifetimeSeconds = 3600; + private const int RefreshBufferMinutes = 5; + + private readonly GoogleServiceAccountConfig _config; + private 
readonly RSA _rsa; + + private static readonly ConcurrentDictionary TokenCache = new(); + private static readonly ConcurrentDictionary RefreshLocks = new(); + + public GoogleServiceAccountTokenProvider(GoogleServiceAccountConfig config) + { + _config = config; + _rsa = RSA.Create(); + _rsa.ImportFromPem(config.PrivateKey.Replace("\\n", "\n")); + } + + public async Task GetAccessTokenAsync(HttpClient httpClient) + { + var email = _config.ClientEmail; + + if (TokenCache.TryGetValue(email, out var cached) && !cached.IsExpired) + return cached.Token; + + var refreshLock = RefreshLocks.GetOrAdd(email, _ => new SemaphoreSlim(1, 1)); + await refreshLock.WaitAsync(); + try + { + // Double-check after acquiring lock + if (TokenCache.TryGetValue(email, out cached) && !cached.IsExpired) + return cached.Token; + + var jwt = BuildSignedJwt(); + var token = await ExchangeJwtForTokenAsync(httpClient, jwt); + + var accessToken = token.AccessToken + ?? throw new InvalidOperationException("Vertex AI token response missing access_token."); + var expiry = DateTime.UtcNow.AddSeconds(token.ExpiresIn).AddMinutes(-RefreshBufferMinutes); + + TokenCache[email] = new CachedToken(accessToken, expiry); + return accessToken; + } + finally + { + refreshLock.Release(); + } + } + + private string BuildSignedJwt() + { + var now = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); + + var header = Base64UrlEncode(System.Text.Json.JsonSerializer.SerializeToUtf8Bytes(new + { + alg = "RS256", + typ = "JWT" + })); + + var payload = Base64UrlEncode(System.Text.Json.JsonSerializer.SerializeToUtf8Bytes(new + { + iss = _config.ClientEmail, + scope = Scope, + aud = _config.TokenUri, + iat = now, + exp = now + TokenLifetimeSeconds + })); + + var dataToSign = Encoding.ASCII.GetBytes($"{header}.{payload}"); + var signature = _rsa.SignData(dataToSign, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + + return $"{header}.{payload}.{Base64UrlEncode(signature)}"; + } + + private static async Task 
ExchangeJwtForTokenAsync(HttpClient httpClient, string jwt) + { + using var content = new FormUrlEncodedContent(new Dictionary + { + ["grant_type"] = "urn:ietf:params:oauth:grant-type:jwt-bearer", + ["assertion"] = jwt + }); + + using var response = await httpClient.PostAsync("https://oauth2.googleapis.com/token", content); + + if (!response.IsSuccessStatusCode) + { + var error = await response.Content.ReadAsStringAsync(); + throw new InvalidOperationException( + $"Vertex AI token exchange failed ({response.StatusCode}): {error}"); + } + + return await response.Content.ReadFromJsonAsync() + ?? throw new InvalidOperationException("Failed to parse Vertex AI token response."); + } + + public void Dispose() => _rsa.Dispose(); + + private static string Base64UrlEncode(byte[] data) + => Convert.ToBase64String(data).TrimEnd('=').Replace('+', '-').Replace('/', '_'); + + private sealed record CachedToken(string Token, DateTime Expiry) + { + public bool IsExpired => DateTime.UtcNow >= Expiry; + } + + private sealed class TokenResponse + { + [JsonPropertyName("access_token")] public string? 
AccessToken { get; set; } + [JsonPropertyName("expires_in")] public int ExpiresIn { get; set; } + } +} diff --git a/src/MaIN.Services/Services/LLMService/Factory/ImageGenServiceFactory.cs b/src/MaIN.Services/Services/LLMService/Factory/ImageGenServiceFactory.cs index 0854b665..f8fa4290 100644 --- a/src/MaIN.Services/Services/LLMService/Factory/ImageGenServiceFactory.cs +++ b/src/MaIN.Services/Services/LLMService/Factory/ImageGenServiceFactory.cs @@ -20,11 +20,11 @@ public class ImageGenServiceFactory(IServiceProvider serviceProvider) : IImageGe BackendType.Anthropic => null, BackendType.Xai => new XaiImageGenService(serviceProvider.GetRequiredService(), serviceProvider.GetRequiredService()), + BackendType.Vertex => new VertexImageGenService(serviceProvider.GetRequiredService(), + serviceProvider.GetRequiredService()), BackendType.Ollama => null, BackendType.Self => new ImageGenService(serviceProvider.GetRequiredService(), serviceProvider.GetRequiredService()), - - // Add other backends as needed _ => throw new NotSupportedException("Not support image generation."), }; } diff --git a/src/MaIN.Services/Services/LLMService/Factory/LLMServiceFactory.cs b/src/MaIN.Services/Services/LLMService/Factory/LLMServiceFactory.cs index d404c89d..844ad387 100644 --- a/src/MaIN.Services/Services/LLMService/Factory/LLMServiceFactory.cs +++ b/src/MaIN.Services/Services/LLMService/Factory/LLMServiceFactory.cs @@ -58,6 +58,13 @@ public ILLMService CreateService(BackendType backendType) serviceProvider.GetRequiredService(), serviceProvider.GetRequiredService()), + BackendType.Vertex => new VertexService( + serviceProvider.GetRequiredService(), + serviceProvider.GetRequiredService(), + serviceProvider.GetRequiredService(), + serviceProvider.GetRequiredService(), + serviceProvider.GetRequiredService()), + BackendType.Self => new LLMService( serviceProvider.GetRequiredService(), serviceProvider.GetRequiredService(), diff --git 
a/src/MaIN.Services/Services/LLMService/Memory/IMemoryFactory.cs b/src/MaIN.Services/Services/LLMService/Memory/IMemoryFactory.cs index c411ae63..955ddf32 100644 --- a/src/MaIN.Services/Services/LLMService/Memory/IMemoryFactory.cs +++ b/src/MaIN.Services/Services/LLMService/Memory/IMemoryFactory.cs @@ -14,4 +14,5 @@ public interface IMemoryFactory MemoryParams memoryParams); IKernelMemory CreateMemoryWithOpenAi(string openAiKey, MemoryParams memoryParams); IKernelMemory CreateMemoryWithGemini(string geminiKey, MemoryParams memoryParams); + IKernelMemory CreateMemoryWithVertex(Func> bearerTokenProvider, string location, string projectId, MemoryParams memoryParams); } \ No newline at end of file diff --git a/src/MaIN.Services/Services/LLMService/Memory/MemoryFactory.cs b/src/MaIN.Services/Services/LLMService/Memory/MemoryFactory.cs index e8b45ec4..046c2293 100644 --- a/src/MaIN.Services/Services/LLMService/Memory/MemoryFactory.cs +++ b/src/MaIN.Services/Services/LLMService/Memory/MemoryFactory.cs @@ -74,6 +74,29 @@ public IKernelMemory CreateMemoryWithGemini(string geminiKey, MemoryParams memor .WithSemanticKernelTextEmbeddingGenerationService( new GoogleAITextEmbeddingGenerationService("gemini-embedding-001", geminiKey), new SemanticKernelConfig()) #pragma warning restore SKEXP0070 + .WithCustomImageOcr(new OcrWrapper()) + .WithSimpleVectorDb() + .Build(); + + return kernelMemory; + } + + public IKernelMemory CreateMemoryWithVertex(Func> bearerTokenProvider, string location, string projectId, MemoryParams memoryParams) + { + var searchOptions = ConfigureSearchOptions(memoryParams); + + var kernelMemory = new KernelMemoryBuilder() + .WithSearchClientConfig(searchOptions) +#pragma warning disable SKEXP0070 + .WithSemanticKernelTextGenerationService( + new GeminiTextGeneratorAdapter( + new VertexAIGeminiChatCompletionService("gemini-2.5-flash", bearerTokenProvider, location, projectId)), + new SemanticKernelConfig()) + .WithSemanticKernelTextEmbeddingGenerationService( 
+ new VertexAITextEmbeddingGenerationService("text-embedding-005", bearerTokenProvider, location, projectId), + new SemanticKernelConfig()) +#pragma warning restore SKEXP0070 + .WithCustomImageOcr(new OcrWrapper()) .WithSimpleVectorDb() .Build(); diff --git a/src/MaIN.Services/Services/LLMService/Utils/ChatHelper.cs b/src/MaIN.Services/Services/LLMService/Utils/ChatHelper.cs index d9c08bd9..fc2b34c5 100644 --- a/src/MaIN.Services/Services/LLMService/Utils/ChatHelper.cs +++ b/src/MaIN.Services/Services/LLMService/Utils/ChatHelper.cs @@ -259,6 +259,11 @@ private static string DetectImageMimeType(byte[] imageBytes) if (imageBytes.Length < 4) return "image/jpeg"; + // PDF: %PDF (0x25 0x50 0x44 0x46) + if (imageBytes[0] == 0x25 && imageBytes[1] == 0x50 && + imageBytes[2] == 0x44 && imageBytes[3] == 0x46) + return "application/pdf"; + if (imageBytes[0] == 0xFF && imageBytes[1] == 0xD8) return "image/jpeg"; diff --git a/src/MaIN.Services/Services/LLMService/VertexService.cs b/src/MaIN.Services/Services/LLMService/VertexService.cs new file mode 100644 index 00000000..c5cd4571 --- /dev/null +++ b/src/MaIN.Services/Services/LLMService/VertexService.cs @@ -0,0 +1,243 @@ +using System.Text; +using MaIN.Domain.Configuration; +using MaIN.Domain.Configuration.BackendInferenceParams; +using MaIN.Domain.Configuration.Vertex; +using MaIN.Domain.Entities; +using MaIN.Domain.Models; +using MaIN.Domain.Models.Concrete; +using MaIN.Services.Constants; +using MaIN.Services.Services.Abstract; +using MaIN.Services.Services.LLMService.Auth; +using MaIN.Services.Services.LLMService.Memory; +using MaIN.Services.Services.Models; +using MaIN.Services.Utils; +using Microsoft.Extensions.Logging; + +namespace MaIN.Services.Services.LLMService; + +public sealed class VertexService( + MaINSettings settings, + INotificationService notificationService, + IHttpClientFactory httpClientFactory, + IMemoryFactory memoryFactory, + IMemoryService memoryService, + ILogger? 
logger = null) + : OpenAiCompatibleService(notificationService, httpClientFactory, memoryFactory, memoryService, logger), ILLMService +{ + private GoogleServiceAccountTokenProvider? _tokenProvider; + private string _location = "us-central1"; + + private GoogleServiceAccountConfig Auth + => settings.GoogleServiceAccountAuth + ?? throw new InvalidOperationException("Vertex AI service account is not configured."); + + protected override string HttpClientName => ServiceConstants.HttpClients.VertexClient; + + protected override string ChatCompletionsUrl + => $"https://{_location}-aiplatform.googleapis.com/v1beta1/projects/{Auth.ProjectId}/locations/{_location}/endpoints/openapi/chat/completions"; + + protected override string ModelsUrl + => $"https://{_location}-aiplatform.googleapis.com/v1beta1/projects/{Auth.ProjectId}/locations/{_location}/endpoints/openapi/models"; + + protected override Type ExpectedParamsType => typeof(VertexInferenceParams); + + protected override string GetApiKey() + { + var auth = Auth; + _tokenProvider ??= new GoogleServiceAccountTokenProvider(auth); + + var httpClient = httpClientFactory.CreateClient(HttpClientName); + // Task.Run avoids deadlocking on Blazor Server's single-threaded SynchronizationContext + return Task.Run(() => _tokenProvider.GetAccessTokenAsync(httpClient)).GetAwaiter().GetResult(); + } + + protected override string GetApiName() => LLMApiRegistry.Vertex.ApiName; + + protected override void ValidateApiKey() + { + var auth = Auth; + if (string.IsNullOrEmpty(auth.ProjectId)) + throw new InvalidOperationException("GoogleServiceAccountConfig.ProjectId is required."); + if (string.IsNullOrEmpty(auth.ClientEmail)) + throw new InvalidOperationException("GoogleServiceAccountConfig.ClientEmail is required."); + if (string.IsNullOrEmpty(auth.PrivateKey)) + throw new InvalidOperationException("GoogleServiceAccountConfig.PrivateKey is required."); + } + + protected override void ApplyBackendParams(Dictionary requestBody, Chat chat) + { 
+        if (chat.BackendParams is not VertexInferenceParams p) return;
+        if (p.Temperature.HasValue) requestBody["temperature"] = p.Temperature.Value;
+        if (p.MaxTokens.HasValue) requestBody["max_tokens"] = p.MaxTokens.Value;
+        if (p.TopP.HasValue) requestBody["top_p"] = p.TopP.Value;
+        if (p.StopSequences is { Length: > 0 }) requestBody["stop"] = p.StopSequences;
+    }
+
+    public new async Task<ChatResult?> Send(
+        Chat chat,
+        ChatRequestOptions options,
+        CancellationToken cancellationToken = default)
+    {
+        ExtractLocation(chat);
+        return await base.Send(chat, options, cancellationToken);
+    }
+
+    /// <summary>
+    /// Sends files directly to Gemini via multimodal API (bypasses KernelMemory).
+    /// PDFs and images are sent inline (Gemini handles OCR natively),
+    /// other formats are pre-processed to text via DocumentProcessor.
+    /// </summary>
+    public override async Task<ChatResult?> AskMemory(
+        Chat chat,
+        ChatMemoryOptions memoryOptions,
+        ChatRequestOptions requestOptions,
+        CancellationToken cancellationToken = default)
+    {
+        ExtractLocation(chat);
+
+        if (!chat.Messages.Any())
+            return null;
+
+        var lastMessage = chat.Messages.Last();
+        var originalContent = lastMessage.Content;
+        var originalFiles = lastMessage.Files;
+        var originalImages = lastMessage.Images;
+
+        try
+        {
+            var inlineBytes = new List<byte[]>();
+            var textContext = new StringBuilder();
+
+            CollectTextData(memoryOptions, textContext);
+            await CollectFilesData(memoryOptions, inlineBytes, textContext, cancellationToken);
+            await CollectStreamData(memoryOptions, inlineBytes, textContext, cancellationToken);
+            CollectMemoryItems(memoryOptions, textContext);
+
+            lastMessage.Content = BuildQuery(originalContent, textContext, chat.MemoryParams.Grammar);
+            lastMessage.Files = null;
+            lastMessage.Images = MergeInlineContent(originalImages, inlineBytes);
+
+            return await Send(chat, requestOptions, cancellationToken);
+        }
+        finally
+        {
+            lastMessage.Content = originalContent;
+            lastMessage.Files = originalFiles;
+            lastMessage.Images = originalImages;
+        }
+    }
+
+    #region Multimodal File Processing
+
+    private static readonly HashSet<string> NativeMultimodalExtensions =
+        [".pdf", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif", ".heic", ".heif", ".avif"];
+
+    private static bool IsNativeMultimodalFile(string fileName)
+        => NativeMultimodalExtensions.Contains(Path.GetExtension(fileName).ToLowerInvariant());
+
+    private static void CollectTextData(ChatMemoryOptions options, StringBuilder textContext)
+    {
+        foreach (var (name, content) in options.TextData)
+            AppendDocument(textContext, name, content);
+    }
+
+    private static async Task CollectFilesData(
+        ChatMemoryOptions options, List<byte[]> inlineBytes, StringBuilder textContext,
+        CancellationToken cancellationToken)
+    {
+        foreach (var (name, path) in options.FilesData)
+        {
+            if (IsNativeMultimodalFile(name))
+                inlineBytes.Add(await File.ReadAllBytesAsync(path, cancellationToken));
+            else
+                AppendDocument(textContext, name, DocumentProcessor.ProcessDocument(path));
+        }
+    }
+
+    private static async Task CollectStreamData(
+        ChatMemoryOptions options, List<byte[]> inlineBytes, StringBuilder textContext,
+        CancellationToken cancellationToken)
+    {
+        foreach (var (name, stream) in options.StreamData)
+        {
+            if (stream.CanSeek) stream.Position = 0;
+            using var ms = new MemoryStream();
+            await stream.CopyToAsync(ms, cancellationToken);
+            var bytes = ms.ToArray();
+
+            if (IsNativeMultimodalFile(name))
+            {
+                inlineBytes.Add(bytes);
+            }
+            else
+            {
+                var tempPath = Path.Combine(Path.GetTempPath(), $"vertex_{Guid.NewGuid()}{Path.GetExtension(name)}");
+                try
+                {
+                    await File.WriteAllBytesAsync(tempPath, bytes, cancellationToken);
+                    AppendDocument(textContext, name, DocumentProcessor.ProcessDocument(tempPath));
+                }
+                finally
+                {
+                    if (File.Exists(tempPath)) File.Delete(tempPath);
+                }
+            }
+        }
+    }
+
+    private static void CollectMemoryItems(ChatMemoryOptions options, StringBuilder textContext)
+    {
+        if (options.Memory is not { Count: > 0 }) return;
+        foreach (var item in options.Memory)
+        {
+            textContext.AppendLine(item);
+            textContext.AppendLine();
+        }
+    }
+
+    private static void AppendDocument(StringBuilder sb, string name, string content)
+    {
+        sb.AppendLine($"[Document: {name}]");
+        sb.AppendLine(content);
+        sb.AppendLine();
+    }
+
+    private static string BuildQuery(string userQuestion, StringBuilder documentContext, Grammar? grammar)
+    {
+        var query = new StringBuilder();
+        if (documentContext.Length > 0)
+        {
+            query.AppendLine("Use the following document content to answer the question:\n");
+            query.Append(documentContext);
+            query.AppendLine();
+        }
+        query.Append(userQuestion);
+
+        if (grammar != null)
+        {
+            var jsonGrammar = new GrammarToJsonConverter().ConvertToJson(grammar);
+            query.Append(
+                $" | For your next response only, please respond using exactly the following JSON format: \n{jsonGrammar}\n. Do not include any explanations, code blocks, or additional content. After this single JSON response, resume your normal conversational style.");
+        }
+
+        return query.ToString();
+    }
+
+    private static List<byte[]>? MergeInlineContent(List<byte[]>? existingImages, List<byte[]> newBytes)
+    {
+        if ((existingImages == null || existingImages.Count == 0) && newBytes.Count == 0)
+            return null;
+
+        var merged = new List<byte[]>(existingImages ?? []);
+        merged.AddRange(newBytes);
+        return merged;
+    }
+
+    #endregion
+
+    private void ExtractLocation(Chat chat)
+    {
+        if (chat.BackendParams is VertexInferenceParams vp)
+            _location = vp.Location;
+    }
+}
diff --git a/src/MaIN.Services/Services/McpService.cs b/src/MaIN.Services/Services/McpService.cs
index 9572b40b..8bd1b7f6 100644
--- a/src/MaIN.Services/Services/McpService.cs
+++ b/src/MaIN.Services/Services/McpService.cs
@@ -2,6 +2,7 @@
 using MaIN.Domain.Entities;
 using MaIN.Domain.Models.Concrete;
 using MaIN.Services.Services.Abstract;
+using MaIN.Services.Services.LLMService.Auth;
 using MaIN.Services.Services.LLMService.Utils;
 using MaIN.Services.Services.Models;
 using Microsoft.SemanticKernel;
@@ -30,7 +31,7 @@ public async Task<McpResult> Prompt(Mcp config, List<Message> messageHistory)
         );
 
         var builder = Kernel.CreateBuilder();
-        var promptSettings = InitializeChatCompletions(builder, config.Backend ?? settings.BackendType, config.Model);
+        var promptSettings = InitializeChatCompletions(builder, config);
         var kernel = builder.Build();
         var tools = await mcpClient.ListToolsAsync();
         kernel.Plugins.AddFromFunctions("Tools", tools.Select(x => x.AsKernelFunction()));
@@ -49,10 +50,10 @@ public async Task<McpResult> Prompt(Mcp config, List<Message> messageHistory)
         }
 
         var chatService = kernel.GetRequiredService<IChatCompletionService>();
-        
+
         var result = await chatService.GetChatMessageContentsAsync(
-            chatHistory, 
-            promptSettings, 
+            chatHistory,
+            promptSettings,
             kernel);
 
         return new McpResult
@@ -68,8 +69,11 @@ public async Task<McpResult> Prompt(Mcp config, List<Message> messageHistory)
         };
     }
 
-    private PromptExecutionSettings InitializeChatCompletions(IKernelBuilder kernelBuilder, BackendType backendType, string model)
+    private PromptExecutionSettings InitializeChatCompletions(IKernelBuilder kernelBuilder, Mcp config)
     {
+        var backendType = config.Backend ?? settings.BackendType;
+        var model = config.Model;
+
         switch (backendType)
         {
             case BackendType.OpenAi:
@@ -118,6 +122,24 @@ private PromptExecutionSettings InitializeChatCompletions(IKernelBuilder kernelBuilder, BackendType backendType, string model)
                     FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: new() { RetainArgumentTypes = true })
                 };
 
+            case BackendType.Vertex:
+                var auth = settings.GoogleServiceAccountAuth
+                           ?? throw new InvalidOperationException("Vertex AI service account is not configured.");
+                var tokenProvider = new GoogleServiceAccountTokenProvider(auth);
+                var httpClient = new HttpClient();
+                Func<ValueTask<string>> bearerTokenProvider = async ()
+                    => await tokenProvider.GetAccessTokenAsync(httpClient);
+
+                var modelName = model.StartsWith("google/", StringComparison.OrdinalIgnoreCase)
+                    ? model["google/".Length..]
+                    : model;
+
+                kernelBuilder.Services.AddVertexAIGeminiChatCompletion(modelName, bearerTokenProvider, config.Location, auth.ProjectId);
+                return new GeminiPromptExecutionSettings
+                {
+                    FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: new() { RetainArgumentTypes = true })
+                };
+
             case BackendType.Ollama:
                 throw new NotSupportedException("Ollama models does not support MCP integration.");
 
diff --git a/src/MaIN.Services/Utils/AgentStateManager.cs b/src/MaIN.Services/Utils/AgentStateManager.cs
index 3effa2a5..62d223d3 100644
--- a/src/MaIN.Services/Utils/AgentStateManager.cs
+++ b/src/MaIN.Services/Utils/AgentStateManager.cs
@@ -1,6 +1,6 @@
 using MaIN.Domain.Entities;
 using MaIN.Domain.Entities.Agents;
-using MaIN.Services.Services.ImageGenServices;
+using MaIN.Domain.Models.Abstract;
 
 namespace MaIN.Services.Utils;
 
@@ -11,7 +11,7 @@ public static void ClearState(Agent agent, Chat chat)
     {
         agent.CurrentBehaviour = "Default";
         chat.Properties.Clear();
-        if (chat.ModelId == ImageGenService.LocalImageModels.FLUX)
+        if (ModelRegistry.TryGetById(chat.ModelId, out var model) && model!.HasImageGeneration)
         {
             chat.Messages = [];
         }