From d35bf1b99bce4fa0b16ac2766c20a0c0ba7d3ed7 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Tue, 17 Mar 2026 15:19:39 +0100 Subject: [PATCH] descript: validate low-level schedules --- src/xtc/backends/mlir/MlirScheduler.py | 27 +++--- .../schedules/test_low_level_validation.py | 87 +++++++++++++++++++ 2 files changed, 103 insertions(+), 11 deletions(-) create mode 100644 tests/filecheck/schedules/test_low_level_validation.py diff --git a/src/xtc/backends/mlir/MlirScheduler.py b/src/xtc/backends/mlir/MlirScheduler.py index cdf382b5..ba2e9b29 100644 --- a/src/xtc/backends/mlir/MlirScheduler.py +++ b/src/xtc/backends/mlir/MlirScheduler.py @@ -9,7 +9,7 @@ import xtc.itf as itf import xtc.backends.mlir as backend -from .MlirNodeScheduler import MlirNodeScheduler, MlirNodeSchedule +from .MlirNodeScheduler import MlirNodeScheduler, MlirNodeSchedule, basename __all__ = [ "MlirScheduler", @@ -212,36 +212,41 @@ def get_loop_nest(self) -> LoopNest: loop_nest = LoopNest(abstract_dims=dims) root_node = loop_nest.build_root_node(node_sched.node_name) - # Assign splits to root_node first + # Assign splits to root_node first, stripping the root prefix from names for axis, axis_splits in node_sched.splits.items(): - root_node.splits[axis] = dict(axis_splits) + root_node.splits[axis] = {basename(k): v for k, v in axis_splits.items()} # Build mapper to get splits_info mapper = LoopInfo.build_from_node(root_node) def populate_node(node: LoopNestNode, perm: list[str]) -> None: """Populate node with data for loops in its permutation.""" - node.interchange = list(perm) perm_set = set(perm) + node.interchange = [basename(n) for n in perm] for axis, axis_tiles in node_sched.tiles.items(): for tile_name, size in axis_tiles.items(): if tile_name in perm_set: if axis not in node.tiles: node.tiles[axis] = {} - node.tiles[axis][tile_name] = size - node.vectorize = [v for v in node_sched.vectorization if v in perm_set] - node.parallelize = [p for p in node_sched.parallelization if p in perm_set] + node.tiles[axis][basename(tile_name)] = size + node.vectorize = [ + basename(v) for v in node_sched.vectorization if v in perm_set + ] + node.parallelize = [ + basename(p) for p in node_sched.parallelization if p in perm_set + ] node.unroll = { - k: v for k, v in node_sched.unrolling.items() if k in perm_set + basename(k): v for k, v in node_sched.unrolling.items() if k in perm_set } # Process each root in permutation for root, perm in node_sched.permutation.items(): - if root in mapper.splits_info: + root_name = basename(root) + if root_name in mapper.splits_info: # This root is a split - create child node - axis, start, end = mapper.splits_info[root] + axis, start, end = mapper.splits_info[root_name] child = LoopNestNode( - root=root, + root=root_name, tiles={d: {} for d in dims}, split_origin=SplitOrigin(axis=axis, start=start, end=end), ) diff --git a/tests/filecheck/schedules/test_low_level_validation.py b/tests/filecheck/schedules/test_low_level_validation.py new file mode 100644 index 00000000..add10b79 --- /dev/null +++ b/tests/filecheck/schedules/test_low_level_validation.py @@ -0,0 +1,87 @@ +# RUN: python %s 2>&1 | filecheck %s --check-prefix=CHECK-VALID +# RUN: not python %s --unused-dim 2>&1 | filecheck %s --check-prefix=CHECK-UNUSED-DIM +# RUN: not python %s --vect-inconsistency 2>&1 | filecheck %s --check-prefix=CHECK-VECT +# RUN: not python %s --tile-before-axis 2>&1 | filecheck %s --check-prefix=CHECK-ORDER +# RUN: not python %s --tile-too-large 2>&1 | filecheck %s --check-prefix=CHECK-SIZE + +import sys +import xtc.graphs.xtc.op as O +from xtc.backends.mlir import Backend + +I, J, K, dtype = 4, 32, 512, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph + + +def make_scheduler(): + impl = Backend(graph) + return impl.get_scheduler() + + +if len(sys.argv) == 1: + sch = make_scheduler() + sch.set_dims(["I", "J", "K"]) + sch.tile("I", {"I0": 2}) + sch.tile("J", {"J0": 16}) + sch.interchange(["K", "I", "J", "I0", "J0"]) + sch.vectorize(["J0"]) + + loop_nest = sch.get_loop_nest() + loop_nest.check() + print("ok") + +# CHECK-VALID: ok + +elif "--unused-dim" in sys.argv: + sch = make_scheduler() + sch.set_dims(["I", "J", "K"]) + sch.tile("I", {"I0": 2}) + sch.interchange(["I", "J", "I0"]) + + loop_nest = sch.get_loop_nest() + loop_nest.check() + +# CHECK-UNUSED-DIM: K defined but never used + +elif "--vect-inconsistency" in sys.argv: + sch = make_scheduler() + sch.set_dims(["I", "J", "K"]) + sch.tile("I", {"I0": 2}) + sch.tile("J", {"J0": 16}) + sch.interchange(["K", "I", "J", "J0", "I0"]) + sch.vectorize(["J0"]) + + loop_nest = sch.get_loop_nest() + loop_nest.check() + +# CHECK-VECT: Inner loop I0 isn't vectorized but an outer one is. + +elif "--tile-before-axis" in sys.argv: + sch = make_scheduler() + sch.set_dims(["I", "J", "K"]) + sch.tile("I", {"I0": 2}) + sch.tile("J", {"J0": 16}) + sch.interchange(["K", "I0", "I", "J", "J0"]) + + loop_nest = sch.get_loop_nest() + loop_nest.check() + +# CHECK-ORDER: `I#2`: I has not been materialized yet. + +elif "--tile-too-large" in sys.argv: + sch = make_scheduler() + sch.set_dims(["I", "J", "K"]) + sch.tile("I", {"I0": 4}) + sch.tile("I", {"I00": 8}) + sch.tile("J", {"J0": 16}) + sch.interchange(["K", "I", "J", "I0", "I00", "J0"]) + + loop_nest = sch.get_loop_nest() + loop_nest.check() + +# CHECK-SIZE: Inner loop I00 on axis I must be smaller than outer loop.