8 changes: 6 additions & 2 deletions py/packages/genkit/src/genkit/blocks/generate.py
@@ -637,7 +637,11 @@ async def _resolve_tool_request(tool: Action, tool_request_part: ToolRequestPart
         Part(
             tool_request=tool_request_part.tool_request,
             metadata={
-                **(tool_request_part.metadata if tool_request_part.metadata else {}),
+                **(
+                    tool_request_part.metadata.root
+                    if isinstance(tool_request_part.metadata, Metadata)
+                    else (tool_request_part.metadata or {})
+                ),
                 'interrupt': (interrupt_error.metadata if interrupt_error.metadata else True),
             },
         ),
@@ -815,7 +819,7 @@ def _find_corresponding_tool_response(responses: list[ToolResponsePart], request
     """
     for p in responses:
         if p.tool_response.name == request.tool_request.name and p.tool_response.ref == request.tool_request.ref:
-            return p
+            return Part(root=p)
     return None


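The `.root` unwrap above is needed because a Pydantic root model cannot be splatted into a dict literal. A minimal sketch of the failure mode, assuming Genkit's `Metadata` is a `RootModel` around a plain dict (the class below is a stand-in, not the real definition):

from pydantic import RootModel

class Metadata(RootModel[dict]):  # stand-in for genkit.core.typing.Metadata
    pass

meta = Metadata({'source': 'tool'})
# {**meta} raises TypeError: 'Metadata' object is not a mapping
merged = {**meta.root, 'interrupt': True}  # unwrap .root, then spread
assert merged == {'source': 'tool', 'interrupt': True}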
11 changes: 9 additions & 2 deletions py/packages/genkit/src/genkit/blocks/tools.py
@@ -17,7 +17,7 @@
 from typing import Any

 from genkit.core.action import ActionRunContext
-from genkit.core.typing import Part, ToolRequestPart, ToolResponse
+from genkit.core.typing import Metadata, Part, ToolRequestPart, ToolResponse


 class ToolRunContext(ActionRunContext):
@@ -96,13 +96,20 @@ def tool_response(
     """
     # TODO: validate against tool schema
     tool_request = interrupt.root.tool_request if isinstance(interrupt, Part) else interrupt.tool_request
+
+    interrupt_metadata = True
+    if isinstance(metadata, Metadata):
+        interrupt_metadata = metadata.root
+    elif metadata:
+        interrupt_metadata = metadata
+
     return Part(
         tool_response=ToolResponse(
             name=tool_request.name,
             ref=tool_request.ref,
             output=response_data,
         ),
         metadata={
-            'interruptResponse': metadata if metadata else True,
+            'interruptResponse': interrupt_metadata,
         },
     )
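With the normalization above, `tool_response` accepts either a raw dict or a `Metadata` wrapper and collapses both to the same `interruptResponse` payload. A sketch of the three accepted shapes, assuming the signature shown above (`interrupt` is a previously captured tool request part):

tool_response(interrupt, response_data='ok')                                # metadata omitted -> True
tool_response(interrupt, response_data='ok', metadata={'resumed': True})    # plain dict passes through
tool_response(interrupt, response_data='ok', metadata=Metadata({'resumed': True}))  # unwraps .root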
@@ -20,7 +20,12 @@
     GeminiEmbeddingModels,
     VertexEmbeddingModels,
 )
-from genkit.plugins.google_genai.models.gemini import GeminiConfigSchema, GoogleAIGeminiVersion, VertexAIGeminiVersion
+from genkit.plugins.google_genai.models.gemini import (
+    GeminiConfigSchema,
+    GeminiImageConfigSchema,
+    GoogleAIGeminiVersion,
+    VertexAIGeminiVersion,
+)
 from genkit.plugins.google_genai.models.imagen import ImagenVersion


@@ -45,5 +50,6 @@ def package_name() -> str:
     VertexAIGeminiVersion.__name__,
     EmbeddingTaskType.__name__,
     GeminiConfigSchema.__name__,
+    GeminiImageConfigSchema.__name__,
     ImagenVersion.__name__,
 ]
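With the re-export in place, the new image config schema is importable straight from the plugin package root:

from genkit.plugins.google_genai import GeminiImageConfigSchema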
@@ -182,6 +182,10 @@ class GeminiConfigSchema(genai_types.GenerateContentConfig):

     code_execution: bool | None = None
     response_modalities: list[str] | None = None
+    thinking_config: dict[str, Any] | None = None
+    file_search: dict[str, Any] | None = None
+    url_context: dict[str, Any] | None = None
+    api_version: str | None = None


 class GeminiTtsConfigSchema(GeminiConfigSchema):
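Of the four new fields, `thinking_config` passes through to google-genai's `GenerateContentConfig`, while `api_version`, `file_search`, and `url_context` are Genkit-side fields stripped before the request (`api_version` drives the client-cloning logic in `generate` further down). A usage sketch with illustrative, assumed values:

config = GeminiConfigSchema(
    temperature=0.2,
    thinking_config={'include_thoughts': True, 'thinking_budget': 1024},
    api_version='v1alpha',  # selects the client API version; never sent to the model
)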
@@ -678,6 +682,11 @@ def _create_tool(self, tool: ToolDefinition) -> genai_types.Tool:
             Genai tool compatible with Gemini API.
         """
         params = self._convert_schema_property(tool.input_schema)
+        # Fix for no-arg tools: the Gemini API expects an OBJECT-typed parameters
+        # schema, so substitute an empty object schema instead of passing None.
+        if not params:
+            params = genai_types.Schema(type=genai_types.Type.OBJECT, properties={})
+
         function = genai_types.FunctionDeclaration(
             name=tool.name,
             description=tool.description,
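A minimal illustration of the guard: a tool whose input schema converts to `None` now yields an explicit empty OBJECT schema rather than `parameters=None` (the tool name and description here are hypothetical):

params = genai_types.Schema(type=genai_types.Type.OBJECT, properties={})
fn = genai_types.FunctionDeclaration(
    name='get_current_time',  # hypothetical no-arg tool
    description='Returns the current server time.',
    parameters=params,
)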
@@ -741,7 +750,7 @@ def _convert_schema_property(

         if schema_type == genai_types.Type.OBJECT:
             schema.properties = {}
-            properties = input_schema['properties']
+            properties = input_schema.get('properties', {})
             for key in properties:
                 nested_schema = self._convert_schema_property(properties[key], defs)
                 schema.properties[key] = nested_schema
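The `.get` guard covers object schemas that omit `properties` altogether, such as the bare `{'type': 'object'}` a no-arg tool can produce, which previously raised `KeyError`:

input_schema = {'type': 'object'}                # no 'properties' key
properties = input_schema.get('properties', {})  # -> {} instead of KeyError: 'properties'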
@@ -844,13 +853,59 @@ async def generate(self, request: GenerateRequest, ctx: ActionRunContext) -> Gen
         if cached_content:
             request_cfg.cached_content = cached_content.name

+        client = self._client
+        # If the config specifies an api_version different from the default (e.g. 'v1alpha'),
+        # create a temporary client with that version, since api_version is a client-level setting.
+        api_version = None
+        if request.config:
+            api_version = getattr(request.config, 'api_version', None)
+            if not api_version and isinstance(request.config, dict):
+                api_version = request.config.get('api_version')
+
+        if api_version:
+            # TODO: Request a public API from the google-genai maintainers.
+            # Currently there is no public way to access the configured api_key, project, or
+            # location from an existing Client instance, so we reach into the private
+            # _api_client to clone the configuration when overriding the api_version.
+            # This is brittle and relies on internal implementation details of the
+            # google-genai library: if it renames _api_client or _credentials, this
+            # code WILL BREAK.
+            api_client = self._client._api_client
+            kwargs = {
+                'vertexai': api_client.vertexai,
+                'http_options': {'api_version': api_version},
+            }
+            if api_client.vertexai:
+                # Vertex AI mode: requires project/location (an api_key is unlikely here).
+                if api_client.project:
+                    kwargs['project'] = api_client.project
+                if api_client.location:
+                    kwargs['location'] = api_client.location
+                if api_client._credentials:
+                    kwargs['credentials'] = api_client._credentials
+                # Don't pass api_key in Vertex AI mode with credentials/project.
+            else:
+                # Google AI mode: primarily uses api_key.
+                if api_client.api_key:
+                    kwargs['api_key'] = api_client.api_key
+                # Don't pass project/location in Google AI mode.
+                if api_client._credentials and not kwargs.get('api_key'):
+                    # Fallback if no api_key but credentials are present (unlikely for pure Google AI, but possible).
+                    kwargs['credentials'] = api_client._credentials
+
+            client = genai.Client(**kwargs)
+
         if ctx.is_streaming:
             response = await self._streaming_generate(
-                request_contents=request_contents, request_cfg=request_cfg, ctx=ctx, model_name=model_name
+                request_contents=request_contents,
+                request_cfg=request_cfg,
+                ctx=ctx,
+                model_name=model_name,
+                client=client,
             )
         else:
             response = await self._generate(
-                request_contents=request_contents, request_cfg=request_cfg, model_name=model_name
+                request_contents=request_contents, request_cfg=request_cfg, model_name=model_name, client=client
             )

         response.usage = self._create_usage_stats(request=request, response=response)
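Because `api_version` is a client-level setting in google-genai, the override amounts to rebuilding the client with the same credentials but different HTTP options. The equivalent direct construction, assuming Google AI mode with an API key:

client = genai.Client(
    api_key=api_client.api_key,
    http_options={'api_version': 'v1alpha'},  # e.g. to reach preview-only endpoints
)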
@@ -862,6 +917,7 @@ async def _generate(
         request_contents: list[genai_types.Content],
         request_cfg: genai_types.GenerateContentConfig,
         model_name: str,
+        client: genai.Client | None = None,
     ) -> GenerateResponse:
         """Call google-genai generate.

@@ -885,7 +941,8 @@
                     fallback=lambda _: '[!! failed to serialize !!]',
                 ),
             )
-            response = await self._client.aio.models.generate_content(
+            client = client or self._client
+            response = await client.aio.models.generate_content(
                 model=model_name, contents=request_contents, config=request_cfg
             )
             span.set_attribute('genkit:output', dump_json(response))
@@ -905,6 +962,7 @@ async def _streaming_generate(
         request_cfg: genai_types.GenerateContentConfig | None,
         ctx: ActionRunContext,
         model_name: str,
+        client: genai.Client | None = None,
     ) -> GenerateResponse:
         """Call google-genai generate for streaming.

@@ -926,7 +984,8 @@
                 'model': model_name,
             }),
         )
-        generator = self._client.aio.models.generate_content_stream(
+        client = client or self._client
+        generator = client.aio.models.generate_content_stream(
             model=model_name, contents=request_contents, config=request_cfg
         )
         accumulated_content = []
@@ -989,7 +1048,11 @@ async def _build_messages(
                 continue
             content_parts: list[genai_types.Part] = []
             for p in msg.content:
-                content_parts.append(PartConverter.to_gemini(p))
+                converted = PartConverter.to_gemini(p)
+                if isinstance(converted, list):
+                    content_parts.extend(converted)
+                else:
+                    content_parts.append(converted)
             request_contents.append(genai_types.Content(parts=content_parts, role=msg.role))

             if msg.metadata and msg.metadata.get('cache'):
@@ -1050,7 +1113,19 @@ def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.Gener
             if request_config.code_execution:
                 tools.extend([genai_types.Tool(code_execution=genai_types.ToolCodeExecution())])
         elif isinstance(request_config, dict):
-            cfg = genai_types.GenerateContentConfig(**request_config)
+            if 'image_config' in request_config:
+                cfg = GeminiImageConfigSchema(**request_config)
+            elif 'speech_config' in request_config:
+                cfg = GeminiTtsConfigSchema(**request_config)
+            else:
+                cfg = GeminiConfigSchema(**request_config)
+
+            if isinstance(cfg, GeminiConfigSchema):
+                dumped_config = cfg.model_dump(exclude_none=True)
+                for key in ['code_execution', 'file_search', 'url_context', 'api_version']:
+                    if key in dumped_config:
+                        del dumped_config[key]
+                cfg = genai_types.GenerateContentConfig(**dumped_config)

         if request.output:
             if not cfg:
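The round trip above lets dict configs carry the Genkit-only fields while guaranteeing the object handed to google-genai contains only keys the API accepts. Condensed, the flow for a plain dict config is:

request_config = {'temperature': 0.7, 'api_version': 'v1alpha', 'code_execution': True}
cfg = GeminiConfigSchema(**request_config)          # validates Genkit-side fields too
dumped = cfg.model_dump(exclude_none=True)
for key in ('code_execution', 'file_search', 'url_context', 'api_version'):
    dumped.pop(key, None)                           # drop fields the API would reject
cfg = genai_types.GenerateContentConfig(**dumped)   # only google-genai fields remain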
@@ -90,6 +90,49 @@ def to_gemini(cls, part: Part) -> genai.types.Part:
                 thought_signature=cls._extract_thought_signature(part.root.metadata),
             )
         if isinstance(part.root, ToolResponsePart):
+            tool_output = part.root.tool_response.output
+            parts_to_return = []
+
+            # Check for a multimodal content structure: {content: [{media: ...}]}
+            if isinstance(tool_output, dict) and 'content' in tool_output:
+                content_list = tool_output['content']
+                if isinstance(content_list, list):
+                    # Copy so that separating the content from the other fields
+                    # does not mutate the original output.
+                    clean_output = tool_output.copy()
+                    clean_output.pop('content')
+
+                    # Heuristic: if media is found, extract it into separate parts.
+                    has_media = False
+                    for item in content_list:
+                        if isinstance(item, dict) and 'media' in item:
+                            has_media = True
+                            media_info = item['media']
+                            url = media_info.get('url')
+                            content_type = media_info.get('contentType') or media_info.get('content_type')
+
+                            if url and url.startswith(cls.DATA):
+                                _, data_str = url.split(',', 1)
+                                data = base64.b64decode(data_str)
+                                parts_to_return.append(
+                                    genai.types.Part(inline_data=genai.types.Blob(mime_type=content_type, data=data))
+                                )
+
+                    if has_media:
+                        # Insert the function response part first so it precedes the media parts.
+                        parts_to_return.insert(
+                            0,
+                            genai.types.Part(
+                                function_response=genai.types.FunctionResponse(
+                                    id=part.root.tool_response.ref,
+                                    name=part.root.tool_response.name.replace('/', '__'),
+                                    response=clean_output,
+                                )
+                            ),
+                        )
+                        return parts_to_return
+
+            # Default behavior for standard tool responses.
             return genai.types.Part(
                 function_response=genai.types.FunctionResponse(
                     id=part.root.tool_response.ref,
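The branch above targets tool outputs that embed media next to structured data. An assumed output shape that triggers it, and the resulting part list (`_build_messages` flattens the list via the `isinstance(converted, list)` check added earlier):

tool_output = {
    'description': 'rendered preview',
    'content': [
        {'media': {'url': 'data:image/png;base64,iVBORw0KGgo...', 'contentType': 'image/png'}},
    ],
}
# to_gemini returns:
#   [Part(function_response=FunctionResponse(response={'description': 'rendered preview'})),
#    Part(inline_data=Blob(mime_type='image/png', data=b'\x89PNG...'))]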
@@ -167,7 +210,7 @@ def from_gemini(cls, part: genai.types.Part, ref: str | None = None) -> Part:
                 metadata=cls._encode_thought_signature(part.thought_signature),
             )
         )
-        if part.text:
+        if part.text is not None:
             return Part(root=TextPart(text=part.text))
         if part.function_call:
             return Part(
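The truthiness check silently dropped empty-string text parts; comparing against `None` preserves them:

part = genai.types.Part(text='')
# `if part.text:` fell through to the function_call branch and beyond;
# `if part.text is not None:` now returns Part(root=TextPart(text='')).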
Binary file added py/samples/google-genai-hello/my_room.png
Binary file added py/samples/google-genai-hello/palm_tree.png
Binary file added py/samples/google-genai-hello/photo.jpg