[MLAS][KleidiAI] Add Arm64 fp16 MLAS and KleidiAI support#28487
[MLAS][KleidiAI] Add Arm64 fp16 MLAS and KleidiAI support#28487Laan33 wants to merge 10 commits into
Conversation
b443302 to
97afce3
Compare
There was a problem hiding this comment.
Pull request overview
Enable opt-in native CPU fp16 execution on Arm64 by wiring MLAS and KleidiAI fp16 GEMM/Conv paths through the CPU EP, adding backend-selection config plumbing, and expanding coverage with targeted unit tests.
Changes:
- Add CPU EP fp16 MatMul/Gemm kernels, native packed-B handling, and fp16 conv “HalfConv” (KleidiAI) routing/prepack.
- Extend InsertCastTransformer and session config to preserve eligible fp16 CPU nodes when
session.enable_cpu_fp16is enabled (with an fp32 fallback heuristic by default). - Add/extend tests for fp16 MatMul/Gemm/HalfGemm/HalfConv behavior, backend selector propagation, zero-K behavior, and prepack-sharing rules.
Reviewed changes
Copilot reviewed 37 out of 37 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | Adds fp16 Conv tests covering KleidiAI eligibility, disabling KleidiAI via session config, and prepack sharing scenarios. |
| onnxruntime/test/providers/cpu/math/matmul_test.cc | Adds fp16 MatMul test validating native packed weights are not shared across sessions. |
| onnxruntime/test/providers/cpu/math/gemm_test.cc | Adds fp16 Gemm tests for beta==0 semantics when C is missing/present. |
| onnxruntime/test/mlas/unittest/test_util.h | Fixes guard-buffer state reset by clearing _GuardAddress. |
| onnxruntime/test/mlas/unittest/test_halfgemm.h | Extends HalfGemm tests to validate native-fp16 vs fp32 reference, output-processor bypass, and override enforcement. |
| onnxruntime/test/mlas/unittest/test_halfgemm.cpp | Adds extensive HalfGemm regression tests: selector config behavior, zero-K behavior, pack-B overflow/padding, packed-B fallback behavior, and KleidiAI packed-B contract tests. |
| onnxruntime/test/mlas/unittest/test_conv2d.cpp | Adds test ensuring HalfConv prepare honors backend selector config. |
| onnxruntime/test/framework/insert_cast_transformer_test.cc | Adds tests for CPU fp16 opt-in behavior, heuristic routing, mixed-EP behavior, and runtime execution checks. |
| onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc | Excludes non-owning EPs for internal-domain NHWC fp16 pool tests. |
| onnxruntime/core/util/math_cpu.cc | Adds MatMul<MLFloat16> specialization using MLAS halfgemm when fp16 intrinsics are available. |
| onnxruntime/core/session/inference_session.cc | Wires session config into InsertCastTransformer and aligns transformer heuristics with MLAS backend selector config. |
| onnxruntime/core/providers/cpu/math/matmul.h | Adds CPU EP fp16 MatMul kernel class with native packed-B prepack support and backend selector config. |
| onnxruntime/core/providers/cpu/math/matmul.cc | Registers fp16 MatMul kernels across opsets; implements native packed-B prepack and fp16 execution via halfgemm/hgemm. |
| onnxruntime/core/providers/cpu/math/gemm.cc | Routes fp16 Gemm through MLAS halfgemm for supported cases, including no-bias beta==0, and passes backend selector config. |
| onnxruntime/core/providers/cpu/fp16/fp16_conv.cc | Adds KleidiAI HalfConv fast path, packs weights+bias for eligible convs, and propagates backend selector config to MLAS calls. |
| onnxruntime/core/providers/cpu/cpu_execution_provider.cc | Registers fp16 Gemm/MatMul kernels in the CPU EP kernel registry. |
| onnxruntime/core/optimizer/insert_cast_transformer.h | Extends InsertCastTransformer API to accept cpu registry list, cpu-fp16 enablement, heuristic control, and MLAS backend selector config. |
| onnxruntime/core/optimizer/insert_cast_transformer.cc | Implements CPU fp16 preservation policy + fallback heuristic for MatMul/Gemm; improves kernel availability checks using resolved type constraints. |
| onnxruntime/core/mlas/lib/platform.cpp | Enables KleidiAI fp16 halfgemm/halfconv overrides on SME/SME2 capable systems. |
| onnxruntime/core/mlas/lib/mlasi.h | Adds MLAS platform override function pointer types for halfgemm and halfconv. |
| onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp | Renames parameter to BatchN and updates corresponding logic for clarity/consistency. |
| onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h | Declares KleidiAI fp16 halfgemm/halfconv entry points and adds checked TryAddSize helper. |
| onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp | Adds KleidiAI fp16 halfgemm implementation with RHS packing, fp32 input conversion, and contract validation. |
| onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp | Adds KleidiAI fp16 halfconv implementation with IMATMUL packing/execution and locality heuristics. |
| onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp | Fixes cache access to avoid stale reference and uses per-pad cache directly. |
| onnxruntime/core/mlas/lib/kai_ukernel_interface.h | Adds FP16 IMATMUL and FP16 HGEMM ukernel wrapper types and accessors. |
| onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp | Wires SME/SME2 FP16 IMATMUL and FP16 HGEMM ukernel selection. |
| onnxruntime/core/mlas/lib/halfgemm.h | Adds overflow-safe helpers and implements CopyPackB with aligned-K zero padding. |
| onnxruntime/core/mlas/lib/halfgemm.cpp | Extends halfgemm batch API to accept backend selector config, adds zero-K handling, native pack-B APIs, and backend-native packed-B guard. |
| onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp | Enables CopyPackB routine in the NEON halfgemm dispatch table. |
| onnxruntime/core/mlas/lib/halfconv.cpp | Adds public MLAS dispatch wrappers for optional halfconv backend support. |
| onnxruntime/core/mlas/inc/mlas.h | Enables fp16 intrinsics on Apple Arm64, extends halfgemm params with packed-B flags, and adds public halfconv + native pack-B APIs. |
| onnxruntime/core/framework/session_state.cc | Allows kernels to prepack session-local weights without requiring shareable prepack buffers. |
| onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc | Passes MLAS backend selector config into fp16 MoE GEMM dispatch. |
| include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h | Adds session config keys for opt-in CPU fp16 and the fp32 fallback heuristic control. |
| docs/OperatorKernels.md | Documents fp16 CPU kernel coverage for Gemm and MatMul. |
| cmake/onnxruntime_mlas.cmake | Adds halfconv sources and enables Arm64 fp16 MLAS sources on Apple; wires new KleidiAI sources. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| template <> | ||
| void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, ThreadPool* threadpool, | ||
| const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) { | ||
| #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED | ||
| MLAS_HALF_GEMM_DATA_PARAMS data; | ||
| data.A = A; | ||
| data.lda = static_cast<size_t>(K); | ||
| data.B = B; | ||
| data.ldb = static_cast<size_t>(N); | ||
| data.C = C; | ||
| data.ldc = static_cast<size_t>(N); | ||
| MlasHalfGemmBatch(static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), 1, &data, threadpool, | ||
| mlas_backend_kernel_selector_config); |
| if (trans_a == CblasNoTrans && trans_b == CblasNoTrans && | ||
| alpha == onnxruntime::MLFloat16::One && (use_mlas_no_bias || use_mlas_bias)) { | ||
| MLAS_HALF_GEMM_DATA_PARAMS data; | ||
| data.A = a_data; | ||
| data.lda = K; | ||
| data.B = b_data; | ||
| data.ldb = N; | ||
| data.C = y_data; | ||
| data.ldc = N; | ||
| if (c_shape != nullptr) { | ||
| if (use_mlas_bias && c_shape != nullptr) { | ||
| data.Bias = c_data; | ||
| } | ||
| MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool); | ||
| MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool, mlas_backend_kernel_selector_config); | ||
| return; |
| template <> | ||
| Status MoE<MLFloat16>::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFloat16* C, | ||
| int64_t M, int64_t K, int64_t N, bool transpose_B) const { | ||
| MLAS_HALF_GEMM_DATA_PARAMS params; | ||
| params.A = A; | ||
| params.lda = static_cast<size_t>(K); | ||
| params.C = C; | ||
| params.ldc = static_cast<size_t>(N); | ||
| params.AIsfp32 = false; | ||
| params.BIsfp32 = false; | ||
| params.B = B; | ||
|
|
||
| if (transpose_B) { | ||
| params.ldb = static_cast<size_t>(K); | ||
| } else { | ||
| params.ldb = static_cast<size_t>(N); | ||
| } | ||
|
|
||
| MlasHalfGemmBatch(static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), 1, ¶ms, nullptr); | ||
| MlasHalfGemmBatch(static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), 1, ¶ms, nullptr, | ||
| &mlas_backend_kernel_selector_config_); | ||
| return Status::OK(); |
| MLAS_HALF_GEMM_DATA_PARAMS data; | ||
| data.A = A.data(); | ||
| data.B = B.data(); | ||
| data.C = reinterpret_cast<MLAS_FP16*>(C.data()); | ||
| data.lda = K; | ||
| data.ldb = N; | ||
| data.ldc = N; | ||
|
|
| MlasHalfGemmBatch( | ||
| static_cast<size_t>(output_count), | ||
| static_cast<size_t>(group_output_channels), | ||
| static_cast<size_t>(kernel_dim), | ||
| 1, &gemm_params, nullptr); | ||
| 1, &gemm_params, nullptr, &mlas_backend_kernel_selector_config_); | ||
| } |
Add native fp16 HGEMM and halfconv IMATMUL support for Arm64 MLAS/KleidiAI, including CPU fp16 MatMul/Gemm exposure, prepack handling, backend selector routing, and focused test coverage. Also include Apple Arm64 fp16 build enablement, insert-cast mixed-EP coverage, and fixes for fp16 routing edge cases such as zero-K halfgemm handling and backend-native packed-B contracts. Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
Ensure CPU-assigned fp16 nodes without a matching CPU fp16 kernel are routed through the existing fp32 cast fallback when a float CPU kernel exists. This preserves main-compatible behavior for fused nodes such as BiasGelu. Also split the graph-output cast coverage so the fallback behavior is easier to review. Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
Native ARM64 fp16 execution can now exercise provider paths that were previously hidden by fp32 fallback behaviour. Update the affected tests to reflect supported provider coverage and the expected numerical drift from HQNBIT_CompFp16 fp16 accumulation on large-K MatMulNBits cases. Signed-off-by: Cathal Lawlor cathal.lawlor@arm.com
e0d814e to
2483990
Compare
Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com> # Conflicts: # onnxruntime/core/optimizer/insert_cast_transformer.cc # onnxruntime/core/optimizer/insert_cast_transformer.h # onnxruntime/core/session/inference_session.cc
Value-initialize MLAS_HALF_GEMM_DATA_PARAMS at call sites so optional flags and pointers default to null/false unless explicitly set. Include <cstring> directly in matmul.cc for the native fp16 prepack memset calls. Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
| #include "test/common/random_generator.h" | ||
| #include "test/providers/provider_test_utils.h" | ||
| #include "test/util/include/scoped_env_vars.h" | ||
| #include "default_providers.h" | ||
| #include "core/session/onnxruntime_session_options_config_keys.h" | ||
|
|
| SessionOptions so; | ||
| ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); | ||
| ASSERT_EQ(so.config_options.AddConfigEntry("session.enable_cpu_fp16", "1"), Status::OK()); | ||
| ASSERT_EQ(so.config_options.AddConfigEntry("session.cpu_fp16_use_fp32_fallback_heuristic", "0"), Status::OK()); | ||
|
|
| const auto SetTypeConstraint = [&](const std::string& type_str, const NodeArg* def) { | ||
| if (!def || !def->Exists()) { | ||
| return; | ||
| } | ||
|
|
||
| TypeConstraintMap::const_iterator it = type_schema.find(type_str); | ||
| if (it == type_schema.end()) { | ||
| return; | ||
| } | ||
|
|
||
| auto type = DataTypeImpl::TypeFromProto(*(def->TypeAsProto())); | ||
| if (replace_fp16_with_float && type == DataTypeImpl::GetTensorType<MLFloat16>()) { | ||
| type = DataTypeImpl::GetTensorType<float>(); | ||
| } | ||
|
|
||
| type_constraint_map[type_str] = type; |
…handling and compatibility checks; add CPU FP16 resize model test Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
…GPU CI Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
Review — PR #28487
|
|
Thank you for the review, we will see what we will go forwards with and keep you up to date @hariharans29 |
Description
Enable selected native CPU fp16 execution paths on Arm64 through MLAS and KleidiAI.
This adds Arm64 fp16 HGEMM/MatMul/Gemm coverage, KleidiAI halfconv IMATMUL support, native prepack handling, backend selector routing, and the CPU EP plumbing needed to preserve profitable fp16 nodes when
session.enable_cpu_fp16is enabled. Unsupported or unprofitable cases continue to use the existing fp32 fallback path via the default heuristic.This also includes Apple Arm64 fp16 build enablement, mixed-EP insert-cast coverage, and fixes for fp16 routing edge cases such as zero-K halfgemm handling, backend-native packed-B contracts, and
mlas.disable_kleidiaipropagation.Motivation and Context
ORT CPU EP previously converted most fp16 CPU work back through fp32 fallback paths, which prevented Arm64 fp16 MLAS/KleidiAI kernels from being used by MatMul/Gemm and eligible convolution shapes.
The goal is to expose native fp16 execution where it is supported and profitable, while keeping the default policy conservative so existing fp32 fallback behaviour is preserved for cast-heavy, unsupported, or slower shapes.
Usage
Native CPU fp16 execution is opt-in.
Set the session config entry below to allow the CPU EP to preserve supported fp16 nodes instead of always falling back through fp32 casts:
By default, the fp16 fallback heuristic remains enabled. This means only supported/profitable fp16 MatMul/Gemm and eligible convolution paths are preserved natively; other fp16 CPU nodes continue to use the existing fp32 fallback path.
For diagnostic or benchmarking use only, the fallback heuristic can be disabled:
KleidiAI can be disabled independently through the MLAS backend selector config:
Validation
Validation performed locally on Arm64:
onnxruntime_test_allandonnxruntime_provider_test.