From 8e40b2c9e1a80faa4edb9b0f673bf5cfc8a904bc Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Sat, 28 Mar 2026 03:28:07 -0700 Subject: [PATCH] feat: Add ChatCompletionsRequest object This is part of a larger chain of commits for adding chat completion API support to the Apigee model. PiperOrigin-RevId: 890876818 --- .../adk/models/ChatCompletionsCommon.java | 57 ++++ .../adk/models/ChatCompletionsRequest.java | 316 ++++++++++++++++++ .../adk/models/ChatCompletionsResponse.java | 33 +- .../models/ChatCompletionsRequestTest.java | 192 +++++++++++ .../models/ChatCompletionsResponseTest.java | 4 +- 5 files changed, 570 insertions(+), 32 deletions(-) create mode 100644 core/src/main/java/com/google/adk/models/ChatCompletionsCommon.java create mode 100644 core/src/main/java/com/google/adk/models/ChatCompletionsRequest.java create mode 100644 core/src/test/java/com/google/adk/models/ChatCompletionsRequestTest.java diff --git a/core/src/main/java/com/google/adk/models/ChatCompletionsCommon.java b/core/src/main/java/com/google/adk/models/ChatCompletionsCommon.java new file mode 100644 index 000000000..cf21c06de --- /dev/null +++ b/core/src/main/java/com/google/adk/models/ChatCompletionsCommon.java @@ -0,0 +1,57 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Map; + +/** Shared models for Chat Completions Request and Response. */ +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(JsonInclude.Include.NON_NULL) +final class ChatCompletionsCommon { + + private ChatCompletionsCommon() {} + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class ToolCall { + public Integer index; // Applicable for response streaming, null in requests + public String id; + public String type; + public Function function; + public Custom custom; + + @JsonProperty("extra_content") + public Map extraContent; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class Function { + public String name; + public String arguments; // JSON string + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class Custom { + public String input; + public String name; + } +} diff --git a/core/src/main/java/com/google/adk/models/ChatCompletionsRequest.java b/core/src/main/java/com/google/adk/models/ChatCompletionsRequest.java new file mode 100644 index 000000000..1dbe6fdc3 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/ChatCompletionsRequest.java @@ -0,0 +1,316 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonValue; +import java.util.List; +import java.util.Map; + +/** Data Transfer Objects for Chat Completion API requests. */ +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(JsonInclude.Include.NON_NULL) +final class ChatCompletionsRequest { + + public List messages; + public String model; + public AudioParam audio; + + @JsonProperty("frequency_penalty") + public Double frequencyPenalty; + + @JsonProperty("logit_bias") + public Map logitBias; + + public Boolean logprobs; + + @JsonProperty("max_completion_tokens") + public Integer maxCompletionTokens; + + @JsonProperty("max_tokens") + public Integer maxTokens; + + public Map metadata; + public List modalities; + public Integer n; + + @JsonProperty("parallel_tool_calls") + public Boolean parallelToolCalls; + + public Prediction prediction; + + @JsonProperty("presence_penalty") + public Double presencePenalty; + + @JsonProperty("prompt_cache_key") + public String promptCacheKey; + + @JsonProperty("prompt_cache_retention") + public String promptCacheRetention; + + @JsonProperty("reasoning_effort") + public String reasoningEffort; + + @JsonProperty("response_format") + public ResponseFormat responseFormat; + + @JsonProperty("safety_identifier") + public String safetyIdentifier; + + public Long seed; + + @JsonProperty("service_tier") + public String serviceTier; + + public Object stop; // Can be String or List + public Boolean store; + public Boolean stream; + + @JsonProperty("stream_options") + public StreamOptions streamOptions; + + public Double temperature; + + @JsonProperty("tool_choice") + public ToolChoice toolChoice; + + public List tools; + + @JsonProperty("top_logprobs") + public Integer topLogprobs; + + @JsonProperty("top_p") + public Double topP; + + public String user; + public String verbosity; + + @JsonProperty("web_search_options") + public WebSearchOptions webSearchOptions; + + @JsonProperty("extra_body") + public Map extraBody; + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class Message { + public String role; + public Object content; // Can be String or List + public String name; + + @JsonProperty("tool_calls") + public List toolCalls; + + @JsonProperty("function_call") + public FunctionCall functionCall; // Deprecated + + @JsonProperty("tool_call_id") + public String toolCallId; + + public Audio audio; // For assistant messages with audio + public String refusal; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class ContentPart { + public String type; + public String text; + public String refusal; + + @JsonProperty("image_url") + public ImageUrl imageUrl; + + @JsonProperty("input_audio") + public InputAudio inputAudio; + + public File file; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class ImageUrl { + public String url; + public String detail; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class InputAudio { + public String data; + public String format; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class File { + @JsonProperty("file_data") + public String fileData; + + @JsonProperty("file_id") + public String fileId; + + public String filename; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class FunctionCall { + public String name; + public String arguments; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class AudioParam { + public String format; + public Object voice; // Can be String or Map + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class Audio { + public String id; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class Prediction { + public String type; + public Object content; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class StreamOptions { + @JsonProperty("include_obfuscation") + public Boolean includeObfuscation; + + @JsonProperty("include_usage") + public Boolean includeUsage; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class Tool { + public String type; + public FunctionDefinition function; + public CustomTool custom; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class FunctionDefinition { + public String name; + public String description; + public Object parameters; // JSON Schema + public Boolean strict; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class CustomTool { + public String name; + public String description; + public Object format; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class WebSearchOptions { + @JsonProperty("search_context_size") + public String searchContextSize; + + @JsonProperty("user_location") + public UserLocation userLocation; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class UserLocation { + public String type; + public ApproximateLocation approximate; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class ApproximateLocation { + public String city; + public String country; + public String region; + public String timezone; + } + + interface ResponseFormat {} + + static class ResponseFormatText implements ResponseFormat { + public String type = "text"; + } + + static class ResponseFormatJsonObject implements ResponseFormat { + public String type = "json_object"; + } + + static class ResponseFormatJsonSchema implements ResponseFormat { + public String type = "json_schema"; + + @JsonProperty("json_schema") + public JsonSchema jsonSchema; + + static class JsonSchema { + public String name; + public String description; + public Object schema; + public Boolean strict; + } + } + + interface ToolChoice {} + + static class ToolChoiceMode implements ToolChoice { + private final String mode; + + public ToolChoiceMode(String mode) { + this.mode = mode; + } + + @JsonValue + public String getMode() { + return mode; + } + } + + static class NamedToolChoice implements ToolChoice { + public String type = "function"; + public FunctionName function; + + static class FunctionName { + public String name; + } + } + + static class NamedToolChoiceCustom implements ToolChoice { + public String type = "custom"; + public CustomName custom; + + static class CustomName { + public String name; + } + } +} diff --git a/core/src/main/java/com/google/adk/models/ChatCompletionsResponse.java b/core/src/main/java/com/google/adk/models/ChatCompletionsResponse.java index fe5cdd116..eed726069 100644 --- a/core/src/main/java/com/google/adk/models/ChatCompletionsResponse.java +++ b/core/src/main/java/com/google/adk/models/ChatCompletionsResponse.java @@ -20,7 +20,6 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; -import java.util.Map; /** * Data Transfer Objects for Chat Completion and Chat Completion Chunk API responses. @@ -93,42 +92,16 @@ static class Message { public String role; @JsonProperty("tool_calls") - public List toolCalls; + public List toolCalls; - // function_call is not supported in ChatCompletionChunk and ChatCompletion support is - // deprecated. + // function_call is deprecated. @JsonProperty("function_call") - public Function functionCall; // Fallback for deprecated top-level function calls + public ChatCompletionsCommon.Function functionCall; public List annotations; public Audio audio; } - @JsonIgnoreProperties(ignoreUnknown = true) - static class ToolCall { - // Index is only used in ChatCompletionChunk. - public Integer index; - public String id; - public String type; - public Function function; - public Custom custom; - - @JsonProperty("extra_content") - public Map extraContent; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class Function { - public String name; - public String arguments; // JSON string - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class Custom { - public String input; - public String name; - } - @JsonIgnoreProperties(ignoreUnknown = true) static class Logprobs { public List content; diff --git a/core/src/test/java/com/google/adk/models/ChatCompletionsRequestTest.java b/core/src/test/java/com/google/adk/models/ChatCompletionsRequestTest.java new file mode 100644 index 000000000..6a1a418a9 --- /dev/null +++ b/core/src/test/java/com/google/adk/models/ChatCompletionsRequestTest.java @@ -0,0 +1,192 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models; + +import static com.google.common.truth.Truth.assertThat; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import java.util.HashMap; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class ChatCompletionsRequestTest { + + private ObjectMapper objectMapper; + + @Before + public void setUp() { + objectMapper = new ObjectMapper(); + } + + @Test + public void testSerializeChatCompletionRequest_standard() throws Exception { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + + ChatCompletionsRequest.Message message = new ChatCompletionsRequest.Message(); + message.role = "user"; + message.content = "Hello"; + request.messages = ImmutableList.of(message); + + String json = objectMapper.writeValueAsString(request); + + assertThat(json).contains("\"model\":\"gemini-3-flash-preview\""); + assertThat(json).contains("\"role\":\"user\""); + assertThat(json).contains("\"content\":\"Hello\""); + } + + @Test + public void testSerializeChatCompletionRequest_withExtraBody() throws Exception { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + + ChatCompletionsRequest.Message message = new ChatCompletionsRequest.Message(); + message.role = "user"; + message.content = "Explain to me how AI works"; + request.messages = ImmutableList.of(message); + + Map thinkingConfig = new HashMap<>(); + thinkingConfig.put("thinking_level", "low"); + thinkingConfig.put("include_thoughts", true); + + Map google = new HashMap<>(); + google.put("thinking_config", thinkingConfig); + + Map extraBody = new HashMap<>(); + extraBody.put("google", google); + + request.extraBody = extraBody; + + String json = objectMapper.writeValueAsString(request); + + assertThat(json).contains("\"extra_body\":{"); + assertThat(json).contains("\"thinking_level\":\"low\""); + assertThat(json).contains("\"include_thoughts\":true"); + } + + @Test + public void testSerializeChatCompletionRequest_withToolCallsAndExtraContent() throws Exception { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + + ChatCompletionsRequest.Message userMessage = new ChatCompletionsRequest.Message(); + userMessage.role = "user"; + userMessage.content = "Check flight status"; + + ChatCompletionsRequest.Message modelMessage = new ChatCompletionsRequest.Message(); + modelMessage.role = "model"; + + ChatCompletionsCommon.ToolCall toolCall = new ChatCompletionsCommon.ToolCall(); + toolCall.id = "function-call-1"; + toolCall.type = "function"; + + ChatCompletionsCommon.Function function = new ChatCompletionsCommon.Function(); + function.name = "check_flight"; + function.arguments = "{\"flight\":\"AA100\"}"; + toolCall.function = function; + + Map google = new HashMap<>(); + google.put("thought_signature", ""); + + Map extraContent = new HashMap<>(); + extraContent.put("google", google); + + toolCall.extraContent = extraContent; + + modelMessage.toolCalls = ImmutableList.of(toolCall); + + ChatCompletionsRequest.Message toolMessage = new ChatCompletionsRequest.Message(); + toolMessage.role = "tool"; + toolMessage.name = "check_flight"; + toolMessage.toolCallId = "function-call-1"; + toolMessage.content = "{\"status\":\"delayed\"}"; + + request.messages = ImmutableList.of(userMessage, modelMessage, toolMessage); + + String json = objectMapper.writeValueAsString(request); + + assertThat(json).contains("\"role\":\"user\""); + assertThat(json).contains("\"role\":\"model\""); + assertThat(json).contains("\"role\":\"tool\""); + assertThat(json).contains("\"extra_content\":{"); + assertThat(json).contains("\"thought_signature\":\"\""); + assertThat(json).contains("\"tool_call_id\":\"function-call-1\""); + } + + @Test + public void testSerializeChatCompletionRequest_comprehensive() throws Exception { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + + // Developer message with name + ChatCompletionsRequest.Message devMsg = new ChatCompletionsRequest.Message(); + devMsg.role = "developer"; + devMsg.content = "System instruction"; + devMsg.name = "system-bot"; + + request.messages = ImmutableList.of(devMsg); + + // Response Format JSON Schema + ChatCompletionsRequest.ResponseFormatJsonSchema format = + new ChatCompletionsRequest.ResponseFormatJsonSchema(); + format.jsonSchema = new ChatCompletionsRequest.ResponseFormatJsonSchema.JsonSchema(); + format.jsonSchema.name = "MySchema"; + format.jsonSchema.strict = true; + request.responseFormat = format; + + // Tool Choice Named + ChatCompletionsRequest.NamedToolChoice choice = new ChatCompletionsRequest.NamedToolChoice(); + choice.function = new ChatCompletionsRequest.NamedToolChoice.FunctionName(); + choice.function.name = "my_function"; + request.toolChoice = choice; + + String json = objectMapper.writeValueAsString(request); + + // Assert Developer Message + assertThat(json).contains("\"role\":\"developer\""); + assertThat(json).contains("\"name\":\"system-bot\""); + assertThat(json).contains("\"content\":\"System instruction\""); + + // Assert Response Format + assertThat(json).contains("\"response_format\":{"); + assertThat(json).contains("\"type\":\"json_schema\""); + assertThat(json).contains("\"name\":\"MySchema\""); + assertThat(json).contains("\"strict\":true"); + + // Assert Tool Choice + assertThat(json).contains("\"tool_choice\":{"); + assertThat(json).contains("\"type\":\"function\""); + assertThat(json).contains("\"name\":\"my_function\""); + } + + @Test + public void testSerializeChatCompletionRequest_withToolChoiceMode() throws Exception { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + + request.toolChoice = new ChatCompletionsRequest.ToolChoiceMode("none"); + + String json = objectMapper.writeValueAsString(request); + + assertThat(json).contains("\"tool_choice\":\"none\""); + } +} diff --git a/core/src/test/java/com/google/adk/models/ChatCompletionsResponseTest.java b/core/src/test/java/com/google/adk/models/ChatCompletionsResponseTest.java index 53fcdfbdf..1faaf4446 100644 --- a/core/src/test/java/com/google/adk/models/ChatCompletionsResponseTest.java +++ b/core/src/test/java/com/google/adk/models/ChatCompletionsResponseTest.java @@ -245,7 +245,7 @@ public void testDeserializeChatCompletion_withCustomToolCall() throws Exception objectMapper.readValue(json, ChatCompletionsResponse.ChatCompletion.class); assertThat(completion.choices.get(0).message.toolCalls).hasSize(1); - ChatCompletionsResponse.ToolCall toolCall = completion.choices.get(0).message.toolCalls.get(0); + ChatCompletionsCommon.ToolCall toolCall = completion.choices.get(0).message.toolCalls.get(0); assertThat(toolCall.type).isEqualTo("custom"); assertThat(toolCall.custom.name).isEqualTo("custom_tool"); assertThat(toolCall.custom.input).isEqualTo("{\"arg\":\"val\"}"); @@ -310,7 +310,7 @@ public void testDeserializeChatCompletionChunk_withToolCallDelta() throws Except ChatCompletionChunk chunk = objectMapper.readValue(json, ChatCompletionChunk.class); assertThat(chunk.choices.get(0).delta.toolCalls).hasSize(1); - ChatCompletionsResponse.ToolCall toolCall = chunk.choices.get(0).delta.toolCalls.get(0); + ChatCompletionsCommon.ToolCall toolCall = chunk.choices.get(0).delta.toolCalls.get(0); assertThat(toolCall.index).isEqualTo(1); assertThat(toolCall.id).isEqualTo("call_abc"); assertThat(toolCall.type).isEqualTo("function");