diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 7ae18db235ccb..7b623fe4cbaad 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -23,6 +23,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/threading.cpp
${MLAS_SRC_DIR}/sgemm.cpp
${MLAS_SRC_DIR}/halfgemm.cpp
+ ${MLAS_SRC_DIR}/halfconv.cpp
${MLAS_SRC_DIR}/qgemm.cpp
${MLAS_SRC_DIR}/qdwconv.cpp
${MLAS_SRC_DIR}/convolve.cpp
@@ -309,6 +310,8 @@ function(setup_kleidiai)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/kai_ukernel_interface.cpp
${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp
+ ${MLAS_SRC_DIR}/kleidiai/halfgemm_kleidiai.cpp
+ ${MLAS_SRC_DIR}/kleidiai/halfconv_kleidiai.cpp
${MLAS_SRC_DIR}/kleidiai/sbgemm_kleidiai.cpp
${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp
${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp
@@ -363,7 +366,7 @@ function (setup_arm_neon_nchwc)
${MLAS_SRC_DIR}/aarch64/SconvNchwcKernelNeon.S
${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S
${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S
- )
+ )
endif()
list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC)
set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE)
@@ -546,8 +549,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
- if (NOT APPLE)
- set(mlas_platform_srcs
+ set(mlas_platform_srcs
${mlas_platform_srcs}
${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
@@ -573,37 +575,36 @@ else()
${MLAS_SRC_DIR}/gelu_neon_fp16.h
${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
)
- if (onnxruntime_USE_ARM_NEON_NCHWC)
- list(APPEND mlas_platform_srcs
- ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S
- ${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S
- ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S
- )
- endif()
- set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
- set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
- set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16_8bit.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
- if (onnxruntime_USE_ARM_NEON_NCHWC)
- set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
- endif()
- set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ if (onnxruntime_USE_ARM_NEON_NCHWC AND NOT APPLE)
+ list(APPEND mlas_platform_srcs
+ ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S
+ ${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S
+ ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S
+ )
+ endif()
+ set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
+ set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
+ set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16_8bit.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ if (onnxruntime_USE_ARM_NEON_NCHWC AND NOT APPLE)
+ set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
endif()
+ set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs})
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index d0d8e750285d4..a0c8ae3dfe4ca 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -174,10 +174,10 @@ The **OpSet Version** column uses the following notation:
|||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)|
|||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)|
|Gelu|*in* X:**T**
*out* Y:**T**|20+|**T** = tensor(float)|
-|Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
-|||[11, 12]|**T** = tensor(double), tensor(float)|
-|||[9, 10]|**T** = tensor(double), tensor(float)|
-|||[7, 8]|**T** = tensor(double), tensor(float)|
+|Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|GlobalAveragePool|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)|
|||[1, 21]|**T** = tensor(float)|
|GlobalLpPool|*in* X:**T**
*out* Y:**T**|2+|**T** = tensor(float)|
@@ -258,9 +258,9 @@ The **OpSet Version** column uses the following notation:
|||[18, 21]|**T** = tensor(float)|
|||[11, 17]|**T** = tensor(float)|
|||[2, 10]|**T** = tensor(float)|
-|MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
-|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
-|||[1, 8]|**T** = tensor(double), tensor(float)|
+|MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||[1, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulInteger|*in* A:**T1**
*in* B:**T2**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*out* Y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int32)|
|Max|*in* data_0:**T**
*out* max:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index 9d61165927d8c..2b1d4d8701563 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -462,6 +462,24 @@ static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFil
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
+// Enables opt-in CPU EP float16 execution for supported ops.
+// This allows InsertCastTransformer to preserve eligible fp16 CPU nodes instead of inserting fp16<->fp32 casts.
+// The fp32 fallback heuristic below is still enabled by default and may keep selected shapes on the fp32 path when
+// native fp16 is not expected to be profitable for the active MLAS backend.
+// Option values:
+// - "0": CPU fp16 is not explicitly enabled, and default cast/fallback behavior is used. [DEFAULT]
+// - "1": Enable CPU fp16 preservation for the current opt-in scope.
+static const char* const kOrtSessionOptionsEnableCpuFp16 = "session.enable_cpu_fp16";
+
+// Controls the CPU fp16 -> fp32 fallback heuristic used when session.enable_cpu_fp16 is "1".
+// The heuristic keeps native CPU fp16 only for cases expected to be profitable, such as supported GEMV-like shapes or
+// constant-RHS MatMul when the active MLAS backend reports a native packed-B path. Other eligible nodes are assigned to
+// CPU through fp32 casts when a valid CPU fp32 kernel exists. This is the recommended/default CPU fp16 mode.
+// Option values:
+// - "0": Do not use the heuristic; preserve native CPU fp16 more broadly for eligible CPU fp16 kernels.
+// - "1": Use the fp32 fallback heuristic. [DEFAULT]
+static const char* const kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic = "session.cpu_fp16_use_fp32_fallback_heuristic";
+
// Use LUT (Lookup Table) based GEMM for quantized models when available.
// Option values:
// - "0": Do not use LUT based GEMM. [DEFAULT]
diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc
index c3194f1c57ba7..3cc1bdca98490 100644
--- a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc
@@ -536,7 +536,7 @@ Status MoE::ComputeGEMM(const float* A, const float* B, float* C,
template <>
Status MoE::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFloat16* C,
int64_t M, int64_t K, int64_t N, bool transpose_B) const {
- MLAS_HALF_GEMM_DATA_PARAMS params;
+ MLAS_HALF_GEMM_DATA_PARAMS params{};
params.A = A;
params.lda = static_cast(K);
params.C = C;
@@ -551,7 +551,8 @@ Status MoE::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFlo
params.ldb = static_cast(N);
}
- MlasHalfGemmBatch(static_cast(M), static_cast(N), static_cast(K), 1, ¶ms, nullptr);
+ MlasHalfGemmBatch(static_cast(M), static_cast(N), static_cast(K), 1, ¶ms, nullptr,
+ &mlas_backend_kernel_selector_config_);
return Status::OK();
}
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index 6ef2319c1d3f4..69cfe16c20ff5 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -517,13 +517,10 @@ Status SessionState::PrepackConstantInitializedTensors(
is_packed,
&weights_to_be_filled_in));
- if (is_packed) {
- // BUG CHECK: Ensure that the kernel has filled in the pre-packed weight
- // to be cached if the weight was pre-packed
- ORT_ENFORCE(weights_to_be_filled_in.buffers_.size() > 0,
- "The kernel corresponding to the node ", node.Name(),
- " doesn't have an implementation that can cache computed pre-packed weights");
-
+ // Some kernels pre-pack for their own session but intentionally do not
+ // expose those buffers for sharing, for example when the packed layout is
+ // backend-specific and the shared container has no layout metadata.
+ if (is_packed && !weights_to_be_filled_in.buffers_.empty()) {
const auto& op_type = node.OpType();
// Sanity check
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index ddb9daa5e244b..1552c15db2e6d 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -92,15 +92,7 @@ Module Name:
#if (!defined(_MSC_VER)) || (_MSC_VER >= 1930)
#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC)
-#if !defined(__APPLE__)
-// Had to temporary disable fp16 under APPLE ARM64, as compiling
-// the source files require a hardware specific compilation flag.
-// When building an universial binary for APPLE, this flag would
-// cause trouble for x64 target.
-
#define MLAS_F16VEC_INTRINSICS_SUPPORTED
-
-#endif //
#endif // ARM64
#endif // Visual Studio 16 or earlier does not support fp16 intrinsic
@@ -906,6 +898,7 @@ struct MLAS_CONV_PARAMETERS {
float Beta;
MLAS_CONV_ALGORITHM Algorithm;
ptrdiff_t ThreadCount;
+ bool InputOutputChannelsLast;
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr;
const void* PackedFilter = nullptr;
size_t PackedFilterGroupStride = 0;
@@ -1898,11 +1891,20 @@ struct MLAS_HALF_GEMM_DATA_PARAMS {
const MLAS_HALF_GEMM_POSTPROCESSOR* OutputProcessor = nullptr;
bool AIsfp32 = false; /**< matrix A is fp32, needs to be casted into fp16*/
bool BIsfp32 = false; /**< matrix B is fp32, needs to be casted into fp16*/
+ bool BIsPacked = false; /**< matrix B is pre-packed by MlasHalfGemmPackB/MlasHalfGemmConvertPackB */
+ /**
+ * Matrix B uses a backend-specific direct-consumption packed layout.
+ * When true, B must be produced by MlasHalfGemmNativePackB, ldb must be 0,
+ * Bias must be nullptr, and OutputProcessor must be nullptr.
+ */
+ bool BIsBackendNativePacked = false;
};
/**
* @brief Half precision Batched GEMM: C = A * B + Bias
- * Either A or B can be fp32 or fp16
+ * Either A or B can be fp32 or fp16.
+ * Backend-native packed B is a constrained direct-consumption layout
+ * and does not support runtime Bias or OutputProcessor.
*
* Note: We only support uniform batching, so shapes and types of the
* input must be same across all parameter blocks.
@@ -1913,6 +1915,7 @@ struct MLAS_HALF_GEMM_DATA_PARAMS {
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] ThreadPool
+ * @param[in] BackendKernelSelectorConfig Optional backend selector config
* @return
*/
void
@@ -1923,7 +1926,8 @@ MlasHalfGemmBatch(
const size_t K,
const size_t BatchN,
const MLAS_HALF_GEMM_DATA_PARAMS* DataParams,
- MLAS_THREADPOOL* ThreadPool
+ MLAS_THREADPOOL* ThreadPool,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr
);
/**
@@ -1964,6 +1968,106 @@ MlasHalfGemmPackB(
void* PackedB
);
+/**
+ * @brief For half precision GEMM, returns the size of a backend-native
+ * direct-consumption packing buffer for right hand side B.
+ *
+ * The returned layout is only valid when MlasHalfGemmBatch consumes
+ * MLAS_HALF_GEMM_DATA_PARAMS with BIsBackendNativePacked set to true, ldb set
+ * to 0, Bias set to nullptr, and OutputProcessor set to nullptr. Returns 0
+ * when no backend-native format is available for the given transpose/shape/config.
+ */
+size_t
+MLASCALL
+MlasHalfGemmNativePackBSize(
+ CBLAS_TRANSPOSE TransA,
+ CBLAS_TRANSPOSE TransB,
+ size_t N,
+ size_t K,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr
+ );
+
+/**
+ * @brief For half precision GEMM, pack right hand side B into the
+ * backend-native direct-consumption format.
+ *
+ * Returns false when the backend-native format is unavailable for the given
+ * transpose/shape/config. The resulting buffer must be passed to
+ * MlasHalfGemmBatch with ldb set to 0, BIsBackendNativePacked set to true,
+ * Bias set to nullptr, and OutputProcessor set to nullptr.
+ */
+bool
+MLASCALL
+MlasHalfGemmNativePackB(
+ CBLAS_TRANSPOSE TransA,
+ CBLAS_TRANSPOSE TransB,
+ size_t N,
+ size_t K,
+ const MLAS_FP16* B,
+ size_t ldb,
+ void* PackedB,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr
+ );
+
+bool
+MLASCALL
+MlasHalfConvPrepare(
+ MLAS_CONV_PARAMETERS* Parameters,
+ size_t Dimensions,
+ size_t BatchCount,
+ size_t GroupCount,
+ size_t InputChannels,
+ const int64_t* InputShape,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape,
+ const int64_t* Padding,
+ const int64_t* StrideShape,
+ const int64_t* OutputShape,
+ size_t FilterCount,
+ const MLAS_ACTIVATION* Activation,
+ size_t* WorkingBufferSize,
+ float Beta,
+ bool InputOutputChannelsLast,
+ MLAS_THREADPOOL* ThreadPool,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr);
+
+bool
+MLASCALL
+MlasHalfConv(
+ const MLAS_CONV_PARAMETERS* Parameters,
+ const MLAS_FP16* Input,
+ const MLAS_FP16* Filter,
+ // When true, Filter must point to a packed weights+bias buffer produced by
+ // MlasHalfConvPackWeightsAndBias and Bias must be nullptr.
+ bool FilterAndBiasArePacked,
+ const MLAS_FP16* Bias,
+ MLAS_FP16* WorkingBuffer,
+ MLAS_FP16* Output,
+ MLAS_THREADPOOL* ThreadPool);
+
+size_t
+MLASCALL
+MlasHalfConvPackWeightsAndBiasSize(
+ size_t FilterCount,
+ size_t InputChannels,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr);
+
+bool
+MLASCALL
+MlasHalfConvPackWeightsAndBias(
+ size_t FilterCount,
+ size_t InputChannels,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape,
+ const MLAS_FP16* Filter,
+ // Optional bias to bake into the packed direct-consumption buffer.
+ const MLAS_FP16* Bias,
+ void* PackedWeightsAndBias,
+ MLAS_THREADPOOL* ThreadPool,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr);
+
/**
* @brief For half precision GEMM, convert the float matrix B
* to half precision and pack it into a packing buffer
diff --git a/onnxruntime/core/mlas/lib/halfconv.cpp b/onnxruntime/core/mlas/lib/halfconv.cpp
new file mode 100644
index 0000000000000..22fa48dbc2874
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/halfconv.cpp
@@ -0,0 +1,158 @@
+//
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: MIT
+//
+
+/*++
+
+Module Name:
+
+ halfconv.cpp
+
+Abstract:
+
+ This module implements public dispatch wrappers for optional half precision
+ convolution backends.
+
+--*/
+
+#include "mlasi.h"
+
+bool
+MLASCALL
+MlasHalfConvPrepare(
+ MLAS_CONV_PARAMETERS* Parameters,
+ size_t Dimensions,
+ size_t BatchCount,
+ size_t GroupCount,
+ size_t InputChannels,
+ const int64_t* InputShape,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape,
+ const int64_t* Padding,
+ const int64_t* StrideShape,
+ const int64_t* OutputShape,
+ size_t FilterCount,
+ const MLAS_ACTIVATION* Activation,
+ size_t* WorkingBufferSize,
+ float Beta,
+ bool InputOutputChannelsLast,
+ MLAS_THREADPOOL* ThreadPool,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
+ )
+{
+ if (GetMlasPlatform().MlasHalfConvPrepareOverride == nullptr) {
+ return false;
+ }
+
+ return GetMlasPlatform().MlasHalfConvPrepareOverride(
+ Parameters,
+ Dimensions,
+ BatchCount,
+ GroupCount,
+ InputChannels,
+ InputShape,
+ KernelShape,
+ DilationShape,
+ Padding,
+ StrideShape,
+ OutputShape,
+ FilterCount,
+ Activation,
+ WorkingBufferSize,
+ Beta,
+ InputOutputChannelsLast,
+ ThreadPool,
+ BackendKernelSelectorConfig);
+}
+
+bool
+MLASCALL
+MlasHalfConv(
+ const MLAS_CONV_PARAMETERS* Parameters,
+ const MLAS_FP16* Input,
+ const MLAS_FP16* Filter,
+ bool FilterAndBiasArePacked,
+ const MLAS_FP16* Bias,
+ MLAS_FP16* WorkingBuffer,
+ MLAS_FP16* Output,
+ MLAS_THREADPOOL* ThreadPool
+ )
+{
+ if (GetMlasPlatform().MlasHalfConvOverride == nullptr) {
+ return false;
+ }
+
+ if (FilterAndBiasArePacked && Bias != nullptr) {
+ return false;
+ }
+
+ return GetMlasPlatform().MlasHalfConvOverride(
+ Parameters,
+ Input,
+ Filter,
+ FilterAndBiasArePacked,
+ Bias,
+ WorkingBuffer,
+ Output,
+ ThreadPool);
+}
+
+size_t
+MLASCALL
+MlasHalfConvPackWeightsAndBiasSize(
+ size_t FilterCount,
+ size_t InputChannels,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
+ )
+{
+ if (BackendKernelSelectorConfig != nullptr && !BackendKernelSelectorConfig->use_kleidiai) {
+ return 0;
+ }
+
+ if (GetMlasPlatform().MlasHalfConvPackWeightsAndBiasSizeOverride == nullptr) {
+ return 0;
+ }
+
+ return GetMlasPlatform().MlasHalfConvPackWeightsAndBiasSizeOverride(
+ FilterCount,
+ InputChannels,
+ KernelShape,
+ DilationShape);
+}
+
+bool
+MLASCALL
+MlasHalfConvPackWeightsAndBias(
+ size_t FilterCount,
+ size_t InputChannels,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape,
+ const MLAS_FP16* Filter,
+ const MLAS_FP16* Bias,
+ void* PackedWeightsAndBias,
+ MLAS_THREADPOOL* ThreadPool,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
+ )
+{
+ if (BackendKernelSelectorConfig != nullptr && !BackendKernelSelectorConfig->use_kleidiai) {
+ return false;
+ }
+
+ if (GetMlasPlatform().MlasHalfConvPackWeightsAndBiasOverride == nullptr) {
+ return false;
+ }
+
+ return GetMlasPlatform().MlasHalfConvPackWeightsAndBiasOverride(
+ FilterCount,
+ InputChannels,
+ KernelShape,
+ DilationShape,
+ Filter,
+ Bias,
+ PackedWeightsAndBias,
+ ThreadPool);
+}
diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp
index 66a335665d024..e7fe787b7f7dd 100644
--- a/onnxruntime/core/mlas/lib/halfgemm.cpp
+++ b/onnxruntime/core/mlas/lib/halfgemm.cpp
@@ -19,9 +19,38 @@ Module Name:
#include "mlas_float16.h"
#include "halfgemm.h"
+#if defined(USE_KLEIDIAI)
+#include "kleidiai/mlasi_kleidiai.h"
+#endif
#include
+static void
+MlasHalfGemmZeroKBatch(
+ const size_t M,
+ const size_t N,
+ const size_t BatchN,
+ const MLAS_HALF_GEMM_DATA_PARAMS* DataParams
+ )
+{
+ for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
+ const auto* Data = &DataParams[gemm_i];
+ auto* C = reinterpret_cast(Data->C);
+ const auto* Bias = reinterpret_cast(Data->Bias);
+ const size_t ldc = Data->ldc;
+
+ for (size_t m = 0; m < M; m++) {
+ for (size_t n = 0; n < N; n++) {
+ C[m * ldc + n] = Bias == nullptr ? MLAS_FP16::FromBits(0) : Bias[n];
+ }
+ }
+
+ if (Data->OutputProcessor != nullptr) {
+ Data->OutputProcessor->Process(Data->C, 0, 0, M, N, ldc);
+ }
+ }
+}
+
bool MLASCALL
MlasFp16AccelerationSupported()
{
@@ -41,9 +70,41 @@ MlasHalfGemmBatch(
const size_t K,
const size_t BatchN,
const MLAS_HALF_GEMM_DATA_PARAMS* DataParams,
- MLAS_THREADPOOL* ThreadPool
+ MLAS_THREADPOOL* ThreadPool,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
)
{
+ MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig);
+
+ if (BatchN == 0 || M == 0 || N == 0) {
+ return;
+ }
+
+ if (K == 0) {
+ MlasHalfGemmZeroKBatch(M, N, BatchN, DataParams);
+ return;
+ }
+
+#if defined(USE_KLEIDIAI)
+ if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) &&
+ GetMlasPlatform().MlasHalfGemmBatchOverride != nullptr &&
+ GetMlasPlatform().MlasHalfGemmBatchOverride(
+ M, N, K, BatchN, DataParams, ThreadPool, BackendKernelSelectorConfig)) {
+ return;
+ }
+#endif
+
+ // BIsPacked denotes the generic MLAS halfgemm packed-B layout and can be
+ // consumed here. Backend-native packed layouts are separate and must not
+ // silently fall through to the generic kernels.
+ for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
+ if (DataParams[gemm_i].BIsBackendNativePacked) {
+ MLAS_THROW_EX(
+ std::runtime_error,
+ "backend-native halfgemm packed B is not supported by generic MLAS halfgemm");
+ }
+ }
+
const MLAS_HALFGEMM_DISPATCH* dispatch = MlasHalfGemmGetDispatch();
MLAS_HALFGEMM_OPERATION* operation = dispatch->Operation;
@@ -128,11 +189,24 @@ MlasHalfGemmPackBSize(
// No packing routine provided
return 0;
}
- const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1);
- const size_t BytesRequired = N * AlignedK * FP16_SIZE + padding;
+ size_t aligned_k_input = 0;
+ if (MlasTryAddSizeT(K, PackedK - 1, &aligned_k_input)) {
+ return 0;
+ }
+ const size_t AlignedK = aligned_k_input & ~(PackedK - 1);
+
+ size_t BytesRequired = 0;
+ if (MlasTryMultiplySizeT(N, AlignedK, &BytesRequired) ||
+ MlasTryMultiplySizeT(BytesRequired, FP16_SIZE, &BytesRequired) ||
+ MlasTryAddSizeT(BytesRequired, padding, &BytesRequired)) {
+ return 0;
+ }
const size_t BufferAlignment = MlasGetPreferredBufferAlignment();
- const size_t AlignedBytesRequired =
- (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1);
+ size_t aligned_bytes_input = 0;
+ if (MlasTryAddSizeT(BytesRequired, BufferAlignment - 1, &aligned_bytes_input)) {
+ return 0;
+ }
+ const size_t AlignedBytesRequired = aligned_bytes_input & ~(BufferAlignment - 1);
return AlignedBytesRequired;
}
@@ -150,6 +224,62 @@ MlasHalfGemmPackB(
dispatch->CopyPackBRoutine((_mlas_fp16_*)PackedB, (const _mlas_fp16_*)B, ldb, N, K);
}
+size_t
+MLASCALL
+MlasHalfGemmNativePackBSize(
+ CBLAS_TRANSPOSE TransA,
+ CBLAS_TRANSPOSE TransB,
+ size_t N,
+ size_t K,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
+ )
+{
+#if defined(USE_KLEIDIAI)
+ if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) &&
+ GetMlasPlatform().MlasHalfGemmPackBSizeOverride != nullptr) {
+ return GetMlasPlatform().MlasHalfGemmPackBSizeOverride(TransA, TransB, N, K);
+ }
+#else
+ MLAS_UNREFERENCED_PARAMETER(TransA);
+ MLAS_UNREFERENCED_PARAMETER(TransB);
+ MLAS_UNREFERENCED_PARAMETER(N);
+ MLAS_UNREFERENCED_PARAMETER(K);
+ MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig);
+#endif
+ return 0;
+}
+
+bool
+MLASCALL
+MlasHalfGemmNativePackB(
+ CBLAS_TRANSPOSE TransA,
+ CBLAS_TRANSPOSE TransB,
+ size_t N,
+ size_t K,
+ const MLAS_FP16* B,
+ size_t ldb,
+ void* PackedB,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
+ )
+{
+#if defined(USE_KLEIDIAI)
+ if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) &&
+ GetMlasPlatform().MlasHalfGemmPackBOverride != nullptr) {
+ return GetMlasPlatform().MlasHalfGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB);
+ }
+#else
+ MLAS_UNREFERENCED_PARAMETER(TransA);
+ MLAS_UNREFERENCED_PARAMETER(TransB);
+ MLAS_UNREFERENCED_PARAMETER(N);
+ MLAS_UNREFERENCED_PARAMETER(K);
+ MLAS_UNREFERENCED_PARAMETER(B);
+ MLAS_UNREFERENCED_PARAMETER(ldb);
+ MLAS_UNREFERENCED_PARAMETER(PackedB);
+ MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig);
+#endif
+ return false;
+}
+
void
MLASCALL
MlasHalfGemmConvertPackB(
@@ -574,7 +704,7 @@ MlasGemmBatch(
const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = {
MlasHalfGemmOperation,
- nullptr,
+ MlasHalfGemmCopyPackB,
MlasHalfGemmConvertPackB,
MLAS_HALF_GEMM_KERNEL_DEFAULT::PackedK,
MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM,
diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h
index 529db48f58e6f..4d0a9b9888b3a 100644
--- a/onnxruntime/core/mlas/lib/halfgemm.h
+++ b/onnxruntime/core/mlas/lib/halfgemm.h
@@ -32,14 +32,57 @@ Module Name:
#pragma once
-#include
#include
+#include
+#include
+#include
+#include
#include
#include "mlasi.h"
#include "mlas_float16.h"
+MLAS_FORCEINLINE
+bool
+MlasTryMultiplySizeT(
+ size_t a,
+ size_t b,
+ size_t* out
+ )
+{
+#if defined(__has_builtin)
+#if __has_builtin(__builtin_mul_overflow)
+ return __builtin_mul_overflow(a, b, out);
+#endif
+#endif
+ if (b != 0 && a > (std::numeric_limits::max)() / b) {
+ return true;
+ }
+ *out = a * b;
+ return false;
+}
+
+MLAS_FORCEINLINE
+bool
+MlasTryAddSizeT(
+ size_t a,
+ size_t b,
+ size_t* out
+ )
+{
+#if defined(__has_builtin)
+#if __has_builtin(__builtin_add_overflow)
+ return __builtin_add_overflow(a, b, out);
+#endif
+#endif
+ if (a > (std::numeric_limits::max)() - b) {
+ return true;
+ }
+ *out = a + b;
+ return false;
+}
+
/**
* @brief Define the default striding parameters for
* the half precision gemm operation
@@ -71,12 +114,46 @@ MlasHalfGemmCopyPackB(
size_t CountK
)
{
- MLAS_UNREFERENCED_PARAMETER(D);
- MLAS_UNREFERENCED_PARAMETER(B);
- MLAS_UNREFERENCED_PARAMETER(ldb);
- MLAS_UNREFERENCED_PARAMETER(CountN);
- MLAS_UNREFERENCED_PARAMETER(CountK);
- // No packing needed by default
+ size_t aligned_count_k_input = 0;
+ if (MlasTryAddSizeT(CountK, KernelType::PackedK - 1, &aligned_count_k_input)) {
+ MLAS_THROW_EX(std::runtime_error, "MlasHalfGemmCopyPackB aligned K overflow");
+ }
+ const size_t AlignedCountK = aligned_count_k_input & ~(KernelType::PackedK - 1);
+ size_t PaddingCountK = AlignedCountK - CountK;
+
+ if (ldb == CountN) {
+ size_t bytes_to_copy = 0;
+ if (MlasTryMultiplySizeT(CountK, CountN, &bytes_to_copy) ||
+ MlasTryMultiplySizeT(bytes_to_copy, sizeof(_mlas_fp16_), &bytes_to_copy)) {
+ MLAS_THROW_EX(std::runtime_error, "MlasHalfGemmCopyPackB size overflow");
+ }
+ std::memcpy(D, B, bytes_to_copy);
+ if (PaddingCountK > 0) {
+ size_t padding_bytes = 0;
+ if (MlasTryMultiplySizeT(PaddingCountK, CountN, &padding_bytes) ||
+ MlasTryMultiplySizeT(padding_bytes, sizeof(_mlas_fp16_), &padding_bytes)) {
+ MLAS_THROW_EX(std::runtime_error, "MlasHalfGemmCopyPackB padding size overflow");
+ }
+ std::memset(D + CountK * CountN, 0, padding_bytes);
+ }
+ return;
+ }
+
+ size_t row_bytes = 0;
+ if (MlasTryMultiplySizeT(CountN, sizeof(_mlas_fp16_), &row_bytes)) {
+ MLAS_THROW_EX(std::runtime_error, "MlasHalfGemmCopyPackB row size overflow");
+ }
+ while (CountK > 0) {
+ std::memcpy(D, B, row_bytes);
+ B += ldb;
+ D += CountN;
+ CountK--;
+ }
+ while (PaddingCountK > 0) {
+ std::memset(D, 0, row_bytes);
+ D += CountN;
+ PaddingCountK--;
+ }
}
/**
diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
index d7f5a90b00589..f9bb3da1c5b57 100644
--- a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
@@ -179,7 +179,7 @@ MlasHalfGemmKernel(
const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon = {
MlasHalfGemmOperation,
- nullptr,
+ MlasHalfGemmCopyPackB,
MlasHalfGemmConvertPackB,
MLAS_HALF_GEMM_KERNEL_NEON::PackedK,
MLAS_HALF_GEMM_KERNEL_NEON::KernelMaxM,
diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp
index 6ee80594c6b49..a352fc85abbd1 100644
--- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp
+++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp
@@ -28,6 +28,7 @@
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h"
// IMATMUL
#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
+#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h"
// SME2 kernels
// GEMM/QGEMM/SBGEMM
@@ -40,6 +41,11 @@
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
// IMATMUL
#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h"
+#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h"
+
+// FP16 HGEMM kernels
+#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla.h"
+#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h"
#if defined(ENABLE_QMX_KERNELS)
// QMX kernels (optional)
@@ -229,6 +235,12 @@ const KaiF32IMatmulKernel imatmul_conv_sme =
const KaiF32IMatmulKernel imatmul_conv_sme2 =
KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa);
+const KaiF16IMatmulKernel imatmul_f16_conv_sme =
+ KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa);
+
+const KaiF16IMatmulKernel imatmul_f16_conv_sme2 =
+ KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa);
+
const KaiBF16SBgemmKernel sbgemm_gemm_sme2 =
KAI_WRAP_UKERNEL_RUN_MATMUL_11(matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa);
@@ -246,6 +258,12 @@ const KaiDynamicQGemmKernel qgemm_gemm_sme =
const KaiDynamicQGemmKernel qgemm_gemm_sme2 =
KAI_WRAP_UKERNEL_RUN_MATMUL_11(matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa);
+const KaiF16HgemmKernel hgemm_sme =
+ KAI_WRAP_UKERNEL_RUN_MATMUL_10_LHS_OFFSET(matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla);
+
+const KaiF16HgemmKernel hgemm_sme2 =
+ KAI_WRAP_UKERNEL_RUN_MATMUL_10_LHS_PACKED_OFFSET(matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot);
+
#if defined(ENABLE_QMX_KERNELS)
@@ -350,6 +368,14 @@ const KaiF32IMatmulKernel& GetKleidiAIF32IMatmulUKernel() {
}
}
+const KaiF16IMatmulKernel& GetKleidiAIF16IMatmulUKernel() {
+ if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) {
+ return imatmul_f16_conv_sme2;
+ } else {
+ return imatmul_f16_conv_sme;
+ }
+}
+
const KaiDynamicQGemmKernel& GetKleidiAIQGemmUKernel() {
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) {
return qgemm_gemm_sme2;
@@ -372,3 +398,11 @@ const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel() {
// Currently only SME2 variant exists for bfloat16/SBGEMM kernel
return sbgemm_gemm_sme2;
}
+
+const KaiF16HgemmKernel& GetKleidiAIHgemmUKernel() {
+ if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) {
+ return hgemm_sme2;
+ } else {
+ return hgemm_sme;
+ }
+}
diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h
index 155ecf1762b3b..9b42e30a969d0 100644
--- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h
+++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h
@@ -18,6 +18,10 @@
#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h"
+#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h"
+
+#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p_interface.h"
+
// Wrapper type that carries a stable "name" alongside the KAI ukernel interface.
// This avoids needing to infer which underlying microkernel was selected from a function pointer.
template
@@ -41,8 +45,14 @@ using KaiDynamicQGemmKernel = KaiMatmulKernel;
+// Wrapper for FP16 IMATMUL kernels used by the KleidiAI convolution implementation.
+using KaiF16IMatmulKernel = KaiMatmulKernel;
+
using KaiBF16SBgemmKernel = KaiMatmulKernel;
+// Wrapper for FP16 HGEMM kernels producing FP16 output.
+using KaiF16HgemmKernel = KaiMatmulKernel;
+
// Returns the selected Qnbit GEMM ukernel based on runtime CPU capabilities.
const KaiQnbitGemmKernel& GetKleidiAIGemmUKernel();
@@ -61,5 +71,11 @@ const KaiF32SgemvKernel& GetKleidiAISGemvUKernel();
// Returns the selected FP32 IMATMUL ukernel used by the KleidiAI convolution implementation.
const KaiF32IMatmulKernel& GetKleidiAIF32IMatmulUKernel();
+// Returns the selected FP16 IMATMUL ukernel used by the KleidiAI convolution implementation.
+const KaiF16IMatmulKernel& GetKleidiAIF16IMatmulUKernel();
+
// Returns the selected BF16 SBGEMM ukernel used by the KleidiAI based on runtime CPU capabilities.
-const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel();
\ No newline at end of file
+const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel();
+
+// Returns the selected FP16 HGEMM ukernel based on runtime CPU capabilities.
+const KaiF16HgemmKernel& GetKleidiAIHgemmUKernel();
diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
index cca4f5a19c417..1b8b80ed496d9 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
+++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
@@ -480,14 +480,12 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s
1, 1
};
- auto& lhs_ptrs_cache = lhs_ptrs_cache_by_pad[cur_pad_ptr];
-
std::shared_ptr lhs_ptrs;
- if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) {
+ if (auto found = lhs_ptrs_cache_by_pad[cur_pad_ptr].find(key); found != lhs_ptrs_cache_by_pad[cur_pad_ptr].end()) {
lhs_ptrs = found->second;
} else {
lhs_ptrs = LhsPtrFill(ci, ih, iw, kh, kw, sh, sw, padding, &pad_ptr[0]);
- lhs_ptrs_cache[key] = lhs_ptrs;
+ lhs_ptrs_cache_by_pad[cur_pad_ptr][key] = lhs_ptrs;
}
MultiThreadedLHSPackSme(ThreadPool, ci, m, kh, kw, &lhs_ptrs[0], &lhs[0], activation_src, &pad_ptr[0]);
diff --git a/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp
new file mode 100644
index 0000000000000..f1261b399e380
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp
@@ -0,0 +1,1002 @@
+//
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: MIT
+//
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h"
+#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h"
+#include "kai_ukernel_interface.h"
+#include "mlasi_kleidiai.h"
+
+namespace
+{
+
+// Cache-oriented heuristics for the chunked IMATMUL path. These bound the
+// packed LHS working set so we only split when the single-pass pack/compute
+// footprint is large enough that chunking can improve locality.
+constexpr size_t AutomaticMaximumLhsChunkBytes = 2 * 1024 * 1024;
+constexpr size_t MinimumLhsPackedBytesForAutomaticChunking = 8 * 1024 * 1024;
+constexpr size_t MinimumEffectiveKForAutomaticChunking = 1024;
+// Small filter counts typically do not amortize the extra chunking overhead.
+constexpr size_t MaximumFilterCountForAutomaticChunking = 32;
+
+bool
+TryComputeKernelSize(
+ size_t dilation,
+ size_t kernel,
+ size_t& dilated_kernel
+)
+{
+ if (dilation == 0 || kernel == 0) {
+ return false;
+ }
+
+ size_t scaled_kernel = 0;
+ if (mul_overflow_size_t_builtin(dilation, kernel, &scaled_kernel) || scaled_kernel < (dilation - 1)) {
+ return false;
+ }
+
+ dilated_kernel = scaled_kernel - (dilation - 1);
+ return true;
+}
+
+bool
+TryComputeConvOutSize(
+ size_t input,
+ size_t kernel,
+ size_t padding,
+ size_t stride,
+ size_t& output
+)
+{
+ output = 0;
+ if (stride == 0) {
+ return false;
+ }
+
+ size_t total_padding = 0;
+ size_t padded_input = 0;
+ if (mul_overflow_size_t_builtin(padding, 2, &total_padding) ||
+ !TryAddSize(input, total_padding, padded_input)) {
+ return false;
+ }
+
+ if (padded_input < kernel) {
+ return true;
+ }
+
+ output = ((padded_input - kernel) / stride) + 1;
+ return true;
+}
+
+bool
+TryComputeOutputSize(
+ size_t input_height,
+ size_t input_width,
+ size_t kernel_height,
+ size_t kernel_width,
+ size_t padding_height,
+ size_t padding_width,
+ size_t stride_height,
+ size_t stride_width,
+ size_t& output_height,
+ size_t& output_width,
+ size_t& output_size
+)
+{
+ if (!TryComputeConvOutSize(input_height, kernel_height, padding_height, stride_height, output_height) ||
+ !TryComputeConvOutSize(input_width, kernel_width, padding_width, stride_width, output_width) ||
+ mul_overflow_size_t_builtin(output_height, output_width, &output_size)) {
+ return false;
+ }
+
+ return true;
+}
+
+bool
+TryComputeOutputSize(
+ size_t input_height,
+ size_t input_width,
+ size_t kernel_height,
+ size_t kernel_width,
+ size_t padding_height,
+ size_t padding_width,
+ size_t stride_height,
+ size_t stride_width,
+ size_t& output_size
+)
+{
+ size_t output_height = 0;
+ size_t output_width = 0;
+ return TryComputeOutputSize(
+ input_height,
+ input_width,
+ kernel_height,
+ kernel_width,
+ padding_height,
+ padding_width,
+ stride_height,
+ stride_width,
+ output_height,
+ output_width,
+ output_size);
+}
+
+size_t
+SelectMaximumLhsChunkBytes(
+ size_t full_lhs_size,
+ size_t filter_count,
+ size_t effective_k
+)
+{
+ // The bounded-LHS path is most useful when LHS packing is cache-hostile and
+ // there are few output channels, so full-LHS reuse across N tiles is limited.
+ if (full_lhs_size >= MinimumLhsPackedBytesForAutomaticChunking &&
+ filter_count <= MaximumFilterCountForAutomaticChunking &&
+ effective_k >= MinimumEffectiveKForAutomaticChunking) {
+ return AutomaticMaximumLhsChunkBytes;
+ }
+
+ return 0;
+}
+
+bool
+IsPaddingSymmetric2D(const MLAS_CONV_PARAMETERS* parameters)
+{
+ return parameters->Padding[0] == parameters->Padding[1] &&
+ parameters->Padding[0] == parameters->Padding[2] &&
+ parameters->Padding[0] == parameters->Padding[3];
+}
+
+bool
+CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* parameters)
+{
+ if (parameters == nullptr) {
+ return false;
+ }
+
+ if (parameters->BackendKernelSelectorConfig != nullptr &&
+ !parameters->BackendKernelSelectorConfig->use_kleidiai) {
+ KLEIDIAI_DEBUG_LOG("User explicitly disabled KleidiAI, returning false from MlasHalfConv.");
+ return false;
+ }
+
+ if ((parameters->Dimensions != 2) ||
+ (parameters->BatchCount != 1) ||
+ (parameters->GroupCount != 1) ||
+ (parameters->Beta != 0.0f) ||
+ !(parameters->Activation == nullptr ||
+ parameters->Activation->ActivationKind == MlasIdentityActivation) ||
+ !IsPaddingSymmetric2D(parameters)) {
+ KLEIDIAI_DEBUG_LOG("MlasHalfConv capability check failed: unsupported configuration.");
+ return false;
+ }
+
+ size_t d_kh = 0;
+ size_t d_kw = 0;
+ size_t output_size = 0;
+ if (!TryComputeKernelSize(parameters->DilationShape[0], parameters->KernelShape[0], d_kh) ||
+ !TryComputeKernelSize(parameters->DilationShape[1], parameters->KernelShape[1], d_kw) ||
+ !TryComputeOutputSize(
+ parameters->InputShape[0],
+ parameters->InputShape[1],
+ d_kh,
+ d_kw,
+ parameters->Padding[0],
+ parameters->Padding[1],
+ parameters->StrideShape[0],
+ parameters->StrideShape[1],
+ output_size)) {
+ return false;
+ }
+
+ if (output_size == 0 ||
+ output_size != parameters->OutputSize ||
+ parameters->InputChannels == 0 ||
+ parameters->FilterCount == 0 ||
+ parameters->FilterCount == 1 ||
+ parameters->KernelShape[0] < 3 ||
+ parameters->KernelShape[1] < 3) {
+ KLEIDIAI_DEBUG_LOG("MlasHalfConv capability check failed: shape/heuristic gating.");
+ return false;
+ }
+
+ return true;
+}
+
+bool
+GetPackedFilterSize(
+ size_t filter_count,
+ size_t input_channels,
+ const int64_t* kernel_shape,
+ const int64_t* dilation_shape,
+ size_t* packed_size
+)
+{
+ if (packed_size == nullptr ||
+ kernel_shape == nullptr ||
+ dilation_shape == nullptr ||
+ filter_count <= 1 ||
+ input_channels == 0 ||
+ kernel_shape[0] < 3 ||
+ kernel_shape[1] < 3 ||
+ dilation_shape[0] <= 0 ||
+ dilation_shape[1] <= 0) {
+ return false;
+ }
+
+ size_t d_kh = 0;
+ size_t d_kw = 0;
+ size_t k_chunk_count = 0;
+ if (!TryComputeKernelSize(static_cast(dilation_shape[0]), static_cast(kernel_shape[0]), d_kh) ||
+ !TryComputeKernelSize(static_cast(dilation_shape[1]), static_cast(kernel_shape[1]), d_kw) ||
+ mul_overflow_size_t_builtin(d_kh, d_kw, &k_chunk_count)) {
+ return false;
+ }
+
+ *packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
+ filter_count, k_chunk_count, input_channels
+ );
+ return *packed_size != 0;
+}
+
+bool
+NchwToNhwc(
+ const MLAS_FP16* input,
+ size_t channels,
+ size_t height,
+ size_t width,
+ std::vector& output
+)
+{
+ size_t element_count = 0;
+ if (mul_overflow_size_t_builtin(channels, height, &element_count) ||
+ mul_overflow_size_t_builtin(element_count, width, &element_count)) {
+ return false;
+ }
+ output.resize(element_count);
+
+ if (input == nullptr) {
+ return false;
+ }
+
+ for (size_t c = 0; c < channels; ++c) {
+ for (size_t h = 0; h < height; ++h) {
+ for (size_t w = 0; w < width; ++w) {
+ output[(h * width + w) * channels + c] = input[(c * height + h) * width + w];
+ }
+ }
+ }
+
+ return true;
+}
+
+bool
+FillIndirectionTable(
+ const MLAS_FP16* input_nhwc,
+ const MLAS_FP16* pad,
+ size_t input_channels,
+ size_t input_height,
+ size_t input_width,
+ size_t kernel_height,
+ size_t kernel_width,
+ size_t stride_height,
+ size_t stride_width,
+ size_t padding,
+ size_t output_size,
+ std::vector& indirection
+)
+{
+ const auto& imatmul = GetKleidiAIF16IMatmulUKernel();
+ const size_t m_step = imatmul.ukernel.get_m_step();
+ size_t lhs_ptrs_k = 0;
+ if (mul_overflow_size_t_builtin(kernel_height, kernel_width, &lhs_ptrs_k)) {
+ return false;
+ }
+
+ size_t lhs_ptrs_m = 0;
+ if (mul_overflow_size_t_builtin(m_step, MlasDivRoundup(output_size, m_step), &lhs_ptrs_m)) {
+ return false;
+ }
+
+ size_t table_size = 0;
+ if (mul_overflow_size_t_builtin(lhs_ptrs_k, lhs_ptrs_m, &table_size)) {
+ return false;
+ }
+ indirection.resize(table_size);
+
+ std::fill(indirection.begin(), indirection.end(), pad);
+
+ auto ptr_offset = [lhs_ptrs_k, m_step](size_t k, size_t m) {
+ return ((m / m_step) * lhs_ptrs_k * m_step) + (k * m_step) + (m % m_step);
+ };
+
+ auto pixel_ptr = [=](size_t h, size_t w) -> const void* {
+ if (h < padding || w < padding) {
+ return pad;
+ }
+
+ h -= padding;
+ w -= padding;
+ if (h >= input_height || w >= input_width) {
+ return pad;
+ }
+
+ return input_nhwc + (h * input_width + w) * input_channels;
+ };
+
+ size_t output_height = 0;
+ size_t output_width = 0;
+ size_t computed_output_size = 0;
+ if (!TryComputeOutputSize(
+ input_height,
+ input_width,
+ kernel_height,
+ kernel_width,
+ padding,
+ padding,
+ stride_height,
+ stride_width,
+ output_height,
+ output_width,
+ computed_output_size) ||
+ computed_output_size != output_size) {
+ return false;
+ }
+
+ size_t m = 0;
+ for (size_t oh = 0; oh < output_height; ++oh) {
+ for (size_t ow = 0; ow < output_width; ++ow, ++m) {
+ size_t k = 0;
+ const size_t input_base_h = oh * stride_height;
+ const size_t input_base_w = ow * stride_width;
+ for (size_t kh = 0; kh < kernel_height; ++kh) {
+ for (size_t kw = 0; kw < kernel_width; ++kw, ++k) {
+ indirection[ptr_offset(k, m)] = pixel_ptr(input_base_h + kh, input_base_w + kw);
+ }
+ }
+ }
+ }
+
+ return m == output_size;
+}
+
+bool
+PrepareLhsInput(
+ const MLAS_CONV_PARAMETERS* parameters,
+ const MLAS_FP16* input,
+ size_t output_size,
+ std::vector& input_nhwc,
+ std::vector& pad,
+ std::vector& indirection
+)
+{
+ size_t d_kh = 0;
+ size_t d_kw = 0;
+ if (!TryComputeKernelSize(parameters->DilationShape[0], parameters->KernelShape[0], d_kh) ||
+ !TryComputeKernelSize(parameters->DilationShape[1], parameters->KernelShape[1], d_kw)) {
+ return false;
+ }
+
+ const size_t input_channels = parameters->InputChannels;
+
+ const MLAS_FP16* input_nhwc_data = input;
+ if (!parameters->InputOutputChannelsLast) {
+ if (!NchwToNhwc(
+ input,
+ input_channels,
+ parameters->InputShape[0],
+ parameters->InputShape[1],
+ input_nhwc
+ )) {
+ return false;
+ }
+ input_nhwc_data = input_nhwc.data();
+ }
+
+ pad.resize(input_channels);
+ std::fill(pad.begin(), pad.end(), MLAS_FP16::FromBits(0));
+
+ return FillIndirectionTable(
+ input_nhwc_data,
+ pad.data(),
+ input_channels,
+ parameters->InputShape[0],
+ parameters->InputShape[1],
+ d_kh,
+ d_kw,
+ parameters->StrideShape[0],
+ parameters->StrideShape[1],
+ parameters->Padding[0],
+ output_size,
+ indirection
+ );
+}
+
+bool
+PackFilter(
+ size_t filter_count,
+ size_t input_channels,
+ const int64_t* kernel_shape,
+ const int64_t* dilation_shape,
+ const MLAS_FP16* filter,
+ const MLAS_FP16* bias,
+ void* packed_filter
+)
+{
+ if (filter == nullptr || packed_filter == nullptr ||
+ kernel_shape == nullptr || dilation_shape == nullptr ||
+ filter_count <= 1 || input_channels == 0 ||
+ kernel_shape[0] < 3 || kernel_shape[1] < 3 ||
+ dilation_shape[0] <= 0 || dilation_shape[1] <= 0) {
+ return false;
+ }
+
+ const size_t kernel_height = static_cast(kernel_shape[0]);
+ const size_t kernel_width = static_cast(kernel_shape[1]);
+ const size_t dilation_height = static_cast(dilation_shape[0]);
+ const size_t dilation_width = static_cast(dilation_shape[1]);
+ size_t d_kh = 0;
+ size_t d_kw = 0;
+ size_t k_chunk_count = 0;
+ if (!TryComputeKernelSize(dilation_height, kernel_height, d_kh) ||
+ !TryComputeKernelSize(dilation_width, kernel_width, d_kw) ||
+ mul_overflow_size_t_builtin(d_kh, d_kw, &k_chunk_count)) {
+ return false;
+ }
+
+ size_t reordered_size = 0;
+ if (mul_overflow_size_t_builtin(k_chunk_count, input_channels, &reordered_size) ||
+ mul_overflow_size_t_builtin(reordered_size, filter_count, &reordered_size)) {
+ return false;
+ }
+
+ std::vector reordered_filter;
+ reordered_filter.resize(reordered_size);
+ std::fill(reordered_filter.begin(), reordered_filter.end(), MLAS_FP16::FromBits(0));
+
+ for (size_t oc = 0; oc < filter_count; ++oc) {
+ for (size_t ic = 0; ic < input_channels; ++ic) {
+ for (size_t kh = 0; kh < kernel_height; ++kh) {
+ for (size_t kw = 0; kw < kernel_width; ++kw) {
+ const size_t src = ((oc * input_channels + ic) * kernel_height + kh) * kernel_width + kw;
+ const size_t dk = ((kh * dilation_height) * d_kw + (kw * dilation_width)) * input_channels + ic;
+ reordered_filter[dk * filter_count + oc] = filter[src];
+ }
+ }
+ }
+ }
+
+ std::vector zero_bias;
+ const MLAS_FP16* bias_data = bias;
+ if (bias_data == nullptr) {
+ zero_bias.resize(filter_count);
+ std::fill(zero_bias.begin(), zero_bias.end(), MLAS_FP16::FromBits(0));
+ bias_data = zero_bias.data();
+ }
+
+ KLEIDIAI_KERNEL_LOG("kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme" << " N=" << filter_count << " k_chunk_count=" << k_chunk_count << " k_chunk_length=" << input_channels << " rhs_stride_row=" << (filter_count * sizeof(MLAS_FP16)));
+ kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
+ filter_count,
+ k_chunk_count,
+ input_channels,
+ filter_count * sizeof(MLAS_FP16),
+ reordered_filter.data(),
+ bias_data,
+ packed_filter
+ );
+
+ return true;
+}
+
+bool
+ConvolveSme(
+ const MLAS_CONV_PARAMETERS* parameters,
+ const MLAS_FP16* input,
+ const MLAS_FP16* filter,
+ bool filter_and_bias_are_packed,
+ const MLAS_FP16* bias,
+ MLAS_FP16* working_buffer,
+ MLAS_FP16* output,
+ MLAS_THREADPOOL* thread_pool
+)
+{
+ if (input == nullptr || filter == nullptr || output == nullptr ||
+ (!parameters->InputOutputChannelsLast && working_buffer == nullptr)) {
+ return false;
+ }
+
+ size_t d_kh = 0;
+ size_t d_kw = 0;
+ size_t output_size = 0;
+ if (!TryComputeKernelSize(parameters->DilationShape[0], parameters->KernelShape[0], d_kh) ||
+ !TryComputeKernelSize(parameters->DilationShape[1], parameters->KernelShape[1], d_kw) ||
+ !TryComputeOutputSize(
+ parameters->InputShape[0],
+ parameters->InputShape[1],
+ d_kh,
+ d_kw,
+ parameters->Padding[0],
+ parameters->Padding[1],
+ parameters->StrideShape[0],
+ parameters->StrideShape[1],
+ output_size)) {
+ return false;
+ }
+
+ std::vector input_nhwc;
+ std::vector pad;
+ std::vector indirection;
+ if (!PrepareLhsInput(parameters, input, output_size, input_nhwc, pad, indirection)) {
+ return false;
+ }
+
+ std::vector packed_filter_buffer;
+ const std::byte* packed_filter = reinterpret_cast(filter);
+ if (!filter_and_bias_are_packed) {
+ const std::array kernel_shape{
+ static_cast(parameters->KernelShape[0]),
+ static_cast(parameters->KernelShape[1])
+ };
+ const std::array dilation_shape{
+ static_cast(parameters->DilationShape[0]),
+ static_cast(parameters->DilationShape[1])
+ };
+ size_t packed_filter_size = 0;
+ if (!GetPackedFilterSize(
+ parameters->FilterCount,
+ parameters->InputChannels,
+ kernel_shape.data(),
+ dilation_shape.data(),
+ &packed_filter_size
+ )) {
+ return false;
+ }
+ packed_filter_buffer.resize(packed_filter_size);
+ if (!PackFilter(
+ parameters->FilterCount,
+ parameters->InputChannels,
+ kernel_shape.data(),
+ dilation_shape.data(),
+ filter,
+ bias,
+ packed_filter_buffer.data()
+ )) {
+ return false;
+ }
+ packed_filter = packed_filter_buffer.data();
+ }
+
+ const auto& imatmul = GetKleidiAIF16IMatmulUKernel();
+ const size_t base_n_step = imatmul.ukernel.get_n_step();
+ const size_t base_m_step = imatmul.ukernel.get_m_step();
+ size_t n_step = base_n_step;
+ size_t m_step = base_m_step;
+ const size_t filter_count = parameters->FilterCount;
+ const size_t input_channels = parameters->InputChannels;
+
+ std::array dim{
+ MlasDivRoundup(output_size, m_step),
+ MlasDivRoundup(filter_count, n_step)
+ };
+
+ size_t tile_count = 0;
+ if (mul_overflow_size_t_builtin(dim[0], dim[1], &tile_count)) {
+ return false;
+ }
+
+ const size_t required_tiles = std::min(
+ static_cast(MlasGetMaximumThreadCount(thread_pool)),
+ tile_count
+ );
+
+ if (required_tiles == 0) {
+ return false;
+ }
+
+ const size_t original_dim0 = dim[0];
+ const size_t original_dim1 = dim[1];
+ size_t scaled_dim0 = 0;
+ size_t scaled_dim1 = 0;
+ if (mul_overflow_size_t_builtin(required_tiles, original_dim0, &scaled_dim0) ||
+ mul_overflow_size_t_builtin(required_tiles, original_dim1, &scaled_dim1)) {
+ return false;
+ }
+
+ dim[0] = MlasDivRoundup(scaled_dim0, tile_count);
+ dim[1] = MlasDivRoundup(scaled_dim1, tile_count);
+
+ size_t new_m_step = 0;
+ size_t new_n_step = 0;
+ if (mul_overflow_size_t_builtin(m_step, MlasDivRoundup(MlasDivRoundup(output_size, dim[0]), m_step), &new_m_step) ||
+ mul_overflow_size_t_builtin(n_step, MlasDivRoundup(MlasDivRoundup(filter_count, dim[1]), n_step), &new_n_step)) {
+ return false;
+ }
+ m_step = new_m_step;
+ n_step = new_n_step;
+
+ dim[0] = MlasDivRoundup(output_size, m_step);
+ dim[1] = MlasDivRoundup(filter_count, n_step);
+ size_t finalized_tile_count = 0;
+ if (mul_overflow_size_t_builtin(dim[0], dim[1], &finalized_tile_count)) {
+ return false;
+ }
+
+ const float clamp_min = -std::numeric_limits::infinity();
+ const float clamp_max = std::numeric_limits::infinity();
+ size_t dst_stride = 0;
+ if (mul_overflow_size_t_builtin(filter_count, sizeof(MLAS_FP16), &dst_stride)) {
+ return false;
+ }
+
+ MLAS_FP16* destination = parameters->InputOutputChannelsLast ? output : working_buffer;
+
+ size_t kernel_chunk_count = 0;
+ size_t effective_k = 0;
+ if (mul_overflow_size_t_builtin(d_kh, d_kw, &kernel_chunk_count) ||
+ mul_overflow_size_t_builtin(kernel_chunk_count, input_channels, &effective_k)) {
+ return false;
+ }
+
+ const size_t full_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme(
+ output_size, kernel_chunk_count, input_channels
+ );
+ if (full_lhs_size == 0) {
+ return false;
+ }
+
+ const size_t maximum_lhs_chunk_bytes = SelectMaximumLhsChunkBytes(
+ full_lhs_size,
+ filter_count,
+ effective_k
+ );
+
+ if (maximum_lhs_chunk_bytes == 0 || full_lhs_size <= maximum_lhs_chunk_bytes) {
+ std::vector packed_lhs;
+ packed_lhs.resize(full_lhs_size);
+
+ KLEIDIAI_KERNEL_LOG("kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme"
+ << " M=" << output_size
+ << " k_chunk_count=" << kernel_chunk_count
+ << " k_chunk_length=" << input_channels);
+ kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme(
+ output_size,
+ kernel_chunk_count,
+ input_channels,
+ indirection.data(),
+ 0,
+ pad.data(),
+ packed_lhs.data()
+ );
+
+ std::atomic ok{true};
+ MlasTrySimpleParallel(thread_pool, static_cast(finalized_tile_count), [&](ptrdiff_t tid) {
+ if (!ok.load(std::memory_order_relaxed)) {
+ return;
+ }
+
+ const size_t m_idx = (static_cast(tid) / dim[1]) * m_step;
+ const size_t n_idx = (static_cast(tid) % dim[1]) * n_step;
+ const size_t tile_m = std::min(m_step, output_size - m_idx);
+ const size_t tile_n = std::min(n_step, filter_count - n_idx);
+
+ const std::byte* lhs_tile =
+ packed_lhs.data() + imatmul.ukernel.get_lhs_packed_offset(m_idx, kernel_chunk_count, input_channels);
+ const std::byte* rhs_tile =
+ packed_filter + imatmul.ukernel.get_rhs_packed_offset(n_idx, kernel_chunk_count, input_channels);
+
+ size_t dst_elements = 0;
+ size_t dst_bytes = 0;
+ if (mul_overflow_size_t_builtin(m_idx, filter_count, &dst_elements) ||
+ !TryAddSize(dst_elements, n_idx, dst_elements) ||
+ mul_overflow_size_t_builtin(dst_elements, sizeof(MLAS_FP16), &dst_bytes)) {
+ ok.store(false, std::memory_order_relaxed);
+ return;
+ }
+
+ std::byte* dst_tile = reinterpret_cast(destination) + dst_bytes;
+
+ KLEIDIAI_KERNEL_LOG(imatmul.name << " M=" << tile_m
+ << " N=" << tile_n
+ << " k_chunk_count=" << kernel_chunk_count
+ << " k_chunk_length=" << input_channels);
+ imatmul.ukernel.run_imatmul(
+ tile_m,
+ tile_n,
+ kernel_chunk_count,
+ input_channels,
+ lhs_tile,
+ rhs_tile,
+ dst_tile,
+ dst_stride,
+ clamp_min,
+ clamp_max
+ );
+ });
+
+ if (!ok.load(std::memory_order_relaxed)) {
+ return false;
+ }
+ } else {
+ const size_t bytes_per_m_step = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme(
+ base_m_step, kernel_chunk_count, input_channels
+ );
+ if (bytes_per_m_step == 0) {
+ return false;
+ }
+
+ const size_t max_m_steps_per_chunk = std::min(
+ MlasDivRoundup(output_size, base_m_step),
+ std::max(1, maximum_lhs_chunk_bytes / bytes_per_m_step)
+ );
+ size_t lhs_chunk_m = 0;
+ if (mul_overflow_size_t_builtin(base_m_step, max_m_steps_per_chunk, &lhs_chunk_m)) {
+ return false;
+ }
+
+ const size_t lhs_chunk_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme(
+ lhs_chunk_m, kernel_chunk_count, input_channels
+ );
+ if (lhs_chunk_size == 0 || lhs_chunk_size > std::vector().max_size()) {
+ return false;
+ }
+
+ const size_t m_chunk_count = MlasDivRoundup(output_size, lhs_chunk_m);
+ std::atomic ok{true};
+ MlasTrySimpleParallel(thread_pool, static_cast(m_chunk_count), [&](ptrdiff_t tid) {
+ if (!ok.load(std::memory_order_relaxed)) {
+ return;
+ }
+
+ const size_t global_m_idx = static_cast(tid) * lhs_chunk_m;
+ const size_t chunk_m = std::min(lhs_chunk_m, output_size - global_m_idx);
+ size_t indirection_offset = 0;
+ size_t indirection_tiles = 0;
+ if (mul_overflow_size_t_builtin(global_m_idx / base_m_step, kernel_chunk_count, &indirection_tiles) ||
+ mul_overflow_size_t_builtin(indirection_tiles, base_m_step, &indirection_offset)) {
+ ok.store(false, std::memory_order_relaxed);
+ return;
+ }
+
+ std::vector packed_lhs;
+ packed_lhs.resize(lhs_chunk_size);
+
+ KLEIDIAI_KERNEL_LOG("kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme"
+ << " M=" << chunk_m
+ << " k_chunk_count=" << kernel_chunk_count
+ << " k_chunk_length=" << input_channels);
+ kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme(
+ chunk_m,
+ kernel_chunk_count,
+ input_channels,
+ indirection.data() + indirection_offset,
+ 0,
+ pad.data(),
+ packed_lhs.data()
+ );
+
+ for (size_t n_tile_idx = 0; n_tile_idx < dim[1]; ++n_tile_idx) {
+ const size_t n_idx = n_tile_idx * n_step;
+ const size_t tile_n = std::min(n_step, filter_count - n_idx);
+ const std::byte* rhs_tile =
+ packed_filter + imatmul.ukernel.get_rhs_packed_offset(n_idx, kernel_chunk_count, input_channels);
+
+ size_t dst_elements = 0;
+ size_t dst_bytes = 0;
+ if (mul_overflow_size_t_builtin(global_m_idx, filter_count, &dst_elements) ||
+ !TryAddSize(dst_elements, n_idx, dst_elements) ||
+ mul_overflow_size_t_builtin(dst_elements, sizeof(MLAS_FP16), &dst_bytes)) {
+ ok.store(false, std::memory_order_relaxed);
+ return;
+ }
+ std::byte* dst_tile = reinterpret_cast(destination) + dst_bytes;
+
+ KLEIDIAI_KERNEL_LOG(imatmul.name << " M=" << chunk_m
+ << " N=" << tile_n
+ << " k_chunk_count=" << kernel_chunk_count
+ << " k_chunk_length=" << input_channels);
+ imatmul.ukernel.run_imatmul(
+ chunk_m,
+ tile_n,
+ kernel_chunk_count,
+ input_channels,
+ packed_lhs.data(),
+ rhs_tile,
+ dst_tile,
+ dst_stride,
+ clamp_min,
+ clamp_max
+ );
+ }
+ });
+
+ if (!ok.load(std::memory_order_relaxed)) {
+ return false;
+ }
+ }
+
+ if (parameters->InputOutputChannelsLast) {
+ return true;
+ }
+
+ MlasTranspose(working_buffer, output, output_size, filter_count, thread_pool);
+ return true;
+}
+
+} // namespace
+
+bool
+ MLASCALL
+ ArmKleidiAI::MlasHalfConvPrepare(
+ MLAS_CONV_PARAMETERS* Parameters,
+ size_t Dimensions,
+ size_t BatchCount,
+ size_t GroupCount,
+ size_t InputChannels,
+ const int64_t* InputShape,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape,
+ const int64_t* Padding,
+ const int64_t* StrideShape,
+ const int64_t* OutputShape,
+ size_t FilterCount,
+ const MLAS_ACTIVATION* Activation,
+ size_t* WorkingBufferSize,
+ float Beta,
+ bool InputOutputChannelsLast,
+ MLAS_THREADPOOL* ThreadPool,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
+ )
+{
+ MLAS_UNREFERENCED_PARAMETER(ThreadPool);
+
+ if (Parameters == nullptr ||
+ InputShape == nullptr ||
+ KernelShape == nullptr ||
+ DilationShape == nullptr ||
+ Padding == nullptr ||
+ StrideShape == nullptr ||
+ OutputShape == nullptr ||
+ WorkingBufferSize == nullptr ||
+ Dimensions != 2) {
+ return false;
+ }
+
+ if (BackendKernelSelectorConfig != nullptr &&
+ !BackendKernelSelectorConfig->use_kleidiai) {
+ KLEIDIAI_DEBUG_LOG("User explicitly disabled KleidiAI, returning false from MlasHalfConvPrepare.");
+ return false;
+ }
+
+ Parameters->BackendKernelSelectorConfig = BackendKernelSelectorConfig;
+ Parameters->Activation = Activation;
+ Parameters->Dimensions = Dimensions;
+ Parameters->BatchCount = BatchCount;
+ Parameters->GroupCount = GroupCount;
+ Parameters->InputChannels = InputChannels;
+ Parameters->FilterCount = FilterCount;
+ Parameters->Beta = Beta;
+ Parameters->InputOutputChannelsLast = InputOutputChannelsLast;
+
+ size_t input_size = 1;
+ size_t output_size = 1;
+ size_t k = InputChannels;
+ size_t dilated_kernel_size = InputChannels;
+ for (size_t dim = 0; dim < Dimensions; ++dim) {
+ if (InputShape[dim] <= 0 ||
+ OutputShape[dim] <= 0 ||
+ KernelShape[dim] <= 0 ||
+ DilationShape[dim] <= 0 ||
+ StrideShape[dim] <= 0 ||
+ Padding[dim] < 0 ||
+ Padding[dim + Dimensions] < 0) {
+ return false;
+ }
+
+ Parameters->InputShape[dim] = static_cast(InputShape[dim]);
+ Parameters->OutputShape[dim] = static_cast(OutputShape[dim]);
+ Parameters->KernelShape[dim] = static_cast(KernelShape[dim]);
+ Parameters->DilationShape[dim] = static_cast(DilationShape[dim]);
+ Parameters->Padding[dim] = static_cast(Padding[dim]);
+ Parameters->Padding[dim + Dimensions] = static_cast(Padding[dim + Dimensions]);
+ Parameters->StrideShape[dim] = static_cast(StrideShape[dim]);
+
+ size_t dilated_kernel_dim = 0;
+ if (!TryComputeKernelSize(Parameters->DilationShape[dim], Parameters->KernelShape[dim], dilated_kernel_dim)) {
+ return false;
+ }
+
+ if (mul_overflow_size_t_builtin(input_size, Parameters->InputShape[dim], &input_size) ||
+ mul_overflow_size_t_builtin(output_size, Parameters->OutputShape[dim], &output_size) ||
+ mul_overflow_size_t_builtin(k, Parameters->KernelShape[dim], &k) ||
+ mul_overflow_size_t_builtin(dilated_kernel_size, dilated_kernel_dim, &dilated_kernel_size)) {
+ return false;
+ }
+ }
+
+ Parameters->InputSize = input_size;
+ Parameters->OutputSize = output_size;
+ Parameters->K = k;
+ Parameters->ThreadCount = MlasGetMaximumThreadCount(ThreadPool);
+
+ if (!CheckCapabilitiesSme(Parameters)) {
+ return false;
+ }
+
+ size_t working_elements = 0;
+ if (mul_overflow_size_t_builtin(Parameters->OutputSize, Parameters->FilterCount, &working_elements)) {
+ return false;
+ }
+ *WorkingBufferSize = Parameters->InputOutputChannelsLast ? 0 : working_elements;
+ return true;
+}
+
+bool
+ MLASCALL
+ ArmKleidiAI::MlasHalfConv(
+ const MLAS_CONV_PARAMETERS* Parameters,
+ const MLAS_FP16* Input,
+ const MLAS_FP16* Filter,
+ bool FilterAndBiasArePacked,
+ const MLAS_FP16* Bias,
+ MLAS_FP16* WorkingBuffer,
+ MLAS_FP16* Output,
+ MLAS_THREADPOOL* ThreadPool
+ )
+{
+ if (!CheckCapabilitiesSme(Parameters)) {
+ return false;
+ }
+
+ return ConvolveSme(Parameters, Input, Filter, FilterAndBiasArePacked, Bias, WorkingBuffer, Output, ThreadPool);
+}
+
+size_t
+ MLASCALL
+ ArmKleidiAI::MlasHalfConvPackWeightsAndBiasSize(
+ size_t FilterCount,
+ size_t InputChannels,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape
+ )
+{
+ size_t packed_size = 0;
+ if (!GetPackedFilterSize(FilterCount, InputChannels, KernelShape, DilationShape, &packed_size)) {
+ return 0;
+ }
+ return packed_size;
+}
+
+bool
+ MLASCALL
+ ArmKleidiAI::MlasHalfConvPackWeightsAndBias(
+ size_t FilterCount,
+ size_t InputChannels,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape,
+ const MLAS_FP16* Filter,
+ const MLAS_FP16* Bias,
+ void* PackedWeightsAndBias,
+ MLAS_THREADPOOL* ThreadPool
+ )
+{
+ MLAS_UNREFERENCED_PARAMETER(ThreadPool);
+
+ return PackFilter(
+ FilterCount,
+ InputChannels,
+ KernelShape,
+ DilationShape,
+ Filter,
+ Bias,
+ PackedWeightsAndBias
+ );
+}
diff --git a/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp
new file mode 100644
index 0000000000000..bc10db249466d
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp
@@ -0,0 +1,283 @@
+//
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: MIT
+//
+
+#include
+#include
+#include
+#include
+#include "mlas.h"
+
+#include "mlasi_kleidiai.h"
+
+#include "kai_ukernel_interface.h"
+
+#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h"
+
+namespace {
+struct KaiHalfTlsBuffers {
+ std::vector lhs_converted;
+ std::vector rhs_converted;
+ std::vector bias_zero;
+ std::vector rhs_packed;
+};
+
+thread_local KaiHalfTlsBuffers g_kai_half_tls;
+
+template
+bool TryResizeVector(std::vector& buffer, size_t size) {
+ if (size > buffer.max_size()) {
+ return false;
+ }
+ buffer.resize(size);
+ return true;
+}
+
+static inline void ConvertFloatMatrixToHalf(
+ const float* src,
+ MLAS_FP16* dst,
+ size_t rows,
+ size_t cols,
+ size_t src_ld) {
+ for (size_t r = 0; r < rows; ++r) {
+ MlasConvertFloatToHalfBuffer(src + r * src_ld, dst + r * cols, cols);
+ }
+}
+} // namespace
+
+size_t
+MLASCALL
+ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(
+ CBLAS_TRANSPOSE TransA,
+ CBLAS_TRANSPOSE TransB,
+ size_t N,
+ size_t K
+) {
+ if (TransA != CblasNoTrans || TransB != CblasNoTrans || N == 0 || K == 0) {
+ return 0;
+ }
+
+ return kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(N, K);
+}
+
+bool
+MLASCALL
+ArmKleidiAI::MlasHalfGemmKleidiAIPackB(
+ CBLAS_TRANSPOSE TransA,
+ CBLAS_TRANSPOSE TransB,
+ size_t N,
+ size_t K,
+ const MLAS_FP16* B,
+ size_t ldb,
+ void* PackedB
+) {
+ if (TransA != CblasNoTrans || TransB != CblasNoTrans) {
+ return false;
+ }
+
+ if (PackedB == nullptr || B == nullptr || N == 0 || K == 0) {
+ return false;
+ }
+
+ const size_t packed_rhs_size = ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(TransA, TransB, N, K);
+ if (packed_rhs_size == 0) {
+ return false;
+ }
+
+ std::vector zero_bias(N, MLAS_FP16::FromBits(0));
+
+ size_t ldb_bytes = 0;
+ if (mul_overflow_size_t_builtin(ldb, sizeof(MLAS_FP16), &ldb_bytes)) {
+ return false;
+ }
+
+ const auto& hgemm = GetKleidiAIHgemmUKernel();
+ kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(
+ 1, N, K, hgemm.ukernel.get_nr(), hgemm.ukernel.get_kr(), hgemm.ukernel.get_sr(), ldb_bytes,
+ B,
+ zero_bias.data(),
+ nullptr,
+ PackedB,
+ 0,
+ nullptr);
+
+ return true;
+}
+
+bool
+MLASCALL
+ArmKleidiAI::MlasHalfGemmBatch(
+ size_t M,
+ size_t N,
+ size_t K,
+ size_t BatchN,
+ const MLAS_HALF_GEMM_DATA_PARAMS* DataParams,
+ MLAS_THREADPOOL* ThreadPool,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
+) {
+ if (BatchN == 0 || M == 0 || N == 0) {
+ return true;
+ }
+ if (K == 0) {
+ return false;
+ }
+ if (DataParams == nullptr) {
+ return false;
+ }
+
+ MLAS_UNREFERENCED_PARAMETER(ThreadPool);
+ MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig);
+ // Validate all batch entries up front so we never partially execute and then
+ // fall back (which would corrupt results for the already-written outputs).
+ bool needs_rhs_packing = false;
+ for (size_t b = 0; b < BatchN; ++b) {
+ const auto& data = DataParams[b];
+ if (data.OutputProcessor != nullptr) {
+ return false;
+ }
+ if (data.BIsBackendNativePacked && data.Bias != nullptr) {
+ return false;
+ }
+ if (data.BIsBackendNativePacked && data.ldb != 0) {
+ return false;
+ }
+ // Native-packed RHS is consumed directly below. Only allocate the
+ // runtime RHS packing scratch when at least one batch entry needs it.
+ needs_rhs_packing = needs_rhs_packing || !data.BIsBackendNativePacked;
+ }
+
+ const auto& hgemm = GetKleidiAIHgemmUKernel();
+ const size_t n_step = hgemm.ukernel.get_n_step();
+ const size_t nr = hgemm.ukernel.get_nr();
+ const size_t kr = hgemm.ukernel.get_kr();
+ const size_t sr = hgemm.ukernel.get_sr();
+ KLEIDIAI_KERNEL_LOG(hgemm.name);
+
+ const size_t packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(N, K);
+ if (packed_rhs_size == 0) {
+ return false;
+ }
+
+ const float clamp_min = -std::numeric_limits::infinity();
+ const float clamp_max = std::numeric_limits::infinity();
+
+ if (needs_rhs_packing && !TryResizeVector(g_kai_half_tls.rhs_packed, packed_rhs_size)) {
+ return false;
+ }
+
+ for (size_t b = 0; b < BatchN; ++b) {
+ const auto& data = DataParams[b];
+
+ const MLAS_FP16* lhs_base = reinterpret_cast(data.A);
+ const MLAS_FP16* rhs_base = reinterpret_cast(data.B);
+ const std::byte* rhs_packed = nullptr;
+ size_t lhs_ld = data.lda;
+ size_t rhs_ld = data.ldb;
+
+ if (data.AIsfp32) {
+ size_t lhs_elements = 0;
+ if (mul_overflow_size_t_builtin(M, K, &lhs_elements) ||
+ !TryResizeVector(g_kai_half_tls.lhs_converted, lhs_elements)) {
+ return false;
+ }
+ ConvertFloatMatrixToHalf(
+ reinterpret_cast(data.A),
+ g_kai_half_tls.lhs_converted.data(),
+ M, K, data.lda);
+ lhs_base = g_kai_half_tls.lhs_converted.data();
+ lhs_ld = K;
+ }
+
+ if (data.BIsBackendNativePacked) {
+ rhs_packed = reinterpret_cast(data.B);
+ } else if (data.ldb == 0) {
+ // Prepacked B from MlasHalfGemmPackB/MlasHalfGemmConvertPackB.
+ // For the current default halfgemm dispatch this is a row-major
+ // fp16 KxN buffer with leading dimension N. It is not the native
+ // KleidiAI RHS-packed layout, so this path falls back to packing
+ // it into KleidiAI format before execution.
+ rhs_ld = N;
+ } else if (data.BIsfp32) {
+ size_t rhs_elements = 0;
+ if (mul_overflow_size_t_builtin(K, N, &rhs_elements) ||
+ !TryResizeVector(g_kai_half_tls.rhs_converted, rhs_elements)) {
+ return false;
+ }
+ ConvertFloatMatrixToHalf(
+ reinterpret_cast(data.B),
+ g_kai_half_tls.rhs_converted.data(),
+ K, N, data.ldb);
+ rhs_base = g_kai_half_tls.rhs_converted.data();
+ rhs_ld = N;
+ }
+
+ if (rhs_packed == nullptr) {
+ auto* rhs_packed_buffer = g_kai_half_tls.rhs_packed.data();
+
+ size_t ldb_bytes = 0;
+ if (mul_overflow_size_t_builtin(rhs_ld, sizeof(MLAS_FP16), &ldb_bytes)) {
+ return false;
+ }
+ if (data.Bias == nullptr) {
+ if (!TryResizeVector(g_kai_half_tls.bias_zero, N)) {
+ return false;
+ }
+ std::fill(g_kai_half_tls.bias_zero.begin(), g_kai_half_tls.bias_zero.end(), MLAS_FP16::FromBits(0));
+ }
+
+ kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(
+ 1, N, K, nr, kr, sr, ldb_bytes,
+ rhs_base,
+ data.Bias != nullptr ? data.Bias : g_kai_half_tls.bias_zero.data(),
+ nullptr,
+ rhs_packed_buffer,
+ 0,
+ nullptr);
+ rhs_packed = rhs_packed_buffer;
+ }
+
+ size_t lda_bytes = 0;
+ if (mul_overflow_size_t_builtin(lhs_ld, sizeof(MLAS_FP16), &lda_bytes)) {
+ return false;
+ }
+ size_t dst_stride_bytes = 0;
+ if (mul_overflow_size_t_builtin(data.ldc, sizeof(MLAS_FP16), &dst_stride_bytes)) {
+ return false;
+ }
+
+ MlasTrySimpleParallel(ThreadPool, static_cast(M), [&](ptrdiff_t m_idx) {
+ const size_t m = static_cast(m_idx);
+ const auto* lhs = lhs_base + m * lhs_ld;
+ auto* dst_row = data.C + m * data.ldc;
+ const auto* rhs_packed_base = rhs_packed;
+ // The selected KleidiAI HGEMM micro-kernel is 1xN by design.
+ // We execute one output row per call and parallelize over rows.
+ constexpr size_t kernel_m = 1;
+ for (size_t n_idx = 0; n_idx < N; n_idx += n_step) {
+ const size_t tile_n = std::min(n_step, N - n_idx);
+ const auto* rhs_tile = rhs_packed_base + hgemm.ukernel.get_rhs_packed_offset(n_idx, K);
+ auto* dst_tile = reinterpret_cast(
+ reinterpret_cast(dst_row) +
+ hgemm.ukernel.get_dst_offset(0, n_idx, dst_stride_bytes));
+
+ hgemm.ukernel.run_matmul(
+ kernel_m,
+ tile_n,
+ K,
+ lhs,
+ lda_bytes,
+ rhs_tile,
+ dst_tile,
+ dst_stride_bytes,
+ sizeof(MLAS_FP16),
+ clamp_min,
+ clamp_max);
+ }
+ });
+
+ }
+
+ return true;
+}
diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
index 3c9f398ece887..207935b8a8bd8 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
+++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
@@ -7,6 +7,7 @@
#pragma once
#include "../mlasi.h"
+#include
#include
// Fix to ensure compatibility with MSVC build
@@ -224,6 +225,102 @@ MlasConvSymmetricChannelsLast2DFloatPackW(
size_t PackedFilterGroupStride,
MLAS_THREADPOOL* ThreadPool
);
+
+bool
+MLASCALL
+MlasHalfGemmBatch(
+ size_t M,
+ size_t N,
+ size_t K,
+ size_t BatchN,
+ const MLAS_HALF_GEMM_DATA_PARAMS* DataParams,
+ MLAS_THREADPOOL* ThreadPool,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
+ );
+
+size_t
+MLASCALL
+MlasHalfGemmKleidiAIPackBSize(
+ CBLAS_TRANSPOSE TransA,
+ CBLAS_TRANSPOSE TransB,
+ size_t N,
+ size_t K
+ );
+
+// Packs B into the native KleidiAI RHS-packed layout for the supported
+// halfgemm configuration only. This differs from the generic MLAS halfgemm
+// prepacked-B format produced by MlasHalfGemmPackB and
+// MlasHalfGemmConvertPackB, so generic MLAS prepacked weights may need to be
+// repacked into this layout before running the KleidiAI halfgemm path.
+// Unsupported transpose combinations return false/0 so the caller can fall
+// back to the generic MLAS path.
+bool
+MLASCALL
+MlasHalfGemmKleidiAIPackB(
+ CBLAS_TRANSPOSE TransA,
+ CBLAS_TRANSPOSE TransB,
+ size_t N,
+ size_t K,
+ const MLAS_FP16* B,
+ size_t ldb,
+ void* PackedB
+ );
+
+bool
+MLASCALL
+MlasHalfConvPrepare(MLAS_CONV_PARAMETERS* Parameters,
+ size_t Dimensions,
+ size_t BatchCount,
+ size_t GroupCount,
+ size_t InputChannels,
+ const int64_t* InputShape,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape,
+ const int64_t* Padding,
+ const int64_t* StrideShape,
+ const int64_t* OutputShape,
+ size_t FilterCount,
+ const MLAS_ACTIVATION* Activation,
+ size_t* WorkingBufferSize,
+ float Beta,
+ bool InputOutputChannelsLast,
+ MLAS_THREADPOOL* ThreadPool,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig);
+
+bool
+MLASCALL
+MlasHalfConv(
+ const MLAS_CONV_PARAMETERS* Parameters,
+ const MLAS_FP16* Input,
+ const MLAS_FP16* Filter,
+ bool FilterAndBiasArePacked,
+ const MLAS_FP16* Bias,
+ MLAS_FP16* WorkingBuffer,
+ MLAS_FP16* Output,
+ MLAS_THREADPOOL* ThreadPool
+ );
+
+size_t
+MLASCALL
+MlasHalfConvPackWeightsAndBiasSize(
+ size_t FilterCount,
+ size_t InputChannels,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape
+ );
+
+bool
+MLASCALL
+MlasHalfConvPackWeightsAndBias(
+ size_t FilterCount,
+ size_t InputChannels,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape,
+ const MLAS_FP16* Filter,
+ const MLAS_FP16* Bias,
+ void* PackedWeightsAndBias,
+ MLAS_THREADPOOL* ThreadPool
+ );
}
/*++
@@ -259,3 +356,42 @@ inline bool mul_overflow_size_t_builtin(size_t a, size_t b, size_t* out) {
if (out) *out = a * b;
return false;
}
+
+/*++
+
+Routine Description:
+
+ This routine adds two size_t values and returns the sum when no wraparound
+ occurs. Uses __builtin_add_overflow if available on the current system and
+ falls back to a default implementation otherwise.
+
+Arguments:
+
+ a - Supplies the first number to be added.
+
+ b - Supplies the second number to be added.
+
+ out - Supplies a size_t reference which receives the result in success
+ cases.
+
+Return Value:
+
+ Returns true if the operation was successful
+ Returns false if wraparound of size_t was detected
+
+--*/
+inline bool TryAddSize(size_t a, size_t b, size_t& out)
+{
+#if defined(__has_builtin)
+# if __has_builtin(__builtin_add_overflow)
+ return !__builtin_add_overflow(a, b, &out);
+# endif
+#endif
+
+ if (a > (std::numeric_limits::max)() - b) {
+ return false;
+ }
+
+ out = a + b;
+ return true;
+}
diff --git a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp
index 9e70b0742217a..2e5306d442476 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp
+++ b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp
@@ -101,7 +101,7 @@ MLASCALL
ArmKleidiAI::MlasDynamicQGemmBatch(
const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape,
const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams,
- const size_t BatchSize,
+ const size_t BatchN,
MLAS_THREADPOOL* ThreadPool
) {
@@ -112,7 +112,7 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
size_t m_step = qgemm_gemm.ukernel.get_m_step();
size_t n_step = qgemm_gemm.ukernel.get_n_step();
- if (BatchSize == 0 || Shape.M == 0 || Shape.N == 0 || Shape.K == 0) {
+ if (BatchN == 0 || Shape.M == 0 || Shape.N == 0 || Shape.K == 0) {
return;
}
@@ -123,7 +123,7 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
MLAS_THROW_EX(std::runtime_error, "Dynamic QGEMM requires valid DataParams.");
}
- for (size_t batch_idx = 0; batch_idx < BatchSize; ++batch_idx) {
+ for (size_t batch_idx = 0; batch_idx < BatchN; ++batch_idx) {
const auto& params = DataParams[batch_idx];
if (params.A == nullptr) {
@@ -151,24 +151,24 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
const size_t LhsPackedStride = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr);
std::byte* LhsPackedData = nullptr;
- if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchSize) {
+ if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchN) {
- g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize);
+ g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchN);
}
- g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize);
+ g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchN);
LhsPackedData = g_kai_tls_qgemm.lhs_packed.data();
// Per-batch table of LHS base pointers.
- if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchSize) {
+ if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchN) {
- g_kai_tls_qgemm.lhs_base_table.reserve(BatchSize);
+ g_kai_tls_qgemm.lhs_base_table.reserve(BatchN);
}
- g_kai_tls_qgemm.lhs_base_table.resize(BatchSize);
+ g_kai_tls_qgemm.lhs_base_table.resize(BatchN);
// Capture the shared batch table pointer so worker threads use the same backing storage.
const std::byte** tls_lhs_base = g_kai_tls_qgemm.lhs_base_table.data();
// B batches require no packing.
// We have already decided the matmul variant we are using before having values for M, N, and K.
- MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
+ MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t batch_idx) {
std::byte* lhs = nullptr;
if (DataParams[batch_idx].Workspace && DataParams[batch_idx].WorkspaceSize >= LhsPackedStride) {
@@ -184,7 +184,7 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
// Tile iteration dimensions.
std::array dim;
- dim[0] = BatchSize; // B
+ dim[0] = BatchN; // B
dim[1] = MlasDivRoundup(Shape.M, m_step); // M
dim[2] = MlasDivRoundup(Shape.N, n_step); // N
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index dbb414505ff38..b06761befd48a 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -967,6 +967,90 @@ bool
void* PackedB);
#endif
+typedef
+bool
+(MLASCALL MLAS_HALF_GEMM_BATCH_OVERRIDE)(
+ size_t M,
+ size_t N,
+ size_t K,
+ size_t BatchN,
+ const MLAS_HALF_GEMM_DATA_PARAMS* DataParams,
+ MLAS_THREADPOOL* ThreadPool,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig);
+
+typedef
+size_t
+(MLASCALL MLAS_HALF_GEMM_PACK_B_SIZE_OVERRIDE)(
+ CBLAS_TRANSPOSE TransA,
+ CBLAS_TRANSPOSE TransB,
+ size_t N,
+ size_t K);
+
+typedef
+bool
+(MLASCALL MLAS_HALF_GEMM_PACK_B_OVERRIDE)(
+ CBLAS_TRANSPOSE TransA,
+ CBLAS_TRANSPOSE TransB,
+ size_t N,
+ size_t K,
+ const MLAS_FP16* B,
+ size_t ldb,
+ void* PackedB);
+
+typedef
+bool
+(MLASCALL MLAS_HALF_CONV_PREPARE_OVERRIDE)(
+ MLAS_CONV_PARAMETERS* Parameters,
+ size_t Dimensions,
+ size_t BatchCount,
+ size_t GroupCount,
+ size_t InputChannels,
+ const int64_t* InputShape,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape,
+ const int64_t* Padding,
+ const int64_t* StrideShape,
+ const int64_t* OutputShape,
+ size_t FilterCount,
+ const MLAS_ACTIVATION* Activation,
+ size_t* WorkingBufferSize,
+ float Beta,
+ bool InputOutputChannelsLast,
+ MLAS_THREADPOOL* ThreadPool,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig);
+
+typedef
+bool
+(MLASCALL MLAS_HALF_CONV_OVERRIDE)(
+ const MLAS_CONV_PARAMETERS* Parameters,
+ const MLAS_FP16* Input,
+ const MLAS_FP16* Filter,
+ bool FilterAndBiasArePacked,
+ const MLAS_FP16* Bias,
+ MLAS_FP16* WorkingBuffer,
+ MLAS_FP16* Output,
+ MLAS_THREADPOOL* ThreadPool);
+
+typedef
+size_t
+(MLASCALL MLAS_HALF_CONV_PACK_WEIGHTS_AND_BIAS_SIZE_OVERRIDE)(
+ size_t FilterCount,
+ size_t InputChannels,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape);
+
+typedef
+bool
+(MLASCALL MLAS_HALF_CONV_PACK_WEIGHTS_AND_BIAS_OVERRIDE)(
+ size_t FilterCount,
+ size_t InputChannels,
+ const int64_t* KernelShape,
+ const int64_t* DilationShape,
+ const MLAS_FP16* Filter,
+ const MLAS_FP16* Bias,
+ void* PackedWeightsAndBias,
+ MLAS_THREADPOOL* ThreadPool);
+
extern "C" {
#if defined(MLAS_TARGET_AMD64_IX86)
@@ -1485,6 +1569,15 @@ struct MLAS_PLATFORM {
MLAS_DYNAMIC_QGEMM_BATCH_OVERRIDE* MlasDynamicQGemmBatchOverride = nullptr;
MLAS_DYNAMIC_QGEMM_PACK_B_SIZE_OVERRIDE* MlasDynamicQGemmPackBSizeOverride = nullptr;
MLAS_DYNAMIC_QGEMM_PACK_B_OVERRIDE* MlasDynamicQGemmPackBOverride = nullptr;
+ // MLAS HalfGemm overrides
+ MLAS_HALF_GEMM_BATCH_OVERRIDE* MlasHalfGemmBatchOverride = nullptr;
+ MLAS_HALF_GEMM_PACK_B_SIZE_OVERRIDE* MlasHalfGemmPackBSizeOverride = nullptr;
+ MLAS_HALF_GEMM_PACK_B_OVERRIDE* MlasHalfGemmPackBOverride = nullptr;
+ // MLAS HalfConv overrides
+ MLAS_HALF_CONV_PREPARE_OVERRIDE* MlasHalfConvPrepareOverride = nullptr;
+ MLAS_HALF_CONV_OVERRIDE* MlasHalfConvOverride = nullptr;
+ MLAS_HALF_CONV_PACK_WEIGHTS_AND_BIAS_SIZE_OVERRIDE* MlasHalfConvPackWeightsAndBiasSizeOverride = nullptr;
+ MLAS_HALF_CONV_PACK_WEIGHTS_AND_BIAS_OVERRIDE* MlasHalfConvPackWeightsAndBiasOverride = nullptr;
// MLAS Conv overrides
MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr;
MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr;
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index 466fa9a3e9497..e975787877462 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -677,13 +677,20 @@ Return Value:
}
#if defined(USE_KLEIDIAI)
- if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){
+ if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME() || MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()){
this->MlasSGemmBatchOverride = ArmKleidiAI::MlasGemmBatch;
this->MlasSGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize;
this->MlasSGemmPackBOverride = ArmKleidiAI::MlasGemmPackB;
this->MlasDynamicQGemmBatchOverride = ArmKleidiAI::MlasDynamicQGemmBatch;
this->MlasDynamicQGemmPackBSizeOverride = ArmKleidiAI::MlasDynamicQGemmPackBSize;
this->MlasDynamicQGemmPackBOverride = ArmKleidiAI::MlasDynamicQGemmPackB;
+ this->MlasHalfGemmBatchOverride = ArmKleidiAI::MlasHalfGemmBatch;
+ this->MlasHalfGemmPackBSizeOverride = ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize;
+ this->MlasHalfGemmPackBOverride = ArmKleidiAI::MlasHalfGemmKleidiAIPackB;
+ this->MlasHalfConvPrepareOverride = ArmKleidiAI::MlasHalfConvPrepare;
+ this->MlasHalfConvOverride = ArmKleidiAI::MlasHalfConv;
+ this->MlasHalfConvPackWeightsAndBiasSizeOverride = ArmKleidiAI::MlasHalfConvPackWeightsAndBiasSize;
+ this->MlasHalfConvPackWeightsAndBiasOverride = ArmKleidiAI::MlasHalfConvPackWeightsAndBias;
this->MlasConvPrepareOverride = ArmKleidiAI::MlasConvPrepare;
this->MlasConvOverride = ArmKleidiAI::MlasConv;
#if defined(__aarch64__) && defined(__linux__)
diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc
index c8a0bbaba9df5..38f045f913124 100644
--- a/onnxruntime/core/optimizer/insert_cast_transformer.cc
+++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc
@@ -3,28 +3,422 @@
#include "core/optimizer/insert_cast_transformer.h"
#include "core/framework/data_types.h"
-#include "core/graph/graph_utils.h"
#include "core/framework/compute_capability.h"
+#include "core/graph/graph_utils.h"
#include "core/graph/indexed_sub_graph.h"
+#include "core/mlas/inc/mlas.h"
+
+#include
+#include
+#include
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
+void InsertCastTransformer::RecordPartitionAssignment(const onnxruntime::Graph& graph,
+ const onnxruntime::Node& node,
+ const std::string& ep_type) const {
+ if (!on_partition_assignment_fn_) {
+ return;
+ }
+
+ auto indexed_subgraph = std::make_unique();
+ indexed_subgraph->nodes.push_back(node.Index());
+ const ComputeCapability compute_capability{std::move(indexed_subgraph)};
+ on_partition_assignment_fn_(graph, compute_capability, ep_type);
+}
+
+template
+static bool IsTensorOfType(const NodeArg& node_arg) {
+ const auto* type_proto = node_arg.TypeAsProto();
+ return node_arg.Exists() &&
+ type_proto != nullptr &&
+ DataTypeImpl::TypeFromProto(*type_proto) == DataTypeImpl::GetTensorType();
+}
+
+template
+static bool HasTensorArgOfType(const NodeArgs& node_args) {
+ return std::any_of(node_args.cbegin(), node_args.cend(),
+ [](const NodeArg* node_arg) {
+ return node_arg != nullptr && IsTensorOfType(*node_arg);
+ });
+}
+
static bool IsMLFloat16Tensor(const NodeArg& node_arg) {
- // Type() will return nullptr if node_arg.Exists() is true so don't need an additional check for that
- return node_arg.Type() != nullptr &&
- DataTypeImpl::TypeFromProto(*node_arg.TypeAsProto()) == DataTypeImpl::GetTensorType();
+ return IsTensorOfType(node_arg);
}
-bool InsertCastTransformer::NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const {
- // If the node's input is float16 and currently the node is not assigned to any EP
- // we need to insert a cast to float, and put the node on CPU for default behavior.
- // We don't cast a node with a subgraph as we'd need to do a lot more checking of the subgraph inputs
- // (both explicit and implicit) and contents to determine if it was safe to do so.
- // TODO: a better check is to check does the CPU kernel with float exist or not.
+static bool HasCpuFloat32FallbackKernel(
+ const onnxruntime::Node& node,
+ const InlinedVector>& cpu_kernel_registries,
+ const logging::Logger& logger);
+
+bool InsertCastTransformer::NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input,
+ const logging::Logger& logger) const {
+ // Returns true when this input is an fp16 input to an unassigned node that is eligible
+ // for the cast-to-fp32 fallback path.
+ //
+ // Nodes with subgraphs are excluded because rewriting explicit and implicit subgraph
+ // inputs safely requires additional checks of the subgraph boundaries and contents.
return node->GetExecutionProviderType().empty() &&
!node->ContainsSubgraph() &&
- IsMLFloat16Tensor(*input);
+ IsMLFloat16Tensor(*input) &&
+ HasCpuFloat32FallbackKernel(*node, cpu_kernel_registries_, logger);
+}
+
+static bool HasFp16IO(const onnxruntime::Node& node) {
+ return HasTensorArgOfType(node.InputDefs()) ||
+ HasTensorArgOfType(node.OutputDefs());
+}
+
+static bool HasFp16Input(const onnxruntime::Node& node) {
+ return HasTensorArgOfType(node.InputDefs());
+}
+
+static bool IsCpuFp16OptInPolicyOp(const onnxruntime::Node& node) {
+ // Standard-domain ops whose CPU fp16 kernels are governed by session.enable_cpu_fp16
+ // and the fp32 fallback heuristic. Existing CPU fp16 kernels are intentionally not
+ // included in this opt-in/fallback policy.
+ return node.Domain().empty() &&
+ (node.OpType() == "MatMul" || node.OpType() == "Gemm");
+}
+
+static std::optional DimValue(const ONNX_NAMESPACE::TensorShapeProto* shape, int dim_idx) {
+ if (!shape || dim_idx < 0 || dim_idx >= shape->dim_size()) {
+ return std::nullopt;
+ }
+
+ const auto& dim = shape->dim(dim_idx);
+ if (!dim.has_dim_value()) {
+ return std::nullopt;
+ }
+
+ return dim.dim_value();
+}
+
+static std::optional ProductOfDimsExceptLast(const ONNX_NAMESPACE::TensorShapeProto* shape) {
+ if (!shape || shape->dim_size() < 2) {
+ return std::nullopt;
+ }
+
+ int64_t product = 1;
+ for (int i = 0; i < shape->dim_size() - 1; ++i) {
+ const auto dim = DimValue(shape, i);
+ if (!dim || *dim < 0) {
+ return std::nullopt;
+ }
+
+ if (*dim != 0 && product > std::numeric_limits::max() / *dim) {
+ return std::nullopt;
+ }
+ product *= *dim;
+ }
+
+ return product;
+}
+
+static std::optional GetAttributeInt(const onnxruntime::Node& node,
+ const std::string& attr_name,
+ int64_t default_value) {
+ const auto& attrs = node.GetAttributes();
+ const auto attr = attrs.find(attr_name);
+ if (attr == attrs.end()) {
+ return default_value;
+ }
+
+ if (attr->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_INT) {
+ return std::nullopt;
+ }
+
+ return attr->second.i();
+}
+
+struct CpuFp16GemmShape {
+ int64_t M;
+ int64_t N;
+ int64_t K;
+};
+
+static std::optional GetMatMulShapeForCpuFp16Heuristic(const onnxruntime::Node& node) {
+ const auto& inputs = node.InputDefs();
+ if (inputs.size() < 2 || !inputs[0] || !inputs[1]) {
+ return std::nullopt;
+ }
+
+ const auto* a_shape = inputs[0]->Shape();
+ const auto* b_shape = inputs[1]->Shape();
+ if (!a_shape || !b_shape || a_shape->dim_size() < 2 || b_shape->dim_size() < 2) {
+ return std::nullopt;
+ }
+
+ const auto b_inner_dim = DimValue(b_shape, b_shape->dim_size() - 2);
+ const auto b_size_to_inner_dim = ProductOfDimsExceptLast(b_shape);
+ // Match MatMulComputeHelper: RHS shapes like [1, ..., 1, K, N] also flatten
+ // the left input because the leading RHS dims are only padding.
+ const bool flattens_left = a_shape->dim_size() >= b_shape->dim_size() &&
+ b_inner_dim && b_size_to_inner_dim &&
+ *b_size_to_inner_dim == *b_inner_dim;
+
+ // Match MatMulComputeHelper: effectively 2D RHS flattens the left input, while
+ // genuinely batched RHS uses the per-GEMM row count.
+ auto M = flattens_left ? ProductOfDimsExceptLast(a_shape)
+ : DimValue(a_shape, a_shape->dim_size() - 2);
+ auto K = DimValue(a_shape, a_shape->dim_size() - 1);
+ auto N = DimValue(b_shape, b_shape->dim_size() - 1);
+ if (!M || !N || !K) {
+ return std::nullopt;
+ }
+
+ return CpuFp16GemmShape{*M, *N, *K};
+}
+
+struct CpuFp16MatMulRhsShape {
+ int64_t N;
+ int64_t K;
+};
+
+static std::optional GetMatMulRhsShapeForCpuFp16Heuristic(const onnxruntime::Node& node) {
+ const auto& inputs = node.InputDefs();
+ if (inputs.size() < 2 || !inputs[1]) {
+ return std::nullopt;
+ }
+
+ const auto* b_shape = inputs[1]->Shape();
+ if (!b_shape || b_shape->dim_size() != 2) {
+ return std::nullopt;
+ }
+
+ auto K = DimValue(b_shape, b_shape->dim_size() - 2);
+ auto N = DimValue(b_shape, b_shape->dim_size() - 1);
+ if (!N || !K || *N < 0 || *K < 0) {
+ return std::nullopt;
+ }
+
+ if (*K != 0 && *N > std::numeric_limits::max() / *K) {
+ return std::nullopt;
+ }
+
+ return CpuFp16MatMulRhsShape{*N, *K};
+}
+
+static std::optional GetGemmShapeForCpuFp16Heuristic(const onnxruntime::Node& node) {
+ const auto& inputs = node.InputDefs();
+ if (inputs.size() < 2 || !inputs[0] || !inputs[1]) {
+ return std::nullopt;
+ }
+
+ const auto* a_shape = inputs[0]->Shape();
+ const auto* b_shape = inputs[1]->Shape();
+ if (!a_shape || !b_shape || a_shape->dim_size() != 2 || b_shape->dim_size() != 2) {
+ return std::nullopt;
+ }
+
+ const auto trans_a = GetAttributeInt(node, "transA", 0);
+ const auto trans_b = GetAttributeInt(node, "transB", 0);
+ if (!trans_a || !trans_b) {
+ return std::nullopt;
+ }
+
+ const auto M = DimValue(a_shape, *trans_a ? 1 : 0);
+ const auto K = DimValue(a_shape, *trans_a ? 0 : 1);
+ const auto N = DimValue(b_shape, *trans_b ? 0 : 1);
+ if (!M || !N || !K) {
+ return std::nullopt;
+ }
+
+ return CpuFp16GemmShape{*M, *N, *K};
+}
+
+static bool ShouldKeepNativeCpuFp16ForMatMulOrGemm(
+ const Graph& graph,
+ const onnxruntime::Node& node,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG& mlas_backend_kernel_selector_config) {
+ constexpr int64_t kMaxNativeFp16GemvM = 4;
+ constexpr int64_t kMinNativeFp16ConstantMatMulNK = 512 * 1024;
+ constexpr int64_t kMinNativeFp16NK = 1024 * 1024;
+
+ std::optional shape;
+ if (node.OpType() == "MatMul") {
+ const auto& inputs = node.InputDefs();
+ const bool has_constant_rhs =
+ inputs.size() > 1 && inputs[1] && graph_utils::IsInitializer(graph, inputs[1]->Name(), true);
+ if (has_constant_rhs) {
+ const auto rhs_shape = GetMatMulRhsShapeForCpuFp16Heuristic(node);
+ if (!rhs_shape) {
+ return false;
+ }
+
+ const auto nk = rhs_shape->N * rhs_shape->K;
+ if (nk < kMinNativeFp16ConstantMatMulNK) {
+ return false;
+ }
+
+ return MlasHalfGemmNativePackBSize(CblasNoTrans, CblasNoTrans,
+ static_cast(rhs_shape->N),
+ static_cast(rhs_shape->K),
+ &mlas_backend_kernel_selector_config) != 0;
+ }
+
+ shape = GetMatMulShapeForCpuFp16Heuristic(node);
+ } else if (node.OpType() == "Gemm") {
+ shape = GetGemmShapeForCpuFp16Heuristic(node);
+ }
+
+ if (!shape || shape->M < 0 || shape->N < 0 || shape->K < 0) {
+ return false;
+ }
+
+ if (shape->K != 0 && shape->N > std::numeric_limits::max() / shape->K) {
+ return false;
+ }
+
+ const auto nk = shape->N * shape->K;
+
+ if (node.OpType() == "MatMul") {
+ return shape->M <= kMaxNativeFp16GemvM &&
+ nk >= kMinNativeFp16NK &&
+ MlasHGemmSupported(CblasNoTrans, CblasNoTrans);
+ }
+
+ if (shape->M > kMaxNativeFp16GemvM) {
+ return false;
+ }
+
+ return nk >= kMinNativeFp16NK;
+}
+
+static bool KernelDomainMatchesNode(const KernelDef& kernel_def, const onnxruntime::Node& node) {
+ const auto& kernel_domain = kernel_def.Domain();
+ const auto& node_domain = node.Domain();
+ return kernel_domain == node_domain ||
+ (kernel_domain == kOnnxDomainAlias && node_domain.empty());
+}
+
+static bool KernelVersionMatchesNode(const KernelDef& kernel_def, const onnxruntime::Node& node) {
+ const auto [kernel_start_version, kernel_end_version] = kernel_def.SinceVersion();
+ return kernel_start_version <= node.SinceVersion() &&
+ kernel_end_version >= node.SinceVersion();
+}
+
+static bool IsKernelTypeCompatible(gsl::span enabled_types,
+ const ONNX_NAMESPACE::TypeProto& actual_type,
+ bool replace_fp16_with_float) {
+ const auto* type_to_check = &actual_type;
+ TypeProto float_tensor_type;
+ if (replace_fp16_with_float &&
+ DataTypeImpl::TypeFromProto(actual_type) == DataTypeImpl::GetTensorType()) {
+ float_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+ type_to_check = &float_tensor_type;
+ }
+
+ return std::any_of(enabled_types.begin(), enabled_types.end(),
+ [type_to_check](const DataTypeImpl* enabled_type) {
+ return enabled_type->IsCompatible(*type_to_check);
+ });
+}
+
+static bool KernelTypeConstraintsMatchAllNodeArgs(const onnxruntime::Node& node,
+ const KernelDef& kernel_def,
+ bool replace_fp16_with_float) {
+#if !defined(ORT_MINIMAL_BUILD)
+ static const OpSchemaKernelTypeStrResolver resolver;
+#else
+ static const KernelTypeStrResolver resolver;
+#endif
+
+ const auto actual_inputs = node.InputDefs();
+ const auto actual_outputs = node.OutputDefs();
+ const auto& actual_input_arg_counts = node.InputArgCount();
+ InlinedVector actual_input_arg_offsets;
+ actual_input_arg_offsets.reserve(actual_input_arg_counts.size());
+ int current_offset = 0;
+ for (const auto arg_count : actual_input_arg_counts) {
+ actual_input_arg_offsets.push_back(current_offset);
+ current_offset += arg_count;
+ }
+
+ const auto CheckArg = [&](const NodeArg* arg,
+ gsl::span enabled_types) {
+ if (!arg || !arg->Exists()) {
+ return true;
+ }
+
+ const auto* type_proto = arg->TypeAsProto();
+ if (type_proto == nullptr) {
+ return false;
+ }
+
+ return IsKernelTypeCompatible(enabled_types, *type_proto, replace_fp16_with_float);
+ };
+
+ for (const auto& [kernel_type_str, enabled_types] : kernel_def.TypeConstraints()) {
+ gsl::span constraint_args;
+ if (!resolver.ResolveKernelTypeStr(node, kernel_type_str, constraint_args).IsOK()) {
+ return false;
+ }
+
+ for (const auto& [arg_type, formal_arg_idx] : constraint_args) {
+ if (arg_type == ArgType::kInput) {
+ if (formal_arg_idx >= actual_input_arg_counts.size() ||
+ actual_input_arg_counts[formal_arg_idx] == 0) {
+ continue;
+ }
+
+ const auto first_arg_idx = actual_input_arg_offsets[formal_arg_idx];
+ for (int arg_idx = 0; arg_idx < actual_input_arg_counts[formal_arg_idx]; arg_idx++) {
+ const auto actual_arg_idx = static_cast(first_arg_idx + arg_idx);
+ ORT_ENFORCE(actual_arg_idx < actual_inputs.size());
+ if (!CheckArg(actual_inputs[actual_arg_idx], enabled_types)) {
+ return false;
+ }
+ }
+ } else {
+ if (formal_arg_idx < actual_outputs.size() &&
+ !CheckArg(actual_outputs[formal_arg_idx], enabled_types)) {
+ return false;
+ }
+ }
+ }
+ }
+
+ return true;
+}
+
+static bool HasCpuKernelWithTypeSupport(
+ const onnxruntime::Node& node,
+ const InlinedVector>& cpu_kernel_registries,
+ bool replace_fp16_with_float) {
+ for (const KernelRegistry* cpu_kernel_registry : cpu_kernel_registries) {
+ for (const auto& [_, kernel_create_info] : cpu_kernel_registry->GetKernelCreateMap()) {
+ const auto* kernel_def = kernel_create_info.kernel_def.get();
+ if (kernel_def != nullptr &&
+ kernel_def->Provider() == kCpuExecutionProvider &&
+ kernel_def->OpName() == node.OpType() &&
+ KernelDomainMatchesNode(*kernel_def, node) &&
+ KernelVersionMatchesNode(*kernel_def, node) &&
+ KernelTypeConstraintsMatchAllNodeArgs(node, *kernel_def, replace_fp16_with_float)) {
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
+static bool HasCpuKernelForCurrentTypes(
+ const onnxruntime::Node& node,
+ const InlinedVector>& cpu_kernel_registries,
+ const logging::Logger& logger) {
+ ORT_UNUSED_PARAMETER(logger);
+ return HasCpuKernelWithTypeSupport(node, cpu_kernel_registries, false);
+}
+
+static bool HasCpuFloat32FallbackKernel(
+ const onnxruntime::Node& node,
+ const InlinedVector>& cpu_kernel_registries,
+ const logging::Logger& logger) {
+ ORT_UNUSED_PARAMETER(logger);
+ return HasCpuKernelWithTypeSupport(node, cpu_kernel_registries, true);
}
onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph,
@@ -50,21 +444,29 @@ onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph,
// check if the node has an fp16 input but was not able to be assigned an execution provider.
// we will need to add casts to/from fp32 around the node for it to be executed using the CPU EP.
-static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) {
+static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node,
+ const InlinedVector>& cpu_kernel_registries,
+ const logging::Logger& logger) {
bool not_assigned = node.GetExecutionProviderType().empty();
if (not_assigned) {
- const auto& input_defs = node.InputDefs();
- bool has_fp16_input = std::any_of(input_defs.cbegin(), input_defs.cend(),
- [](const NodeArg* input_def) {
- return IsMLFloat16Tensor(*input_def);
- });
- return has_fp16_input;
+ return HasFp16Input(node) && HasCpuFloat32FallbackKernel(node, cpu_kernel_registries, logger);
}
return false;
}
+static bool CpuAssignedFp16NodeNeedsFallbackCast(
+ const onnxruntime::Node& node,
+ const InlinedVector>& cpu_kernel_registries,
+ const logging::Logger& logger) {
+ return node.GetExecutionProviderType() == kCpuExecutionProvider &&
+ !node.ContainsSubgraph() &&
+ HasFp16IO(node) &&
+ !HasCpuKernelForCurrentTypes(node, cpu_kernel_registries, logger) &&
+ HasCpuFloat32FallbackKernel(node, cpu_kernel_registries, logger);
+}
+
// Detect an isolated node that is able to process fp16 data but is between other nodes that have fp16 inputs
// but will need a Cast inserted to enable them to run.
//
@@ -87,7 +489,7 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) {
//
// Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32.
static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph,
- const KernelRegistry& cpu_kernel_registry,
+ const InlinedVector>& cpu_kernel_registries,
const logging::Logger& logger) {
// we can check if it's an isolated fp16 node
// if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't),
@@ -150,7 +552,13 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
type_constraint_map[type_str] = DataTypeImpl::GetTensorType();
break; // we don't have multiple tensors feeding into one input
}
- type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*(input_def->TypeAsProto()));
+
+ const auto* type_proto = input_def->TypeAsProto();
+ if (type_proto == nullptr) {
+ return false;
+ }
+
+ type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*type_proto);
break; // we don't have multiple tensors feeding into one input
}
}
@@ -164,7 +572,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
const int arg_idx = input_edge->GetDstArgIndex();
if (fp16_args.find(arg_idx) != fp16_args.end()) {
// if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16
- if (!NodeNeedsInputCastToFp32(input_edge->GetNode())) {
+ if (!NodeNeedsInputCastToFp32(input_edge->GetNode(), cpu_kernel_registries, logger)) {
return false;
}
}
@@ -192,7 +600,12 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
fp16_args.emplace((int)idx);
type_constraint_map[type_str] = DataTypeImpl::GetTensorType();
} else {
- type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*(output_def->TypeAsProto()));
+ const auto* type_proto = output_def->TypeAsProto();
+ if (type_proto == nullptr) {
+ return false;
+ }
+
+ type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*type_proto);
}
}
@@ -204,7 +617,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
const int arg_idx = output_edge->GetSrcArgIndex();
if (fp16_args.find(arg_idx) != fp16_args.end()) {
// if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16
- if (!NodeNeedsInputCastToFp32(output_edge->GetNode())) {
+ if (!NodeNeedsInputCastToFp32(output_edge->GetNode(), cpu_kernel_registries, logger)) {
return false;
}
}
@@ -212,32 +625,36 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
// now all fp16 inputs and outputs would have a cast
// make sure fp32 version of the kernel is available.
- const KernelCreateInfo* kernel_create_info{};
- const auto lookup_status = cpu_kernel_registry.TryFindKernel(
- kCpuExecutionProvider, node.OpType(), node.Domain(),
- node.SinceVersion(), type_constraint_map, logger, &kernel_create_info);
- if (lookup_status.IsOK() && kernel_create_info != nullptr) {
- return true;
+ for (const KernelRegistry* cpu_kernel_registry : cpu_kernel_registries) {
+ const KernelCreateInfo* kernel_create_info{};
+ const auto lookup_status = cpu_kernel_registry->TryFindKernel(
+ kCpuExecutionProvider, node.OpType(), node.Domain(),
+ node.SinceVersion(), type_constraint_map, logger, &kernel_create_info);
+ if (lookup_status.IsOK() && kernel_create_info != nullptr) {
+ return true;
+ }
}
}
return false;
}
-// These nodes have an fp16 CPU kernel and were therefore assigned to CPU EP by the partitioner;
-// that assignment has already been recorded. Clearing the EP is the mechanism that makes
-// NeedInsertCast return true so they get wrapped with fp32 casts like any other unassigned node.
-// Collect their indices so ApplyImpl can skip the on_partition_assignment_fn_ callback for them:
-// the callback is only for nodes that are newly receiving a CPU EP fallback from this transformer,
-// not for nodes whose partitioner assignment is already on record.
-static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry,
- const logging::Logger& logger,
- InlinedHashSet& nodes_with_recorded_cpu_assignment) {
+static Status ForceSingleNodeCPUFloat16ToFloat32(
+ onnxruntime::Graph& graph,
+ const InlinedVector>& cpu_kernel_registries,
+ const logging::Logger& logger,
+ InlinedHashSet& forced_fp32_nodes,
+ InlinedHashSet& nodes_with_recorded_cpu_assignment) {
for (auto& node : graph.Nodes()) {
- if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry, logger)) {
- // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32
+ if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registries, logger)) {
+ // These nodes have an fp16 CPU kernel and were therefore assigned to CPU EP by the partitioner;
+ // that assignment has already been recorded. Clearing the EP is the mechanism that makes
+ // NeedInsertCast return true so they get wrapped with fp32 casts like any other unassigned node.
+ // Track the node index as well so the later broad fp16-preservation logic does not
+ // immediately assign it back to CPU and undo the heuristic.
nodes_with_recorded_cpu_assignment.insert(node.Index());
node.SetExecutionProviderType("");
+ forced_fp32_nodes.insert(node.Index());
}
}
@@ -518,14 +935,16 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
// is rewritten to consume fp32 (via inserted casts) and given a CPU EP assignment.
// on_partition_assignment_fn_ is fired to record each such new CPU EP assignment.
//
- // Exception: ForceSingleNodeCPUFloat16ToFloat32 may clear the EP of nodes that were already
- // assigned to CPU EP by the partitioner (they have an fp16 CPU kernel and got cast-wrapped to
- // avoid isolated fp16 islands). Those nodes are tracked here so we can skip the callback —
- // their assignment was already recorded by the partitioner.
+ // Exception: some CPU EP assignments may be cleared by this transformer to route isolated or
+ // unprofitable fp16 islands through fp32 casts. Those assignments were already recorded by the
+ // partitioner, so track them and skip the callback when they are assigned back to CPU.
+ InlinedHashSet forced_fp32_nodes;
InlinedHashSet nodes_with_recorded_cpu_assignment;
- if (force_cpu_fp32_)
- ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_, logger,
- nodes_with_recorded_cpu_assignment));
+ if (force_cpu_fp32_ && !cpu_kernel_registries_.empty()) {
+ ORT_RETURN_IF_ERROR(
+ ForceSingleNodeCPUFloat16ToFloat32(graph, cpu_kernel_registries_, logger,
+ forced_fp32_nodes, nodes_with_recorded_cpu_assignment));
+ }
GraphViewer graph_viewer(graph);
auto& order = graph_viewer.GetNodesInTopologicalOrder();
@@ -540,11 +959,56 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
if (!node)
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
+ if (CpuAssignedFp16NodeNeedsFallbackCast(*node, cpu_kernel_registries_, logger)) {
+ node->SetExecutionProviderType("");
+ }
+
+ if (!enable_cpu_fp16_ &&
+ node->GetExecutionProviderType() == kCpuExecutionProvider &&
+ IsCpuFp16OptInPolicyOp(*node) &&
+ HasFp16Input(*node) &&
+ !node->ContainsSubgraph()) {
+ node->SetExecutionProviderType("");
+ }
+
+ if (enable_cpu_fp16_ &&
+ force_cpu_fp32_ &&
+ IsCpuFp16OptInPolicyOp(*node) &&
+ HasFp16Input(*node) &&
+ !node->ContainsSubgraph() &&
+ (node->GetExecutionProviderType().empty() ||
+ node->GetExecutionProviderType() == kCpuExecutionProvider) &&
+ !ShouldKeepNativeCpuFp16ForMatMulOrGemm(graph, *node, mlas_backend_kernel_selector_config_) &&
+ HasCpuFloat32FallbackKernel(*node, cpu_kernel_registries_, logger)) {
+ // Current Arm fp16 paths are profitable for constant-RHS MatMul once native
+ // packed-B is available, and for large GEMV-like shapes. Keep Gemm conservative
+ // until MLAS native fp16 is consistently faster across its common shapes.
+ if (node->GetExecutionProviderType() == kCpuExecutionProvider) {
+ nodes_with_recorded_cpu_assignment.insert(node->Index());
+ }
+ node->SetExecutionProviderType("");
+ forced_fp32_nodes.insert(node->Index());
+ }
+
+ const bool has_fp16_io = !node->ContainsSubgraph() && HasFp16IO(*node);
+ const bool has_cpu_fp16_kernel =
+ has_fp16_io && HasCpuKernelForCurrentTypes(*node, cpu_kernel_registries_, logger);
+
+ if (enable_cpu_fp16_ &&
+ node->GetExecutionProviderType().empty() &&
+ has_cpu_fp16_kernel &&
+ forced_fp32_nodes.find(node->Index()) == forced_fp32_nodes.end()) {
+ // When CPU fp16 is enabled, assign any currently-unassigned fp16-capable node to CPU
+ // so it is preserved in fp16 instead of being routed through the fp32 cast fallback.
+ node->SetExecutionProviderType(kCpuExecutionProvider);
+ RecordPartitionAssignment(graph, *node, kCpuExecutionProvider);
+ }
+
auto& inputs = node->MutableInputDefs();
std::map replacement_defs;
bool casted = false;
for (auto input : inputs) {
- if (NeedInsertCast(node, input)) {
+ if (NeedInsertCast(node, input, logger)) {
auto src_arg = input;
if (input_def_updates.count(src_arg)) {
replacement_defs[src_arg] = input_def_updates[src_arg];
@@ -566,17 +1030,10 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
if (casted) {
// Set current node to run on the CPU execution provider
- // Keep in mind that the EP will be empty because NeedInsertCast() already insures that
+ // Keep in mind that the EP will be empty because NeedInsertCast() already ensures that
node->SetExecutionProviderType(kCpuExecutionProvider);
-
- // Record the new CPU EP assignment via the partition assignment callback if provided.
- // Skip nodes in nodes_with_recorded_cpu_assignment: their CPU EP assignment was made by the partitioner
- // and is already on record; the callback must not fire again for them.
- if (on_partition_assignment_fn_ && nodes_with_recorded_cpu_assignment.find(node->Index()) == nodes_with_recorded_cpu_assignment.end()) {
- auto sub_graph = std::make_unique();
- sub_graph->nodes = {node->Index()};
- ComputeCapability capability(std::move(sub_graph));
- on_partition_assignment_fn_(graph, capability, kCpuExecutionProvider);
+ if (nodes_with_recorded_cpu_assignment.find(node->Index()) == nodes_with_recorded_cpu_assignment.end()) {
+ RecordPartitionAssignment(graph, *node, kCpuExecutionProvider);
}
// Some ONNX operators have an attribute `dtype` which define the output type for these operators
diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.h b/onnxruntime/core/optimizer/insert_cast_transformer.h
index ca968fdfa1545..f5b1822d9fed0 100644
--- a/onnxruntime/core/optimizer/insert_cast_transformer.h
+++ b/onnxruntime/core/optimizer/insert_cast_transformer.h
@@ -3,12 +3,14 @@
#pragma once
#include "core/common/common.h"
+#include "core/common/inlined_containers.h"
#include "core/graph/graph_viewer.h"
#include "core/framework/op_kernel.h"
+#include "core/framework/graph_partitioner.h"
#include "core/optimizer/graph_transformer.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/kernel_registry.h"
-#include "core/framework/graph_partitioner.h"
+#include "core/mlas/inc/mlas.h"
namespace onnxruntime {
@@ -23,28 +25,73 @@ class InsertCastTransformer : public onnxruntime::GraphTransformer {
* @brief Initializer
* @param name for logging purpose
* @param cpu_kernel_registry used to query whether an op node can be safely created
+ * @param enable_cpu_fp16 if true, allows CPU fp16 kernels to run without forcing fp32 casts
+ * @param force_cpu_fp32 if true, applies the CPU fp16 fallback heuristic, which may keep selected
+ * fp16 CPU nodes on the existing fp32 cast fallback path when native fp16 is
+ * not expected to be profitable for the active MLAS backend
+ * @param mlas_backend_kernel_selector_config
+ * active MLAS backend selector config. Used by the CPU fp16 heuristic to avoid
+ * preserving native fp16 paths that rely on backend-specific support, such as
+ * native packed-B, when that backend is disabled or unavailable.
+ * @param on_partition_assignment_fn
+ * optional callback for recording nodes manually assigned by this transformer.
*/
InsertCastTransformer(const std::string& name, const KernelRegistry* cpu_kernel_registry,
+ bool enable_cpu_fp16 = false, bool force_cpu_fp32 = true,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config = nullptr,
+ OnPartitionAssignmentFunction on_partition_assignment_fn = {})
+ : onnxruntime::GraphTransformer(name),
+ cpu_kernel_registries_(cpu_kernel_registry != nullptr ? InlinedVector>{cpu_kernel_registry}
+ : InlinedVector>{}),
+ enable_cpu_fp16_(enable_cpu_fp16),
+ force_cpu_fp32_(!cpu_kernel_registries_.empty() && force_cpu_fp32),
+ mlas_backend_kernel_selector_config_(mlas_backend_kernel_selector_config != nullptr
+ ? *mlas_backend_kernel_selector_config
+ : MLAS_BACKEND_KERNEL_SELECTOR_CONFIG{}),
+ on_partition_assignment_fn_(std::move(on_partition_assignment_fn)) {}
+
+ InsertCastTransformer(const std::string& name, const KernelRegistry* cpu_kernel_registry,
+ OnPartitionAssignmentFunction on_partition_assignment_fn)
+ : InsertCastTransformer(name, cpu_kernel_registry,
+ /*enable_cpu_fp16*/ false,
+ /*force_cpu_fp32*/ true,
+ /*mlas_backend_kernel_selector_config*/ nullptr,
+ std::move(on_partition_assignment_fn)) {}
+
+ InsertCastTransformer(const std::string& name,
+ InlinedVector> cpu_kernel_registries,
+ bool enable_cpu_fp16 = false, bool force_cpu_fp32 = true,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config = nullptr,
OnPartitionAssignmentFunction on_partition_assignment_fn = {})
: onnxruntime::GraphTransformer(name),
- cpu_kernel_registries_(cpu_kernel_registry),
- force_cpu_fp32_(cpu_kernel_registry != nullptr),
+ cpu_kernel_registries_(std::move(cpu_kernel_registries)),
+ enable_cpu_fp16_(enable_cpu_fp16),
+ force_cpu_fp32_(!cpu_kernel_registries_.empty() && force_cpu_fp32),
+ mlas_backend_kernel_selector_config_(mlas_backend_kernel_selector_config != nullptr
+ ? *mlas_backend_kernel_selector_config
+ : MLAS_BACKEND_KERNEL_SELECTOR_CONFIG{}),
on_partition_assignment_fn_(std::move(on_partition_assignment_fn)) {}
private:
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
- bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const;
+ bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input,
+ const logging::Logger& logger) const;
+ void RecordPartitionAssignment(const onnxruntime::Graph& graph, const onnxruntime::Node& node,
+ const std::string& ep_type) const;
- const KernelRegistry* cpu_kernel_registries_;
+ const InlinedVector> cpu_kernel_registries_;
- // Currently because we only have very few cpu kernels support float16, place those nodes on float16
- // will introduce many cast between fp32 and fp16, which will slow the execution.
- // A better solution is to have a cost model to evaluate does it works to place the node on float16.
- // Here for simplify, we only force the single-node-float16 sub-graph to float32
+ const bool enable_cpu_fp16_;
+
+ // Some CPU fp16 kernels are only profitable for specific shapes and backend capabilities. A broader cost model would
+ // be better; for now we use conservative checks for known slower cases and native packed-B availability.
const bool force_cpu_fp32_;
+ // Copied from session options so graph optimization makes the same backend-capability decisions that execution will.
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_;
+
// Optional callback to record when nodes are assigned to CPU EP by this transformer.
// Reuses the same callback type as GraphPartitioner to maintain consistent EP assignment tracking.
- OnPartitionAssignmentFunction on_partition_assignment_fn_;
+ const OnPartitionAssignmentFunction on_partition_assignment_fn_;
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index fb854b19accfa..5476a9b210f65 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -177,11 +177,13 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 21, Atan);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, float, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, double, MatMul);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MLFloat16, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, Softmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, Softmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, TopK);
@@ -397,8 +399,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float,
@@ -575,6 +579,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift);
@@ -709,8 +714,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
#endif
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, MatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Min);
@@ -1745,6 +1752,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
float, Gemm)>,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
index 08dbc46213f65..222dcc4b3c6dc 100644
--- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
+++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
@@ -9,9 +9,13 @@
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
+#include
+
+#include "core/common/narrow.h"
#include "core/common/safeint.h"
#include "core/common/float16.h"
#include "core/framework/op_kernel.h"
+#include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h"
#include "core/providers/cpu/nn/conv_attributes.h"
#include "contrib_ops/cpu/fused_activation.h"
@@ -46,6 +50,13 @@ class FusedConvFp16 final : public OpKernel {
FusedConvFp16(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
channels_last_ = (info.GetKernelDef().OpName() == "NhwcFusedConv");
+ SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions());
+ const auto& input_defs = info.node().InputDefs();
+ has_bias_input_ = input_defs.size() >= 3 && input_defs[2]->Exists();
+ const Tensor* bias = nullptr;
+ if (has_bias_input_ && info.TryGetConstantInput(2, &bias)) {
+ constant_B_ = bias;
+ }
}
Status Compute(OpKernelContext* context) const override;
@@ -93,8 +104,13 @@ class FusedConvFp16 final : public OpKernel {
}
MLAS_ACTIVATION activation_;
+ MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_;
ConvAttributes conv_attrs_;
+ bool has_bias_input_{false};
+ const Tensor* constant_B_{nullptr};
TensorShape W_shape_;
+ BufferUniquePtr packed_halfconv_weights_and_bias_buffer_;
+ size_t packed_halfconv_weights_and_bias_size_{0};
BufferUniquePtr packed_W_buffer_;
size_t packed_W_size_{0};
bool is_W_packed_{false};
@@ -139,8 +155,53 @@ Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
const size_t kernel_dim = group_input_channels * kernel_size;
bool share_prepacked_weights = (prepacked_weights != nullptr);
+ const bool has_valid_constant_halfconv_bias =
+ constant_B_ != nullptr &&
+ !share_prepacked_weights &&
+ constant_B_->Shape().NumDimensions() == 1 &&
+ constant_B_->Shape()[0] == static_cast(output_channels);
const bool is_depthwise_conv = (group_input_channels == 1 && group_output_channels == 1);
+
+ if (group_count == 1 &&
+ rank == 4 &&
+ output_channels > 1 &&
+ activation_.ActivationKind == MlasIdentityActivation &&
+ (!has_bias_input_ || has_valid_constant_halfconv_bias)) {
+ std::array halfconv_kernel_shape{shape[2], shape[3]};
+ std::array halfconv_dilations{1, 1};
+ if (!conv_attrs_.dilations.empty()) {
+ halfconv_dilations = {conv_attrs_.dilations[0], conv_attrs_.dilations[1]};
+ }
+
+ packed_halfconv_weights_and_bias_size_ = MlasHalfConvPackWeightsAndBiasSize(
+ output_channels,
+ group_input_channels,
+ halfconv_kernel_shape.data(),
+ halfconv_dilations.data(),
+ &mlas_backend_kernel_selector_config_);
+ if (packed_halfconv_weights_and_bias_size_ != 0) {
+ auto* packed_halfconv_weights_and_bias = alloc->Alloc(packed_halfconv_weights_and_bias_size_);
+ BufferUniquePtr packed_halfconv_weights_and_bias_buffer(
+ packed_halfconv_weights_and_bias, BufferDeleter(alloc));
+ const auto* bias_data = constant_B_ != nullptr ? constant_B_->Data() : nullptr;
+ if (MlasHalfConvPackWeightsAndBias(
+ output_channels,
+ group_input_channels,
+ halfconv_kernel_shape.data(),
+ halfconv_dilations.data(),
+ Wdata,
+ bias_data,
+ packed_halfconv_weights_and_bias,
+ nullptr,
+ &mlas_backend_kernel_selector_config_)) {
+ packed_halfconv_weights_and_bias_buffer_ = std::move(packed_halfconv_weights_and_bias_buffer);
+ } else {
+ packed_halfconv_weights_and_bias_size_ = 0;
+ }
+ }
+ }
+
// Don't pack the filter buffer if the MlasConvDepthwise path is used.
if (!is_depthwise_conv) {
packed_W_size_ = MlasHalfGemmPackBSize(group_output_channels, kernel_dim, false);
@@ -185,6 +246,9 @@ Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
}
if (share_prepacked_weights) {
+ // `packed_halfconv_weights_and_bias_buffer_` is session-local only. It encodes a bias
+ // choice even when the bias is null, so it must not be shared under the
+ // W-only prepack cache key.
prepacked_weights->buffers_.push_back(nullptr); // packed_W_buffer_ is nullptr
prepacked_weights->buffer_sizes_.push_back(0);
}
@@ -222,12 +286,13 @@ Status FusedConvFp16::UseSharedPrePackedBuffers(std::vector& pr
used_shared_buffers = true;
- if (prepacked_buffers.size() == 1) { // This means that only packed_W_ exists
+ if (prepacked_buffers.size() == 1) { // only packed_W_ exists
packed_W_buffer_ = std::move(prepacked_buffers[0]);
- } else if (prepacked_buffers.size() == 2) { // This means that only reordered_W_ exists
- // Enforce that the first "placeholder" buffer is nullptr
+ } else if (prepacked_buffers.size() == 2) { // placeholder + reordered_W_
ORT_ENFORCE(prepacked_buffers[0].get() == nullptr);
reordered_W_buffer_ = std::move(prepacked_buffers[1]);
+ } else {
+ ORT_ENFORCE(false, "Unexpected number of shared prepacked fp16 conv buffers.");
}
return Status::OK();
@@ -296,6 +361,63 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
+ concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
+
+ const auto* Bdata = B != nullptr ? B->Data() : nullptr;
+
+ if (Sum == nullptr && kernel_rank == 2) {
+ MLAS_CONV_PARAMETERS parameters;
+
+ size_t working_buffer_size = 0;
+ if (MlasHalfConvPrepare(¶meters,
+ kernel_rank,
+ narrow(N),
+ narrow(conv_attrs_.group),
+ narrow(C / conv_attrs_.group),
+ input_shape.GetDims().data(),
+ kernel_shape.data(),
+ dilations.data(),
+ pads.data(),
+ strides.data(),
+ output_shape.GetDims().data(),
+ narrow(M / conv_attrs_.group),
+ &activation_,
+ &working_buffer_size,
+ 0.0f,
+ channels_last_,
+ thread_pool,
+ &mlas_backend_kernel_selector_config_)) {
+ const MLFloat16* halfconv_filter = nullptr;
+ const MLFloat16* halfconv_bias = Bdata;
+ bool halfconv_filter_and_bias_are_packed = false;
+
+ if (packed_halfconv_weights_and_bias_buffer_ != nullptr) {
+ halfconv_filter = static_cast(packed_halfconv_weights_and_bias_buffer_.get());
+ halfconv_filter_and_bias_are_packed = true;
+ halfconv_bias = nullptr;
+ } else if (W != nullptr) {
+ halfconv_filter = W->Data();
+ }
+
+ if (halfconv_filter != nullptr) {
+ auto* working_data = working_buffer_size > 0
+ ? alloc->Alloc(sizeof(MLFloat16) * SafeInt(working_buffer_size))
+ : nullptr;
+ BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
+
+ if (MlasHalfConv(¶meters,
+ X->Data(),
+ halfconv_filter,
+ halfconv_filter_and_bias_are_packed,
+ halfconv_bias,
+ static_cast(working_buffer.get()),
+ Y->MutableData(),
+ thread_pool)) {
+ return Status::OK();
+ }
+ }
+ }
+ }
// Handle the case of a dynamic weight filter.
BufferUniquePtr reordered_W_buffer;
@@ -337,7 +459,6 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
const int64_t col_buffer_size = kernel_dim * output_image_size;
const auto* Xdata = X->Data();
- const auto* Bdata = B != nullptr ? B->Data() : nullptr;
auto* Ydata = Y->MutableData();
const auto* sum_data = Sum != nullptr ? Sum->Data() : nullptr;
@@ -387,8 +508,6 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
padding_data.resize(static_cast(C), MLFloat16());
}
- concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
-
/*************************************
* Thread partition idea: we are essentially partition a GEMM A[M,K] x B[K,N].
* Here B contains the conv filters, which are usually not big, so we assume
@@ -544,7 +663,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
const auto* gemm_add = add_src == nullptr ? nullptr : worker_addsrc + group_id * group_output_channels;
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_, gemm_add);
- MLAS_HALF_GEMM_DATA_PARAMS gemm_params;
+ MLAS_HALF_GEMM_DATA_PARAMS gemm_params{};
gemm_params.A = AData;
gemm_params.lda = lda;
if (packed_W_buffer_) {
@@ -563,7 +682,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
static_cast(output_count),
static_cast(group_output_channels),
static_cast(kernel_dim),
- 1, &gemm_params, nullptr);
+ 1, &gemm_params, nullptr, &mlas_backend_kernel_selector_config_);
}
}
};
diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc
index c0da9aec1e1b1..a86e5a36aadf9 100644
--- a/onnxruntime/core/providers/cpu/math/gemm.cc
+++ b/onnxruntime/core/providers/cpu/math/gemm.cc
@@ -206,26 +206,30 @@ void Gemm_MLFloat16(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b,
if (c_data == nullptr)
beta = onnxruntime::MLFloat16::Zero;
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
- bool support_mlas = false;
+ bool support_mlas_bias = false;
if (c_shape == nullptr) {
- support_mlas = true;
+ support_mlas_bias = true;
} else if (c_shape->NumDimensions() == 1 && (*c_shape)[0] == N) {
- support_mlas = true;
- } else if (c_shape->NumDimensions() == 2 && (((*c_shape)[0] == 1 && (*c_shape)[1] == N) || ((*c_shape)[0] == N && (*c_shape)[1] == 1))) {
- support_mlas = true;
+ support_mlas_bias = true;
+ } else if (c_shape->NumDimensions() == 2 &&
+ (((*c_shape)[0] == 1 && (*c_shape)[1] == N) || ((*c_shape)[0] == N && (*c_shape)[1] == 1))) {
+ support_mlas_bias = true;
}
- if (trans_a == CblasNoTrans && trans_b == CblasNoTrans && support_mlas && alpha.ToFloat() == 1.0 && beta.ToFloat() == 1.0) {
- MLAS_HALF_GEMM_DATA_PARAMS data;
+ const bool use_mlas_no_bias = beta == onnxruntime::MLFloat16::Zero;
+ const bool use_mlas_bias = beta == onnxruntime::MLFloat16::One && support_mlas_bias;
+ if (trans_a == CblasNoTrans && trans_b == CblasNoTrans &&
+ alpha == onnxruntime::MLFloat16::One && (use_mlas_no_bias || use_mlas_bias)) {
+ MLAS_HALF_GEMM_DATA_PARAMS data{};
data.A = a_data;
data.lda = K;
data.B = b_data;
data.ldb = N;
data.C = y_data;
data.ldc = N;
- if (c_shape != nullptr) {
+ if (use_mlas_bias && c_shape != nullptr) {
data.Bias = c_data;
}
- MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool);
+ MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool, mlas_backend_kernel_selector_config);
return;
}
#endif
diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc
index ca91f46db93da..e19f8ec370d26 100644
--- a/onnxruntime/core/providers/cpu/math/matmul.cc
+++ b/onnxruntime/core/providers/cpu/math/matmul.cc
@@ -6,6 +6,9 @@
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
+#include
+#include
+#include
namespace onnxruntime {
@@ -23,6 +26,13 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
MatMul);
+ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
+ MatMul,
+ 1, 8,
+ MLFloat16,
+ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
+ MatMul);
+
// opset 9 supports more types
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
MatMul,
@@ -40,6 +50,14 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
MatMul);
+ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
+ MatMul,
+ 9,
+ 12,
+ MLFloat16,
+ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
+ MatMul);
+
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
MatMul,
9,
@@ -72,6 +90,13 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL(
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
MatMul);
+ONNX_CPU_OPERATOR_TYPED_KERNEL(
+ MatMul,
+ 13,
+ MLFloat16,
+ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
+ MatMul);
+
ONNX_CPU_OPERATOR_TYPED_KERNEL(
MatMul,
13,
@@ -106,9 +131,8 @@ Status MatMul::Compute(OpKernelContext* ctx) const {
if (helper.K() == 0) {
// When we have (M, 0, N) then the inputs are empty, but the output should
// be filled out with zeros.
- EigenMatrixMapRowMajor dest(y->MutableData(),
- narrow(helper.M()), narrow(helper.N()));
- dest.setZero();
+ auto output_span = gsl::make_span(y->MutableData(), narrow(y->Shape().Size()));
+ std::fill(output_span.begin(), output_span.end(), T{});
return Status::OK();
}
@@ -235,6 +259,45 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc,
}
#endif
+bool GemmPackBHalfNative(AllocatorPtr& alloc,
+ const Tensor& tensor_b,
+ IAllocatorUniquePtr& packed_b,
+ size_t& packed_b_size,
+ TensorShape& b_shape,
+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) {
+ // Only handle the common case of a 2D weight matrix. Additional matrices
+ // could be handled by stacking the packed buffers.
+ if (tensor_b.Shape().NumDimensions() != 2) {
+ return false;
+ }
+
+ b_shape = tensor_b.Shape();
+
+ const size_t K = static_cast(b_shape[0]);
+ const size_t N = static_cast(b_shape[1]);
+
+ packed_b_size = MlasHalfGemmNativePackBSize(CblasNoTrans, CblasNoTrans, N, K,
+ mlas_backend_kernel_selector_config);
+ if (packed_b_size == 0) {
+ return false;
+ }
+
+ packed_b = IAllocator::MakeUniquePtr(alloc, packed_b_size, true);
+ auto* packed_b_data = packed_b.get();
+ memset(packed_b_data, 0, packed_b_size);
+
+ if (!MlasHalfGemmNativePackB(CblasNoTrans, CblasNoTrans, N, K,
+ reinterpret_cast