diff --git a/.gitignore b/.gitignore index 4fe8b88..1cb204e 100644 --- a/.gitignore +++ b/.gitignore @@ -67,6 +67,7 @@ src/main/resources/static/assets/ ### Local scratchpad ### tmp/ bin/ +dogfood-output/ ### Build artifacts (root level) ### BOOT-INF/ diff --git a/src/main/java/com/williamcallahan/javachat/config/RetrievalAugmentationConfig.java b/src/main/java/com/williamcallahan/javachat/config/RetrievalAugmentationConfig.java index 4dea61c..4a09620 100644 --- a/src/main/java/com/williamcallahan/javachat/config/RetrievalAugmentationConfig.java +++ b/src/main/java/com/williamcallahan/javachat/config/RetrievalAugmentationConfig.java @@ -14,7 +14,7 @@ public class RetrievalAugmentationConfig { private static final int OVERLAP_DEF = 150; private static final int CITE_DEF = 3; private static final double MMR_LAMBDA_DEF = 0.5d; - private static final Duration RERANK_TIMEOUT_DEF = Duration.ofSeconds(12); + private static final Duration RERANK_TIMEOUT_DEF = Duration.ofSeconds(30); private static final int MIN_POSITIVE = 1; private static final int MIN_NON_NEG = 0; private static final double MMR_MIN = 0.0d; diff --git a/src/main/java/com/williamcallahan/javachat/service/EmbeddingBatchEmbedder.java b/src/main/java/com/williamcallahan/javachat/service/EmbeddingBatchEmbedder.java index c233e22..d72c67e 100644 --- a/src/main/java/com/williamcallahan/javachat/service/EmbeddingBatchEmbedder.java +++ b/src/main/java/com/williamcallahan/javachat/service/EmbeddingBatchEmbedder.java @@ -56,7 +56,7 @@ private static List embedSingleBatch( List batchEmbeddings; try { - batchEmbeddings = embeddingClient.embed(textBatch); + batchEmbeddings = embeddingClient.embed(textBatch, LlmGatewayTier.BATCH); } catch (EmbeddingServiceUnavailableException embeddingFailure) { String firstBatchUrl = extractDocumentUrl(documentBatch.getFirst(), batchStartIndex); String lastBatchUrl = extractDocumentUrl(documentBatch.getLast(), batchEndIndex - 1); diff --git a/src/main/java/com/williamcallahan/javachat/service/EmbeddingClient.java b/src/main/java/com/williamcallahan/javachat/service/EmbeddingClient.java index 7ced0e8..614e24c 100644 --- a/src/main/java/com/williamcallahan/javachat/service/EmbeddingClient.java +++ b/src/main/java/com/williamcallahan/javachat/service/EmbeddingClient.java @@ -14,23 +14,26 @@ public interface EmbeddingClient { * Produces one dense embedding vector per input text, preserving input order. * * @param texts input texts + * @param requestTier gateway capacity tier for this embedding request * @return embedding vectors in the same order as {@code texts} */ - List embed(List texts); + List embed(List texts, LlmGatewayTier requestTier); /** * Produces a dense embedding vector for a single text. * * @param text input text + * @param requestTier gateway capacity tier for this embedding request * @return embedding vector */ - default float[] embed(String text) { + default float[] embed(String text, LlmGatewayTier requestTier) { + Objects.requireNonNull(requestTier, "requestTier"); String safeText = Objects.requireNonNullElse(text, ""); - List vectors = embed(List.of(safeText)); - if (vectors.isEmpty()) { + List embeddingVectors = embed(List.of(safeText), requestTier); + if (embeddingVectors.isEmpty()) { throw new EmbeddingServiceUnavailableException("Embedding response was empty"); } - return vectors.get(0); + return embeddingVectors.get(0); } /** @@ -44,7 +47,7 @@ default float[] embed(String text) { * Issues a minimal embedding request so the provider keeps its model resident. * *

