Skip to content

[MLAS][KleidiAI] Add Arm64 fp16 MLAS and KleidiAI support#28487

Open
Laan33 wants to merge 10 commits into
microsoft:mainfrom
Laan33:catlaw01/add-fp16-path
Open

[MLAS][KleidiAI] Add Arm64 fp16 MLAS and KleidiAI support#28487
Laan33 wants to merge 10 commits into
microsoft:mainfrom
Laan33:catlaw01/add-fp16-path

Conversation

@Laan33
Copy link
Copy Markdown
Contributor

@Laan33 Laan33 commented May 13, 2026

Description

Enable selected native CPU fp16 execution paths on Arm64 through MLAS and KleidiAI.

This adds Arm64 fp16 HGEMM/MatMul/Gemm coverage, KleidiAI halfconv IMATMUL support, native prepack handling, backend selector routing, and the CPU EP plumbing needed to preserve profitable fp16 nodes when session.enable_cpu_fp16 is enabled. Unsupported or unprofitable cases continue to use the existing fp32 fallback path via the default heuristic.

This also includes Apple Arm64 fp16 build enablement, mixed-EP insert-cast coverage, and fixes for fp16 routing edge cases such as zero-K halfgemm handling, backend-native packed-B contracts, and mlas.disable_kleidiai propagation.

Motivation and Context

ORT CPU EP previously converted most fp16 CPU work back through fp32 fallback paths, which prevented Arm64 fp16 MLAS/KleidiAI kernels from being used by MatMul/Gemm and eligible convolution shapes.

The goal is to expose native fp16 execution where it is supported and profitable, while keeping the default policy conservative so existing fp32 fallback behaviour is preserved for cast-heavy, unsupported, or slower shapes.

Usage

Native CPU fp16 execution is opt-in.

Set the session config entry below to allow the CPU EP to preserve supported fp16 nodes instead of always falling back through fp32 casts:

session.enable_cpu_fp16|1

By default, the fp16 fallback heuristic remains enabled. This means only supported/profitable fp16 MatMul/Gemm and eligible convolution paths are preserved natively; other fp16 CPU nodes continue to use the existing fp32 fallback path.

For diagnostic or benchmarking use only, the fallback heuristic can be disabled:

session.enable_cpu_fp16|1 session.cpu_fp16_use_fp32_fallback_heuristic|0

KleidiAI can be disabled independently through the MLAS backend selector config:

mlas.disable_kleidiai|1

Validation

Validation performed locally on Arm64:

  • Built onnxruntime_test_all and onnxruntime_provider_test.
  • Ran focused CPU fp16 transformer/cast tests.
  • Ran focused fp16 MatMul/Gemm/prepack tests.
  • Ran focused fp16 Conv/KleidiAI halfconv tests.
  • Ran generated-input fp32-vs-fp16 parity checks across fp16 model pairs.
  • Ran model-level CPU fp16 performance sweeps comparing disabled, default, native-forced, and KleidiAI-disabled configurations.

@Laan33 Laan33 force-pushed the catlaw01/add-fp16-path branch 2 times, most recently from b443302 to 97afce3 Compare May 19, 2026 09:25
@hariharans29 hariharans29 requested a review from Copilot May 19, 2026 16:04
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Enable opt-in native CPU fp16 execution on Arm64 by wiring MLAS and KleidiAI fp16 GEMM/Conv paths through the CPU EP, adding backend-selection config plumbing, and expanding coverage with targeted unit tests.

Changes:

  • Add CPU EP fp16 MatMul/Gemm kernels, native packed-B handling, and fp16 conv “HalfConv” (KleidiAI) routing/prepack.
  • Extend InsertCastTransformer and session config to preserve eligible fp16 CPU nodes when session.enable_cpu_fp16 is enabled (with an fp32 fallback heuristic by default).
  • Add/extend tests for fp16 MatMul/Gemm/HalfGemm/HalfConv behavior, backend selector propagation, zero-K behavior, and prepack-sharing rules.

Reviewed changes

