Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 112 additions & 29 deletions src/test/java/com/llmproxy/controller/LlmProxyControllerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,18 @@ class LlmProxyControllerTest {

@Mock
private RateLimiterService rateLimiterService;

@Mock
private LlmClient llmClient;

private LlmProxyController controller;
private MockHttpServletRequest mockRequest;

@BeforeEach
void setUp() {
controller = new LlmProxyController(routerService, clientFactory, cacheService, rateLimiterService);
mockRequest = new MockHttpServletRequest();
mockRequest.setRemoteAddr("127.0.0.1");

lenient().when(rateLimiterService.allowClient(anyString())).thenReturn(true);
lenient().when(clientFactory.getClient(any(ModelType.class))).thenReturn(llmClient);
}
Expand All @@ -66,7 +65,7 @@ void query_validRequest_returnsResponse() {
QueryRequest request = QueryRequest.builder()
.query("Test query")
.build();

QueryResult queryResult = QueryResult.builder()
.response("Test response")
.statusCode(HttpStatus.OK.value())
Expand All @@ -76,13 +75,13 @@ void query_validRequest_returnsResponse() {
.numTokens(30)
.responseTimeMs(100)
.build();

lenient().when(cacheService.get(any(QueryRequest.class))).thenReturn(null);
lenient().when(routerService.routeRequest(any(QueryRequest.class))).thenReturn(ModelType.OPENAI);
lenient().when(llmClient.query(any(), any())).thenReturn(queryResult);

ResponseEntity<QueryResponse> response = controller.query(request, mockRequest);

assertEquals(HttpStatus.OK, response.getStatusCode());
assertNotNull(response.getBody());
assertEquals("Test response", response.getBody().getResponse());
Expand All @@ -97,9 +96,9 @@ void query_emptyQuery_returnsBadRequest() {
QueryRequest request = QueryRequest.builder()
.query("")
.build();

ResponseEntity<QueryResponse> response = controller.query(request, mockRequest);

assertEquals(HttpStatus.BAD_REQUEST, response.getStatusCode());
assertNotNull(response.getBody());
assertEquals("Query cannot be empty", response.getBody().getError());
Expand All @@ -111,11 +110,11 @@ void query_rateLimited_returnsTooManyRequests() {
QueryRequest request = QueryRequest.builder()
.query("Test query")
.build();

lenient().when(rateLimiterService.allowClient(anyString())).thenReturn(false);

ResponseEntity<QueryResponse> response = controller.query(request, mockRequest);

assertEquals(HttpStatus.TOO_MANY_REQUESTS, response.getStatusCode());
assertNotNull(response.getBody());
assertEquals("Rate limit exceeded. Please try again later.", response.getBody().getError());
Expand All @@ -127,18 +126,18 @@ void query_cachedResponse_returnsCachedResponse() {
QueryRequest request = QueryRequest.builder()
.query("Test query")
.build();

QueryResponse cachedResponse = QueryResponse.builder()
.response("Cached response")
.model(ModelType.OPENAI)
.cached(true)
.timestamp(Instant.now())
.build();

lenient().when(cacheService.get(any(QueryRequest.class))).thenReturn(cachedResponse);

ResponseEntity<QueryResponse> response = controller.query(request, mockRequest);

assertEquals(HttpStatus.OK, response.getStatusCode());
assertNotNull(response.getBody());
assertEquals("Cached response", response.getBody().getResponse());
Expand All @@ -151,15 +150,15 @@ void query_modelError_returnsErrorResponse() {
QueryRequest request = QueryRequest.builder()
.query("Test query")
.build();

ModelError apiKeyError = ModelError.apiKeyMissingError(ModelType.OPENAI.toString());

lenient().when(cacheService.get(any(QueryRequest.class))).thenReturn(null);
lenient().when(routerService.routeRequest(any(QueryRequest.class))).thenReturn(ModelType.OPENAI);
lenient().when(llmClient.query(any(), any())).thenThrow(apiKeyError);

ResponseEntity<QueryResponse> response = controller.query(request, mockRequest);

assertEquals(HttpStatus.UNAUTHORIZED, response.getStatusCode());
assertNotNull(response.getBody());
assertEquals("API key not configured", response.getBody().getError());
Expand All @@ -174,11 +173,11 @@ void status_returnsAvailability() {
.mistral(true)
.claude(false)
.build();

lenient().when(routerService.getAvailability()).thenReturn(statusResponse);

ResponseEntity<StatusResponse> response = controller.status(mockRequest);

assertEquals(HttpStatus.OK, response.getStatusCode());
assertNotNull(response.getBody());
assertTrue(response.getBody().isOpenai());
Expand All @@ -187,27 +186,111 @@ void status_returnsAvailability() {
assertFalse(response.getBody().isClaude());
}

@Test
void status_onlyGeminiAvailable_returnsCorrectAvailability() {
StatusResponse statusResponse = StatusResponse.builder()
.openai(false)
.gemini(true)
.mistral(false)
.claude(false)
.build();

when(routerService.getAvailability()).thenReturn(statusResponse);

ResponseEntity<StatusResponse> response = controller.status(mockRequest);

assertEquals(HttpStatus.OK, response.getStatusCode());
assertNotNull(response.getBody());
assertFalse(response.getBody().isOpenai());
assertTrue(response.getBody().isGemini());
assertFalse(response.getBody().isMistral());
assertFalse(response.getBody().isClaude());
}

@Test
void status_mistralAndClaudeAvailable_returnsCorrectAvailability() {
StatusResponse statusResponse = StatusResponse.builder()
.openai(false)
.gemini(false)
.mistral(true)
.claude(true)
.build();

when(routerService.getAvailability()).thenReturn(statusResponse);

ResponseEntity<StatusResponse> response = controller.status(mockRequest);

assertEquals(HttpStatus.OK, response.getStatusCode());
assertNotNull(response.getBody());
assertFalse(response.getBody().isOpenai());
assertFalse(response.getBody().isGemini());
assertTrue(response.getBody().isMistral());
assertTrue(response.getBody().isClaude());
}

@Test
void status_allModelsAvailable_returnsCorrectAvailability() {
StatusResponse statusResponse = StatusResponse.builder()
.openai(true)
.gemini(true)
.mistral(true)
.claude(true)
.build();

when(routerService.getAvailability()).thenReturn(statusResponse);

ResponseEntity<StatusResponse> response = controller.status(mockRequest);

assertEquals(HttpStatus.OK, response.getStatusCode());
assertNotNull(response.getBody());
assertTrue(response.getBody().isOpenai());
assertTrue(response.getBody().isGemini());
assertTrue(response.getBody().isMistral());
assertTrue(response.getBody().isClaude());
}

@Test
void status_noModelsAvailable_returnsCorrectAvailability() {
StatusResponse statusResponse = StatusResponse.builder()
.openai(false)
.gemini(false)
.mistral(false)
.claude(false)
.build();

when(routerService.getAvailability()).thenReturn(statusResponse);

ResponseEntity<StatusResponse> response = controller.status(mockRequest);

assertEquals(HttpStatus.OK, response.getStatusCode());
assertNotNull(response.getBody());
assertFalse(response.getBody().isOpenai());
assertFalse(response.getBody().isGemini());
assertFalse(response.getBody().isMistral());
assertFalse(response.getBody().isClaude());
}

@Test
void health_returnsOk() {
ResponseEntity<Map<String, Object>> response = controller.health(mockRequest);

assertEquals(HttpStatus.OK, response.getStatusCode());
assertNotNull(response.getBody());
assertEquals("ok", response.getBody().get("status"));
}

@Test
void download_validRequest_returnsFile() {
Map<String, String> request = Map.of(
"response", "Test response",
"format", "txt"
);

ResponseEntity<byte[]> response = controller.download(request, mockRequest);

assertEquals(HttpStatus.OK, response.getStatusCode());
assertEquals(MediaType.TEXT_PLAIN_VALUE, response.getHeaders().getContentType().toString());
assertEquals("attachment; filename=llm_response.txt", response.getHeaders().getFirst("Content-Disposition"));
assertEquals("Test response", new String(response.getBody()));
}
}
}
Loading