Implementations must call their provider-specific request path directly instead - * of delegating to {@link #embed(List)}. The RAG pipeline logging aspect advises + * of delegating to {@link #embed(List, LlmGatewayTier)}. The RAG pipeline logging aspect advises * public {@code embed} executions, so routing scheduled probes around that method * keeps "STEP 1" pipeline logs scoped to real requests.

* diff --git a/src/main/java/com/williamcallahan/javachat/service/HybridSearchService.java b/src/main/java/com/williamcallahan/javachat/service/HybridSearchService.java index 11cd718..b61b426 100644 --- a/src/main/java/com/williamcallahan/javachat/service/HybridSearchService.java +++ b/src/main/java/com/williamcallahan/javachat/service/HybridSearchService.java @@ -128,7 +128,7 @@ public SearchOutcome searchOutcome(String query, int topK, RetrievalConstraint r return new SearchOutcome(List.of(), List.of()); } - float[] denseVector = queryEncoding.embeddingClient().embed(query); + float[] denseVector = queryEncoding.embeddingClient().embed(query, LlmGatewayTier.LIVE); LexicalSparseVectorEncoder.SparseVector sparseVector = queryEncoding.sparseVectorEncoder().encode(query); Optional retrievalFilter = queryEncoding.constraintBuilder().buildFilter(retrievalConstraint); diff --git a/src/main/java/com/williamcallahan/javachat/service/LlmGatewayTier.java b/src/main/java/com/williamcallahan/javachat/service/LlmGatewayTier.java new file mode 100644 index 0000000..086ed83 --- /dev/null +++ b/src/main/java/com/williamcallahan/javachat/service/LlmGatewayTier.java @@ -0,0 +1,33 @@ +package com.williamcallahan.javachat.service; + +/** + * Defines the request tiers understood by the LLM gateway. + * + *

Live user-facing work uses {@link #LIVE}. Background ingestion and scheduled + * embedding probes use {@link #BATCH} so live requests keep reserved capacity.

+ */ +public enum LlmGatewayTier { + /** User-facing request tier with production reserved capacity. */ + LIVE("production-z"), + + /** Background request tier for ingestion, backfills, and scheduled probes. */ + BATCH("batch"); + + /** HTTP header used by the gateway to classify request capacity. */ + public static final String REQUEST_TIER_HEADER = "X-Tier"; + + private final String requestHeader; + + LlmGatewayTier(String requestHeader) { + this.requestHeader = requestHeader; + } + + /** + * Returns the gateway header payload for this request tier. + * + * @return header payload sent in {@link #REQUEST_TIER_HEADER} + */ + public String requestHeader() { + return requestHeader; + } +} diff --git a/src/main/java/com/williamcallahan/javachat/service/LocalEmbeddingClient.java b/src/main/java/com/williamcallahan/javachat/service/LocalEmbeddingClient.java index b99f834..1313e36 100644 --- a/src/main/java/com/williamcallahan/javachat/service/LocalEmbeddingClient.java +++ b/src/main/java/com/williamcallahan/javachat/service/LocalEmbeddingClient.java @@ -66,7 +66,8 @@ public LocalEmbeddingClient( } @Override - public List embed(List texts) { + public List embed(List texts, LlmGatewayTier requestTier) { + Objects.requireNonNull(requestTier, "requestTier"); if (texts == null || texts.isEmpty()) { return List.of(); } diff --git a/src/main/java/com/williamcallahan/javachat/service/OpenAIStreamingService.java b/src/main/java/com/williamcallahan/javachat/service/OpenAIStreamingService.java index 35f2c40..3387d43 100644 --- a/src/main/java/com/williamcallahan/javachat/service/OpenAIStreamingService.java +++ b/src/main/java/com/williamcallahan/javachat/service/OpenAIStreamingService.java @@ -42,7 +42,6 @@ public class OpenAIStreamingService { private static final Logger log = LoggerFactory.getLogger(OpenAIStreamingService.class); private static final int COMPLETE_REQUEST_TIMEOUT_SECONDS = 30; - private static final String LLM_GATEWAY_TIER_LIVE = "production-z"; private static final String STREAM_STATUS_CODE_PROVIDER_FALLBACK = SseConstants.STATUS_CODE_STREAM_PROVIDER_FALLBACK; private static final String STREAM_STAGE_STREAM = SseConstants.STATUS_STAGE_STREAM; @@ -180,6 +179,25 @@ public Mono streamResponse(StructuredPrompt structuredPrompt, d * @return completion text from the first successful provider attempt */ public Mono complete(String prompt, double temperature) { + return complete(prompt, temperature, null); + } + + /** + * Sends a non-streaming completion request with an explicit output budget. + * + * @param prompt completion prompt + * @param temperature response temperature + * @param maximumOutputTokens maximum output tokens needed by this caller + * @return completion text from the first successful provider attempt + */ + public Mono complete(String prompt, double temperature, int maximumOutputTokens) { + if (maximumOutputTokens <= 0) { + return Mono.error(new IllegalArgumentException("maximumOutputTokens must be positive")); + } + return complete(prompt, temperature, Integer.valueOf(maximumOutputTokens)); + } + + private Mono complete(String prompt, double temperature, Integer maximumOutputTokens) { return Mono.defer(() -> { List availableProviders = providerRoutingService.selectAvailableProviderCandidates(clientPrimary, clientSecondary); @@ -196,7 +214,7 @@ public Mono complete(String prompt, double temperature) { RateLimitService.ApiProvider activeProvider = providerCandidate.provider(); ResponseCreateParams requestParameters = - requestFactory.buildCompletionRequest(prompt, temperature, activeProvider); + buildCompletionRequest(prompt, temperature, activeProvider, maximumOutputTokens); try { log.info("[LLM] Complete started (providerId={})", activeProvider.ordinal()); RequestOptions requestOptions = RequestOptions.builder() @@ -237,6 +255,17 @@ public Mono complete(String prompt, double temperature) { .subscribeOn(Schedulers.boundedElastic()); } + private ResponseCreateParams buildCompletionRequest( + String prompt, + double temperature, + RateLimitService.ApiProvider activeProvider, + Integer maximumOutputTokens) { + if (maximumOutputTokens == null) { + return requestFactory.buildCompletionRequest(prompt, temperature, activeProvider); + } + return requestFactory.buildCompletionRequest(prompt, temperature, activeProvider, maximumOutputTokens); + } + /** * Returns whether a streaming failure is likely recoverable with a retry. * @@ -338,7 +367,7 @@ private OpenAIClient createClient(String apiKey, String baseUrl) { return OpenAIOkHttpClient.builder() .apiKey(apiKey) .baseUrl(OpenAiSdkUrlNormalizer.normalize(baseUrl)) - .putHeader("X-Tier", resolvedLlmGatewayTier()) + .putHeader(LlmGatewayTier.REQUEST_TIER_HEADER, resolvedLlmGatewayTier()) // Disable SDK-level retries: Reactor timeout and onErrorResume handle failures. // Retries cause InterruptedException when Reactor cancels a sleeping retry. .maxRetries(0) @@ -346,7 +375,9 @@ private OpenAIClient createClient(String apiKey, String baseUrl) { } private String resolvedLlmGatewayTier() { - return llmGatewayTier == null || llmGatewayTier.isBlank() ? LLM_GATEWAY_TIER_LIVE : llmGatewayTier.trim(); + return llmGatewayTier == null || llmGatewayTier.isBlank() + ? LlmGatewayTier.LIVE.requestHeader() + : llmGatewayTier.trim(); } private void closeClientSafely(OpenAIClient client, String clientName) { diff --git a/src/main/java/com/williamcallahan/javachat/service/OpenAiCompatibleEmbeddingClient.java b/src/main/java/com/williamcallahan/javachat/service/OpenAiCompatibleEmbeddingClient.java index afb7f71..4b263fe 100644 --- a/src/main/java/com/williamcallahan/javachat/service/OpenAiCompatibleEmbeddingClient.java +++ b/src/main/java/com/williamcallahan/javachat/service/OpenAiCompatibleEmbeddingClient.java @@ -40,9 +40,11 @@ public class OpenAiCompatibleEmbeddingClient implements EmbeddingClient, AutoClo private static final int HTTP_TOO_MANY_REQUESTS = 429; private static final int HTTP_INTERNAL_SERVER_ERROR = 500; - private final OpenAIClient client; + private final OpenAIClient liveEmbeddingClient; + private final OpenAIClient batchEmbeddingClient; private final String modelName; private final int dimensionsHint; + private final boolean closeBatchEmbeddingClient; /** * Creates an OpenAI-compatible embedding client backed by a remote REST API endpoint. @@ -56,61 +58,76 @@ public class OpenAiCompatibleEmbeddingClient implements EmbeddingClient, AutoClo public static OpenAiCompatibleEmbeddingClient create( String baseUrl, String apiKey, String modelName, int dimensionsHint) { validateDimensions(dimensionsHint); - OpenAIClient client = OpenAIOkHttpClient.builder() - .apiKey(requireConfiguredApiKey(apiKey)) - .baseUrl(normalizeSdkBaseUrl(baseUrl)) - // Embedding traffic is ingestion/backfill-dominated, so it is classed - // as the LLM gateway's "batch" tier. The current sf7-direct endpoint - // ignores the header; it becomes load-bearing if this client is ever - // pointed at the gateway queue (api.llm-gateway.iocloudhost.net). - .putHeader("X-Tier", "batch") - .build(); - return new OpenAiCompatibleEmbeddingClient(client, requireConfiguredModel(modelName), dimensionsHint); + String configuredApiKey = requireConfiguredApiKey(apiKey); + String normalizedBaseUrl = normalizeSdkBaseUrl(baseUrl); + OpenAIClient liveEmbeddingClient = createTieredClient(configuredApiKey, normalizedBaseUrl, LlmGatewayTier.LIVE); + OpenAIClient batchEmbeddingClient = + createTieredClient(configuredApiKey, normalizedBaseUrl, LlmGatewayTier.BATCH); + return new OpenAiCompatibleEmbeddingClient( + liveEmbeddingClient, batchEmbeddingClient, requireConfiguredModel(modelName), dimensionsHint); } static OpenAiCompatibleEmbeddingClient create(OpenAIClient client, String modelName, int dimensionsHint) { validateDimensions(dimensionsHint); + OpenAIClient embeddingClient = Objects.requireNonNull(client, "client"); return new OpenAiCompatibleEmbeddingClient( - Objects.requireNonNull(client, "client"), requireConfiguredModel(modelName), dimensionsHint); + embeddingClient, embeddingClient, requireConfiguredModel(modelName), dimensionsHint, false); } - OpenAiCompatibleEmbeddingClient(OpenAIClient client, String modelName, int dimensionsHint) { - this.client = client; + OpenAiCompatibleEmbeddingClient( + OpenAIClient liveEmbeddingClient, OpenAIClient batchEmbeddingClient, String modelName, int dimensionsHint) { + this(liveEmbeddingClient, batchEmbeddingClient, modelName, dimensionsHint, true); + } + + private OpenAiCompatibleEmbeddingClient( + OpenAIClient liveEmbeddingClient, + OpenAIClient batchEmbeddingClient, + String modelName, + int dimensionsHint, + boolean closeBatchEmbeddingClient) { + this.liveEmbeddingClient = Objects.requireNonNull(liveEmbeddingClient, "liveEmbeddingClient"); + this.batchEmbeddingClient = Objects.requireNonNull(batchEmbeddingClient, "batchEmbeddingClient"); this.modelName = modelName; this.dimensionsHint = dimensionsHint; + this.closeBatchEmbeddingClient = closeBatchEmbeddingClient; } @Override - public List embed(List texts) { + public List embed(List texts, LlmGatewayTier requestTier) { + Objects.requireNonNull(requestTier, "requestTier"); if (texts == null || texts.isEmpty()) { return List.of(); } - return createEmbeddings(texts); + return createEmbeddings(texts, requestTier); } @Override public void warmUp() { - createEmbeddings(List.of(EMBEDDING_WARM_UP_PROBE_TEXT)); + createEmbeddings(List.of(EMBEDDING_WARM_UP_PROBE_TEXT), LlmGatewayTier.BATCH); } - private List createEmbeddings(List texts) { + private List createEmbeddings(List texts, LlmGatewayTier requestTier) { EmbeddingCreateParams.Builder embeddingRequestBuilder = EmbeddingCreateParams.builder().model(modelName).inputOfArrayOfStrings(texts); if (supportsDimensionOverride(modelName)) { embeddingRequestBuilder.dimensions((long) dimensionsHint); } - EmbeddingCreateParams params = embeddingRequestBuilder.build(); + EmbeddingCreateParams embeddingRequest = embeddingRequestBuilder.build(); RequestOptions requestOptions = RequestOptions.builder().timeout(embeddingTimeout()).build(); - return executeWithRetry(params, requestOptions, texts.size()); + return executeWithRetry(clientFor(requestTier), embeddingRequest, requestOptions, texts.size()); } private List executeWithRetry( - EmbeddingCreateParams params, RequestOptions requestOptions, int expectedCount) { + OpenAIClient requestClient, + EmbeddingCreateParams embeddingRequest, + RequestOptions requestOptions, + int expectedCount) { long retryBackoffMillis = INITIAL_RETRY_BACKOFF_MILLIS; for (int attemptNumber = 1; attemptNumber <= MAX_EMBED_ATTEMPTS; attemptNumber++) { try { - CreateEmbeddingResponse embeddingResponse = client.embeddings().create(params, requestOptions); + CreateEmbeddingResponse embeddingResponse = + requestClient.embeddings().create(embeddingRequest, requestOptions); return parseResponse(embeddingResponse, expectedCount); } catch (OpenAIServiceException exception) { retryBackoffMillis = handleServiceError(exception, attemptNumber, retryBackoffMillis); @@ -372,7 +389,10 @@ private static String sanitizeMessage(String message) { */ @Override public void close() { - client.close(); + liveEmbeddingClient.close(); + if (closeBatchEmbeddingClient) { + batchEmbeddingClient.close(); + } } private static String requireConfiguredApiKey(String apiKey) { @@ -393,6 +413,21 @@ private static String normalizeSdkBaseUrl(String baseUrl) { return OpenAiSdkUrlNormalizer.normalize(baseUrl); } + private static OpenAIClient createTieredClient(String apiKey, String baseUrl, LlmGatewayTier requestTier) { + return OpenAIOkHttpClient.builder() + .apiKey(apiKey) + .baseUrl(baseUrl) + .putHeader(LlmGatewayTier.REQUEST_TIER_HEADER, requestTier.requestHeader()) + .build(); + } + + private OpenAIClient clientFor(LlmGatewayTier requestTier) { + return switch (requestTier) { + case LIVE -> liveEmbeddingClient; + case BATCH -> batchEmbeddingClient; + }; + } + private static void validateDimensions(int dimensionsHint) { if (dimensionsHint <= 0) { throw new IllegalArgumentException("Embedding dimensions must be positive"); diff --git a/src/main/java/com/williamcallahan/javachat/service/OpenAiProviderRoutingService.java b/src/main/java/com/williamcallahan/javachat/service/OpenAiProviderRoutingService.java index bade83d..bced68f 100644 --- a/src/main/java/com/williamcallahan/javachat/service/OpenAiProviderRoutingService.java +++ b/src/main/java/com/williamcallahan/javachat/service/OpenAiProviderRoutingService.java @@ -10,6 +10,8 @@ import com.openai.errors.SseException; import com.openai.errors.UnauthorizedException; import com.williamcallahan.javachat.support.AsciiTextNormalizer; +import java.io.InterruptedIOException; +import java.net.SocketTimeoutException; import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit; @@ -176,30 +178,16 @@ public boolean isRecoverableStreamingFailure(Throwable throwable) { } boolean shouldBackoffPrimary(Throwable throwable) { - if (isRateLimit(throwable)) { - return true; - } - if (throwable instanceof OpenAIIoException) { - return true; - } - if (throwable instanceof InterruptedException) { - Thread.currentThread().interrupt(); - return true; - } - if (throwable instanceof UnauthorizedException || throwable instanceof PermissionDeniedException) { - return true; - } - if (throwable instanceof InternalServerException) { - return true; - } - if (throwable instanceof NotFoundException) { - return true; - } - if (throwable instanceof OpenAIServiceException serviceException) { - return serviceException.statusCode() >= HTTP_INTERNAL_SERVER_ERROR; + if (isCallerCancellation(throwable)) { + return false; } - String message = throwable.getMessage(); - return message != null && AsciiTextNormalizer.toLowerAscii(message).contains("sleep interrupted"); + return isRateLimit(throwable) + || throwable instanceof OpenAIIoException + || throwable instanceof UnauthorizedException + || throwable instanceof PermissionDeniedException + || throwable instanceof InternalServerException + || throwable instanceof NotFoundException + || isServerError(throwable); } private List orderedProviderCandidates( @@ -275,6 +263,32 @@ private boolean isRateLimit(Throwable throwable) { && serviceException.statusCode() == HTTP_TOO_MANY_REQUESTS); } + private boolean isServerError(Throwable throwable) { + return throwable instanceof OpenAIServiceException serviceException + && serviceException.statusCode() >= HTTP_INTERNAL_SERVER_ERROR; + } + + private boolean isCallerCancellation(Throwable throwable) { + Throwable cancellationCandidate = throwable; + while (cancellationCandidate != null) { + if (cancellationCandidate instanceof InterruptedException) { + Thread.currentThread().interrupt(); + return true; + } + if (cancellationCandidate instanceof InterruptedIOException + && !(cancellationCandidate instanceof SocketTimeoutException)) { + return true; + } + String cancellationMessage = cancellationCandidate.getMessage(); + if (cancellationMessage != null + && AsciiTextNormalizer.toLowerAscii(cancellationMessage).contains("sleep interrupted")) { + return true; + } + cancellationCandidate = cancellationCandidate.getCause(); + } + return false; + } + private boolean isPrimaryInBackoff() { return System.currentTimeMillis() < primaryBackoffUntilEpochMs; } diff --git a/src/main/java/com/williamcallahan/javachat/service/OpenAiRequestFactory.java b/src/main/java/com/williamcallahan/javachat/service/OpenAiRequestFactory.java index a9665fd..c6a5793 100644 --- a/src/main/java/com/williamcallahan/javachat/service/OpenAiRequestFactory.java +++ b/src/main/java/com/williamcallahan/javachat/service/OpenAiRequestFactory.java @@ -23,7 +23,7 @@ public class OpenAiRequestFactory { private static final Logger log = LoggerFactory.getLogger(OpenAiRequestFactory.class); - private static final int MAX_COMPLETION_TOKENS = 4000; + private static final int GPT5_COMPLETION_OUTPUT_TOKEN_BUDGET = 4000; /** Prefix matching gpt-5, gpt-5.2, gpt-5.2-pro, etc. */ private static final String GPT_5_MODEL_PREFIX = "gpt-5"; @@ -117,10 +117,32 @@ public OpenAiPreparedRequest prepareStreamingRequest( */ public ResponseCreateParams buildCompletionRequest( String prompt, double temperature, RateLimitService.ApiProvider provider) { + return buildCompletionRequest(prompt, temperature, provider, null); + } + + /** + * Builds completion request parameters with an explicit output budget. + * + * @param prompt completion prompt + * @param temperature response temperature + * @param provider provider chosen for this request attempt + * @param maximumOutputTokens maximum output tokens needed by this caller + * @return request payload ready for SDK execution + */ + public ResponseCreateParams buildCompletionRequest( + String prompt, double temperature, RateLimitService.ApiProvider provider, int maximumOutputTokens) { + if (maximumOutputTokens <= 0) { + throw new IllegalArgumentException("maximumOutputTokens must be positive"); + } + return buildCompletionRequest(prompt, temperature, provider, Integer.valueOf(maximumOutputTokens)); + } + + private ResponseCreateParams buildCompletionRequest( + String prompt, double temperature, RateLimitService.ApiProvider provider, Integer maximumOutputTokens) { boolean useGitHubModels = provider == RateLimitService.ApiProvider.GITHUB_MODELS; String modelId = normalizedModelId(useGitHubModels); String truncatedPrompt = truncatePromptForCompletion(prompt, modelId, useGitHubModels); - return buildResponseParams(truncatedPrompt, temperature, modelId); + return buildResponseParams(truncatedPrompt, temperature, modelId, maximumOutputTokens); } private String truncatePromptForCompletion(String prompt, String modelId, boolean useGitHubModels) { @@ -145,6 +167,11 @@ private String truncatePromptForCompletion(String prompt, String modelId, boolea } private ResponseCreateParams buildResponseParams(String prompt, double temperature, String normalizedModelId) { + return buildResponseParams(prompt, temperature, normalizedModelId, null); + } + + private ResponseCreateParams buildResponseParams( + String prompt, double temperature, String normalizedModelId, Integer maximumOutputTokens) { boolean gpt5Family = isGpt5Family(normalizedModelId); boolean reasoningModel = gpt5Family || canonicalModelName(normalizedModelId).startsWith("o"); @@ -152,8 +179,13 @@ private ResponseCreateParams buildResponseParams(String prompt, double temperatu ResponseCreateParams.Builder builder = ResponseCreateParams.builder().input(prompt).model(ResponsesModel.ofString(normalizedModelId)); + if (maximumOutputTokens != null) { + builder.maxOutputTokens(maximumOutputTokens.longValue()); + } else if (gpt5Family) { + builder.maxOutputTokens((long) GPT5_COMPLETION_OUTPUT_TOKEN_BUDGET); + } + if (gpt5Family) { - builder.maxOutputTokens((long) MAX_COMPLETION_TOKENS); log.debug("Using GPT-5 family configuration for model: {}", normalizedModelId); resolveReasoningEffort() diff --git a/src/main/java/com/williamcallahan/javachat/service/RerankerService.java b/src/main/java/com/williamcallahan/javachat/service/RerankerService.java index c9cf816..1e10c46 100644 --- a/src/main/java/com/williamcallahan/javachat/service/RerankerService.java +++ b/src/main/java/com/williamcallahan/javachat/service/RerankerService.java @@ -32,6 +32,9 @@ public class RerankerService { /** Maximum character length of document text included in the rerank prompt. */ private static final int RERANK_PROMPT_TEXT_MAX_LENGTH = 500; + /** Output budget for the small JSON ordering the reranker requires. */ + private static final int RERANKER_OUTPUT_TOKEN_BUDGET = 128; + private final OpenAIStreamingService openAIStreamingService; private final ObjectMapper mapper; private final Duration rerankerTimeout; @@ -100,7 +103,7 @@ private Optional callLlmForReranking(String query, List docume try { return openAIStreamingService - .complete(prompt, 0.0) + .complete(prompt, 0.0, RERANKER_OUTPUT_TOKEN_BUDGET) .timeout(rerankerTimeout) .doOnError( timeoutOrApiError -> log.debug("Reranker LLM call timed out or failed", timeoutOrApiError)) @@ -121,7 +124,8 @@ private String buildRerankPrompt(String query, List documents) { prompt.append("Consider Java-specific context, version relevance, and learning value.\n"); prompt.append("Prefer official documentation over blogs or third-party sources.\n"); prompt.append("Prefer stable release documentation over early-access or preview content.\n"); - prompt.append("Return JSON: {\"order\":[indices...]} with 0-based indices.\n\n"); + prompt.append("Return only JSON: {\"order\":[indices...]} with 0-based indices.\n"); + prompt.append("Do not include markdown, prose, or explanations.\n\n"); prompt.append("Query: ").append(query).append("\n\n"); for (int docIndex = 0; docIndex < documents.size(); docIndex++) { diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 8b9fad4..060377a 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -111,7 +111,7 @@ app.rag.search-top-k=${RAG_TOP_K:12} app.rag.search-return-k=${RAG_RETURN_K:6} app.rag.search-citations=${RAG_CITATIONS_K:3} app.rag.search-mmr-lambda=${RAG_MMR_LAMBDA:0.5} -app.rag.reranker-timeout=${RAG_RERANKER_TIMEOUT:12s} +app.rag.reranker-timeout=${RAG_RERANKER_TIMEOUT:30s} # LLM defaults used by openai-java streaming service app.llm.temperature=${APP_LLM_TEMPERATURE:0.7} diff --git a/src/test/java/com/williamcallahan/javachat/config/RetrievalAugmentationConfigTest.java b/src/test/java/com/williamcallahan/javachat/config/RetrievalAugmentationConfigTest.java index 630700f..08fd763 100644 --- a/src/test/java/com/williamcallahan/javachat/config/RetrievalAugmentationConfigTest.java +++ b/src/test/java/com/williamcallahan/javachat/config/RetrievalAugmentationConfigTest.java @@ -1,6 +1,7 @@ package com.williamcallahan.javachat.config; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import java.time.Duration; @@ -16,6 +17,7 @@ void validateConfigurationAcceptsDefaultRerankerTimeout() { RetrievalAugmentationConfig config = new RetrievalAugmentationConfig(); assertDoesNotThrow(config::validateConfiguration); + assertEquals(Duration.ofSeconds(30), config.getRerankerTimeout()); } @Test diff --git a/src/test/java/com/williamcallahan/javachat/service/EmbeddingModelKeepAliveTest.java b/src/test/java/com/williamcallahan/javachat/service/EmbeddingModelKeepAliveTest.java index 9b1ab86..57c9e8d 100644 --- a/src/test/java/com/williamcallahan/javachat/service/EmbeddingModelKeepAliveTest.java +++ b/src/test/java/com/williamcallahan/javachat/service/EmbeddingModelKeepAliveTest.java @@ -28,12 +28,15 @@ void keepEmbeddingModelWarmDoesNotPropagateProviderUnavailability() { assertDoesNotThrow(keepAlive::keepEmbeddingModelWarm); } + /** + * Records scheduled warm-up calls without allowing pipeline embedding calls. + */ private static final class RecordingEmbeddingClient implements EmbeddingClient { private int warmUpInvocationCount; @Override - public List embed(List texts) { - throw new AssertionError("keep-alive probes must not call embed(List)"); + public List embed(List texts, LlmGatewayTier requestTier) { + throw new AssertionError("keep-alive probes must not call embed(List, LlmGatewayTier)"); } @Override @@ -47,9 +50,12 @@ public int dimensions() { } } + /** + * Simulates a provider that is unavailable during scheduled warm-up. + */ private static final class UnavailableEmbeddingClient implements EmbeddingClient { @Override - public List embed(List texts) { + public List embed(List texts, LlmGatewayTier requestTier) { throw new EmbeddingServiceUnavailableException("provider offline for test"); } diff --git a/src/test/java/com/williamcallahan/javachat/service/EmbeddingProviderFailureTest.java b/src/test/java/com/williamcallahan/javachat/service/EmbeddingProviderFailureTest.java index 8e19ca5..a5935e3 100644 --- a/src/test/java/com/williamcallahan/javachat/service/EmbeddingProviderFailureTest.java +++ b/src/test/java/com/williamcallahan/javachat/service/EmbeddingProviderFailureTest.java @@ -30,8 +30,9 @@ void localEmbeddingSurfacesHttpErrors() throws IOException { try { LocalEmbeddingClient localClient = new LocalEmbeddingClient(baseUrl, "local-test-model", 12, 8, new RestTemplateBuilder()); - EmbeddingServiceUnavailableException thrown = - assertThrows(EmbeddingServiceUnavailableException.class, () -> localClient.embed(List.of("hello"))); + EmbeddingServiceUnavailableException thrown = assertThrows( + EmbeddingServiceUnavailableException.class, + () -> localClient.embed(List.of("hello"), LlmGatewayTier.LIVE)); assertTrue(thrown.getMessage().contains("HTTP 500")); } finally { @@ -49,7 +50,8 @@ void remoteEmbeddingHttpErrorsSurfaceStatusCodes() throws IOException { try (OpenAiCompatibleEmbeddingClient remoteClient = OpenAiCompatibleEmbeddingClient.create(baseUrl, "test-key", "text-embedding-qwen3-embedding-8b", 8)) { EmbeddingServiceUnavailableException thrown = assertThrows( - EmbeddingServiceUnavailableException.class, () -> remoteClient.embed(List.of("hello"))); + EmbeddingServiceUnavailableException.class, + () -> remoteClient.embed(List.of("hello"), LlmGatewayTier.LIVE)); assertTrue(thrown.getMessage().contains("HTTP 401")); assertNotNull(thrown.getCause()); diff --git a/src/test/java/com/williamcallahan/javachat/service/HybridSearchServiceTest.java b/src/test/java/com/williamcallahan/javachat/service/HybridSearchServiceTest.java index 3e55edd..08bdeb1 100644 --- a/src/test/java/com/williamcallahan/javachat/service/HybridSearchServiceTest.java +++ b/src/test/java/com/williamcallahan/javachat/service/HybridSearchServiceTest.java @@ -47,7 +47,7 @@ void setUp() { void appliesServerFilterToQueryAndPrefetchWithConfiguredRrfK() { appProperties.getQdrant().setRrfK(77); - when(embeddingClient.embed("Java 25 streams")).thenReturn(new float[] {0.1f, 0.2f, 0.3f}); + when(embeddingClient.embed("Java 25 streams", LlmGatewayTier.LIVE)).thenReturn(new float[] {0.1f, 0.2f, 0.3f}); when(sparseEncoder.encode("Java 25 streams")) .thenReturn(new LexicalSparseVectorEncoder.SparseVector(List.of(1L, 3L), List.of(2.0f, 1.0f))); @@ -112,7 +112,7 @@ private HybridSearchService buildSearchService() { } private void stubPartialFailureQueryResponses(String queryText) { - when(embeddingClient.embed(queryText)).thenReturn(new float[] {0.5f, 0.1f, 0.4f}); + when(embeddingClient.embed(queryText, LlmGatewayTier.LIVE)).thenReturn(new float[] {0.5f, 0.1f, 0.4f}); when(sparseEncoder.encode(queryText)) .thenReturn(new LexicalSparseVectorEncoder.SparseVector(List.of(2L), List.of(1.0f))); diff --git a/src/test/java/com/williamcallahan/javachat/service/LocalEmbeddingClientTest.java b/src/test/java/com/williamcallahan/javachat/service/LocalEmbeddingClientTest.java index 3c9cb59..ff16e1f 100644 --- a/src/test/java/com/williamcallahan/javachat/service/LocalEmbeddingClientTest.java +++ b/src/test/java/com/williamcallahan/javachat/service/LocalEmbeddingClientTest.java @@ -57,7 +57,8 @@ void batchesRequestsAndPreservesEmbeddingOrderByIndex() throws IOException { try { LocalEmbeddingClient localEmbeddingClient = new LocalEmbeddingClient(baseUrl, "local-model", 3, 2, new RestTemplateBuilder()); - List embeddingVectors = localEmbeddingClient.embed(List.of("alpha", "beta", "gamma")); + List embeddingVectors = + localEmbeddingClient.embed(List.of("alpha", "beta", "gamma"), LlmGatewayTier.LIVE); assertEquals(2, requestCounter.get()); assertEquals(List.of(2, 1), observedBatchSizes); @@ -89,7 +90,8 @@ void failsWhenLocalEmbeddingDimensionsDoNotMatchConfiguredDimensions() throws IO LocalEmbeddingClient localEmbeddingClient = new LocalEmbeddingClient(baseUrl, "local-model", 3, 8, new RestTemplateBuilder()); EmbeddingServiceUnavailableException thrownException = assertThrows( - EmbeddingServiceUnavailableException.class, () -> localEmbeddingClient.embed(List.of("alpha"))); + EmbeddingServiceUnavailableException.class, + () -> localEmbeddingClient.embed(List.of("alpha"), LlmGatewayTier.LIVE)); assertTrue(thrownException.getMessage().contains("dimension mismatch")); } finally { httpServer.stop(0); diff --git a/src/test/java/com/williamcallahan/javachat/service/OpenAIStreamingServiceTest.java b/src/test/java/com/williamcallahan/javachat/service/OpenAIStreamingServiceTest.java index e523a56..1be575b 100644 --- a/src/test/java/com/williamcallahan/javachat/service/OpenAIStreamingServiceTest.java +++ b/src/test/java/com/williamcallahan/javachat/service/OpenAIStreamingServiceTest.java @@ -13,6 +13,7 @@ import com.openai.errors.RateLimitException; import com.openai.errors.UnauthorizedException; import com.williamcallahan.javachat.application.prompt.PromptTruncator; +import java.io.InterruptedIOException; import java.util.List; import org.junit.jupiter.api.Test; import reactor.core.Exceptions; @@ -47,6 +48,15 @@ void shouldBackoffPrimaryTreatsSdkIoAsBackoffEligible() { assertTrue(routingService.shouldBackoffPrimary(new OpenAIIoException("io"))); } + @Test + void shouldBackoffPrimaryIgnoresCallerCancellationWrappedBySdkIo() { + OpenAiProviderRoutingService routingService = createRoutingService(); + InterruptedIOException interruptedRequest = new InterruptedIOException("request interrupted by caller timeout"); + OpenAIIoException cancelledCompletion = new OpenAIIoException("Request failed", interruptedRequest); + + assertFalse(routingService.shouldBackoffPrimary(cancelledCompletion)); + } + @Test void shouldBackoffPrimaryTreats401AsBackoffEligible() { OpenAiProviderRoutingService routingService = createRoutingService(); diff --git a/src/test/java/com/williamcallahan/javachat/service/OpenAiCompatibleEmbeddingClientTest.java b/src/test/java/com/williamcallahan/javachat/service/OpenAiCompatibleEmbeddingClientTest.java index c6123f8..18c7c3f 100644 --- a/src/test/java/com/williamcallahan/javachat/service/OpenAiCompatibleEmbeddingClientTest.java +++ b/src/test/java/com/williamcallahan/javachat/service/OpenAiCompatibleEmbeddingClientTest.java @@ -7,6 +7,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import com.openai.client.OpenAIClient; @@ -53,7 +54,7 @@ void callUsesSdkAndPreservesIndexOrdering() { try (OpenAiCompatibleEmbeddingClient clientAdapter = OpenAiCompatibleEmbeddingClient.create( client, "text-embedding-qwen3-embedding-8b", EXPECTED_EMBEDDING_DIMENSION)) { - List vectors = clientAdapter.embed(List.of("a", "b")); + List vectors = clientAdapter.embed(List.of("a", "b"), LlmGatewayTier.LIVE); assertEquals(2, vectors.size()); assertEquals(0.25f, vectors.get(0)[0]); @@ -88,8 +89,9 @@ void throwsWhenEmbeddingDimensionDoesNotMatchConfiguration() { try (OpenAiCompatibleEmbeddingClient clientAdapter = OpenAiCompatibleEmbeddingClient.create( client, "text-embedding-qwen3-embedding-8b", EXPECTED_EMBEDDING_DIMENSION)) { - EmbeddingServiceUnavailableException thrownException = - assertThrows(EmbeddingServiceUnavailableException.class, () -> clientAdapter.embed(List.of("a"))); + EmbeddingServiceUnavailableException thrownException = assertThrows( + EmbeddingServiceUnavailableException.class, + () -> clientAdapter.embed(List.of("a"), LlmGatewayTier.LIVE)); assertTrue(thrownException.getMessage().contains("dimension mismatch")); verify(embeddingService, times(1)).create(any(), any(RequestOptions.class)); } @@ -131,7 +133,7 @@ void retriesTransientResponseValidationFailuresAndRecovers() { try (OpenAiCompatibleEmbeddingClient clientAdapter = OpenAiCompatibleEmbeddingClient.create( client, "text-embedding-qwen3-embedding-8b", EXPECTED_EMBEDDING_DIMENSION)) { - List vectors = clientAdapter.embed(List.of("single")); + List vectors = clientAdapter.embed(List.of("single"), LlmGatewayTier.LIVE); assertEquals(1, vectors.size()); assertEquals(0.7f, vectors.get(0)[0]); @@ -161,7 +163,7 @@ void embed_includesDimensionsForTextEmbedding3Models() { try (OpenAiCompatibleEmbeddingClient clientAdapter = OpenAiCompatibleEmbeddingClient.create( client, "text-embedding-3-small", EXPECTED_EMBEDDING_DIMENSION)) { - clientAdapter.embed(List.of("dimension check")); + clientAdapter.embed(List.of("dimension check"), LlmGatewayTier.LIVE); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(EmbeddingCreateParams.class); verify(embeddingService).create(requestCaptor.capture(), any(RequestOptions.class)); @@ -193,11 +195,46 @@ void embed_omitsDimensionsForNonTextEmbedding3Models() { try (OpenAiCompatibleEmbeddingClient clientAdapter = OpenAiCompatibleEmbeddingClient.create( client, "text-embedding-qwen3-embedding-8b", EXPECTED_EMBEDDING_DIMENSION)) { - clientAdapter.embed(List.of("dimension check")); + clientAdapter.embed(List.of("dimension check"), LlmGatewayTier.LIVE); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(EmbeddingCreateParams.class); verify(embeddingService).create(requestCaptor.capture(), any(RequestOptions.class)); assertTrue(requestCaptor.getValue().dimensions().isEmpty()); } } + + @Test + void routesEmbeddingRequestsToTierSpecificSdkClients() { + OpenAIClient liveClient = mock(OpenAIClient.class); + OpenAIClient batchClient = mock(OpenAIClient.class); + EmbeddingService liveEmbeddingService = mock(EmbeddingService.class); + EmbeddingService batchEmbeddingService = mock(EmbeddingService.class); + + when(liveClient.embeddings()).thenReturn(liveEmbeddingService); + when(batchClient.embeddings()).thenReturn(batchEmbeddingService); + + CreateEmbeddingResponse response = CreateEmbeddingResponse.builder() + .model("text-embedding-qwen3-embedding-8b") + .usage(CreateEmbeddingResponse.Usage.builder() + .promptTokens(1L) + .totalTokens(1L) + .build()) + .data(List.of(com.openai.models.embeddings.Embedding.builder() + .index(0L) + .embedding(List.of(0.4f, 0.6f)) + .build())) + .build(); + when(liveEmbeddingService.create(any(), any(RequestOptions.class))).thenReturn(response); + when(batchEmbeddingService.create(any(), any(RequestOptions.class))).thenReturn(response); + + try (OpenAiCompatibleEmbeddingClient clientAdapter = new OpenAiCompatibleEmbeddingClient( + liveClient, batchClient, "text-embedding-qwen3-embedding-8b", EXPECTED_EMBEDDING_DIMENSION)) { + clientAdapter.embed(List.of("live query"), LlmGatewayTier.LIVE); + verify(liveEmbeddingService).create(any(), any(RequestOptions.class)); + verifyNoInteractions(batchEmbeddingService); + + clientAdapter.embed(List.of("batch document"), LlmGatewayTier.BATCH); + verify(batchEmbeddingService).create(any(), any(RequestOptions.class)); + } + } } diff --git a/src/test/java/com/williamcallahan/javachat/service/OpenAiRequestFactoryTest.java b/src/test/java/com/williamcallahan/javachat/service/OpenAiRequestFactoryTest.java index 850aaac..033c371 100644 --- a/src/test/java/com/williamcallahan/javachat/service/OpenAiRequestFactoryTest.java +++ b/src/test/java/com/williamcallahan/javachat/service/OpenAiRequestFactoryTest.java @@ -49,6 +49,17 @@ void buildCompletionRequestRetainsQualifiedGitHubModelIdentifier() { assertEquals(0.25, responseCreateParams.temperature().orElseThrow(), 0.000_001); } + @Test + void buildCompletionRequestAppliesCallerOutputBudget() { + OpenAiRequestFactory requestFactory = + new OpenAiRequestFactory(new Chunker(), new PromptTruncator(), "qwen3.6:onprem", "openai/gpt-5", ""); + + ResponseCreateParams responseCreateParams = requestFactory.buildCompletionRequest( + "Rank these documents", 0.0, RateLimitService.ApiProvider.OPENAI, 128); + + assertEquals(128L, responseCreateParams.maxOutputTokens().orElseThrow()); + } + @Test void buildCompletionRequestKeepsPromptWithinSelectedOpenAiModelLimit() { OpenAiRequestFactory requestFactory = diff --git a/src/test/java/com/williamcallahan/javachat/service/RerankerServiceTest.java b/src/test/java/com/williamcallahan/javachat/service/RerankerServiceTest.java index a29a93a..75c2413 100644 --- a/src/test/java/com/williamcallahan/javachat/service/RerankerServiceTest.java +++ b/src/test/java/com/williamcallahan/javachat/service/RerankerServiceTest.java @@ -1,14 +1,22 @@ package com.williamcallahan.javachat.service; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; import com.williamcallahan.javachat.config.AppProperties; import java.util.List; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.ai.document.Document; +import reactor.core.publisher.Mono; /** * Ensures reranker surfaces failures instead of silently falling back. @@ -22,8 +30,27 @@ void rerankThrowsWhenServiceUnavailable() { RerankerService rerankerService = new RerankerService(streamingService, new ObjectMapper(), new AppProperties()); - List docs = List.of(new Document("first"), new Document("second")); + List sourceDocuments = List.of(new Document("first"), new Document("second")); - assertThrows(RerankingFailureException.class, () -> rerankerService.rerank("query", docs, 2)); + assertThrows(RerankingFailureException.class, () -> rerankerService.rerank("query", sourceDocuments, 2)); + } + + @Test + void rerankUsesBoundedCompletionBudget() { + OpenAIStreamingService streamingService = mock(OpenAIStreamingService.class); + when(streamingService.isAvailable()).thenReturn(true); + when(streamingService.complete(anyString(), eq(0.0), anyInt())).thenReturn(Mono.just("{\"order\":[1,0]}")); + + RerankerService rerankerService = + new RerankerService(streamingService, new ObjectMapper(), new AppProperties()); + List sourceDocuments = List.of(new Document("first"), new Document("second")); + + List rankedDocuments = rerankerService.rerank("query", sourceDocuments, 2); + + ArgumentCaptor outputBudgetCaptor = ArgumentCaptor.forClass(Integer.class); + verify(streamingService).complete(anyString(), eq(0.0), outputBudgetCaptor.capture()); + verify(streamingService, never()).complete(anyString(), eq(0.0)); + assertEquals(128, outputBudgetCaptor.getValue()); + assertEquals(sourceDocuments.get(1), rankedDocuments.get(0)); } }