Skip to content

Commit 2e1af21

Browse files
tchardonnensclaude
andcommitted
feat: add Human-in-the-Loop (HITL) deferred tool calls support
Add DeferredToolCallsException and supporting classes to enable human confirmation flows in run_async and run_stream_async. When a tool is registered with requires_confirmation=True or the server returns a function call with confirmation_status="pending", run_async raises DeferredToolCallsException. The user catches it, calls .confirm() or .reject() on each deferred call, and passes the responses back to run_async to resume the conversation. Changes: - extra/exceptions.py: Add DeferralReason, DeferredToolCallEntry, DeferredToolCallConfirmation, DeferredToolCallRejection, DeferredToolCallsException with serialization support - extra/run/deferred.py: New module with helpers for processing deferred responses (client-side execution, server-side ToolCallConfirmation) - extra/run/context.py: Add requires_confirmation() method, _tool_configurations dict, requires_confirmation param on register_func() and register_mcp_client(), tool include/exclude filtering, FunctionResultEntry pass-through in _validate_run - client/conversations.py: Update run_async and run_stream_async to partition function calls into deferred vs executable, process DeferredToolCallResponse inputs, and pass tool_confirmations to append_async/append_stream_async Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f0969ad commit 2e1af21

4 files changed

Lines changed: 590 additions & 36 deletions

File tree

src/mistralai/client/conversations.py

Lines changed: 167 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@
2828
)
2929
from mistralai.extra.run.utils import run_requirements
3030
from mistralai.extra.observability.otel import GenAISpanEnum, get_or_create_otel_tracer
31+
from mistralai.extra.exceptions import (
32+
DeferralReason,
33+
DeferredToolCallsException,
34+
DeferredToolCallEntry,
35+
DeferredToolCallResponse,
36+
)
37+
from mistralai.extra.run.deferred import (
38+
_is_deferred_response,
39+
_is_server_deferred,
40+
_process_deferred_responses,
41+
)
3142

