Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions src/xtc/backends/mlir/MlirScheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import xtc.itf as itf
import xtc.backends.mlir as backend

from .MlirNodeScheduler import MlirNodeScheduler, MlirNodeSchedule
from .MlirNodeScheduler import MlirNodeScheduler, MlirNodeSchedule, basename

__all__ = [
"MlirScheduler",
Expand Down Expand Up @@ -212,36 +212,41 @@ def get_loop_nest(self) -> LoopNest:
loop_nest = LoopNest(abstract_dims=dims)
root_node = loop_nest.build_root_node(node_sched.node_name)

# Assign splits to root_node first
# Assign splits to root_node first, stripping the root prefix from names
for axis, axis_splits in node_sched.splits.items():
root_node.splits[axis] = dict(axis_splits)
root_node.splits[axis] = {basename(k): v for k, v in axis_splits.items()}

# Build mapper to get splits_info
mapper = LoopInfo.build_from_node(root_node)

def populate_node(node: LoopNestNode, perm: list[str]) -> None:
"""Populate node with data for loops in its permutation."""
node.interchange = list(perm)
perm_set = set(perm)
node.interchange = [basename(n) for n in perm]
for axis, axis_tiles in node_sched.tiles.items():
for tile_name, size in axis_tiles.items():
if tile_name in perm_set:
if axis not in node.tiles:
node.tiles[axis] = {}
node.tiles[axis][tile_name] = size
node.vectorize = [v for v in node_sched.vectorization if v in perm_set]
node.parallelize = [p for p in node_sched.parallelization if p in perm_set]
node.tiles[axis][basename(tile_name)] = size
node.vectorize = [
basename(v) for v in node_sched.vectorization if v in perm_set
]
node.parallelize = [
basename(p) for p in node_sched.parallelization if p in perm_set
]
node.unroll = {
k: v for k, v in node_sched.unrolling.items() if k in perm_set
basename(k): v for k, v in node_sched.unrolling.items() if k in perm_set
}

# Process each root in permutation
for root, perm in node_sched.permutation.items():
if root in mapper.splits_info:
root_name = basename(root)
if root_name in mapper.splits_info:
# This root is a split - create child node
axis, start, end = mapper.splits_info[root]
axis, start, end = mapper.splits_info[root_name]
child = LoopNestNode(
root=root,
root=root_name,
tiles={d: {} for d in dims},
split_origin=SplitOrigin(axis=axis, start=start, end=end),
)
Expand Down
87 changes: 87 additions & 0 deletions tests/filecheck/schedules/test_low_level_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# RUN: python %s 2>&1 | filecheck %s --check-prefix=CHECK-VALID
# RUN: not python %s --unused-dim 2>&1 | filecheck %s --check-prefix=CHECK-UNUSED-DIM
# RUN: not python %s --vect-inconsistency 2>&1 | filecheck %s --check-prefix=CHECK-VECT
# RUN: not python %s --tile-before-axis 2>&1 | filecheck %s --check-prefix=CHECK-ORDER
# RUN: not python %s --tile-too-large 2>&1 | filecheck %s --check-prefix=CHECK-SIZE

import sys
import xtc.graphs.xtc.op as O
from xtc.backends.mlir import Backend

I, J, K, dtype = 4, 32, 512, "float32"
a = O.tensor((I, K), dtype, name="A")
b = O.tensor((K, J), dtype, name="B")

with O.graph(name="matmul") as gb:
O.matmul(a, b, name="C")

graph = gb.graph


def make_scheduler():
impl = Backend(graph)
return impl.get_scheduler()


if len(sys.argv) == 1:
sch = make_scheduler()
sch.set_dims(["I", "J", "K"])
sch.tile("I", {"I0": 2})
sch.tile("J", {"J0": 16})
sch.interchange(["K", "I", "J", "I0", "J0"])
sch.vectorize(["J0"])

loop_nest = sch.get_loop_nest()
loop_nest.check()
print("ok")

# CHECK-VALID: ok

elif "--unused-dim" in sys.argv:
sch = make_scheduler()
sch.set_dims(["I", "J", "K"])
sch.tile("I", {"I0": 2})
sch.interchange(["I", "J", "I0"])

loop_nest = sch.get_loop_nest()
loop_nest.check()

# CHECK-UNUSED-DIM: K defined but never used

elif "--vect-inconsistency" in sys.argv:
sch = make_scheduler()
sch.set_dims(["I", "J", "K"])
sch.tile("I", {"I0": 2})
sch.tile("J", {"J0": 16})
sch.interchange(["K", "I", "J", "J0", "I0"])
sch.vectorize(["J0"])

loop_nest = sch.get_loop_nest()
loop_nest.check()

# CHECK-VECT: Inner loop I0 isn't vectorized but an outer one is.

elif "--tile-before-axis" in sys.argv:
sch = make_scheduler()
sch.set_dims(["I", "J", "K"])
sch.tile("I", {"I0": 2})
sch.tile("J", {"J0": 16})
sch.interchange(["K", "I0", "I", "J", "J0"])

loop_nest = sch.get_loop_nest()
loop_nest.check()

# CHECK-ORDER: `I#2`: I has not been materialized yet.

elif "--tile-too-large" in sys.argv:
sch = make_scheduler()
sch.set_dims(["I", "J", "K"])
sch.tile("I", {"I0": 4})
sch.tile("I", {"I00": 8})
sch.tile("J", {"J0": 16})
sch.interchange(["K", "I", "J", "I0", "I00", "J0"])

loop_nest = sch.get_loop_nest()
loop_nest.check()

# CHECK-SIZE: Inner loop I00 on axis I must be smaller than outer loop.
Loading