Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,40 @@ output, so every function is a `TableBufferingFunction` (Sink+Source):

- `process(batch)` — sink each input batch to execution-scoped `BoundStorage`.
- `combine(state_ids)` — collapse to a single finalize key (one bucket).
- `finalize(...)` — reassemble the full table (`buffered_frame()` → pandas),
run the estimator once, emit one result batch, then `out.finish()`.

`SinkBuffer` in `buffering.py` implements `process`/`combine`/`buffered_frame`;
each function only writes `on_bind` (its output schema) + `finalize`. A
`DrainState(done: bool)` cursor makes finalize emit exactly once.
- `finalize(...)` — on the first tick reassemble the full table
(`buffered_frame()` → pandas) and run the estimator once into the cursor; each
tick then emits a bounded slice and `out.finish()` once drained.

`SinkBuffer` in `buffering.py` implements `process`/`combine`/`buffered_frame`
plus `drain_result(...)` (the cursor loop); each function only writes `on_bind`
(its output schema) + a one-line `finalize` that hands `drain_result` the
estimator call.

### Why finalize streams an OFFSET cursor (HTTP continuation)

Over the **stateless http transport** the framework round-trips a producer's
per-finalize-stream state through a continuation token: after each `finalize()`
tick it wire-serializes the state (`ArrowSerializableDataclass.serialize_to_bytes()`),
the client returns it, and the worker resumes by deserializing it — emitting at
most one (the producer batch limit) data batch per response. subprocess/unix keep
the live state in-process so they hide the bug; only http (and the
`run_buffering(..., serialize_state=True)` unit harness) expose it.

A position-less `DrainState{done: bool}` that emits ALL rows in one `out.emit`
then sets `done` restarts from row 0 on every http resume and **loops forever**
once the output exceeds one producer batch — which `propensity_scores` (one row
per input subject, unbounded) routinely does. So `DrainState` carries an explicit
**offset cursor**: the already-computed result batch as IPC bytes (`result_ipc`,
fully serializable), a `started` flag, and an integer `offset`. The first tick
computes + packs the result; every tick emits at most `ROWS_PER_TICK` (64) rows
from `offset`, advances `offset`, and finishes when `offset >= total`. Because
`offset` survives the wire round-trip, a resumed tick emits the NEXT slice — never
re-runs the estimator, never restarts from row 0. `ate` (3 rows) and `att` (1 row)
are bounded but use the identical cursor for uniformity. Results are byte-identical
to the old emit-all path. The regression test is
`TestCursorSurvivesContinuation` in `test_tables.py` (re-serializes finalize state
between every tick, 10 000-tick overrun guard) plus the big-cohort paging asserts
in `causal.test` (which only terminate over http if the cursor works).

## Estimators (the math)

