Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 34 additions & 33 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/threading.cpp
${MLAS_SRC_DIR}/sgemm.cpp
${MLAS_SRC_DIR}/halfgemm.cpp
${MLAS_SRC_DIR}/halfconv.cpp
${MLAS_SRC_DIR}/qgemm.cpp
${MLAS_SRC_DIR}/qdwconv.cpp
${MLAS_SRC_DIR}/convolve.cpp
Expand Down Expand Up @@ -309,6 +310,8 @@ function(setup_kleidiai)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/kai_ukernel_interface.cpp
${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp
${MLAS_SRC_DIR}/kleidiai/halfgemm_kleidiai.cpp
${MLAS_SRC_DIR}/kleidiai/halfconv_kleidiai.cpp
${MLAS_SRC_DIR}/kleidiai/sbgemm_kleidiai.cpp
${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp
${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp
Expand Down Expand Up @@ -363,7 +366,7 @@ function (setup_arm_neon_nchwc)
${MLAS_SRC_DIR}/aarch64/SconvNchwcKernelNeon.S
${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S
${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S
)
)
endif()
list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC)
set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE)
Expand Down Expand Up @@ -546,8 +549,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")

if (NOT APPLE)
set(mlas_platform_srcs
set(mlas_platform_srcs
${mlas_platform_srcs}
${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
Expand All @@ -573,37 +575,36 @@ else()
${MLAS_SRC_DIR}/gelu_neon_fp16.h
${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
)
if (onnxruntime_USE_ARM_NEON_NCHWC)
list(APPEND mlas_platform_srcs
${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S
${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S
${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S
)
endif()
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16_8bit.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
if (onnxruntime_USE_ARM_NEON_NCHWC)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
endif()
set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
if (onnxruntime_USE_ARM_NEON_NCHWC AND NOT APPLE)
list(APPEND mlas_platform_srcs
${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S
${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S
${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S
)
endif()
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16_8bit.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
if (onnxruntime_USE_ARM_NEON_NCHWC AND NOT APPLE)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
endif()
set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs})
Expand Down
14 changes: 7 additions & 7 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ The **OpSet Version** column uses the following notation:
|||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(float)|
|Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[9, 10]|**T** = tensor(double), tensor(float)|
|||[7, 8]|**T** = tensor(double), tensor(float)|
|Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|GlobalAveragePool|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(float)|
|||[1, 21]|**T** = tensor(float)|
|GlobalLpPool|*in* X:**T**<br> *out* Y:**T**|2+|**T** = tensor(float)|
Expand Down Expand Up @@ -258,9 +258,9 @@ The **OpSet Version** column uses the following notation:
|||[18, 21]|**T** = tensor(float)|
|||[11, 17]|**T** = tensor(float)|
|||[2, 10]|**T** = tensor(float)|
|MatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[1, 8]|**T** = tensor(double), tensor(float)|
|MatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[1, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulInteger|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *out* Y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int32)|
|Max|*in* data_0:**T**<br> *out* max:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,24 @@ static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFil
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";

// Enables opt-in CPU EP float16 execution for supported ops.
// This allows InsertCastTransformer to preserve eligible fp16 CPU nodes instead of inserting fp16<->fp32 casts.
// The fp32 fallback heuristic below ("session.cpu_fp16_use_fp32_fallback_heuristic") is still enabled by default and
// may keep selected shapes on the fp32 path when native fp16 is not expected to be profitable for the active MLAS
// backend.
// Option values:
// - "0": CPU fp16 is not explicitly enabled, and default cast/fallback behavior is used. [DEFAULT]
// - "1": Enable CPU fp16 preservation for the current opt-in scope.
static const char* const kOrtSessionOptionsEnableCpuFp16 = "session.enable_cpu_fp16";

// Controls the CPU fp16 -> fp32 fallback heuristic used when session.enable_cpu_fp16 is "1".
// With the heuristic on, native CPU fp16 is kept only for cases expected to be profitable, such as supported
// GEMV-like shapes or constant-RHS MatMul when the active MLAS backend reports a native packed-B path. Other eligible
// nodes are assigned to CPU through fp32 casts when a valid CPU fp32 kernel exists. This is the recommended/default
// CPU fp16 mode.
// NOTE(review): presumably this key is ignored when session.enable_cpu_fp16 is "0" - confirm against the code that
// reads it.
// Option values:
// - "0": Do not use the heuristic; preserve native CPU fp16 more broadly for eligible CPU fp16 kernels.
// - "1": Use the fp32 fallback heuristic. [DEFAULT]
static const char* const kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic = "session.cpu_fp16_use_fp32_fallback_heuristic";

// Use LUT (Lookup Table) based GEMM for quantized models when available.
// Option values:
// - "0": Do not use LUT based GEMM. [DEFAULT]
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ Status MoE<float>::ComputeGEMM(const float* A, const float* B, float* C,
template <>
Status MoE<MLFloat16>::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFloat16* C,
int64_t M, int64_t K, int64_t N, bool transpose_B) const {
MLAS_HALF_GEMM_DATA_PARAMS params;
MLAS_HALF_GEMM_DATA_PARAMS params{};
params.A = A;
params.lda = static_cast<size_t>(K);
params.C = C;
Expand All @@ -551,7 +551,8 @@ Status MoE<MLFloat16>::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFlo
params.ldb = static_cast<size_t>(N);
}

MlasHalfGemmBatch(static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), 1, &params, nullptr);
MlasHalfGemmBatch(static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), 1, &params, nullptr,
&mlas_backend_kernel_selector_config_);
return Status::OK();
}

Expand Down
11 changes: 4 additions & 7 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,13 +517,10 @@ Status SessionState::PrepackConstantInitializedTensors(
is_packed,
&weights_to_be_filled_in));

if (is_packed) {
// BUG CHECK: Ensure that the kernel has filled in the pre-packed weight
// to be cached if the weight was pre-packed
ORT_ENFORCE(weights_to_be_filled_in.buffers_.size() > 0,
"The kernel corresponding to the node ", node.Name(),
" doesn't have an implementation that can cache computed pre-packed weights");

// Some kernels pre-pack for their own session but intentionally do not
// expose those buffers for sharing, for example when the packed layout is
// backend-specific and the shared container has no layout metadata.
if (is_packed && !weights_to_be_filled_in.buffers_.empty()) {
const auto& op_type = node.OpType();

// Sanity check
Expand Down
Loading
Loading