2828)
2929from mistralai .extra .run .utils import run_requirements
3030from mistralai .extra .observability .otel import GenAISpanEnum , get_or_create_otel_tracer
31+ from mistralai .extra .exceptions import (
32+ DeferralReason ,
33+ DeferredToolCallsException ,
34+ DeferredToolCallEntry ,
35+ DeferredToolCallResponse ,
36+ )
37+ from mistralai .extra .run .deferred import (
38+ _is_deferred_response ,
39+ _is_server_deferred ,
40+ _process_deferred_responses ,
41+ )
3142
3243logger = logging .getLogger (__name__ )
3344tracing_enabled , tracer = get_or_create_otel_tracer ()
@@ -48,7 +59,11 @@ class Conversations(BaseSDK):
4859 async def run_async (
4960 self ,
5061 run_ctx : "RunContext" ,
51- inputs : Union [models .ConversationInputs , models .ConversationInputsTypedDict ],
62+ inputs : Union [
63+ models .ConversationInputs ,
64+ models .ConversationInputsTypedDict ,
65+ List [DeferredToolCallResponse ],
66+ ],
5267 instructions : OptionalNullable [str ] = UNSET ,
5368 tools : OptionalNullable [
5469 Union [
@@ -68,16 +83,44 @@ async def run_async(
6883 ) -> RunResult :
6984 """Run a conversation with the given inputs and context.
7085
71- The execution of a run will only stop when no required local execution can be done."""
86+ The execution of a run will only stop when no required local execution can be done.
87+
88+ Inputs can be:
89+ - Regular conversation inputs (messages, function results, etc.)
90+ - DeferredToolResponse objects (from deferred.confirm(), reject())
91+
92+ When passing DeferredToolResponse objects, the SDK will:
93+ - Execute confirmed tools automatically
94+ - Convert rejections to function results with the rejection message
95+ """
7296 from mistralai .client .beta import Beta # pylint: disable=import-outside-toplevel
7397 from mistralai .extra .run .context import _validate_run # pylint: disable=import-outside-toplevel
7498 from mistralai .extra .run .tools import get_function_calls # pylint: disable=import-outside-toplevel
7599
100+ # Check if inputs contain deferred responses - process them
101+ pending_tool_confirmations : Optional [List [models .ToolCallConfirmation ]] = None
102+ if inputs and isinstance (inputs , list ):
103+ deferred_inputs = typing .cast (
104+ List [DeferredToolCallResponse ],
105+ [i for i in inputs if _is_deferred_response (i )],
106+ )
107+ other_inputs = typing .cast (
108+ List [InputEntries ], [i for i in inputs if not _is_deferred_response (i )]
109+ )
110+ if deferred_inputs :
111+ (
112+ processed ,
113+ pending_tool_confirmations ,
114+ ) = await _process_deferred_responses (run_ctx , deferred_inputs )
115+ inputs = other_inputs + processed
116+ if not pending_tool_confirmations :
117+ pending_tool_confirmations = None
118+
76119 with tracer .start_as_current_span (GenAISpanEnum .VALIDATE_RUN .value ):
77120 req , run_result , input_entries = await _validate_run (
78121 beta_client = Beta (self .sdk_configuration ),
79122 run_ctx = run_ctx ,
80- inputs = inputs ,
123+ inputs = typing . cast ( List [ InputEntries ], inputs ) ,
81124 instructions = instructions ,
82125 tools = tools ,
83126 completion_args = completion_args ,
@@ -105,26 +148,68 @@ async def run_async(
105148 res = await self .append_async (
106149 conversation_id = run_ctx .conversation_id ,
107150 inputs = input_entries ,
151+ tool_confirmations = pending_tool_confirmations ,
108152 retries = retries ,
109153 server_url = server_url ,
110154 timeout_ms = timeout_ms ,
155+ http_headers = http_headers ,
111156 )
157+ # Clear after first use
158+ pending_tool_confirmations = None
112159 run_ctx .request_count += 1
113160 run_result .output_entries .extend (res .outputs )
114161 fcalls = get_function_calls (res .outputs )
115162 if not fcalls :
116163 logger .debug ("No more function calls to execute" )
117164 break
118- fresults = await run_ctx .execute_function_calls (fcalls )
119- run_result .output_entries .extend (fresults )
120- input_entries = typing .cast (list [InputEntries ], fresults )
165+
166+ # Partition by permission: include server-side deferred calls
167+ to_defer = [
168+ fc
169+ for fc in fcalls
170+ if run_ctx .requires_confirmation (fc .name ) or _is_server_deferred (fc )
171+ ]
172+ to_execute = [fc for fc in fcalls if fc not in to_defer ]
173+
174+ # Execute approved
175+ fresults = []
176+ if to_execute :
177+ fresults = await run_ctx .execute_function_calls (to_execute )
178+ run_result .output_entries .extend (fresults )
179+ input_entries = typing .cast (list [InputEntries ], fresults )
180+
181+ # Defer the rest - include executed_results so user can pass them back
182+ if to_defer :
183+ deferred_objects = [
184+ DeferredToolCallEntry (
185+ fc ,
186+ reason = DeferralReason .SERVER_SIDE_CONFIRMATION_REQUIRED
187+ if _is_server_deferred (fc )
188+ else DeferralReason .CONFIRMATION_REQUIRED ,
189+ )
190+ for fc in to_defer
191+ ]
192+ raise DeferredToolCallsException (
193+ run_ctx .conversation_id ,
194+ deferred_objects ,
195+ run_result .output_entries ,
196+ executed_results = fresults ,
197+ )
198+
199+ # If we only executed tools (none deferred), continue the loop
200+ if not to_execute :
201+ break
121202 return run_result
122203
123204 @run_requirements
124205 async def run_stream_async (
125206 self ,
126207 run_ctx : "RunContext" ,
127- inputs : Union [models .ConversationInputs , models .ConversationInputsTypedDict ],
208+ inputs : Union [
209+ models .ConversationInputs ,
210+ models .ConversationInputsTypedDict ,
211+ List [DeferredToolCallResponse ],
212+ ],
128213 instructions : OptionalNullable [str ] = UNSET ,
129214 tools : OptionalNullable [
130215 Union [
@@ -144,15 +229,39 @@ async def run_stream_async(
144229 ) -> AsyncGenerator [Union [RunResultEvents , RunResult ], None ]:
145230 """Similar to `run_async` but returns a generator which streams events.
146231
147- The last streamed object is the RunResult object which summarises what happened in the run."""
232+ The last streamed object is the RunResult object which summarises what happened in the run.
233+
234+ Inputs can be:
235+ - Regular conversation inputs (messages, function results, etc.)
236+ - DeferredToolResponse objects (from deferred.confirm(), reject())
237+ """
148238 from mistralai .client .beta import Beta # pylint: disable=import-outside-toplevel
149239 from mistralai .extra .run .context import _validate_run # pylint: disable=import-outside-toplevel
150240 from mistralai .extra .run .tools import get_function_calls # pylint: disable=import-outside-toplevel
151241
242+ # Check if inputs contain deferred responses - process them
243+ pending_tool_confirmations : Optional [List [models .ToolCallConfirmation ]] = None
244+ if inputs and isinstance (inputs , list ):
245+ deferred_inputs = typing .cast (
246+ List [DeferredToolCallResponse ],
247+ [i for i in inputs if _is_deferred_response (i )],
248+ )
249+ other_inputs = typing .cast (
250+ List [InputEntries ], [i for i in inputs if not _is_deferred_response (i )]
251+ )
252+ if deferred_inputs :
253+ (
254+ processed ,
255+ pending_tool_confirmations ,
256+ ) = await _process_deferred_responses (run_ctx , deferred_inputs )
257+ inputs = other_inputs + processed
258+ if not pending_tool_confirmations :
259+ pending_tool_confirmations = None
260+
152261 req , run_result , input_entries = await _validate_run (
153262 beta_client = Beta (self .sdk_configuration ),
154263 run_ctx = run_ctx ,
155- inputs = inputs ,
264+ inputs = typing . cast ( List [ InputEntries ], inputs ) ,
156265 instructions = instructions ,
157266 tools = tools ,
158267 completion_args = completion_args ,
@@ -161,6 +270,7 @@ async def run_stream_async(
161270 async def run_generator () -> (
162271 AsyncGenerator [Union [RunResultEvents , RunResult ], None ]
163272 ):
273+ nonlocal pending_tool_confirmations
164274 current_entries = input_entries
165275 while True :
166276 received_event_tracker : defaultdict [
@@ -181,10 +291,13 @@ async def run_generator() -> (
181291 res = await self .append_stream_async (
182292 conversation_id = run_ctx .conversation_id ,
183293 inputs = current_entries ,
294+ tool_confirmations = pending_tool_confirmations ,
184295 retries = retries ,
185296 server_url = server_url ,
186297 timeout_ms = timeout_ms ,
187298 )
299+ # Clear after first use
300+ pending_tool_confirmations = None
188301 async for event in res :
189302 if (
190303 isinstance (event .data , ResponseStartedEvent )
@@ -207,18 +320,52 @@ async def run_generator() -> (
207320 if not fcalls :
208321 logger .debug ("No more function calls to execute" )
209322 break
210- fresults = await run_ctx .execute_function_calls (fcalls )
211- run_result .output_entries .extend (fresults )
212- for fresult in fresults :
213- yield RunResultEvents (
214- event = "function.result" ,
215- data = FunctionResultEvent (
216- type = "function.result" ,
217- result = fresult .result ,
218- tool_call_id = fresult .tool_call_id ,
219- ),
323+
324+ # Partition by permission: include server-side deferred calls
325+ to_defer = [
326+ fc
327+ for fc in fcalls
328+ if run_ctx .requires_confirmation (fc .name ) or _is_server_deferred (fc )
329+ ]
330+ to_execute = [fc for fc in fcalls if fc not in to_defer ]
331+
332+ # Execute approved
333+ fresults = []
334+ if to_execute :
335+ fresults = await run_ctx .execute_function_calls (to_execute )
336+ run_result .output_entries .extend (fresults )
337+ for fresult in fresults :
338+ yield RunResultEvents (
339+ event = "function.result" ,
340+ data = FunctionResultEvent (
341+ type = "function.result" ,
342+ result = fresult .result ,
343+ tool_call_id = fresult .tool_call_id ,
344+ ),
345+ )
346+ current_entries = typing .cast (list [InputEntries ], fresults )
347+
348+ # Defer the rest - include executed_results so user can pass them back
349+ if to_defer :
350+ deferred_objects = [
351+ DeferredToolCallEntry (
352+ fc ,
353+ reason = DeferralReason .SERVER_SIDE_CONFIRMATION_REQUIRED
354+ if _is_server_deferred (fc )
355+ else DeferralReason .CONFIRMATION_REQUIRED ,
356+ )
357+ for fc in to_defer
358+ ]
359+ raise DeferredToolCallsException (
360+ run_ctx .conversation_id ,
361+ deferred_objects ,
362+ run_result .output_entries ,
363+ executed_results = fresults ,
220364 )
221- current_entries = typing .cast (list [InputEntries ], fresults )
365+
366+ # If we only executed tools (none deferred), continue the loop
367+ if not to_execute :
368+ break
222369 yield run_result
223370
224371 return run_generator ()
0 commit comments