Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,46 @@ only), and overloads resolve by **arity + type**:
LIST/STRUCT-style fixed schemas use `bind_fixed_schema` + `FIXED_SCHEMA` and the
`field()` helper for column comments.

## Scan state is an explicit cursor — the HTTP-continuation contract

All five table functions are `TableFunctionGenerator[Args, ScanState]` (in
`tables.py`), **not** `state: None`. `ScanState(ArrowSerializableDataclass)` carries
`started: bool`, `offset: int`, `rows_ipc: bytes` (the full result as Arrow IPC
bytes via `result_to_ipc`/`ipc_to_table`), and a module constant
`ROWS_PER_TICK = 64`.

**Why a cursor and not emit-all.** Over the stateless **http** transport the
framework round-trips the producer's per-scan state through a *continuation token*:
after each `process()` tick it wire-serializes the state
(`ArrowSerializableDataclass.serialize_to_bytes()`), the client returns it, and the
worker resumes by deserializing it — emitting **at most one producer batch per
response**. A position-less `state: None` function that did
`out.emit(...all rows...); out.finish()` restarts from row 0 on every http resume
and **loops forever** once the result exceeds one batch. subprocess/unix keep the
live state in-process so they hide the bug; only http (and the
serialize-between-ticks unit test) expose it. `image_classes()` = 1000 rows is the
smoking gun — if the http leg times out, the cursor is wrong.

**The contract.** Each `process()` tick: on the first tick (`not state.started`)
read the source and materialize the full output batch into `state.rows_ipc` (set
`started`); then emit one bounded `ROWS_PER_TICK` slice from `state.offset`,
**advance `state.offset` before `out.emit()`** (so a tick suspended at the limit-1
boundary serializes the already-advanced offset), and `out.finish()` once
`offset >= total`. The NULL/empty-image early-out is preserved: a `None`/empty image
materializes a zero-row batch and finishes immediately (`0 >= 0`). The shared
`_emit_cursor` helper does the slicing; `_classify_batch` builds the full classify
batch. Rows/schema are byte-identical to the old emit-all behaviour.

The regression test lives in `tests/test_tables.py`
(`TestScanStateRoundTrip`/`TestCursorSurvivesContinuation`): the test harness
`invoke_table_function(..., serialize_state=True)` models http by capping each tick
to one producer batch and wire-serializing `ScanState` between every tick, with a
1000-tick guard. On the old emit-all/`None` code it overruns the guard (re-emits
row 0); on the cursor code it terminates with each row emitted once. The
model-gated `image_classes` round-trips the 1000-row label set; an offline classify
variant monkeypatches `model.classify_image` to return ~200 synthetic preds so the
regression is covered without the ONNX weights.

## Sharp edges (learned the hard way)

1. **`haybarn-unittest` skips `require vgi`** — under haybarn the extension is
Expand Down
18 changes: 16 additions & 2 deletions test/sql/classify.test
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,33 @@ statement ok
ATTACH 'vision' AS vision (TYPE vgi, LOCATION '${VGI_VISION_WORKER}');

# ---- image_classes(): the model's full ImageNet label set ----
# 1000 rows is well over one producer batch, so over the http transport this
# result is paged across continuation tokens. If the scan cursor (ScanState.offset)
# did not survive each continuation boundary, the http leg would loop forever
# (re-emitting from row 0) or truncate -- so this is a live regression test for the
# HTTP-continuation cursor, exercised on every transport.

query I
SELECT count(*) FROM vision.image_classes();
----
1000

# Indices are 0..999, dense.
# Indices are 0..999, dense -- every row emitted exactly once, no dupes, no gaps.
# (A broken cursor that re-paged from 0 would inflate the count / break DISTINCT.)
query I
SELECT count(DISTINCT idx) = 1000 AND min(idx) = 0 AND max(idx) = 999
SELECT count(DISTINCT idx) = 1000 AND min(idx) = 0 AND max(idx) = 999 AND count(*) = 1000
FROM vision.image_classes();
----
true

# Ordered head well past the first producer batch: idx=500 is reachable only if
# the cursor paged correctly across the http continuation boundary.
query II
SELECT idx, (label IS NOT NULL AND length(label) > 0) FROM vision.image_classes()
ORDER BY idx LIMIT 1 OFFSET 500;
----
500 true

# ---- top_label(image): #1 prediction for the committed sample.png ----

