diff --git a/src/providers/openai_compatible.py b/src/providers/openai_compatible.py
index fcfe8ed..4288449 100644
--- a/src/providers/openai_compatible.py
+++ b/src/providers/openai_compatible.py
@@ -330,13 +330,28 @@ def chat_stream_response(
 ) -> ChatResponse:
     """Stream OpenAI-compatible chunks while rebuilding the final response.
 
-    ESC-cancellation lives in ``StreamAbortGuard`` (see
-    ``_stream_abort.py``). This provider keeps the SDK-specific
-    iteration shape — bare ``for chunk in stream`` plus an
-    in-loop ``guard.aborted`` check that catches the case where
-    chunks arrive back-to-back fast enough that the listener's
-    close lands one iteration late (or where the SDK has already
-    prefetched chunks past the close point).
+    ESC-cancellation runs the SDK iteration on a daemon worker
+    thread that pushes chunks into a ``queue.Queue``. The main
+    thread polls the queue with a 100 ms timeout and re-checks
+    ``guard.aborted`` between ticks. On abort the main thread
+    raises ``AbortError`` immediately and orphans the worker —
+    the worker dies when the underlying connection eventually
+    closes.
+
+    Why the worker indirection (vs. the simpler in-loop check
+    used in earlier revisions): the OpenAI Python SDK uses sync
+    ``httpx`` for streaming, and ``response.close()`` from
+    another thread is purely advisory. For LiteLLM-proxied
+    connections (and certain other httpx + chunked-transfer
+    configurations) the SDK's blocking socket read doesn't
+    actually return when the response is "closed" — it keeps
+    consuming bytes. Unlike JavaScript's native ``fetch +
+    AbortSignal`` integration (which the TypeScript reference at
+    ``typescript/src/services/api/openaiShim.ts`` uses), Python
+    has no portable way to make a sync blocking read honor an
+    abort from another thread, so the worker exists to keep the
+    main thread's response time independent of the SDK's
+    cooperation.
     """
     from ._stream_abort import StreamAbortGuard
 
