From a4f4199f64059656e63d30371338b85bf0547262 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20K=C3=A1konyi?= Date: Wed, 11 Mar 2026 09:16:53 +0100 Subject: [PATCH 01/13] Fix Vertex AI listSessions null handling --- .../adk/sessions/VertexAiSessionService.java | 10 +++--- .../sessions/VertexAiSessionServiceTest.java | 34 +++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 4336f96c9..b62add27a 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -128,15 +128,17 @@ public Single listSessions(String appName, String userId) .map( listSessionsResponseMap -> parseListSessionsResponse(listSessionsResponseMap, appName, userId)) - .defaultIfEmpty(ListSessionsResponse.builder().build()); + .defaultIfEmpty(ListSessionsResponse.builder().sessions(new ArrayList<>()).build()); } private ListSessionsResponse parseListSessionsResponse( JsonNode listSessionsResponseMap, String appName, String userId) { + JsonNode sessionsNode = listSessionsResponseMap.get("sessions"); + if (sessionsNode == null || sessionsNode.isNull() || sessionsNode.isEmpty()) { + return ListSessionsResponse.builder().sessions(new ArrayList<>()).build(); + } List> apiSessions = - objectMapper.convertValue( - listSessionsResponseMap.get("sessions"), - new TypeReference>>() {}); + objectMapper.convertValue(sessionsNode, new TypeReference>>() {}); List sessions = new ArrayList<>(); for (Map apiSession : apiSessions) { diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index def4faf4c..3dab94b46 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -25,6 +25,8 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import okhttp3.MediaType; +import okhttp3.ResponseBody; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -37,6 +39,20 @@ public class VertexAiSessionServiceTest { private static final ObjectMapper mapper = JsonBaseModel.getMapper(); + private static final MediaType JSON_MEDIA_TYPE = + MediaType.parse("application/json; charset=utf-8"); + + private static ApiResponse apiResponseJson(String json) { + return new ApiResponse() { + @Override + public ResponseBody getResponseBody() { + return ResponseBody.create(JSON_MEDIA_TYPE, json); + } + + @Override + public void close() {} + }; + } private static final String MOCK_SESSION_STRING_1 = """ @@ -319,6 +335,24 @@ public void listSessions_empty() { .isEmpty(); } + @Test + public void listSessions_missingSessionsField_returnsEmpty() { + when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=userX", "")) + .thenReturn(apiResponseJson("{}")); + + assertThat(vertexAiSessionService.listSessions("123", "userX").blockingGet().sessions()) + .isEmpty(); + } + + @Test + public void listSessions_nullSessionsField_returnsEmpty() { + when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=userY", "")) + .thenReturn(apiResponseJson("{\"sessions\": null}")); + + assertThat(vertexAiSessionService.listSessions("123", "userY").blockingGet().sessions()) + .isEmpty(); + } + @Test public void listEvents_empty() { assertThat(vertexAiSessionService.listEvents("789", "user1", "3").blockingGet().events()) From 70056707f42281772bd737e2c7fd5878181c7c37 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Fri, 20 Mar 2026 16:47:40 +0100 Subject: [PATCH 02/13] refactor: migrate LangChain4j to builder pattern, enhance token usage, and use JSpecify Nullable - Migrate LangChain4j to a builder pattern - Enhance token usage handling with TokenCountEstimator (from PR #623) - Upgrade to latest version of LangChain4j - Replace javax.annotation.Nullable with org.jspecify.annotations.Nullable --- .../adk/models/langchain4j/LangChain4j.java | 230 ++++++++++++------ .../LangChain4jIntegrationTest.java | 24 +- .../models/langchain4j/LangChain4jTest.java | 162 +++++++++++- pom.xml | 2 +- 4 files changed, 327 insertions(+), 91 deletions(-) diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 3ccb1e029..8279dc21a 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -23,6 +23,7 @@ import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.auto.value.AutoValue; import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; @@ -30,11 +31,11 @@ import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import com.google.genai.types.Schema; import com.google.genai.types.ToolConfig; import com.google.genai.types.Type; -import dev.langchain4j.Experimental; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.audio.Audio; @@ -52,6 +53,7 @@ import dev.langchain4j.data.pdf.PdfFile; import dev.langchain4j.data.video.Video; import dev.langchain4j.exception.UnsupportedFeatureException; +import dev.langchain4j.model.TokenCountEstimator; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -65,6 +67,7 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.output.TokenUsage; import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; @@ -72,66 +75,101 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.UUID; +import org.jspecify.annotations.Nullable; -@Experimental -public class LangChain4j extends BaseLlm { +@AutoValue +public abstract class LangChain4j extends BaseLlm { private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference<>() {}; - private final ChatModel chatModel; - private final StreamingChatModel streamingChatModel; - private final ObjectMapper objectMapper; + LangChain4j() { + super(""); + } + + @Nullable + public abstract ChatModel chatModel(); + + @Nullable + public abstract StreamingChatModel streamingChatModel(); + + public abstract ObjectMapper objectMapper(); + + public abstract String modelName(); + + @Nullable + public abstract TokenCountEstimator tokenCountEstimator(); + + @Override + public String model() { + return modelName(); + } + + public static Builder builder() { + return new AutoValue_LangChain4j.Builder().objectMapper(new ObjectMapper()); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder chatModel(ChatModel chatModel); + + public abstract Builder streamingChatModel(StreamingChatModel streamingChatModel); + + public abstract Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator); + + public abstract Builder objectMapper(ObjectMapper objectMapper); + + public abstract Builder modelName(String modelName); + + public abstract LangChain4j build(); + } public LangChain4j(ChatModel chatModel) { - super( - Objects.requireNonNull( - chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = null; - this.objectMapper = new ObjectMapper(); + this(chatModel, null, null, chatModel.defaultRequestParameters().modelName(), null); } public LangChain4j(ChatModel chatModel, String modelName) { - super(Objects.requireNonNull(modelName, "chat model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = null; - this.objectMapper = new ObjectMapper(); + this(chatModel, null, null, modelName, null); } public LangChain4j(StreamingChatModel streamingChatModel) { - super( - Objects.requireNonNull( - streamingChatModel.defaultRequestParameters().modelName(), - "streaming chat model name cannot be null")); - this.chatModel = null; - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this( + null, + streamingChatModel, + null, + streamingChatModel.defaultRequestParameters().modelName(), + null); } public LangChain4j(StreamingChatModel streamingChatModel, String modelName) { - super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null")); - this.chatModel = null; - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this(null, streamingChatModel, null, modelName, null); } public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) { - super(Objects.requireNonNull(modelName, "model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this(chatModel, streamingChatModel, null, modelName, null); + } + + private LangChain4j( + ChatModel chatModel, + StreamingChatModel streamingChatModel, + ObjectMapper objectMapper, + String modelName, + TokenCountEstimator tokenCountEstimator) { + this(); + LangChain4j.builder() + .chatModel(chatModel) + .streamingChatModel(streamingChatModel) + .objectMapper(objectMapper) + .modelName(modelName) + .tokenCountEstimator(tokenCountEstimator) + .build(); } @Override public Flowable generateContent(LlmRequest llmRequest, boolean stream) { if (stream) { - if (this.streamingChatModel == null) { + if (this.streamingChatModel() == null) { return Flowable.error(new IllegalStateException("StreamingChatModel is not configured")); } @@ -139,54 +177,57 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre return Flowable.create( emitter -> { - streamingChatModel.chat( - chatRequest, - new StreamingChatResponseHandler() { - @Override - public void onPartialResponse(String s) { - emitter.onNext( - LlmResponse.builder().content(Content.fromParts(Part.fromText(s))).build()); - } - - @Override - public void onCompleteResponse(ChatResponse chatResponse) { - if (chatResponse.aiMessage().hasToolExecutionRequests()) { - AiMessage aiMessage = chatResponse.aiMessage(); - toParts(aiMessage).stream() - .map(Part::functionCall) - .forEach( - functionCall -> { - functionCall.ifPresent( - function -> { - emitter.onNext( - LlmResponse.builder() - .content( - Content.fromParts( - Part.fromFunctionCall( - function.name().orElse(""), - function.args().orElse(Map.of())))) - .build()); - }); - }); - } - emitter.onComplete(); - } - - @Override - public void onError(Throwable throwable) { - emitter.onError(throwable); - } - }); + streamingChatModel() + .chat( + chatRequest, + new StreamingChatResponseHandler() { + @Override + public void onPartialResponse(String s) { + emitter.onNext( + LlmResponse.builder() + .content(Content.fromParts(Part.fromText(s))) + .build()); + } + + @Override + public void onCompleteResponse(ChatResponse chatResponse) { + if (chatResponse.aiMessage().hasToolExecutionRequests()) { + AiMessage aiMessage = chatResponse.aiMessage(); + toParts(aiMessage).stream() + .map(Part::functionCall) + .forEach( + functionCall -> { + functionCall.ifPresent( + function -> { + emitter.onNext( + LlmResponse.builder() + .content( + Content.fromParts( + Part.fromFunctionCall( + function.name().orElse(""), + function.args().orElse(Map.of())))) + .build()); + }); + }); + } + emitter.onComplete(); + } + + @Override + public void onError(Throwable throwable) { + emitter.onError(throwable); + } + }); }, BackpressureStrategy.BUFFER); } else { - if (this.chatModel == null) { + if (this.chatModel() == null) { return Flowable.error(new IllegalStateException("ChatModel is not configured")); } ChatRequest chatRequest = toChatRequest(llmRequest); - ChatResponse chatResponse = chatModel.chat(chatRequest); - LlmResponse llmResponse = toLlmResponse(chatResponse); + ChatResponse chatResponse = chatModel().chat(chatRequest); + LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest); return Flowable.just(llmResponse); } @@ -413,7 +454,7 @@ private AiMessage toAiMessage(Content content) { private String toJson(Object object) { try { - return objectMapper.writeValueAsString(object); + return objectMapper().writeValueAsString(object); } catch (JsonProcessingException e) { throw new RuntimeException(e); } @@ -511,11 +552,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) { } } - private LlmResponse toLlmResponse(ChatResponse chatResponse) { + private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) { Content content = Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build(); - return LlmResponse.builder().content(content).build(); + LlmResponse.Builder builder = LlmResponse.builder().content(content); + TokenUsage tokenUsage = chatResponse.tokenUsage(); + if (tokenCountEstimator() != null) { + try { + int estimatedInput = + tokenCountEstimator().estimateTokenCountInMessages(chatRequest.messages()); + int estimatedOutput = + tokenCountEstimator().estimateTokenCountInText(chatResponse.aiMessage().text()); + int estimatedTotal = estimatedInput + estimatedOutput; + builder.usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(estimatedInput) + .candidatesTokenCount(estimatedOutput) + .totalTokenCount(estimatedTotal) + .build()); + } catch (Exception e) { + e.printStackTrace(); + } + } else if (tokenUsage != null) { + builder.usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(tokenUsage.inputTokenCount()) + .candidatesTokenCount(tokenUsage.outputTokenCount()) + .totalTokenCount(tokenUsage.totalTokenCount()) + .build()); + } + + return builder.build(); } private List toParts(AiMessage aiMessage) { @@ -546,7 +614,7 @@ private List toParts(AiMessage aiMessage) { private Map toArgs(ToolExecutionRequest toolExecutionRequest) { try { - return objectMapper.readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE); + return objectMapper().readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE); } catch (JsonProcessingException e) { throw new RuntimeException(e); } diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java index 191e48017..5b6d3f3ad 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -62,7 +62,8 @@ void testSimpleAgent() { LlmAgent.builder() .name("science-app") .description("Science teacher agent") - .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) + .model( + LangChain4j.builder().chatModel(claudeModel).modelName(CLAUDE_4_6_SONNET).build()) .instruction( """ You are a helpful science teacher that explains science concepts @@ -98,7 +99,8 @@ void testSingleAgentWithTools() { LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) + .model( + LangChain4j.builder().chatModel(claudeModel).modelName(CLAUDE_4_6_SONNET).build()) .instruction( """ You are a friendly assistant. @@ -183,7 +185,7 @@ void testAgentTool() { LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly assistant. @@ -246,7 +248,7 @@ void testSubAgent() { LlmAgent.builder() .name("greeterAgent") .description("Friendly agent that greets users") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly that greets users. @@ -257,7 +259,7 @@ void testSubAgent() { LlmAgent.builder() .name("farewellAgent") .description("Friendly agent that says goodbye to users") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly that says goodbye to users. @@ -355,7 +357,11 @@ void testSimpleStreamingResponse() { .modelName(CLAUDE_4_6_SONNET) .build(); - LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_4_6_SONNET); + LangChain4j lc4jClaude = + LangChain4j.builder() + .streamingChatModel(claudeStreamingModel) + .modelName(CLAUDE_4_6_SONNET) + .build(); // when Flowable responses = @@ -413,7 +419,11 @@ void testStreamingRunConfig() { When someone greets you, respond with "Hello". If someone asks about the weather, call the `getWeather` function. """) - .model(new LangChain4j(streamingModel, "GPT_4_O_MINI")) + .model( + LangChain4j.builder() + .streamingChatModel(streamingModel) + .modelName("GPT_4_O_MINI") + .build()) // .model(new LangChain4j(streamingModel, // CLAUDE_3_7_SONNET_20250219)) .tools(FunctionTool.create(ToolExample.class, "getWeather")) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 076bb79a3..f88237ff1 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -19,6 +19,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.tools.FunctionTool; @@ -26,6 +27,7 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.TokenCountEstimator; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -33,6 +35,7 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.output.TokenUsage; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; import java.util.List; @@ -57,8 +60,26 @@ void setUp() { chatModel = mock(ChatModel.class); streamingChatModel = mock(StreamingChatModel.class); - langChain4j = new LangChain4j(chatModel, MODEL_NAME); - streamingLangChain4j = new LangChain4j(streamingChatModel, MODEL_NAME); + langChain4j = LangChain4j.builder().chatModel(chatModel).modelName(MODEL_NAME).build(); + streamingLangChain4j = + LangChain4j.builder().streamingChatModel(streamingChatModel).modelName(MODEL_NAME).build(); + } + + @Test + void testBuilder() { + ObjectMapper customMapper = new ObjectMapper(); + LangChain4j customLc4j = + LangChain4j.builder() + .chatModel(chatModel) + .streamingChatModel(streamingChatModel) + .objectMapper(customMapper) + .modelName("custom-model") + .build(); + + assertThat(customLc4j.chatModel()).isEqualTo(chatModel); + assertThat(customLc4j.streamingChatModel()).isEqualTo(streamingChatModel); + assertThat(customLc4j.objectMapper()).isEqualTo(customMapper); + assertThat(customLc4j.modelName()).isEqualTo("custom-model"); } @Test @@ -812,4 +833,141 @@ void testGenerateContentWithMcpToolParametersJsonSchemaAsSchema() { assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool"); assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool"); } + + @Test + @DisplayName( + "Should use TokenCountEstimator to estimate token usage when TokenUsage is not available") + void testTokenCountEstimatorFallback() { + // Given + // Create a mock TokenCountEstimator + final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class); + when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(50); // Input tokens + when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(20); // Output tokens + + // Create LangChain4j with the TokenCountEstimator using Builder + final LangChain4j langChain4jWithEstimator = + LangChain4j.builder() + .chatModel(chatModel) + .modelName(MODEL_NAME) + .tokenCountEstimator(tokenCountEstimator) + .build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?")))) + .build(); + + // Mock ChatResponse WITHOUT TokenUsage (simulating when LLM doesn't provide token counts) + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("The weather is sunny today."); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response has usage metadata estimated by TokenCountEstimator + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("The weather is sunny today."); + + // IMPORTANT: Verify that token usage was estimated via the TokenCountEstimator + assertThat(response.usageMetadata()).isPresent(); + final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get(); + assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(50)); // From estimator + assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(20)); // From estimator + assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(70)); // 50 + 20 + + // Verify the estimator was actually called + verify(tokenCountEstimator).estimateTokenCountInMessages(any()); + verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today."); + } + + @Test + @DisplayName("Should prioritize TokenCountEstimator over TokenUsage when estimator is provided") + void testTokenCountEstimatorPriority() { + // Given + // Create a mock TokenCountEstimator + final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class); + when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(100); // From estimator + when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(50); // From estimator + + // Create LangChain4j with the TokenCountEstimator using Builder + final LangChain4j langChain4jWithEstimator = + LangChain4j.builder() + .chatModel(chatModel) + .modelName(MODEL_NAME) + .tokenCountEstimator(tokenCountEstimator) + .build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?")))) + .build(); + + // Mock ChatResponse WITH actual TokenUsage from the LLM + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("The weather is sunny today."); + final TokenUsage actualTokenUsage = new TokenUsage(30, 15, 45); // Actual token counts from LLM + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(actualTokenUsage); // LLM provides token usage + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // IMPORTANT: When TokenCountEstimator is present, it takes priority over TokenUsage + assertThat(response).isNotNull(); + assertThat(response.usageMetadata()).isPresent(); + final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get(); + assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(100)); // From estimator + assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(50)); // From estimator + assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(150)); // 100 + 50 + + // Verify the estimator was called (it takes priority) + verify(tokenCountEstimator).estimateTokenCountInMessages(any()); + verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today."); + } + + @Test + @DisplayName("Should not include usageMetadata when TokenUsage is null and no estimator provided") + void testNoUsageMetadataWithoutEstimator() { + // Given + // Create LangChain4j WITHOUT TokenCountEstimator (default behavior) + final LangChain4j langChain4jNoEstimator = + LangChain4j.builder().chatModel(chatModel).modelName(MODEL_NAME).build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Hello, world!")))) + .build(); + + // Mock ChatResponse WITHOUT TokenUsage + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("Hello! How can I help you?"); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jNoEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response does NOT have usage metadata + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Hello! How can I help you?"); + + // IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator + assertThat(response.usageMetadata()).isEmpty(); + } } diff --git a/pom.xml b/pom.xml index cbeca1b72..40332472f 100644 --- a/pom.xml +++ b/pom.xml @@ -62,7 +62,7 @@ 0.18.1 3.41.0 3.9.0 - 1.11.0 + 1.12.2 2.0.17 1.4.5 1.0.0 From 3633a7dd071265087ea2ff148d419969b0c888ef Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Mar 2026 13:15:38 -0700 Subject: [PATCH 03/13] fix: Removing deprecated methods from Runner PiperOrigin-RevId: 886942637 --- .../java/com/google/adk/runner/Runner.java | 42 ------------------- 1 file changed, 42 deletions(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 1f7d924ab..849a3cd04 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -425,36 +425,6 @@ public Flowable runAsync(String userId, String sessionId, Content newMess return runAsync(userId, sessionId, newMessage, RunConfig.builder().build()); } - /** - * See {@link #runAsync(Session, Content, RunConfig, Map)}. - * - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync(Session session, Content newMessage, RunConfig runConfig) { - return runAsync(session, newMessage, runConfig, /* stateDelta= */ null); - } - - /** - * Runs the agent asynchronously using a provided Session object. - * - * @param session The session to run the agent in. - * @param newMessage The new message from the user to process. - * @param runConfig Configuration for the agent run. - * @param stateDelta Optional map of state updates to merge into the session for this run. - * @return A Flowable stream of {@link Event} objects generated by the agent during execution. - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync( - Session session, - Content newMessage, - RunConfig runConfig, - @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta) - .compose(Tracing.trace("invocation")); - } - /** * Runs the agent asynchronously using a provided Session object. * @@ -735,18 +705,6 @@ protected Flowable runLiveImpl( }); } - /** - * Runs the agent asynchronously with a default user ID. - * - * @return stream of generated events. - */ - @Deprecated(since = "0.5.0", forRemoval = true) - public Flowable runWithSessionId( - String sessionId, Content newMessage, RunConfig runConfig) { - // TODO(b/410859954): Add user_id to getter or method signature. Assuming "tmp-user" for now. - return this.runAsync("tmp-user", sessionId, newMessage, runConfig); - } - /** * Checks if the agent and its parent chain allow transfer up the tree. * From 8e9fb085354f8148e00cbd236e8f29e82de56d6e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Mar 2026 13:56:07 -0700 Subject: [PATCH 04/13] refactor: Use concatMap for sequential event persistence in Runner Ensure sequential event processing and persistence in ADK Runner. This ensures that events are appended in order and returned from runAsync in order. This aligns better with the Python implementation. PiperOrigin-RevId: 886961696 --- .../java/com/google/adk/runner/Runner.java | 2 +- .../com/google/adk/runner/RunnerTest.java | 42 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 849a3cd04..2bfbca881 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -529,7 +529,7 @@ private Flowable runAgentWithFreshSession( contextWithUpdatedSession .agent() .runAsync(contextWithUpdatedSession) - .flatMap( + .concatMap( agentEvent -> this.sessionService .appendEvent(updatedSession, agentEvent) diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index a3e21cb73..efd565c16 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -26,6 +26,7 @@ import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Arrays.stream; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.mock; @@ -33,6 +34,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; @@ -43,6 +45,7 @@ import com.google.adk.flows.llmflows.Functions; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; +import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; import com.google.adk.summarizer.EventsCompactionConfig; @@ -851,6 +854,45 @@ public void beforeRunCallback_withStateDelta_seesMergedState() { assertThat(sessionInCallback.state()).containsEntry("number", 123); } + @Test + public void runAsync_ensureEventsAreAppendedInOrder() throws Exception { + Event event1 = TestUtils.createEvent("1"); + Event event2 = TestUtils.createEvent("2"); + BaseAgent mockAgent = TestUtils.createSubAgent("test agent", event1, event2); + + BaseSessionService mockSessionService = mock(BaseSessionService.class); + + when(mockSessionService.getSession(any(), any(), any(), any())).thenReturn(Maybe.just(session)); + when(mockSessionService.appendEvent(any(), any())) + .thenAnswer( + invocation -> { + Event eventArg = invocation.getArgument(1); + Single result = Single.just(eventArg); + if (eventArg.id().equals("1")) { + // Artificially delay the first event to ensure it is appended first. + return result.delay(100, MILLISECONDS); + } + return result; + }); + + Runner mockRunner = + Runner.builder() + .agent(mockAgent) + .appName("test") + .sessionService(mockSessionService) + .build(); + + List results = + mockRunner + .runAsync("user", session.id(), createContent("user message")) + .toList() + .blockingGet(); + + assertThat(simplifyEvents(results)) + .containsExactly("author: content for event 1", "author: content for event 2") + .inOrder(); + } + private Content createContent(String text) { return Content.builder().parts(Part.builder().text(text).build()).build(); } From 3e21e7ac46b634341819b3543388a38caef85516 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Sat, 21 Mar 2026 20:11:12 +0100 Subject: [PATCH 05/13] fix: handle null `AiMessage.text()` to prevent NPE and add unit test (PR #1035) --- .../adk/models/langchain4j/LangChain4j.java | 7 ++++-- .../models/langchain4j/LangChain4jTest.java | 23 +++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 8279dc21a..97331e7b4 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -607,8 +607,11 @@ private List toParts(AiMessage aiMessage) { }); return parts; } else { - Part part = Part.builder().text(aiMessage.text()).build(); - return List.of(part); + String text = aiMessage.text(); + if (text == null) { + return List.of(); + } + return List.of(Part.builder().text(text).build()); } } diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index f88237ff1..a1ec7a3c2 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -970,4 +970,27 @@ void testNoUsageMetadataWithoutEstimator() { // IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator assertThat(response.usageMetadata()).isEmpty(); } + + @Test + @DisplayName("Should handle null AiMessage text without throwing NPE") + void testGenerateContentWithNullAiMessageText() { + // Given + final LlmRequest llmRequest = + LlmRequest.builder().contents(List.of(Content.fromParts(Part.fromText("Hello")))).build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = mock(AiMessage.class); + when(aiMessage.text()).thenReturn(null); + when(aiMessage.hasToolExecutionRequests()).thenReturn(false); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final Flowable responseFlowable = langChain4j.generateContent(llmRequest, false); + final LlmResponse response = responseFlowable.blockingFirst(); + // Then - no NPE thrown, and content has no text parts + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts().orElse(List.of())).isEmpty(); + } } From cdc5199eb0f92cb95db2ee7ff139d67317968457 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Mon, 23 Mar 2026 13:43:34 +0100 Subject: [PATCH 06/13] fix: add schema validation to SetModelResponseTool (issue #587 already implemented, but adding tests from PR #603) --- .../adk/tools/SetModelResponseTool.java | 7 +- .../adk/tools/SetModelResponseToolTest.java | 123 ++++++++++++++++++ 2 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java diff --git a/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java index e23d6414a..3b0e411b4 100644 --- a/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java +++ b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java @@ -16,6 +16,7 @@ package com.google.adk.tools; +import com.google.adk.SchemaUtils; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Single; @@ -58,6 +59,10 @@ public Optional declaration() { public Single> runAsync(Map args, ToolContext toolContext) { // This tool is a marker for the final response, it doesn't do anything but return its arguments // which will be captured as the final result. - return Single.just(args); + return Single.fromCallable( + () -> { + SchemaUtils.validateMapOnSchema(args, outputSchema, /* isInput= */ false); + return args; + }); } } diff --git a/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java b/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java new file mode 100644 index 000000000..64b600af9 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SetModelResponseToolTest { + + @Test + public void declaration_returnsCorrectFunctionDeclaration() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + FunctionDeclaration declaration = tool.declaration().get(); + + assertThat(declaration.name()).hasValue("set_model_response"); + assertThat(declaration.description()).isPresent(); + assertThat(declaration.description().get()).contains("Set your final response"); + assertThat(declaration.parameters()).hasValue(outputSchema); + } + + @Test + public void runAsync_returnsArgs() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + Map args = ImmutableMap.of("field1", "value1"); + + Map result = tool.runAsync(args, null).blockingGet(); + + assertThat(result).isEqualTo(args); + } + + @Test + public void runAsync_validatesArgs() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + Map invalidArgs = ImmutableMap.of("field2", "value2"); + + // Should throw validation error + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> tool.runAsync(invalidArgs, null).blockingGet()); + + assertThat(exception).hasMessageThat().contains("does not match agent output schema"); + } + + @Test + public void runAsync_validatesComplexArgs() { + Schema complexSchema = + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "id", + Schema.builder().type("INTEGER").build(), + "tags", + Schema.builder() + .type("ARRAY") + .items(Schema.builder().type("STRING").build()) + .build(), + "metadata", + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("key", Schema.builder().type("STRING").build())) + .build())) + .required(ImmutableList.of("id", "tags", "metadata")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(complexSchema); + Map complexArgs = + ImmutableMap.of( + "id", 123, + "tags", ImmutableList.of("tag1", "tag2"), + "metadata", ImmutableMap.of("key", "value")); + + Map result = tool.runAsync(complexArgs, null).blockingGet(); + + assertThat(result).containsEntry("id", 123); + assertThat(result).containsEntry("tags", ImmutableList.of("tag1", "tag2")); + assertThat(result).containsEntry("metadata", ImmutableMap.of("key", "value")); + } +} From e9df447f1445044552e8710713ab5a76c2ae5093 Mon Sep 17 00:00:00 2001 From: "Michael Vorburger.ch" Date: Mon, 23 Mar 2026 08:42:56 -0700 Subject: [PATCH 07/13] Remove explicit SLF4J binding from city-time-weather ADK tutorial. The `slf4j-simple` dependency and the exclusion of `logback-classic` are removed, allowing the default logging implementation provided by `google-adk-dev` to be used. PiperOrigin-RevId: 888114465 --- tutorials/city-time-weather/pom.xml | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tutorials/city-time-weather/pom.xml b/tutorials/city-time-weather/pom.xml index aeb110cf6..19ef08a2d 100644 --- a/tutorials/city-time-weather/pom.xml +++ b/tutorials/city-time-weather/pom.xml @@ -36,16 +36,6 @@ com.google.adk google-adk-dev ${project.version} - - - ch.qos.logback - logback-classic - - - - - org.slf4j - slf4j-simple From ce4b642220c785f48711d92657faccaa4eded4f1 Mon Sep 17 00:00:00 2001 From: ddobrin Date: Mon, 23 Mar 2026 13:33:26 -0400 Subject: [PATCH 08/13] Fixes #490 and #1064 ToolConverter issues in the spring-ai module --- contrib/spring-ai/DOCUMENT-GEMINI.md | 86 ------------------ contrib/spring-ai/README.md | 26 +++--- .../adk/models/springai/ToolConverter.java | 88 +++++++++++++------ .../ToolConverterArgumentProcessingTest.java | 84 ++++++++++++++++++ .../models/springai/ToolConverterTest.java | 34 +++++++ 5 files changed, 190 insertions(+), 128 deletions(-) delete mode 100644 contrib/spring-ai/DOCUMENT-GEMINI.md diff --git a/contrib/spring-ai/DOCUMENT-GEMINI.md b/contrib/spring-ai/DOCUMENT-GEMINI.md deleted file mode 100644 index 393562528..000000000 --- a/contrib/spring-ai/DOCUMENT-GEMINI.md +++ /dev/null @@ -1,86 +0,0 @@ -# Documentation for the ADK Spring AI Library - -## 📖 Overview -The `google-adk-spring-ai` library provides an integration layer between the Google Agent Development Kit (ADK) and the Spring AI project. It allows developers to use Spring AI's `ChatModel`, `StreamingChatModel`, and `EmbeddingModel` as `BaseLlm` and `Embedding` implementations within the ADK framework. - -The library handles the conversion between ADK's request/response formats and Spring AI's prompt/chat response formats. It also includes auto-configuration to automatically expose Spring AI models as ADK `SpringAI` and `SpringAIEmbedding` beans in a Spring Boot application. - -## 🛠️ Building -To include this library in your project, use the following Maven coordinates: - -```xml - - com.google.adk - google-adk-spring-ai - 0.3.1-SNAPSHOT - -``` - -You will also need to include a dependency for the specific Spring AI model you want to use, for example: -```xml - - org.springframework.ai - spring-ai-openai - -``` - -## 🚀 Usage -The primary way to use this library is through Spring Boot auto-configuration. By including the `google-adk-spring-ai` dependency and a Spring AI model dependency (e.g., `spring-ai-openai`), the library will automatically create a `SpringAI` bean. This bean can then be injected and used as a `BaseLlm` in the ADK. - -**Example `application.properties`:** -```properties -# OpenAI configuration -spring.ai.openai.api-key=${OPENAI_API_KEY} -spring.ai.openai.chat.options.model=gpt-4o-mini -spring.ai.openai.chat.options.temperature=0.7 - -# ADK Spring AI configuration -adk.spring-ai.model=gpt-4o-mini -adk.spring-ai.validation.enabled=true -``` - -**Example usage in a Spring service:** -```java -import com.google.adk.models.BaseLlm; -import com.google.adk.models.LlmRequest; -import com.google.adk.models.LlmResponse; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; -import reactor.core.publisher.Mono; - -@Service -public class MyAgentService { - - private final BaseLlm llm; - - @Autowired - public MyAgentService(BaseLlm llm) { - this.llm = llm; - } - - public Mono generateResponse(String prompt) { - LlmRequest request = LlmRequest.builder() - .addText(prompt) - .build(); - return Mono.from(llm.generateContent(request)) - .map(llmResponse -> llmResponse.content().get().parts().get(0).text().get()); - } -} -``` - -## 📚 API Reference -### Key Classes -- **`SpringAI`**: The main class that wraps a Spring AI `ChatModel` and/or `StreamingChatModel` and implements the ADK `BaseLlm` interface. - - **Methods**: - - `generateContent(LlmRequest llmRequest, boolean stream)`: Generates content, either streaming or non-streaming, by calling the underlying Spring AI model. It converts the ADK `LlmRequest` to a Spring AI `Prompt` and the `ChatResponse` back to an ADK `LlmResponse`. - -- **`SpringAIEmbedding`**: Wraps a Spring AI `EmbeddingModel` to be used for generating embeddings within the ADK framework. - -- **`SpringAIAutoConfiguration`**: The Spring Boot auto-configuration class that automatically discovers and configures `SpringAI` and `SpringAIEmbedding` beans based on the `ChatModel`, `StreamingChatModel`, and `EmbeddingModel` beans present in the application context. - -- **`SpringAIProperties`**: A configuration properties class (`@ConfigurationProperties("adk.spring-ai")`) that allows for customization of the Spring AI integration. - - **Properties**: - - `model`: The model name to use. - - `validation.enabled`: Whether to enable configuration validation. - - `validation.fail-fast`: Whether to fail fast on validation errors. - - `observability.enabled`: Whether to enable observability features. diff --git a/contrib/spring-ai/README.md b/contrib/spring-ai/README.md index c45f0e033..0ce7de4fe 100644 --- a/contrib/spring-ai/README.md +++ b/contrib/spring-ai/README.md @@ -18,21 +18,21 @@ To use ADK Java with the Spring AI integration in your application, add the foll com.google.adk google-adk - 0.3.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT com.google.adk google-adk-spring-ai - 0.3.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT org.springframework.ai spring-ai-bom - 1.1.0-M3 + 2.0.0-M3 pom import @@ -109,14 +109,14 @@ Add the Spring AI provider dependencies for the AI services you want to use: org.springframework.boot spring-boot-starter-parent - 3.2.0 + 4.0.2 17 - 1.1.0-M3 - 0.3.1-SNAPSHOT + 2.0.0-M3 + 1.0.1-rc.1-SNAPSHOT @@ -271,7 +271,7 @@ public class MyAdkSpringAiApplication { .anthropicApi(anthropicApi) .build(); - return new SpringAI(chatModel, "claude-3-5-sonnet-20241022"); + return new SpringAI(chatModel, "claude-sonnet-4-6"); } @Bean @@ -312,7 +312,7 @@ spring: api-key: ${ANTHROPIC_API_KEY} chat: options: - model: claude-3-5-sonnet-20241022 + model: claude-sonnet-4-6 temperature: 0.7 # ADK Spring AI Configuration @@ -365,13 +365,13 @@ The main adapter class that implements `BaseLlm` and wraps Spring AI `ChatModel` **Usage:** ```java // With ChatModel only -SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-20250514"); +SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-6"); // With both ChatModel and StreamingChatModel -SpringAI springAI = new SpringAI(chatModel, streamingChatModel, "claude-sonnet-4-20250514"); +SpringAI springAI = new SpringAI(chatModel, streamingChatModel, "claude-sonnet-4-6"); // With observability configuration -SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-20250514", observabilityConfig); +SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-6", observabilityConfig); ``` #### 2. MessageConverter (MessageConverter.java) @@ -533,7 +533,7 @@ The library works with any Spring AI provider: - Features: Chat, streaming, function calling, embeddings 2. **Anthropic** (`spring-ai-anthropic`) - - Models: Claude 3.5 Sonnet, Claude 3 Haiku + - Models: Claude 4.x Sonnet, Claude 4.x Haiku - Features: Chat, streaming, function calling - **Note:** Requires proper function schema registration @@ -563,7 +563,7 @@ The library works with any Spring AI provider: #### Anthropic - **Function Calling:** Requires explicit schema registration using `inputSchema()` method -- **Model Names:** Use full model names like `claude-3-5-sonnet-20241022` +- **Model Names:** Use full model names like `claude-sonnet-4-6` - **API Key:** Requires `ANTHROPIC_API_KEY` environment variable #### OpenAI diff --git a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java index 95dafadb4..4012ee5d6 100644 --- a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java +++ b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.tool.ToolCallback; @@ -172,6 +173,17 @@ public List convertToSpringAiTools(Map tools) { } catch (Exception e) { logger.error("Error serializing schema to JSON: {}", e.getMessage(), e); } + } else if (declaration.parametersJsonSchema().isPresent()) { + callbackBuilder.inputType(Map.class); + try { + String schemaJson = + new com.fasterxml.jackson.databind.ObjectMapper() + .writeValueAsString(declaration.parametersJsonSchema().get()); + callbackBuilder.inputSchema(schemaJson); + logger.debug("Set input schema JSON from parametersJsonSchema: {}", schemaJson); + } catch (Exception e) { + logger.error("Error serializing parametersJsonSchema to JSON: {}", e.getMessage(), e); + } } toolCallbacks.add(callbackBuilder.build()); @@ -187,45 +199,63 @@ public List convertToSpringAiTools(Map tools) { */ private Map processArguments( Map args, FunctionDeclaration declaration) { - // If the arguments already match the expected format, return as-is if (declaration.parameters().isPresent()) { var schema = declaration.parameters().get(); if (schema.properties().isPresent()) { - var expectedParams = schema.properties().get().keySet(); - - // Check if all expected parameters are present at the top level - boolean allParamsPresent = expectedParams.stream().allMatch(args::containsKey); - if (allParamsPresent) { - return args; + return normalizeArguments(args, schema.properties().get().keySet()); + } + } else if (declaration.parametersJsonSchema().isPresent()) { + try { + @SuppressWarnings("unchecked") + Map schemaMap = + new com.fasterxml.jackson.databind.ObjectMapper() + .convertValue(declaration.parametersJsonSchema().get(), Map.class); + Object propertiesObj = schemaMap.get("properties"); + if (propertiesObj instanceof Map) { + @SuppressWarnings("unchecked") + Set expectedParams = ((Map) propertiesObj).keySet(); + return normalizeArguments(args, expectedParams); } + } catch (Exception e) { + logger.warn( + "Error processing parametersJsonSchema for argument mapping: {}", e.getMessage()); + } + } - // Check if arguments are nested under a single key (common pattern) - if (args.size() == 1) { - var singleValue = args.values().iterator().next(); - if (singleValue instanceof Map) { - @SuppressWarnings("unchecked") - Map nestedArgs = (Map) singleValue; - boolean allNestedParamsPresent = - expectedParams.stream().allMatch(nestedArgs::containsKey); - if (allNestedParamsPresent) { - return nestedArgs; - } - } - } + // If no processing worked, return original args and let ADK handle the error + return args; + } - // Check if we have a single parameter function and got a direct value - if (expectedParams.size() == 1) { - String expectedParam = expectedParams.iterator().next(); - if (args.size() == 1 && !args.containsKey(expectedParam)) { - // Try to map the single value to the expected parameter name - Object singleValue = args.values().iterator().next(); - return Map.of(expectedParam, singleValue); - } + private Map normalizeArguments( + Map args, Set expectedParams) { + // Check if all expected parameters are present at the top level + boolean allParamsPresent = expectedParams.stream().allMatch(args::containsKey); + if (allParamsPresent) { + return args; + } + + // Check if arguments are nested under a single key (common pattern) + if (args.size() == 1) { + var singleValue = args.values().iterator().next(); + if (singleValue instanceof Map) { + @SuppressWarnings("unchecked") + Map nestedArgs = (Map) singleValue; + boolean allNestedParamsPresent = expectedParams.stream().allMatch(nestedArgs::containsKey); + if (allNestedParamsPresent) { + return nestedArgs; } } } - // If no processing worked, return original args and let ADK handle the error + // Check if we have a single parameter function and got a direct value + if (expectedParams.size() == 1) { + String expectedParam = expectedParams.iterator().next(); + if (args.size() == 1 && !args.containsKey(expectedParam)) { + Object singleValue = args.values().iterator().next(); + return Map.of(expectedParam, singleValue); + } + } + return args; } diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java index 301a145e0..77b988837 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java @@ -115,6 +115,90 @@ private Map invokeProcessArguments( return (Map) method.invoke(converter, args, declaration); } + @Test + void testArgumentProcessingWithParametersJsonSchema_correctFormat() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map correctArgs = Map.of("location", "San Francisco"); + Map processedArgs = + invokeProcessArguments(processArguments, converter, correctArgs, declaration); + + assertThat(processedArgs).isEqualTo(correctArgs); + } + + @Test + void testArgumentProcessingWithParametersJsonSchema_nestedFormat() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map nestedArgs = Map.of("args", Map.of("location", "San Francisco")); + Map processedArgs = + invokeProcessArguments(processArguments, converter, nestedArgs, declaration); + + assertThat(processedArgs).containsEntry("location", "San Francisco"); + } + + @Test + void testArgumentProcessingWithParametersJsonSchema_directValue() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map directValueArgs = Map.of("value", "San Francisco"); + Map processedArgs = + invokeProcessArguments(processArguments, converter, directValueArgs, declaration); + + assertThat(processedArgs).containsEntry("location", "San Francisco"); + } + + @Test + void testArgumentProcessingWithParametersJsonSchema_noMatch() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map wrongArgs = Map.of("city", "San Francisco", "country", "USA"); + Map processedArgs = + invokeProcessArguments(processArguments, converter, wrongArgs, declaration); + + assertThat(processedArgs).isEqualTo(wrongArgs); + } + public static class WeatherTools { public static Map getWeatherInfo(String location) { return Map.of( diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java index 231c8e1fe..1f3044159 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java @@ -26,6 +26,7 @@ import java.util.Optional; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.ToolCallback; class ToolConverterTest { @@ -178,4 +179,37 @@ void testToolMetadata() { assertThat(metadata.getDescription()).isEqualTo("Test description"); assertThat(metadata.getDeclaration()).isEqualTo(function); } + + @Test + void testConvertToSpringAiToolsWithParametersJsonSchema() { + Map jsonSchema = + Map.of( + "type", + "object", + "properties", + Map.of("location", Map.of("type", "string", "description", "City name")), + "required", + List.of("location")); + + FunctionDeclaration function = + FunctionDeclaration.builder() + .name("get_weather") + .description("Get weather for a location") + .parametersJsonSchema(jsonSchema) + .build(); + + BaseTool testTool = + new BaseTool("get_weather", "Get weather for a location") { + @Override + public Optional declaration() { + return Optional.of(function); + } + }; + + Map tools = Map.of("get_weather", testTool); + List toolCallbacks = toolConverter.convertToSpringAiTools(tools); + + assertThat(toolCallbacks).hasSize(1); + assertThat(toolCallbacks.get(0).getToolDefinition().name()).isEqualTo("get_weather"); + } } From 8a7f816ffeb80d58b7e8e2a32d7c70ba8ad89d73 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Mar 2026 08:00:58 -0700 Subject: [PATCH 09/13] refactor: use mock api answers for tests PiperOrigin-RevId: 888667558 --- .../adk/sessions/VertexAiSessionService.java | 4 ++-- .../google/adk/sessions/MockApiAnswer.java | 11 ++++++++++ .../sessions/VertexAiSessionServiceTest.java | 21 ++----------------- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index b62add27a..99e7e3479 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -135,7 +135,7 @@ private ListSessionsResponse parseListSessionsResponse( JsonNode listSessionsResponseMap, String appName, String userId) { JsonNode sessionsNode = listSessionsResponseMap.get("sessions"); if (sessionsNode == null || sessionsNode.isNull() || sessionsNode.isEmpty()) { - return ListSessionsResponse.builder().sessions(new ArrayList<>()).build(); + return ListSessionsResponse.builder().build(); } List> apiSessions = objectMapper.convertValue(sessionsNode, new TypeReference>>() {}); @@ -174,7 +174,7 @@ public Single listEvents(String appName, String userId, Stri private ListEventsResponse parseListEventsResponse(JsonNode listEventsResponse) { JsonNode sessionEventsNode = listEventsResponse.get("sessionEvents"); if (sessionEventsNode == null || sessionEventsNode.isEmpty()) { - return ListEventsResponse.builder().events(new ArrayList<>()).build(); + return ListEventsResponse.builder().build(); } return ListEventsResponse.builder() .events( diff --git a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java index 111b1dce3..743bcee8d 100644 --- a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java +++ b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java @@ -36,14 +36,25 @@ class MockApiAnswer implements Answer { private final Map sessionMap; private final Map eventMap; + private final String rawApiResponse; MockApiAnswer(Map sessionMap, Map eventMap) { this.sessionMap = sessionMap; this.eventMap = eventMap; + this.rawApiResponse = null; + } + + MockApiAnswer(String rawApiResponse) { + this.sessionMap = null; + this.eventMap = null; + this.rawApiResponse = rawApiResponse; } @Override public ApiResponse answer(InvocationOnMock invocation) throws Throwable { + if (rawApiResponse != null) { + return responseWithBody(rawApiResponse); + } String httpMethod = invocation.getArgument(0); String path = invocation.getArgument(1); if (httpMethod.equals("POST") && SESSIONS_REGEX.matcher(path).matches()) { diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index 3dab94b46..dd62263d7 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -25,8 +25,6 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import okhttp3.MediaType; -import okhttp3.ResponseBody; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -39,21 +37,6 @@ public class VertexAiSessionServiceTest { private static final ObjectMapper mapper = JsonBaseModel.getMapper(); - private static final MediaType JSON_MEDIA_TYPE = - MediaType.parse("application/json; charset=utf-8"); - - private static ApiResponse apiResponseJson(String json) { - return new ApiResponse() { - @Override - public ResponseBody getResponseBody() { - return ResponseBody.create(JSON_MEDIA_TYPE, json); - } - - @Override - public void close() {} - }; - } - private static final String MOCK_SESSION_STRING_1 = """ { @@ -338,7 +321,7 @@ public void listSessions_empty() { @Test public void listSessions_missingSessionsField_returnsEmpty() { when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=userX", "")) - .thenReturn(apiResponseJson("{}")); + .thenAnswer(new MockApiAnswer("{}")); assertThat(vertexAiSessionService.listSessions("123", "userX").blockingGet().sessions()) .isEmpty(); @@ -347,7 +330,7 @@ public void listSessions_missingSessionsField_returnsEmpty() { @Test public void listSessions_nullSessionsField_returnsEmpty() { when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=userY", "")) - .thenReturn(apiResponseJson("{\"sessions\": null}")); + .thenAnswer(new MockApiAnswer("{\"sessions\": null}")); assertThat(vertexAiSessionService.listSessions("123", "userY").blockingGet().sessions()) .isEmpty(); From 677b6d7452aa28fab42d554d18c150d59ca88eec Mon Sep 17 00:00:00 2001 From: Mateusz Krawiec Date: Wed, 25 Mar 2026 03:24:35 -0700 Subject: [PATCH 10/13] fix: parallel agent execution PiperOrigin-RevId: 889140710 --- .../com/google/adk/agents/ParallelAgent.java | 30 +++++-- .../google/adk/agents/ParallelAgentTest.java | 86 ++++++++++++++++++- 2 files changed, 108 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/ParallelAgent.java b/core/src/main/java/com/google/adk/agents/ParallelAgent.java index f30d951aa..2593ec13a 100644 --- a/core/src/main/java/com/google/adk/agents/ParallelAgent.java +++ b/core/src/main/java/com/google/adk/agents/ParallelAgent.java @@ -16,11 +16,14 @@ package com.google.adk.agents; import static com.google.common.base.Strings.isNullOrEmpty; -import static com.google.common.collect.ImmutableList.toImmutableList; import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.events.Event; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Scheduler; +import io.reactivex.rxjava3.schedulers.Schedulers; +import java.util.ArrayList; import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +38,7 @@ public class ParallelAgent extends BaseAgent { private static final Logger logger = LoggerFactory.getLogger(ParallelAgent.class); + private final Scheduler scheduler; /** * Constructor for ParallelAgent. @@ -44,24 +48,35 @@ public class ParallelAgent extends BaseAgent { * @param subAgents The list of sub-agents to run in parallel. * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. + * @param scheduler The scheduler to use for parallel execution. */ private ParallelAgent( String name, String description, List subAgents, List beforeAgentCallback, - List afterAgentCallback) { + List afterAgentCallback, + Scheduler scheduler) { super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); + this.scheduler = scheduler; } /** Builder for {@link ParallelAgent}. */ public static class Builder extends BaseAgent.Builder { + private Scheduler scheduler = Schedulers.io(); + + @CanIgnoreReturnValue + public Builder scheduler(Scheduler scheduler) { + this.scheduler = scheduler; + return this; + } + @Override public ParallelAgent build() { return new ParallelAgent( - name, description, subAgents, beforeAgentCallback, afterAgentCallback); + name, description, subAgents, beforeAgentCallback, afterAgentCallback, scheduler); } } @@ -129,10 +144,11 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { } var updatedInvocationContext = setBranchForCurrentAgent(this, invocationContext); - return Flowable.merge( - currentSubAgents.stream() - .map(subAgent -> subAgent.runAsync(updatedInvocationContext)) - .collect(toImmutableList())); + List> agentFlowables = new ArrayList<>(); + for (BaseAgent subAgent : currentSubAgents) { + agentFlowables.add(subAgent.runAsync(updatedInvocationContext).subscribeOn(scheduler)); + } + return Flowable.merge(agentFlowables); } /** diff --git a/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java b/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java index a6afb5793..e51240c45 100644 --- a/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java @@ -25,7 +25,10 @@ import com.google.genai.types.Content; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Scheduler; import io.reactivex.rxjava3.schedulers.Schedulers; +import io.reactivex.rxjava3.schedulers.TestScheduler; +import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,10 +39,16 @@ public final class ParallelAgentTest { static class TestingAgent extends BaseAgent { private final long delayMillis; + private final Scheduler scheduler; private TestingAgent(String name, String description, long delayMillis) { + this(name, description, delayMillis, Schedulers.computation()); + } + + private TestingAgent(String name, String description, long delayMillis, Scheduler scheduler) { super(name, description, ImmutableList.of(), null, null); this.delayMillis = delayMillis; + this.scheduler = scheduler; } @Override @@ -55,7 +64,7 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { .build()); if (delayMillis > 0) { - return event.delay(delayMillis, MILLISECONDS, Schedulers.computation()); + return event.delay(delayMillis, MILLISECONDS, scheduler); } return event; } @@ -110,4 +119,79 @@ public void runAsync_noSubAgents_returnsEmptyFlowable() { assertThat(events).isEmpty(); } + + static class BlockingAgent extends BaseAgent { + private final long sleepMillis; + + private BlockingAgent(String name, long sleepMillis) { + super(name, "Blocking Agent", ImmutableList.of(), null, null); + this.sleepMillis = sleepMillis; + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.fromCallable( + () -> { + Thread.sleep(sleepMillis); + return Event.builder() + .author(name()) + .branch(invocationContext.branch().orElse(null)) + .invocationId(invocationContext.invocationId()) + .content(Content.fromParts(Part.fromText("Done"))) + .build(); + }); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + throw new UnsupportedOperationException("Not implemented"); + } + } + + @Test + public void runAsync_blockingSubAgents_shouldExecuteInParallel() { + long sleepTime = 1000; + BlockingAgent agent1 = new BlockingAgent("agent1", sleepTime); + BlockingAgent agent2 = new BlockingAgent("agent2", sleepTime); + + ParallelAgent parallelAgent = + ParallelAgent.builder().name("parallel_agent").subAgents(agent1, agent2).build(); + + InvocationContext invocationContext = createInvocationContext(parallelAgent); + + long startTime = System.currentTimeMillis(); + List events = parallelAgent.runAsync(invocationContext).toList().blockingGet(); + long duration = System.currentTimeMillis() - startTime; + + assertThat(events).hasSize(2); + // If parallel, duration should be less than 1.5 * sleepTime (1500ms). + assertThat(duration).isAtLeast(sleepTime); + assertThat(duration).isLessThan((long) (1.5 * sleepTime)); + } + + @Test + public void runAsync_withTestScheduler_usesVirtualTime() { + TestScheduler testScheduler = new TestScheduler(); + long delayMillis = 1000; + TestingAgent agent = + new TestingAgent("delayed_agent", "Delayed Agent", delayMillis, testScheduler); + + ParallelAgent parallelAgent = + ParallelAgent.builder() + .name("parallel_agent") + .subAgents(agent) + .scheduler(testScheduler) + .build(); + + InvocationContext invocationContext = createInvocationContext(parallelAgent); + + TestSubscriber testSubscriber = parallelAgent.runAsync(invocationContext).test(); + + testScheduler.advanceTimeBy(delayMillis - 100, MILLISECONDS); + testSubscriber.assertNoValues(); + testSubscriber.assertNotComplete(); + testScheduler.advanceTimeBy(200, MILLISECONDS); + testSubscriber.assertValueCount(1); + testSubscriber.assertComplete(); + } } From 5a2abbfe6f9e4e1ebdd5b918e34fcdb144603b5a Mon Sep 17 00:00:00 2001 From: Mateusz Krawiec Date: Wed, 25 Mar 2026 03:25:54 -0700 Subject: [PATCH 11/13] fix: resolve MCP tool parsing errors in Claude integration The Claude model integration parsing logic failed when processing MCP tool responses because it only extracted output from the legacy `result` field. Extended extraction logic to: - Support native MCP `content` arrays. - Support legacy `result` structures natively. - Fallback to generic JSON serialization of the entire map. Additionally, updated AbstractMcpTool.wrapCallResult() format to match Python ADK. PiperOrigin-RevId: 889141233 --- .../java/com/google/adk/models/Claude.java | 51 ++++++++-- .../google/adk/tools/mcp/AbstractMcpTool.java | 52 +--------- .../com/google/adk/models/ClaudeTest.java | 97 +++++++++++++++++++ .../adk/tools/mcp/AbstractMcpToolTest.java | 62 ++++++++++++ 4 files changed, 203 insertions(+), 59 deletions(-) create mode 100644 core/src/test/java/com/google/adk/models/ClaudeTest.java create mode 100644 core/src/test/java/com/google/adk/tools/mcp/AbstractMcpToolTest.java diff --git a/core/src/main/java/com/google/adk/models/Claude.java b/core/src/main/java/com/google/adk/models/Claude.java index ebb786e35..01feda1d4 100644 --- a/core/src/main/java/com/google/adk/models/Claude.java +++ b/core/src/main/java/com/google/adk/models/Claude.java @@ -31,8 +31,7 @@ import com.anthropic.models.messages.ToolUnion; import com.anthropic.models.messages.ToolUseBlockParam; import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import com.google.adk.JsonBaseModel; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -170,9 +169,22 @@ private ContentBlockParam partToAnthropicMessageBlock(Part part) { .build()); } else if (part.functionResponse().isPresent()) { String content = ""; - if (part.functionResponse().get().response().isPresent() - && part.functionResponse().get().response().get().getOrDefault("result", null) != null) { - content = part.functionResponse().get().response().get().get("result").toString(); + if (part.functionResponse().get().response().isPresent()) { + Map responseData = part.functionResponse().get().response().get(); + + Object contentObj = responseData.get("content"); + Object resultObj = responseData.get("result"); + + if (contentObj instanceof List list && !list.isEmpty()) { + // Native MCP format: list of content blocks + content = extractMcpContentBlocks(list); + } else if (resultObj != null) { + // ADK tool result object + content = resultObj instanceof String s ? s : serializeToJson(resultObj); + } else if (!responseData.isEmpty()) { + // Fallback: arbitrary JSON structure + content = serializeToJson(responseData); + } } return ContentBlockParam.ofToolResult( ToolResultBlockParam.builder() @@ -184,6 +196,30 @@ private ContentBlockParam partToAnthropicMessageBlock(Part part) { throw new UnsupportedOperationException("Not supported yet."); } + private String extractMcpContentBlocks(List list) { + List textBlocks = new ArrayList<>(); + for (Object item : list) { + if (item instanceof Map m && "text".equals(m.get("type"))) { + Object textObj = m.get("text"); + textBlocks.add(textObj != null ? String.valueOf(textObj) : ""); + } else if (item instanceof String s) { + textBlocks.add(s); + } else { + textBlocks.add(serializeToJson(item)); + } + } + return String.join("\n", textBlocks); + } + + private String serializeToJson(Object obj) { + try { + return JsonBaseModel.getMapper().writeValueAsString(obj); + } catch (Exception e) { + logger.warn("Failed to serialize object to JSON", e); + return String.valueOf(obj); + } + } + private void updateTypeString(Map valueDict) { if (valueDict == null) { return; @@ -221,10 +257,9 @@ private Tool functionDeclarationToAnthropicTool(FunctionDeclaration functionDecl .get() .forEach( (key, schema) -> { - ObjectMapper objectMapper = new ObjectMapper(); - objectMapper.registerModule(new Jdk8Module()); Map schemaMap = - objectMapper.convertValue(schema, new TypeReference>() {}); + JsonBaseModel.getMapper() + .convertValue(schema, new TypeReference>() {}); updateTypeString(schemaMap); properties.put(key, schemaMap); }); diff --git a/core/src/main/java/com/google/adk/tools/mcp/AbstractMcpTool.java b/core/src/main/java/com/google/adk/tools/mcp/AbstractMcpTool.java index d9c28e501..3b0c3d70a 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/AbstractMcpTool.java +++ b/core/src/main/java/com/google/adk/tools/mcp/AbstractMcpTool.java @@ -16,7 +16,6 @@ package com.google.adk.tools.mcp; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.tools.BaseTool; @@ -24,13 +23,9 @@ import com.google.common.collect.ImmutableMap; import com.google.genai.types.FunctionDeclaration; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.Content; import io.modelcontextprotocol.spec.McpSchema.JsonSchema; -import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpSchema.ToolAnnotations; -import java.util.ArrayList; -import java.util.List; import java.util.Map; import java.util.Optional; @@ -116,51 +111,6 @@ protected static Map wrapCallResult( return ImmutableMap.of("error", "MCP framework error: CallToolResult was null"); } - List contents = callResult.content(); - Boolean isToolError = callResult.isError(); - - if (isToolError != null && isToolError) { - String errorMessage = "Tool execution failed."; - if (contents != null - && !contents.isEmpty() - && contents.get(0) instanceof TextContent textContent) { - if (textContent.text() != null && !textContent.text().isEmpty()) { - errorMessage += " Details: " + textContent.text(); - } - } - return ImmutableMap.of("error", errorMessage); - } - - if (contents == null || contents.isEmpty()) { - return ImmutableMap.of(); - } - - List textOutputs = new ArrayList<>(); - for (Content content : contents) { - if (content instanceof TextContent textContent) { - if (textContent.text() != null) { - textOutputs.add(textContent.text()); - } - } - } - - if (textOutputs.isEmpty()) { - return ImmutableMap.of( - "error", - "Tool '" + mcpToolName + "' returned content that is not TextContent.", - "content_details", - contents.toString()); - } - - List> resultMaps = new ArrayList<>(); - for (String textOutput : textOutputs) { - try { - resultMaps.add( - objectMapper.readValue(textOutput, new TypeReference>() {})); - } catch (JsonProcessingException e) { - resultMaps.add(ImmutableMap.of("text", textOutput)); - } - } - return ImmutableMap.of("text_output", resultMaps); + return objectMapper.convertValue(callResult, new TypeReference>() {}); } } diff --git a/core/src/test/java/com/google/adk/models/ClaudeTest.java b/core/src/test/java/com/google/adk/models/ClaudeTest.java new file mode 100644 index 000000000..677d40627 --- /dev/null +++ b/core/src/test/java/com/google/adk/models/ClaudeTest.java @@ -0,0 +1,97 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models; + +import static com.google.common.truth.Truth.assertThat; + +import com.anthropic.client.AnthropicClient; +import com.anthropic.models.messages.ContentBlockParam; +import com.anthropic.models.messages.ToolResultBlockParam; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import java.lang.reflect.Method; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public final class ClaudeTest { + + private Claude claude; + private Method partToAnthropicMessageBlockMethod; + + @Before + public void setUp() throws Exception { + AnthropicClient mockClient = Mockito.mock(AnthropicClient.class); + claude = new Claude("claude-3-opus", mockClient); + + // Access private method for testing the extraction logic + partToAnthropicMessageBlockMethod = + Claude.class.getDeclaredMethod("partToAnthropicMessageBlock", Part.class); + partToAnthropicMessageBlockMethod.setAccessible(true); + } + + @Test + public void testPartToAnthropicMessageBlock_mcpNativeFormat() throws Exception { + Map responseData = + ImmutableMap.of( + "content", + ImmutableList.of(ImmutableMap.of("type", "text", "text", "Extracted native MCP text"))); + FunctionResponse funcParam = + FunctionResponse.builder().name("test_tool").response(responseData).id("call_123").build(); + Part part = Part.builder().functionResponse(funcParam).build(); + + ContentBlockParam result = + (ContentBlockParam) partToAnthropicMessageBlockMethod.invoke(claude, part); + + ToolResultBlockParam toolResult = result.asToolResult(); + assertThat(toolResult.content().get().asString()).isEqualTo("Extracted native MCP text"); + } + + @Test + public void testPartToAnthropicMessageBlock_legacyResultKey() throws Exception { + Map responseData = ImmutableMap.of("result", "Legacy result text"); + FunctionResponse funcParam = + FunctionResponse.builder().name("test_tool").response(responseData).id("call_123").build(); + Part part = Part.builder().functionResponse(funcParam).build(); + + ContentBlockParam result = + (ContentBlockParam) partToAnthropicMessageBlockMethod.invoke(claude, part); + + ToolResultBlockParam toolResult = result.asToolResult(); + assertThat(toolResult.content().get().asString()).isEqualTo("Legacy result text"); + } + + @Test + public void testPartToAnthropicMessageBlock_jsonFallback() throws Exception { + Map responseData = ImmutableMap.of("custom_key", "custom_value"); + FunctionResponse funcParam = + FunctionResponse.builder().name("test_tool").response(responseData).id("call_123").build(); + Part part = Part.builder().functionResponse(funcParam).build(); + + ContentBlockParam result = + (ContentBlockParam) partToAnthropicMessageBlockMethod.invoke(claude, part); + + ToolResultBlockParam toolResult = result.asToolResult(); + assertThat(toolResult.content().get().asString()).contains("\"custom_key\":\"custom_value\""); + } +} diff --git a/core/src/test/java/com/google/adk/tools/mcp/AbstractMcpToolTest.java b/core/src/test/java/com/google/adk/tools/mcp/AbstractMcpToolTest.java new file mode 100644 index 000000000..e8d9ea631 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/mcp/AbstractMcpToolTest.java @@ -0,0 +1,62 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.mcp; + +import static com.google.common.truth.Truth.assertThat; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import java.util.List; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class AbstractMcpToolTest { + + private ObjectMapper objectMapper; + + @Before + public void setUp() { + objectMapper = new ObjectMapper(); + } + + @Test + public void testWrapCallResult_success() { + CallToolResult result = + CallToolResult.builder() + .content(ImmutableList.of(new TextContent("success"))) + .isError(false) + .build(); + + Map map = AbstractMcpTool.wrapCallResult(objectMapper, "my_tool", result); + + assertThat(map).containsKey("content"); + List content = (List) map.get("content"); + assertThat(content).hasSize(1); + + Map contentItem = (Map) content.get(0); + assertThat(contentItem).containsEntry("type", "text"); + assertThat(contentItem).containsEntry("text", "success"); + + assertThat(map).containsEntry("isError", false); + } +} From 6a5a55eb3e531c6f8a7083712308c4800f680ca5 Mon Sep 17 00:00:00 2001 From: "Ganesh, Mohan" Date: Thu, 26 Feb 2026 18:15:32 -0500 Subject: [PATCH 12/13] fix(firestore): Remove hardcoded dependency version --- contrib/firestore-session-service/pom.xml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/contrib/firestore-session-service/pom.xml b/contrib/firestore-session-service/pom.xml index ed1ecd09b..34b577984 100644 --- a/contrib/firestore-session-service/pom.xml +++ b/contrib/firestore-session-service/pom.xml @@ -14,7 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. --> - + 4.0.0 @@ -49,7 +51,6 @@ com.google.cloud google-cloud-firestore - 3.30.3 com.google.truth From 8ab7f072cdaa363e07b7a786044376c021c4c009 Mon Sep 17 00:00:00 2001 From: pkarmarkar Date: Tue, 6 Jan 2026 13:24:22 -0800 Subject: [PATCH 13/13] fix: add media/image support in Spring AI MessageConverter Previously, MessageConverter only transferred text content from ADK to Spring AI, ignoring image and media attachments. This caused vision model requests to fail even though Spring AI's underlying models (like GPT-4o) support image inputs. Updated MessageConverter to properly handle image/media parts by constructing UserMessage with Media attachments. Fixes #705 --- .../adk/models/springai/MessageConverter.java | 12 +- .../models/springai/MessageConverterTest.java | 181 +++++++++++++++++- 2 files changed, 183 insertions(+), 10 deletions(-) diff --git a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java index 036a898bb..3983b08a5 100644 --- a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java +++ b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java @@ -221,8 +221,7 @@ private List handleUserContent(Content content) { } catch (Exception e) { // Log warning but continue processing other parts // In production, consider proper logging framework - System.err.println( - "Warning: Failed to parse media mime type: " + blob.mimeType().get()); + System.err.println("Warning: Failed to process media part: " + e.getMessage()); } } } else if (part.fileData().isPresent()) { @@ -235,19 +234,14 @@ private List handleUserContent(Content content) { URI uri = URI.create(fileData.fileUri().get()); mediaList.add(new Media(mimeType, uri)); } catch (Exception e) { - System.err.println( - "Warning: Failed to parse media mime type: " + fileData.mimeType().get()); + System.err.println("Warning: Failed to process media part: " + e.getMessage()); } } } } List messages = new ArrayList<>(); - // Create UserMessage with text - // TODO: Media attachments support - UserMessage constructors with media are private in Spring - // AI 1.1.0 - // For now, only text content is supported - messages.add(new UserMessage(textBuilder.toString())); + messages.add(UserMessage.builder().text(textBuilder.toString()).media(mediaList).build()); messages.addAll(toolResponseMessages); return messages; diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java index a57644b5d..b861a71f2 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java @@ -60,7 +60,9 @@ void testToLlmPromptWithUserMessage() { assertThat(prompt.getInstructions()).hasSize(1); Message message = prompt.getInstructions().get(0); assertThat(message).isInstanceOf(UserMessage.class); - assertThat(((UserMessage) message).getText()).isEqualTo("Hello, how are you?"); + UserMessage userMessage = (UserMessage) message; + assertThat(userMessage.getText()).isEqualTo("Hello, how are you?"); + assertThat(userMessage.getMedia()).isEmpty(); } @Test @@ -444,4 +446,181 @@ void testCombineMultipleSystemMessagesForGeminiCompatibility() { assertThat(secondMessage).isInstanceOf(UserMessage.class); assertThat(((UserMessage) secondMessage).getText()).isEqualTo("Hello world"); } + + @Test + void testUserMessageWithInlineMediaData() { + // Test conversion of ADK Content with inline media (image bytes) to Spring AI UserMessage + byte[] imageData = "fake-image-data".getBytes(); + String mimeType = "image/png"; + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.fromText("What's in this image?"), + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType(mimeType) + .data(imageData) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + Message message = prompt.getInstructions().get(0); + assertThat(message).isInstanceOf(UserMessage.class); + + UserMessage userMessage = (UserMessage) message; + assertThat(userMessage.getText()).isEqualTo("What's in this image?"); + assertThat(userMessage.getMedia()).hasSize(1); + org.springframework.ai.content.Media media = userMessage.getMedia().get(0); + assertThat(media.getMimeType().toString()).isEqualTo(mimeType); + assertThat(media.getData()).isInstanceOf(byte[].class); + byte[] actualData = (byte[]) media.getData(); + assertThat(actualData).isEqualTo(imageData); + } + + @Test + void testUserMessageWithFileMediaData() { + // Test conversion of ADK Content with file-based media (URI) to Spring AI UserMessage + String fileUri = "gs://bucket/image.jpg"; + String mimeType = "image/jpeg"; + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.fromText("Analyze this image"), + Part.builder() + .fileData( + com.google.genai.types.FileData.builder() + .mimeType(mimeType) + .fileUri(fileUri) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + Message message = prompt.getInstructions().get(0); + assertThat(message).isInstanceOf(UserMessage.class); + + UserMessage userMessage = (UserMessage) message; + assertThat(userMessage.getText()).isEqualTo("Analyze this image"); + assertThat(userMessage.getMedia()).hasSize(1); + org.springframework.ai.content.Media media = userMessage.getMedia().get(0); + assertThat(media.getMimeType().toString()).isEqualTo(mimeType); + assertThat(media.getData()).isInstanceOf(String.class); + String actualUri = (String) media.getData(); + assertThat(actualUri).isEqualTo(fileUri); + } + + @Test + void testUserMessageWithMultipleMediaAttachments() { + // Test conversion with multiple media attachments + byte[] image1 = "image1-data".getBytes(); + byte[] image2 = "image2-data".getBytes(); + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.fromText("Compare these images"), + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType("image/png") + .data(image1) + .build()) + .build(), + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType("image/jpeg") + .data(image2) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + UserMessage userMessage = (UserMessage) prompt.getInstructions().get(0); + assertThat(userMessage.getText()).isEqualTo("Compare these images"); + assertThat(userMessage.getMedia()).hasSize(2); + } + + @Test + void testUserMessageWithInvalidMimeTypeGracefullySkipsMediaPart() { + // Test that an invalid MIME type string causes the media part to be skipped gracefully + byte[] imageData = "fake-image-data".getBytes(); + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.fromText("What's in this image?"), + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType("invalid/mime/type!!!") // invalid MIME type + .data(imageData) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + // Should not throw — invalid MIME type is silently skipped + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + UserMessage userMessage = (UserMessage) prompt.getInstructions().get(0); + assertThat(userMessage.getText()).isEqualTo("What's in this image?"); + // Media part is skipped due to invalid MIME type + assertThat(userMessage.getMedia()).isEmpty(); + } + + @Test + void testUserMessageWithMediaOnly() { + // Test conversion with media but no text + byte[] imageData = "image-only".getBytes(); + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType("image/png") + .data(imageData) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + UserMessage userMessage = (UserMessage) prompt.getInstructions().get(0); + assertThat(userMessage.getText()).isEmpty(); + assertThat(userMessage.getMedia()).hasSize(1); + } }