diff --git a/src/main/java/com/llmproxy/service/llm/ModelVersionValidator.java b/src/main/java/com/llmproxy/service/llm/ModelVersionValidator.java index e778a4f..4a1d1ac 100644 --- a/src/main/java/com/llmproxy/service/llm/ModelVersionValidator.java +++ b/src/main/java/com/llmproxy/service/llm/ModelVersionValidator.java @@ -3,9 +3,12 @@ import com.llmproxy.model.ModelType; import org.springframework.stereotype.Component; -import java.util.HashMap; +import java.util.Collections; +import java.util.EnumMap; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.HashSet; @Component public class ModelVersionValidator { @@ -15,10 +18,15 @@ public class ModelVersionValidator { public static final String DEFAULT_CLAUDE_VERSION = "claude-3-sonnet-20240229"; private final Map> supportedModelVersions; + private final Map> validVersionSets; public ModelVersionValidator() { - supportedModelVersions = new HashMap<>(); - supportedModelVersions.put(ModelType.OPENAI, List.of( + // Use EnumMap for better performance with enum keys + supportedModelVersions = new EnumMap<>(ModelType.class); + validVersionSets = new EnumMap<>(ModelType.class); + + // Define lists once and store immutable references + List openaiVersions = List.of( "gpt-4o", "gpt-4o-mini", "gpt-4-turbo", @@ -26,8 +34,9 @@ public ModelVersionValidator() { "gpt-4-vision-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k" - )); - supportedModelVersions.put(ModelType.GEMINI, List.of( + ); + + List geminiVersions = List.of( "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25", "gemini-2.0-flash", @@ -37,8 +46,9 @@ public ModelVersionValidator() { "gemini-1.5-pro", "gemini-pro", "gemini-pro-vision" - )); - supportedModelVersions.put(ModelType.MISTRAL, List.of( + ); + + List mistralVersions = List.of( "codestral-latest", "mistral-large-latest", "mistral-saba-latest", @@ -46,8 +56,9 @@ public ModelVersionValidator() { "mistral-small", "mistral-medium", "mistral-large" - )); - supportedModelVersions.put(ModelType.CLAUDE, List.of( + ); + + List claudeVersions = List.of( "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", @@ -56,7 +67,17 @@ public ModelVersionValidator() { "claude-3-haiku", "claude-2.1", "claude-2.0" - )); + ); + + supportedModelVersions.put(ModelType.OPENAI, openaiVersions); + supportedModelVersions.put(ModelType.GEMINI, geminiVersions); + supportedModelVersions.put(ModelType.MISTRAL, mistralVersions); + supportedModelVersions.put(ModelType.CLAUDE, claudeVersions); + + // Create sets for O(1) lookups instead of stream filtering + for (Map.Entry> entry : supportedModelVersions.entrySet()) { + validVersionSets.put(entry.getKey(), new HashSet<>(entry.getValue())); + } } public String validateModelVersion(ModelType modelType, String version) { @@ -64,11 +85,9 @@ public String validateModelVersion(ModelType modelType, String version) { return getDefaultVersionForModel(modelType); } - List validVersions = supportedModelVersions.get(modelType); - return validVersions.stream() - .filter(v -> v.equals(version)) - .findFirst() - .orElse(getDefaultVersionForModel(modelType)); + Set validVersions = validVersionSets.get(modelType); + return validVersions != null && validVersions.contains(version) ? + version : getDefaultVersionForModel(modelType); } private String getDefaultVersionForModel(ModelType modelType) { @@ -81,6 +100,6 @@ private String getDefaultVersionForModel(ModelType modelType) { } public List getSupportedVersionsForModel(ModelType modelType) { - return supportedModelVersions.getOrDefault(modelType, List.of()); + return supportedModelVersions.getOrDefault(modelType, Collections.emptyList()); } -} +} \ No newline at end of file