From 4cb522c9d94ba2bb058a694ccfb76edb3b25f2b9 Mon Sep 17 00:00:00 2001 From: melkap01 Date: Wed, 27 May 2026 15:44:09 +0100 Subject: [PATCH 1/5] Add DynamicQuantMatMulFp8 contrib op with internal FP8 GEMM helper Signed-off-by: melkap01 --- docs/ContribOperators.md | 62 +- docs/OperatorKernels.md | 1 + .../contrib_ops/cpu/cpu_contrib_kernels.cc | 6 + .../quantization/dynamic_quant_matmul_fp8.cc | 961 ++++++++++++++++++ onnxruntime/core/graph/contrib_ops/ms_opset.h | 6 + .../graph/contrib_ops/quantization_defs.cc | 121 +++ .../dynamic_quant_matmul_fp8_test.cc | 844 +++++++++++++++ 7 files changed, 1999 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc create mode 100644 onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 6df316097e719..251735009ba41 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -26,6 +26,7 @@ Do not modify directly.* * com.microsoft.DequantizeBFP * com.microsoft.DequantizeLinear * com.microsoft.DequantizeWithOrder + * com.microsoft.DynamicQuantMatMulFp8 * com.microsoft.DynamicQuantizeLSTM * com.microsoft.DynamicQuantizeMatMul * com.microsoft.DynamicTimeWarping @@ -1493,6 +1494,65 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.DynamicQuantMatMulFp8** + + Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from float16/bfloat16/float) and dynamic runtime quantization of activations to fp8 using internally computed block-wise scales. All zero-point inputs, when provided, must encode 0.0. Optional trailing inputs may be omitted, but intermediate optional inputs must use an empty input name to keep later input positions. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
block_size_k : int
+
Block size along K for A and B block-wise scales.
+
block_size_n : int
+
Block size along N for B block-wise scales.
+
fp8_type : int
+
FP8 TensorProto data type used when non-FP8 constant B is dynamically quantized during prepack. Defaults to FLOAT8E4M3FN.
+
+ +#### Inputs (2 - 6) + +
+
A : TA
+
Input tensor A.
+
B : TB
+
Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only supported when B is a constant initializer that can be quantized during prepack.
+
B_scale (optional) : TS
+
Scale of FP8 input 'B'. Must be a block-wise tensor with shape (N / block_size_n, K / block_size_k). Required when B is already FP8. Ignored for non-FP8 constant B, where scales are computed during prepack.
+
B_zero_point (optional) : TZ
+
Zero point tensor for input 'B'. Must have the same shape as B_scale and all values must encode 0.0.
+
Y_scale (optional) : TS
+
Scale of output 'Y'. Must be a scalar when provided.
+
Y_zero_point (optional) : TZ
+
Zero point tensor for output 'Y'. Must be a scalar encoding 0.0 when provided. May be provided without Y_scale; only Y_scale changes the floating-point output values.
+
+ +#### Outputs + +
+
Y : TY
+
Output tensor of shape (..., M, N).
+
+ +#### Type Constraints + +
+
TA : tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain input A type to float16, bfloat16, or float.
+
TB : tensor(float16), tensor(bfloat16), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
+
Constrain input B type to fp8, or to float16, bfloat16, or float for constant initializers.
+
TZ : tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
+
Constrain zero point types to fp8. Only zero-valued zero points are supported.
+
TS : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain scale types to float, float16, or bfloat16.
+
TY : tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain output type to float16, bfloat16, or float.
+
+ + ### **com.microsoft.DynamicQuantizeLSTM** #### Version @@ -6691,5 +6751,3 @@ No versioning maintained for experimental ops.
T : tensor(float)
Constrain input and output types to float32 tensors.
- - diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d0d8e750285d4..969083aba3aa0 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -577,6 +577,7 @@ The **OpSet Version** column uses the following notation: |CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)| |DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**T** = tensor(float)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float)| +|DynamicQuantMatMulFp8|*in* A:**TA**
*in* B:**TB**
*in* B_scale:**TS**
*in* B_zero_point:**TZ**
*in* Y_scale:**TS**
*in* Y_zero_point:**TZ**
*out* Y:**TY**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**TS** = tensor(bfloat16), tensor(float), tensor(float16)
**TY** = tensor(bfloat16), tensor(float), tensor(float16)
**TZ** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)| |DynamicQuantizeLSTM|*in* X:**T**
*in* W:**T2**
*in* R:**T2**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* W_scale:**T**
*in* W_zero_point:**T2**
*in* R_scale:**T**
*in* R_zero_point:**T2**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(float)
**T1** = tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)| diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 0749457f5a182..b581064b180a2 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -116,6 +116,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, NhwcMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, NhwcMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicQuantMatMulFp8); +#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE); @@ -284,6 +287,9 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, +#if !defined(DISABLE_FLOAT8_TYPES) + BuildKernelCreateInfo, +#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc new file mode 100644 index 0000000000000..4be1a7e73ecd2 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -0,0 +1,961 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "dynamic_quant_matmul_fp8.h" + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/graph/onnx_protobuf.h" +#include "core/common/float16.h" +#include "core/common/float8.h" +#include "core/common/safeint.h" +#include "core/platform/threadpool.h" +#include "core/providers/cpu/math/matmul_helper.h" + +#include +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +#if !defined(DISABLE_FLOAT8_TYPES) + +namespace { + +constexpr int64_t kDefaultBlockSize = 128; +constexpr int64_t kPackedBMetadataVersion = 1; +constexpr size_t kPackedBMetadataElementCount = 6; +constexpr size_t kPackedBMetadataSize = kPackedBMetadataElementCount * sizeof(int64_t); + +enum PackedBMetadataIndex : size_t { + kPackedBMetadataVersionIndex = 0, + kPackedBMetadataRowsIndex, + kPackedBMetadataColsIndex, + kPackedBMetadataSizeIndex, + kPackedBMetadataScaleCountIndex, + kPackedBMetadataFp8ModeIndex, +}; + +bool IsFp8DataType(ONNX_NAMESPACE::TensorProto_DataType elem_type) { + return elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || + elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ || + elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2 || + elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ; +} + +Status RestorePackedBMetadata(const void* metadata_buffer, + size_t metadata_size, + size_t quantized_b_buffer_size, + size_t b_scale_buffer_size, + TensorShape& b_shape, + size_t& quantized_b_size, + size_t& b_scale_count, + Fp8Mode& b_type, + bool& has_b_type) { + ORT_RETURN_IF(metadata_buffer == nullptr, + "DynamicQuantMatMulFp8 requires shared prepacked B metadata."); + ORT_RETURN_IF(metadata_size != kPackedBMetadataSize, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an unexpected size."); + + const auto* metadata = static_cast(metadata_buffer); + ORT_RETURN_IF(metadata[kPackedBMetadataVersionIndex] != kPackedBMetadataVersion, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an unsupported version."); + ORT_RETURN_IF(metadata[kPackedBMetadataRowsIndex] <= 0 || metadata[kPackedBMetadataColsIndex] <= 0, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B shape."); + ORT_RETURN_IF(metadata[kPackedBMetadataSizeIndex] <= 0, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B buffer size."); + ORT_RETURN_IF(metadata[kPackedBMetadataScaleCountIndex] <= 0, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B scale count."); + const int64_t restored_fp8_mode = metadata[kPackedBMetadataFp8ModeIndex]; + ORT_RETURN_IF(restored_fp8_mode < static_cast(Fp8Mode::E4M3Inf) || + restored_fp8_mode >= static_cast(Fp8Mode::End), + "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid FP8 type."); + + const size_t rows = static_cast(metadata[kPackedBMetadataRowsIndex]); + const size_t cols = static_cast(metadata[kPackedBMetadataColsIndex]); + const size_t expected_quantized_b_size = SafeMul(rows, cols); + const size_t restored_quantized_b_size = static_cast(metadata[kPackedBMetadataSizeIndex]); + ORT_RETURN_IF(restored_quantized_b_size != expected_quantized_b_size || + restored_quantized_b_size != quantized_b_buffer_size, + "DynamicQuantMatMulFp8 shared prepacked B metadata does not match the B buffer size."); + const size_t restored_b_scale_count = static_cast(metadata[kPackedBMetadataScaleCountIndex]); + ORT_RETURN_IF(restored_b_scale_count > std::numeric_limits::max() / sizeof(float) || + restored_b_scale_count * sizeof(float) != b_scale_buffer_size, + "DynamicQuantMatMulFp8 shared prepacked B metadata does not match the B scale buffer size."); + + b_shape = TensorShape({metadata[kPackedBMetadataRowsIndex], metadata[kPackedBMetadataColsIndex]}); + quantized_b_size = restored_quantized_b_size; + b_scale_count = restored_b_scale_count; + b_type = static_cast(metadata[kPackedBMetadataFp8ModeIndex]); + has_b_type = true; + return Status::OK(); +} + +// Reject invalid scales before quantization divides by them or the GEMM dequantizes with them. +Status ValidatePositiveFiniteScaleTensor(const Tensor& scale, const char* scale_name) { + const size_t count = static_cast(scale.Shape().Size()); + if (scale.IsDataType()) { + const auto* data = scale.Data(); + for (size_t i = 0; i < count; ++i) { + ORT_RETURN_IF(!std::isfinite(data[i]) || data[i] <= 0.0f, + "DynamicQuantMatMulFp8 requires ", scale_name, " values to be finite and positive."); + } + return Status::OK(); + } + + if (scale.IsDataType()) { + const auto* data = scale.Data(); + for (size_t i = 0; i < count; ++i) { + const float value = static_cast(data[i]); + ORT_RETURN_IF(!std::isfinite(value) || value <= 0.0f, + "DynamicQuantMatMulFp8 requires ", scale_name, " values to be finite and positive."); + } + return Status::OK(); + } + + if (scale.IsDataType()) { + const auto* data = scale.Data(); + for (size_t i = 0; i < count; ++i) { + const float value = static_cast(data[i]); + ORT_RETURN_IF(!std::isfinite(value) || value <= 0.0f, + "DynamicQuantMatMulFp8 requires ", scale_name, " values to be finite and positive."); + } + return Status::OK(); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires ", scale_name, + " input to be float, float16, or bfloat16."); +} + +uint8_t FloatToFp8Byte(float value, Fp8Mode mode) { + switch (mode) { + case Fp8Mode::E4M3Inf: + return Float8E4M3FN(value, true).val; + case Fp8Mode::E4M3Sat: + return Float8E4M3FNUZ(value, true).val; + case Fp8Mode::E5M2Inf: + return Float8E5M2(value, true).val; + case Fp8Mode::E5M2Sat: + return Float8E5M2FNUZ(value, true).val; + default: + ORT_THROW("Unsupported FP8 mode."); + } +} + +float Fp8ByteToFloat(uint8_t value, Fp8Mode mode) { + switch (mode) { + case Fp8Mode::E4M3Inf: + return static_cast(Float8E4M3FN(value, Float8E4M3FN::FromBits())); + case Fp8Mode::E4M3Sat: + return static_cast(Float8E4M3FNUZ(value, Float8E4M3FNUZ::FromBits())); + case Fp8Mode::E5M2Inf: + return static_cast(Float8E5M2(value, Float8E5M2::FromBits())); + case Fp8Mode::E5M2Sat: + return static_cast(Float8E5M2FNUZ(value, Float8E5M2FNUZ::FromBits())); + default: + ORT_THROW("Unsupported FP8 mode."); + } +} + +Status GetFp8MaxAbs(Fp8Mode mode, float& max_abs) { + switch (mode) { + case Fp8Mode::E4M3Inf: + case Fp8Mode::E4M3Sat: + max_abs = 448.0f; + return Status::OK(); + case Fp8Mode::E5M2Inf: + case Fp8Mode::E5M2Sat: + max_abs = 57344.0f; + return Status::OK(); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported fp8 mode for DynamicQuantMatMulFp8."); + } +} + +Status ValidateZeroPointValuesAreZero(const Tensor& zero_point, size_t expected_count, + const char* zero_point_name) { + const size_t actual_count = static_cast(zero_point.Shape().Size()); + ORT_RETURN_IF(actual_count != expected_count, + "DynamicQuantMatMulFp8 requires ", zero_point_name, " to have the expected number of elements."); + + const auto reject_non_zero = [zero_point_name](float value) { + ORT_RETURN_IF(value != 0.0f, + "DynamicQuantMatMulFp8 supports symmetric quantization only; ", + zero_point_name, " values must be zero."); + return Status::OK(); + }; + + if (zero_point.IsDataType()) { + const auto* zp = static_cast(zero_point.DataRaw()); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], Fp8Mode::E4M3Inf))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = static_cast(zero_point.DataRaw()); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], Fp8Mode::E4M3Sat))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = static_cast(zero_point.DataRaw()); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], Fp8Mode::E5M2Inf))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = static_cast(zero_point.DataRaw()); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], Fp8Mode::E5M2Sat))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = zero_point.Data(); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(zp[i])); + } + } else if (zero_point.IsDataType()) { + const auto* zp = zero_point.Data(); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(static_cast(zp[i]))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = zero_point.Data(); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(static_cast(zp[i]))); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported zero point type for DynamicQuantMatMulFp8."); + } + return Status::OK(); +} + +template +void QuantizeBlockwiseFp8ABlockDynamic(const SrcT* src, + size_t K, + size_t block_size_k, + size_t blocks_k, + size_t row, + size_t block_k, + float fp8_max_abs, + Fp8Mode mode, + uint8_t* dst, + float* scales) { + const size_t k_begin = block_k * block_size_k; + const size_t k_end = std::min(K, k_begin + block_size_k); + const size_t idx = row * blocks_k + block_k; + + // Use one dynamic quantization scale per A row and K block. The scale depends on + // the block max_abs, so this reference path scans first and quantizes in a second pass. + float max_abs = 0.0f; + const size_t row_offset = row * K; + for (size_t k = k_begin; k < k_end; ++k) { + max_abs = std::max(max_abs, std::fabs(static_cast(src[row_offset + k]))); + } + + const float scale = max_abs == 0.0f ? 1.0f : max_abs / fp8_max_abs; + scales[idx] = scale; + + for (size_t k = k_begin; k < k_end; ++k) { + const float value = static_cast(src[row_offset + k]); + const float quantized = value / scale; + dst[row_offset + k] = FloatToFp8Byte(quantized, mode); + } +} + +template +void QuantizeBlockwiseFp8WithScales(const SrcT* src, + size_t K, + size_t N, + size_t block_size_k, + size_t block_size_n, + const float* scales, + uint8_t* dst) { + // Block sizes come from op attributes; scale shapes only provide the number of blocks. + const size_t blocks_k = K / block_size_k; + for (size_t k = 0; k < K; ++k) { + const size_t block_k = k / block_size_k; + const size_t row_offset = k * N; + for (size_t n = 0; n < N; ++n) { + const size_t block_n = n / block_size_n; + const size_t scale_idx = block_n * blocks_k + block_k; + const float scale = scales[scale_idx]; + const float value = static_cast(src[row_offset + n]); + const float quantized = value / scale; + const Fp8T fp8_value(quantized, true); + dst[row_offset + n] = fp8_value.val; + } + } +} + +template +void ComputeBlockwiseScalesFromInput(const SrcT* src, + size_t K, + size_t N, + size_t block_size_k, + size_t block_size_n, + float fp8_max_abs, + float* scales) { + // Reference-style dynamic quantization: derive one positive scale from each source block. + const size_t blocks_k = K / block_size_k; + const size_t blocks_n = N / block_size_n; + for (size_t block_k = 0; block_k < blocks_k; ++block_k) { + const size_t k_begin = block_k * block_size_k; + const size_t k_end = k_begin + block_size_k; + for (size_t block_n = 0; block_n < blocks_n; ++block_n) { + const size_t n_begin = block_n * block_size_n; + const size_t n_end = n_begin + block_size_n; + float max_abs = 0.0f; + for (size_t k = k_begin; k < k_end; ++k) { + const size_t row_offset = k * N; + for (size_t n = n_begin; n < n_end; ++n) { + max_abs = std::max(max_abs, std::fabs(static_cast(src[row_offset + n]))); + } + } + // Use one scale per N block and K block. Quantization runs after all block + // scales are known, so this reference path intentionally reads B in two phases. + scales[block_n * blocks_k + block_k] = max_abs == 0.0f ? 1.0f : max_abs / fp8_max_abs; + } + } +} + +template +Status QuantizeToFp8ByModeWithScales(Fp8Mode fp8_mode, + const SrcT* src, + size_t K, + size_t N, + size_t block_size_k, + size_t block_size_n, + const float* scales, + uint8_t* dst) { + // Dispatch quantization using the requested FP8 mode and runtime block sizes. + switch (fp8_mode) { + case Fp8Mode::E4M3Inf: + QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst); + return Status::OK(); + case Fp8Mode::E4M3Sat: + QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst); + return Status::OK(); + case Fp8Mode::E5M2Inf: + QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst); + return Status::OK(); + case Fp8Mode::E5M2Sat: + QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst); + return Status::OK(); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported fp8 mode for DynamicQuantMatMulFp8."); + } +} + +struct Fp8GemmShapeParams { + size_t M; + size_t N; + size_t K; + Fp8Mode Fp8Type; +}; + +struct Fp8GemmDataParams { + const void* A = nullptr; + size_t lda = 0; + const void* B = nullptr; + size_t ldb = 0; + void* C = nullptr; + size_t ldc = 0; + const float* ScaleA = nullptr; + const float* ScaleB = nullptr; + const float* ScaleY = nullptr; + size_t BlockSizeM = 0; + size_t BlockSizeK = 0; + size_t BlockSizeN = 0; + size_t ScaleAStrideK = 0; + size_t ScaleAStrideM = 0; + size_t ScaleBStrideK = 0; + size_t ScaleBStrideN = 0; +}; + +Status ReferenceFp8GemmBatch(const Fp8GemmShapeParams& shape, + const Fp8GemmDataParams* data_params, + size_t batch_count, + concurrency::ThreadPool* thread_pool) { + const size_t M = shape.M; + const size_t N = shape.N; + const size_t K = shape.K; + + if (batch_count == 0 || M == 0 || N == 0) { + return Status::OK(); + } + + size_t work_items_size = 0; + ORT_RETURN_IF(batch_count > std::numeric_limits::max() / M, + "DynamicQuantMatMulFp8 scalar GEMM work item count overflow."); + work_items_size = batch_count * M; + ORT_RETURN_IF(work_items_size > static_cast(std::numeric_limits::max()), + "DynamicQuantMatMulFp8 scalar GEMM work item count exceeds ptrdiff_t range."); + const auto work_items = static_cast(work_items_size); + + const TensorOpCost unit_cost{ + static_cast(SafeMul(K, sizeof(uint8_t)) * 2), + static_cast(N * sizeof(float)), + static_cast(SafeMul(K, N) * 2)}; + concurrency::ThreadPool::TryParallelFor(thread_pool, work_items, unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t tid = begin; tid < end; ++tid) { + const size_t batch = static_cast(tid) / M; + const size_t m = static_cast(tid) % M; + const auto& params = data_params[batch]; + + const auto* a_fp8 = static_cast(params.A); + const auto* b_fp8 = static_cast(params.B); + auto* c = static_cast(params.C); + const auto* scale_a = params.ScaleA; + const auto* scale_b = params.ScaleB; + + const size_t block_m = m / params.BlockSizeM; + for (size_t n = 0; n < N; ++n) { + const size_t block_n = n / params.BlockSizeN; + float acc = 0.0f; + for (size_t k = 0; k < K; ++k) { + const size_t block_k = k / params.BlockSizeK; + const size_t a_scale_idx = + block_m * params.ScaleAStrideM + block_k * params.ScaleAStrideK; + const size_t b_scale_idx = + block_k * params.ScaleBStrideK + block_n * params.ScaleBStrideN; + const float scale_a_val = scale_a ? scale_a[a_scale_idx] : 1.0f; + const float scale_b_val = scale_b ? scale_b[b_scale_idx] : 1.0f; + const float a_val = + Fp8ByteToFloat(a_fp8[m * params.lda + k], shape.Fp8Type); + const float b_val = + Fp8ByteToFloat(b_fp8[k * params.ldb + n], shape.Fp8Type); + acc += (a_val * scale_a_val) * (b_val * scale_b_val); + } + + if (params.ScaleY != nullptr) { + acc *= params.ScaleY[0]; + } + c[m * params.ldc + n] = acc; + } + } + }); + return Status::OK(); +} + +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + DynamicQuantMatMulFp8, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("TA", std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("TB", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("TZ", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("TS", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("TY", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + DynamicQuantMatMulFp8); + +DynamicQuantMatMulFp8::DynamicQuantMatMulFp8(const OpKernelInfo& info) : OpKernel(info) { + const int64_t block_size_k = info.GetAttrOrDefault("block_size_k", kDefaultBlockSize); + const int64_t block_size_n = info.GetAttrOrDefault("block_size_n", kDefaultBlockSize); + const int64_t fp8_type = + info.GetAttrOrDefault("fp8_type", ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN); + ORT_ENFORCE(block_size_k > 0, + "DynamicQuantMatMulFp8 requires block_size_k to be greater than zero."); + ORT_ENFORCE(block_size_n > 0, + "DynamicQuantMatMulFp8 requires block_size_n to be greater than zero."); + block_size_k_ = static_cast(block_size_k); + block_size_n_ = static_cast(block_size_n); + ORT_THROW_IF_ERROR(GetFp8Type(static_cast(fp8_type), fp8_type_)); +} + +Status DynamicQuantMatMulFp8::GetFp8Type(const Tensor& tensor, Fp8Mode& out_type) { + return GetFp8Type(static_cast(tensor.GetElementType()), out_type); +} + +Status DynamicQuantMatMulFp8::GetFp8Type(ONNX_NAMESPACE::TensorProto_DataType elem_type, + Fp8Mode& out_type) { + switch (elem_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + out_type = Fp8Mode::E4M3Inf; + return Status::OK(); + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ: + out_type = Fp8Mode::E4M3Sat; + return Status::OK(); + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + out_type = Fp8Mode::E5M2Inf; + return Status::OK(); + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: + out_type = Fp8Mode::E5M2Sat; + return Status::OK(); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported fp8 type for DynamicQuantMatMulFp8."); + } +} + +Status DynamicQuantMatMulFp8::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + + const Tensor* constant_b = nullptr; + const bool model_b_scale_is_ignored = + OpKernel::Info().TryGetConstantInput(GetBIdx(), &constant_b) && + !IsFp8DataType(static_cast(constant_b->GetElementType())); + + if (input_idx == IN_B_SCALE) { + // Non-FP8 constant B computes its own prepacked scales, so model B_scale values are not consumed. + if (!model_b_scale_is_ignored) { + ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScaleTensor(tensor, "B scale")); + constant_b_scale_values_validated_ = true; + } + return Status::OK(); + } + + if (input_idx == IN_B_ZERO_POINT) { + ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero( + tensor, static_cast(tensor.Shape().Size()), "B zero point")); + constant_b_zero_point_values_validated_ = true; + return Status::OK(); + } + + if (input_idx != GetBIdx()) { + return Status::OK(); + } + + b_shape_ = tensor.Shape(); + if (b_shape_.NumDimensions() != 2) { + return Status::OK(); + } + + const size_t K = static_cast(b_shape_[0]); + const size_t N = static_cast(b_shape_[1]); + const auto b_elem_type = static_cast(tensor.GetElementType()); + const bool b_is_fp8 = IsFp8DataType(b_elem_type); + if (b_is_fp8) { + ORT_RETURN_IF_ERROR(GetFp8Type(tensor, b_type_)); + has_b_type_ = true; + return Status::OK(); + } + + b_type_ = fp8_type_; + has_b_type_ = true; + if (K == 0) { + return Status::OK(); + } + if (N == 0) { + return Status::OK(); + } + + ORT_RETURN_IF_NOT(K % block_size_k_ == 0, + "DynamicQuantMatMulFp8 requires K to be divisible by block_size_k."); + ORT_RETURN_IF_NOT(N % block_size_n_ == 0, + "DynamicQuantMatMulFp8 requires N to be divisible by block_size_n."); + const size_t blocks_k = K / block_size_k_; + const size_t blocks_n = N / block_size_n_; + b_scale_count_ = SafeMul(blocks_k, blocks_n); + b_scales_ = IAllocator::MakeUniquePtr(alloc, b_scale_count_ * sizeof(float), true); + auto* prepacked_b_scales = static_cast(b_scales_.get()); + float fp8_max_abs = 0.0f; + ORT_RETURN_IF_ERROR(GetFp8MaxAbs(b_type_, fp8_max_abs)); + + const size_t quantized_b_size = SafeMul(K, N); + quantized_b_ = IAllocator::MakeUniquePtr(alloc, quantized_b_size, true); + quantized_b_size_ = quantized_b_size; + auto* quantized_b_bytes = static_cast(quantized_b_.get()); + if (tensor.IsDataType()) { + ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_, + fp8_max_abs, prepacked_b_scales); + ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_, + prepacked_b_scales, quantized_b_bytes)); + } else if (tensor.IsDataType()) { + ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_, + fp8_max_abs, prepacked_b_scales); + ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_, + prepacked_b_scales, quantized_b_bytes)); + } else if (tensor.IsDataType()) { + ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_, + fp8_max_abs, prepacked_b_scales); + ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_, + prepacked_b_scales, quantized_b_bytes)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported B type for DynamicQuantMatMulFp8 prepack."); + } + + if (prepacked_weights != nullptr) { + const std::array metadata_values = { + kPackedBMetadataVersion, + b_shape_[0], + b_shape_[1], + static_cast(quantized_b_size_), + static_cast(b_scale_count_), + static_cast(b_type_), + }; + auto metadata = IAllocator::MakeUniquePtr(alloc, kPackedBMetadataSize, true); + std::memcpy(metadata.get(), metadata_values.data(), kPackedBMetadataSize); + prepacked_weights->buffers_.push_back(std::move(quantized_b_)); + prepacked_weights->buffer_sizes_.push_back(quantized_b_size_); + prepacked_weights->buffers_.push_back(std::move(b_scales_)); + prepacked_weights->buffer_sizes_.push_back(b_scale_count_ * sizeof(float)); + prepacked_weights->buffers_.push_back(std::move(metadata)); + prepacked_weights->buffer_sizes_.push_back(kPackedBMetadataSize); + } + is_packed = true; + return Status::OK(); +} + +Status DynamicQuantMatMulFp8::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + if (input_idx != GetBIdx()) { + return Status::OK(); + } + + ORT_RETURN_IF(prepacked_buffers.size() != 3 || prepacked_buffer_sizes.size() != 3, + "DynamicQuantMatMulFp8 requires shared prepacked B data, scale, and metadata buffers."); + ORT_RETURN_IF(prepacked_buffers[0].get() == nullptr, + "DynamicQuantMatMulFp8 requires shared prepacked B data."); + ORT_RETURN_IF(prepacked_buffers[1].get() == nullptr, + "DynamicQuantMatMulFp8 requires shared prepacked B scales."); + + // Buffer 0 owns quantized B bytes; buffer 1 owns computed B scales; buffer 2 restores kernel state. + ORT_RETURN_IF_ERROR(RestorePackedBMetadata(prepacked_buffers[2].get(), + prepacked_buffer_sizes[2], + prepacked_buffer_sizes[0], + prepacked_buffer_sizes[1], + b_shape_, + quantized_b_size_, + b_scale_count_, + b_type_, + has_b_type_)); + quantized_b_ = std::move(prepacked_buffers[0]); + b_scales_ = std::move(prepacked_buffers[1]); + used_shared_buffers = true; + return Status::OK(); +} + +Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { + const Tensor* a = context->Input(IN_A); + const Tensor* b = quantized_b_ ? nullptr : context->Input(IN_B); + const Tensor* b_scale = context->Input(IN_B_SCALE); + const Tensor* b_zero_point = context->Input(IN_B_ZERO_POINT); + const Tensor* y_scale = context->Input(IN_Y_SCALE); + const Tensor* y_zero_point = context->Input(IN_Y_ZERO_POINT); + + // Runtime B uses one 2D B scale/zero-point layout, so reject batched B before MatMul broadcasts it. + ORT_RETURN_IF(!quantized_b_ && b->Shape().NumDimensions() != 2, + "DynamicQuantMatMulFp8 requires runtime B to be a 2D tensor."); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), + quantized_b_ ? b_shape_ : b->Shape(), + nullptr, + nullptr)); + + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + Tensor* y = context->Output(OUT_Y, helper.OutputShape()); + const size_t y_size = static_cast(y->Shape().Size()); + ORT_RETURN_IF(!y->IsDataType() && !y->IsDataType() && !y->IsDataType(), + "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16."); + + if (y_zero_point != nullptr) { + // Runtime tensors must match the schema scalar contract before reading element 0. + ORT_RETURN_IF(y_zero_point->Shape().NumDimensions() != 0 || y_zero_point->Shape().Size() != 1, + "DynamicQuantMatMulFp8 requires Y zero point input to be a scalar."); + ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*y_zero_point, 1, "Y zero point")); + } + + float y_scale_storage = 0.0f; + const float* y_scale_data = nullptr; + if (y_scale != nullptr) { + // Runtime tensors must match the schema scalar contract before reading element 0. + ORT_RETURN_IF(y_scale->Shape().NumDimensions() != 0 || y_scale->Shape().Size() != 1, + "DynamicQuantMatMulFp8 requires Y scale input to be a scalar."); + if (y_scale->IsDataType()) { + y_scale_data = y_scale->Data(); + } else if (y_scale->IsDataType()) { + y_scale_storage = static_cast(y_scale->Data()[0]); + y_scale_data = &y_scale_storage; + } else if (y_scale->IsDataType()) { + y_scale_storage = static_cast(y_scale->Data()[0]); + y_scale_data = &y_scale_storage; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires Y scale input to be float, float16, or bfloat16."); + } + ORT_RETURN_IF(!std::isfinite(y_scale_data[0]) || y_scale_data[0] <= 0.0f, + "DynamicQuantMatMulFp8 requires Y scale values to be finite and positive."); + } + + // Empty reduction does not need B data, so fill zeros before enforcing runtime FP8 B. + if (K == 0) { + if (y_size == 0) { + return Status::OK(); + } + if (y->IsDataType()) { + std::fill_n(y->MutableData(), y_size, 0.0f); + } else if (y->IsDataType()) { + std::fill_n(y->MutableData(), y_size, MLFloat16::FromBits(0)); + } else if (y->IsDataType()) { + std::fill_n(y->MutableData(), y_size, BFloat16::FromBits(0)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16."); + } + return Status::OK(); + } + + const bool a_is_supported = + a->IsDataType() || a->IsDataType() || a->IsDataType(); + ORT_RETURN_IF(!a_is_supported, "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16."); + + const auto b_elem_type = b ? static_cast(b->GetElementType()) + : static_cast(0); + const bool b_is_fp8 = IsFp8DataType(b_elem_type); + + Fp8Mode b_type{}; + if (has_b_type_) { + b_type = b_type_; + } else if (b_is_fp8) { + ORT_RETURN_IF_ERROR(GetFp8Type(b_elem_type, b_type)); + } else { + b_type = fp8_type_; + } + + if (b_zero_point != nullptr) { + const auto b_zp_elem_type = + static_cast(b_zero_point->GetElementType()); + Fp8Mode b_zp_type{}; + ORT_RETURN_IF_ERROR(GetFp8Type(b_zp_elem_type, b_zp_type)); + ORT_RETURN_IF(b_type != b_zp_type, + "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); + } + + if (y_size == 0) { + return Status::OK(); + } + + // Select the FP8 B buffer: prefer pre-quantized B from PrePack, otherwise accept FP8-typed B input. + const uint8_t* b_fp8 = nullptr; + if (quantized_b_) { + b_fp8 = static_cast(quantized_b_.get()); + } else if (b_is_fp8) { + b_fp8 = static_cast(b->DataRaw()); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires runtime B input to be FP8. Non-FP8 B is only supported " + "when B is a constant initializer that can be quantized during prepack."); + } + + const size_t num_gemms = helper.OutputOffsets().size(); + ORT_RETURN_IF(K % block_size_k_ != 0, + "DynamicQuantMatMulFp8 requires K to be divisible by block_size_k."); + const size_t expected_blocks_k = K / block_size_k_; + const size_t blocks_m = M; + const size_t blocks_k = expected_blocks_k; + + const bool uses_model_b_scale = b_scales_ == nullptr; + ORT_RETURN_IF(uses_model_b_scale && b_scale == nullptr, + "DynamicQuantMatMulFp8 requires B scale when B is already FP8."); + ORT_RETURN_IF(uses_model_b_scale && b_scale->Shape().NumDimensions() != 2, + "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor."); + ORT_RETURN_IF(N % block_size_n_ != 0, + "DynamicQuantMatMulFp8 requires N to be divisible by block_size_n."); + const size_t blocks_n = N / block_size_n_; + ORT_RETURN_IF(blocks_n == 0, "DynamicQuantMatMulFp8 requires non-zero B scale N dimension."); + ORT_RETURN_IF(uses_model_b_scale && static_cast(b_scale->Shape()[0]) != blocks_n, + "DynamicQuantMatMulFp8 requires B scale N dimension to be N / block_size_n."); + ORT_RETURN_IF(uses_model_b_scale && static_cast(b_scale->Shape()[1]) != blocks_k, + "DynamicQuantMatMulFp8 requires B scale K dimension to be K / block_size_k."); + + const size_t a_scale_batch_stride = SafeMul(blocks_m, blocks_k); + const size_t b_zp_count = SafeMul(blocks_k, blocks_n); + + if (uses_model_b_scale && b_zero_point != nullptr) { + ORT_RETURN_IF(b_zero_point->Shape().NumDimensions() != 2, + "DynamicQuantMatMulFp8 requires B zero point to be a 2D tensor."); + ORT_RETURN_IF(b_zero_point->Shape()[0] != static_cast(blocks_n) || + b_zero_point->Shape()[1] != static_cast(blocks_k), + "DynamicQuantMatMulFp8 requires B zero point to have shape [N / block_size_n, K / block_size_k]."); + if (!constant_b_zero_point_values_validated_) { + ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*b_zero_point, b_zp_count, "B zero point")); + } + } + + AllocatorPtr temp_allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&temp_allocator)); + + const float* b_scales = nullptr; + IAllocatorUniquePtr b_scale_float; + size_t b_scale_elems = 0; + if (b_scales_) { + b_scales = static_cast(b_scales_.get()); + b_scale_elems = b_scale_count_; + } else if (b_scale->IsDataType()) { + b_scales = b_scale->Data(); + b_scale_elems = static_cast(b_scale->Shape().Size()); + } else { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + b_scale_elems = static_cast(b_scale->Shape().Size()); + b_scale_float = IAllocator::MakeUniquePtr(allocator, b_scale_elems, true); + if (b_scale->IsDataType()) { + for (size_t i = 0; i < b_scale_elems; ++i) { + b_scale_float.get()[i] = static_cast(b_scale->Data()[i]); + } + } else if (b_scale->IsDataType()) { + for (size_t i = 0; i < b_scale_elems; ++i) { + b_scale_float.get()[i] = static_cast(b_scale->Data()[i]); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires B scale input to be float, float16, or bfloat16."); + } + b_scales = b_scale_float.get(); + } + + // MLAS FP8 GEMM accumulates and stores float output. Use scratch for lower-precision Y, + // then convert once after all batched GEMMs complete. + IAllocatorUniquePtr y_float_buffer; + float* y_float_data = nullptr; + if (y->IsDataType()) { + y_float_data = y->MutableData(); + } else if (y->IsDataType() || y->IsDataType()) { + y_float_buffer = IAllocator::MakeUniquePtr(temp_allocator, y_size, true); + y_float_data = y_float_buffer.get(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16."); + } + + Fp8GemmShapeParams gemm_shape; + gemm_shape.M = M; + gemm_shape.N = N; + gemm_shape.K = K; + gemm_shape.Fp8Type = b_type; + + if (uses_model_b_scale && !constant_b_scale_values_validated_) { + for (size_t i = 0; i < b_scale_elems; ++i) { + ORT_RETURN_IF(!std::isfinite(b_scales[i]) || b_scales[i] <= 0.0f, + "DynamicQuantMatMulFp8 requires B scale values to be finite and positive."); + } + } + + const size_t a_fp8_size = SafeMul(M, K); + const size_t a_num_elements = static_cast(a->Shape().Size()); + ORT_RETURN_IF(a_num_elements % a_fp8_size != 0, + "DynamicQuantMatMulFp8 requires A to contain complete MxK matrices."); + const size_t a_batch_count = a_num_elements / a_fp8_size; + + // Quantize the physical A tensor once. Broadcasted output GEMMs then reuse the same FP8 A slice. + auto a_fp8_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_num_elements, true); + const size_t a_scale_count = SafeMul(a_batch_count, a_scale_batch_stride); + auto a_scale_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_scale_count, true); + const size_t a_quant_work_items = SafeMul(a_batch_count, a_scale_batch_stride); + ORT_RETURN_IF(a_quant_work_items > static_cast(std::numeric_limits::max()), + "DynamicQuantMatMulFp8 A quantization work item count exceeds ptrdiff_t range."); + const auto a_quant_work_items_i = static_cast(a_quant_work_items); + const size_t a_quant_block_elems = block_size_k_; + const TensorOpCost a_quant_unit_cost{ + static_cast(SafeMul(a_quant_block_elems, sizeof(float))), + static_cast(SafeMul(a_quant_block_elems, sizeof(uint8_t))), + static_cast(a_quant_block_elems) * 2.0}; + float fp8_max_abs = 0.0f; + ORT_RETURN_IF_ERROR(GetFp8MaxAbs(b_type, fp8_max_abs)); + const auto quantize_a_batches = [&](const auto* a_data) { + concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), a_quant_work_items_i, + a_quant_unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t tid = begin; tid < end; ++tid) { + const size_t work_idx = static_cast(tid); + const size_t a_batch_idx = work_idx / a_scale_batch_stride; + const size_t scale_block_idx = work_idx % a_scale_batch_stride; + const size_t row = scale_block_idx / blocks_k; + const size_t block_k = scale_block_idx % blocks_k; + const size_t a_batch_offset = a_batch_idx * a_fp8_size; + const size_t a_scale_batch_offset = a_batch_idx * a_scale_batch_stride; + QuantizeBlockwiseFp8ABlockDynamic( + a_data + a_batch_offset, + K, block_size_k_, blocks_k, + row, block_k, fp8_max_abs, b_type, + a_fp8_buffer.get() + a_batch_offset, + a_scale_buffer.get() + a_scale_batch_offset); + } + }); + }; + if (a->IsDataType()) { + quantize_a_batches(a->Data()); + } else if (a->IsDataType()) { + quantize_a_batches(a->Data()); + } else if (a->IsDataType()) { + quantize_a_batches(a->Data()); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16."); + } + + std::vector gemm_data_vec(num_gemms); + for (size_t gemm_idx = 0; gemm_idx < num_gemms; ++gemm_idx) { + const size_t a_offset = helper.LeftOffsets()[gemm_idx]; + ORT_RETURN_IF(a_offset >= a_num_elements || (a_offset % a_fp8_size) != 0, + "DynamicQuantMatMulFp8 requires A offsets to reference complete MxK matrices."); + const size_t scale_batch_index = a_offset / a_fp8_size; + ORT_RETURN_IF(scale_batch_index >= a_batch_count, + "DynamicQuantMatMulFp8 requires A offsets to reference complete MxK matrices."); + const size_t a_scale_batch_offset = SafeMul(scale_batch_index, a_scale_batch_stride); + const float* a_scales_batch = a_scale_buffer.get() + a_scale_batch_offset; + auto& gemm_data = gemm_data_vec[gemm_idx]; + gemm_data.A = a_fp8_buffer.get() + a_offset; + gemm_data.lda = K; + gemm_data.B = b_fp8 + helper.RightOffsets()[gemm_idx]; + gemm_data.ldb = N; + gemm_data.C = y_float_data + helper.OutputOffsets()[gemm_idx]; + gemm_data.ldc = N; + gemm_data.ScaleA = a_scales_batch; + gemm_data.ScaleB = b_scales; + gemm_data.ScaleY = y_scale_data; + gemm_data.BlockSizeM = 1; + gemm_data.BlockSizeK = block_size_k_; + gemm_data.BlockSizeN = block_size_n_; + gemm_data.ScaleAStrideK = 1; + gemm_data.ScaleAStrideM = blocks_k; + gemm_data.ScaleBStrideN = blocks_k; + gemm_data.ScaleBStrideK = 1; + } + + ORT_RETURN_IF_ERROR(ReferenceFp8GemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, + context->GetOperatorThreadPool())); + + if (y_float_buffer != nullptr) { + if (y->IsDataType()) { + auto* y_data = y->MutableData(); + for (size_t i = 0; i < y_size; ++i) { + y_data[i] = static_cast(y_float_data[i]); + } + } else { + auto* y_data = y->MutableData(); + for (size_t i = 0; i < y_size; ++i) { + y_data[i] = BFloat16(y_float_data[i]); + } + } + } + + return Status::OK(); +} + +#endif // !defined(DISABLE_FLOAT8_TYPES) + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 59f97c222ceb2..6894b4e7eed40 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -25,6 +25,9 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MulInteger); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QGemm); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicQuantMatMulFp8); +#endif class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAdd); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConcat); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearWhere); @@ -139,6 +142,9 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); +#if !defined(DISABLE_FLOAT8_TYPES) + fn(GetOpSchema()); +#endif fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 5a3cd86b04492..69f3ca8837291 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -946,6 +946,127 @@ ONNX_MS_OPERATOR_SET_SCHEMA( updateOutputShape(ctx, 0, {first_input_shape.dim(transA ? 1 : 0), second_input_shape.dim(transB ? 0 : 1)}); } })); + +#if !defined(DISABLE_FLOAT8_TYPES) +ONNX_MS_OPERATOR_SET_SCHEMA( + DynamicQuantMatMulFp8, 1, + OpSchema() + .SetDoc("Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from " + "float16/bfloat16/float) and dynamic runtime quantization of activations to fp8 using " + "internally computed block-wise scales. All zero-point inputs, when provided, must encode 0.0. " + "Optional trailing inputs may be omitted, but intermediate optional inputs must use an empty " + "input name to keep later input positions.") + .Input(0, "A", "Input tensor A.", "TA") + .Input(1, "B", + "Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only " + "supported when B is a constant initializer that can be quantized during prepack.", + "TB") + .Input(2, "B_scale", + "Scale of FP8 input 'B'. Must be a block-wise tensor with shape " + "(N / block_size_n, K / block_size_k). Required when B is already FP8. Ignored for non-FP8 " + "constant B, where scales are computed during prepack.", + "TS", OpSchema::Optional) + .Input(3, "B_zero_point", + "Zero point tensor for input 'B'. Must have the same shape as B_scale and all values must encode 0.0.", + "TZ", OpSchema::Optional) + .Input(4, "Y_scale", "Scale of output 'Y'. Must be a scalar when provided.", "TS", + OpSchema::Optional) + .Input(5, "Y_zero_point", + "Zero point tensor for output 'Y'. Must be a scalar encoding 0.0 when provided. " + "May be provided without Y_scale; only Y_scale changes the floating-point output values.", + "TZ", OpSchema::Optional) + .Output(0, "Y", "Output tensor of shape (..., M, N).", "TY") + .Attr("block_size_k", "Block size along K for A and B block-wise scales.", AttributeProto::INT, + static_cast(128)) + .Attr("block_size_n", "Block size along N for B block-wise scales.", AttributeProto::INT, + static_cast(128)) + .Attr("fp8_type", + "FP8 TensorProto data type used when non-FP8 constant B is dynamically quantized during prepack. " + "Defaults to FLOAT8E4M3FN.", + AttributeProto::INT, + static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN)) + .TypeConstraint("TA", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, + "Constrain input A type to float16, bfloat16, or float.") + .TypeConstraint("TB", + {"tensor(float16)", "tensor(bfloat16)", "tensor(float)", + "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", + "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"}, + "Constrain input B type to fp8, or to float16, bfloat16, or float for constant initializers.") + .TypeConstraint("TZ", {"tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"}, + "Constrain zero point types to fp8. Only zero-valued zero points are supported.") + // Scale tensors are upcast to float by the CPU kernel before calling MLAS. + .TypeConstraint("TS", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain scale types to float, float16, or bfloat16.") + .TypeConstraint("TY", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, + "Constrain output type to float16, bfloat16, or float.") + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + const int64_t block_size_k = getAttribute(ctx, "block_size_k", static_cast(128)); + const int64_t block_size_n = getAttribute(ctx, "block_size_n", static_cast(128)); + if (block_size_k <= 0 || block_size_n <= 0) { + fail_type_inference("block_size_k and block_size_n must be greater than zero."); + } + if (hasInputShape(ctx, 1)) { + auto& b_shape = getInputShape(ctx, 1); + if (b_shape.dim_size() != 2) { + fail_type_inference("B must be 2D."); + } + } + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { + ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1); + } + if (hasInputShape(ctx, 4)) { + auto shape = ctx.getInputType(4)->tensor_type().shape(); + if (shape.dim_size() != 0) { + fail_type_inference("Y scale input must be a scalar."); + } + } + if (hasInputShape(ctx, 5)) { + auto shape = ctx.getInputType(5)->tensor_type().shape(); + if (shape.dim_size() != 0) { + fail_type_inference("Y zero point input must be a scalar."); + } + } + if (hasInputShape(ctx, 1) && hasInputShape(ctx, 2)) { + auto& b_shape = getInputShape(ctx, 1); + auto& b_scale_shape = getInputShape(ctx, 2); + if (b_scale_shape.dim_size() != 2) { + fail_type_inference("B scale must be 2D."); + } + if (b_shape.dim(1).has_dim_value() && b_scale_shape.dim(0).has_dim_value()) { + const auto n = b_shape.dim(1).dim_value(); + if ((n % block_size_n) != 0 || b_scale_shape.dim(0).dim_value() != (n / block_size_n)) { + fail_type_inference("B scale first dimension must be N / block_size_n."); + } + } + if (b_shape.dim(0).has_dim_value() && b_scale_shape.dim(1).has_dim_value()) { + const auto k = b_shape.dim(0).dim_value(); + if ((k % block_size_k) != 0 || b_scale_shape.dim(1).dim_value() != (k / block_size_k)) { + fail_type_inference("B scale last dimension must be K / block_size_k."); + } + } + } + if (hasInputShape(ctx, 1) && hasInputShape(ctx, 3)) { + auto& b_shape = getInputShape(ctx, 1); + auto& b_zp_shape = getInputShape(ctx, 3); + if (b_zp_shape.dim_size() != 2) { + fail_type_inference("B zero point must be 2D."); + } + if (b_shape.dim(1).has_dim_value() && b_zp_shape.dim(0).has_dim_value()) { + const auto n = b_shape.dim(1).dim_value(); + if ((n % block_size_n) != 0 || b_zp_shape.dim(0).dim_value() != (n / block_size_n)) { + fail_type_inference("B zero point first dimension must be N / block_size_n."); + } + } + if (b_shape.dim(0).has_dim_value() && b_zp_shape.dim(1).has_dim_value()) { + const auto k = b_shape.dim(0).dim_value(); + if ((k % block_size_k) != 0 || b_zp_shape.dim(1).dim_value() != (k / block_size_k)) { + fail_type_inference("B zero point last dimension must be K / block_size_k."); + } + } + } + })); + +#endif ONNX_MS_OPERATOR_SET_SCHEMA( QAttention, 1, OpSchema() diff --git a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc new file mode 100644 index 0000000000000..b532d39ac8384 --- /dev/null +++ b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc @@ -0,0 +1,844 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "core/session/inference_session.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/test_environment.h" +#include "test/unittest_util/conversion.h" +#include "default_providers.h" + +#if !defined(DISABLE_FLOAT8_TYPES) +#include "core/common/float8.h" + +namespace onnxruntime { +namespace test { + +class DynamicQuantMatMulFp8SessionTester : public OpTester { + public: + using BaseTester::ExecuteModel; + using BaseTester::FillFeedsAndOutputNames; + using BaseTester::SetTestFunctionCalled; + using OpTester::BuildModel; + using OpTester::OpTester; +}; + +template +struct Fp8TensorProtoType; + +template <> +struct Fp8TensorProtoType { + static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN; +}; + +template <> +struct Fp8TensorProtoType { + static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ; +}; + +template <> +struct Fp8TensorProtoType { + static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; +}; + +template <> +struct Fp8TensorProtoType { + static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ; +}; + +template +float Fp8MaxAbs(); + +template <> +float Fp8MaxAbs() { + return 448.0f; +} + +template <> +float Fp8MaxAbs() { + return 448.0f; +} + +template <> +float Fp8MaxAbs() { + return 57344.0f; +} + +template <> +float Fp8MaxAbs() { + return 57344.0f; +} + +template +float QuantizeDequantize(float value, float scale) { + return Fp8T(value / scale, true).ToFloat() * scale; +} + +template +std::vector ComputeRowwiseAQuantized(const std::vector& a_data, + int64_t m, + int64_t k, + int64_t block_size_k) { + const int64_t blocks_k = k / block_size_k; + std::vector result(a_data.size()); + for (int64_t row = 0; row < m; ++row) { + for (int64_t block_k = 0; block_k < blocks_k; ++block_k) { + const int64_t k_begin = block_k * block_size_k; + const int64_t k_end = k_begin + block_size_k; + float max_abs = 0.0f; + for (int64_t kk = k_begin; kk < k_end; ++kk) { + max_abs = std::max(max_abs, std::fabs(a_data[static_cast(row * k + kk)])); + } + const float scale = max_abs == 0.0f ? 1.0f : max_abs / Fp8MaxAbs(); + for (int64_t kk = k_begin; kk < k_end; ++kk) { + const size_t index = static_cast(row * k + kk); + result[index] = QuantizeDequantize(a_data[index], scale); + } + } + } + return result; +} + +template +std::vector ComputeConstantBQuantized(const std::vector& b_data, + int64_t k, + int64_t n, + int64_t block_size_k, + int64_t block_size_n) { + const int64_t blocks_k = k / block_size_k; + const int64_t blocks_n = n / block_size_n; + std::vector result(b_data.size()); + for (int64_t block_n = 0; block_n < blocks_n; ++block_n) { + const int64_t n_begin = block_n * block_size_n; + const int64_t n_end = n_begin + block_size_n; + for (int64_t block_k = 0; block_k < blocks_k; ++block_k) { + const int64_t k_begin = block_k * block_size_k; + const int64_t k_end = k_begin + block_size_k; + float max_abs = 0.0f; + for (int64_t kk = k_begin; kk < k_end; ++kk) { + for (int64_t nn = n_begin; nn < n_end; ++nn) { + max_abs = std::max(max_abs, std::fabs(b_data[static_cast(kk * n + nn)])); + } + } + const float scale = max_abs == 0.0f ? 1.0f : max_abs / Fp8MaxAbs(); + for (int64_t kk = k_begin; kk < k_end; ++kk) { + for (int64_t nn = n_begin; nn < n_end; ++nn) { + const size_t index = static_cast(kk * n + nn); + result[index] = QuantizeDequantize(b_data[index], scale); + } + } + } + } + return result; +} + +template +std::vector ComputeRuntimeBQuantized(const std::vector& b_data, + const std::vector& b_scale, + int64_t k, + int64_t n, + int64_t block_size_k, + int64_t block_size_n) { + const int64_t blocks_k = k / block_size_k; + std::vector result(b_data.size()); + for (int64_t kk = 0; kk < k; ++kk) { + const int64_t block_k = kk / block_size_k; + for (int64_t nn = 0; nn < n; ++nn) { + const int64_t block_n = nn / block_size_n; + const size_t index = static_cast(kk * n + nn); + result[index] = b_data[index].ToFloat() * b_scale[static_cast(block_n * blocks_k + block_k)]; + } + } + return result; +} + +std::vector ComputeMatMul(const std::vector& a_data, + const std::vector& b_data, + int64_t m, + int64_t n, + int64_t k, + float y_scale = 1.0f) { + std::vector y_data(static_cast(m * n), 0.0f); + for (int64_t row = 0; row < m; ++row) { + for (int64_t col = 0; col < n; ++col) { + float sum = 0.0f; + for (int64_t kk = 0; kk < k; ++kk) { + sum += a_data[static_cast(row * k + kk)] * b_data[static_cast(kk * n + col)]; + } + y_data[static_cast(row * n + col)] = sum * y_scale; + } + } + return y_data; +} + +template +void AddZeroPoint(OpTester& test, const char* name, const std::vector& shape, size_t count, bool initializer) { + test.AddInput(name, shape, std::vector(count, Fp8T(0.0f)), initializer); +} + +template +void RunConstantBInputs() { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + const auto a_quantized = ComputeRowwiseAQuantized(a_data, M, K, 128); + const auto b_quantized = ComputeConstantBQuantized(b_data, K, N, 128, 128); + const auto y_data = ComputeMatMul(a_quantized, b_quantized, M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("fp8_type", Fp8TensorProtoType::value); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 1e-5f); + test.Run(); +} + +template +void RunRuntimeFp8BInput() { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Fp8T(-0.5f)); + std::vector b_scale{1.0f}; + const auto a_quantized = ComputeRowwiseAQuantized(a_data, M, K, 128); + const auto b_quantized = ComputeRuntimeBQuantized(b_data, b_scale, K, N, 128, 128); + const auto y_data = ComputeMatMul(a_quantized, b_quantized, M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 1e-5f); + test.Run(); +} + +void RunDynamicQuantMatMulFp8WithSharedPrepack(DynamicQuantMatMulFp8SessionTester& test, + OrtValue& shared_b, + PrepackedWeightsContainer& prepacked_weights_container, + size_t& shared_prepack_count) { + test.SetTestFunctionCalled(); + + auto& model = test.BuildModel(); + Status status = model.MainGraph().Resolve(); + ASSERT_TRUE(status.IsOK()) << status; + + std::unordered_map feeds; + std::vector output_names; + test.FillFeedsAndOutputNames(feeds, output_names); + + SessionOptions so; + status = so.AddInitializer("B", &shared_b); + ASSERT_TRUE(status.IsOK()) << status; + + InferenceSession session{so, GetEnvironment()}; + status = session.AddPrePackedWeightsContainer(&prepacked_weights_container); + ASSERT_TRUE(status.IsOK()) << status; + + status = session.RegisterExecutionProvider(DefaultCpuExecutionProvider()); + ASSERT_TRUE(status.IsOK()) << status; + + test.ExecuteModel(model, + session, + OpTester::ExpectResult::kExpectSuccess, + "", + nullptr, + feeds, + output_names, + kCpuExecutionProvider); + shared_prepack_count = session.GetSessionState().GetUsedSharedPrePackedWeightCounter(); +} + +TEST(DynamicQuantMatMulFp8, WithConstantBInputs) { + RunConstantBInputs(); +} + +TEST(DynamicQuantMatMulFp8, WithConstantBInputsE4M3FNUZ) { + RunConstantBInputs(); +} + +TEST(DynamicQuantMatMulFp8, WithConstantBInputsE5M2) { + RunConstantBInputs(); +} + +TEST(DynamicQuantMatMulFp8, WithConstantBInputsE5M2FNUZ) { + RunConstantBInputs(); +} + +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInput) { + RunRuntimeFp8BInput(); +} + +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE4M3FNUZ) { + RunRuntimeFp8BInput(); +} + +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE5M2) { + RunRuntimeFp8BInput(); +} + +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE5M2FNUZ) { + RunRuntimeFp8BInput(); +} + +TEST(DynamicQuantMatMulFp8, WithOmittedOutputQuantizationInputs) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeRuntimeBQuantized(b_data, b_scale, K, N, 128, 128), M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 1e-5f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, WithOnlyYScale) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + constexpr float YScale = 0.5f; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + std::vector y_scale{YScale}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeRuntimeBQuantized(b_data, b_scale, K, N, 128, 128), M, N, K, YScale); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); + test.AddInput("Y_scale", {}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 1e-5f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, RejectsNonZeroYZeroPoint) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + std::vector y_data(static_cast(M * N), 0.0f); + std::vector y_zp{Float8E4M3FN(1.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); + test.AddOptionalInputEdge(); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 supports symmetric quantization only; Y zero point values must be zero."); +} + +TEST(DynamicQuantMatMulFp8, WithRuntimeBInputsBf16Scales) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale_float{1.0f}; + std::vector b_scale = MakeBFloat16({1.0f}); + std::vector y_scale = MakeBFloat16({1.0f}); + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeRuntimeBQuantized(b_data, b_scale_float, K, N, 128, 128), M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); + test.AddInput("Y_scale", {}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 1e-5f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, Float16Output) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeConstantBQuantized(b_data, K, N, 128, 128), M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, FloatsToMLFloat16s(a_data)); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, FloatsToMLFloat16s(y_data)); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, BFloat16Output) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeConstantBQuantized(b_data, K, N, 128, 128), M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, FloatsToBFloat16s(a_data)); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, FloatsToBFloat16s(y_data)); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, RejectsNonConstantB) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires runtime B input to be FP8."); +} + +TEST(DynamicQuantMatMulFp8, RejectsRuntimeFp8BZeroPointTypeMismatch) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + std::vector b_zp{Float8E5M2(0.0f)}; + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddInput("B_zero_point", {N / 128, K / 128}, b_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); +} + +TEST(DynamicQuantMatMulFp8, RejectsConstantFp8BZeroPointTypeMismatch) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + std::vector b_zp{Float8E5M2(0.0f)}; + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddInput("B_zero_point", {N / 128, K / 128}, b_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); +} + +TEST(DynamicQuantMatMulFp8, RejectsNonZeroBZeroPoint) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + std::vector b_zp{Float8E4M3FN(1.0f)}; + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddInput("B_zero_point", {N / 128, K / 128}, b_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 supports symmetric quantization only; B zero point values must be zero."); +} + +TEST(DynamicQuantMatMulFp8, NonDefaultBlockSizesWithPartialM) { + constexpr int64_t M = 9; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + constexpr int64_t BlockK = 2; + constexpr int64_t BlockN = 2; + + std::vector a_data(static_cast(M * K), 0.0f); + for (int64_t m = 0; m < M; ++m) { + for (int64_t k = 0; k < K; ++k) { + a_data[static_cast(m * K + k)] = static_cast((m + 1) * (k + 1)) / 16.0f; + } + } + + std::vector b_data(static_cast(K * N), Float8E4M3FN(0.0f)); + for (int64_t k = 0; k < K; ++k) { + for (int64_t n = 0; n < N; ++n) { + b_data[static_cast(k * N + n)] = Float8E4M3FN(k == n ? 1.0f : 0.0f); + } + } + std::vector b_scale{1.0f, 2.0f, + 3.0f, 4.0f}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, BlockK), + ComputeRuntimeBQuantized(b_data, b_scale, K, N, BlockK, BlockN), M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("block_size_k", BlockK); + test.AddAttribute("block_size_n", BlockN); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / BlockN, K / BlockK}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / BlockN, K / BlockK}, b_scale.size(), false); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 1e-5f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsRestoresPackedBMetadata) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + constexpr int64_t BlockSize = 2; + + std::vector a_data{ + 1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 0.5f, 1.0f, -0.5f, -1.0f, + 4.0f, 3.0f, 2.0f, 1.0f}; + std::vector b_data{ + 2.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 2.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 2.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 2.0f}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, BlockSize), + ComputeConstantBQuantized(b_data, K, N, BlockSize, BlockSize), + M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("fp8_type", Fp8TensorProtoType::value); + test.AddAttribute("block_size_k", BlockSize); + test.AddAttribute("block_size_n", BlockSize); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 1e-5f); + + OrtValue b; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({K, N}), b_data.data(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + SessionOptions so; + ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); + + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t prepack_count_session_1 = 0; + size_t shared_prepack_count = 0; + test.Config(so).ConfigEps(cpu_ep()).RunWithConfig(&prepack_count_session_1, &shared_prepack_count); + ASSERT_EQ(shared_prepack_count, static_cast(0)); + ASSERT_EQ(test.GetNumPrePackedWeightsShared(), prepack_count_session_1); + ASSERT_GT(prepack_count_session_1, static_cast(0)); + + size_t prepack_count_session_2 = 0; + test.Config(so).ConfigEps(cpu_ep()).RunWithConfig(&prepack_count_session_2, &shared_prepack_count); + ASSERT_EQ(prepack_count_session_2, prepack_count_session_1); + ASSERT_EQ(shared_prepack_count, prepack_count_session_2); +} + +TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsWithComputedBScalesReuseCorrectly) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + constexpr int64_t BlockSize = 2; + + const std::vector a_data{ + 1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}; + std::vector b_data{ + 0.31f, -0.37f, 0.83f, -1.70f, + 1.20f, -2.30f, 3.40f, -4.50f, + 5.50f, -6.25f, 7.75f, -8.50f, + 9.00f, -10.50f, 11.25f, -12.75f}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, BlockSize), + ComputeConstantBQuantized(b_data, K, N, BlockSize, BlockSize), + M, N, K); + + OrtValue b; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({K, N}), b_data.data(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + PrepackedWeightsContainer prepacked_weights_container; + + DynamicQuantMatMulFp8SessionTester test_1("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test_1.AddAttribute("block_size_k", BlockSize); + test_1.AddAttribute("block_size_n", BlockSize); + test_1.AddInput("A", {M, K}, a_data); + test_1.AddInput("B", {K, N}, b_data, true /*initializer*/); + test_1.AddOutput("Y", {M, N}, y_data); + test_1.SetOutputAbsErr("Y", 1e-5f); + + size_t shared_prepack_count = 0; + RunDynamicQuantMatMulFp8WithSharedPrepack(test_1, b, prepacked_weights_container, shared_prepack_count); + ASSERT_EQ(shared_prepack_count, static_cast(0)); + ASSERT_EQ(prepacked_weights_container.GetNumberOfElements(), static_cast(1)); + + DynamicQuantMatMulFp8SessionTester test_2("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test_2.AddAttribute("block_size_k", BlockSize); + test_2.AddAttribute("block_size_n", BlockSize); + test_2.AddInput("A", {M, K}, a_data); + test_2.AddInput("B", {K, N}, b_data, true /*initializer*/); + test_2.AddOutput("Y", {M, N}, y_data); + test_2.SetOutputAbsErr("Y", 1e-5f); + + RunDynamicQuantMatMulFp8WithSharedPrepack(test_2, b, prepacked_weights_container, shared_prepack_count); + ASSERT_GT(shared_prepack_count, static_cast(0)); + ASSERT_EQ(prepacked_weights_container.GetNumberOfElements(), static_cast(1)); +} + +TEST(DynamicQuantMatMulFp8, RejectsMalformedBScaleShapeBeforeReadingScaleData) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(0.5f)); + std::vector b_scale = FloatsToMLFloat16s({1.0f}); + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("block_size_k", 4); + test.AddAttribute("block_size_n", 4); + test.AddShapeToTensorData(false); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {1}, b_scale); + AddZeroPoint(test, "B_zero_point", {1, 1}, 1, false); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor."); +} + +TEST(DynamicQuantMatMulFp8, RejectsRuntimeFp8BWithoutBScale) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(0.5f)); + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("block_size_k", 4); + test.AddAttribute("block_size_n", 4); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires B scale when B is already FP8."); +} + +TEST(DynamicQuantMatMulFp8, ZeroMInput) { + constexpr int64_t M = 0; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data{}; + std::vector b_data(static_cast(K * N), 0.0f); + std::vector y_data{}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, y_data); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, ZeroKInput) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 0; + + std::vector a_data{}; + std::vector b_data{}; + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, y_data); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, ZeroKInputRejectsInvalidYScaleShape) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 0; + + std::vector a_data{}; + std::vector b_data{}; + std::vector y_data(static_cast(M * N), 0.0f); + std::vector y_scale{1.0f}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddShapeToTensorData(false); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddInput("Y_scale", {1}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires Y scale input to be a scalar."); +} + +TEST(DynamicQuantMatMulFp8, ZeroKInputRejectsInvalidYScaleValue) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 0; + + std::vector a_data{}; + std::vector b_data{}; + std::vector y_data(static_cast(M * N), 0.0f); + std::vector y_scale{0.0f}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddInput("Y_scale", {}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Y scale values to be finite and positive."); +} + +TEST(DynamicQuantMatMulFp8, ZeroKInputRejectsInvalidYScaleType) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 0; + + std::vector a_data{}; + std::vector b_data{}; + std::vector y_data(static_cast(M * N), 0.0f); + std::vector y_scale{1}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddInput("Y_scale", {}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, "Type Error"); +} + +TEST(DynamicQuantMatMulFp8, ZeroNInput) { + constexpr int64_t M = 128; + constexpr int64_t N = 0; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.0f); + std::vector b_data{}; + std::vector y_data{}; + std::vector b_scale{}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddOutput("Y", {M, N}, y_data); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, ZeroNInputRejectsInvalidYScaleValue) { + constexpr int64_t M = 128; + constexpr int64_t N = 0; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.0f); + std::vector b_data{}; + std::vector y_data{}; + std::vector b_scale{}; + std::vector y_scale{0.0f}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddOptionalInputEdge(); + test.AddInput("Y_scale", {}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Y scale values to be finite and positive."); +} + +TEST(DynamicQuantMatMulFp8, ZeroNInputWithConstantNonFp8B) { + constexpr int64_t M = 128; + constexpr int64_t N = 0; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.0f); + std::vector b_data{}; + std::vector y_data{}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, y_data); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime +#endif // !defined(DISABLE_FLOAT8_TYPES) From deebf48d7a6c78c183a4a16c09e96bc3ad5f7a0f Mon Sep 17 00:00:00 2001 From: melkap01 Date: Wed, 27 May 2026 17:23:22 +0100 Subject: [PATCH 2/5] removing 2 mlas references from internal implementation Signed-off-by: melkap01 --- .../contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc | 2 +- onnxruntime/core/graph/contrib_ops/quantization_defs.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc index 4be1a7e73ecd2..93cd2543aa0d4 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -826,7 +826,7 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { b_scales = b_scale_float.get(); } - // MLAS FP8 GEMM accumulates and stores float output. Use scratch for lower-precision Y, + // The internal FP8 GEMM helper accumulates and stores float output. Use scratch for lower-precision Y, // then convert once after all batched GEMMs complete. IAllocatorUniquePtr y_float_buffer; float* y_float_data = nullptr; diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 69f3ca8837291..b3abcb800eb93 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -994,7 +994,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Constrain input B type to fp8, or to float16, bfloat16, or float for constant initializers.") .TypeConstraint("TZ", {"tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"}, "Constrain zero point types to fp8. Only zero-valued zero points are supported.") - // Scale tensors are upcast to float by the CPU kernel before calling MLAS. + // Scale tensors are upcast to float by the CPU kernel before compute. .TypeConstraint("TS", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain scale types to float, float16, or bfloat16.") .TypeConstraint("TY", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, From 1e7951ff51ab65fda45c7f5b054755a8179216f3 Mon Sep 17 00:00:00 2001 From: melkap01 Date: Wed, 27 May 2026 17:37:38 +0100 Subject: [PATCH 3/5] missing header added Signed-off-by: melkap01 --- .../quantization/dynamic_quant_matmul_fp8.h | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h new file mode 100644 index 0000000000000..02fbf142fd9c8 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/prepacked_weights.h" +#include "core/graph/onnx_protobuf.h" + +namespace onnxruntime { +namespace contrib { + +enum class Fp8Mode : int { + E4M3Inf = 0, + E4M3Sat, + E5M2Inf, + E5M2Sat, + End, +}; + +class DynamicQuantMatMulFp8 final : public OpKernel { + public: + DynamicQuantMatMulFp8(const OpKernelInfo& info); + + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status Compute(OpKernelContext* context) const override; + + enum InputTensors : int { + IN_A = 0, + IN_B = 1, + IN_B_SCALE = 2, + IN_B_ZERO_POINT = 3, + IN_Y_SCALE = 4, + IN_Y_ZERO_POINT = 5 + }; + + enum OutputTensors : int { OUT_Y = 0 }; + + static Status GetFp8Type(const Tensor& tensor, Fp8Mode& out_type); + static Status GetFp8Type(ONNX_NAMESPACE::TensorProto_DataType elem_type, Fp8Mode& out_type); + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + /*out*/ bool& used_shared_buffers) override; + + private: + static constexpr int GetBIdx() { return IN_B; } + IAllocatorUniquePtr quantized_b_; + size_t quantized_b_size_{0}; + IAllocatorUniquePtr b_scales_; + size_t b_scale_count_{0}; + bool constant_b_scale_values_validated_{false}; + bool constant_b_zero_point_values_validated_{false}; + TensorShape b_shape_; + Fp8Mode b_type_{Fp8Mode::E4M3Inf}; + bool has_b_type_{false}; + Fp8Mode fp8_type_{Fp8Mode::E4M3Inf}; + size_t block_size_k_{128}; + size_t block_size_n_{128}; +}; + +} // namespace contrib +} // namespace onnxruntime From 8dabaa2947d00fdd9a9b54f7d0bd6d7b305ac1ca Mon Sep 17 00:00:00 2001 From: melkap01 Date: Wed, 27 May 2026 21:31:51 +0100 Subject: [PATCH 4/5] github review comments addessed Signed-off-by: melkap01 --- .../quantization/dynamic_quant_matmul_fp8.cc | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc index 93cd2543aa0d4..57e4e209011db 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -2,24 +2,25 @@ // SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates // SPDX-License-Identifier: MIT -#include "dynamic_quant_matmul_fp8.h" - -#include "core/common/common.h" -#include "core/framework/op_kernel.h" -#include "core/graph/onnx_protobuf.h" -#include "core/common/float16.h" -#include "core/common/float8.h" -#include "core/common/safeint.h" -#include "core/platform/threadpool.h" -#include "core/providers/cpu/math/matmul_helper.h" +#include "contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h" #include #include #include #include #include +#include #include +#include "core/common/common.h" +#include "core/common/float16.h" +#include "core/common/float8.h" +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "core/graph/onnx_protobuf.h" +#include "core/platform/threadpool.h" +#include "core/providers/cpu/math/matmul_helper.h" + namespace onnxruntime { namespace contrib { @@ -936,7 +937,7 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { } ORT_RETURN_IF_ERROR(ReferenceFp8GemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, - context->GetOperatorThreadPool())); + context->GetOperatorThreadPool())); if (y_float_buffer != nullptr) { if (y->IsDataType()) { From bc6f9b3483e5df124952b494a1cf5dcd95551c26 Mon Sep 17 00:00:00 2001 From: melkap01 Date: Fri, 29 May 2026 17:09:29 +0100 Subject: [PATCH 5/5] review comments addressed Signed-off-by: melkap01 --- .../quantization/dynamic_quant_matmul_fp8.cc | 43 ++++++++----------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc index 57e4e209011db..f6de3c24bb76a 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -98,35 +98,31 @@ Status RestorePackedBMetadata(const void* metadata_buffer, } // Reject invalid scales before quantization divides by them or the GEMM dequantizes with them. -Status ValidatePositiveFiniteScaleTensor(const Tensor& scale, const char* scale_name) { +template +Status ValidatePositiveFiniteScaleTensorImpl(const Tensor& scale, const char* scale_name) { + const auto* data = scale.Data(); const size_t count = static_cast(scale.Shape().Size()); + + for (size_t i = 0; i < count; ++i) { + const float value = static_cast(data[i]); + ORT_RETURN_IF(!std::isfinite(value) || value <= 0.0f, + "DynamicQuantMatMulFp8 requires ", scale_name, " values to be finite and positive."); + } + + return Status::OK(); +} + +Status ValidatePositiveFiniteScaleTensor(const Tensor& scale, const char* scale_name) { if (scale.IsDataType()) { - const auto* data = scale.Data(); - for (size_t i = 0; i < count; ++i) { - ORT_RETURN_IF(!std::isfinite(data[i]) || data[i] <= 0.0f, - "DynamicQuantMatMulFp8 requires ", scale_name, " values to be finite and positive."); - } - return Status::OK(); + return ValidatePositiveFiniteScaleTensorImpl(scale, scale_name); } if (scale.IsDataType()) { - const auto* data = scale.Data(); - for (size_t i = 0; i < count; ++i) { - const float value = static_cast(data[i]); - ORT_RETURN_IF(!std::isfinite(value) || value <= 0.0f, - "DynamicQuantMatMulFp8 requires ", scale_name, " values to be finite and positive."); - } - return Status::OK(); + return ValidatePositiveFiniteScaleTensorImpl(scale, scale_name); } if (scale.IsDataType()) { - const auto* data = scale.Data(); - for (size_t i = 0; i < count; ++i) { - const float value = static_cast(data[i]); - ORT_RETURN_IF(!std::isfinite(value) || value <= 0.0f, - "DynamicQuantMatMulFp8 requires ", scale_name, " values to be finite and positive."); - } - return Status::OK(); + return ValidatePositiveFiniteScaleTensorImpl(scale, scale_name); } return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -548,10 +544,7 @@ Status DynamicQuantMatMulFp8::PrePack(const Tensor& tensor, int input_idx, Alloc b_type_ = fp8_type_; has_b_type_ = true; - if (K == 0) { - return Status::OK(); - } - if (N == 0) { + if (K == 0 || N == 0) { return Status::OK(); }