From 348f0ed3e3c7fb3efe4b522f4b99384cd2b79644 Mon Sep 17 00:00:00 2001
From: TurinTech Bot
Date: Tue, 13 May 2025 23:04:47 +0000
Subject: [PATCH 1/6] Artemis Changes

---
 .../llm/ModelVersionValidatorTest.java        | 148 +++++++++++++++++-
 1 file changed, 142 insertions(+), 6 deletions(-)

diff --git a/src/test/java/com/llmproxy/service/llm/ModelVersionValidatorTest.java b/src/test/java/com/llmproxy/service/llm/ModelVersionValidatorTest.java
index 621c983..a2442eb 100644
--- a/src/test/java/com/llmproxy/service/llm/ModelVersionValidatorTest.java
+++ b/src/test/java/com/llmproxy/service/llm/ModelVersionValidatorTest.java
@@ -90,11 +90,147 @@ private static Stream<Arguments> provideValidModelVersions() {
     }
 
     private static Stream<Arguments> provideInvalidModelVersions() {
-        return Stream.of(
-                Arguments.of(ModelType.OPENAI, "invalid-model"),
-                Arguments.of(ModelType.GEMINI, "gpt-4"),
-                Arguments.of(ModelType.MISTRAL, "claude-3"),
-                Arguments.of(ModelType.CLAUDE, "gemini-pro")
+        // Valid versions from each model type (to be used as INVALID for other types)
+        List<String> openaiVersions = List.of(
+                "gpt-4o",
+                "gpt-4o-mini",
+                "gpt-4-turbo",
+                "gpt-4",
+                "gpt-4-vision-preview",
+                "gpt-3.5-turbo",
+                "gpt-3.5-turbo-16k"
+        );
+        List<String> geminiVersions = List.of(
+                "gemini-2.5-flash-preview-04-17",
+                "gemini-2.5-pro-preview-03-25",
+                "gemini-2.0-flash",
+                "gemini-2.0-flash-lite",
+                "gemini-1.5-flash",
+                "gemini-1.5-flash-8b",
+                "gemini-1.5-pro",
+                "gemini-pro",
+                "gemini-pro-vision"
+        );
+        List<String> mistralVersions = List.of(
+                "codestral-latest",
+                "mistral-large-latest",
+                "mistral-saba-latest",
+                "mistral-tiny",
+                "mistral-small",
+                "mistral-medium",
+                "mistral-large"
+        );
+        List<String> claudeVersions = List.of(
+                "claude-3-opus-20240229",
+                "claude-3-sonnet-20240229",
+                "claude-3-haiku-20240307",
+                "claude-3-opus",
+                "claude-3-sonnet",
+                "claude-3-haiku",
+                "claude-2.1",
+                "claude-2.0"
+        );
+
+        // Semantically plausible but non-existent versions per model type
+        List<String> plausibleOpenai = List.of(
+                "gpt-5-turbo",    // plausible, but not in supported list
+                "gpt-4o-2024",    // plausible pattern
+                "gpt-3.9-turbo",  // plausible between existing versions
+                "gpt-4-ultimate"  // plausible tier
+        );
+        List<String> plausibleGemini = List.of(
+                "gemini-1.5-ultra",    // not listed
+                "gemini-2.1-flash",    // plausible next
+                "gemini-1.6-pro",      // plausible sequential
+                "gemini-super-vision"  // plausible name
+        );
+        List<String> plausibleMistral = List.of(
+                "mistral-huge",        // no such variant
+                "codestral-2025",      // plausible year update
+                "mistral-medium-8x",   // plausible config
+                "mistral-extra-large"  // plausible name
+        );
+        List<String> plausibleClaude = List.of(
+                "claude-4-opus",            // version bump
+                "claude-3-haiku-20240501",  // extended pattern
+                "claude-3-giant",           // plausible name
+                "claude-2.2"                // plausible version bump
         );
+
+        Stream.Builder<Arguments> builder = Stream.builder();
+
+        // Classic 'totally unrecognized' string and prior test cases
+        builder.add(Arguments.of(ModelType.OPENAI, "invalid-model"));
+        builder.add(Arguments.of(ModelType.GEMINI, "gpt-4"));
+        builder.add(Arguments.of(ModelType.MISTRAL, "claude-3"));
+        builder.add(Arguments.of(ModelType.CLAUDE, "gemini-pro"));
+
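+        // These (type, version) pairs are consumed by a @ParameterizedTest with
+        // @MethodSource("provideInvalidModelVersions"); a minimal sketch of such a
+        // consumer (the validator method name is an assumption, not necessarily
+        // the actual ModelVersionValidator API) would be:
+        //
+        //   @ParameterizedTest
+        //   @MethodSource("provideInvalidModelVersions")
+        //   void validate_invalidVersion_isRejected(ModelType type, String version) {
+        //       assertFalse(validator.isValidVersion(type, version));
+        //   }
+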
+        // Cross-model valid version inputs (should be INVALID for the given type)
+        for (String v : geminiVersions) {
+            builder.add(Arguments.of(ModelType.OPENAI, v));
+        }
+        for (String v : mistralVersions) {
+            builder.add(Arguments.of(ModelType.OPENAI, v));
+        }
+        for (String v : claudeVersions) {
+            builder.add(Arguments.of(ModelType.OPENAI, v));
+        }
+        for (String v : openaiVersions) {
+            builder.add(Arguments.of(ModelType.GEMINI, v));
+        }
+        for (String v : mistralVersions) {
+            builder.add(Arguments.of(ModelType.GEMINI, v));
+        }
+        for (String v : claudeVersions) {
+            builder.add(Arguments.of(ModelType.GEMINI, v));
+        }
+        for (String v : openaiVersions) {
+            builder.add(Arguments.of(ModelType.MISTRAL, v));
+        }
+        for (String v : geminiVersions) {
+            builder.add(Arguments.of(ModelType.MISTRAL, v));
+        }
+        for (String v : claudeVersions) {
+            builder.add(Arguments.of(ModelType.MISTRAL, v));
+        }
+        for (String v : openaiVersions) {
+            builder.add(Arguments.of(ModelType.CLAUDE, v));
+        }
+        for (String v : geminiVersions) {
+            builder.add(Arguments.of(ModelType.CLAUDE, v));
+        }
+        for (String v : mistralVersions) {
+            builder.add(Arguments.of(ModelType.CLAUDE, v));
+        }
+
+        // Semantically plausible but non-existent version strings
+        for (String v : plausibleOpenai) {
+            builder.add(Arguments.of(ModelType.OPENAI, v));
+        }
+        for (String v : plausibleGemini) {
+            builder.add(Arguments.of(ModelType.GEMINI, v));
+        }
+        for (String v : plausibleMistral) {
+            builder.add(Arguments.of(ModelType.MISTRAL, v));
+        }
+        for (String v : plausibleClaude) {
+            builder.add(Arguments.of(ModelType.CLAUDE, v));
+        }
+
+        // Some additional string-related edge cases
+        builder.add(Arguments.of(ModelType.OPENAI, "Gpt-4o"));                 // case sensitivity
+        builder.add(Arguments.of(ModelType.GEMINI, "Gemini-1.5-Pro"));         // case sensitivity
+        builder.add(Arguments.of(ModelType.MISTRAL, "MISTRAL-LARGE-LATEST"));  // case sensitivity
+        builder.add(Arguments.of(ModelType.CLAUDE, "CLAUDE-3-SONNET"));        // case sensitivity
+
+        builder.add(Arguments.of(ModelType.OPENAI, "gpt-4o!@#$"));             // special chars
+        builder.add(Arguments.of(ModelType.GEMINI, "gemini-1.5-pro?"));        // special chars
+        builder.add(Arguments.of(ModelType.MISTRAL, "mistral-large-latest/")); // special chars
+        builder.add(Arguments.of(ModelType.CLAUDE, "claude-3-sonnet~"));       // special chars
+
+        builder.add(Arguments.of(ModelType.OPENAI, "gpt-4o" + "x".repeat(100)));                       // overly long
+        builder.add(Arguments.of(ModelType.GEMINI, "gemini-2.5-pro-preview-03-25" + "y".repeat(200))); // overly long
+
+        return builder.build();
     }
-}
+}
\ No newline at end of file

From d4ee23fb1fac4ade5078f2d246316f25a66dab94 Mon Sep 17 00:00:00 2001
From: TurinTech Bot
Date: Tue, 13 May 2025 23:07:13 +0000
Subject: [PATCH 2/6] Artemis Changes

---
 .../controller/LlmProxyControllerTest.java    | 89 +++++++++++++++++-
 .../ratelimit/RateLimiterServiceTest.java     | 54 ++++++++++-
 .../service/router/RouterServiceTest.java     | 94 ++++++++++++++++++-
 3 files changed, 234 insertions(+), 3 deletions(-)

diff --git a/src/test/java/com/llmproxy/controller/LlmProxyControllerTest.java b/src/test/java/com/llmproxy/controller/LlmProxyControllerTest.java
index 80617a1..dd31561 100644
--- a/src/test/java/com/llmproxy/controller/LlmProxyControllerTest.java
+++ b/src/test/java/com/llmproxy/controller/LlmProxyControllerTest.java
@@ -21,6 +21,7 @@
 import org.springframework.http.MediaType;
 import org.springframework.http.ResponseEntity;
 import org.springframework.mock.web.MockHttpServletRequest;
+import org.junit.jupiter.api.DisplayName;
 
 import java.time.Instant;
 import java.util.Map;
@@ -62,6 +63,7 @@ void setUp() {
     }
 
     @Test
+    @DisplayName("Should return valid response for a valid query request")
     void query_validRequest_returnsResponse() {
         QueryRequest request = QueryRequest.builder()
                 .query("Test query")
@@ -93,6 +95,7 @@ void query_validRequest_returnsResponse() {
     }
 
     @Test
+    @DisplayName("Should return bad request response for empty query")
query") void query_emptyQuery_returnsBadRequest() { QueryRequest request = QueryRequest.builder() .query("") @@ -107,6 +110,7 @@ void query_emptyQuery_returnsBadRequest() { } @Test + @DisplayName("Should return too many requests response when rate limited") void query_rateLimited_returnsTooManyRequests() { QueryRequest request = QueryRequest.builder() .query("Test query") @@ -123,6 +127,7 @@ void query_rateLimited_returnsTooManyRequests() { } @Test + @DisplayName("Should return cached response when available") void query_cachedResponse_returnsCachedResponse() { QueryRequest request = QueryRequest.builder() .query("Test query") @@ -147,6 +152,7 @@ void query_cachedResponse_returnsCachedResponse() { } @Test + @DisplayName("Should return error response when model error occurs") void query_modelError_returnsErrorResponse() { QueryRequest request = QueryRequest.builder() .query("Test query") @@ -167,6 +173,7 @@ void query_modelError_returnsErrorResponse() { } @Test + @DisplayName("Should return model availability status") void status_returnsAvailability() { StatusResponse statusResponse = StatusResponse.builder() .openai(true) @@ -188,6 +195,7 @@ void status_returnsAvailability() { } @Test + @DisplayName("Should return OK status for health endpoint") void health_returnsOk() { ResponseEntity> response = controller.health(mockRequest); @@ -197,6 +205,7 @@ void health_returnsOk() { } @Test + @DisplayName("Should return file when download request is valid") void download_validRequest_returnsFile() { Map request = Map.of( "response", "Test response", @@ -210,4 +219,82 @@ void download_validRequest_returnsFile() { assertEquals("attachment; filename=llm_response.txt", response.getHeaders().getFirst("Content-Disposition")); assertEquals("Test response", new String(response.getBody())); } -} + + @Test + @DisplayName("Should return bad request when download request has empty response") + void download_emptyResponse_returnsBadRequest() { + Map request = Map.of( + "response", "", + "format", "txt" + ); + + ResponseEntity response = controller.download(request, mockRequest); + + assertEquals(HttpStatus.BAD_REQUEST, response.getStatusCode()); + } + + @Test + @DisplayName("Should return too many requests when rate limited for download") + void download_rateLimited_returnsTooManyRequests() { + Map request = Map.of( + "response", "Test response", + "format", "txt" + ); + + lenient().when(rateLimiterService.allowClient(anyString())).thenReturn(false); + + ResponseEntity response = controller.download(request, mockRequest); + + assertEquals(HttpStatus.TOO_MANY_REQUESTS, response.getStatusCode()); + } + + @Test + @DisplayName("Should use default format when format not specified") + void download_nullFormat_usesDefaultFormat() { + Map request = Map.of( + "response", "Test response" + ); + + ResponseEntity response = controller.download(request, mockRequest); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + assertEquals(MediaType.TEXT_PLAIN_VALUE, response.getHeaders().getContentType().toString()); + assertEquals("attachment; filename=llm_response.txt", response.getHeaders().getFirst("Content-Disposition")); + } + + @Test + @DisplayName("Should return too many requests when rate limited for status") + void status_rateLimited_returnsTooManyRequests() { + lenient().when(rateLimiterService.allowClient(anyString())).thenReturn(false); + + ResponseEntity response = controller.status(mockRequest); + + assertEquals(HttpStatus.TOO_MANY_REQUESTS, response.getStatusCode()); + } + + @Test + @DisplayName("Should 
return too many requests when rate limited for health") + void health_rateLimited_returnsTooManyRequests() { + lenient().when(rateLimiterService.allowClient(anyString())).thenReturn(false); + + ResponseEntity> response = controller.health(mockRequest); + + assertEquals(HttpStatus.TOO_MANY_REQUESTS, response.getStatusCode()); + } + + @Test + @DisplayName("Should return bad request for query exceeding maximum length") + void query_tooLongQuery_returnsBadRequest() { + String longQuery = "a".repeat(32001); + QueryRequest request = QueryRequest.builder() + .query(longQuery) + .build(); + + ResponseEntity response = controller.query(request, mockRequest); + + assertEquals(HttpStatus.BAD_REQUEST, response.getStatusCode()); + assertNotNull(response.getBody()); + assertEquals("Query exceeds maximum length of 32000 characters", response.getBody().getError()); + assertEquals("validation_error", response.getBody().getErrorType()); + } +} \ No newline at end of file diff --git a/src/test/java/com/llmproxy/service/ratelimit/RateLimiterServiceTest.java b/src/test/java/com/llmproxy/service/ratelimit/RateLimiterServiceTest.java index ea890db..a6ca3c9 100644 --- a/src/test/java/com/llmproxy/service/ratelimit/RateLimiterServiceTest.java +++ b/src/test/java/com/llmproxy/service/ratelimit/RateLimiterServiceTest.java @@ -63,4 +63,56 @@ void allowClient_withCustomFunction_usesFunction() { assertFalse(limiter.allowClient("blocked")); assertEquals(2, callCount.get()); } -} + + @Test + void allow_zeroRefillRate_tokensDoNotRefill() { + RateLimiterService limiter = new RateLimiterService(0, 3); // 0 requests per minute, burst 3 + + assertTrue(limiter.allow()); // Consumes 1st token + assertTrue(limiter.allow()); // Consumes 2nd token + assertTrue(limiter.allow()); // Consumes 3rd token + + // All tokens consumed, and refill rate is 0 + assertFalse(limiter.allow()); // Should be false as no tokens can be refilled + assertFalse(limiter.allow()); // Should remain false + } + + @Test + void allowClient_zeroRefillRate_clientTokensDoNotRefill() { + // Parent limiter has 0 refill rate, so client-specific limiters will also have 0 refill rate. + RateLimiterService parentLimiter = new RateLimiterService(0, 2); + + // Test for client1 + assertTrue(parentLimiter.allowClient("client1")); // Uses 1st token for client1 + assertTrue(parentLimiter.allowClient("client1")); // Uses 2nd token for client1 + assertFalse(parentLimiter.allowClient("client1"));// No more tokens for client1, should not refill + assertFalse(parentLimiter.allowClient("client1"));// Still no tokens for client1 + + // Test for client2 (should be independent but also not refill) + assertTrue(parentLimiter.allowClient("client2")); // Uses 1st token for client2 + assertTrue(parentLimiter.allowClient("client2")); // Uses 2nd token for client2 + assertFalse(parentLimiter.allowClient("client2"));// No more tokens for client2, should not refill + } + + @Test + void allow_zeroBurst_alwaysReturnsFalse() { + RateLimiterService limiter = new RateLimiterService(60, 0); // 60 requests per minute, burst 0 + + // With zero burst capacity, no tokens are ever available. + assertFalse(limiter.allow()); + assertFalse(limiter.allow()); // Should remain false even if time passes, as maxTokens is 0. + } + + @Test + void allowClient_zeroBurst_alwaysReturnsFalseForClients() { + // Parent limiter has 0 burst, so client-specific limiters will also have 0 burst. 
+    @Test
+    void allowClient_zeroBurst_alwaysReturnsFalseForClients() {
+        // Parent limiter has 0 burst, so client-specific limiters will also have 0 burst.
+        RateLimiterService parentLimiter = new RateLimiterService(60, 0);
+
+        // Test for client1
+        assertFalse(parentLimiter.allowClient("client1")); // Client limiter inherits zero burst
+        assertFalse(parentLimiter.allowClient("client1"));
+
+        // Test for client2
+        assertFalse(parentLimiter.allowClient("client2")); // Client limiter inherits zero burst
+    }
+}
\ No newline at end of file
diff --git a/src/test/java/com/llmproxy/service/router/RouterServiceTest.java b/src/test/java/com/llmproxy/service/router/RouterServiceTest.java
index d61851c..2c8e005 100644
--- a/src/test/java/com/llmproxy/service/router/RouterServiceTest.java
+++ b/src/test/java/com/llmproxy/service/router/RouterServiceTest.java
@@ -228,4 +228,96 @@ void fallbackOnError_withUserSpecifiedModel_usesThatModel() {
 
         assertEquals(ModelType.MISTRAL, result);
     }
-}
+
+    // ---- Enhanced Robustness tests for fallbackOnError start here ----
+
+    @Test
+    void fallbackOnError_withUserSpecifiedFallbackButUnavailable_selectsOtherAvailable() {
+        routerService.setModelAvailability(ModelType.OPENAI, false);
+        routerService.setModelAvailability(ModelType.GEMINI, false); // user fallback is unavailable
+        routerService.setModelAvailability(ModelType.MISTRAL, true); // Mistral is available
+        routerService.setModelAvailability(ModelType.CLAUDE, false);
+
+        // User wanted GEMINI, which is unavailable
+        QueryRequest request = QueryRequest.builder()
+                .query("Test query")
+                .model(ModelType.GEMINI)
+                .build();
+
+        ModelError error = ModelError.rateLimitError(ModelType.OPENAI.toString());
+
+        // Should select the next available model, which is MISTRAL
+        ModelType result = routerService.fallbackOnError(ModelType.OPENAI, request, error);
+        assertEquals(ModelType.MISTRAL, result);
+    }
+
+    @Test
+    void fallbackOnError_withAllButOneModelAvailable_returnsThatModel() {
+        // Only Claude is up
+        routerService.setModelAvailability(ModelType.OPENAI, false);
+        routerService.setModelAvailability(ModelType.GEMINI, false);
+        routerService.setModelAvailability(ModelType.MISTRAL, false);
+        routerService.setModelAvailability(ModelType.CLAUDE, true);
+
+        QueryRequest request = QueryRequest.builder()
+                .query("Test query")
+                .build();
+        ModelError error = ModelError.rateLimitError(ModelType.OPENAI.toString());
+        ModelType result = routerService.fallbackOnError(ModelType.OPENAI, request, error);
+        assertEquals(ModelType.CLAUDE, result);
+    }
+
+    @Test
+    void fallbackOnError_initialModelIsOnlyAvailableOptionButFails_noAlternatives_throw() {
+        // Only OpenAI is up, but it fails and fallback should not find any others
+        routerService.setModelAvailability(ModelType.OPENAI, true);
+        routerService.setModelAvailability(ModelType.GEMINI, false);
+        routerService.setModelAvailability(ModelType.MISTRAL, false);
+        routerService.setModelAvailability(ModelType.CLAUDE, false);
+        QueryRequest request = QueryRequest.builder()
+                .query("Test query")
+                .build();
+        ModelError error = ModelError.rateLimitError(ModelType.OPENAI.toString());
+        // Mark OpenAI as unavailable now (simulate it failed and shouldn't be retried)
+        routerService.setModelAvailability(ModelType.OPENAI, false);
+        assertThrows(ModelError.class, () -> routerService.fallbackOnError(ModelType.OPENAI, request, error));
+    }
+
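+    // Taken together, the robustness tests above and below pin down the fallback
+    // contract: the failed model is excluded from selection, a user-specified or
+    // task-preferred model is honoured only while it is available, and otherwise
+    // any remaining available model is an acceptable choice.
+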
+    @Test
+    void fallbackOnError_withTaskTypeAndPreferredModelUnavailable_selectsTaskTypeAlternative() {
+        // Use sentiment_analysis mapping: prefer OpenAI, if not, next available
+        routerService.setModelAvailability(ModelType.OPENAI, false);
+        routerService.setModelAvailability(ModelType.GEMINI, false);
+        routerService.setModelAvailability(ModelType.MISTRAL, true);
+        routerService.setModelAvailability(ModelType.CLAUDE, false);
+
+        QueryRequest request = QueryRequest.builder()
+                .query("Test query")
+                .taskType(TaskType.SENTIMENT_ANALYSIS)
+                .build();
+        ModelError error = ModelError.rateLimitError(ModelType.OPENAI.toString());
+        ModelType result = routerService.fallbackOnError(ModelType.OPENAI, request, error);
+        assertEquals(ModelType.MISTRAL, result);
+    }
+
+    @Test
+    void fallbackOnError_specifiedModelEqualsFailedModel_returnsAlternative() {
+        // User asked for MISTRAL, but MISTRAL failed, so must route elsewhere
+        routerService.setModelAvailability(ModelType.OPENAI, false);
+        routerService.setModelAvailability(ModelType.GEMINI, true);
+        routerService.setModelAvailability(ModelType.MISTRAL, true); // was up, but failed
+        routerService.setModelAvailability(ModelType.CLAUDE, true);
+
+        QueryRequest request = QueryRequest.builder()
+                .query("Test query")
+                .model(ModelType.MISTRAL)
+                .build();
+        // Mark MISTRAL as down now ("fails")
+        routerService.setModelAvailability(ModelType.MISTRAL, false);
+
+        ModelError error = ModelError.rateLimitError(ModelType.MISTRAL.toString());
+        // Should skip MISTRAL and pick next available: CLAUDE or GEMINI (impl can pick one of them; assert one of them)
+        ModelType result = routerService.fallbackOnError(ModelType.MISTRAL, request, error);
+        assertTrue(result == ModelType.GEMINI || result == ModelType.CLAUDE);
+    }
+}
\ No newline at end of file

From b216a762de42bafb3618f17c0aac80c223549230 Mon Sep 17 00:00:00 2001
From: Mike Basios
Date: Wed, 14 May 2025 00:26:53 +0100
Subject: [PATCH 3/6] feat: add benchmarks

---
 .../service/cache/CacheServiceTest.java       | 332 +++++++++++++++++-
 .../ratelimit/RateLimiterServiceTest.java     | 237 ++++++++++++-
 .../service/router/RouterServiceTest.java     | 258 +++++++++++++-
 3 files changed, 802 insertions(+), 25 deletions(-)

diff --git a/src/test/java/com/llmproxy/service/cache/CacheServiceTest.java b/src/test/java/com/llmproxy/service/cache/CacheServiceTest.java
index 03a3f86..2ae9643 100644
--- a/src/test/java/com/llmproxy/service/cache/CacheServiceTest.java
+++ b/src/test/java/com/llmproxy/service/cache/CacheServiceTest.java
@@ -7,46 +7,64 @@ import com.llmproxy.model.TaskType;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.parallel.Execution;
+import org.junit.jupiter.api.parallel.ExecutionMode;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static org.junit.jupiter.api.Assertions.*;
 
+@Execution(ExecutionMode.CONCURRENT)
 class CacheServiceTest {
 
-    private CacheService cacheService;
-    private QueryRequest request;
-    private QueryResponse response;
+    private static final Logger logger = LoggerFactory.getLogger(CacheServiceTest.class);
+
+    private final AtomicReference<CacheService> cacheServiceRef = new AtomicReference<>();
+    private final AtomicReference<QueryRequest> requestRef = new AtomicReference<>();
+    private final AtomicReference<QueryResponse> responseRef = new AtomicReference<>();
 
     @BeforeEach
     void setUp() {
-        cacheService = new CacheService(true, 300, 1000, new ObjectMapper());
+        cacheServiceRef.set(new CacheService(true, 300, 1000, new ObjectMapper()));
 
-        request = QueryRequest.builder()
+        requestRef.set(QueryRequest.builder()
                 .query("Test query")
                 .model(ModelType.OPENAI)
                 .taskType(TaskType.TEXT_GENERATION)
-                .build();
+                .build());
 
-        response = QueryResponse.builder()
+        responseRef.set(QueryResponse.builder()
                 .response("Test response")
                 .model(ModelType.OPENAI)
-                .build();
+                .build());
     }
 
     @Test
     void get_cacheDisabled_returnsNull() {
         CacheService disabledCache = new CacheService(false, 300, 1000, new ObjectMapper());
-        disabledCache.set(request, response);
+        disabledCache.set(requestRef.get(), responseRef.get());
 
-        assertNull(disabledCache.get(request));
+        assertNull(disabledCache.get(requestRef.get()));
     }
 
     @Test
     void get_cacheEnabled_cacheMiss_returnsNull() {
-        assertNull(cacheService.get(request));
+        assertNull(cacheServiceRef.get().get(requestRef.get()));
     }
 
     @Test
     void get_cacheEnabled_cacheHit_returnsResponse() {
+        CacheService cacheService = cacheServiceRef.get();
+        QueryRequest request = requestRef.get();
+        QueryResponse response = responseRef.get();
+
         cacheService.set(request, response);
 
         QueryResponse cachedResponse = cacheService.get(request);
@@ -58,13 +76,17 @@ void get_cacheEnabled_cacheHit_returnsResponse() {
 
     @Test
     void set_cacheDisabled_doesNothing() {
         CacheService disabledCache = new CacheService(false, 300, 1000, new ObjectMapper());
-        disabledCache.set(request, response);
+        disabledCache.set(requestRef.get(), responseRef.get());
 
-        assertNull(disabledCache.get(request));
+        assertNull(disabledCache.get(requestRef.get()));
     }
 
     @Test
     void set_cacheEnabled_storesResponse() {
+        CacheService cacheService = cacheServiceRef.get();
+        QueryRequest request = requestRef.get();
+        QueryResponse response = responseRef.get();
+
         cacheService.set(request, response);
 
         QueryResponse cachedResponse = cacheService.get(request);
@@ -74,6 +96,9 @@ void set_cacheEnabled_storesResponse() {
 
     @Test
     void generateCacheKey_differentQueries_differentKeys() {
+        CacheService cacheService = cacheServiceRef.get();
+        QueryResponse response = responseRef.get();
+
         QueryRequest request1 = QueryRequest.builder()
                 .query("Query 1")
                 .model(ModelType.OPENAI)
@@ -92,6 +117,9 @@ void generateCacheKey_differentQueries_differentKeys() {
 
     @Test
     void generateCacheKey_differentModels_differentKeys() {
+        CacheService cacheService = cacheServiceRef.get();
+        QueryResponse response = responseRef.get();
+
         QueryRequest request1 = QueryRequest.builder()
                 .query("Test query")
                 .model(ModelType.OPENAI)
@@ -110,6 +138,9 @@ void generateCacheKey_differentModels_differentKeys() {
 
     @Test
     void generateCacheKey_differentTaskTypes_differentKeys() {
+        CacheService cacheService = cacheServiceRef.get();
+        QueryResponse response = responseRef.get();
+
         QueryRequest request1 = QueryRequest.builder()
                 .query("Test query")
                 .taskType(TaskType.TEXT_GENERATION)
@@ -125,4 +156,277 @@ void generateCacheKey_differentTaskTypes_differentKeys() {
 
         assertNotNull(cacheService.get(request1));
         assertNull(cacheService.get(request2));
     }
-}
+
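+    // The three generateCacheKey_* tests above pin down the cache-key contract:
+    // two requests share an entry only if query, model, and taskType all match.
+    // A key derivation consistent with this behaviour (the actual CacheService
+    // implementation may differ) would be hashing the three fields together,
+    // e.g. sha256(query + "|" + model + "|" + taskType).
+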
+    @Test
+    @Execution(ExecutionMode.SAME_THREAD)
+    void cacheExpiration_shortDuration_entriesExpireQuickly() throws InterruptedException {
+        // Set up cache with very short expiration time (1 second)
+        int shortExpirationTime = 1; // seconds
+        CacheService shortCache = new CacheService(true, shortExpirationTime, 1000, new ObjectMapper());
+        QueryRequest request = requestRef.get();
+        QueryResponse response = responseRef.get();
+
+        logger.info("Testing cache with expiration time: {} seconds", shortExpirationTime);
+
+        shortCache.set(request, response);
+
+        // Verify item is in the cache
+        assertNotNull(shortCache.get(request));
+
+        // Wait for the entry to expire
+        TimeUnit.SECONDS.sleep(shortExpirationTime + 1);
+
+        // Verify item is no longer in the cache
+        assertNull(shortCache.get(request),
+                "Cache entry should have expired after " + shortExpirationTime + " seconds");
+    }
+
+    @Test
+    @Execution(ExecutionMode.SAME_THREAD)
+    void cacheExpiration_moderateDuration_entriesExpireAfterTimeout() throws InterruptedException {
+        // Set up cache with moderate expiration time (2 seconds)
+        int moderateExpirationTime = 2; // seconds
+        CacheService moderateCache = new CacheService(true, moderateExpirationTime, 1000, new ObjectMapper());
+        QueryRequest request = requestRef.get();
+        QueryResponse response = responseRef.get();
+
+        logger.info("Testing cache with expiration time: {} seconds", moderateExpirationTime);
+
+        moderateCache.set(request, response);
+
+        // Verify item is in the cache
+        assertNotNull(moderateCache.get(request));
+
+        // Check that it's still there after half the expiration time
+        TimeUnit.MILLISECONDS.sleep(moderateExpirationTime * 500);
+        assertNotNull(moderateCache.get(request),
+                "Cache entry should still exist after half the expiration time");
+
+        // Wait for the entry to expire
+        TimeUnit.MILLISECONDS.sleep(moderateExpirationTime * 1500);
+
+        // Verify item is no longer in the cache
+        assertNull(moderateCache.get(request),
+                "Cache entry should have expired after " + moderateExpirationTime + " seconds");
+    }
+
+    @Test
+    @Execution(ExecutionMode.SAME_THREAD)
+    void cacheExpiration_longDuration_entriesRemainValid() throws InterruptedException {
+        // Set up cache with longer expiration time
+        int longExpirationTime = 5; // seconds
+        CacheService longCache = new CacheService(true, longExpirationTime, 1000, new ObjectMapper());
+        QueryRequest request = requestRef.get();
+        QueryResponse response = responseRef.get();
+
+        logger.info("Testing cache with expiration time: {} seconds", longExpirationTime);
+
+        longCache.set(request, response);
+
+        // Verify item is in the cache initially
+        assertNotNull(longCache.get(request));
+
+        // Wait for some time (but less than expiration)
+        TimeUnit.SECONDS.sleep(1);
+
+        // Verify item is still in the cache
+        assertNotNull(longCache.get(request),
+                "Cache entry should still exist before expiration time");
+    }
+
+    @Test
+    @Tag("performance")
+    void performanceBenchmark_lookupSpeed() {
+        CacheService cacheService = cacheServiceRef.get();
+
+        // Warm-up phase
+        int warmupItems = 20;
+        List<QueryRequest> warmupRequests = generateUniqueRequests(warmupItems);
+        for (int i = 0; i < warmupItems; i++) {
+            QueryResponse resp = QueryResponse.builder()
+                    .response("Warmup response " + i)
+                    .model(ModelType.OPENAI)
+                    .build();
+            cacheService.set(warmupRequests.get(i), resp);
+            cacheService.get(warmupRequests.get(i)); // Warm up get operation
+        }
+
+        // Prepare a pre-populated cache
+        int numItems = 100;
+        List<QueryRequest> requests = new ArrayList<>();
+
+        // Generate and cache requests
+        for (int i = 0; i < numItems; i++) {
+            QueryRequest req = QueryRequest.builder()
+                    .query("Performance test query " + i)
+                    .model(ModelType.OPENAI)
+                    .taskType(TaskType.TEXT_GENERATION)
+                    .requestId(UUID.randomUUID().toString())
+                    .build();
+
+            QueryResponse resp = QueryResponse.builder()
+                    .response("Test response " + i)
+                    .model(ModelType.OPENAI)
+                    .build();
+
+            cacheService.set(req, resp);
+            requests.add(req);
+        }
+
+        // Benchmark lookups with multiple iterations for more accurate results
+        int iterations = 5;
+        double[] iterationTimes = new double[iterations];
+
+        for (int iter = 0; iter < iterations; iter++) {
+            long startTime = System.nanoTime();
+            for (QueryRequest req : requests) {
+                QueryResponse resp = cacheService.get(req);
+                assertNotNull(resp);
+            }
+            long endTime = System.nanoTime();
+
+            iterationTimes[iter] = (endTime - startTime) / (double) (numItems * 1_000_000);
+        }
+
+        // Calculate median time (more stable than average)
+        java.util.Arrays.sort(iterationTimes);
+        double medianLookupTimeMs = iterationTimes[iterations / 2];
+        double avgLookupTimeMs = java.util.Arrays.stream(iterationTimes).average().orElse(0);
+
+        logger.info("Cache lookup performance: Median={} ms, Avg={} ms",
+                String.format("%.3f", medianLookupTimeMs),
+                String.format("%.3f", avgLookupTimeMs));
+    }
+
+    @Test
+    @Tag("performance")
+    void performanceBenchmark_insertionSpeed() {
+        CacheService cacheService = cacheServiceRef.get();
+        int numItems = 1000;
+        List<QueryRequest> requests = generateUniqueRequests(numItems);
+
+        // Warm-up phase
+        int warmupItems = 50;
+        for (int i = 0; i < warmupItems; i++) {
+            QueryResponse resp = QueryResponse.builder()
+                    .response("Warmup response " + i)
+                    .model(ModelType.OPENAI)
+                    .build();
+            cacheService.set(requests.get(i), resp);
+        }
+
+        // Benchmark insertions with multiple iterations
+        int iterations = 5;
+        double[] iterationTimes = new double[iterations];
+
+        for (int iter = 0; iter < iterations; iter++) {
+            // Create a fresh cache for each iteration to avoid measuring cache size impact
+            CacheService freshCache = new CacheService(true, 300, 1000, new ObjectMapper());
+
+            long startTime = System.nanoTime();
+            for (int i = 0; i < numItems; i++) {
+                QueryResponse resp = QueryResponse.builder()
+                        .response("Performance test response " + i)
+                        .model(ModelType.OPENAI)
+                        .build();
+
+                freshCache.set(requests.get(i), resp);
+            }
+            long endTime = System.nanoTime();
+
+            iterationTimes[iter] = (endTime - startTime) / (double) (numItems * 1_000_000);
+        }
+
+        // Calculate median time
+        java.util.Arrays.sort(iterationTimes);
+        double medianInsertTimeMs = iterationTimes[iterations / 2];
+        double avgInsertTimeMs = java.util.Arrays.stream(iterationTimes).average().orElse(0);
+
+        logger.info("Cache insertion performance: Median={} ms, Avg={} ms",
+                String.format("%.3f", medianInsertTimeMs),
+                String.format("%.3f", avgInsertTimeMs));
+    }
+
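+    // Unit note for the two benchmarks above: System.nanoTime() deltas are in
+    // nanoseconds, so dividing by (numItems * 1_000_000.0) yields the average
+    // cost per operation in milliseconds (1 ms = 1_000_000 ns). The median is
+    // reported alongside the mean because it is less sensitive to a single GC
+    // pause or JIT compilation spike.
+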
+    @Test
+    @Tag("performance")
+    void performanceBenchmark_scalingWithSize() {
+        // Test with different cache sizes
+        int[] cacheSizes = {100, 1000, 5000};
+
+        for (int size : cacheSizes) {
+            CacheService sizedCache = new CacheService(true, 300, size, new ObjectMapper());
+            List<QueryRequest> requests = generateUniqueRequests(size);
+
+            // Warmup phase
+            int warmupItems = Math.min(size / 10, 100);
+            for (int i = 0; i < warmupItems; i++) {
+                QueryResponse resp = QueryResponse.builder()
+                        .response("Warmup response " + i)
+                        .model(ModelType.OPENAI)
+                        .build();
+
+                sizedCache.set(requests.get(i), resp);
+                sizedCache.get(requests.get(i)); // Warm up get operation
+            }
+
+            // Measure insertion speed
+            long insertStart = System.nanoTime();
+            for (int i = 0; i < size; i++) {
+                QueryResponse resp = QueryResponse.builder()
+                        .response("Scaling test response " + i)
+                        .model(ModelType.OPENAI)
+                        .build();
+
+                sizedCache.set(requests.get(i), resp);
+            }
+            long insertEnd = System.nanoTime();
+
+            // Measure lookup speed (random access)
+            int lookupCount = Math.min(size, 1000);
+            long[] lookupTimes = new long[lookupCount];
+
+            for (int i = 0; i < lookupCount; i++) {
+                int index = (int) (Math.random() * size);
+                long start = System.nanoTime();
+                sizedCache.get(requests.get(index));
+                lookupTimes[i] = System.nanoTime() - start;
+            }
+
+            double totalInsertTimeMs = (insertEnd - insertStart) / 1_000_000.0;
+            double avgInsertTimeMs = totalInsertTimeMs / size;
+            double avgLookupTimeNs = java.util.Arrays.stream(lookupTimes).average().orElse(0);
+            double avgLookupTimeMs = avgLookupTimeNs / 1_000_000.0;
+            double p95LookupTimeMs = calculatePercentile(lookupTimes, 95) / 1_000_000.0;
+
+            logger.info("Cache size {}: Avg insert={} ms, Avg lookup={} ms, P95 lookup={} ms",
+                    size,
+                    String.format("%.3f", avgInsertTimeMs),
+                    String.format("%.3f", avgLookupTimeMs),
+                    String.format("%.3f", p95LookupTimeMs));
+        }
+    }
+
+    // Nearest-rank percentile: index = ceil(p/100 * n) - 1 over the sorted samples
+    private double calculatePercentile(long[] times, double percentile) {
+        java.util.Arrays.sort(times);
+        int index = (int) Math.ceil(percentile / 100.0 * times.length) - 1;
+        return times[index];
+    }
+
+    // Helper method to generate unique requests for performance testing
+    private List<QueryRequest> generateUniqueRequests(int count) {
+        List<QueryRequest> requests = new ArrayList<>(count);
+        for (int i = 0; i < count; i++) {
+            QueryRequest req = QueryRequest.builder()
+                    .query("Generated query " + i)
+                    .model(i % 2 == 0 ? ModelType.OPENAI : ModelType.GEMINI)
+                    .taskType(i % 4 == 0 ? TaskType.TEXT_GENERATION :
+                              i % 4 == 1 ? TaskType.SUMMARIZATION :
+                              i % 4 == 2 ? TaskType.SENTIMENT_ANALYSIS :
+                                           TaskType.QUESTION_ANSWERING)
+                    .requestId(UUID.randomUUID().toString())
+                    .build();
+            requests.add(req);
+        }
+        return requests;
+    }
+}
\ No newline at end of file
diff --git a/src/test/java/com/llmproxy/service/ratelimit/RateLimiterServiceTest.java b/src/test/java/com/llmproxy/service/ratelimit/RateLimiterServiceTest.java
index ea890db..2417c09 100644
--- a/src/test/java/com/llmproxy/service/ratelimit/RateLimiterServiceTest.java
+++ b/src/test/java/com/llmproxy/service/ratelimit/RateLimiterServiceTest.java
@@ -1,7 +1,13 @@
 package com.llmproxy.service.ratelimit;
 
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.DisplayName;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.*;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 
@@ -9,39 +15,51 @@
 class RateLimiterServiceTest {
 
+    private ExecutorService executor;
+
+    @BeforeEach
+    void setupExecutor() {
+        executor = Executors.newCachedThreadPool();
+    }
+
+    @AfterEach
+    void teardownExecutor() {
+        executor.shutdownNow();
+    }
+
     @Test
     void allow_withinLimit_returnsTrue() {
         RateLimiterService limiter = new RateLimiterService(60, 10);
-
+
         assertTrue(limiter.allow());
     }
 
     @Test
     void allow_exceedsLimit_returnsFalse() {
         RateLimiterService limiter = new RateLimiterService(60, 3);
-
+
         assertTrue(limiter.allow());
         assertTrue(limiter.allow());
         assertTrue(limiter.allow());
-
+
        assertFalse(limiter.allow());
    }
 
     @Test
     void allowClient_withinLimit_returnsTrue() {
         RateLimiterService limiter = new RateLimiterService(60, 10);
-
+
         assertTrue(limiter.allowClient("client1"));
     }
 
     @Test
     void allowClient_differentClients_separateLimits() {
         RateLimiterService limiter = new RateLimiterService(60, 2);
-
+
         assertTrue(limiter.allowClient("client1"));
         assertTrue(limiter.allowClient("client1"));
         assertFalse(limiter.allowClient("client1"));
-
+
         assertTrue(limiter.allowClient("client2"));
         assertTrue(limiter.allowClient("client2"));
         assertFalse(limiter.allowClient("client2"));
     }
 
@@ -50,17 +68,216 @@ void allowClient_differentClients_separateLimits() {
     @Test
     void allowClient_withCustomFunction_usesFunction() {
         RateLimiterService limiter = new RateLimiterService(60, 10);
-
+
         AtomicInteger callCount = new AtomicInteger(0);
         Function<String, Boolean> customFunc = clientId -> {
             callCount.incrementAndGet();
             return "allowed".equals(clientId);
         };
-
+
         limiter.setAllowClientFunc(customFunc);
-
+
         assertTrue(limiter.allowClient("allowed"));
         assertFalse(limiter.allowClient("blocked"));
         assertEquals(2, callCount.get());
     }
-}
+
+    // --- PERFORMANCE BENCHMARK TESTS ---
+
+    @Test
+    @DisplayName("Throughput benchmark: single-threaded, default/global limiter")
+    void throughput_singleThread_globalLimiter() {
+        RateLimiterService limiter = new RateLimiterService(600, 100); // 600/min = 10 QPS refill, burst of 100
+        int attempts = 100;
+        long start = System.nanoTime();
+        int allowed = 0;
+        for (int i = 0; i < attempts; i++) {
+            if (limiter.allow()) allowed++;
+        }
+        long elapsedMs = (System.nanoTime() - start) / 1_000_000;
+        // Within burst, all should be allowed
+        assertEquals(attempts, allowed);
+        System.out.println("[throughput_singleThread_globalLimiter] Time: " + elapsedMs + " ms for " + allowed + " ops");
+    }
+
+    @Test
+    @DisplayName("Throughput benchmark: multi-threaded, global limiter, burst exhaustion")
+    void throughput_multiThread_globalLimiter_burstLimit() throws InterruptedException {
+        final int burst = 20;
+        final int threads = 8;
+        final RateLimiterService limiter = new RateLimiterService(600, burst); // 10 QPS, burst of 20
+        AtomicInteger accepted = new AtomicInteger(0);
+        AtomicInteger rejected = new AtomicInteger(0);
+        int attemptsPerThread = 10;
+
+        List<Callable<Void>> tasks = new ArrayList<>();
+        for (int t = 0; t < threads; t++) {
+            tasks.add(() -> {
+                for (int i = 0; i < attemptsPerThread; i++) {
+                    if (limiter.allow()) accepted.incrementAndGet();
+                    else rejected.incrementAndGet();
+                }
+                return null;
+            });
+        }
+        long start = System.nanoTime();
+        executor.invokeAll(tasks);
+        long elapsedMs = (System.nanoTime() - start) / 1_000_000;
+        assertEquals(burst, accepted.get());
+        assertEquals(threads * attemptsPerThread - burst, rejected.get());
+        System.out.println("[throughput_multiThread_globalLimiter_burstLimit] Time: " + elapsedMs + " ms, Accepted: " + accepted.get() + ", Rejected: " + rejected.get());
+    }
+
+    @Test
+    @DisplayName("Latency benchmark: multi-threaded, global limiter")
+    void latency_multiThreaded_globalLimiter() throws InterruptedException {
+        final int burst = 40;
+        final int threads = 20;
+        final RateLimiterService limiter = new RateLimiterService(1000, burst);
+        final CyclicBarrier barrier = new CyclicBarrier(threads);
+        List<Long> latencies = new CopyOnWriteArrayList<>();
+
+        List<Callable<Void>> tasks = new ArrayList<>();
+        for (int t = 0; t < threads; t++) {
+            tasks.add(() -> {
+                barrier.await();
+                long t0 = System.nanoTime();
+                limiter.allow();
+                latencies.add(System.nanoTime() - t0);
+                return null;
+            });
+        }
+        executor.invokeAll(tasks);
+
+        double avgLatencyUs = latencies.stream().mapToLong(x -> x).average().orElse(0) / 1000.0;
+        long maxLatencyUs = latencies.stream().mapToLong(x -> x).max().orElse(0) / 1000;
+        System.out.println("[latency_multiThreaded_globalLimiter] Average: " + avgLatencyUs + " μs, Max: " + maxLatencyUs + " μs");
+        assertTrue(avgLatencyUs < 5000); // Should be low for single op in-memory
+    }
+
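+    // The CyclicBarrier above releases all worker threads at once, so each timed
+    // allow() call runs under maximal contention; without the barrier the threads
+    // would trickle in and the reported latencies would be optimistic.
+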
+    @Test
+    @DisplayName("Throughput & isolation: multi-threaded, client-specific, distinct limits")
+    void throughput_multiClient_isolatedLimits() throws InterruptedException {
+        final int clientCount = 10;
+        final int burst = 5;
+        final RateLimiterService limiter = new RateLimiterService(100, burst); // generous refill for test
+        List<String> clientIds = new ArrayList<>();
+        for (int i = 0; i < clientCount; i++) clientIds.add("client" + i);
+
+        AtomicInteger allowed = new AtomicInteger(0);
+        AtomicInteger rejected = new AtomicInteger(0);
+
+        List<Callable<Void>> tasks = new ArrayList<>();
+        for (String client : clientIds) {
+            tasks.add(() -> {
+                // Let's try 10 attempts per client
+                for (int i = 0; i < 10; i++) {
+                    if (limiter.allowClient(client)) {
+                        allowed.incrementAndGet();
+                    } else {
+                        rejected.incrementAndGet();
+                    }
+                }
+                return null;
+            });
+        }
+
+        long start = System.nanoTime();
+        executor.invokeAll(tasks);
+        long elapsedMs = (System.nanoTime() - start) / 1_000_000;
+
+        // Each client can burst up to 5, so allowed = 5 * clientCount, rejected = rest
+        assertEquals(burst * clientCount, allowed.get());
+        assertEquals((10 - burst) * clientCount, rejected.get());
+        System.out.println("[throughput_multiClient_isolatedLimits] Time: " + elapsedMs + " ms, Allowed: " + allowed.get() + ", Rejected: " + rejected.get());
+    }
+
+    @Test
+    @DisplayName("Client-specific: variable client counts & request rates")
+    void clientSpecific_variablePatterns_benchmark() throws InterruptedException {
+        final int[] clientCounts = {1, 5, 20, 50};
+        final int burst = 20;
+        for (int c = 0; c < clientCounts.length; c++) {
+            int numClients = clientCounts[c];
+            RateLimiterService limiter = new RateLimiterService(200, burst);
+            AtomicInteger allowed = new AtomicInteger(0);
+            AtomicInteger rejected = new AtomicInteger(0);
+            List<Callable<Void>> tasks = new ArrayList<>();
+            for (int i = 0; i < numClients; i++) {
+                String client = "client-" + i;
+                tasks.add(() -> {
+                    for (int j = 0; j < burst + 10; j++) {
+                        if (limiter.allowClient(client)) allowed.incrementAndGet();
+                        else rejected.incrementAndGet();
+                    }
+                    return null;
+                });
+            }
+            long start = System.nanoTime();
+            executor.invokeAll(tasks);
+            long elapsedMs = (System.nanoTime() - start) / 1_000_000;
+            int expectedAllowed = burst * numClients;
+            int expectedRejected = 10 * numClients;
+            assertEquals(expectedAllowed, allowed.get());
+            assertEquals(expectedRejected, rejected.get());
+            System.out.printf("[clientSpecific_variablePatterns_benchmark] Clients: %d, Time: %d ms, Allowed: %d, Rejected: %d%n", numClients, elapsedMs, allowed.get(), rejected.get());
+        }
+    }
+
+    @Test
+    @DisplayName("Resource usage: stress test with many clients and threads")
+    void resourceUsage_stress_highConcurrency_manyClients() throws InterruptedException {
+        final int clientCount = 100;
+        final int threads = 50;
+        final int burst = 10;
+        RateLimiterService limiter = new RateLimiterService(1000, burst);
+        AtomicInteger allowed = new AtomicInteger(0);
+        AtomicInteger rejected = new AtomicInteger(0);
+
+        List<String> clientIds = new ArrayList<>();
+        for (int i = 0; i < clientCount; i++) clientIds.add("client-" + i);
+
+        List<Callable<Void>> tasks = new ArrayList<>();
+        for (int t = 0; t < threads; t++) {
+            tasks.add(() -> {
+                for (String clientId : clientIds) {
+                    // Each thread does 5 attempts per client (should be over burst for most)
+                    for (int att = 0; att < 5; att++) {
+                        if (limiter.allowClient(clientId)) allowed.incrementAndGet();
+                        else rejected.incrementAndGet();
+                    }
+                }
+                return null;
+            });
+        }
+        long memBefore = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
+        long start = System.nanoTime();
+        executor.invokeAll(tasks);
+        long elapsedMs = (System.nanoTime() - start) / 1_000_000;
+        long memAfter = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
+        long memDeltaKb = (memAfter - memBefore) / 1024;
+        System.out.printf("[resourceUsage_stress_highConcurrency_manyClients] Time: %d ms, Memory used: %d KB, Allowed: %d, Rejected: %d%n",
+                elapsedMs, memDeltaKb, allowed.get(), rejected.get());
+        // We cannot assert exactly due to concurrent race but numbers should be reasonable
+        assertTrue(allowed.get() > 0);
+        assertTrue(rejected.get() > 0);
+        assertTrue(memDeltaKb >= 0);
+    }
+
+    @Test
+    @DisplayName("Latency under exhaustion: repeated over-limit calls")
+    void latency_underExhaustion() {
+        RateLimiterService limiter = new RateLimiterService(60, 3);
+
+        // Exhaust burst limit
+        for (int i = 0; i < 3; i++) assertTrue(limiter.allow());
+        assertFalse(limiter.allow());
+
+        // Now, measure latency when over limit
+        long t0 = System.nanoTime();
+        boolean allowed = limiter.allow();
+        long elapsedUs = (System.nanoTime() - t0) / 1_000;
+        assertFalse(allowed);
+        System.out.println("[latency_underExhaustion] Over-limit check latency: " + elapsedUs + " μs (should be very low)");
+    }
+
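+    // A rejected allow() should amount to a token check with no blocking or
+    // sleeping, which is why the over-limit latency measured above is expected
+    // to stay in the low-microsecond range on typical hardware.
+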
+}
\ No newline at end of file
diff --git a/src/test/java/com/llmproxy/service/router/RouterServiceTest.java b/src/test/java/com/llmproxy/service/router/RouterServiceTest.java
index d61851c..00c41ff 100644
--- a/src/test/java/com/llmproxy/service/router/RouterServiceTest.java
+++ b/src/test/java/com/llmproxy/service/router/RouterServiceTest.java
@@ -9,10 +9,19 @@
 import com.llmproxy.service.llm.LlmClientFactory;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Tag;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
 import static org.junit.jupiter.api.Assertions.*;
 import static org.mockito.Mockito.*;
 
@@ -228,4 +237,251 @@ void fallbackOnError_withUserSpecifiedModel_usesThatModel() {
 
         assertEquals(ModelType.MISTRAL, result);
     }
-}
+
+    // Performance benchmarking tests
+
+    @Test
+    @Tag("performance")
+    void benchmark_routeRequest_singleThread() {
+        // Setup - all models available
+        routerService.setModelAvailability(ModelType.OPENAI, true);
+        routerService.setModelAvailability(ModelType.GEMINI, true);
+        routerService.setModelAvailability(ModelType.MISTRAL, true);
+        routerService.setModelAvailability(ModelType.CLAUDE, true);
+
+        QueryRequest request = QueryRequest.builder()
+                .query("Performance test query")
+                .build();
+
+        // Warm up
+        for (int i = 0; i < 100; i++) {
+            routerService.routeRequest(request);
+        }
+
+        // Benchmark
+        int numRequests = 10000;
+        long startTime = System.nanoTime();
+
+        for (int i = 0; i < numRequests; i++) {
+            routerService.routeRequest(request);
+        }
+
+        long endTime = System.nanoTime();
+        double elapsedTimeMs = (endTime - startTime) / 1_000_000.0;
+        double requestsPerSecond = numRequests / (elapsedTimeMs / 1000.0);
+
+        System.out.println("Single thread performance:");
+        System.out.println("Requests processed: " + numRequests);
+        System.out.println("Total time (ms): " + elapsedTimeMs);
+        System.out.println("Requests per second: " + requestsPerSecond);
+
+        // Simple assertion to verify test ran without error
+        assertTrue(requestsPerSecond > 0);
+    }
+
+    @Test
+    @Tag("performance")
+    void benchmark_routeRequest_multiThread() throws InterruptedException {
+        // Setup - all models available
+        routerService.setModelAvailability(ModelType.OPENAI, true);
+        routerService.setModelAvailability(ModelType.GEMINI, true);
+        routerService.setModelAvailability(ModelType.MISTRAL, true);
+        routerService.setModelAvailability(ModelType.CLAUDE, true);
+
+        int numThreads = 8;
+        int requestsPerThread = 1000;
+        int totalRequests = numThreads * requestsPerThread;
+
+        QueryRequest request = QueryRequest.builder()
+                .query("Performance test query")
+                .build();
+
+        // Warm up
+        for (int i = 0; i < 100; i++) {
+            routerService.routeRequest(request);
+        }
+
+        ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
+        CountDownLatch startLatch = new CountDownLatch(1);
+        CountDownLatch endLatch = new CountDownLatch(numThreads);
+        AtomicInteger successCounter = new AtomicInteger(0);
+        AtomicInteger errorCounter = new AtomicInteger(0);
+
+        for (int t = 0; t < numThreads; t++) {
+            executorService.submit(() -> {
+                try {
+                    startLatch.await(); // Wait for all threads to be ready
+
+                    for (int i = 0; i < requestsPerThread; i++) {
+                        try {
+                            routerService.routeRequest(request);
+                            successCounter.incrementAndGet();
+                        } catch (Exception e) {
+                            errorCounter.incrementAndGet();
+                        }
+                    }
+                } catch (Exception e) {
+                    e.printStackTrace();
+                } finally {
+                    endLatch.countDown();
+                }
+            });
+        }
+
+        long startTime = System.nanoTime();
+        startLatch.countDown(); // Start all threads
+
+        // Wait for all threads to finish
+        boolean completed = endLatch.await(30, TimeUnit.SECONDS);
+        long endTime = System.nanoTime();
+
+        executorService.shutdown();
+
+        double elapsedTimeMs = (endTime - startTime) / 1_000_000.0;
+        double requestsPerSecond = totalRequests / (elapsedTimeMs / 1000.0);
+
+        System.out.println("Multi-thread performance (" + numThreads + " threads):");
+        System.out.println("Requests processed: " + totalRequests);
+        System.out.println("Successful requests: " + successCounter.get());
+        System.out.println("Failed requests: " + errorCounter.get());
+        System.out.println("Total time (ms): " + elapsedTimeMs);
+        System.out.println("Requests per second: " + requestsPerSecond);
+
+        assertTrue(completed, "Benchmark timed out");
+        assertEquals(totalRequests, successCounter.get(), "Some requests failed");
+        assertTrue(requestsPerSecond > 0);
+    }
+
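+    // The two-latch pattern above is a standard start-gate/finish-gate harness:
+    // startLatch.countDown() releases every worker at the same instant, and the
+    // bounded endLatch.await(30, TimeUnit.SECONDS) fails the benchmark instead
+    // of hanging the build if a worker wedges.
+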
+    @Test
+    @Tag("performance")
+    void benchmark_differentAvailabilityScenarios() {
+        List<AvailabilityScenario> scenarios = new ArrayList<>();
+
+        // All models available
+        scenarios.add(new AvailabilityScenario(
+                "All models available",
+                true, true, true, true
+        ));
+
+        // Only one model available
+        scenarios.add(new AvailabilityScenario(
+                "Only OpenAI available",
+                true, false, false, false
+        ));
+
+        scenarios.add(new AvailabilityScenario(
+                "Only Gemini available",
+                false, true, false, false
+        ));
+
+        // Two models available
+        scenarios.add(new AvailabilityScenario(
+                "OpenAI and Claude available",
+                true, false, false, true
+        ));
+
+        // Benchmark each scenario
+        int requestsPerScenario = 5000;
+        QueryRequest request = QueryRequest.builder()
+                .query("Performance test query")
+                .build();
+
+        for (AvailabilityScenario scenario : scenarios) {
+            routerService.setModelAvailability(ModelType.OPENAI, scenario.openaiAvailable);
+            routerService.setModelAvailability(ModelType.GEMINI, scenario.geminiAvailable);
+            routerService.setModelAvailability(ModelType.MISTRAL, scenario.mistralAvailable);
+            routerService.setModelAvailability(ModelType.CLAUDE, scenario.claudeAvailable);
+
+            // Warm up
+            for (int i = 0; i < 100; i++) {
+                try {
+                    routerService.routeRequest(request);
+                } catch (Exception e) {
+                    // Ignore - may happen if no models available
+                }
+            }
+
+            long startTime = System.nanoTime();
+            int successCount = 0;
+
+            for (int i = 0; i < requestsPerScenario; i++) {
+                try {
+                    routerService.routeRequest(request);
+                    successCount++;
+                } catch (Exception e) {
+                    // Expected if no models available
+                }
+            }
+
+            long endTime = System.nanoTime();
+            double elapsedTimeMs = (endTime - startTime) / 1_000_000.0;
+            double requestsPerSecond = successCount / (elapsedTimeMs / 1000.0);
+
+            System.out.println("Scenario: " + scenario.name);
+            System.out.println("  Successful requests: " + successCount + "/" + requestsPerScenario);
+            System.out.println("  Total time (ms): " + elapsedTimeMs);
+            System.out.println("  Requests per second: " + requestsPerSecond);
+        }
+
+        // No assertions needed - this is a benchmark
+        assertTrue(true);
+    }
+
+    @Test
+    @Tag("performance")
+    void benchmark_fallbackOnError() {
+        // Setup - multiple models available for fallback
+        routerService.setModelAvailability(ModelType.OPENAI, true);
+        routerService.setModelAvailability(ModelType.GEMINI, true);
+        routerService.setModelAvailability(ModelType.MISTRAL, true);
+        routerService.setModelAvailability(ModelType.CLAUDE, true);
+
+        QueryRequest request = QueryRequest.builder()
+                .query("Performance test query")
+                .build();
+
+        ModelError retryableError = ModelError.rateLimitError(ModelType.OPENAI.toString());
+
+        // Warm up
+        for (int i = 0; i < 100; i++) {
+            routerService.fallbackOnError(ModelType.OPENAI, request, retryableError);
+        }
+
+        // Benchmark
+        int numRequests = 5000;
+        long startTime = System.nanoTime();
+
+        for (int i = 0; i < numRequests; i++) {
+            routerService.fallbackOnError(ModelType.OPENAI, request, retryableError);
+        }
+
+        long endTime = System.nanoTime();
+        double elapsedTimeMs = (endTime - startTime) / 1_000_000.0;
+        double fallbacksPerSecond = numRequests / (elapsedTimeMs / 1000.0);
+
+        System.out.println("Fallback performance:");
+        System.out.println("Fallbacks processed: " + numRequests);
+        System.out.println("Total time (ms): " + elapsedTimeMs);
+        System.out.println("Fallbacks per second: " + fallbacksPerSecond);
+
+        assertTrue(fallbacksPerSecond > 0);
+    }
+
+    // Helper class for availability scenarios
+    private static class AvailabilityScenario {
+        final String name;
+        final boolean openaiAvailable;
+        final boolean geminiAvailable;
+        final boolean mistralAvailable;
+        final boolean claudeAvailable;
+
+        AvailabilityScenario(String name, boolean openai, boolean gemini,
+                             boolean mistral, boolean claude) {
+            this.name = name;
+            this.openaiAvailable = openai;
+            this.geminiAvailable = gemini;
+            this.mistralAvailable = mistral;
+            this.claudeAvailable = claude;
+        }
+    }
+}
\ No newline at end of file

From b80523acfad7eb00f2bdad5848cfac77c602fa62 Mon Sep 17 00:00:00 2001
From: Mike Basios
Date: Wed, 14 May 2025 00:36:07 +0100
Subject: [PATCH 4/6] feat: add benchmarks

---
 .../service/cache/CacheServiceTest.java       | 239 +----------------
 .../ratelimit/RateLimiterServiceTest.java     | 201 +++++++++++++-
 .../service/router/RouterServiceTest.java     | 246 +++++++++++++++++-
 3 files changed, 446 insertions(+), 240 deletions(-)

diff --git a/src/test/java/com/llmproxy/service/cache/CacheServiceTest.java b/src/test/java/com/llmproxy/service/cache/CacheServiceTest.java
index 2ae9643..fc939a7 100644
--- a/src/test/java/com/llmproxy/service/cache/CacheServiceTest.java
+++ b/src/test/java/com/llmproxy/service/cache/CacheServiceTest.java
@@ -1,239 +1,3 @@
-package com.llmproxy.service.cache;
-
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.llmproxy.model.ModelType;
-import com.llmproxy.model.QueryRequest;
-import com.llmproxy.model.QueryResponse;
-import com.llmproxy.model.TaskType;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.Tag;
-import org.junit.jupiter.api.parallel.Execution;
-import org.junit.jupiter.api.parallel.ExecutionMode;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.UUID;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
-
-import static org.junit.jupiter.api.Assertions.*;
-
-@Execution(ExecutionMode.CONCURRENT)
-class CacheServiceTest {
-
-    private static final Logger logger = LoggerFactory.getLogger(CacheServiceTest.class);
-
-    private final AtomicReference<CacheService> cacheServiceRef = new AtomicReference<>();
-    private final AtomicReference<QueryRequest> requestRef = new AtomicReference<>();
-    private final AtomicReference<QueryResponse> responseRef = new AtomicReference<>();
-
-    @BeforeEach
-    void setUp() {
-        cacheServiceRef.set(new CacheService(true, 300, 1000, new ObjectMapper()));
-
-        requestRef.set(QueryRequest.builder()
-                .query("Test query")
-                .model(ModelType.OPENAI)
-                .taskType(TaskType.TEXT_GENERATION)
-                .build());
-
-        responseRef.set(QueryResponse.builder()
-                .response("Test response")
-                .model(ModelType.OPENAI)
-                .build());
-    }
-
-    @Test
-    void get_cacheDisabled_returnsNull() {
-        CacheService disabledCache = new CacheService(false, 300, 1000, new ObjectMapper());
-        disabledCache.set(requestRef.get(), responseRef.get());
-
-        assertNull(disabledCache.get(requestRef.get()));
-    }
-
-    @Test
-    void get_cacheEnabled_cacheMiss_returnsNull() {
-        assertNull(cacheServiceRef.get().get(requestRef.get()));
-    }
-
-    @Test
-    void get_cacheEnabled_cacheHit_returnsResponse() {
-        CacheService cacheService = cacheServiceRef.get();
-        QueryRequest request = requestRef.get();
-        QueryResponse response = responseRef.get();
-
-        cacheService.set(request, response);
-
-        QueryResponse cachedResponse = cacheService.get(request);
-        assertNotNull(cachedResponse);
-        assertEquals(response.getResponse(), cachedResponse.getResponse());
-        assertEquals(response.getModel(), cachedResponse.getModel());
-    }
-
-    @Test
-    void set_cacheDisabled_doesNothing() {
-        CacheService disabledCache = new CacheService(false, 300, 1000, new ObjectMapper());
-        disabledCache.set(requestRef.get(), responseRef.get());
-
-        assertNull(disabledCache.get(requestRef.get()));
-    }
-
-    @Test
-    void set_cacheEnabled_storesResponse() {
-        CacheService cacheService = cacheServiceRef.get();
-        QueryRequest request = requestRef.get();
-        QueryResponse response = responseRef.get();
-
-        cacheService.set(request, response);
-
-        QueryResponse cachedResponse = cacheService.get(request);
-        assertNotNull(cachedResponse);
-        assertEquals(response.getResponse(), cachedResponse.getResponse());
-    }
-
-    @Test
-    void generateCacheKey_differentQueries_differentKeys() {
-        CacheService cacheService = cacheServiceRef.get();
-        QueryResponse response = responseRef.get();
-
-        QueryRequest request1 = QueryRequest.builder()
-                .query("Query 1")
-                .model(ModelType.OPENAI)
-                .build();
-
-        QueryRequest request2 = QueryRequest.builder()
-                .query("Query 2")
-                .model(ModelType.OPENAI)
-                .build();
-
-        cacheService.set(request1, response);
-
-        assertNotNull(cacheService.get(request1));
-        assertNull(cacheService.get(request2));
-    }
-
-    @Test
-    void generateCacheKey_differentModels_differentKeys() {
-        CacheService cacheService = cacheServiceRef.get();
-        QueryResponse response = responseRef.get();
-
-        QueryRequest request1 = QueryRequest.builder()
-                .query("Test query")
-                .model(ModelType.OPENAI)
-                .build();
-
-        QueryRequest request2 = QueryRequest.builder()
-                .query("Test query")
-                .model(ModelType.GEMINI)
-                .build();
-
-        cacheService.set(request1, response);
-
-        assertNotNull(cacheService.get(request1));
-        assertNull(cacheService.get(request2));
-    }
-
-    @Test
-    void generateCacheKey_differentTaskTypes_differentKeys() {
-        CacheService cacheService = cacheServiceRef.get();
-        QueryResponse response = responseRef.get();
-
-        QueryRequest request1 = QueryRequest.builder()
-                .query("Test query")
-                .taskType(TaskType.TEXT_GENERATION)
-                .build();
-
-        QueryRequest request2 = QueryRequest.builder()
-                .query("Test query")
-                .taskType(TaskType.SUMMARIZATION)
-                .build();
-
-        cacheService.set(request1, response);
-
-        assertNotNull(cacheService.get(request1));
-        assertNull(cacheService.get(request2));
-    }
-
-    @Test
-    @Execution(ExecutionMode.SAME_THREAD)
-    void cacheExpiration_shortDuration_entriesExpireQuickly() throws InterruptedException {
-        // Set up cache with very short expiration time (1 second)
-        int shortExpirationTime = 1; // seconds
-        CacheService shortCache = new CacheService(true, shortExpirationTime, 1000, new ObjectMapper());
-        QueryRequest request = requestRef.get();
-        QueryResponse response = responseRef.get();
-
-        logger.info("Testing cache with expiration time: {} seconds", shortExpirationTime);
-
-        shortCache.set(request, response);
-
-        // Verify item is in the cache
-        assertNotNull(shortCache.get(request));
-
-        // Wait for the entry to expire
-        TimeUnit.SECONDS.sleep(shortExpirationTime + 1);
-
-        // Verify item is no longer in the cache
-        assertNull(shortCache.get(request),
-                "Cache entry should have expired after " + shortExpirationTime + " seconds");
-    }
-
-    @Test
-    @Execution(ExecutionMode.SAME_THREAD)
-    void cacheExpiration_moderateDuration_entriesExpireAfterTimeout() throws InterruptedException {
-        // Set up cache with moderate expiration time (2 seconds)
-        int moderateExpirationTime = 2; // seconds
-        CacheService moderateCache = new CacheService(true, moderateExpirationTime, 1000, new ObjectMapper());
-        QueryRequest request = requestRef.get();
-        QueryResponse response = responseRef.get();
-
-        logger.info("Testing cache with expiration time: {} seconds", moderateExpirationTime);
-
-        moderateCache.set(request, response);
-
-        // Verify item is in the cache
-        assertNotNull(moderateCache.get(request));
-
-        // Check that it's still there after half the expiration time
-        TimeUnit.MILLISECONDS.sleep(moderateExpirationTime * 500);
-        assertNotNull(moderateCache.get(request),
-                "Cache entry should still exist after half the expiration time");
-
-        // Wait for the entry to expire
-        TimeUnit.MILLISECONDS.sleep(moderateExpirationTime * 1500);
-
-        // Verify item is no longer in the cache
-        assertNull(moderateCache.get(request),
-                "Cache entry should have expired after " + moderateExpirationTime + " seconds");
-    }
-
than expiration) - TimeUnit.SECONDS.sleep(1); - - // Verify item is still in the cache - assertNotNull(longCache.get(request), - "Cache entry should still exist before expiration time"); - } - @Test @Tag("performance") void performanceBenchmark_lookupSpeed() { @@ -428,5 +192,4 @@ private List generateUniqueRequests(int count) { requests.add(req); } return requests; - } -} \ No newline at end of file + } \ No newline at end of file diff --git a/src/test/java/com/llmproxy/service/ratelimit/RateLimiterServiceTest.java b/src/test/java/com/llmproxy/service/ratelimit/RateLimiterServiceTest.java index 144fd0f..c9b66c3 100644 --- a/src/test/java/com/llmproxy/service/ratelimit/RateLimiterServiceTest.java +++ b/src/test/java/com/llmproxy/service/ratelimit/RateLimiterServiceTest.java @@ -133,4 +133,203 @@ void allowClient_zeroBurst_alwaysReturnsFalseForClients() { // Test for client2 assertFalse(parentLimiter.allowClient("client2")); // Client limiter inherits zero burst } -} \ No newline at end of file + + // --- PERFORMANCE BENCHMARK TESTS --- + + @Test + @DisplayName("Throughput benchmark: single-threaded, default/global limiter") + void throughput_singleThread_globalLimiter() { + RateLimiterService limiter = new RateLimiterService(600, 100); // 10 QPS global for 10 seconds window + int attempts = 100; + long start = System.nanoTime(); + int allowed = 0; + for (int i = 0; i < attempts; i++) { + if (limiter.allow()) allowed++; + } + long elapsedMs = (System.nanoTime() - start) / 1_000_000; + // Within burst, all should be allowed + assertEquals(attempts, allowed); + System.out.println("[throughput_singleThread_globalLimiter] Time: " + elapsedMs + " ms for " + allowed + " ops"); + } + + @Test + @DisplayName("Throughput benchmark: multi-threaded, global limiter, burst exhaustion") + void throughput_multiThread_globalLimiter_burstLimit() throws InterruptedException { + final int burst = 20; + final int threads = 8; + final RateLimiterService limiter = new RateLimiterService(600, burst); // 10 QPS, burst of 20 + AtomicInteger accepted = new AtomicInteger(0); + AtomicInteger rejected = new AtomicInteger(0); + int attemptsPerThread = 10; + + List> tasks = new ArrayList<>(); + for (int t = 0; t < threads; t++) { + tasks.add(() -> { + for (int i = 0; i < attemptsPerThread; i++) { + if (limiter.allow()) accepted.incrementAndGet(); + else rejected.incrementAndGet(); + } + return null; + }); + } + long start = System.nanoTime(); + executor.invokeAll(tasks); + long elapsedMs = (System.nanoTime() - start) / 1_000_000; + assertEquals(burst, accepted.get()); + assertEquals(threads * attemptsPerThread - burst, rejected.get()); + System.out.println("[throughput_multiThread_globalLimiter_burstLimit] Time: " + elapsedMs + " ms, Accepted: " + accepted.get() + ", Rejected: " + rejected.get()); + } + + @Test + @DisplayName("Latency benchmark: multi-threaded, global limiter") + void latency_multiThreaded_globalLimiter() throws InterruptedException { + final int burst = 40; + final int threads = 20; + final RateLimiterService limiter = new RateLimiterService(1000, burst); + final CyclicBarrier barrier = new CyclicBarrier(threads); + List latencies = new CopyOnWriteArrayList<>(); + + List> tasks = new ArrayList<>(); + for (int t = 0; t < threads; t++) { + tasks.add(() -> { + barrier.await(); + long t0 = System.nanoTime(); + limiter.allow(); + latencies.add(System.nanoTime() - t0); + return null; + }); + } + executor.invokeAll(tasks); + + double avgLatencyUs = latencies.stream().mapToLong(x -> x).average().orElse(0) / 
1000.0;
+        long maxLatencyUs = latencies.stream().mapToLong(x -> x).max().orElse(0) / 1000;
+        System.out.println("[latency_multiThreaded_globalLimiter] Average: " + avgLatencyUs + " μs, Max: " + maxLatencyUs + " μs");
+        assertTrue(avgLatencyUs < 5000); // should be low for a single in-memory operation
+    }
+
+    @Test
+    @DisplayName("Throughput & isolation: multi-threaded, client-specific, distinct limits")
+    void throughput_multiClient_isolatedLimits() throws InterruptedException {
+        final int clientCount = 10;
+        final int burst = 5;
+        final RateLimiterService limiter = new RateLimiterService(100, burst); // refill far slower than the test window, so only the burst is served
+        List<String> clientIds = new ArrayList<>();
+        for (int i = 0; i < clientCount; i++) clientIds.add("client" + i);
+
+        AtomicInteger allowed = new AtomicInteger(0);
+        AtomicInteger rejected = new AtomicInteger(0);
+
+        List<Callable<Void>> tasks = new ArrayList<>();
+        for (String client : clientIds) {
+            tasks.add(() -> {
+                // Try 10 attempts per client
+                for (int i = 0; i < 10; i++) {
+                    if (limiter.allowClient(client)) {
+                        allowed.incrementAndGet();
+                    } else {
+                        rejected.incrementAndGet();
+                    }
+                }
+                return null;
+            });
+        }
+
+        long start = System.nanoTime();
+        executor.invokeAll(tasks);
+        long elapsedMs = (System.nanoTime() - start) / 1_000_000;
+
+        // Each client can burst up to 5, so allowed = 5 * clientCount, rejected = rest
+        assertEquals(burst * clientCount, allowed.get());
+        assertEquals((10 - burst) * clientCount, rejected.get());
+        System.out.println("[throughput_multiClient_isolatedLimits] Time: " + elapsedMs + " ms, Allowed: " + allowed.get() + ", Rejected: " + rejected.get());
+    }
+
+    @Test
+    @DisplayName("Client-specific: variable client counts & request rates")
+    void clientSpecific_variablePatterns_benchmark() throws InterruptedException {
+        final int[] clientCounts = {1, 5, 20, 50};
+        final int burst = 20;
+        for (int c = 0; c < clientCounts.length; c++) {
+            int numClients = clientCounts[c];
+            RateLimiterService limiter = new RateLimiterService(200, burst);
+            AtomicInteger allowed = new AtomicInteger(0);
+            AtomicInteger rejected = new AtomicInteger(0);
+            List<Callable<Void>> tasks = new ArrayList<>();
+            for (int i = 0; i < numClients; i++) {
+                String client = "client-" + i;
+                tasks.add(() -> {
+                    for (int j = 0; j < burst + 10; j++) {
+                        if (limiter.allowClient(client)) allowed.incrementAndGet();
+                        else rejected.incrementAndGet();
+                    }
+                    return null;
+                });
+            }
+            long start = System.nanoTime();
+            executor.invokeAll(tasks);
+            long elapsedMs = (System.nanoTime() - start) / 1_000_000;
+            int expectedAllowed = burst * numClients;
+            int expectedRejected = 10 * numClients;
+            assertEquals(expectedAllowed, allowed.get());
+            assertEquals(expectedRejected, rejected.get());
+            System.out.printf("[clientSpecific_variablePatterns_benchmark] Clients: %d, Time: %d ms, Allowed: %d, Rejected: %d%n", numClients, elapsedMs, allowed.get(), rejected.get());
+        }
+    }
+
+    @Test
+    @DisplayName("Resource usage: stress test with many clients and threads")
+    void resourceUsage_stress_highConcurrency_manyClients() throws InterruptedException {
+        final int clientCount = 100;
+        final int threads = 50;
+        final int burst = 10;
+        RateLimiterService limiter = new RateLimiterService(1000, burst);
+        AtomicInteger allowed = new AtomicInteger(0);
+        AtomicInteger rejected = new AtomicInteger(0);
+
+        List<String> clientIds = new ArrayList<>();
+        for (int i = 0; i < clientCount; i++) clientIds.add("client-" + i);
+
+        List<Callable<Void>> tasks = new ArrayList<>();
+        for (int t = 0; t < threads; t++) {
+            tasks.add(() -> {
+                for (String clientId : clientIds) {
+                    // Each thread does 5 attempts per client (likely exceeding the burst for most clients)
+                    for (int att = 0; att < 5; att++) {
+                        if (limiter.allowClient(clientId)) allowed.incrementAndGet();
+                        else rejected.incrementAndGet();
+                    }
+                }
+                return null;
+            });
+        }
+        long memBefore = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
+        long start = System.nanoTime();
+        executor.invokeAll(tasks);
+        long elapsedMs = (System.nanoTime() - start) / 1_000_000;
+        long memAfter = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
+        long memDeltaKb = (memAfter - memBefore) / 1024;
+        System.out.printf("[resourceUsage_stress_highConcurrency_manyClients] Time: %d ms, Memory used: %d KB, Allowed: %d, Rejected: %d%n",
+                elapsedMs, memDeltaKb, allowed.get(), rejected.get());
+        // Exact counts are nondeterministic under this level of concurrency; sanity-check the totals instead
+        assertTrue(allowed.get() > 0);
+        assertTrue(rejected.get() > 0);
+        assertTrue(memDeltaKb >= 0);
+    }
+
+    @Test
+    @DisplayName("Latency under exhaustion: repeated over-limit calls")
+    void latency_underExhaustion() {
+        RateLimiterService limiter = new RateLimiterService(60, 3);
+
+        // Exhaust burst limit
+        for (int i = 0; i < 3; i++) assertTrue(limiter.allow());
+        assertFalse(limiter.allow());
+
+        // Measure latency of a rejected, over-limit call
+        long t0 = System.nanoTime();
+        boolean allowed = limiter.allow();
+        long elapsedUs = (System.nanoTime() - t0) / 1_000;
+        assertFalse(allowed);
+        System.out.println("[latency_underExhaustion] Over-limit check latency: " + elapsedUs + " μs (expected to be very low)");
+    }
+}
diff --git a/src/test/java/com/llmproxy/service/router/RouterServiceTest.java b/src/test/java/com/llmproxy/service/router/RouterServiceTest.java
index 3ead254..4dc1db8 100644
--- a/src/test/java/com/llmproxy/service/router/RouterServiceTest.java
+++ b/src/test/java/com/llmproxy/service/router/RouterServiceTest.java
@@ -329,4 +329,248 @@ void fallbackOnError_specifiedModelEqualsFailedModel_returnsAlternative() {
         ModelType result = routerService.fallbackOnError(ModelType.MISTRAL, request, error);
         assertTrue(result == ModelType.GEMINI || result == ModelType.CLAUDE);
     }
-}
\ No newline at end of file
+    @Test
+    @Tag("performance")
+    void benchmark_routeRequest_singleThread() {
+        // Setup - all models available
+        routerService.setModelAvailability(ModelType.OPENAI, true);
+        routerService.setModelAvailability(ModelType.GEMINI, true);
+        routerService.setModelAvailability(ModelType.MISTRAL, true);
+        routerService.setModelAvailability(ModelType.CLAUDE, true);
+
+        QueryRequest request = QueryRequest.builder()
+                .query("Performance test query")
+                .build();
+
+        // Warm up
+        for (int i = 0; i < 100; i++) {
+            routerService.routeRequest(request);
+        }
+
+        // Benchmark
+        int numRequests = 10000;
+        long startTime = System.nanoTime();
+
+        for (int i = 0; i < numRequests; i++) {
+            routerService.routeRequest(request);
+        }
+
+        long endTime = System.nanoTime();
+        double elapsedTimeMs = (endTime - startTime) / 1_000_000.0;
+        double requestsPerSecond = numRequests / (elapsedTimeMs / 1000.0);
+
+        System.out.println("Single thread performance:");
+        System.out.println("Requests processed: " + numRequests);
+        System.out.println("Total time (ms): " + elapsedTimeMs);
+        System.out.println("Requests per second: " + requestsPerSecond);
+
+        // Simple assertion to verify test ran without error
+        assertTrue(requestsPerSecond > 0);
+    }
+
+    @Test
+    @Tag("performance")
+    void benchmark_routeRequest_multiThread() throws InterruptedException {
+        // Setup - all models
available + routerService.setModelAvailability(ModelType.OPENAI, true); + routerService.setModelAvailability(ModelType.GEMINI, true); + routerService.setModelAvailability(ModelType.MISTRAL, true); + routerService.setModelAvailability(ModelType.CLAUDE, true); + + int numThreads = 8; + int requestsPerThread = 1000; + int totalRequests = numThreads * requestsPerThread; + + QueryRequest request = QueryRequest.builder() + .query("Performance test query") + .build(); + + // Warm up + for (int i = 0; i < 100; i++) { + routerService.routeRequest(request); + } + + ExecutorService executorService = Executors.newFixedThreadPool(numThreads); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch endLatch = new CountDownLatch(numThreads); + AtomicInteger successCounter = new AtomicInteger(0); + AtomicInteger errorCounter = new AtomicInteger(0); + + for (int t = 0; t < numThreads; t++) { + executorService.submit(() -> { + try { + startLatch.await(); // Wait for all threads to be ready + + for (int i = 0; i < requestsPerThread; i++) { + try { + routerService.routeRequest(request); + successCounter.incrementAndGet(); + } catch (Exception e) { + errorCounter.incrementAndGet(); + } + } + } catch (Exception e) { + e.printStackTrace(); + } finally { + endLatch.countDown(); + } + }); + } + + long startTime = System.nanoTime(); + startLatch.countDown(); // Start all threads + + // Wait for all threads to finish + boolean completed = endLatch.await(30, TimeUnit.SECONDS); + long endTime = System.nanoTime(); + + executorService.shutdown(); + + double elapsedTimeMs = (endTime - startTime) / 1_000_000.0; + double requestsPerSecond = totalRequests / (elapsedTimeMs / 1000.0); + + System.out.println("Multi-thread performance (" + numThreads + " threads):"); + System.out.println("Requests processed: " + totalRequests); + System.out.println("Successful requests: " + successCounter.get()); + System.out.println("Failed requests: " + errorCounter.get()); + System.out.println("Total time (ms): " + elapsedTimeMs); + System.out.println("Requests per second: " + requestsPerSecond); + + assertTrue(completed, "Benchmark timed out"); + assertEquals(totalRequests, successCounter.get(), "Some requests failed"); + assertTrue(requestsPerSecond > 0); + } + + @Test + @Tag("performance") + void benchmark_differentAvailabilityScenarios() { + List scenarios = new ArrayList<>(); + + // All models available + scenarios.add(new AvailabilityScenario( + "All models available", + true, true, true, true + )); + + // Only one model available + scenarios.add(new AvailabilityScenario( + "Only OpenAI available", + true, false, false, false + )); + + scenarios.add(new AvailabilityScenario( + "Only Gemini available", + false, true, false, false + )); + + // Two models available + scenarios.add(new AvailabilityScenario( + "OpenAI and Claude available", + true, false, false, true + )); + + // Benchmark each scenario + int requestsPerScenario = 5000; + QueryRequest request = QueryRequest.builder() + .query("Performance test query") + .build(); + + for (AvailabilityScenario scenario : scenarios) { + routerService.setModelAvailability(ModelType.OPENAI, scenario.openaiAvailable); + routerService.setModelAvailability(ModelType.GEMINI, scenario.geminiAvailable); + routerService.setModelAvailability(ModelType.MISTRAL, scenario.mistralAvailable); + routerService.setModelAvailability(ModelType.CLAUDE, scenario.claudeAvailable); + + // Warm up + for (int i = 0; i < 100; i++) { + try { + routerService.routeRequest(request); + } catch (Exception e) { 
+ // Ignore - may happen if no models available + } + } + + long startTime = System.nanoTime(); + int successCount = 0; + + for (int i = 0; i < requestsPerScenario; i++) { + try { + routerService.routeRequest(request); + successCount++; + } catch (Exception e) { + // Expected if no models available + } + } + + long endTime = System.nanoTime(); + double elapsedTimeMs = (endTime - startTime) / 1_000_000.0; + double requestsPerSecond = successCount / (elapsedTimeMs / 1000.0); + + System.out.println("Scenario: " + scenario.name); + System.out.println(" Successful requests: " + successCount + "/" + requestsPerScenario); + System.out.println(" Total time (ms): " + elapsedTimeMs); + System.out.println(" Requests per second: " + requestsPerSecond); + } + + // No assertions needed - this is a benchmark + assertTrue(true); + } + + @Test + @Tag("performance") + void benchmark_fallbackOnError() { + // Setup - multiple models available for fallback + routerService.setModelAvailability(ModelType.OPENAI, true); + routerService.setModelAvailability(ModelType.GEMINI, true); + routerService.setModelAvailability(ModelType.MISTRAL, true); + routerService.setModelAvailability(ModelType.CLAUDE, true); + + QueryRequest request = QueryRequest.builder() + .query("Performance test query") + .build(); + + ModelError retryableError = ModelError.rateLimitError(ModelType.OPENAI.toString()); + + // Warm up + for (int i = 0; i < 100; i++) { + routerService.fallbackOnError(ModelType.OPENAI, request, retryableError); + } + + // Benchmark + int numRequests = 5000; + long startTime = System.nanoTime(); + + for (int i = 0; i < numRequests; i++) { + routerService.fallbackOnError(ModelType.OPENAI, request, retryableError); + } + + long endTime = System.nanoTime(); + double elapsedTimeMs = (endTime - startTime) / 1_000_000.0; + double fallbacksPerSecond = numRequests / (elapsedTimeMs / 1000.0); + + System.out.println("Fallback performance:"); + System.out.println("Fallbacks processed: " + numRequests); + System.out.println("Total time (ms): " + elapsedTimeMs); + System.out.println("Fallbacks per second: " + fallbacksPerSecond); + + assertTrue(fallbacksPerSecond > 0); + } + + // Helper class for availability scenarios + private static class AvailabilityScenario { + final String name; + final boolean openaiAvailable; + final boolean geminiAvailable; + final boolean mistralAvailable; + final boolean claudeAvailable; + + AvailabilityScenario(String name, boolean openai, boolean gemini, + boolean mistral, boolean claude) { + this.name = name; + this.openaiAvailable = openai; + this.geminiAvailable = gemini; + this.mistralAvailable = mistral; + this.claudeAvailable = claude; + } + } +} From 300da0e10c2f504a3369c762ddc7e3c77b35c660 Mon Sep 17 00:00:00 2001 From: Mike Basios Date: Wed, 14 May 2025 00:40:56 +0100 Subject: [PATCH 5/6] feat: add benchmarks --- .../service/cache/CacheServiceTest.java | 239 +++++++++++++++++- 1 file changed, 238 insertions(+), 1 deletion(-) diff --git a/src/test/java/com/llmproxy/service/cache/CacheServiceTest.java b/src/test/java/com/llmproxy/service/cache/CacheServiceTest.java index fc939a7..2ae9643 100644 --- a/src/test/java/com/llmproxy/service/cache/CacheServiceTest.java +++ b/src/test/java/com/llmproxy/service/cache/CacheServiceTest.java @@ -1,3 +1,239 @@ +package com.llmproxy.service.cache; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.llmproxy.model.ModelType; +import com.llmproxy.model.QueryRequest; +import com.llmproxy.model.QueryResponse; +import com.llmproxy.model.TaskType; 
+import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.api.parallel.ExecutionMode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.*; + +@Execution(ExecutionMode.CONCURRENT) +class CacheServiceTest { + + private static final Logger logger = LoggerFactory.getLogger(CacheServiceTest.class); + + private final AtomicReference cacheServiceRef = new AtomicReference<>(); + private final AtomicReference requestRef = new AtomicReference<>(); + private final AtomicReference responseRef = new AtomicReference<>(); + + @BeforeEach + void setUp() { + cacheServiceRef.set(new CacheService(true, 300, 1000, new ObjectMapper())); + + requestRef.set(QueryRequest.builder() + .query("Test query") + .model(ModelType.OPENAI) + .taskType(TaskType.TEXT_GENERATION) + .build()); + + responseRef.set(QueryResponse.builder() + .response("Test response") + .model(ModelType.OPENAI) + .build()); + } + + @Test + void get_cacheDisabled_returnsNull() { + CacheService disabledCache = new CacheService(false, 300, 1000, new ObjectMapper()); + disabledCache.set(requestRef.get(), responseRef.get()); + + assertNull(disabledCache.get(requestRef.get())); + } + + @Test + void get_cacheEnabled_cacheMiss_returnsNull() { + assertNull(cacheServiceRef.get().get(requestRef.get())); + } + + @Test + void get_cacheEnabled_cacheHit_returnsResponse() { + CacheService cacheService = cacheServiceRef.get(); + QueryRequest request = requestRef.get(); + QueryResponse response = responseRef.get(); + + cacheService.set(request, response); + + QueryResponse cachedResponse = cacheService.get(request); + assertNotNull(cachedResponse); + assertEquals(response.getResponse(), cachedResponse.getResponse()); + assertEquals(response.getModel(), cachedResponse.getModel()); + } + + @Test + void set_cacheDisabled_doesNothing() { + CacheService disabledCache = new CacheService(false, 300, 1000, new ObjectMapper()); + disabledCache.set(requestRef.get(), responseRef.get()); + + assertNull(disabledCache.get(requestRef.get())); + } + + @Test + void set_cacheEnabled_storesResponse() { + CacheService cacheService = cacheServiceRef.get(); + QueryRequest request = requestRef.get(); + QueryResponse response = responseRef.get(); + + cacheService.set(request, response); + + QueryResponse cachedResponse = cacheService.get(request); + assertNotNull(cachedResponse); + assertEquals(response.getResponse(), cachedResponse.getResponse()); + } + + @Test + void generateCacheKey_differentQueries_differentKeys() { + CacheService cacheService = cacheServiceRef.get(); + QueryResponse response = responseRef.get(); + + QueryRequest request1 = QueryRequest.builder() + .query("Query 1") + .model(ModelType.OPENAI) + .build(); + + QueryRequest request2 = QueryRequest.builder() + .query("Query 2") + .model(ModelType.OPENAI) + .build(); + + cacheService.set(request1, response); + + assertNotNull(cacheService.get(request1)); + assertNull(cacheService.get(request2)); + } + + @Test + void generateCacheKey_differentModels_differentKeys() { + CacheService cacheService = cacheServiceRef.get(); + QueryResponse response = responseRef.get(); + + QueryRequest request1 = QueryRequest.builder() + .query("Test query") + .model(ModelType.OPENAI) + 
.build(); + + QueryRequest request2 = QueryRequest.builder() + .query("Test query") + .model(ModelType.GEMINI) + .build(); + + cacheService.set(request1, response); + + assertNotNull(cacheService.get(request1)); + assertNull(cacheService.get(request2)); + } + + @Test + void generateCacheKey_differentTaskTypes_differentKeys() { + CacheService cacheService = cacheServiceRef.get(); + QueryResponse response = responseRef.get(); + + QueryRequest request1 = QueryRequest.builder() + .query("Test query") + .taskType(TaskType.TEXT_GENERATION) + .build(); + + QueryRequest request2 = QueryRequest.builder() + .query("Test query") + .taskType(TaskType.SUMMARIZATION) + .build(); + + cacheService.set(request1, response); + + assertNotNull(cacheService.get(request1)); + assertNull(cacheService.get(request2)); + } + + @Test + @Execution(ExecutionMode.SAME_THREAD) + void cacheExpiration_shortDuration_entriesExpireQuickly() throws InterruptedException { + // Set up cache with very short expiration time (1 second) + int shortExpirationTime = 1; // seconds + CacheService shortCache = new CacheService(true, shortExpirationTime, 1000, new ObjectMapper()); + QueryRequest request = requestRef.get(); + QueryResponse response = responseRef.get(); + + logger.info("Testing cache with expiration time: {} seconds", shortExpirationTime); + + shortCache.set(request, response); + + // Verify item is in the cache + assertNotNull(shortCache.get(request)); + + // Wait for the entry to expire + TimeUnit.SECONDS.sleep(shortExpirationTime + 1); + + // Verify item is no longer in the cache + assertNull(shortCache.get(request), + "Cache entry should have expired after " + shortExpirationTime + " seconds"); + } + + @Test + @Execution(ExecutionMode.SAME_THREAD) + void cacheExpiration_moderateDuration_entriesExpireAfterTimeout() throws InterruptedException { + // Set up cache with moderate expiration time (2 seconds) + int moderateExpirationTime = 2; // seconds + CacheService moderateCache = new CacheService(true, moderateExpirationTime, 1000, new ObjectMapper()); + QueryRequest request = requestRef.get(); + QueryResponse response = responseRef.get(); + + logger.info("Testing cache with expiration time: {} seconds", moderateExpirationTime); + + moderateCache.set(request, response); + + // Verify item is in the cache + assertNotNull(moderateCache.get(request)); + + // Check that it's still there after half the expiration time + TimeUnit.MILLISECONDS.sleep(moderateExpirationTime * 500); + assertNotNull(moderateCache.get(request), + "Cache entry should still exist after half the expiration time"); + + // Wait for the entry to expire + TimeUnit.MILLISECONDS.sleep(moderateExpirationTime * 1500); + + // Verify item is no longer in the cache + assertNull(moderateCache.get(request), + "Cache entry should have expired after " + moderateExpirationTime + " seconds"); + } + + @Test + @Execution(ExecutionMode.SAME_THREAD) + void cacheExpiration_longDuration_entriesRemainValid() throws InterruptedException { + // Set up cache with longer expiration time + int longExpirationTime = 5; // seconds + CacheService longCache = new CacheService(true, longExpirationTime, 1000, new ObjectMapper()); + QueryRequest request = requestRef.get(); + QueryResponse response = responseRef.get(); + + logger.info("Testing cache with expiration time: {} seconds", longExpirationTime); + + longCache.set(request, response); + + // Verify item is in the cache initially + assertNotNull(longCache.get(request)); + + // Wait for some time (but less than expiration) + 
TimeUnit.SECONDS.sleep(1);
+
+        // Verify item is still in the cache
+        assertNotNull(longCache.get(request),
+                "Cache entry should still exist before expiration time");
+    }
+
 @Test
 @Tag("performance")
 void performanceBenchmark_lookupSpeed() {
@@ -192,4 +428,5 @@ private List<QueryRequest> generateUniqueRequests(int count) {
         requests.add(req);
     }
     return requests;
-    }
\ No newline at end of file
+    }
+}
\ No newline at end of file

From a0c3f5e7e17001505c24bd9af79f4061ad51fc07 Mon Sep 17 00:00:00 2001
From: Mike Basios
Date: Wed, 14 May 2025 00:51:50 +0100
Subject: [PATCH 6/6] feat: add some more tests for the LlmProxyController

---
 .../controller/LlmProxyControllerTest.java    | 226 +++++++++---------
 1 file changed, 111 insertions(+), 115 deletions(-)

diff --git a/src/test/java/com/llmproxy/controller/LlmProxyControllerTest.java b/src/test/java/com/llmproxy/controller/LlmProxyControllerTest.java
index dd31561..5dbda44 100644
--- a/src/test/java/com/llmproxy/controller/LlmProxyControllerTest.java
+++ b/src/test/java/com/llmproxy/controller/LlmProxyControllerTest.java
@@ -21,7 +21,6 @@
 import org.springframework.http.MediaType;
 import org.springframework.http.ResponseEntity;
 import org.springframework.mock.web.MockHttpServletRequest;
-import org.junit.jupiter.api.DisplayName;
 
 import java.time.Instant;
 import java.util.Map;
@@ -45,30 +44,28 @@ class LlmProxyControllerTest {
 
     @Mock
     private RateLimiterService rateLimiterService;
-    
+
     @Mock
     private LlmClient llmClient;
-    
+
     private LlmProxyController controller;
     private MockHttpServletRequest mockRequest;
-    
+
     @BeforeEach
     void setUp() {
         controller = new LlmProxyController(routerService, clientFactory, cacheService, rateLimiterService);
         mockRequest = new MockHttpServletRequest();
         mockRequest.setRemoteAddr("127.0.0.1");
-        
         lenient().when(rateLimiterService.allowClient(anyString())).thenReturn(true);
         lenient().when(clientFactory.getClient(any(ModelType.class))).thenReturn(llmClient);
     }
 
     @Test
-    @DisplayName("Should return valid response for a valid query request")
     void query_validRequest_returnsResponse() {
         QueryRequest request = QueryRequest.builder()
                 .query("Test query")
                 .build();
-        
+
         QueryResult queryResult = QueryResult.builder()
                 .response("Test response")
                 .statusCode(HttpStatus.OK.value())
                 .model("gpt-4o")
                 .modelType(ModelType.OPENAI)
                 .inputTokens(10)
                 .outputTokens(20)
                 .numTokens(30)
                 .responseTimeMs(100)
                 .build();
-        
+
         lenient().when(cacheService.get(any(QueryRequest.class))).thenReturn(null);
         lenient().when(routerService.routeRequest(any(QueryRequest.class))).thenReturn(ModelType.OPENAI);
         lenient().when(llmClient.query(any(), any())).thenReturn(queryResult);
-        
+
         ResponseEntity<QueryResponse> response = controller.query(request, mockRequest);
-        
+
         assertEquals(HttpStatus.OK, response.getStatusCode());
         assertNotNull(response.getBody());
         assertEquals("Test response", response.getBody().getResponse());
@@ -95,14 +92,13 @@ void query_validRequest_returnsResponse() {
     }
 
     @Test
-    @DisplayName("Should return bad request response for empty query")
     void query_emptyQuery_returnsBadRequest() {
         QueryRequest request = QueryRequest.builder()
                 .query("")
                 .build();
-        
+
         ResponseEntity<QueryResponse> response = controller.query(request, mockRequest);
-        
+
         assertEquals(HttpStatus.BAD_REQUEST, response.getStatusCode());
         assertNotNull(response.getBody());
         assertEquals("Query cannot be empty", response.getBody().getError());
         assertEquals("validation_error", response.getBody().getErrorType());
     }
 
     @Test
-    @DisplayName("Should return too many requests response when rate limited")
     void query_rateLimited_returnsTooManyRequests() {
         QueryRequest request =
QueryRequest.builder() .query("Test query") .build(); - + lenient().when(rateLimiterService.allowClient(anyString())).thenReturn(false); - + ResponseEntity response = controller.query(request, mockRequest); - + assertEquals(HttpStatus.TOO_MANY_REQUESTS, response.getStatusCode()); assertNotNull(response.getBody()); assertEquals("Rate limit exceeded. Please try again later.", response.getBody().getError()); @@ -127,23 +122,22 @@ void query_rateLimited_returnsTooManyRequests() { } @Test - @DisplayName("Should return cached response when available") void query_cachedResponse_returnsCachedResponse() { QueryRequest request = QueryRequest.builder() .query("Test query") .build(); - + QueryResponse cachedResponse = QueryResponse.builder() .response("Cached response") .model(ModelType.OPENAI) .cached(true) .timestamp(Instant.now()) .build(); - + lenient().when(cacheService.get(any(QueryRequest.class))).thenReturn(cachedResponse); - + ResponseEntity response = controller.query(request, mockRequest); - + assertEquals(HttpStatus.OK, response.getStatusCode()); assertNotNull(response.getBody()); assertEquals("Cached response", response.getBody().getResponse()); @@ -152,20 +146,19 @@ void query_cachedResponse_returnsCachedResponse() { } @Test - @DisplayName("Should return error response when model error occurs") void query_modelError_returnsErrorResponse() { QueryRequest request = QueryRequest.builder() .query("Test query") .build(); - + ModelError apiKeyError = ModelError.apiKeyMissingError(ModelType.OPENAI.toString()); - + lenient().when(cacheService.get(any(QueryRequest.class))).thenReturn(null); lenient().when(routerService.routeRequest(any(QueryRequest.class))).thenReturn(ModelType.OPENAI); lenient().when(llmClient.query(any(), any())).thenThrow(apiKeyError); - + ResponseEntity response = controller.query(request, mockRequest); - + assertEquals(HttpStatus.UNAUTHORIZED, response.getStatusCode()); assertNotNull(response.getBody()); assertEquals("API key not configured", response.getBody().getError()); @@ -173,7 +166,6 @@ void query_modelError_returnsErrorResponse() { } @Test - @DisplayName("Should return model availability status") void status_returnsAvailability() { StatusResponse statusResponse = StatusResponse.builder() .openai(true) @@ -181,11 +173,11 @@ void status_returnsAvailability() { .mistral(true) .claude(false) .build(); - + lenient().when(routerService.getAvailability()).thenReturn(statusResponse); - + ResponseEntity response = controller.status(mockRequest); - + assertEquals(HttpStatus.OK, response.getStatusCode()); assertNotNull(response.getBody()); assertTrue(response.getBody().isOpenai()); @@ -195,106 +187,110 @@ void status_returnsAvailability() { } @Test - @DisplayName("Should return OK status for health endpoint") - void health_returnsOk() { - ResponseEntity> response = controller.health(mockRequest); - + void status_onlyGeminiAvailable_returnsCorrectAvailability() { + StatusResponse statusResponse = StatusResponse.builder() + .openai(false) + .gemini(true) + .mistral(false) + .claude(false) + .build(); + + when(routerService.getAvailability()).thenReturn(statusResponse); + + ResponseEntity response = controller.status(mockRequest); + assertEquals(HttpStatus.OK, response.getStatusCode()); assertNotNull(response.getBody()); - assertEquals("ok", response.getBody().get("status")); + assertFalse(response.getBody().isOpenai()); + assertTrue(response.getBody().isGemini()); + assertFalse(response.getBody().isMistral()); + assertFalse(response.getBody().isClaude()); } - + @Test - 
@DisplayName("Should return file when download request is valid") - void download_validRequest_returnsFile() { - Map request = Map.of( - "response", "Test response", - "format", "txt" - ); - - ResponseEntity response = controller.download(request, mockRequest); - + void status_mistralAndClaudeAvailable_returnsCorrectAvailability() { + StatusResponse statusResponse = StatusResponse.builder() + .openai(false) + .gemini(false) + .mistral(true) + .claude(true) + .build(); + + when(routerService.getAvailability()).thenReturn(statusResponse); + + ResponseEntity response = controller.status(mockRequest); + assertEquals(HttpStatus.OK, response.getStatusCode()); - assertEquals(MediaType.TEXT_PLAIN_VALUE, response.getHeaders().getContentType().toString()); - assertEquals("attachment; filename=llm_response.txt", response.getHeaders().getFirst("Content-Disposition")); - assertEquals("Test response", new String(response.getBody())); - } - - @Test - @DisplayName("Should return bad request when download request has empty response") - void download_emptyResponse_returnsBadRequest() { - Map request = Map.of( - "response", "", - "format", "txt" - ); - - ResponseEntity response = controller.download(request, mockRequest); - - assertEquals(HttpStatus.BAD_REQUEST, response.getStatusCode()); - } - - @Test - @DisplayName("Should return too many requests when rate limited for download") - void download_rateLimited_returnsTooManyRequests() { - Map request = Map.of( - "response", "Test response", - "format", "txt" - ); - - lenient().when(rateLimiterService.allowClient(anyString())).thenReturn(false); - - ResponseEntity response = controller.download(request, mockRequest); - - assertEquals(HttpStatus.TOO_MANY_REQUESTS, response.getStatusCode()); + assertNotNull(response.getBody()); + assertFalse(response.getBody().isOpenai()); + assertFalse(response.getBody().isGemini()); + assertTrue(response.getBody().isMistral()); + assertTrue(response.getBody().isClaude()); } - + @Test - @DisplayName("Should use default format when format not specified") - void download_nullFormat_usesDefaultFormat() { - Map request = Map.of( - "response", "Test response" - ); - - ResponseEntity response = controller.download(request, mockRequest); - + void status_allModelsAvailable_returnsCorrectAvailability() { + StatusResponse statusResponse = StatusResponse.builder() + .openai(true) + .gemini(true) + .mistral(true) + .claude(true) + .build(); + + when(routerService.getAvailability()).thenReturn(statusResponse); + + ResponseEntity response = controller.status(mockRequest); + assertEquals(HttpStatus.OK, response.getStatusCode()); - assertEquals(MediaType.TEXT_PLAIN_VALUE, response.getHeaders().getContentType().toString()); - assertEquals("attachment; filename=llm_response.txt", response.getHeaders().getFirst("Content-Disposition")); + assertNotNull(response.getBody()); + assertTrue(response.getBody().isOpenai()); + assertTrue(response.getBody().isGemini()); + assertTrue(response.getBody().isMistral()); + assertTrue(response.getBody().isClaude()); } - + @Test - @DisplayName("Should return too many requests when rate limited for status") - void status_rateLimited_returnsTooManyRequests() { - lenient().when(rateLimiterService.allowClient(anyString())).thenReturn(false); - + void status_noModelsAvailable_returnsCorrectAvailability() { + StatusResponse statusResponse = StatusResponse.builder() + .openai(false) + .gemini(false) + .mistral(false) + .claude(false) + .build(); + + when(routerService.getAvailability()).thenReturn(statusResponse); + 
ResponseEntity<StatusResponse> response = controller.status(mockRequest);
-        
-        assertEquals(HttpStatus.TOO_MANY_REQUESTS, response.getStatusCode());
+
+        assertEquals(HttpStatus.OK, response.getStatusCode());
+        assertNotNull(response.getBody());
+        assertFalse(response.getBody().isOpenai());
+        assertFalse(response.getBody().isGemini());
+        assertFalse(response.getBody().isMistral());
+        assertFalse(response.getBody().isClaude());
     }
-    
+
     @Test
-    @DisplayName("Should return too many requests when rate limited for health")
-    void health_rateLimited_returnsTooManyRequests() {
-        lenient().when(rateLimiterService.allowClient(anyString())).thenReturn(false);
-        
+    void health_returnsOk() {
         ResponseEntity<Map<String, String>> response = controller.health(mockRequest);
-        
-        assertEquals(HttpStatus.TOO_MANY_REQUESTS, response.getStatusCode());
+
+        assertEquals(HttpStatus.OK, response.getStatusCode());
+        assertNotNull(response.getBody());
+        assertEquals("ok", response.getBody().get("status"));
     }
 
     @Test
-    @DisplayName("Should return bad request for query exceeding maximum length")
-    void query_tooLongQuery_returnsBadRequest() {
-        String longQuery = "a".repeat(32001);
-        QueryRequest request = QueryRequest.builder()
-                .query(longQuery)
-                .build();
-        
-        ResponseEntity<QueryResponse> response = controller.query(request, mockRequest);
-        
-        assertEquals(HttpStatus.BAD_REQUEST, response.getStatusCode());
-        assertNotNull(response.getBody());
-        assertEquals("Query exceeds maximum length of 32000 characters", response.getBody().getError());
-        assertEquals("validation_error", response.getBody().getErrorType());
+    void download_validRequest_returnsFile() {
+        Map<String, String> request = Map.of(
+                "response", "Test response",
+                "format", "txt"
+        );
+
+        ResponseEntity<byte[]> response = controller.download(request, mockRequest);
+
+        assertEquals(HttpStatus.OK, response.getStatusCode());
+        assertEquals(MediaType.TEXT_PLAIN_VALUE, response.getHeaders().getContentType().toString());
+        assertEquals("attachment; filename=llm_response.txt", response.getHeaders().getFirst("Content-Disposition"));
+        assertEquals("Test response", new String(response.getBody()));
     }
 }
\ No newline at end of file
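
A note on the rate-limiter contract that the benchmarks above exercise: the call sites construct the service as new RateLimiterService(requestsPerMinute, burst) (for example, new RateLimiterService(600, 100) behaves as roughly 10 requests/second with a burst capacity of 100), rely on allow() for a global limit, and on allowClient(String) for lazily created per-client limits. The patches never show the implementation itself, so the following is a minimal token-bucket sketch of that assumed contract, illustrative only: the constructor semantics, the per-client bucket strategy, and the class layout are all inferred from the tests rather than taken from the project.

// Minimal token-bucket sketch of the RateLimiterService contract used above.
// ASSUMPTIONS: constructor args are (requestsPerMinute, burst); per-client
// buckets are created on first use and share the global rate/burst settings.
// This is NOT the project's implementation, only an illustration.
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class RateLimiterService {
    private final double refillPerNano; // tokens regained per nanosecond
    private final int burst;            // bucket capacity
    private final Bucket global;
    private final Map<String, Bucket> clientBuckets = new ConcurrentHashMap<>();

    public RateLimiterService(int requestsPerMinute, int burst) {
        this.refillPerNano = requestsPerMinute / 60.0 / 1_000_000_000.0;
        this.burst = burst;
        this.global = new Bucket();
    }

    // Global limiter backing allow()
    public boolean allow() {
        return global.tryAcquire();
    }

    // Per-client limiter; a zero-capacity bucket never grants a token
    public boolean allowClient(String clientId) {
        return clientBuckets.computeIfAbsent(clientId, id -> new Bucket()).tryAcquire();
    }

    private final class Bucket {
        private double tokens = burst;          // start full so the burst is honored immediately
        private long lastRefill = System.nanoTime();

        synchronized boolean tryAcquire() {
            long now = System.nanoTime();
            // Refill proportionally to elapsed time, capped at the burst capacity
            tokens = Math.min(burst, tokens + (now - lastRefill) * refillPerNano);
            lastRefill = now;
            if (tokens >= 1.0) {
                tokens -= 1.0;
                return true;
            }
            return false;
        }
    }
}

Under a bucket like this, the exact-equality assertions in the burst-exhaustion benchmarks hold only because the configured refill rates (for example, 600 per minute) add far less than one token during the few milliseconds the benchmarks run, so each bucket effectively serves just its initial burst within the measured window.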