diff --git a/aie_kernels/aie2p/softmax.cc b/aie_kernels/aie2p/softmax.cc index 7d480354..64cca202 100644 --- a/aie_kernels/aie2p/softmax.cc +++ b/aie_kernels/aie2p/softmax.cc @@ -177,4 +177,12 @@ void partial_softmax_bf16(bfloat16 *restrict input, partial_softmax_alias_bf16(input, output, scale_buffer, input_size, row_idx, num_rows, scale); } +void mask_bf16(bfloat16 *inout, const int32 unmasked_size, const int32 total_size) +{ + // TODO: Optimize this to use vector code + for (int32 i = unmasked_size; i < total_size; i++) { + inout[i] = (bfloat16)(-INFINITY); + } +} + } // extern "C" \ No newline at end of file diff --git a/aie_kernels/generic/mv.cc b/aie_kernels/generic/mv.cc index 34da4550..f632e8f0 100644 --- a/aie_kernels/generic/mv.cc +++ b/aie_kernels/generic/mv.cc @@ -15,6 +15,10 @@ #include +#ifndef VEC_SIZE +#define VEC_SIZE 64 +#endif + void matvec_scalar(uint32_t m, uint32_t k, const bfloat16 *__restrict a, @@ -40,22 +44,17 @@ Matrix-vector multiplication kernel - c: Pointer to the output vector - r: Vector size; data from the matrix and vector will be loaded in and processed in chunks of this size */ -template -void matvec_vectorized(uint32_t m, - uint32_t k, - const bfloat16 *__restrict a, - const bfloat16 *__restrict b, - bfloat16 *__restrict c) +template +void matvec_vectorized(uint32_t m, const bfloat16 *__restrict a, const bfloat16 *__restrict b, bfloat16 *__restrict c) { ::aie::set_rounding(aie::rounding_mode::conv_even); bfloat16 *c_end = c + m; const bfloat16 *b_end = b + k; for (; c < c_end; c++) { aie::accum acc = aie::zeros(); - // The following two pragmas enable pipelining the zero-overhead loop, but they do assume that k is at least - // two. This assumption should hold for any useful use of this function; if k were one, this would be a simple - // scalar multiplication of a vector. - AIE_LOOP_MIN_ITERATION_COUNT(2) + // The following two pragmas enable pipelining the zero-overhead loop, but they do assume that there are at + // least two iterations of the loop, i.e. k >= 2*r. This pragma will break the code if that is not the case! + AIE_LOOP_MIN_ITERATION_COUNT(k / VEC_SIZE) for (const bfloat16 *__restrict b_cur = b; b_cur < b_end; b_cur += r, a += r) { aie::vector a_vec = aie::load_v(a); aie::vector b_vec = aie::load_v(b_cur); @@ -72,25 +71,23 @@ extern "C" { * `c`. */ void matvec_scalar_bf16_bf16(uint32_t m, - uint32_t k, uint32_t row_offset, const bfloat16 *__restrict a_in, const bfloat16 *__restrict b_in, bfloat16 *__restrict c_out) { c_out += row_offset; - matvec_scalar(m, k, a_in, b_in, c_out); + matvec_scalar(m, DIM_K, a_in, b_in, c_out); } void matvec_vectorized_bf16_bf16(uint32_t m, - uint32_t k, uint32_t row_offset, const bfloat16 *__restrict a_in, const bfloat16 *__restrict b_in, bfloat16 *__restrict c_out) { c_out += row_offset; - matvec_vectorized<64>(m, k, a_in, b_in, c_out); + matvec_vectorized(m, a_in, b_in, c_out); } } // extern "C" \ No newline at end of file diff --git a/conftest.py b/conftest.py index 2f4ab726..1a3c0e89 100644 --- a/conftest.py +++ b/conftest.py @@ -16,7 +16,9 @@ @pytest.fixture def aie_context(): """Create a fresh AIEContext for each test""" - return AIEContext() + ctx = AIEContext() + yield ctx + ctx.device_manager.reset() def pytest_addoption(parser): diff --git a/iron/applications/llama_3.2_1b/analyze_profile.py b/iron/applications/llama_3.2_1b/analyze_profile.py deleted file mode 100644 index 7e2c76f1..00000000 --- a/iron/applications/llama_3.2_1b/analyze_profile.py +++ /dev/null @@ -1,358 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Analyze profiling logs generated by inference.py - -This script parses the profile logs and provides statistics about function execution times. -The total times reported in the analysis results are the cumulative times of the functions, including subcalls. -""" - -import argparse -import re -from collections import defaultdict -from pathlib import Path -import sys -import csv -import statistics -from collections import deque - - -class FunctionStats: - def __init__(self, name): - self.name = name - self.call_count = 0 - self.total_time = 0.0 - self.min_time = float("inf") - self.max_time = 0.0 - self.durations = [] - - def add_duration(self, duration): - self.call_count += 1 - self.total_time += duration - self.min_time = min(self.min_time, duration) - self.max_time = max(self.max_time, duration) - self.durations.append(duration) - - @property - def avg_time(self): - if not self.durations: - return 0.0 - return statistics.mean(self.durations) - - @property - def median_time(self): - if not self.durations: - return 0.0 - return statistics.median(self.durations) - - -def parse_profile_log(log_file): - """ - Parse a profile log file and extract function timing information. - - Args: - log_file: Path to the profile log file - - Returns: - dict: Dictionary mapping function names to FunctionStats objects - """ - stats = defaultdict(lambda: FunctionStats("")) - function_stack = deque() # Track ongoing calls by function identifier - - # Regex patterns for parsing log lines - call_pattern = re.compile(r"\[CALL\] (.+?) started at ([\d.]+)") - return_pattern = re.compile(r"\[RETURN\] (.+?) ended at ([\d.]+)") - - with open(log_file, "r") as f: - for line in f: - # Try to match CALL pattern - call_match = call_pattern.search(line) - if call_match: - func_id = call_match.group(1) - timestamp = float(call_match.group(2)) - function_stack.append((func_id, timestamp)) - continue - - # Try to match RETURN pattern - return_match = return_pattern.search(line) - if return_match: - func_id = return_match.group(1) - timestamp = float(return_match.group(2)) - - # Use the full function identifier (filepath:function_name:line_no) - if func_id not in stats: - stats[func_id].name = func_id - - if function_stack: - stats[func_id].add_duration(timestamp - function_stack.pop()[1]) - else: - raise RuntimeError( - f"Stack empty, found a log for the return but missing a log for the call of {func_id}" - ) - - return dict(stats) - - -def print_summary(stats, sort_by="total", top_n=20, min_calls=1): - """ - Print a summary of function statistics. - - Args: - stats: Dictionary of function statistics - sort_by: Sort criterion ('total', 'avg', 'max', 'calls') - top_n: Number of top functions to display - min_calls: Minimum number of calls to include in results - """ - # Filter by minimum calls - filtered_stats = {name: s for name, s in stats.items() if s.call_count >= min_calls} - - if not filtered_stats: - print("No functions found matching the criteria.") - return - - # Sort functions - sort_keys = { - "total": lambda x: x[1].total_time, - "avg": lambda x: x[1].avg_time, - "max": lambda x: x[1].max_time, - "calls": lambda x: x[1].call_count, - } - - if sort_by not in sort_keys: - print(f"Invalid sort criterion: {sort_by}. Using 'total'.") - sort_by = "total" - - sorted_stats = sorted(filtered_stats.items(), key=sort_keys[sort_by], reverse=True) - - # Print header - print("\n" + "=" * 160) - print(f"FUNCTION PROFILING SUMMARY (sorted by {sort_by}, top {top_n})") - print("=" * 160) - print( - f"{'Function Identifier':<80} {'Calls':>8} {'Total (s)':>12} {'Avg (s)':>12} {'Min (s)':>12} {'Max (s)':>12} {'Median (s)':>12}" - ) - print("-" * 160) - - # Print top N functions - for func_name, func_stats in sorted_stats[:top_n]: - # Truncate long function identifiers for display - display_name = func_name if len(func_name) <= 80 else func_name[:77] + "..." - print( - f"{display_name:<80} {func_stats.call_count:>8} " - f"{func_stats.total_time:>12.6f} {func_stats.avg_time:>12.6f} " - f"{func_stats.min_time:>12.6f} {func_stats.max_time:>12.6f} " - f"{func_stats.median_time:>12.6f}" - ) - - print("-" * 160) - - -def print_function_details(stats, function_name): - """ - Print detailed statistics for functions matching the given name. - - Args: - stats: Dictionary of function statistics - function_name: Name or substring to search for in function identifiers - """ - # Find all function identifiers containing the function_name string - matching_funcs = { - func_id: func_stats - for func_id, func_stats in stats.items() - if function_name in func_id - } - - if not matching_funcs: - print(f"No functions found containing '{function_name}' in profile data.") - print(f"\nAvailable functions (showing first 20):") - for i, name in enumerate(sorted(stats.keys())[:20]): - print(f" - {name}") - if len(stats) > 20: - print(f" ... and {len(stats) - 20} more") - return - - print("\n" + "=" * 120) - print(f"DETAILED STATISTICS FOR FUNCTIONS CONTAINING: '{function_name}'") - print(f"Found {len(matching_funcs)} matching function(s)") - print("=" * 120) - - for func_id, func_stats in sorted( - matching_funcs.items(), key=lambda x: x[1].total_time, reverse=True - ): - print(f"\nFunction: {func_id}") - print("-" * 120) - print(f" Total calls: {func_stats.call_count:,}") - print(f" Total time: {func_stats.total_time:.6f} seconds") - print(f" Average time: {func_stats.avg_time:.6f} seconds") - print(f" Median time: {func_stats.median_time:.6f} seconds") - print(f" Min time: {func_stats.min_time:.6f} seconds") - print(f" Max time: {func_stats.max_time:.6f} seconds") - - if func_stats.call_count > 1: - std_dev = statistics.stdev(func_stats.durations) - print(f" Std deviation: {std_dev:.6f} seconds") - - print("=" * 120 + "\n") - - -def export_to_csv(stats, output_file, sort_by="total", min_calls=1): - """ - Export function statistics to a CSV file. - - Args: - stats: Dictionary of function statistics - output_file: Path to output CSV file - sort_by: Sort criterion ('total', 'avg', 'max', 'calls') - min_calls: Minimum number of calls to include in results - """ - # Filter by minimum calls - filtered_stats = {name: s for name, s in stats.items() if s.call_count >= min_calls} - - if not filtered_stats: - print("No functions found matching the criteria.") - return - - # Sort functions - sort_keys = { - "total": lambda x: x[1].total_time, - "avg": lambda x: x[1].avg_time, - "max": lambda x: x[1].max_time, - "calls": lambda x: x[1].call_count, - } - - if sort_by not in sort_keys: - print(f"Invalid sort criterion: {sort_by}. Using 'total'.") - sort_by = "total" - - sorted_stats = sorted(filtered_stats.items(), key=sort_keys[sort_by], reverse=True) - - # Write to CSV - with open(output_file, "w", newline="") as csvfile: - fieldnames = [ - "function_name", - "call_count", - "total_time_seconds", - "avg_time_seconds", - "median_time_seconds", - "min_time_seconds", - "max_time_seconds", - "std_dev_seconds", - ] - - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - writer.writeheader() - - for func_name, func_stats in sorted_stats: - std_dev = ( - statistics.stdev(func_stats.durations) - if func_stats.call_count > 1 - else 0.0 - ) - - writer.writerow( - { - "function_name": func_name, - "call_count": func_stats.call_count, - "total_time_seconds": f"{func_stats.total_time:.9f}", - "avg_time_seconds": f"{func_stats.avg_time:.9f}", - "median_time_seconds": f"{func_stats.median_time:.9f}", - "min_time_seconds": f"{func_stats.min_time:.9f}", - "max_time_seconds": f"{func_stats.max_time:.9f}", - "std_dev_seconds": f"{std_dev:.9f}", - } - ) - - print(f"\nCSV file saved to: {output_file}") - print(f"Total functions exported: {len(sorted_stats)}") - - -def main(): - parser = argparse.ArgumentParser( - description="Analyze profiling logs from inference.py", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Analyze the most recent profile log - python analyze_profile.py - - # Analyze a specific log file - python analyze_profile.py logs/profile_20250110_160000.log - - # Sort by average time and show top 30 - python analyze_profile.py --sort avg --top 30 - - # Show details for a specific function - python analyze_profile.py --function inference - - # Filter functions with at least 10 calls - python analyze_profile.py --min-calls 10 - - # Export to CSV file - python analyze_profile.py --csv profile_stats.csv - - # Export to CSV with custom sorting and filtering - python analyze_profile.py --csv results.csv --sort avg --min-calls 5 - """, - ) - - parser.add_argument( - "log_file", - type=str, - help="Path to profile log file", - ) - parser.add_argument( - "--sort", - choices=["total", "avg", "max", "calls"], - default="total", - help="Sort criterion (default: total)", - ) - parser.add_argument( - "--top", - type=int, - default=20, - help="Number of top functions to display (default: 20)", - ) - parser.add_argument( - "--min-calls", - type=int, - default=1, - help="Minimum number of calls to include (default: 1)", - ) - parser.add_argument( - "--function", type=str, help="Show detailed statistics for a specific function" - ) - parser.add_argument( - "--csv", - type=str, - help="Export results to CSV file instead of printing to console", - ) - - args = parser.parse_args() - - # Parse the log file - log_file = Path(args.log_file) - print(f"Parsing {log_file}...") - stats = parse_profile_log(log_file) - - if not stats: - print("No profiling data found in log file.") - else: - print(f"Found {len(stats)} unique functions") - - # Show results - if args.csv: - # Export to CSV - export_to_csv(stats, args.csv, sort_by=args.sort, min_calls=args.min_calls) - elif args.function: - # Show detailed function statistics - print_function_details(stats, args.function) - else: - # Print summary to console - print_summary( - stats, sort_by=args.sort, top_n=args.top, min_calls=args.min_calls - ) - - -if __name__ == "__main__": - main() diff --git a/iron/applications/llama_3.2_1b/configs/llama32_1b.json b/iron/applications/llama_3.2_1b/configs/llama32_1b.json deleted file mode 100644 index ed6bc4bf..00000000 --- a/iron/applications/llama_3.2_1b/configs/llama32_1b.json +++ /dev/null @@ -1,39 +0,0 @@ -{ - "model_config": { - "vocab_size": 128256, - "context_length": 131072, - "emb_dim": 2048, - "n_heads": 32, - "n_layers": 16, - "hidden_dim": 8192, - "n_kv_groups": 8, - "use_kv_cache": true, - "rope_base": 500000.0, - "dtype": "bfloat16", - "use_aie_final_norm": true, - "use_aie_ffn_gemm": false, - "use_aie_ffn_silu": false, - "use_aie_ffn_mul": false, - "use_aie_ffn_swiglu": true, - "use_aie_ffn_gemv": true, - "use_aie_attn_projection_gemm": true, - "use_aie_gqa_gemv": true, - "use_aie_rope": true, - "use_aie_norm1": true, - "use_aie_norm2": true, - "use_aie_residual": true, - "use_aie_regular_mha": false, - "use_aie_fused_mha": true, - "use_aie_final_gemm": true, - "use_aie_final_gemv": true, - "rope_freq": { - "factor": 32.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_context_length": 8192 - } - }, - "aie_config": { - "device": "npu2" - } -} \ No newline at end of file diff --git a/iron/applications/llama_3.2_1b/configs/llama32_1b.json.license b/iron/applications/llama_3.2_1b/configs/llama32_1b.json.license deleted file mode 100644 index 50daea92..00000000 --- a/iron/applications/llama_3.2_1b/configs/llama32_1b.json.license +++ /dev/null @@ -1,7 +0,0 @@ -Copyright (c) Sebastian Raschka under Apache License 2.0. -Source for "Build a Large Language Model From Scratch" - - https://www.manning.com/books/build-a-large-language-model-from-scratch -Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb - -SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -SPDX-License-Identifier: Apache-2.0 diff --git a/iron/applications/llama_3.2_1b/inference.py b/iron/applications/llama_3.2_1b/inference.py deleted file mode 100755 index 8109c543..00000000 --- a/iron/applications/llama_3.2_1b/inference.py +++ /dev/null @@ -1,445 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import sys -from pathlib import Path - -import argparse -import time -import torch -from src.model_with_json import Llama3ModelWithJSONConfig - -# from src.model import Llama3Model -from src.tokenizer import Tokenizer, ChatFormat -from safetensors.torch import load_file -import os -import shutil -import logging -from collections import deque - -from iron.common import AIEOperatorBase -from src.utils import ( - model_memory_size, - load_weights_into_llama, - text_to_token_ids, - token_ids_to_text, - clean_text, - generate, -) - -# Global logger for profiling -_profile_logger = None - - -def profile_function_calls(frame, event, arg): - """ - Profile function that logs start and end times of every function call. - - Args: - frame: The current stack frame - event: The event type ('call', 'return', 'c_call', 'c_return', 'c_exception') - arg: Event-specific argument - """ - global _profile_logger - - if _profile_logger is None: - return - - func_name = frame.f_code.co_name - filename = frame.f_code.co_filename - line_no = frame.f_lineno - - # Create a readable function identifier - func_identifier = f"{filename}:{func_name}:{line_no}" - - if event == "call": - # Function is being called - timestamp = time.perf_counter() - _profile_logger.debug(f"[CALL] {func_identifier} started at {timestamp:.9f}") - - elif event == "return": - # Function is returning - timestamp = time.perf_counter() - _profile_logger.debug(f"[RETURN] {func_identifier} ended at {timestamp:.9f}") - - return profile_function_calls - - -def enable_profiling(logs_dir_name): - """Enable function call profiling using sys.setprofile.""" - global _profile_logger - - # Create a dedicated logger for profiling - _profile_logger = logging.getLogger("function_profiler") - _profile_logger.setLevel(logging.DEBUG) - # Prevent propagation to root logger to avoid console output - _profile_logger.propagate = False - - # Create log file for profiling data - timestamp = time.strftime("%Y%m%d_%H%M%S") - log_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - logs_dir_name, - f"profile_{timestamp}.log", - ) - - # Add file handler for profiling (only file, no console output) - profile_handler = logging.FileHandler(log_path) - profile_handler.setLevel(logging.DEBUG) - profile_formatter = logging.Formatter("%(asctime)s - %(message)s") - profile_handler.setFormatter(profile_formatter) - _profile_logger.addHandler(profile_handler) - - # Set the profile function - sys.setprofile(profile_function_calls) - _profile_logger.info("Function profiling enabled") - - # Explicitly call profile_function_calls to log this function's call - import inspect - - frame = inspect.currentframe() - profile_function_calls(frame, "call", None) - - -def disable_profiling(): - """Disable function call profiling.""" - global _profile_logger - - sys.setprofile(None) - if _profile_logger: - _profile_logger.info("Function profiling disabled") - # Close all handlers - for handler in _profile_logger.handlers[:]: - handler.close() - _profile_logger.removeHandler(handler) - - -_iron_chat = r""" - /$$$$$$ /$$$$$$$ /$$$$$$ /$$ /$$ - |_ $$_/| $$__ $$ /$$__ $$| $$$ | $$ - | $$ | $$ \ $$| $$ \ $$| $$$$| $$ - | $$ | $$$$$$$/| $$ | $$| $$ $$ $$ - | $$ | $$__ $$| $$ | $$| $$ $$$$ - | $$ | $$ \ $$| $$ | $$| $$\ $$$ - /$$$$$$| $$ | $$| $$$$$$/| $$ \ $$ - |______/|__/ |__/ \______/ |__/ \__/ - - - /$$ /$$ /$$$$$$ /$$ /$$ /$$$$$$ -| $$ | $$ /$$__ $$| $$$ /$$$ /$$__ $$ -| $$ | $$ | $$ \ $$| $$$$ /$$$$| $$ \ $$ -| $$ | $$ | $$$$$$$$| $$ $$/$$ $$| $$$$$$$$ -| $$ | $$ | $$__ $$| $$ $$$| $$| $$__ $$ -| $$ | $$ | $$ | $$| $$\ $ | $$| $$ | $$ -| $$$$$$$$| $$$$$$$$| $$ | $$| $$ \/ | $$| $$ | $$ -|________/|________/|__/ |__/|__/ |__/|__/ |__/ -""" - - -def setup_logging(verbosity): - """Set up logging based on verbosity level.""" - - # Ensure the logs directory is created in case of profiling - logs_dir_name = "logs" - if not os.path.exists(logs_dir_name): - os.makedirs(logs_dir_name) - - if verbosity != 0: - levels = { - 4: logging.DEBUG, - 3: logging.INFO, - 2: logging.WARNING, - # 1: log everything (DEBUG) to a file - } - - # Create log file - timestamp = time.strftime("%Y%m%d_%H%M%S") - log_file = f"logs/inference_{timestamp}.log" - - handlers = [logging.FileHandler(log_file)] - if verbosity > 0: - handlers.append(logging.StreamHandler(sys.stderr)) - handlers[-1].setLevel(levels[verbosity]) - - # Configure root logger - logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=handlers, - force=True, # Override any existing configuration - ) - - return logs_dir_name - - -def save_layer_data(module, input, output, name, input_data_path, output_data_path): - for count, i in enumerate(input): - torch.save( - i.detach(), - f"{input_data_path}/{name}_input_{count}_{input[0].size()[1]}_toks.pt", - ) - torch.save( - output.detach(), f"{output_data_path}/{name}_output_{output.size()[1]}_toks.pt" - ) - - -def inference( - weights_file_path, - tokenizer_file_path, - num_tokens, - prompt, - use_prompt_template, - save_outputs, - chat: bool, - prompt_len: int = 64, -): - """ - Main function to load a Llama3 model, process input, and generate output text. - """ - logging.info("Weights file path: %s", weights_file_path) - logging.info("Tokenizer file path: %s", tokenizer_file_path) - logging.info("Number of tokens: %d", num_tokens) - logging.debug("Prompt: %s", prompt) - logging.info("Use prompt template: %s", use_prompt_template) - logging.info("Save outputs: %s", save_outputs) - torch.manual_seed(1608560892) - input_data_path = "results/inputs" - output_data_path = "results/outputs" - - tokenizer = Tokenizer(tokenizer_file_path) - - print(_iron_chat) - if chat: - prompt = input("Enter your prompt: ").strip() - print("") - - logging.info(f"Loading model and tokenizer...") - token_ids = text_to_token_ids(prompt, tokenizer)[:, :prompt_len] - truncated_prompt = token_ids_to_text(token_ids, tokenizer) - - script_dir = os.path.dirname(os.path.abspath(__file__)) - config_path = os.path.join(script_dir, "configs", "llama32_1b.json") - model = Llama3ModelWithJSONConfig( - config_path=config_path, - prompt_length=prompt_len, - num_tokens=num_tokens, - ) - logging.info("Model and tokenizer loaded.") - - # Important: Set the seed again after initialization of the model. Each - # call that initializes an nn.Linear layer updates the RNG state, because - # weights are initialized with random values. For different JSON - # configurations, we initialize a different number of linear layers, - # so different configurations result in a different RNG state here. Since - # we use random numbers to sample from the token distribution during - # inference, it is important to have the same RNG state between runs so we - # can have reproducible results across configurations. - torch.manual_seed(1608560892) - - hook_handles = [] - if save_outputs: - if os.path.exists(output_data_path): - shutil.rmtree(output_data_path) - os.makedirs(output_data_path) - if os.path.exists(input_data_path): - shutil.rmtree(input_data_path) - os.makedirs(input_data_path) - for name, module in model.named_modules(): - handle = module.register_forward_hook( - lambda module, input, output, name=name, input_data_path=input_data_path, output_data_path=output_data_path: ( - save_layer_data( - module, input, output, name, input_data_path, output_data_path - ) - ) - ) - hook_handles.append(handle) - - device = torch.device("cpu") - model.to(device) - chat_tokenizer = ChatFormat(tokenizer) - - total_params = sum(p.numel() for p in model.parameters()) - total_params_normalized = total_params - model.tok_emb.weight.numel() - logging.info(f"Total number of parameters: {total_params:,}") - logging.info(f"Total number of unique parameters: {total_params_normalized:,}") - logging.info( - f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB" - ) - logging.info( - f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB" - ) - - combined_weights = load_file(weights_file_path) - # Get parameters from model config - model_config = { - "n_layers": model.cfg["n_layers"], - "emb_dim": model.cfg["emb_dim"], - "n_heads": model.cfg["n_heads"], - "n_kv_groups": model.cfg["n_kv_groups"], - "vocab_size": model.cfg["vocab_size"], - "context_length": model.cfg["context_length"], - "hidden_dim": model.cfg["hidden_dim"], - "rope_base": model.cfg["rope_base"], - "dtype": model.cfg["dtype"], - "rope_freq": model.cfg["rope_freq"], - } - load_weights_into_llama(model, model_config, combined_weights) - model.to(device) - del combined_weights - - logging.info("Preparing AIE operators...") - # At this point the model is fully described (operators and their dimensions and how to compile them) - AIEOperatorBase.get_default_context().compile_all() - AIEOperatorBase.get_default_context().prepare_runtime() - logging.info("AIE operator preparation completed.") - print(f"Starting text generation...") - print(f"Generating {num_tokens} tokens...") - print("=" * 55) - - prefill_end_time = None - - def set_prefill_time(): - nonlocal prefill_end_time - prefill_end_time = time.time() - - # Start total wall clock timing - start = time.time() - token_ids = generate( - model=model, - idx=token_ids.to(device), - max_new_tokens=num_tokens, - context_size=model.cfg["context_length"], - eos_id=tokenizer.special["<|end_of_text|>"], - hook_handles=hook_handles, - temperature=0.7, - top_k=50, - tokenizer=tokenizer, - prompt=truncated_prompt, - prefill_done_callback=set_prefill_time, - ) - end = time.time() - prefill_time = prefill_end_time - start - total_time = end - start - post_prefill_time = end - prefill_end_time if num_tokens > 0 else 0 - - tokens_per_second = (num_tokens - 1) / post_prefill_time if num_tokens > 1 else 0 - time_per_token = total_time / (num_tokens - 1) if num_tokens > 1 else prefill_time - - print("=" * 55) - print(" TIMING RESULTS:") - print(f" Total time: {total_time:.4f} seconds") - print(f" Prefill time: {prefill_time:.4f} seconds") - print(f" Tokens generated: {num_tokens}") - print(f" Tokens per second: {tokens_per_second:.2f}") - print( - f" Time per token: {time_per_token:.4f} seconds" - if num_tokens > 0 - else " Time per token: N/A" - ) - print("=" * 55) - - logging.info(f"Generation time: {total_time:.4f} sec") - logging.info(f"Total wall clock time: {total_time:.4f} sec") - logging.info(f"Tokens per second: {tokens_per_second:.2f}") - logging.info( - f"Time per token: {time_per_token:.4f} sec" - if num_tokens > 0 - else "Time per token: N/A" - ) - - output_text = token_ids_to_text(token_ids, tokenizer) - logging.info("Output text:\n %s", clean_text(output_text)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run Llama3 model inference.") - parser.add_argument( - "weights_file_path", - type=str, - help="Path to the weights file: model.safetensors", - ) - parser.add_argument( - "tokenizer_file_path", - type=str, - help="Path to the tokenizer file: tokenizer.model", - ) - parser.add_argument( - "--num_tokens", type=int, default=1, help="Number of tokens to predict." - ) - parser.add_argument( - "--prompt", - type=str, - default="", - help="Prompt for the model to generate text from.", - ) - parser.add_argument( - "--use_prompt_template", - action="store_true", - help="Use a prompt template for the model.", - ) - parser.add_argument( - "--save_outputs", - action="store_true", - help="Enable hooks to save outputs of the layers in the model", - ) - parser.add_argument( - "--chat", - action="store_true", - help="Enable interactive mode to enter your own prompt.", - ) - parser.add_argument( - "--prompt_len", - type=int, - default=2048, - help="Truncate prompt to this many tokens.", - ) - parser.add_argument( - "--profile", - action="store_true", - help="Use a custom profiler for performance measurements", - ) - parser.add_argument( - "-v", - action="count", - default=0, - help="Increase verbosity level (use -v (logs to file), -vv, -vvv, or -vvvv)", - ) - args = parser.parse_args() - - # Set up logging - logs_dir_name = setup_logging(args.v) - - # Enable function profiling - if args.profile: - enable_profiling(logs_dir_name) - - try: - prompt = args.prompt - if not prompt: - # Default prompt is text from Shakespeare's King Lear: https://shakespeare.mit.edu/lear/lear.1.1.html - prompt_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "prompt.txt" - ) - with open(prompt_path, "r", encoding="utf-8") as file: - prompt = file.read().strip() - - inference( - args.weights_file_path, - args.tokenizer_file_path, - args.num_tokens, - prompt, - args.use_prompt_template, - args.save_outputs, - args.chat, - args.prompt_len, - ) - finally: - if args.profile: - # Disable profiling when done - disable_profiling() diff --git a/iron/applications/llama_3.2_1b/llama_cpu.py b/iron/applications/llama_3.2_1b/llama_cpu.py new file mode 100755 index 00000000..0def104a --- /dev/null +++ b/iron/applications/llama_3.2_1b/llama_cpu.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import math +import llama_inference_harness as harness + +# Operators +# ########################################################################## + + +def rope_forward(x, angles): + """Rotary positional embedding using precomputed angles""" + # x: (batch, seq_len, num_heads, head_dim) after view and before transpose + # angles: (context_length, head_dim) + _, seq_len, _, head_dim = x.shape + angles_slice = angles[:seq_len] # (seq_len, head_dim) + + # Split into even and odd dimensions + x1 = x[..., : head_dim // 2] # (batch, seq_len, num_heads, head_dim//2) + x2 = x[..., head_dim // 2 :] # (batch, seq_len, num_heads, head_dim//2) + + # Get cos and sin from angles + cos = angles_slice[:, ::2] # (seq_len, head_dim//2) + sin = angles_slice[:, 1::2] # (seq_len, head_dim//2) + + # Reshape for broadcasting: (1, seq_len, 1, head_dim//2) + # (The same cosine and sine values are used across batch and heads.) + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + # Rotate: [x1*cos - x2*sin, x1*sin + x2*cos] + rotated = torch.empty_like(x) + rotated[..., : head_dim // 2] = x1 * cos - x2 * sin + rotated[..., head_dim // 2 :] = x1 * sin + x2 * cos + + return rotated + + +def rms_norm_forward(x, weight, eps=1e-5): + """Root Mean Square Layer Normalization""" + # x: (batch, seq_len, dim) + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + return weight * x + + +def grouped_query_attention_forward( + x, + keys_cache, + values_cache, + W_query, + W_key, + W_value, + W_out, + angles, + mask=None, + num_heads=32, + num_kv_groups=8, +): + batch, seq_len, d_in = x.shape + assert W_query.shape[0] >= num_heads and W_query.shape[0] % num_heads == 0 + head_dim = W_query.shape[0] // num_heads + assert W_key.shape[0] == num_kv_groups * head_dim + assert W_value.shape[0] == num_kv_groups * head_dim + num_preceding_tokens = keys_cache.shape[2] + assert keys_cache.shape == (batch, num_kv_groups, num_preceding_tokens, head_dim) + assert values_cache.shape == (batch, num_kv_groups, num_preceding_tokens, head_dim) + + # Step 1: Linear projections + # This multiplication produces queries, keys and values for all tokens in the sequence. + # The weight matrix is such that multiple queries, keys and values are generated for each token. + # For each token, each head corresponds to one query. + # In particular, each token gets `num_heads` queries and `num_kv_groups` keys/values (keys/values shared for multiple queries). + # Due to the structure of the matmul, all queries, keys and values are contiguous for each token. + # Note that during the decode phase, seq_len=1, and we are only calculating the projections for the most recent token -- the keys and values of previous tokens will be concatenated in step 4. + queries = torch.nn.functional.linear( + x, W_query + ) # (batch, seq_len, num_heads * head_dim) + keys = torch.nn.functional.linear( + x, W_key + ) # (batch, seq_len, num_kv_groups * head_dim) + values = torch.nn.functional.linear( + x, W_value + ) # (batch, seq_len, num_kv_groups * head_dim) + queries = queries.view( + batch, seq_len, num_heads, head_dim + ) # (batch, seq_len, num_heads, head_dim) + keys = keys.view( + batch, seq_len, num_kv_groups, head_dim + ) # (batch, seq_len, num_kv_groups, head_dim) + values = values.view( + batch, seq_len, num_kv_groups, head_dim + ) # (batch, seq_len, num_kv_groups, head_dim) + + # Step 2: Apply RoPE + queries = rope_forward( + queries, angles[num_preceding_tokens : num_preceding_tokens + seq_len] + ) + keys = rope_forward( + keys, angles[num_preceding_tokens : num_preceding_tokens + seq_len] + ) + + # Step 3: Transpose for attention computation + # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. + # Transpose so that heads are consecutive for attention computation: (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) + queries = queries.transpose(1, 2) # (batch, num_heads, seq_len, head_dim) + keys = keys.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) + values = values.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) + + # Step 4: Combine newly computed keys/values for most recent token with cache; these values are used as the updated cache and will be returned to use in the next iteration. + keys_cache = torch.cat([keys_cache, keys], dim=2) + values_cache = torch.cat([values_cache, values], dim=2) + keys = keys_cache + values = values_cache + + # Step 5: Repeat keys and values for grouped attention -- multiple queries get the same key/value + group_size = num_heads // num_kv_groups + keys = keys.repeat_interleave(group_size, dim=1) + values = values.repeat_interleave(group_size, dim=1) + + # Step 6: Compute attention scores + # (batch, num_heads, seq_len, head_dim) @ (batch, num_heads, head_dim, seq_len) + # -> (batch, num_heads, seq_len, seq_len) + # Entry at row i, column j, indicates how much token i's query attends to token j's key. + scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(head_dim) + + # Step 7: Apply mask + # This ensures causality, so that tokens in the future cannot attend to tokens in the past. + if mask is not None: + scores = scores.masked_fill(mask, float("-inf")) + + # Step 8: Apply softmax to squeeze scores into probabilities (0, 1) + attention_weights = torch.nn.functional.softmax(scores, dim=-1) + + # Step 9: Compute attention output + # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim) + # -> (batch, num_heads, seq_len, head_dim) + context = torch.matmul(attention_weights, values) + + # Step 10: Concatenate heads and project + # (batch, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads * head_dim) + context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) + + output = torch.nn.functional.linear(context, W_out) + + return output, keys_cache, values_cache + + +def swiglu_ffn_forward(x, fc1_weight, fc2_weight, fc3_weight): + # Step 1: Parallel projections: (batch, seq_len, embedding_dim) -> (batch, seq_len, swiglu_hidden_dim) + gate = torch.nn.functional.linear(x, fc1_weight) # gate projection + up = torch.nn.functional.linear(x, fc2_weight) # up projection + + # Step 2: Apply SiLU activation + gate_activated = torch.nn.functional.silu( + gate + ) # (batch, seq_len, swiglu_hidden_dim) + + # Step 3: Element-wise multiplication (apply the 'gating') + hidden = gate_activated * up # (batch, seq_len, swiglu_hidden_dim) + + # Step 4: Down projection: (batch, seq_len, swiglu_hidden_dim) -> (batch, seq_len, embedding_dim) + output = torch.nn.functional.linear(hidden, fc3_weight) + + return output + + +def transformer_block_forward( + x, + attn_keys_cache, + attn_values_cache, + num_heads, + num_kv_groups, + W_norm1, + W_attn_query, + W_attn_key, + W_attn_value, + W_attn_out, + W_norm2, + W_ffn_fc1, + W_ffn_fc2, + W_ffn_fc3, + rope_angles, + attn_mask, +): + # Step 1: RMS normalization + x_norm = rms_norm_forward(x, W_norm1) + + # Step 2: Attention + attn_output, attn_keys, attn_values = grouped_query_attention_forward( + x_norm, + attn_keys_cache, + attn_values_cache, + W_attn_query, + W_attn_key, + W_attn_value, + W_attn_out, + rope_angles, + attn_mask, + num_heads, + num_kv_groups, + ) + + # Step 3: Residual + x = x + attn_output + + # Step 4: Post-norm + x_norm = rms_norm_forward(x, W_norm2) + + # Step 5: fully-connected feed-forward network + ffn_output = swiglu_ffn_forward(x_norm, W_ffn_fc1, W_ffn_fc2, W_ffn_fc3) + + # Step 6: Residual + x = x + ffn_output + + return x, attn_keys, attn_values + + +def llama_forward_pass(config, state): + batch, seq_len = state.token_ids.shape + + # Step 1: Token embedding + tok_emb_weight = config.weights["model.embed_tokens.weight"] + x = torch.nn.functional.embedding( + state.token_ids, tok_emb_weight + ) # (batch, seq_len, emb_dim) + + # Step 2: Create causal mask + attn_mask = torch.triu( + torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1 + ) + + # Step 3: Apply transformer blocks + for layer_idx in range(config.n_layers): + x, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = ( + transformer_block_forward( + x, + state.attn_keys_caches[layer_idx], + state.attn_values_caches[layer_idx], + config.n_heads, + config.n_kv_groups, + W_norm1=config.weights[ + f"model.layers.{layer_idx}.input_layernorm.weight" + ], + W_attn_query=config.weights[ + f"model.layers.{layer_idx}.self_attn.q_proj.weight" + ], + W_attn_key=config.weights[ + f"model.layers.{layer_idx}.self_attn.k_proj.weight" + ], + W_attn_value=config.weights[ + f"model.layers.{layer_idx}.self_attn.v_proj.weight" + ], + W_attn_out=config.weights[ + f"model.layers.{layer_idx}.self_attn.o_proj.weight" + ], + W_ffn_fc1=config.weights[ + f"model.layers.{layer_idx}.mlp.gate_proj.weight" + ], + W_ffn_fc2=config.weights[ + f"model.layers.{layer_idx}.mlp.up_proj.weight" + ], + W_ffn_fc3=config.weights[ + f"model.layers.{layer_idx}.mlp.down_proj.weight" + ], + W_norm2=config.weights[ + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + ], + rope_angles=config.angles, + attn_mask=attn_mask, + ) + ) + + # Step 4: Final normalization + final_norm_weight = config.weights["model.norm.weight"] + x = rms_norm_forward(x, final_norm_weight) + + # Step 5: Output projection + logits = torch.nn.functional.linear( + x, config.weights["model.embed_tokens.weight"] + ) # (batch, seq_len, vocab_size) + + return logits, state + + +# Main +# ########################################################################## + + +def main(): + args = harness.parse_args() + prompt = harness.get_prompt(args.prompt_len) + config, state = harness.init(args.weights_path, args.tokenizer_path, prompt=prompt) + print(prompt, end="", flush=True) + harness.generate(config, state, llama_forward_pass, num_tokens=args.num_tokens) + + +if __name__ == "__main__": + main() diff --git a/iron/applications/llama_3.2_1b/llama_inference_harness.py b/iron/applications/llama_3.2_1b/llama_inference_harness.py new file mode 100644 index 00000000..2e7b7f7a --- /dev/null +++ b/iron/applications/llama_3.2_1b/llama_inference_harness.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Inference harness -- all the necessary code _other_ than the actual model (forward pass). +Exposes a 'harness' function that can be called with a 'forward_pass' function that implements the model. +The 'harness' function does the following: +1. Load and set up model weights, tokenizer, and RoPE angle look-up table. +2. Tokenize the provided input prompt. +3. Run the generation loop to produce new tokens; this calls the provided forward_pass function. Decode and print each generated token. +""" + +import torch +import math +import sys +import time +import argparse + +import safetensors.torch +import tiktoken, tiktoken.load + +# Configuration +# ########################################################################## + + +class LlamaConfig: + def __init__(self, weights_path, tokenizer_path): + # Model architecture + self.vocab_size = 128256 + self.emb_dim = 2048 + self.n_layers = 16 + self.n_heads = 32 + self.n_kv_groups = 8 + self.head_dim = self.emb_dim // self.n_heads # 64 + self.hidden_dim = 8192 + + # RoPE + self.rope_base = 500000.0 + self.context_length = 131072 + + # Generation + self.temperature = 0.7 + self.top_k = 50 + + # Tokenization + self.special_tokens = { + "<|begin_of_text|>": 128000, + "<|end_of_text|>": 128001, + "<|start_header_id|>": 128006, + "<|end_header_id|>": 128007, + "<|eot_id|>": 128009, + } + self.special_tokens.update( + { + f"<|reserved_{i}|>": i + for i in list(range(128002, 128006)) + list(range(128009, 128256)) + } + ) + + # Load model weights and tokenizer + self.weights = safetensors.torch.load_file(weights_path) + self.tokenizer = get_tokenizer(tokenizer_path, self.special_tokens) + # TODO: Assert that weight dimensions match config + + # Compute RoPE angle look-up table + self.angles = compute_rope_angles( + self.head_dim, self.context_length, self.rope_base + ) + + +class LlamaModelState: + def __init__(self, config): + # Current IDs of tokens being processed (most recent token for decode; all prompt tokens for prefill) + self.token_ids = torch.empty(0, dtype=torch.long) + self.reset_kv_cache(config) + + def reset_kv_cache(self, config): + # Set up KV cache -- initially empty + # This is what passes information from previous tokens to the current token during generation + self.attn_keys_caches = [ + torch.empty( + 1, + config.n_kv_groups, + 0, + config.head_dim, + dtype=config.weights["model.layers.0.self_attn.k_proj.weight"].dtype, + ) # (batch_size, n_kv_groups, seq_len, head_dim) + for _ in range(config.n_layers) + ] + self.attn_values_caches = [ + torch.empty( + 1, + config.n_kv_groups, + 0, + config.head_dim, + dtype=config.weights["model.layers.0.self_attn.v_proj.weight"].dtype, + ) # (batch_size, n_kv_groups, seq_len, head_dim) + for _ in range(config.n_layers) + ] + + +# Utilities +# ########################################################################## + + +def compute_rope_angles(head_dim, context_length, rope_base=500000.0): + """Compute RoPE (Rotary Position Embedding) angles.""" + # Precompute the frequency tensor + inv_freq = 1.0 / (rope_base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + position = torch.arange(context_length).float() + freqs = torch.outer(position, inv_freq) + + cos = torch.cos(freqs) + sin = torch.sin(freqs) + + # Interleave cos and sin - create angles buffer + angles = torch.empty(context_length, head_dim) + angles[:, ::2] = cos + angles[:, 1::2] = sin + return angles + + +def get_tokenizer(tokenizer_path, special_tokens): + mergeable = tiktoken.load.load_tiktoken_bpe(tokenizer_path) + return tiktoken.Encoding( + name="llama3.2-1b", + pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)" + r"|[^\r\n\p{L}\p{N}]?\p{L}+" + r"|\p{N}{1,3}" + r"| ?[^\s\p{L}\p{N}]+[\r\n]*" + r"|\s*[\r\n]+" + r"|\s+(?!\S)" + r"|\s+", + mergeable_ranks=mergeable, + special_tokens=special_tokens, + ) + + +# Generation loop +# ########################################################################## + + +def generate_token(config, forward_pass, state): + generated_tokens = [] + + # Step 1: Forward pass + logits, state = forward_pass(config, state) + + # Step 2: Get logits for last token + last_token_logits = logits[:, -1, :] # (batch, vocab_size) + + # Step 3: Temperature scaling + if config.temperature > 0: + last_token_logits = last_token_logits / config.temperature + + # Step 4: Top-k filtering + if config.top_k is not None: + top_logits, top_indices = torch.topk(last_token_logits, config.top_k) + min_val = top_logits[:, -1:] + last_token_logits = torch.where( + last_token_logits < min_val, torch.tensor(float("-inf")), last_token_logits + ) + + # Step 5: Sample + probs = torch.nn.functional.softmax(last_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + + return next_token.item(), state + + +def parse_args(): + parser = argparse.ArgumentParser(description="LLaMA 3.2 1B Inference Harness") + parser.add_argument( + "weights_path", type=str, help="Path to the model weights (safetensors file)" + ) + parser.add_argument( + "tokenizer_path", type=str, help="Path to the tokenizer model (tiktoken file)" + ) + parser.add_argument( + "--prompt-len", + type=int, + default=2048, + help="Length of the input prompt in tokens (default: 2048)", + ) + parser.add_argument( + "--num-tokens", + type=int, + default=40, + help="Number of tokens to generate (default: 40)", + ) + return parser.parse_args() + + +def get_prompt(prompt_len): + with open("prompt.txt", "r") as f: + prompt = f.read() + prompt = prompt[:prompt_len] + return prompt + + +def init( + weights_path, + tokenizer_path, + prompt="The capital of France is ", +): + config = LlamaConfig(weights_path, tokenizer_path) + state = LlamaModelState(config) + + seed = 1608560892 + torch.manual_seed(seed) + + # Tokenize prompt + prompt_token_ids = [config.special_tokens["<|begin_of_text|>"]] + prompt_token_ids += config.tokenizer.encode(prompt) + assert ( + len(prompt_token_ids) <= config.context_length + ), "Prompt + new tokens to generate too long (exceed context)" + prompt_token_ids = torch.tensor([prompt_token_ids], dtype=torch.long) + + state.token_ids = prompt_token_ids + + return config, state + + +def generate(config, state, forward_pass, num_tokens=100, use_kv_cache=True): + # Generate tokens + # First token (prefill) + n_tokens_generated = 0 + t_prefill_start = time.perf_counter() + first_token, state = generate_token(config, forward_pass, state) + token_text = config.tokenizer.decode([first_token]) + n_tokens_generated += 1 + print(token_text, end="", flush=True) + t_prefill_stop = time.perf_counter() + + # Remaining tokens (decode) + if use_kv_cache: + state.token_ids = torch.tensor([[first_token]], dtype=torch.long) + else: + state.reset_kv_cache(config) + state.token_ids = torch.cat( + [state.token_ids, torch.tensor([[first_token]], dtype=torch.long)], dim=1 + ) + t_decode_start = time.perf_counter() + for _ in range(num_tokens - 1): + next_token, state = generate_token(config, forward_pass, state) + token_text = config.tokenizer.decode([next_token]) + n_tokens_generated += 1 + print(token_text, end="", flush=True) + if use_kv_cache: + state.token_ids = torch.tensor([[next_token]], dtype=torch.long) + else: + state.reset_kv_cache(config) + state.token_ids = torch.cat( + [state.token_ids, torch.tensor([[next_token]], dtype=torch.long)], dim=1 + ) + t_decode_end = time.perf_counter() + + t_prefill = t_prefill_stop - t_prefill_start + t_decode = t_decode_end - t_decode_start + sys.stderr.write("\n\n=== Performance Statistics ===\n") + sys.stderr.write(f"[Prefill] Time to first token: {t_prefill:7.3f} s\n") + if n_tokens_generated > 1: + sys.stderr.write( + f"[Decode] Time per token (mean): {t_decode / (n_tokens_generated - 1):7.3f} s\n" + ) + sys.stderr.write( + f"[Decode] Tokens per second: {(n_tokens_generated - 1) / t_decode:7.3f}\n" + ) + sys.stderr.write( + f"[Total] Time per token (mean): {(t_prefill + t_decode) / n_tokens_generated:7.3f} s\n" + ) + sys.stderr.write( + f"[Total] Tokens per second: {n_tokens_generated / (t_prefill + t_decode):7.3f}\n" + ) + + +if __name__ == "__main__": + main() diff --git a/iron/applications/llama_3.2_1b/llama_npu.py b/iron/applications/llama_3.2_1b/llama_npu.py new file mode 100755 index 00000000..9ca130cb --- /dev/null +++ b/iron/applications/llama_3.2_1b/llama_npu.py @@ -0,0 +1,1350 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Next steps for decode performance: +# [ ] All decode operators operate on 2048-padded buffers; instead, should bin into shorter sequence lengths and call smaller operators +# [ ] Opportunity to fuse data layout transformations (e.g., transpose ops) onto end of other operations (e.g., transpose after RoPE) +# [ ] Some kernels are not optimized; e.g., softmax masking is using scalar cores +# [ ] Fine-tune parameters of operators (e.g., num AIE columns, tile sizes) +# [ ] Patching of operators (instantiating new xrt::elf for each token) is slow; find quicker way of patching instruction sequence in-memory +# [ ] Spatial fusion of operators + +import torch +import math +from pathlib import Path +import sys +import numpy as np +import ml_dtypes +import llama_inference_harness as harness +import logging +import time + +repo_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(repo_root)) + +from iron.common.context import AIEContext +from iron.common import AIEBuffer +from iron.common.utils import torch_to_numpy +from iron.common.base import PatchableSingleXclbinCallable +from iron.common.fusion import ( + FusedMLIROperator, + FusedFullELFCallable, + load_elf, + patch_elf, +) +from iron.operators import ( + AIERMSNorm, + AIEGEMM, + AIEGEMV, + AIEElementwiseAdd, + AIEElementwiseMul, + AIESiLU, + AIERope, + AIEStridedCopy, + AIERepeat, + AIESoftmax, + AIETranspose, +) + +logging.basicConfig(level=logging.DEBUG) + +max_seq_len = 2048 + + +# AIE Operator Configuration +# ########################################################################## + + +aie_ops = None + + +class AIEPrefillOperations: + pass + + +class AIEDecodeOperations: + pass + + +class AIELlamaOperators: + + def __init__(self, config, prompt_len): + self.context = AIEContext() + self.context.build_dir.mkdir(parents=True, exist_ok=True) + + self.prefill = AIEPrefillOperations() + self.decode = AIEDecodeOperations() + + # ################################################################## + # Prefill operators + + self.prefill.rms_norm = ( + AIERMSNorm( + size=prompt_len * config.emb_dim, + eps=1e-5, + num_aie_columns=8, + num_channels=2, + tile_size=config.emb_dim, + weighted=True, + context=self.context, + ) + .compile() + .get_callable() + ) + + self.prefill.residual_add = ( + AIEElementwiseAdd( + size=prompt_len * config.emb_dim, tile_size=config.emb_dim + ) + .compile() + .get_callable() + ) + + min_N = 64 * 8 * 4 # tile_n * num_aie_columns * partition_N + config.padded_vocab_size = (config.vocab_size + min_N - 1) // min_N * min_N + config.vocab_partitions = 4 + self.prefill.gemv_out_head_compilable = AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.padded_vocab_size // config.vocab_partitions, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=True, + separate_c_tiles=True, + context=self.context, + ).compile() + self.prefill.out_head = self.prefill.gemv_out_head_compilable.get_callable() + + # SwiGLU FFN operators + # Prefill: M=prompt_len, K=emb_dim, N=hidden_dim + self.prefill.ffn_up_gate = ( + AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.hidden_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, # exceeds stride dimensions otherwise; just transpose weights + context=self.context, + ) + .compile() + .get_callable() + ) + + self.prefill.ffn_down = ( + AIEGEMM( + M=prompt_len, + K=config.hidden_dim, + N=config.emb_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, # exceeds stride dimensions otherwise; just transpose weights + context=self.context, + ) + .compile() + .get_callable() + ) + + self.prefill.ffn_silu = ( + AIESiLU( + size=prompt_len * config.hidden_dim, + tile_size=config.hidden_dim, + num_aie_columns=8, + context=self.context, + ) + .compile() + .get_callable() + ) + + self.prefill.eltwise_mul_ffn = ( + AIEElementwiseMul( + size=prompt_len * config.hidden_dim, + tile_size=config.hidden_dim, + num_aie_columns=8, + context=self.context, + ) + .compile() + .get_callable() + ) + + # Attention score scaling operators + # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector; maybe use AXPY + self.prefill.attn_scale = ( + AIEElementwiseMul( + size=config.n_heads * prompt_len * prompt_len, + tile_size=prompt_len, + num_aie_columns=8, + context=self.context, + ) + .compile() + .get_callable() + ) + + # RoPE operators + # For queries: (seq_len, num_heads * head_dim) = (seq_len, 2048) + # For keys: (seq_len, num_kv_groups * head_dim) = (seq_len, 512) + # angle_rows=1 because all rows use the same angle row (angles are per position) + self.prefill.rope_queries = ( + AIERope( + rows=prompt_len * config.n_heads, + cols=config.head_dim, + angle_rows=prompt_len, + context=self.context, + ) + .compile() + .get_callable() + ) + + self.prefill.rope_keys = ( + AIERope( + rows=prompt_len * config.n_kv_groups, + cols=config.head_dim, + angle_rows=prompt_len, + context=self.context, + ) + .compile() + .get_callable() + ) + + # Attention projection operators + # Query projection: (seq_len, emb_dim) -> (seq_len, n_heads * head_dim) + self.prefill.attn_query = ( + AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.n_heads * config.head_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + context=self.context, + ) + .compile() + .get_callable() + ) + + # Key projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) + self.prefill.attn_key = ( + AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.n_kv_groups * config.head_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + context=self.context, + ) + .compile() + .get_callable() + ) + + # Value projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) + self.prefill.attn_value = ( + AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.n_kv_groups * config.head_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + context=self.context, + ) + .compile() + .get_callable() + ) + + # Attention score computation: Q @ K^T per head + # For prefill: (seq_len, head_dim) @ (head_dim, seq_len) = (seq_len, seq_len) per head + self.prefill.attn_scores = ( + AIEGEMM( + M=prompt_len, + K=config.head_dim, + N=prompt_len, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + context=self.context, + ) + .compile() + .get_callable() + ) + + # Decode operator (everything temporally fused) + # ################################################################## + + elf_ctx = AIEContext(build_dir="build_elf") + + gemv_attn_query_op = AIEGEMV( + M=config.n_heads * config.head_dim, + K=config.emb_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.head_dim // 2, + context=elf_ctx, + ) + + gemv_attn_key_value_op = AIEGEMV( + M=config.n_kv_groups * config.head_dim, + K=config.emb_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.head_dim // 2, + context=elf_ctx, + ) + + rope_queries_op = AIERope( + rows=1 * config.n_heads, cols=config.head_dim, angle_rows=1, context=elf_ctx + ) + + rope_keys_op = AIERope( + rows=1 * config.n_kv_groups, + cols=config.head_dim, + angle_rows=1, + context=elf_ctx, + ) + + strided_copy_cache_magic = 0xDEADBEE0 + strided_copy_cache_op = AIEStridedCopy( + input_sizes=(config.n_kv_groups, config.head_dim), + input_strides=(config.head_dim, 1), + input_offset=0, + output_sizes=(1, config.n_kv_groups, config.head_dim), + output_strides=(0, prompt_len * config.head_dim, 1), + output_offset=7 * config.head_dim * 2, # Will be patched at runtime + input_buffer_size=1 * config.n_kv_groups * config.head_dim, + output_buffer_size=config.n_kv_groups * prompt_len * config.head_dim, + num_aie_channels=1, + output_offset_patch_marker=strided_copy_cache_magic, + context=elf_ctx, + ) + + # For decode: per head, (1, head_dim) @ (head_dim, max_context_len) + # Use GEMV: (max_context_len, head_dim) @ (head_dim,) = (max_context_len,) + gemv_attn_scores_op = AIEGEMV( + M=prompt_len, # max possible context length + K=config.head_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=prompt_len // 8, + num_batches=config.n_heads, + context=elf_ctx, + ) + + attn_scale_op = AIEElementwiseMul( + size=config.n_heads * prompt_len, + tile_size=prompt_len // 8, + num_aie_columns=8, + context=elf_ctx, + ) + + # Softmax operators for attention weights + softmax_magic = 0xBA5EBA11 + softmax_op = AIESoftmax( + rows=config.n_heads, + cols=prompt_len, + num_aie_columns=1, + num_channels=1, + rtp_vector_size=prompt_len, # Compile with max size + mask_patch_value=softmax_magic, # Magic value for patching + context=elf_ctx, + ) + + # Fused transpose for all attention heads (decode) + transpose_values_op = AIETranspose( + M=prompt_len, + N=config.head_dim, + num_aie_columns=2, + num_channels=1, + m=256, + n=32, + s=8, + context=elf_ctx, + ) + + # GEMV for attention context: (head_dim, max_context_len) @ (max_context_len,) = (head_dim,) per head + gemv_attn_context_op = AIEGEMV( + M=config.head_dim, + K=prompt_len, # max possible context length + num_aie_columns=8, + tile_size_input=4, + tile_size_output=4, + num_batches=config.n_heads, + context=elf_ctx, + ) + + gemv_attn_output_op = AIEGEMV( + M=config.emb_dim, + K=config.n_heads * config.head_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.emb_dim // 8, + context=elf_ctx, + ) + + rms_norm_op = AIERMSNorm( + size=config.emb_dim, + eps=1e-5, + num_aie_columns=1, + num_channels=2, + tile_size=config.emb_dim, + weighted=True, + context=elf_ctx, + ) + + gemv_ffn_up_gate_op = AIEGEMV( + M=config.hidden_dim, + K=config.emb_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.hidden_dim // 8, + context=elf_ctx, + ) + + gemv_ffn_down_op = AIEGEMV( + M=config.emb_dim, + K=config.hidden_dim, + num_aie_columns=8, + tile_size_input=1, + tile_size_output=config.emb_dim // 8, + context=elf_ctx, + ) + + silu_ffn_op = AIESiLU( + size=config.hidden_dim, + tile_size=config.hidden_dim // 8, + num_aie_columns=8, + context=elf_ctx, + ) + + eltwise_mul_ffn_op = AIEElementwiseMul( + size=config.hidden_dim, + tile_size=config.hidden_dim // 8, + num_aie_columns=8, + context=elf_ctx, + ) + + residual_add_op = AIEElementwiseAdd( + size=config.emb_dim, tile_size=config.emb_dim // 8, context=elf_ctx + ) + + repeat_interleave_op = AIERepeat( + rows=config.n_kv_groups, + cols=prompt_len * config.head_dim, # Max context length + repeat=config.n_heads // config.n_kv_groups, + transfer_size=config.head_dim, + context=elf_ctx, + ) + + gemv_out_head_op = AIEGEMV( + M=config.vocab_size, + K=config.emb_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=32, + context=self.context, + ) + + # Create fused operator + + cache_buffer_size = ( + config.n_kv_groups * prompt_len * config.head_dim * 2 + ) # * 2 for bfloat16 + values_per_head_buffer_size = ( + prompt_len * config.head_dim * 2 + ) # * 2 for bfloat16 + values_buffer_size = config.n_heads * values_per_head_buffer_size + + runlist = [] + for layer_idx in range(config.n_layers): + # + runlist.extend( + [ + ( + rms_norm_op, + "x", + f"W_norm1_{layer_idx}", + "x_norm", + ) # Step 1: RMS normalization + ] + + [ + # + ( + gemv_attn_query_op, + f"W_attn_query_{layer_idx}", + "x_norm", + "queries", + ), + ( + gemv_attn_key_value_op, + f"W_attn_key_{layer_idx}", + "x_norm", + "keys", + ), + ( + gemv_attn_key_value_op, + f"W_attn_value_{layer_idx}", + "x_norm", + "values", + ), + (rope_queries_op, "queries", "rope_angles", "queries"), + (rope_keys_op, "keys", "rope_angles", "keys"), + (strided_copy_cache_op, "keys", f"keys_cache_{layer_idx}"), + (strided_copy_cache_op, "values", f"values_cache_{layer_idx}"), + ( + repeat_interleave_op, + f"keys_cache_{layer_idx}", + "attn_scores_keys", + ), + ( + repeat_interleave_op, + f"values_cache_{layer_idx}", + "attn_scores_values", + ), + (gemv_attn_scores_op, "attn_scores_keys", "queries", "attn_scores"), + (attn_scale_op, "attn_scores", "attn_scale_factor", "attn_scores"), + (softmax_op, "attn_scores", "attn_weights"), + ] + + [ + ( + transpose_values_op, + f"attn_scores_values[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]", + f"attn_scores_values_transposed[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]", + ) + for h in range(config.n_heads) + ] + + [ + ( + gemv_attn_context_op, + "attn_scores_values_transposed", + "attn_weights", + "attn_context", + ), + ( + gemv_attn_output_op, + f"W_attn_output_decode_{layer_idx}", + "attn_context", + "attn_output", + ), + # + ] + + [ + (residual_add_op, "x", "attn_output", "x"), + (rms_norm_op, "x", f"W_norm2_{layer_idx}", "x_norm"), + ( + gemv_ffn_up_gate_op, + f"W_ffn_gate_{layer_idx}", + "x_norm", + "ffn_gate", + ), + (gemv_ffn_up_gate_op, f"W_ffn_up_{layer_idx}", "x_norm", "ffn_up"), + (silu_ffn_op, "ffn_gate", "ffn_gate"), + (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), + ( + gemv_ffn_down_op, + f"W_ffn_down_{layer_idx}", + "ffn_hidden", + "ffn_output", + ), + (residual_add_op, "x", "ffn_output", "x"), + ] + ) + # + runlist += [ + (rms_norm_op, "x", "W_final_norm", "x"), + (gemv_out_head_op, "W_out_head", "x", "logits"), + ] + + self.decode.fused_op = FusedMLIROperator( + "fused_op", + runlist, + input_args=[ # arguments that change between invocations of the fused kernel and therefore need to be synced on each token + "x", + "rope_angles", + ], + output_args=["logits"], + buffer_sizes={ + **{ + f"keys_cache_{layer_idx}": cache_buffer_size + for layer_idx in range(config.n_layers) + }, + **{ + f"values_cache_{layer_idx}": cache_buffer_size + for layer_idx in range(config.n_layers) + }, + **{ + "attn_scores_values": values_buffer_size, + "attn_scores_values_transposed": values_buffer_size, + }, + }, + context=elf_ctx, + ).compile() + + # Operator patching + + self.decode.fused_elf_data = load_elf(self.decode.fused_op) + + def get_patch_locs(elf_data, magic): + magic = magic & 0xFFFFFFFF + return np.where(elf_data == magic)[0] + + keys_patches = {} + values_patches = {} + for layer_idx in range(config.n_layers): + _, keys_cache_offs, _ = self.decode.fused_op.get_layout_for_buffer( + f"keys_cache_{layer_idx}" + ) + _, values_cache_offs, _ = self.decode.fused_op.get_layout_for_buffer( + f"values_cache_{layer_idx}" + ) + keys_patches.update( + { + int(l): keys_cache_offs + for l in get_patch_locs( + self.decode.fused_elf_data, + (keys_cache_offs + strided_copy_cache_magic * 2), + ) + } + ) + values_patches.update( + { + int(l): values_cache_offs + for l in get_patch_locs( + self.decode.fused_elf_data, + (values_cache_offs + strided_copy_cache_magic * 2), + ) + } + ) + no_offset_patches = { + int(l): 0 + for l in get_patch_locs( + self.decode.fused_elf_data, (strided_copy_cache_magic * 2) + ) + } + self.decode.fused_patch_locations = { + **keys_patches, + **values_patches, + **no_offset_patches, + } + assert len(self.decode.fused_patch_locations) == 4 * config.n_layers + 2 + + self.decode.softmax_patch_offsets = get_patch_locs( + self.decode.fused_elf_data, softmax_magic + ) + assert len(self.decode.softmax_patch_offsets) == config.n_layers + 1 + + self.decode.fused = FusedFullELFCallable( + self.decode.fused_op, elf_data=self.decode.fused_elf_data + ) + + # Operator static buffers (weights, LUTs) + + for layer_idx in range(config.n_layers): + self.decode.fused.get_buffer(f"W_norm1_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.input_layernorm.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_attn_query_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.self_attn.q_proj.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_attn_key_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.self_attn.k_proj.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_attn_value_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.self_attn.v_proj.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_attn_output_decode_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.self_attn.o_proj.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_norm2_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_ffn_gate_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.mlp.gate_proj.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_ffn_up_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.mlp.up_proj.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_ffn_down_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.mlp.down_proj.weight" + ].flatten() + scale_factor = 1.0 / math.sqrt(config.head_dim) + self.decode.fused.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[ + : + ] = scale_factor + self.decode.fused.get_buffer("W_final_norm").to("cpu").view_as_torch()[:] = ( + config.weights["model.norm.weight"].flatten() + ) + self.decode.fused.get_buffer("W_out_head").to("cpu").view_as_torch()[:] = ( + config.weights["model.embed_tokens.weight"].flatten() + ) + self.decode.fused.input_buffer.to("npu") + self.decode.fused.scratch_buffer.to("npu") + self.decode.fused.output_buffer.to("npu") + + +# Allocate buffers shared with NPU +# ########################################################################## + +aie_buffers = None + + +class AIEPrefillBuffers: + def __init__(self, prompt_len, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): + self.x = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) + self.x_norm = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) + self.attn_output = AIEBuffer( + shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16 + ) + self.ffn_output = AIEBuffer( + shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16 + ) + # SwiGLU intermediate buffers + self.ffn_gate = AIEBuffer( + shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16 + ) + self.ffn_up = AIEBuffer( + shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16 + ) + self.ffn_hidden = AIEBuffer( + shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16 + ) + # Attention buffers: queries and keys serve as both projection output and RoPE input/output + self.queries = AIEBuffer( + shape=(prompt_len * n_heads, head_dim), dtype=ml_dtypes.bfloat16 + ) + self.keys = AIEBuffer( + shape=(prompt_len * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16 + ) + self.values = AIEBuffer( + shape=(prompt_len, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16 + ) + self.rope_angles = AIEBuffer( + shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16 + ) + # Attention score computation buffers (per-head) - parent buffers with subbuffers + # Parent buffer for all heads' queries: (n_heads, prompt_len, head_dim) stored contiguously + self.attn_scores_queries_all = AIEBuffer( + shape=(n_heads * prompt_len, head_dim), dtype=ml_dtypes.bfloat16 + ) + self.attn_scores_queries_per_head = [ + self.attn_scores_queries_all.subbuffer( + length=prompt_len * head_dim, + offset=h * prompt_len * head_dim, + shape=(prompt_len, head_dim), + ) + for h in range(n_heads) + ] + # Parent buffer for all KV groups' keys: (n_kv_groups, head_dim, prompt_len) stored contiguously + self.attn_scores_keys_all = AIEBuffer( + shape=(n_kv_groups * head_dim, prompt_len), dtype=ml_dtypes.bfloat16 + ) + self.attn_scores_keys_per_kv_group = [ + self.attn_scores_keys_all.subbuffer( + length=head_dim * prompt_len, + offset=g * head_dim * prompt_len, + shape=(head_dim, prompt_len), + ) + for g in range(n_kv_groups) + ] + # Parent buffer for all heads' scores: (n_heads * prompt_len, prompt_len) + self.attn_scores = AIEBuffer( + shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16 + ) + self.attn_scores_per_head = [ + self.attn_scores.subbuffer( + length=prompt_len * prompt_len, + offset=h * prompt_len * prompt_len, + shape=(prompt_len, prompt_len), + ) + for h in range(n_heads) + ] + # Attention score scaling buffer (pre-initialized with 1/sqrt(head_dim)) + scale_factor = 1.0 / math.sqrt(head_dim) + self.attn_scale_factor = AIEBuffer( + shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16 + ) + self.attn_scale_factor.view_as_torch()[:] = scale_factor + self.attn_scale_factor.to("npu") + # Attention weights buffer (output of softmax) + self.attn_weights = AIEBuffer( + shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16 + ) + + +class AIELlamaBuffers: + def __init__(self, config, prompt_len): + # Vector of the current token(s) being processed through the pipeline + self.prefill = AIEPrefillBuffers( + prompt_len, + config.emb_dim, + config.hidden_dim, + config.n_heads, + config.n_kv_groups, + config.head_dim, + ) + + # Per-layer KV cache buffers on NPU (used by strided copy for transpose and concatenate) + self.keys_cache = [ + AIEBuffer( + shape=(config.n_kv_groups, prompt_len, config.head_dim), + dtype=ml_dtypes.bfloat16, + ) + for _ in range(config.n_layers) + ] + self.values_cache = [ + AIEBuffer( + shape=(config.n_kv_groups, prompt_len, config.head_dim), + dtype=ml_dtypes.bfloat16, + ) + for _ in range(config.n_layers) + ] + + # Transformer block layer-wise RMS norm + self.W_norm1 = [] + self.W_norm2 = [] + # Attention projection weights + self.W_attn_query_prefill = [] + self.W_attn_query_decode = [] + self.W_attn_key_prefill = [] + self.W_attn_key_decode = [] + self.W_attn_value_prefill = [] + self.W_attn_value_decode = [] + self.W_attn_output_decode = [] + # SwiGLU FFN weights + self.W_ffn_gate_prefill = [] + self.W_ffn_up_prefill = [] + self.W_ffn_down_prefill = [] + self.W_ffn_gate_decode = [] + self.W_ffn_up_decode = [] + self.W_ffn_down_decode = [] + for layer_idx in range(config.n_layers): + self.W_norm1.append( + AIEBuffer.from_torch( + config.weights[f"model.layers.{layer_idx}.input_layernorm.weight"] + ).to("npu") + ) + self.W_norm2.append( + AIEBuffer.from_torch( + config.weights[ + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + ] + ).to("npu") + ) + self.W_attn_query_prefill.append( + AIEBuffer.from_torch( + config.weights[ + f"model.layers.{layer_idx}.self_attn.q_proj.weight" + ].T + ).to("npu") + ) + self.W_attn_key_prefill.append( + AIEBuffer.from_torch( + config.weights[ + f"model.layers.{layer_idx}.self_attn.k_proj.weight" + ].T + ).to("npu") + ) + self.W_attn_value_prefill.append( + AIEBuffer.from_torch( + config.weights[ + f"model.layers.{layer_idx}.self_attn.v_proj.weight" + ].T + ).to("npu") + ) + self.W_ffn_gate_prefill.append( + AIEBuffer.from_torch( + config.weights[f"model.layers.{layer_idx}.mlp.gate_proj.weight"].T + ).to("npu") + ) + self.W_ffn_up_prefill.append( + AIEBuffer.from_torch( + config.weights[f"model.layers.{layer_idx}.mlp.up_proj.weight"].T + ).to("npu") + ) + self.W_ffn_down_prefill.append( + AIEBuffer.from_torch( + config.weights[f"model.layers.{layer_idx}.mlp.down_proj.weight"].T + ).to("npu") + ) + + # Final RMS norm weights + self.W_final_norm = AIEBuffer.from_torch( + config.weights["model.norm.weight"] + ).to("npu") + # Final linear layer + self.W_out_head = AIEBuffer.from_torch( + config.weights["model.embed_tokens.weight"] + ).to( + "npu" + ) # unpadded/unpartitioned, used by GEMV + W_out_head_parts = aie_ops.prefill.gemv_out_head_compilable.partition_B( + torch_to_numpy(config.weights["model.embed_tokens.weight"]), + config.vocab_partitions, + ) + self.W_out_head_parts = [ + AIEBuffer.from_np(W_out_head_part).to("npu") + for W_out_head_part in W_out_head_parts + ] # partitioned, padded parts of weight, used by GEMM + self.prefill.logits = AIEBuffer( + shape=( + config.vocab_partitions, + prompt_len, + config.padded_vocab_size // config.vocab_partitions, + ) + ).to("npu") + self.prefill.logits_parts = [ + self.prefill.logits.subbuffer( + length=prompt_len + * (config.padded_vocab_size // config.vocab_partitions), + offset=i + * prompt_len + * (config.padded_vocab_size // config.vocab_partitions), + shape=(prompt_len, config.padded_vocab_size // config.vocab_partitions), + ) + for i in range(config.vocab_partitions) + ] + + +# Prefill +# ########################################################################## + + +def grouped_query_attention_forward_prefill( + config, + x, + keys_cache, + values_cache, + layer_idx, + mask=None, +): + batch, seq_len, emb_dim = x.shape + num_preceding_tokens = keys_cache.shape[2] + + # Step 1: Linear projections + aie_ops.prefill.attn_query( + aie_buffers.prefill.x_norm, + aie_buffers.W_attn_query_prefill[layer_idx], + aie_buffers.prefill.queries, + ) + aie_ops.prefill.attn_key( + aie_buffers.prefill.x_norm, + aie_buffers.W_attn_key_prefill[layer_idx], + aie_buffers.prefill.keys, + ) + aie_ops.prefill.attn_value( + aie_buffers.prefill.x_norm, + aie_buffers.W_attn_value_prefill[layer_idx], + aie_buffers.prefill.values, + ) + + # Step 2: Apply RoPE to queries and keys + aie_ops.prefill.rope_queries( + aie_buffers.prefill.queries, + aie_buffers.prefill.rope_angles, + aie_buffers.prefill.queries, + ) + aie_ops.prefill.rope_keys( + aie_buffers.prefill.keys, + aie_buffers.prefill.rope_angles, + aie_buffers.prefill.keys, + ) + + # Read results from NPU + queries = aie_buffers.prefill.queries.to("cpu").view_as_torch()[ + : seq_len * config.n_heads, : + ] + keys = aie_buffers.prefill.keys.to("cpu").view_as_torch()[ + : seq_len * config.n_kv_groups, : + ] + values = aie_buffers.prefill.values.to("cpu").view_as_torch()[ + :seq_len, : + ] # (seq_len, n_kv_groups * head_dim) + queries = queries.view(batch, seq_len, config.n_heads, config.head_dim) + keys = keys.unsqueeze(0).view(batch, seq_len, config.n_kv_groups, config.head_dim) + values = values.unsqueeze(0).view( + batch, seq_len, config.n_kv_groups, config.head_dim + ) # (batch, seq_len, num_kv_groups, head_dim) + + # Step 3: Transpose for attention computation + # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. + # Transpose so that heads are consecutive for attention computation: + # (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) + queries = queries.transpose(1, 2) # (batch, num_heads, seq_len, head_dim) + keys = keys.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) + values = values.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) + + # Step 4: Combine newly computed keys/values for most recent token with cache; these values are used as the updated cache and will be returned to use in the next iteration. + keys_cache = torch.cat([keys_cache, keys], dim=2) + values_cache = torch.cat([values_cache, values], dim=2) + keys = keys_cache + values = values_cache + + # Step 5: Repeat keys and values for grouped attention -- multiple queries get the same key/value + group_size = config.n_heads // config.n_kv_groups + values = values.repeat_interleave(group_size, dim=1) + context_len = keys.shape[2] + + # Step 6: Compute attention scores using NPU (per-head) + # (batch, num_heads, seq_len, head_dim) @ (batch, num_heads, head_dim, context_len) + # -> (batch, num_heads, seq_len, context_len) + + queries_buf = aie_buffers.prefill.attn_scores_queries_all.view_as_torch().view( + config.n_heads, -1, config.head_dim + ) + queries_buf[:, :seq_len, :] = queries.squeeze(0)[ + :, :seq_len, : + ] # (num_heads, seq_len, head_dim) + keys_buf = aie_buffers.prefill.attn_scores_keys_all.view_as_torch().view( + config.n_kv_groups, config.head_dim, -1 + ) + keys_buf[:, :, :context_len] = keys.squeeze(0).transpose( + -2, -1 + ) # (num_kv_groups, head_dim, context_len) + + # Transfer parent buffers to NPU once + aie_buffers.prefill.attn_scores_queries_all.to("npu") + aie_buffers.prefill.attn_scores_keys_all.to("npu") + aie_buffers.prefill.attn_scores.to("npu") + + # Execute GEMM for each head using sub-buffers + for h in range(config.n_heads): + kv_group = h // group_size + aie_ops.prefill.attn_scores( + aie_buffers.prefill.attn_scores_queries_per_head[h], + aie_buffers.prefill.attn_scores_keys_per_kv_group[kv_group], + aie_buffers.prefill.attn_scores_per_head[h], + ) + + # Read back all results at once from parent buffer and apply scaling on NPU + aie_ops.prefill.attn_scale( + aie_buffers.prefill.attn_scores, + aie_buffers.prefill.attn_scale_factor, + aie_buffers.prefill.attn_scores, + ) + aie_buffers.prefill.attn_scores.to("cpu") + # Buffer is (n_heads * max_seq_len, max_seq_len), view as (n_heads, max_seq_len, max_seq_len) then slice + max_seq_len = aie_buffers.prefill.attn_scores.shape[0] // config.n_heads + scores = ( + aie_buffers.prefill.attn_scores.view_as_torch() + .view(config.n_heads, max_seq_len, max_seq_len) + .unsqueeze(0)[:, :, :seq_len, :context_len] + ) + + # Step 7: Apply mask + # This ensures causality, so that tokens in the future cannot attend to tokens in the past. + if mask is not None: + scores = scores.masked_fill(mask, float("-inf")) + + # Step 8: Apply softmax on CPU + scores = torch.softmax(scores.to(torch.float32), dim=-1).to(torch.bfloat16) + attention_weights = scores + + # Step 9: Compute attention output + # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim) + # -> (batch, num_heads, seq_len, head_dim) + context = torch.matmul(attention_weights, values) + + # Step 10: Concatenate heads and project + # (batch, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads * head_dim) + context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) + + output = torch.nn.functional.linear( + context, config.weights[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] + ) + + return output, keys_cache, values_cache + + +def swiglu_ffn_forward_prefill(layer_idx): + # Step 1: Gate projection + aie_ops.prefill.ffn_up_gate( + aie_buffers.prefill.x_norm, + aie_buffers.W_ffn_gate_prefill[layer_idx], + aie_buffers.prefill.ffn_gate, + ) + + # Step 2: Up projection + aie_ops.prefill.ffn_up_gate( + aie_buffers.prefill.x_norm, + aie_buffers.W_ffn_up_prefill[layer_idx], + aie_buffers.prefill.ffn_up, + ) + + # Step 3: Apply SiLU activation + aie_ops.prefill.ffn_silu(aie_buffers.prefill.ffn_gate, aie_buffers.prefill.ffn_gate) + + # Step 4: Element-wise multiplication + aie_ops.prefill.eltwise_mul_ffn( + aie_buffers.prefill.ffn_gate, + aie_buffers.prefill.ffn_up, + aie_buffers.prefill.ffn_hidden, + ) + + # Step 5: Down projection + aie_ops.prefill.ffn_down( + aie_buffers.prefill.ffn_hidden, + aie_buffers.W_ffn_down_prefill[layer_idx], + aie_buffers.prefill.ffn_output, + ) + + +def transformer_block_forward_prefill( + config, seq_len, layer_idx, attn_keys_cache, attn_values_cache, attn_mask +): + # Step 1: RMS normalization + aie_ops.prefill.rms_norm( + aie_buffers.prefill.x, + aie_buffers.W_norm1[layer_idx], + aie_buffers.prefill.x_norm, + ) + aie_buffers.prefill.x_norm.to("cpu") + x_norm = aie_buffers.prefill.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] + + # Step 2: Attention + attn_output, attn_keys, attn_values = grouped_query_attention_forward_prefill( + config, + x_norm, + attn_keys_cache, + attn_values_cache, + layer_idx, + attn_mask, + ) + + # Step 3: Residual + aie_buffers.prefill.attn_output.view_as_torch().unsqueeze(0)[ + 0, :seq_len, : + ] = attn_output + aie_ops.prefill.residual_add( + aie_buffers.prefill.x, aie_buffers.prefill.attn_output, aie_buffers.prefill.x + ) + x = aie_buffers.prefill.x.to("cpu").view_as_torch().unsqueeze(0)[:, :seq_len, :] + + # Step 4: Post-norm + aie_buffers.prefill.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + aie_ops.prefill.rms_norm( + aie_buffers.prefill.x, + aie_buffers.W_norm2[layer_idx], + aie_buffers.prefill.x_norm, + ) + aie_buffers.prefill.x_norm.to("cpu") + x_norm = aie_buffers.prefill.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] + + # Step 5: Feed-forward network + swiglu_ffn_forward_prefill(layer_idx) + + # Step 6: Residual + aie_ops.prefill.residual_add( + aie_buffers.prefill.x, aie_buffers.prefill.ffn_output, aie_buffers.prefill.x + ) + + return attn_keys, attn_values + + +def llama_forward_pass_prefill(config, state): + batch, seq_len = state.token_ids.shape + + # Step 1: RoPE angles + num_preceding_tokens = state.attn_keys_caches[0].shape[2] + angles_slice = config.angles[num_preceding_tokens : num_preceding_tokens + seq_len] + aie_buffers.prefill.rope_angles.view_as_torch()[:seq_len, :] = angles_slice + + # Step 2: Token embedding + tok_emb_weight = config.weights["model.embed_tokens.weight"] + x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) + attn_mask = torch.triu( + torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1 + ) + aie_buffers.prefill.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + + # Step 3: Transformer blocks + for layer_idx in range(config.n_layers): + state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = ( + transformer_block_forward_prefill( + config, + seq_len, + layer_idx, + state.attn_keys_caches[layer_idx], + state.attn_values_caches[layer_idx], + attn_mask=attn_mask, + ) + ) + + # Step 4: Final normalization + aie_ops.prefill.rms_norm( + aie_buffers.prefill.x, aie_buffers.W_final_norm, aie_buffers.prefill.x + ) + + # Step 5: Output projection + for i in range(config.vocab_partitions): + aie_ops.prefill.out_head( + aie_buffers.prefill.x, + aie_buffers.W_out_head_parts[i], + aie_buffers.prefill.logits_parts[i], + ) + aie_buffers.prefill.logits.to("cpu") + logits_padded_partitioned = aie_buffers.prefill.logits.view_as_torch() + logits_padded = ( + logits_padded_partitioned.transpose(0, 1) + .contiguous() + .view(-1, config.padded_vocab_size) + ) + logits = logits_padded.unsqueeze(0)[:, :seq_len, : config.vocab_size] + + # Step 6: Initialize per-layer NPU cache buffers with current cache state for decode phase + for layer_idx in range(config.n_layers): + cache_len = state.attn_keys_caches[layer_idx].shape[2] + aie_buffers.keys_cache[layer_idx].view_as_torch()[:, :cache_len, :] = ( + state.attn_keys_caches[layer_idx].squeeze(0) + ) + aie_buffers.values_cache[layer_idx].view_as_torch()[:, :cache_len, :] = ( + state.attn_values_caches[layer_idx].squeeze(0) + ) + aie_buffers.keys_cache[layer_idx].to("npu") + aie_buffers.values_cache[layer_idx].to("npu") + + return logits, state + + +# Decode +# ########################################################################## + + +def patch_fused_decode_operator(ops, config, num_preceding_tokens): + context_len = num_preceding_tokens + 1 + + # Patch fused operator for strided copy cache offset + output_offset = num_preceding_tokens * config.head_dim + offset_val = output_offset * 2 # Multiply by 2 for bfloat16 byte offset + strided_copy_patches = { + i: (base + offset_val, 0xFFFFFFFF) + for i, base in ops.fused_patch_locations.items() + } + softmax_patches = {i: (context_len, 0xFFFFFFFF) for i in ops.softmax_patch_offsets} + patches = {**strided_copy_patches, **softmax_patches} + patched_elf_data = ops.fused_elf_data.copy() + patch_elf(patched_elf_data, patches) + + ops.fused.reload_elf(patched_elf_data) + + +def llama_forward_pass_decode(config, state): + batch, seq_len = state.token_ids.shape + assert seq_len == 1 + assert state.num_preceding_tokens < max_seq_len + + patch_fused_decode_operator(aie_ops.decode, config, state.num_preceding_tokens) + + # Prefill RoPE angle look-up tables + angles_slice = config.angles[ + state.num_preceding_tokens : state.num_preceding_tokens + seq_len + ] + aie_ops.decode.fused.get_buffer("rope_angles").to("cpu").view_as_torch()[ + : + ] = angles_slice + + # Token embedding (on CPU) + tok_emb_weight = config.weights["model.embed_tokens.weight"] + x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) + aie_ops.decode.fused.get_buffer("x").view_as_torch().view(-1, config.emb_dim)[ + :seq_len, : + ] = x + + # Fused NPU operator for all of decode (16 transformer blocks + final norm + final linear layer) + aie_ops.decode.fused.input_buffer.to("cpu") + aie_ops.decode.fused() + aie_ops.decode.fused.output_buffer.to("cpu") + logits = ( + aie_ops.decode.fused.get_buffer("logits") + .view_as_torch() + .view(1, 1, config.vocab_size) + ) + + return logits, state + + +# Main +# ########################################################################## + + +def llama_forward_pass(config, state): + global aie_ops, aie_buffers + + batch, seq_len = state.token_ids.shape + if seq_len > 1: + ret = llama_forward_pass_prefill(config, state) + state.num_preceding_tokens = state.token_ids.shape[1] + # Pass KV cache data onto fused decode operator + for layer_idx in range(config.n_layers): + aie_ops.decode.fused.get_buffer(f"keys_cache_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = ( + aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten() + ) + aie_ops.decode.fused.get_buffer(f"values_cache_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = ( + aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() + ) + aie_ops.decode.fused.scratch_buffer.to("cpu") + return ret + else: + ret = llama_forward_pass_decode(config, state) + state.num_preceding_tokens += 1 + return ret + + +def main(): + global aie_ops, aie_buffers, max_seq_len + args = harness.parse_args() + + assert ( + max_seq_len >= args.prompt_len + args.num_tokens + ), "max_seq_len must be at least prompt_len + num_tokens" + + prompt = harness.get_prompt(args.prompt_len) + + config, state = harness.init(args.weights_path, args.tokenizer_path, prompt=prompt) + + aie_ops = AIELlamaOperators(config, max_seq_len) + aie_buffers = AIELlamaBuffers(config, max_seq_len) + + print(prompt, end="", flush=True) + harness.generate( + config, state, llama_forward_pass, use_kv_cache=True, num_tokens=args.num_tokens + ) + + +if __name__ == "__main__": + main() diff --git a/iron/applications/llama_3.2_1b/src/block/feed_forward.py b/iron/applications/llama_3.2_1b/src/block/feed_forward.py deleted file mode 100644 index 8bae36ec..00000000 --- a/iron/applications/llama_3.2_1b/src/block/feed_forward.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.nn as nn -from ..utils import assign -from iron.operators import ( - AIEElementwiseMul, - AIEGEMM, - AIEGEMV, - AIESiLU, - AIESwiGLUPrefill, - AIESwiGLUDecode, -) -from ml_dtypes import bfloat16 - - -class FeedForward(nn.Module): - def __init__( - self, - cfg, - prompt_length=0, - num_tokens=1, - ): - super().__init__() - self.cfg = cfg.copy() - - assert ( - cfg["use_aie_ffn_swiglu"] - and not ( - cfg["use_aie_ffn_silu"] - or cfg["use_aie_ffn_gemm"] - or cfg["use_aie_ffn_mul"] - ) - or not cfg["use_aie_ffn_swiglu"] - ), "Cannot mix fused SwiGLU with individual AIE operators." - - self.emb_dim = cfg["emb_dim"] - self.hidden_dim = cfg["hidden_dim"] - - # Initialize SiLU activation - if self.cfg["use_aie_ffn_silu"]: - if self.cfg["use_kv_cache"]: - max_prefill_size = prompt_length * self.hidden_dim - else: - max_prefill_size = (prompt_length + num_tokens) * self.hidden_dim - self.aie_silu_prefill = AIESiLU( - size=max_prefill_size, - num_aie_columns=8, - num_channels=2, - tile_size=self.hidden_dim, - ) - # For decode phase - single token (only when using KV cache) - if self.cfg["use_kv_cache"]: - decode_size = self.hidden_dim # 1 token * emb_dim - self.aie_silu_decode = AIESiLU( - size=decode_size, - num_aie_columns=1, - num_channels=1, - tile_size=self.hidden_dim, - ) - else: - # When not using KV cache, use same operator for both phases - self.aie_silu_decode = self.silu_prefill - else: - self.silu = nn.SiLU() - - if self.cfg["use_aie_ffn_swiglu"]: - self.aie_swiglu_prefill = AIESwiGLUPrefill( - seq_len=prompt_length, - embedding_dim=self.emb_dim, - hidden_dim=self.hidden_dim, - ) - if self.cfg["use_kv_cache"]: - self.aie_swiglu_decode = AIESwiGLUDecode( - embedding_dim=self.emb_dim, hidden_dim=self.hidden_dim - ) - - if self.cfg["use_aie_ffn_gemm"]: - if self.cfg["use_kv_cache"]: - M_prefill = prompt_length - else: - M_prefill = prompt_length + num_tokens - - aie_config_prefill = { - "num_aie_columns": 8, - "tile_m": 64, - "tile_k": 64, - "tile_n": 64, - "use_static_weight": True, - } - - self.fc1 = AIEGEMM( - M=M_prefill, K=self.emb_dim, N=self.hidden_dim, **aie_config_prefill - ) - self.fc2 = AIEGEMM( - M=M_prefill, K=self.emb_dim, N=self.hidden_dim, **aie_config_prefill - ) - self.fc3 = AIEGEMM( - M=M_prefill, K=self.hidden_dim, N=self.emb_dim, **aie_config_prefill - ) - else: - self.fc1 = nn.Linear( - cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False - ) - self.fc2 = nn.Linear( - cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False - ) - self.fc3 = nn.Linear( - cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False - ) - - if self.cfg["use_kv_cache"] and self.cfg["use_aie_ffn_gemv"]: - aie_gemv_config = {"num_aie_columns": 8, "is_mv": False} - # FC1 and FC2: emb_dim -> hidden_dim - self.aie_fc1_gemv = AIEGEMV( - M=self.hidden_dim, - K=self.emb_dim, - tile_size_input=1, - tile_size_output=self.hidden_dim // 16, - **aie_gemv_config, - ) - self.aie_fc2_gemv = AIEGEMV( - M=self.hidden_dim, - K=self.emb_dim, - tile_size_input=1, - tile_size_output=self.hidden_dim // 16, - **aie_gemv_config, - ) - # FC3: hidden_dim -> emb_dim - self.aie_fc3_gemv = AIEGEMV( - M=self.emb_dim, - K=self.hidden_dim, - tile_size_input=1, - tile_size_output=self.emb_dim // 16, - **aie_gemv_config, - ) - - # Initialize AIE elementwise multiply - if self.cfg["use_aie_ffn_mul"]: - if self.cfg["use_kv_cache"]: - max_prefill_size = prompt_length * self.hidden_dim - else: - max_prefill_size = (prompt_length + num_tokens) * self.hidden_dim - - self.aie_mul_prefill = AIEElementwiseMul( - size=max_prefill_size, - num_aie_columns=8, - num_channels=2, - tile_size=self.hidden_dim, - ) - - # For decode phase - single token (only when using KV cache) - if self.cfg["use_kv_cache"]: - decode_size = self.hidden_dim # 1 token * emb_dim - self.aie_mul_decode = AIEElementwiseMul( - size=decode_size, - num_aie_columns=1, - num_channels=2, - tile_size=self.hidden_dim, - ) - else: - # When not using KV cache, use same operator for both phases - self.aie_mul_decode = self.aie_mul_prefill - - def forward(self, x): - original_shape = x.shape - - # Check if input is a vector (decode phase) or matrix (prefill phase) - # Handle 1D: (emb_dim,), 2D: (1, emb_dim), or 3D: (1, 1, emb_dim) - is_vector = ( - len(x.shape) == 1 - or (len(x.shape) == 2 and x.shape[0] == 1) - or (len(x.shape) == 3 and x.shape[0] == 1 and x.shape[1] == 1) - ) - - is_prefill = not is_vector or not self.cfg["use_kv_cache"] - is_decode_with_kv = is_vector and self.cfg["use_kv_cache"] - - if self.cfg["use_aie_ffn_swiglu"]: - if is_prefill: - return self.aie_swiglu_prefill(x) - else: - return self.aie_swiglu_decode(x) - - if is_decode_with_kv and self.cfg["use_aie_ffn_gemv"]: - x_fc1 = self.aie_fc1_gemv(x) - x_fc2 = self.aie_fc2_gemv(x) - else: - x_fc1 = self.fc1(x) - x_fc2 = self.fc2(x) - - if self.cfg["use_aie_ffn_silu"]: - if is_decode_with_kv: - x_fc1_silu = self.aie_silu_decode(x_fc1) - else: - x_fc1_silu = self.aie_silu_prefill(x_fc1) - else: - x_fc1_silu = self.silu(x_fc1) - - if self.cfg["use_aie_ffn_mul"]: - if is_decode_with_kv: - x = self.aie_mul_decode(x_fc1_silu, x_fc2) - else: - x = self.aie_mul_prefill(x_fc1_silu, x_fc2) - else: - x = x_fc1_silu * x_fc2 - - if is_decode_with_kv and self.cfg["use_aie_ffn_gemv"]: - result = self.aie_fc3_gemv(x) - return result.view(original_shape) - else: - return self.fc3(x).view(original_shape) - - def assign_weights(self, l, fc1, fc2, fc3): - if self.cfg["use_kv_cache"] and self.cfg["use_aie_ffn_gemv"]: - self.aie_fc1_gemv.weight = fc1 - self.aie_fc2_gemv.weight = fc2 - self.aie_fc3_gemv.weight = fc3 - - if self.cfg["use_aie_ffn_swiglu"]: - self.aie_swiglu_prefill.weights_1 = fc1 - self.aie_swiglu_prefill.weights_2 = fc2 - self.aie_swiglu_prefill.weights_3 = fc3 - if self.cfg["use_kv_cache"]: - self.aie_swiglu_decode.weights_1 = fc1 - self.aie_swiglu_decode.weights_2 = fc2 - self.aie_swiglu_decode.weights_3 = fc3 - return - - self.fc1.weight = assign( - self.fc1.weight, - fc1, - f"model.layers.{l}.mlp.gate_proj.weight", - ) - self.fc2.weight = assign( - self.fc2.weight, - fc2, - f"model.layers.{l}.mlp.up_proj.weight", - ) - self.fc3.weight = assign( - self.fc3.weight, - fc3, - f"model.layers.{l}.mlp.down_proj.weight", - ) diff --git a/iron/applications/llama_3.2_1b/src/block/gqa.py b/iron/applications/llama_3.2_1b/src/block/gqa.py deleted file mode 100644 index 1a712ff9..00000000 --- a/iron/applications/llama_3.2_1b/src/block/gqa.py +++ /dev/null @@ -1,505 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch -import torch.nn as nn - -from iron.operators import AIERope, AIESoftmax, AIEMHA, AIEGEMM, AIEGEMV -from iron.operators.rope.rope_utils import apply_rope - -from torchtune.modules import KVCache - -from ..utils import assign - - -class GroupedQueryAttention(nn.Module): - def __init__( - self, - d_in, - d_out, - num_heads, - num_kv_groups, - prompt_length=0, - num_tokens=1, - dtype=None, - max_batch_size=1, - max_seq_len=8192, - cfg=None, - ): - super().__init__() - assert d_out % num_heads == 0, "d_out must be divisible by num_heads" - assert ( - num_heads % num_kv_groups == 0 - ), "num_heads must be divisible by num_kv_groups" - - self.cfg = cfg.copy() if cfg is not None else {} - - self.d_out = d_out - self.num_heads = num_heads - self.head_dim = d_out // num_heads - - self.num_tokens = num_tokens - - # Weights for Attention layer - self.W_key = nn.Linear( - d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype - ) - self.W_value = nn.Linear( - d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype - ) - self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype) - self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype) - - self.num_kv_groups = num_kv_groups - self.group_size = num_heads // num_kv_groups - - self.prompt_length = prompt_length - - aie_gemm_config = { - "num_aie_columns": 8, - "tile_m": 64, - "tile_k": 64, - "tile_n": 64, - "use_static_weight": False, - } - - # Initialize KV Cache - if self.cfg["use_kv_cache"]: - self.kv_cache = KVCache( - batch_size=max_batch_size, - max_seq_len=max_seq_len, - num_kv_heads=self.num_kv_groups, - head_dim=self.head_dim, - dtype=torch.bfloat16, - ) - - # Initialize AIE Regular MHA operator - if self.cfg["use_aie_regular_mha"]: - self.aie_softmax = AIESoftmax( - num_aie_columns=1, - num_channels=1, - rows=prompt_length, - cols=prompt_length, - ) - M_for_gemm = prompt_length + num_tokens - self.aie_mha_gemm_qk = AIEGEMM( - M=M_for_gemm, K=self.head_dim, N=M_for_gemm, **aie_gemm_config - ) - self.aie_mha_gemm_pv = AIEGEMM( - M=M_for_gemm, K=M_for_gemm, N=self.head_dim, **aie_gemm_config - ) - - # Initialize AIE RoPE operator - if self.cfg["use_aie_rope"]: - self.aie_rope_prefill_k = AIERope( - rows=self.prompt_length * self.num_kv_groups, - cols=self.head_dim, - angle_rows=self.prompt_length, - ) - self.aie_rope_prefill_q = AIERope( - rows=self.prompt_length * self.num_heads, - cols=self.head_dim, - angle_rows=self.prompt_length, - ) - self.aie_rope_decode_k = AIERope( - rows=self.num_kv_groups, - cols=self.head_dim, - angle_rows=1, - ) - self.aie_rope_decode_q = AIERope( - rows=self.num_heads, - cols=self.head_dim, - angle_rows=1, - ) - - # Initialize fused AIE MHA operator - if self.cfg["use_aie_fused_mha"]: - self.aie_mha = AIEMHA( - num_heads=num_heads, - seq_len=prompt_length, - d=self.head_dim, - num_KV_heads=0, # Regular MHA since we feed repeated K/V - num_of_pipelines=8, - ) - - # Initialize AIE GEMV operators for decode phase (when using KV cache) - if self.cfg["use_kv_cache"] and self.cfg["use_aie_gqa_gemv"]: - - aie_gemv_config = { - "num_aie_columns": 8, - "is_mv": False, - "use_static_weight": True, - } - self.aie_query_gemv = AIEGEMV( - M=d_out, - K=d_in, - tile_size_input=1, - tile_size_output=d_out // 16, - **aie_gemv_config, - ) - kv_out_dim = num_kv_groups * self.head_dim - self.aie_key_gemv = AIEGEMV( - M=kv_out_dim, - K=d_in, - tile_size_input=1, - tile_size_output=kv_out_dim // 16, - **aie_gemv_config, - ) - self.aie_value_gemv = AIEGEMV( - M=kv_out_dim, - K=d_in, - tile_size_input=1, - tile_size_output=kv_out_dim // 16, - **aie_gemv_config, - ) - self.aie_out_proj_gemv = AIEGEMV( - M=d_out, - K=d_out, - tile_size_input=1, - tile_size_output=d_out // 16, - **aie_gemv_config, - ) - - # Initialize AIE GEMM operators - if self.cfg["use_aie_attn_projection_gemm"]: - if self.cfg["use_kv_cache"]: - M_for_gemm = self.prompt_length - else: - M_for_gemm = self.prompt_length + self.num_tokens - - # GEMMs for projection use weights - aie_gemm_config["use_static_weight"] = True - # Query: (batch_size, d_in) @ (d_in, d_out) -> (batch_size, d_out) - self.aie_query = AIEGEMM(M=M_for_gemm, K=d_in, N=d_out, **aie_gemm_config) - # Key: (batch_size, d_in) @ (d_in, num_kv_groups * head_dim) -> (batch_size, num_kv_groups * head_dim) - kv_out_dim = num_kv_groups * self.head_dim - self.aie_key = AIEGEMM( - M=M_for_gemm, K=d_in, N=kv_out_dim, **aie_gemm_config - ) - # Value: same dimensions as key - self.aie_value = AIEGEMM( - M=M_for_gemm, K=d_in, N=kv_out_dim, **aie_gemm_config - ) - # Output projection: (batch_size, d_out) @ (d_out, d_out) -> (batch_size, d_out) - self.aie_out_proj = AIEGEMM( - M=M_for_gemm, K=d_out, N=d_out, **aie_gemm_config - ) - - def forward(self, x, mask, angles, input_pos=None): - b, num_tokens, d_in = x.shape - is_prefill = input_pos is None - is_decode = input_pos is not None - - # Step 1. - # --- - # Linear projections -- calculate quries, keys and values by multiplying embedding vector (in decode) or matrix (in prefill) with weight matrices - - # Choose between GEMM (prefill) and GEMV (decode) based on KV cache usage - if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gqa_gemv"]: - # Decode phase with KV cache - use GEMV for single token - # weight.T @ input, which is vector-matrix multiplication (So, is_mv=False) - x_flat = x.reshape(1, -1) # Shape: (1, d_in) - - queries_flat = self.aie_query_gemv(x_flat) - queries = queries_flat.reshape(b, num_tokens, self.d_out) - - keys_flat = self.aie_key_gemv(x_flat) - keys = keys_flat.reshape(b, num_tokens, self.num_kv_groups * self.head_dim) - - values_flat = self.aie_value_gemv(x_flat) - values = values_flat.reshape( - b, num_tokens, self.num_kv_groups * self.head_dim - ) - - elif self.cfg["use_aie_attn_projection_gemm"]: - # Prefill phase - use GEMM for multiple tokens - x_flat = x.reshape(-1, d_in) - input_dtype = x.dtype - - queries_flat = self.aie_query(x_flat) - queries = queries_flat.reshape(b, num_tokens, self.d_out) - - keys_flat = self.aie_key(x_flat) - keys = keys_flat.reshape(b, num_tokens, self.num_kv_groups * self.head_dim) - - values_flat = self.aie_value(x_flat) - values = values_flat.reshape( - b, num_tokens, self.num_kv_groups * self.head_dim - ) - else: - queries = self.W_query(x) - keys = self.W_key(x) - values = self.W_value(x) - - # Each attention head gets its own slice of the embedding dimension. - # For each head, we have query, key and value. - # In grouped-query attention, the keys and values are shared across groups of heads. - # Therefore, we have self.num_heads queries, and self.num_kv_groups (== self.num_heads in case of regular attention) keys and values. - # Each head can be applied independently to its subslice of the embedding dimension. - keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim) - values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim) - queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) - - # Step 2. - # --- - # Apply positional encoding to keys and queries. - # The positional embedding is applied independently to each head. - # It modifies the embedding vectors to encode where in the sequence each token is located. - - # Determine angle slice based on KV cache usage and phase - if self.cfg["use_kv_cache"] and is_decode: - # Decode phase with KV cache: use single position - current_pos = input_pos.item() - angle_slice = angles[current_pos : current_pos + 1, :] - else: - # Prefill phase or no KV cache: use all tokens - angle_slice = angles[:num_tokens, :] - - # Apply RoPE with AIE - def apply_rope_and_transpose(aie_op, tensor, num_heads_dim, angle_slice): - angle_slice = angle_slice.to(dtype=tensor.dtype) - if self.cfg["use_aie_rope"]: - result = aie_op( - tensor.view(num_tokens * num_heads_dim, self.head_dim), angle_slice - ) - result = result.view( - b, num_tokens, num_heads_dim, self.head_dim - ).transpose(1, 2) - else: - transposed = ( - tensor.view(num_tokens, num_heads_dim, self.head_dim) - .transpose(0, 1) - .contiguous() - ) - result = apply_rope( - transposed.view(1, num_heads_dim, num_tokens, self.head_dim), - angle_slice, - ) - # ref = apply_rope(transposed.view(1, num_heads_dim, num_tokens, self.head_dim), angle_slice) - # assert torch.allclose(ref, result, atol=0.7, rtol=0.07), "AIE RoPE result does not match reference" - return result - - keys = apply_rope_and_transpose( - ( - (self.aie_rope_prefill_k if is_prefill else self.aie_rope_decode_k) - if self.cfg["use_aie_rope"] - else None - ), - keys, - self.num_kv_groups, - angle_slice, - ) - queries = apply_rope_and_transpose( - ( - (self.aie_rope_prefill_q if is_prefill else self.aie_rope_decode_q) - if self.cfg["use_aie_rope"] - else None - ), - queries, - self.num_heads, - angle_slice, - ) - values = values.transpose(1, 2) - - if self.cfg["use_kv_cache"]: - if is_prefill: - self.kv_cache.reset() - self.kv_cache.update(keys, values) - cached_keys, cached_values = keys, values - else: - self.kv_cache.update(keys, values) - current_seq_len = input_pos.item() + 1 - cached_keys = self.kv_cache.k_cache[:, :, :current_seq_len, :] - cached_values = self.kv_cache.v_cache[:, :, :current_seq_len, :] - - keys = cached_keys - values = cached_values - - # Step 3. - # --- - # Since the keys and values are shared across groups of heads in grouped-query attention, - # we now expand (repeat) the same keys and values so that each head has its own keys and values. - keys = keys.repeat_interleave(self.group_size, dim=1) - values = values.repeat_interleave(self.group_size, dim=1) - - # Step 4. - # --- - # Compute attention scores (indepdentently for each head), apply softmax to get attention weights, then apply those weights to the attention values to get output. - # Attention scores are the dot-product of queries and keys. - - # Use fused AIE MHA if enabled and conditions are met - if is_prefill or not self.cfg["use_kv_cache"]: - if ( - self.cfg["use_aie_fused_mha"] - and b == 1 - and num_tokens == self.prompt_length - and self.head_dim == 64 - ): - # TODO: Doesn't give good output ven with num_kv_groups set to 8 with kv_cache - # TODO: Doesn't match the output of CPU only when used without kv_cache - context_vec = self.aie_mha( - queries, keys, values - ) # Shape: (num_heads, num_tokens, head_dim) - - # Reshape context_vec to prepare for output projection - context_vec = context_vec.transpose(0, 1) - context_vec = context_vec.reshape(b, num_tokens, self.d_out) - - elif self.cfg["use_aie_regular_mha"]: - # attn_scores = queries @ keys.transpose(2, 3) - # Compute attention scores for each head separately since AIE GEMM doesn't support batched operations - attn_scores_list = [] - for head in range(self.num_heads): - q_head = queries[:, head, :, :] # Shape: (b, num_tokens, head_dim) - k_head = keys[:, head, :, :] # Shape: (b, num_tokens, head_dim) - - # Use 2D tensors directly (remove batch dimension if b=1) - q_2d = q_head.squeeze(0) # Shape: (num_tokens, head_dim) - k_2d = k_head.squeeze(0) # Shape: (num_tokens, head_dim) - - # Compute Q @ K^T for this head - attn_head = self.aie_mha_gemm_qk( - q_2d, k_2d.T - ) # Shape: (num_tokens, num_tokens) - attn_head = attn_head.unsqueeze(0).unsqueeze( - 0 - ) # Add batch and head dimensions - attn_scores_list.append( - attn_head - ) # Shape: (1, 1, num_tokens, num_tokens) - - attn_scores = torch.cat( - attn_scores_list, dim=1 - ) # Shape: (b, num_heads, num_tokens, num_tokens) - attn_scores = attn_scores.masked_fill(mask, -torch.inf) - scaled_scores = attn_scores / (self.head_dim**0.5) - - # TODO: Make softmax more configurable to run in any scenario - if ( - scaled_scores.shape[-1] == self.prompt_length - and scaled_scores.shape[-1] % 16 == 0 - ): - attn_weights = self.aie_softmax(scaled_scores) - else: - attn_weights = torch.nn.functional.softmax(scaled_scores, dim=-1) - - # Compute context vector for each head separately using AIE GEMM - context_vec_list = [] - for head in range(self.num_heads): - attn_head = attn_weights[ - :, head, :, : - ] # Shape: (b, num_tokens, num_tokens) - v_head = values[:, head, :, :] # Shape: (b, num_tokens, head_dim) - - # Use 2D tensors directly (remove batch dimension if b=1) - attn_2d = attn_head.squeeze(0) # Shape: (num_tokens, num_tokens) - v_2d = v_head.squeeze(0) # Shape: (num_tokens, head_dim) - - # Compute attn @ V for this head - context_head = self.aie_mha_gemm_pv( - attn_2d, v_2d - ) # Shape: (num_tokens, head_dim) - context_head = context_head.unsqueeze(0).unsqueeze( - 1 - ) # Add batch and head dimensions - context_vec_list.append( - context_head - ) # Shape: (1, 1, num_tokens, head_dim) - - context_vec = torch.cat( - context_vec_list, dim=1 - ) # Shape: (b, num_heads, num_tokens, head_dim) - context_vec = context_vec.transpose( - 1, 2 - ) # Shape: (b, num_tokens, num_heads, head_dim) - context_vec = context_vec.reshape(b, num_tokens, self.d_out) - else: - - def my_mha(queries, keys, values): - inv_scale = 1 / np.sqrt(values.shape[-1]) - context_vec = torch.nn.functional.scaled_dot_product_attention( - queries, - keys, - values, - dropout_p=0.0, - is_causal=True, - scale=inv_scale, - ) - return context_vec - - context_vec = my_mha(queries, keys, values) - context_vec = context_vec.transpose(1, 2) - context_vec = context_vec.reshape(b, num_tokens, self.d_out) - else: - attn_scores = queries @ keys.transpose(2, 3) - - if mask is not None: - attn_scores = attn_scores.masked_fill(mask, -torch.inf) - - scaled_scores = attn_scores / (self.head_dim**0.5) - - if ( - scaled_scores.shape[-1] == self.prompt_length - and self.cfg["use_aie_softmax"] - and scaled_scores.shape[-1] % 16 == 0 - ): - attn_weights = self.aie_softmax(scaled_scores) - else: - attn_weights = torch.nn.functional.softmax(scaled_scores, dim=-1) - - context_vec = (attn_weights @ values).transpose(1, 2) - context_vec = context_vec.reshape(b, num_tokens, self.d_out) - - # Choose output projection based on phase - if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gqa_gemv"]: - context_vec_flat = context_vec.reshape(1, -1) - output_flat = self.aie_out_proj_gemv(context_vec_flat) - context_vec = output_flat.reshape(b, num_tokens, self.d_out) - elif self.cfg["use_aie_attn_projection_gemm"]: - context_vec_flat = context_vec.reshape(-1, self.d_out) - output_flat = self.aie_out_proj(context_vec_flat) - context_vec = output_flat.reshape(b, num_tokens, self.d_out) - else: - context_vec = self.out_proj(context_vec) - - return context_vec - - def assign_weights(self, l, w_query, w_key, w_value, w_out_proj): - if self.cfg["use_kv_cache"] and self.cfg["use_aie_gqa_gemv"]: - self.aie_query_gemv.weight = w_query - self.aie_key_gemv.weight = w_key - self.aie_value_gemv.weight = w_value - self.aie_out_proj_gemv.weight = w_out_proj - - if self.cfg["use_aie_attn_projection_gemm"]: - self.aie_query.weight = w_query - self.aie_key.weight = w_key - self.aie_value.weight = w_value - self.aie_out_proj.weight = w_out_proj - - self.W_query.weight = assign( - self.W_query.weight, - w_query, - f"model.layers.{l}.self_attn.q_proj.weight", - ) - self.W_key.weight = assign( - self.W_key.weight, - w_key, - f"model.layers.{l}.self_attn.k_proj.weight", - ) - self.W_value.weight = assign( - self.W_value.weight, - w_value, - f"model.layers.{l}.self_attn.v_proj.weight", - ) - self.out_proj.weight = assign( - self.out_proj.weight, - w_out_proj, - f"model.layers.{l}.self_attn.o_proj.weight", - ) diff --git a/iron/applications/llama_3.2_1b/src/block/transformer.py b/iron/applications/llama_3.2_1b/src/block/transformer.py deleted file mode 100644 index f2b46cdf..00000000 --- a/iron/applications/llama_3.2_1b/src/block/transformer.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.nn as nn -from ..utils import assign -from src.block.gqa import GroupedQueryAttention -from src.block.feed_forward import FeedForward -from iron.operators import AIERMSNorm, AIEElementwiseAdd - - -class TransformerBlock(nn.Module): - def __init__( - self, - cfg, - prompt_length=42, - num_tokens=1, - ): - super().__init__() - self.cfg = cfg.copy() - - self.att = GroupedQueryAttention( - d_in=cfg["emb_dim"], - d_out=cfg["emb_dim"], - num_heads=cfg["n_heads"], - num_kv_groups=cfg["n_kv_groups"], - dtype=cfg["dtype"], - prompt_length=prompt_length, - cfg=cfg, - ) - self.ff = FeedForward( - cfg, - prompt_length=prompt_length, - num_tokens=num_tokens, - ) - - if self.cfg["use_aie_norm1"]: - if self.cfg["use_kv_cache"]: - max_prefill_size = prompt_length * self.cfg["emb_dim"] - else: - max_prefill_size = (prompt_length + num_tokens) * self.cfg["emb_dim"] - self.aie_norm1_prefill = AIERMSNorm( - size=max_prefill_size, - eps=1e-5, - num_aie_columns=8, - num_channels=2, - tile_size=self.cfg["emb_dim"], - ) - # For decode phase - single token (only when using KV cache) - if self.cfg["use_kv_cache"]: - decode_size = self.cfg["emb_dim"] # 1 token * emb_dim - self.aie_norm1_decode = AIERMSNorm( - size=decode_size, - eps=1e-5, - num_aie_columns=1, - num_channels=2, - tile_size=self.cfg["emb_dim"], - ) - else: - # When not using KV cache, use same operator for both phases - self.aie_norm1_decode = self.aie_norm1_prefill - else: - self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) - - if self.cfg["use_aie_norm2"]: - if self.cfg["use_kv_cache"]: - max_prefill_size = prompt_length * self.cfg["emb_dim"] - else: - max_prefill_size = (prompt_length + num_tokens) * self.cfg["emb_dim"] - self.aie_norm2_prefill = AIERMSNorm( - size=max_prefill_size, - eps=1e-5, - num_aie_columns=8, - num_channels=2, - tile_size=self.cfg["emb_dim"], - ) - # For decode phase - single token (only when using KV cache) - if self.cfg["use_kv_cache"]: - decode_size = self.cfg["emb_dim"] # 1 token * emb_dim - self.aie_norm2_decode = AIERMSNorm( - size=decode_size, - eps=1e-5, - num_aie_columns=1, - num_channels=2, - tile_size=self.cfg["emb_dim"], - ) - else: - # When not using KV cache, use same operator for both phases - self.aie_norm2_decode = self.aie_norm2_prefill - else: - self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) - - if self.cfg["use_aie_residual"]: - if self.cfg["use_kv_cache"]: - max_prefill_size = prompt_length * cfg["emb_dim"] - else: - max_prefill_size = (prompt_length + num_tokens) * cfg["emb_dim"] - - self.aie_residual_add_prefill = AIEElementwiseAdd( - size=max_prefill_size, - num_aie_columns=8, - num_channels=2, - tile_size=cfg["emb_dim"], - ) - - # For decode phase - single token (only when using KV cache) - if self.cfg["use_kv_cache"]: - decode_size = cfg["emb_dim"] # 1 token * emb_dim - self.aie_residual_add_decode = AIEElementwiseAdd( - size=decode_size, - num_aie_columns=1, - num_channels=2, - tile_size=cfg["emb_dim"], - ) - else: - # When not using KV cache, use same operator for both phases - self.aie_residual_add_decode = self.aie_residual_add_prefill - - def forward(self, x, mask, angles, input_pos): - original_shape = x.shape - - # (batch, sequence, embedding) where sequence=1 indicates decode - if len(x.shape) == 3: - is_decode_with_kv = (x.shape[1] == 1) and self.cfg["use_kv_cache"] - elif len(x.shape) == 2: - is_decode_with_kv = (x.shape[0] == 1) and self.cfg["use_kv_cache"] - else: - is_decode_with_kv = False - - shortcut = x - if self.cfg["use_aie_norm1"]: - if is_decode_with_kv: - x = self.aie_norm1_decode(x) - else: - x = self.aie_norm1_prefill(x) - else: - x = self.norm1(x) - - x = self.att(x, mask, angles, input_pos) - - if self.cfg["use_aie_residual"]: - if is_decode_with_kv: - x = self.aie_residual_add_decode(x, shortcut) - else: - x = self.aie_residual_add_prefill(x, shortcut) - else: - x = x + shortcut - - # Shortcut connection for feed-forward block - shortcut = x - if self.cfg["use_aie_norm2"]: - if is_decode_with_kv: - x = self.aie_norm2_decode(x) - else: - x = self.aie_norm2_prefill(x) - else: - x = self.norm2(x) - x = self.ff(x) - - if self.cfg["use_aie_residual"]: - if is_decode_with_kv: - x = self.aie_residual_add_decode(x, shortcut) - else: - x = self.aie_residual_add_prefill(x, shortcut) - else: - x = x + shortcut - - return x - - def assign_weights(self, l, norm1, norm2): - if self.cfg["use_aie_norm1"]: - self.aie_norm1_prefill.weight = norm1 - if self.cfg["use_kv_cache"]: - self.aie_norm1_decode.weight = norm1 - if self.cfg["use_aie_norm2"]: - self.aie_norm2_prefill.weight = norm2 - if self.cfg["use_kv_cache"]: - self.aie_norm2_decode.weight = norm2 - return - - self.norm1.weight = assign( - self.norm1.weight, - norm1, - f"model.layers.{l}.input_layernorm.weight", - ) - self.norm2.weight = assign( - self.norm2.weight, - norm2, - f"model.layers.{l}.post_attention_layernorm.weight", - ) diff --git a/iron/applications/llama_3.2_1b/src/model_with_json.py b/iron/applications/llama_3.2_1b/src/model_with_json.py deleted file mode 100644 index 856fb048..00000000 --- a/iron/applications/llama_3.2_1b/src/model_with_json.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.nn as nn -import json -from pathlib import Path -from src.block.transformer import TransformerBlock -from iron.operators.rope.rope_utils import compute_rope_params -from iron.operators import AIERMSNorm, AIEGEMM, AIEGEMV -from rich.console import Console -from rich.text import Text - -from .utils import assign - - -def dtype_from_string(inp): - if isinstance(inp, torch.dtype): - return inp - return {"bfloat16": torch.bfloat16, "float16": torch.float16}.get( - inp, torch.float32 - ) - - -# fmt: off -# Configuration flag key -> (type function, default value, description) -config_options = { - "dtype": (dtype_from_string, torch.float32, "Data type"), - "use_kv_cache": (bool, False, "[Model] KV Cache"), - "use_aie_rope": (bool, False, "[Attention] Rope"), - "use_aie_attn_projection_gemm": (bool, False, "[Attention] QKV GEMM"), - "use_aie_regular_mha": (bool, False, "[Attention] Regular MHA"), - "use_aie_fused_mha": (bool, False, "[Attention] Fused MHA"), - "use_aie_gqa_gemv": (bool, False, "[Attention] GEMV (Decode)"), - "use_aie_ffn_gemm": (bool, False, "[FFN] GEMM"), - "use_aie_ffn_mul": (bool, False, "[FFN] Elementwise Mul"), - "use_aie_ffn_silu": (bool, False, "[FFN] SiLU"), - "use_aie_ffn_swiglu": (bool, False, "[FFN] Runlist-based SwiGLU"), - "use_aie_ffn_gemv": (bool, False, "[FFN] GEMV (Decode)"), - "use_aie_residual": (bool, False, "[Transformer] Residual Addition"), - "use_aie_norm1": (bool, False, "[Transformer] Pre Norm"), - "use_aie_norm2": (bool, False, "[Transformer] Post Norm"), - "use_aie_final_norm": (bool, False, "[Transformer] Final Norm"), - "use_aie_final_gemm": (bool, False, "[Transformer] Final GEMM"), - "use_aie_final_gemv": (bool, False, "[Transformer] Final GEMV"), -} -# fmt: on - - -def load_llama_config(config_path=None): - """Load Llama configuration from JSON file""" - if config_path is None: - # Default to config.json in the llama directory - config_path = Path(__file__).parent.parent / "llama32_1b.json" - - with open(config_path, "r") as f: - config = json.load(f) - - model_config = config["model_config"].copy() - for key, (type_fn, default_value, description) in config_options.items(): - if key in model_config: - model_config[key] = type_fn(model_config[key]) - else: - model_config[key] = default_value - - return model_config - - -def print_config(cfg, console=Console()): - def format_option(name, value): - if isinstance(value, bool): - checkmark = "[green]✔[/green]" if value else "[red]✘[/red]" - return f"{name} {checkmark}" - return f"{name}: {value}" - - dont_print = {"dtype"} - # The following options are mutually exclusive, e.g. regular and fused MHA - # cannot be enabled at the same time. But it looks bad to have red Xs, - # indicating things are running on the CPU when they are not. So, we only - # print one of these mutually exclusive options. - if cfg["use_aie_fused_mha"]: - dont_print |= {"use_aie_regular_mha"} - else: - dont_print |= {"use_aie_fused_mha"} - if cfg["use_aie_ffn_swiglu"]: - dont_print |= { - "use_aie_ffn_gemm", - "use_aie_ffn_mul", - "use_aie_ffn_silu", - } - else: - dont_print |= {"use_aie_ffn_swiglu"} - - console.print( - "AIE Configuration ([green]✔[/green] = AIE NPU / [red]✘[/red] = CPU):", - style="bold underline", - ) - for option_key, (option_ty, option_default, option_name) in config_options.items(): - if option_key in dont_print: - continue - console.print(format_option(option_name, cfg.get(option_key, option_default))) - console.print("") - - -class Llama3ModelWithJSONConfig(nn.Module): - """Llama3 model that loads configuration from JSON file""" - - def __init__( - self, - config_path=None, - prompt_length=0, - num_tokens=1, - ): - super().__init__() - - # Load configuration from JSON - self.cfg = load_llama_config(config_path) - self.prompt_length = prompt_length - self.num_tokens = num_tokens - print_config(self.cfg) - - # Main model parameters - self.tok_emb = nn.Embedding( - self.cfg["vocab_size"], self.cfg["emb_dim"], dtype=self.cfg["dtype"] - ) - - self.trf_blocks = nn.ModuleList( - [ - TransformerBlock( - self.cfg, - prompt_length=prompt_length, - num_tokens=num_tokens, - ) - for i in range(self.cfg["n_layers"]) - ] - ) - - # Create final norm - either AIE or PyTorch - if self.cfg.get("use_aie_final_norm", False): - if self.cfg["use_kv_cache"]: - max_prefill_size = prompt_length * self.cfg["emb_dim"] - else: - max_prefill_size = (prompt_length + num_tokens) * self.cfg["emb_dim"] - self.aie_final_norm_prefill = AIERMSNorm( - size=max_prefill_size, - eps=1e-5, - num_aie_columns=8, - num_channels=2, - tile_size=self.cfg["emb_dim"], - ) - # For decode phase - single token (only when using KV cache) - if self.cfg["use_kv_cache"]: - decode_size = self.cfg["emb_dim"] # 1 token * emb_dim - self.aie_final_norm_decode = AIERMSNorm( - size=decode_size, - eps=1e-5, - num_aie_columns=1, - num_channels=2, - tile_size=self.cfg["emb_dim"], - ) - else: - # When not using KV cache, use same operator for both phases - self.aie_final_norm_decode = self.aie_final_norm_prefill - else: - self.final_norm = nn.RMSNorm( - self.cfg["emb_dim"], eps=1e-5, dtype=self.cfg["dtype"] - ) - - # Offload final linear layer if enabled - if self.cfg.get("use_aie_final_gemm", False): - # Since this GEMM has such a large N dimension, partition the N dimension by 4, - # and GEMM will execute for a workload of that smaller N dimension across different buffers of B and C - aie_config_prefill = { - "num_aie_columns": 8, - "tile_m": 64, - "tile_k": 64, - "tile_n": 64, - "b_col_maj": True, - "use_static_weight": True, - "separate_c_tiles": True, - "partition_N": 4, - } - if self.cfg["use_kv_cache"]: - M_for_gemm = self.prompt_length - else: - M_for_gemm = self.prompt_length + self.num_tokens - self.out_head_prefill = AIEGEMM( - M=M_for_gemm, - K=self.cfg["emb_dim"], - N=self.cfg["vocab_size"], - **aie_config_prefill, - ) - aie_gemv_config = { - "num_aie_columns": 8, - "is_mv": True, - "use_static_weight": True, - "num_aie_columns": 8, - "tile_size_input": 4, - "tile_size_output": 32, - } - # FC1 and FC2: emb_dim -> hidden_dim - if self.cfg["use_aie_final_gemv"]: - self.out_head_decode = AIEGEMV( - M=self.cfg["vocab_size"], K=self.cfg["emb_dim"], **aie_gemv_config - ) - else: - self.out_head = nn.Linear( - self.cfg["emb_dim"], - self.cfg["vocab_size"], - bias=False, - dtype=self.cfg["dtype"], - ) - - # Reusable utilities - cos, sin = compute_rope_params( - head_dim=self.cfg["emb_dim"] // self.cfg["n_heads"], - theta_base=self.cfg["rope_base"], - context_length=self.cfg["context_length"], - freq_config=self.cfg["rope_freq"], - ) - angles = torch.cat([torch.empty_like(cos), torch.empty_like(cos)], dim=1) - angles[:, ::2] = cos - angles[:, 1::2] = sin - self.register_buffer("angles", angles, persistent=False) - - def forward(self, in_idx, input_pos=None, use_kv_cache=False): - # Forward pass - tok_embeds = self.tok_emb(in_idx) - x = tok_embeds - - # Check if input is a vector (decode phase) or matrix (prefill phase) - # Handle 1D: (emb_dim,), 2D: (1, emb_dim), or 3D: (1, 1, emb_dim) - is_vector = ( - len(x.shape) == 1 - or (len(x.shape) == 2 and x.shape[0] == 1) - or (len(x.shape) == 3 and x.shape[0] == 1 and x.shape[1] == 1) - ) - - # (batch, sequence, embedding) where sequence=1 indicates decode - if len(x.shape) == 3: - is_decode_with_kv = (x.shape[1] == 1) and self.cfg["use_kv_cache"] - elif len(x.shape) == 2: - is_decode_with_kv = (x.shape[0] == 1) and self.cfg["use_kv_cache"] - else: - is_decode_with_kv = False - - num_tokens = x.shape[1] - - # During generation phase with KV cache, don't create a mask - # The attention layer will handle masking based on position - if use_kv_cache and input_pos is not None: - mask = None - else: - # During prefill, create standard causal mask - mask = torch.triu( - torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), - diagonal=1, - ) - - for block in self.trf_blocks: - x = block(x, mask, self.angles, input_pos) - - # Sequence length of 1 from input shape means we're in the decode stage, which can use KV cache - if self.cfg.get("use_aie_final_norm", False): - if (x.shape[-2] == 1) and self.cfg.get("use_kv_cache", False): - x = self.aie_final_norm_decode(x) - else: - x = self.aie_final_norm_prefill(x) - else: - x = self.final_norm(x) - - if self.cfg["use_aie_final_gemm"]: - if is_decode_with_kv and self.cfg["use_aie_final_gemv"]: - logits = self.out_head_decode(x) - else: - logits = self.out_head_prefill(x) - else: - logits = self.out_head(x) - - return logits - - def assign_weights(self, final_norm, out_head, out_head_name): - if self.cfg.get("use_aie_final_norm", False): - self.aie_final_norm_prefill.weight = final_norm - if self.cfg["use_kv_cache"]: - self.aie_final_norm_decode.weight = final_norm - else: - self.final_norm.weight = assign( - self.final_norm.weight, - final_norm, - f"model.norm.weight", - ) - - if self.cfg["use_aie_final_gemm"]: - # Want column-major for B - self.out_head_prefill.weight = out_head.T - if self.cfg["use_aie_final_gemv"]: - self.out_head_decode.weight = out_head.T - else: - self.out_head.weight = assign( - self.out_head.weight, - out_head, - out_head_name, - ) diff --git a/iron/applications/llama_3.2_1b/src/tokenizer.py b/iron/applications/llama_3.2_1b/src/tokenizer.py deleted file mode 100644 index 1a16cf57..00000000 --- a/iron/applications/llama_3.2_1b/src/tokenizer.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import os -from pathlib import Path - -import tiktoken -from tiktoken.load import load_tiktoken_bpe - - -class Tokenizer: - """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.""" - - def __init__(self, model_path): - if not os.path.isfile(model_path): - raise FileNotFoundError(model_path) - - mergeable = load_tiktoken_bpe(model_path) - - # hard-coded from Meta's tokenizer.json - self.special = { - "<|begin_of_text|>": 128000, - "<|end_of_text|>": 128001, - "<|start_header_id|>": 128006, - "<|end_header_id|>": 128007, - "<|eot_id|>": 128009, - } - self.special.update( - { - f"<|reserved_{i}|>": 128002 + i - for i in range(256) - if 128002 + i not in self.special.values() - } - ) - - self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)" - r"|[^\r\n\p{L}\p{N}]?\p{L}+" - r"|\p{N}{1,3}" - r"| ?[^\s\p{L}\p{N}]+[\r\n]*" - r"|\s*[\r\n]+" - r"|\s+(?!\S)" - r"|\s+", - mergeable_ranks=mergeable, - special_tokens=self.special, - ) - - def encode(self, text, bos=False, eos=False): - ids = ([self.special["<|begin_of_text|>"]] if bos else []) + self.model.encode( - text - ) - if eos: - ids.append(self.special["<|end_of_text|>"]) - return ids - - def decode(self, ids): - return self.model.decode(ids) - - -class ChatFormat: - - def __init__( - self, tokenizer: Tokenizer, *, default_system="You are a helpful assistant." - ): - self.tok = tokenizer - self.default_system = default_system - - def _header(self, role): - """Encode <|start_header_id|>role<|end_header_id|>\n\n""" - return ( - [self.tok.special["<|start_header_id|>"]] - + self.tok.encode(role) - + [self.tok.special["<|end_header_id|>"]] - + self.tok.encode("\n\n") - ) - - def encode(self, user_message, system_message=None): - sys_msg = system_message if system_message is not None else self.default_system - - ids = [self.tok.special["<|begin_of_text|>"]] - - # system - ids += self._header("system") - ids += self.tok.encode(sys_msg) - ids += [self.tok.special["<|eot_id|>"]] - - # user - ids += self._header("user") - ids += self.tok.encode(user_message) - ids += [self.tok.special["<|eot_id|>"]] - - # assistant header (no content yet) - ids += self._header("assistant") - - return ids diff --git a/iron/applications/llama_3.2_1b/src/utils.py b/iron/applications/llama_3.2_1b/src/utils.py deleted file mode 100644 index 158b59df..00000000 --- a/iron/applications/llama_3.2_1b/src/utils.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import time -import torch -import numpy as np -from ml_dtypes import bfloat16 - - -def model_memory_size(model, input_dtype=torch.float32): - """ - Calculate the estimated memory size of a PyTorch model in gigabytes. - - This function computes the total memory required for the model's parameters, - gradients, and buffers based on the input data type. - - Args: - model (torch.nn.Module): The PyTorch model for which to calculate memory size. - input_dtype (torch.dtype, optional): The data type of the model's input. - Defaults to torch.float32. - - Returns: - float: The estimated memory size of the model in gigabytes. - """ - - total_params = 0 - total_grads = 0 - for param in model.parameters(): - # Calculate total number of elements per parameter - param_size = param.numel() - total_params += param_size - # Check if gradients are stored for this parameter - if param.requires_grad: - total_grads += param_size - - # Calculate buffer size (non-parameters that require memory) - total_buffers = sum(buf.numel() for buf in model.buffers()) - - # Size in bytes = (Number of elements) * (Size of each element in bytes) - # We assume parameters and gradients are stored in the same type as input dtype - element_size = torch.tensor(0, dtype=input_dtype).element_size() - total_memory_bytes = (total_params + total_grads + total_buffers) * element_size - - # Convert bytes to gigabytes - total_memory_gb = total_memory_bytes / (1024**3) - - return total_memory_gb - - -def assign(left, right, tensor_name="unknown"): - """ - Assigns the value of the right tensor to a new torch.nn.Parameter after validating shape compatibility. - - Parameters: - left (torch.Tensor or any): The tensor to compare shape with. - right (torch.Tensor or any): The tensor or value to be assigned. - tensor_name (str): The name of the tensor for error reporting (default is "unknown"). - - Returns: - torch.nn.Parameter: A new parameter containing the value of right. - - Raises: - ValueError: If the shapes of left and right do not match. - """ - - if left.shape != right.shape: - raise ValueError( - f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}" - ) - - if isinstance(right, torch.Tensor): - return torch.nn.Parameter(right.clone().detach()) - else: - return torch.nn.Parameter(torch.tensor(right)) - - -def load_weights_into_llama(model, param_config, params): - """ - Load weights into the LLaMA model from the provided parameters. - - This function assigns weights from the given parameters to the corresponding - layers of the LLaMA model. It handles the embedding layer, attention layers, - feedforward layers, and the output layer. The function also checks for weight - tying between the output head and the embedding layer. - - Args: - model: The LLaMA model instance into which weights will be loaded. - param_config (dict): A configuration dictionary containing model parameters, - including the number of layers (`n_layers`). - params (dict): A dictionary containing the weights to be loaded, with keys - corresponding to the model's architecture. - """ - model.tok_emb.weight = assign( - model.tok_emb.weight, - params["model.embed_tokens.weight"], - "model.embed_tokens.weight", - ) - - for l in range(param_config["n_layers"]): - - # Load attention weights - model.trf_blocks[l].att.assign_weights( - l, - params[f"model.layers.{l}.self_attn.q_proj.weight"], - params[f"model.layers.{l}.self_attn.k_proj.weight"], - params[f"model.layers.{l}.self_attn.v_proj.weight"], - params[f"model.layers.{l}.self_attn.o_proj.weight"], - ) - # Load FeedForward weights - model.trf_blocks[l].ff.assign_weights( - l, - fc1=params[f"model.layers.{l}.mlp.gate_proj.weight"], - fc2=params[f"model.layers.{l}.mlp.up_proj.weight"], - fc3=params[f"model.layers.{l}.mlp.down_proj.weight"], - ) - # Load RMS norm weights - model.trf_blocks[l].assign_weights( - l, - params[f"model.layers.{l}.input_layernorm.weight"], - params[f"model.layers.{l}.post_attention_layernorm.weight"], - ) - - # Load output layer weights - if "lm_head.weight" in params.keys(): - model.assign_weights( - params["model.norm.weight"], params["lm_head.weight"], "lm_head.weight" - ) - else: - model.assign_weights( - params["model.norm.weight"], - params["model.embed_tokens.weight"], - "model.embed_tokens.weight", - ) - - -def text_to_token_ids(text, tokenizer): - """ - Convert a given text into token IDs using the specified tokenizer. - - Args: - text (str): The input text to be tokenized. - tokenizer: An instance of a tokenizer that has an `encode` method. - - Returns: - torch.Tensor: A tensor containing the token IDs of the input text, - with an added batch dimension. - """ - encoded = tokenizer.encode(text) - encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension - return encoded_tensor - - -def token_ids_to_text(token_ids, tokenizer): - """ - Convert a tensor of token IDs to a human-readable text string. - - Args: - token_ids (torch.Tensor): A tensor containing token IDs, - typically with a batch dimension. - tokenizer (Tokenizer): An instance of a tokenizer that has a - decode method to convert token IDs to text. - - Returns: - str: The decoded text string corresponding to the input token IDs. - """ - flat = token_ids.squeeze(0) # remove batch dimension - return tokenizer.decode(flat.tolist()) - - -def generate( - model, - idx, - max_new_tokens, - context_size, - eos_id, - hook_handles, - temperature=0.0, - top_k=None, - tokenizer=None, - prompt=None, - do_print=True, - prefill_done_callback=None, -): - """ - Generate new tokens using the provided model based on the input sequence. - - Args: - model: The model used for generating tokens. It should accept input sequences and return logits. - idx (torch.Tensor): The input sequence of token indices (shape: (batch_size, sequence_length)). - max_new_tokens (int): The maximum number of new tokens to generate. - context_size (int): The number of tokens from the input sequence to consider for generation. - temperature (float, optional): The temperature for scaling logits. Higher values result in more random outputs. Default is 0.0 (no scaling). - top_k (int, optional): The number of top logits to consider for sampling. If None, all logits are used. Default is None. - eos_id (int, optional): The end-of-sequence token ID. If specified, generation will stop when this token is produced. Default is None. - - Returns: - torch.Tensor: The updated sequence of token indices after generation (shape: (batch_size, new_sequence_length)). - """ - # For-loop is the same as before: Get logits, and only focus on last time step - finished_prefill = False - - print(f"Starting prefill inference...") - - for i in range(max_new_tokens): - use_kv_cache = model.cfg["use_kv_cache"] - - if use_kv_cache: - if i == 0: - # Prefill phase - process entire sequence - idx_cond = idx[:, -context_size:] - input_pos = None - else: - # Generation phase with KV cache - single token, need to track position - # Extract only the last token - idx_cond = idx[:, -1:] - input_pos = torch.tensor([idx.shape[1] - 1], device=idx.device) - else: - # No KV cache - always process entire sequence (GEMM every time) - idx_cond = idx[:, -context_size:] - input_pos = None - with torch.no_grad(): - logits = model(idx_cond, input_pos=input_pos, use_kv_cache=use_kv_cache) - logits = logits[:, -1, :] - - # New: Filter logits with top_k sampling - if top_k is not None: - # Keep only top_k values - top_logits, _ = torch.topk(logits, top_k) - min_val = top_logits[:, -1] - logits = torch.where( - logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits - ) - - # New: Apply temperature scaling - if temperature > 0.0: - logits = logits / temperature - - # Apply softmax to get probabilities - probs = torch.softmax(logits, dim=-1) # (batch_size, context_len) - - # Sample from the distribution - idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1) - - # Otherwise same as before: get idx of the vocab entry with the highest logits value - else: - idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1) - - # Only run the forward hook for the prefill stage, remove it afterwards to speed up inference - if not finished_prefill: - if hook_handles: - for handle in hook_handles: - handle.remove() - finished_prefill = True - - if ( - idx_next == eos_id - ): # Stop generating early if end-of-sequence token is encountered and eos_id is specified - break - - # Same as before: append sampled index to the running sequence - idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1) - - # End timing the first iteration - if i == 0: - if prefill_done_callback is not None: - prefill_done_callback() - if do_print: - print(prompt) - - # print(f"\rGenerating token {i + 1}/{max_new_tokens}...", end="") - generated_text = token_ids_to_text(idx_next, tokenizer) - if do_print: - print(f"{generated_text}", end="", flush=True) - - print("\n\n") - return idx - - -def clean_text(text, header_end="assistant<|end_header_id|>\n\n"): - """ - Cleans the input text by removing the header portion defined by the header_end token. - - Parameters: - text (str): The input text to be cleaned. - header_end (str): The token that marks the end of the header. Defaults to "assistant<|end_header_id|>\n\n". - - Returns: - str: The cleaned text, which is the substring after the header_end token. - If the token is not found, the original text is returned. - """ - - # Find the index of the first occurrence of "<|end_header_id|>" - index = text.find(header_end) - - if index != -1: - # Return the substring starting after "<|end_header_id|>" - return text[ - index + len(header_end) : - ].strip() # Strip removes leading/trailing whitespace - else: - # If the token is not found, return the original text - return text diff --git a/iron/applications/llama_3.2_1b/test.py b/iron/applications/llama_3.2_1b/test.py index 933b7d5e..03a7639e 100644 --- a/iron/applications/llama_3.2_1b/test.py +++ b/iron/applications/llama_3.2_1b/test.py @@ -5,13 +5,14 @@ import subprocess import pytest from pathlib import Path +import os test_dir = Path(__file__).parent -weights_dir = Path("/srv") +weights_dir = Path(os.environ.get("IRON_EXAMPLE_WEIGHTS_DIR", "/srv")) def generate_test_params(): - prompt_lengths = [2048, 13] + prompt_lengths = [1024, 13] num_tokens_list = [40, 1] params = [] @@ -27,13 +28,12 @@ def generate_test_params(): @pytest.mark.metrics( - TTFT=r"Prefill time: (?P[\d\.e\+-]+) seconds", - TPS=r"Tokens per second: (?P[\d\.e\+-]+)", - Num_Tokens=r"Tokens generated: (?P[\d\.e\+-]+)", + TTFT=r"\[Prefill\]\s*Time to first token:\s*(?P[\d\.e\+-]+) s", + TPS=r"\[Decode\]\s*Tokens per second: (?P[\d\.e\+-]+)", ) @pytest.mark.parametrize("prompt_len,num_tokens", params, ids=names) def test_llama_3_2_1b(prompt_len, num_tokens): - command = f"python3 {test_dir}/inference.py {weights_dir}/llama3.2-1b/model.safetensors {weights_dir}/llama3.2-1b/tokenizer.model --prompt_len {prompt_len} --num_tokens {num_tokens}" + command = f"python3 {test_dir}/llama_npu.py {weights_dir}/llama3.2-1b/model.safetensors {weights_dir}/llama3.2-1b/tokenizer.model --num-tokens {num_tokens} --prompt-len {prompt_len}" result = subprocess.run( command, diff --git a/iron/applications/llama_3.2_1b/torch_to_npy.py b/iron/applications/llama_3.2_1b/torch_to_npy.py deleted file mode 100644 index e7d06be0..00000000 --- a/iron/applications/llama_3.2_1b/torch_to_npy.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import torch -import argparse -import numpy as np -import os -import shutil - - -def torch_to_npy(inp_file_path, outp_file_path): - # Load the torch file - data = torch.load(inp_file_path) - # Convert the tensor to a numpy array of floats - data_np = data.to(torch.float32).numpy() - # Compare the values between data and data_np - if not torch.equal(data, torch.from_numpy(data_np)): - raise ValueError("Mismatch between original data and converted numpy array.") - - # Save the array to a npy file - np.save(outp_file_path, data_np) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Convert torch files to npy files.") - parser.add_argument( - "file_path", - type=str, - help="Path to the torch file or directory containing torch files", - ) - args = parser.parse_args() - file_path = args.file_path - - output_dir = os.path.join("results", f"{os.path.basename(file_path)}_npy") - if os.path.exists(output_dir): - shutil.rmtree(output_dir) - os.makedirs(output_dir) - - # Check if the file path is a directory - if os.path.isdir(file_path): - for file_name in os.listdir(file_path): - if file_name.endswith(".pt") or file_name.endswith(".pth"): - full_path = os.path.join(file_path, file_name) - output_file_path = os.path.join( - output_dir, file_name.replace(".pt", ".npy").replace(".pth", ".npy") - ) - torch_to_npy(full_path, output_file_path) - else: - torch_to_npy(file_path) diff --git a/iron/common/__init__.py b/iron/common/__init__.py index 4fa9ae3b..68cafac6 100644 --- a/iron/common/__init__.py +++ b/iron/common/__init__.py @@ -3,8 +3,16 @@ """Common utilities and base classes for IRON operators.""" -from .aie_base import AIEOperatorBase, AIEOperatorConstraintError -from .aie_context import AIEContext +from .base import ( + AIEOperatorBase, + MLIROperator, + CompositeOperator, + CompositeCallable, + AIEBuffer, + SingleXclbinCallable, + AIERuntimeArgSpec, +) +from .context import AIEContext from .compilation import ( XclbinArtifact, InstsBinArtifact, @@ -13,4 +21,4 @@ SourceArtifact, PythonGeneratedMLIRArtifact, ) -from .aie_device_manager import AIEDeviceManager +from .device_manager import AIEDeviceManager diff --git a/iron/common/aie_base.py b/iron/common/aie_base.py deleted file mode 100644 index 5238f6f5..00000000 --- a/iron/common/aie_base.py +++ /dev/null @@ -1,229 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import os -from pathlib import Path -from abc import ABC, abstractmethod -import logging -import time -import torch -from ml_dtypes import bfloat16 - -import aie.utils.config -from . import compilation as comp -from .aie_context import AIEContext -from .aie_device_manager import AIEDeviceManager, pyxrt -from .utils import numpy_to_torch, torch_to_numpy - - -class AIEOperatorBase(ABC): - """Base class for AIE-accelerated operations""" - - @classmethod - def get_default_context(cls): - """One global 'default' context if none is specified""" - if not hasattr(AIEOperatorBase, "_default_context"): - AIEOperatorBase._default_context = AIEContext() - return AIEOperatorBase._default_context - - def __init__(self, context=None): - self.artifacts = ( - [] - ) # CompilationArtifact objects are uniqued within the context - self.kernels = {} # Name -> (xclbin_path, xclbin_kernel_name, insts_path) - self.buffers = {} # Name -> required buffer size in bytes - self.buffer_static_data = {} - self.runlist = ( - [] - ) # List of (kernel_name, buffers_name, buffer_name...), will be executed in sequence - - # AIE runtime state - self.buffer_bos = {} # Buffer name -> buffer object - self.xrt_kernels = ( - {} - ) # Kernel name -> (XRT context, XRT kernel object, instruction buffer object, instruction length) - self.xrt_runlist = None - - if context is None: - context = self.get_default_context() - context.register_operator(self) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def add_kernel( - self, - name: str, - xclbin_artifact: comp.XclbinArtifact, - xclbin_kernel_name: str, - insts_artifact: comp.InstsBinArtifact, - ): - assert name not in self.kernels - self.kernels[name] = (xclbin_artifact, xclbin_kernel_name, insts_artifact) - - def add_buffer(self, name, count, dtype=bfloat16, static_data=None): - assert name not in self.buffers - self.buffers[name] = count * np.dtype(dtype).itemsize - if static_data is not None: - assert ( - static_data.nbytes <= self.buffers[name] - ), f"Static data for buffer {name} exceeds allocated size: expected {self.buffers[name]} bytes, got {static_data.nbytes} bytes." - static_data_bytes = static_data.flatten().view(np.uint8).tobytes() - if static_data_bytes not in self.context.static_data_pool: - self.context.static_data_pool[static_data_bytes] = None - self.buffer_static_data[name] = next( - k - for k, v in self.context.static_data_pool.items() - if k == static_data_bytes - ) - - def add_to_runlist(self, kernel_name, *args): - if kernel_name not in self.kernels: - raise RuntimeError(f"No such kernel: {kernel_name}") - for arg in args: - if arg not in self.buffers: - raise RuntimeError(f"No such buffer: {arg}") - self.runlist.append((kernel_name, *args)) - - def get_bo(self, buffer_name): - return self.buffer_bos[buffer_name] - - def read_buffer(self, buffer_name, shape, copy=False, dtype=bfloat16): - """Read buffer and return values as a numpy array""" - # Create a byte accessible memory view of the buffer object - mv = self.get_bo(buffer_name).map() - - # Interpret the buffer as a 1-dimensional array then change its view to the expected shape - arr = np.frombuffer(mv, dtype=dtype, count=np.prod(shape)).reshape(shape) - - # Return an independent copy of the array if needed - return arr.copy() if copy else arr - - def read_buffer_as_torch(self, buffer_name, shape, dtype=bfloat16): - return numpy_to_torch(self.read_buffer(buffer_name, shape, dtype)) - - def write_buffer(self, buffer_name, array): - """Write buffer from a numpy array into a XRT buffer object""" - if buffer_name in self.buffer_static_data: - raise RuntimeError(f"Cannot write to static buffer: {buffer_name}") - - # Normalize the source - if isinstance(array, torch.Tensor): - src = torch_to_numpy(array) - else: - src = np.asarray(array) - - # Create a flattened 1D byte view of the source - src_bytes = src.ravel().view(np.uint8) - - bo = self.get_bo(buffer_name) - mv = bo.map() # byte accessible memory view - # Interpret the buffer as a 1-dimensional array - dst_bytes = np.frombuffer(mv, dtype=np.uint8, count=bo.size()) - - # The BO is an existing array, so copyto() can be called, which doesn't create a new array - np.copyto(dst_bytes[: src_bytes.size], src_bytes, casting="no") - - @abstractmethod - def set_up_artifacts(self): - """ - Subclasses should overwrite this method to set up their required dependenices and runtime runlist, kernels and buffers with calls to add_artifacts(), add_kernel(), add_buffer(), and add_to_runlist(). - Note: This method should only *describe* the required artifacts and runtime buffers, and not yet do any computation or compilation. - Compilation will be handled automatically based on the provided description. - """ - pass - - @abstractmethod - def set_up_runtime(self): - pass - - def compile(self, dry_run=None): - """ - Set up the operator and compile any necessary artifacts. - Subclasses are expected to overwrite set_up(); they may register any artifacts that they need to be compiled there. - """ - context = self.context - self.set_up_artifacts() - self._move_artifact_paths() - work_list = comp.get_work_list(self.artifacts) - compilation_rules = [ - comp.GenerateMLIRFromPythonCompilationRule(dry_run=dry_run), - comp.PeanoCompilationRule( - context.peano_dir, context.mlir_aie_dir, dry_run=dry_run - ), - comp.ArchiveCompilationRule(context.peano_dir, dry_run=dry_run), - comp.AieccCompilationRule( - context.build_dir, - context.peano_dir, - context.mlir_aie_dir, - dry_run=dry_run, - ), - ] - if work_list: - logging.info( - f"Compiling {len(work_list)} new artifacts for AIE operator {self.__class__.__name__}: {', '.join(str(artifact.path.name) for artifact in work_list)}" - ) - comp.compile(compilation_rules, work_list) - - def add_artifacts(self, artifacts): - self.artifacts.extend(artifacts) - - def _move_artifact_paths(self): - """Make all artifacts paths point into the build directory (source artifacts into the ironclad source directory). This doesn't phyisically move files; this function is called before artifact generation.""" - context = self.context - todo = self.artifacts.copy() - while todo: - artifact = todo[0] - todo.pop(0) - if isinstance(artifact, comp.SourceArtifact): - artifact.set_path(context.base_dir / artifact.path) - else: - artifact.set_path(context.build_dir / artifact.path) - todo.extend(artifact.depends) - - def run_runlist(self): - elapsed = 0.0 - if self.xrt_runlist is None: - # Execute as separate xclbin kernel invocations - for i, (kernel_name, *buffer_args) in enumerate(self.runlist): - context, xrt_kernel, insts_bo, insts_len = self.xrt_kernels[kernel_name] - insts_bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) - bos = [self.buffer_bos[buffer_arg] for buffer_arg in buffer_args] - for bo in bos: - bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) - opcode = 3 - start = time.perf_counter() - run = xrt_kernel(opcode, insts_bo, insts_len, *bos) - result = run.wait() - stop = time.perf_counter() - elapsed += stop - start - if result != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: - raise RuntimeError( - f"Kernel {kernel_name} did not complete correctly: {result}" - ) - for bo in bos: - bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE) - else: - bos = set( - self.buffer_bos[buffer_arg] - for _, *buffer_args in self.runlist - for buffer_arg in buffer_args - ) - insts_bos = set( - self.xrt_kernels[kernel_name][2] for (kernel_name, *_) in self.runlist - ) - for bo in bos | insts_bos: - bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) - start = time.perf_counter() - self.xrt_runlist.execute() - self.xrt_runlist.wait() - stop = time.perf_counter() - for bo in bos: - bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE) - elapsed = stop - start - return elapsed - - -class AIEOperatorConstraintError(RuntimeError): - pass diff --git a/iron/common/aie_context.py b/iron/common/aie_context.py deleted file mode 100644 index 804499f6..00000000 --- a/iron/common/aie_context.py +++ /dev/null @@ -1,211 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import logging -from pathlib import Path -import os - -from .aie_device_manager import AIEDeviceManager, pyxrt -from . import compilation as comp -import aie.utils.config - - -class AIEContext: - """Context for managing AIE operator compilation and runtime state""" - - def __init__(self, use_runlist=True): - self.operators = [] - self.static_data_pool = {} - self.device_manager = AIEDeviceManager() - self.base_dir = Path(__file__).parent.parent.parent - self.build_dir = Path(os.getcwd()) / "build" - self.mlir_aie_dir = Path(aie.utils.config.root_path()) - self.peano_dir = Path(aie.utils.config.peano_install_dir()) - # Disable the XRT runlist sacrifices performance by executing kernels individually as separate xclbin invocations for easier debugging (can tell which part of runlist execution failed) - self.use_runlist = use_runlist - self._runtime_prepared = False - - def register_operator(self, operator): - """Register an operator with this context""" - if self._runtime_prepared: - raise RuntimeError("Cannot register operators after runtime is prepared") - operator.context = self - self.operators.append(operator) - - def compile_all(self): - """Compile all registered operators""" - self.build_dir.mkdir(parents=True, exist_ok=True) - for op in self.operators: - op.compile() - - def prepare_runtime(self): - """Setup XRT runtime for all registered operators""" - if self._runtime_prepared: - return - - for op in self.operators: - op.set_up_runtime() - - # Pools of preallocated buffer objects; each buffer object is allocated - # once at program start and then reused across operators where possible. - bo_pools = {} - page_sz = 4096 - get_pool_sz = lambda x: (x + page_sz - 1) // page_sz * page_sz - - # Allocate static buffers first - for buffer_data in self.static_data_pool: - logging.debug( - f"Allocating static buffer with size {len(buffer_data)} bytes." - ) - bo = pyxrt.bo( - self.device_manager.device, - len(buffer_data), - pyxrt.bo.host_only, - 0x10000, - ) - bo.write(np.frombuffer(buffer_data, dtype=np.uint8), 0) - self.static_data_pool[buffer_data] = bo - - for op in self.operators: - if len(op.kernels) == 0: - continue - - logging.info(f"Preparing runtime for AIE operator: {op.__class__.__name__}") - - # Set up kernels - for kernel_name, (xclbin, xclbin_kernel_name, insts) in op.kernels.items(): - handle = self.device_manager.get_kernel_handle( - str(xclbin.path), xclbin_kernel_name, str(insts.path) - ) - op.xrt_kernels[kernel_name] = ( - handle.context, - handle.kernel, - handle.insts_bo, - len(handle.insts), - ) - - # If multiple buffers (of the same binned size) are used in the - # same kernel invocation OR across different invocations with shared - # buffers, they require separate allocations. - conflicting_buffers = {} # map buffer -> {set of conflicting buffers} - buffer_to_runlist_entries = {} # map buffer -> set of runlist entry indices - - # First pass: track which buffers appear in which runlist entries - for idx, (kernel, *args) in enumerate(op.runlist): - for arg in args: - buffer_to_runlist_entries.setdefault(arg, set()).add(idx) - - # Second pass: determine conflicts - for idx, (kernel, *args) in enumerate(op.runlist): - for arg in args: - if arg in op.buffer_static_data: - # Static buffers never conflict - continue - pool_sz = get_pool_sz(op.buffers[arg]) - - # Buffers conflict if they're in the same runlist entry - conflicting_args = { - a for a in args if get_pool_sz(op.buffers[a]) == pool_sz - } - {arg} - - # Also conflict with buffers in other runlist entries that share - # a buffer with this entry - for other_arg in args: - if other_arg == arg: - continue - for other_idx in buffer_to_runlist_entries.get( - other_arg, set() - ): - if other_idx != idx: - _, *other_args = op.runlist[other_idx] - conflicting_args.update( - { - a - for a in other_args - if get_pool_sz(op.buffers[a]) == pool_sz - and a != arg - } - ) - - conflicting_buffers[arg] = conflicting_buffers.get( - arg, set() - ).union(conflicting_args) - - # Allocate buffers - buffer_allocations = {} - for buffer_name, buffer_min_size in op.buffers.items(): - if buffer_name in op.buffer_static_data: - static_data = op.buffer_static_data[buffer_name] - op.buffer_bos[buffer_name] = self.static_data_pool[static_data] - continue - - alloc_pool = get_pool_sz(buffer_min_size) - alloc_idx = 0 - for conflict in conflicting_buffers.get(buffer_name, set()): - if conflict not in buffer_allocations: - continue - conflict_pool, conflict_idx = buffer_allocations[conflict] - alloc_idx = max(alloc_idx, conflict_idx + 1) - - assert 0 <= alloc_idx < len(bo_pools.get(alloc_pool, [])) + 1 - if alloc_idx == len(bo_pools.get(alloc_pool, [])): - bo = pyxrt.bo( - self.device_manager.device, - alloc_pool, - pyxrt.bo.host_only, - 0x10000, - ) - bo_pools.setdefault(alloc_pool, []).append(bo) - - buffer_allocations[buffer_name] = (alloc_pool, alloc_idx) - op.buffer_bos[buffer_name] = bo_pools[alloc_pool][alloc_idx] - - # Setup runlist - _, (first_xclbin, first_xclbin_kernel_name, first_insts) = next( - iter(op.kernels.items()) - ) - handle = self.device_manager.get_kernel_handle( - str(first_xclbin.path), first_xclbin_kernel_name, str(first_insts.path) - ) - context = handle.context - if self.use_runlist: - op.xrt_runlist = pyxrt.runlist(context) - for i, (kernel_name, *buffer_args) in enumerate(op.runlist): - this_context, xrt_kernel, insts_bo, insts_len = op.xrt_kernels[ - kernel_name - ] - assert this_context == context - opcode = 3 - run = pyxrt.run(xrt_kernel) - run.set_arg(0, opcode) - run.set_arg(1, insts_bo) - run.set_arg(2, insts_len) - for j, buffer_arg in enumerate(buffer_args): - run.set_arg(j + 3, op.buffer_bos[buffer_arg]) - op.xrt_runlist.add(run) - else: - op.xrt_runlist = None - - # Log allocation info - bo_count = sum(len(pool) for pool in bo_pools.values()) - bo_footprint = sum(len(pool) * pool_sz for pool_sz, pool in bo_pools.items()) - logging.info( - f"Allocated {bo_count} total buffer objects with a total memory footprint of " - + ( - f"{bo_footprint//1024//1024} MiB." - if bo_footprint >= 1024 * 1024 - else f"{bo_footprint//1024} KiB." - ) - ) - static_data_footprint = sum(len(data) for data in self.static_data_pool) - logging.info( - f"Allocated {len(self.static_data_pool)} static buffers with a total memory footprint of " - + ( - f"{static_data_footprint//1024//1024} MiB." - if static_data_footprint >= 1024 * 1024 - else f"{static_data_footprint//1024} KiB." - ) - ) - - self._runtime_prepared = True diff --git a/iron/common/base.py b/iron/common/base.py new file mode 100644 index 00000000..641061bb --- /dev/null +++ b/iron/common/base.py @@ -0,0 +1,356 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import os +from pathlib import Path +from abc import ABC, abstractmethod +import logging +import time +import torch +from ml_dtypes import bfloat16 + +import aie.utils.config +from . import compilation as comp +from .context import AIEContext +from .device_manager import AIEDeviceManager, pyxrt +from .utils import numpy_to_torch, torch_to_numpy +from .compilation import ( + XclbinArtifact, + InstsBinArtifact, + KernelObjectArtifact, + KernelArchiveArtifact, + SourceArtifact, + PythonGeneratedMLIRArtifact, +) + + +class AIEOperatorBase(ABC): + """Base class for AIE-accelerated operations""" + + def __init__(self, context=None): + self.artifacts = comp.CompilationArtifactGraph( + [] + ) # CompilationArtifact objects are uniqued within the context + if context is None: + context = self.get_default_context() + context.register_operator(self) + self.context = context + + @abstractmethod + def set_up_artifacts(self): + """ + Subclasses should overwrite this method to set up their required dependenices and runtime runlist, kernels and buffers with calls to add_artifacts(), add_kernel(), add_buffer(), and add_to_runlist(). + Note: This method should only *describe* the required artifacts and runtime buffers, and not yet do any computation or compilation. + Compilation will be handled automatically based on the provided description. + """ + pass + + @abstractmethod + def get_arg_spec(self): + pass + + @abstractmethod + def get_callable(self): + pass + + @classmethod + def get_default_context(cls): + """One global 'default' context if none is specified""" + if not hasattr(AIEOperatorBase, "_default_context"): + AIEOperatorBase._default_context = AIEContext() + return AIEOperatorBase._default_context + + def compile(self, dry_run=False): + """ + Set up the operator and compile any necessary artifacts. + Subclasses are expected to overwrite set_up(); they may register any artifacts that they need to be compiled there. + """ + self.set_up_artifacts() + comp.compile( + self.context.compilation_rules, + self.artifacts, + self.context.build_dir, + dry_run=dry_run, + ) + return self + + def add_artifacts(self, artifacts): + for artifact in artifacts: + self.artifacts.add(artifact) + + +def sync_to_device(bos): + for bo in bos: + bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) + + +def sync_from_device(bos): + for bo in bos: + bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE) + + +def execute_runlist(runlist): + runlist.execute() + runlist.wait() + + +class MLIROperator(AIEOperatorBase, ABC): + """Base class for AIE-accelerated operations defined by a single MLIR source""" + + def __init__(self, *args, **kwargs): + self.kernel_archive = f"{self.get_operator_name()}_kernels.a" + AIEOperatorBase.__init__(self, *args, **kwargs) + + @abstractmethod + def get_operator_name(self): + pass + + @abstractmethod + def get_mlir_artifact(self): + pass + + @abstractmethod + def get_kernel_artifacts(self): + pass + + def get_artifacts(self, prefix=""): + operator_name = prefix + self.get_operator_name() + mlir_artifact = self.get_mlir_artifact() + kernel_deps_inputs = self.get_kernel_artifacts() + if len(kernel_deps_inputs) > 0: + # FIXME: currently hard-coding that the design will accept this argument as an input if it uses kernels + # Also not handling name collisions of kernels with the same name + mlir_artifact.callback_kwargs["kernel_archive"] = self.kernel_archive + kernel_deps = ( + [ + KernelArchiveArtifact( + self.kernel_archive, + dependencies=kernel_deps_inputs, + ) + ] + if kernel_deps_inputs + else [] + ) + xclbin_artifact = XclbinArtifact( + f"{operator_name}.xclbin", + mlir_input=mlir_artifact, + dependencies=[mlir_artifact] + kernel_deps, + ) + insts_artifact = InstsBinArtifact( + f"{operator_name}.bin", + mlir_input=mlir_artifact, + dependencies=[mlir_artifact], + ) + return xclbin_artifact, insts_artifact + + def set_up_artifacts(self): + xclbin_artifact, insts_artifact = self.get_artifacts() + self.xclbin_artifact = xclbin_artifact + self.insts_artifact = insts_artifact + self.add_artifacts([xclbin_artifact, insts_artifact]) + + def get_callable(self): + return SingleXclbinCallable( + xclbin_path=self.xclbin_artifact.filename, + kernel_name=self.xclbin_artifact.kernel_name, + insts_bin_path=self.insts_artifact.filename, + args_spec=self.get_arg_spec(), + ) + + +class CompositeOperator(AIEOperatorBase, ABC): + """Base class for composite operators that chain multiple sub-operators""" + + def __init__(self, context=None): + super().__init__(context) + + +class AIERuntimeArgSpec: + def __init__(self, direction, shape, dtype=bfloat16): + self.shape = shape + self.dtype = dtype + assert direction in {"in", "out", "inout"} + self.direction = direction + + def __repr__(self): + return f"AIERuntimeArgSpec(direction={self.direction}, shape={self.shape}, dtype={self.dtype})" + + +class AIEBuffer: + def __init__(self, shape, dtype=bfloat16, bo=None, device_manager=None): + size = np.prod(shape) * np.dtype(dtype).itemsize + self.shape = shape + self.dtype = dtype + self.bo = bo + self.on = "cpu" + self.device_manager = device_manager or AIEDeviceManager() + if not self.bo: + self.bo = pyxrt.bo( + self.device_manager.device, + size, + pyxrt.bo.host_only, + 0x10000, + ) + self.memory_view = self.bo.map() + self.subviews = [] + + def subbuffer(self, length, offset, shape, dtype=None): + if dtype is None: + dtype = self.dtype + assert np.prod(shape) == length + itemsize = np.dtype(dtype).itemsize + assert offset >= 0 + assert offset * itemsize <= np.prod(self.shape) * np.dtype(self.dtype).itemsize + assert ( + length * itemsize + offset * itemsize + <= np.prod(self.shape) * np.dtype(self.dtype).itemsize + ) + sub_bo = pyxrt.bo( + self.bo, # parent bo + length * itemsize, # size + offset * itemsize, # offset + ) + sub_buffer = AIEBuffer( + shape=shape, dtype=dtype, bo=sub_bo, device_manager=self.device_manager + ) + sub_buffer.on = self.on + self.subviews.append(sub_buffer) + return sub_buffer + + def view(self, shape): + assert np.prod(shape) == np.prod(self.shape) + sub_buffer = AIEBuffer( + shape=shape, + dtype=self.dtype, + bo=self.bo, + device_manager=self.device_manager, + ) + sub_buffer.on = self.on + self.subviews.append(sub_buffer) + return sub_buffer + + def view_as_np(self): + self.to("cpu") + # Interpret the buffer as a 1-dimensional array then change its view to the expected shape + return np.frombuffer( + self.memory_view, dtype=self.dtype, count=np.prod(self.shape) + ).reshape(self.shape) + + def view_as_torch(self): + return numpy_to_torch(self.view_as_np()) + + def to(self, dest): + direction = { + "npu": pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE, + "cpu": pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE, + } + if dest not in direction: + raise RuntimeError(f"Unknown destination for AIEBuffer.to(): {dest}") + if self.on == dest: + return self + direction = direction[dest] + self.bo.sync(direction) + self.on = dest + todo = self.subviews.copy() + while todo: + sub_buffer = todo.pop() + sub_buffer.on = self.on + todo.extend(sub_buffer.subviews) + return self + + @staticmethod + def from_np(buffer): + shape = buffer.shape + dtype = buffer.dtype + size = np.prod(shape) * np.dtype(dtype).itemsize + aie_buffer = AIEBuffer(shape=shape, dtype=dtype) + aie_buffer.view_as_np()[:] = buffer + aie_buffer.to("npu") + return aie_buffer + + @staticmethod + def from_torch(tensor): + return AIEBuffer.from_np(torch_to_numpy(tensor)) + + +class SingleXclbinCallable: + def __init__( + self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None + ): + self.device_manager = device_manager or AIEDeviceManager() + self.context, self.xrt_kernel = self.device_manager.get_context_and_kernel( + str(xclbin_path), kernel_name + ) + with open(str(insts_bin_path), "rb") as f: + instructions = np.frombuffer(f.read(), dtype=np.uint32) + insts_bo = pyxrt.bo( + self.device_manager.device, + instructions.nbytes, + pyxrt.bo.cacheable, + self.xrt_kernel.group_id(1), + ) + insts_bo.write(instructions.view(np.uint8), 0) + self.insts_buffer = AIEBuffer( + shape=(len(instructions),), dtype=np.uint32, bo=insts_bo + ) + self.insts_buffer.to("npu") + self.args_spec = args_spec + + def __call__(self, *buffers): + assert len(buffers) == len(self.args_spec) + # assert all( + # np.prod(buffers[i].shape) >= np.prod(self.args_spec[i].shape) and buffers[i].dtype == self.args_spec[i].dtype + # for i in range(len(buffers)) + # ), "Input buffer shapes or dtypes do not match expected argument specification." + self.insts_buffer.to("npu") + for buf in buffers: + buf.to("npu") + opcode = 3 + bos = [buffer.bo for buffer in buffers] + run = self.xrt_kernel( + opcode, self.insts_buffer.bo, self.insts_buffer.shape[0], *bos + ) + ret_code = run.wait() + if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: + raise RuntimeError(f"Kernel did not complete correctly: {ret_code}") + + +class PatchableSingleXclbinCallable(SingleXclbinCallable): + def __init__( + self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None + ): + super().__init__( + xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager + ) + self.baseline_instructions = self.insts_buffer.view_as_np().copy() + + def patch(self, patches): + """Apply patches with masking: dict of {position: (value, mask)}.""" + insts = self.insts_buffer.view_as_np() + insts[:] = self.baseline_instructions + for pos, (val, mask) in patches.items(): + insts[pos] = (np.int64(insts[pos]) & ~mask) | (val & mask) + self.insts_buffer.to("npu") + + +class CompositeCallable: + """Callable for executing a sequence of sub-operators""" + + def __init__(self, sequence, intermediate_buffers=None): + """ + Args: + sequence: List of (callable, args_indices) tuples. + args_indices is a list of indices into the combined list of [inputs, outputs, intermediates]. + intermediate_buffers: List of AIEBuffer objects for intermediate results. + """ + self.sequence = sequence + self.intermediate_buffers = intermediate_buffers or [] + + def __call__(self, *args): + # args contains inputs and outputs + all_buffers = list(args) + self.intermediate_buffers + + for op_callable, indices in self.sequence: + op_args = [all_buffers[i] for i in indices] + op_callable(*op_args) diff --git a/iron/common/compilation.py b/iron/common/compilation.py deleted file mode 100644 index 2cbaa916..00000000 --- a/iron/common/compilation.py +++ /dev/null @@ -1,630 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -This file implements a simple Python-based build system. You specify what you -want to compile (*artifacts*) through subclasses of `CompilationArtifact`. -Each artifact can have a list of depenencies of other artifacts that it relies -on. Each artifact corresponds to exactly one file. If a file with a matching -name already exists, and all its dependencies are built and older than the file, -then the existing file will be reused. - -For each file name, artifacts are singletons. You create artifacts by calling -the `new` class method of the appropriate class. This ensures that artifact -objects are uniqued, i.e., calling `new` twice with the same file name will -return the same object. - -There is a special artifact for source files that do not need to get generated, -`SourceArtifact`. It is likely that in your compilation dependency graph, -the leaf nodes will be `SourceArtifact`s. - -You specify how to generate (compile) an artifact through *rules*, which are -expressed as subclasses of `CompilationRule`. This class requires you to -implement two methods: `matches` and `compile`. During compilation, we will -call `matches` on the set of remaining artifacts to see if the given rule is -able to produce any of the artifacts not available yet. If this function -returns `True`, we will call `compile` on the rule to generate the artifact. -`compile` returns a new list of artifacts, which may be the same one as -before; however, if `matches()==True`, at least one of the artifacts in the -list must be made available after calling `compile()`. -""" - -from abc import ABC, abstractmethod -from pathlib import Path -import os.path -import zlib -import logging -import subprocess -import importlib.util -from contextlib import nullcontext -from aie.extras.context import mlir_mod_ctx - -# Compilation Artifacts -# -------------------------------------------------------------------------- - - -class CompilationArtifact(ABC): - _instances = {} - - @classmethod - def new(cls, path, *args, **kwargs): - """Uniques artifacts based on absolute file path; any two artifacts with the same absolute path will be represented by the same object.""" - path = Path(path) - abs_path = path.absolute() - if abs_path not in cls._instances: - cls._instances[abs_path] = None - instance = cls(path, *args, **kwargs) - cls._instances[abs_path] = instance - else: - assert ( - type(cls._instances[abs_path]) == cls - ), f"Artifact with path {abs_path} is already registered with a different type" - return cls._instances[abs_path] - - def __init__(self, path, depends=None): - abs_path = path.absolute() - assert ( - abs_path in self._instances - ), "do not construct artifact objects directly; call the get() class method instead for uniquing" - self.path: Path = path - self.depends: list[CompilationArtifact] = depends if depends is not None else [] - self.users: list[CompilationArtifact] = ( - [] - ) # List of ancestor artifacts that depend on this artifact - for dependency in self.depends: - dependency.users.append(self) - self.fake_available = False - - def __repr__(self): - return f"{self.__class__.__name__}(path={self.path}, depends={self.depends})" - - def set_path(self, new_path): - old_abs_path = self.path.absolute() - new_path = Path(new_path) - abs_path = new_path.absolute() - self.path = new_path - del CompilationArtifact._instances[old_abs_path] - CompilationArtifact._instances[abs_path] = self - - def is_available(self): - if self.fake_available: - return True - if not self.path.exists(): - return False - for dependency in self.depends: - # If any of our dependencies' dependencies are outdated, this artifact is also outdated - if not dependency.is_available(): - return False - # If any of our direct dependencies are newer than this artifact, this artifact is invalid - if dependency.is_newer_than(os.path.getmtime(str(self.path))): - return False - return True - - def is_newer_than(self, time): - if self.fake_available: - return True - return os.path.getmtime(str(self.path)) > time - - def delete(self): - for user in self.users: - user.depends.remove(self) - del self._instances[self.path.absolute()] - return self.users - - -class SourceArtifact(CompilationArtifact): - pass - - -class XclbinArtifact(CompilationArtifact): - def __init__( - self, path, depends, kernel_name="MLIR_AIE", extra_flags=None, xclbin_input=None - ): - super().__init__(path, depends) - self.kernel_name = kernel_name - self.extra_flags = extra_flags if extra_flags is not None else [] - self.xclbin_input = xclbin_input - - -class InstsBinArtifact(CompilationArtifact): - def __init__(self, path, depends, extra_flags=None): - super().__init__(path, depends) - self.extra_flags = extra_flags if extra_flags is not None else [] - - -class KernelObjectArtifact(CompilationArtifact): - def __init__(self, path, depends, extra_flags=None, rename_symbols=None): - super().__init__(path, depends) - self.extra_flags = extra_flags if extra_flags is not None else [] - self.rename_symbols = rename_symbols if rename_symbols is not None else {} - - -class KernelArchiveArtifact(CompilationArtifact): - pass - - -class PythonGeneratedMLIRArtifact(CompilationArtifact): - def __init__( - self, - path, - import_path, - callback_fn, - callback_args=None, - callback_kwargs=None, - requires_context=False, - ): - self.import_path = import_path - self.callback_fn = callback_fn - self.callback_args = callback_args if callback_args is not None else [] - self.callback_kwargs = callback_kwargs if callback_kwargs is not None else {} - self.requires_context = requires_context - super().__init__(path) - - def is_available(self): - if self.fake_available: - return True - is_available = super().is_available() - if is_available: - # Force regeneration if the Python source is changed - return os.path.getmtime(str(self.path)) >= os.path.getmtime( - self.import_path - ) - return is_available - - -# Compilation Rules -# -------------------------------------------------------------------------- - - -class CompilationRule(ABC): - def __init__(self, dry_run=None): - self.dry_run = dry_run - - @abstractmethod - def matches(self, artifact: list[CompilationArtifact]) -> bool: - pass - - @abstractmethod - def compile( - self, artifacts: list[CompilationArtifact] - ) -> list[CompilationArtifact]: - pass - - -class GenerateMLIRFromPythonCompilationRule(CompilationRule): - def matches(self, artifacts): - return any( - isinstance(artifact, PythonGeneratedMLIRArtifact) - and len(artifact.depends) == 0 - for artifact in artifacts - ) - - def compile(self, artifacts): - """Generate MLIR from a Python callback that uses the MLIR bindings""" - for i, artifact in enumerate(artifacts): - if not isinstance(artifact, PythonGeneratedMLIRArtifact): - continue - if not all(dependency.is_available() for dependency in artifact.depends): - continue - - if self.dry_run is None: - # Import the Python source file - spec = importlib.util.spec_from_file_location( - Path(artifact.import_path).name, artifact.import_path - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - # We only initiate an MLIR context if requested; otherwise, it is expected that the callback creates the context - ctx_callback = lambda: ( - mlir_mod_ctx() if artifact.requires_context else nullcontext() - ) - with ctx_callback() as ctx: - callback_function = getattr(module, artifact.callback_fn) - mlir_code = callback_function( - *artifact.callback_args, **artifact.callback_kwargs - ) - # Stringify the generated MLIR - if artifact.requires_context: - mlir_code = str(ctx.module) - else: - mlir_code = str(mlir_code) - - with open(artifact.path, "w") as f: - f.write(mlir_code) - - # Now that the artifact is generated, replace this artifact with the MLIR source code file - old_users = artifact.delete() - new_artifact = SourceArtifact.new(artifact.path) - for user in old_users: - user.depends.append(new_artifact) - if self.dry_run is not None: - python_cmd = "" - # Import the Python source file - python_cmd += ( - "import sys; sys.path.append(" - f'"{Path(artifact.import_path).parent}"' - "); " - ) - python_cmd += f"from {Path(artifact.import_path).stem} import {artifact.callback_fn}; " - - # Check if we need to import device classes - # Device classes have __module__ == 'abc' but need to be imported from aie.iron.device - device_classes = set() - for arg in artifact.callback_args: - obj_module = type(arg).__module__ - obj_class = type(arg).__name__ - if obj_module == "abc" and ( - obj_class.startswith("NPU") or obj_class.startswith("XCVC") - ): - device_classes.add(obj_class) - for v in artifact.callback_kwargs.values(): - obj_module = type(v).__module__ - obj_class = type(v).__name__ - if obj_module == "abc" and ( - obj_class.startswith("NPU") or obj_class.startswith("XCVC") - ): - device_classes.add(obj_class) - - if device_classes: - python_cmd += f"from aie.iron.device import {', '.join(sorted(device_classes))}; " - - if artifact.requires_context: - python_cmd += "from aie.extras.context import mlir_mod_ctx; " - python_cmd += "with mlir_mod_ctx() as ctx: " - python_cmd += f"mlir_code = {artifact.callback_fn}({', '.join(map(GenerateMLIRFromPythonCompilationRule._repr_for_codegen, artifact.callback_args))}, {', '.join(f'{k}={_repr_for_codegen(v)}' for k, v in artifact.callback_kwargs.items())}); " - if artifact.requires_context: - python_cmd += "print(str(ctx.module))" - else: - python_cmd += "print(str(mlir_code))" - self.dry_run.append(f"python3 -c '{python_cmd}' > {artifact.path}") - new_artifact.fake_available = True - artifacts[i] = new_artifact - logging.debug(f"Created MLIR source string for {artifact.path.name}") - - return artifacts - - @staticmethod - def _repr_for_codegen(obj): - """Convert an object to its string representation for code generation. - - Handles special cases like device classes that need to be instantiated - rather than using their default repr(). - """ - # Check if this is a device class from aie.iron.device - # These classes have __module__ == 'abc' but are imported from aie.iron.device - obj_module = type(obj).__module__ - obj_class = type(obj).__name__ - - # Check for known device class patterns (NPU1, NPU2, XCVC1902, etc.) - # These are imported from aie.iron.device but have __module__ == 'abc' - if obj_module == "abc" and ( - obj_class.startswith("NPU") or obj_class.startswith("XCVC") - ): - # For device classes, generate instantiation code - return f"{obj_class}()" - - # Default to repr() for other types - return repr(obj) - - -class AieccCompilationRule(CompilationRule): - def __init__(self, build_dir, peano_dir, mlir_aie_dir, *args, **kwargs): - self.build_dir = build_dir - self.aiecc_path = Path(mlir_aie_dir) / "bin" / "aiecc.py" - self.peano_dir = peano_dir - super().__init__(*args, **kwargs) - - def matches(self, artifacts): - return any( - isinstance(artifact, (XclbinArtifact, InstsBinArtifact)) - and all(dependency.is_available() for dependency in artifact.depends) - for artifact in artifacts - ) - - def compile(self, artifacts): - # If there are both xclbin and insts.bin targets based on the same source MLIR code, we can combine them into one single `aiecc.py` invocation. - mlir_sources = set() - mlir_sources_to_xclbins = {} - mlir_sources_to_insts_bins = {} - for artifact in artifacts: - if not isinstance(artifact, (XclbinArtifact, InstsBinArtifact)): - continue - if not all(dependency.is_available() for dependency in artifact.depends): - continue - mlir_dependencies = [ - d - for d in artifact.depends - if isinstance(d, (SourceArtifact, PythonGeneratedMLIRArtifact)) - ] - if len(mlir_dependencies) != 1: - raise RuntimeError( - f"Expected exactly one dependency of {artifact.path} to be SourceArtifact or PythonGeneratedMLIRArtifact, got: {', '.join(str(dep.path) for dep in artifact.depends)}" - ) - mlir_dependency = mlir_dependencies[0] - mlir_sources.add(mlir_dependency) - if isinstance(artifact, XclbinArtifact): - mlir_sources_to_xclbins.setdefault(mlir_dependency, []).append(artifact) - elif isinstance(artifact, InstsBinArtifact): - mlir_sources_to_insts_bins.setdefault(mlir_dependency, []).append( - artifact - ) - - # Now we know for each mlir source if we need to generate an xclbin, an insts.bin or both for it - for mlir_source in mlir_sources: - # Build aiecc command using Peano - compile_cmd = [ - "python", - str(self.aiecc_path), - "--no-compile-host", - "--no-xchesscc", - "--no-xbridge", - "--peano", - str(self.peano_dir), - "--dynamic-objFifos", - ] - do_compile_xclbin = mlir_source in mlir_sources_to_xclbins - do_compile_insts_bin = mlir_source in mlir_sources_to_insts_bins - if do_compile_xclbin: - first_xclbin = mlir_sources_to_xclbins[mlir_source][ - 0 - ] # FIXME: this does not handle the case of multiple xclbins with different kernel names or flags from the same MLIR - compile_cmd += first_xclbin.extra_flags + [ - "--aie-generate-xclbin", - "--xclbin-name=" + str(first_xclbin.path), - "--xclbin-kernel-name=" + first_xclbin.kernel_name, - ] - if first_xclbin.xclbin_input is not None: - compile_cmd += [ - "--xclbin-input=" + str(first_xclbin.xclbin_input.path) - ] - if do_compile_insts_bin: - first_insts_bin = mlir_sources_to_insts_bins[mlir_source][ - 0 - ] # FIXME: this does not handle the case of multiple insts.bins with different flags from the same MLIR - if not do_compile_xclbin: - compile_cmd += ["--no-compile"] - compile_cmd += first_insts_bin.extra_flags + [ - "--aie-generate-npu", - "--npu-insts-name=" + str(first_insts_bin.path), - ] - compile_cmd += [str(mlir_source.path)] - - env = os.environ.copy() - logging.debug(f"Compiling MLIR with command: {' '.join(compile_cmd)}") - if not self.dry_run: - result = subprocess.run( - compile_cmd, - cwd=str(self.build_dir), - capture_output=True, - text=True, - timeout=300, - env=env, - ) - if result.returncode == 0: - logging.debug( - f"Successfully compiled {mlir_source.path} to {', '.join([str(first_xclbin.path)] if do_compile_xclbin else [] + [str(first_insts_bin.path)] if do_compile_insts_bin else [])}" - ) - else: - raise RuntimeError( - f"MLIR compilation for {mlir_source.path} failed: {result.stderr}" - ) - - # There may be multiple targets that require an xclbin/insts.bin from the same MLIR with different names; copy them - for sources_to in [mlir_sources_to_xclbins, mlir_sources_to_insts_bins]: - if sources_to.get(mlir_source, [])[1:]: - copy_src = sources_to[mlir_source][0] - for copy_dest in sources_to[mlir_source][1:]: - shutil.copy(copy_src.path, copy_dest.path) - - else: - for sources_to in [mlir_sources_to_xclbins, mlir_sources_to_insts_bins]: - for artifact in sources_to.get(mlir_source, []): - self.dry_run.append( - f"pushd {str(self.build_dir)} && {' '.join(compile_cmd)} && popd" - ) - artifact.fake_available = True - - # With the newly generated files, is_available() should now return True on the Xclbin and InstsBin targets - return artifacts - - -class PeanoCompilationRule(CompilationRule): - def __init__(self, peano_dir, mlir_aie_dir, *args, **kwargs): - self.peano_dir = peano_dir - self.mlir_aie_dir = mlir_aie_dir - super().__init__(*args, **kwargs) - - def matches(self, artifacts): - return any( - isinstance(artifact, KernelObjectArtifact) - and all( - isinstance(dependency, SourceArtifact) and dependency.is_available() - for dependency in artifact.depends - ) - for artifact in artifacts - ) - - def compile(self, artifacts): - clang_path = Path(self.peano_dir) / "bin" / "clang++" - include_path = Path(self.mlir_aie_dir) / "include" - - for artifact in artifacts: - if not isinstance(artifact, KernelObjectArtifact): - continue - - if len(artifact.depends) != 1: - raise RuntimeError( - "Expected exactly one dependency (the C source code) for KernelObjectArtifact" - ) - source_file = artifact.depends[0] - if not isinstance(source_file, SourceArtifact): - raise RuntimeError( - "Expected KernelObject dependency to be a C source file" - ) - - cmd = ( - [ - str(clang_path), - "-O2", - "-std=c++20", - "--target=aie2p-none-unknown-elf", - "-Wno-parentheses", - "-Wno-attributes", - "-Wno-macro-redefined", - "-Wno-empty-body", - "-Wno-missing-template-arg-list-after-template-kw", - f"-I{str(include_path)}", - ] - + artifact.extra_flags - + ["-c", str(source_file.path), "-o", str(artifact.path)] - ) - logging.debug(f"Running compilation command: {' '.join(cmd)}") - - if self.dry_run is None: - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - raise RuntimeError(f"Compilation failed: {result.stderr}") - logging.debug(f"Successfully compiled: {artifact.path.name}") - else: - artifact.fake_available = True - self.dry_run.append(" ".join(cmd)) - - if artifact.rename_symbols: - self._rename_symbols(artifact) - - return artifacts - - def _rename_symbols(self, artifact): - objcopy_path = "llvm-objcopy-18" - cmd = [ - objcopy_path, - ] - for old_sym, new_sym in artifact.rename_symbols.items(): - cmd += [ - "--redefine-sym", - f"{old_sym}={new_sym}", - ] - cmd += [str(artifact.path)] - - logging.debug(f"Running renaming command: {' '.join(cmd)}") - if self.dry_run is None: - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode == 0: - logging.info(f"Successfully renamed symbols in: {artifact.path.name}") - else: - raise RuntimeError(f"Symbol renaming failed: {result.stderr}") - else: - artifact.fake_available = True - self.dry_run.append(" ".join(cmd)) - - -class ArchiveCompilationRule(CompilationRule): - def __init__(self, peano_dir, *args, **kwargs): - self.peano_dir = peano_dir - super().__init__(*args, **kwargs) - - def matches(self, artifacts): - return any( - isinstance(artifact, KernelArchiveArtifact) and len(artifact.depends) > 0 - for artifact in artifacts - ) - - def compile(self, artifacts): - """Create an archive (.a) from compiled object files""" - for artifact in artifacts: - if not isinstance(artifact, KernelArchiveArtifact): - continue - - # Get archive filename from method - archive_path = str(artifact.path) - object_files = [ - str(dep.path) - for dep in artifact.depends - if isinstance(dep, KernelObjectArtifact) - ] - - # Try to find ar tool from PEANO, then system - ar_path = None - - if self.peano_dir: - # Peano has llvm-ar for archiving - peano_ar = Path(self.peano_dir) / "bin" / "llvm-ar" - if os.path.exists(peano_ar): - ar_path = peano_ar - - if ar_path is None: - raise RuntimeError( - "Could not find 'ar' tool in PEANO installation or system PATH" - ) - - cmd = [str(ar_path), "rcs", archive_path] + object_files - - if self.dry_run is None: - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode == 0: - logging.debug( - f"Successfully created archive: {Path(archive_path).name}" - ) - else: - raise RuntimeError(f"Archive creation failed: {result.stderr}") - else: - artifact.fake_available = True - self.dry_run.append(" ".join(cmd)) - - return artifacts - - -# Global Functions -# -------------------------------------------------------------------------- - - -def apply_rules(rules, artifacts): - for rule in rules: - if rule.matches(artifacts): - logging.debug(f"Applying rule: {rule.__class__.__name__}") - artifacts = rule.compile(artifacts) - break - else: - # None of the rules matched - return False, artifacts - - return True, artifacts - - -def compile(rules, artifacts): - # While some artifacts remain to be compiled (not all are available) - while not all(artifact.is_available() for artifact in artifacts): - remaining = [artifact for artifact in artifacts if not artifact.is_available()] - success, artifacts = apply_rules(rules, remaining) - if not success: - raise RuntimeError( - f"No matching rule to compile target(s): {', '.join(str(artifact.path.name) for artifact in artifacts if not artifact.is_available())}" - ) - return artifacts - - -def get_work_list(artifacts): - """ - Return a flattened artifact creation worklist in reverse topological order from dependencies. - The returned list will start with leaf nodes (artifacts with no dependencies), and any following artifacts will only contain artifacts from earlier in the list. - """ - work_list = [] - todo = list(artifacts) - visited = set() - - def dfs_visit(artifact): - if artifact in visited: - # Thanks to uniquing of artifact objects, this avoids duplicate creation of the same artifacts - return - visited.add(artifact) - # First visit all dependencies, so put leaves first (post-order) ... - for dep in artifact.depends: - dfs_visit(dep) - # ... then put parent - if not artifact.is_available(): - work_list.append(artifact) - - for artifact in todo: - dfs_visit(artifact) - - return work_list diff --git a/iron/common/compilation/__init__.py b/iron/common/compilation/__init__.py new file mode 100644 index 00000000..405df6b0 --- /dev/null +++ b/iron/common/compilation/__init__.py @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from .base import * +from .fusion import * diff --git a/iron/common/compilation/base.py b/iron/common/compilation/base.py new file mode 100644 index 00000000..fb4b2c4d --- /dev/null +++ b/iron/common/compilation/base.py @@ -0,0 +1,682 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +This file implements a simple Python-based build system. You specify what you +want to compile (*artifacts*) through subclasses of `CompilationArtifact`. +Multiple `CompilationArtifacts` form a `CompilationArtifactGraph`. Each artifact +can have a list (subgraph) of depenencies of other artifacts that it relies on. +Each artifact corresponds to exactly one file. + +There is a special artifact for source files that do not need to get generated, +`SourceArtifact`. It is likely that in your compilation dependency graph, +the leaf nodes will be `SourceArtifact`s. + +You specify how to generate (compile) an artifact through *rules*, which are +expressed as subclasses of `CompilationRule`. Rules must implement two methods: +`matches` and `compile`. If a rule `matches` to an artifact graph, it can be +applied. Applying a rule is done by calling `compile`; this transforms the +artifact graph (in the simplest case, marks one of the artifacts as available) +and returns a list of compilation commands. + +At this point, we can print the compilation commands to the console (dry-run) +or actually run them to generate the artifacts. + +Before starting compilation, you may call +`populate_availability_from_filesystem()` -- this will check if any artifacts +are already available at the given file paths (and ensure that dependencies are +as old or older than the artifacts that depend on them). This way, you can avoid +recompiling artifacts that are already up-to-date on disk. If you wish to +regenerate everything, you can skip this step, but will at a minimum want to +mark the `SourceArtifact`s as available -- they cannot be generated. +""" + +from abc import ABC, abstractmethod +from pathlib import Path +import os.path +import zlib +import logging +import subprocess +import importlib.util +from contextlib import nullcontext +from aie.extras.context import mlir_mod_ctx +import sys + +# Global Functions +# ########################################################################## + + +def plan(rules, graph): + if all(artifact.is_available() for artifact in graph): + return [] # Everything has been compiled + for rule in rules: + if rule.matches(graph): + commands = rule.compile(graph) + break + else: + raise RuntimeError( + f"No matching rule to compile target(s): {', '.join(artifact.filename for artifact in graph)}" + ) + return [(rule, commands)] + plan(rules, graph) + + +def execute(plan_steps): + for rule, commands in plan_steps: + logging.debug(f"Applying rule: {rule.__class__.__name__}") + for command in commands: + logging.debug(f" Executing command: {command}") + success = command.run() + if not success: + raise RuntimeError(f"Command failed: {command}") + + +def compile(rules, artifacts, build_dir="build", dry_run=False): + if not os.path.exists(build_dir) and not dry_run: + os.makedirs(build_dir) + artifacts.move_artifacts(build_dir) + artifacts.populate_availability_from_filesystem() + plan_steps = plan(rules, artifacts) + if not dry_run: + execute(plan_steps) + else: + print("\n".join("\n".join(map(str, cmds)) for _, cmds in plan_steps)) + + +# Compilation Artifact Graph +# ########################################################################## + + +class CompilationArtifactGraph: + def __init__(self, artifacts=None): + self.artifacts = artifacts if artifacts is not None else [] + + def __repr__(self): + def format_artifact(artifact, indent=0): + prefix = " " * indent + avail = "[x] " if artifact.is_available() else "[ ] " + result = f"{prefix}{avail}{artifact.__class__.__name__}({Path(artifact.filename).name})\n" + for dep in artifact.dependencies: + result += format_artifact(dep, indent + 1) + return result + + result = "CompilationArtifactGraph(\n" + for artifact in self.artifacts: + result += format_artifact(artifact, indent=1) + result += ")" + return result + + def __iter__(self): + return iter(self.artifacts) + + def __len__(self): + return len(self.artifacts) + + def __getitem__(self, index): + return self.artifacts[index] + + def dfs(self): + return self._traverse(True) + + def bfs(self): + return self._traverse(False) + + def _traverse(self, dfs): + visited = set() + todo = self.artifacts.copy() + while todo: + artifact = todo.pop() if dfs else todo.pop(0) + if artifact in visited: + continue + visited.add(artifact) + todo.extend(artifact.dependencies) + yield artifact + + def replace(self, old_artifact, new_artifact): + for i, artifact in enumerate(self.artifacts): + if artifact == old_artifact: + self.artifacts[i] = new_artifact + else: + artifact.dependencies.replace(old_artifact, new_artifact) + return self + + def populate_availability_from_filesystem(self): + for artifact in self.artifacts: + artifact.dependencies.populate_availability_from_filesystem() + artifact.available = artifact.is_available_in_filesystem() + + def get_worklist(self, kind): + """Return a list of artifacts of the given kind that can be built in the next step (dependencies available).""" + return [ + artifact + for artifact in self.bfs() + if isinstance(artifact, kind) + and not artifact.is_available() + and artifact.dependencies_available() + ] + + def move_artifacts(self, new_root): + """Make all artifacts paths point into a build directory""" + for artifact in self.bfs(): + if not os.path.isabs(artifact.filename): + artifact.filename = str(Path(new_root) / Path(artifact.filename).name) + + def add(self, artifact): + self.artifacts.append(artifact) + + +# Compilation Artifacts +# ########################################################################## + + +class CompilationArtifact(ABC): + def __init__(self, filename, dependencies=None, available=False): + self.filename = str(filename) + self.dependencies: CompilationArtifactGraph = CompilationArtifactGraph( + artifacts=dependencies if dependencies is not None else [] + ) + self.available = available + + def __repr__(self): + return f"{self.__class__.__name__}({self.filename})" + + def is_available(self): + """'Conceptual' availability: during a dry-run or in the planning stage, available may be True even if the underlying file does not exist yet.""" + # If any of our dependencies' dependencies are outdated, this artifact is also outdated + return self.available and self.dependencies_available() + + def dependencies_available(self): + return all(d.is_available() for d in self.dependencies) + + def is_available_in_filesystem(self): + """'Real' availability: checks if the underlying file exists and is up-to-date with respect to dependencies.""" + if not os.path.exists(self.filename): + return False + file_mtime = os.path.getmtime(self.filename) + for dependency in self.dependencies: + if ( + not dependency.is_available_in_filesystem() + or os.path.getmtime(dependency.filename) > file_mtime + ): + return False + return True + + +class SourceArtifact(CompilationArtifact): + """Artifact representing a source file that does not need to be generated, is assumed to be there.""" + + pass + + +class FullElfArtifact(CompilationArtifact): + def __init__(self, filename, mlir_input, dependencies): + if mlir_input not in dependencies: + dependencies = dependencies + [mlir_input] + super().__init__(filename, dependencies) + self.mlir_input = mlir_input + + +class XclbinArtifact(CompilationArtifact): + def __init__( + self, + filename, + mlir_input, + dependencies, + kernel_name="MLIR_AIE", + extra_flags=None, + xclbin_input=None, + ): + if mlir_input not in dependencies: + dependencies = dependencies + [mlir_input] + super().__init__(filename, dependencies) + self.mlir_input = mlir_input + self.kernel_name = kernel_name + self.extra_flags = extra_flags if extra_flags is not None else [] + self.xclbin_input = xclbin_input + + +class InstsBinArtifact(CompilationArtifact): + def __init__(self, filename, mlir_input, dependencies, extra_flags=None): + self.mlir_input = mlir_input + if mlir_input not in dependencies: + dependencies = dependencies + [mlir_input] + super().__init__(filename, dependencies) + self.extra_flags = extra_flags if extra_flags is not None else [] + + +class KernelObjectArtifact(CompilationArtifact): + def __init__( + self, + filename, + dependencies, + extra_flags=None, + rename_symbols=None, + prefix_symbols=None, + ): + super().__init__(filename, dependencies) + self.extra_flags = extra_flags if extra_flags is not None else [] + self.rename_symbols = rename_symbols if rename_symbols is not None else {} + self.prefix_symbols = prefix_symbols + + +class KernelArchiveArtifact(CompilationArtifact): + pass + + +class PythonGeneratedMLIRArtifact(CompilationArtifact): + def __init__( + self, + filename, + import_path, + callback_fn, + callback_args=None, + callback_kwargs=None, + requires_context=False, + uses_kernel_archive=False, + kernel_archive=None, + ): + self.import_path = import_path + self.callback_fn = callback_fn + self.callback_args = callback_args if callback_args is not None else [] + self.callback_kwargs = callback_kwargs if callback_kwargs is not None else {} + self.requires_context = requires_context + dependencies = [SourceArtifact(import_path)] + super().__init__(filename, dependencies=dependencies) + + +# Compilation Command +# ########################################################################## + + +class CompilationCommand(ABC): + """An abstraction for anything that can be executed to physically produce artifacts.""" + + @abstractmethod + def run(self) -> bool: + pass + + @abstractmethod + def __repr__(self): + pass + + +class ShellCompilationCommand(CompilationCommand): + def __init__(self, command: list[str], cwd=None, env="copy"): + self.command = command + self.cwd = cwd + if env == "copy": + env = os.environ.copy() + self.env = env + + def run(self) -> bool: + result = subprocess.run( + self.command, + capture_output=True, + text=True, + cwd=self.cwd, + env=self.env, + ) + if 0 != result.returncode: + print(result.stdout) + print(result.stderr, file=sys.stderr) + return 0 == result.returncode + + def __repr__(self): + return f"Shell({' '.join(self.command)})" + + +class PythonCallbackCompilationCommand(CompilationCommand): + def __init__(self, callback): + self.callback = callback + + def run(self) -> bool: + result = self.callback() + return bool(result) if result is not None else True + + def __repr__(self): + return f"PythonCallback({self.callback})" + + +# Compilation Rules +# ########################################################################## + + +class CompilationRule(ABC): + """A compilation rule is applied to a artifact graph, producing compilation commands and a transformed artifact graph.""" + + @abstractmethod + def matches(self, artifact: CompilationArtifactGraph) -> bool: + """Return true if this rule can be applied to any artifact in the artifact graph.""" + pass + + @abstractmethod + def compile(self, artifacts: CompilationArtifactGraph) -> list[CompilationCommand]: + """Apply this rule to the artifact graph, returning compilation commands. This should modify the artifact graph in-place to reflect the newly generated artifacts.""" + pass + + +class GenerateMLIRFromPythonCompilationRule(CompilationRule): + def matches(self, graph): + return any(graph.get_worklist(PythonGeneratedMLIRArtifact)) + + def compile(self, graph): + """Generate MLIR from a Python callback that uses the MLIR bindings""" + commands = [] + worklist = graph.get_worklist(PythonGeneratedMLIRArtifact) + for artifact in worklist: + new_artifact = SourceArtifact(artifact.filename) + # To make Python capture variables in this closure by value, not by reference, use default arguments + callback = lambda new_artifact=new_artifact, import_path=artifact.import_path, callback_fn=artifact.callback_fn, callback_args=artifact.callback_args, callback_kwargs=artifact.callback_kwargs, requires_context=artifact.requires_context: self.generate_mlir( + new_artifact, + import_path, + callback_fn, + callback_args, + callback_kwargs, + requires_context, + ) + commands.append(PythonCallbackCompilationCommand(callback)) + new_artifact.available = True + graph.replace(artifact, new_artifact) + return commands + + @staticmethod + def generate_mlir( + output_artifact, + import_path, + callback_fn, + callback_args=None, + callback_kwargs=None, + requires_context=False, + ): + # Import the Python source file + spec = importlib.util.spec_from_file_location( + Path(import_path).name, import_path + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # We only initiate an MLIR context if requested; otherwise, it is expected that the callback creates the context + ctx_callback = lambda: (mlir_mod_ctx() if requires_context else nullcontext()) + with ctx_callback() as ctx: + callback_function = getattr(module, callback_fn) + mlir_code = callback_function(*callback_args, **callback_kwargs) + # Stringify the generated MLIR + if requires_context: + mlir_code = str(ctx.module) + else: + mlir_code = str(mlir_code) + + with open(output_artifact.filename, "w") as f: + f.write(mlir_code) + + +class AieccCompilationRule(CompilationRule, ABC): + def __init__(self, build_dir, peano_dir, mlir_aie_dir, *args, **kwargs): + self.build_dir = build_dir + self.aiecc_path = Path(mlir_aie_dir) / "bin" / "aiecc.py" + self.peano_dir = peano_dir + super().__init__(*args, **kwargs) + + +class AieccFullElfCompilationRule(AieccCompilationRule): + def matches(self, graph): + return any(graph.get_worklist(FullElfArtifact)) + + def compile(self, graph): + worklist = graph.get_worklist(FullElfArtifact) + commands = [] + + for artifact in worklist: + compile_cmd = [ + "python", + str(self.aiecc_path), + "--no-compile-host", + "--no-xchesscc", + "--no-xbridge", + "--peano", + str(self.peano_dir), + "--dynamic-objFifos", + "--expand-load-pdis", + "--generate-full-elf", + "--full-elf-name", + os.path.abspath(artifact.filename), + os.path.abspath(artifact.mlir_input.filename), + ] + commands.append( + ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir)) + ) + artifact.available = True + + return commands + + +class AieccXclbinInstsCompilationRule(AieccCompilationRule): + def matches(self, graph): + return any(graph.get_worklist((XclbinArtifact, InstsBinArtifact))) + + def compile(self, graph): + # If there are both xclbin and insts.bin targets based on the same source MLIR code, we can combine them into one single `aiecc.py` invocation. + mlir_sources = set() + mlir_sources_to_xclbins = {} + mlir_sources_to_insts = {} + worklist = graph.get_worklist((XclbinArtifact, InstsBinArtifact)) + for artifact in worklist: + mlir_dependency = artifact.mlir_input + mlir_sources.add(mlir_dependency) + if isinstance(artifact, XclbinArtifact): + mlir_sources_to_xclbins.setdefault(mlir_dependency, []).append(artifact) + elif isinstance(artifact, InstsBinArtifact): + mlir_sources_to_insts.setdefault(mlir_dependency, []).append(artifact) + + commands = [] + # Now we know for each mlir source if we need to generate an xclbin, an insts.bin or both for it + for mlir_source in mlir_sources: + compile_cmd = [ + "python", + str(self.aiecc_path), + "--no-compile-host", + "--no-xchesscc", + "--no-xbridge", + "--peano", + str(self.peano_dir), + "--dynamic-objFifos", + ] + do_compile_xclbin = mlir_source in mlir_sources_to_xclbins + do_compile_insts_bin = mlir_source in mlir_sources_to_insts + if do_compile_xclbin: + first_xclbin = mlir_sources_to_xclbins[mlir_source][ + 0 + ] # TODO: this does not handle the case of multiple xclbins with different kernel names or flags from the same MLIR + compile_cmd += first_xclbin.extra_flags + [ + "--aie-generate-xclbin", + "--xclbin-name=" + os.path.abspath(first_xclbin.filename), + "--xclbin-kernel-name=" + first_xclbin.kernel_name, + ] + if first_xclbin.xclbin_input is not None: + compile_cmd += [ + "--xclbin-input=" + + os.path.abspath(first_xclbin.xclbin_input.filename) + ] + if do_compile_insts_bin: + first_insts_bin = mlir_sources_to_insts[mlir_source][ + 0 + ] # TODO: this does not handle the case of multiple insts.bins with different flags from the same MLIR + if not do_compile_xclbin: + compile_cmd += ["--no-compile"] + compile_cmd += first_insts_bin.extra_flags + [ + "--aie-generate-npu", + "--npu-insts-name=" + os.path.abspath(first_insts_bin.filename), + ] + compile_cmd += [os.path.abspath(mlir_source.filename)] + + # If the MLIR source depends on a kernel archive, pass it to aiecc.py so it can be linked + if ( + isinstance(mlir_source, PythonGeneratedMLIRArtifact) + and "kernel_archive" in mlir_source.callback_kwargs + ): + compile_cmd.append( + os.path.abspath( + os.path.join( + self.build_dir, + mlir_source.callback_kwargs["kernel_archive"], + ) + ) + ) + + commands.append( + ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir)) + ) + + # There may be multiple targets that require an xclbin/insts.bin from the same MLIR with different names; copy them + for sources_to in [mlir_sources_to_xclbins, mlir_sources_to_insts]: + if sources_to.get(mlir_source, [])[1:]: + copy_src = sources_to[mlir_source][0] + for copy_dest in sources_to[mlir_source][1:]: + commands.append( + ShellCompilationCommand( + ["cp", copy_src.filename, copy_dest.filename] + ) + ) + + # Update graph + for artifact in worklist: + artifact.available = True + + return commands + + +class PeanoCompilationRule(CompilationRule): + def __init__(self, peano_dir, mlir_aie_dir, *args, **kwargs): + self.peano_dir = peano_dir + self.mlir_aie_dir = mlir_aie_dir + super().__init__(*args, **kwargs) + + def matches(self, artifacts): + return any(artifacts.get_worklist(KernelObjectArtifact)) + + def compile(self, artifacts): + clang_path = Path(self.peano_dir) / "bin" / "clang++" + include_path = Path(self.mlir_aie_dir) / "include" + worklist = artifacts.get_worklist(KernelObjectArtifact) + commands = [] + for artifact in worklist: + if len(artifact.dependencies) != 1: + raise RuntimeError( + "Expected exactly one dependency (the C source code) for KernelObjectArtifact" + ) + source_file = artifact.dependencies[0] + if not isinstance(source_file, SourceArtifact): + raise RuntimeError( + "Expected KernelObject dependency to be a C source file" + ) + + cmd = ( + [ + str(clang_path), + "-O2", + "-std=c++20", + "--target=aie2p-none-unknown-elf", + "-Wno-parentheses", + "-Wno-attributes", + "-Wno-macro-redefined", + "-Wno-empty-body", + "-Wno-missing-template-arg-list-after-template-kw", + f"-I{str(include_path)}", + ] + + artifact.extra_flags + + ["-c", source_file.filename, "-o", artifact.filename] + ) + + commands.append(ShellCompilationCommand(cmd)) + if artifact.rename_symbols: + commands.extend(self._rename_symbols(artifact)) + if artifact.prefix_symbols: + commands.extend(self._prefix_symbols(artifact, artifact.prefix_symbols)) + artifact.available = True + + return commands + + def _rename_symbols(self, artifact): + objcopy_path = "llvm-objcopy-18" + cmd = [ + objcopy_path, + ] + for old_sym, new_sym in artifact.rename_symbols.items(): + cmd += [ + "--redefine-sym", + f"{old_sym}={new_sym}", + ] + cmd += [artifact.filename] + return [ShellCompilationCommand(cmd)] + + def _prefix_symbols(self, artifact, prefix): + objcopy_path = "llvm-objcopy-18" + nm_path = "llvm-nm-18" + symbol_map_file = artifact.filename + ".symbol_map" + + # Extract defined symbols and create symbol map + nm_cmd = [ + "sh", + "-c", + f"{nm_path} --defined-only --extern-only {artifact.filename} | " + f"awk '{{print $3 \" {prefix}\" $3}}' > {symbol_map_file}", + ] + + # Apply the renaming using the symbol map + objcopy_cmd = [ + objcopy_path, + "--redefine-syms=" + symbol_map_file, + artifact.filename, + ] + + return [ShellCompilationCommand(nm_cmd), ShellCompilationCommand(objcopy_cmd)] + + +class ArchiveCompilationRule(CompilationRule): + def __init__(self, peano_dir, *args, **kwargs): + self.peano_dir = peano_dir + super().__init__(*args, **kwargs) + + def matches(self, artifacts): + return any(artifacts.get_worklist(KernelArchiveArtifact)) + + def compile(self, artifacts): + """Create an archive (.a) from compiled object files""" + worklist = artifacts.get_worklist(KernelArchiveArtifact) + commands = [] + for artifact in worklist: + # Get archive filename from method + archive_path = artifact.filename + object_files = [ + dep.filename + for dep in artifact.dependencies + if isinstance(dep, KernelObjectArtifact) + ] + + # Try to find ar tool from PEANO, then system + ar_path = None + + if self.peano_dir: + # Peano has llvm-ar for archiving + peano_ar = Path(self.peano_dir) / "bin" / "llvm-ar" + if os.path.exists(peano_ar): + ar_path = peano_ar + + if ar_path is None: + raise RuntimeError( + "Could not find 'ar' tool in PEANO installation or system PATH" + ) + + cmd = [str(ar_path), "rcs", archive_path] + object_files + commands.append(ShellCompilationCommand(cmd)) + + # Check for duplicate symbol definitions in the archive + check_cmd = [ + "sh", + "-c", + f"nm {archive_path} | grep ' [TDR] ' | awk '{{print $3}}' | sort | uniq -d | " + f'if read sym; then echo "Error: Duplicate symbol in archive: $sym" >&2; exit 1; fi', + ] + commands.append(ShellCompilationCommand(check_cmd)) + + artifact.available = True + + return commands diff --git a/iron/common/compilation/fusion.py b/iron/common/compilation/fusion.py new file mode 100644 index 00000000..ea1d47e2 --- /dev/null +++ b/iron/common/compilation/fusion.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Temporal fusion of multiple MLIR modules into one module with multiple devices and a main runtime sequence that calls into them. +""" + +import numpy as np +import importlib.util +from pathlib import Path +from aie import ir +from aie.dialects import aie, aiex, memref +from aie.extras.context import mlir_mod_ctx +import ml_dtypes + +from . import ( + CompilationArtifact, + CompilationRule, + CompilationCommand, + PythonCallbackCompilationCommand, + SourceArtifact, + PythonGeneratedMLIRArtifact, +) + +# Compilation Artifacts +# ########################################################################## + + +class FusedMLIRSource(CompilationArtifact): + def __init__( + self, + filename, + operator_mlir_map, + runlist, + subbuffer_layout, + buffer_sizes, + slice_info=None, + ): + dependencies = list(operator_mlir_map.values()) + super().__init__(filename, dependencies) + self.operator_mlir_map = operator_mlir_map + self.runlist = runlist + self.subbuffer_layout = subbuffer_layout + self.buffer_sizes = buffer_sizes + self.slice_info = slice_info or {} + + +# Helper Functions +# ########################################################################## + + +def extract_runtime_sequence_arg_types(dev_op): + """MLIR helper: Extract argument types from a device operation's runtime sequence.""" + for nested_op in dev_op.body_region.blocks[0].operations: + op_name = nested_op.operation.name + if op_name == "aie.runtime_sequence": + if hasattr(nested_op, "body") and hasattr(nested_op.body, "blocks"): + if len(nested_op.body.blocks) > 0: + entry_block = nested_op.body.blocks[0] + arg_types = [ + entry_block.arguments[i].type + for i in range(len(entry_block.arguments)) + ] + return arg_types + raise RuntimeError("Could not find runtime sequence in device operation") + + +def get_child_mlir_module(mlir_artifact): + """Extract MLIR module from a PythonGeneratedMLIRArtifact.""" + assert isinstance(mlir_artifact, PythonGeneratedMLIRArtifact) + spec = importlib.util.spec_from_file_location( + Path(mlir_artifact.import_path).name, mlir_artifact.import_path + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + if mlir_artifact.requires_context: + raise NotImplementedError("Not handled, make your operator return a ctx.module") + + callback_function = getattr(module, mlir_artifact.callback_fn) + mlir_module = callback_function( + *mlir_artifact.callback_args, **mlir_artifact.callback_kwargs + ) + return mlir_module + + +def fuse_mlir(artifact): + """Fuse multiple MLIR modules by inlining their device operations and adding a new main device and runtime sequence that call into sequence of operations based on a runlist.""" + + input_buffer_size, output_buffer_size, scratch_buffer_size = artifact.buffer_sizes + + # Extract device operations from each operator's MLIR artifact + device_mlir_strings = {} + device_ty = None + sequence_arg_types = {} + for op_name, mlir_artifact in artifact.operator_mlir_map.items(): + mlir_module = get_child_mlir_module(mlir_artifact) + device_ops = [ + op for op in mlir_module.body.operations if isinstance(op, aie.DeviceOp) + ] + assert ( + len(device_ops) == 1 + ), f"Expected exactly one device operation in MLIR artifact for operator '{op_name}'" + device_op = device_ops[0] + if device_ty is None: + device_ty = device_op.device + device_mlir_strings[op_name] = str(device_op) + sequence_arg_types[op_name] = extract_runtime_sequence_arg_types(device_op) + + # Build fused MLIR module + with mlir_mod_ctx() as ctx: + + # Concatenate aie.device ops + for op_name, device_str in device_mlir_strings.items(): + dev_op = aie.DeviceOp.parse(device_str) + dev_op.sym_name = ir.StringAttr.get(op_name) + ctx.module.body.append(dev_op) + + # Create the main device -- this contains the runtime sequence calling into the other devices + @aie.device(device_ty) + def main(): + buf_dtype = np.dtype[ + ml_dtypes.bfloat16 + ] # TODO: support for other data types + itemsize = 2 + + # RuntimeSequenceOp + @aiex.runtime_sequence( + np.ndarray[(input_buffer_size // itemsize,), buf_dtype], + np.ndarray[(output_buffer_size // itemsize,), buf_dtype], + np.ndarray[(scratch_buffer_size // itemsize,), buf_dtype], + ) + def sequence(input_buf, output_buf, scratch_buf): + consolidated_buffers = { + "input": input_buf, + "output": output_buf, + "scratch": scratch_buf, + } + + # Execute operations in runlist order + configure_op = None + last_op_name = None + for op_name, *buffer_names in artifact.runlist: + expected_arg_types = sequence_arg_types[op_name] + + # Avoid reconfiguring altogether if the same op is called multiple times consecutively + if configure_op is None or op_name != last_op_name: + # Configure Op + configure_sym_ref_attr = ir.FlatSymbolRefAttr.get(op_name) + configure_op = aiex.ConfigureOp( + configure_sym_ref_attr + ) # TODO: optimization -- if previous op was in the same device, skip reconfiguration + configure_body = configure_op.body.blocks.append() + last_op_name = op_name + + with ir.InsertionPoint(configure_body): + + # For each buffer, add subview and reinterpret_cast ops + buffer_ssa_values = [] + for idx, buf_name in enumerate(buffer_names): + # Check if this is a sliced buffer + if buf_name in artifact.slice_info: + base_name, start, end = artifact.slice_info[buf_name] + # Get parent buffer info + buf_type, parent_offset, parent_length = ( + artifact.subbuffer_layout[base_name] + ) + # Calculate actual offset and length for slice + offset = parent_offset + start + length = end - start + else: + # Regular buffer + buf_type, offset, length = artifact.subbuffer_layout[ + buf_name + ] + + # Subview Op + consolidated_buf = consolidated_buffers[buf_type] + offset_elements = offset // itemsize + size_elements = length // itemsize + subview = memref.subview( + consolidated_buf, + [offset_elements], + [size_elements], + [1], + ) + + # Reinterpret_cast Op + target_type = expected_arg_types[idx] + expected_memref = ir.MemRefType(target_type) + target_shape = [ + expected_memref.shape[i] + for i in range(expected_memref.rank) + ] + expected_size = np.prod(target_shape) + assert ( + expected_size == size_elements + ), f"Size mismatch for buffer '{buf_name}': MLIR runtime sequence expected {expected_size}, Python fused operator provided {size_elements}" + strides = [] + stride = 1 + for dim in reversed(target_shape): + strides.insert(0, stride) + stride *= dim + result_type = ir.MemRefType.get( + target_shape, ir.BF16Type.get() + ) + reinterpreted = memref.reinterpret_cast( + result=result_type, + source=subview, + offsets=[], + sizes=[], + strides=[], + static_offsets=[0], + static_sizes=target_shape, + static_strides=strides, + ) + buffer_ssa_values.append(reinterpreted) + + # Run Op + sequence_sym_ref_attr = ir.FlatSymbolRefAttr.get("sequence") + run_op = aiex.RunOp(sequence_sym_ref_attr, buffer_ssa_values) + + # Write the fused MLIR to file + with open(artifact.filename, "w") as f: + f.write(str(ctx.module)) + + +# Compilation Rules +# ########################################################################## + + +class FusePythonGeneratedMLIRCompilationRule(CompilationRule): + """Compilation rule that fuses multiple MLIR modules into one.""" + + def matches(self, graph): + return any(graph.get_worklist(FusedMLIRSource)) + + def compile(self, graph): + commands = [] + worklist = graph.get_worklist(FusedMLIRSource) + for artifact in worklist: + callback = lambda artifact=artifact: fuse_mlir(artifact) + commands.append(PythonCallbackCompilationCommand(callback)) + new_artifact = SourceArtifact(artifact.filename) + new_artifact.available = True + graph.replace(artifact, new_artifact) + return commands diff --git a/iron/common/context.py b/iron/common/context.py new file mode 100644 index 00000000..1cde7087 --- /dev/null +++ b/iron/common/context.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import logging +from pathlib import Path +import os + +from .device_manager import AIEDeviceManager, pyxrt +from . import compilation as comp +import aie.utils.config + + +class AIEContext: + """Context for managing AIE operator compilation and runtime state""" + + def __init__(self, use_runlist=True, build_dir=None): + self.operators = [] + self.static_data_pool = {} + self.device_manager = AIEDeviceManager() + self.base_dir = Path(__file__).parent.parent.parent + self.build_dir = build_dir or Path(os.getcwd()) / "build" + self.mlir_aie_dir = Path(aie.utils.config.root_path()) + self.peano_dir = Path(aie.utils.config.peano_install_dir()) + # Disable the XRT runlist sacrifices performance by executing kernels individually as separate xclbin invocations for easier debugging (can tell which part of runlist execution failed) + self.use_runlist = use_runlist + self.compilation_rules = [ + comp.FusePythonGeneratedMLIRCompilationRule(), + comp.GenerateMLIRFromPythonCompilationRule(), + comp.PeanoCompilationRule(self.peano_dir, self.mlir_aie_dir), + comp.ArchiveCompilationRule(self.peano_dir), + comp.AieccXclbinInstsCompilationRule( + self.build_dir, self.peano_dir, self.mlir_aie_dir + ), + comp.AieccFullElfCompilationRule( + self.build_dir, self.peano_dir, self.mlir_aie_dir + ), + ] + + def register_operator(self, operator): + """Register an operator with this context""" + operator.context = self + self.operators.append(operator) + + def compile_all(self): + """Compile all registered operators""" + self.build_dir.mkdir(parents=True, exist_ok=True) + for op in self.operators: + op.compile() diff --git a/iron/common/device_manager.py b/iron/common/device_manager.py new file mode 100644 index 00000000..2ae18bfe --- /dev/null +++ b/iron/common/device_manager.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Global AIE Device Manager for resource sharing and cleanup +""" + +import logging +import os +import sys +from pathlib import Path +from typing import Dict, Optional, Any +import pyxrt +from aie.utils.hostruntime.xrtruntime.hostruntime import XRTHostRuntime +from aie.iron.device import NPU1, NPU2 + + +class AIEDeviceManager: + """Singleton manager for AIE XRT resources""" + + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + # Only initialize once + if AIEDeviceManager._initialized: + return + AIEDeviceManager._initialized = True + + self.device = pyxrt.device(0) + self.device_type = XRTHostRuntime().device() + self.contexts = {} # xclbin_path -> (context, xclbin) + self.kernels = {} # (xclbin_path, kernel_name) -> kernel + + def get_context_and_kernel( + self, xclbin_path: str, kernel_name: str | None = None + ) -> (pyxrt.hw_context, pyxrt.kernel): + """Get or create hardware context and kernel for xclbin""" + # Check if we already have a context for this xclbin + + if xclbin_path not in self.contexts: + xclbin = pyxrt.xclbin(xclbin_path) + self.device.register_xclbin(xclbin) + xclbin_uuid = xclbin.get_uuid() + context = pyxrt.hw_context(self.device, xclbin_uuid) + self.contexts[xclbin_path] = (context, xclbin) + logging.debug(f"Created new context for {Path(xclbin_path).name}") + else: + context, xclbin = self.contexts[xclbin_path] + logging.debug(f"Reusing context for {Path(xclbin_path).name}") + + # Get kernel name if not provided + if kernel_name is None: + kernels = xclbin.get_kernels() + if not kernels: + raise RuntimeError("No kernels found in xclbin") + kernel_name = kernels[0].get_name() + + # Check if we already have the kernel + kernel_key = (xclbin_path, kernel_name) + if kernel_key not in self.kernels: + self.kernels[kernel_key] = pyxrt.kernel(context, kernel_name) + logging.debug( + f"Created new kernel {kernel_name} from xclbin {Path(xclbin_path).name}" + ) + else: + logging.debug( + f"Reusing kernel: {kernel_name} from xclbin {Path(xclbin_path).name}" + ) + + return context, self.kernels[kernel_key] + + def device_str(self) -> str: + return self.device_type.resolve().name + + def cleanup(self): + """Clean up all XRT resources""" + self.kernels.clear() + + # Clear contexts + for xclbin_path, (context, xclbin) in self.contexts.items(): + try: + del context + except: + pass + self.contexts.clear() + + # Clear device + if self.device is not None: + try: + del self.device + except: + pass + self.device = None + + logging.debug("Cleaned up AIE device manager") + + def reset(self): + """Reset the device manager (for debugging)""" + self.cleanup() + AIEDeviceManager._instance = None + AIEDeviceManager._initialized = False diff --git a/iron/common/fusion.py b/iron/common/fusion.py new file mode 100644 index 00000000..e2ced3c0 --- /dev/null +++ b/iron/common/fusion.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import ml_dtypes +import pyxrt +import ctypes +from . import compilation as comp +from .base import AIEOperatorBase, MLIROperator, AIEBuffer +from .device_manager import AIEDeviceManager + +# Fused Operator +# ########################################################################## + + +class FusedMLIROperator(AIEOperatorBase): + """Operator that fuses multiple MLIROperators into one.""" + + def __init__( + self, name, runlist, input_args, output_args, buffer_sizes=None, *args, **kwargs + ): + assert all( + isinstance(op, MLIROperator) and all(isinstance(buf, str) for buf in bufs) + for op, *bufs in runlist + ) + self.runlist = runlist + self.name = name + self.input_args = input_args + self.output_args = output_args + self.explicit_buffer_sizes = ( + buffer_sizes or {} + ) # Optional dict: buffer_name -> size_in_bytes + self.kernel_archive = "kernels.a" + super().__init__(*args, **kwargs) + + def get_operator_name(self): + return self.name + + def get_kernel_artifacts(self): + """Collect all kernel artifacts from child operators.""" + kernel_artifacts = [] + unique_operators = [] + for op, *_ in self.runlist: + if op not in unique_operators: + unique_operators.append(op) + for idx, op in enumerate(unique_operators): + objs = op.get_kernel_artifacts() + for obj in objs: + obj.filename = f"op{idx}_{obj.filename}" + obj.prefix_symbols = f"op{idx}_" + kernel_artifacts.extend(objs) + return kernel_artifacts + + def get_mlir_artifact(self): + # Build operator_mlir_map: {op_name -> PythonGeneratedMLIRArtifact} + operator_mlir_map = {} + mlir_dependencies = [] + comp_runlist = [] + op_names = {} # op -> op_name + + unique_operators = [] + for op, *_ in self.runlist: + if op not in unique_operators: + unique_operators.append(op) + for idx, op in enumerate(unique_operators): + mlir_artifact = op.get_mlir_artifact() + if len(op.get_kernel_artifacts()) > 0: + # FIXME: currently hard-coding that the design will accept this argument as an input if it uses kernels + # Also not handling name collisions of kernels with the same name + mlir_artifact.callback_kwargs["kernel_archive"] = self.kernel_archive + mlir_artifact.callback_kwargs["func_prefix"] = f"op{idx}_" + op_name = f"op{idx}_{op.__class__.__name__}" + op_names[op] = op_name + operator_mlir_map[op_name] = mlir_artifact + + for op, *bufs in self.runlist: + comp_runlist.append((op_names[op], *bufs)) + + # Calculate buffer layout: {buffer_name -> (type, offset, length)} + self.subbuffer_layout, self.buffer_sizes, self.slice_info = ( + self._calculate_buffer_layout() + ) + + filename = self.get_operator_name() + "_fused.mlir" + fused_artifact = comp.FusedMLIRSource( + filename, + operator_mlir_map=operator_mlir_map, + runlist=comp_runlist, + subbuffer_layout=self.subbuffer_layout, + buffer_sizes=self.buffer_sizes, + slice_info=self.slice_info, + ) + + return fused_artifact + + def _calculate_buffer_layout(self): + args = {} # base_buffer_name -> args_spec + sliced_buffers = ( + {} + ) # full_buffer_name (with slice) -> (base_name, start, end, args_spec) + + # Collect all buffer specs from operators + for op, *bufs in self.runlist: + args_specs = op.get_arg_spec() + assert len(args_specs) == len( + bufs + ), "Number of buffers must match operator argument specification" + for i, buf_name in enumerate(bufs): + args_spec = args_specs[i] + + # Parse slice notation: "buffer_name[start:end]" + if "[" in buf_name and buf_name.endswith("]"): + base_name = buf_name[: buf_name.index("[")] + slice_part = buf_name[buf_name.index("[") + 1 : -1] + start, end = map(int, slice_part.split(":")) + sliced_buffers[buf_name] = (base_name, start, end, args_spec) + # Track that base buffer exists (size will be set later) + if ( + base_name not in args + and base_name not in self.explicit_buffer_sizes + ): + raise ValueError( + f"Sliced buffer '{buf_name}' requires explicit size for base buffer '{base_name}' in buffer_sizes parameter" + ) + else: + # Regular buffer (no slice) + if buf_name not in args: + args[buf_name] = args_spec + else: + assert np.prod(args[buf_name].shape) == np.prod( + args_spec.shape + ), f"Buffer {buf_name} has conflicting sizes between operators" + + # Verify all input/output args are present (either as regular or sliced buffers) + all_buffer_names = set(args.keys()) | set(sliced_buffers.keys()) + for arg in self.input_args: + # Check if it's a base buffer name in explicit_buffer_sizes + if arg not in all_buffer_names and arg not in self.explicit_buffer_sizes: + raise AssertionError( + f"Input argument {arg} not found in runlist buffers" + ) + for arg in self.output_args: + if arg not in all_buffer_names and arg not in self.explicit_buffer_sizes: + raise AssertionError( + f"Output argument {arg} not found in runlist buffers" + ) + + # Determine buffer types and create layout + subbuffer_layout = {} + slice_info = {} # full_buffer_name -> (base_name, start, end) + + def add_buffers(buffer_type, args_list): + offset = 0 + for arg in args_list: + if arg in self.explicit_buffer_sizes: + # Explicit size specified - this is a parent buffer for slices + length = self.explicit_buffer_sizes[arg] + subbuffer_layout[arg] = (buffer_type, offset, length) + offset += length + elif arg in args: + # Regular buffer with inferred size + arg_spec = args[arg] + length = int( + np.prod(arg_spec.shape) * np.dtype(arg_spec.dtype).itemsize + ) + subbuffer_layout[arg] = (buffer_type, offset, length) + offset += length + # Note: sliced buffers are handled separately, not in args_list + return offset # == total length + + # Add sliced buffer entries to layout (they reference parent buffers) + for buf_name, (base_name, start, end, args_spec) in sliced_buffers.items(): + slice_info[buf_name] = (base_name, start, end) + + input_buffer_size = add_buffers("input", self.input_args) + output_buffer_size = add_buffers("output", self.output_args) + scratch_args = [ + arg + for arg in args + if arg not in self.input_args and arg not in self.output_args + ] + # Also include explicit buffers that are only used for slicing + for explicit_buf in self.explicit_buffer_sizes: + if ( + explicit_buf not in self.input_args + and explicit_buf not in self.output_args + and explicit_buf not in scratch_args + ): + scratch_args.append(explicit_buf) + scratch_buffer_size = add_buffers("scratch", scratch_args) + + buffer_sizes = (input_buffer_size, output_buffer_size, scratch_buffer_size) + return subbuffer_layout, buffer_sizes, slice_info + + def set_up_artifacts(self): + operator_name = self.get_operator_name() + mlir_artifact = self.get_mlir_artifact() + kernel_objects = self.get_kernel_artifacts() + kernel_dep = ( + [ + comp.KernelArchiveArtifact( + self.kernel_archive, + dependencies=kernel_objects, + ) + ] + if kernel_objects + else [] + ) + full_elf_artifact = comp.FullElfArtifact( + f"{operator_name}.elf", + mlir_input=mlir_artifact, + dependencies=[mlir_artifact] + kernel_dep, + ) + self.add_artifacts([full_elf_artifact]) + + def get_arg_spec(self): + pass + + def get_callable(self): + return FusedFullELFCallable(self) + + def get_layout_for_buffer(self, buffer_name): + if buffer_name in self.slice_info: + buf_name, start, end = self.slice_info[buffer_name] + buf_type, parent_start, parent_end = self.get_layout_for_buffer(buf_name) + return buf_type, parent_start + start, parent_start + end + + buf_type, offset, length = self.subbuffer_layout[buffer_name] + return buf_type, offset, length + + +def load_elf(op): + assert isinstance(op.artifacts[0], comp.FullElfArtifact) + elf_data = None + with open(op.artifacts[0].filename, "rb") as f: + elf_data = np.frombuffer(f.read(), dtype=np.uint32) + return elf_data + + +def patch_elf(elf_data, patches): + for i, patch in patches.items(): + val, mask = patch + val = np.uint64(val) + mask = np.uint64(mask) # avoid numpy overflow errors + elf_data[i] = np.uint32((elf_data[i] & ~mask) | (val & mask)) + return elf_data + + +class FullELFCallable: + def __init__( + self, + elf_data, + device_name="main", + sequence_name="sequence", + device_manager=None, + ): + self.device_name = device_name + self.sequence_name = sequence_name + self.device_manager = device_manager or AIEDeviceManager() + self.reload_elf(elf_data) + + def __call__(self, *args): + run = pyxrt.run(self.xrt_kernel) + for i, arg in enumerate(args): + assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo" + run.set_arg(i, arg) + run.start() + ret_code = run.wait() + if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: + raise RuntimeError(f"Kernel execution failed with return code {ret_code}") + + def reload_elf(self, elf_data): + # Create a PyCapsule from the numpy array pointer for pybind11 + elf_data_u8 = elf_data.view(dtype=np.uint8) + ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object + ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.c_void_p, + ] + capsule = ctypes.pythonapi.PyCapsule_New(elf_data_u8.ctypes.data, None, None) + xrt_elf = pyxrt.elf(capsule, elf_data.nbytes) + xrt_context = pyxrt.hw_context(self.device_manager.device, xrt_elf) + self.xrt_kernel = pyxrt.ext.kernel( + xrt_context, f"{self.device_name}:{self.sequence_name}" + ) + + +class FusedFullELFCallable(FullELFCallable): + def __init__(self, op, elf_data=None, device_manager=None): + if elf_data is None: + elf_data = load_elf(op) + super().__init__(elf_data, device_manager=device_manager) + + self.op = op + input_buffer_size, output_buffer_size, scratch_buffer_size = op.buffer_sizes + itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + + self.input_buffer = AIEBuffer( + shape=(max(input_buffer_size, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + + self.output_buffer = AIEBuffer( + shape=(max(output_buffer_size, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + + self.scratch_buffer = AIEBuffer( + shape=(max(scratch_buffer_size, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + + self._buffer_cache = {} + + def get_buffer(self, buffer_name): + # Return cached buffer if already allocated + if buffer_name in self._buffer_cache: + return self._buffer_cache[buffer_name] + + buf_type, offset, length = self.op.get_layout_for_buffer(buffer_name) + + # Select the appropriate main buffer + if buf_type == "input": + main_buffer = self.input_buffer + elif buf_type == "output": + main_buffer = self.output_buffer + elif buf_type == "scratch": + main_buffer = self.scratch_buffer + else: + raise ValueError( + f"Unknown buffer type '{buf_type}' for buffer '{buffer_name}'" + ) + + if main_buffer is None: + raise RuntimeError(f"Main buffer for type '{buf_type}' is not allocated") + + # Convert byte offset/length to element offset/length + itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + offset_elements = offset // itemsize + length_elements = length // itemsize + + # Create subbuffer with appropriate shape + sub_buffer = main_buffer.subbuffer( + length=length_elements, + offset=offset_elements, + shape=(length_elements,), + dtype=ml_dtypes.bfloat16, + ) + + # Cache and return + self._buffer_cache[buffer_name] = sub_buffer + return sub_buffer + + def __call__(self): + self.input_buffer.to("npu") + self.output_buffer.to("npu") + super().__call__( + self.input_buffer.bo if self.input_buffer else None, + self.output_buffer.bo if self.output_buffer else None, + self.scratch_buffer.bo if self.scratch_buffer else None, + ) diff --git a/iron/common/test_utils.py b/iron/common/test_utils.py index dc19df5d..53f40e9f 100644 --- a/iron/common/test_utils.py +++ b/iron/common/test_utils.py @@ -6,6 +6,7 @@ from ml_dtypes import bfloat16 from .utils import torch_to_numpy import logging +from .base import MLIROperator, CompositeOperator, AIEBuffer def nearly_equal( @@ -29,11 +30,11 @@ def nearly_equal( return diff < max(abs_tol, rel_tol * norm) -def verify_buffer(operator, buf_name, reference, rel_tol=0.04, abs_tol=1e-6): +def verify_buffer(output, buf_name, reference, rel_tol=0.04, abs_tol=1e-6): errors = [] expected_np = torch_to_numpy(reference).reshape((-1,)) - buf_size = operator.buffers[buf_name] // 2 - output = operator.read_buffer(buf_name, (buf_size,)) + output = output.reshape((-1,)) + if len(output) < len(expected_np): # Allow larger buffers - binning may have allocated more space than needed print( @@ -65,7 +66,7 @@ def run_test( Run operator test with specified input/output/intermediate buffers. Args: - operator: AIE operator instance with registered buffers + operator: AIE operator instance input_buffers: Dict mapping buffer names to input data arrays output_buffers: Dict mapping buffer names to reference output arrays intermediate_buffers: Optional dict mapping buffer names to reference arrays for validation @@ -83,45 +84,78 @@ def run_test( level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) - operator.context.compile_all() - operator.context.prepare_runtime() - # Run warmup iterations before writing to buffers (warmup iters might corrupt the buffers) + if not isinstance(operator, (MLIROperator, CompositeOperator)): + raise ValueError("run_test only supports MLIROperator or CompositeOperator") + + operator.compile() + op_func = operator.get_callable() + + args = [] + arg_spec = operator.get_arg_spec() + + input_iter = iter(input_buffers.items()) + output_iter = iter(output_buffers.items()) + output_map = {} + + total_bytes = 0 + + for spec in arg_spec: + if spec.direction == "in": + try: + name, data = next(input_iter) + except StopIteration: + raise ValueError("Not enough input buffers provided for arg spec") + data_np = torch_to_numpy(data) + buf = AIEBuffer.from_np(data_np) + args.append(buf) + total_bytes += buf.bo.size() + elif spec.direction == "out": + try: + name, expected = next(output_iter) + except StopIteration: + raise ValueError("Not enough output buffers provided for arg spec") + buf = AIEBuffer(shape=spec.shape, dtype=spec.dtype) + args.append(buf) + output_map[name] = buf + total_bytes += buf.bo.size() + else: + # Handle other directions if needed, or raise error + raise ValueError(f"Unsupported direction: {spec.direction}") + + # Run warmup iterations for _ in range(warmup_iters): - operator.run_runlist() # warmup run to configure - - # Write input buffers and zero outputs - for buf_name in output_buffers: - buf_size = operator.buffers[buf_name] - operator.write_buffer(buf_name, np.zeros(buf_size, dtype=np.uint8)) - # Operator may share the same buffer object for inputs and outputs; hence, write input after outputs - for buf_name, data in input_buffers.items(): - data_np = torch_to_numpy(data) - operator.write_buffer(buf_name, data_np) + op_func(*args) # Run operator - elapsed_total = 0 + start_time = time.time() for _ in range(timed_iters): - elapsed_total += operator.run_runlist() - elapsed = elapsed_total / timed_iters + op_func(*args) + end_time = time.time() + + elapsed = (end_time - start_time) / timed_iters latency_us = elapsed * 1e6 # Verify outputs errors = {} for buf_name, expected in output_buffers.items(): - buf_errors = verify_buffer(operator, buf_name, expected, rel_tol, abs_tol) - if buf_errors: - errors[buf_name] = buf_errors - - for buf_name, expected in intermediate_buffers.items(): - buf_errors = verify_buffer(operator, buf_name, expected, rel_tol, abs_tol) - if buf_errors: - errors[buf_name] = buf_errors + if expected is None: + continue + if buf_name in output_map: + buf = output_map[buf_name] + output_np = buf.view_as_np() + buf_errors = verify_buffer(output_np, buf_name, expected, rel_tol, abs_tol) + if buf_errors: + errors[buf_name] = buf_errors + else: + print(f"Warning: Output buffer {buf_name} not found in operator arguments") + + # Intermediate buffers are not supported in this generic run_test + # unless we expose them somehow. For now, ignore or warn. + if intermediate_buffers: + print("Warning: intermediate_buffers verification is not supported in run_test") # Calculate bandwidth - input_bytes = sum(operator.buffers[buf_name] for buf_name in input_buffers) - output_bytes = sum(operator.buffers[buf_name] for buf_name in output_buffers) - total_bytes = input_bytes + output_bytes bandwidth_gbps = total_bytes / (latency_us * 1e-6) / 1e9 return errors, latency_us, bandwidth_gbps diff --git a/iron/operators/__init__.py b/iron/operators/__init__.py index fc203892..f40db99d 100644 --- a/iron/operators/__init__.py +++ b/iron/operators/__init__.py @@ -1,24 +1,17 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from .axpy.op import AIEAXPY -from .dequant.op import AIEDequant from .elementwise_add.op import AIEElementwiseAdd from .elementwise_mul.op import AIEElementwiseMul -from .gelu.op import AIEGELU from .gemm.op import AIEGEMM from .gemv.op import AIEGEMV -from .layer_norm.op import AIELayerNorm -from .leaky_relu.op import AIELeakyReLU -from .mem_copy.op import AIEMemCopy from .mha.op import AIEMHA -from .relu.op import AIEReLU from .rms_norm.op import AIERMSNorm from .rope.op import AIERope -from .sigmoid.op import AIESigmoid from .silu.op import AIESiLU from .softmax.op import AIESoftmax from .swiglu_decode.op import AIESwiGLUDecode from .swiglu_prefill.op import AIESwiGLUPrefill -from .tanh.op import AIETanh from .transpose.op import AIETranspose +from .strided_copy.op import AIEStridedCopy +from .repeat.op import AIERepeat diff --git a/iron/operators/axpy/design.py b/iron/operators/axpy/design.py index 69468940..bfa676f8 100644 --- a/iron/operators/axpy/design.py +++ b/iron/operators/axpy/design.py @@ -16,7 +16,14 @@ def my_axpy( - dev, num_elements, num_columns, num_channels, tile_size, trace_size, scalar_factor + dev, + num_elements, + num_columns, + num_channels, + tile_size, + trace_size, + scalar_factor, + kernel_archive=None, ): factor = scalar_factor per_tile_elements = 4096 if tile_size > 4096 else tile_size diff --git a/iron/operators/axpy/op.py b/iron/operators/axpy/op.py index ce1702c6..37e66a33 100644 --- a/iron/operators/axpy/op.py +++ b/iron/operators/axpy/op.py @@ -7,8 +7,8 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -17,7 +17,7 @@ ) -class AIEAXPY(AIEOperatorBase): +class AIEAXPY(MLIROperator): """AIE-accelerated aX + Y operator""" def __init__( @@ -30,25 +30,26 @@ def __init__( context=None, ): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_aie_columns = num_aie_columns self.num_channels = num_channels self.scalar_factor = scalar_factor - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"axpy_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t_{self.scalar_factor}s" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"axpy_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t_{self.scalar_factor}s" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_axpy", callback_args=[ @@ -62,68 +63,21 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"axpy.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "generic" - / "axpy.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("x", self.size) - self.add_buffer("y", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "axpy", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("axpy", "x", "y", "output") - - def forward(self, x, y): - if x.numel() > self.size or y.numel() > self.size: - raise AIEOperatorConstraintError( - "AIEAXPY: input too large for configured size" - ) - if x.numel() != y.numel(): - raise AIEOperatorConstraintError("AIEAXPY: sizes of X and Y do not match") - - original_shape = x.shape - x_flat = x.reshape(-1) - y_flat = y.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - y_flat = torch.nn.functional.pad(y_flat, (0, pad_len)) - - self.write_buffer("x", x_flat) - self.write_buffer("y", y_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"axpy.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "generic" / "axpy.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # x + AIERuntimeArgSpec("in", (self.size,)), # y + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/iron/operators/axpy/test.py b/iron/operators/axpy/test.py index b91e802f..b37fabe2 100755 --- a/iron/operators/axpy/test.py +++ b/iron/operators/axpy/test.py @@ -12,40 +12,37 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 2 - input_lengths = [2048] if not extensive else [1024, 2048, 4096, 8192] - scalar_factors = [3.0] if not extensive else [3.0, 10.0] + input_lengths = [1024, 2048, 4096, 8192] + scalar_factors = [3.0, 10.0] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns if tile_size * num_aie_columns != input_length: continue for scalar in scalar_factors: - names.append( - f"axpy_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}_{scalar}" - ) - params.append( - (input_length, num_aie_columns, num_channels, tile_size, scalar) - ) - return params, names + name = f"axpy_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}_{scalar}" + # Determine if this is a regular test case + is_regular = input_length == 2048 and scalar == 3.0 + marks = [] if is_regular else [pytest.mark.extensive] -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + scalar, + id=name, + marks=marks, + ) + ) + return params @pytest.mark.metrics( @@ -54,7 +51,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size,scalar_factor", - all_params, + get_params(), ) def test_axpy( input_length, num_aie_columns, num_channels, tile_size, scalar_factor, aie_context diff --git a/iron/operators/dequant/design.py b/iron/operators/dequant/design.py index 05cf2ddd..07c3e3bf 100644 --- a/iron/operators/dequant/design.py +++ b/iron/operators/dequant/design.py @@ -16,7 +16,14 @@ def my_dequant_kernel( - dev, num_elements, num_columns, num_channels, trace_size, tile_size, group_size + dev, + num_elements, + num_columns, + num_channels, + trace_size, + tile_size, + group_size, + kernel_archive=None, ): per_tile_elements = ( 16384 if tile_size > 16384 else tile_size diff --git a/iron/operators/dequant/op.py b/iron/operators/dequant/op.py index d4aeab8a..8fd3e933 100644 --- a/iron/operators/dequant/op.py +++ b/iron/operators/dequant/op.py @@ -7,8 +7,8 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -17,7 +17,7 @@ ) -class AIEDequant(AIEOperatorBase): +class AIEDequant(MLIROperator): def __init__( self, @@ -46,17 +46,15 @@ def __init__( assert self.size % total_cores == 0, "Size must be divisible by total cores" assert total_cores <= 16, "Total cores (columns * channels) must be <= 16" - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"dequant_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"dequant_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_dequant_kernel", callback_args=[ @@ -70,68 +68,24 @@ def set_up_artifacts(self): ], ) - # Build the kernel object file with the appropriate tile size and group size - kernel_artifact = KernelObjectArtifact.new( - f"expand_aie2_{self.tile_size}.o", - depends=[ - SourceArtifact.new( - self.context.base_dir / "aie_kernels" / "generic" / "expand.cc" - ) - ], - extra_flags=[ - f"-DTILE_SIZE={self.tile_size}", - f"-DGROUP_SIZE={self.group_size}", - ], - ) - - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[mlir_artifact, kernel_artifact], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # Input buffer uses uint8 dtype, output uses bfloat16 - self.add_buffer("input", self.input_size, dtype=np.uint8) - self.add_buffer("output", self.output_size, dtype=bfloat16) - self.add_kernel( - "dequant", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("dequant", "input", "output") - - def forward(self, x_packed): - """ - Forward pass for dequantization. - - Args: - x_packed: Packed uint8 numpy array containing int4 data + scale factors - - Returns: - Dequantized bfloat16 torch tensor - """ - if x_packed.size != self.input_size: - raise AIEOperatorConstraintError( - f"AIEDequant: input size {x_packed.size} does not match expected size {self.input_size}" + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"expand_aie2_{self.tile_size}.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "generic" / "expand.cc" + ) + ], + extra_flags=[ + f"-DTILE_SIZE={self.tile_size}", + f"-DGROUP_SIZE={self.group_size}", + ], ) + ] - # Write input and execute - self.write_buffer("input", x_packed.flatten()) - self.write_buffer("output", np.zeros(self.output_size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch( - "output", shape=(self.output_size,), dtype=bfloat16 - ) - - return result + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.input_size,), dtype=np.uint8), # input + AIERuntimeArgSpec("out", (self.output_size,), dtype=bfloat16), # output + ] diff --git a/iron/operators/dequant/test.py b/iron/operators/dequant/test.py index 03b037f4..4ab904c0 100644 --- a/iron/operators/dequant/test.py +++ b/iron/operators/dequant/test.py @@ -12,12 +12,11 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): - input_lengths = [2048] if not extensive else [1024, 2048, 4096, 8192] +def get_params(): + input_lengths = [1024, 2048, 4096, 8192] group_size = 32 params = [] - names = [] for input_length in input_lengths: for num_columns in range(1, 9): # 1 to 8 columns for num_channels in range(1, 3): # 1 or 2 channels @@ -30,26 +29,23 @@ def generate_test_params(extensive=False): # Only proceed if tile_size * total_cores == input_length (exact division) if tile_size * total_cores == input_length: - names.append( - f"dequant_{num_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - ) - params.append( - (input_length, num_columns, num_channels, tile_size, group_size) - ) - return params, names + name = f"dequant_{num_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params.append( + pytest.param( + input_length, + num_columns, + num_channels, + tile_size, + group_size, + id=name, + marks=marks, + ) + ) + return params @pytest.mark.metrics( @@ -58,7 +54,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size,group_size", - all_params, + get_params(), ) def test_dequant( input_length, num_aie_columns, num_channels, tile_size, group_size, aie_context diff --git a/iron/operators/elementwise_add/design.py b/iron/operators/elementwise_add/design.py index d1eda376..246331b7 100644 --- a/iron/operators/elementwise_add/design.py +++ b/iron/operators/elementwise_add/design.py @@ -15,7 +15,15 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_eltwise_add(dev, num_elements, num_columns, num_channels, tile_size, trace_size): +def my_eltwise_add( + dev, + num_elements, + num_columns, + tile_size, + trace_size, + kernel_archive, + func_prefix="", +): per_tile_elements = 4096 if tile_size > 4096 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -37,7 +45,9 @@ def my_eltwise_add(dev, num_elements, num_columns, num_channels, tile_size, trac # AIE Core Function declaration eltwise_add_bf16_vector = Kernel( - "eltwise_add_bf16_vector", "add.o", [tile_ty, tile_ty, tile_ty, np.int32] + f"{func_prefix}eltwise_add_bf16_vector", + kernel_archive, + [tile_ty, tile_ty, tile_ty, np.int32], ) # Define a task that will run on a compute tile diff --git a/iron/operators/elementwise_add/op.py b/iron/operators/elementwise_add/op.py index d1963723..7d2dd7a7 100644 --- a/iron/operators/elementwise_add/op.py +++ b/iron/operators/elementwise_add/op.py @@ -8,8 +8,8 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -19,152 +19,62 @@ ) -class AIEElementwiseAdd(AIEOperatorBase): +class AIEElementwiseAdd(MLIROperator): """AIE-accelerated element-wise addition""" def __init__( self, size, - num_aie_columns=None, - num_channels=None, - tile_size=None, + tile_size, + num_aie_columns=8, context=None, ): - max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % (num_aie_columns * tile_size) == 0 + ), "size must be multiple of num_aie_columns * tile_size" + self.size = size self.tile_size = tile_size - self.num_aie_columns = num_aie_columns - self.num_channels = num_channels # Enforce ShimDMA limits for elementwise_add (uses 2 inputs per core) # Maximum safe configuration: 8 columns × 2 channels = 16 ShimDMA channels - total_shimdma_channels = self.num_aie_columns * self.num_channels + total_shimdma_channels = self.num_aie_columns * 2 assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" + MLIROperator.__init__(self, context=context) - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None - - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"add_{self.num_aie_columns}col_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): - # Compilation artifacts + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"add_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_eltwise_add", callback_args=[ self.context.device_manager.device_type, self.size, self.num_aie_columns, - self.num_channels, self.tile_size, 0, ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"add.o", - depends=[ - SourceArtifact.new( - self.context.base_dir / "aie_kernels" / "generic" / "add.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"add.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "generic" / "add.cc" + ) + ], + ), + ] + + def get_arg_spec(self): # Runtime setup - self.add_buffer("input1", self.size) - self.add_buffer("input2", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "eltwise_add", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("eltwise_add", "input1", "input2", "output") - - def forward(self, x, y): - """Forward pass for element-wise addition""" - applicable = ( - len(x.shape) >= 1 - and len(y.shape) >= 1 - and x.shape[-1] <= self.size - and y.shape[-1] <= self.size - and x.numel() <= self.size - and y.numel() <= self.size - and x.numel() == y.numel() - and x.shape == y.shape - ) - if not applicable: - raise AIEOperatorConstraintError( - "AIEElementwiseAdd: incompatible tensor shape(s)" - ) - - # Always flatten to [batch, orig_size] - original_shape = x.shape - batch = x.shape[0] if x.dim() > 1 else 1 - x_flat = x.reshape(batch, -1) - y_flat = y.reshape(batch, -1) - - pad_len = self.size - x_flat.shape[1] - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - y_flat = torch.nn.functional.pad(y_flat, (0, pad_len)) - - out = self._execute_aie_operation(x_flat, y_flat) - - # Remove padding if added - numel = np.prod(original_shape) - if pad_len > 0: - out = out.reshape(-1)[..., :numel] - # Restore original shape - out = out.reshape(*original_shape) - - return out - - def _execute_aie_operation(self, x, y): - """Execute element-wise addition operation on AIE hardware""" - # x, y are [batch, size] - batch = x.shape[0] if x.dim() > 1 else 1 - - # Flatten inputs for AIE processing - x_flat = x.view(-1) - y_flat = y.view(-1) - - # Verify size matches expected - if len(x_flat) != self.size or len(y_flat) != self.size: - raise AIEOperatorConstraintError( - f"Input size x={len(x_flat)}, y={len(y_flat)} doesn't match configured size {self.size}" - ) - - self.write_buffer("input1", x_flat) - self.write_buffer("input2", y_flat) - test_pattern = np.zeros(len(x_flat), dtype=bfloat16) - self.write_buffer("output", test_pattern) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=x_flat.shape, dtype=bfloat16) - - return result + return [ + AIERuntimeArgSpec("in", (self.size,)), # input1 + AIERuntimeArgSpec("in", (self.size,)), # input2 + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/iron/operators/elementwise_add/test.py b/iron/operators/elementwise_add/test.py index 781265f5..5794a2c4 100755 --- a/iron/operators/elementwise_add/test.py +++ b/iron/operators/elementwise_add/test.py @@ -12,36 +12,35 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 2 - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + # Combine all lengths + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns if tile_size * num_aie_columns != input_length: continue - names.append( - f"eltwise_add_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - ) - params.append((input_length, num_aie_columns, num_channels, tile_size)) - return params, names + name = f"eltwise_add_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) + ) + return params @pytest.mark.metrics( @@ -50,7 +49,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_elementwise_add( input_length, num_aie_columns, num_channels, tile_size, aie_context @@ -60,7 +59,6 @@ def test_elementwise_add( operator = AIEElementwiseAdd( size=input_length, num_aie_columns=num_aie_columns, - num_channels=num_channels, tile_size=tile_size, context=aie_context, ) diff --git a/iron/operators/elementwise_mul/design.py b/iron/operators/elementwise_mul/design.py index 88ae1e31..51319004 100644 --- a/iron/operators/elementwise_mul/design.py +++ b/iron/operators/elementwise_mul/design.py @@ -12,9 +12,18 @@ from aie.iron.device import NPU1, NPU2 from aie.helpers.taplib.tap import TensorAccessPattern from aie.iron.controlflow import range_ - - -def my_eltwise_mul(dev, num_elements, num_columns, num_channels, tile_size, trace_size): +from aie.helpers.util import np_ndarray_type_get_shape + + +def my_eltwise_mul( + dev, + num_elements, + num_columns, + tile_size, + trace_size, + kernel_archive, + func_prefix="", +): per_tile_elements = 4096 if tile_size > 4096 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -36,7 +45,9 @@ def my_eltwise_mul(dev, num_elements, num_columns, num_channels, tile_size, trac # AIE Core Function declaration eltwise_mul_bf16_vector = Kernel( - "eltwise_mul_bf16_vector", "mul.o", [tile_ty, tile_ty, tile_ty, np.int32] + f"{func_prefix}eltwise_mul_bf16_vector", + kernel_archive, + [tile_ty, tile_ty, tile_ty, np.int32], ) # Define a task that will run on a compute tile @@ -146,11 +157,6 @@ def str_to_device(device: str): p.add_argument( "-co", "--columns", required=True, dest="cols", help="Number of columns" ) - # Number of channels is required to define the number of channels to be used - # It must be 1 or 2 - p.add_argument( - "-ch", "--channels", required=True, dest="chans", help="Number of channels" - ) # Tile size (elements per tile) - defaults to 1024 for backward compatibility p.add_argument( "-ts", @@ -183,9 +189,6 @@ def str_to_device(device: str): elif isinstance(dev, NPU2) and columns > 8: raise ValueError("[ERROR] NPU2 device cannot allocate more than 8 columns") - channels = int(opts.chans) - if channels < 1 or channels > 2: - raise ValueError("Number of channels must be 1 or 2") tile_size = int(opts.tile_size) if length % (tile_size * columns) != 0: print( @@ -198,7 +201,7 @@ def str_to_device(device: str): raise ValueError trace_size = int(opts.trace_size) if opts.trace_size is not None else 0 - module = my_eltwise_mul(dev, length, columns, channels, tile_size, trace_size) + module = my_eltwise_mul(dev, length, columns, tile_size, trace_size, "mul.o") output_file_path = Path(opts.output_file_path) diff --git a/iron/operators/elementwise_mul/op.py b/iron/operators/elementwise_mul/op.py index 60113341..2304ca99 100644 --- a/iron/operators/elementwise_mul/op.py +++ b/iron/operators/elementwise_mul/op.py @@ -1,164 +1,75 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import torch -import numpy as np -from ml_dtypes import bfloat16 from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, - KernelArchiveArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIEElementwiseMul(AIEOperatorBase): +class AIEElementwiseMul(MLIROperator): """AIE-accelerated element-wise multiplication""" def __init__( - self, size, num_aie_columns, num_channels, tile_size, trace_size=0, context=None + self, + size, + tile_size, + num_aie_columns=8, + context=None, ): - max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % (num_aie_columns * tile_size) == 0 + ), "size must be multiple of num_aie_columns * tile_size" + self.size = size self.tile_size = tile_size self.num_aie_columns = num_aie_columns - self.num_channels = num_channels - self.trace_size = trace_size - - total_shimdma_channels = self.num_aie_columns * self.num_channels + # Enforce ShimDMA limits for elementwise_mul (uses 2 inputs per core) + # Maximum safe configuration: 8 columns × 2 channels = 16 ShimDMA channels + total_shimdma_channels = self.num_aie_columns * 2 assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" + MLIROperator.__init__(self, context=context) - self.xclbin_artifact = None - self.insts_artifact = None - - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"mul_{self.num_aie_columns}col_{self.size}_{self.tile_size}t" - def get_artifacts(self, prefix="eltwise_mul_"): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"{prefix}{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_eltwise_mul", callback_args=[ self.context.device_manager.device_type, self.size, self.num_aie_columns, - self.num_channels, self.tile_size, - self.trace_size, + 0, ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"mul.o", - depends=[ - SourceArtifact.new( - self.context.base_dir / "aie_kernels" / "generic" / "mul.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - return xclbin_artifact, insts_artifact - - def set_up_artifacts(self): - xclbin_artifact, insts_artifact = self.get_artifacts() - - mlir_artifact = xclbin_artifact.depends[0] - mlir_artifact.callback_args[0] = self.context.device_manager.device_type - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - self.add_buffer("input1", self.size) - self.add_buffer("input2", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "eltwise_mul", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("eltwise_mul", "input1", "input2", "output") - - def forward(self, x, y): - """Forward pass for element-wise multiplication""" - applicable = ( - len(x.shape) >= 1 - and len(y.shape) >= 1 - and x.shape[-1] <= self.size - and y.shape[-1] <= self.size - and x.numel() <= self.size - and y.numel() <= self.size - and x.numel() == y.numel() - and x.shape == y.shape - ) - - # Always flatten to [batch, orig_size] - original_shape = x.shape - batch = x.shape[0] if x.dim() > 1 else 1 - x_flat = x.reshape(batch, -1) - y_flat = y.reshape(batch, -1) - - pad_len = self.size - x_flat.shape[1] - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - y_flat = torch.nn.functional.pad(y_flat, (0, pad_len)) - - out = self._execute_aie_operation(x_flat, y_flat) - - # Remove padding if added - numel = np.prod(original_shape) - if pad_len > 0: - out = out.reshape(-1)[..., :numel] - # Restore original shape - out = out.reshape(*original_shape) - - return out - - def _execute_aie_operation(self, x, y): - """Execute element-wise multiplication operation on AIE hardware""" - # x, y are [batch, size] - batch = x.shape[0] if x.dim() > 1 else 1 - - # Flatten inputs for AIE processing - x_flat = x.view(-1) - y_flat = y.view(-1) - - # Verify size matches expected - if len(x_flat) != self.size or len(y_flat) != self.size: - raise AIEOperatorConstraintError( - f"Input size x={len(x_flat)}, y={len(y_flat)} doesn't match configured size {self.size}" - ) - - self.write_buffer("input1", x_flat) - self.write_buffer("input2", y_flat) - test_pattern = np.zeros(len(x_flat), dtype=bfloat16) - self.write_buffer("output", test_pattern) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=x_flat.shape, dtype=bfloat16) - - return result + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"mul.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "generic" / "mul.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + # Runtime setup + return [ + AIERuntimeArgSpec("in", (self.size,)), # input1 + AIERuntimeArgSpec("in", (self.size,)), # input2 + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/iron/operators/elementwise_mul/test.py b/iron/operators/elementwise_mul/test.py index 2c92d288..82b34a9b 100755 --- a/iron/operators/elementwise_mul/test.py +++ b/iron/operators/elementwise_mul/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 2 - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns @@ -26,24 +25,23 @@ def generate_test_params(extensive=False): tile_size = 4096 if tile_size * num_aie_columns != input_length: continue - names.append( - f"eltwise_mul_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - ) - params.append((input_length, num_aie_columns, num_channels, tile_size)) - return params, names + name = f"eltwise_mul_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) + ) + return params @pytest.mark.metrics( @@ -52,7 +50,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_elementwise_mul( input_length, num_aie_columns, num_channels, tile_size, aie_context @@ -61,9 +59,8 @@ def test_elementwise_mul( operator = AIEElementwiseMul( size=input_length, - num_aie_columns=num_aie_columns, - num_channels=num_channels, tile_size=tile_size, + num_aie_columns=num_aie_columns, context=aie_context, ) diff --git a/iron/operators/gelu/design.py b/iron/operators/gelu/design.py index 7a110286..3ecd85a5 100644 --- a/iron/operators/gelu/design.py +++ b/iron/operators/gelu/design.py @@ -15,7 +15,9 @@ from aie.iron.controlflow import range_ -def my_gelu(dev, size, num_columns, num_channels, tile_size, trace_size): +def my_gelu( + dev, size, num_columns, num_channels, tile_size, trace_size, kernel_archive=None +): xfr_dtype = bfloat16 line_size = 8192 if tile_size > 8192 else tile_size fifodepth = 1 if line_size > 4096 else 2 diff --git a/iron/operators/gelu/op.py b/iron/operators/gelu/op.py index 86fea435..8f8f8157 100644 --- a/iron/operators/gelu/op.py +++ b/iron/operators/gelu/op.py @@ -7,8 +7,8 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -17,14 +17,17 @@ ) -class AIEGELU(AIEOperatorBase): +class AIEGELU(MLIROperator): """AIE-accelerated GELU activation function""" def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_aie_columns = num_aie_columns self.num_channels = num_channels @@ -32,17 +35,15 @@ def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None) total_shimdma_channels = self.num_aie_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"gelu_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"gelu_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_gelu", callback_args=[ @@ -55,65 +56,20 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"gelu.o", - depends=[ - SourceArtifact.new( - self.context.base_dir / "aie_kernels" / "aie2p" / "gelu.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "gelu", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("gelu", "input", "output") - - def forward(self, x): - """Forward pass for GELU activation""" - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIEGELU: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - # Pad if necessary - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - # Execute on AIE - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - # Remove padding and restore shape - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"gelu.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "aie2p" / "gelu.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/iron/operators/gelu/test.py b/iron/operators/gelu/test.py index d91a9e7a..69b4519d 100755 --- a/iron/operators/gelu/test.py +++ b/iron/operators/gelu/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels_choices = [1, 2] - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): for num_channels in num_channels_choices: @@ -28,26 +27,22 @@ def generate_test_params(extensive=False): tile_size = 8192 check_length = tile_size * total_cores if check_length == input_length: - names.append( - f"gelu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - ) - params.append( - (input_length, num_aie_columns, num_channels, tile_size) - ) - return params, names + name = f"gelu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) + ) + return params @pytest.mark.metrics( @@ -56,7 +51,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_gelu(input_length, num_aie_columns, num_channels, tile_size, aie_context): golden_ref = generate_golden_reference(input_length=input_length) diff --git a/iron/operators/gemm/design.py b/iron/operators/gemm/design.py index 6ea439d5..e5b4d748 100644 --- a/iron/operators/gemm/design.py +++ b/iron/operators/gemm/design.py @@ -106,6 +106,7 @@ def main(): args.separate_c_tiles, args.trace_size, args.archive, + "", args.generate_taps, ) @@ -140,7 +141,8 @@ def my_matmul( prio_accuracy, separate_c_tiles, trace_size, - archive=None, + kernel_archive=None, + func_prefix="", generate_taps=False, ): n_aie_rows = 4 @@ -273,7 +275,11 @@ def my_matmul( # AIE Core Function declarations scalar_suffix = "_scalar" if use_scalar else "" - archive_name = f"gemm_{m}x{k}x{n}_archive.a" if archive is None else archive + kernel_archive = ( + f"{func_prefix}gemm_{m}x{k}x{n}_archive.a" + if kernel_archive is None + else kernel_archive + ) if use_larger_internal_buffer: # Fix fifo depth for C objfifo to 1 since 1 buffer will be used for accumulation # and another for transfer to L2 @@ -283,19 +289,19 @@ def my_matmul( # A kernel to convert from the internal f32 accumulation to bf16 for transfer to L2 is needed convert_copy_kernel = Kernel( f"convert_copy_f32_to_bf16", - archive_name, + kernel_archive, [C_l1_ty_internal, C_l1_ty, np.int32], ) # Fix the kernels to use f32 outputs zero_kernel = Kernel( f"zero{scalar_suffix}_f32", - archive_name, + kernel_archive, [C_l1_ty_internal], ) matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_f32" matmul_kernel = Kernel( matmul_func_name, - archive_name, + kernel_archive, [A_l1_ty, B_l1_ty, C_l1_ty_internal], ) else: @@ -304,13 +310,13 @@ def my_matmul( fifo_depth_out = fifo_depth zero_kernel = Kernel( f"zero{scalar_suffix}_{dtype_out_str}", - archive_name, + kernel_archive, [C_l1_ty], ) matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" matmul_kernel = Kernel( matmul_func_name, - archive_name, + kernel_archive, [A_l1_ty, B_l1_ty, C_l1_ty], ) diff --git a/iron/operators/gemm/op.py b/iron/operators/gemm/op.py index 007e46b3..841fef03 100644 --- a/iron/operators/gemm/op.py +++ b/iron/operators/gemm/op.py @@ -8,12 +8,11 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, - KernelArchiveArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) @@ -21,7 +20,7 @@ from iron.common.utils import torch_to_numpy, numpy_to_torch -class AIEGEMM(AIEOperatorBase): +class AIEGEMM(MLIROperator): """AIE-accelerated General Matrix Multiplication (GEMM) layer""" def __init__( @@ -36,64 +35,50 @@ def __init__( # TODO: Add support for partitioning M and/or K # partition_M=1, # partition_K=1, - partition_N=1, num_aie_columns=8, context=None, **gemm_kwargs, ): - + num_aie_rows = 4 + min_M = tile_m * num_aie_rows + min_K = tile_k + min_N = tile_n * num_aie_columns + assert M % min_M == 0, f"M ({M}) must be multiple of {min_M}" + assert K % min_K == 0, f"K ({K}) must be multiple of {min_K}" + assert N % min_N == 0, f"N ({N}) must be multiple of {min_N}" + self.M = M + self.K = K + self.N = N self.tile_m = tile_m self.tile_k = tile_k self.tile_n = tile_n + self.num_aie_columns = num_aie_columns self.gemm_args = gemm_kwargs - - # Set frequently accessed gemm_args self.b_col_maj = gemm_kwargs.get("b_col_maj", False) self.c_col_maj = gemm_kwargs.get("c_col_maj", False) - self.weight = ( - None - if not use_static_weight - else torch.zeros((K, N), dtype=torch.bfloat16).T - ) - self.static_weight_shape = (K, N) - - # The operator's M, K, N represent what the NPU operator supports. - # Calls to forward() may supply matrices of different sizes, and the - # Python code will perform necessary padding/repeated application of - # the NPU operator. - assert ( - N % partition_N == 0 - ), f"N ({N}) must be divisible by partition_N ({partition_N})" - M_padded, K_padded, N_padded = self._get_padded_dims( - M, K, N // partition_N, tile_m, tile_k, tile_n + + emulate_bf16_mmul_with_bfp16 = self.gemm_args.get( + "emulate_bf16_mmul_with_bfp16", True ) - self.M = M_padded - self.K = K_padded - self.N = N_padded - self.partition_N = partition_N + if emulate_bf16_mmul_with_bfp16: + min_tile_m, min_tile_k, min_tile_n = 8, 8, 8 + else: + min_tile_m, min_tile_k, min_tile_n = 4, 8, 8 + assert tile_m >= min_tile_m, f"tile_m ({tile_m}) must be >= {min_tile_m}" + assert tile_k >= min_tile_k, f"tile_k ({tile_k}) must be >= {min_tile_k}" + assert tile_n >= min_tile_n, f"tile_n ({tile_n}) must be >= {min_tile_n}" - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"gemm_{self.M}x{self.K}x{self.N}_{self.tile_m}x{self.tile_k}x{self.tile_n}_{int(self.b_col_maj)}_{int(self.c_col_maj)}" - def get_artifacts(self, prefix="gemm_"): - # Extract parameters from self + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - tile_m = self.tile_m - tile_k = self.tile_k - tile_n = self.tile_n - M = self.M - K = self.K - N = self.N - num_aie_columns = self.num_aie_columns + operator_name = self.get_operator_name() base_dir = self.context.base_dir device_str = self.context.device_manager.device_str() - - b_col_maj = self.b_col_maj - c_col_maj = self.c_col_maj dtype_in = self.gemm_args.get("dtype_in", "bf16") dtype_out = self.gemm_args.get("dtype_out", "bf16") emulate_bf16_mmul_with_bfp16 = self.gemm_args.get( @@ -102,245 +87,171 @@ def get_artifacts(self, prefix="gemm_"): prio_accuracy = self.gemm_args.get("prio_accuracy", False) use_scalar = self.gemm_args.get("use_scalar", False) round_conv_even = self.gemm_args.get("round_conv_even", True) - - if emulate_bf16_mmul_with_bfp16: - min_tile_m, min_tile_k, min_tile_n = 8, 8, 8 - else: - min_tile_m, min_tile_k, min_tile_n = 4, 8, 8 - assert tile_m >= min_tile_m, f"tile_m ({tile_m}) must be >= {min_tile_m}" - assert tile_k >= min_tile_k, f"tile_k ({tile_k}) must be >= {min_tile_k}" - assert tile_n >= min_tile_n, f"tile_n ({tile_n}) must be >= {min_tile_n}" - - file_name_tile_base = f"{prefix}{tile_m}x{tile_k}x{tile_n}" - file_name_total_base = f"{prefix}{M}x{K}x{N}_{tile_m}x{tile_k}x{tile_n}_{int(b_col_maj)}_{int(c_col_maj)}" - xclbin_kernel_name = f"gemm_{file_name_tile_base}" - kernel_flags = [ - f"-DDIM_M={tile_m}", - f"-DDIM_K={tile_k}", - f"-DDIM_N={tile_n}", - "-DROUND_CONV_EVEN", - ] - if prio_accuracy: - kernel_flags.append("-Dbf16_f32_ONLY") - else: - kernel_flags.append("-Dbf16_bf16_ONLY") - if round_conv_even: - kernel_flags.append("-DROUND_CONV_EVEN") - if emulate_bf16_mmul_with_bfp16: - kernel_flags.append("-DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16") - if b_col_maj: - kernel_flags.append("-DB_COL_MAJ") - if c_col_maj: - kernel_flags.append("-DC_COL_MAJ") - - kernel_archive = ( - f"gemm_{tile_m}x{tile_k}x{tile_n}_{int(b_col_maj)}_{int(c_col_maj)}.a" - ) - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_total_base}.mlir", + separate_c_tiles = self.gemm_args.get("separate_c_tiles", False) + return PythonGeneratedMLIRArtifact( + f"{operator_name}.mlir", import_path=operator_dir / "design.py", callback_fn="my_matmul", callback_kwargs={ "dev": device_str, - "M": M, - "K": K, - "N": N, - "m": tile_m, - "k": tile_k, - "n": tile_n, - "n_aie_cols": num_aie_columns, + "M": self.M, + "K": self.K, + "N": self.N, + "m": self.tile_m, + "k": self.tile_k, + "n": self.tile_n, + "n_aie_cols": self.num_aie_columns, "dtype_in_str": dtype_in, "dtype_out_str": dtype_out, - "b_col_maj": int(b_col_maj), - "c_col_maj": int(c_col_maj), + "b_col_maj": int(self.b_col_maj), + "c_col_maj": int(self.c_col_maj), "use_scalar": use_scalar, "emulate_bf16_mmul_with_bfp16": emulate_bf16_mmul_with_bfp16, "prio_accuracy": prio_accuracy, - "separate_c_tiles": int(self.partition_N > 1), + "separate_c_tiles": int(separate_c_tiles), "trace_size": 0, - "archive": kernel_archive, "generate_taps": False, }, requires_context=False, ) - # FIXME: We should be able to reuse the same xclbin for same tile - # sizes, only swapping out the instruction sequence for different - # problem sizes. However, there seem to be cases where this does - # not work and the GEMM appears to be misconfigured for the wrong - # size (resulting in a timeout when trying to run it). Perhaps - # XRT is caching something, or something is wrong with the run- - # time parameter (synchronization)? For now, create separate - # xclbins for each problem size. - xclbin_artifact = XclbinArtifact.new( - f"{file_name_total_base}.xclbin", - depends=[ - mlir_artifact, - KernelArchiveArtifact.new( - kernel_archive, - depends=[ - KernelObjectArtifact.new( - f"gemm_{tile_m}x{tile_k}x{tile_n}_{int(b_col_maj)}_{int(c_col_maj)}.o", - extra_flags=kernel_flags, - depends=[ - SourceArtifact.new( - base_dir / "aie_kernels" / "aie2p" / "mm.cc" - ) - ], - ), - KernelObjectArtifact.new( - "convert_copy.o", - [ - SourceArtifact.new( - base_dir - / "aie_kernels" - / "generic" - / "convert_copy.cc" - ) - ], - ), - ], - ), - ], - extra_flags=["--dynamic-objFifos"], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_total_base}.bin", - depends=[mlir_artifact], - extra_flags=["--dynamic-objFifos"], - ) - - return (xclbin_artifact, insts_artifact) - - def set_up_artifacts(self): - # Describe required artifacts (xclbin, insts.bin) - device_str = self.context.device_manager.device_str() - xclbin_artifact, insts_artifact = self.get_artifacts() - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - static_weights = None - if self.weight is not None: - static_weights = self.weight.T - if isinstance(static_weights, torch.Tensor): - static_weights = torch_to_numpy(static_weights) - self.add_kernel( - "gemm", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_buffer("A", self.M * self.K) - B_parts = self._partition_B(static_weights) - for i, B_part in enumerate(B_parts): - self.add_buffer( - f"B_{i}", - self.K * self.N, - static_data=B_part, - ) - self.add_buffer(f"C_{i}", self.M * self.N) - self.add_to_runlist("gemm", "A", f"B_{i}", f"C_{i}") - - def _get_B_dims(self, B_shape): - """Extract K and N dimensions from B matrix shape based on layout. - - Returns: - tuple: (K, N) dimensions regardless of B's layout - """ - if self.b_col_maj: - return B_shape[-1], B_shape[-2] # B is (N, K) -> return (K, N) - else: - return B_shape[-2], B_shape[-1] # B is (K, N) -> return (K, N) - - def forward(self, A, B=None): - """Forward pass through GEMM operation: C = A @ B""" - B_shape = B.shape if B is not None else self.static_weight_shape - - # Determine output dimensions based on matrix layout - K2, N = self._get_B_dims(B_shape) - N_part = N // self.partition_N - - # Build expected output shape based on C layout - expected_output_shape = ( - A.shape[:-2] + (N, A.shape[-1]) if self.c_col_maj else A.shape[:-1] + (N,) - ) - - # Remove batch dimension, if any - if len(A.shape) > 2: - A = A.view(-1, A.shape[-1]) - if B is not None and len(B.shape) > 2: - B = B.view(-1, B_shape[-1]) - - M, K = A.shape - - applicable = ( - K == K2 - and (M <= self.M or not self.c_col_maj) - and K <= self.K - and N <= self.N - ) - if not applicable: - raise AIEOperatorConstraintError("AIEGEMM: incompatible tensor shape(s)") - - A_padded = self._pad_A(torch_to_numpy(A)) - if B is not None: - B_parts = self._partition_B(torch_to_numpy(B)) - else: - B_parts = None - - logging.debug( - f"Executing GEMM for dimensions M={M}, K={K}, N={N} using NPU operator with M={self.M}, K={self.N}, N={self.N}" + def get_kernel_artifacts(self): + base_dir = self.context.base_dir + emulate_bf16_mmul_with_bfp16 = self.gemm_args.get( + "emulate_bf16_mmul_with_bfp16", True ) - - if self.c_col_maj: - result_padded = np.zeros((N, M), dtype=A_padded.dtype) + prio_accuracy = self.gemm_args.get("prio_accuracy", False) + round_conv_even = self.gemm_args.get("round_conv_even", True) + kernel_flags = [ + f"-DDIM_M={self.tile_m}", + f"-DDIM_K={self.tile_k}", + f"-DDIM_N={self.tile_n}", + "-DROUND_CONV_EVEN", + ] + if prio_accuracy: + kernel_flags.append("-Dbf16_f32_ONLY") else: - result_padded = np.zeros((M, N), dtype=A_padded.dtype) - for M_lo in range(0, M, self.M): - A_part = A_padded[M_lo : M_lo + self.M, :] - result_parts = self._execute_aie_operation(A_part, B_parts) - max_M = min(M_lo + self.M, M) - for part in range(self.partition_N): - if self.c_col_maj: - result_padded[part * N_part : (part + 1) * N_part, M_lo:max_M] = ( - result_parts[part][:N_part, :max_M] - ) - else: - result_padded[M_lo:max_M, part * N_part : (part + 1) * N_part] = ( - result_parts[part][:max_M, :N_part] - ) - - # GEMM produces 2D result, reshape to expected output shape + kernel_flags.append("-Dbf16_bf16_ONLY") + if round_conv_even: + kernel_flags.append("-DROUND_CONV_EVEN") + if emulate_bf16_mmul_with_bfp16: + kernel_flags.append("-DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16") + if self.b_col_maj: + kernel_flags.append("-DB_COL_MAJ") if self.c_col_maj: - result = numpy_to_torch(result_padded[:N, :M]) - else: - result = numpy_to_torch(result_padded[:M, :N]) - result = result.view(expected_output_shape) - - return result - - def _get_padded_dims(self, M, K, N, tile_m, tile_k, tile_n): - num_aie_columns = self.num_aie_columns - num_aie_rows = 4 - - min_M = tile_m * num_aie_rows - min_K = tile_k - min_N = tile_n * num_aie_columns + kernel_flags.append("-DC_COL_MAJ") - # Calculate padded dimensions - M_padded = ((M + min_M - 1) // min_M) * min_M - K_padded = ((K + min_K - 1) // min_K) * min_K - N_padded = ((N + min_N - 1) // min_N) * min_N + # Include flags in the filename to avoid stale builds when flags change + flags_suffix = f"_{int(prio_accuracy)}_{int(emulate_bf16_mmul_with_bfp16)}_{int(round_conv_even)}" + + return [ + KernelObjectArtifact( + f"gemm_{self.tile_m}x{self.tile_k}x{self.tile_n}_{int(self.b_col_maj)}_{int(self.c_col_maj)}{flags_suffix}.o", + extra_flags=kernel_flags, + dependencies=[ + SourceArtifact(base_dir / "aie_kernels" / "aie2p" / "mm.cc") + ], + ), + KernelObjectArtifact( + "convert_copy.o", + [ + SourceArtifact( + base_dir / "aie_kernels" / "generic" / "convert_copy.cc" + ) + ], + ), + ] - return M_padded, K_padded, N_padded + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.M, self.K)), # input A + AIERuntimeArgSpec( + "in", (self.K, self.N) if not self.b_col_maj else (self.N, self.K) + ), # input B (weights) + AIERuntimeArgSpec( + "out", (self.M, self.N) if not self.c_col_maj else (self.N, self.M) + ), # output C + ] - def _pad_A(self, A_np): + # def _get_B_dims(self, B_shape): + # """Extract K and N dimensions from B matrix shape based on layout. + + # Returns: + # tuple: (K, N) dimensions regardless of B's layout + # """ + # if self.b_col_maj: + # return B_shape[-1], B_shape[-2] # B is (N, K) -> return (K, N) + # else: + # return B_shape[-2], B_shape[-1] # B is (K, N) -> return (K, N) + + # def forward(self, A, B=None): + # """Forward pass through GEMM operation: C = A @ B""" + # B_shape = B.shape if B is not None else self.static_weight_shape + + # # Determine output dimensions based on matrix layout + # K2, N = self._get_B_dims(B_shape) + # N_part = N // self.partition_N + + # # Build expected output shape based on C layout + # expected_output_shape = ( + # A.shape[:-2] + (N, A.shape[-1]) if self.c_col_maj else A.shape[:-1] + (N,) + # ) + + # # Remove batch dimension, if any + # if len(A.shape) > 2: + # A = A.view(-1, A.shape[-1]) + # if B is not None and len(B.shape) > 2: + # B = B.view(-1, B_shape[-1]) + + # M, K = A.shape + + # applicable = ( + # K == K2 + # and (M <= self.M or not self.c_col_maj) + # and K <= self.K + # and N <= self.N + # ) + # if not applicable: + # raise AIEOperatorConstraintError("AIEGEMM: incompatible tensor shape(s)") + + # A_padded = self._pad_A(torch_to_numpy(A)) + # if B is not None: + # B_parts = self._partition_B(torch_to_numpy(B)) + # else: + # B_parts = None + + # logging.debug( + # f"Executing GEMM for dimensions M={M}, K={K}, N={N} using NPU operator with M={self.M}, K={self.N}, N={self.N}" + # ) + + # if self.c_col_maj: + # result_padded = np.zeros((N, M), dtype=A_padded.dtype) + # else: + # result_padded = np.zeros((M, N), dtype=A_padded.dtype) + # for M_lo in range(0, M, self.M): + # A_part = A_padded[M_lo : M_lo + self.M, :] + # result_parts = self._execute_aie_operation(A_part, B_parts) + # max_M = min(M_lo + self.M, M) + # for part in range(self.partition_N): + # if self.c_col_maj: + # result_padded[part * N_part : (part + 1) * N_part, M_lo:max_M] = ( + # result_parts[part][:N_part, :max_M] + # ) + # else: + # result_padded[M_lo:max_M, part * N_part : (part + 1) * N_part] = ( + # result_parts[part][:max_M, :N_part] + # ) + + # # GEMM produces 2D result, reshape to expected output shape + # if self.c_col_maj: + # result = numpy_to_torch(result_padded[:N, :M]) + # else: + # result = numpy_to_torch(result_padded[:M, :N]) + # result = result.view(expected_output_shape) + + # return result + + def pad_A(self, A_np): """Pad A matrix to match operator dimensions (M, K)""" M, K = A_np.shape if M % self.M == 0 and K == self.K: @@ -351,7 +262,7 @@ def _pad_A(self, A_np): A_padded[:M, :K] = A_np return A_padded - def _pad_B(self, B_np): + def pad_B(self, B_np): """Pad B matrix to match operator dimensions based on layout""" if self.b_col_maj: N, K = B_np.shape @@ -367,56 +278,16 @@ def _pad_B(self, B_np): B_padded[:K, :N] = B_np return B_padded - def _partition_B(self, B): - B_parts = [None] * self.partition_N + def partition_B(self, B, partition_N): + B_parts = [None] * partition_N if B is None: return B_parts - for i in range(self.partition_N): + for i in range(partition_N): col_start = i * self.N col_end = (i + 1) * self.N - # Just in case, pad the weights before adding the buffer if self.b_col_maj: - B_parts[i] = self._pad_B(B[col_start:col_end, :]) + B_parts[i] = self.pad_B(B[col_start:col_end, :]) else: - B_parts[i] = self._pad_B(B[:, col_start:col_end]) - self.static_weight_shape = B_parts[0].shape + B_parts[i] = self.pad_B(B[:, col_start:col_end]) return B_parts - - def _execute_aie_operation(self, A_np, B_nps=None): - """Execute GEMM operation on AIE hardware""" - M, K = A_np.shape - B_shape = B_nps[0].shape if B_nps is not None else self.static_weight_shape - K2, N = self._get_B_dims(B_shape) - C_shape = (N, M) if self.c_col_maj else (M, N) - - # Validate dimensions match operator configuration - assert M == self.M - assert K == K2 and K == self.K - assert N == self.N - - self.write_buffer("A", A_np) - if B_nps is not None: - for i, B_np in enumerate(B_nps): - self.add_buffer( - f"B_{i}", - self.M * self.N, - static_data=B_np, - ) - self.run_runlist() - result_nps = [ - self.read_buffer(f"C_{i}", shape=C_shape, dtype=bfloat16) - for i in range(self.partition_N) - ] - - # Check for NaN and fail hard - # for result_np in result_nps: - # if np.isnan(result_np).any(): - # nan_count = np.isnan(result_np).sum() - # total_count = result_np.size - # raise RuntimeError( - # f"AIE execution returned {nan_count}/{total_count} NaN values. " - # ) - - # Convert back to torch tensor - return result_nps diff --git a/iron/operators/gemm/test.py b/iron/operators/gemm/test.py index 6480aeff..b1dc8194 100755 --- a/iron/operators/gemm/test.py +++ b/iron/operators/gemm/test.py @@ -12,10 +12,10 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): # fmt: off - params = [ - # M, K, N, num_aie_columns, b_col_maj, c_col_maj, m, k, n, trace_size, partition_N + # M, K, N, num_aie_columns, b_col_maj, c_col_maj, m, k, n, trace_size, partition_N + regular_params = [ (2048, 2048, 2048, 1, False, False, 64, 64, 64, 0, 1), (2048, 2048, 2048, 2, True, False, 64, 64, 64, 0, 1), (2048, 2048, 2048, 8, True, True, 64, 64, 64, 0, 1), @@ -44,48 +44,43 @@ def generate_test_params(extensive=False): ] # fmt: on - if extensive: - params = extensive_params - - names = [] - for ( - M, - K, - N, - num_aie_columns, - b_col_maj, - c_col_maj, - m, - k, - n, - trace_size, - partition_N, - ) in params: - name = f"gemm_{M}x{K}x{N}_{m}x{k}x{n}_{num_aie_columns}cols" - if b_col_maj: - name += "_bcolmaj" - if c_col_maj: - name += "_ccolmaj" - if partition_N > 1: - name += f"_{partition_N}npart" - if trace_size > 0: - name += f"_{trace_size}trace" - names.append(name) - - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params = [] + + # Helper to generate name and append param + def add_params(param_list, is_extensive): + for p in param_list: + ( + M, + K, + N, + num_aie_columns, + b_col_maj, + c_col_maj, + m, + k, + n, + trace_size, + partition_N, + ) = p + + name = f"gemm_{M}x{K}x{N}_{m}x{k}x{n}_{num_aie_columns}cols" + if b_col_maj: + name += "_bcolmaj" + if c_col_maj: + name += "_ccolmaj" + if partition_N > 1: + name += f"_{partition_N}npart" + if trace_size > 0: + name += f"_{trace_size}trace" + + marks = [pytest.mark.extensive] if is_extensive else [] + + params.append(pytest.param(*p, id=name, marks=marks)) + + add_params(regular_params, is_extensive=False) + add_params(extensive_params, is_extensive=True) + + return params @pytest.mark.metrics( @@ -95,7 +90,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "M,K,N,num_aie_columns,b_col_maj,c_col_maj,m,k,n,trace_size,partition_N", - all_params, + get_params(), ) def test_gemm( M, diff --git a/iron/operators/gemv/design.py b/iron/operators/gemv/design.py index 0a153364..6d48aa6d 100644 --- a/iron/operators/gemv/design.py +++ b/iron/operators/gemv/design.py @@ -29,10 +29,21 @@ - K: number of columns in the matrix == number of rows in the vector - m_input: number of input rows stored on each AIE core == chunk size for data movement of input A - m_output: number of output rows stored on each AIE core == chunk size for data movement of output C + - num_batches: number of iterations of this mat-vec to perform on contiguous matrices and vectors in memory (results concatenated) """ -def my_matvec(dev, cols, M, K, m_input, m_output=None): +def my_matvec( + dev, + cols, + M, + K, + m_input, + m_output=None, + num_batches=1, + kernel_archive="mv.o", + func_prefix="", +): if m_output is None: m_output = m_input @@ -68,20 +79,17 @@ def my_matvec(dev, cols, M, K, m_input, m_output=None): L1_B_ty = np.ndarray[(K,), dtype_in] L1_C_ty = np.ndarray[(m_output,), dtype_out] L3_A_ty = np.ndarray[ - ( - M, - K, - ), + (num_batches * M * K,), dtype_in, ] - L3_B_ty = np.ndarray[(K,), dtype_in] - L3_C_ty = np.ndarray[(M,), dtype_out] + L3_B_ty = np.ndarray[(num_batches * K,), dtype_in] + L3_C_ty = np.ndarray[(num_batches * M,), dtype_out] func_type = "vectorized" if vectorized else "scalar" matvec = Kernel( - f"matvec_{func_type}_{dtype_in_str}_{dtype_out_str}", - "mv.o", - [np.int32, np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty], + f"{func_prefix}matvec_{func_type}_{dtype_in_str}_{dtype_out_str}", + kernel_archive, + [np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty], ) A_L3L1_fifos = [ @@ -96,7 +104,7 @@ def my_matvec(dev, cols, M, K, m_input, m_output=None): def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec): one_idx = index.constant(1) - for _ in range_(0xFFFFFFFF): + for _ in range_(0xFFFFFFFF): # batch dim handled as part of this loop b = B_L3L1_fifo.acquire(1) # The kernel function computes m output rows; each core is responsible for (M/cols) output rows, so we need to call the kernel (M/cols)/m times. for i_idx in range_(M // m_output // cols): @@ -106,7 +114,7 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec): j_i32 = index.casts(T.i32(), j_idx) output_row_offset = j_i32 * m_input a = A_L3L1_fifo.acquire(1) - matvec(m_input, K, output_row_offset, a, b, c) + matvec(m_input, output_row_offset, a, b, c) A_L3L1_fifo.release(1) C_L1L3_fifo.release(1) B_L3L1_fifo.release(1) @@ -128,66 +136,63 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec): # The input matrix in DDR is MxK-sized (row-major); each core processes (M/cols)xK-sized matrices in chunks of mxK-sized tiles. # The chunking into mxK-sized tiles happens in the ObjectFIFO; the shim puts all data on the stream in sequence. A_taps = [ - TensorAccessPattern( - tensor_dims=(M, K), - offset=col * (M // cols) * K, - sizes=[1, 1, 1, (M // cols) * K], - strides=[0, 0, 0, 1], - ) + [ + TensorAccessPattern( + tensor_dims=L3_A_ty.__args__[0], + offset=col * (M // cols) * K + batch * M * K, + sizes=[1, 1, 1, (M // cols) * K], + strides=[0, 0, 0, 1], + ) + for batch in range(num_batches) + ] for col in range(cols) ] # Every column gets the entirety of the vector B, no TAP needed. # This design assumes that all of B fits on the cores. + B_tap = TensorAccessPattern( + tensor_dims=L3_B_ty.__args__[0], + offset=0, + sizes=[1, 1, 1, num_batches * K], + strides=[0, 0, 0, 1], + ) # Collection pattern for the output vector C: each AIE core writes back its contiguous chunk of rows. C_taps = [ - TensorAccessPattern( - tensor_dims=(1, M), - offset=col * (M // cols), - sizes=[1, 1, 1, (M // cols)], - strides=[0, 0, 0, 1], - ) + [ + TensorAccessPattern( + tensor_dims=L3_C_ty.__args__[0], + offset=col * (M // cols) + batch * M, + sizes=[1, 1, 1, (M // cols)], + strides=[0, 0, 0, 1], + ) + for batch in range(num_batches) + ] for col in range(cols) ] rt = Runtime() with rt.sequence(L3_A_ty, L3_B_ty, L3_C_ty) as (A, B, C): rt.start(*workers) - tg = rt.task_group() - for i in range(cols): - rt.fill(A_L3L1_fifos[i].prod(), A, A_taps[i], task_group=tg) - rt.fill(B_L3L1_fifos[i].prod(), B, task_group=tg) - for i in range(cols): - rt.drain(C_L1L3_fifos[i].cons(), C, C_taps[i], task_group=tg, wait=True) - rt.finish_task_group(tg) + tg_b = rt.task_group() + for col in range(cols): + # Simple linear transfer of B, includes all batches in sequence + rt.fill(B_L3L1_fifos[col].prod(), B, B_tap, task_group=tg_b) + for batch in range(num_batches): + tg_ac = rt.task_group() + for col in range(cols): + rt.fill( + A_L3L1_fifos[col].prod(), A, A_taps[col][batch], task_group=tg_ac + ) + for col in range(cols): + rt.drain( + C_L1L3_fifos[col].cons(), + C, + C_taps[col][batch], + task_group=tg_ac, + wait=True, + ) + rt.finish_task_group(tg_ac) + rt.finish_task_group(tg_b) return Program(dev_ty, rt).resolve_program(SequentialPlacer()) - - -def main(): - argparser = argparse.ArgumentParser( - prog="AIE Matrix Vector Multiplication MLIR Design", - ) - argparser.add_argument("--dev", type=str, choices=["npu", "npu2"], default="npu") - argparser.add_argument("-M", type=int) - argparser.add_argument("-K", type=int) - argparser.add_argument("-m", type=int) - argparser.add_argument("--cols", type=int) - argparser.add_argument( - "--output-file-path", - "-o", - type=str, - help="Output file path for the generated MLIR module", - ) - args = argparser.parse_args() - module = my_matvec(args.dev, args.cols, args.M, args.K, args.m) - - output_file_path = Path(args.output_file_path) - - with open(output_file_path, "w") as f: - f.write(str(module)) - - -if __name__ == "__main__": - main() diff --git a/iron/operators/gemv/op.py b/iron/operators/gemv/op.py index 6ed5a9fe..45c691f0 100644 --- a/iron/operators/gemv/op.py +++ b/iron/operators/gemv/op.py @@ -7,8 +7,8 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -19,7 +19,7 @@ from iron.common.utils import torch_to_numpy -class AIEGEMV(AIEOperatorBase): +class AIEGEMV(MLIROperator): """AIE-accelerated General Matrix-Vector/Vector-Matrix Multiplication layer""" def __init__( @@ -29,8 +29,9 @@ def __init__( num_aie_columns=1, tile_size_input=2, tile_size_output=None, - is_mv=True, + num_batches=1, use_static_weight=False, + kernel_vector_size=64, context=None, ): if tile_size_output is None: @@ -40,31 +41,30 @@ def __init__( tile_size_output % tile_size_input == 0 and tile_size_output >= tile_size_input ), "tile_size_output must be a multiple of tile_size_input" - self.M = M # matrix rows (if is_mv=False, matrix columns) - self.K = K # matrix columns, vector rows (if is_mv=False, matrix rows, vector columns) + self.M = M # matrix rows + self.K = K # matrix columns, vector rows self.num_aie_columns = num_aie_columns self.tile_size_input = tile_size_input self.tile_size_output = tile_size_output - self.is_mv = is_mv - if use_static_weight: - self.weight = torch.zeros( - (M, K) if is_mv else (K, M), dtype=torch.bfloat16 - ).T # weights are stored col-major/transposed - else: - self.weight = None + self.num_batches = num_batches + self.kernel_vector_size = kernel_vector_size + assert ( + K >= kernel_vector_size and K % kernel_vector_size == 0 + ), "K must be multiple of kernel_vector_size" self.xclbin_artifact = None self.insts_artifact = None - AIEOperatorBase.__init__(self, context=context) + MLIROperator.__init__(self, context=context) + + def get_operator_name(self): + return f"gemv_{self.M}x{self.K}_{self.tile_size_input}tsi_{self.tile_size_output}tso_{self.num_batches}batch_{self.num_aie_columns}col" - def get_artifacts(self, prefix="gemv_"): - # The underlying MLIR design is a matrix-vector multiplication. We support vector-matrix multiplication by transposing the matrix beforehand (AB = C <=> B^T A^T = C^T). + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"{prefix}{self.M}x{self.K}_{self.tile_size_input}tsi_{self.tile_size_output}tso_{self.num_aie_columns}col" - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_matvec", callback_args=[ @@ -74,119 +74,30 @@ def get_artifacts(self, prefix="gemv_"): self.K, self.tile_size_input, self.tile_size_output, + self.num_batches, ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"mv.o", - depends=[ - SourceArtifact.new( - self.context.base_dir / "aie_kernels" / "generic" / "mv.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - return xclbin_artifact, insts_artifact - - def set_up_artifacts(self): - xclbin_artifact, insts_artifact = self.get_artifacts() - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # If this operator is only used as a sub-operator in another operator that sets it up, we should skip the setup here as those artifacts and buffers may not be needed. - # Runtime Setup - # --- - static_weights = None - if self.weight is not None: - # Kernel expects row-major weights, so might need to transpose (torch weights are stored in col-major); - # also might need to transpose if is_mv - if self.is_mv: - static_weights = self.weight.T - else: - # Double transpose cancels out - static_weights = self.weight - if isinstance(static_weights, torch.Tensor): - static_weights = torch_to_numpy(static_weights) - self.add_kernel( - "gemv", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_buffer("matrix", self.M * self.K, static_data=static_weights) - self.add_buffer("vector", self.K) - self.add_buffer("output", self.M) - self.add_to_runlist("gemv", "matrix", "vector", "output") - - def forward(self, vector, matrix=None): - """Forward pass through GEMV operation - - Args: - matrix: Input matrix of shape (..., M, K) - vector: Input vector of shape (..., K) for MV or (..., M) for VM - is_mv: True for matrix-vector multiplication, False for vector-matrix - - Returns: - Output vector of shape (..., M) for MV or (..., K) for VM - """ - - # Flatten batch dimensions if needed - if matrix is not None: - matrix = matrix.reshape(*matrix.shape[-2:]) - original_vector_dims = vector.ndim - vector = vector.reshape(*vector.shape[-1:]) - - # For vector-matrix, we'll transpose the matrix internally - if matrix is not None and not self.is_mv: - # Transpose the matrix for vector-matrix multiplication - # (if using static weights, the matrix is already transposed once at setup if needed) - matrix = matrix.transpose(-2, -1) - - if matrix is not None: - matrix_rows = matrix.shape[-2] - matrix_cols = matrix.shape[-1] - else: - matrix_rows = self.M - matrix_cols = self.K - - vector_size = vector.shape[-1] - - applicable = ( - matrix_cols == vector_size - and matrix_rows == self.M - and matrix_cols == self.K - and (matrix is None or matrix.dtype == torch.bfloat16) - and vector.dtype == torch.bfloat16 - ) - if not applicable: - raise AIEOperatorConstraintError( - "AIEElementwiseAdd: incompatible tensor shape(s)" - ) - - if matrix is not None: - # If matrix is none, we are using static weights that have already been written to the buffer - self.write_buffer("matrix", matrix) - self.write_buffer("vector", vector) - self.run_runlist() - result = self.read_buffer_as_torch("output", (self.M,)) - - # Add back batch dimensions if we removed them earlier. - if result.ndim < original_vector_dims: - result = result.reshape(*((1,) * (original_vector_dims - 1)), -1) - - return result + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"gemv_{self.K}k.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "generic" / "mv.cc" + ) + ], + extra_flags=[ + f"-DDIM_K={self.K}", + f"-DVEC_SIZE={self.kernel_vector_size}", + ], + ), + ] + + def get_arg_spec(self): + batch_dim = (self.num_batches,) if self.num_batches > 1 else () + return [ + AIERuntimeArgSpec("in", batch_dim + (self.M, self.K)), # matrix + AIERuntimeArgSpec("in", batch_dim + (self.K,)), # vector + AIERuntimeArgSpec("out", batch_dim + (self.M,)), # output + ] diff --git a/iron/operators/gemv/test.py b/iron/operators/gemv/test.py index 2dd4a8e6..c26fb1f4 100755 --- a/iron/operators/gemv/test.py +++ b/iron/operators/gemv/test.py @@ -12,8 +12,8 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): - params = [ +def get_params(): + params_list = [ (128, 128, 1, 32, 128), (2048, 8192, 1, 1, 2048), (8192, 2048, 1, 4, 1024), @@ -24,24 +24,16 @@ def generate_test_params(extensive=False): (2048, 8192, 8, 1, 256), (8192, 2048, 8, 4, 1024), ] - names = [ - f"matrix_vector_mul_{M}x{K}_{tile_size_input}tsi_{tile_size_output}tso_{num_aie_columns}col" - for M, K, num_aie_columns, tile_size_input, tile_size_output in params - ] - return params, names - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) + params = [] + for p in params_list: + M, K, num_aie_columns, tile_size_input, tile_size_output = p + name = f"matrix_vector_mul_{M}x{K}_{tile_size_input}tsi_{tile_size_output}tso_{num_aie_columns}col" -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + # All tests are considered regular here as per original code structure + # (original code returned same list for both regular and extensive) + params.append(pytest.param(*p, id=name)) + return params @pytest.mark.metrics( @@ -50,7 +42,7 @@ def generate_test_params(extensive=False): Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", ) @pytest.mark.parametrize( - "M,K,num_aie_columns,tile_size_input,tile_size_output", all_params + "M,K,num_aie_columns,tile_size_input,tile_size_output", get_params() ) def test_gemv(M, K, num_aie_columns, tile_size_input, tile_size_output, aie_context): golden_ref = generate_golden_reference(M=M, K=K) diff --git a/iron/operators/layer_norm/design.py b/iron/operators/layer_norm/design.py index f48bb2d2..c5f088a4 100644 --- a/iron/operators/layer_norm/design.py +++ b/iron/operators/layer_norm/design.py @@ -15,7 +15,15 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_layer_norm(dev, num_elements, num_columns, num_channels, trace_size, tile_size): +def my_layer_norm( + dev, + num_elements, + num_columns, + num_channels, + trace_size, + tile_size, + kernel_archive=None, +): per_tile_elements = 8192 if tile_size > 8192 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: diff --git a/iron/operators/layer_norm/op.py b/iron/operators/layer_norm/op.py index cc3c1aa2..b2b7a35e 100644 --- a/iron/operators/layer_norm/op.py +++ b/iron/operators/layer_norm/op.py @@ -7,8 +7,8 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -17,16 +17,19 @@ ) -class AIELayerNorm(AIEOperatorBase): +class AIELayerNorm(MLIROperator): """AIE-accelerated LAYER NORM operator""" def __init__( self, size, num_aie_columns, num_channels, tile_size, trace_size=0, context=None ): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.trace_size = trace_size self.num_aie_columns = num_aie_columns @@ -35,17 +38,15 @@ def __init__( total_shimdma_channels = self.num_aie_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"layer_norm_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"layer_norm_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_layer_norm", callback_args=[ @@ -58,62 +59,23 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"layer_norm.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "aie2p" - / "layer_norm.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "layer_norm", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("layer_norm", "input", "output") - - def forward(self, x): - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIELayerNorm: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"layer_norm.o", + dependencies=[ + SourceArtifact( + self.context.base_dir + / "aie_kernels" + / "aie2p" + / "layer_norm.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/iron/operators/layer_norm/test.py b/iron/operators/layer_norm/test.py index 2b14641c..360da0a1 100755 --- a/iron/operators/layer_norm/test.py +++ b/iron/operators/layer_norm/test.py @@ -12,11 +12,10 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): for num_channels_layer in range(1, 3): # 1 or 2 @@ -26,26 +25,22 @@ def generate_test_params(extensive=False): tile_size = 8192 check_length = tile_size * total_cores if check_length == input_length: - names.append( - f"layer_norm_{num_aie_columns}_cols_{num_channels_layer}_channels_{input_length}_tile_{tile_size}" - ) - params.append( - (input_length, num_aie_columns, num_channels_layer, tile_size) - ) - return params, names + name = f"layer_norm_{num_aie_columns}_cols_{num_channels_layer}_channels_{input_length}_tile_{tile_size}" + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels_layer, + tile_size, + id=name, + marks=marks, + ) + ) + return params @pytest.mark.metrics( @@ -54,7 +49,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_layer_norm( input_length, num_aie_columns, num_channels, tile_size, aie_context diff --git a/iron/operators/leaky_relu/design.py b/iron/operators/leaky_relu/design.py index 25cd580b..a5d5c534 100644 --- a/iron/operators/leaky_relu/design.py +++ b/iron/operators/leaky_relu/design.py @@ -14,7 +14,16 @@ from aie.iron.controlflow import range_ -def my_leaky_relu(dev, size, num_columns, num_channels, tile_size, trace_size, alpha): +def my_leaky_relu( + dev, + size, + num_columns, + num_channels, + tile_size, + trace_size, + alpha, + kernel_archive=None, +): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] diff --git a/iron/operators/leaky_relu/op.py b/iron/operators/leaky_relu/op.py index e26fc368..72fddeb7 100644 --- a/iron/operators/leaky_relu/op.py +++ b/iron/operators/leaky_relu/op.py @@ -7,8 +7,8 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -17,16 +17,19 @@ ) -class AIELeakyReLU(AIEOperatorBase): +class AIELeakyReLU(MLIROperator): """AIE-accelerated LEAKY RELU operator""" def __init__( self, size, num_aie_columns, num_channels, tile_size, alpha=0.01, context=None ): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_columns = num_aie_columns @@ -36,17 +39,15 @@ def __init__( total_shimdma_channels = self.num_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"leaky_relu_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"leaky_relu_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_leaky_relu", callback_args=[ @@ -60,62 +61,23 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"leaky_relu.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "aie2p" - / "leaky_relu.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "leaky_relu", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("leaky_relu", "input", "output") - - def forward(self, x): - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIELeakyReLU: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"leaky_relu.o", + dependencies=[ + SourceArtifact( + self.context.base_dir + / "aie_kernels" + / "aie2p" + / "leaky_relu.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/iron/operators/leaky_relu/test.py b/iron/operators/leaky_relu/test.py index cac577ad..6adb8d4d 100755 --- a/iron/operators/leaky_relu/test.py +++ b/iron/operators/leaky_relu/test.py @@ -12,24 +12,10 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): # Leaky ReLU is currently broken (#36); leave it untested params = [] - names = [] - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -38,7 +24,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size,alpha", - all_params, + get_params(), ) def test_leaky_relu( input_length, num_aie_columns, num_channels, tile_size, alpha, aie_context diff --git a/iron/operators/mem_copy/design.py b/iron/operators/mem_copy/design.py index ce807a48..73a0eca2 100644 --- a/iron/operators/mem_copy/design.py +++ b/iron/operators/mem_copy/design.py @@ -167,7 +167,16 @@ def create_partial_workload_config( # -def my_mem_copy(dev, size, num_cores, num_channels, bypass, tile_size, trace_size): +def my_mem_copy( + dev, + size, + num_cores, + num_channels, + bypass, + tile_size, + trace_size, + kernel_archive=None, +): # -------------------------------------------------------------------------- # Configuration # -------------------------------------------------------------------------- diff --git a/iron/operators/mem_copy/op.py b/iron/operators/mem_copy/op.py index c5c9f14e..08cd95c9 100644 --- a/iron/operators/mem_copy/op.py +++ b/iron/operators/mem_copy/op.py @@ -7,17 +7,18 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, + KernelArchiveArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIEMemCopy(AIEOperatorBase): +class AIEMemCopy(MLIROperator): def __init__(self, size, num_cores, num_channels, bypass, tile_size, context=None): self.size = size @@ -29,22 +30,16 @@ def __init__(self, size, num_cores, num_channels, bypass, tile_size, context=Non # For naming consistency with other operators self.bypass_str = "bypass" if bypass else "no_bypass" - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"mem_copy_{self.num_cores}_cores_{self.num_channels}_chans_tile_{self.tile_size}_{self.bypass_str}" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - size = self.tile_size * self.num_cores - - # Xclbin base name (shared) - xclbin_base_name = f"mem_copy_{self.num_cores}_cores_{self.num_channels}_chans_tile_{self.tile_size}_{self.bypass_str}" - - # Generate MLIR for xclbin (using dummy size) - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{xclbin_base_name}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_mem_copy", callback_args=[ @@ -58,67 +53,57 @@ def set_up_artifacts(self): ], ) - # Build kernel only if not bypass mode + def get_kernel_artifacts(self): if not self.bypass: - kernel_artifact = KernelObjectArtifact.new( - "mem_copy.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "generic" - / "passThrough.cc" - ) - ], - ) - xclbin_depends = [mlir_artifact, kernel_artifact] + return [ + KernelObjectArtifact( + "mem_copy.o", + dependencies=[ + SourceArtifact( + self.context.base_dir + / "aie_kernels" + / "generic" + / "passThrough.cc" + ) + ], + ) + ] else: - xclbin_depends = [mlir_artifact] - - xclbin_artifact = XclbinArtifact.new( - f"{xclbin_base_name}.xclbin", - depends=xclbin_depends, - extra_flags=["--dynamic-objFifos"], + return [] + + def get_artifacts(self): + # Override to add --dynamic-objFifos flag + operator_name = self.get_operator_name() + mlir_artifact = self.get_mlir_artifact() + kernel_deps_inputs = self.get_kernel_artifacts() + if len(kernel_deps_inputs) > 0: + mlir_artifact.callback_kwargs["kernel_archive"] = self.kernel_archive + kernel_deps = ( + [ + KernelArchiveArtifact( + self.kernel_archive, + dependencies=kernel_deps_inputs, + ) + ] + if kernel_deps_inputs + else [] ) - - insts_file_name = f"mem_copy_{self.num_cores}_cores_{self.num_channels}_chans_{self.size}_tile_{self.tile_size}_{self.bypass_str}" - insts_artifact = InstsBinArtifact.new( - f"{insts_file_name}.bin", - depends=[mlir_artifact], + xclbin_artifact = XclbinArtifact( + f"{operator_name}.xclbin", + mlir_input=mlir_artifact, + dependencies=[mlir_artifact] + kernel_deps, extra_flags=["--dynamic-objFifos"], ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "mem_copy", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, + insts_artifact = InstsBinArtifact( + f"{operator_name}.bin", + mlir_input=mlir_artifact, + dependencies=[mlir_artifact], + extra_flags=["--dynamic-objFifos"], ) - self.add_to_runlist("mem_copy", "input", "output") - - def forward(self, x): - """Forward pass for memory copy""" - if x.numel() != self.size: - raise AIEOperatorConstraintError( - f"AIEMemCopy: input size {x.numel()} does not match expected size {self.size}" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - # Execute on AIE - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) + return xclbin_artifact, insts_artifact - return result.reshape(*original_shape) + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/iron/operators/mem_copy/test.py b/iron/operators/mem_copy/test.py index afd7f540..f6314e5b 100644 --- a/iron/operators/mem_copy/test.py +++ b/iron/operators/mem_copy/test.py @@ -12,12 +12,11 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): - input_lengths = [2048] if not extensive else [1024, 2048, 4096, 8192] - bypass_modes = [False] if not extensive else [False, True] +def get_params(): + input_lengths = [1024, 2048, 4096, 8192] + bypass_modes = [False, True] params = [] - names = [] for input_length in input_lengths: for num_cores in range(1, 17): # 1 to 16 cores @@ -35,33 +34,24 @@ def generate_test_params(extensive=False): # Only proceed if tile_size * num_cores == input_length (exact division) if tile_size * num_cores == input_length: - names.append( - f"mem_copy_{num_cores}_cores_{num_channels}_chans_{input_length}_tile_{tile_size}_{str(bypass)}" - ) + name = f"mem_copy_{num_cores}_cores_{num_channels}_chans_{input_length}_tile_{tile_size}_{str(bypass)}" + + is_regular = input_length == 2048 and bypass == False + marks = [] if is_regular else [pytest.mark.extensive] + params.append( - ( + pytest.param( input_length, num_cores, num_channels, bypass, tile_size, + id=name, + marks=marks, ) ) - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -70,7 +60,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_cores,num_channels,bypass,tile_size", - all_params, + get_params(), ) def test_mem_copy( input_length, num_cores, num_channels, bypass, tile_size, aie_context diff --git a/iron/operators/mha/design.py b/iron/operators/mha/design.py index d11e4ed4..9dc33b92 100644 --- a/iron/operators/mha/design.py +++ b/iron/operators/mha/design.py @@ -115,6 +115,7 @@ def fused_mha( emulate_bf16_mmul_with_bfp16: bool, trace_size: int = 0, verbose: bool = False, + kernel_archive=None, ): of_depth = 2 @@ -205,7 +206,7 @@ def fused_mha( # AIE kernel declarations func_type = "" if vectorized else "_scalar" - bin_name = "mha_kernels.a" + bin_name = kernel_archive if kernel_archive else "mha_kernels.a" zero_kernel = Kernel(f"zero_{dtype_str}", bin_name, [qk_ty]) diff --git a/iron/operators/mha/op.py b/iron/operators/mha/op.py index 58864519..950614a8 100644 --- a/iron/operators/mha/op.py +++ b/iron/operators/mha/op.py @@ -8,8 +8,8 @@ from typing import Dict, List from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -20,7 +20,7 @@ from iron.common.utils import torch_to_numpy, numpy_to_torch -class AIEMHA(AIEOperatorBase): +class AIEMHA(MLIROperator): def __init__( self, @@ -40,20 +40,34 @@ def __init__( self.num_of_pipelines = num_of_pipelines assert d == 64, "Only d=64 is supported in this version" - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + kv_heads = self.num_KV_heads if self.num_KV_heads > 0 else self.num_heads + return f"mha_{self.num_heads}h_{kv_heads}kv_{self.seq_len}s_{self.d}d" - def set_up_artifacts(self): - # Set up compilation artifacts - # --- + def get_mlir_artifact(self): operator_dir = Path(__file__).parent + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", + import_path=operator_dir / "design.py", + callback_fn="fused_mha", + callback_kwargs={ + "heads": self.num_heads, + "S_q": self.seq_len, + "S_kv": self.seq_len, + "d": self.d, + "B_q": self.B_q, + "B_kv": self.B_kv, + "num_KV_heads": self.num_KV_heads, + "number_of_pipelines": self.num_of_pipelines, + "emulate_bf16_mmul_with_bfp16": True, + "trace_size": 0, + "verbose": False, + }, + ) - kv_heads = self.num_KV_heads if self.num_KV_heads > 0 else self.num_heads - file_name_base = f"mha_{self.num_heads}h_{kv_heads}kv_{self.seq_len}s_{self.d}d" - + def get_kernel_artifacts(self): # Define source files mm_source = str(self.context.base_dir / "aie_kernels" / "aie2p" / "mm.cc") softmax_source = str( @@ -83,105 +97,72 @@ def set_up_artifacts(self): "zero_scalar_bf16": "zero_scalar_bf16_rowmaj", } - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", - import_path=operator_dir / "design.py", - callback_fn="fused_mha", - callback_kwargs={ - "heads": self.num_heads, - "S_q": self.seq_len, - "S_kv": self.seq_len, - "d": self.d, - "B_q": self.B_q, - "B_kv": self.B_kv, - "num_KV_heads": self.num_KV_heads, - "number_of_pipelines": self.num_of_pipelines, - "emulate_bf16_mmul_with_bfp16": True, - "trace_size": 0, - "verbose": False, - }, - ) + return [ + KernelObjectArtifact( + f"mha_mm.o", + extra_flags=mm_defines_colmaj, + dependencies=[SourceArtifact(mm_source)], + ), + KernelObjectArtifact( + f"mha_mm_rowmaj.o", + extra_flags=mm_defines_rowmaj, + dependencies=[SourceArtifact(mm_source)], + rename_symbols=mm_rename_symbols, + ), + KernelObjectArtifact( + "mha_softmax.o", + dependencies=[SourceArtifact(softmax_source)], + ), + KernelObjectArtifact( + "mha_mha.o", dependencies=[SourceArtifact(mha_source)] + ), + KernelObjectArtifact( + "mha_passThrough.o", + extra_flags=["-DBIT_WIDTH=16"], + dependencies=[SourceArtifact(passthrough_source)], + ), + ] - xclbin_artifact = XclbinArtifact.new( - f"mha.xclbin", - depends=[ - mlir_artifact, - KernelArchiveArtifact.new( - f"mha_kernels.a", - depends=[ - KernelObjectArtifact.new( - f"mha_mm.o", - extra_flags=mm_defines_colmaj, - depends=[SourceArtifact.new(mm_source)], - ), - KernelObjectArtifact.new( - f"mha_mm_rowmaj.o", - extra_flags=mm_defines_rowmaj, - depends=[SourceArtifact.new(mm_source)], - rename_symbols=mm_rename_symbols, - ), - KernelObjectArtifact.new( - "mha_softmax.o", - depends=[SourceArtifact.new(softmax_source)], - ), - KernelObjectArtifact.new( - "mha_mha.o", depends=[SourceArtifact.new(mha_source)] - ), - KernelObjectArtifact.new( - "mha_passThrough.o", - extra_flags=["-DBIT_WIDTH=16"], - depends=[SourceArtifact.new(passthrough_source)], - ), - ], - ), - ], + def get_artifacts(self): + # Override to add --dynamic-objFifos flag + operator_name = self.get_operator_name() + mlir_artifact = self.get_mlir_artifact() + kernel_deps_inputs = self.get_kernel_artifacts() + if len(kernel_deps_inputs) > 0: + mlir_artifact.callback_kwargs["kernel_archive"] = self.kernel_archive + kernel_deps = ( + [ + KernelArchiveArtifact( + self.kernel_archive, + dependencies=kernel_deps_inputs, + ) + ] + if kernel_deps_inputs + else [] + ) + xclbin_artifact = XclbinArtifact( + f"{operator_name}.xclbin", + mlir_input=mlir_artifact, + dependencies=[mlir_artifact] + kernel_deps, extra_flags=["--dynamic-objFifos"], ) - - insts_artifact = InstsBinArtifact.new( - f"mha.bin", depends=[mlir_artifact], extra_flags=["--dynamic-objFifos"] + insts_artifact = InstsBinArtifact( + f"{operator_name}.bin", + mlir_input=mlir_artifact, + dependencies=[mlir_artifact], + extra_flags=["--dynamic-objFifos"], ) + return xclbin_artifact, insts_artifact - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # Set up runtime - # --- - self.add_kernel( - "mha", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_buffer( - "Q", - self.num_heads - * self.d - * self._calculate_seq_padding(self.seq_len, self.num_of_pipelines), - ) - self.add_buffer( - "K", - self.num_heads - * self.d - * self._calculate_seq_padding(self.seq_len, self.num_of_pipelines), - ) - self.add_buffer( - "V", - self.num_heads - * self.d - * self._calculate_seq_padding(self.seq_len, self.num_of_pipelines), - ) - self.add_buffer( - "O", - self.num_heads - * self.d - * self._calculate_seq_padding(self.seq_len, self.num_of_pipelines), - ) - self.add_to_runlist("mha", "Q", "K", "V", "O") + def get_arg_spec(self): + seq_padding = self._calculate_seq_padding(self.seq_len, self.num_of_pipelines) + buffer_size = self.num_heads * self.d * seq_padding + return [ + AIERuntimeArgSpec("in", (buffer_size,)), # Q + AIERuntimeArgSpec("in", (buffer_size,)), # K + AIERuntimeArgSpec("in", (buffer_size,)), # V + AIERuntimeArgSpec("out", (buffer_size,)), # O + ] def _calculate_seq_padding(self, seq_len, num_pipeline=1): return ((seq_len + 63 * num_pipeline) // (64 * num_pipeline)) * ( @@ -190,7 +171,7 @@ def _calculate_seq_padding(self, seq_len, num_pipeline=1): def _pad_to_multiple_of_64(self, tensor, seq_dim, num_pipeline=1): seq_len = tensor.shape[seq_dim] - padded_seq_len = _calculate_seq_padding(seq_len, num_pipeline) + padded_seq_len = self._calculate_seq_padding(seq_len, num_pipeline) if padded_seq_len == seq_len: return tensor @@ -219,63 +200,3 @@ def _unpack_padded_to_compact( dst = np.zeros((H, S, D), dtype=src.dtype) dst = src[:H, :S, :D] return dst - - def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): - applicable = ( - q.shape[-1] == self.d - and k.shape[-1] == self.d - and v.shape[-1] == self.d - and q.shape[-2] == self.seq_len - and k.shape[-2] == self.seq_len - and v.shape[-2] == self.seq_len - and self.seq_len % 64 == 0, # Sequence length must be multiple of 64 - ) - if not applicable: - raise AIEOperatorConstraintError( - "AIEElementwiseAdd: incompatible tensor shape(s)" - ) - - ret = self._execute_aie_operation(q, k, v) - return ret - - def _execute_aie_operation(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): - # Convert to numpy - q_np = torch_to_numpy(q) - k_np = torch_to_numpy(k) - v_np = torch_to_numpy(v) - - # Calculate padded sequence length - S_pad = self._calculate_seq_padding(self.seq_len, self.num_of_pipelines) - - # Pack compact inputs to padded format - q_padded = self._pack_compact_to_padded( - q_np, self.num_heads, self.seq_len, S_pad, self.d - ) - k_padded = self._pack_compact_to_padded( - k_np, self.num_heads, self.seq_len, S_pad, self.d - ) - v_padded = self._pack_compact_to_padded( - v_np, self.num_heads, self.seq_len, S_pad, self.d - ) - - # Write padded buffers - self.write_buffer("Q", q_padded) - self.write_buffer("K", k_padded) - self.write_buffer("V", v_padded) - - # Execute - self.run_runlist() - - # Read padded output - o_padded = self.read_buffer( - "O", shape=(self.num_heads, S_pad, self.d), dtype=bfloat16 - ) - - # Unpack padded output to compact format - o_compact = self._unpack_padded_to_compact( - o_padded, self.num_heads, self.seq_len, S_pad, self.d - ) - - # Convert back to torch with correct shape - result = numpy_to_torch(o_compact) - return result diff --git a/iron/operators/mha/test.py b/iron/operators/mha/test.py index 35c5087f..b1871e42 100755 --- a/iron/operators/mha/test.py +++ b/iron/operators/mha/test.py @@ -12,30 +12,21 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): - params = [(16384, 64, 1, 8)] +def get_params(): + params_list = [(16384, 64, 1, 8)] names = ["mha"] - return params, names - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params = [] + for p, name in zip(params_list, names): + params.append(pytest.param(*p, id=name)) + return params @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", ) -@pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines", all_params) +@pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines", get_params()) def test_mha(seq_len, dim, num_heads, num_pipelines, aie_context): golden_ref = generate_golden_reference( S_q=seq_len, diff --git a/iron/operators/relu/design.py b/iron/operators/relu/design.py index 496bb443..5c46fbb9 100644 --- a/iron/operators/relu/design.py +++ b/iron/operators/relu/design.py @@ -14,7 +14,9 @@ from aie.iron.controlflow import range_ -def my_relu(dev, size, num_columns, num_channels, tile_size, trace_size): +def my_relu( + dev, size, num_columns, num_channels, tile_size, trace_size, kernel_archive=None +): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] diff --git a/iron/operators/relu/op.py b/iron/operators/relu/op.py index 8b1f54e8..24ad44dc 100644 --- a/iron/operators/relu/op.py +++ b/iron/operators/relu/op.py @@ -7,8 +7,8 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -17,14 +17,17 @@ ) -class AIEReLU(AIEOperatorBase): +class AIEReLU(MLIROperator): """AIE-accelerated ReLU activation function""" def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_aie_columns = num_aie_columns self.num_channels = num_channels @@ -32,17 +35,15 @@ def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None) total_shimdma_channels = self.num_aie_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"relu_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"relu_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_relu", callback_args=[ @@ -55,59 +56,20 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"relu.o", - depends=[ - SourceArtifact.new( - self.context.base_dir / "aie_kernels" / "aie2p" / "relu.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "relu", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("relu", "input", "output") - - def forward(self, x): - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIEReLU: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"relu.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "aie2p" / "relu.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/iron/operators/relu/test.py b/iron/operators/relu/test.py index 3194c8c0..4bea4584 100755 --- a/iron/operators/relu/test.py +++ b/iron/operators/relu/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 1 # 1 channel for 1 input - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns @@ -26,24 +25,22 @@ def generate_test_params(extensive=False): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - names.append( - f"relu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + name = f"relu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] + + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) ) - params.append((input_length, num_aie_columns, num_channels, tile_size)) - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -52,7 +49,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_relu(input_length, num_aie_columns, num_channels, tile_size, aie_context): golden_ref = generate_golden_reference(input_length=input_length) diff --git a/iron/operators/repeat/design.py b/iron/operators/repeat/design.py new file mode 100644 index 00000000..a3539caa --- /dev/null +++ b/iron/operators/repeat/design.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np + +# from aie.extras.context import mlir_mod_ctx +# from aie.ir import StridedLayoutAttr, ShapedType +# from aie.dialects.aie import * +# from aie.dialects.aiex import * +from aie.dialects.aiex import TensorAccessPattern +from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker +from aie.iron.placers import SequentialPlacer + +""" +Repeat interleave +""" + + +def repeat(dev, dtype, rows, cols, repeat, transfer_size=None): + dtype = np.dtype[dtype] + + # Try to work around hardware size limitations by breaking transfers into smaller chunks + cols_split = 1 + if cols > 1023: + for divisor in range(2, cols + 1): + if cols % divisor == 0 and cols // divisor <= 1023: + cols_split = divisor + break + else: + raise ValueError( + f"Cannot split cols={cols} into chunks <= 1023; hardware limits cols to not exceed 1023" + ) + assert cols_split <= 1023, "cols is too large, can't split into smaller transfers" + + if transfer_size is None: + transfer_size = cols + + inp_ty = np.ndarray[ + (rows, cols), + dtype, + ] + out_ty = np.ndarray[ + (rows * repeat, cols), + dtype, + ] + transfer_ty = np.ndarray[ + (transfer_size,), + dtype, + ] + + input_tap = TensorAccessPattern( + tensor_dims=(rows, cols), + offset=0, + sizes=[repeat, rows, cols // cols_split, cols_split], + strides=[0, cols, cols_split, 1], + ) + + output_tap = TensorAccessPattern( + tensor_dims=(rows * repeat, cols), + offset=0, + sizes=[repeat, rows, cols // cols_split, cols_split], + strides=[cols, cols * repeat, cols_split, 1], + ) + + # Use smaller FIFOs for the transfer amount + fifo_in = ObjectFifo(transfer_ty, name="fifo_in", depth=2) + fifo_out = fifo_in.cons().forward(name="fifo_out", depth=2) + + rt = Runtime() + with rt.sequence(inp_ty, out_ty) as (inp, out): + tg = rt.task_group() + rt.fill(fifo_in.prod(), inp, input_tap, task_group=tg) + rt.drain(fifo_out.cons(), out, output_tap, task_group=tg, wait=True) + rt.finish_task_group(tg) + + return Program(dev, rt).resolve_program(SequentialPlacer()) diff --git a/iron/operators/repeat/op.py b/iron/operators/repeat/op.py new file mode 100644 index 00000000..b056f591 --- /dev/null +++ b/iron/operators/repeat/op.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +from ml_dtypes import bfloat16 +from pathlib import Path + +from iron.common import ( + MLIROperator, + AIERuntimeArgSpec, + KernelObjectArtifact, + SourceArtifact, + PythonGeneratedMLIRArtifact, + XclbinArtifact, + InstsBinArtifact, +) + + +class AIERepeat(MLIROperator): + """AIE-accelerated General Matrix-Vector/Vector-Matrix Multiplication layer""" + + def __init__( + self, + rows, + cols, + repeat, + transfer_size=None, + dtype=bfloat16, + context=None, + ): + self.rows = rows + self.cols = cols + self.repeat = repeat + self.transfer_size = transfer_size + self.dtype = dtype + MLIROperator.__init__(self, context=context) + + def get_operator_name(self): + name = f"repeat_{self.rows}x{self.cols}_by_{self.repeat}" + if self.transfer_size is not None: + name += f"_{self.transfer_size}ts" + return name + + def get_mlir_artifact(self): + operator_dir = Path(__file__).parent + + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", + import_path=operator_dir / "design.py", + callback_fn="repeat", + callback_args=[ + self.context.device_manager.device_type, + self.dtype, + self.rows, + self.cols, + self.repeat, + self.transfer_size, + ], + ) + + def get_kernel_artifacts(self): + return [] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.rows, self.cols)), + AIERuntimeArgSpec("out", (self.rows * self.repeat, self.cols)), + ] diff --git a/iron/operators/rms_norm/design.py b/iron/operators/rms_norm/design.py index 2bf09b43..583ca8f6 100644 --- a/iron/operators/rms_norm/design.py +++ b/iron/operators/rms_norm/design.py @@ -15,7 +15,15 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_rms_norm(dev, num_elements, num_columns, num_channels, trace_size, tile_size): +def my_rms_norm( + dev, + num_elements, + num_columns, + num_channels, + trace_size, + tile_size, + kernel_archive="rms_norm.a", +): per_tile_elements = 8192 if tile_size > 8192 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -46,7 +54,7 @@ def my_rms_norm(dev, num_elements, num_columns, num_channels, trace_size, tile_s # AIE Core Function declaration rms_norm_kernel = Kernel( - "rms_norm_bf16_vector", "rms_norm.o", [tile_ty, tile_ty, np.int32] + "rms_norm_bf16_vector", kernel_archive, [tile_ty, tile_ty, np.int32] ) # Define a task that will run on a compute tile @@ -120,93 +128,3 @@ def core_body(of_in1, of_out, rms_norm_kernel): # Place program components (assign them resources on the device) and generate an MLIR module return Program(dev, rt).resolve_program(SequentialPlacer()) - - -if __name__ == "__main__": - - def str_to_device(device: str): - if device == "npu": - return NPU1() - elif device == "npu2": - return NPU2() - else: - raise ValueError(f"Device name {device} is unknown.") - - p = argparse.ArgumentParser() - # Parse command line arguments - - # Device name is required to select the AIE device: npu or npu2 - p.add_argument( - "-d", - "--dev", - required=True, - dest="device", - help="AIE Device", - type=str_to_device, - ) - # Transfer size is required to define the size of the data to be transferred - # It must be a multiple of 1024 and divisible by the number of columns and 2 channels per column - p.add_argument("-l", "--length", required=True, dest="length", help="Transfer size") - # Number of columns is required to define the number of columns to be used - # It must be less than or equal to 4 for npu and 8 for npu2 - p.add_argument( - "-co", "--columns", required=True, dest="cols", help="Number of columns" - ) - # Number of channels is required to define the number of channels to be used - # It must be 1 or 2 - p.add_argument( - "-ch", "--channels", required=True, dest="chans", help="Number of channels" - ) - # Tile size (columns per tile) - defaults to 1024 for backward compatibility - p.add_argument( - "-ts", - "--tile-size", - required=False, - dest="tile_size", - default="1024", - help="Tile size (columns per tile)", - ) - # Trace Size - p.add_argument( - "-tr", "--trace-size", required=True, dest="trace_size", help="Trace size" - ) - p.add_argument( - "--output-file-path", - "-o", - type=str, - help="Output file path for the generated MLIR module", - ) - - opts = p.parse_args(sys.argv[1:]) - - length = int(opts.length) - columns = int(opts.cols) - dev = opts.device # Now this is already a device object! - - # Validate columns based on device type - if isinstance(dev, NPU1) and columns > 4: - raise ValueError("[ERROR] NPU device cannot allocate more than 4 columns") - elif isinstance(dev, NPU2) and columns > 8: - raise ValueError("[ERROR] NPU2 device cannot allocate more than 8 columns") - - channels = int(opts.chans) - if channels < 1 or channels > 2: - raise ValueError("Number of channels must be 1 or 2") - tile_size = int(opts.tile_size) - if ((length % tile_size) % columns % channels) != 0: - print( - "transfer size (" - + str(length) - + ") must be a multiple of " - + str(tile_size) - + " and divisible by the number of columns and 2 channels per column" - ) - raise ValueError - trace_size = int(opts.trace_size) if opts.trace_size is not None else 0 - - module = my_rms_norm(dev, length, columns, channels, trace_size, tile_size) - - output_file_path = Path(opts.output_file_path) - - with open(output_file_path, "w") as f: - f.write(str(module)) diff --git a/iron/operators/rms_norm/design_weighted.py b/iron/operators/rms_norm/design_weighted.py index 20c4fbbe..fab3caac 100644 --- a/iron/operators/rms_norm/design_weighted.py +++ b/iron/operators/rms_norm/design_weighted.py @@ -16,7 +16,14 @@ def my_weighted_rms_norm( - dev, num_elements, num_columns, num_channels, weight_length, trace_size + dev, + num_elements, + num_columns, + num_channels, + weight_length, + trace_size, + kernel_archive="rms_norm.a", + func_prefix="", ): per_tile_elements = weight_length total_cores = num_columns # For each core that does rms norm, another core will take its output to do eltwise mul @@ -53,11 +60,13 @@ def my_weighted_rms_norm( # AIE Core Function declaration rms_norm_kernel = Kernel( - "rms_norm_bf16_vector", "rms_norm_archive.a", [tile_ty, tile_ty, np.int32] + f"{func_prefix}rms_norm_bf16_vector", + kernel_archive, + [tile_ty, tile_ty, np.int32], ) eltwise_mul_kernel = Kernel( - "eltwise_mul_bf16_vector", - "rms_norm_archive.a", + f"{func_prefix}eltwise_mul_bf16_vector", + kernel_archive, [tile_ty, weights_ty, tile_ty, np.int32], ) @@ -157,96 +166,3 @@ def core_body_mul(of_in1, of_in2, of_out2, eltwise_mul): # Place program components (assign them resources on the device) and generate an MLIR module return Program(dev, rt).resolve_program(SequentialPlacer()) - - -if __name__ == "__main__": - - def str_to_device(device: str): - if device == "npu": - return NPU1() - elif device == "npu2": - return NPU2() - else: - raise ValueError(f"Device name {device} is unknown.") - - p = argparse.ArgumentParser() - # Parse command line arguments - - # Device name is required to select the AIE device: npu or npu2 - p.add_argument( - "-d", - "--dev", - required=True, - dest="device", - help="AIE Device", - type=str_to_device, - ) - # Transfer size is required to define the size of the data to be transferred - # It must be a multiple of 1024 and divisible by the number of columns and 2 channels per column - p.add_argument("-l", "--length", required=True, dest="length", help="Transfer size") - # Number of columns is required to define the number of columns to be used - # It must be less than or equal to 4 for npu and 8 for npu2 - p.add_argument( - "-co", "--columns", required=True, dest="cols", help="Number of columns" - ) - # Number of channels is required to define the number of channels to be used - # It must be 1 or 2 - p.add_argument( - "-ch", "--channels", required=True, dest="chans", help="Number of channels" - ) - # Weight length - p.add_argument( - "-wl", - "--weight-length", - required=True, - dest="weight_length", - help="Weight vector length", - ) - # Trace Size - p.add_argument( - "-ts", "--trace-size", required=True, dest="trace_size", help="Trace size" - ) - p.add_argument( - "--output-file-path", - "-o", - type=str, - help="Output file path for the generated MLIR module", - ) - - opts = p.parse_args(sys.argv[1:]) - - length = int(opts.length) - columns = int(opts.cols) - dev = opts.device # Now this is already a device object! - - # Validate columns based on device type - if isinstance(dev, NPU1) and columns > 4: - raise ValueError("[ERROR] NPU device cannot allocate more than 4 columns") - elif isinstance(dev, NPU2) and columns > 8: - raise ValueError("[ERROR] NPU2 device cannot allocate more than 8 columns") - - channels = int(opts.chans) - if channels < 1 or channels > 2: - raise ValueError("Number of channels must be 1 or 2") - weight_length = int(opts.weight_length) - # For weighted RMS norm: cores = columns (weights are broadcasted) - total_cores = columns - if (length % (weight_length * total_cores)) != 0: - print( - "transfer size (" - + str(length) - + ") must be a multiple of weight_length * total_cores (" - + str(weight_length * total_cores) - + ")" - ) - raise ValueError - trace_size = int(opts.trace_size) if opts.trace_size is not None else 0 - - module = my_weighted_rms_norm( - dev, length, columns, channels, weight_length, trace_size - ) - - output_file_path = Path(opts.output_file_path) - - with open(output_file_path, "w") as f: - f.write(str(module)) diff --git a/iron/operators/rms_norm/op.py b/iron/operators/rms_norm/op.py index 1ba38d92..5ca06e88 100644 --- a/iron/operators/rms_norm/op.py +++ b/iron/operators/rms_norm/op.py @@ -8,8 +8,8 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -20,7 +20,7 @@ from iron.common.utils import torch_to_numpy -class AIERMSNorm(AIEOperatorBase): +class AIERMSNorm(MLIROperator): """AIE-accelerated RMS Normalization layer""" def __init__( @@ -34,9 +34,12 @@ def __init__( context=None, ): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_columns = num_aie_columns @@ -44,158 +47,80 @@ def __init__( self.eps = eps self.weighted = weighted - # Initializes weights to 1. Weights have size embedding dim, which is assumed to be tile size - self.weight = nn.Parameter(torch.ones(tile_size, dtype=torch.bfloat16)) - # Enforce ShimDMA limits for weighted RMS Norm (uses 2 inputs per core) # Maximum safe configuration: 8 columns × 2 channels = 16 ShimDMA channels total_shimdma_channels = self.num_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"weighted_rms_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): - # Compilation artifacts + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"weighted_rms_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", - import_path=operator_dir / "design_weighted.py", - callback_fn="my_weighted_rms_norm", - callback_args=[ + if self.weighted: + import_path = operator_dir / "design_weighted.py" + callback_fn = "my_weighted_rms_norm" + callback_args = [ self.context.device_manager.device_type, self.size, self.num_columns, self.num_channels, self.tile_size, 0, - ], + ] + else: + import_path = operator_dir / "design.py" + callback_fn = "my_rms_norm" + callback_args = [ + self.context.device_manager.device_type, + self.size, + self.num_columns, + self.num_channels, + 0, # trace_size + self.tile_size, + ] + + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", + import_path=import_path, + callback_fn=callback_fn, + callback_args=callback_args, + callback_kwargs={ + "kernel_archive": self.kernel_archive, + }, ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelArchiveArtifact.new( - f"rms_norm_archive.a", - depends=[ - KernelObjectArtifact.new( - f"rms_norm.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "aie2p" - / "rms_norm.cc" - ) - ], - ), - KernelObjectArtifact.new( - "mul.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "generic" - / "mul.cc" - ) - ], - ), + def get_kernel_artifacts(self): + artifacts = [ + KernelObjectArtifact( + f"rms_norm.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "aie2p" / "rms_norm.cc" + ) + ], + ), + ] + if self.weighted: + artifacts.append( + KernelObjectArtifact( + "mul.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "generic" / "mul.cc" + ) ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # Runtime setup - static_weights = None - if self.weight is not None: - static_weights = torch_to_numpy(self.weight) - - self.add_buffer("input1", self.size) - self.add_buffer("input2", self.tile_size, static_data=static_weights) - self.add_buffer("output", self.size) - self.add_kernel( - "eltwise_mul", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("eltwise_mul", "input1", "input2", "output") - - def forward(self, x, y=None): - """Forward pass through RMS normalization""" - applicable = ( - len(x.shape) >= 1 and x.shape[-1] <= self.size and x.numel() <= self.size - ) - if not applicable: - raise AIEOperatorConstraintError("AIERMSNorm: incompatible tensor shape(s)") - - # Always flatten to [batch, orig_size] - original_shape = x.shape - batch = x.shape[0] if x.dim() > 1 else 1 - x_flat = x.reshape(batch, -1) - if y is not None: - y_flat = y.reshape(batch, -1) - else: - y_flat = None - - pad_len = self.size - x_flat.shape[1] - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - out = self._execute_aie_operation(x_flat, y_flat) - - # Remove padding if added - numel = np.prod(original_shape) - if pad_len > 0: - out = out.reshape(-1)[..., :numel] - # Restore original shape - out = out.reshape(*original_shape) - - return out - - def _execute_aie_operation(self, x, y=None): - """Execute RMS normalization on AIE hardware""" - # x, y are [batch, size] - batch = x.shape[0] if x.dim() > 1 else 1 - - # Flatten inputs for AIE processing - x_flat = x.view(-1) - if y is not None: - y_flat = y.view(-1) - - # Verify size matches expected - if len(x_flat) != self.size: - raise AIEOperatorConstraintError( - f"Input size x={len(x_flat)} doesn't match configured size {self.size}" + ) ) - - self.write_buffer("input1", x_flat) - if y is not None: - self.write_buffer("input2", y_flat) - else: - assert ( - self.weight is not None - ), "Weights must be provided either as input or during initialization." - test_pattern = np.zeros(len(x_flat), dtype=bfloat16) - self.write_buffer("output", test_pattern) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=x_flat.shape, dtype=bfloat16) - - return result + return artifacts + + def get_arg_spec(self): + specs = [AIERuntimeArgSpec("in", (self.size // self.tile_size, self.tile_size))] + if self.weighted: + specs.append(AIERuntimeArgSpec("in", (self.tile_size,))) + specs.append( + AIERuntimeArgSpec("out", (self.size // self.tile_size, self.tile_size)) + ) + return specs diff --git a/iron/operators/rms_norm/test.py b/iron/operators/rms_norm/test.py index e6dd012d..f7c183f9 100755 --- a/iron/operators/rms_norm/test.py +++ b/iron/operators/rms_norm/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 2 - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for weighted in [False, True]: for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): @@ -37,37 +36,26 @@ def generate_test_params(extensive=False): check_length = tile_size * num_aie_columns if check_length == input_length: if not weighted: - names.append( - f"rms_norm_{num_aie_columns}_cols_{num_channels_rms}_channels_{input_length}_tile_{tile_size}" - ) + name = f"rms_norm_{num_aie_columns}_cols_{num_channels_rms}_channels_{input_length}_tile_{tile_size}" else: - names.append( - f"weighted_rms_norm_{num_aie_columns}_cols_{num_channels_rms}_channels_{input_length}_weights_{tile_size}" - ) + name = f"weighted_rms_norm_{num_aie_columns}_cols_{num_channels_rms}_channels_{input_length}_weights_{tile_size}" + + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] + params.append( - ( + pytest.param( input_length, num_aie_columns, num_channels_rms, tile_size, weighted, + id=name, + marks=marks, ) ) - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -76,7 +64,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size,weighted", - all_params, + get_params(), ) def test_rms_norm( input_length, num_aie_columns, num_channels, tile_size, weighted, aie_context @@ -97,6 +85,7 @@ def test_rms_norm( input_buffers = {"input1": golden_ref["input"]} if weighted: operator.weight = golden_ref["weight"] + input_buffers["weight"] = golden_ref["weight"] output_buffers = {"output": golden_ref["output"]} errors, latency_us, bandwidth_gbps = run_test( diff --git a/iron/operators/rope/design.py b/iron/operators/rope/design.py index f1082bdd..f486071d 100644 --- a/iron/operators/rope/design.py +++ b/iron/operators/rope/design.py @@ -37,11 +37,17 @@ def rope( num_aie_columns=1, trace_size=0, method_type=None, + kernel_archive=None, + func_prefix="", ): dtype = bfloat16 if angle_rows is None: angle_rows = rows + if kernel_archive is None: + kernel_archive = ( + "rope" + (f"_{method_type}" if method_type is not None else "") + ".o" + ) assert cols % (16 * 2) == 0 and cols >= ( 16 * 2 @@ -73,8 +79,8 @@ def rope( # AIE Core Function declaration rope_kernel = Kernel( - "rope", - "rope" + (f"_{method_type}" if method_type is not None else "") + ".o", + f"{func_prefix}rope", + kernel_archive, [tensor_tile_ty, angle_tile_ty, tensor_tile_ty, np.int32], ) @@ -127,7 +133,7 @@ def core_body(of_in, of_lut, of_out, rope_kernel): # Runtime operations to move data to/from the AIE-array rt = Runtime() - with rt.sequence(tensor_ty, tensor_ty, tensor_ty) as (A, B, C): + with rt.sequence(tensor_ty, angle_ty, tensor_ty) as (A, B, C): rt.start(*my_workers) # Initialize a group for parallel drain tasks, with fill resources free'd when drains complete. diff --git a/iron/operators/rope/op.py b/iron/operators/rope/op.py index be8e7f95..fa6f1e6a 100644 --- a/iron/operators/rope/op.py +++ b/iron/operators/rope/op.py @@ -1,38 +1,43 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import torch -import numpy as np -from ml_dtypes import bfloat16 from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, - KernelArchiveArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIERope(AIEOperatorBase): +class AIERope(MLIROperator): def __init__( self, rows: int, cols: int, angle_rows=None, - num_aie_columns=None, + num_aie_columns=1, method_type=0, context=None, ): if angle_rows is None: angle_rows = rows - if num_aie_columns is None: - num_aie_columns = 1 + + assert cols % (16 * 2) == 0 and cols >= ( + 16 * 2 + ), "cols must be multiple of 32 and >= 32" + assert rows % num_aie_columns == 0, "rows must be divisible by num_aie_columns" + assert ( + angle_rows <= rows and rows % angle_rows == 0 + ), "angle_rows must divide rows" + assert ( + angle_rows >= num_aie_columns and angle_rows % num_aie_columns == 0 + ), "angle_rows must be divisible by num_aie_columns" self.rows = rows self.cols = cols @@ -41,19 +46,15 @@ def __init__( self.method_type = method_type assert method_type in {0, 1} - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"rope_{self.num_aie_columns}col_{self.rows}rows_{self.cols}cols_{self.angle_rows}arows_{self.method_type}m" - def set_up_artifacts(self): - # Compilation artifacts + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"rope_{self.num_aie_columns}c_{self.rows}rows_{self.cols}cols_{self.angle_rows}arows_{self.method_type}m" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="rope", callback_args=[ @@ -67,68 +68,42 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"rope_{self.method_type}.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "generic" - / "rope.cc" - ) - ], - extra_flags=[ - "-DTWO_HALVES" if 0 == self.method_type else "-DINTERLEAVED" - ], + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"rope_{self.method_type}.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "generic" / "rope.cc" + ) + ], + extra_flags=[ + "-DTWO_HALVES" if 0 == self.method_type else "-DINTERLEAVED" + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec( + "in", + ( + self.rows, + self.cols, ), - ], - ) - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # Runtime setup - self.add_buffer("in", self.rows * self.cols) - self.add_buffer("angles", self.angle_rows * self.cols) - self.add_buffer("output", self.rows * self.cols) - self.add_kernel( - "rope", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("rope", "in", "angles", "output") - - def forward(self, tensor, angles): - applicable = ( - tensor.shape[-2] == self.rows - and tensor.shape[-1] == self.cols - and tensor.shape[-1] % 16 == 0 - and angles.shape[-2] == self.angle_rows - and angles.shape[-1] == self.cols - ) - if not applicable: - raise AIEOperatorConstraintError("AIERope: incompatible tensor shape(s)") - - # Write data to buffers - self.write_buffer("in", tensor) - self.write_buffer("angles", angles) - - # Execute kernel - self.run_runlist() - - # Read output - result = self.read_buffer_as_torch("output", shape=tensor.shape, dtype=bfloat16) - - return result + ), # input tensor + AIERuntimeArgSpec( + "in", + ( + self.angle_rows, + self.cols, + ), + ), # angles + AIERuntimeArgSpec( + "out", + ( + self.rows, + self.cols, + ), + ), # output + ] diff --git a/iron/operators/rope/test.py b/iron/operators/rope/test.py index 095a8cc3..d7156a7b 100755 --- a/iron/operators/rope/test.py +++ b/iron/operators/rope/test.py @@ -12,55 +12,49 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): - params = [] - names = [] - +def get_params(): num_aie_columns_options = [1, 2, 8] - if not extensive: - input_rows = [32] - input_cols = [512] - input_angle_rows = [8, 32] - method_types = [0] # 0: Two-halves method - else: - input_rows = [32, 64] - input_cols = [128] - input_angle_rows = [8, 16, 32] - method_types = [0, 1] # 0: Two-halves method, 1: interleaved method + # Combine all options + input_rows = [32, 64] + input_cols = [128, 512] + input_angle_rows = [8, 16, 32] + method_types = [0, 1] # 0: Two-halves method, 1: interleaved method + params = [] for num_aie_columns in num_aie_columns_options: for n_rows in input_rows: for n_angle_rows in input_angle_rows: for n_cols in input_cols: for method_type in method_types: - names.append( - f"rope_{num_aie_columns}c_{n_rows}rows_{n_cols}cols_{n_angle_rows}arows_{method_type}m" + name = f"rope_{num_aie_columns}c_{n_rows}rows_{n_cols}cols_{n_angle_rows}arows_{method_type}m" + + is_regular = ( + n_rows == 32 + and n_cols == 512 + and n_angle_rows in [8, 32] + and method_type == 0 ) + + is_extensive_valid = n_cols == 128 + + if not is_regular and not is_extensive_valid: + continue + + marks = [] if is_regular else [pytest.mark.extensive] + params.append( - ( + pytest.param( n_rows, n_cols, n_angle_rows, num_aie_columns, method_type, + id=name, + marks=marks, ) ) - - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -69,7 +63,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "rows,cols,angle_rows,aie_columns,method_type", - all_params, + get_params(), ) def test_rope(rows, cols, angle_rows, aie_columns, method_type, aie_context): golden_ref = generate_golden_reference( @@ -97,12 +91,7 @@ def test_rope(rows, cols, angle_rows, aie_columns, method_type, aie_context): operator, input_buffers, output_buffers, rel_tol=0.05, abs_tol=0.5 ) - print(golden_ref["C"]) - print( - operator.read_buffer_as_torch("output", (rows // angle_rows, angle_rows, cols)) - ) - print(f"\nLatency (us): {latency_us:.1f}") print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") - # assert not errors, f"Test failed with errors: {errors}" + assert not errors, f"Test failed with errors: {errors}" diff --git a/iron/operators/sigmoid/design.py b/iron/operators/sigmoid/design.py index 49d33502..927f9432 100644 --- a/iron/operators/sigmoid/design.py +++ b/iron/operators/sigmoid/design.py @@ -14,7 +14,9 @@ from aie.iron.controlflow import range_ -def my_sigmoid(dev, size, num_columns, num_channels, tile_size, trace_size): +def my_sigmoid( + dev, size, num_columns, num_channels, tile_size, trace_size, kernel_archive=None +): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] diff --git a/iron/operators/sigmoid/op.py b/iron/operators/sigmoid/op.py index a24d051d..0135800b 100644 --- a/iron/operators/sigmoid/op.py +++ b/iron/operators/sigmoid/op.py @@ -7,8 +7,8 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -17,14 +17,17 @@ ) -class AIESigmoid(AIEOperatorBase): +class AIESigmoid(MLIROperator): """AIE-accelerated Sigmoid activation function""" def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_columns = num_aie_columns @@ -33,17 +36,15 @@ def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None) total_shimdma_channels = self.num_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"sigmoid_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"sigmoid_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_sigmoid", callback_args=[ @@ -56,62 +57,20 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"sigmoid.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "aie2p" - / "sigmoid.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "sigmoid", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("sigmoid", "input", "output") - - def forward(self, x): - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIESigmoid: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"sigmoid.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "aie2p" / "sigmoid.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/iron/operators/sigmoid/test.py b/iron/operators/sigmoid/test.py index 1dc5b99d..641fca96 100755 --- a/iron/operators/sigmoid/test.py +++ b/iron/operators/sigmoid/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 1 # 1 channel for 1 input - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns @@ -26,24 +25,22 @@ def generate_test_params(extensive=False): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - names.append( - f"sigmoid_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + name = f"sigmoid_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] + + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) ) - params.append((input_length, num_aie_columns, num_channels, tile_size)) - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -52,7 +49,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_sigmoid(input_length, num_aie_columns, num_channels, tile_size, aie_context): golden_ref = generate_golden_reference(input_length=input_length) diff --git a/iron/operators/silu/design.py b/iron/operators/silu/design.py index 5968943b..4c041afb 100644 --- a/iron/operators/silu/design.py +++ b/iron/operators/silu/design.py @@ -12,15 +12,19 @@ from aie.iron.device import Tile, NPU1, NPU2 from aie.helpers.taplib.tap import TensorAccessPattern from aie.iron.controlflow import range_ +from aie.helpers.util import np_ndarray_type_get_shape -def my_silu(dev, size, num_columns, num_channels, tile_size, trace_size): +def my_silu( + dev, size, num_columns, tile_size, trace_size, kernel_archive, func_prefix="" +): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] transfer_type = np.ndarray[(size,), np.dtype[xfr_dtype]] - # Calculate number of iterations per core + # Calculate number of iterations per core (using 1 channel per column) + num_channels = 1 total_cores = num_columns * num_channels per_core_elements = size // total_cores N_div_n = per_core_elements // line_size @@ -42,8 +46,8 @@ def my_silu(dev, size, num_columns, num_channels, tile_size, trace_size): # External, binary kernel definition silu_fcn = Kernel( - "silu_bf16", - "silu.o", + f"{func_prefix}silu_bf16", + kernel_archive, [line_type, line_type, np.int32], ) @@ -152,11 +156,6 @@ def str_to_device(device: str): p.add_argument( "-co", "--columns", required=True, dest="cols", help="Number of columns" ) - # Number of channels is required to define the number of channels to be used - # It must be 1 or 2 - p.add_argument( - "-ch", "--channels", required=True, dest="chans", help="Number of channels" - ) # Tile size (elements per tile) - defaults to 1024 for backward compatibility p.add_argument( "-ts", @@ -189,11 +188,10 @@ def str_to_device(device: str): elif isinstance(dev, NPU2) and columns > 8: raise ValueError("[ERROR] NPU2 device cannot allocate more than 8 columns") - channels = int(opts.chans) - if channels < 1 or channels > 2: - raise ValueError("Number of channels must be 1 or 2") tile_size = int(opts.tile_size) - if ((length % tile_size) % columns % channels) != 0: + # Using 1 channel per column for SiLU + num_channels = 1 + if ((length % tile_size) % columns % num_channels) != 0: print( "transfer size (" + str(length) @@ -204,7 +202,7 @@ def str_to_device(device: str): raise ValueError trace_size = opts.trace_size - module = my_silu(dev, length, columns, channels, tile_size, trace_size) + module = my_silu(dev, length, columns, tile_size, trace_size, "silu.o") output_file_path = Path(opts.output_file_path) diff --git a/iron/operators/silu/op.py b/iron/operators/silu/op.py index 3583868c..1fe853f2 100644 --- a/iron/operators/silu/op.py +++ b/iron/operators/silu/op.py @@ -1,155 +1,68 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import torch -import numpy as np -from ml_dtypes import bfloat16 from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, - KernelArchiveArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIESiLU(AIEOperatorBase): +class AIESiLU(MLIROperator): """AIE-accelerated SiLU activation function""" - def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): - max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + def __init__(self, size, tile_size, num_aie_columns=8, context=None): + assert ( + size % (num_aie_columns * tile_size) == 0 + ), "size must be multiple of num_aie_columns * tile_size" + self.size = size self.tile_size = tile_size - - self.num_columns = num_aie_columns - self.num_channels = num_channels + self.num_aie_columns = num_aie_columns # Enforce ShimDMA limits for SiLU (uses 1 input per core) - # Maximum safe configuration: 8 columns × 2 channels = 16 ShimDMA channels - total_shimdma_channels = self.num_columns * self.num_channels + # Maximum safe configuration: 8 columns × 1 channel = 8 ShimDMA channels + total_shimdma_channels = self.num_aie_columns * 1 assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" + MLIROperator.__init__(self, context=context) - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None - - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"silu_{self.num_aie_columns}col_{self.size}_{self.tile_size}t" - def get_artifacts(self, prefix="silu_"): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"{prefix}{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_silu", callback_args=[ self.context.device_manager.device_type, self.size, - self.num_columns, - self.num_channels, + self.num_aie_columns, self.tile_size, 0, ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"silu.o", - depends=[ - SourceArtifact.new( - self.context.base_dir / "aie_kernels" / "aie2p" / "silu.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - return xclbin_artifact, insts_artifact - - def set_up_artifacts(self): - # If this operator is only used as a sub-operator in another operator that sets it up, we should skip the setup here as those artifacts and buffers may not be needed. - # Compilation artifacts - xclbin_artifact, insts_artifact = self.get_artifacts() - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # If this operator is only used as a sub-operator in another operator that sets it up, we should skip the setup here as those artifacts and buffers may not be needed. + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"silu.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "aie2p" / "silu.cc" + ) + ], + ), + ] + + def get_arg_spec(self): # Runtime setup - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "silu", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("silu", "input", "output") - - def forward(self, x): - """Forward pass for SiLU activation""" - applicable = ( - len(x.shape) >= 1 and x.shape[-1] <= self.size and x.numel() <= self.size - ) - if not applicable: - raise AIEOperatorConstraintError("AIESiLU: incompatible tensor shape(s)") - - # Always flatten to [batch, orig_size] - original_shape = x.shape - batch = x.shape[0] if x.dim() > 1 else 1 - x_flat = x.reshape(batch, -1) - - pad_len = self.size - x_flat.shape[1] - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - out = self._execute_aie_operation(x_flat) - - # Remove padding if added - numel = np.prod(original_shape) - if pad_len > 0: - out = out.reshape(-1)[..., :numel] - # Restore original shape - out = out.reshape(*original_shape) - - return out - - def _execute_aie_operation(self, x, y=None): - """Execute SiLU operation on AIE hardware""" - # x is [batch, size] - batch = x.shape[0] if x.dim() > 1 else 1 - - # Flatten inputs for AIE processing - x_flat = x.view(-1) - - # Verify size matches expected - if len(x_flat) != self.size: - raise AIEOperatorConstraintError( - f"Input size x={len(x_flat)} doesn't match configured size {self.size}" - ) - - self.write_buffer("input", x_flat) - test_pattern = np.zeros(len(x_flat), dtype=bfloat16) - self.write_buffer("output", test_pattern) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=x_flat.shape, dtype=bfloat16) - - return result + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/iron/operators/silu/test.py b/iron/operators/silu/test.py index 4dc52ba0..6eb22f20 100755 --- a/iron/operators/silu/test.py +++ b/iron/operators/silu/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 1 # 1 channel for 1 input - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns @@ -26,24 +25,22 @@ def generate_test_params(extensive=False): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - names.append( - f"silu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + name = f"silu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] + + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) ) - params.append((input_length, num_aie_columns, num_channels, tile_size)) - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -52,7 +49,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_silu(input_length, num_aie_columns, num_channels, tile_size, aie_context): golden_ref = generate_golden_reference(input_length=input_length) @@ -60,7 +57,6 @@ def test_silu(input_length, num_aie_columns, num_channels, tile_size, aie_contex operator = AIESiLU( size=input_length, num_aie_columns=num_aie_columns, - num_channels=num_channels, tile_size=tile_size, context=aie_context, ) diff --git a/iron/operators/softmax/design.py b/iron/operators/softmax/design.py index 981312be..567dbbc6 100644 --- a/iron/operators/softmax/design.py +++ b/iron/operators/softmax/design.py @@ -7,7 +7,15 @@ import argparse import sys -from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker +from aie.iron import ( + Kernel, + ObjectFifo, + Program, + Runtime, + Worker, + Buffer, + WorkerRuntimeBarrier, +) from aie.iron.placers import SequentialPlacer from aie.iron.device import NPU1, NPU2 from aie.helpers.taplib.tap import TensorAccessPattern @@ -15,15 +23,28 @@ from ml_dtypes import bfloat16 -def softmax(dev, num_elements, num_columns, num_channels, trace_size, tile_size): +def softmax( + dev, + num_elements, + num_aie_columns, + num_channels, + trace_size, + tile_size, + rtp_vector_size=None, + mask_patch_value=0, + kernel_archive="softmax.a", + func_prefix="", +): per_tile_elements = tile_size - n = per_tile_elements * num_columns + if rtp_vector_size is None: + rtp_vector_size = per_tile_elements + n = per_tile_elements * num_aie_columns if num_elements % n != 0: raise ValueError( f"Number of elements ({num_elements}) must be a multiple of {n}." ) N_div_n = num_elements // n - chunk = num_elements // num_columns // num_channels # For offset calculation + chunk = num_elements // num_aie_columns // num_channels # For offset calculation dtype = bfloat16 # Define tensor types @@ -33,28 +54,52 @@ def softmax(dev, num_elements, num_columns, num_channels, trace_size, tile_size) # AIE-array data movement with object fifos of_in1s = [ ObjectFifo(tile_ty, name=f"in1_{i}_{j}") - for i in range(num_columns) + for i in range(num_aie_columns) for j in range(num_channels) ] of_outs = [ ObjectFifo(tile_ty, name=f"out_{i}_{j}") - for i in range(num_columns) + for i in range(num_aie_columns) for j in range(num_channels) ] # AIE Core Function declaration - softmax_kernel = Kernel("softmax_bf16", "softmax.o", [tile_ty, tile_ty, np.int32]) + softmax_kernel = Kernel( + f"{func_prefix}softmax_bf16", kernel_archive, [tile_ty, tile_ty, np.int32] + ) + mask_kernel = Kernel( + f"{func_prefix}mask_bf16", kernel_archive, [tile_ty, np.int32, np.int32] + ) # Define a task that will run on a compute tile - def core_body(of_in1, of_out, softmax_kernel): + def core_body(of_in1, of_out, softmax_kernel, mask_kernel, rtp, barrier): # Number of sub-vector "tile" iterations + barrier.wait_for_value(1) + vector_size = rtp[0] for _ in range_(N_div_n): elem_in1 = of_in1.acquire(1) elem_out = of_out.acquire(1) + mask_kernel(elem_in1, vector_size, per_tile_elements) softmax_kernel(elem_in1, elem_out, per_tile_elements) of_in1.release(1) of_out.release(1) + rtps = [ + Buffer( + np.ndarray[(1,), np.dtype[np.int32]], + name=f"rtp_{i}_{j}", + use_write_rtp=True, + ) + for i in range(num_aie_columns) + for j in range(num_channels) + ] + + barriers = [ + WorkerRuntimeBarrier() + for i in range(num_aie_columns) + for j in range(num_channels) + ] + # Create a worker to run the task on a compute tile my_workers = [ Worker( @@ -63,9 +108,12 @@ def core_body(of_in1, of_out, softmax_kernel): of_in1s[i * num_channels + j].cons(), of_outs[i * num_channels + j].prod(), softmax_kernel, + mask_kernel, + rtps[i * num_channels + j], + barriers[i * num_channels + j], ], ) - for i in range(num_columns) + for i in range(num_aie_columns) for j in range(num_channels) ] @@ -81,7 +129,7 @@ def core_body(of_in1, of_out, softmax_kernel): [1, 1, 1, chunk], [0, 0, 0, 1], ) - for i in range(num_columns) + for i in range(num_aie_columns) for j in range(num_channels) ] @@ -90,11 +138,21 @@ def core_body(of_in1, of_out, softmax_kernel): with rt.sequence(tensor_ty, tensor_ty) as (A, C): rt.start(*my_workers) + # Set run-time parameter for actual vector size (remainder is considered padding and ignored by the computation) + def set_rtps(*args): + for rtp in args: + rtp[0] = rtp_vector_size if not mask_patch_value else mask_patch_value + + rt.inline_ops(set_rtps, rtps) + + for i in range(num_aie_columns * num_channels): + rt.set_barrier(barriers[i], 1) + # Initialize a group for parallel drain tasks, with fill resources free'd when drains complete. tg = rt.task_group() # Fill the input objectFIFOs with data - for i in range(num_columns): + for i in range(num_aie_columns): for j in range(num_channels): rt.fill( of_in1s[i * num_channels + j].prod(), @@ -103,7 +161,7 @@ def core_body(of_in1, of_out, softmax_kernel): task_group=tg, ) # Drain the output objectFIFOs with data - for i in range(num_columns): + for i in range(num_aie_columns): for j in range(num_channels): rt.drain( of_outs[i * num_channels + j].cons(), diff --git a/iron/operators/softmax/op.py b/iron/operators/softmax/op.py index 106f0415..2beb0627 100644 --- a/iron/operators/softmax/op.py +++ b/iron/operators/softmax/op.py @@ -2,133 +2,86 @@ # SPDX-License-Identifier: Apache-2.0 import torch -import numpy as np -from ml_dtypes import bfloat16 from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, - KernelArchiveArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIESoftmax(AIEOperatorBase): +class AIESoftmax(MLIROperator): + """AIE-accelerated Softmax operation""" def __init__( - self, rows: int, cols: int, num_aie_columns=1, num_channels=1, context=None + self, + rows: int, + cols: int, + num_aie_columns=1, + num_channels=1, + rtp_vector_size=None, + mask_patch_value=0, + context=None, ): - self.size = rows * cols + assert rows % 16 == 0, "rows must be multiple of 16" + assert cols % 16 == 0, "cols must be multiple of 16" + assert (rows * cols) % ( + num_aie_columns * cols + ) == 0, "size must be multiple of num_aie_columns * tile_size" + self.rows = rows self.cols = cols - + self.size = rows * cols + self.num_aie_columns = num_aie_columns self.num_channels = num_channels - self.num_columns = num_aie_columns + self.rtp_vector_size = rtp_vector_size + self.mask_patch_value = mask_patch_value - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + name = f"softmax_{self.num_aie_columns}col_{self.num_channels}ch_{self.size}_{self.cols}t" + if self.rtp_vector_size is not None: + name += f"_{self.rtp_vector_size}rtp" + return name - def set_up_artifacts(self): - # Compilation artifacts + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"softmax_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.cols}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="softmax", callback_args=[ self.context.device_manager.device_type, - self.rows * self.cols, - self.num_columns, + self.size, + self.num_aie_columns, self.num_channels, - 0, + 0, # trace_size self.cols, + self.rtp_vector_size, + self.mask_patch_value, ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"softmax.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "aie2p" - / "softmax.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"gemm_{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # Runlist setup - self.add_buffer("in", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "softmax", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("softmax", "in", "output") - - def forward(self, x): - applicable = ( - x.shape[-1] * x.shape[-2] == self.size - and x.shape[-1] == self.cols - and x.shape[-1] % 16 == 0 - and x.shape[-2] % 16 == 0 - ) - if not applicable: - raise AIEOperatorConstraintError("AIESoftmax: incompatible tensor shape(s)") - - return self._execute_aie_operation(x) - - def _execute_aie_operation(self, x): - original_shape = x.shape - - # Reshape for processing - # Split x into a list of H tensors of size [S_q, S_kv] - heads = x.shape[1] - x_list = [x[0, h, :, :] for h in range(heads)] - results = [] - for i in range(heads): - x_iter = x_list[i] - input_size = x_iter.nbytes - self.write_buffer("in", x_iter) - test_pattern = np.zeros(len(x_iter), dtype=bfloat16) - self.write_buffer("output", test_pattern) - self.run_runlist() - result = self.read_buffer_as_torch( - "output", shape=x_list[i].shape, dtype=bfloat16 - ) - results.append(result) - - result = torch.stack(results, dim=0).unsqueeze( - 0 - ) # Shape: (1, heads, S_q, S_kv) - - return result + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"softmax.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "aie2p" / "softmax.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), + AIERuntimeArgSpec("out", (self.size,)), + ] diff --git a/iron/operators/softmax/test.py b/iron/operators/softmax/test.py index 1ad613d9..093610a7 100755 --- a/iron/operators/softmax/test.py +++ b/iron/operators/softmax/test.py @@ -30,37 +30,27 @@ def get_optimal_columns_channels(input_length, tile_size): return 2, 2 # Default fallback -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 2 - input_lengths = [4096] if not extensive else [] + input_lengths = [32768] tile_sizes = [1024, 512, 2048] params = [] - names = [] for input_length in input_lengths: for tile_size in tile_sizes: optimal_columns, optimal_channels = get_optimal_columns_channels( input_length, tile_size ) - names.append( - f"softmax_{optimal_columns}_cols_{optimal_channels}_channels_{input_length}_tile_{tile_size}" - ) - params.append((input_length, optimal_columns, optimal_channels, tile_size)) - return params, names - + name = f"softmax_{optimal_columns}_cols_{optimal_channels}_channels_{input_length}_tile_{tile_size}" -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + # All tests are regular as extensive list was empty in original code + params.append( + pytest.param( + input_length, optimal_columns, optimal_channels, tile_size, id=name + ) + ) + return params @pytest.mark.metrics( @@ -69,7 +59,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_softmax(input_length, num_aie_columns, num_channels, tile_size, aie_context): diff --git a/iron/operators/strided_copy/design.py b/iron/operators/strided_copy/design.py new file mode 100644 index 00000000..63b97e33 --- /dev/null +++ b/iron/operators/strided_copy/design.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np + +# from aie.extras.context import mlir_mod_ctx +# from aie.ir import StridedLayoutAttr, ShapedType +# from aie.dialects.aie import * +# from aie.dialects.aiex import * +from aie.dialects.aiex import TensorAccessPattern +from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker +from aie.iron.placers import SequentialPlacer + +""" +Strided copy design + +This can be useful for data layout manipulation and data copying such as: +input[0, :, 0] -> output[:, 0, 0] +""" + + +def strided_copy( + dev, + dtype, + input_buffer_size, + input_sizes, + input_strides, + input_offset, + output_buffer_size, + output_sizes, + output_strides, + output_offset, + transfer_size=None, + num_aie_channels=1, + input_offset_patch_marker=0, + output_offset_patch_marker=0, +): + assert len(input_sizes) == len(input_strides) + assert len(output_sizes) == len(output_strides) + + # Pad out dimensions to 4D; dropping leading dimensions leads to compiler not initializing these registers, causing hard-to-debug errors + input_sizes = [1] * (4 - len(input_sizes)) + list(input_sizes) + input_strides = [0] * (4 - len(input_strides)) + list(input_strides) + output_sizes = [1] * (4 - len(output_sizes)) + list(output_sizes) + output_strides = [0] * (4 - len(output_strides)) + list(output_strides) + + input_highest_sz_idx = max(idx for idx, sz in enumerate(input_sizes) if sz >= 1) + output_highest_sz_idx = max(idx for idx, sz in enumerate(output_sizes) if sz >= 1) + assert ( + input_sizes[input_highest_sz_idx] % num_aie_channels == 0 + ), "Highest dimension of input_sizes must be divisible by num_aie_channels" + assert ( + output_sizes[output_highest_sz_idx] % num_aie_channels == 0 + ), "Highest dimension of output_sizes must be divisible by num_aie_channels" + + if transfer_size is None: + transfer_size = int(np.prod(input_sizes)) + assert np.prod(input_sizes) % transfer_size == 0 + transfer_ty = np.ndarray[ + (transfer_size,), + np.dtype[dtype], + ] + + inp_ty = np.ndarray[ + (int(input_buffer_size),), + np.dtype[dtype], + ] + out_ty = np.ndarray[ + (int(output_buffer_size),), + np.dtype[dtype], + ] + + input_taps = [ + TensorAccessPattern( + tensor_dims=(int(input_buffer_size + input_offset_patch_marker),), + offset=( + input_offset_patch_marker + if input_offset_patch_marker != 0 + else input_offset + + c + * (input_sizes[input_highest_sz_idx] // num_aie_channels) + * input_strides[input_highest_sz_idx] + ), + sizes=( + input_sizes[:input_highest_sz_idx] + + [input_sizes[input_highest_sz_idx] // num_aie_channels] + + input_sizes[input_highest_sz_idx + 1 :] + ), + strides=list(input_strides), + ) + for c in range(num_aie_channels) + ] + + output_taps = [ + TensorAccessPattern( + tensor_dims=(int(output_buffer_size + output_offset_patch_marker),), + offset=( + output_offset_patch_marker + if output_offset_patch_marker != 0 + else output_offset + + c + * (output_sizes[output_highest_sz_idx] // num_aie_channels) + * output_strides[output_highest_sz_idx] + ), + sizes=( + output_sizes[:output_highest_sz_idx] + + [output_sizes[output_highest_sz_idx] // num_aie_channels] + + output_sizes[output_highest_sz_idx + 1 :] + ), + strides=list(output_strides), + ) + for c in range(num_aie_channels) + ] + + # Use smaller FIFOs for the transfer amount + fifos_in = [ + ObjectFifo(transfer_ty, name=f"fifo_in_{c}", depth=1) + for c in range(num_aie_channels) + ] + fifos_out = [ + fifos_in[c].cons().forward(name=f"fifo_out_{c}", depth=1) + for c in range(num_aie_channels) + ] + + rt = Runtime() + with rt.sequence(inp_ty, out_ty) as (inp, out): + tg = rt.task_group() + for c in range(num_aie_channels): + rt.fill(fifos_in[c].prod(), inp, input_taps[c], task_group=tg) + rt.drain(fifos_out[c].cons(), out, output_taps[c], task_group=tg, wait=True) + rt.finish_task_group(tg) + + return Program(dev, rt).resolve_program(SequentialPlacer()) diff --git a/iron/operators/strided_copy/op.py b/iron/operators/strided_copy/op.py new file mode 100644 index 00000000..5996a90d --- /dev/null +++ b/iron/operators/strided_copy/op.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +from ml_dtypes import bfloat16 +from pathlib import Path + +from iron.common import ( + MLIROperator, + AIERuntimeArgSpec, + KernelObjectArtifact, + SourceArtifact, + PythonGeneratedMLIRArtifact, + XclbinArtifact, + InstsBinArtifact, +) + + +class AIEStridedCopy(MLIROperator): + """AIE-accelerated General Matrix-Vector/Vector-Matrix Multiplication layer""" + + def __init__( + self, + input_sizes, + input_strides, + input_offset, + output_sizes, + output_strides, + output_offset, + input_buffer_size, + output_buffer_size, + dtype=bfloat16, + transfer_size=None, + num_aie_channels=1, + context=None, + **kwargs, + ): + assert len(input_sizes) == len(input_strides) + assert len(output_sizes) == len(output_strides) + self.input_sizes = input_sizes + self.input_strides = input_strides + self.input_offset = input_offset + self.output_sizes = output_sizes + self.output_strides = output_strides + self.output_offset = output_offset + self.input_buffer_size = input_buffer_size + self.output_buffer_size = output_buffer_size + self.dtype = dtype + self.transfer_size = transfer_size + self.num_aie_channels = num_aie_channels + self.kwargs = kwargs + MLIROperator.__init__(self, context=context) + + def get_operator_name(self): + return f"strided_copy_{'x'.join(map(str, self.input_sizes))}sz_{'x'.join(map(str, self.input_strides))}st_{self.input_offset}off_to_{'x'.join(map(str, self.output_sizes))}sz_{'x'.join(map(str, self.output_strides))}st_{self.output_offset}off_{self.transfer_size if self.transfer_size is not None else 'auto'}tr_{self.num_aie_channels}ch" + + def get_mlir_artifact(self): + operator_dir = Path(__file__).parent + + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", + import_path=operator_dir / "design.py", + callback_fn="strided_copy", + callback_args=[ + self.context.device_manager.device_type, + self.dtype, + self.input_buffer_size, + self.input_sizes, + self.input_strides, + self.input_offset, + self.output_buffer_size, + self.output_sizes, + self.output_strides, + self.output_offset, + self.transfer_size, + self.num_aie_channels, + ], + callback_kwargs=self.kwargs, + ) + + def get_kernel_artifacts(self): + return [] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", self.input_buffer_size), # matrix + AIERuntimeArgSpec("out", self.output_buffer_size), # output + ] diff --git a/iron/operators/swiglu_decode/op.py b/iron/operators/swiglu_decode/op.py index 869493c9..05496634 100644 --- a/iron/operators/swiglu_decode/op.py +++ b/iron/operators/swiglu_decode/op.py @@ -7,7 +7,10 @@ from ml_dtypes import bfloat16 from iron.common import ( - AIEOperatorBase, + CompositeOperator, + AIERuntimeArgSpec, + AIEBuffer, + SingleXclbinCallable, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -21,7 +24,91 @@ from iron.common.utils import torch_to_numpy -class AIESwiGLUDecode(AIEOperatorBase): +class SwiGLUDecodeCallable: + def __init__(self, op): + self.op = op + # Create callables for sub-operators + # We need to manually construct SingleXclbinCallable because sub-operators weren't "compiled" in the standard way + + # Helper to create callable from operator and artifacts + def create_callable(sub_op, xclbin_path, kernel_name, insts_artifact): + return SingleXclbinCallable( + xclbin_path=xclbin_path, + kernel_name=kernel_name, + insts_bin_path=insts_artifact.filename, + args_spec=sub_op.get_arg_spec(), + ) + + self.gemv_1_callable = create_callable( + op.gemv_1, + op.combined_xclbin.filename, + op.gemv_1_xclbin.kernel_name, + op.gemv_1_insts, + ) + self.silu_callable = create_callable( + op.silu, + op.combined_xclbin.filename, + op.silu_xclbin.kernel_name, + op.silu_insts, + ) + self.eltwise_mul_callable = create_callable( + op.eltwise_mul, + op.combined_xclbin.filename, + op.eltwise_mul_xclbin.kernel_name, + op.eltwise_mul_insts, + ) + self.gemv_2_callable = create_callable( + op.gemv_2, + op.combined_xclbin.filename, + op.gemv_2_xclbin.kernel_name, + op.gemv_2_insts, + ) + + # Allocate and upload weights + self.weights_1 = AIEBuffer.from_np(torch_to_numpy(op.weights_1)) + self.weights_2 = AIEBuffer.from_np(torch_to_numpy(op.weights_2)) + self.weights_3 = AIEBuffer.from_np(torch_to_numpy(op.weights_3)) + + # Allocate intermediate buffers + # left: output of gemv_1 (hidden_dim_padded) + self.left = AIEBuffer(shape=(op.hidden_dim_padded,), dtype=bfloat16) + # right: output of gemv_1 (hidden_dim_padded) + self.right = AIEBuffer(shape=(op.hidden_dim_padded,), dtype=bfloat16) + # left_swished: output of silu (hidden_dim_padded) + self.left_swished = AIEBuffer(shape=(op.hidden_dim_padded,), dtype=bfloat16) + # intermediate: output of eltwise_mul (hidden_dim_padded) + self.intermediate = AIEBuffer(shape=(op.hidden_dim_padded,), dtype=bfloat16) + + def __call__(self, input_buf, output_buf): + # Ensure inputs are on device + input_buf.to("npu") + output_buf.to("npu") + self.weights_1.to("npu") + self.weights_2.to("npu") + self.weights_3.to("npu") + self.left.to("npu") + self.right.to("npu") + self.left_swished.to("npu") + self.intermediate.to("npu") + + # Sequence: + # 1. GEMV(weights_1, input, left) + self.gemv_1_callable(self.weights_1, input_buf, self.left) + + # 2. GEMV(weights_2, input, right) + self.gemv_1_callable(self.weights_2, input_buf, self.right) + + # 3. SiLU(left, left_swished) + self.silu_callable(self.left, self.left_swished) + + # 4. EltwiseMul(left_swished, right, intermediate) + self.eltwise_mul_callable(self.left_swished, self.right, self.intermediate) + + # 5. GEMV(weights_3, intermediate, output) + self.gemv_2_callable(self.weights_3, self.intermediate, output_buf) + + +class AIESwiGLUDecode(CompositeOperator): def __init__(self, embedding_dim, hidden_dim, prio_accuracy=False, context=None): self.hidden_dim = hidden_dim @@ -57,9 +144,7 @@ def set_up_artifacts(self): tile_size_output=self.hidden_dim // 8, ) self.gemv_1 = gemv_1 - gemv_1_xclbin, gemv_1_insts = gemv_1.get_artifacts( - prefix="swiglu_decode_gemv_1_" - ) + gemv_1_xclbin, gemv_1_insts = gemv_1.get_artifacts(prefix="swiglu_gemv_1_") gemv_1_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_gemv_1", "--xclbin-kernel-id=0x901", @@ -72,31 +157,29 @@ def set_up_artifacts(self): silu = AIESiLU( size=self.hidden_dim, num_aie_columns=8, - num_channels=2, tile_size=self.hidden_dim // 16, ) self.silu = silu self.hidden_dim_padded = silu.size - silu_xclbin, silu_insts = silu.get_artifacts(prefix="swiglu_decode_silu_") + silu_xclbin, silu_insts = silu.get_artifacts(prefix="swiglu_silu_") silu_xclbin.xclbin_input = gemv_1_xclbin silu_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_silu", "--xclbin-kernel-id=0x902", ] silu_xclbin.kernel_name = "swiglu_silu" - silu_xclbin.depends += [gemv_1_xclbin] + silu_xclbin.dependencies.add(gemv_1_xclbin) artifacts.append(silu_insts) eltwise_mul = AIEElementwiseMul( size=self.hidden_dim, num_aie_columns=8, - num_channels=2, tile_size=self.hidden_dim // 8, ) self.eltwise_mul = eltwise_mul assert self.hidden_dim <= eltwise_mul.size <= self.hidden_dim_padded eltwise_mul_xclbin, eltwise_mul_insts = eltwise_mul.get_artifacts( - prefix="swiglu_decode_eltwise_mul_" + prefix="swiglu_eltwise_mul_" ) eltwise_mul_xclbin.xclbin_input = silu_xclbin eltwise_mul_xclbin.extra_flags += [ @@ -104,7 +187,7 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x903", ] eltwise_mul_xclbin.kernel_name = "swiglu_eltwise_mul" - eltwise_mul_xclbin.depends += [silu_xclbin] + eltwise_mul_xclbin.dependencies.add(silu_xclbin) artifacts.append(eltwise_mul_insts) gemv_2 = AIEGEMV( @@ -115,16 +198,14 @@ def set_up_artifacts(self): tile_size_output=self.embedding_dim // 8, ) self.gemv_2 = gemv_2 - gemv_2_xclbin, gemv_2_insts = gemv_2.get_artifacts( - prefix="swiglu_decode_gemv_2_" - ) + gemv_2_xclbin, gemv_2_insts = gemv_2.get_artifacts(prefix="swiglu_gemv_2_") gemv_2_xclbin.xclbin_input = eltwise_mul_xclbin gemv_2_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_gemv_2", "--xclbin-kernel-id=0x904", ] gemv_2_xclbin.kernel_name = "swiglu_gemv_2" - gemv_2_xclbin.depends += [eltwise_mul_xclbin] + gemv_2_xclbin.dependencies.add(eltwise_mul_xclbin) artifacts.append(gemv_2_xclbin) artifacts.append(gemv_2_insts) @@ -140,69 +221,11 @@ def set_up_artifacts(self): self.add_artifacts(artifacts) - def set_up_runtime(self): - self.add_buffer("input", self.embedding_dim) - self.add_buffer( - "weights_1", - self.embedding_dim * self.hidden_dim_padded, - static_data=torch_to_numpy(self.weights_1), - ) - self.add_buffer( - "weights_2", - self.embedding_dim * self.hidden_dim_padded, - static_data=torch_to_numpy(self.weights_2), - ) - self.add_buffer( - "weights_3", - self.hidden_dim_padded * self.embedding_dim, - static_data=torch_to_numpy(self.weights_3), - ) - self.add_buffer("left", self.hidden_dim_padded) - self.add_buffer("left_swished", self.hidden_dim_padded) - self.add_buffer("right", self.hidden_dim_padded) - self.add_buffer("intermediate", self.hidden_dim_padded) - self.add_buffer("output", self.embedding_dim) - self.add_kernel( - "swiglu_gemv_1", - self.combined_xclbin, - self.gemv_1_xclbin.kernel_name, - self.gemv_1_insts, - ) - self.add_kernel( - "swiglu_silu", - self.combined_xclbin, - self.silu_xclbin.kernel_name, - self.silu_insts, - ) - self.add_kernel( - "swiglu_eltwise_mul", - self.combined_xclbin, - self.eltwise_mul_xclbin.kernel_name, - self.eltwise_mul_insts, - ) - self.add_kernel( - "swiglu_gemv_2", - self.combined_xclbin, - self.gemv_2_xclbin.kernel_name, - self.gemv_2_insts, - ) - self.add_to_runlist("swiglu_gemv_1", "weights_1", "input", "left") - self.add_to_runlist("swiglu_gemv_1", "weights_2", "input", "right") - self.add_to_runlist("swiglu_silu", "left", "left_swished") - self.add_to_runlist( - "swiglu_eltwise_mul", "left_swished", "right", "intermediate" - ) - self.add_to_runlist("swiglu_gemv_2", "weights_3", "intermediate", "output") - - def forward(self, x): - x_flat = x.reshape(x.shape[-1]) - assert x_flat.shape[0] == self.embedding_dim - - self.write_buffer("input", x_flat) - self.run_runlist() - result = self.read_buffer_as_torch( - "output", - (self.embedding_dim,), - ).view_as(x) + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.embedding_dim,)), + AIERuntimeArgSpec("out", (self.embedding_dim,)), + ] - return result + def get_callable(self): + return SwiGLUDecodeCallable(self) diff --git a/iron/operators/swiglu_decode/test.py b/iron/operators/swiglu_decode/test.py index 11b35fa2..8d4a51d2 100755 --- a/iron/operators/swiglu_decode/test.py +++ b/iron/operators/swiglu_decode/test.py @@ -7,35 +7,30 @@ from pathlib import Path +from ml_dtypes import bfloat16 +from iron.common.base import AIEBuffer +from iron.common.utils import torch_to_numpy from iron.operators.swiglu_decode.op import AIESwiGLUDecode from iron.operators.swiglu_decode.reference import generate_golden_reference -from iron.common.test_utils import run_test, verify_buffer +from iron.common.test_utils import verify_buffer -def generate_test_params(extensive=False): - params = [(2048, 2048)] - names = [f"swiglu_decode_1x{emb}x{hid}" for emb, hid in params] - return params, names +def get_params(): + params_list = [(2048, 2048)] - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params = [] + for p in params_list: + emb, hid = p + name = f"swiglu_decode_1x{emb}x{hid}" + params.append(pytest.param(*p, id=name)) + return params @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", ) -@pytest.mark.parametrize("embedding_dim,hidden_dim", all_params) +@pytest.mark.parametrize("embedding_dim,hidden_dim", get_params()) def test_swiglu_decode(embedding_dim, hidden_dim, aie_context): golden_ref = generate_golden_reference(M=1, K=embedding_dim, N=hidden_dim) @@ -46,39 +41,32 @@ def test_swiglu_decode(embedding_dim, hidden_dim, aie_context): operator.weights_2 = golden_ref["w_up"].T operator.weights_3 = golden_ref["w_down"].T - # In the following, some buffers are commented out. - # Because this operator calls multiple kernels in sequence, rounding errors due to the smaller bf16 data type accumulate, which can cause it to fail verification. - # So, instead of verifying the final output buffers against the float32-calculated reference, we calculate another reference for the final output: - # This reference is based on the previous intermediate result read back from the AIE operator, "resetting" the accumulated error to zero. - # Note that the previous intermediate result _is_ still verified up to the given tolerance. + operator.compile() + op_func = operator.get_callable() + + input_buf = AIEBuffer.from_np(torch_to_numpy(golden_ref["input"])) + output_buf = AIEBuffer(shape=(1, embedding_dim), dtype=bfloat16) - input_buffers = {"input": golden_ref["input"]} - output_buffers = {} - intermediate_buffers = { - "left": golden_ref["left"], - "left_swished": golden_ref["left_swished"], - "right": golden_ref["right"], - "intermediate": golden_ref["intermediate"], - } + op_func(input_buf, output_buf) - errors, latency_us, bandwidth_gbps = run_test( - operator, - input_buffers, - output_buffers, - intermediate_buffers, + errors = {} + # Verify intermediate result + intermediate = op_func.intermediate.view_as_torch().reshape((1, hidden_dim)) + errors_intermediate = verify_buffer( + intermediate, + "intermediate", + golden_ref["intermediate"], rel_tol=0.07, abs_tol=0.7, ) - - ref_2 = ( - operator.read_buffer_as_torch("intermediate", (1, hidden_dim)) - @ golden_ref["w_down"] - ) - errors_2 = verify_buffer(operator, "output", ref_2, rel_tol=0.04, abs_tol=0.4) - if errors_2: - errors["output"] = errors_2 - - print(f"\nLatency (us): {latency_us:.1f}") - print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") + if errors_intermediate: + errors["intermediate"] = errors_intermediate + + # Verify output using intermediate result + ref_2 = intermediate @ golden_ref["w_down"] + output = output_buf.view_as_torch().reshape((1, embedding_dim)) + errors_output = verify_buffer(output, "output", ref_2, rel_tol=0.04, abs_tol=0.4) + if errors_output: + errors["output"] = errors_output assert not errors, f"Test failed with errors: {errors}" diff --git a/iron/operators/swiglu_prefill/op.py b/iron/operators/swiglu_prefill/op.py index 2b2aa341..d572c21f 100644 --- a/iron/operators/swiglu_prefill/op.py +++ b/iron/operators/swiglu_prefill/op.py @@ -7,7 +7,10 @@ from ml_dtypes import bfloat16 from iron.common import ( - AIEOperatorBase, + CompositeOperator, + AIERuntimeArgSpec, + AIEBuffer, + SingleXclbinCallable, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -21,7 +24,87 @@ from iron.common.utils import torch_to_numpy -class AIESwiGLUPrefill(AIEOperatorBase): +class SwiGLUPrefillCallable: + def __init__(self, op): + self.op = op + + def create_callable(sub_op, xclbin_path, kernel_name, insts_artifact): + return SingleXclbinCallable( + xclbin_path=xclbin_path, + kernel_name=kernel_name, + insts_bin_path=insts_artifact.filename, + args_spec=sub_op.get_arg_spec(), + ) + + self.gemm_1_callable = create_callable( + op.gemm_1, + op.combined_xclbin.filename, + op.gemm_1_xclbin.kernel_name, + op.gemm_1_insts, + ) + self.silu_callable = create_callable( + op.silu, + op.combined_xclbin.filename, + op.silu_xclbin.kernel_name, + op.silu_insts, + ) + self.eltwise_mul_callable = create_callable( + op.eltwise_mul, + op.combined_xclbin.filename, + op.eltwise_mul_xclbin.kernel_name, + op.eltwise_mul_insts, + ) + self.gemm_2_callable = create_callable( + op.gemm_2, + op.combined_xclbin.filename, + op.gemm_2_xclbin.kernel_name, + op.gemm_2_insts, + ) + + # Allocate and upload weights + self.weights_1 = AIEBuffer.from_np(torch_to_numpy(op.weights_1.T)) + self.weights_2 = AIEBuffer.from_np(torch_to_numpy(op.weights_2.T)) + self.weights_3 = AIEBuffer.from_np(torch_to_numpy(op.weights_3.T)) + + # Allocate intermediate buffers + # Sizes are padded + size_hidden = op.seq_len_padded * op.hidden_dim_padded + self.left = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) + self.right = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) + self.left_swished = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) + self.intermediate = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) + self.last_output_buf = None + + def __call__(self, input_buf, output_buf): + self.last_output_buf = output_buf + input_buf.to("npu") + output_buf.to("npu") + self.weights_1.to("npu") + self.weights_2.to("npu") + self.weights_3.to("npu") + self.left.to("npu") + self.right.to("npu") + self.left_swished.to("npu") + self.intermediate.to("npu") + + # Sequence: + # 1. GEMM(input, weights_1, left) + self.gemm_1_callable(input_buf, self.weights_1, self.left) + + # 2. GEMM(input, weights_2, right) + self.gemm_1_callable(input_buf, self.weights_2, self.right) + + # 3. SiLU(left, left_swished) + self.silu_callable(self.left, self.left_swished) + + # 4. EltwiseMul(left_swished, right, intermediate) + self.eltwise_mul_callable(self.left_swished, self.right, self.intermediate) + + # 5. GEMM(intermediate, weights_3, output) + self.gemm_2_callable(self.intermediate, self.weights_3, output_buf) + + +class AIESwiGLUPrefill(CompositeOperator): def __init__( self, seq_len, embedding_dim, hidden_dim, prio_accuracy=False, context=None @@ -85,7 +168,6 @@ def set_up_artifacts(self): silu = AIESiLU( size=self.seq_len_padded * self.hidden_dim_padded, num_aie_columns=8, - num_channels=2, tile_size=self.hidden_dim_padded // 8, ) self.silu = silu @@ -98,13 +180,12 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x902", ] silu_xclbin.kernel_name = "swiglu_silu" - silu_xclbin.depends += [gemm_1_xclbin] + silu_xclbin.dependencies.add(gemm_1_xclbin) artifacts.append(silu_insts) eltwise_mul = AIEElementwiseMul( size=self.seq_len_padded * self.hidden_dim_padded, num_aie_columns=8, - num_channels=2, tile_size=self.hidden_dim_padded // 8, ) self.eltwise_mul = eltwise_mul @@ -119,7 +200,7 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x903", ] eltwise_mul_xclbin.kernel_name = "swiglu_eltwise_mul" - eltwise_mul_xclbin.depends += [silu_xclbin] + eltwise_mul_xclbin.dependencies.add(silu_xclbin) artifacts.append(eltwise_mul_insts) gemm_2 = AIEGEMM( @@ -137,7 +218,7 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x904", ] gemm_2_xclbin.kernel_name = "swiglu_gemm_2" - gemm_2_xclbin.depends += [eltwise_mul_xclbin] + gemm_2_xclbin.dependencies.add(eltwise_mul_xclbin) artifacts.append(gemm_2_xclbin) artifacts.append(gemm_2_insts) @@ -153,109 +234,13 @@ def set_up_artifacts(self): self.add_artifacts(artifacts) - def set_up_runtime(self): - # Runtime setup - # --- - self.add_buffer("input", self.seq_len_padded * self.embedding_dim_padded) - self.add_buffer( - "weights_1", - self.embedding_dim_padded * self.hidden_dim_padded, - static_data=torch_to_numpy(self.weights_1.T), - ) - self.add_buffer( - "weights_2", - self.embedding_dim_padded * self.hidden_dim_padded, - static_data=torch_to_numpy(self.weights_2.T), - ) - self.add_buffer( - "weights_3", - self.hidden_dim_padded * self.embedding_dim_padded, - static_data=torch_to_numpy(self.weights_3.T), - ) - self.add_buffer("left", self.seq_len_padded * self.hidden_dim_padded) - self.add_buffer("left_swished", self.seq_len_padded * self.hidden_dim_padded) - self.add_buffer("right", self.seq_len_padded * self.hidden_dim_padded) - self.add_buffer("intermediate", self.seq_len_padded * self.hidden_dim_padded) - self.add_buffer("output", self.seq_len_padded * self.embedding_dim_padded) - self.add_kernel( - "swiglu_gemm_1", - self.combined_xclbin, - self.gemm_1_xclbin.kernel_name, - self.gemm_1_insts, - ) - self.add_kernel( - "swiglu_silu", - self.combined_xclbin, - self.silu_xclbin.kernel_name, - self.silu_insts, - ) - self.add_kernel( - "swiglu_eltwise_mul", - self.combined_xclbin, - self.eltwise_mul_xclbin.kernel_name, - self.eltwise_mul_insts, - ) - self.add_kernel( - "swiglu_gemm_2", - self.combined_xclbin, - self.gemm_2_xclbin.kernel_name, - self.gemm_2_insts, - ) - self.add_to_runlist("swiglu_gemm_1", "input", "weights_1", "left") - self.add_to_runlist("swiglu_gemm_1", "input", "weights_2", "right") - self.add_to_runlist("swiglu_silu", "left", "left_swished") - self.add_to_runlist( - "swiglu_eltwise_mul", "left_swished", "right", "intermediate" - ) - self.add_to_runlist("swiglu_gemm_2", "intermediate", "weights_3", "output") - - def forward(self, x): - """Forward pass for SwiGLU operation""" - - # Always flatten to [batch, orig_size] - original_shape = x.shape - batch = x.shape[0] if x.dim() > 1 else 1 - x_flat = x.reshape(batch, -1) - - out = self._execute_aie_operation(x_flat) - - # Restore original shape - out = out.reshape(*original_shape) - - return out - - def _execute_aie_operation(self, x): - # x is [batch, size] - batch = x.shape[0] if x.dim() > 1 else 1 - - # Flatten inputs for AIE processing - x_flat = x.view(-1) - - # Verify input size matches expected dimensions - expected_size = batch * self.seq_len * self.embedding_dim - assert x_flat.shape[0] == expected_size - - # Pad input if necessary to match GEMM requirements - if self.seq_len_padded * self.embedding_dim_padded > x_flat.shape[0]: - x_padded = torch.zeros( - self.seq_len_padded * self.embedding_dim_padded, - dtype=x_flat.dtype, - device=x_flat.device, - ) - x_padded[: x_flat.shape[0]] = x_flat - x_flat = x_padded - - self.write_buffer("input", x_flat) - self.run_runlist() - - # Read padded output buffer - result_padded = self.read_buffer_as_torch( - "output", - shape=(self.seq_len_padded * self.embedding_dim_padded,), - dtype=bfloat16, - ) - - # Extract only the unpadded portion - result = result_padded[:expected_size].view(batch, -1) + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.seq_len_padded * self.embedding_dim_padded,)), + AIERuntimeArgSpec( + "out", (self.seq_len_padded * self.embedding_dim_padded,) + ), + ] - return result + def get_callable(self): + return SwiGLUPrefillCallable(self) diff --git a/iron/operators/swiglu_prefill/test.py b/iron/operators/swiglu_prefill/test.py index 75510d63..10df9243 100755 --- a/iron/operators/swiglu_prefill/test.py +++ b/iron/operators/swiglu_prefill/test.py @@ -7,36 +7,31 @@ from pathlib import Path +from ml_dtypes import bfloat16 +from iron.common.base import AIEBuffer +from iron.common.utils import torch_to_numpy from iron.operators.swiglu_prefill.op import AIESwiGLUPrefill from iron.operators.swiglu_decode.reference import generate_golden_reference -from iron.common.test_utils import run_test, verify_buffer +from iron.common.test_utils import verify_buffer -def generate_test_params(extensive=False): +def get_params(): # This operation is currently untested except for the integrated llama application tests. - params = [] - names = [] - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) + params_list = [(256, 2048, 2048, False)] -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params = [] + for p in params_list: + _, emb, hid, _ = p + name = f"swiglu_prefill_256x{emb}x{hid}" + params.append(pytest.param(*p, id=name)) + return params @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", ) -@pytest.mark.parametrize("seq_len,embedding_dim,hidden_dim,prio_accuracy", all_params) +@pytest.mark.parametrize("seq_len,embedding_dim,hidden_dim,prio_accuracy", get_params()) def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_context): golden_ref = generate_golden_reference(M=seq_len, K=embedding_dim, N=hidden_dim) @@ -51,41 +46,36 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c operator.weights_2 = golden_ref["w_up"].T operator.weights_3 = golden_ref["w_down"].T - input_buffers = {"input": golden_ref["input"]} - # output_buffers = {'output': golden_ref['output']} - output_buffers = {} - intermediate_buffers = { - "left": golden_ref["left"], - "left_swished": golden_ref["left_swished"], - "right": golden_ref["right"], - # 'intermediate': golden_ref['intermediate'] - } - - errors, latency_us, bandwidth_gbps = run_test( - operator, - input_buffers, - output_buffers, - intermediate_buffers, - rel_tol=0.07, - abs_tol=0.7, - ) + operator.compile() + op_func = operator.get_callable() + + input_buf = AIEBuffer.from_np(torch_to_numpy(golden_ref["input"])) + output_buf = AIEBuffer( + shape=(seq_len * embedding_dim,), dtype=bfloat16 + ) # Output is flattened - ref_2 = operator.read_buffer_as_torch( - "left_swished", (seq_len, hidden_dim) - ) * operator.read_buffer_as_torch("right", (seq_len, hidden_dim)) - errors_2 = verify_buffer(operator, "intermediate", ref_2, rel_tol=0.04, abs_tol=0.4) + op_func(input_buf, output_buf) + + errors = {} + + # Verify intermediate result (left_swished * right) + left_swished = op_func.left_swished.view_as_torch().reshape((seq_len, hidden_dim)) + right = op_func.right.view_as_torch().reshape((seq_len, hidden_dim)) + ref_2 = left_swished * right + + # Note: intermediate buffer in op_func stores the result of eltwise_mul + intermediate = op_func.intermediate.view_as_torch().reshape((seq_len, hidden_dim)) + errors_2 = verify_buffer( + intermediate, "intermediate", ref_2, rel_tol=0.04, abs_tol=0.4 + ) if errors_2: errors["intermediate"] = errors_2 - ref_3 = ( - operator.read_buffer_as_torch("intermediate", (seq_len, hidden_dim)) - @ golden_ref["w_down"] - ) - errors_3 = verify_buffer(operator, "output", ref_3, rel_tol=0.04, abs_tol=0.4) + # Verify output using intermediate result + ref_3 = intermediate @ golden_ref["w_down"] + output = output_buf.view_as_torch().reshape((seq_len, embedding_dim)) + errors_3 = verify_buffer(output, "output", ref_3, rel_tol=0.04, abs_tol=0.4) if errors_3: - errors["output"] = errors_2 - - print(f"\nLatency (us): {latency_us:.1f}") - print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") + errors["output"] = errors_3 assert not errors, f"Test failed with errors: {errors}" diff --git a/iron/operators/tanh/design.py b/iron/operators/tanh/design.py index 0f78fc92..c3e0acad 100644 --- a/iron/operators/tanh/design.py +++ b/iron/operators/tanh/design.py @@ -14,7 +14,9 @@ from aie.iron.controlflow import range_ -def my_tanh(dev, size, num_columns, num_channels, tile_size, trace_size): +def my_tanh( + dev, size, num_columns, num_channels, tile_size, trace_size, kernel_archive=None +): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] diff --git a/iron/operators/tanh/op.py b/iron/operators/tanh/op.py index 5bccad5e..2a0233aa 100644 --- a/iron/operators/tanh/op.py +++ b/iron/operators/tanh/op.py @@ -7,8 +7,8 @@ from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -17,14 +17,17 @@ ) -class AIETanh(AIEOperatorBase): +class AIETanh(MLIROperator): """AIE-accelerated Tanh activation function""" def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_columns = num_aie_columns @@ -33,17 +36,15 @@ def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None) total_shimdma_channels = self.num_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"tanh_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"tanh_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_tanh", callback_args=[ @@ -56,59 +57,20 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"tanh.o", - depends=[ - SourceArtifact.new( - self.context.base_dir / "aie_kernels" / "aie2p" / "tanh.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "tanh", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("tanh", "input", "output") - - def forward(self, x): - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIETanh: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"tanh.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "aie2p" / "tanh.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/iron/operators/tanh/test.py b/iron/operators/tanh/test.py index f9986bb3..0a50b183 100755 --- a/iron/operators/tanh/test.py +++ b/iron/operators/tanh/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 1 # 1 channel for 1 input - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns @@ -26,24 +25,22 @@ def generate_test_params(extensive=False): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - names.append( - f"tanh_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + name = f"tanh_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] + + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) ) - params.append((input_length, num_aie_columns, num_channels, tile_size)) - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -52,7 +49,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_tanh(input_length, num_aie_columns, num_channels, tile_size, aie_context): golden_ref = generate_golden_reference(input_length=input_length) diff --git a/iron/operators/transpose/design.py b/iron/operators/transpose/design.py index 7a53365a..03fad5d3 100644 --- a/iron/operators/transpose/design.py +++ b/iron/operators/transpose/design.py @@ -2,20 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 from ml_dtypes import bfloat16 -from pathlib import Path import numpy as np -import argparse -import sys from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker from aie.iron.placers import SequentialPlacer -from aie.iron.device import NPU1, NPU2 from aie.helpers.taplib.tap import TensorAccessPattern from aie.iron.controlflow import range_ -from aie.helpers.util import np_ndarray_type_get_shape -def shuffle_transpose(dev, M, N, num_columns, num_channels, trace_size, m, n, s): +def shuffle_transpose( + dev, M, N, num_columns, num_channels, m, n, s, kernel_archive=None, func_prefix="" +): num_elements = M * N per_tile_elements = m * n dtype = bfloat16 @@ -103,8 +100,10 @@ def shuffle_transpose(dev, M, N, num_columns, num_channels, trace_size, m, n, s) ] # AIE Core Function declaration + if kernel_archive is None: + kernel_archive = f"transpose_{s}x{s}.a" transpose_kernel = Kernel( - f"transpose_{s}x{s}", f"transpose_{m}x{n}.o", [tile_ty, tile_ty] + f"{func_prefix}transpose_{s}x{s}", kernel_archive, [tile_ty, tile_ty] ) # Define a task that will run on a compute tile @@ -163,115 +162,3 @@ def core_body(of_in1, of_out, transpose_kernel): # Place program components (assign them resources on the device) and generate an MLIR module return Program(dev, rt).resolve_program(SequentialPlacer()) - - -if __name__ == "__main__": - - def str_to_device(device: str): - if device == "npu": - return NPU1() - elif device == "npu2": - return NPU2() - else: - raise ValueError(f"Device name {device} is unknown.") - - p = argparse.ArgumentParser() - # Parse command line arguments - - # Device name is required to select the AIE device: npu or npu2 - p.add_argument( - "-d", - "--dev", - required=True, - dest="device", - help="AIE Device", - type=str_to_device, - ) - # Transfer size is required to define the size of the data to be transferred - p.add_argument( - "-M", "--workload-rows", required=True, dest="work_rows", help="Number of rows" - ) - p.add_argument( - "-N", - "--workload-columns", - required=True, - dest="work_cols", - help="Number of columns", - ) - # Number of columns is required to define the number of columns to be used - # It must be less than or equal to 4 for npu and 8 for npu2 - p.add_argument( - "-co", "--columns", required=True, dest="cols", help="Number of columns" - ) - # Number of channels is required to define the number of channels to be used - # It must be 1 or 2 - p.add_argument( - "-ch", "--channels", required=True, dest="chans", help="Number of channels" - ) - # Tile size - p.add_argument( - "-m", "--tile-rows", required=True, dest="tile_rows", help="Outer tile rows" - ) - p.add_argument( - "-n", - "--tile-columns", - required=True, - dest="tile_cols", - help="Outer tile columns", - ) - p.add_argument( - "-s", - "--kernel-dim", - required=True, - choices=["4", "8"], - dest="kernel_dim", - help="Inner tile dimension (square)", - ) - # Trace Size - p.add_argument( - "-tr", "--trace-size", required=True, dest="trace_size", help="Trace size" - ) - p.add_argument( - "--output-file-path", - "-o", - type=str, - help="Output file path for the generated MLIR module", - ) - - opts = p.parse_args(sys.argv[1:]) - - M = int(opts.work_rows) - N = int(opts.work_cols) - columns = int(opts.cols) - - dev = opts.device # Already a device object from str_to_device - - # Validate columns based on device type - if isinstance(dev, NPU1) and columns > 4: - raise ValueError("[ERROR] Device NPU cannot allocate more than 4 columns") - elif isinstance(dev, NPU2) and columns > 8: - raise ValueError("[ERROR] Device NPU2 cannot allocate more than 8 columns") - - channels = int(opts.chans) - if channels < 1 or channels > 2: - raise ValueError("Number of channels must be 1 or 2") - m = int(opts.tile_rows) - n = int(opts.tile_cols) - s = int(opts.kernel_dim) - if (((M * N) % (m * n)) % columns % channels) != 0: - print( - "transfer size (" - + str(M * N) - + ") must be a multiple of " - + str(m * n) - + f" and divisible by the number of columns ({columns}) and {channels} channels per column" - ) - raise ValueError - trace_size = int(opts.trace_size) if opts.trace_size is not None else 0 - - module = shuffle_transpose(dev, M, N, columns, channels, trace_size, m, n, s) - - output_file_path = Path(opts.output_file_path) - - with open(output_file_path, "w") as f: - f.write(str(module)) diff --git a/iron/operators/transpose/op.py b/iron/operators/transpose/op.py index 7963fd06..83c7891e 100644 --- a/iron/operators/transpose/op.py +++ b/iron/operators/transpose/op.py @@ -1,14 +1,11 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import torch -import numpy as np -from ml_dtypes import bfloat16 from pathlib import Path from iron.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + MLIROperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -17,37 +14,35 @@ ) -class AIETranspose(AIEOperatorBase): +class AIETranspose(MLIROperator): """AIE-accelerated transpose operator""" def __init__(self, M, N, num_aie_columns, num_channels, m, n, s, context=None): + assert M % m == 0, f"Matrix rows ({M}) must be a multiple of {m}" + assert N % n == 0, f"Matrix columns ({N}) must be a multiple of {n}" + assert m % s == 0, f"AIE tile rows ({m}) must be a multiple of {s}" + assert n % s == 0, f"AIE tile columns ({n}) must be a multiple of {s}" + assert ( + M * N % (m * n * num_aie_columns * num_channels) == 0 + ), "Transfer size must be divisible by m*n*num_columns*num_channels" + self.M = M self.N = N self.m = m self.n = n self.s = s - self.size = M * N - self.tile_size = m * n - self.num_columns = num_aie_columns self.num_channels = num_channels - total_shimdma_channels = self.num_columns * self.num_channels - if 1 > 1: - total_shimdma_channels *= 1 - assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - - self.xclbin_artifact = None - self.insts_artifact = None + MLIROperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"transpose_{self.num_columns}c_{self.num_channels}ch_{self.M}x{self.N}_{self.m}x{self.n}_{self.s}s" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"transpose_{self.num_columns}c_{self.num_channels}ch_{self.M}x{self.N}_{self.m}x{self.n}_{self.s}s" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="shuffle_transpose", callback_args=[ @@ -56,73 +51,33 @@ def set_up_artifacts(self): self.N, self.num_columns, self.num_channels, - 0, self.m, self.n, self.s, ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"transpose_{self.m}x{self.n}.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "generic" - / "transpose.cc" - ) - ], - extra_flags=[ - f"-DDIM_m={self.m}", - f"-DDIM_n={self.n}", - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "transpose", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("transpose", "input", "output") - - def forward(self, x): - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIETranspose: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"transpose_{self.m}x{self.n}.o", + dependencies=[ + SourceArtifact( + self.context.base_dir + / "aie_kernels" + / "generic" + / "transpose.cc" + ) + ], + extra_flags=[ + f"-DDIM_m={self.m}", + f"-DDIM_n={self.n}", + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.M * self.N,)), # input + AIERuntimeArgSpec("out", (self.M * self.N,)), # output (transposed) + ] diff --git a/iron/operators/transpose/test.py b/iron/operators/transpose/test.py index 8f0d9981..00cf562b 100755 --- a/iron/operators/transpose/test.py +++ b/iron/operators/transpose/test.py @@ -12,16 +12,15 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): - params = [] - names = [] +def get_params(): max_aie_columns = 8 - input_lengths = [2048] if not extensive else [64, 2048] - n_list = [64] if not extensive else [64, 128, 256, 512] + input_lengths = [64, 2048] + n_list = [64, 128, 256, 512] s_list = [8] m = 64 n = 64 + params = [] for M in input_lengths: for N in n_list: for s in s_list: @@ -37,32 +36,33 @@ def generate_test_params(extensive=False): length = M * N if check_length != length: continue - names.append( - f"transpose_{M}_M_{N}_N_{num_aie_columns}_cols_{num_channels}_channels_{m}_m_{n}_n_{s}_s" + name = f"transpose_{M}_M_{N}_N_{num_aie_columns}_cols_{num_channels}_channels_{m}_m_{n}_n_{s}_s" + + is_regular = M == 2048 and N == 64 + marks = [] if is_regular else [pytest.mark.extensive] + + params.append( + pytest.param( + M, + N, + num_aie_columns, + num_channels, + m, + n, + s, + id=name, + marks=marks, + ) ) - params.append((M, N, num_aie_columns, num_channels, m, n, s)) - - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", ) -@pytest.mark.parametrize("M,N,aie_columns,channels,m,n,s", all_params) +@pytest.mark.parametrize("M,N,aie_columns,channels,m,n,s", get_params()) def test_transpose(M, N, aie_columns, channels, m, n, s, aie_context): golden_ref = generate_golden_reference(rows=M, cols=N)