93 changes: 88 additions & 5 deletions src/providers/anthropic_provider.py
@@ -3,10 +3,13 @@
from __future__ import annotations

import sys
from typing import Generator, Optional, Any
from typing import Generator, Optional, Any, TYPE_CHECKING

from .base import BaseProvider, ChatResponse, MessageInput, TextChunkCallback

if TYPE_CHECKING:
from src.utils.abort_controller import AbortSignal


# WI-4.4 (ch17 Phase 4): defer the ``import anthropic`` call. The SDK
# alone is ~150-200ms to import (verified by ``my-docs/profiler-baseline.md``:
@@ -248,6 +251,7 @@ def chat_stream_response(
messages: list[MessageInput],
tools: Optional[list[dict[str, Any]]] = None,
on_text_chunk: TextChunkCallback | None = None,
abort_signal: "AbortSignal | None" = None,
**kwargs
) -> ChatResponse:
"""Stream Anthropic text chunks and return the final structured response.
@@ -257,9 +261,26 @@ def chat_stream_response(
``CLAUDE_STREAM_IDLE_TIMEOUT_MS`` (default 90 s). On timeout the
iterator raises; we catch it and fall back to the non-streaming
``chat()`` path so the user gets an answer rather than a hung session.

ESC-cancellation: when ``abort_signal`` is provided, a listener is
registered that calls ``stream.response.close()`` when the signal
fires. The close interrupts the SDK's blocking socket read so the
``for text in stream.text_stream`` iterator raises immediately —
without it, ESC during a tool-use-only response (no intervening
text chunks for ``on_text_chunk`` to observe) waits for the model
to finish generating before the outer query loop can bail. We
translate the raise into ``AbortError`` so callers can distinguish
a user-initiated cancel from the watchdog's idle-timeout fallback.
"""
from src.utils.abort_controller import AbortError
from src.utils.stream_watchdog import StreamWatchdog

# Fast-path: if abort fired before we even build the request, don't
# spend the round-trip — raise directly so the caller's cancel
# boundary unwinds at the same place the mid-stream path lands.
if abort_signal is not None and abort_signal.aborted:
raise AbortError(abort_signal.reason or "user_interrupt")

model = self._get_model(**kwargs)
max_tokens = kwargs.get("max_tokens", 4096)
system = kwargs.pop("system", None)
@@ -294,6 +315,7 @@ def _fallback_to_chat() -> ChatResponse:
streamed_text = ""
watchdog_fired = False
final_message = None
abort_listener: Any = None
try:
with client.messages.stream(
model=model,
@@ -303,6 +325,44 @@ def _fallback_to_chat() -> ChatResponse:
**extra_kwargs,
**{k: v for k, v in kwargs.items() if k not in ["model", "max_tokens", "tools"]},
) as stream:
# Register the abort listener BEFORE the iterator pulls
# its first chunk, so a signal that fires between context
# entry and the first ``text_stream.__next__`` still wins
# the race. Mirrors ``StreamWatchdog``'s close pattern:
# close the underlying HTTP response from another thread,
# which raises in the consumer thread on the next pull.
if abort_signal is not None:
def _close_stream_on_abort() -> None:
try:
response = getattr(stream, "response", None)
if response is not None:
close = getattr(response, "close", None)
if callable(close):
close()
except Exception:
# Best-effort — never let the close
# propagate out of the listener thread.
pass

# Register-then-recheck (NOT check-then-register):
# the naive ordering has a sub-microsecond race
# where another thread can call ``_fire`` between
# our ``aborted`` read and the ``add_listener``
# append. ``_fire`` snapshots the listener list
# before iterating, so any listener appended after
# that snapshot is silently dropped.
# Register-then-recheck closes the gap: ``aborted``
# is sticky-True after ``_fire`` runs, so the
# post-add read catches any concurrent fire, and
# ``_close_stream_on_abort`` is idempotent so a
# double-call (listener fires AND we call directly)
# is harmless.
abort_listener = abort_signal.add_listener(
_close_stream_on_abort, once=True,
)
if abort_signal.aborted:
_close_stream_on_abort()

watchdog = StreamWatchdog(stream)
watchdog.arm()
try:
@@ -329,10 +389,22 @@ def _fallback_to_chat() -> ChatResponse:
watchdog_fired = watchdog.fired
watchdog.disarm()
except Exception as streaming_exc:
# WI-5.2 fallback path: stream interrupted. If our watchdog
# triggered the interruption, fall back to non-streaming so
# the user still gets an answer. If the failure is something
# else (network/auth/etc.), re-raise the original.
# Abort path: the abort listener closed the stream's response,
# which raised in the consumer thread. Translate to
# ``AbortError`` so the query loop's
# ``except AbortError: raise`` cancel boundary unwinds
# cleanly. We check the signal AFTER the catch (not the
# exception type) because the SDK can raise several different
# exception classes depending on which socket operation was
# in flight when we closed; the abort_signal state is the
# authoritative source of truth.
if abort_signal is not None and abort_signal.aborted:
raise AbortError(abort_signal.reason or "user_interrupt") from streaming_exc

# WI-5.2 fallback path: stream interrupted by the idle
# watchdog. Fall back to non-streaming so the user still
# gets an answer. If the failure is something else
# (network/auth/etc.), re-raise the original.
if watchdog_fired:
try:
return _fallback_to_chat()
Expand All @@ -344,6 +416,17 @@ def _fallback_to_chat() -> ChatResponse:
# error and re-raised only the streaming one.
raise fallback_exc from streaming_exc
raise
finally:
# Always detach the abort listener so it doesn't pin the
# provider alive past one call.
if abort_listener is not None and abort_signal is not None:
abort_signal.remove_listener(abort_listener)

# Stream completed normally but abort may have fired between
# ``stream.__exit__`` and here. Surface it now so the caller
# bails at the same place every other path does.
if abort_signal is not None and abort_signal.aborted:
raise AbortError(abort_signal.reason or "user_interrupt")

if watchdog_fired:
# Stream got interrupted but no exception escaped the
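The register-then-recheck ordering above is the load-bearing detail of this file's change. A minimal, self-contained sketch of the race and its fix; the AbortSignal here is an illustrative stand-in that mirrors the snapshot-on-fire behavior the comment describes, not the repo's actual src/utils/abort_controller:

from __future__ import annotations

import threading
from typing import Callable

class AbortSignal:
    """Illustrative stand-in: a sticky flag plus a listener list that
    _fire() snapshots before iterating (the source of the race)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.aborted = False
        self.reason: str | None = None
        self._listeners: list[Callable[[], None]] = []

    def add_listener(self, fn: Callable[[], None], once: bool = True) -> Callable[[], None]:
        # `once` accepted for signature parity with the real API;
        # this stand-in ignores it.
        with self._lock:
            self._listeners.append(fn)
        return fn  # handle for remove_listener()

    def remove_listener(self, handle: Callable[[], None]) -> None:
        with self._lock:
            if handle in self._listeners:
                self._listeners.remove(handle)

    def _fire(self, reason: str = "user_interrupt") -> None:
        with self._lock:
            self.aborted = True              # sticky-True BEFORE the snapshot
            self.reason = reason
            snapshot = list(self._listeners)
        for fn in snapshot:                  # listeners appended after this
            fn()                             # snapshot are silently dropped

def register_close(signal: AbortSignal, close_stream: Callable[[], None]) -> None:
    # Naive check-then-register: _fire() can run between the `aborted`
    # read and the append, and its snapshot misses our listener:
    #     if not signal.aborted:
    #         signal.add_listener(close_stream)   # fire already happened
    # Register-then-recheck closes the gap: `aborted` is sticky-True
    # after _fire(), so the post-add read observes a concurrent fire.
    # close_stream must be idempotent because both paths may run it.
    signal.add_listener(close_stream, once=True)
    if signal.aborted:
        close_stream()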
14 changes: 13 additions & 1 deletion src/providers/base.py
@@ -4,7 +4,10 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Generator, Optional, TypeAlias
from typing import Any, Callable, Generator, Optional, TYPE_CHECKING, TypeAlias

if TYPE_CHECKING:
from src.utils.abort_controller import AbortSignal


@dataclass
@@ -95,10 +98,19 @@ def chat_stream_response(
messages: list[MessageInput],
tools: Optional[list[dict[str, Any]]] = None,
on_text_chunk: TextChunkCallback | None = None,
abort_signal: "AbortSignal | None" = None,
**kwargs
) -> ChatResponse:
"""Stream a response while also returning the final structured ChatResponse.

When ``abort_signal`` is provided, a provider implementation should
register a listener on it that forcibly closes the underlying HTTP
stream when the signal fires. Without this, a tripped abort can only
be observed between chunks via ``on_text_chunk`` — which never fires
for a turn that emits tool_use blocks without intervening text, so
ESC ends up waiting for the model to finish generating before the
outer query loop can bail.

Providers may override this to support tool-aware streaming. The default
implementation signals that rich streamed responses are unavailable.
"""
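The contract above leans on one property of the HTTP layer: closing the response object from another thread makes the blocking socket read raise in the consumer thread on its next pull. A sketch of the consumer-loop shape a conforming provider takes, reusing the AbortSignal stand-in sketched earlier (the stream.response and stream.text_stream names follow the Anthropic SDK usage in this PR; treat them as assumptions for any other SDK):

class AbortError(RuntimeError):
    """User-initiated cancellation (illustrative stand-in for
    src.utils.abort_controller.AbortError)."""

def consume_stream(stream, abort_signal, on_text_chunk=None) -> str:
    """Consumer-loop shape: the listener closes the response from the
    abort thread; the iterator then raises here, in this thread."""
    def _close() -> None:
        try:
            stream.response.close()      # interrupts the blocking read
        except Exception:
            pass                         # best-effort; never propagate

    handle = abort_signal.add_listener(_close, once=True)
    if abort_signal.aborted:             # register-then-recheck, as above
        _close()
    try:
        text = ""
        for chunk in stream.text_stream:
            text += chunk
            if on_text_chunk is not None:
                on_text_chunk(chunk)
        return text
    except Exception as exc:
        # Signal state, not exception type, decides: the SDK can raise
        # different classes depending on which socket op was in flight.
        if abort_signal.aborted:
            raise AbortError(abort_signal.reason or "user_interrupt") from exc
        raise
    finally:
        abort_signal.remove_listener(handle)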
10 changes: 10 additions & 0 deletions src/providers/minimax_provider.py
Expand Up @@ -165,8 +165,18 @@ def chat_stream_response(
messages: list[MessageInput],
tools: Optional[list[dict[str, Any]]] = None,
on_text_chunk: TextChunkCallback | None = None,
abort_signal: Any = None,
**kwargs
) -> ChatResponse:
# Pre-call fast-path: matches AnthropicProvider. A signal that
# tripped at a turn boundary skips the API round-trip entirely.
# Mid-stream cancellation isn't implemented yet — that needs the
# same response-close listener pattern AnthropicProvider uses,
# which the Minimax/anthropic-compatible SDK should support
# (it's the same underlying ``anthropic`` package) — future PR.
if abort_signal is not None and getattr(abort_signal, "aborted", False):
from src.utils.abort_controller import AbortError
raise AbortError(getattr(abort_signal, "reason", None) or "user_interrupt")
model = self._get_model(**kwargs)
max_tokens = kwargs.get("max_tokens", 4096)
system = kwargs.pop("system", None)
9 changes: 9 additions & 0 deletions src/providers/openai_compatible.py
Expand Up @@ -325,9 +325,18 @@ def chat_stream_response(
messages: list[MessageInput],
tools: Optional[list[dict[str, Any]]] = None,
on_text_chunk: TextChunkCallback | None = None,
abort_signal: Any = None,
**kwargs
) -> ChatResponse:
"""Stream OpenAI-compatible chunks while rebuilding the final response."""
# Pre-call fast-path: matches AnthropicProvider. A signal that
# tripped at a turn boundary skips the API round-trip entirely.
# Mid-stream cancellation isn't implemented yet — that needs a
# response-close listener around the OpenAI SDK's stream
# iterator — future PR.
if abort_signal is not None and getattr(abort_signal, "aborted", False):
from src.utils.abort_controller import AbortError
raise AbortError(getattr(abort_signal, "reason", None) or "user_interrupt")
model = self._get_model(**kwargs)
provider_messages = self._prepare_messages(messages)

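The minimax and openai_compatible fast-paths are the same duck-typed guard twice over. A tiny shared helper would keep the call sites in sync if a third provider grows the parameter; a hypothetical refactor, not part of this PR:

from typing import Any

def raise_if_aborted(abort_signal: Any) -> None:
    """Pre-call guard for providers without mid-stream cancellation:
    skip the API round-trip when the signal already tripped at a
    turn boundary."""
    if abort_signal is not None and getattr(abort_signal, "aborted", False):
        from src.utils.abort_controller import AbortError
        raise AbortError(getattr(abort_signal, "reason", None) or "user_interrupt")

Each provider's chat_stream_response would then open with a single raise_if_aborted(abort_signal) call.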
31 changes: 29 additions & 2 deletions src/query/query.py
@@ -21,7 +21,7 @@
from ..tool_system.context import ToolContext
from ..tool_system.protocol import ToolCall, ToolResult
from ..tool_system.registry import ToolRegistry
from ..utils.abort_controller import AbortController
from ..utils.abort_controller import AbortController, AbortError
from ..providers.base import BaseProvider, ChatResponse

from .config import QueryConfig, build_query_config
@@ -282,6 +282,7 @@ async def _call_model_sync(
system_prompt: str,
tools: Tools,
max_output_tokens_override: int | None = None,
abort_signal: Any = None,
) -> tuple[list[AssistantMessage], list[ToolUseBlock]]:
from ..types.messages import normalize_messages_for_api

@@ -383,11 +384,27 @@ async def _call_model_sync(
logger.warning("[DIAG] _call_model_sync: calling provider (streaming)...")
try:
try:
response = provider.chat_stream_response(api_messages, **call_kwargs)
# ``abort_signal`` reaches the provider so a tripped controller
# can close the streaming HTTP response immediately rather than
# waiting for the model to finish generating. Without this
# plumbing, ESC during a tool-use-only response (no intermediate
# text chunks for an ``on_text_chunk`` to observe) waits the
# full model latency before the outer query loop bails.
response = provider.chat_stream_response(
api_messages, abort_signal=abort_signal, **call_kwargs,
)
except (NotImplementedError, AttributeError):
if _diag:
logger.warning("[DIAG] _call_model_sync: streaming not supported, falling back to chat()")
response = provider.chat(api_messages, **call_kwargs)
except AbortError:
# User-initiated cancel — propagate so the query loop's
# ``except AbortError: pass`` boundary unwinds to the
# post-API abort-check block. We do NOT route this through
# the error-message classification below: a future addition
# to those substring checks could accidentally match an abort
# reason and convert the cancel into a model-error reply.
raise
except Exception as e:
if _diag:
logger.warning("[DIAG] _call_model_sync: EXCEPTION after %.1fs: %s", time.monotonic() - _t0, e)
@@ -866,6 +883,7 @@ async def query(
system_prompt=params.system_prompt,
tools=params.tools,
max_output_tokens_override=max_output_tokens_override,
abort_signal=params.abort_controller.signal,
)
assistant_messages = returned_assistants
tool_use_blocks = returned_tool_blocks
@@ -886,6 +904,15 @@
if not withheld:
yield msg

except AbortError:
# The provider's abort listener closed the streaming HTTP
# response mid-flight (ESC pressed while the model was still
# generating). The signal is already tripped, so let the
# ``if params.abort_controller.signal.aborted`` block right
# below us do the cancellation processing in exactly one
# place — anything we did here would duplicate that work.
pass

except Exception as e:
logger.error("Query error: %s", e)
error_message = str(e)
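Taken end to end, the plumbing reads: a key handler trips the controller from its own thread, the provider's abort listener closes the streaming response, and exactly one post-API block performs the cancellation processing. A compressed sketch of that wiring; AbortController.abort() and .signal are inferred from this diff's usage rather than checked against the module, and the Timer stands in for the real ESC handler:

import threading

from src.utils.abort_controller import AbortController, AbortError

def run_turn(provider, messages):
    controller = AbortController()
    # Stand-in for the ESC key handler: trip the controller 2 s in,
    # from another thread, while the model is still generating.
    # (.abort() assumed per the Web AbortController convention.)
    threading.Timer(2.0, lambda: controller.abort("user_interrupt")).start()

    try:
        response = provider.chat_stream_response(
            messages, abort_signal=controller.signal,
        )
    except AbortError:
        response = None  # unwind; the single check below handles it

    if controller.signal.aborted:
        # One place performs cancellation processing, regardless of
        # which path raised (pre-call, mid-stream, or post-stream).
        return None
    return response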
9 changes: 9 additions & 0 deletions src/tool_system/agent_loop.py
@@ -147,16 +147,24 @@
call_kwargs: dict[str, Any],
stream: bool,
on_text_chunk: TextChunkHandler | None,
cancel_signal: AbortSignal | None = None,
) -> tuple[Any, bool]:
"""Call the provider, preferring structured streaming when available.

Returns (response, streamed_live_text).

``cancel_signal`` is forwarded to the provider so a tripped signal can
close the streaming HTTP response immediately rather than waiting for
the model to finish generating. Without this plumbing, ESC during a
tool-use-only turn (no intervening text chunks for ``on_text_chunk``
to observe) waits the full model latency before the agent loop bails.
"""
if stream:
try:
response = provider.chat_stream_response(
api_messages,
on_text_chunk=on_text_chunk,
abort_signal=cancel_signal,
**call_kwargs,
)
if not isinstance(response, ChatResponse):
@@ -344,6 +352,7 @@ def _check_cancel() -> None:
call_kwargs=call_kwargs,
stream=stream,
on_text_chunk=on_text_chunk,
cancel_signal=cancel_signal,
)
turn_count += 1
