diff --git a/CLAUDE.md b/CLAUDE.md index d1aea2c..44093d2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -50,12 +50,40 @@ output, so every function is a `TableBufferingFunction` (Sink+Source): - `process(batch)` — sink each input batch to execution-scoped `BoundStorage`. - `combine(state_ids)` — collapse to a single finalize key (one bucket). -- `finalize(...)` — reassemble the full table (`buffered_frame()` → pandas), - run the estimator once, emit one result batch, then `out.finish()`. - -`SinkBuffer` in `buffering.py` implements `process`/`combine`/`buffered_frame`; -each function only writes `on_bind` (its output schema) + `finalize`. A -`DrainState(done: bool)` cursor makes finalize emit exactly once. +- `finalize(...)` — on the first tick reassemble the full table + (`buffered_frame()` → pandas) and run the estimator once into the cursor; each + tick then emits a bounded slice and `out.finish()` once drained. + +`SinkBuffer` in `buffering.py` implements `process`/`combine`/`buffered_frame` +plus `drain_result(...)` (the cursor loop); each function only writes `on_bind` +(its output schema) + a one-line `finalize` that hands `drain_result` the +estimator call. + +### Why finalize streams an OFFSET cursor (HTTP continuation) + +Over the **stateless http transport** the framework round-trips a producer's +per-finalize-stream state through a continuation token: after each `finalize()` +tick it wire-serializes the state (`ArrowSerializableDataclass.serialize_to_bytes()`), +the client returns it, and the worker resumes by deserializing it — emitting at +most one (the producer batch limit) data batch per response. subprocess/unix keep +the live state in-process so they hide the bug; only http (and the +`run_buffering(..., serialize_state=True)` unit harness) expose it. + +A position-less `DrainState{done: bool}` that emits ALL rows in one `out.emit` +then sets `done` restarts from row 0 on every http resume and **loops forever** +once the output exceeds one producer batch — which `propensity_scores` (one row +per input subject, unbounded) routinely does. So `DrainState` carries an explicit +**offset cursor**: the already-computed result batch as IPC bytes (`result_ipc`, +fully serializable), a `started` flag, and an integer `offset`. The first tick +computes + packs the result; every tick emits at most `ROWS_PER_TICK` (64) rows +from `offset`, advances `offset`, and finishes when `offset >= total`. Because +`offset` survives the wire round-trip, a resumed tick emits the NEXT slice — never +re-runs the estimator, never restarts from row 0. `ate` (3 rows) and `att` (1 row) +are bounded but use the identical cursor for uniformity. Results are byte-identical +to the old emit-all path. The regression test is +`TestCursorSurvivesContinuation` in `test_tables.py` (re-serializes finalize state +between every tick, 10 000-tick overrun guard) plus the big-cohort paging asserts +in `causal.test` (which only terminate over http if the cursor works). ## Estimators (the math) diff --git a/test/sql/causal.test b/test/sql/causal.test index 3e39d16..668b2ed 100644 --- a/test/sql/causal.test +++ b/test/sql/causal.test @@ -110,5 +110,59 @@ SELECT * FROM causal.ate((SELECT id AS t, y, x FROM cohort), treatment := 't', o ---- binary 0/1 +# --------------------------------------------------------------------------- +# HTTP continuation paging: a cohort large enough that propensity_scores emits +# MORE than one producer batch worth of rows. propensity_scores returns one row +# per input subject, so a 200-row cohort yields 200 output rows -- well past +# ROWS_PER_TICK (64). Over the stateless http transport the finalize cursor must +# page across the limit-1 continuation boundary; a position-less cursor would +# loop forever (the test would hang/time out) instead of returning these counts. +statement ok +CREATE TEMP TABLE big_cohort AS +SELECT + g AS id, + (g - 100)::DOUBLE / 50.0 AS x, + CASE WHEN ((g * 7 + 3) % 10) < (5 + 3 * sign((g - 100)::DOUBLE)) THEN 1 ELSE 0 END AS t, + 5.0 * (CASE WHEN ((g * 7 + 3) % 10) < (5 + 3 * sign((g - 100)::DOUBLE)) THEN 1 ELSE 0 END) + + 2.0 * ((g - 100)::DOUBLE / 50.0) AS y +FROM generate_series(0, 199) AS s(g); + +# Every input subject gets exactly one propensity row -> 200 rows. If the cursor +# fails to page over http this query never terminates. +query I +SELECT count(*) +FROM causal.propensity_scores((SELECT id, t, x FROM big_cohort), treatment := 't', id := 'id'); +---- +200 + +# All 200 propensity scores are valid probabilities in (0,1) -- so every paged +# slice carried real rows, not re-emitted duplicates of row 0. +query I +SELECT count(*) +FROM causal.propensity_scores((SELECT id, t, x FROM big_cohort), treatment := 't', id := 'id') +WHERE propensity > 0.0 AND propensity < 1.0; +---- +200 + +# Each id appears exactly once across all paged slices (no dupes from a restart). +query I +SELECT count(DISTINCT id) +FROM causal.propensity_scores((SELECT id, t, x FROM big_cohort), treatment := 't', id := 'id'); +---- +200 + +# Ordered head of the paged result is stable and correct (ids 0..4 in order). +query I +SELECT id +FROM causal.propensity_scores((SELECT id, t, x FROM big_cohort), treatment := 't', id := 'id') +ORDER BY id +LIMIT 5; +---- +0 +1 +2 +3 +4 + statement ok DETACH causal; diff --git a/tests/harness.py b/tests/harness.py index 583b4eb..ceabd5b 100644 --- a/tests/harness.py +++ b/tests/harness.py @@ -18,14 +18,37 @@ from vgi.table_buffering_function import TableBufferingParams +class _BatchTooLarge(Exception): + """A finalize tick emitted more rows than one continuation response can carry.""" + + class _Collector: - """Captures emitted batches from a finalize stream.""" + """Captures emitted batches from a finalize stream. + + When ``max_rows_per_tick`` is set (the ``serialize_state`` / HTTP-continuation + model), an ``emit`` whose batch exceeds that cap raises :class:`_BatchTooLarge` + *before* recording the batch -- mirroring the stateless transport, which rejects + an over-cap response and resumes the tick from the pre-tick continuation token. + A position-less finalize re-emits the whole result every tick and never fits; + the offset cursor emits a bounded slice that does. + """ - def __init__(self) -> None: + def __init__(self, max_rows_per_tick: int | None = None) -> None: self.batches: list[pa.RecordBatch] = [] self.finished = False + self._max_rows_per_tick = max_rows_per_tick + self._tick_rows = 0 + + def begin_tick(self) -> None: + self._tick_rows = 0 def emit(self, batch: pa.RecordBatch, *_a: Any, **_kw: Any) -> None: + self._tick_rows += batch.num_rows + if self._max_rows_per_tick is not None and self._tick_rows > self._max_rows_per_tick: + raise _BatchTooLarge( + f"finalize emitted {self._tick_rows} rows in one tick, exceeding the " + f"{self._max_rows_per_tick}-row continuation cap" + ) self.batches.append(batch) def finish(self) -> None: @@ -40,6 +63,7 @@ def run_buffering( table: pa.Table, *, named: dict[str, str] | None = None, + serialize_state: bool = False, ) -> pa.Table: """Drive a causal buffering function over a whole input ``table``. @@ -47,9 +71,18 @@ def run_buffering( func_cls: The ``TableBufferingFunction`` subclass to run. table: The input relation (the ``(SELECT ...)`` data) as an Arrow table. named: Named string column-role args (e.g. ``{"treatment": "t"}``). + serialize_state: When ``True``, wire-serialize the finalize state between + every ``finalize()`` tick (``deserialize_from_bytes(serialize_to_bytes())``), + emulating the stateless HTTP continuation token. A position-less cursor + (the old ``DrainState{done}``) re-emits row 0 forever under this mode and + trips the overrun guard; the offset cursor pages and terminates. Returns: The emitted result as a single Arrow table (the function's output). + + Raises: + AssertionError: If a finalize stream runs past the ~10000-tick guard + (i.e. it never terminates -- the latent HTTP-continuation bug). """ input_schema = table.schema args = Arguments( @@ -95,12 +128,43 @@ def make_params() -> TableBufferingParams: # Combine phase. finalize_ids = func_cls.combine(state_ids, make_params()) - # Source phase: drain each finalize stream. - out = _Collector() + # Source phase: drain each finalize stream. Under ``serialize_state`` we model + # the stateless HTTP continuation: re-serialize the finalize state between every + # tick, and cap each tick at one response worth of rows (``ROWS_PER_TICK``). An + # over-cap emit is rejected and the tick resumes from the PRE-tick token (the + # mutation is discarded) -- exactly what makes a position-less finalize loop + # forever and a position cursor page. + from vgi_causal.buffering import ROWS_PER_TICK + + cap = ROWS_PER_TICK if serialize_state else None + out = _Collector(max_rows_per_tick=cap) + max_ticks = 10_000 for fid in finalize_ids: params = make_params() state = func_cls.initial_finalize_state(fid, params) + ticks = 0 while not out.finished: - func_cls.finalize(params, fid, state, out) + ticks += 1 + assert ticks < max_ticks, ( + f"{func_cls.Meta.name}.finalize did not terminate within {max_ticks} ticks " + f"(serialize_state={serialize_state}): the finalize cursor never advances " + "across the continuation boundary -- the HTTP-continuation bug." + ) + if not serialize_state: + func_cls.finalize(params, fid, state, out) + continue + pre_tick_blob = state.serialize_to_bytes() + saved_batches = len(out.batches) + out.begin_tick() + try: + func_cls.finalize(params, fid, state, out) + except _BatchTooLarge: + # Continuation: drop the over-cap batches and the un-committed state + # advance; the next attempt resumes from the pre-tick token. + del out.batches[saved_batches:] + state = type(state).deserialize_from_bytes(pre_tick_blob) + continue + state = type(state).deserialize_from_bytes(state.serialize_to_bytes()) + out.finished = False # reset for the next finalize stream return pa.Table.from_batches(out.batches, schema=bind_resp.output_schema) diff --git a/tests/test_tables.py b/tests/test_tables.py index 85ce18f..9d70750 100644 --- a/tests/test_tables.py +++ b/tests/test_tables.py @@ -60,3 +60,63 @@ def test_non_binary_treatment_raises() -> None: tbl = pa.table({"t": [0, 1, 2], "y": [1.0, 2.0, 3.0], "x1": [0.1, 0.2, 0.3]}) with pytest.raises(Exception, match="binary 0/1"): run_buffering(Ate, tbl, named={"treatment": "t", "outcome": "y"}) + + +class TestCursorSurvivesContinuation: + """The finalize cursor must survive a wire round-trip between every tick. + + Over the stateless HTTP transport the framework serializes the finalize state + after each tick, returns at most one data batch, then resumes by deserializing + the token. ``run_buffering(..., serialize_state=True)`` emulates that. The old + position-less ``DrainState{done}`` re-emitted row 0 forever (overrun guard); + the offset cursor pages through the result and terminates. + + ``propensity_scores`` emits one row per input subject -- genuinely unbounded -- + so an 800-row cohort produces 800 output rows, far exceeding ``ROWS_PER_TICK`` + (64); the cursor must page across many continuation boundaries. + """ + + def test_propensity_pages_identically_under_serialization(self) -> None: + from vgi_causal import buffering + + df = make_confounded(n=800) + tbl = _arrow(df[["id", "t", "x1", "x2"]]) + named = {"treatment": "t", "id": "id"} + + # Sanity: the result genuinely spans several ROWS_PER_TICK ticks. + assert len(df) > buffering.ROWS_PER_TICK + + baseline = run_buffering(PropensityScores, tbl, named=named).to_pydict() + paged = run_buffering(PropensityScores, tbl, named=named, serialize_state=True).to_pydict() + + # (1) Same number of rows -- no truncation, no infinite re-emit. + assert len(paged["id"]) == len(baseline["id"]) == len(df) + # (2) Byte-identical rows in identical order (sort by id to be safe). + b_order = np.argsort(np.asarray(baseline["id"])) + p_order = np.argsort(np.asarray(paged["id"])) + for col in ("id", "propensity", "treatment"): + np.testing.assert_array_equal(np.asarray(baseline[col])[b_order], np.asarray(paged[col])[p_order]) + # (3) Each id emitted exactly once (no dupes). + assert len(set(paged["id"])) == len(paged["id"]) == len(df) + + def test_small_rows_per_tick_pages_bounded_result(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Force the bounded estimators (ate=3 rows, att=1 row) to also page, so the + # cursor is exercised across the continuation boundary for every function. + from vgi_causal import buffering + + monkeypatch.setattr(buffering, "ROWS_PER_TICK", 2) + df = make_confounded(n=400) + + ate_tbl = _arrow(df[["t", "y", "x1", "x2"]]) + baseline = run_buffering(Ate, ate_tbl, named={"treatment": "t", "outcome": "y"}).to_pydict() + paged = run_buffering(Ate, ate_tbl, named={"treatment": "t", "outcome": "y"}, serialize_state=True).to_pydict() + assert paged == baseline + assert set(paged["method"]) == {"ipw", "regression_adjustment", "aipw"} + + att_tbl = _arrow(df[["t", "y", "x1", "x2"]]) + att_base = run_buffering(Att, att_tbl, named={"treatment": "t", "outcome": "y"}).to_pydict() + att_paged = run_buffering( + Att, att_tbl, named={"treatment": "t", "outcome": "y"}, serialize_state=True + ).to_pydict() + assert att_paged == att_base + assert len(att_paged["estimate"]) == 1 diff --git a/vgi_causal/buffering.py b/vgi_causal/buffering.py index cea0808..3ddcc87 100644 --- a/vgi_causal/buffering.py +++ b/vgi_causal/buffering.py @@ -14,6 +14,7 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass from typing import Any, cast @@ -21,15 +22,51 @@ import pyarrow as pa from vgi.table_buffering_function import TableBufferingFunction, TableBufferingParams from vgi_rpc import ArrowSerializableDataclass +from vgi_rpc.rpc import OutputCollector _DATA_KEY = b"input_batches" +# Rows emitted per finalize tick. Bounded so the cursor (``offset``) is observable +# across the HTTP limit-1 continuation boundary: the stateless HTTP transport +# wire-serializes the finalize state after every tick, returns at most one data +# batch per response, then resumes by deserializing the token. Correctness no +# longer depends on the whole result fitting in one producer batch. +ROWS_PER_TICK = 64 + @dataclass(kw_only=True) class DrainState(ArrowSerializableDataclass): - """Per-finalize-stream cursor: emit the single result batch once, then finish.""" + """Externalized finalize cursor: result batch (IPC bytes) plus next-row offset. + + Both fields wire-serialize through the HTTP continuation token, so a resumed + finalize tick sees the advanced ``offset`` and emits the next bounded slice + (or finishes) -- it never re-runs the estimator or restarts from row 0. This + is what lets ``propensity_scores`` (one row per input subject, unbounded) + page correctly over the stateless HTTP transport. + + ``result_ipc`` is empty until the first tick computes the estimate; + ``started`` distinguishes "not yet computed" from "computed an empty result". + """ + + started: bool = False + offset: int = 0 + result_ipc: bytes = b"" + + +def result_to_ipc(batch: pa.RecordBatch) -> bytes: + """Serialize the full computed result batch to Arrow IPC bytes for the cursor.""" + sink = pa.BufferOutputStream() + # pyarrow.ipc ships no type stubs, so mypy sees these as untyped calls. + with pa.ipc.new_stream(sink, batch.schema) as writer: # type: ignore[no-untyped-call] + writer.write_batch(batch) + return cast(bytes, sink.getvalue().to_pybytes()) + - done: bool = False +def ipc_to_table(value: bytes) -> pa.Table: + """Inverse of :func:`result_to_ipc`: read the cursor's result back as a table.""" + # pyarrow.ipc ships no type stubs, so mypy sees this as an untyped call. + reader = pa.ipc.open_stream(pa.BufferReader(value)) # type: ignore[no-untyped-call] + return cast(pa.Table, reader.read_all()) def serialize_batch(batch: pa.RecordBatch) -> bytes: @@ -105,3 +142,45 @@ def buffered_frame(cls, params: TableBufferingParams[TArgs]) -> pd.DataFrame: if not batches: return cast(pd.DataFrame, pa.Table.from_batches([], schema=input_schema).to_pandas()) return cast(pd.DataFrame, pa.Table.from_batches(batches, schema=input_schema).to_pandas()) + + @classmethod + def drain_result( + cls, + params: TableBufferingParams[TArgs], + state: DrainState, + out: OutputCollector, + compute: Callable[[pd.DataFrame], dict[str, list[Any]]], + ) -> None: + """Compute the result once into the cursor, then stream a bounded slice. + + The first finalize tick runs ``compute`` over the buffered frame, packs the + result batch into ``state.result_ipc`` (Arrow IPC bytes) and flips + ``state.started``. Every tick then emits at most ``ROWS_PER_TICK`` rows from + ``state.offset``, advances ``state.offset``, and calls ``out.finish()`` once + the result is drained. Because ``state`` round-trips through the HTTP + continuation token, a resumed tick sees the advanced offset and never + re-runs ``compute`` or restarts from row 0. + + Args: + params: The buffering params (args, output schema, storage). + state: The finalize cursor (result bytes + offset). + out: The collector to emit the result slice into. + compute: Maps the buffered frame to the estimator's ``dict[str, list]``. + """ + if not state.started: + df = cls.buffered_frame(params) + result = compute(df) + batch = pa.RecordBatch.from_pydict(result, schema=params.output_schema) + state.result_ipc = result_to_ipc(batch) + state.started = True + state.offset = 0 + + table = ipc_to_table(state.result_ipc) + total = table.num_rows + if state.offset >= total: + out.finish() + return + end = min(state.offset + ROWS_PER_TICK, total) + chunk = table.slice(state.offset, end - state.offset).combine_chunks() + out.emit(chunk.to_batches()[0]) + state.offset = end diff --git a/vgi_causal/tables.py b/vgi_causal/tables.py index dabb0ad..51a4f6a 100644 --- a/vgi_causal/tables.py +++ b/vgi_causal/tables.py @@ -146,7 +146,7 @@ def initial_finalize_state(cls, finalize_state_id: bytes, params: TableBuffering params: The buffering params for this execution. Returns: - A fresh drain cursor that emits the result batch once. + A fresh finalize cursor (result bytes + offset at 0). """ return DrainState() @@ -163,17 +163,11 @@ def finalize( Args: params: The buffering params (args, output schema, storage). finalize_state_id: The finalize stream identifier. - state: The drain cursor tracking single emission. - out: The collector to emit the result batch into. + state: The finalize cursor (result bytes + offset). + out: The collector to emit the result slice into. """ - if state.done: - out.finish() - return - state.done = True a = params.args - df = cls.buffered_frame(params) - result = causal.ate(df, treatment=a.treatment, outcome=a.outcome) - out.emit(pa.RecordBatch.from_pydict(result, schema=params.output_schema)) + cls.drain_result(params, state, out, lambda df: causal.ate(df, treatment=a.treatment, outcome=a.outcome)) class PropensityScores(SinkBuffer[PropensityArgs, DrainState]): @@ -224,7 +218,7 @@ def initial_finalize_state( params: The buffering params for this execution. Returns: - A fresh drain cursor that emits the result batch once. + A fresh finalize cursor (result bytes + offset at 0). """ return DrainState() @@ -241,17 +235,11 @@ def finalize( Args: params: The buffering params (args, output schema, storage). finalize_state_id: The finalize stream identifier. - state: The drain cursor tracking single emission. - out: The collector to emit the result batch into. + state: The finalize cursor (result bytes + offset). + out: The collector to emit the result slice into. """ - if state.done: - out.finish() - return - state.done = True a = params.args - df = cls.buffered_frame(params) - result = causal.propensity_scores(df, treatment=a.treatment, id=a.id) - out.emit(pa.RecordBatch.from_pydict(result, schema=params.output_schema)) + cls.drain_result(params, state, out, lambda df: causal.propensity_scores(df, treatment=a.treatment, id=a.id)) class Att(SinkBuffer[AttArgs, DrainState]): @@ -297,7 +285,7 @@ def initial_finalize_state(cls, finalize_state_id: bytes, params: TableBuffering params: The buffering params for this execution. Returns: - A fresh drain cursor that emits the result batch once. + A fresh finalize cursor (result bytes + offset at 0). """ return DrainState() @@ -314,17 +302,11 @@ def finalize( Args: params: The buffering params (args, output schema, storage). finalize_state_id: The finalize stream identifier. - state: The drain cursor tracking single emission. - out: The collector to emit the result batch into. + state: The finalize cursor (result bytes + offset). + out: The collector to emit the result slice into. """ - if state.done: - out.finish() - return - state.done = True a = params.args - df = cls.buffered_frame(params) - result = causal.att(df, treatment=a.treatment, outcome=a.outcome) - out.emit(pa.RecordBatch.from_pydict(result, schema=params.output_schema)) + cls.drain_result(params, state, out, lambda df: causal.att(df, treatment=a.treatment, outcome=a.outcome)) TABLE_FUNCTIONS: list[type] = [Ate, PropensityScores, Att]