From bcceaca46f874de7e91f961fb5c3cda29345f4b3 Mon Sep 17 00:00:00 2001 From: Eric Lee Date: Fri, 15 May 2026 10:32:44 -0700 Subject: [PATCH] refactor(providers): extract StreamAbortGuard helper for streaming cancel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PRs #144 and #145 added abort-signal-aware streaming for three providers (Anthropic, Minimax, OpenAI-compatible). The pattern across all three is structurally identical — pre-call fast-path, register- then-recheck listener closing the stream's response on abort, signal- state-authoritative exception translation, listener detachment, post- stream recheck — but the bookkeeping was inlined and triplicated. Adding a fourth provider would have meant a fourth copy. Extract the pattern into ``src/providers/_stream_abort.py``: * ``StreamAbortGuard(abort_signal)`` — one per call. ``abort_signal=None`` makes every method a no-op so providers can use it unconditionally. * ``raise_if_pre_aborted()`` — pre-call fast-path raising ``AbortError``. * ``raise_if_post_aborted()`` — same check at a different boundary. * ``aborted`` property — cheap in-loop check (OpenAI-compat path). * ``reraise_if_aborted(original_exc)`` — translate SDK exceptions to ``AbortError`` via ``raise ... from``; no-op if signal didn't fire. * ``with guard.attach(stream):`` — register a listener that closes ``stream.response`` on abort; detach in ``__exit__``. ``__exit__`` ALSO closes the response if ``signal.aborted`` is True at exit, as a safety net for the race window where ``AbortSignal._fire`` snapshots the listener list, the consumer thread observes ``aborted=True``, breaks out, detaches the listener, and the snapshot's iteration then runs against an empty list — that path would otherwise leak the underlying httpx response open. The close is idempotent on httpx (``if not self.is_closed`` guard) so the double-fire path is safe. Provider line counts: anthropic_provider.py -88 lines minimax_provider.py -73 lines openai_compatible.py -212 lines _stream_abort.py +211 (mostly docstring) Net: providers shrink from 374 lines of abort bookkeeping to 110 (SDK-specific iteration shape only). The provider-level tests from #144 and #145 (15 tests across three files) pass unmodified, proving behavior preservation end-to-end. Seventeen new unit tests in ``tests/test_stream_abort_guard.py`` pin the helper's contract directly: pre/post fast-paths, ``aborted`` property semantics, ``attach`` lifecycle (register, detach, race- recovery via register-then-recheck), close-failure tolerance, no- response-attribute graceful degradation, ``reraise_if_aborted`` translation + cause chaining + no-op, and — critically — the ``__exit__`` close-on-abort safety net. The safety-net test mutation-verified: removing the ``__exit__`` close branch makes the test fail with the exact expected assertion. The OpenAI-compat in-loop test gains a ``stream.response.close.called`` assertion to pin the close at the provider level too. Co-Authored-By: Claude Opus 4.7 --- src/providers/_stream_abort.py | 245 ++++++++++++++++++++ src/providers/anthropic_provider.py | 88 ++----- src/providers/minimax_provider.py | 73 ++---- src/providers/openai_compatible.py | 212 ++++++----------- tests/test_openai_compat_abort_signal.py | 7 + tests/test_stream_abort_guard.py | 280 +++++++++++++++++++++++ 6 files changed, 642 insertions(+), 263 deletions(-) create mode 100644 src/providers/_stream_abort.py create mode 100644 tests/test_stream_abort_guard.py diff --git a/src/providers/_stream_abort.py b/src/providers/_stream_abort.py new file mode 100644 index 0000000..3104176 --- /dev/null +++ b/src/providers/_stream_abort.py @@ -0,0 +1,245 @@ +"""Shared abort-signal helper for provider streaming paths. + +Three providers (Anthropic, Minimax, OpenAI-compatible) all need the +same pattern to make ESC interrupt a streaming HTTP read mid-flight: + +1. Pre-call fast-path so a tripped signal at a turn boundary skips + the API round-trip entirely. +2. Register a listener on the abort signal that calls + ``stream.response.close()``. The close interrupts the SDK's + blocking next-chunk read, which raises in the consumer thread. +3. Race-safe ordering: register-then-recheck closes the + sub-microsecond window where ``AbortSignal._fire`` could snapshot + the listener list and silently drop a freshly-appended listener. +4. Signal-state-authoritative exception translation: the SDK / httpx + layer can raise several different exception classes depending on + which syscall was in flight when the response closed, so + ``signal.aborted`` is the only stable abort indicator. +5. Cleanup: detach the listener in ``finally`` so long-lived + controllers (the REPL engine's, reused across many turns) don't + accumulate dead listeners against gone streams. +6. Post-stream recheck: catch a signal that fires after the iterator + exits naturally but before we return. + +This module factors the bookkeeping into ``StreamAbortGuard`` so each +provider only owns the SDK-specific iteration shape (Anthropic's +``stream.text_stream``, OpenAI's bare ``for chunk in stream``, +Minimax's ``with``-block + ``get_final_message``). Adding a new +provider becomes: build a ``StreamAbortGuard(abort_signal)``, call +``raise_if_pre_aborted()`` before the API request, wrap the SDK's +stream object in ``with guard.attach(stream):``, and translate +exceptions via ``guard.reraise_if_aborted(exc)`` in the ``except`` +block. The provider keeps full control over fallbacks (e.g. +Anthropic's ``StreamWatchdog`` non-streaming recovery) — the guard +just owns the listener lifecycle. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from src.utils.abort_controller import AbortError + +if TYPE_CHECKING: + from src.utils.abort_controller import AbortSignal + + +__all__ = ["StreamAbortGuard"] + + +def _close_response_safely(stream: Any) -> None: + """Best-effort close of ``stream.response`` — never raises. + + Both the Anthropic SDK (``client.messages.stream``'s + ``MessageStream``) and the OpenAI SDK (``Stream`` from + ``client.chat.completions.create(stream=True)``) expose the + underlying httpx ``Response`` as ``stream.response``. Close is + idempotent on httpx (``if not self.is_closed`` guard), so a + double-close (e.g., listener fires AND the post-loop path also + closes) is harmless. + + Failures in the listener thread must not propagate — the close is + purely defensive; the next-chunk read will eventually fail by + other means (timeout, server-side disconnect) even if the close + is a no-op. + """ + try: + response = getattr(stream, "response", None) + if response is not None: + close = getattr(response, "close", None) + if callable(close): + close() + except Exception: + pass + + +class StreamAbortGuard: + """Provider-side coordinator for abort-aware streaming. + + A single instance is built per ``chat_stream_response`` call; the + same instance handles the pre-call fast-path, listener lifecycle + around the SDK iteration, and post-stream recheck. When + ``abort_signal`` is ``None``, every method is a no-op — providers + can use the guard unconditionally without branching on the + presence of an abort signal. + + Usage pattern:: + + guard = StreamAbortGuard(abort_signal) + guard.raise_if_pre_aborted() # before API request + + with client.messages.stream(...) as stream: # SDK-specific + with guard.attach(stream): + try: + for chunk in stream.text_stream: # SDK-specific + if guard.aborted: # optional in-loop check + break + ... + except Exception as exc: + guard.reraise_if_aborted(exc) + raise # non-abort exception + + guard.raise_if_post_aborted() # signal may have fired after + # the with-block exited normally + + The guard does NOT own the SDK's stream lifecycle — the provider + keeps its own ``with`` / ``try`` / ``finally`` around the stream + object, so provider-specific recovery (e.g., the Anthropic + watchdog's non-streaming fallback) and provider-specific cleanup + (e.g., ``stream.get_final_message()``) stay where they belong. + """ + + __slots__ = ("_signal",) + + def __init__(self, abort_signal: "AbortSignal | None") -> None: + self._signal = abort_signal + + @property + def aborted(self) -> bool: + """True when the signal has fired. ``False`` when no signal was provided. + + Cheaper than calling ``raise_if_aborted`` in a hot loop — the + in-loop check inside ``for chunk in stream:`` uses this so it + doesn't pay the exception construction cost on every chunk. + """ + return self._signal is not None and self._signal.aborted + + def raise_if_pre_aborted(self) -> None: + """Raise ``AbortError`` if the signal was already tripped at call entry. + + Called BEFORE the API round-trip so a signal that tripped at a + turn boundary doesn't pay the request cost. Identical shape + across every provider. + """ + if self._signal is not None and self._signal.aborted: + raise AbortError(self._signal.reason or "user_interrupt") + + def raise_if_post_aborted(self) -> None: + """Raise ``AbortError`` if the signal tripped after stream exit. + + Catches the window between ``stream.__exit__`` (or the + iterator's natural exhaustion) and the provider's return. + Same shape as the pre-aborted check, called from a different + boundary. + """ + self.raise_if_pre_aborted() + + def reraise_if_aborted(self, original_exc: BaseException) -> None: + """If the signal aborted, translate the SDK exception to ``AbortError``. + + The provider catches ``Exception`` around the streaming + iterator and asks the guard whether the exception was caused + by the user pressing ESC. We check the signal state (not the + exception class) because different SDK versions raise + different classes when the underlying response is closed + mid-read — ``httpx.ReadError``, ``httpx.RemoteProtocolError``, + ``OSError``, ``BrokenPipeError``, or wrapped variants + depending on which syscall was in flight. + + If the signal isn't aborted (genuine network error, auth + failure, etc.), this is a no-op so the provider's ``raise`` + statement runs and the real error propagates with its + original class intact. + """ + if self._signal is not None and self._signal.aborted: + raise AbortError(self._signal.reason or "user_interrupt") from original_exc + + def attach(self, stream: Any) -> "_StreamAbortContext": + """Register a close-on-abort listener for ``stream``'s lifetime. + + Returns a context manager. While active, a tripped signal + synchronously calls ``stream.response.close()`` from whichever + thread fires the abort (TUI keypress thread, headless SIGINT + handler, etc.). On context exit the listener is detached so a + long-lived ``AbortController`` doesn't accumulate dead + listeners pointing at gone streams. + + When ``abort_signal`` is ``None`` this is a no-op context. + """ + return _StreamAbortContext(self._signal, stream) + + +class _StreamAbortContext: + """Context manager that owns one close-on-abort listener. + + Not part of the public API — callers construct via + ``StreamAbortGuard.attach(stream)``. + + Registration ordering: register-then-recheck. The naive + "if aborted: close else: add_listener" sequence has a + sub-microsecond race where another thread can call ``_fire`` + between the ``aborted`` read and the ``add_listener`` append; + ``_fire`` snapshots the listener list before iterating, so a + listener appended after the snapshot is silently dropped. + Register-then-recheck closes the gap: ``aborted`` is sticky-True + after ``_fire`` runs, so the post-add read catches any concurrent + fire, and the close callback is idempotent. + """ + + __slots__ = ("_signal", "_stream", "_listener") + + def __init__(self, signal: "AbortSignal | None", stream: Any) -> None: + self._signal = signal + self._stream = stream + self._listener: Any = None + + def __enter__(self) -> "_StreamAbortContext": + if self._signal is None: + return self + + stream = self._stream + + def _close() -> None: + _close_response_safely(stream) + + # Register, then re-check. See the docstring above for the + # race analysis. + self._listener = self._signal.add_listener(_close, once=True) + if self._signal.aborted: + _close() + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + # Close-on-abort guarantee. The listener-firing path closes + # the response synchronously from whichever thread tripped + # the abort — but ``AbortSignal._fire`` snapshots the + # listener list BEFORE iterating, so a narrow race exists: + # if the consumer thread observes ``guard.aborted == True``, + # breaks out, and runs this ``__exit__`` (which detaches the + # listener) before the abort thread reaches the snapshot's + # firing iteration, the listener is silently dropped and the + # underlying httpx response leaks open. To close the gap we + # do one more idempotent close here whenever ``aborted`` is + # True at exit — covers both the in-loop-break path (OpenAI + # provider) and any future provider that exits the attach + # context after observing the abort without raising. + if self._signal is not None and self._signal.aborted: + _close_response_safely(self._stream) + if self._listener is not None and self._signal is not None: + try: + self._signal.remove_listener(self._listener) + except Exception: + pass + # Never suppress exceptions — the provider's surrounding + # try/except is where exception translation happens via + # ``StreamAbortGuard.reraise_if_aborted``. + return False diff --git a/src/providers/anthropic_provider.py b/src/providers/anthropic_provider.py index 128732a..a54d221 100644 --- a/src/providers/anthropic_provider.py +++ b/src/providers/anthropic_provider.py @@ -272,14 +272,15 @@ def chat_stream_response( translate the raise into ``AbortError`` so callers can distinguish a user-initiated cancel from the watchdog's idle-timeout fallback. """ - from src.utils.abort_controller import AbortError from src.utils.stream_watchdog import StreamWatchdog - # Fast-path: if abort fired before we even build the request, don't - # spend the round-trip — raise directly so the caller's cancel - # boundary unwinds at the same place the mid-stream path lands. - if abort_signal is not None and abort_signal.aborted: - raise AbortError(abort_signal.reason or "user_interrupt") + from ._stream_abort import StreamAbortGuard + + guard = StreamAbortGuard(abort_signal) + # Fast-path: if abort fired before we even build the request, + # raise directly so the caller's cancel boundary unwinds at + # the same place the mid-stream path lands. + guard.raise_if_pre_aborted() model = self._get_model(**kwargs) max_tokens = kwargs.get("max_tokens", 4096) @@ -315,7 +316,6 @@ def _fallback_to_chat() -> ChatResponse: streamed_text = "" watchdog_fired = False final_message = None - abort_listener: Any = None try: with client.messages.stream( model=model, @@ -324,45 +324,12 @@ def _fallback_to_chat() -> ChatResponse: **({"system": system} if system else {}), **extra_kwargs, **{k: v for k, v in kwargs.items() if k not in ["model", "max_tokens", "tools"]}, - ) as stream: - # Register the abort listener BEFORE the iterator pulls - # its first chunk, so a signal that fires between context - # entry and the first ``text_stream.__next__`` still wins - # the race. Mirrors ``StreamWatchdog``'s close pattern: - # close the underlying HTTP response from another thread, - # which raises in the consumer thread on the next pull. - if abort_signal is not None: - def _close_stream_on_abort() -> None: - try: - response = getattr(stream, "response", None) - if response is not None: - close = getattr(response, "close", None) - if callable(close): - close() - except Exception: - # Best-effort — never let the close - # propagate out of the listener thread. - pass - - # Register-then-recheck (NOT check-then-register): - # the naive ordering has a sub-microsecond race - # where another thread can call ``_fire`` between - # our ``aborted`` read and the ``add_listener`` - # append. ``_fire`` snapshots the listener list - # before iterating, so any listener appended after - # that snapshot is silently dropped. - # Register-then-recheck closes the gap: ``aborted`` - # is sticky-True after ``_fire`` runs, so the - # post-add read catches any concurrent fire, and - # ``_close_stream_on_abort`` is idempotent so a - # double-call (listener fires AND we call directly) - # is harmless. - abort_listener = abort_signal.add_listener( - _close_stream_on_abort, once=True, - ) - if abort_signal.aborted: - _close_stream_on_abort() - + ) as stream, guard.attach(stream): + # ``guard.attach`` registered the close-on-abort listener + # (see ``_stream_abort.py`` for the race-safe ordering + # and the close-via-stream.response.close mechanism). + # The provider keeps the watchdog and fallback logic + # local: they aren't abort-related. watchdog = StreamWatchdog(stream) watchdog.arm() try: @@ -389,17 +356,12 @@ def _close_stream_on_abort() -> None: watchdog_fired = watchdog.fired watchdog.disarm() except Exception as streaming_exc: - # Abort path: the abort listener closed the stream's response, - # which raised in the consumer thread. Translate to - # ``AbortError`` so the query loop's - # ``except AbortError: raise`` cancel boundary unwinds - # cleanly. We check the signal AFTER the catch (not the - # exception type) because the SDK can raise several different - # exception classes depending on which socket operation was - # in flight when we closed; the abort_signal state is the - # authoritative source of truth. - if abort_signal is not None and abort_signal.aborted: - raise AbortError(abort_signal.reason or "user_interrupt") from streaming_exc + # Abort path FIRST: a user cancel must win over the + # watchdog fallback (the abort listener may also have + # tripped the watchdog's race, so we'd otherwise route a + # user cancel through non-streaming recovery and burn + # another round-trip). + guard.reraise_if_aborted(streaming_exc) # WI-5.2 fallback path: stream interrupted by the idle # watchdog. Fall back to non-streaming so the user still @@ -416,17 +378,11 @@ def _close_stream_on_abort() -> None: # error and re-raised only the streaming one. raise fallback_exc from streaming_exc raise - finally: - # Always detach the abort listener so it doesn't pin the - # provider alive past one call. - if abort_listener is not None and abort_signal is not None: - abort_signal.remove_listener(abort_listener) # Stream completed normally but abort may have fired between - # ``stream.__exit__`` and here. Surface it now so the caller - # bails at the same place every other path does. - if abort_signal is not None and abort_signal.aborted: - raise AbortError(abort_signal.reason or "user_interrupt") + # ``stream.__exit__`` and here. Surface it at the same boundary + # the mid-stream path uses. + guard.raise_if_post_aborted() if watchdog_fired: # Stream got interrupted but no exception escaped the diff --git a/src/providers/minimax_provider.py b/src/providers/minimax_provider.py index def28e5..03e7d47 100644 --- a/src/providers/minimax_provider.py +++ b/src/providers/minimax_provider.py @@ -170,20 +170,18 @@ def chat_stream_response( ) -> ChatResponse: """Stream Minimax response with abort-signal-aware cancellation. - Minimax wraps the anthropic SDK against its compatible endpoint, - so the response-close listener pattern AnthropicProvider uses - works here too. Same contract: pre-call fast-path, register- - then-recheck listener that closes the underlying HTTP response, - signal-state-authoritative abort detection in the exception - handler, post-with-block recheck, ``finally`` detaches the - listener. + Minimax wraps the anthropic SDK against its compatible + endpoint, so the same response-close listener pattern + AnthropicProvider uses works here too. The bookkeeping lives + in ``StreamAbortGuard``; this provider only owns the + SDK-specific iteration shape (``with client.messages.stream`` + + ``stream.text_stream`` + ``get_final_message``). """ - from src.utils.abort_controller import AbortError + from ._stream_abort import StreamAbortGuard + + guard = StreamAbortGuard(abort_signal) + guard.raise_if_pre_aborted() - # Pre-call fast-path: matches AnthropicProvider. A signal that - # tripped at a turn boundary skips the API round-trip entirely. - if abort_signal is not None and getattr(abort_signal, "aborted", False): - raise AbortError(getattr(abort_signal, "reason", None) or "user_interrupt") model = self._get_model(**kwargs) max_tokens = kwargs.get("max_tokens", 4096) system = kwargs.pop("system", None) @@ -196,7 +194,6 @@ def chat_stream_response( streamed_text = "" final_message: Any = None - abort_listener: Any = None try: with client.messages.stream( model=model, @@ -205,29 +202,7 @@ def chat_stream_response( **({"system": system} if system else {}), **extra_kwargs, **{k: v for k, v in kwargs.items() if k not in ["model", "max_tokens", "tools"]}, - ) as stream: - if abort_signal is not None: - def _close_stream_on_abort() -> None: - try: - response = getattr(stream, "response", None) - if response is not None: - close = getattr(response, "close", None) - if callable(close): - close() - except Exception: - pass - - # Register-then-recheck: see AnthropicProvider for the - # full race analysis. ``_fire`` snapshots the listener - # list before iterating, so a listener appended after - # the snapshot is silently dropped; the post-add - # ``aborted`` read closes the gap. - abort_listener = abort_signal.add_listener( - _close_stream_on_abort, once=True, - ) - if abort_signal.aborted: - _close_stream_on_abort() - + ) as stream, guard.attach(stream): for text in stream.text_stream: if not text: continue @@ -239,28 +214,12 @@ def _close_stream_on_abort() -> None: except Exception: final_message = None except Exception as streaming_exc: - # Abort path: signal state is authoritative — different SDK - # versions raise different exception types when the response - # is closed mid-read. - if abort_signal is not None and getattr(abort_signal, "aborted", False): - raise AbortError( - getattr(abort_signal, "reason", None) or "user_interrupt" - ) from streaming_exc + guard.reraise_if_aborted(streaming_exc) raise - finally: - if abort_listener is not None and abort_signal is not None: - try: - abort_signal.remove_listener(abort_listener) - except Exception: - pass - - # Stream completed normally but abort may have fired between - # ``__exit__`` and here. Surface it at the same boundary every - # other path uses. - if abort_signal is not None and getattr(abort_signal, "aborted", False): - raise AbortError( - getattr(abort_signal, "reason", None) or "user_interrupt" - ) + + # Stream exited normally but abort may have fired between + # ``__exit__`` and here. + guard.raise_if_post_aborted() if final_message is not None: return self._build_chat_response(final_message) diff --git a/src/providers/openai_compatible.py b/src/providers/openai_compatible.py index 208392a..fcfe8ed 100644 --- a/src/providers/openai_compatible.py +++ b/src/providers/openai_compatible.py @@ -330,33 +330,19 @@ def chat_stream_response( ) -> ChatResponse: """Stream OpenAI-compatible chunks while rebuilding the final response. - ESC-cancellation: when ``abort_signal`` is provided, two defenses - cooperate so ESC unwinds the stream promptly regardless of the - provider's chunk cadence: - - * **Response-close listener** registered on the abort signal — - calls ``stream.response.close()``. Closes the underlying HTTP - socket so the SDK's blocking next-chunk read raises - immediately, even when the model is in a long gap between - chunks (extended thinking, tool_use generation). - * **In-loop abort check** at the top of each ``for chunk in - stream`` iteration — catches the case where chunks arrive - back-to-back and the listener's close lands one iteration - late, so we stop iterating before the next read. - - Mirrors the contract ``AnthropicProvider.chat_stream_response`` - established for the Anthropic SDK; same correctness arguments - apply (signal state is authoritative for abort detection; - register-then-recheck closes the registration race; listener - is detached in a ``finally`` so long-lived controllers don't - accumulate dead listeners). + ESC-cancellation lives in ``StreamAbortGuard`` (see + ``_stream_abort.py``). This provider keeps the SDK-specific + iteration shape — bare ``for chunk in stream`` plus an + in-loop ``guard.aborted`` check that catches the case where + chunks arrive back-to-back fast enough that the listener's + close lands one iteration late (or where the SDK has already + prefetched chunks past the close point). """ - from src.utils.abort_controller import AbortError + from ._stream_abort import StreamAbortGuard + + guard = StreamAbortGuard(abort_signal) + guard.raise_if_pre_aborted() - # Pre-call fast-path: matches AnthropicProvider. A signal that - # tripped at a turn boundary skips the API round-trip entirely. - if abort_signal is not None and getattr(abort_signal, "aborted", False): - raise AbortError(getattr(abort_signal, "reason", None) or "user_interrupt") model = self._get_model(**kwargs) provider_messages = self._prepare_messages(messages) @@ -392,122 +378,68 @@ def chat_stream_response( usage_obj: Any = None tool_calls_by_index: dict[int, dict[str, str]] = {} - # --- Abort-listener wiring --- - # Close the underlying HTTP response when the signal trips so a - # blocking next-chunk read raises immediately. The OpenAI Python - # SDK 1.x and 2.x both expose the underlying httpx Response as - # ``stream.response`` (see ``openai/_streaming.py``). - # ``httpx.Response.close()`` is idempotent (guarded by - # ``if not self.is_closed``), so a double-close — e.g., the - # listener fires AND the post-loop path explicitly closes — is - # harmless. - def _close_stream_on_abort() -> None: + with guard.attach(stream): try: - response = getattr(stream, "response", None) - if response is not None: - close = getattr(response, "close", None) - if callable(close): - close() - except Exception: - # Best-effort — never let close() propagate out of the - # listener thread. - pass - - abort_listener: Any = None - if abort_signal is not None: - # Register-then-recheck: see the Anthropic provider for the - # full race analysis. The TL;DR is that ``_fire`` snapshots - # the listener list before iterating, so a listener appended - # after that snapshot is silently dropped; the post-add - # ``aborted`` read closes the gap (signal state is sticky). - abort_listener = abort_signal.add_listener( - _close_stream_on_abort, once=True, - ) - if abort_signal.aborted: - _close_stream_on_abort() - - try: - for chunk in stream: - # In-loop abort check: even when the listener fires - # mid-stream, chunks already buffered by the SDK can - # still get yielded before the closed-socket raise lands. - # The in-loop check makes the abort observable on the - # very next chunk boundary regardless of buffering. - if abort_signal is not None and abort_signal.aborted: - break - - response_model = getattr(chunk, "model", response_model) - usage_candidate = getattr(chunk, "usage", None) - if usage_candidate is not None: - usage_obj = usage_candidate - - choices = getattr(chunk, "choices", None) or [] - if not choices: - continue - choice = choices[0] - if getattr(choice, "finish_reason", None): - finish_reason = choice.finish_reason - - delta = getattr(choice, "delta", None) - if delta is None: - continue - - content_piece = getattr(delta, "content", None) - if content_piece: - piece = str(content_piece) - content_parts.append(piece) - if on_text_chunk is not None: - on_text_chunk(piece) - - reasoning_piece = getattr(delta, "reasoning_content", None) - if reasoning_piece: - reasoning_parts.append(str(reasoning_piece)) - - tool_call_deltas = getattr(delta, "tool_calls", None) or [] - for tc in tool_call_deltas: - idx = getattr(tc, "index", 0) - entry = tool_calls_by_index.setdefault(idx, {"id": "", "name": "", "arguments": ""}) - - tc_id = getattr(tc, "id", None) - if tc_id: - entry["id"] = str(tc_id) - - function = getattr(tc, "function", None) - if function is not None: - fn_name = getattr(function, "name", None) - if fn_name: - entry["name"] += str(fn_name) - fn_args = getattr(function, "arguments", None) - if fn_args: - entry["arguments"] += str(fn_args) - except Exception as streaming_exc: - # Abort path: the listener closed the underlying HTTP - # response, which raised on the SDK's next read in the - # consumer thread. Detect via signal state (not exception - # class — the OpenAI/httpx layer can raise several different - # exception types depending on which syscall was in flight). - if abort_signal is not None and getattr(abort_signal, "aborted", False): - raise AbortError( - getattr(abort_signal, "reason", None) or "user_interrupt" - ) from streaming_exc - raise - finally: - if abort_listener is not None and abort_signal is not None: - try: - abort_signal.remove_listener(abort_listener) - except Exception: - pass - - # The stream may have completed naturally OR we broke out of - # the loop because the in-loop abort check fired. Surface the - # abort here so the caller bails at the same place every other - # path does. ``stream.close()`` after a clean exit is a no-op - # on httpx, so this stays safe. - if abort_signal is not None and getattr(abort_signal, "aborted", False): - _close_stream_on_abort() - raise AbortError( - getattr(abort_signal, "reason", None) or "user_interrupt" - ) + for chunk in stream: + # In-loop check catches the SDK-prefetched-chunks + # case: the listener's close lands but the SDK has + # already buffered several chunks ahead. We break + # before consuming the next one. + if guard.aborted: + break + + response_model = getattr(chunk, "model", response_model) + usage_candidate = getattr(chunk, "usage", None) + if usage_candidate is not None: + usage_obj = usage_candidate + + choices = getattr(chunk, "choices", None) or [] + if not choices: + continue + choice = choices[0] + if getattr(choice, "finish_reason", None): + finish_reason = choice.finish_reason + + delta = getattr(choice, "delta", None) + if delta is None: + continue + + content_piece = getattr(delta, "content", None) + if content_piece: + piece = str(content_piece) + content_parts.append(piece) + if on_text_chunk is not None: + on_text_chunk(piece) + + reasoning_piece = getattr(delta, "reasoning_content", None) + if reasoning_piece: + reasoning_parts.append(str(reasoning_piece)) + + tool_call_deltas = getattr(delta, "tool_calls", None) or [] + for tc in tool_call_deltas: + idx = getattr(tc, "index", 0) + entry = tool_calls_by_index.setdefault(idx, {"id": "", "name": "", "arguments": ""}) + + tc_id = getattr(tc, "id", None) + if tc_id: + entry["id"] = str(tc_id) + + function = getattr(tc, "function", None) + if function is not None: + fn_name = getattr(function, "name", None) + if fn_name: + entry["name"] += str(fn_name) + fn_args = getattr(function, "arguments", None) + if fn_args: + entry["arguments"] += str(fn_args) + except Exception as streaming_exc: + guard.reraise_if_aborted(streaming_exc) + raise + + # Stream completed naturally OR the in-loop check broke out. + # In the latter case the signal is already tripped; raise so + # the caller bails at the same place every other path does. + guard.raise_if_post_aborted() tool_uses: list[dict[str, Any]] = [] for idx in sorted(tool_calls_by_index.keys()): diff --git a/tests/test_openai_compat_abort_signal.py b/tests/test_openai_compat_abort_signal.py index 4a5ac68..aed120d 100644 --- a/tests/test_openai_compat_abort_signal.py +++ b/tests/test_openai_compat_abort_signal.py @@ -209,6 +209,13 @@ def __iter__(self): # back to the original "ESC waits for the model to finish # generating" behaviour. assert seen == ["first"], f"in-loop check leaked second chunk: {seen}" + # And on the in-loop-break path, the underlying httpx response + # must still be closed (otherwise the socket leaks). The + # listener fired during the synchronous ``controller.abort()`` + # call inside the iterator, so close() was already invoked once + # there; the helper's ``__exit__`` close-on-abort guarantee adds + # a second idempotent call. We just assert at-least-once. + assert stream.response.close.called, "stream.response was not closed on in-loop break" def test_uncancelled_stream_returns_normally() -> None: diff --git a/tests/test_stream_abort_guard.py b/tests/test_stream_abort_guard.py new file mode 100644 index 0000000..eb1f5a1 --- /dev/null +++ b/tests/test_stream_abort_guard.py @@ -0,0 +1,280 @@ +"""Unit tests for ``src/providers/_stream_abort.py``. + +The provider-level tests in ``test_provider_abort_signal.py``, +``test_openai_compat_abort_signal.py``, and +``test_minimax_abort_signal.py`` already cover the end-to-end behavior +through each provider's ``chat_stream_response`` path. This file pins +the helper's contract directly so a future refactor that changes one +provider but forgets to update the helper (or vice versa) fails fast +at the unit level. +""" +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from src.providers._stream_abort import StreamAbortGuard +from src.utils.abort_controller import AbortController, AbortError + + +def _make_stream() -> MagicMock: + """Build a stub stream with a ``response.close`` we can assert on.""" + stream = MagicMock() + stream.response = MagicMock() + return stream + + +# --------------------------------------------------------------------------- +# Pre-call / post-call fast-paths + + +def test_raise_if_pre_aborted_no_signal_is_noop() -> None: + """A guard with ``abort_signal=None`` does not raise on the pre-call check. + + Providers can build a guard unconditionally — callers that don't + pass an abort signal just get a guard that does nothing. + """ + StreamAbortGuard(None).raise_if_pre_aborted() # no exception + + +def test_raise_if_pre_aborted_signal_clear_is_noop() -> None: + controller = AbortController() + StreamAbortGuard(controller.signal).raise_if_pre_aborted() # no exception + + +def test_raise_if_pre_aborted_signal_set_raises_abort_error() -> None: + controller = AbortController() + controller.abort("user_interrupt") + guard = StreamAbortGuard(controller.signal) + with pytest.raises(AbortError) as exc_info: + guard.raise_if_pre_aborted() + assert exc_info.value.reason == "user_interrupt" + + +def test_raise_if_post_aborted_mirrors_pre_check() -> None: + """Post-stream recheck has the same shape as the pre-call check. + + Providers call this after the SDK's ``with``-block exits to catch + a signal that fired between ``__exit__`` and the function return. + """ + controller = AbortController() + guard = StreamAbortGuard(controller.signal) + guard.raise_if_post_aborted() # clean — no exception + + controller.abort("user_interrupt") + with pytest.raises(AbortError): + guard.raise_if_post_aborted() + + +# --------------------------------------------------------------------------- +# Aborted property — used by the in-loop check inside OpenAI-compat + + +def test_aborted_property_reflects_signal_state() -> None: + controller = AbortController() + guard = StreamAbortGuard(controller.signal) + assert guard.aborted is False + controller.abort("user_interrupt") + assert guard.aborted is True + + +def test_aborted_property_is_false_when_no_signal() -> None: + """``aborted`` is False when ``abort_signal=None`` — never tripped.""" + assert StreamAbortGuard(None).aborted is False + + +# --------------------------------------------------------------------------- +# attach() — the listener-lifecycle context manager + + +def test_attach_no_signal_is_noop_context() -> None: + """With no signal, ``attach`` returns a context that does nothing. + + Lets providers wrap the iteration unconditionally without branching + on whether the caller passed an abort_signal. + """ + stream = _make_stream() + guard = StreamAbortGuard(None) + with guard.attach(stream): + pass # no listener registered, nothing to clean up + stream.response.close.assert_not_called() + + +def test_attach_registers_listener_and_detaches_on_exit() -> None: + """The listener exists while attached and is gone after exit. + + Pins the long-running-controller invariant: a single + AbortController reused across many turns must not accumulate + listeners pointing at gone streams. + """ + controller = AbortController() + stream = _make_stream() + guard = StreamAbortGuard(controller.signal) + + assert controller.signal._listeners == [] + with guard.attach(stream): + assert len(controller.signal._listeners) == 1 + assert controller.signal._listeners == [] + + +def test_attach_fires_close_when_signal_trips_after_enter() -> None: + """A signal that fires mid-attach calls ``stream.response.close()``.""" + controller = AbortController() + stream = _make_stream() + guard = StreamAbortGuard(controller.signal) + + with guard.attach(stream): + stream.response.close.assert_not_called() + controller.abort("user_interrupt") + stream.response.close.assert_called_once() + + +def test_attach_fires_close_when_signal_already_tripped_at_enter() -> None: + """Race-recovery: signal fired before ``attach`` calls ``__enter__``. + + The naive "check then register" sequence has a sub-microsecond + race where ``_fire`` can snapshot the listener list before our + ``add_listener`` append; the listener would be silently dropped. + The helper's ``register-then-recheck`` ordering closes the gap: + after ``add_listener`` we re-check ``aborted`` and call the close + callback directly if the signal is already tripped. + """ + controller = AbortController() + controller.abort("user_interrupt") + stream = _make_stream() + guard = StreamAbortGuard(controller.signal) + + with guard.attach(stream): + # The recheck after add_listener fired close() directly. + stream.response.close.assert_called() + + +def test_attach_close_failures_do_not_propagate() -> None: + """A raising ``stream.response.close()`` is swallowed. + + The listener fires from whichever thread tripped the abort (UI + thread, SIGINT handler) — letting close() raise there would crash + that thread without delivering the cancel. + """ + controller = AbortController() + stream = _make_stream() + stream.response.close.side_effect = RuntimeError("simulated close failure") + guard = StreamAbortGuard(controller.signal) + + with guard.attach(stream): + # Should not raise even though close() throws. + controller.abort("user_interrupt") + + # And the listener detach in __exit__ also tolerates the + # already-fired state (the once=True wrapper has already + # self-detached). + + +def test_attach_no_response_attribute_is_safe() -> None: + """A stream without a ``response`` attribute is silently skipped. + + Future SDKs may name the response differently; the helper should + degrade to "no close happens" rather than raising AttributeError + inside the listener thread. + """ + controller = AbortController() + stream = MagicMock(spec=[]) # no response attribute + guard = StreamAbortGuard(controller.signal) + + with guard.attach(stream): + controller.abort("user_interrupt") # no AttributeError raised + + # Stream lacks a response attribute; helper just no-ops. + + +def test_exit_closes_stream_when_signal_aborted_no_listener_fire() -> None: + """``__exit__`` closes the response if the signal aborted but the listener never fired. + + Pins the race-recovery guarantee. ``AbortSignal._fire`` snapshots + the listener list before iterating, so a narrow window exists + where the consumer thread can: + 1. Observe ``aborted == True`` (set by ``_fire`` BEFORE the + listener iteration starts), + 2. Break out of the iteration, + 3. Exit the ``with`` block — ``__exit__`` runs, detaches the + listener, + 4. The original ``_fire`` thread resumes, ``list(self._listeners)`` + is now empty, the close never fires. + + Without the ``__exit__`` close fallback, the underlying httpx + response would leak open. We simulate the race by setting + ``_aborted = True`` directly (bypasses ``_fire``'s listener + iteration) so the listener is guaranteed to have NOT fired. + """ + controller = AbortController() + stream = _make_stream() + guard = StreamAbortGuard(controller.signal) + + with guard.attach(stream): + # Trip the signal WITHOUT going through _fire, so no listener + # is invoked — mimics the race window above where the abort + # thread set ``_aborted=True`` but the listener iteration + # races with our ``__exit__``. + controller.signal._aborted = True + controller.signal._reason = "user_interrupt" + stream.response.close.assert_not_called() + # ``__exit__`` must close even when the listener never fired. + stream.response.close.assert_called() + + +def test_exit_does_not_close_when_signal_not_aborted() -> None: + """The fallback close only fires on abort — clean exits don't trigger it. + + Regression guard: a stream that exits the attach context after + natural iterator exhaustion (no abort) must not get a redundant + ``close()`` call. The SDK's own ``__exit__`` is responsible for + cleanup on the happy path. + """ + controller = AbortController() + stream = _make_stream() + guard = StreamAbortGuard(controller.signal) + + with guard.attach(stream): + pass # no abort, no break — clean exit + + stream.response.close.assert_not_called() + + +# --------------------------------------------------------------------------- +# reraise_if_aborted — exception translation + + +def test_reraise_if_aborted_no_abort_is_noop() -> None: + """If the signal didn't fire, leave the original exception alone.""" + controller = AbortController() + guard = StreamAbortGuard(controller.signal) + orig = RuntimeError("genuine network error") + # No raise — the caller's subsequent ``raise`` re-raises ``orig``. + guard.reraise_if_aborted(orig) + + +def test_reraise_if_aborted_translates_to_abort_error_with_cause() -> None: + """When the signal fired, translate to AbortError preserving the cause. + + The SDK / httpx layer can raise several different exception + classes when the underlying response is closed mid-read; the + guard uses the signal state (not the exception class) as the + authoritative abort indicator. The original exception is + chained via ``raise ... from`` so observers can still see what + the SDK reported. + """ + controller = AbortController() + controller.abort("user_interrupt") + guard = StreamAbortGuard(controller.signal) + orig = ConnectionError("socket closed mid-read") + + with pytest.raises(AbortError) as exc_info: + guard.reraise_if_aborted(orig) + assert exc_info.value.reason == "user_interrupt" + assert exc_info.value.__cause__ is orig + + +def test_reraise_if_aborted_no_signal_is_noop() -> None: + """``abort_signal=None`` guards always treat the exception as non-abort.""" + StreamAbortGuard(None).reraise_if_aborted(RuntimeError("anything"))