Expand Down
54 changes: 54 additions & 0 deletions test/sql/causal.test
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,59 @@ SELECT * FROM causal.ate((SELECT id AS t, y, x FROM cohort), treatment := 't', o
----
binary 0/1

# ---------------------------------------------------------------------------
# HTTP continuation paging: a cohort large enough that propensity_scores emits
# MORE than one producer batch worth of rows. propensity_scores returns one row
# per input subject, so a 200-row cohort yields 200 output rows -- well past
# ROWS_PER_TICK (64). Over the stateless http transport the finalize cursor must
# page across the limit-1 continuation boundary; a position-less cursor would
# loop forever (the test would hang/time out) instead of returning these counts.
statement ok
CREATE TEMP TABLE big_cohort AS
SELECT
g AS id,
(g - 100)::DOUBLE / 50.0 AS x,
CASE WHEN ((g * 7 + 3) % 10) < (5 + 3 * sign((g - 100)::DOUBLE)) THEN 1 ELSE 0 END AS t,
5.0 * (CASE WHEN ((g * 7 + 3) % 10) < (5 + 3 * sign((g - 100)::DOUBLE)) THEN 1 ELSE 0 END)
+ 2.0 * ((g - 100)::DOUBLE / 50.0) AS y
FROM generate_series(0, 199) AS s(g);

# Every input subject gets exactly one propensity row -> 200 rows. If the cursor
# fails to page over http this query never terminates.
query I
SELECT count(*)
FROM causal.propensity_scores((SELECT id, t, x FROM big_cohort), treatment := 't', id := 'id');
----
200

# All 200 propensity scores are valid probabilities in (0,1) -- so every paged
# slice carried real rows, not re-emitted duplicates of row 0.
query I
SELECT count(*)
FROM causal.propensity_scores((SELECT id, t, x FROM big_cohort), treatment := 't', id := 'id')
WHERE propensity > 0.0 AND propensity < 1.0;
----
200

# Each id appears exactly once across all paged slices (no dupes from a restart).
query I
SELECT count(DISTINCT id)
FROM causal.propensity_scores((SELECT id, t, x FROM big_cohort), treatment := 't', id := 'id');
----
200

# Ordered head of the paged result is stable and correct (ids 0..4 in order).
query I
SELECT id
FROM causal.propensity_scores((SELECT id, t, x FROM big_cohort), treatment := 't', id := 'id')
ORDER BY id
LIMIT 5;
----
0
1
2
3
4

statement ok
DETACH causal;
74 changes: 69 additions & 5 deletions tests/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,37 @@
from vgi.table_buffering_function import TableBufferingParams


class _BatchTooLarge(Exception):
"""A finalize tick emitted more rows than one continuation response can carry."""


class _Collector:
"""Captures emitted batches from a finalize stream."""
"""Captures emitted batches from a finalize stream.

When ``max_rows_per_tick`` is set (the ``serialize_state`` / HTTP-continuation
model), an ``emit`` whose batch exceeds that cap raises :class:`_BatchTooLarge`
*before* recording the batch -- mirroring the stateless transport, which rejects
an over-cap response and resumes the tick from the pre-tick continuation token.
A position-less finalize re-emits the whole result every tick and never fits;
the offset cursor emits a bounded slice that does.
"""

def __init__(self) -> None:
def __init__(self, max_rows_per_tick: int | None = None) -> None:
self.batches: list[pa.RecordBatch] = []
self.finished = False
self._max_rows_per_tick = max_rows_per_tick
self._tick_rows = 0

def begin_tick(self) -> None:
self._tick_rows = 0

def emit(self, batch: pa.RecordBatch, *_a: Any, **_kw: Any) -> None:
self._tick_rows += batch.num_rows
if self._max_rows_per_tick is not None and self._tick_rows > self._max_rows_per_tick:
raise _BatchTooLarge(
f"finalize emitted {self._tick_rows} rows in one tick, exceeding the "
f"{self._max_rows_per_tick}-row continuation cap"
)
self.batches.append(batch)

def finish(self) -> None:
Expand All @@ -40,16 +63,26 @@ def run_buffering(
table: pa.Table,
*,
named: dict[str, str] | None = None,
serialize_state: bool = False,
) -> pa.Table:
"""Drive a causal buffering function over a whole input ``table``.

Args:
func_cls: The ``TableBufferingFunction`` subclass to run.
table: The input relation (the ``(SELECT ...)`` data) as an Arrow table.
named: Named string column-role args (e.g. ``{"treatment": "t"}``).
serialize_state: When ``True``, wire-serialize the finalize state between
every ``finalize()`` tick (``deserialize_from_bytes(serialize_to_bytes())``),
emulating the stateless HTTP continuation token. A position-less cursor
(the old ``DrainState{done}``) re-emits row 0 forever under this mode and
trips the overrun guard; the offset cursor pages and terminates.

Returns:
The emitted result as a single Arrow table (the function's output).

Raises:
AssertionError: If a finalize stream runs past the ~10000-tick guard
(i.e. it never terminates -- the latent HTTP-continuation bug).
"""
input_schema = table.schema
args = Arguments(
Expand Down Expand Up @@ -95,12 +128,43 @@ def make_params() -> TableBufferingParams:
# Combine phase.
finalize_ids = func_cls.combine(state_ids, make_params())

# Source phase: drain each finalize stream.
out = _Collector()
# Source phase: drain each finalize stream. Under ``serialize_state`` we model
# the stateless HTTP continuation: re-serialize the finalize state between every
# tick, and cap each tick at one response worth of rows (``ROWS_PER_TICK``). An
# over-cap emit is rejected and the tick resumes from the PRE-tick token (the
# mutation is discarded) -- exactly what makes a position-less finalize loop
# forever and a position cursor page.
from vgi_causal.buffering import ROWS_PER_TICK

cap = ROWS_PER_TICK if serialize_state else None
out = _Collector(max_rows_per_tick=cap)
max_ticks = 10_000
for fid in finalize_ids:
params = make_params()
state = func_cls.initial_finalize_state(fid, params)
ticks = 0
while not out.finished:
func_cls.finalize(params, fid, state, out)
ticks += 1
assert ticks < max_ticks, (
f"{func_cls.Meta.name}.finalize did not terminate within {max_ticks} ticks "
f"(serialize_state={serialize_state}): the finalize cursor never advances "
"across the continuation boundary -- the HTTP-continuation bug."
)
if not serialize_state:
func_cls.finalize(params, fid, state, out)
continue
pre_tick_blob = state.serialize_to_bytes()
saved_batches = len(out.batches)
out.begin_tick()
try:
func_cls.finalize(params, fid, state, out)
except _BatchTooLarge:
# Continuation: drop the over-cap batches and the un-committed state
# advance; the next attempt resumes from the pre-tick token.
del out.batches[saved_batches:]
state = type(state).deserialize_from_bytes(pre_tick_blob)
continue
state = type(state).deserialize_from_bytes(state.serialize_to_bytes())
out.finished = False # reset for the next finalize stream

return pa.Table.from_batches(out.batches, schema=bind_resp.output_schema)
60 changes: 60 additions & 0 deletions tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,63 @@ def test_non_binary_treatment_raises() -> None:
tbl = pa.table({"t": [0, 1, 2], "y": [1.0, 2.0, 3.0], "x1": [0.1, 0.2, 0.3]})
with pytest.raises(Exception, match="binary 0/1"):
run_buffering(Ate, tbl, named={"treatment": "t", "outcome": "y"})


class TestCursorSurvivesContinuation:
"""The finalize cursor must survive a wire round-trip between every tick.

Over the stateless HTTP transport the framework serializes the finalize state
after each tick, returns at most one data batch, then resumes by deserializing
the token. ``run_buffering(..., serialize_state=True)`` emulates that. The old
position-less ``DrainState{done}`` re-emitted row 0 forever (overrun guard);
the offset cursor pages through the result and terminates.

``propensity_scores`` emits one row per input subject -- genuinely unbounded --
so an 800-row cohort produces 800 output rows, far exceeding ``ROWS_PER_TICK``
(64); the cursor must page across many continuation boundaries.
"""

def test_propensity_pages_identically_under_serialization(self) -> None:
from vgi_causal import buffering

df = make_confounded(n=800)
tbl = _arrow(df[["id", "t", "x1", "x2"]])
named = {"treatment": "t", "id": "id"}

# Sanity: the result genuinely spans several ROWS_PER_TICK ticks.
assert len(df) > buffering.ROWS_PER_TICK

baseline = run_buffering(PropensityScores, tbl, named=named).to_pydict()
paged = run_buffering(PropensityScores, tbl, named=named, serialize_state=True).to_pydict()

# (1) Same number of rows -- no truncation, no infinite re-emit.
assert len(paged["id"]) == len(baseline["id"]) == len(df)
# (2) Byte-identical rows in identical order (sort by id to be safe).
b_order = np.argsort(np.asarray(baseline["id"]))
p_order = np.argsort(np.asarray(paged["id"]))
for col in ("id", "propensity", "treatment"):
np.testing.assert_array_equal(np.asarray(baseline[col])[b_order], np.asarray(paged[col])[p_order])
# (3) Each id emitted exactly once (no dupes).
assert len(set(paged["id"])) == len(paged["id"]) == len(df)

def test_small_rows_per_tick_pages_bounded_result(self, monkeypatch: pytest.MonkeyPatch) -> None:
# Force the bounded estimators (ate=3 rows, att=1 row) to also page, so the
# cursor is exercised across the continuation boundary for every function.
from vgi_causal import buffering

monkeypatch.setattr(buffering, "ROWS_PER_TICK", 2)
df = make_confounded(n=400)

ate_tbl = _arrow(df[["t", "y", "x1", "x2"]])
baseline = run_buffering(Ate, ate_tbl, named={"treatment": "t", "outcome": "y"}).to_pydict()
paged = run_buffering(Ate, ate_tbl, named={"treatment": "t", "outcome": "y"}, serialize_state=True).to_pydict()
assert paged == baseline
assert set(paged["method"]) == {"ipw", "regression_adjustment", "aipw"}

att_tbl = _arrow(df[["t", "y", "x1", "x2"]])
att_base = run_buffering(Att, att_tbl, named={"treatment": "t", "outcome": "y"}).to_pydict()
att_paged = run_buffering(
Att, att_tbl, named={"treatment": "t", "outcome": "y"}, serialize_state=True
).to_pydict()
assert att_paged == att_base
assert len(att_paged["estimate"]) == 1
83 changes: 81 additions & 2 deletions vgi_causal/buffering.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,59 @@

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, cast

import pandas as pd
import pyarrow as pa
from vgi.table_buffering_function import TableBufferingFunction, TableBufferingParams
from vgi_rpc import ArrowSerializableDataclass
from vgi_rpc.rpc import OutputCollector

_DATA_KEY = b"input_batches"

# Rows emitted per finalize tick. Bounded so the cursor (``offset``) is observable
# across the HTTP limit-1 continuation boundary: the stateless HTTP transport
# wire-serializes the finalize state after every tick, returns at most one data
# batch per response, then resumes by deserializing the token. Correctness no
# longer depends on the whole result fitting in one producer batch.
ROWS_PER_TICK = 64


@dataclass(kw_only=True)
class DrainState(ArrowSerializableDataclass):
"""Per-finalize-stream cursor: emit the single result batch once, then finish."""
"""Externalized finalize cursor: result batch (IPC bytes) plus next-row offset.

Both fields wire-serialize through the HTTP continuation token, so a resumed
finalize tick sees the advanced ``offset`` and emits the next bounded slice
(or finishes) -- it never re-runs the estimator or restarts from row 0. This
is what lets ``propensity_scores`` (one row per input subject, unbounded)
page correctly over the stateless HTTP transport.

``result_ipc`` is empty until the first tick computes the estimate;
``started`` distinguishes "not yet computed" from "computed an empty result".
"""

started: bool = False
offset: int = 0
result_ipc: bytes = b""


def result_to_ipc(batch: pa.RecordBatch) -> bytes:
"""Serialize the full computed result batch to Arrow IPC bytes for the cursor."""
sink = pa.BufferOutputStream()
# pyarrow.ipc ships no type stubs, so mypy sees these as untyped calls.
with pa.ipc.new_stream(sink, batch.schema) as writer: # type: ignore[no-untyped-call]
writer.write_batch(batch)
return cast(bytes, sink.getvalue().to_pybytes())


done: bool = False
def ipc_to_table(value: bytes) -> pa.Table:
"""Inverse of :func:`result_to_ipc`: read the cursor's result back as a table."""
# pyarrow.ipc ships no type stubs, so mypy sees this as an untyped call.
reader = pa.ipc.open_stream(pa.BufferReader(value)) # type: ignore[no-untyped-call]
return cast(pa.Table, reader.read_all())


def serialize_batch(batch: pa.RecordBatch) -> bytes:
Expand Down Expand Up @@ -105,3 +142,45 @@ def buffered_frame(cls, params: TableBufferingParams[TArgs]) -> pd.DataFrame:
if not batches:
return cast(pd.DataFrame, pa.Table.from_batches([], schema=input_schema).to_pandas())
return cast(pd.DataFrame, pa.Table.from_batches(batches, schema=input_schema).to_pandas())

@classmethod
def drain_result(
cls,
params: TableBufferingParams[TArgs],
state: DrainState,
out: OutputCollector,
compute: Callable[[pd.DataFrame], dict[str, list[Any]]],
) -> None:
"""Compute the result once into the cursor, then stream a bounded slice.

The first finalize tick runs ``compute`` over the buffered frame, packs the
result batch into ``state.result_ipc`` (Arrow IPC bytes) and flips
``state.started``. Every tick then emits at most ``ROWS_PER_TICK`` rows from
``state.offset``, advances ``state.offset``, and calls ``out.finish()`` once
the result is drained. Because ``state`` round-trips through the HTTP
continuation token, a resumed tick sees the advanced offset and never
re-runs ``compute`` or restarts from row 0.

Args:
params: The buffering params (args, output schema, storage).
state: The finalize cursor (result bytes + offset).
out: The collector to emit the result slice into.
compute: Maps the buffered frame to the estimator's ``dict[str, list]``.
"""
if not state.started:
df = cls.buffered_frame(params)
result = compute(df)
batch = pa.RecordBatch.from_pydict(result, schema=params.output_schema)
state.result_ipc = result_to_ipc(batch)
state.started = True
state.offset = 0

table = ipc_to_table(state.result_ipc)
total = table.num_rows
if state.offset >= total:
out.finish()
return
end = min(state.offset + ROWS_PER_TICK, total)
chunk = table.slice(state.offset, end - state.offset).combine_chunks()
out.emit(chunk.to_batches()[0])
state.offset = end
Loading
Loading