@@ -378,63 +393,128 @@ def chat_stream_response(
     usage_obj: Any = None
     tool_calls_by_index: dict[int, dict[str, str]] = {}
 
-    with guard.attach(stream):
+    # Worker-thread iteration. The OpenAI Python SDK uses sync
+    # ``httpx`` for streaming, and ``response.close()`` from another
+    # thread is best-effort — for LiteLLM-proxied connections (and
+    # some other httpx configurations) the SDK's blocking socket
+    # read doesn't actually return when the response is closed.
+    # Unlike JavaScript's native ``fetch + AbortSignal`` integration
+    # (which the TypeScript reference uses), Python has no portable
+    # way to make a sync blocking read honor an abort from another
+    # thread.
+    #
+    # Workaround: hoist the iteration onto a daemon worker thread
+    # that pushes chunks into a queue. The main thread polls the
+    # queue with a short timeout and re-checks ``guard.aborted``
+    # each tick. On abort we raise ``AbortError`` immediately and
+    # orphan the worker — it'll die when the underlying connection
+    # eventually closes (server-side, idle timeout, or the SDK's
+    # natural exhaustion). The cost is some wasted bandwidth on
+    # the orphaned read; the benefit is that the user's prompt
+    # comes back in ~100 ms regardless of LiteLLM/httpx behavior.
+    import queue as _queue
+    import threading as _threading
+
+    _DONE = object()
+    chunk_queue: _queue.Queue = _queue.Queue()
+
+    def _drain_stream() -> None:
         try:
-            for chunk in stream:
-                # In-loop check catches the SDK-prefetched-chunks
-                # case: the listener's close lands but the SDK has
-                # already buffered several chunks ahead. We break
-                # before consuming the next one.
-                if guard.aborted:
-                    break
+            for c in stream:
+                chunk_queue.put(c)
+        except BaseException as exc:  # noqa: BLE001 — surface to consumer
+            chunk_queue.put(exc)
+        finally:
+            chunk_queue.put(_DONE)
+
+    worker = _threading.Thread(
+        target=_drain_stream,
+        daemon=True,
+        name=f"openai-stream-{id(stream)}",
+    )
 
-                response_model = getattr(chunk, "model", response_model)
-                usage_candidate = getattr(chunk, "usage", None)
-                if usage_candidate is not None:
-                    usage_obj = usage_candidate
+    with guard.attach(stream):
+        worker.start()
+        while True:
+            try:
+                item = chunk_queue.get(timeout=0.1)
+            except _queue.Empty:
+                # No chunk available right now — check abort and
+                # loop. The 100 ms tick bounds how long the user
+                # waits between pressing ESC and the prompt
+                # returning, regardless of how slow / blocked the
+                # underlying SDK iteration is.
+                if guard.aborted:
+                    # Use ``raise_if_post_aborted`` so the abort
+                    # reason from the controller is preserved
+                    # (rather than hardcoding ``"user_interrupt"``,
+                    # which would silently downgrade a non-default
+                    # reason like a future ``"rate_limit_backoff"``).
+                    guard.raise_if_post_aborted()
+                continue
 
-                choices = getattr(chunk, "choices", None) or []
-                if not choices:
-                    continue
+            if item is _DONE:
+                break
+            if isinstance(item, BaseException):
+                if isinstance(item, Exception):
+                    guard.reraise_if_aborted(item)
+                    raise item
+                # KeyboardInterrupt/SystemExit from the worker
+                # path — re-raise as-is so the outer signal-
+                # handling story stays intact.
+                raise item
+
+            chunk = item
+            response_model = getattr(chunk, "model", response_model)
+            usage_candidate = getattr(chunk, "usage", None)
+            if usage_candidate is not None:
+                usage_obj = usage_candidate
+
+            choices = getattr(chunk, "choices", None) or []
+            if choices:
                 choice = choices[0]
                 if getattr(choice, "finish_reason", None):
                     finish_reason = choice.finish_reason
 
                 delta = getattr(choice, "delta", None)
-                if delta is None:
-                    continue
-
-                content_piece = getattr(delta, "content", None)
-                if content_piece:
-                    piece = str(content_piece)
-                    content_parts.append(piece)
-                    if on_text_chunk is not None:
-                        on_text_chunk(piece)
-
-                reasoning_piece = getattr(delta, "reasoning_content", None)
-                if reasoning_piece:
-                    reasoning_parts.append(str(reasoning_piece))
-
-                tool_call_deltas = getattr(delta, "tool_calls", None) or []
-                for tc in tool_call_deltas:
-                    idx = getattr(tc, "index", 0)
-                    entry = tool_calls_by_index.setdefault(idx, {"id": "", "name": "", "arguments": ""})
-
-                    tc_id = getattr(tc, "id", None)
-                    if tc_id:
entry["id"] = str(tc_id) + + function = getattr(tc, "function", None) + if function is not None: + fn_name = getattr(function, "name", None) + if fn_name: + entry["name"] += str(fn_name) + fn_args = getattr(function, "arguments", None) + if fn_args: + entry["arguments"] += str(fn_args) + + # Check abort AFTER processing this chunk so any + # already-delivered content is preserved (matches the + # in-loop-check semantics from the old implementation: + # the chunk-list test pins that the chunk we received + # before the abort gets processed; we just don't take + # the next one). + if guard.aborted: + guard.raise_if_post_aborted() # Stream completed naturally OR the in-loop check broke out. # In the latter case the signal is already tripped; raise so diff --git a/tests/test_openai_compat_abort_signal.py b/tests/test_openai_compat_abort_signal.py index aed120d..617d55a 100644 --- a/tests/test_openai_compat_abort_signal.py +++ b/tests/test_openai_compat_abort_signal.py @@ -266,3 +266,129 @@ def test_listener_detached_after_normal_completion() -> None: ) assert controller.signal._listeners == [] + + +class _StuckStream: + """Mimic an OpenAI Stream whose iterator never honors ``response.close()``. + + Models the LiteLLM/proxy scenario reported by the user: the + underlying socket is not interrupted when ``stream.response.close()`` + is called from another thread, so the SDK iterator stays blocked + on the next chunk indefinitely. The worker-thread iteration in + ``OpenAICompatibleProvider.chat_stream_response`` must NOT rely on + the iterator unblocking — the main thread polls a queue with + timeout and bails on abort. + + ``__iter__`` blocks on an ``Event`` that the test never sets, so + iteration would hang forever without the worker+queue decoupling. + """ + + def __init__(self) -> None: + self.response = MagicMock() + self._never_set = threading.Event() + self._iter_entered = threading.Event() + + def __iter__(self): + self._iter_entered.set() + # Block forever — even if response.close() is called. + # ``_never_set`` is never set in this test. + self._never_set.wait() + # Unreachable. If we somehow get here, yield nothing so the + # iterator ends and the test doesn't go on forever. + return + yield # pragma: no cover + + +def test_abort_unwinds_promptly_even_when_iterator_never_returns() -> None: + """The user's bug: ESC must unwind in <1s even when the SDK never honors close(). + + Pre-fix (single-threaded ``for chunk in stream``): the main thread + was blocked on ``next(stream)`` waiting for a chunk the LiteLLM + proxy never delivered, ``response.close()`` from the listener + thread didn't propagate to the kernel socket read, and ESC waited + indefinitely. + + Post-fix (worker thread + queue): the SDK iteration runs on a + daemon worker that gets orphaned on abort. The main thread polls + the queue with a 100 ms timeout and bails on ``guard.aborted``. + Total ESC-to-AbortError budget is one poll tick plus listener + cascade — well under 1 second on any reasonable machine. + + Failure mode this regression-tests against: someone reverting the + worker+queue would make the main thread block on ``next(stream)`` + again. With ``_StuckStream``'s never-set Event, the test would + hang forever (the assertion-failure form is a CI timeout, not a + fast fail — but a CI timeout is still loud). 
+ """ + controller = AbortController() + stream = _StuckStream() + provider = _provider_with_stream(stream) + + def _trip_after_worker_starts() -> None: + # Wait for the worker thread to actually enter the iterator, + # so the test pins "abort during a stuck iteration" rather + # than "abort before the worker started". + assert stream._iter_entered.wait(timeout=2.0), "worker never entered iterator" + controller.abort("user_interrupt") + + threading.Thread(target=_trip_after_worker_starts, daemon=True).start() + + start = time.monotonic() + with pytest.raises(AbortError): + provider.chat_stream_response( + messages=[{"role": "user", "content": "hi"}], + abort_signal=controller.signal, + ) + elapsed = time.monotonic() - start + + # 100 ms poll tick + listener cascade + abort propagation. 1.5 s + # is comfortable headroom on slow CI; on a healthy laptop this is + # well under 300 ms. + assert elapsed < 1.5, f"abort took {elapsed:.2f}s — expected <1.5s" + + +class _ContentThenUsageStream: + """Stream that yields one content chunk then a final usage-only chunk. + + Mirrors OpenAI's streaming wire format when + ``stream_options.include_usage=True``: content/delta chunks first, + then a final chunk with empty ``choices`` and populated ``usage``. + """ + + def __init__(self) -> None: + self.response = MagicMock() + + def __iter__(self): + # Regular content chunk. + yield _FakeChunk(content="hello") + # Final usage-only chunk: empty choices, populated usage. + final = MagicMock() + final.model = "test-model" + final.choices = [] + final.usage = MagicMock( + prompt_tokens=10, completion_tokens=5, total_tokens=15, + ) + yield final + + +def test_normal_completion_still_captures_final_usage() -> None: + """The worker+queue path must not drop the final usage chunk. + + OpenAI emits usage stats only in the last chunk (with empty + ``choices``). The main thread must drain every queued chunk + before breaking on ``_DONE`` — otherwise token counting would + silently regress for non-aborted streams. + """ + controller = AbortController() + stream = _ContentThenUsageStream() + provider = _provider_with_stream(stream) + + response = provider.chat_stream_response( + messages=[{"role": "user", "content": "hi"}], + abort_signal=controller.signal, + ) + assert response.content == "hello" + # The final usage chunk made it through the queue; otherwise + # ``response.usage`` would be the default empty dict, and the + # ``↓ N tokens`` REPL spinner would silently lose count. + assert response.usage.get("total_tokens") == 15