1 change: 1 addition & 0 deletions .gitignore
@@ -19,3 +19,4 @@ id_ed25519.pub
 *.safetensors
 *.model
 .cline_storage
+*.egg-info
5 changes: 3 additions & 2 deletions conftest.py
@@ -14,9 +14,10 @@


 @pytest.fixture
-def aie_context():
+def aie_context(request):
     """Create a fresh AIEContext for each test"""
-    return AIEContext()
+    verbose_mlir = request.config.option.verbose > 0
+    return AIEContext(mlir_verbose=verbose_mlir)


 def pytest_addoption(parser):
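For context: `request.config.option.verbose` is pytest's built-in verbosity counter, which starts at 0 and increases by one per `-v` flag (and decreases per `-q`), so MLIR-verbose output turns on exactly when the suite is run with at least one `-v`. A minimal, self-contained sketch of the same pattern (the fixture name here is hypothetical, not from the diff):

```python
import pytest


@pytest.fixture
def verbosity_flag(request) -> bool:
    # pytest: 0 by default, 1 for -v, 2 for -vv, negative under -q.
    # True only when the user asked for verbose output.
    return request.config.option.verbose > 0
```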
3 changes: 2 additions & 1 deletion iron/common/aie_context.py
@@ -14,7 +14,7 @@
 class AIEContext:
     """Context for managing AIE operator compilation and runtime state"""

-    def __init__(self, use_runlist=True):
+    def __init__(self, use_runlist=True, mlir_verbose=None):
         self.operators = []
         self.static_data_pool = {}
         self.device_manager = AIEDeviceManager()
@@ -24,6 +24,7 @@ def __init__(self, use_runlist=True):
         self.peano_dir = Path(aie.utils.config.peano_install_dir())
         # Disabling the XRT runlist sacrifices performance by executing kernels individually as separate xclbin invocations, but makes debugging easier (you can tell which part of the runlist execution failed).
         self.use_runlist = use_runlist
+        self.mlir_verbose = bool(mlir_verbose)
         self._runtime_prepared = False

     def register_operator(self, operator):
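Both constructor flags can be combined for a debugging session; an illustrative instantiation (values chosen for the example, not from the diff):

```python
# Illustrative usage: favor debuggability over throughput.
# use_runlist=False runs each kernel as a separate xclbin invocation, so a
# failure can be attributed to a specific kernel; mlir_verbose is coerced
# to a plain bool (None -> False).
ctx = AIEContext(use_runlist=False, mlir_verbose=True)
assert ctx.use_runlist is False and ctx.mlir_verbose is True
```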
8 changes: 7 additions & 1 deletion iron/operators/gemv/design.py
@@ -32,10 +32,16 @@
 """


-def my_matvec(dev, cols, M, K, m_input, m_output=None):
+def my_matvec(dev, cols, M, K, m_input, m_output=None, verbose=False):
     if m_output is None:
         m_output = m_input

+    if verbose:
+        print(f"Device: {dev}")
+        print(f"Matrix dimensions: M={M}, K={K}")
+        print(f"Tiling: m_input={m_input}, m_output={m_output}")
+        print(f"Columns: {cols}")
+
     # This requirement exists because we first acquire output rows from the C FIFO, then fill those rows by acquiring rows of the A input.
     assert (
         m_output % m_input == 0 and m_output >= m_input
2 changes: 2 additions & 0 deletions iron/operators/gemv/op.py
@@ -62,6 +62,7 @@ def get_artifacts(self, prefix="gemv_"):
         # The underlying MLIR design is a matrix-vector multiplication. We support vector-matrix multiplication by transposing the matrix beforehand (AB = C <=> B^T A^T = C^T).
         operator_dir = Path(__file__).parent
         file_name_base = f"{prefix}{self.M}x{self.K}_{self.tile_size_input}tsi_{self.tile_size_output}tso_{self.num_aie_columns}col"
+        mlir_verbose = getattr(self.context, "mlir_verbose", False)

         mlir_artifact = PythonGeneratedMLIRArtifact.new(
             f"{file_name_base}.mlir",
@@ -74,6 +75,7 @@
                 self.K,
                 self.tile_size_input,
                 self.tile_size_output,
+                mlir_verbose,
             ],
         )
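The transpose identity the comment leans on, (AB)^T = B^T A^T, is easy to sanity-check; here is a minimal NumPy sketch (illustrative only, not repository code):

```python
import numpy as np

# Vector-matrix multiply x @ A computed via a matrix-vector kernel:
# (x A) = (A^T x)^T, and for 1-D x the transpose is a no-op.
rng = np.random.default_rng(0)
A = rng.standard_normal((8, 4))
x = rng.standard_normal(8)

vec_mat = x @ A      # vector-matrix form the operator exposes
mat_vec = A.T @ x    # matrix-vector form the MLIR design implements
assert np.allclose(vec_mat, mat_vec)
```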
20 changes: 13 additions & 7 deletions iron/operators/mha/design.py
@@ -712,9 +712,13 @@ def batched_matmul_pv(
         (heads * S_q_pad, d), (number_of_pipelines_join_distribute * B_q, d), (1, 1)
     )

-    K_tiles = TensorTiler2D.group_tiler((heads * S_kv_pad, d), (S_kv_pad, d), (1, 1))
+    K_tiles = TensorTiler2D.group_tiler(
+        (num_KV_heads * S_kv_pad, d), (S_kv_pad, d), (1, 1)
+    )

-    V_tiles = TensorTiler2D.group_tiler((heads * S_kv_pad, d), (S_kv_pad, d), (1, 1))
+    V_tiles = TensorTiler2D.group_tiler(
+        (num_KV_heads * S_kv_pad, d), (S_kv_pad, d), (1, 1)
+    )

     O_tiles = TensorTiler2D.group_tiler(
         (heads * S_q_pad, d), (number_of_pipelines_join_distribute * B_q, d), (1, 1)
@@ -759,9 +763,9 @@ def legalize_tas(tas: TensorAccessSequence):
     if verbose:
         print(f"DMA Transfer Configuration: DRAM <-> Mem tile")
         # print_tap_seq_info(Q_tiles, "Q")
-        # print_tap_seq_info(K_tiles, "K")
-        # print_tap_seq_info(V_tiles, "V")
-        print_tap_seq_info(O_tiles, "O")
+        print_tap_seq_info(K_tiles, "K")
+        print_tap_seq_info(V_tiles, "V")
+        # print_tap_seq_info(O_tiles, "O")

     # Runtime operations to move data to/from the AIE-array
     rt = Runtime()
@@ -788,6 +792,8 @@ def set_mha_rtps():

     for head_idx in range(heads):

+        kv_head_idx = head_idx // (heads // num_KV_heads)
+
         for q_block_idx in range(num_q_block_per_pipeline):

             # Initialize a group for parallel drain tasks, with fill resources freed when drains complete.
@@ -827,14 +833,14 @@ def set_mha_rtps():
             rt.fill(
                 inK.prod(),
                 K,
-                tap=K_tiles[head_idx],
+                tap=K_tiles[kv_head_idx],
                 placement=Tile(col=5, row=0),
                 task_group=tg,
             )
             rt.fill(
                 inV.prod(),
                 V,
-                tap=V_tiles[head_idx],
+                tap=V_tiles[kv_head_idx],
                 placement=Tile(col=6, row=0),
                 task_group=tg,
             )
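The new `kv_head_idx` computation is the standard grouped-query-attention (GQA) mapping: `heads // num_KV_heads` consecutive query heads share one KV head, which is why the K/V fills now index the tilers with `kv_head_idx` rather than `head_idx`. A tiny illustrative check (head counts assumed for the example):

```python
# With 8 query heads grouped over 2 KV heads, each KV head serves
# heads // num_KV_heads == 4 consecutive query heads.
heads, num_KV_heads = 8, 2
mapping = [head_idx // (heads // num_KV_heads) for head_idx in range(heads)]
assert mapping == [0, 0, 0, 0, 1, 1, 1, 1]
```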
9 changes: 6 additions & 3 deletions iron/operators/mha/op.py
@@ -53,6 +53,7 @@ def set_up_artifacts(self):

         kv_heads = self.num_KV_heads if self.num_KV_heads > 0 else self.num_heads
         file_name_base = f"mha_{self.num_heads}h_{kv_heads}kv_{self.seq_len}s_{self.d}d"
+        mlir_verbose = getattr(self.context, "mlir_verbose", False)

         # Define source files
         mm_source = str(self.context.base_dir / "aie_kernels" / "aie2p" / "mm.cc")
@@ -98,12 +99,12 @@ def set_up_artifacts(self):
                 "number_of_pipelines": self.num_of_pipelines,
                 "emulate_bf16_mmul_with_bfp16": True,
                 "trace_size": 0,
-                "verbose": False,
+                "verbose": mlir_verbose,
             },
         )

         xclbin_artifact = XclbinArtifact.new(
-            f"mha.xclbin",
+            f"{file_name_base}.xclbin",
             depends=[
                 mlir_artifact,
                 KernelArchiveArtifact.new(
@@ -139,7 +140,9 @@ def set_up_artifacts(self):
         )

         insts_artifact = InstsBinArtifact.new(
-            f"mha.bin", depends=[mlir_artifact], extra_flags=["--dynamic-objFifos"]
+            f"{file_name_base}.bin",
+            depends=[mlir_artifact],
+            extra_flags=["--dynamic-objFifos"],
         )

         self.xclbin_artifact = xclbin_artifact
31 changes: 19 additions & 12 deletions iron/operators/mha/reference.py
@@ -2,6 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
 import numpy as np
 from ml_dtypes import bfloat16

@@ -59,29 +61,34 @@ def generate_golden_reference(
     K = torch.rand(num_kv_heads, S_kv, d, dtype=torch.bfloat16) * val_range
     V = torch.rand(num_kv_heads, S_kv, d, dtype=torch.bfloat16) * val_range

+    K_original = K.clone()
+    V_original = V.clone()
+
     K = K.repeat_interleave(number_of_groups, dim=0)
     V = V.repeat_interleave(number_of_groups, dim=0)

     # MHA from PyTorch
     inv_scale = 1 / np.sqrt(K.shape[-1])
-    O = torch.nn.functional.scaled_dot_product_attention(
-        Q.to(torch.bfloat16),
-        K.to(torch.bfloat16),
-        V.to(torch.bfloat16),
-        dropout_p=0.0,
-        is_causal=True,
-        scale=inv_scale,
-    )
+    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+        O = torch.nn.functional.scaled_dot_product_attention(
+            Q.to(torch.bfloat16).unsqueeze(0),
+            K.to(torch.bfloat16).unsqueeze(0),
+            V.to(torch.bfloat16).unsqueeze(0),
+            dropout_p=0.0,
+            is_causal=True,
+            scale=inv_scale,
+        ).squeeze(0)

     # Pad all tensors to multiple of 64
     Q = pad_to_multiple_of_64(Q, seq_dim=1, num_pipeline=num_pipeline)
     K = pad_to_multiple_of_64(K, seq_dim=1, num_pipeline=num_pipeline)
     V = pad_to_multiple_of_64(V, seq_dim=1, num_pipeline=num_pipeline)
+    K_original = pad_to_multiple_of_64(K_original, seq_dim=1, num_pipeline=num_pipeline)
+    V_original = pad_to_multiple_of_64(V_original, seq_dim=1, num_pipeline=num_pipeline)
     O = pad_to_multiple_of_64(O, seq_dim=1, num_pipeline=num_pipeline)

     return {
         "Q": Q,
-        "K": K,
-        "V": V,
+        "K": K_original,
+        "V": V_original,
         "O": O,
     }
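The reference now keeps the un-expanded `K_original`/`V_original` for the device (which does the GQA indexing itself via `kv_head_idx` in design.py), while `repeat_interleave` expands K/V so the dense PyTorch SDPA sees one KV head per query head. A short shape sketch (dimensions assumed for illustration):

```python
import torch

num_heads, num_kv_heads, S, d = 8, 2, 16, 4
number_of_groups = num_heads // num_kv_heads  # query heads per KV head

K = torch.rand(num_kv_heads, S, d)
K_expanded = K.repeat_interleave(number_of_groups, dim=0)

assert K_expanded.shape == (num_heads, S, d)
# Query heads 0..3 see KV head 0; heads 4..7 see KV head 1. This matches
# kv_head_idx = head_idx // (num_heads // num_kv_heads) in design.py.
assert torch.equal(K_expanded[3], K[0]) and torch.equal(K_expanded[4], K[1])
```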
40 changes: 34 additions & 6 deletions iron/operators/mha/test.py
@@ -13,8 +13,24 @@


 def generate_test_params(extensive=False):
-    params = [(16384, 64, 1, 8)]
-    names = ["mha"]
+    # (seq_len, head_dim, heads, number_of_pipeline, num_kv_heads)
+
+    names = []
+
+    params = [(16384, 64, 1, 8, 0)]
+
+    if extensive:
+        params += [
+            (4096, 64, 8, 8, 4),
+            (4096, 64, 8, 8, 2),
+            (4096, 64, 8, 8, 0),
+        ]
+
+    for seq_len, head_dim, heads, number_of_pipeline, num_kv_heads in params:
+        names += [
+            f"mha_{seq_len}_{head_dim}_{heads}_{number_of_pipeline}_{num_kv_heads}"
+        ]
+
     return params, names


@@ -35,22 +51,34 @@ def generate_test_params(extensive=False):
     Latency=r"Latency \(us\): (?P<value>[\d\.]+)",
     Bandwidth=r"Effective Bandwidth: (?P<value>[\d\.e\+-]+) GB/s",
 )
-@pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines", all_params)
-def test_mha(seq_len, dim, num_heads, num_pipelines, aie_context):
+@pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines,num_kv_heads", all_params)
+def test_mha(
+    seq_len: int,
+    dim: int,
+    num_heads: int,
+    num_pipelines: int,
+    num_kv_heads: int,
+    aie_context,
+):
+
+    print(
+        f"\nTest configuration: seq_len={seq_len}, dim={dim}, num_heads={num_heads}, num_pipelines={num_pipelines}, num_kv_heads={num_kv_heads}"
+    )

     golden_ref = generate_golden_reference(
         S_q=seq_len,
         S_kv=seq_len,
         d=dim,
         heads=num_heads,
-        num_kv_heads=num_heads,
+        num_kv_heads=num_kv_heads,
         num_pipeline=num_pipelines,
     )

     operator = AIEMHA(
         num_heads=num_heads,
         seq_len=seq_len,
         d=dim,
-        num_KV_heads=num_heads,
+        num_KV_heads=num_kv_heads,
         num_of_pipelines=num_pipelines,
         context=aie_context,
     )
2 changes: 1 addition & 1 deletion scripts/hooks/pre-push
@@ -23,7 +23,7 @@ echo "Checking licenses with reuse..."
 if command -v reuse &> /dev/null; then
     if ! reuse lint; then
         echo "❌ License check failed"
-        echo '   Run: reuse annotate --template ApacheAMD --copyright-prefix spdx-string-c --copyright "Advanced Micro Devices, Inc. All rights reserved." --license="Apache-2.0" --recursive --skip-unrecognised ./
+        echo '   Run: reuse annotate --template ApacheAMD --copyright-prefix spdx-string-c --copyright "Advanced Micro Devices, Inc. All rights reserved." --license="Apache-2.0" --recursive --skip-unrecognised ./'
         FAILED=1
     else
         echo "✅ License check passed"