diff --git a/src/mistralai/client/conversations.py b/src/mistralai/client/conversations.py index 49810eb6..f33c557b 100644 --- a/src/mistralai/client/conversations.py +++ b/src/mistralai/client/conversations.py @@ -28,6 +28,17 @@ ) from mistralai.extra.run.utils import run_requirements from mistralai.extra.observability.otel import GenAISpanEnum, get_or_create_otel_tracer +from mistralai.extra.exceptions import ( + DeferralReason, + DeferredToolCallsException, + DeferredToolCallEntry, + DeferredToolCallResponse, +) +from mistralai.extra.run.deferred import ( + _is_deferred_response, + _is_server_deferred, + _process_deferred_responses, +) logger = logging.getLogger(__name__) tracing_enabled, tracer = get_or_create_otel_tracer() @@ -48,7 +59,11 @@ class Conversations(BaseSDK): async def run_async( self, run_ctx: "RunContext", - inputs: Union[models.ConversationInputs, models.ConversationInputsTypedDict], + inputs: Union[ + models.ConversationInputs, + models.ConversationInputsTypedDict, + List[DeferredToolCallResponse], + ], instructions: OptionalNullable[str] = UNSET, tools: OptionalNullable[ Union[ @@ -68,16 +83,44 @@ async def run_async( ) -> RunResult: """Run a conversation with the given inputs and context. - The execution of a run will only stop when no required local execution can be done.""" + The execution of a run will only stop when no required local execution can be done. + + Inputs can be: + - Regular conversation inputs (messages, function results, etc.) 
+ - DeferredToolResponse objects (from deferred.confirm(), reject()) + + When passing DeferredToolResponse objects, the SDK will: + - Execute confirmed tools automatically + - Convert rejections to function results with the rejection message + """ from mistralai.client.beta import Beta # pylint: disable=import-outside-toplevel from mistralai.extra.run.context import _validate_run # pylint: disable=import-outside-toplevel from mistralai.extra.run.tools import get_function_calls # pylint: disable=import-outside-toplevel + # Check if inputs contain deferred responses - process them + pending_tool_confirmations: Optional[List[models.ToolCallConfirmation]] = None + if inputs and isinstance(inputs, list): + deferred_inputs = typing.cast( + List[DeferredToolCallResponse], + [i for i in inputs if _is_deferred_response(i)], + ) + other_inputs = typing.cast( + List[InputEntries], [i for i in inputs if not _is_deferred_response(i)] + ) + if deferred_inputs: + ( + processed, + pending_tool_confirmations, + ) = await _process_deferred_responses(run_ctx, deferred_inputs) + inputs = other_inputs + processed + if not pending_tool_confirmations: + pending_tool_confirmations = None + with tracer.start_as_current_span(GenAISpanEnum.VALIDATE_RUN.value): req, run_result, input_entries = await _validate_run( beta_client=Beta(self.sdk_configuration), run_ctx=run_ctx, - inputs=inputs, + inputs=typing.cast(List[InputEntries], inputs), instructions=instructions, tools=tools, completion_args=completion_args, @@ -105,26 +148,68 @@ async def run_async( res = await self.append_async( conversation_id=run_ctx.conversation_id, inputs=input_entries, + tool_confirmations=pending_tool_confirmations, retries=retries, server_url=server_url, timeout_ms=timeout_ms, + http_headers=http_headers, ) + # Clear after first use + pending_tool_confirmations = None run_ctx.request_count += 1 run_result.output_entries.extend(res.outputs) fcalls = get_function_calls(res.outputs) if not fcalls: logger.debug("No more 
function calls to execute") break - fresults = await run_ctx.execute_function_calls(fcalls) - run_result.output_entries.extend(fresults) - input_entries = typing.cast(list[InputEntries], fresults) + + # Partition by permission: include server-side deferred calls + to_defer = [ + fc + for fc in fcalls + if run_ctx.requires_confirmation(fc.name) or _is_server_deferred(fc) + ] + to_execute = [fc for fc in fcalls if fc not in to_defer] + + # Execute approved + fresults = [] + if to_execute: + fresults = await run_ctx.execute_function_calls(to_execute) + run_result.output_entries.extend(fresults) + input_entries = typing.cast(list[InputEntries], fresults) + + # Defer the rest - include executed_results so user can pass them back + if to_defer: + deferred_objects = [ + DeferredToolCallEntry( + fc, + reason=DeferralReason.SERVER_SIDE_CONFIRMATION_REQUIRED + if _is_server_deferred(fc) + else DeferralReason.CONFIRMATION_REQUIRED, + ) + for fc in to_defer + ] + raise DeferredToolCallsException( + run_ctx.conversation_id, + deferred_objects, + run_result.output_entries, + executed_results=fresults, + ) + + # If we only executed tools (none deferred), continue the loop + if not to_execute: + break return run_result @run_requirements async def run_stream_async( self, run_ctx: "RunContext", - inputs: Union[models.ConversationInputs, models.ConversationInputsTypedDict], + inputs: Union[ + models.ConversationInputs, + models.ConversationInputsTypedDict, + List[DeferredToolCallResponse], + ], instructions: OptionalNullable[str] = UNSET, tools: OptionalNullable[ Union[ @@ -144,23 +229,48 @@ async def run_stream_async( ) -> AsyncGenerator[Union[RunResultEvents, RunResult], None]: """Similar to `run_async` but returns a generator which streams events. - The last streamed object is the RunResult object which summarises what happened in the run.""" + The last streamed object is the RunResult object which summarises what happened in the run. 
+ + Inputs can be: + - Regular conversation inputs (messages, function results, etc.) + - DeferredToolResponse objects (from deferred.confirm(), reject()) + """ from mistralai.client.beta import Beta # pylint: disable=import-outside-toplevel from mistralai.extra.run.context import _validate_run # pylint: disable=import-outside-toplevel from mistralai.extra.run.tools import get_function_calls # pylint: disable=import-outside-toplevel + # Check if inputs contain deferred responses - process them + pending_tool_confirmations: Optional[List[models.ToolCallConfirmation]] = None + if inputs and isinstance(inputs, list): + deferred_inputs = typing.cast( + List[DeferredToolCallResponse], + [i for i in inputs if _is_deferred_response(i)], + ) + other_inputs = typing.cast( + List[InputEntries], [i for i in inputs if not _is_deferred_response(i)] + ) + if deferred_inputs: + ( + processed, + pending_tool_confirmations, + ) = await _process_deferred_responses(run_ctx, deferred_inputs) + inputs = other_inputs + processed + if not pending_tool_confirmations: + pending_tool_confirmations = None + req, run_result, input_entries = await _validate_run( beta_client=Beta(self.sdk_configuration), run_ctx=run_ctx, - inputs=inputs, + inputs=typing.cast(List[InputEntries], inputs), instructions=instructions, tools=tools, completion_args=completion_args, ) - async def run_generator() -> ( - AsyncGenerator[Union[RunResultEvents, RunResult], None] - ): + async def run_generator() -> AsyncGenerator[ + Union[RunResultEvents, RunResult], None + ]: + nonlocal pending_tool_confirmations current_entries = input_entries while True: received_event_tracker: defaultdict[ @@ -181,10 +291,13 @@ async def run_generator() -> ( res = await self.append_stream_async( conversation_id=run_ctx.conversation_id, inputs=current_entries, + tool_confirmations=pending_tool_confirmations, retries=retries, server_url=server_url, timeout_ms=timeout_ms, ) + # Clear after first use + pending_tool_confirmations = None async 
for event in res: if ( isinstance(event.data, ResponseStartedEvent) @@ -207,18 +320,52 @@ async def run_generator() -> ( if not fcalls: logger.debug("No more function calls to execute") break - fresults = await run_ctx.execute_function_calls(fcalls) - run_result.output_entries.extend(fresults) - for fresult in fresults: - yield RunResultEvents( - event="function.result", - data=FunctionResultEvent( - type="function.result", - result=fresult.result, - tool_call_id=fresult.tool_call_id, - ), + + # Partition by permission: include server-side deferred calls + to_defer = [ + fc + for fc in fcalls + if run_ctx.requires_confirmation(fc.name) or _is_server_deferred(fc) + ] + to_execute = [fc for fc in fcalls if fc not in to_defer] + + # Execute approved + fresults = [] + if to_execute: + fresults = await run_ctx.execute_function_calls(to_execute) + run_result.output_entries.extend(fresults) + for fresult in fresults: + yield RunResultEvents( + event="function.result", + data=FunctionResultEvent( + type="function.result", + result=fresult.result, + tool_call_id=fresult.tool_call_id, + ), + ) + current_entries = typing.cast(list[InputEntries], fresults) + + # Defer the rest - include executed_results so user can pass them back + if to_defer: + deferred_objects = [ + DeferredToolCallEntry( + fc, + reason=DeferralReason.SERVER_SIDE_CONFIRMATION_REQUIRED + if _is_server_deferred(fc) + else DeferralReason.CONFIRMATION_REQUIRED, + ) + for fc in to_defer + ] + raise DeferredToolCallsException( + run_ctx.conversation_id, + deferred_objects, + run_result.output_entries, + executed_results=fresults, ) - current_entries = typing.cast(list[InputEntries], fresults) + + # If we only executed tools (none deferred), continue the loop + if not to_execute: + break yield run_result return run_generator() diff --git a/src/mistralai/extra/exceptions.py b/src/mistralai/extra/exceptions.py index d2cd3e79..86a39917 100644 --- a/src/mistralai/extra/exceptions.py +++ 
class DeferralReason(str, Enum):
    """Reason why a tool call was deferred."""

    # Deferred client-side: the tool was registered as requiring explicit
    # user confirmation before execution.
    CONFIRMATION_REQUIRED = "confirmation_required"
    # Deferred server-side: the API reported the call as pending confirmation.
    SERVER_SIDE_CONFIRMATION_REQUIRED = "server_side_confirmation_required"


@dataclass
class DeferredToolCallConfirmation:
    """Response indicating the tool call is approved for execution."""

    tool_call_id: str
    tool_name: str
    function_call: FunctionCallEntry
    # Optional overrides merged over the original call arguments before execution.
    override_args: Optional[dict[str, Any]] = None
    # Why the call was deferred; routes between client- and server-side handling.
    deferral_reason: Optional[DeferralReason] = None


@dataclass
class DeferredToolCallRejection:
    """Response indicating tool should not be executed."""

    tool_call_id: str
    # Surfaced as the function result text for the rejected call.
    message: str = "Rejected by user"
    deferral_reason: Optional[DeferralReason] = None


# Union accepted wherever a user answers a deferred tool call.
DeferredToolCallResponse = Union[
    DeferredToolCallConfirmation, DeferredToolCallRejection
]


class FunctionCallSchema(typing.TypedDict):
    """JSON shape of a serialized function call entry."""

    id: str | None
    tool_call_id: str
    name: str
    arguments: FunctionCallEntryArguments


class DeferredToolCallEntrySchema(typing.TypedDict):
    """JSON shape of a serialized :class:`DeferredToolCallEntry`."""

    tool_call_id: str
    tool_name: str
    arguments: FunctionCallEntryArguments
    reason: str
    metadata: dict[str, Any]
    function_call: FunctionCallSchema


class DeferredToolCallEntry:
    """Represents a tool call that requires confirmation.

    Wraps a :class:`FunctionCallEntry` together with the deferral reason and
    optional metadata, and offers ``confirm``/``reject`` factories to build the
    corresponding response objects, plus ``to_dict``/``from_dict`` for
    stateless (serialized) round-trips.
    """

    def __init__(
        self,
        function_call: FunctionCallEntry,
        reason: DeferralReason = DeferralReason.CONFIRMATION_REQUIRED,
        metadata: Optional[dict[str, Any]] = None,
    ):
        self.function_call = function_call
        # Denormalized copies of the call's identifying fields for convenience.
        self.tool_call_id = function_call.tool_call_id
        self.tool_name = function_call.name
        self.arguments = function_call.arguments
        self.reason = reason
        self.metadata = metadata or {}

    def to_function_result(self, result: str) -> dict[str, str]:
        """Convert to function result dict for use as input."""
        return {
            "tool_call_id": self.tool_call_id,
            "result": result,
        }

    def confirm(
        self, override_args: Optional[dict[str, Any]] = None
    ) -> DeferredToolCallConfirmation:
        """Create a confirmation response for this tool call.

        Args:
            override_args: Optional argument overrides merged over the original
                arguments when the confirmed call is executed.
        """
        return DeferredToolCallConfirmation(
            tool_call_id=self.tool_call_id,
            tool_name=self.tool_name,
            function_call=self.function_call,
            override_args=override_args,
            deferral_reason=self.reason,
        )

    def reject(self, message: str = "Rejected by user") -> DeferredToolCallRejection:
        """Create a rejection response for this tool call."""
        return DeferredToolCallRejection(
            tool_call_id=self.tool_call_id,
            message=message,
            deferral_reason=self.reason,
        )

    def to_dict(self) -> DeferredToolCallEntrySchema:
        """Serialize to a JSON-serializable dictionary for stateless scenarios."""
        return {
            "tool_call_id": self.tool_call_id,
            "tool_name": self.tool_name,
            "arguments": self.arguments,
            "reason": self.reason.value,
            "metadata": self.metadata,
            "function_call": {
                "id": self.function_call.id,
                "tool_call_id": self.function_call.tool_call_id,
                "name": self.function_call.name,
                "arguments": self.function_call.arguments,
            },
        }

    @classmethod
    def from_dict(cls, data: DeferredToolCallEntrySchema) -> DeferredToolCallEntry:
        """Deserialize from a dictionary produced by :meth:`to_dict`."""
        function_call = FunctionCallEntry(
            id=data["function_call"].get("id"),
            tool_call_id=data["function_call"]["tool_call_id"],
            name=data["function_call"]["name"],
            arguments=data["function_call"]["arguments"],
        )
        return cls(
            function_call=function_call,
            # Tolerate payloads from serializers that omitted "reason".
            reason=DeferralReason(
                data.get("reason", DeferralReason.CONFIRMATION_REQUIRED.value)
            ),
            metadata=data.get("metadata", {}),
        )


class DeferredToolCallsExceptionSchema(typing.TypedDict):
    """JSON shape of a serialized :class:`DeferredToolCallsException`."""

    conversation_id: str | None
    deferred_calls: list[DeferredToolCallEntrySchema]
    outputs: list[dict[str, Any]]
    executed_results: list[dict[str, Any]]


class DeferredToolCallsException(RunException):
    """Exception raised when tool calls require human confirmation.

    Carries the conversation id, the deferred calls awaiting an answer, the
    run's output entries so far, and the results of any tools that already
    executed before the deferral was raised.
    """

    def __init__(
        self,
        conversation_id: str | None,
        deferred_calls: list[DeferredToolCallEntry],
        outputs: list[RunOutputEntries] | None = None,
        executed_results: list[FunctionResultEntry] | None = None,
    ):
        self.conversation_id = conversation_id
        self.deferred_calls = deferred_calls
        self.outputs = outputs or []
        # Results of tools that DID run in the same step; callers can pass
        # these back alongside their confirmations/rejections.
        self.executed_results = executed_results or []
        super().__init__(
            f"Deferred tool calls requiring confirmation: {[dc.tool_name for dc in deferred_calls]}"
        )

    def to_dict(self) -> DeferredToolCallsExceptionSchema:
        """Serialize to a JSON-serializable dictionary for stateless scenarios."""
        return {
            "conversation_id": self.conversation_id,
            "deferred_calls": [dc.to_dict() for dc in self.deferred_calls],
            "outputs": [entry.model_dump(mode="json") for entry in self.outputs],
            "executed_results": [
                entry.model_dump(mode="json") for entry in self.executed_results
            ],
        }

    @classmethod
    def from_dict(
        cls, data: DeferredToolCallsExceptionSchema
    ) -> DeferredToolCallsException:
        """Deserialize from a dictionary produced by :meth:`to_dict`.

        Output entries whose ``type`` discriminator is unrecognized are
        silently dropped.
        """
        from pydantic import BaseModel
        from mistralai.client.models import (
            MessageOutputEntry,
            FunctionCallEntry,
            FunctionResultEntry,
            AgentHandoffEntry,
            ToolExecutionEntry,
        )

        # Maps the serialized "type" discriminator back to its pydantic model.
        output_entry_types: dict[str, type[BaseModel]] = {
            "message.output": MessageOutputEntry,
            "function.call": FunctionCallEntry,
            "function.result": FunctionResultEntry,
            "agent.handoff": AgentHandoffEntry,
            "tool.execution": ToolExecutionEntry,
        }

        deferred_calls = [
            DeferredToolCallEntry.from_dict(dc_data)
            for dc_data in data["deferred_calls"]
        ]

        outputs: list[RunOutputEntries] = []
        for entry_data in data.get("outputs", []):
            entry_type = entry_data.get("type")
            if isinstance(entry_type, str):
                model_cls = output_entry_types.get(entry_type)
                if model_cls is not None:
                    outputs.append(
                        typing.cast(
                            "RunOutputEntries", model_cls.model_validate(entry_data)
                        )
                    )

        executed_results = [
            FunctionResultEntry.model_validate(r)
            for r in data.get("executed_results", [])
        ]

        return cls(
            conversation_id=data["conversation_id"],
            deferred_calls=deferred_calls,
            outputs=outputs,
            executed_results=executed_results,
        )
for mcp_client in self._mcp_clients: await mcp_client.aclose() - def register_func(self, func: Callable): + def requires_confirmation(self, tool_name: str) -> bool: + """Check if tool requires confirmation. Default: False.""" + config = self._tool_configurations.get(tool_name) + if config is None: + return False + return config.get("requires_confirmation", False) + + def register_func(self, func: Callable, requires_confirmation: bool = False): """Add a function to the context.""" if not inspect.isfunction(func): raise RunException( @@ -119,6 +131,10 @@ def register_func(self, func: Callable): tool=create_tool_call(func), ) + self._tool_configurations[func.__name__] = { + "requires_confirmation": requires_confirmation, + } + @wraps(func) def wrapper(*args, **kwargs): logger.info(f"Executing {func.__name__}") @@ -126,24 +142,63 @@ def wrapper(*args, **kwargs): return wrapper - async def register_mcp_clients(self, mcp_clients: list[MCPClientProtocol]) -> None: + async def register_mcp_clients( + self, + mcp_clients: list[MCPClientProtocol], + tool_configurations: list[dict[str, list[str]] | None] | None = None, + ) -> None: """Registering multiple MCP clients at the same time in the same asyncio.Task.""" - for mcp_client in mcp_clients: - await self.register_mcp_client(mcp_client) + for i, mcp_client in enumerate(mcp_clients): + tool_configuration = tool_configurations[i] if tool_configurations else None + await self.register_mcp_client( + mcp_client, tool_configuration=tool_configuration + ) - async def register_mcp_client(self, mcp_client: MCPClientProtocol) -> None: + async def register_mcp_client( + self, + mcp_client: MCPClientProtocol, + tool_configuration: dict[str, list[str]] | None = None, + ) -> None: """Add a MCP client to the context.""" await mcp_client.initialize(exit_stack=self._exit_stack) tools = await mcp_client.get_tools() + + include = ( + set(tool_configuration.get("include", [])) if tool_configuration else set() + ) + exclude = ( + 
set(tool_configuration.get("exclude", [])) if tool_configuration else set() + ) + requires_confirmation_list = ( + set(tool_configuration.get("requires_confirmation", [])) + if tool_configuration + else set() + ) + for tool in tools: + tool_name = tool.function.name + + if include: + if tool_name not in include: + continue + elif exclude: + if tool_name in exclude: + continue + logger.info( - f"Adding tool {tool.function.name} from {mcp_client._name or 'mcp client'}" + f"Adding tool {tool_name} from {mcp_client._name or 'mcp client'}" ) - self._callable_tools[tool.function.name] = RunMCPTool( - name=tool.function.name, + self._callable_tools[tool_name] = RunMCPTool( + name=tool_name, tool=tool, mcp_client=mcp_client, ) + + if tool_configuration is not None: + self._tool_configurations[tool_name] = { + "requires_confirmation": tool_name in requires_confirmation_list, + } + self._mcp_clients.append(mcp_client) async def execute_function_calls( @@ -213,8 +268,12 @@ async def prepare_agent_request(self, beta_client: "Beta") -> AgentRequestKwargs async def prepare_model_request( self, - tools: OptionalNullable[list[ConversationRequestTool] | list[ConversationRequestToolTypedDict]] = UNSET, - completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict] = UNSET, + tools: OptionalNullable[ + list[ConversationRequestTool] | list[ConversationRequestToolTypedDict] + ] = UNSET, + completion_args: OptionalNullable[ + CompletionArgs | CompletionArgsTypedDict + ] = UNSET, instructions: OptionalNullable[str] = None, ) -> ModelRequestKwargs: if self.model is None: @@ -254,11 +313,11 @@ async def _validate_run( run_ctx: RunContext, inputs: ConversationInputs | ConversationInputsTypedDict, instructions: OptionalNullable[str] = UNSET, - tools: OptionalNullable[list[ConversationRequestTool] | list[ConversationRequestToolTypedDict]] = UNSET, + tools: OptionalNullable[ + list[ConversationRequestTool] | list[ConversationRequestToolTypedDict] + ] = UNSET, completion_args: 
OptionalNullable[CompletionArgs | CompletionArgsTypedDict] = UNSET, -) -> tuple[ - AgentRequestKwargs | ModelRequestKwargs, RunResult, list[InputEntries] -]: +) -> tuple[AgentRequestKwargs | ModelRequestKwargs, RunResult, list[InputEntries]]: input_entries: list[InputEntries] = [] if isinstance(inputs, str): input_entries.append(MessageInputEntry(role="user", content=inputs)) @@ -268,6 +327,8 @@ async def _validate_run( input_entries.append( pydantic.TypeAdapter(InputEntries).validate_python(input) ) + elif isinstance(input, FunctionResultEntry): + input_entries.append(input) run_result = RunResult( input_entries=input_entries, output_model=run_ctx.output_format, diff --git a/src/mistralai/extra/run/deferred.py b/src/mistralai/extra/run/deferred.py new file mode 100644 index 00000000..5aa463b2 --- /dev/null +++ b/src/mistralai/extra/run/deferred.py @@ -0,0 +1,121 @@ +"""Helper functions for processing deferred tool call responses. + +Moved out of conversations.py to avoid conflicts with speakeasy code generation, +which overwrites everything outside custom regions. 
def _is_deferred_response(obj: object) -> bool:
    """Return True if *obj* is a user answer to a deferred tool call."""
    return isinstance(obj, (DeferredToolCallConfirmation, DeferredToolCallRejection))


def _is_server_deferred(fc: models.FunctionCallEntry) -> bool:
    """Check if a function call was deferred server-side (pending confirmation).

    NOTE(review): relies on an optional ``confirmation_status`` attribute on
    the entry model; ``getattr`` keeps this safe when the field is absent.
    """
    return getattr(fc, "confirmation_status", None) == "pending"


async def _process_deferred_responses(
    run_ctx: "RunContext",
    responses: list[DeferredToolCallResponse],
) -> tuple[list[models.InputEntries], list[models.ToolCallConfirmation]]:
    """Process deferred tool responses and return function results and server-side confirmations.

    For client-side deferrals (CONFIRMATION_REQUIRED):
    - Confirmations: executes the tool using run_ctx -> FunctionResultEntry
    - Rejections: creates a result with the rejection message -> FunctionResultEntry
    For server-side deferrals (SERVER_SIDE_CONFIRMATION_REQUIRED):
    - Confirmations: returns ToolCallConfirmation(confirmation="allow")
    - Rejections: returns ToolCallConfirmation(confirmation="deny")

    Raises:
        RunException: if a confirmed client-side tool produced no result
            (e.g. the tool is not registered in the RunContext).
    """
    results: list[models.InputEntries] = []
    tool_confirmations: list[models.ToolCallConfirmation] = []
    # (tool_name, task) pairs for confirmed client-side executions; tasks are
    # created eagerly so all confirmed tools run concurrently.
    confirmation_tasks: list[tuple[str, asyncio.Task]] = []

    for response in responses:
        if isinstance(response, DeferredToolCallConfirmation):
            # getattr tolerates responses deserialized without a deferral_reason.
            reason = getattr(
                response, "deferral_reason", DeferralReason.CONFIRMATION_REQUIRED
            )

            if reason == DeferralReason.SERVER_SIDE_CONFIRMATION_REQUIRED:
                tool_confirmations.append(
                    models.ToolCallConfirmation(
                        tool_call_id=response.tool_call_id,
                        confirmation="allow",
                    )
                )
            else:
                if response.override_args is not None:
                    # Merge overrides over the original arguments; overrides win
                    # on key collisions.
                    original_args = (
                        json.loads(response.function_call.arguments)
                        if isinstance(response.function_call.arguments, str)
                        else response.function_call.arguments
                    )
                    merged_args = {**original_args, **response.override_args}
                    function_call = models.FunctionCallEntry(
                        id=response.function_call.id,
                        tool_call_id=response.tool_call_id,
                        name=response.tool_name,
                        arguments=json.dumps(merged_args),
                    )
                else:
                    function_call = response.function_call

                task = asyncio.create_task(
                    run_ctx.execute_function_calls([function_call])
                )
                confirmation_tasks.append((response.tool_name, task))

        elif isinstance(response, DeferredToolCallRejection):
            reason = getattr(
                response, "deferral_reason", DeferralReason.CONFIRMATION_REQUIRED
            )

            if reason == DeferralReason.SERVER_SIDE_CONFIRMATION_REQUIRED:
                tool_confirmations.append(
                    models.ToolCallConfirmation(
                        tool_call_id=response.tool_call_id,
                        confirmation="deny",
                    )
                )
            else:
                # Client-side rejection: surface the message as the tool result
                # so the model sees why the call was not executed.
                results.append(
                    models.FunctionResultEntry(
                        tool_call_id=response.tool_call_id,
                        result=response.message,
                    )
                )

    if confirmation_tasks:
        # Wait for all confirmed executions, then collect results in order.
        await asyncio.gather(*[task for _, task in confirmation_tasks])
        for tool_name, task in confirmation_tasks:
            task_results = task.result()
            if task_results:
                results.append(task_results[0])
            else:
                raise RunException(
                    f"Tool '{tool_name}' is not registered in the RunContext"
                )

    return results, tool_confirmations