Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 36 additions & 17 deletions src/main/java/com/llmproxy/service/llm/ModelVersionValidator.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import com.llmproxy.model.ModelType;
import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.HashSet;

@Component
public class ModelVersionValidator {
Expand All @@ -15,19 +18,25 @@ public class ModelVersionValidator {
public static final String DEFAULT_CLAUDE_VERSION = "claude-3-sonnet-20240229";

private final Map<ModelType, List<String>> supportedModelVersions;
private final Map<ModelType, Set<String>> validVersionSets;

public ModelVersionValidator() {
supportedModelVersions = new HashMap<>();
supportedModelVersions.put(ModelType.OPENAI, List.of(
// Use EnumMap for better performance with enum keys
supportedModelVersions = new EnumMap<>(ModelType.class);
validVersionSets = new EnumMap<>(ModelType.class);

// Define lists once and store immutable references
List<String> openaiVersions = List.of(
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"gpt-4",
"gpt-4-vision-preview",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k"
));
supportedModelVersions.put(ModelType.GEMINI, List.of(
);

List<String> geminiVersions = List.of(
"gemini-2.5-flash-preview-04-17",
"gemini-2.5-pro-preview-03-25",
"gemini-2.0-flash",
Expand All @@ -37,17 +46,19 @@ public ModelVersionValidator() {
"gemini-1.5-pro",
"gemini-pro",
"gemini-pro-vision"
));
supportedModelVersions.put(ModelType.MISTRAL, List.of(
);

List<String> mistralVersions = List.of(
"codestral-latest",
"mistral-large-latest",
"mistral-saba-latest",
"mistral-tiny",
"mistral-small",
"mistral-medium",
"mistral-large"
));
supportedModelVersions.put(ModelType.CLAUDE, List.of(
);

List<String> claudeVersions = List.of(
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
Expand All @@ -56,19 +67,27 @@ public ModelVersionValidator() {
"claude-3-haiku",
"claude-2.1",
"claude-2.0"
));
);

supportedModelVersions.put(ModelType.OPENAI, openaiVersions);
supportedModelVersions.put(ModelType.GEMINI, geminiVersions);
supportedModelVersions.put(ModelType.MISTRAL, mistralVersions);
supportedModelVersions.put(ModelType.CLAUDE, claudeVersions);

// Create sets for O(1) lookups instead of stream filtering
for (Map.Entry<ModelType, List<String>> entry : supportedModelVersions.entrySet()) {
validVersionSets.put(entry.getKey(), new HashSet<>(entry.getValue()));
}
}

public String validateModelVersion(ModelType modelType, String version) {
if (version == null || version.isBlank()) {
return getDefaultVersionForModel(modelType);
}

List<String> validVersions = supportedModelVersions.get(modelType);
return validVersions.stream()
.filter(v -> v.equals(version))
.findFirst()
.orElse(getDefaultVersionForModel(modelType));
Set<String> validVersions = validVersionSets.get(modelType);
return validVersions != null && validVersions.contains(version) ?
version : getDefaultVersionForModel(modelType);
}

private String getDefaultVersionForModel(ModelType modelType) {
Expand All @@ -81,6 +100,6 @@ private String getDefaultVersionForModel(ModelType modelType) {
}

public List<String> getSupportedVersionsForModel(ModelType modelType) {
return supportedModelVersions.getOrDefault(modelType, List.of());
return supportedModelVersions.getOrDefault(modelType, Collections.emptyList());
}
}
}