From 55282c197af3b7f2427f02b73c687b0019da0c46 Mon Sep 17 00:00:00 2001 From: Rusty Conover Date: Tue, 23 Jun 2026 23:11:43 -0400 Subject: [PATCH] Fix HTTP-continuation loop: externalize table-function scan state as a cursor The five vision table functions were TableFunctionGenerator[Args] with `state: None`, emitting all rows in one `out.emit()` then `out.finish()`. Over the stateless HTTP transport the framework round-trips per-scan state through a continuation token and emits at most one producer batch per response; a position-less state restarts from row 0 on every resume and loops forever once the result exceeds one batch. `image_classes()` (1000 rows) is the smoking gun. subprocess/unix keep live state in-process and hide the bug. Fix (mirrors vgi-search's ScanState): - Add `ROWS_PER_TICK = 64` and `ScanState(ArrowSerializableDataclass)` with `started`/`offset`/`rows_ipc` (+ `result_to_ipc`/`ipc_to_table` helpers). - Convert all five functions to TableFunctionGenerator[Args, ScanState] with `initial_state`. First tick materializes the full output batch into `state.rows_ipc`; each tick emits one bounded ROWS_PER_TICK slice from `offset`, advances `offset` BEFORE emit (so a tick suspended at the limit-1 boundary serializes the advanced offset), and finishes when drained. The NULL/empty-image early-out (zero-row batch + finish) is preserved. Rows and schema are byte-identical. Validation: - harness `invoke_table_function(..., serialize_state=True)` models http by capping each tick to one producer batch and wire-serializing ScanState between every tick, with a 1000-tick guard. On the old emit-all/None code it overruns the guard (re-emits row 0); on the cursor code it terminates once per row. - New TestScanStateRoundTrip (model-gated, 1000-row image_classes) and TestCursorSurvivesContinuation (offline, monkeypatched classify_image -> 200 synthetic preds) tests, plus small-chunk and no-preds cases. - classify.test asserts image_classes() pages correctly over http (count, dense idx, ordered head at offset 500). Verified locally green on subprocess, http, and unix transports. Co-Authored-By: Claude Opus 4.8 (1M context) --- CLAUDE.md | 40 +++++++++ test/sql/classify.test | 18 +++- tests/harness.py | 60 +++++++++++-- tests/test_tables.py | 91 +++++++++++++++++++ vgi_vision/tables.py | 195 ++++++++++++++++++++++++++++++++--------- 5 files changed, 358 insertions(+), 46 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index b530fc3..c5dd1a9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -59,6 +59,46 @@ only), and overloads resolve by **arity + type**: LIST/STRUCT-style fixed schemas use `bind_fixed_schema` + `FIXED_SCHEMA` and the `field()` helper for column comments. +## Scan state is an explicit cursor — the HTTP-continuation contract + +All five table functions are `TableFunctionGenerator[Args, ScanState]` (in +`tables.py`), **not** `state: None`. `ScanState(ArrowSerializableDataclass)` carries +`started: bool`, `offset: int`, `rows_ipc: bytes` (the full result as Arrow IPC +bytes via `result_to_ipc`/`ipc_to_table`), and a module constant +`ROWS_PER_TICK = 64`. + +**Why a cursor and not emit-all.** Over the stateless **http** transport the +framework round-trips the producer's per-scan state through a *continuation token*: +after each `process()` tick it wire-serializes the state +(`ArrowSerializableDataclass.serialize_to_bytes()`), the client returns it, and the +worker resumes by deserializing it — emitting **at most one producer batch per +response**. A position-less `state: None` function that did +`out.emit(...all rows...); out.finish()` restarts from row 0 on every http resume +and **loops forever** once the result exceeds one batch. subprocess/unix keep the +live state in-process so they hide the bug; only http (and the +serialize-between-ticks unit test) expose it. `image_classes()` = 1000 rows is the +smoking gun — if the http leg times out, the cursor is wrong. + +**The contract.** Each `process()` tick: on the first tick (`not state.started`) +read the source and materialize the full output batch into `state.rows_ipc` (set +`started`); then emit one bounded `ROWS_PER_TICK` slice from `state.offset`, +**advance `state.offset` before `out.emit()`** (so a tick suspended at the limit-1 +boundary serializes the already-advanced offset), and `out.finish()` once +`offset >= total`. The NULL/empty-image early-out is preserved: a `None`/empty image +materializes a zero-row batch and finishes immediately (`0 >= 0`). The shared +`_emit_cursor` helper does the slicing; `_classify_batch` builds the full classify +batch. Rows/schema are byte-identical to the old emit-all behaviour. + +The regression test lives in `tests/test_tables.py` +(`TestScanStateRoundTrip`/`TestCursorSurvivesContinuation`): the test harness +`invoke_table_function(..., serialize_state=True)` models http by capping each tick +to one producer batch and wire-serializing `ScanState` between every tick, with a +1000-tick guard. On the old emit-all/`None` code it overruns the guard (re-emits +row 0); on the cursor code it terminates with each row emitted once. The +model-gated `image_classes` round-trips the 1000-row label set; an offline classify +variant monkeypatches `model.classify_image` to return ~200 synthetic preds so the +regression is covered without the ONNX weights. + ## Sharp edges (learned the hard way) 1. **`haybarn-unittest` skips `require vgi`** — under haybarn the extension is diff --git a/test/sql/classify.test b/test/sql/classify.test index d4c1866..2cfdef8 100644 --- a/test/sql/classify.test +++ b/test/sql/classify.test @@ -11,19 +11,33 @@ statement ok ATTACH 'vision' AS vision (TYPE vgi, LOCATION '${VGI_VISION_WORKER}'); # ---- image_classes(): the model's full ImageNet label set ---- +# 1000 rows is well over one producer batch, so over the http transport this +# result is paged across continuation tokens. If the scan cursor (ScanState.offset) +# did not survive each continuation boundary, the http leg would loop forever +# (re-emitting from row 0) or truncate -- so this is a live regression test for the +# HTTP-continuation cursor, exercised on every transport. query I SELECT count(*) FROM vision.image_classes(); ---- 1000 -# Indices are 0..999, dense. +# Indices are 0..999, dense -- every row emitted exactly once, no dupes, no gaps. +# (A broken cursor that re-paged from 0 would inflate the count / break DISTINCT.) query I -SELECT count(DISTINCT idx) = 1000 AND min(idx) = 0 AND max(idx) = 999 +SELECT count(DISTINCT idx) = 1000 AND min(idx) = 0 AND max(idx) = 999 AND count(*) = 1000 FROM vision.image_classes(); ---- true +# Ordered head well past the first producer batch: idx=500 is reachable only if +# the cursor paged correctly across the http continuation boundary. +query II +SELECT idx, (label IS NOT NULL AND length(label) > 0) FROM vision.image_classes() +ORDER BY idx LIMIT 1 OFFSET 500; +---- +500 true + # ---- top_label(image): #1 prediction for the committed sample.png ---- # The prediction is deterministic for this exact fixture + model (MobileNetV2-12). diff --git a/tests/harness.py b/tests/harness.py index ee2c678..d781fbd 100644 --- a/tests/harness.py +++ b/tests/harness.py @@ -10,6 +10,7 @@ from __future__ import annotations +import contextlib import io from typing import Any @@ -71,12 +72,24 @@ def test_storage() -> FunctionStorage: class MockOutputCollector: - """Captures emitted batches for assertions.""" + """Captures emitted batches for assertions. - def __init__(self, output_schema: pa.Schema) -> None: + ``batch_limit`` models the HTTP transport's per-response producer batch limit: + once that many batches have been emitted in a single ``process()`` tick, further + emits raise ``_BatchLimitReached`` so the driver can suspend, wire-serialize the + state, and resume (exactly as the http server does across a continuation token). + ``None`` (the default) means unbounded -- the in-process behaviour. + """ + + def __init__(self, output_schema: pa.Schema, batch_limit: int | None = None) -> None: self.output_schema = output_schema self.batches: list[pa.RecordBatch] = [] self._finished = False + self._batch_limit = batch_limit + self._emitted_this_tick = 0 + + def begin_tick(self) -> None: + self._emitted_this_tick = 0 def emit( self, @@ -85,6 +98,9 @@ def emit( metadata: dict[str, str] | None = None, ) -> None: self.batches.append(batch) + self._emitted_this_tick += 1 + if self._batch_limit is not None and self._emitted_this_tick >= self._batch_limit: + raise _BatchLimitReached def finish(self) -> None: self._finished = True @@ -97,13 +113,30 @@ def emit_client_log_message(self, msg: Any) -> None: pass +class _BatchLimitReached(Exception): + """Internal: the per-tick producer batch limit was hit; suspend + resume.""" + + def invoke_table_function( func_cls: type, *, named: dict[str, pa.Scalar] | None = None, positional: tuple[pa.Scalar, ...] = (), + serialize_state: bool = False, + max_ticks: int = 1000, ) -> pa.Table: - """Run a (source) table function through bind -> init -> process -> table.""" + """Run a (source) table function through bind -> init -> process -> table. + + When ``serialize_state=True`` the driver faithfully models the stateless HTTP + transport: each ``process()`` tick may emit at most ONE producer batch (the + limit-1 continuation boundary), after which the per-scan state is wire- + serialized and deserialized + (``type(state).deserialize_from_bytes(state.serialize_to_bytes())``) before the + next tick resumes. A position-less state that re-emits row 0 on every resume + loops forever; the ``max_ticks`` guard turns that into a loud failure instead + of an infinite hang. With a cursor state the offset survives each round-trip + and the scan terminates after ~ceil(rows / ROWS_PER_TICK) ticks. + """ args = Arguments(positional=positional, named=named or {}) bind_req = BindRequest( @@ -128,9 +161,26 @@ def invoke_table_function( ) state = func_cls.initial_state(params) - out = MockOutputCollector(bind_resp.output_schema) + # Over http, each response carries at most one producer batch; model that with + # batch_limit=1 so the cursor must be observable across the boundary. + out = MockOutputCollector(bind_resp.output_schema, batch_limit=1 if serialize_state else None) + ticks = 0 while not out.finished: - func_cls.process(params, state, out) + if serialize_state and state is not None: + state = type(state).deserialize_from_bytes(state.serialize_to_bytes()) + out.begin_tick() + # _BatchLimitReached suspends the tick mid-flight exactly as the http server + # does once it has filled a response with one batch; the loop then resumes + # with the (serialized) state. + with contextlib.suppress(_BatchLimitReached): + func_cls.process(params, state, out) + ticks += 1 + if ticks > max_ticks: + raise AssertionError( + f"{func_cls.__name__} did not finish after {max_ticks} ticks " + f"(serialize_state={serialize_state}): the scan cursor is not " + f"surviving the continuation boundary (likely re-emitting from row 0)." + ) return pa.Table.from_batches(out.batches, schema=bind_resp.output_schema) diff --git a/tests/test_tables.py b/tests/test_tables.py index 36dd486..b2698ad 100644 --- a/tests/test_tables.py +++ b/tests/test_tables.py @@ -7,8 +7,12 @@ from __future__ import annotations +import math + import pyarrow as pa +import pytest +from vgi_vision import tables from vgi_vision.tables import ( ClassifyFunction, ClassifyTopKFunction, @@ -67,3 +71,90 @@ def test_garbage_image_no_rows(self) -> None: def test_empty_image_no_rows(self) -> None: table = invoke_table_function(ClassifyFunction, positional=(pa.scalar(b"", type=pa.binary()),)) assert table.num_rows == 0 + + +# --------------------------------------------------------------------------- +# HTTP-continuation regression: the scan cursor must survive being wire- +# serialized/deserialized between every process() tick (as the stateless HTTP +# transport does across its limit-1 continuation boundary). On the OLD +# emit-all-then-finish code with `state: None`, a serialized resume re-emits from +# row 0 forever -- these tests would overrun the tick guard (hang -> AssertionError) +# or see duplicated rows. On the cursor code the offset survives and the scan +# terminates with each row emitted exactly once. +# --------------------------------------------------------------------------- + + +class TestScanStateRoundTrip: + """image_classes() (~1000 rows >> ROWS_PER_TICK) drives the real cursor path.""" + + @needs_model + def test_serialize_between_ticks_matches_single_shot(self) -> None: + plain = invoke_table_function(ImageClassesFunction) + rt = invoke_table_function(ImageClassesFunction, serialize_state=True) + + assert plain.num_rows == 1000 + # Identical rows AND order, despite a wire round-trip on every tick. + assert rt.to_pylist() == plain.to_pylist() + # No duplicates: idx is the dense 0..999 range exactly once. + idx = rt.column("idx").to_pylist() + assert idx == list(range(1000)) + assert len(set(idx)) == 1000 + + @needs_model + def test_terminates_in_expected_tick_count(self) -> None: + # ceil(1000 / ROWS_PER_TICK) emitting ticks + 1 finishing tick at most. + expected = math.ceil(1000 / tables.ROWS_PER_TICK) + # A tight guard: if the cursor regressed, this overruns and raises. + table = invoke_table_function(ImageClassesFunction, serialize_state=True, max_ticks=expected + 2) + assert table.num_rows == 1000 + + @needs_model + def test_small_chunk(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Force many tiny slices so the cursor must advance through ~500 ticks, + # each across a serialize boundary -- a stress test of offset survival. + monkeypatch.setattr(tables, "ROWS_PER_TICK", 2) + rt = invoke_table_function(ImageClassesFunction, serialize_state=True) + assert rt.column("idx").to_pylist() == list(range(1000)) + + +class TestCursorSurvivesContinuation: + """Offline classify round-trip: monkeypatch classify_image to span many ticks. + + Runs without the ONNX weights so the regression is covered even on a bare + checkout (the image_classes tests above are model-gated). + """ + + def test_many_synthetic_preds_round_trip(self, monkeypatch: pytest.MonkeyPatch) -> None: + n = 200 # >> ROWS_PER_TICK (64): forces multiple continuation slices. + synthetic = [(f"label-{i:04d}", 1.0 - i / n) for i in range(n)] + monkeypatch.setattr(tables.model, "classify_image", lambda *a, **k: synthetic) + + img = pa.scalar(png_bytes((1, 2, 3)), type=pa.binary()) + plain = invoke_table_function(ClassifyFunction, positional=(img,)) + rt = invoke_table_function(ClassifyFunction, positional=(img,), serialize_state=True) + + assert plain.num_rows == n + assert rt.num_rows == n + # Byte-identical rows AND order across the wire round-trip. + assert rt.to_pylist() == plain.to_pylist() + labels = rt.column("label").to_pylist() + assert labels == [p[0] for p in synthetic] + assert len(set(labels)) == n # no dupes + + def test_small_chunk(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(tables, "ROWS_PER_TICK", 2) + synthetic = [(f"c{i}", 0.5) for i in range(50)] + monkeypatch.setattr(tables.model, "classify_image", lambda *a, **k: synthetic) + img = pa.scalar(png_bytes((9, 9, 9)), type=pa.binary()) + rt = invoke_table_function( + ClassifyTopKFunction, positional=(img, pa.scalar(50, type=pa.int64())), serialize_state=True + ) + assert rt.column("label").to_pylist() == [p[0] for p in synthetic] + + def test_no_preds_still_terminates(self, monkeypatch: pytest.MonkeyPatch) -> None: + # The NULL/empty-image early-out path: zero rows, finishes immediately, + # even with serialize-between-ticks (0 >= 0). + monkeypatch.setattr(tables.model, "classify_image", lambda *a, **k: None) + img = pa.scalar(png_bytes((0, 0, 0)), type=pa.binary()) + rt = invoke_table_function(ClassifyFunction, positional=(img,), serialize_state=True, max_ticks=5) + assert rt.num_rows == 0 diff --git a/vgi_vision/tables.py b/vgi_vision/tables.py index a4e4241..7e4b85b 100644 --- a/vgi_vision/tables.py +++ b/vgi_vision/tables.py @@ -36,6 +36,7 @@ bind_fixed_schema, init_single_worker, ) +from vgi_rpc import ArrowSerializableDataclass from vgi_rpc.rpc import OutputCollector from . import model @@ -44,6 +45,78 @@ _DEFAULT_TOP_K = 5 +# Rows emitted per process() tick. Bounded so the scan cursor (``offset``) is +# observable across the HTTP limit-1 continuation boundary: correctness no longer +# depends on the whole result fitting inside a single producer batch. See +# ScanState below and CLAUDE.md "HTTP continuation" for the why. +ROWS_PER_TICK = 64 + + +@dataclass(kw_only=True) +class ScanState(ArrowSerializableDataclass): + """Externalized scan cursor for the vision table functions. + + Over the stateless HTTP transport the framework wire-serializes a producer's + per-scan state through a continuation token after each ``process()`` tick (the + client returns it; the worker resumes by deserializing it). A position-less + state that emits *all* rows in one ``out.emit()`` and finishes therefore + restarts from row 0 on every HTTP resume and loops forever once the output + exceeds one producer batch. Carrying an explicit cursor here fixes that. + + Fields (all wire-serialize through the continuation token): + + * ``started`` -- flips once the (possibly heavy) source has been read and the + full result batch materialized into ``rows_ipc``. Distinguishes + "not yet computed" from "computed an empty result". + * ``offset`` -- index of the next unemitted row; advanced at each emit. + * ``rows_ipc`` -- the full materialized result as Arrow IPC stream bytes. + """ + + started: bool = False + offset: int = 0 + rows_ipc: bytes = b"" + + +def result_to_ipc(batch: pa.RecordBatch) -> bytes: + """Serialize a single RecordBatch to Arrow IPC stream bytes (for ScanState).""" + sink = pa.BufferOutputStream() + with pa.ipc.new_stream(sink, batch.schema) as writer: # type: ignore[no-untyped-call] + writer.write_batch(batch) + result: bytes = sink.getvalue().to_pybytes() + return result + + +def ipc_to_table(value: bytes) -> pa.Table: + """Deserialize Arrow IPC stream bytes (from ScanState) back to a Table.""" + reader = pa.ipc.open_stream(pa.BufferReader(value)) # type: ignore[no-untyped-call] + return reader.read_all() + + +def _emit_cursor(state: ScanState, out: OutputCollector, schema: pa.Schema) -> None: + """Emit one bounded ``ROWS_PER_TICK`` slice from ``state.offset``; finish when drained. + + ``state.started`` must already be set (``rows_ipc`` materialized). Advances + ``state.offset`` past the emitted slice so a resumed tick (post wire round-trip) + sees the new position and never re-emits row 0. An empty/zero-row result + finishes immediately (``0 >= 0``). + """ + table = ipc_to_table(state.rows_ipc) + total = table.num_rows + if state.offset >= total: + out.finish() + return + end = min(state.offset + ROWS_PER_TICK, total) + chunk = table.slice(state.offset, end - state.offset) + # Advance the cursor BEFORE emitting: over http, emit() may suspend the tick + # (limit-1 continuation boundary) and the framework wire-serializes the state + # as it stands -- the advanced offset must already be recorded so the resumed + # tick continues past this slice instead of re-emitting it. + state.offset = end + out.emit(chunk.combine_chunks().to_batches()[0]) + if state.offset >= total: + out.finish() + + _CLASSIFY_SCHEMA = pa.schema( [ field("label", pa.string(), "Predicted ImageNet class label.", nullable=False), @@ -59,19 +132,30 @@ ) -def _emit_classify(preds: list[tuple[str, float]] | None, out: OutputCollector, schema: pa.Schema) -> None: - """Emit one row per prediction (or nothing for a NULL/unclassifiable image).""" +def _classify_batch(preds: list[tuple[str, float]] | None, schema: pa.Schema) -> pa.RecordBatch: + """Build the full ``(label, confidence)`` batch for a set of predictions. + + A NULL/unclassifiable image (``preds`` is ``None``/empty) yields a zero-row + batch -- the cursor then finishes with no rows, preserving the early-out + contract. The cursor (not this helper) does the emit/finish. + """ if not preds: - out.emit(pa.RecordBatch.from_pydict({"label": [], "confidence": []}, schema=schema)) - out.finish() - return - out.emit( - pa.RecordBatch.from_pydict( - {"label": [p[0] for p in preds], "confidence": [p[1] for p in preds]}, - schema=schema, - ) + return pa.RecordBatch.from_pydict({"label": [], "confidence": []}, schema=schema) + return pa.RecordBatch.from_pydict( + {"label": [p[0] for p in preds], "confidence": [p[1] for p in preds]}, + schema=schema, ) - out.finish() + + +def _process_classify( + state: ScanState, out: OutputCollector, schema: pa.Schema, preds: list[tuple[str, float]] | None +) -> None: + """Cursor-driven classify tick: materialize predictions on the first tick, then stream slices.""" + if not state.started: + state.rows_ipc = result_to_ipc(_classify_batch(preds, schema)) + state.started = True + state.offset = 0 + _emit_cursor(state, out, schema) # --------------------------------------------------------------------------- @@ -92,7 +176,7 @@ class _ClassifyBlobTopKArgs: @init_single_worker @bind_fixed_schema -class ClassifyFunction(TableFunctionGenerator[_ClassifyBlobArgs]): +class ClassifyFunction(TableFunctionGenerator[_ClassifyBlobArgs, ScanState]): """``classify(image)`` -- top-5 ImageNet predictions, confidence descending.""" FunctionArguments = _ClassifyBlobArgs @@ -117,15 +201,20 @@ def cardinality(cls, params: BindParams[_ClassifyBlobArgs]) -> TableCardinality: return TableCardinality(estimate=_DEFAULT_TOP_K, max=_DEFAULT_TOP_K) @classmethod - def process(cls, params: ProcessParams[_ClassifyBlobArgs], state: None, out: OutputCollector) -> None: - """Classify the image BLOB and emit the top-5 predictions.""" - preds = model.classify_image(params.args.image, top_k=_DEFAULT_TOP_K) - _emit_classify(preds, out, params.output_schema) + def initial_state(cls, params: ProcessParams[_ClassifyBlobArgs]) -> ScanState: + """Fresh scan cursor for this image's predictions.""" + return ScanState() + + @classmethod + def process(cls, params: ProcessParams[_ClassifyBlobArgs], state: ScanState, out: OutputCollector) -> None: + """Classify the image BLOB and stream the top-5 predictions via the cursor.""" + preds = None if state.started else model.classify_image(params.args.image, top_k=_DEFAULT_TOP_K) + _process_classify(state, out, params.output_schema, preds) @init_single_worker @bind_fixed_schema -class ClassifyTopKFunction(TableFunctionGenerator[_ClassifyBlobTopKArgs]): +class ClassifyTopKFunction(TableFunctionGenerator[_ClassifyBlobTopKArgs, ScanState]): """``classify(image, top_k)`` -- top-k ImageNet predictions, confidence desc.""" FunctionArguments = _ClassifyBlobTopKArgs @@ -151,10 +240,15 @@ def cardinality(cls, params: BindParams[_ClassifyBlobTopKArgs]) -> TableCardinal return TableCardinality(estimate=k, max=k) @classmethod - def process(cls, params: ProcessParams[_ClassifyBlobTopKArgs], state: None, out: OutputCollector) -> None: - """Classify the image BLOB and emit the top-k predictions.""" - preds = model.classify_image(params.args.image, top_k=params.args.top_k) - _emit_classify(preds, out, params.output_schema) + def initial_state(cls, params: ProcessParams[_ClassifyBlobTopKArgs]) -> ScanState: + """Fresh scan cursor for this image's predictions.""" + return ScanState() + + @classmethod + def process(cls, params: ProcessParams[_ClassifyBlobTopKArgs], state: ScanState, out: OutputCollector) -> None: + """Classify the image BLOB and stream the top-k predictions via the cursor.""" + preds = None if state.started else model.classify_image(params.args.image, top_k=params.args.top_k) + _process_classify(state, out, params.output_schema, preds) # --------------------------------------------------------------------------- @@ -175,7 +269,7 @@ class _ClassifyPathTopKArgs: @init_single_worker @bind_fixed_schema -class ClassifyPathFunction(TableFunctionGenerator[_ClassifyPathArgs]): +class ClassifyPathFunction(TableFunctionGenerator[_ClassifyPathArgs, ScanState]): """``classify(path)`` -- top-5 predictions for an image read off disk.""" FunctionArguments = _ClassifyPathArgs @@ -200,15 +294,20 @@ def cardinality(cls, params: BindParams[_ClassifyPathArgs]) -> TableCardinality: return TableCardinality(estimate=_DEFAULT_TOP_K, max=_DEFAULT_TOP_K) @classmethod - def process(cls, params: ProcessParams[_ClassifyPathArgs], state: None, out: OutputCollector) -> None: - """Read the image off disk, classify it, and emit the top-5 predictions.""" - preds = model.classify_image(_read_path(params.args.path), top_k=_DEFAULT_TOP_K) - _emit_classify(preds, out, params.output_schema) + def initial_state(cls, params: ProcessParams[_ClassifyPathArgs]) -> ScanState: + """Fresh scan cursor for this image's predictions.""" + return ScanState() + + @classmethod + def process(cls, params: ProcessParams[_ClassifyPathArgs], state: ScanState, out: OutputCollector) -> None: + """Read the image off disk, classify it, and stream the top-5 predictions via the cursor.""" + preds = None if state.started else model.classify_image(_read_path(params.args.path), top_k=_DEFAULT_TOP_K) + _process_classify(state, out, params.output_schema, preds) @init_single_worker @bind_fixed_schema -class ClassifyPathTopKFunction(TableFunctionGenerator[_ClassifyPathTopKArgs]): +class ClassifyPathTopKFunction(TableFunctionGenerator[_ClassifyPathTopKArgs, ScanState]): """``classify(path, top_k)`` -- top-k predictions for an image read off disk.""" FunctionArguments = _ClassifyPathTopKArgs @@ -234,10 +333,15 @@ def cardinality(cls, params: BindParams[_ClassifyPathTopKArgs]) -> TableCardinal return TableCardinality(estimate=k, max=k) @classmethod - def process(cls, params: ProcessParams[_ClassifyPathTopKArgs], state: None, out: OutputCollector) -> None: - """Read the image off disk, classify it, and emit the top-k predictions.""" - preds = model.classify_image(_read_path(params.args.path), top_k=params.args.top_k) - _emit_classify(preds, out, params.output_schema) + def initial_state(cls, params: ProcessParams[_ClassifyPathTopKArgs]) -> ScanState: + """Fresh scan cursor for this image's predictions.""" + return ScanState() + + @classmethod + def process(cls, params: ProcessParams[_ClassifyPathTopKArgs], state: ScanState, out: OutputCollector) -> None: + """Read the image off disk, classify it, and stream the top-k predictions via the cursor.""" + preds = None if state.started else model.classify_image(_read_path(params.args.path), top_k=params.args.top_k) + _process_classify(state, out, params.output_schema, preds) # --------------------------------------------------------------------------- @@ -252,7 +356,7 @@ class _NoArgs: @init_single_worker @bind_fixed_schema -class ImageClassesFunction(TableFunctionGenerator[_NoArgs]): +class ImageClassesFunction(TableFunctionGenerator[_NoArgs, ScanState]): """``image_classes()`` -- every ``(idx, label)`` the classifier can predict.""" FunctionArguments = _NoArgs @@ -277,16 +381,29 @@ def cardinality(cls, params: BindParams[_NoArgs]) -> TableCardinality: return TableCardinality(estimate=model.NUM_CLASSES, max=model.NUM_CLASSES) @classmethod - def process(cls, params: ProcessParams[_NoArgs], state: None, out: OutputCollector) -> None: - """Emit every ``(idx, label)`` the classifier can predict.""" - rows = model.class_table() - out.emit( - pa.RecordBatch.from_pydict( + def initial_state(cls, params: ProcessParams[_NoArgs]) -> ScanState: + """Fresh scan cursor for the label-set enumeration.""" + return ScanState() + + @classmethod + def process(cls, params: ProcessParams[_NoArgs], state: ScanState, out: OutputCollector) -> None: + """Stream every ``(idx, label)`` the classifier can predict, ``ROWS_PER_TICK`` at a time. + + The label set is ~1000 rows -- well over one producer batch -- so the cursor + is *required*: on the first tick we materialize the full table into + ``state.rows_ipc``, then emit bounded slices so the offset survives each HTTP + continuation round-trip. + """ + if not state.started: + rows = model.class_table() + batch = pa.RecordBatch.from_pydict( {"idx": [r[0] for r in rows], "label": [r[1] for r in rows]}, schema=params.output_schema, ) - ) - out.finish() + state.rows_ipc = result_to_ipc(batch) + state.started = True + state.offset = 0 + _emit_cursor(state, out, params.output_schema) TABLE_FUNCTIONS: list[type] = [