Implement streaming window functions in cudf-polars #22191
Matt711 wants to merge 65 commits into rapidsai:main from
Conversation
Matt711
left a comment
Review guide: I recommend reading the PR description first, then looking over the tests, then the three execution strategies (and their corresponding tests), and finally the full-order-preservation logic.
If you think we should split up the PR, that's okay. Additionally, if you think logic should be shared (especially in the scalar-aggs/groupby case), we can discuss the specifics in your review. I abstracted some logic into helper functions like `_make_hash_shuffle_metadata`, but in general I avoided it (in the groupby case) because it made the code more difficult to understand, IMO.
    [False, True],
    ids=["same_rank", "cross_rank"],
)
def test_over_multirank(
I tested this using rrun
(rapids) coder ➜ ~/cudf $ rrun -n 2 python -m pytest python/cudf_polars/tests/experimental/test_spmd.py::test_over_multirank -x -v
[rrun] All ranks launched. Waiting for completion...
============================= test session starts ==============================
platform linux -- Python 3.14.4, pytest-9.0.3, pluggy-1.6.0 -- /home/coder/.conda/envs/rapids/bin/python
============================= test session starts ==============================
platform linux -- Python 3.14.4, pytest-9.0.3, pluggy-1.6.0 -- /home/coder/.conda/envs/rapids/bin/python
cachedir: .pytest_cache
hypothesis profile 'default'
benchmark: 5.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /home/coder/cudf/python/cudf_polars
configfile: pyproject.toml
plugins: cases-3.10.1, anyio-4.13.0, hypothesis-6.151.13, cov-7.1.0, xdist-3.8.0, benchmark-5.2.3, pytest_httpserver-1.1.5, rerunfailures-16.1
cachedir: .pytest_cache
hypothesis profile 'default'
benchmark: 5.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /home/coder/cudf/python/cudf_polars
configfile: pyproject.toml
plugins: cases-3.10.1, anyio-4.13.0, hypothesis-6.151.13, cov-7.1.0, xdist-3.8.0, benchmark-5.2.3, pytest_httpserver-1.1.5, rerunfailures-16.1
collecting ... collected 4 items
collecting ... collected 4 items
python/cudf_polars/tests/experimental/test_spmd.py::test_over_multirank[same_rank-scalar_sum] PASSED [ 25%]
python/cudf_polars/tests/experimental/test_spmd.py::test_over_multirank[same_rank-scalar_sum] PASSED [ 25%]
python/cudf_polars/tests/experimental/test_spmd.py::test_over_multirank[same_rank-nonscalar_rank] PASSED [ 50%]
python/cudf_polars/tests/experimental/test_spmd.py::test_over_multirank[same_rank-nonscalar_rank] PASSED [ 50%]
python/cudf_polars/tests/experimental/test_spmd.py::test_over_multirank[cross_rank-scalar_sum] PASSED [ 75%]
python/cudf_polars/tests/experimental/test_spmd.py::test_over_multirank[cross_rank-scalar_sum] PASSED [ 75%]
python/cudf_polars/tests/experimental/test_spmd.py::test_over_multirank[cross_rank-nonscalar_rank] XFAIL [100%]
python/cudf_polars/tests/experimental/test_spmd.py::test_over_multirank[cross_rank-nonscalar_rank] XFAIL [100%]
========================= 3 passed, 1 xfailed in 3.45s =========================
========================= 3 passed, 1 xfailed in 3.46s =========================
594810d to 696bbc9
Moved this to WIP while I work through the rapidsmpf test failures: https://github.com/rapidsai/cudf/actions/runs/24570156132/job/71845610400?pr=22191

/ok to test 0d6d6b1
9d48123 to bc584d7
/ok to test bc584d7
- This is a follow-up to rapidsai#21796
- This (hopefully) simplifies some code in rapidsai#22191

**Problem statement**: We currently translate `HStack` nodes with non-pointwise expressions to the equivalent `Select` node at lowering time. This is because all our non-pointwise `Expr`-decomposition logic is specific to `Select`. Before this PR, this translation was skipped whenever the underlying `HStack` was completely overwriting its original columns. The problem with this case is that we lose "anchor" columns that tell the `Select` how to broadcast scalar-aggregation results.

**Proposed solution**: We add a temporary "anchor" column to the translated `HStack` so that broadcasting works correctly in the `Select` node.

**Motivation**:
- We can handle all `over()` expression decomposition within `Select` if we know **all** non-pointwise `HStack` operations are lowered to `Select` anyway.
- We don't "fall back" for other non-`over` `HStack` corner cases either.

Authors:
- Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
- Matthew Murray (https://github.com/Matt711)

URL: rapidsai#22353
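The anchor-column idea can be illustrated with a toy model. This is a hypothetical plain-Python sketch, not the cudf-polars API (`select_with_broadcast` and the dict-of-lists "table" are invented for illustration): a `Select` must broadcast scalar aggregation results to the height of its input, and that height is exactly the information an all-overwriting `HStack` would otherwise destroy.

```python
# Hypothetical sketch, not the cudf-polars implementation: a "table" is a
# dict mapping column names to equal-length lists.

def select_with_broadcast(table, exprs):
    """Evaluate each expression against `table`, broadcasting scalars.

    The broadcast height comes from the input table's columns -- the
    "anchor". If every original column had been overwritten before the
    HStack -> Select translation, this height would be unrecoverable
    for a purely scalar-aggregation expression.
    """
    height = len(next(iter(table.values())))
    out = {}
    for name, fn in exprs.items():
        result = fn(table)
        if not isinstance(result, list):
            # Scalar aggregation result: broadcast to the anchor's height.
            result = [result] * height
        out[name] = result
    return out

# A temporary anchor column keeps the input height visible even though the
# output only contains the new (overwriting) column.
table = {"anchor": [0, 0, 0], "x": [1.0, 2.0, 3.0]}
result = select_with_broadcast(table, {"x": lambda t: sum(t["x"])})
assert result == {"x": [6.0, 6.0, 6.0]}
```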
/ok to test 338757a

/ok to test 5b25eea
wence-
left a comment
A first go. I think there is opportunity to find some more abstraction here, since it seems we are recreating many concepts sui generis for this particular implementation. I did not get to the shuffle-based implementation yet.
I think a useful signpost would be a module-level docstring that describes the algorithmic aspects of what is going on, without reaching into the implementation details.
@dataclass(frozen=True)
class OriginStamps:
I think all of this is sui generis sort_by_key with the key being a tuple of (rank, chunk_index, position). Can we reuse the infrastructure for sort to do that for us?
I guess that doesn't necessarily give you everything with a given rank as its key on the input rank?
Yeah the sort infrastructure doesn't give us "everything with a given rank as its key on the input rank"
…ts, clean up Over.do_evaluate signature
9cb0834 to 397d50f
/ok to test f7687f1
Description
`over()` is, at its heart, a grouped aggregation followed by a broadcast back to the shape of the input. For each group `g` defined by the partition-by keys, evaluate the expression, then map the result back to every row that belongs to `g`. Polars represents this with a `WindowMapping` enum. This PR adds support for the `group_to_rows` mapping in the RapidsMPF streaming executor (the variant where the output has the same number of rows as the input and each row receives the value computed for its group). The entry point is a new `over_actor` that selects one of three execution strategies at runtime based on the incoming channel metadata and expression shape.

The `over_actor`: three strategies

1. Chunkwise (already partitioned)
If the incoming channel metadata shows the data is already hash-partitioned on the `over()` keys (or any prefix of them; being partitioned on `('a',)` is sufficient for `over('a', 'b')`, since every group is contained within one rank), the window function is trivially correct on each chunk in isolation. We evaluate chunkwise with no coordination at all.

2. Scalar aggregations: AllGather + broadcast
When every `GroupedWindow` in the expression is a scalar aggregation (`sum`, `mean`, `count`, etc.), we exploit the fact that these are decomposable: each worker computes partial aggregates chunkwise, an AllGather collects all workers' partial results, a single reduction produces the global aggregate per group, and then each original chunk has those results broadcast back into its row positions via a hash join on the partition keys.

3. Non-scalar aggregations: forward-shuffle + return-shuffle
Functions like `rank` are not decomposable; they require every row in the group to be visible at once. We hash-shuffle by the partition keys so that all rows belonging to group `g` land in the same rank for evaluation. The challenge is then twofold: putting rows back in the right order, and getting them back to the rank that owns the corresponding output chunk in the first place. Output channels are rank-local, so only the rank that received an input chunk is wired up to emit it, and the hash shuffle scatters rows by group with no regard for where they originated. We need an explicit return trip.

Preserving full order
A lot of the implementation exists purely to put output chunks back in the same sequence-number order as the input. Getting this right across both strategies is where most of the complexity lives.
Scalar aggregation path. We can't produce any output until the global aggregate is known, so we buffer incoming chunks while simultaneously computing partial aggregates over them. Once the AllGather + final reduction completes, we iterate over the buffer and evaluate each chunk against the global aggregate, emitting results with their original sequence numbers. Order preservation falls out naturally: the buffer is in receive order and we never reorder it.
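As a minimal sketch of this decomposable path (plain Python with invented names, not the actual actor code), a grouped `sum` reduces to: per-chunk partial aggregates, a simulated AllGather of every rank's partials, one global reduction, and a broadcast back over the buffered chunks in receive order.

```python
from collections import defaultdict

def partial_sums(chunk):
    """Per-chunk partial aggregate: group key -> partial sum."""
    acc = defaultdict(float)
    for key, value in chunk:
        acc[key] += value
    return dict(acc)

def reduce_partials(gathered):
    """Final reduction over all ranks' partials (post-AllGather)."""
    total = defaultdict(float)
    for partials in gathered:
        for key, s in partials.items():
            total[key] += s
    return dict(total)

# Chunks of (group_key, value) rows buffered on one rank, in receive order.
buffered = [[("a", 1.0), ("b", 2.0)], [("a", 3.0)]]
# Pretend the AllGather also delivered another rank's partial results.
gathered = [partial_sums(c) for c in buffered] + [{"b": 4.0}]
global_agg = reduce_partials(gathered)  # {"a": 4.0, "b": 6.0}

# Broadcast back: iterate the buffer in receive order, so output order
# matches input order with no explicit reordering step.
output = [[global_agg[key] for key, _ in chunk] for chunk in buffered]
assert output == [[4.0, 6.0], [4.0]]
```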
Non-scalar shuffle path. Each row is stamped with three pieces of origin metadata before it enters the forward shuffle: an `origin_rank` (which rank ingested it), a `chunk_index` (a rank-local 0-based counter, not the upstream message sequence number, which can collide when the input is the output of a prior shuffle), and a `position` within that input chunk. After the forward shuffle, each rank holds a mix of rows from every origin, but each row knows where it came from. We evaluate the window function on each local forward partition (so `rank` sees every row in the group), then route the results through a return shuffle keyed on `origin_rank`. The return shuffle uses `num_partitions = nranks` and `PartitionAssignment.CONTIGUOUS`, so partition `i` lives on rank `i`, and every row goes back to the rank that originally received it. Each rank then sorts the returned rows by `(chunk_index, position)`, splits at chunk-index transitions, drops the stamp columns, and emits one output chunk per input chunk in input order.

To avoid buffering every input chunk just to size the forward shuffle, the actor samples a small number of chunks up front (
`_choose_modulus`), AllGathers a size estimate, picks the modulus, and then replays the sampled chunks back through a fresh channel via `replay_buffered_channel`. The forward-insert phase reads from that replay channel and streams rows into the shuffle as they arrive, never holding more than the shuffle's own internal buffering.

.over() in streaming cuDF-Polars #22047

Checklist
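The return-trip reassembly described in this PR can be sketched in plain Python (`reassemble` is a hypothetical helper, not the actual actor code): rows come back from the return shuffle carrying their `(chunk_index, position)` stamps; sorting on that pair and splitting at chunk-index transitions reproduces one output chunk per input chunk, in input order.

```python
from itertools import groupby
from operator import itemgetter

def reassemble(returned_rows):
    """Rebuild per-chunk outputs from shuffled-back rows.

    `returned_rows` is an iterable of (chunk_index, position, value)
    triples for a single rank; order of arrival is arbitrary.
    """
    ordered = sorted(returned_rows, key=itemgetter(0, 1))
    # Split at chunk-index transitions and drop the stamp columns.
    return [[value for _, _, value in group]
            for _, group in groupby(ordered, key=itemgetter(0))]

# Rows arrive interleaved after the return shuffle:
rows = [(1, 0, "d"), (0, 2, "c"), (0, 0, "a"), (0, 1, "b"), (1, 1, "e")]
assert reassemble(rows) == [["a", "b", "c"], ["d", "e"]]
```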