New transform: Push indexing to only materialized nodes.#634
New transform: Push indexing to only materialized nodes.#634kaushikcfd wants to merge 3 commits intomainfrom
Conversation
35e4a63 to
a15bd58
Compare
558a1f0 to
f6a6fa8
Compare
8d1f516 to
c42a786
Compare
72aa2be to
a1f0f4b
Compare
a1f0f4b to
bf08e13
Compare
There was a problem hiding this comment.
Pull request overview
Adds a new transform to push indexing operations down the expression DAG so that indexing happens only at (or is forced onto) materialized nodes, generalizing prior work in #425. This is accompanied by a substantial new test suite and a small testing helper to compare transformed vs. reference evaluation.
Changes:
- Introduce
pt.push_index_to_materialized_nodestransform and export it frompytato.__init__. - Add extensive tests covering many indexing composition/broadcasting scenarios.
- Add
assert_allclose_to_reftest helper for comparing two Pytato expressions by evaluation.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
pytato/transform/push_index_to_materialized_nodes.py |
New transform implementation that composes/pushes indices through various ops. |
pytato/__init__.py |
Exposes push_index_to_materialized_nodes as part of the public API. |
pytato/raising.py |
Adjusts ZerosLikeOp typing/raising for the new transform’s HLO lowering path. |
pytato/array.py |
Tweaks index-expression typing (ConvertibleToIndexExpr). |
test/testlib.py |
Adds assert_allclose_to_ref helper to compare two Pytato expressions via execution. |
test/test_transform.py |
New large test module for the indexing pusher transform. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if kind == "basic": | ||
| tgt_axis = 0 | ||
| for idx in indices: | ||
| if isinstance(idx, INT_CLASSES): | ||
| accesses.append(PointAccess(idx)) | ||
| else: | ||
| assert isinstance(idx, NormalizedSlice) | ||
| accesses.append(SliceAccess(tgt_axis, idx)) | ||
| tgt_axis += 1 |
There was a problem hiding this comment.
get_indexing_kind() treats 0-d Array indices as "basic", but _get_axis_accesses()'s kind == "basic" branch assumes non-slice indices are only INT_CLASSES and asserts everything else is a NormalizedSlice. This will fail for valid scalar-array indices (shape ()). Handle Array indices with ndim == 0 here (e.g., represent them as an ArrayIndexAccess((), idx)), or adjust the kind classification so scalar arrays don't enter the basic branch.
| @@ -90,7 +90,7 @@ class BinaryOp(HighLevelOp): | |||
| @dataclass(frozen=True, eq=True, repr=True) | |||
| class ZerosLikeOp(HighLevelOp): | |||
| function: str | |||
There was a problem hiding this comment.
ZerosLikeOp still declares a required function: str field, but the code constructing it no longer passes that argument. This will raise TypeError when index_lambda_to_high_level_op encounters pytato.zero(...). Either remove the unused function field from ZerosLikeOp or pass the expected function name when constructing the op.
| function: str |
| inner_expr.parameters, expr.bindings, expr.shape | ||
| ) | ||
| assert isinstance(ary, Array) | ||
| return ZerosLikeOp(ary) |
There was a problem hiding this comment.
index_lambda_to_high_level_op returns ZerosLikeOp(ary) but ZerosLikeOp currently requires both function and x. This call will fail at runtime. Update the constructor call to match the dataclass fields (and keep it consistent with downstream users like push_index_to_materialized_nodes expecting .x).
| return ZerosLikeOp(ary) | |
| return ZerosLikeOp(inner_expr.function, ary) |
| elif isinstance(hlo, C99CallOp): | ||
| new_args = tuple( | ||
| ( | ||
| self.rec_w_broadcast(ary_arg, expr.shape, indices) | ||
| if isinstance(ary_arg, Array) | ||
| else ary_arg | ||
| ) | ||
| for ary_arg in hlo.args | ||
| ) | ||
| return _lower_call_op_hlo( | ||
| replace( # pyright: ignore[reportUnboundVariable,reportUnknownArgumentType] | ||
| hlo, args=new_args | ||
| ) | ||
| ) |
There was a problem hiding this comment.
replace is only imported inside the BinaryOp branch, but it is also used in the C99CallOp branch. If hlo is a C99CallOp, this will raise NameError at runtime. Import dataclasses.replace at function scope (or re-import inside the C99CallOp branch) before calling it.
Generalizes the logic in #425.