Skip to content

Commit cca51d4

Browse files
author
Your Name
committed
refactor: Improve interrupt handling with interruptible wrapper
Co-authored-by: cecli (openai/gemini_cli_local/gemini-2.5-pro)
1 parent 4c436a3 commit cca51d4

3 files changed

Lines changed: 52 additions & 34 deletions

File tree

cecli/coders/agent_coder.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from cecli.helpers import nested, responses
1717
from cecli.helpers.background_commands import BackgroundCommandManager
1818
from cecli.helpers.conversation import ConversationService, MessageTag
19+
from cecli.helpers.coroutines import interruptible
1920
from cecli.helpers.similarity import (
2021
cosine_similarity,
2122
create_bigram_vector,
@@ -301,25 +302,15 @@ async def _execute_local_tool_calls(self, tool_calls_list):
301302
else:
302303
all_results_content.append(f"Error: Unknown tool name '{tool_name}'")
303304
if tasks:
304-
gather_future = asyncio.gather(*tasks, return_exceptions=True)
305-
interrupt_task = asyncio.create_task(self.interrupt_event.wait())
306-
307-
done, pending = await asyncio.wait(
308-
{gather_future, interrupt_task},
309-
return_when=asyncio.FIRST_COMPLETED,
305+
gather_coro = asyncio.gather(*tasks, return_exceptions=True)
306+
task_results, interrupted = await interruptible(
307+
gather_coro, self.interrupt_event
310308
)
311309

312-
if interrupt_task in done:
313-
gather_future.cancel()
314-
try:
315-
await gather_future
316-
except asyncio.CancelledError:
317-
pass
310+
if interrupted:
318311
self.io.tool_warning("Tool execution interrupted.")
319-
# Append a message indicating interruption
320312
all_results_content.append("Tool execution interrupted by user.")
321-
else:
322-
task_results = gather_future.result()
313+
elif task_results:
323314
for res in task_results:
324315
if isinstance(res, Exception):
325316
all_results_content.append(f"Error in tool execution: {res}")
@@ -415,24 +406,11 @@ async def _exec_async():
415406
""")
416407
return f"Error executing tool call {tool_name}: {e}"
417408

418-
exec_future = asyncio.create_task(_exec_async())
419-
interrupt_task = asyncio.create_task(self.interrupt_event.wait())
420-
421-
done, pending = await asyncio.wait(
422-
{exec_future, interrupt_task},
423-
return_when=asyncio.FIRST_COMPLETED,
424-
)
409+
result, interrupted = await interruptible(_exec_async(), self.interrupt_event)
425410

426-
if interrupt_task in done:
427-
exec_future.cancel()
428-
try:
429-
await exec_future
430-
except asyncio.CancelledError:
431-
pass
411+
if interrupted:
432412
return "Tool execution interrupted by user."
433-
else:
434-
interrupt_task.cancel()
435-
return await exec_future
413+
return result
436414

437415
def _calculate_context_block_tokens(self, force=False):
438416
"""

cecli/coders/base_coder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2722,6 +2722,9 @@ async def _execute_mcp_tools(self, server, tool_calls):
27222722
tool_responses.append(
27232723
{"role": "tool", "tool_call_id": tool_call.id, "content": connection_error}
27242724
)
2725+
except asyncio.CancelledError:
2726+
# Re-raise CancelledError to ensure the task cancellation propagates
2727+
raise
27252728
except Exception as e:
27262729
connection_error = f"Could not connect to server {server.name}\n{e}"
27272730
self.io.tool_warning(connection_error)

cecli/helpers/coroutines.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,45 @@
1-
import asyncio # noqa: F401
1+
import asyncio
22

33

4-
def is_active(coroutine):
5-
if not coroutine or coroutine.done() or coroutine.cancelled():
4+
def is_active(task):
5+
if not task or task.done() or task.cancelled():
66
return False
77

88
return True
9+
10+
11+
async def interruptible(coroutine, interrupt_event):
12+
"""
13+
Runs a coroutine and allows it to be interrupted by an asyncio.Event.
14+
15+
Args:
16+
coroutine: The coroutine to run.
17+
interrupt_event: The asyncio.Event that signals an interruption.
18+
19+
Returns:
20+
A tuple of (result, interrupted).
21+
- If not interrupted: (coroutine_result, False)
22+
- If interrupted: (None, True)
23+
"""
24+
main_task = asyncio.create_task(coroutine)
25+
interrupt_task = asyncio.create_task(interrupt_event.wait())
26+
27+
done, pending = await asyncio.wait(
28+
{main_task, interrupt_task},
29+
return_when=asyncio.FIRST_COMPLETED,
30+
)
31+
32+
for task in pending:
33+
task.cancel()
34+
try:
35+
await task
36+
except asyncio.CancelledError:
37+
pass # Expected
38+
39+
if interrupt_task in done:
40+
return None, True
41+
42+
try:
43+
return main_task.result(), False
44+
except asyncio.CancelledError:
45+
return None, True

0 commit comments

Comments
 (0)