diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 554ff5a1bf863..76b81686161e4 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -26,6 +26,7 @@ Do not modify directly.*
* com.microsoft.DequantizeBFP
* com.microsoft.DequantizeLinear
* com.microsoft.DequantizeWithOrder
+ * com.microsoft.DynamicQuantMatMulFp8
* com.microsoft.DynamicQuantizeLSTM
* com.microsoft.DynamicQuantizeMatMul
* com.microsoft.DynamicTimeWarping
@@ -1495,6 +1496,65 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.DynamicQuantMatMulFp8**
+
+ Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from float16/bfloat16/float) and dynamic runtime quantization of activations to fp8 using internally computed block-wise scales. All zero-point inputs, when provided, must encode 0.0. Optional trailing inputs may be omitted, but intermediate optional inputs must use an empty input name to keep later input positions.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- block_size_k : int
+- Block size along K for A and B block-wise scales.
+- block_size_n : int
+- Block size along N for B block-wise scales.
+- fp8_type : int
+- FP8 TensorProto data type used when non-FP8 constant B is dynamically quantized during prepack. Defaults to FLOAT8E4M3FN.
+
+
+#### Inputs (2 - 6)
+
+
+- A : TA
+- Input tensor A.
+- B : TB
+- Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only supported when B is a constant initializer that can be quantized during prepack.
+- B_scale (optional) : TS
+- Scale of FP8 input 'B'. Must be a block-wise tensor with shape (N / block_size_n, K / block_size_k). Required when B is already FP8. Ignored for non-FP8 constant B, where scales are computed during prepack.
+- B_zero_point (optional) : TZ
+- Zero point tensor for input 'B'. Must have the same shape as B_scale and all values must encode 0.0.
+- Y_scale (optional) : TS
+- Scale of output 'Y'. Must be a scalar when provided.
+- Y_zero_point (optional) : TZ
+- Zero point tensor for output 'Y'. Must be a scalar encoding 0.0 when provided. May be provided without Y_scale; only Y_scale changes the floating-point output values.
+
+
+#### Outputs
+
+
+- Y : TY
+- Output tensor of shape (..., M, N).
+
+
+#### Type Constraints
+
+
+- TA : tensor(float16), tensor(bfloat16), tensor(float)
+- Constrain input A type to float16, bfloat16, or float.
+- TB : tensor(float16), tensor(bfloat16), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
+- Constrain input B type to fp8, or to float16, bfloat16, or float for constant initializers.
+- TZ : tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
+- Constrain zero point types to fp8. Only zero-valued zero points are supported.
+- TS : tensor(float), tensor(float16), tensor(bfloat16)
+- Constrain scale types to float, float16, or bfloat16.
+- TY : tensor(float16), tensor(bfloat16), tensor(float)
+- Constrain output type to float16, bfloat16, or float.
+
+
+
### **com.microsoft.DynamicQuantizeLSTM**
#### Version
@@ -6883,5 +6943,3 @@ No versioning maintained for experimental ops.
T : tensor(float)
Constrain input and output types to float32 tensors.
-
-
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index d0d8e750285d4..969083aba3aa0 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -577,6 +577,7 @@ The **OpSet Version** column uses the following notation:
|CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)|
|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**T** = tensor(float)|
|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float)|
+|DynamicQuantMatMulFp8|*in* A:**TA**
*in* B:**TB**
*in* B_scale:**TS**
*in* B_zero_point:**TZ**
*in* Y_scale:**TS**
*in* Y_zero_point:**TZ**
*out* Y:**TY**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**TS** = tensor(bfloat16), tensor(float), tensor(float16)
**TY** = tensor(bfloat16), tensor(float), tensor(float16)
**TZ** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)|
|DynamicQuantizeLSTM|*in* X:**T**
*in* W:**T2**
*in* R:**T2**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* W_scale:**T**
*in* W_zero_point:**T2**
*in* R_scale:**T**
*in* R_zero_point:**T2**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(float)
**T1** = tensor(int32)
**T2** = tensor(int8), tensor(uint8)|
|DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)|
|DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)|
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 0749457f5a182..b581064b180a2 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -116,6 +116,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, NhwcMaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, NhwcMaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization);
+#if !defined(DISABLE_FLOAT8_TYPES)
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicQuantMatMulFp8);  // CPU kernel, com.microsoft domain, since version 1; declared only when FP8 types are enabled
+#endif  // !defined(DISABLE_FLOAT8_TYPES)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE);
@@ -284,6 +287,9 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+#if !defined(DISABLE_FLOAT8_TYPES)
+ BuildKernelCreateInfo,
+#endif
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc
new file mode 100644
index 0000000000000..f6de3c24bb76a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc
@@ -0,0 +1,955 @@
+// Copyright (c) 2026 Arm Limited. All rights reserved.
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+// SPDX-License-Identifier: MIT
+
+#include "contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "core/common/common.h"
+#include "core/common/float16.h"
+#include "core/common/float8.h"
+#include "core/common/safeint.h"
+#include "core/framework/op_kernel.h"
+#include "core/graph/onnx_protobuf.h"
+#include "core/platform/threadpool.h"
+#include "core/providers/cpu/math/matmul_helper.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+namespace {
+
+constexpr int64_t kDefaultBlockSize = 128;  // fallback for the block_size_k / block_size_n attributes
+constexpr int64_t kPackedBMetadataVersion = 1;  // written to metadata slot 0; restore rejects any other value
+constexpr size_t kPackedBMetadataElementCount = 6;  // number of int64_t slots, one per PackedBMetadataIndex entry
+constexpr size_t kPackedBMetadataSize = kPackedBMetadataElementCount * sizeof(int64_t);  // metadata buffer size in bytes
+
+enum PackedBMetadataIndex : size_t {  // slot positions inside the int64_t metadata buffer stored next to prepacked B
+ kPackedBMetadataVersionIndex = 0,  // layout version, compared against kPackedBMetadataVersion on restore
+ kPackedBMetadataRowsIndex,  // first dimension of the prepacked B shape
+ kPackedBMetadataColsIndex,  // second dimension of the prepacked B shape
+ kPackedBMetadataSizeIndex,  // size of the quantized B buffer; restore requires it to equal rows * cols
+ kPackedBMetadataScaleCountIndex,  // number of float scales in the B scale buffer
+ kPackedBMetadataFp8ModeIndex,  // Fp8Mode of the quantized B, stored as an integer
+};
+
+bool IsFp8DataType(ONNX_NAMESPACE::TensorProto_DataType elem_type) {  // true only for the four FLOAT8 tensor element types listed below
+ return elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN ||
+ elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ ||
+ elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2 ||
+ elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
+}
+
+Status RestorePackedBMetadata(const void* metadata_buffer,  // in: metadata slots laid out per PackedBMetadataIndex
+ size_t metadata_size,  // in: byte size of metadata_buffer; must equal kPackedBMetadataSize
+ size_t quantized_b_buffer_size,  // in: byte size of the shared quantized B buffer
+ size_t b_scale_buffer_size,  // in: byte size of the shared B scale buffer
+ TensorShape& b_shape,  // out: {rows, cols} taken from the metadata
+ size_t& quantized_b_size,  // out: quantized B size taken from the metadata
+ size_t& b_scale_count,  // out: number of float scales taken from the metadata
+ Fp8Mode& b_type,  // out: FP8 mode taken from the metadata
+ bool& has_b_type) {  // out: set to true on success
+ ORT_RETURN_IF(metadata_buffer == nullptr,
+ "DynamicQuantMatMulFp8 requires shared prepacked B metadata.");
+ ORT_RETURN_IF(metadata_size != kPackedBMetadataSize,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata has an unexpected size.");
+
+ const auto* metadata = static_cast(metadata_buffer);  // NOTE(review): cast target type is not visible in this view; slots are indexed as int64_t values
+ ORT_RETURN_IF(metadata[kPackedBMetadataVersionIndex] != kPackedBMetadataVersion,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata has an unsupported version.");
+ ORT_RETURN_IF(metadata[kPackedBMetadataRowsIndex] <= 0 || metadata[kPackedBMetadataColsIndex] <= 0,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B shape.");
+ ORT_RETURN_IF(metadata[kPackedBMetadataSizeIndex] <= 0,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B buffer size.");
+ ORT_RETURN_IF(metadata[kPackedBMetadataScaleCountIndex] <= 0,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B scale count.");
+ const int64_t restored_fp8_mode = metadata[kPackedBMetadataFp8ModeIndex];
+ ORT_RETURN_IF(restored_fp8_mode < static_cast(Fp8Mode::E4M3Inf) ||
+ restored_fp8_mode >= static_cast(Fp8Mode::End),
+ "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid FP8 type.");  // valid modes lie in [E4M3Inf, End)
+
+ const size_t rows = static_cast(metadata[kPackedBMetadataRowsIndex]);  // positive, checked above
+ const size_t cols = static_cast(metadata[kPackedBMetadataColsIndex]);  // positive, checked above
+ const size_t expected_quantized_b_size = SafeMul(rows, cols);  // quantized B holds one byte per element
+ const size_t restored_quantized_b_size = static_cast(metadata[kPackedBMetadataSizeIndex]);
+ ORT_RETURN_IF(restored_quantized_b_size != expected_quantized_b_size ||
+ restored_quantized_b_size != quantized_b_buffer_size,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata does not match the B buffer size.");  // metadata, shape and actual buffer must all agree
+ const size_t restored_b_scale_count = static_cast(metadata[kPackedBMetadataScaleCountIndex]);
+ ORT_RETURN_IF(restored_b_scale_count > std::numeric_limits::max() / sizeof(float) ||
+ restored_b_scale_count * sizeof(float) != b_scale_buffer_size,
+ "DynamicQuantMatMulFp8 shared prepacked B metadata does not match the B scale buffer size.");  // first operand guards the multiplication against overflow
+
+ b_shape = TensorShape({metadata[kPackedBMetadataRowsIndex], metadata[kPackedBMetadataColsIndex]});  // outputs are written only after every check has passed
+ quantized_b_size = restored_quantized_b_size;
+ b_scale_count = restored_b_scale_count;
+ b_type = static_cast(metadata[kPackedBMetadataFp8ModeIndex]);  // same slot that was range-checked as restored_fp8_mode
+ has_b_type = true;
+ return Status::OK();
+}
+
+// Returns an error unless every element of `scale` is finite and > 0, so quantization never divides by a bad scale and the GEMM never dequantizes with one.
+template
+Status ValidatePositiveFiniteScaleTensorImpl(const Tensor& scale, const char* scale_name) {  // scale_name is only used to build the error message
+ const auto* data = scale.Data();
+ const size_t count = static_cast(scale.Shape().Size());
+
+ for (size_t i = 0; i < count; ++i) {
+ const float value = static_cast(data[i]);  // widen to float so one check covers every scale element type
+ ORT_RETURN_IF(!std::isfinite(value) || value <= 0.0f,
+ "DynamicQuantMatMulFp8 requires ", scale_name, " values to be finite and positive.");  // rejects NaN, +/-inf, zero and negatives
+ }
+
+ return Status::OK();
+}
+
+Status ValidatePositiveFiniteScaleTensor(const Tensor& scale, const char* scale_name) {  // dispatches on the scale element type to the typed validator
+ if (scale.IsDataType()) {  // NOTE(review): template arguments are not visible in this view; presumably float, float16, bfloat16 per the error text below
+ return ValidatePositiveFiniteScaleTensorImpl(scale, scale_name);
+ }
+
+ if (scale.IsDataType()) {
+ return ValidatePositiveFiniteScaleTensorImpl(scale, scale_name);
+ }
+
+ if (scale.IsDataType()) {
+ return ValidatePositiveFiniteScaleTensorImpl(scale, scale_name);
+ }
+
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires ", scale_name,
+ " input to be float, float16, or bfloat16.");  // reached only when no supported element type matched
+}
+
+uint8_t FloatToFp8Byte(float value, Fp8Mode mode) {  // encodes `value` in the FP8 format selected by `mode` and returns the raw byte
+ switch (mode) {
+ case Fp8Mode::E4M3Inf:
+ return Float8E4M3FN(value, true).val;  // NOTE(review): the `true` argument presumably requests saturation - confirm in core/common/float8.h
+ case Fp8Mode::E4M3Sat:
+ return Float8E4M3FNUZ(value, true).val;
+ case Fp8Mode::E5M2Inf:
+ return Float8E5M2(value, true).val;
+ case Fp8Mode::E5M2Sat:
+ return Float8E5M2FNUZ(value, true).val;
+ default:
+ ORT_THROW("Unsupported FP8 mode.");  // throws instead of returning a Status because the return type is the encoded byte
+ }
+}
+
+float Fp8ByteToFloat(uint8_t value, Fp8Mode mode) {  // decodes a raw FP8 byte in the format selected by `mode`; same mode-to-type mapping as FloatToFp8Byte
+ switch (mode) {
+ case Fp8Mode::E4M3Inf:
+ return static_cast(Float8E4M3FN(value, Float8E4M3FN::FromBits()));  // FromBits() tag: `value` is a bit pattern, not a number to convert
+ case Fp8Mode::E4M3Sat:
+ return static_cast(Float8E4M3FNUZ(value, Float8E4M3FNUZ::FromBits()));
+ case Fp8Mode::E5M2Inf:
+ return static_cast(Float8E5M2(value, Float8E5M2::FromBits()));
+ case Fp8Mode::E5M2Sat:
+ return static_cast(Float8E5M2FNUZ(value, Float8E5M2FNUZ::FromBits()));
+ default:
+ ORT_THROW("Unsupported FP8 mode.");  // throws instead of returning a Status because the return type is the decoded float
+ }
+}
+
+Status GetFp8MaxAbs(Fp8Mode mode, float& max_abs) {  // writes the largest finite magnitude representable in `mode` to max_abs
+ switch (mode) {
+ case Fp8Mode::E4M3Inf:  // FLOAT8E4M3FN in this file: largest finite value is 448
+ case Fp8Mode::E4M3Sat:  // FLOAT8E4M3FNUZ in this file: largest finite value is 240
+ max_abs = (mode == Fp8Mode::E4M3Sat) ? 240.0f : 448.0f;  // a 448 bound for FNUZ would scale block maxima past 240, where they get clipped
+ return Status::OK();
+ case Fp8Mode::E5M2Inf:  // FLOAT8E5M2
+ case Fp8Mode::E5M2Sat:  // FLOAT8E5M2FNUZ; both E5M2 variants share the same largest finite value
+ max_abs = 57344.0f;
+ return Status::OK();
+ default:
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Unsupported fp8 mode for DynamicQuantMatMulFp8.");  // max_abs is left untouched on error
+ }
+}
+
+Status ValidateZeroPointValuesAreZero(const Tensor& zero_point, size_t expected_count,  // checks element count, then that every value decodes to zero
+ const char* zero_point_name) {  // zero_point_name is only used in error messages
+ const size_t actual_count = static_cast(zero_point.Shape().Size());
+ ORT_RETURN_IF(actual_count != expected_count,
+ "DynamicQuantMatMulFp8 requires ", zero_point_name, " to have the expected number of elements.");
+
+ const auto reject_non_zero = [zero_point_name](float value) {  // shared per-element check used by every type branch below
+ ORT_RETURN_IF(value != 0.0f,
+ "DynamicQuantMatMulFp8 supports symmetric quantization only; ",
+ zero_point_name, " values must be zero.");  // NaN compares unequal and is rejected; -0.0f compares equal and is accepted
+ return Status::OK();
+ };
+
+ if (zero_point.IsDataType()) {  // NOTE(review): template arguments are not visible in this view; this branch decodes bytes as Fp8Mode::E4M3Inf
+ const auto* zp = static_cast(zero_point.DataRaw());
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], Fp8Mode::E4M3Inf)));
+ }
+ } else if (zero_point.IsDataType()) {  // this branch decodes bytes as Fp8Mode::E4M3Sat
+ const auto* zp = static_cast(zero_point.DataRaw());
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], Fp8Mode::E4M3Sat)));
+ }
+ } else if (zero_point.IsDataType()) {  // this branch decodes bytes as Fp8Mode::E5M2Inf
+ const auto* zp = static_cast(zero_point.DataRaw());
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], Fp8Mode::E5M2Inf)));
+ }
+ } else if (zero_point.IsDataType()) {  // this branch decodes bytes as Fp8Mode::E5M2Sat
+ const auto* zp = static_cast(zero_point.DataRaw());
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], Fp8Mode::E5M2Sat)));
+ }
+ } else if (zero_point.IsDataType()) {  // values are passed to the check without a cast - presumably the float branch
+ const auto* zp = zero_point.Data();
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(zp[i]));
+ }
+ } else if (zero_point.IsDataType()) {  // values are cast before the check - presumably a 16-bit float type; confirm
+ const auto* zp = zero_point.Data();
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(static_cast(zp[i])));
+ }
+ } else if (zero_point.IsDataType()) {  // values are cast before the check - presumably the other 16-bit float type; confirm
+ const auto* zp = zero_point.Data();
+ for (size_t i = 0; i < actual_count; ++i) {
+ ORT_RETURN_IF_ERROR(reject_non_zero(static_cast(zp[i])));
+ }
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Unsupported zero point type for DynamicQuantMatMulFp8.");
+ }
+ return Status::OK();
+}
+
+template
+void QuantizeBlockwiseFp8ABlockDynamic(const SrcT* src,  // in: A data, indexed as row * K + k
+ size_t K,  // in: row length of src and dst
+ size_t block_size_k,  // in: number of K elements covered by one scale
+ size_t blocks_k,  // in: number of scales stored per row
+ size_t row,  // in: row to quantize
+ size_t block_k,  // in: K block to quantize within that row
+ float fp8_max_abs,  // in: magnitude the block maximum is mapped onto
+ Fp8Mode mode,  // in: FP8 format used for encoding
+ uint8_t* dst,  // out: FP8 bytes, same layout as src
+ float* scales) {  // out: one scale written at row * blocks_k + block_k
+ const size_t k_begin = block_k * block_size_k;
+ const size_t k_end = std::min(K, k_begin + block_size_k);  // clamp so a trailing partial block stays inside the row
+ const size_t idx = row * blocks_k + block_k;
+
+ // Use one dynamic quantization scale per A row and K block. The scale depends on
+ // the block max_abs, so this reference path scans first and quantizes in a second pass.
+ float max_abs = 0.0f;
+ const size_t row_offset = row * K;
+ for (size_t k = k_begin; k < k_end; ++k) {
+ max_abs = std::max(max_abs, std::fabs(static_cast(src[row_offset + k])));
+ }
+
+ const float scale = max_abs == 0.0f ? 1.0f : max_abs / fp8_max_abs;  // an all-zero block gets scale 1 so the division below is safe
+ scales[idx] = scale;
+
+ for (size_t k = k_begin; k < k_end; ++k) {
+ const float value = static_cast(src[row_offset + k]);
+ const float quantized = value / scale;
+ dst[row_offset + k] = FloatToFp8Byte(quantized, mode);
+ }
+}
+
+template
+void QuantizeBlockwiseFp8WithScales(const SrcT* src,  // in: B data, indexed as k * N + n
+ size_t K,  // in: number of rows of B
+ size_t N,  // in: number of columns of B
+ size_t block_size_k,  // in: K extent of one scale block
+ size_t block_size_n,  // in: N extent of one scale block
+ const float* scales,  // in: block scales, indexed as block_n * blocks_k + block_k
+ uint8_t* dst) {  // out: FP8 bytes, same layout as src
+ // Block sizes come from op attributes; scale shapes only provide the number of blocks.
+ const size_t blocks_k = K / block_size_k;  // integer division; the visible PrePack caller checks K % block_size_k == 0 first
+ for (size_t k = 0; k < K; ++k) {
+ const size_t block_k = k / block_size_k;
+ const size_t row_offset = k * N;
+ for (size_t n = 0; n < N; ++n) {
+ const size_t block_n = n / block_size_n;
+ const size_t scale_idx = block_n * blocks_k + block_k;
+ const float scale = scales[scale_idx];
+ const float value = static_cast(src[row_offset + n]);
+ const float quantized = value / scale;  // scales are not validated here; callers must supply non-zero values
+ const Fp8T fp8_value(quantized, true);  // NOTE(review): `true` presumably requests saturation - confirm in core/common/float8.h
+ dst[row_offset + n] = fp8_value.val;
+ }
+ }
+}
+
+template
+void ComputeBlockwiseScalesFromInput(const SrcT* src,  // in: B data, indexed as k * N + n
+ size_t K,  // in: number of rows of B
+ size_t N,  // in: number of columns of B
+ size_t block_size_k,  // in: K extent of one scale block
+ size_t block_size_n,  // in: N extent of one scale block
+ float fp8_max_abs,  // in: magnitude each block maximum is mapped onto
+ float* scales) {  // out: blocks_n * blocks_k scales, indexed as block_n * blocks_k + block_k
+ // Reference-style dynamic quantization: derive one positive scale from each source block.
+ const size_t blocks_k = K / block_size_k;  // integer division: elements beyond the last whole block are not scanned
+ const size_t blocks_n = N / block_size_n;
+ for (size_t block_k = 0; block_k < blocks_k; ++block_k) {
+ const size_t k_begin = block_k * block_size_k;
+ const size_t k_end = k_begin + block_size_k;  // no clamp to K: only whole blocks are visited
+ for (size_t block_n = 0; block_n < blocks_n; ++block_n) {
+ const size_t n_begin = block_n * block_size_n;
+ const size_t n_end = n_begin + block_size_n;
+ float max_abs = 0.0f;
+ for (size_t k = k_begin; k < k_end; ++k) {
+ const size_t row_offset = k * N;
+ for (size_t n = n_begin; n < n_end; ++n) {
+ max_abs = std::max(max_abs, std::fabs(static_cast(src[row_offset + n])));
+ }
+ }
+ // Use one scale per N block and K block. Quantization runs after all block
+ // scales are known, so this reference path intentionally reads B in two phases.
+ scales[block_n * blocks_k + block_k] = max_abs == 0.0f ? 1.0f : max_abs / fp8_max_abs;  // an all-zero block gets scale 1
+ }
+ }
+}
+
+template
+Status QuantizeToFp8ByModeWithScales(Fp8Mode fp8_mode,  // in: selects which FP8 type the bytes are encoded as
+ const SrcT* src,
+ size_t K,
+ size_t N,
+ size_t block_size_k,
+ size_t block_size_n,
+ const float* scales,
+ uint8_t* dst) {  // remaining arguments are forwarded unchanged to QuantizeBlockwiseFp8WithScales
+ // Dispatch quantization using the requested FP8 mode and runtime block sizes.
+ switch (fp8_mode) {
+ case Fp8Mode::E4M3Inf:
+ QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst);  // NOTE(review): the explicit FP8 template argument is not visible in this view
+ return Status::OK();
+ case Fp8Mode::E4M3Sat:
+ QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst);
+ return Status::OK();
+ case Fp8Mode::E5M2Inf:
+ QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst);
+ return Status::OK();
+ case Fp8Mode::E5M2Sat:
+ QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst);
+ return Status::OK();
+ default:
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Unsupported fp8 mode for DynamicQuantMatMulFp8.");  // dst is left untouched for an unknown mode
+ }
+}
+
+struct Fp8GemmShapeParams {  // problem shape shared by every batch entry passed to ReferenceFp8GemmBatch
+ size_t M;  // rows of A and C per batch entry
+ size_t N;  // columns of B and C
+ size_t K;  // reduction length
+ Fp8Mode Fp8Type;  // format used to decode both A and B bytes
+};
+
+struct Fp8GemmDataParams {  // per-batch pointers, strides and scale layout for ReferenceFp8GemmBatch
+ const void* A = nullptr;  // FP8 bytes, element (m, k) at m * lda + k
+ size_t lda = 0;
+ const void* B = nullptr;  // FP8 bytes, element (k, n) at k * ldb + n
+ size_t ldb = 0;
+ void* C = nullptr;  // output, element (m, n) at m * ldc + n
+ size_t ldc = 0;
+ const float* ScaleA = nullptr;  // optional; the GEMM uses 1.0f when null
+ const float* ScaleB = nullptr;  // optional; the GEMM uses 1.0f when null
+ const float* ScaleY = nullptr;  // optional; only element 0 is read and multiplied into each output
+ size_t BlockSizeM = 0;  // the GEMM divides by the block sizes, so they must be set to non-zero values
+ size_t BlockSizeK = 0;
+ size_t BlockSizeN = 0;
+ size_t ScaleAStrideK = 0;  // ScaleA index = block_m * ScaleAStrideM + block_k * ScaleAStrideK
+ size_t ScaleAStrideM = 0;
+ size_t ScaleBStrideK = 0;  // ScaleB index = block_k * ScaleBStrideK + block_n * ScaleBStrideN
+ size_t ScaleBStrideN = 0;
+};
+
+Status ReferenceFp8GemmBatch(const Fp8GemmShapeParams& shape,  // in: M, N, K and the FP8 format shared by all batch entries
+ const Fp8GemmDataParams* data_params,  // in: array of batch_count entries
+ size_t batch_count,
+ concurrency::ThreadPool* thread_pool) {  // in: pool handed to TryParallelFor
+ const size_t M = shape.M;
+ const size_t N = shape.N;
+ const size_t K = shape.K;
+
+ if (batch_count == 0 || M == 0 || N == 0) {  // nothing to write; also keeps the division by M below safe
+ return Status::OK();
+ }
+
+ size_t work_items_size = 0;
+ ORT_RETURN_IF(batch_count > std::numeric_limits::max() / M,
+ "DynamicQuantMatMulFp8 scalar GEMM work item count overflow.");
+ work_items_size = batch_count * M;  // one work item per (batch, output row)
+ ORT_RETURN_IF(work_items_size > static_cast(std::numeric_limits::max()),
+ "DynamicQuantMatMulFp8 scalar GEMM work item count exceeds ptrdiff_t range.");
+ const auto work_items = static_cast(work_items_size);
+
+ const TensorOpCost unit_cost{  // per-row cost estimate built from K, N and K * N
+ static_cast(SafeMul(K, sizeof(uint8_t)) * 2),
+ static_cast(N * sizeof(float)),
+ static_cast(SafeMul(K, N) * 2)};
+ concurrency::ThreadPool::TryParallelFor(thread_pool, work_items, unit_cost,
+ [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ for (std::ptrdiff_t tid = begin; tid < end; ++tid) {
+ const size_t batch = static_cast(tid) / M;  // work item -> (batch, m)
+ const size_t m = static_cast(tid) % M;
+ const auto& params = data_params[batch];
+
+ const auto* a_fp8 = static_cast(params.A);
+ const auto* b_fp8 = static_cast(params.B);
+ auto* c = static_cast(params.C);  // NOTE(review): cast target is not visible in this view; a float accumulator is stored through it
+ const auto* scale_a = params.ScaleA;
+ const auto* scale_b = params.ScaleB;
+
+ const size_t block_m = m / params.BlockSizeM;  // assumes BlockSizeM != 0 (struct default is 0) - confirm the caller sets it
+ for (size_t n = 0; n < N; ++n) {
+ const size_t block_n = n / params.BlockSizeN;  // assumes BlockSizeN != 0
+ float acc = 0.0f;
+ for (size_t k = 0; k < K; ++k) {
+ const size_t block_k = k / params.BlockSizeK;  // assumes BlockSizeK != 0
+ const size_t a_scale_idx =
+ block_m * params.ScaleAStrideM + block_k * params.ScaleAStrideK;
+ const size_t b_scale_idx =
+ block_k * params.ScaleBStrideK + block_n * params.ScaleBStrideN;
+ const float scale_a_val = scale_a ? scale_a[a_scale_idx] : 1.0f;  // missing scale array means "unscaled"
+ const float scale_b_val = scale_b ? scale_b[b_scale_idx] : 1.0f;
+ const float a_val =
+ Fp8ByteToFloat(a_fp8[m * params.lda + k], shape.Fp8Type);
+ const float b_val =
+ Fp8ByteToFloat(b_fp8[k * params.ldb + n], shape.Fp8Type);  // A and B are decoded with the same FP8 format
+ acc += (a_val * scale_a_val) * (b_val * scale_b_val);  // dequantize each operand, then accumulate in float
+ }
+
+ if (params.ScaleY != nullptr) {
+ acc *= params.ScaleY[0];  // the accumulated value is multiplied, not divided, by the Y scale
+ }
+ c[m * params.ldc + n] = acc;
+ }
+ }
+ });
+ return Status::OK();
+}
+
+} // namespace
+
+ONNX_OPERATOR_KERNEL_EX(  // registers the kernel for the op name, domain, version and provider listed below
+ DynamicQuantMatMulFp8,
+ kMSDomain,
+ 1,
+ kCpuExecutionProvider,
+ KernelDefBuilder()
+ .TypeConstraint("TA", std::vector{  // NOTE(review): GetTensorType template arguments are not visible in this view
+ DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType()})
+ .TypeConstraint("TB", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()})  // seven allowed B types
+ .TypeConstraint("TZ", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()})  // four allowed zero-point types
+ .TypeConstraint("TS", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()})  // three allowed scale types
+ .TypeConstraint("TY", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}),  // three allowed output types
+ DynamicQuantMatMulFp8);
+
+DynamicQuantMatMulFp8::DynamicQuantMatMulFp8(const OpKernelInfo& info) : OpKernel(info) {  // reads and validates the three attributes; throws on invalid values
+ const int64_t block_size_k = info.GetAttrOrDefault("block_size_k", kDefaultBlockSize);
+ const int64_t block_size_n = info.GetAttrOrDefault("block_size_n", kDefaultBlockSize);
+ const int64_t fp8_type =
+ info.GetAttrOrDefault("fp8_type", ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN);  // default FP8 type when the attribute is absent
+ ORT_ENFORCE(block_size_k > 0,
+ "DynamicQuantMatMulFp8 requires block_size_k to be greater than zero.");
+ ORT_ENFORCE(block_size_n > 0,
+ "DynamicQuantMatMulFp8 requires block_size_n to be greater than zero.");  // both are validated before the casts below, since they are later used as divisors
+ block_size_k_ = static_cast(block_size_k);
+ block_size_n_ = static_cast(block_size_n);
+ ORT_THROW_IF_ERROR(GetFp8Type(static_cast(fp8_type), fp8_type_));  // a non-FP8 fp8_type attribute makes construction throw
+}
+
+Status DynamicQuantMatMulFp8::GetFp8Type(const Tensor& tensor, Fp8Mode& out_type) {  // convenience overload: maps the tensor's element type to an Fp8Mode
+ return GetFp8Type(static_cast(tensor.GetElementType()), out_type);
+}
+
+Status DynamicQuantMatMulFp8::GetFp8Type(ONNX_NAMESPACE::TensorProto_DataType elem_type,  // in: ONNX element type to translate
+ Fp8Mode& out_type) {  // out: written only when elem_type is one of the four FLOAT8 types
+ switch (elem_type) {
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
+ out_type = Fp8Mode::E4M3Inf;
+ return Status::OK();
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:
+ out_type = Fp8Mode::E4M3Sat;  // note the naming: the "Sat" modes correspond to the FNUZ element types
+ return Status::OK();
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
+ out_type = Fp8Mode::E5M2Inf;
+ return Status::OK();
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:
+ out_type = Fp8Mode::E5M2Sat;
+ return Status::OK();
+ default:
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported fp8 type for DynamicQuantMatMulFp8.");  // any non-FLOAT8 element type
+ }
+}
+
+Status DynamicQuantMatMulFp8::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,  // validates constant scale/zero-point inputs and quantizes a constant non-FP8 2-D B
+ /*out*/ bool& is_packed,
+ /*out*/ PrePackedWeights* prepacked_weights) {
+ is_packed = false;  // stays false on every path except a successful quantization of B
+
+ const Tensor* constant_b = nullptr;
+ const bool model_b_scale_is_ignored =
+ OpKernel::Info().TryGetConstantInput(GetBIdx(), &constant_b) &&
+ !IsFp8DataType(static_cast(constant_b->GetElementType()));  // true when B is a constant input whose element type is not FP8
+
+ if (input_idx == IN_B_SCALE) {
+ // Non-FP8 constant B computes its own prepacked scales, so model B_scale values are not consumed.
+ if (!model_b_scale_is_ignored) {
+ ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScaleTensor(tensor, "B scale"));
+ constant_b_scale_values_validated_ = true;  // records that the constant B scale values passed validation here
+ }
+ return Status::OK();
+ }
+
+ if (input_idx == IN_B_ZERO_POINT) {
+ ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(
+ tensor, static_cast(tensor.Shape().Size()), "B zero point"));  // expected count is the tensor's own size, so only the values are checked here
+ constant_b_zero_point_values_validated_ = true;
+ return Status::OK();
+ }
+
+ if (input_idx != GetBIdx()) {  // no other input is prepacked
+ return Status::OK();
+ }
+
+ b_shape_ = tensor.Shape();  // recorded before the rank check, so it is updated even when B is not packed
+ if (b_shape_.NumDimensions() != 2) {  // only 2-D B is prepacked
+ return Status::OK();
+ }
+
+ const size_t K = static_cast(b_shape_[0]);
+ const size_t N = static_cast(b_shape_[1]);
+ const auto b_elem_type = static_cast(tensor.GetElementType());
+ const bool b_is_fp8 = IsFp8DataType(b_elem_type);
+ if (b_is_fp8) {  // FP8 B is not repacked: only its type is recorded and is_packed stays false
+ ORT_RETURN_IF_ERROR(GetFp8Type(tensor, b_type_));
+ has_b_type_ = true;
+ return Status::OK();
+ }
+
+ b_type_ = fp8_type_;  // non-FP8 B is quantized to the type chosen by the fp8_type attribute
+ has_b_type_ = true;
+ if (K == 0 || N == 0) {  // empty B: nothing to quantize, left unpacked
+ return Status::OK();
+ }
+
+ ORT_RETURN_IF_NOT(K % block_size_k_ == 0,
+ "DynamicQuantMatMulFp8 requires K to be divisible by block_size_k.");
+ ORT_RETURN_IF_NOT(N % block_size_n_ == 0,
+ "DynamicQuantMatMulFp8 requires N to be divisible by block_size_n.");  // whole blocks only, as the block-wise helpers below assume
+ const size_t blocks_k = K / block_size_k_;
+ const size_t blocks_n = N / block_size_n_;
+ b_scale_count_ = SafeMul(blocks_k, blocks_n);  // one scale per (N block, K block)
+ b_scales_ = IAllocator::MakeUniquePtr(alloc, b_scale_count_ * sizeof(float), true);
+ auto* prepacked_b_scales = static_cast(b_scales_.get());
+ float fp8_max_abs = 0.0f;
+ ORT_RETURN_IF_ERROR(GetFp8MaxAbs(b_type_, fp8_max_abs));
+
+ const size_t quantized_b_size = SafeMul(K, N);  // one FP8 byte per B element
+ quantized_b_ = IAllocator::MakeUniquePtr(alloc, quantized_b_size, true);
+ quantized_b_size_ = quantized_b_size;
+ auto* quantized_b_bytes = static_cast(quantized_b_.get());
+ if (tensor.IsDataType()) {  // NOTE(review): template arguments are not visible in this view; the three branches run the same two phases for different source types
+ ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_,
+ fp8_max_abs, prepacked_b_scales);  // phase 1: derive block scales from B
+ ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_,
+ prepacked_b_scales, quantized_b_bytes));  // phase 2: quantize B with those scales
+ } else if (tensor.IsDataType()) {
+ ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_,
+ fp8_max_abs, prepacked_b_scales);
+ ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_,
+ prepacked_b_scales, quantized_b_bytes));
+ } else if (tensor.IsDataType()) {
+ ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_,
+ fp8_max_abs, prepacked_b_scales);
+ ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_,
+ prepacked_b_scales, quantized_b_bytes));
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Unsupported B type for DynamicQuantMatMulFp8 prepack.");
+ }
+
+ if (prepacked_weights != nullptr) {  // hand the three buffers to the caller-supplied container
+ const std::array metadata_values = {  // order must match PackedBMetadataIndex
+ kPackedBMetadataVersion,
+ b_shape_[0],
+ b_shape_[1],
+ static_cast(quantized_b_size_),
+ static_cast(b_scale_count_),
+ static_cast(b_type_),
+ };
+ auto metadata = IAllocator::MakeUniquePtr(alloc, kPackedBMetadataSize, true);
+ std::memcpy(metadata.get(), metadata_values.data(), kPackedBMetadataSize);
+ prepacked_weights->buffers_.push_back(std::move(quantized_b_));  // buffer 0; quantized_b_ is moved-from afterwards
+ prepacked_weights->buffer_sizes_.push_back(quantized_b_size_);
+ prepacked_weights->buffers_.push_back(std::move(b_scales_));  // buffer 1; b_scales_ is moved-from afterwards
+ prepacked_weights->buffer_sizes_.push_back(b_scale_count_ * sizeof(float));
+ prepacked_weights->buffers_.push_back(std::move(metadata));  // buffer 2; presumably read back by UseSharedPrePackedBuffers - confirm with the framework
+ prepacked_weights->buffer_sizes_.push_back(kPackedBMetadataSize);
+ }
+ is_packed = true;
+ return Status::OK();
+}
+
+Status DynamicQuantMatMulFp8::UseSharedPrePackedBuffers(std::vector& prepacked_buffers,  // in/out: buffers 0 and 1 are moved out of this vector on success
+ gsl::span prepacked_buffer_sizes,  // in: byte size of each buffer, same order
+ int input_idx,
+ /*out*/ bool& used_shared_buffers) {
+ used_shared_buffers = false;
+ if (input_idx != GetBIdx()) {  // only input B has shared prepacked state
+ return Status::OK();
+ }
+
+ ORT_RETURN_IF(prepacked_buffers.size() != 3 || prepacked_buffer_sizes.size() != 3,
+ "DynamicQuantMatMulFp8 requires shared prepacked B data, scale, and metadata buffers.");
+ ORT_RETURN_IF(prepacked_buffers[0].get() == nullptr,
+ "DynamicQuantMatMulFp8 requires shared prepacked B data.");
+ ORT_RETURN_IF(prepacked_buffers[1].get() == nullptr,
+ "DynamicQuantMatMulFp8 requires shared prepacked B scales.");  // a null buffer 2 is rejected inside RestorePackedBMetadata
+
+ // Buffer 0 owns quantized B bytes; buffer 1 owns computed B scales; buffer 2 restores kernel state.
+ ORT_RETURN_IF_ERROR(RestorePackedBMetadata(prepacked_buffers[2].get(),
+ prepacked_buffer_sizes[2],
+ prepacked_buffer_sizes[0],
+ prepacked_buffer_sizes[1],
+ b_shape_,
+ quantized_b_size_,
+ b_scale_count_,
+ b_type_,
+ has_b_type_));  // on failure no member has been modified, because outputs are written only after all checks
+ quantized_b_ = std::move(prepacked_buffers[0]);  // take ownership only after the metadata has been validated
+ b_scales_ = std::move(prepacked_buffers[1]);
+ used_shared_buffers = true;
+ return Status::OK();
+}
+
+Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const {  // Validates inputs, quantizes A to FP8 per K block, then runs one FP8 GEMM per output batch into Y.
+ const Tensor* a = context->Input(IN_A);
+ const Tensor* b = quantized_b_ ? nullptr : context->Input(IN_B);  // B is not fetched when quantized_b_ already holds packed B bytes.
+ const Tensor* b_scale = context->Input(IN_B_SCALE);  // This and the inputs below are null-checked before use.
+ const Tensor* b_zero_point = context->Input(IN_B_ZERO_POINT);
+ const Tensor* y_scale = context->Input(IN_Y_SCALE);
+ const Tensor* y_zero_point = context->Input(IN_Y_ZERO_POINT);
+
+ // Runtime B uses one 2D B scale/zero-point layout, so reject batched B before MatMul broadcasts it.
+ ORT_RETURN_IF(!quantized_b_ && b->Shape().NumDimensions() != 2,
+ "DynamicQuantMatMulFp8 requires runtime B to be a 2D tensor.");
+
+ MatMulComputeHelper helper;  // Supplies M/N/K, the output shape, and the per-GEMM offsets used below.
+ ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(),
+ quantized_b_ ? b_shape_ : b->Shape(),  // Packed B uses the shape recorded in b_shape_.
+ nullptr,
+ nullptr));
+
+ const size_t M = static_cast(helper.M());
+ const size_t N = static_cast(helper.N());
+ const size_t K = static_cast(helper.K());
+ Tensor* y = context->Output(OUT_Y, helper.OutputShape());  // Y takes the helper's output shape.
+ const size_t y_size = static_cast(y->Shape().Size());
+ ORT_RETURN_IF(!y->IsDataType() && !y->IsDataType() && !y->IsDataType(),
+ "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16.");
+
+ if (y_zero_point != nullptr) {
+ // Runtime tensors must match the schema scalar contract before reading element 0.
+ ORT_RETURN_IF(y_zero_point->Shape().NumDimensions() != 0 || y_zero_point->Shape().Size() != 1,
+ "DynamicQuantMatMulFp8 requires Y zero point input to be a scalar.");
+ ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*y_zero_point, 1, "Y zero point"));
+ }
+
+ float y_scale_storage = 0.0f;  // Holds the converted scalar when Y scale has to be cast to float first.
+ const float* y_scale_data = nullptr;  // Stays nullptr when Y scale is omitted; forwarded to the GEMM as ScaleY.
+ if (y_scale != nullptr) {
+ // Runtime tensors must match the schema scalar contract before reading element 0.
+ ORT_RETURN_IF(y_scale->Shape().NumDimensions() != 0 || y_scale->Shape().Size() != 1,
+ "DynamicQuantMatMulFp8 requires Y scale input to be a scalar.");
+ if (y_scale->IsDataType()) {
+ y_scale_data = y_scale->Data();  // Used in place, no copy.
+ } else if (y_scale->IsDataType()) {
+ y_scale_storage = static_cast(y_scale->Data()[0]);
+ y_scale_data = &y_scale_storage;
+ } else if (y_scale->IsDataType()) {
+ y_scale_storage = static_cast(y_scale->Data()[0]);
+ y_scale_data = &y_scale_storage;
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires Y scale input to be float, float16, or bfloat16.");
+ }
+ ORT_RETURN_IF(!std::isfinite(y_scale_data[0]) || y_scale_data[0] <= 0.0f,
+ "DynamicQuantMatMulFp8 requires Y scale values to be finite and positive.");
+ }
+
+ // Empty reduction (K == 0) does not need B data, so fill Y with zeros before enforcing runtime FP8 B.
+ if (K == 0) {
+ if (y_size == 0) {
+ return Status::OK();  // Nothing to write for an empty output.
+ }
+ if (y->IsDataType()) {
+ std::fill_n(y->MutableData(), y_size, 0.0f);
+ } else if (y->IsDataType()) {
+ std::fill_n(y->MutableData(), y_size, MLFloat16::FromBits(0));
+ } else if (y->IsDataType()) {
+ std::fill_n(y->MutableData(), y_size, BFloat16::FromBits(0));
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16.");
+ }
+ return Status::OK();
+ }
+
+ const bool a_is_supported =
+ a->IsDataType() || a->IsDataType() || a->IsDataType();
+ ORT_RETURN_IF(!a_is_supported, "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16.");
+
+ const auto b_elem_type = b ? static_cast(b->GetElementType())
+ : static_cast(0);  // 0 is a placeholder element type when B comes from the packed buffer.
+ const bool b_is_fp8 = IsFp8DataType(b_elem_type);
+
+ Fp8Mode b_type{};  // Precedence: recorded b_type_ (has_b_type_), then runtime FP8 B's own type, then fp8_type_.
+ if (has_b_type_) {
+ b_type = b_type_;
+ } else if (b_is_fp8) {
+ ORT_RETURN_IF_ERROR(GetFp8Type(b_elem_type, b_type));
+ } else {
+ b_type = fp8_type_;
+ }
+
+ if (b_zero_point != nullptr) {  // B zero point must use the same FP8 mode as B.
+ const auto b_zp_elem_type =
+ static_cast(b_zero_point->GetElementType());
+ Fp8Mode b_zp_type{};
+ ORT_RETURN_IF_ERROR(GetFp8Type(b_zp_elem_type, b_zp_type));
+ ORT_RETURN_IF(b_type != b_zp_type,
+ "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match.");
+ }
+
+ if (y_size == 0) {  // The checks above still run for an empty output; only the remaining work is skipped.
+ return Status::OK();
+ }
+
+ // Select the FP8 B buffer: prefer pre-quantized B from PrePack, otherwise accept FP8-typed B input.
+ const uint8_t* b_fp8 = nullptr;
+ if (quantized_b_) {
+ b_fp8 = static_cast(quantized_b_.get());
+ } else if (b_is_fp8) {
+ b_fp8 = static_cast(b->DataRaw());
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires runtime B input to be FP8. Non-FP8 B is only supported "
+ "when B is a constant initializer that can be quantized during prepack.");
+ }
+
+ const size_t num_gemms = helper.OutputOffsets().size();  // One GEMM per output offset.
+ ORT_RETURN_IF(K % block_size_k_ != 0,
+ "DynamicQuantMatMulFp8 requires K to be divisible by block_size_k.");
+ const size_t expected_blocks_k = K / block_size_k_;
+ const size_t blocks_m = M;  // One A scale row per row of A (BlockSizeM is set to 1 below).
+ const size_t blocks_k = expected_blocks_k;  // NOTE(review): same value as expected_blocks_k.
+
+ const bool uses_model_b_scale = b_scales_ == nullptr;  // No stored b_scales_ means the B_scale input supplies the scales.
+ ORT_RETURN_IF(uses_model_b_scale && b_scale == nullptr,
+ "DynamicQuantMatMulFp8 requires B scale when B is already FP8.");
+ ORT_RETURN_IF(uses_model_b_scale && b_scale->Shape().NumDimensions() != 2,
+ "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor.");
+ ORT_RETURN_IF(N % block_size_n_ != 0,
+ "DynamicQuantMatMulFp8 requires N to be divisible by block_size_n.");
+ const size_t blocks_n = N / block_size_n_;
+ ORT_RETURN_IF(blocks_n == 0, "DynamicQuantMatMulFp8 requires non-zero B scale N dimension.");
+ ORT_RETURN_IF(uses_model_b_scale && static_cast(b_scale->Shape()[0]) != blocks_n,
+ "DynamicQuantMatMulFp8 requires B scale N dimension to be N / block_size_n.");
+ ORT_RETURN_IF(uses_model_b_scale && static_cast(b_scale->Shape()[1]) != blocks_k,
+ "DynamicQuantMatMulFp8 requires B scale K dimension to be K / block_size_k.");
+
+ const size_t a_scale_batch_stride = SafeMul(blocks_m, blocks_k);  // Number of A scales per MxK matrix.
+ const size_t b_zp_count = SafeMul(blocks_k, blocks_n);  // Number of B zero-point entries to validate.
+
+ if (uses_model_b_scale && b_zero_point != nullptr) {  // Shape/value checks apply only when B_scale comes from the model input.
+ ORT_RETURN_IF(b_zero_point->Shape().NumDimensions() != 2,
+ "DynamicQuantMatMulFp8 requires B zero point to be a 2D tensor.");
+ ORT_RETURN_IF(b_zero_point->Shape()[0] != static_cast(blocks_n) ||
+ b_zero_point->Shape()[1] != static_cast(blocks_k),
+ "DynamicQuantMatMulFp8 requires B zero point to have shape [N / block_size_n, K / block_size_k].");
+ if (!constant_b_zero_point_values_validated_) {  // Skip the per-call value scan when the flag is already set.
+ ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*b_zero_point, b_zp_count, "B zero point"));
+ }
+ }
+
+ AllocatorPtr temp_allocator;  // Allocator for the scratch buffers created below.
+ ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&temp_allocator));
+
+ const float* b_scales = nullptr;  // Float view of the B scales handed to the GEMM.
+ IAllocatorUniquePtr b_scale_float;  // Owns the converted copy when B_scale is not already float.
+ size_t b_scale_elems = 0;
+ if (b_scales_) {  // Stored scales take priority over the B_scale input.
+ b_scales = static_cast(b_scales_.get());
+ b_scale_elems = b_scale_count_;
+ } else if (b_scale->IsDataType()) {
+ b_scales = b_scale->Data();  // Used in place, no copy.
+ b_scale_elems = static_cast(b_scale->Shape().Size());
+ } else {
+ AllocatorPtr allocator;  // NOTE(review): temp_allocator fetched above could be reused instead of a second lookup.
+ ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+ b_scale_elems = static_cast(b_scale->Shape().Size());
+ b_scale_float = IAllocator::MakeUniquePtr(allocator, b_scale_elems, true);
+ if (b_scale->IsDataType()) {
+ for (size_t i = 0; i < b_scale_elems; ++i) {
+ b_scale_float.get()[i] = static_cast(b_scale->Data()[i]);
+ }
+ } else if (b_scale->IsDataType()) {
+ for (size_t i = 0; i < b_scale_elems; ++i) {
+ b_scale_float.get()[i] = static_cast(b_scale->Data()[i]);
+ }
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires B scale input to be float, float16, or bfloat16.");
+ }
+ b_scales = b_scale_float.get();
+ }
+
+ // The internal FP8 GEMM helper accumulates and stores float output. Use scratch for lower-precision Y,
+ // then convert once after all batched GEMMs complete.
+ IAllocatorUniquePtr y_float_buffer;  // Remains null when Y itself is the float destination.
+ float* y_float_data = nullptr;
+ if (y->IsDataType()) {
+ y_float_data = y->MutableData();
+ } else if (y->IsDataType() || y->IsDataType()) {
+ y_float_buffer = IAllocator::MakeUniquePtr(temp_allocator, y_size, true);
+ y_float_data = y_float_buffer.get();
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16.");
+ }
+
+ Fp8GemmShapeParams gemm_shape;  // Shape and FP8 mode shared by every GEMM in the batch call below.
+ gemm_shape.M = M;
+ gemm_shape.N = N;
+ gemm_shape.K = K;
+ gemm_shape.Fp8Type = b_type;
+
+ if (uses_model_b_scale && !constant_b_scale_values_validated_) {  // Per-call scan of model-provided scales unless already validated.
+ for (size_t i = 0; i < b_scale_elems; ++i) {
+ ORT_RETURN_IF(!std::isfinite(b_scales[i]) || b_scales[i] <= 0.0f,
+ "DynamicQuantMatMulFp8 requires B scale values to be finite and positive.");
+ }
+ }
+
+ const size_t a_fp8_size = SafeMul(M, K);  // Elements in one MxK matrix of A.
+ const size_t a_num_elements = static_cast(a->Shape().Size());
+ ORT_RETURN_IF(a_num_elements % a_fp8_size != 0,
+ "DynamicQuantMatMulFp8 requires A to contain complete MxK matrices.");
+ const size_t a_batch_count = a_num_elements / a_fp8_size;  // Number of MxK matrices physically stored in A.
+
+ // Quantize the physical A tensor once. Broadcasted output GEMMs then reuse the same FP8 A slice.
+ auto a_fp8_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_num_elements, true);
+ const size_t a_scale_count = SafeMul(a_batch_count, a_scale_batch_stride);
+ auto a_scale_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_scale_count, true);
+ const size_t a_quant_work_items = SafeMul(a_batch_count, a_scale_batch_stride);  // NOTE(review): same product as a_scale_count.
+ ORT_RETURN_IF(a_quant_work_items > static_cast(std::numeric_limits::max()),
+ "DynamicQuantMatMulFp8 A quantization work item count exceeds ptrdiff_t range.");
+ const auto a_quant_work_items_i = static_cast(a_quant_work_items);  // The parallel-for callback below indexes with std::ptrdiff_t.
+ const size_t a_quant_block_elems = block_size_k_;
+ const TensorOpCost a_quant_unit_cost{  // Presumably bytes loaded / bytes stored / compute cost per work item; confirm against TensorOpCost.
+ static_cast(SafeMul(a_quant_block_elems, sizeof(float))),
+ static_cast(SafeMul(a_quant_block_elems, sizeof(uint8_t))),
+ static_cast(a_quant_block_elems) * 2.0};
+ float fp8_max_abs = 0.0f;
+ ORT_RETURN_IF_ERROR(GetFp8MaxAbs(b_type, fp8_max_abs));  // A is quantized with the same FP8 mode as B.
+ const auto quantize_a_batches = [&](const auto* a_data) {  // Generic over A's element type; one work item per (matrix, row, K block).
+ concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), a_quant_work_items_i,
+ a_quant_unit_cost,
+ [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ for (std::ptrdiff_t tid = begin; tid < end; ++tid) {
+ const size_t work_idx = static_cast(tid);
+ const size_t a_batch_idx = work_idx / a_scale_batch_stride;  // Which MxK matrix of A.
+ const size_t scale_block_idx = work_idx % a_scale_batch_stride;  // Scale index inside that matrix.
+ const size_t row = scale_block_idx / blocks_k;
+ const size_t block_k = scale_block_idx % blocks_k;
+ const size_t a_batch_offset = a_batch_idx * a_fp8_size;
+ const size_t a_scale_batch_offset = a_batch_idx * a_scale_batch_stride;
+ QuantizeBlockwiseFp8ABlockDynamic(  // Receives per-matrix base pointers plus the (row, block_k) to process.
+ a_data + a_batch_offset,
+ K, block_size_k_, blocks_k,
+ row, block_k, fp8_max_abs, b_type,
+ a_fp8_buffer.get() + a_batch_offset,
+ a_scale_buffer.get() + a_scale_batch_offset);
+ }
+ });
+ };
+ if (a->IsDataType()) {  // Dispatch on A's element type.
+ quantize_a_batches(a->Data());
+ } else if (a->IsDataType()) {
+ quantize_a_batches(a->Data());
+ } else if (a->IsDataType()) {
+ quantize_a_batches(a->Data());
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16.");
+ }
+
+ std::vector gemm_data_vec(num_gemms);  // One parameter record per GEMM.
+ for (size_t gemm_idx = 0; gemm_idx < num_gemms; ++gemm_idx) {
+ const size_t a_offset = helper.LeftOffsets()[gemm_idx];  // Element offset of this GEMM's A matrix.
+ ORT_RETURN_IF(a_offset >= a_num_elements || (a_offset % a_fp8_size) != 0,
+ "DynamicQuantMatMulFp8 requires A offsets to reference complete MxK matrices.");
+ const size_t scale_batch_index = a_offset / a_fp8_size;  // Index of the A matrix whose scales this GEMM uses.
+ ORT_RETURN_IF(scale_batch_index >= a_batch_count,
+ "DynamicQuantMatMulFp8 requires A offsets to reference complete MxK matrices.");
+ const size_t a_scale_batch_offset = SafeMul(scale_batch_index, a_scale_batch_stride);
+ const float* a_scales_batch = a_scale_buffer.get() + a_scale_batch_offset;
+ auto& gemm_data = gemm_data_vec[gemm_idx];
+ gemm_data.A = a_fp8_buffer.get() + a_offset;  // The same element offset addresses the quantized copy of A.
+ gemm_data.lda = K;
+ gemm_data.B = b_fp8 + helper.RightOffsets()[gemm_idx];
+ gemm_data.ldb = N;
+ gemm_data.C = y_float_data + helper.OutputOffsets()[gemm_idx];
+ gemm_data.ldc = N;
+ gemm_data.ScaleA = a_scales_batch;
+ gemm_data.ScaleB = b_scales;
+ gemm_data.ScaleY = y_scale_data;  // nullptr when Y scale was omitted.
+ gemm_data.BlockSizeM = 1;  // A scales are per row.
+ gemm_data.BlockSizeK = block_size_k_;
+ gemm_data.BlockSizeN = block_size_n_;
+ gemm_data.ScaleAStrideK = 1;  // A scales are indexed [row][k_block].
+ gemm_data.ScaleAStrideM = blocks_k;
+ gemm_data.ScaleBStrideN = blocks_k;  // B scales are indexed [n_block][k_block].
+ gemm_data.ScaleBStrideK = 1;
+ }
+
+ ORT_RETURN_IF_ERROR(ReferenceFp8GemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms,
+ context->GetOperatorThreadPool()));
+
+ if (y_float_buffer != nullptr) {  // The GEMM wrote to float scratch; convert it into Y's element type.
+ if (y->IsDataType()) {
+ auto* y_data = y->MutableData();
+ for (size_t i = 0; i < y_size; ++i) {
+ y_data[i] = static_cast(y_float_data[i]);
+ }
+ } else {
+ auto* y_data = y->MutableData();
+ for (size_t i = 0; i < y_size; ++i) {
+ y_data[i] = BFloat16(y_float_data[i]);
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+#endif // !defined(DISABLE_FLOAT8_TYPES)
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h
new file mode 100644
index 0000000000000..02fbf142fd9c8
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h
@@ -0,0 +1,68 @@
+// Copyright (c) 2026 Arm Limited. All rights reserved.
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "core/framework/op_kernel.h"
+#include "core/framework/prepacked_weights.h"
+#include "core/graph/onnx_protobuf.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+enum class Fp8Mode : int {  // FP8 mode selector used by the kernel (B type, B zero-point type, and A quantization).
+ E4M3Inf = 0,  // NOTE(review): mapping of each mode to a TensorProto FP8 type lives in GetFp8Type; confirm there.
+ E4M3Sat,
+ E5M2Inf,
+ E5M2Sat,
+ End,  // Last enumerator; presumably a count/sentinel rather than a usable mode - confirm.
+};
+
+class DynamicQuantMatMulFp8 final : public OpKernel {  // CPU kernel for the com.microsoft DynamicQuantMatMulFp8 contrib op.
+ public:
+ DynamicQuantMatMulFp8(const OpKernelInfo& info);  // NOTE(review): single-argument constructor; consider marking it explicit.
+
+ Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ /*out*/ bool& is_packed,
+ /*out*/ PrePackedWeights* prepacked_weights) override;
+
+ Status Compute(OpKernelContext* context) const override;  // Quantizes A at run time and multiplies it with FP8 B.
+
+ enum InputTensors : int {  // Positional input indices; they match the schema's input order.
+ IN_A = 0,
+ IN_B = 1,
+ IN_B_SCALE = 2,
+ IN_B_ZERO_POINT = 3,
+ IN_Y_SCALE = 4,
+ IN_Y_ZERO_POINT = 5
+ };
+
+ enum OutputTensors : int { OUT_Y = 0 };
+
+ static Status GetFp8Type(const Tensor& tensor, Fp8Mode& out_type);  // Writes the tensor's FP8 mode to out_type.
+ static Status GetFp8Type(ONNX_NAMESPACE::TensorProto_DataType elem_type, Fp8Mode& out_type);  // Same, from an element type.
+
+ Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers,  // Adopts shared B data, scale, and metadata buffers for input B.
+ gsl::span prepacked_buffer_sizes,
+ int input_idx,
+ /*out*/ bool& used_shared_buffers) override;
+
+ private:
+ static constexpr int GetBIdx() { return IN_B; }  // Input index that the shared-buffer path applies to.
+ IAllocatorUniquePtr quantized_b_;  // Packed FP8 bytes of B; when set, Compute does not read input B.
+ size_t quantized_b_size_{0};
+ IAllocatorUniquePtr b_scales_;  // Stored B scales; when set, Compute does not read B_scale values.
+ size_t b_scale_count_{0};  // Element count of b_scales_.
+ bool constant_b_scale_values_validated_{false};  // When true, Compute skips its per-call B scale value check.
+ bool constant_b_zero_point_values_validated_{false};  // When true, Compute skips its per-call B zero-point value check.
+ TensorShape b_shape_;  // Shape of B that Compute uses when quantized_b_ is set.
+ Fp8Mode b_type_{Fp8Mode::E4M3Inf};  // FP8 mode of the packed B; only meaningful when has_b_type_ is true.
+ bool has_b_type_{false};
+ Fp8Mode fp8_type_{Fp8Mode::E4M3Inf};  // Fallback mode when neither has_b_type_ nor an FP8 runtime B decides it.
+ size_t block_size_k_{128};  // Scale block size along K.
+ size_t block_size_n_{128};  // Scale block size along N.
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index 59f97c222ceb2..6894b4e7eed40 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -25,6 +25,9 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MulInteger);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QGemm);
+#if !defined(DISABLE_FLOAT8_TYPES)
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicQuantMatMulFp8);  // Declared only when FP8 types are enabled.
+#endif  // !defined(DISABLE_FLOAT8_TYPES)
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAdd);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConcat);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearWhere);
@@ -139,6 +142,9 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
+#if !defined(DISABLE_FLOAT8_TYPES)
+ fn(GetOpSchema());
+#endif  // !defined(DISABLE_FLOAT8_TYPES)
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
index 5a3cd86b04492..b3abcb800eb93 100644
--- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -946,6 +946,127 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
updateOutputShape(ctx, 0, {first_input_shape.dim(transA ? 1 : 0), second_input_shape.dim(transB ? 0 : 1)});
}
}));
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ DynamicQuantMatMulFp8, 1,
+ OpSchema()
+ .SetDoc("Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from "
+ "float16/bfloat16/float) and dynamic runtime quantization of activations to fp8 using "
+ "internally computed block-wise scales. All zero-point inputs, when provided, must encode 0.0. "
+ "Optional trailing inputs may be omitted, but intermediate optional inputs must use an empty "
+ "input name to keep later input positions.")
+ .Input(0, "A", "Input tensor A.", "TA")
+ .Input(1, "B",
+ "Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only "
+ "supported when B is a constant initializer that can be quantized during prepack.",
+ "TB")
+ .Input(2, "B_scale",
+ "Scale of FP8 input 'B'. Must be a block-wise tensor with shape "
+ "(N / block_size_n, K / block_size_k). Required when B is already FP8. Ignored for non-FP8 "
+ "constant B, where scales are computed during prepack.",
+ "TS", OpSchema::Optional)
+ .Input(3, "B_zero_point",
+ "Zero point tensor for input 'B'. Must have the same shape as B_scale and all values must encode 0.0.",
+ "TZ", OpSchema::Optional)
+ .Input(4, "Y_scale", "Scale of output 'Y'. Must be a scalar when provided.", "TS",
+ OpSchema::Optional)
+ .Input(5, "Y_zero_point",
+ "Zero point tensor for output 'Y'. Must be a scalar encoding 0.0 when provided. "
+ "May be provided without Y_scale; only Y_scale changes the floating-point output values.",
+ "TZ", OpSchema::Optional)
+ .Output(0, "Y", "Output tensor of shape (..., M, N).", "TY")
+ .Attr("block_size_k", "Block size along K for A and B block-wise scales.", AttributeProto::INT,
+ static_cast(128))
+ .Attr("block_size_n", "Block size along N for B block-wise scales.", AttributeProto::INT,
+ static_cast(128))
+ .Attr("fp8_type",
+ "FP8 TensorProto data type used when non-FP8 constant B is dynamically quantized during prepack. "
+ "Defaults to FLOAT8E4M3FN.",
+ AttributeProto::INT,
+ static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN))
+ .TypeConstraint("TA", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"},
+ "Constrain input A type to float16, bfloat16, or float.")
+ .TypeConstraint("TB",
+ {"tensor(float16)", "tensor(bfloat16)", "tensor(float)",
+ "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)",
+ "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"},
+ "Constrain input B type to fp8, or to float16, bfloat16, or float for constant initializers.")
+ .TypeConstraint("TZ", {"tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"},
+ "Constrain zero point types to fp8. Only zero-valued zero points are supported.")
+ // Scale tensors are upcast to float by the CPU kernel before compute.
+ .TypeConstraint("TS", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"},
+ "Constrain scale types to float, float16, or bfloat16.")
+ .TypeConstraint("TY", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"},
+ "Constrain output type to float16, bfloat16, or float.")
+ .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {  // Checks block sizes and the B / scale / zero-point shapes that are statically known.
+ const int64_t block_size_k = getAttribute(ctx, "block_size_k", static_cast(128));
+ const int64_t block_size_n = getAttribute(ctx, "block_size_n", static_cast(128));
+ if (block_size_k <= 0 || block_size_n <= 0) {  // Also guards the divisions and modulo operations below.
+ fail_type_inference("block_size_k and block_size_n must be greater than zero.");
+ }
+ if (hasInputShape(ctx, 1)) {
+ auto& b_shape = getInputShape(ctx, 1);
+ if (b_shape.dim_size() != 2) {  // Later blocks index b_shape.dim(0) and dim(1) after this check.
+ fail_type_inference("B must be 2D.");
+ }
+ }
+ if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) {
+ ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1);
+ }
+ if (hasInputShape(ctx, 4)) {  // Y_scale must be rank 0 when its shape is known.
+ auto shape = ctx.getInputType(4)->tensor_type().shape();  // NOTE(review): copies the shape; a const reference would avoid that.
+ if (shape.dim_size() != 0) {
+ fail_type_inference("Y scale input must be a scalar.");
+ }
+ }
+ if (hasInputShape(ctx, 5)) {  // Y_zero_point must be rank 0 when its shape is known.
+ auto shape = ctx.getInputType(5)->tensor_type().shape();  // NOTE(review): copies the shape; a const reference would avoid that.
+ if (shape.dim_size() != 0) {
+ fail_type_inference("Y zero point input must be a scalar.");
+ }
+ }
+ if (hasInputShape(ctx, 1) && hasInputShape(ctx, 2)) {  // B_scale must be [N / block_size_n, K / block_size_k].
+ auto& b_shape = getInputShape(ctx, 1);
+ auto& b_scale_shape = getInputShape(ctx, 2);
+ if (b_scale_shape.dim_size() != 2) {
+ fail_type_inference("B scale must be 2D.");
+ }
+ if (b_shape.dim(1).has_dim_value() && b_scale_shape.dim(0).has_dim_value()) {  // Compared only when both dims are concrete.
+ const auto n = b_shape.dim(1).dim_value();
+ if ((n % block_size_n) != 0 || b_scale_shape.dim(0).dim_value() != (n / block_size_n)) {
+ fail_type_inference("B scale first dimension must be N / block_size_n.");
+ }
+ }
+ if (b_shape.dim(0).has_dim_value() && b_scale_shape.dim(1).has_dim_value()) {  // Compared only when both dims are concrete.
+ const auto k = b_shape.dim(0).dim_value();
+ if ((k % block_size_k) != 0 || b_scale_shape.dim(1).dim_value() != (k / block_size_k)) {
+ fail_type_inference("B scale last dimension must be K / block_size_k.");
+ }
+ }
+ }
+ if (hasInputShape(ctx, 1) && hasInputShape(ctx, 3)) {  // B_zero_point gets the same shape checks as B_scale.
+ auto& b_shape = getInputShape(ctx, 1);
+ auto& b_zp_shape = getInputShape(ctx, 3);
+ if (b_zp_shape.dim_size() != 2) {
+ fail_type_inference("B zero point must be 2D.");
+ }
+ if (b_shape.dim(1).has_dim_value() && b_zp_shape.dim(0).has_dim_value()) {
+ const auto n = b_shape.dim(1).dim_value();
+ if ((n % block_size_n) != 0 || b_zp_shape.dim(0).dim_value() != (n / block_size_n)) {
+ fail_type_inference("B zero point first dimension must be N / block_size_n.");
+ }
+ }
+ if (b_shape.dim(0).has_dim_value() && b_zp_shape.dim(1).has_dim_value()) {
+ const auto k = b_shape.dim(0).dim_value();
+ if ((k % block_size_k) != 0 || b_zp_shape.dim(1).dim_value() != (k / block_size_k)) {
+ fail_type_inference("B zero point last dimension must be K / block_size_k.");
+ }
+ }
+ }
+ }));
+
+#endif  // !defined(DISABLE_FLOAT8_TYPES)
ONNX_MS_OPERATOR_SET_SCHEMA(
QAttention, 1,
OpSchema()
diff --git a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc
new file mode 100644
index 0000000000000..b532d39ac8384
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc
@@ -0,0 +1,844 @@
+// Copyright (c) 2026 Arm Limited. All rights reserved.
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+// SPDX-License-Identifier: MIT
+
+#include
+#include
+#include
+#include
+
+#include "gtest/gtest.h"
+#include "core/session/inference_session.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/util/include/test_environment.h"
+#include "test/unittest_util/conversion.h"
+#include "default_providers.h"
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+#include "core/common/float8.h"
+
+namespace onnxruntime {
+namespace test {
+
+class DynamicQuantMatMulFp8SessionTester : public OpTester {  // OpTester subclass that re-exports selected base members as public.
+ public:
+ using BaseTester::ExecuteModel;  // Presumably non-public in the base, so tests can drive a session directly - confirm.
+ using BaseTester::FillFeedsAndOutputNames;
+ using BaseTester::SetTestFunctionCalled;
+ using OpTester::BuildModel;
+ using OpTester::OpTester;  // Inherits the OpTester constructors.
+};
+
+template
+struct Fp8TensorProtoType;  // Trait exposing a TensorProto data type as `value`; only specializations are defined.
+
+template <>
+struct Fp8TensorProtoType {
+ static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN;  // Specialization yielding FLOAT8E4M3FN.
+};
+
+template <>
+struct Fp8TensorProtoType {
+ static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ;  // Specialization yielding FLOAT8E4M3FNUZ.
+};
+
+template <>
+struct Fp8TensorProtoType