Copilot reviewed 37 out of 37 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc Adds fp16 Conv tests covering KleidiAI eligibility, disabling KleidiAI via session config, and prepack sharing scenarios.
onnxruntime/test/providers/cpu/math/matmul_test.cc Adds fp16 MatMul test validating native packed weights are not shared across sessions.
onnxruntime/test/providers/cpu/math/gemm_test.cc Adds fp16 Gemm tests for beta==0 semantics when C is missing/present.
onnxruntime/test/mlas/unittest/test_util.h Fixes guard-buffer state reset by clearing _GuardAddress.
onnxruntime/test/mlas/unittest/test_halfgemm.h Extends HalfGemm tests to validate native-fp16 vs fp32 reference, output-processor bypass, and override enforcement.
onnxruntime/test/mlas/unittest/test_halfgemm.cpp Adds extensive HalfGemm regression tests: selector config behavior, zero-K behavior, pack-B overflow/padding, packed-B fallback behavior, and KleidiAI packed-B contract tests.
onnxruntime/test/mlas/unittest/test_conv2d.cpp Adds test ensuring HalfConv prepare honors backend selector config.
onnxruntime/test/framework/insert_cast_transformer_test.cc Adds tests for CPU fp16 opt-in behavior, heuristic routing, mixed-EP behavior, and runtime execution checks.
onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc Excludes non-owning EPs for internal-domain NHWC fp16 pool tests.
onnxruntime/core/util/math_cpu.cc Adds MatMul<MLFloat16> specialization using MLAS halfgemm when fp16 intrinsics are available.
onnxruntime/core/session/inference_session.cc Wires session config into InsertCastTransformer and aligns transformer heuristics with MLAS backend selector config.
onnxruntime/core/providers/cpu/math/matmul.h Adds CPU EP fp16 MatMul kernel class with native packed-B prepack support and backend selector config.
onnxruntime/core/providers/cpu/math/matmul.cc Registers fp16 MatMul kernels across opsets; implements native packed-B prepack and fp16 execution via halfgemm/hgemm.
onnxruntime/core/providers/cpu/math/gemm.cc Routes fp16 Gemm through MLAS halfgemm for supported cases, including no-bias beta==0, and passes backend selector config.
onnxruntime/core/providers/cpu/fp16/fp16_conv.cc Adds KleidiAI HalfConv fast path, packs weights+bias for eligible convs, and propagates backend selector config to MLAS calls.
onnxruntime/core/providers/cpu/cpu_execution_provider.cc Registers fp16 Gemm/MatMul kernels in the CPU EP kernel registry.
onnxruntime/core/optimizer/insert_cast_transformer.h Extends InsertCastTransformer API to accept cpu registry list, cpu-fp16 enablement, heuristic control, and MLAS backend selector config.
onnxruntime/core/optimizer/insert_cast_transformer.cc Implements CPU fp16 preservation policy + fallback heuristic for MatMul/Gemm; improves kernel availability checks using resolved type constraints.
onnxruntime/core/mlas/lib/platform.cpp Enables KleidiAI fp16 halfgemm/halfconv overrides on SME/SME2 capable systems.
onnxruntime/core/mlas/lib/mlasi.h Adds MLAS platform override function pointer types for halfgemm and halfconv.
onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp Renames parameter to BatchN and updates corresponding logic for clarity/consistency.
onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h Declares KleidiAI fp16 halfgemm/halfconv entry points and adds checked TryAddSize helper.
onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp Adds KleidiAI fp16 halfgemm implementation with RHS packing, fp32 input conversion, and contract validation.
onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp Adds KleidiAI fp16 halfconv implementation with IMATMUL packing/execution and locality heuristics.
onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp Fixes cache access to avoid stale reference and uses per-pad cache directly.
onnxruntime/core/mlas/lib/kai_ukernel_interface.h Adds FP16 IMATMUL and FP16 HGEMM ukernel wrapper types and accessors.
onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp Wires SME/SME2 FP16 IMATMUL and FP16 HGEMM ukernel selection.
onnxruntime/core/mlas/lib/halfgemm.h Adds overflow-safe helpers and implements CopyPackB with aligned-K zero padding.
onnxruntime/core/mlas/lib/halfgemm.cpp Extends halfgemm batch API to accept backend selector config, adds zero-K handling, native pack-B APIs, and backend-native packed-B guard.
onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp Enables CopyPackB routine in the NEON halfgemm dispatch table.
onnxruntime/core/mlas/lib/halfconv.cpp Adds public MLAS dispatch wrappers for optional halfconv backend support.
onnxruntime/core/mlas/inc/mlas.h Enables fp16 intrinsics on Apple Arm64, extends halfgemm params with packed-B flags, and adds public halfconv + native pack-B APIs.
onnxruntime/core/framework/session_state.cc Allows kernels to prepack session-local weights without requiring shareable prepack buffers.
onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc Passes MLAS backend selector config into fp16 MoE GEMM dispatch.
include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h Adds session config keys for opt-in CPU fp16 and the fp32 fallback heuristic control.
docs/OperatorKernels.md Documents fp16 CPU kernel coverage for Gemm and MatMul.
cmake/onnxruntime_mlas.cmake Adds halfconv sources and enables Arm64 fp16 MLAS sources on Apple; wires new KleidiAI sources.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +193 to +205
template <>
void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, ThreadPool* threadpool,
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) {
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
MLAS_HALF_GEMM_DATA_PARAMS data;
data.A = A;
data.lda = static_cast<size_t>(K);
data.B = B;
data.ldb = static_cast<size_t>(N);
data.C = C;
data.ldc = static_cast<size_t>(N);
MlasHalfGemmBatch(static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), 1, &data, threadpool,
mlas_backend_kernel_selector_config);
Comment on lines +220 to 233
if (trans_a == CblasNoTrans && trans_b == CblasNoTrans &&
alpha == onnxruntime::MLFloat16::One && (use_mlas_no_bias || use_mlas_bias)) {
MLAS_HALF_GEMM_DATA_PARAMS data;
data.A = a_data;
data.lda = K;
data.B = b_data;
data.ldb = N;
data.C = y_data;
data.ldc = N;
if (c_shape != nullptr) {
if (use_mlas_bias && c_shape != nullptr) {
data.Bias = c_data;
}
MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool);
MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool, mlas_backend_kernel_selector_config);
return;
Comment on lines 536 to 556
template <>
Status MoE<MLFloat16>::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFloat16* C,
int64_t M, int64_t K, int64_t N, bool transpose_B) const {
MLAS_HALF_GEMM_DATA_PARAMS params;
params.A = A;
params.lda = static_cast<size_t>(K);
params.C = C;
params.ldc = static_cast<size_t>(N);
params.AIsfp32 = false;
params.BIsfp32 = false;
params.B = B;

if (transpose_B) {
params.ldb = static_cast<size_t>(K);
} else {
params.ldb = static_cast<size_t>(N);
}

MlasHalfGemmBatch(static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), 1, &params, nullptr);
MlasHalfGemmBatch(static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), 1, &params, nullptr,
&mlas_backend_kernel_selector_config_);
return Status::OK();
Comment on lines +116 to +123
MLAS_HALF_GEMM_DATA_PARAMS data;
data.A = A.data();
data.B = B.data();
data.C = reinterpret_cast<MLAS_FP16*>(C.data());
data.lda = K;
data.ldb = N;
data.ldc = N;

