Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/xtc/backends/mlir/MlirCompilerPasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,12 @@ def _generate_node_scheduling(
schedule=schedule,
sched_state=sched_state,
)
if loop_name in schedule.write_buffers:
self._write_buffer(
loop_name=loop_name,
schedule=schedule,
sched_state=sched_state,
)

# Manage the strip-mining
if loop_name in schedule.vectorization:
Expand Down Expand Up @@ -537,6 +543,30 @@ def _pack_buffer(
input_idx=input_idx,
)

def _write_buffer(
    self,
    loop_name: str,
    schedule: MlirNodeSchedule,
    sched_state: SchedulingState,
):
    """Materialize a local buffer for the node's output at this loop level.

    Counterpart of ``_pack_buffer`` for the write (output) operand: the
    output operand index is the first slot after all inputs. Memref
    alias-folding patterns are applied first so buffer placement sees
    canonical memref accesses; the buffer itself is created through the
    SDist ``SDistLocalBufferAtOp`` transform when the extension is enabled.

    NOTE(review): ``loop_name`` is not used directly here — the loop level
    is implied by ``sched_state.handle``; confirm this is intentional.
    """
    # Local imports avoid a circular dependency between backend modules.
    from .MlirGraphBackend import MlirGraphBackend
    from .MlirNodeBackend import MlirNodeBackend

    assert self._mlir_schedule is not None
    graph_backend = self._mlir_schedule.scheduler.backend
    assert isinstance(graph_backend, MlirGraphBackend)
    node_backend = graph_backend.nodes[schedule.node_name]
    assert isinstance(node_backend, MlirNodeBackend)
    # The output operand comes right after the inputs in the operand list.
    output_idx = len(node_backend.np_inputs_spec())
    with InsertionPoint(transform.ApplyPatternsOp(sched_state.handle).patterns):
        memref.ApplyFoldMemrefAliasOpsPatternsOp()
    if "sdist" in self._mlir_program.mlir_extensions:
        assert sdist_transform is not None
        sdist_transform.SDistLocalBufferAtOp(
            target=sched_state.handle,
            input_idx=output_idx,
        )


class MlirProgramApplyTransformPass:
def __init__(
Expand Down
7 changes: 7 additions & 0 deletions src/xtc/backends/mlir/MlirNodeScheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class MlirNodeSchedule:
parallelization: list[str]
unrolling: dict[str, int]
packed_buffers: dict[str, list[int]]
write_buffers: list[str]
memory_mesh: dict[str, int]
processor_mesh: dict[str, int]
distribution: dict[str, str]
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
self.parallelization: list[str] = []
self.unrolling: dict[str, int] = {}
self.packed_buffers: dict[str, list[int]] = {}
self.write_buffers: list[str] = []
self.memory_mesh: dict[str, int] = {}
self.processor_mesh: dict[str, int] = {}
self.distribution: dict[str, str] = {}
Expand All @@ -112,6 +114,7 @@ def mlir_node_schedule(self) -> MlirNodeSchedule:
unrolling=self.unrolling,
memory_mesh=self.memory_mesh,
packed_buffers=self.packed_buffers,
write_buffers=self.write_buffers,
processor_mesh=self.processor_mesh,
distribution=self.distribution,
distributed_buffers=self.distributed_buffers,
Expand Down Expand Up @@ -178,6 +181,10 @@ def pack_at(
else:
self.packed_buffers[axis_key].append(input_idx)

def buffer_at(self, axis: str, mtype: str | None = None, root: str = DEFAULT_ROOT):
    """Record a write-buffer request for *axis* under *root*.

    The axis key is stored in ``self.write_buffers`` for the compiler pass
    to pick up when generating node scheduling.

    NOTE(review): ``mtype`` is accepted for interface compatibility but is
    not recorded — confirm whether the memory type should be preserved.
    """
    self.write_buffers.append(f"{root}{ROOT_SEP}{axis}")

def define_memory_mesh(self, axes: dict[str, int]):
assert len(self.memory_mesh) == 0, "Memory mesh has already been defined"
self.memory_mesh = axes
Expand Down
13 changes: 9 additions & 4 deletions src/xtc/backends/mlir/MlirScheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,14 @@ def interchange(self, permutation: list[str], root: str = DEFAULT_ROOT) -> None:
def buffer_at(
    self, axis: str, mtype: str | None = None, root: str = DEFAULT_ROOT
) -> None:
    """Request a buffer for the node output at loop *axis* of *root*.

    The current implementation relies exclusively on SDist, but the
    upstream transform dialect may be used for some cases.

    :param axis: loop axis name at which the buffer is created.
    :param mtype: memory type; one of None, "global", or "local".
    :param root: schedule root the axis belongs to.
    """
    assert mtype is None or mtype == "global" or mtype == "local"
    if mtype is None or mtype == "global":
        # Global buffering is best-effort: SDist is used when available.
        self._require_extension("sdist", weak=True)
    else:
        # Local buffering strictly requires the SDist extension.
        self._require_extension("sdist")
    self._current_scheduler.buffer_at(axis, mtype, root=root)

@override
def pack_at(
Expand All @@ -144,7 +149,7 @@ def pack_at(
pad: bool = False,
root: str = DEFAULT_ROOT,
) -> None:
# The current implemntation exclusively rely on SDist, but upstream
# The current implementation exclusively rely on SDist, but upstream
# transform dialect may be used for some cases.
assert mtype is None or mtype == "global" or mtype == "local"
if pad:
Expand Down
10 changes: 5 additions & 5 deletions tests/filecheck/search/test_conv_oo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
utils.print_exhaustive_samples(backend, strategy, 100)

# CHECK: schedule O0: [1, 1, 1, 1, 1, 1, 1]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 1}, 'f': {'./f1': 1}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 1, './w1': 1, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 1}, 'f': {'./f1': 1}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 1, './w1': 1, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: schedule O1: [1, 1, 1, 1, 1, 1, 1]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 1}, 'f': {'./f1': 1}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 1, './w1': 1, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 1}, 'f': {'./f1': 1}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 1, './w1': 1, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: schedule O2: [1, 1, 2, 16, 1, 1, 1]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: schedule O3: [1, 1, 2, 16, 1, 1, 3]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 3}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 1, './c1': 3, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 3}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 1, './c1': 3, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: sample 0: [1, 1, 1, 1, 1, 1, 1]
# CHECK-NEXT: sample 1: [1, 1, 1, 1, 1, 1, 3]
# CHECK-NEXT: sample 2: [1, 1, 1, 1, 1, 7, 1]
Expand Down Expand Up @@ -99,4 +99,4 @@
# CHECK-NEXT: sample 76: [2, 2, 2, 8, 1, 1, 1]
# CHECK-NEXT: sample 77: [2, 2, 2, 16, 1, 1, 1]
# CHECK-NEXT: stats {'filtered': 78, 'all': 384}
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 2}, 'h': {'./h1': 2}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 2, './c1': 1, './s1': 1, './r1': 1, './b1': 2}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 2}, 'h': {'./h1': 2}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 2, './c1': 1, './s1': 1, './r1': 1, './b1': 2}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
Loading