# The prediction is deterministic for this exact fixture + model (MobileNetV2-12).
Expand Down
60 changes: 55 additions & 5 deletions tests/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from __future__ import annotations

import contextlib
import io
from typing import Any

Expand Down Expand Up @@ -71,12 +72,24 @@ def test_storage() -> FunctionStorage:


class MockOutputCollector:
"""Captures emitted batches for assertions."""
"""Captures emitted batches for assertions.

def __init__(self, output_schema: pa.Schema) -> None:
``batch_limit`` models the HTTP transport's per-response producer batch limit:
once that many batches have been emitted in a single ``process()`` tick, further
emits raise ``_BatchLimitReached`` so the driver can suspend, wire-serialize the
state, and resume (exactly as the http server does across a continuation token).
``None`` (the default) means unbounded -- the in-process behaviour.
"""

def __init__(self, output_schema: pa.Schema, batch_limit: int | None = None) -> None:
self.output_schema = output_schema
self.batches: list[pa.RecordBatch] = []
self._finished = False
self._batch_limit = batch_limit
self._emitted_this_tick = 0

def begin_tick(self) -> None:
self._emitted_this_tick = 0

def emit(
self,
Expand All @@ -85,6 +98,9 @@ def emit(
metadata: dict[str, str] | None = None,
) -> None:
self.batches.append(batch)
self._emitted_this_tick += 1
if self._batch_limit is not None and self._emitted_this_tick >= self._batch_limit:
raise _BatchLimitReached

def finish(self) -> None:
self._finished = True
Expand All @@ -97,13 +113,30 @@ def emit_client_log_message(self, msg: Any) -> None:
pass


class _BatchLimitReached(Exception):
"""Internal: the per-tick producer batch limit was hit; suspend + resume."""


def invoke_table_function(
func_cls: type,
*,
named: dict[str, pa.Scalar] | None = None,
positional: tuple[pa.Scalar, ...] = (),
serialize_state: bool = False,
max_ticks: int = 1000,
) -> pa.Table:
"""Run a (source) table function through bind -> init -> process -> table."""
"""Run a (source) table function through bind -> init -> process -> table.

When ``serialize_state=True`` the driver faithfully models the stateless HTTP
transport: each ``process()`` tick may emit at most ONE producer batch (the
limit-1 continuation boundary), after which the per-scan state is wire-
serialized and deserialized
(``type(state).deserialize_from_bytes(state.serialize_to_bytes())``) before the
next tick resumes. A position-less state that re-emits row 0 on every resume
loops forever; the ``max_ticks`` guard turns that into a loud failure instead
of an infinite hang. With a cursor state the offset survives each round-trip
and the scan terminates after ~ceil(rows / ROWS_PER_TICK) ticks.
"""
args = Arguments(positional=positional, named=named or {})

bind_req = BindRequest(
Expand All @@ -128,9 +161,26 @@ def invoke_table_function(
)

state = func_cls.initial_state(params)
out = MockOutputCollector(bind_resp.output_schema)
# Over http, each response carries at most one producer batch; model that with
# batch_limit=1 so the cursor must be observable across the boundary.
out = MockOutputCollector(bind_resp.output_schema, batch_limit=1 if serialize_state else None)

ticks = 0
while not out.finished:
func_cls.process(params, state, out)
if serialize_state and state is not None:
state = type(state).deserialize_from_bytes(state.serialize_to_bytes())
out.begin_tick()
# _BatchLimitReached suspends the tick mid-flight exactly as the http server
# does once it has filled a response with one batch; the loop then resumes
# with the (serialized) state.
with contextlib.suppress(_BatchLimitReached):
func_cls.process(params, state, out)
ticks += 1
if ticks > max_ticks:
raise AssertionError(
f"{func_cls.__name__} did not finish after {max_ticks} ticks "
f"(serialize_state={serialize_state}): the scan cursor is not "
f"surviving the continuation boundary (likely re-emitting from row 0)."
)

return pa.Table.from_batches(out.batches, schema=bind_resp.output_schema)
91 changes: 91 additions & 0 deletions tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@

from __future__ import annotations

import math

import pyarrow as pa
import pytest

from vgi_vision import tables
from vgi_vision.tables import (
ClassifyFunction,
ClassifyTopKFunction,
Expand Down Expand Up @@ -67,3 +71,90 @@ def test_garbage_image_no_rows(self) -> None:
def test_empty_image_no_rows(self) -> None:
table = invoke_table_function(ClassifyFunction, positional=(pa.scalar(b"", type=pa.binary()),))
assert table.num_rows == 0


# ---------------------------------------------------------------------------
# HTTP-continuation regression: the scan cursor must survive being wire-
# serialized/deserialized between every process() tick (as the stateless HTTP
# transport does across its limit-1 continuation boundary). On the OLD
# emit-all-then-finish code with `state: None`, a serialized resume re-emits from
# row 0 forever -- these tests would overrun the tick guard (hang -> AssertionError)
# or see duplicated rows. On the cursor code the offset survives and the scan
# terminates with each row emitted exactly once.
# ---------------------------------------------------------------------------


class TestScanStateRoundTrip:
"""image_classes() (~1000 rows >> ROWS_PER_TICK) drives the real cursor path."""

@needs_model
def test_serialize_between_ticks_matches_single_shot(self) -> None:
plain = invoke_table_function(ImageClassesFunction)
rt = invoke_table_function(ImageClassesFunction, serialize_state=True)

assert plain.num_rows == 1000
# Identical rows AND order, despite a wire round-trip on every tick.
assert rt.to_pylist() == plain.to_pylist()
# No duplicates: idx is the dense 0..999 range exactly once.
idx = rt.column("idx").to_pylist()
assert idx == list(range(1000))
assert len(set(idx)) == 1000

@needs_model
def test_terminates_in_expected_tick_count(self) -> None:
# ceil(1000 / ROWS_PER_TICK) emitting ticks + 1 finishing tick at most.
expected = math.ceil(1000 / tables.ROWS_PER_TICK)
# A tight guard: if the cursor regressed, this overruns and raises.
table = invoke_table_function(ImageClassesFunction, serialize_state=True, max_ticks=expected + 2)
assert table.num_rows == 1000

@needs_model
def test_small_chunk(self, monkeypatch: pytest.MonkeyPatch) -> None:
# Force many tiny slices so the cursor must advance through ~500 ticks,
# each across a serialize boundary -- a stress test of offset survival.
monkeypatch.setattr(tables, "ROWS_PER_TICK", 2)
rt = invoke_table_function(ImageClassesFunction, serialize_state=True)
assert rt.column("idx").to_pylist() == list(range(1000))


class TestCursorSurvivesContinuation:
"""Offline classify round-trip: monkeypatch classify_image to span many ticks.

Runs without the ONNX weights so the regression is covered even on a bare
checkout (the image_classes tests above are model-gated).
"""

def test_many_synthetic_preds_round_trip(self, monkeypatch: pytest.MonkeyPatch) -> None:
n = 200 # >> ROWS_PER_TICK (64): forces multiple continuation slices.
synthetic = [(f"label-{i:04d}", 1.0 - i / n) for i in range(n)]
monkeypatch.setattr(tables.model, "classify_image", lambda *a, **k: synthetic)

img = pa.scalar(png_bytes((1, 2, 3)), type=pa.binary())
plain = invoke_table_function(ClassifyFunction, positional=(img,))
rt = invoke_table_function(ClassifyFunction, positional=(img,), serialize_state=True)

assert plain.num_rows == n
assert rt.num_rows == n
# Byte-identical rows AND order across the wire round-trip.
assert rt.to_pylist() == plain.to_pylist()
labels = rt.column("label").to_pylist()
assert labels == [p[0] for p in synthetic]
assert len(set(labels)) == n # no dupes

def test_small_chunk(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(tables, "ROWS_PER_TICK", 2)
synthetic = [(f"c{i}", 0.5) for i in range(50)]
monkeypatch.setattr(tables.model, "classify_image", lambda *a, **k: synthetic)
img = pa.scalar(png_bytes((9, 9, 9)), type=pa.binary())
rt = invoke_table_function(
ClassifyTopKFunction, positional=(img, pa.scalar(50, type=pa.int64())), serialize_state=True
)
assert rt.column("label").to_pylist() == [p[0] for p in synthetic]

def test_no_preds_still_terminates(self, monkeypatch: pytest.MonkeyPatch) -> None:
# The NULL/empty-image early-out path: zero rows, finishes immediately,
# even with serialize-between-ticks (0 >= 0).
monkeypatch.setattr(tables.model, "classify_image", lambda *a, **k: None)
img = pa.scalar(png_bytes((0, 0, 0)), type=pa.binary())
rt = invoke_table_function(ClassifyFunction, positional=(img,), serialize_state=True, max_ticks=5)
assert rt.num_rows == 0
Loading
Loading