3243
logger = logging.getLogger(__name__)
3344
tracing_enabled, tracer = get_or_create_otel_tracer()
@@ -48,7 +59,11 @@ class Conversations(BaseSDK):
4859
async def run_async(
4960
self,
5061
run_ctx: "RunContext",
51-
inputs: Union[models.ConversationInputs, models.ConversationInputsTypedDict],
62+
inputs: Union[
63+
models.ConversationInputs,
64+
models.ConversationInputsTypedDict,
65+
List[DeferredToolCallResponse],
66+
],
5267
instructions: OptionalNullable[str] = UNSET,
5368
tools: OptionalNullable[
5469
Union[
@@ -68,16 +83,44 @@ async def run_async(
6883
) -> RunResult:
6984
"""Run a conversation with the given inputs and context.
7085
71-
The execution of a run will only stop when no required local execution can be done."""
86+
The execution of a run will only stop when no required local execution can be done.
87+
88+
Inputs can be:
89+
- Regular conversation inputs (messages, function results, etc.)
90+
- DeferredToolResponse objects (from deferred.confirm(), reject())
91+
92+
When passing DeferredToolResponse objects, the SDK will:
93+
- Execute confirmed tools automatically
94+
- Convert rejections to function results with the rejection message
95+
"""
7296
from mistralai.client.beta import Beta # pylint: disable=import-outside-toplevel
7397
from mistralai.extra.run.context import _validate_run # pylint: disable=import-outside-toplevel
7498
from mistralai.extra.run.tools import get_function_calls # pylint: disable=import-outside-toplevel
7599

100+
# Check if inputs contain deferred responses - process them
101+
pending_tool_confirmations: Optional[List[models.ToolCallConfirmation]] = None
102+
if inputs and isinstance(inputs, list):
103+
deferred_inputs = typing.cast(
104+
List[DeferredToolCallResponse],
105+
[i for i in inputs if _is_deferred_response(i)],
106+
)
107+
other_inputs = typing.cast(
108+
List[InputEntries], [i for i in inputs if not _is_deferred_response(i)]
109+
)
110+
if deferred_inputs:
111+
(
112+
processed,
113+
pending_tool_confirmations,
114+
) = await _process_deferred_responses(run_ctx, deferred_inputs)
115+
inputs = other_inputs + processed
116+
if not pending_tool_confirmations:
117+
pending_tool_confirmations = None
118+
76119
with tracer.start_as_current_span(GenAISpanEnum.VALIDATE_RUN.value):
77120
req, run_result, input_entries = await _validate_run(
78121
beta_client=Beta(self.sdk_configuration),
79122
run_ctx=run_ctx,
80-
inputs=inputs,
123+
inputs=typing.cast(List[InputEntries], inputs),
81124
instructions=instructions,
82125
tools=tools,
83126
completion_args=completion_args,
@@ -105,26 +148,68 @@ async def run_async(
105148
res = await self.append_async(
106149
conversation_id=run_ctx.conversation_id,
107150
inputs=input_entries,
151+
tool_confirmations=pending_tool_confirmations,
108152
retries=retries,
109153
server_url=server_url,
110154
timeout_ms=timeout_ms,
155+
http_headers=http_headers,
111156
)
157+
# Clear after first use
158+
pending_tool_confirmations = None
112159
run_ctx.request_count += 1
113160
run_result.output_entries.extend(res.outputs)
114161
fcalls = get_function_calls(res.outputs)
115162
if not fcalls:
116163
logger.debug("No more function calls to execute")
117164
break
118-
fresults = await run_ctx.execute_function_calls(fcalls)
119-
run_result.output_entries.extend(fresults)
120-
input_entries = typing.cast(list[InputEntries], fresults)
165+
166+
# Partition by permission: include server-side deferred calls
167+
to_defer = [
168+
fc
169+
for fc in fcalls
170+
if run_ctx.requires_confirmation(fc.name) or _is_server_deferred(fc)
171+
]
172+
to_execute = [fc for fc in fcalls if fc not in to_defer]
173+
174+
# Execute approved
175+
fresults = []
176+
if to_execute:
177+
fresults = await run_ctx.execute_function_calls(to_execute)
178+
run_result.output_entries.extend(fresults)
179+
input_entries = typing.cast(list[InputEntries], fresults)
180+
181+
# Defer the rest - include executed_results so user can pass them back
182+
if to_defer:
183+
deferred_objects = [
184+
DeferredToolCallEntry(
185+
fc,
186+
reason=DeferralReason.SERVER_SIDE_CONFIRMATION_REQUIRED
187+
if _is_server_deferred(fc)
188+
else DeferralReason.CONFIRMATION_REQUIRED,
189+
)
190+
for fc in to_defer
191+
]
192+
raise DeferredToolCallsException(
193+
run_ctx.conversation_id,
194+
deferred_objects,
195+
run_result.output_entries,
196+
executed_results=fresults,
197+
)
198+
199+
# If we only executed tools (none deferred), continue the loop
200+
if not to_execute:
201+
break
121202
return run_result
122203

123204
@run_requirements
124205
async def run_stream_async(
125206
self,
126207
run_ctx: "RunContext",
127-
inputs: Union[models.ConversationInputs, models.ConversationInputsTypedDict],
208+
inputs: Union[
209+
models.ConversationInputs,
210+
models.ConversationInputsTypedDict,
211+
List[DeferredToolCallResponse],
212+
],
128213
instructions: OptionalNullable[str] = UNSET,
129214
tools: OptionalNullable[
130215
Union[
@@ -144,15 +229,39 @@ async def run_stream_async(
144229
) -> AsyncGenerator[Union[RunResultEvents, RunResult], None]:
145230
"""Similar to `run_async` but returns a generator which streams events.
146231
147-
The last streamed object is the RunResult object which summarises what happened in the run."""
232+
The last streamed object is the RunResult object which summarises what happened in the run.
233+
234+
Inputs can be:
235+
- Regular conversation inputs (messages, function results, etc.)
236+
- DeferredToolResponse objects (from deferred.confirm(), reject())
237+
"""
148238
from mistralai.client.beta import Beta # pylint: disable=import-outside-toplevel
149239
from mistralai.extra.run.context import _validate_run # pylint: disable=import-outside-toplevel
150240
from mistralai.extra.run.tools import get_function_calls # pylint: disable=import-outside-toplevel
151241

242+
# Check if inputs contain deferred responses - process them
243+
pending_tool_confirmations: Optional[List[models.ToolCallConfirmation]] = None
244+
if inputs and isinstance(inputs, list):
245+
deferred_inputs = typing.cast(
246+
List[DeferredToolCallResponse],
247+
[i for i in inputs if _is_deferred_response(i)],
248+
)
249+
other_inputs = typing.cast(
250+
List[InputEntries], [i for i in inputs if not _is_deferred_response(i)]
251+
)
252+
if deferred_inputs:
253+
(
254+
processed,
255+
pending_tool_confirmations,
256+
) = await _process_deferred_responses(run_ctx, deferred_inputs)
257+
inputs = other_inputs + processed
258+
if not pending_tool_confirmations:
259+
pending_tool_confirmations = None
260+
152261
req, run_result, input_entries = await _validate_run(
153262
beta_client=Beta(self.sdk_configuration),
154263
run_ctx=run_ctx,
155-
inputs=inputs,
264+
inputs=typing.cast(List[InputEntries], inputs),
156265
instructions=instructions,
157266
tools=tools,
158267
completion_args=completion_args,
@@ -161,6 +270,7 @@ async def run_stream_async(
161270
async def run_generator() -> (
162271
AsyncGenerator[Union[RunResultEvents, RunResult], None]
163272
):
273+
nonlocal pending_tool_confirmations
164274
current_entries = input_entries
165275
while True:
166276
received_event_tracker: defaultdict[
@@ -181,10 +291,13 @@ async def run_generator() -> (
181291
res = await self.append_stream_async(
182292
conversation_id=run_ctx.conversation_id,
183293
inputs=current_entries,
294+
tool_confirmations=pending_tool_confirmations,
184295
retries=retries,
185296
server_url=server_url,
186297
timeout_ms=timeout_ms,
187298
)
299+
# Clear after first use
300+
pending_tool_confirmations = None
188301
async for event in res:
189302
if (
190303
isinstance(event.data, ResponseStartedEvent)
@@ -207,18 +320,52 @@ async def run_generator() -> (
207320
if not fcalls:
208321
logger.debug("No more function calls to execute")
209322
break
210-
fresults = await run_ctx.execute_function_calls(fcalls)
211-
run_result.output_entries.extend(fresults)
212-
for fresult in fresults:
213-
yield RunResultEvents(
214-
event="function.result",
215-
data=FunctionResultEvent(
216-
type="function.result",
217-
result=fresult.result,
218-
tool_call_id=fresult.tool_call_id,
219-
),
323+
324+
# Partition by permission: include server-side deferred calls
325+
to_defer = [
326+
fc
327+
for fc in fcalls
328+
if run_ctx.requires_confirmation(fc.name) or _is_server_deferred(fc)
329+
]
330+
to_execute = [fc for fc in fcalls if fc not in to_defer]
331+
332+
# Execute approved
333+
fresults = []
334+
if to_execute:
335+
fresults = await run_ctx.execute_function_calls(to_execute)
336+
run_result.output_entries.extend(fresults)
337+
for fresult in fresults:
338+
yield RunResultEvents(
339+
event="function.result",
340+
data=FunctionResultEvent(
341+
type="function.result",
342+
result=fresult.result,
343+
tool_call_id=fresult.tool_call_id,
344+
),
345+
)
346+
current_entries = typing.cast(list[InputEntries], fresults)
347+
348+
# Defer the rest - include executed_results so user can pass them back
349+
if to_defer:
350+
deferred_objects = [
351+
DeferredToolCallEntry(
352+
fc,
353+
reason=DeferralReason.SERVER_SIDE_CONFIRMATION_REQUIRED
354+
if _is_server_deferred(fc)
355+
else DeferralReason.CONFIRMATION_REQUIRED,
356+
)
357+
for fc in to_defer
358+
]
359+
raise DeferredToolCallsException(
360+
run_ctx.conversation_id,
361+
deferred_objects,
362+
run_result.output_entries,
363+
executed_results=fresults,
220364
)
221-
current_entries = typing.cast(list[InputEntries], fresults)
365+
366+
# If we only executed tools (none deferred), continue the loop
367+
if not to_execute:
368+
break
222369
yield run_result
223370

224371
return run_generator()

0 commit comments

Comments
 (0)