Comment on lines 681 to 686
MlasHalfGemmBatch(
static_cast<size_t>(output_count),
static_cast<size_t>(group_output_channels),
static_cast<size_t>(kernel_dim),
1, &gemm_params, nullptr);
1, &gemm_params, nullptr, &mlas_backend_kernel_selector_config_);
}
Laan33 added 5 commits May 20, 2026 16:12
Add native fp16 HGEMM and halfconv IMATMUL support for Arm64 MLAS/KleidiAI, including CPU fp16 MatMul/Gemm exposure, prepack handling, backend selector routing, and focused test coverage.

Also include Apple Arm64 fp16 build enablement, insert-cast mixed-EP coverage, and fixes for fp16 routing edge cases such as zero-K halfgemm handling and backend-native packed-B contracts.

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
Ensure CPU-assigned fp16 nodes without a matching CPU fp16 kernel are routed through the existing fp32 cast fallback when a float CPU kernel exists. This preserves main-compatible behavior for fused nodes such as BiasGelu.

Also split the graph-output cast coverage so the fallback behavior is easier to review.

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
Native ARM64 fp16 execution can now exercise provider paths that were
previously hidden by fp32 fallback behaviour. Update the affected tests to
reflect supported provider coverage and the expected numerical drift from
HQNBIT_CompFp16 fp16 accumulation on large-K MatMulNBits cases.

Signed-off-by: Cathal Lawlor cathal.lawlor@arm.com
@Laan33 Laan33 force-pushed the catlaw01/add-fp16-path branch 2 times, most recently from e0d814e to 2483990 Compare May 20, 2026 22:21
Laan33 added 3 commits May 20, 2026 23:46
Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>

# Conflicts:
#	onnxruntime/core/optimizer/insert_cast_transformer.cc
#	onnxruntime/core/optimizer/insert_cast_transformer.h
#	onnxruntime/core/session/inference_session.cc
Value-initialize MLAS_HALF_GEMM_DATA_PARAMS at call sites so optional
flags and pointers default to null/false unless explicitly set.

