Skip to content

New transform: Push indexing to only materialized nodes.#634

Open
kaushikcfd wants to merge 3 commits intomainfrom
indirection_pusher
Open

New transform: Push indexing to only materialized nodes.#634
kaushikcfd wants to merge 3 commits intomainfrom
indirection_pusher

Conversation

@kaushikcfd
Copy link
Collaborator

Generalizes the logic in #425.

@kaushikcfd kaushikcfd force-pushed the indirection_pusher branch 5 times, most recently from 35e4a63 to a15bd58 Compare February 20, 2026 19:43
@kaushikcfd kaushikcfd force-pushed the indirection_pusher branch 8 times, most recently from 558a1f0 to f6a6fa8 Compare March 1, 2026 07:07
@kaushikcfd kaushikcfd force-pushed the indirection_pusher branch 2 times, most recently from 8d1f516 to c42a786 Compare March 2, 2026 01:21
@kaushikcfd kaushikcfd changed the base branch from main to some_typing_fixes_0 March 2, 2026 01:22
@kaushikcfd kaushikcfd marked this pull request as ready for review March 2, 2026 01:22
@kaushikcfd kaushikcfd requested a review from inducer March 2, 2026 01:31
Base automatically changed from some_typing_fixes_0 to main March 2, 2026 01:36
@kaushikcfd kaushikcfd force-pushed the indirection_pusher branch 2 times, most recently from 72aa2be to a1f0f4b Compare March 2, 2026 02:03
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new transform to push indexing operations down the expression DAG so that indexing happens only at (or is forced onto) materialized nodes, generalizing prior work in #425. This is accompanied by a substantial new test suite and a small testing helper to compare transformed vs. reference evaluation.

Changes:

  • Introduce pt.push_index_to_materialized_nodes transform and export it from pytato.__init__.
  • Add extensive tests covering many indexing composition/broadcasting scenarios.
  • Add assert_allclose_to_ref test helper for comparing two Pytato expressions by evaluation.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
pytato/transform/push_index_to_materialized_nodes.py New transform implementation that composes/pushes indices through various ops.
pytato/__init__.py Exposes push_index_to_materialized_nodes as part of the public API.
pytato/raising.py Adjusts ZerosLikeOp typing/raising for the new transform’s HLO lowering path.
pytato/array.py Tweaks index-expression typing (ConvertibleToIndexExpr).
test/testlib.py Adds assert_allclose_to_ref helper to compare two Pytato expressions via execution.
test/test_transform.py New large test module for the indexing pusher transform.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +351 to +359
if kind == "basic":
tgt_axis = 0
for idx in indices:
if isinstance(idx, INT_CLASSES):
accesses.append(PointAccess(idx))
else:
assert isinstance(idx, NormalizedSlice)
accesses.append(SliceAccess(tgt_axis, idx))
tgt_axis += 1
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_indexing_kind() treats 0-d Array indices as "basic", but _get_axis_accesses()'s kind == "basic" branch assumes non-slice indices are only INT_CLASSES and asserts everything else is a NormalizedSlice. This will fail for valid scalar-array indices (shape ()). Handle Array indices with ndim == 0 here (e.g., represent them as an ArrayIndexAccess((), idx)), or adjust the kind classification so scalar arrays don't enter the basic branch.

Copilot uses AI. Check for mistakes.
@@ -90,7 +90,7 @@ class BinaryOp(HighLevelOp):
@dataclass(frozen=True, eq=True, repr=True)
class ZerosLikeOp(HighLevelOp):
function: str
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ZerosLikeOp still declares a required function: str field, but the code constructing it no longer passes that argument. This will raise TypeError when index_lambda_to_high_level_op encounters pytato.zero(...). Either remove the unused function field from ZerosLikeOp or pass the expected function name when constructing the op.

Suggested change
function: str

Copilot uses AI. Check for mistakes.
inner_expr.parameters, expr.bindings, expr.shape
)
assert isinstance(ary, Array)
return ZerosLikeOp(ary)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index_lambda_to_high_level_op returns ZerosLikeOp(ary) but ZerosLikeOp currently requires both function and x. This call will fail at runtime. Update the constructor call to match the dataclass fields (and keep it consistent with downstream users like push_index_to_materialized_nodes expecting .x).

Suggested change
return ZerosLikeOp(ary)
return ZerosLikeOp(inner_expr.function, ary)

Copilot uses AI. Check for mistakes.
Comment on lines +1063 to +1076
elif isinstance(hlo, C99CallOp):
new_args = tuple(
(
self.rec_w_broadcast(ary_arg, expr.shape, indices)
if isinstance(ary_arg, Array)
else ary_arg
)
for ary_arg in hlo.args
)
return _lower_call_op_hlo(
replace( # pyright: ignore[reportUnboundVariable,reportUnknownArgumentType]
hlo, args=new_args
)
)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replace is only imported inside the BinaryOp branch, but it is also used in the C99CallOp branch. If hlo is a C99CallOp, this will raise NameError at runtime. Import dataclasses.replace at function scope (or re-import inside the C99CallOp branch) before calling it.

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants