diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 7ae18db235ccb..7edfdbf504732 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -24,6 +24,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sgemm.cpp
${MLAS_SRC_DIR}/halfgemm.cpp
${MLAS_SRC_DIR}/qgemm.cpp
+ ${MLAS_SRC_DIR}/qgemm_fp8.cpp
${MLAS_SRC_DIR}/qdwconv.cpp
${MLAS_SRC_DIR}/convolve.cpp
${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_greater_than_1.cpp
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 3608f1246450f..454051fe896fa 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -26,6 +26,7 @@ Do not modify directly.*
* com.microsoft.DequantizeBFP
* com.microsoft.DequantizeLinear
* com.microsoft.DequantizeWithOrder
+ * com.microsoft.DynamicQuantMatMulFp8
* com.microsoft.DynamicQuantizeLSTM
* com.microsoft.DynamicQuantizeMatMul
* com.microsoft.DynamicTimeWarping
@@ -1493,6 +1494,67 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.DynamicQuantMatMulFp8**
+
+ Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from float16/bfloat16/float) and dynamic runtime quantization of activations to fp8 using internally computed block-wise scales. All zero-point inputs, when provided, must encode 0.0.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- block_size_k : int
+- Block size along K for A and B block-wise scales.
+- block_size_m : int
+- Block size along M for A block-wise scales. Must be 1.
+- block_size_n : int
+- Block size along N for B block-wise scales.
+- fp8_type : int
+- FP8 TensorProto data type used when non-FP8 constant B is dynamically quantized during prepack. Defaults to FLOAT8E4M3FN.
+
+
+#### Inputs (2 - 6)
+
+
+- A : TA
+- Input tensor A.
+- B : TB
+- Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only supported when B is a constant initializer that can be quantized during prepack.
+- B_scale (optional) : TS
+- Scale of FP8 input 'B'. Must be a block-wise tensor with shape (N / block_size_n, K / block_size_k). Required when B is already FP8. Ignored for non-FP8 constant B, where scales are computed during prepack.
+- B_zero_point (optional) : TZ
+- Zero point tensor for input 'B'. Must have the same shape as B_scale and all values must encode 0.0.
+- Y_scale (optional) : TS
+- Scale of output 'Y'. Must be a scalar when provided.
+- Y_zero_point (optional) : TZ
+- Zero point tensor for output 'Y'. Must be a scalar encoding 0.0 when provided.
+
+
+#### Outputs
+
+
+- Y : TY
+- Output tensor of shape (..., M, N).
+
+
+#### Type Constraints
+
+
+- TA : tensor(float16), tensor(bfloat16), tensor(float)
+- Constrain input A type to float16, bfloat16, or float.
+- TB : tensor(float16), tensor(bfloat16), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
+- Constrain input B type to fp8, or to float16, bfloat16, or float for constant initializers.
+- TZ : tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
+- Constrain zero point types to fp8. Only zero-valued zero points are supported.
+- TS : tensor(float), tensor(float16), tensor(bfloat16)
+- Constrain scale types to float, float16, or bfloat16.
+- TY : tensor(float16), tensor(bfloat16), tensor(float)
+- Constrain output type to float16, bfloat16, or float.
+
+
+
### **com.microsoft.DynamicQuantizeLSTM**
#### Version
@@ -6690,5 +6752,3 @@ No versioning maintained for experimental ops.
T : tensor(float)
Constrain input and output types to float32 tensors.
-
-
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index d0d8e750285d4..969083aba3aa0 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -577,6 +577,7 @@ The **OpSet Version** column uses the following notation:
|CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)|
|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**T** = tensor(float)|
|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float)|
+|DynamicQuantMatMulFp8|*in* A:**TA**
*in* B:**TB**
*in* B_scale:**TS**
*in* B_zero_point:**TZ**
*in* Y_scale:**TS**
*in* Y_zero_point:**TZ**
*out* Y:**TY**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**TS** = tensor(bfloat16), tensor(float), tensor(float16)
**TY** = tensor(bfloat16), tensor(float), tensor(float16)
**TZ** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)|
|DynamicQuantizeLSTM|*in* X:**T**
*in* W:**T2**
*in* R:**T2**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* W_scale:**T**
*in* W_zero_point:**T2**
*in* R_scale:**T**
*in* R_zero_point:**T2**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(float)
**T1** = tensor(int32)
**T2** = tensor(int8), tensor(uint8)|
|DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)|
|DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)|
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 0749457f5a182..b581064b180a2 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -116,6 +116,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, NhwcMaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, NhwcMaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization);
+#if !defined(DISABLE_FLOAT8_TYPES)
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicQuantMatMulFp8);
+#endif
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE);
@@ -284,6 +287,9 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+#if !defined(DISABLE_FLOAT8_TYPES)
+ BuildKernelCreateInfo,
+#endif
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc
new file mode 100644
index 0000000000000..350de4ac9cfa3
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc
@@ -0,0 +1,802 @@
+// Copyright (c) 2026 Arm Limited. All rights reserved.
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+// SPDX-License-Identifier: MIT
+
+#include "dynamic_quant_matmul_fp8.h"
+
+#include "core/common/common.h"
+#include "core/common/fp8_common.h"
+#include "core/framework/op_kernel.h"
+#include "core/graph/onnx_protobuf.h"
+#include "core/mlas/inc/mlas.h"
+#include "core/common/float16.h"
+#include "core/common/float8.h"
+#include "core/common/safeint.h"
+#include "core/platform/threadpool.h"
+#include "core/providers/cpu/math/matmul_helper.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace onnxruntime {
+namespace contrib {
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+namespace {
+
+constexpr int64_t kDefaultBlockSize = 128;
+constexpr int64_t kDefaultBlockSizeM = 1;
+constexpr int64_t kPackedBMetadataVersion = 1;
+constexpr size_t kPackedBMetadataElementCount = 7;
+constexpr size_t kPackedBMetadataSize = kPackedBMetadataElementCount * sizeof(int64_t);
+
+enum PackedBMetadataIndex : size_t {
+ kPackedBMetadataVersionIndex = 0,
+ kPackedBMetadataRowsIndex,
+ kPackedBMetadataColsIndex,
+ kPackedBMetadataSizeIndex,
+ kPackedBMetadataScaleCountIndex,
+ kPackedBMetadataFp8ModeIndex,
+ kPackedBMetadataHasFp8ModeIndex,
+};
+
+bool IsFp8DataType(ONNX_NAMESPACE::TensorProto_DataType elem_type) {
+ return elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN ||
+ elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ ||
+ elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2 ||
+ elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
+}
+
+bool IsValidFp8Mode(int64_t mode) {
+ return mode >= static_cast(MLAS_FP8_MODE_E4M3_INF) &&
+ mode < static_cast(MLAS_FP8_MODE_END);
+}
+
+Status RestorePackedBMetadata(const void* metadata_buffer,
+ size_t metadata_size,
+ size_t quantized_b_buffer_size,
+ size_t b_scale_buffer_size,
+ TensorShape& b_shape,
+ size_t& quantized_b_size,
+ size_t& b_scale_count,
+ mlas_fp8_mode& b_type,
+ bool& has_b_type) {
+ ORT_RETURN_IF(metadata_buffer == nullptr,
+ "DynamicQuantMatMulFp8 requires shared prepacked B metadata.");
+ ORT_RETURN_IF(metadata_size != kPackedBMetadataSize,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata has an unexpected size.");
+
+ const auto* metadata = static_cast(metadata_buffer);
+ ORT_RETURN_IF(metadata[kPackedBMetadataVersionIndex] != kPackedBMetadataVersion,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata has an unsupported version.");
+ ORT_RETURN_IF(metadata[kPackedBMetadataRowsIndex] <= 0 || metadata[kPackedBMetadataColsIndex] <= 0,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B shape.");
+ ORT_RETURN_IF(metadata[kPackedBMetadataSizeIndex] <= 0,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B buffer size.");
+ ORT_RETURN_IF(metadata[kPackedBMetadataScaleCountIndex] <= 0,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B scale count.");
+ ORT_RETURN_IF(metadata[kPackedBMetadataHasFp8ModeIndex] != 1 ||
+ !IsValidFp8Mode(metadata[kPackedBMetadataFp8ModeIndex]),
+ "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid FP8 type.");
+
+ const size_t rows = static_cast(metadata[kPackedBMetadataRowsIndex]);
+ const size_t cols = static_cast(metadata[kPackedBMetadataColsIndex]);
+ const size_t expected_quantized_b_size = SafeMul(rows, cols);
+ const size_t restored_quantized_b_size = static_cast(metadata[kPackedBMetadataSizeIndex]);
+ ORT_RETURN_IF(restored_quantized_b_size != expected_quantized_b_size ||
+ restored_quantized_b_size != quantized_b_buffer_size,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata does not match the B buffer size.");
+ const size_t restored_b_scale_count = static_cast(metadata[kPackedBMetadataScaleCountIndex]);
+ ORT_RETURN_IF(restored_b_scale_count > std::numeric_limits::max() / sizeof(float) ||
+ restored_b_scale_count * sizeof(float) != b_scale_buffer_size,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata does not match the B scale buffer size.");
+
+ b_shape = TensorShape({metadata[kPackedBMetadataRowsIndex], metadata[kPackedBMetadataColsIndex]});
+ quantized_b_size = restored_quantized_b_size;
+ b_scale_count = restored_b_scale_count;
+ b_type = static_cast(metadata[kPackedBMetadataFp8ModeIndex]);
+ has_b_type = true;
+ return Status::OK();
+}
+
+// Reject invalid scales before quantization divides by them or MLAS dequantizes with them.
+Status ValidatePositiveFiniteScales(const float* scales, size_t count, const char* scale_name) {
+ for (size_t i = 0; i < count; ++i) {
+ ORT_RETURN_IF(!std::isfinite(scales[i]) || scales[i] <= 0.0f,
+ "DynamicQuantMatMulFp8 requires ", scale_name, " values to be finite and positive.");
+ }
+ return Status::OK();
+}
+
+Status GetFp8MaxAbs(mlas_fp8_mode mode, float& max_abs) {
+ switch (mode) {
+ case MLAS_FP8_MODE_E4M3_INF:
+ case MLAS_FP8_MODE_E4M3_SAT:
+ max_abs = 448.0f;
+ return Status::OK();
+ case MLAS_FP8_MODE_E5M2_INF:
+ case MLAS_FP8_MODE_E5M2_SAT:
+ max_abs = 57344.0f;
+ return Status::OK();
+ default:
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Unsupported fp8 mode for DynamicQuantMatMulFp8.");
+ }
+}
+
+Status ValidateZeroPointValuesAreZero(const Tensor& zero_point, size_t expected_count,
+ const char* zero_point_name) {
+ const size_t actual_count = static_cast(zero_point.Shape().Size());
+ ORT_RETURN_IF(actual_count != expected_count,
+ "DynamicQuantMatMulFp8 requires ", zero_point_name, " to have the expected number of elements.");
+
+ const auto reject_non_zero = [zero_point_name](float value) {
+ ORT_RETURN_IF(value != 0.0f,
+ "DynamicQuantMatMulFp8 supports symmetric quantization only; ",
+ zero_point_name, " values must be zero.");
+ return Status::OK();
+ };
+
+ if (zero_point.IsDataType()) {
+ const auto* zp = static_cast(zero_point.DataRaw());
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], MLAS_FP8_MODE_E4M3_INF)));
+ }
+ } else if (zero_point.IsDataType()) {
+ const auto* zp = static_cast(zero_point.DataRaw());
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], MLAS_FP8_MODE_E4M3_SAT)));
+ }
+ } else if (zero_point.IsDataType()) {
+ const auto* zp = static_cast(zero_point.DataRaw());
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], MLAS_FP8_MODE_E5M2_INF)));
+ }
+ } else if (zero_point.IsDataType()) {
+ const auto* zp = static_cast(zero_point.DataRaw());
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], MLAS_FP8_MODE_E5M2_SAT)));
+ }
+ } else if (zero_point.IsDataType()) {
+ const auto* zp = zero_point.Data();
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(zp[i]));
+ }
+ } else if (zero_point.IsDataType()) {
+ const auto* zp = zero_point.Data();
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(static_cast(zp[i])));
+ }
+ } else if (zero_point.IsDataType()) {
+ const auto* zp = zero_point.Data();
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(static_cast(zp[i])));
+ }
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Unsupported zero point type for DynamicQuantMatMulFp8.");
+ }
+ return Status::OK();
+}
+
+template
+void QuantizeBlockwiseFp8ABlockDynamic(const SrcT* src,
+ size_t K,
+ size_t block_size_k,
+ size_t blocks_k,
+ size_t row,
+ size_t block_k,
+ float fp8_max_abs,
+ mlas_fp8_mode mode,
+ uint8_t* dst,
+ float* scales) {
+ const size_t k_begin = block_k * block_size_k;
+ const size_t k_end = std::min(K, k_begin + block_size_k);
+ const size_t idx = row * blocks_k + block_k;
+
+ // Match the KleidiAI LHS packer: one dynamic quantization scale per A row and K block.
+ float max_abs = 0.0f;
+ const size_t row_offset = row * K;
+ for (size_t k = k_begin; k < k_end; ++k) {
+ max_abs = std::max(max_abs, std::fabs(static_cast(src[row_offset + k])));
+ }
+
+ const float scale = max_abs == 0.0f ? 1.0f : max_abs / fp8_max_abs;
+ scales[idx] = scale;
+
+ for (size_t k = k_begin; k < k_end; ++k) {
+ const float value = static_cast(src[row_offset + k]);
+ const float quantized = value / scale;
+ dst[row_offset + k] = FloatToFp8Byte(quantized, mode);
+ }
+}
+
+template
+void QuantizeBlockwiseFp8WithScales(const SrcT* src,
+ size_t K,
+ size_t N,
+ size_t block_size_k,
+ size_t block_size_n,
+ const float* scales,
+ uint8_t* dst) {
+ // Block sizes come from op attributes; scale shapes only provide the number of blocks.
+ const size_t blocks_k = K / block_size_k;
+ for (size_t k = 0; k < K; ++k) {
+ const size_t block_k = k / block_size_k;
+ const size_t row_offset = k * N;
+ for (size_t n = 0; n < N; ++n) {
+ const size_t block_n = n / block_size_n;
+ const size_t scale_idx = block_n * blocks_k + block_k;
+ const float scale = scales[scale_idx];
+ const float value = static_cast(src[row_offset + n]);
+ const float quantized = value / scale;
+ const Fp8T fp8_value(quantized, true);
+ dst[row_offset + n] = fp8_value.val;
+ }
+ }
+}
+
+template
+void ComputeBlockwiseScalesFromInput(const SrcT* src,
+ size_t K,
+ size_t N,
+ size_t block_size_k,
+ size_t block_size_n,
+ float fp8_max_abs,
+ float* scales) {
+ // Reference-style dynamic quantization: derive one positive scale from each source block.
+ const size_t blocks_k = K / block_size_k;
+ const size_t blocks_n = N / block_size_n;
+ for (size_t block_k = 0; block_k < blocks_k; ++block_k) {
+ const size_t k_begin = block_k * block_size_k;
+ const size_t k_end = k_begin + block_size_k;
+ for (size_t block_n = 0; block_n < blocks_n; ++block_n) {
+ const size_t n_begin = block_n * block_size_n;
+ const size_t n_end = n_begin + block_size_n;
+ float max_abs = 0.0f;
+ for (size_t k = k_begin; k < k_end; ++k) {
+ const size_t row_offset = k * N;
+ for (size_t n = n_begin; n < n_end; ++n) {
+ max_abs = std::max(max_abs, std::fabs(static_cast(src[row_offset + n])));
+ }
+ }
+ // Match KleidiAI RHS scale layout: one scale per N block and K block.
+ scales[block_n * blocks_k + block_k] = max_abs == 0.0f ? 1.0f : max_abs / fp8_max_abs;
+ }
+ }
+}
+
+template
+Status QuantizeToFp8ByModeWithScales(mlas_fp8_mode fp8_mode,
+ const SrcT* src,
+ size_t K,
+ size_t N,
+ size_t block_size_k,
+ size_t block_size_n,
+ const float* scales,
+ uint8_t* dst) {
+ // Dispatch quantization using the requested FP8 mode and runtime block sizes.
+ switch (fp8_mode) {
+ case MLAS_FP8_MODE_E4M3_INF:
+ QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst);
+ return Status::OK();
+ case MLAS_FP8_MODE_E4M3_SAT:
+ QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst);
+ return Status::OK();
+ case MLAS_FP8_MODE_E5M2_INF:
+ QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst);
+ return Status::OK();
+ case MLAS_FP8_MODE_E5M2_SAT:
+ QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst);
+ return Status::OK();
+ default:
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Unsupported fp8 mode for DynamicQuantMatMulFp8.");
+ }
+}
+
+} // namespace
+
+ONNX_OPERATOR_KERNEL_EX(
+ DynamicQuantMatMulFp8,
+ kMSDomain,
+ 1,
+ kCpuExecutionProvider,
+ KernelDefBuilder()
+ .TypeConstraint("TA", std::vector{
+ DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType()})
+ .TypeConstraint("TB", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()})
+ .TypeConstraint("TZ", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()})
+ .TypeConstraint("TS", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()})
+ .TypeConstraint("TY", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}),
+ DynamicQuantMatMulFp8);
+
+DynamicQuantMatMulFp8::DynamicQuantMatMulFp8(const OpKernelInfo& info) : OpKernel(info) {
+ const int64_t block_size_m = info.GetAttrOrDefault("block_size_m", kDefaultBlockSizeM);
+ const int64_t block_size_k = info.GetAttrOrDefault("block_size_k", kDefaultBlockSize);
+ const int64_t block_size_n = info.GetAttrOrDefault("block_size_n", kDefaultBlockSize);
+ const int64_t fp8_type =
+ info.GetAttrOrDefault("fp8_type", ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN);
+ ORT_ENFORCE(block_size_m == 1,
+ "DynamicQuantMatMulFp8 requires block_size_m to be 1 to match the row-wise A scale layout.");
+ ORT_ENFORCE(block_size_k > 0,
+ "DynamicQuantMatMulFp8 requires block_size_k to be greater than zero.");
+ ORT_ENFORCE(block_size_n > 0,
+ "DynamicQuantMatMulFp8 requires block_size_n to be greater than zero.");
+ block_size_k_ = static_cast(block_size_k);
+ block_size_n_ = static_cast(block_size_n);
+ ORT_THROW_IF_ERROR(GetFp8Type(static_cast(fp8_type), fp8_type_));
+}
+
+Status DynamicQuantMatMulFp8::GetFp8Type(const Tensor& tensor, mlas_fp8_mode& out_type) {
+ return GetFp8Type(static_cast(tensor.GetElementType()), out_type);
+}
+
+Status DynamicQuantMatMulFp8::GetFp8Type(ONNX_NAMESPACE::TensorProto_DataType elem_type,
+ mlas_fp8_mode& out_type) {
+ switch (elem_type) {
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
+ out_type = MLAS_FP8_MODE_E4M3_INF;
+ return Status::OK();
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:
+ out_type = MLAS_FP8_MODE_E4M3_SAT;
+ return Status::OK();
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
+ out_type = MLAS_FP8_MODE_E5M2_INF;
+ return Status::OK();
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:
+ out_type = MLAS_FP8_MODE_E5M2_SAT;
+ return Status::OK();
+ default:
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported fp8 type for DynamicQuantMatMulFp8.");
+ }
+}
+
+Status DynamicQuantMatMulFp8::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ /*out*/ bool& is_packed,
+ /*out*/ PrePackedWeights* prepacked_weights) {
+ is_packed = false;
+ if (input_idx != GetBIdx()) {
+ return Status::OK();
+ }
+
+ b_shape_ = tensor.Shape();
+ if (b_shape_.NumDimensions() != 2) {
+ return Status::OK();
+ }
+
+ const size_t K = static_cast(b_shape_[0]);
+ const size_t N = static_cast(b_shape_[1]);
+ const auto b_elem_type = static_cast(tensor.GetElementType());
+ const bool b_is_fp8 = IsFp8DataType(b_elem_type);
+ if (b_is_fp8) {
+ ORT_RETURN_IF_ERROR(GetFp8Type(tensor, b_type_));
+ has_b_type_ = true;
+ return Status::OK();
+ }
+
+ b_type_ = fp8_type_;
+ has_b_type_ = true;
+ if (K == 0) {
+ return Status::OK();
+ }
+ if (N == 0) {
+ return Status::OK();
+ }
+
+ ORT_RETURN_IF_NOT(K % block_size_k_ == 0,
+ "DynamicQuantMatMulFp8 requires K to be divisible by block_size_k.");
+ ORT_RETURN_IF_NOT(N % block_size_n_ == 0,
+ "DynamicQuantMatMulFp8 requires N to be divisible by block_size_n.");
+ const size_t blocks_k = K / block_size_k_;
+ const size_t blocks_n = N / block_size_n_;
+ b_scale_count_ = SafeMul(blocks_k, blocks_n);
+ b_scales_ = IAllocator::MakeUniquePtr(alloc, b_scale_count_, true);
+ float fp8_max_abs = 0.0f;
+ ORT_RETURN_IF_ERROR(GetFp8MaxAbs(b_type_, fp8_max_abs));
+
+ const size_t quantized_b_size = SafeMul(K, N);
+ quantized_b_ = IAllocator::MakeUniquePtr(alloc, quantized_b_size, true);
+ quantized_b_size_ = quantized_b_size;
+ auto* quantized_b_bytes = static_cast(quantized_b_.get());
+ if (tensor.IsDataType()) {
+ ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_,
+ fp8_max_abs, b_scales_.get());
+ ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_,
+ b_scales_.get(), quantized_b_bytes));
+ } else if (tensor.IsDataType()) {
+ ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_,
+ fp8_max_abs, b_scales_.get());
+ ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_,
+ b_scales_.get(), quantized_b_bytes));
+ } else if (tensor.IsDataType()) {
+ ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_,
+ fp8_max_abs, b_scales_.get());
+ ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_,
+ b_scales_.get(), quantized_b_bytes));
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Unsupported B type for DynamicQuantMatMulFp8 prepack.");
+ }
+
+ if (prepacked_weights != nullptr) {
+ const std::array metadata_values = {
+ kPackedBMetadataVersion,
+ b_shape_[0],
+ b_shape_[1],
+ static_cast(quantized_b_size_),
+ static_cast(b_scale_count_),
+ static_cast(b_type_),
+ 1,
+ };
+ auto metadata = IAllocator::MakeUniquePtr(alloc, kPackedBMetadataSize, true);
+ std::memcpy(metadata.get(), metadata_values.data(), kPackedBMetadataSize);
+ auto b_scales_deleter = std::move(b_scales_.get_deleter());
+ IAllocatorUniquePtr b_scales_buffer(
+ b_scales_.release(),
+ [b_scales_deleter = std::move(b_scales_deleter)](void* p) mutable {
+ b_scales_deleter(static_cast(p));
+ });
+ prepacked_weights->buffers_.push_back(std::move(quantized_b_));
+ prepacked_weights->buffer_sizes_.push_back(quantized_b_size_);
+ prepacked_weights->buffers_.push_back(std::move(b_scales_buffer));
+ prepacked_weights->buffer_sizes_.push_back(b_scale_count_ * sizeof(float));
+ prepacked_weights->buffers_.push_back(std::move(metadata));
+ prepacked_weights->buffer_sizes_.push_back(kPackedBMetadataSize);
+ }
+ is_packed = true;
+ return Status::OK();
+}
+
+Status DynamicQuantMatMulFp8::UseSharedPrePackedBuffers(std::vector& prepacked_buffers,
+ gsl::span prepacked_buffer_sizes,
+ int input_idx,
+ /*out*/ bool& used_shared_buffers) {
+ used_shared_buffers = false;
+ if (input_idx != GetBIdx()) {
+ return Status::OK();
+ }
+
+ ORT_RETURN_IF(prepacked_buffers.size() != 3 || prepacked_buffer_sizes.size() != 3,
+ "DynamicQuantMatMulFp8 requires shared prepacked B data, scale, and metadata buffers.");
+ ORT_RETURN_IF(prepacked_buffers[0].get() == nullptr,
+ "DynamicQuantMatMulFp8 requires shared prepacked B data.");
+ ORT_RETURN_IF(prepacked_buffers[1].get() == nullptr,
+ "DynamicQuantMatMulFp8 requires shared prepacked B scales.");
+
+ // Buffer 0 owns quantized B bytes; buffer 1 owns computed B scales; buffer 2 restores kernel state.
+ ORT_RETURN_IF_ERROR(RestorePackedBMetadata(prepacked_buffers[2].get(),
+ prepacked_buffer_sizes[2],
+ prepacked_buffer_sizes[0],
+ prepacked_buffer_sizes[1],
+ b_shape_,
+ quantized_b_size_,
+ b_scale_count_,
+ b_type_,
+ has_b_type_));
+ quantized_b_ = std::move(prepacked_buffers[0]);
+ auto b_scales_deleter = prepacked_buffers[1].get_deleter();
+ b_scales_ = IAllocatorUniquePtr(
+ static_cast(prepacked_buffers[1].release()),
+ [b_scales_deleter = std::move(b_scales_deleter)](float* p) mutable {
+ b_scales_deleter(static_cast(p));
+ });
+ used_shared_buffers = true;
+ return Status::OK();
+}
+
+Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const {
+ const Tensor* a = context->Input(IN_A);
+ const Tensor* b = quantized_b_ ? nullptr : context->Input(IN_B);
+ const Tensor* b_scale = context->Input(IN_B_SCALE);
+ const Tensor* b_zero_point = context->Input(IN_B_ZERO_POINT);
+ const Tensor* y_scale = context->Input(IN_Y_SCALE);
+ const Tensor* y_zero_point = context->Input(IN_Y_ZERO_POINT);
+
+ // Runtime B uses one 2D B scale/zero-point layout, so reject batched B before MatMul broadcasts it.
+ ORT_RETURN_IF(!quantized_b_ && b->Shape().NumDimensions() != 2,
+ "DynamicQuantMatMulFp8 requires runtime B to be a 2D tensor.");
+
+ MatMulComputeHelper helper;
+ ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(),
+ quantized_b_ ? b_shape_ : b->Shape(),
+ nullptr,
+ nullptr));
+
+ const size_t M = static_cast(helper.M());
+ const size_t N = static_cast(helper.N());
+ const size_t K = static_cast(helper.K());
+ Tensor* y = context->Output(OUT_Y, helper.OutputShape());
+ const size_t y_size = static_cast(y->Shape().Size());
+ ORT_RETURN_IF(!y->IsDataType() && !y->IsDataType() && !y->IsDataType(),
+ "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16.");
+
+ if (y_zero_point != nullptr) {
+ // Runtime tensors must match the schema scalar contract before reading element 0.
+ ORT_RETURN_IF(y_zero_point->Shape().NumDimensions() != 0 || y_zero_point->Shape().Size() != 1,
+ "DynamicQuantMatMulFp8 requires Y zero point input to be a scalar.");
+ ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*y_zero_point, 1, "Y zero point"));
+ }
+
+ float y_scale_storage = 0.0f;
+ const float* y_scale_data = nullptr;
+ if (y_scale != nullptr) {
+ // Runtime tensors must match the schema scalar contract before reading element 0.
+ ORT_RETURN_IF(y_scale->Shape().NumDimensions() != 0 || y_scale->Shape().Size() != 1,
+ "DynamicQuantMatMulFp8 requires Y scale input to be a scalar.");
+ if (y_scale->IsDataType()) {
+ y_scale_data = y_scale->Data();
+ } else if (y_scale->IsDataType()) {
+ y_scale_storage = static_cast(y_scale->Data()[0]);
+ y_scale_data = &y_scale_storage;
+ } else if (y_scale->IsDataType()) {
+ y_scale_storage = static_cast(y_scale->Data()[0]);
+ y_scale_data = &y_scale_storage;
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires Y scale input to be float, float16, or bfloat16.");
+ }
+ ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(y_scale_data, 1, "Y scale"));
+ }
+
+ // Empty reduction does not need B data, so fill zeros before enforcing runtime FP8 B.
+ if (K == 0) {
+ if (y_size == 0) {
+ return Status::OK();
+ }
+ if (y->IsDataType()) {
+ std::fill_n(y->MutableData(), y_size, 0.0f);
+ } else if (y->IsDataType()) {
+ std::fill_n(y->MutableData(), y_size, MLFloat16::FromBits(0));
+ } else if (y->IsDataType()) {
+ std::fill_n(y->MutableData(), y_size, BFloat16::FromBits(0));
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16.");
+ }
+ return Status::OK();
+ }
+
+ const bool a_is_supported =
+ a->IsDataType() || a->IsDataType() || a->IsDataType();
+ ORT_RETURN_IF(!a_is_supported, "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16.");
+
+ const auto b_elem_type = b ? static_cast(b->GetElementType())
+ : static_cast(0);
+ const bool b_is_fp8 = IsFp8DataType(b_elem_type);
+
+ mlas_fp8_mode b_type{};
+ if (has_b_type_) {
+ b_type = b_type_;
+ } else if (b_is_fp8) {
+ ORT_RETURN_IF_ERROR(GetFp8Type(b_elem_type, b_type));
+ } else {
+ b_type = fp8_type_;
+ }
+
+ if (b_zero_point != nullptr) {
+ const auto b_zp_elem_type =
+ static_cast(b_zero_point->GetElementType());
+ mlas_fp8_mode b_zp_type{};
+ ORT_RETURN_IF_ERROR(GetFp8Type(b_zp_elem_type, b_zp_type));
+ ORT_RETURN_IF(b_type != b_zp_type,
+ "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match.");
+ }
+
+ if (y_size == 0) {
+ return Status::OK();
+ }
+
+ // Select the FP8 B buffer: prefer pre-quantized B from PrePack, otherwise accept FP8-typed B input.
+ const uint8_t* b_fp8 = nullptr;
+ if (quantized_b_) {
+ b_fp8 = static_cast(quantized_b_.get());
+ } else if (b_is_fp8) {
+ b_fp8 = static_cast(b->DataRaw());
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires runtime B input to be FP8. Non-FP8 B is only supported "
+ "when B is a constant initializer that can be quantized during prepack.");
+ }
+
+ const size_t num_gemms = helper.OutputOffsets().size();
+ ORT_RETURN_IF(K % block_size_k_ != 0,
+ "DynamicQuantMatMulFp8 requires K to be divisible by block_size_k.");
+ const size_t expected_blocks_k = K / block_size_k_;
+ const size_t blocks_m = M;
+ const size_t blocks_k = expected_blocks_k;
+
+ const size_t blocks_n = N / block_size_n_;
+ ORT_RETURN_IF(!b_scales_ && b_scale == nullptr,
+ "DynamicQuantMatMulFp8 requires B scale when B is already FP8.");
+ ORT_RETURN_IF(b_scale != nullptr && b_scale->Shape().NumDimensions() != 2,
+ "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor.");
+ ORT_RETURN_IF(blocks_n == 0, "DynamicQuantMatMulFp8 requires non-zero B scale N dimension.");
+ ORT_RETURN_IF(N % block_size_n_ != 0,
+ "DynamicQuantMatMulFp8 requires N to be divisible by block_size_n.");
+ ORT_RETURN_IF(b_scale != nullptr && static_cast(b_scale->Shape()[0]) != blocks_n,
+ "DynamicQuantMatMulFp8 requires B scale N dimension to be N / block_size_n.");
+ ORT_RETURN_IF(b_scale != nullptr && static_cast(b_scale->Shape()[1]) != blocks_k,
+ "DynamicQuantMatMulFp8 requires B scale K dimension to be K / block_size_k.");
+
+ const size_t a_scale_batch_stride = SafeMul(blocks_m, blocks_k);
+ const size_t b_zp_count = SafeMul(blocks_k, blocks_n);
+
+ if (b_zero_point != nullptr) {
+ ORT_RETURN_IF(b_zero_point->Shape().NumDimensions() != 2,
+ "DynamicQuantMatMulFp8 requires B zero point to be a 2D tensor.");
+ ORT_RETURN_IF(b_zero_point->Shape()[0] != static_cast(blocks_n) ||
+ b_zero_point->Shape()[1] != static_cast(blocks_k),
+ "DynamicQuantMatMulFp8 requires B zero point to have shape [N / block_size_n, K / block_size_k].");
+ ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*b_zero_point, b_zp_count, "B zero point"));
+ }
+
+ AllocatorPtr temp_allocator;
+ ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&temp_allocator));
+
+ const float* b_scales = nullptr;
+ IAllocatorUniquePtr b_scale_float;
+ size_t b_scale_elems = 0;
+ if (b_scales_) {
+ b_scales = b_scales_.get();
+ b_scale_elems = b_scale_count_;
+ } else if (b_scale->IsDataType()) {
+ b_scales = b_scale->Data();
+ b_scale_elems = static_cast(b_scale->Shape().Size());
+ } else {
+ AllocatorPtr allocator;
+ ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+ b_scale_elems = static_cast(b_scale->Shape().Size());
+ b_scale_float = IAllocator::MakeUniquePtr(allocator, b_scale_elems, true);
+ if (b_scale->IsDataType()) {
+ for (size_t i = 0; i < b_scale_elems; ++i) {
+ b_scale_float.get()[i] = static_cast(b_scale->Data()[i]);
+ }
+ } else if (b_scale->IsDataType()) {
+ for (size_t i = 0; i < b_scale_elems; ++i) {
+ b_scale_float.get()[i] = static_cast(b_scale->Data()[i]);
+ }
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires B scale input to be float, float16, or bfloat16.");
+ }
+ b_scales = b_scale_float.get();
+ }
+
+ // MLAS FP8 GEMM accumulates and stores float output. Use scratch for lower-precision Y,
+ // then convert once after all batched GEMMs complete.
+ IAllocatorUniquePtr y_float_buffer;
+ float* y_float_data = nullptr;
+ if (y->IsDataType()) {
+ y_float_data = y->MutableData();
+ } else if (y->IsDataType() || y->IsDataType()) {
+ y_float_buffer = IAllocator::MakeUniquePtr(temp_allocator, y_size, true);
+ y_float_data = y_float_buffer.get();
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16.");
+ }
+
+ MLAS_FP8_GEMM_SHAPE_PARAMS gemm_shape;
+ gemm_shape.M = M;
+ gemm_shape.N = N;
+ gemm_shape.K = K;
+
+ ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(b_scales, b_scale_elems, "B scale"));
+
+ const size_t a_fp8_size = SafeMul(M, K);
+ const size_t a_num_elements = static_cast(a->Shape().Size());
+ ORT_RETURN_IF(a_num_elements % a_fp8_size != 0,
+ "DynamicQuantMatMulFp8 requires A to contain complete MxK matrices.");
+ const size_t a_batch_count = a_num_elements / a_fp8_size;
+
+ // Quantize the physical A tensor once. Broadcasted output GEMMs then reuse the same FP8 A slice.
+ auto a_fp8_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_num_elements, true);
+ const size_t a_scale_count = SafeMul(a_batch_count, a_scale_batch_stride);
+ auto a_scale_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_scale_count, true);
+ const size_t a_quant_work_items = SafeMul(a_batch_count, a_scale_batch_stride);
+ ORT_RETURN_IF(a_quant_work_items > static_cast(std::numeric_limits::max()),
+ "DynamicQuantMatMulFp8 A quantization work item count exceeds ptrdiff_t range.");
+ const auto a_quant_work_items_i = static_cast(a_quant_work_items);
+ const size_t a_quant_block_elems = block_size_k_;
+ const TensorOpCost a_quant_unit_cost{
+ static_cast(SafeMul(a_quant_block_elems, sizeof(float))),
+ static_cast(SafeMul(a_quant_block_elems, sizeof(uint8_t))),
+ static_cast(a_quant_block_elems) * 2.0};
+ float fp8_max_abs = 0.0f;
+ ORT_RETURN_IF_ERROR(GetFp8MaxAbs(b_type, fp8_max_abs));
+ const auto quantize_a_batches = [&](const auto* a_data) {
+ concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), a_quant_work_items_i,
+ a_quant_unit_cost,
+ [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ for (std::ptrdiff_t tid = begin; tid < end; ++tid) {
+ const size_t work_idx = static_cast(tid);
+ const size_t a_batch_idx = work_idx / a_scale_batch_stride;
+ const size_t scale_block_idx = work_idx % a_scale_batch_stride;
+ const size_t row = scale_block_idx / blocks_k;
+ const size_t block_k = scale_block_idx % blocks_k;
+ const size_t a_batch_offset = a_batch_idx * a_fp8_size;
+ const size_t a_scale_batch_offset = a_batch_idx * a_scale_batch_stride;
+ QuantizeBlockwiseFp8ABlockDynamic(
+ a_data + a_batch_offset,
+ K, block_size_k_, blocks_k,
+ row, block_k, fp8_max_abs, b_type,
+ a_fp8_buffer.get() + a_batch_offset,
+ a_scale_buffer.get() + a_scale_batch_offset);
+ }
+ });
+ };
+ if (a->IsDataType()) {
+ quantize_a_batches(a->Data());
+ } else if (a->IsDataType()) {
+ quantize_a_batches(a->Data());
+ } else if (a->IsDataType()) {
+ quantize_a_batches(a->Data());
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16.");
+ }
+
+ std::vector gemm_data_vec(num_gemms);
+ for (size_t gemm_idx = 0; gemm_idx < num_gemms; ++gemm_idx) {
+ const size_t a_offset = helper.LeftOffsets()[gemm_idx];
+ ORT_RETURN_IF(a_offset >= a_num_elements || (a_offset % a_fp8_size) != 0,
+ "DynamicQuantMatMulFp8 requires A offsets to reference complete MxK matrices.");
+ const size_t scale_batch_index = a_offset / a_fp8_size;
+ ORT_RETURN_IF(scale_batch_index >= a_batch_count,
+ "DynamicQuantMatMulFp8 requires A offsets to reference complete MxK matrices.");
+ const size_t a_scale_batch_offset = SafeMul(scale_batch_index, a_scale_batch_stride);
+ const float* a_scales_batch = a_scale_buffer.get() + a_scale_batch_offset;
+ auto& gemm_data = gemm_data_vec[gemm_idx];
+ gemm_data.A = a_fp8_buffer.get() + a_offset;
+ gemm_data.lda = K;
+ gemm_data.B = b_fp8 + helper.RightOffsets()[gemm_idx];
+ gemm_data.ldb = N;
+ gemm_data.C = y_float_data + helper.OutputOffsets()[gemm_idx];
+ gemm_data.ldc = N;
+ gemm_data.ScaleA = a_scales_batch;
+ gemm_data.ScaleB = b_scales;
+ gemm_data.ScaleY = y_scale_data;
+ gemm_data.Fp8Type = b_type;
+ gemm_data.BlockSizeM = 1;
+ gemm_data.BlockSizeK = block_size_k_;
+ gemm_data.BlockSizeN = block_size_n_;
+ gemm_data.BlocksM = blocks_m;
+ gemm_data.BlocksK = blocks_k;
+ gemm_data.BlocksN = blocks_n;
+ gemm_data.ScaleAStrideK = 1;
+ gemm_data.ScaleAStrideM = blocks_k;
+ gemm_data.ScaleBStrideN = blocks_k;
+ gemm_data.ScaleBStrideK = 1;
+ }
+
+ MlasFp8GemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, context->GetOperatorThreadPool());
+
+ if (y_float_buffer != nullptr) {
+ if (y->IsDataType()) {
+ auto* y_data = y->MutableData();
+ for (size_t i = 0; i < y_size; ++i) {
+ y_data[i] = static_cast(y_float_data[i]);
+ }
+ } else {
+ auto* y_data = y->MutableData();
+ for (size_t i = 0; i < y_size; ++i) {
+ y_data[i] = BFloat16(y_float_data[i]);
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+#endif // !defined(DISABLE_FLOAT8_TYPES)
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h
new file mode 100644
index 0000000000000..cef88def0ec77
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2026 Arm Limited. All rights reserved.
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "core/framework/op_kernel.h"
+#include "core/framework/prepacked_weights.h"
+#include "core/graph/onnx_protobuf.h"
+#include "core/mlas/inc/mlas.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+class DynamicQuantMatMulFp8 final : public OpKernel {
+ public:
+ DynamicQuantMatMulFp8(const OpKernelInfo& info);
+
+ Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ /*out*/ bool& is_packed,
+ /*out*/ PrePackedWeights* prepacked_weights) override;
+
+ Status Compute(OpKernelContext* context) const override;
+
+ enum InputTensors : int {
+ IN_A = 0,
+ IN_B = 1,
+ IN_B_SCALE = 2,
+ IN_B_ZERO_POINT = 3,
+ IN_Y_SCALE = 4,
+ IN_Y_ZERO_POINT = 5
+ };
+
+ enum OutputTensors : int { OUT_Y = 0 };
+
+ static Status GetFp8Type(const Tensor& tensor, mlas_fp8_mode& out_type);
+ static Status GetFp8Type(ONNX_NAMESPACE::TensorProto_DataType elem_type, mlas_fp8_mode& out_type);
+
+ Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers,
+ gsl::span prepacked_buffer_sizes,
+ int input_idx,
+ /*out*/ bool& used_shared_buffers) override;
+
+ private:
+ static constexpr int GetBIdx() { return IN_B; }
+ IAllocatorUniquePtr quantized_b_;
+ size_t quantized_b_size_{0};
+ IAllocatorUniquePtr b_scales_;
+ size_t b_scale_count_{0};
+ TensorShape b_shape_;
+ mlas_fp8_mode b_type_{static_cast(0)};
+ bool has_b_type_{false};
+ mlas_fp8_mode fp8_type_{MLAS_FP8_MODE_E4M3_INF};
+ size_t block_size_k_{128};
+ size_t block_size_n_{128};
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/core/common/fp8_common.h b/onnxruntime/core/common/fp8_common.h
new file mode 100644
index 0000000000000..e4cb73df9cd80
--- /dev/null
+++ b/onnxruntime/core/common/fp8_common.h
@@ -0,0 +1,68 @@
+// Copyright (c) 2026 Arm Limited. All rights reserved.
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+#include "core/common/float8.h"
+#include "core/mlas/inc/mlas.h"
+
+#include
+
+namespace onnxruntime {
+
+inline float Fp8ByteToFloat(uint8_t value, mlas_fp8_mode mode) {
+ switch (mode) {
+ case MLAS_FP8_MODE_E4M3_INF:
+ return Float8E4M3FN(value, Float8E4M3FN::FromBits()).ToFloat();
+ case MLAS_FP8_MODE_E4M3_SAT:
+ return Float8E4M3FNUZ(value, Float8E4M3FNUZ::FromBits()).ToFloat();
+ case MLAS_FP8_MODE_E5M2_INF:
+ return Float8E5M2(value, Float8E5M2::FromBits()).ToFloat();
+ case MLAS_FP8_MODE_E5M2_SAT:
+ return Float8E5M2FNUZ(value, Float8E5M2FNUZ::FromBits()).ToFloat();
+ default:
+ return 0.0f;
+ }
+}
+
+inline uint8_t FloatToFp8Byte(float value, mlas_fp8_mode mode) {
+ switch (mode) {
+ case MLAS_FP8_MODE_E4M3_INF: {
+ const Float8E4M3FN fp8(value, true);
+ return fp8.val;
+ }
+ case MLAS_FP8_MODE_E4M3_SAT: {
+ const Float8E4M3FNUZ fp8(value, true);
+ return fp8.val;
+ }
+ case MLAS_FP8_MODE_E5M2_INF: {
+ const Float8E5M2 fp8(value, true);
+ return fp8.val;
+ }
+ case MLAS_FP8_MODE_E5M2_SAT: {
+ const Float8E5M2FNUZ fp8(value, true);
+ return fp8.val;
+ }
+ default:
+ return 0;
+ }
+}
+
+inline bool IsValidFp8Mode(mlas_fp8_mode mode) {
+ switch (mode) {
+ case MLAS_FP8_MODE_E4M3_INF:
+ case MLAS_FP8_MODE_E4M3_SAT:
+ case MLAS_FP8_MODE_E5M2_INF:
+ case MLAS_FP8_MODE_E5M2_SAT:
+ return true;
+ default:
+ return false;
+ }
+}
+
+} // namespace onnxruntime
+
+#endif // !defined(DISABLE_FLOAT8_TYPES)
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index 59f97c222ceb2..6894b4e7eed40 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -25,6 +25,9 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MulInteger);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QGemm);
+#if !defined(DISABLE_FLOAT8_TYPES)
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicQuantMatMulFp8);
+#endif
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAdd);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConcat);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearWhere);
@@ -139,6 +142,9 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
+#if !defined(DISABLE_FLOAT8_TYPES)
+ fn(GetOpSchema());
+#endif
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
index 5a3cd86b04492..86800e4d1ec09 100644
--- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -946,6 +946,130 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
updateOutputShape(ctx, 0, {first_input_shape.dim(transA ? 1 : 0), second_input_shape.dim(transB ? 0 : 1)});
}
}));
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ DynamicQuantMatMulFp8, 1,
+ OpSchema()
+ .SetDoc("Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from "
+ "float16/bfloat16/float) and dynamic runtime quantization of activations to fp8 using "
+ "internally computed block-wise scales. All zero-point inputs, when provided, must encode 0.0.")
+ .Input(0, "A", "Input tensor A.", "TA")
+ .Input(1, "B",
+ "Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only "
+ "supported when B is a constant initializer that can be quantized during prepack.",
+ "TB")
+ .Input(2, "B_scale",
+ "Scale of FP8 input 'B'. Must be a block-wise tensor with shape "
+ "(N / block_size_n, K / block_size_k). Required when B is already FP8. Ignored for non-FP8 "
+ "constant B, where scales are computed during prepack.",
+ "TS", OpSchema::Optional)
+ .Input(3, "B_zero_point",
+ "Zero point tensor for input 'B'. Must have the same shape as B_scale and all values must encode 0.0.",
+ "TZ", OpSchema::Optional)
+ .Input(4, "Y_scale", "Scale of output 'Y'. Must be a scalar when provided.", "TS",
+ OpSchema::Optional)
+ .Input(5, "Y_zero_point",
+ "Zero point tensor for output 'Y'. Must be a scalar encoding 0.0 when provided.", "TZ",
+ OpSchema::Optional)
+ .Output(0, "Y", "Output tensor of shape (..., M, N).", "TY")
+ .Attr("block_size_m", "Block size along M for A block-wise scales. Must be 1.",
+ AttributeProto::INT, static_cast(1))
+ .Attr("block_size_k", "Block size along K for A and B block-wise scales.", AttributeProto::INT,
+ static_cast(128))
+ .Attr("block_size_n", "Block size along N for B block-wise scales.", AttributeProto::INT,
+ static_cast(128))
+ .Attr("fp8_type",
+ "FP8 TensorProto data type used when non-FP8 constant B is dynamically quantized during prepack. "
+ "Defaults to FLOAT8E4M3FN.",
+ AttributeProto::INT,
+ static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN))
+ .TypeConstraint("TA", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"},
+ "Constrain input A type to float16, bfloat16, or float.")
+ .TypeConstraint("TB",
+ {"tensor(float16)", "tensor(bfloat16)", "tensor(float)",
+ "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)",
+ "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"},
+ "Constrain input B type to fp8, or to float16, bfloat16, or float for constant initializers.")
+ .TypeConstraint("TZ", {"tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"},
+ "Constrain zero point types to fp8. Only zero-valued zero points are supported.")
+ // Scale tensors are upcast to float by the CPU kernel before calling MLAS.
+ .TypeConstraint("TS", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"},
+ "Constrain scale types to float, float16, or bfloat16.")
+ .TypeConstraint("TY", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"},
+ "Constrain output type to float16, bfloat16, or float.")
+ .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+ const int64_t block_size_m = getAttribute(ctx, "block_size_m", static_cast(1));
+ const int64_t block_size_k = getAttribute(ctx, "block_size_k", static_cast(128));
+ const int64_t block_size_n = getAttribute(ctx, "block_size_n", static_cast(128));
+ if (block_size_m != 1) {
+ fail_type_inference("block_size_m must be 1.");
+ }
+ if (block_size_k <= 0 || block_size_n <= 0) {
+ fail_type_inference("block_size_k and block_size_n must be greater than zero.");
+ }
+ if (hasInputShape(ctx, 1)) {
+ auto& b_shape = getInputShape(ctx, 1);
+ if (b_shape.dim_size() != 2) {
+ fail_type_inference("B must be 2D.");
+ }
+ }
+ if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) {
+ ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1);
+ }
+ if (hasInputShape(ctx, 4)) {
+ auto shape = ctx.getInputType(4)->tensor_type().shape();
+ if (shape.dim_size() != 0) {
+ fail_type_inference("Y scale input must be a scalar.");
+ }
+ }
+ if (hasInputShape(ctx, 5)) {
+ auto shape = ctx.getInputType(5)->tensor_type().shape();
+ if (shape.dim_size() != 0) {
+ fail_type_inference("Y zero point input must be a scalar.");
+ }
+ }
+ if (hasInputShape(ctx, 1) && hasInputShape(ctx, 2)) {
+ auto& b_shape = getInputShape(ctx, 1);
+ auto& b_scale_shape = getInputShape(ctx, 2);
+ if (b_scale_shape.dim_size() != 2) {
+ fail_type_inference("B scale must be 2D.");
+ }
+ if (b_shape.dim(1).has_dim_value() && b_scale_shape.dim(0).has_dim_value()) {
+ const auto n = b_shape.dim(1).dim_value();
+ if ((n % block_size_n) != 0 || b_scale_shape.dim(0).dim_value() != (n / block_size_n)) {
+ fail_type_inference("B scale first dimension must be N / block_size_n.");
+ }
+ }
+ if (b_shape.dim(0).has_dim_value() && b_scale_shape.dim(1).has_dim_value()) {
+ const auto k = b_shape.dim(0).dim_value();
+ if ((k % block_size_k) != 0 || b_scale_shape.dim(1).dim_value() != (k / block_size_k)) {
+ fail_type_inference("B scale last dimension must be K / block_size_k.");
+ }
+ }
+ }
+ if (hasInputShape(ctx, 1) && hasInputShape(ctx, 3)) {
+ auto& b_shape = getInputShape(ctx, 1);
+ auto& b_zp_shape = getInputShape(ctx, 3);
+ if (b_zp_shape.dim_size() != 2) {
+ fail_type_inference("B zero point must be 2D.");
+ }
+ if (b_shape.dim(1).has_dim_value() && b_zp_shape.dim(0).has_dim_value()) {
+ const auto n = b_shape.dim(1).dim_value();
+ if ((n % block_size_n) != 0 || b_zp_shape.dim(0).dim_value() != (n / block_size_n)) {
+ fail_type_inference("B zero point first dimension must be N / block_size_n.");
+ }
+ }
+ if (b_shape.dim(0).has_dim_value() && b_zp_shape.dim(1).has_dim_value()) {
+ const auto k = b_shape.dim(0).dim_value();
+ if ((k % block_size_k) != 0 || b_zp_shape.dim(1).dim_value() != (k / block_size_k)) {
+ fail_type_inference("B zero point last dimension must be K / block_size_k.");
+ }
+ }
+ }
+ }));
+
+#endif
ONNX_MS_OPERATOR_SET_SCHEMA(
QAttention, 1,
OpSchema()
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index ddb9daa5e244b..866b4430a9ce7 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -726,6 +726,76 @@ bool
MLASCALL
MlasIsDynamicQGemmAvailable(const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig);
+/**
+ * @brief Parameters that define the shape of an FP8 GEMM operation.
+ */
+struct MLAS_FP8_GEMM_SHAPE_PARAMS {
+ size_t M = 0; /**< Row size of matrix A */
+ size_t N = 0; /**< Column size of matrix B */
+ size_t K = 0; /**< Column size of matrix A and row size of matrix B */
+};
+
+// FP8 mode is aligned with Arm KleidiAI format/overflow modes.
+// Defined here to keep MLAS FP8 APIs platform-agnostic.
+#ifndef MLAS_FP8_MODE_DEFINED
+#define MLAS_FP8_MODE_DEFINED
+enum mlas_fp8_mode : uint8_t {
+ MLAS_FP8_MODE_E4M3_INF = 0, ///< E4M3 with NaN/Inf on overflow.
+ MLAS_FP8_MODE_E4M3_SAT = 1, ///< E4M3 with saturation on overflow.
+ MLAS_FP8_MODE_E5M2_INF = 2, ///< E5M2 with NaN/Inf on overflow.
+ MLAS_FP8_MODE_E5M2_SAT = 3, ///< E5M2 with saturation on overflow.
+ MLAS_FP8_MODE_END = 4, ///< End marker. Do not use.
+};
+#endif // MLAS_FP8_MODE_DEFINED
+
+/**
+ * @brief Parameters that define the data buffers and layout for an FP8 GEMM.
+ */
+struct MLAS_FP8_GEMM_DATA_PARAMS {
+ const void* A = nullptr;
+ size_t lda = 0;
+ const void* B = nullptr;
+ size_t ldb = 0;
+ void* C = nullptr;
+ size_t ldc = 0;
+ // Block-wise scales for A, indexed as block_m * ScaleAStrideM + block_k * ScaleAStrideK.
+ const float* ScaleA = nullptr;
+ // Block-wise scales for B, indexed as block_k * ScaleBStrideK + block_n * ScaleBStrideN.
+ const float* ScaleB = nullptr;
+ const float* ScaleY = nullptr; // Scalar scale for Y.
+ mlas_fp8_mode Fp8Type = static_cast(0);
+ size_t BlockSizeM = 128; // Block size along M for A quantization.
+ size_t BlockSizeK = 128; // Block size along K for A/B quantization.
+ size_t BlockSizeN = 128; // Block size along N for B quantization.
+ size_t BlocksM = 0; // Number of blocks along M (ceil(M / BlockSizeM)).
+ size_t BlocksK = 0; // Number of blocks along K (ceil(K / BlockSizeK)).
+ size_t BlocksN = 0; // Number of blocks along N (ceil(N / BlockSizeN)).
+ size_t ScaleAStrideK = 0; // ScaleA stride between K blocks (elements).
+ size_t ScaleAStrideM = 0; // ScaleA stride between M blocks (elements).
+ size_t ScaleBStrideN = 0; // ScaleB stride between N blocks (elements).
+ size_t ScaleBStrideK = 0; // ScaleB stride between K blocks (elements).
+};
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+void
+MLASCALL
+MlasFp8GemmBatch(
+ const MLAS_FP8_GEMM_SHAPE_PARAMS& Shape,
+ const MLAS_FP8_GEMM_DATA_PARAMS* DataParams,
+ const size_t BatchN,
+ MLAS_THREADPOOL* ThreadPool
+);
+
+inline void
+MlasFp8Gemm(
+ const MLAS_FP8_GEMM_SHAPE_PARAMS& Shape,
+ const MLAS_FP8_GEMM_DATA_PARAMS* DataParams,
+ MLAS_THREADPOOL* ThreadPool
+) {
+ MlasFp8GemmBatch(Shape, DataParams, 1, ThreadPool);
+}
+#endif // !defined(DISABLE_FLOAT8_TYPES)
+
//
// Symmetric QGEMM has limited buffer overrun.
// Currently only supported in ARM64
diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
index 3c9f398ece887..79561d246bac9 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
+++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
@@ -165,7 +165,7 @@ MLASCALL
MlasDynamicQGemmBatch(
const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape,
const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams,
- const size_t BatchN,
+ const size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
);
@@ -225,37 +225,3 @@ MlasConvSymmetricChannelsLast2DFloatPackW(
MLAS_THREADPOOL* ThreadPool
);
}
-
-/*++
-
-Routine Description:
-
- This routine determines if a wraparound will occur when multiplying two size_t variables
- Uses __builtin_mul_overflow if available on the current system and if not falls back
- to a default implementation to check this wraparound.
-
-Arguments:
-
- a - Supplies the first number to be muliplied.
-
- b - Supplies the second number to be muliplied.
-
- out - pointer to a size_t which acts as the return value in success cases.
-
-Return Value:
-
- Returns false if the operation was successful
- Returns true if wraparound of size_t was detected
-
---*/
-inline bool mul_overflow_size_t_builtin(size_t a, size_t b, size_t* out) {
-#if defined(__has_builtin)
-# if __has_builtin(__builtin_mul_overflow)
- return __builtin_mul_overflow(a, b, out);
-# endif
-#endif
- // Fallback to manual check if builtin not available
- if (b != 0 && a > SIZE_MAX / b) return true;
- if (out) *out = a * b;
- return false;
-}
diff --git a/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp
index f88af056aa156..99816649ac7d0 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp
+++ b/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp
@@ -279,7 +279,7 @@ Return Value:
LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr);
size_t lhs_resize = 0;
- if (mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize))
+ if (MlasMultiplyOverflowsSizeT(LhsPackedStride, BatchSize, &lhs_resize))
{
// size_t wraparound detected for LhsPackedStride, fallback to MLAS
return false;
@@ -304,7 +304,7 @@ Return Value:
// Multithread pack lhs and rhs
RhsPackedStride = ArmKleidiAI::MlasSBGemmPackBSize(TransA, TransB, N, K);
size_t rhs_resize = 0;
- if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize))
+ if (MlasMultiplyOverflowsSizeT(RhsPackedStride, BatchSize, &rhs_resize))
{
// size_t wraparound detected for RhsPackedStride, fallback to MLAS
return false;
@@ -354,7 +354,7 @@ Return Value:
// Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop.
// Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively.
size_t max_tile_elems = 0;
- if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) {
+ if (MlasMultiplyOverflowsSizeT(m_step, n_step, &max_tile_elems)) {
// size_t wraparound detected for tile size, fallback to MLAS
return false;
}
diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
index 69b83e1e5b49c..ef599cf1346a8 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
+++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
@@ -542,7 +542,7 @@ Return Value:
LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr);
size_t lhs_resize = 0;
- if(mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize))
+ if(MlasMultiplyOverflowsSizeT(LhsPackedStride, BatchSize, &lhs_resize))
{
// size_t wraparound detected for LhsPackedStride, fallback to MLAS
return false;
@@ -568,7 +568,7 @@ Return Value:
// Multithread pack lhs and rhs
RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K);
size_t rhs_resize = 0;
- if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize))
+ if (MlasMultiplyOverflowsSizeT(RhsPackedStride, BatchSize, &rhs_resize))
{
// size_t wraparound detected for RhsPackedStride, fallback to MLAS
return false;
@@ -616,7 +616,7 @@ Return Value:
// Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop.
// Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively.
size_t max_tile_elems = 0;
- if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) {
+ if (MlasMultiplyOverflowsSizeT(m_step, n_step, &max_tile_elems)) {
// size_t wraparound detected for tile size, fallback to MLAS
return false;
}
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index dbb414505ff38..b4ef99b1f17b2 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -119,6 +119,33 @@ Module Name:
#define MLAS_UNREFERENCED_PARAMETER(parameter) ((void)(parameter))
+//
+// Reports whether multiplying two size_t values would overflow.
+//
+
+MLAS_FORCEINLINE
+bool
+MlasMultiplyOverflowsSizeT(
+ size_t a,
+ size_t b,
+ size_t* out
+ )
+{
+#if defined(__has_builtin)
+#if __has_builtin(__builtin_mul_overflow)
+ size_t result;
+ return __builtin_mul_overflow(a, b, out != nullptr ? out : &result);
+#endif
+#endif
+ if (b != 0 && a > std::numeric_limits::max() / b) {
+ return true;
+ }
+ if (out != nullptr) {
+ *out = a * b;
+ }
+ return false;
+}
+
#ifdef MLAS_NO_EXCEPTION
MLAS_FORCEINLINE void
diff --git a/onnxruntime/core/mlas/lib/qgemm_fp8.cpp b/onnxruntime/core/mlas/lib/qgemm_fp8.cpp
new file mode 100644
index 0000000000000..fa0115db89d5b
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/qgemm_fp8.cpp
@@ -0,0 +1,164 @@
+// Copyright (c) 2026 Arm Limited. All rights reserved.
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+// SPDX-License-Identifier: MIT
+
+#include "mlasi.h"
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+#include "core/common/common.h"
+#include "core/common/fp8_common.h"
+
+namespace {
+
+// Computes the parallel work item count without wrapping before ptrdiff_t narrowing.
+inline ptrdiff_t CheckedWorkItems(size_t batch_count, size_t m) {
+ size_t work_items = 0;
+ ORT_ENFORCE(!MlasMultiplyOverflowsSizeT(batch_count, m, &work_items), "FP8 GEMM work item count overflow.");
+ ORT_ENFORCE(work_items <= static_cast(std::numeric_limits::max()),
+ "FP8 GEMM work item count exceeds ptrdiff_t range.");
+ return static_cast(work_items);
+}
+
+// Validates the largest row-major strided matrix offset used by the GEMM loops.
+inline void CheckStridedMatrixOffset(size_t rows, size_t cols, size_t row_stride) {
+ if (rows == 0 || cols == 0) {
+ return;
+ }
+
+ size_t offset = 0;
+ ORT_ENFORCE(!MlasMultiplyOverflowsSizeT(rows - 1, row_stride, &offset),
+ "FP8 GEMM matrix offset overflow.");
+ ORT_ENFORCE(cols - 1 <= std::numeric_limits::max() - offset,
+ "FP8 GEMM matrix offset overflow.");
+}
+
+// Validates the largest block-scale or zero-point offset used by the GEMM loops.
+inline void CheckBlockedMatrixOffset(size_t blocks0, size_t stride0, size_t blocks1, size_t stride1) {
+ if (blocks0 == 0 || blocks1 == 0) {
+ return;
+ }
+
+ size_t offset0 = 0;
+ ORT_ENFORCE(!MlasMultiplyOverflowsSizeT(blocks0 - 1, stride0, &offset0),
+ "FP8 GEMM block offset overflow.");
+ size_t offset1 = 0;
+ ORT_ENFORCE(!MlasMultiplyOverflowsSizeT(blocks1 - 1, stride1, &offset1),
+ "FP8 GEMM block offset overflow.");
+ ORT_ENFORCE(offset1 <= std::numeric_limits::max() - offset0,
+ "FP8 GEMM block offset overflow.");
+}
+
+// Validate caller-provided buffers, strides, and block metadata before parallel workers dereference them.
+inline void CheckFp8GemmBatchParams(
+ const MLAS_FP8_GEMM_SHAPE_PARAMS& shape,
+ const MLAS_FP8_GEMM_DATA_PARAMS& params) {
+ ORT_ENFORCE(onnxruntime::IsValidFp8Mode(params.Fp8Type), "FP8 GEMM mode must be valid.");
+ ORT_ENFORCE(params.BlockSizeM != 0 && params.BlockSizeK != 0 && params.BlockSizeN != 0,
+ "FP8 GEMM block sizes must be non-zero.");
+
+ const bool writes_output = shape.M != 0 && shape.N != 0;
+ const bool reads_reduction_data = writes_output && shape.K != 0;
+
+ // Empty-output GEMMs do not dereference C, and empty reductions do not read A/B.
+ if (reads_reduction_data) {
+ ORT_ENFORCE(params.A != nullptr, "FP8 GEMM A buffer must not be null.");
+ ORT_ENFORCE(params.B != nullptr, "FP8 GEMM B buffer must not be null.");
+ ORT_ENFORCE(params.lda >= shape.K, "FP8 GEMM lda must be greater than or equal to K.");
+ ORT_ENFORCE(params.ldb >= shape.N, "FP8 GEMM ldb must be greater than or equal to N.");
+
+ CheckStridedMatrixOffset(shape.M, shape.K, params.lda);
+ CheckStridedMatrixOffset(shape.K, shape.N, params.ldb);
+ }
+
+ if (writes_output) {
+ ORT_ENFORCE(params.C != nullptr, "FP8 GEMM C buffer must not be null.");
+ ORT_ENFORCE(params.ldc >= shape.N, "FP8 GEMM ldc must be greater than or equal to N.");
+
+ CheckStridedMatrixOffset(shape.M, shape.N, params.ldc);
+ }
+
+ const size_t blocks_m = shape.M == 0 ? 0 : ((shape.M - 1) / params.BlockSizeM) + 1;
+ const size_t blocks_k = shape.K == 0 ? 0 : ((shape.K - 1) / params.BlockSizeK) + 1;
+ const size_t blocks_n = shape.N == 0 ? 0 : ((shape.N - 1) / params.BlockSizeN) + 1;
+
+ if (reads_reduction_data && params.ScaleA != nullptr) {
+ ORT_ENFORCE(blocks_m == params.BlocksM, "FP8 GEMM M block count must match shape and block size.");
+ ORT_ENFORCE(blocks_k == params.BlocksK, "FP8 GEMM K block count must match shape and block size.");
+ CheckBlockedMatrixOffset(blocks_m, params.ScaleAStrideM, blocks_k, params.ScaleAStrideK);
+ }
+ if (reads_reduction_data && params.ScaleB != nullptr) {
+ ORT_ENFORCE(blocks_k == params.BlocksK, "FP8 GEMM K block count must match shape and block size.");
+ ORT_ENFORCE(blocks_n == params.BlocksN, "FP8 GEMM N block count must match shape and block size.");
+ CheckBlockedMatrixOffset(blocks_k, params.ScaleBStrideK, blocks_n, params.ScaleBStrideN);
+ }
+}
+
+} // namespace
+
+void
+MLASCALL
+MlasFp8GemmBatch(
+ const MLAS_FP8_GEMM_SHAPE_PARAMS& Shape,
+ const MLAS_FP8_GEMM_DATA_PARAMS* DataParams,
+ const size_t BatchN,
+ MLAS_THREADPOOL* ThreadPool
+ )
+{
+ const size_t M = Shape.M;
+ const size_t N = Shape.N;
+ const size_t K = Shape.K;
+
+ if (BatchN == 0 || M == 0 || N == 0) {
+ return;
+ }
+
+ const ptrdiff_t WorkItems = CheckedWorkItems(BatchN, M);
+
+ ORT_ENFORCE(DataParams != nullptr, "FP8 GEMM data parameters must not be null.");
+
+ for (size_t batch = 0; batch < BatchN; ++batch) {
+ CheckFp8GemmBatchParams(Shape, DataParams[batch]);
+ }
+
+ MlasTrySimpleParallel(ThreadPool, WorkItems, [&](ptrdiff_t tid) {
+ const size_t batch = static_cast(tid) / M;
+ const size_t m = static_cast(tid) % M;
+ const auto& params = DataParams[batch];
+ ORT_ENFORCE(onnxruntime::IsValidFp8Mode(params.Fp8Type), "FP8 GEMM mode must be valid.");
+
+ const auto* a_fp8 = static_cast(params.A);
+ const auto* b_fp8 = static_cast(params.B);
+ auto* c = static_cast(params.C);
+ const auto* scale_a = params.ScaleA;
+ const auto* scale_b = params.ScaleB;
+
+ const size_t block_m = m / params.BlockSizeM;
+ for (size_t n = 0; n < N; ++n) {
+ const size_t block_n = n / params.BlockSizeN;
+ float acc = 0.0f;
+ for (size_t k = 0; k < K; ++k) {
+ const size_t block_k = k / params.BlockSizeK;
+
+ const size_t a_scale_idx = block_m * params.ScaleAStrideM + block_k * params.ScaleAStrideK;
+ const size_t b_scale_idx = block_k * params.ScaleBStrideK + block_n * params.ScaleBStrideN;
+ const float scale_a_val = scale_a ? scale_a[a_scale_idx] : 1.0f;
+ const float scale_b_val = scale_b ? scale_b[b_scale_idx] : 1.0f;
+
+ const float a_val = onnxruntime::Fp8ByteToFloat(a_fp8[m * params.lda + k], params.Fp8Type);
+ const float b_val = onnxruntime::Fp8ByteToFloat(b_fp8[k * params.ldb + n], params.Fp8Type);
+
+ const float a_deq = a_val * scale_a_val;
+ const float b_deq = b_val * scale_b_val;
+ acc += a_deq * b_deq;
+ }
+
+ if (params.ScaleY != nullptr) {
+ acc *= params.ScaleY[0];
+ }
+ c[m * params.ldc + n] = acc;
+ }
+ });
+}
+
+#endif // !defined(DISABLE_FLOAT8_TYPES)
diff --git a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc
new file mode 100644
index 0000000000000..763660460d855
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc
@@ -0,0 +1,866 @@
+// Copyright (c) 2026 Arm Limited. All rights reserved.
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+// SPDX-License-Identifier: MIT
+
+#include
+#include
+#include
+#include
+
+#include "gtest/gtest.h"
+#include "core/session/inference_session.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/util/include/test_environment.h"
+#include "test/unittest_util/conversion.h"
+#include "default_providers.h"
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+#include "core/common/float8.h"
+
+namespace onnxruntime {
+namespace test {
+
+class DynamicQuantMatMulFp8SessionTester : public OpTester {
+ public:
+ using BaseTester::ExecuteModel;
+ using BaseTester::FillFeedsAndOutputNames;
+ using BaseTester::SetTestFunctionCalled;
+ using OpTester::BuildModel;
+ using OpTester::OpTester;
+};
+
+template
+struct Fp8TensorProtoType;
+
+template <>
+struct Fp8TensorProtoType {
+ static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN;
+};
+
+template <>
+struct Fp8TensorProtoType {
+ static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ;
+};
+
+template <>
+struct Fp8TensorProtoType {
+ static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
+};
+
+template <>
+struct Fp8TensorProtoType {
+ static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
+};
+
+template
+float Fp8MaxAbs();
+
+template <>
+float Fp8MaxAbs() {
+ return 448.0f;
+}
+
+template <>
+float Fp8MaxAbs() {
+ return 448.0f;
+}
+
+template <>
+float Fp8MaxAbs() {
+ return 57344.0f;
+}
+
+template <>
+float Fp8MaxAbs() {
+ return 57344.0f;
+}
+
+template
+float QuantizeDequantize(float value, float scale) {
+ return Fp8T(value / scale, true).ToFloat() * scale;
+}
+
+template
+std::vector