Include <cstring> directly in matmul.cc for the native fp16 prepack
memset calls.

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 41 out of 41 changed files in this pull request and generated 3 comments.

Comment on lines 10 to 15
#include "test/common/random_generator.h"
#include "test/providers/provider_test_utils.h"
#include "test/util/include/scoped_env_vars.h"
#include "default_providers.h"
#include "core/session/onnxruntime_session_options_config_keys.h"

Comment on lines +718 to +722
SessionOptions so;
ASSERT_EQ(so.AddInitializer("B", &b), Status::OK());
ASSERT_EQ(so.config_options.AddConfigEntry("session.enable_cpu_fp16", "1"), Status::OK());
ASSERT_EQ(so.config_options.AddConfigEntry("session.cpu_fp16_use_fp32_fallback_heuristic", "0"), Status::OK());

Comment on lines +321 to +336
const auto SetTypeConstraint = [&](const std::string& type_str, const NodeArg* def) {
if (!def || !def->Exists()) {
return;
}

TypeConstraintMap::const_iterator it = type_schema.find(type_str);
if (it == type_schema.end()) {
return;
}

auto type = DataTypeImpl::TypeFromProto(*(def->TypeAsProto()));
if (replace_fp16_with_float && type == DataTypeImpl::GetTensorType<MLFloat16>()) {
type = DataTypeImpl::GetTensorType<float>();
}

type_constraint_map[type_str] = type;
Laan33 added 2 commits May 25, 2026 15:21
…handling and compatibility checks; add CPU FP16 resize model test

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
…GPU CI

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>
@hariharans29
Copy link
Copy Markdown
Member

Review — PR #28487 [MLAS][KleidiAI] Add Arm64 fp16 MLAS and KleidiAI support

Verdict: needs significant rework before merge. Direction is good and the kernel work looks solid, but the PR is far too large (~5260 LOC across 41 files), mixes unrelated changes, takes a couple of public ABI/API liberties, and relaxes a long-standing prepack invariant in a way that affects every kernel, not just the new ones.

Scope / hygiene

This PR is doing at least five things that each deserve their own MR:

  1. New MLAS public surface for native fp16 GEMM dispatch + KleidiAI half-GEMM / half-Conv backends.
  2. New CPU EP fp16 MatMul/Gemm kernel registrations (opset 1/9/11/13 for both MatMul and Gemm) and the fp16 conv prepack path in fp16_conv.cc.
  3. InsertCastTransformer rework with new opt-in policy and shape heuristics, plus a substantial constructor signature change and a new resolver-based type-constraint check.
  4. Apple ARM64 fp16 build enablement (CMake guards removed, universal-binary implications).
  5. A pile of unrelated test changes — attention_op_test.cc (disable_webgpu = true), matmul_4bits_test.cc WebGPU tolerance, nhwc_pool_in_op_test.cc provider excludes, nhwc_transformer_test.cc tolerance loosening, resize_op_test.cc excludes, the BatchSizeBatchN rename in MlasDynamicQGemmBatch, the lhs_ptrs_cache_by_pad bug fix in qgemm_kleidiai.cpp.

The qgemm lhs_ptrs_cache_by_pad fix is a real correctness fix (the previous code keyed the cache only on shape and ignored the pad pointer); it should ship on its own with its own test. The BatchSizeBatchN rename is pure churn here. The test-only EP excludes and tolerance changes need standalone justification — FusedConvWithSumFp16 loosening from default to 0.02 / 0.02 is not explained.

Substantive concerns

1. MlasHalfGemmBatch signature change. This is a public MLAS API. Appending a parameter with a C++ default argument preserves source compat for in-tree callers but does not preserve binary compat or function-pointer compat for external consumers, and the MlasHalfGemmBatchOverride typedef in mlasi.h now diverges from the public signature in a subtle way (the override takes the selector config, the public API forwards it). Either commit to the change explicitly and version it, or keep the public API stable and route the selector via a side channel (e.g., MLAS_HALF_GEMM_DATA_PARAMS::BackendSelector, mirroring how MLAS_CONV_PARAMETERS already carries it).

2. MLAS_CONV_PARAMETERS grows a field. Inserting bool InputOutputChannelsLast; mid-struct is an ABI change for anyone who composes this struct out of tree. Append at the end and document the contract, or wrap in a versioned secondary parameter block.

3. New public API by default-arg. MlasHalfGemmNativePackBSize, MlasHalfGemmNativePackB, MlasHalfConvPrepare, MlasHalfConv, MlasHalfConvPackWeightsAndBiasSize, MlasHalfConvPackWeightsAndBias all take BackendKernelSelectorConfig = nullptr as a default. Default arguments in a public C-style header are fragile (no defaults for C callers, hard to evolve). Prefer overload-free explicit-parameter forms with documented defaults in the header text.

4. Session option UX. Two flags with subtle interaction:

session.enable_cpu_fp16 = 0/1                       (default 0)
session.cpu_fp16_use_fp32_fallback_heuristic = 0/1  (default 1)

The second flag only applies when the first is 1, and uses double-negative semantics ("set to 0 to disable the fallback that the first flag's heuristic is using"). Most users will misconfigure this. Collapse into a single tri-state, e.g. session.cpu_fp16_mode = off | heuristic | forced (default off). Also: this is CPU-EP-specific behavior; it belongs in the CPU EP provider options, not the global session option namespace.

5. InsertCastTransformer ↔ MLAS coupling. The transformer now #includes core/mlas/inc/mlas.h and calls MlasHalfGemmNativePackBSize to make graph-rewrite decisions, and it stores MLAS_BACKEND_KERNEL_SELECTOR_CONFIG on the transformer. The hard-coded kMaxNativeFp16GemvM, kMinNativeFp16ConstantMatMulNK, kMinNativeFp16NK constants encode Arm cost-model thresholds inside a generic optimizer. A graph-level transformer should not be making backend-cost decisions for one specific EP. The right place for this policy is the CPU EP's GetCapability or kernel selection, not Level1-applied graph rewriting.

6. Layering: optimizer asks MLAS what its packed-B size would be. That coupling is the wrong direction. If the transformer needs that signal, expose a small predicate on the EP (e.g. CpuExecutionProvider::PrefersNativeFp16(node, shape)) and keep MLAS internals private to MLAS.

7. session_state.cc invariant relaxation. The change replaces

ORT_ENFORCE(weights_to_be_filled_in.buffers_.size() > 0,
            "The kernel ... doesn't have an implementation that can cache computed pre-packed weights");

with a silent if (is_packed && !weights_to_be_filled_in.buffers_.empty()). This change applies globally, not just to fp16 MatMul. Any kernel that returns is_packed = true without populating buffers will now silently lose its packed weights to the shared-prepack cache. This is a real semantics change that needs:

  • A separate flag in the kernel API (PrePackResult::BackendOwnedBuffers or similar) so that "I packed it but I'm keeping it" is a distinct, intentional state, not the same as "I forgot to fill in the buffers."
  • A test that exercises the legacy ORT_ENFORCE failure mode to make sure we still catch the genuinely-broken-kernel case.

8. MatMul<MLFloat16>::PrePack declines shared prepacking. Combined with #7, this means every session pays full packing cost for fp16 MatMul weights, with no path forward. The comment ("Decline shared buffers until the shared prepack cache can validate that layout") is honest but this is exactly the use case shared prepacking exists for. A backend-layout-tag on the shared buffer entry would solve it; that should be the plan rather than a permanent opt-out.

9. Hot-path allocations. halfconv_kleidiai.cpp allocates several std::vector<std::byte> / std::vector<MLAS_FP16> per MlasHalfConv call (and per worker thread in the chunked branch). The neighboring qgemm path uses thread_local KaiHalfTlsBuffers-style TLS scratch — halfconv_kleidiai.cpp should do the same.

10. TLS scratch never shrinks. g_kai_half_tls in halfgemm_kleidiai.cpp grows monotonically across calls (matches the existing g_kai_tls_qgemm pattern). For long-running inference servers, a single large-K outlier permanently inflates per-thread RSS. At minimum apply a soft cap or periodic shrink_to_fit; ideally route through an allocator the host can introspect.

11. Throws from kernel inner loop. MlasHalfGemmCopyPackB<> now MLAS_THROW_EX(std::runtime_error, ...) on size-overflow. These run on worker threads inside MlasTrySimpleParallel. Confirm the exception propagates correctly (it doesn't — MlasTrySimpleParallel doesn't catch); the test KleidiAIPackedBWithBiasThrows only exercises the public-API throw, not the threaded path. Use the existing TryAddSize/TryMultiplySizeT boolean returns and fail the dispatch cleanly.

12. Apple ARM64 fp16 enablement. The original if (NOT APPLE) guard around the fp16 source list and MLAS_F16VEC_INTRINSICS_SUPPORTED was put in for a documented reason ("compiling source files requires a hardware-specific compilation flag; when building a universal binary for APPLE this flag would cause trouble for x64 target"). The PR removes the guard with no replacement. Either prove the universal-binary path still works (Apple lipo'd x86_64+arm64 build in CI) or restore the guard for the universal-binary configuration.

13. KernelTypeConstraintsMatchAllNodeArgs cost. Iterates the entire CPU kernel-create map for every node in every transformer apply. On large graphs this is O(nodes × kernels × constraints) per Level-N pass. Cache the answer per (op_type, version, domain) tuple.

14. OpSchemaKernelTypeStrResolver as static const in a function. Magic-statics are fine, but this resolver pulls in ONNX schemas; in minimal-build configurations the schemas may not be available. Verify the #if !defined(ORT_MINIMAL_BUILD) split actually works at link time in the size-stripped builds.

15. RecordPartitionAssignment only fires on the "force CPU" branch. The "preserve fp16 by clearing EP" branch (node->SetExecutionProviderType("")) and the implicit-CPU branch don't record. If the callback is intended to track every assignment this transformer makes, it's currently lying about half of them.

16. Gemm_MLFloat16 c_data == nullptr semantics. The refactor renames support_mlas to support_mlas_bias and adds use_mlas_bias/use_mlas_no_bias paths, but I can't see the full new structure in the diff snippets. Worth a careful re-read against the test GemmNoTrans_f16 which now covers the missing-C and explicit-zero-C cases — make sure the alpha/beta combinations still hit the right branch.

Nits

  • MlasTryAddSizeT / MlasTryMultiplySizeT in halfgemm.h duplicate mul_overflow_size_t_builtin and the new TryAddSize in mlasi_kleidiai.h. Pick one home (mlasi.h already has MlasMultiplyOverflowsSizeT from [MLAS] Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback #28416) and converge.
  • KaiF16HgemmKernel defined alongside KaiF16IMatmulKernel but only used in one TU — fine, just confirm no ODR issues with the existing F32 variants.
  • using namespace and a stray <iostream> include in mlasi_kleidiai.h — please drop the iostream.
  • MatMul<MLFloat16>::Compute does std::vector<MLAS_HGEMM_DATA_PARAMS> data(max_len); and std::vector<MLAS_HALF_GEMM_DATA_PARAMS> data(max_len); per call; prefer InlinedVector per AGENTS guidance.
  • MakeCpuFp16Model / MakeCpuFp16MatMulModelWithShapes etc. write .onnx files to CWD then std::filesystem::remove them — these will leak under crash. Use TemporaryDirectory or in-memory loaders.
  • New tests use raw string literals for session option keys in places — the Copilot bot already flagged this; tests that already exist in this PR use the kOrt... constants, just be consistent.
  • Conv2d_HalfConv.PrepareRespectsBackendSelectorConfig — make this MLAS_CONV_PARAMETERS parameters{}; (zero-init) for the same reason every other call site is being zero-inited in this PR.

Bottom line

The kernels and KleidiAI plumbing look like good work and I want them in tree. But this PR as posted is asking maintainers to approve simultaneously:

  • a public MLAS API/ABI extension via default arguments,
  • a long-standing SessionState invariant relaxation,
  • a new policy layer in the generic graph optimizer that reaches into MLAS,
  • two new session-scoped config keys with awkward UX,
  • Apple universal-binary build enablement,
  • and several unrelated test changes.

Suggested split:

  1. Independent fixes firstlhs_ptrs_cache_by_pad correctness fix (with regression test), the BatchSizeBatchN rename, the attention_op_test/matmul_4bits_test/resize_op_test exclusions, and the nhwc_transformer_test tolerance change each justified separately.
  2. MLAS half-gemm/half-conv API additions with BackendSelector plumbed through the data-params struct (not as default args), no session_state.cc change yet.
  3. CPU EP fp16 MatMul/Gemm kernels + the fp16_conv.cc prepack path, using the API from MR Remove vsts test runner in cmake file #2 with MatMul<MLFloat16>::PrePack returning a sentinel that the prepack cache understands (modeling fix for Set up CI with Azure Pipelines #7/Fix build #8).
  4. InsertCastTransformer policy redesigned as an EP-side capability decision (CPU EP GetCapability enhancement) rather than a generic transformer pass, with a single tri-state session option.
  5. Apple ARM64 fp16 enablement with documented universal-binary CI evidence.

@Laan33
Copy link
Copy Markdown
Contributor Author

Laan33 commented May 27, 2026

Thank you for the review, we will see what we will go forwards with and keep you up to date @hariharans29

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants