From 536fb3f4093e577530f4047bbafc3ffb1a706450 Mon Sep 17 00:00:00 2001 From: melkap01 Date: Fri, 8 May 2026 16:09:35 +0100 Subject: [PATCH 01/11] Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback Adds DynamicQuantMatMulFp8 schema under the Microsoft contrib opset. Registers the CPU contrib kernel when FP8 types are enabled. Adds dynamic_quant_matmul_fp8.{h,cc} CPU kernel implementation. Adds MLAS FP8 GEMM API surface and scalar fallback implementation in qgemm_fp8.cpp. Wires the MLAS FP8 source into the MLAS build. Adds provider tests for the FP8 op-kernel path. Signed-off-by: melkap01 --- cmake/onnxruntime_mlas.cmake | 1 + .../contrib_ops/cpu/cpu_contrib_kernels.cc | 6 + .../quantization/dynamic_quant_matmul_fp8.cc | 870 +++++++++++++++ .../quantization/dynamic_quant_matmul_fp8.h | 59 ++ onnxruntime/core/graph/contrib_ops/ms_opset.h | 6 + .../graph/contrib_ops/quantization_defs.cc | 210 ++++ onnxruntime/core/mlas/inc/mlas.h | 75 ++ .../core/mlas/lib/kleidiai/mlasi_kleidiai.h | 36 +- .../mlas/lib/kleidiai/sbgemm_kleidiai.cpp | 6 +- .../core/mlas/lib/kleidiai/sgemm_kleidiai.cpp | 6 +- onnxruntime/core/mlas/lib/mlasi.h | 27 + onnxruntime/core/mlas/lib/qgemm_fp8.cpp | 219 ++++ .../dynamic_quant_matmul_fp8_test.cc | 997 ++++++++++++++++++ .../test/mlas/unittest/test_qgemm_fp8.cpp | 193 ++++ 14 files changed, 2670 insertions(+), 41 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc create mode 100644 onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h create mode 100644 onnxruntime/core/mlas/lib/qgemm_fp8.cpp create mode 100644 onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc create mode 100644 onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 0233254ad50ad..bd4f34db880ca 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -21,6 +21,7 @@ 
onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sgemm.cpp ${MLAS_SRC_DIR}/halfgemm.cpp ${MLAS_SRC_DIR}/qgemm.cpp + ${MLAS_SRC_DIR}/qgemm_fp8.cpp ${MLAS_SRC_DIR}/qdwconv.cpp ${MLAS_SRC_DIR}/convolve.cpp ${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_greater_than_1.cpp diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index cc652ed52ee72..799b5bbddae51 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -113,6 +113,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, NhwcMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, NhwcMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicQuantMatMulFp8); +#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE); @@ -281,6 +284,9 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, +#if !defined(DISABLE_FLOAT8_TYPES) + BuildKernelCreateInfo, +#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc new file mode 100644 index 0000000000000..1f5ee0378ad16 --- /dev/null +++ 
b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -0,0 +1,870 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "dynamic_quant_matmul_fp8.h" + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/graph/onnx_protobuf.h" +#include "core/mlas/inc/mlas.h" +#include "core/common/float16.h" +#include "core/common/float8.h" +#include "core/common/safeint.h" +#include "core/providers/cpu/math/matmul_helper.h" + +#include +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +#if !defined(DISABLE_FLOAT8_TYPES) + +namespace { + +constexpr int64_t kDefaultBlockSize = 128; +constexpr int64_t kPackedBMetadataVersion = 1; +constexpr size_t kPackedBMetadataElementCount = 6; +constexpr size_t kPackedBMetadataSize = kPackedBMetadataElementCount * sizeof(int64_t); + +enum PackedBMetadataIndex : size_t { + kPackedBMetadataVersionIndex = 0, + kPackedBMetadataRowsIndex, + kPackedBMetadataColsIndex, + kPackedBMetadataSizeIndex, + kPackedBMetadataFp8ModeIndex, + kPackedBMetadataHasFp8ModeIndex, +}; + +size_t CeilDiv(size_t value, size_t divisor) { + ORT_ENFORCE(divisor != 0, "CeilDiv divisor must be non-zero."); + return value == 0 ? 
0 : ((value - 1) / divisor) + 1; +} + +inline float Fp8ByteToFloat(uint8_t value, mlas_fp8_mode mode) { + switch (mode) { + case MLAS_FP8_MODE_E4M3_INF: + return Float8E4M3FN(value, Float8E4M3FN::FromBits()).ToFloat(); + case MLAS_FP8_MODE_E4M3_SAT: + return Float8E4M3FNUZ(value, Float8E4M3FNUZ::FromBits()).ToFloat(); + case MLAS_FP8_MODE_E5M2_INF: + return Float8E5M2(value, Float8E5M2::FromBits()).ToFloat(); + case MLAS_FP8_MODE_E5M2_SAT: + return Float8E5M2FNUZ(value, Float8E5M2FNUZ::FromBits()).ToFloat(); + default: + ORT_THROW("Unsupported FP8 mode."); + } +} + +inline uint8_t FloatToFp8Byte(float value, mlas_fp8_mode mode) { + switch (mode) { + case MLAS_FP8_MODE_E4M3_INF: { + const Float8E4M3FN fp8(value, true); + return fp8.val; + } + case MLAS_FP8_MODE_E4M3_SAT: { + const Float8E4M3FNUZ fp8(value, true); + return fp8.val; + } + case MLAS_FP8_MODE_E5M2_INF: { + const Float8E5M2 fp8(value, true); + return fp8.val; + } + case MLAS_FP8_MODE_E5M2_SAT: { + const Float8E5M2FNUZ fp8(value, true); + return fp8.val; + } + default: + ORT_THROW("Unsupported fp8 mode for DynamicQuantMatMulFp8."); + } +} + +bool IsFp8DataType(ONNX_NAMESPACE::TensorProto_DataType elem_type) { + return elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || + elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ || + elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2 || + elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ; +} + +bool IsValidFp8Mode(int64_t mode) { + return mode >= static_cast(MLAS_FP8_MODE_E4M3_INF) && + mode < static_cast(MLAS_FP8_MODE_END); +} + +Status RestorePackedBMetadata(const void* metadata_buffer, + size_t metadata_size, + size_t quantized_b_buffer_size, + TensorShape& b_shape, + size_t& quantized_b_size, + mlas_fp8_mode& b_type, + bool& has_b_type) { + ORT_RETURN_IF(metadata_buffer == nullptr, + "DynamicQuantMatMulFp8 requires shared prepacked B metadata."); + ORT_RETURN_IF(metadata_size != kPackedBMetadataSize, + 
"DynamicQuantMatMulFp8 shared prepacked B metadata has an unexpected size."); + + const auto* metadata = static_cast(metadata_buffer); + ORT_RETURN_IF(metadata[kPackedBMetadataVersionIndex] != kPackedBMetadataVersion, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an unsupported version."); + ORT_RETURN_IF(metadata[kPackedBMetadataRowsIndex] <= 0 || metadata[kPackedBMetadataColsIndex] <= 0, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B shape."); + ORT_RETURN_IF(metadata[kPackedBMetadataSizeIndex] <= 0, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B buffer size."); + ORT_RETURN_IF(metadata[kPackedBMetadataHasFp8ModeIndex] != 1 || + !IsValidFp8Mode(metadata[kPackedBMetadataFp8ModeIndex]), + "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid FP8 type."); + + const size_t rows = static_cast(metadata[kPackedBMetadataRowsIndex]); + const size_t cols = static_cast(metadata[kPackedBMetadataColsIndex]); + const size_t expected_quantized_b_size = SafeMul(rows, cols); + const size_t restored_quantized_b_size = static_cast(metadata[kPackedBMetadataSizeIndex]); + ORT_RETURN_IF(restored_quantized_b_size != expected_quantized_b_size || + restored_quantized_b_size != quantized_b_buffer_size, + "DynamicQuantMatMulFp8 shared prepacked B metadata does not match the B buffer size."); + + b_shape = TensorShape({metadata[kPackedBMetadataRowsIndex], metadata[kPackedBMetadataColsIndex]}); + quantized_b_size = restored_quantized_b_size; + b_type = static_cast(metadata[kPackedBMetadataFp8ModeIndex]); + has_b_type = true; + return Status::OK(); +} + +// Reject invalid scales before quantization divides by them or MLAS dequantizes with them. 
+Status ValidatePositiveFiniteScales(const float* scales, size_t count, const char* scale_name) { + for (size_t i = 0; i < count; ++i) { + ORT_RETURN_IF(!std::isfinite(scales[i]) || scales[i] <= 0.0f, + "DynamicQuantMatMulFp8 requires ", scale_name, " values to be finite and positive."); + } + return Status::OK(); +} + +Status ValidateZeroPointValuesAreZero(const Tensor& zero_point, size_t expected_count, + const char* zero_point_name) { + const size_t actual_count = static_cast(zero_point.Shape().Size()); + ORT_RETURN_IF(actual_count != expected_count, + "DynamicQuantMatMulFp8 requires ", zero_point_name, " to have the expected number of elements."); + + const auto reject_non_zero = [zero_point_name](float value) { + ORT_RETURN_IF(value != 0.0f, + "DynamicQuantMatMulFp8 supports symmetric quantization only; ", + zero_point_name, " values must be zero."); + return Status::OK(); + }; + + if (zero_point.IsDataType()) { + const auto* zp = static_cast(zero_point.DataRaw()); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], MLAS_FP8_MODE_E4M3_INF))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = static_cast(zero_point.DataRaw()); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], MLAS_FP8_MODE_E4M3_SAT))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = static_cast(zero_point.DataRaw()); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], MLAS_FP8_MODE_E5M2_INF))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = static_cast(zero_point.DataRaw()); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], MLAS_FP8_MODE_E5M2_SAT))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = zero_point.Data(); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(zp[i])); + } + } else 
if (zero_point.IsDataType()) { + const auto* zp = zero_point.Data(); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(static_cast(zp[i]))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = zero_point.Data(); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(static_cast(zp[i]))); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported zero point type for DynamicQuantMatMulFp8."); + } + return Status::OK(); +} + +template +void QuantizeBlockwiseFp8A(const SrcT* src, + size_t M, + size_t K, + size_t block_size_m, + size_t block_size_k, + size_t blocks_k, + const float* scales, + const uint8_t* zero_points, + mlas_fp8_mode mode, + uint8_t* dst) { + // Block sizes come from op attributes; scale shapes only provide the number of blocks. + for (size_t m = 0; m < M; ++m) { + const size_t block_m = m / block_size_m; + const size_t row_offset = m * K; + for (size_t k = 0; k < K; ++k) { + const size_t block_k = k / block_size_k; + const size_t idx = block_m * blocks_k + block_k; + const float scale = scales[idx]; + const float zp = zero_points ? Fp8ByteToFloat(zero_points[idx], mode) : 0.0f; + const float value = static_cast(src[row_offset + k]); + const float quantized = (value / scale) + zp; + dst[row_offset + k] = FloatToFp8Byte(quantized, mode); + } + } +} + +template +void QuantizeBlockwiseFp8(const SrcT* src, + size_t K, + size_t N, + size_t block_size_k, + size_t block_size_n, + const float* scales, + const uint8_t* zero_points, + mlas_fp8_mode mode, + uint8_t* dst) { + // Block sizes come from op attributes; scale shapes only provide the number of blocks. 
+ const size_t blocks_n = N / block_size_n; + for (size_t k = 0; k < K; ++k) { + const size_t block_k = k / block_size_k; + const size_t row_offset = k * N; + for (size_t n = 0; n < N; ++n) { + const size_t block_n = n / block_size_n; + const size_t scale_idx = block_k * blocks_n + block_n; + const float scale = scales[scale_idx]; + const float zp = zero_points ? Fp8ByteToFloat(zero_points[scale_idx], mode) : 0.0f; + const float value = static_cast(src[row_offset + n]); + const float quantized = (value / scale) + zp; + const Fp8T fp8_value(quantized, true); + dst[row_offset + n] = fp8_value.val; + } + } +} + +template +Status QuantizeToFp8ByMode(mlas_fp8_mode fp8_mode, + const SrcT* src, + size_t K, + size_t N, + size_t block_size_k, + size_t block_size_n, + const float* scales, + const uint8_t* zero_points, + uint8_t* dst) { + // Dispatch quantization using the requested FP8 mode and runtime block sizes. + switch (fp8_mode) { + case MLAS_FP8_MODE_E4M3_INF: + QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, + zero_points, fp8_mode, dst); + return Status::OK(); + case MLAS_FP8_MODE_E4M3_SAT: + QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, + zero_points, fp8_mode, dst); + return Status::OK(); + case MLAS_FP8_MODE_E5M2_INF: + QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, zero_points, + fp8_mode, dst); + return Status::OK(); + case MLAS_FP8_MODE_E5M2_SAT: + QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, + zero_points, fp8_mode, dst); + return Status::OK(); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported fp8 mode for DynamicQuantMatMulFp8."); + } +} + +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + DynamicQuantMatMulFp8, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("TA", std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("TB", 
std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("TZ", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("TS", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("TY", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + DynamicQuantMatMulFp8); + +DynamicQuantMatMulFp8::DynamicQuantMatMulFp8(const OpKernelInfo& info) : OpKernel(info) { + const int64_t block_size_m = info.GetAttrOrDefault("block_size_m", kDefaultBlockSize); + const int64_t block_size_k = info.GetAttrOrDefault("block_size_k", kDefaultBlockSize); + const int64_t block_size_n = info.GetAttrOrDefault("block_size_n", kDefaultBlockSize); + ORT_ENFORCE(block_size_m > 0, + "DynamicQuantMatMulFp8 requires block_size_m to be greater than zero."); + ORT_ENFORCE(block_size_k > 0, + "DynamicQuantMatMulFp8 requires block_size_k to be greater than zero."); + ORT_ENFORCE(block_size_n > 0, + "DynamicQuantMatMulFp8 requires block_size_n to be greater than zero."); + block_size_m_ = static_cast(block_size_m); + block_size_k_ = static_cast(block_size_k); + block_size_n_ = static_cast(block_size_n); +} + +Status DynamicQuantMatMulFp8::GetFp8Type(const Tensor& tensor, mlas_fp8_mode& out_type) { + return GetFp8Type(static_cast(tensor.GetElementType()), out_type); +} + +Status DynamicQuantMatMulFp8::GetFp8Type(ONNX_NAMESPACE::TensorProto_DataType elem_type, + mlas_fp8_mode& out_type) { + switch (elem_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + out_type = MLAS_FP8_MODE_E4M3_INF; + return Status::OK(); + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ: + out_type = 
MLAS_FP8_MODE_E4M3_SAT; + return Status::OK(); + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + out_type = MLAS_FP8_MODE_E5M2_INF; + return Status::OK(); + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: + out_type = MLAS_FP8_MODE_E5M2_SAT; + return Status::OK(); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported fp8 type for DynamicQuantMatMulFp8."); + } +} + +Status DynamicQuantMatMulFp8::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + if (input_idx != GetBIdx()) { + return Status::OK(); + } + // Only prepack if B scale and zero point are constant initializers. + const OrtValue* b_scale_ort = nullptr; + const OrtValue* b_zp_ort = nullptr; + const bool has_b_scale = Info().TryGetConstantInput(IN_B_SCALE, &b_scale_ort); + const bool has_b_zp = Info().TryGetConstantInput(IN_B_ZERO_POINT, &b_zp_ort); + if (!has_b_scale || !has_b_zp) { + const auto b_elem_type = static_cast(tensor.GetElementType()); + ORT_RETURN_IF(!IsFp8DataType(b_elem_type), + "DynamicQuantMatMulFp8 requires B scale and B zero point to be constant initializers when B " + "is not FP8."); + return Status::OK(); + } + + b_shape_ = tensor.Shape(); + if (b_shape_.NumDimensions() != 2) { + return Status::OK(); + } + + const size_t K = static_cast(b_shape_[0]); + const size_t N = static_cast(b_shape_[1]); + if (K == 0 || N == 0) { + return Status::OK(); + } + + const auto& b_scale = b_scale_ort->Get(); + const auto& b_zp = b_zp_ort->Get(); + + const auto b_elem_type = static_cast(tensor.GetElementType()); + const auto b_zp_elem_type = static_cast(b_zp.GetElementType()); + const bool b_is_fp8 = IsFp8DataType(b_elem_type); + const bool zp_is_fp8 = IsFp8DataType(b_zp_elem_type); + mlas_fp8_mode b_type{}; + if (b_is_fp8) { + ORT_RETURN_IF_ERROR(GetFp8Type(tensor, b_type)); + if (zp_is_fp8) { + mlas_fp8_mode b_zp_type{}; + 
ORT_RETURN_IF_ERROR(GetFp8Type(b_zp, b_zp_type)); + ORT_RETURN_IF(b_type != b_zp_type, + "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); + } + } else if (zp_is_fp8) { + ORT_RETURN_IF_ERROR(GetFp8Type(b_zp, b_type)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires fp8 zero points when B is not fp8."); + } + b_type_ = b_type; + has_b_type_ = true; + + ORT_RETURN_IF_NOT(b_scale.IsDataType() || b_scale.IsDataType() || b_scale.IsDataType(), + "DynamicQuantMatMulFp8 requires B scale input to be float, float16, or bfloat16."); + ORT_RETURN_IF_NOT(b_scale.Shape().NumDimensions() == 2, + "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor."); + ORT_RETURN_IF_NOT(b_zp.Shape().NumDimensions() == 2, + "DynamicQuantMatMulFp8 requires B zero point to be a 2D tensor."); + ORT_RETURN_IF_NOT(b_zp.Shape()[0] == b_scale.Shape()[0] && + b_zp.Shape()[1] == b_scale.Shape()[1], + "DynamicQuantMatMulFp8 requires B scale and zero point to have the same shape."); + const size_t blocks_k = static_cast(b_scale.Shape()[0]); + const size_t blocks_n = static_cast(b_scale.Shape()[1]); + ORT_RETURN_IF_NOT(blocks_k != 0 && blocks_n != 0, + "DynamicQuantMatMulFp8 requires non-zero B scale dimensions."); + ORT_RETURN_IF_NOT(K % block_size_k_ == 0, + "DynamicQuantMatMulFp8 requires K to be divisible by block_size_k."); + ORT_RETURN_IF_NOT(N % block_size_n_ == 0, + "DynamicQuantMatMulFp8 requires N to be divisible by block_size_n."); + const size_t expected_blocks_k = K / block_size_k_; + const size_t expected_blocks_n = N / block_size_n_; + ORT_RETURN_IF_NOT(blocks_k == expected_blocks_k, + "DynamicQuantMatMulFp8 requires B scale first dimension to be K / block_size_k."); + ORT_RETURN_IF_NOT(blocks_n == expected_blocks_n, + "DynamicQuantMatMulFp8 requires B scale last dimension to be N / block_size_n."); + + const size_t b_scale_elems = static_cast(b_scale.Shape().Size()); + 
ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(b_zp, b_scale_elems, "B zero point")); + + const float* b_scales = nullptr; + IAllocatorUniquePtr b_scale_float; + if (b_scale.IsDataType()) { + b_scales = b_scale.Data(); + } else if (b_scale.IsDataType()) { + b_scale_float = IAllocator::MakeUniquePtr(alloc, b_scale_elems, true); + for (size_t i = 0; i < b_scale_elems; ++i) { + b_scale_float.get()[i] = static_cast(b_scale.Data()[i]); + } + b_scales = b_scale_float.get(); + } else { + b_scale_float = IAllocator::MakeUniquePtr(alloc, b_scale_elems, true); + for (size_t i = 0; i < b_scale_elems; ++i) { + b_scale_float.get()[i] = static_cast(b_scale.Data()[i]); + } + b_scales = b_scale_float.get(); + } + ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(b_scales, b_scale_elems, "B scale")); + // If B is not already FP8, quantize it once during prepack and reuse the cached FP8 buffer. + if (!b_is_fp8) { + const size_t quantized_b_size = SafeMul(K, N); + quantized_b_ = IAllocator::MakeUniquePtr(alloc, quantized_b_size, true); + quantized_b_size_ = quantized_b_size; + auto* quantized_b_bytes = static_cast(quantized_b_.get()); + if (tensor.IsDataType()) { + ORT_RETURN_IF_ERROR(QuantizeToFp8ByMode(b_type, tensor.Data(), K, N, block_size_k_, block_size_n_, + b_scales, nullptr, quantized_b_bytes)); + } else if (tensor.IsDataType()) { + ORT_RETURN_IF_ERROR(QuantizeToFp8ByMode(b_type, tensor.Data(), K, N, block_size_k_, block_size_n_, + b_scales, nullptr, quantized_b_bytes)); + } else if (tensor.IsDataType()) { + ORT_RETURN_IF_ERROR(QuantizeToFp8ByMode(b_type, tensor.Data(), K, N, block_size_k_, block_size_n_, + b_scales, nullptr, quantized_b_bytes)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported B type for DynamicQuantMatMulFp8 prepack."); + } + if (prepacked_weights != nullptr) { + const std::array metadata_values = { + kPackedBMetadataVersion, + b_shape_[0], + b_shape_[1], + static_cast(quantized_b_size_), + static_cast(b_type_), + 1, + }; 
+ auto metadata = IAllocator::MakeUniquePtr(alloc, kPackedBMetadataSize, true); + std::memcpy(metadata.get(), metadata_values.data(), kPackedBMetadataSize); + prepacked_weights->buffers_.push_back(std::move(quantized_b_)); + prepacked_weights->buffer_sizes_.push_back(quantized_b_size_); + prepacked_weights->buffers_.push_back(std::move(metadata)); + prepacked_weights->buffer_sizes_.push_back(kPackedBMetadataSize); + } + is_packed = true; + return Status::OK(); + } + + return Status::OK(); +} + +Status DynamicQuantMatMulFp8::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + if (input_idx != GetBIdx()) { + return Status::OK(); + } + + ORT_RETURN_IF(prepacked_buffers.size() != 2 || prepacked_buffer_sizes.size() != 2, + "DynamicQuantMatMulFp8 requires shared prepacked B data and metadata buffers."); + ORT_RETURN_IF(prepacked_buffers[0].get() == nullptr, + "DynamicQuantMatMulFp8 requires shared prepacked B data."); + + // Buffer 0 owns quantized B bytes; buffer 1 is metadata used only to restore kernel state. + ORT_RETURN_IF_ERROR(RestorePackedBMetadata(prepacked_buffers[1].get(), + prepacked_buffer_sizes[1], + prepacked_buffer_sizes[0], + b_shape_, + quantized_b_size_, + b_type_, + has_b_type_)); + quantized_b_ = std::move(prepacked_buffers[0]); + used_shared_buffers = true; + return Status::OK(); +} + +Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { + const Tensor* a = context->Input(IN_A); + const Tensor* b = quantized_b_ ? 
nullptr : context->Input(IN_B); + const Tensor* a_scale = context->Input(IN_A_SCALE); + const Tensor* a_zero_point = context->Input(IN_A_ZERO_POINT); + const Tensor* b_scale = context->Input(IN_B_SCALE); + const Tensor* b_zero_point = context->Input(IN_B_ZERO_POINT); + const Tensor* y_scale = context->Input(IN_Y_SCALE); + const Tensor* y_zero_point = context->Input(IN_Y_ZERO_POINT); + + // Runtime B uses one 2D B scale/zero-point layout, so reject batched B before MatMul broadcasts it. + ORT_RETURN_IF(!quantized_b_ && b->Shape().NumDimensions() != 2, + "DynamicQuantMatMulFp8 requires runtime B to be a 2D tensor."); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), + quantized_b_ ? b_shape_ : b->Shape(), + nullptr, + nullptr)); + + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + Tensor* y = context->Output(OUT_Y, helper.OutputShape()); + const size_t y_size = static_cast(y->Shape().Size()); + ORT_RETURN_IF(!y->IsDataType() && !y->IsDataType() && !y->IsDataType(), + "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16."); + + if (y_zero_point != nullptr) { + // Runtime tensors must match the schema scalar contract before reading element 0. + ORT_RETURN_IF(y_zero_point->Shape().NumDimensions() != 0 || y_zero_point->Shape().Size() != 1, + "DynamicQuantMatMulFp8 requires Y zero point input to be a scalar."); + ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*y_zero_point, 1, "Y zero point")); + } + + float y_scale_storage = 0.0f; + const float* y_scale_data = nullptr; + if (y_scale != nullptr) { + // Runtime tensors must match the schema scalar contract before reading element 0. 
+ ORT_RETURN_IF(y_scale->Shape().NumDimensions() != 0 || y_scale->Shape().Size() != 1, + "DynamicQuantMatMulFp8 requires Y scale input to be a scalar."); + if (y_scale->IsDataType()) { + y_scale_data = y_scale->Data(); + } else if (y_scale->IsDataType()) { + y_scale_storage = static_cast(y_scale->Data()[0]); + y_scale_data = &y_scale_storage; + } else if (y_scale->IsDataType()) { + y_scale_storage = static_cast(y_scale->Data()[0]); + y_scale_data = &y_scale_storage; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires Y scale input to be float, float16, or bfloat16."); + } + ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(y_scale_data, 1, "Y scale")); + } + + // Empty reduction does not need B data, so fill zeros before enforcing runtime FP8 B. + if (K == 0) { + if (y_size == 0) { + return Status::OK(); + } + if (y->IsDataType()) { + std::fill_n(y->MutableData(), y_size, 0.0f); + } else if (y->IsDataType()) { + std::fill_n(y->MutableData(), y_size, MLFloat16::FromBits(0)); + } else if (y->IsDataType()) { + std::fill_n(y->MutableData(), y_size, BFloat16::FromBits(0)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16."); + } + return Status::OK(); + } + + const bool a_is_supported = + a->IsDataType() || a->IsDataType() || a->IsDataType(); + ORT_RETURN_IF(!a_is_supported, "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16."); + + const auto b_elem_type = b ? static_cast(b->GetElementType()) + : static_cast(0); + const auto b_zp_elem_type = + static_cast(b_zero_point->GetElementType()); + const bool b_is_fp8 = IsFp8DataType(b_elem_type); + + // Select the FP8 B buffer: prefer pre-quantized B from PrePack, otherwise accept FP8-typed B input. 
+ const uint8_t* b_fp8 = nullptr; + if (quantized_b_) { + b_fp8 = static_cast(quantized_b_.get()); + } else if (b_is_fp8) { + b_fp8 = static_cast(b->DataRaw()); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires runtime B input to be FP8. Non-FP8 B is only supported " + "when B is a constant initializer that can be quantized during prepack."); + } + + mlas_fp8_mode a_type{}; + ORT_RETURN_IF(a_zero_point == nullptr, + "DynamicQuantMatMulFp8 requires FP8 zero point for A."); + ORT_RETURN_IF_ERROR(GetFp8Type(*a_zero_point, a_type)); + mlas_fp8_mode b_type{}; + if (has_b_type_) { + b_type = b_type_; + } else if (b_is_fp8) { + ORT_RETURN_IF_ERROR(GetFp8Type(b_elem_type, b_type)); + if (IsFp8DataType(b_zp_elem_type)) { + mlas_fp8_mode b_zp_type{}; + ORT_RETURN_IF_ERROR(GetFp8Type(b_zp_elem_type, b_zp_type)); + ORT_RETURN_IF(b_type != b_zp_type, + "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); + } + } else { + ORT_RETURN_IF_ERROR(GetFp8Type(b_zp_elem_type, b_type)); + } + ORT_RETURN_IF(a_type != b_type, + "DynamicQuantMatMulFp8 requires A/B FP8 types to match."); + + if (y_size == 0) { + return Status::OK(); + } + + const size_t num_gemms = helper.OutputOffsets().size(); + const size_t a_scale_rank = static_cast(a_scale->Shape().NumDimensions()); + const size_t a_zp_rank = static_cast(a_zero_point->Shape().NumDimensions()); + // The scale tensor layout carries the number of tiles because it physically stores one + // scale per tile. + // A scale/zero-point may be [blocks_m, blocks_k] or [prefix..., blocks_m, blocks_k]. + // The scale tensor shape provides the number of quantization blocks. It does not define + // the block size. Block sizes come from the block_size_m/block_size_k/block_size_n attributes + // so future models can choose different tile sizes without changing how scale tensors are + // interpreted. 
The validation below binds both pieces of information together by requiring: + // blocks_m == ceil(M / block_size_m) + // blocks_k == K / block_size_k + // blocks_n == N / block_size_n + // This prevents silently treating a malformed scale shape as a different runtime block size. + ORT_RETURN_IF(a_scale_rank < 2, + "DynamicQuantMatMulFp8 requires A scale to have rank >= 2."); + ORT_RETURN_IF(a_zp_rank < 2, + "DynamicQuantMatMulFp8 requires A zero point to have rank >= 2."); + ORT_RETURN_IF(a_scale_rank != a_zp_rank, + "DynamicQuantMatMulFp8 requires A scale and zero point to have the same rank."); + const size_t blocks_m = static_cast(a_scale->Shape()[a_scale_rank - 2]); + const size_t blocks_k = static_cast(a_scale->Shape()[a_scale_rank - 1]); + for (size_t dim = 0; dim < a_scale_rank; ++dim) { + ORT_RETURN_IF(a_scale->Shape()[dim] != a_zero_point->Shape()[dim], + "DynamicQuantMatMulFp8 requires A scale and zero point to have the same shape."); + } + if (a_scale_rank != 2) { + const size_t y_rank = y->Shape().NumDimensions(); + ORT_RETURN_IF(a_scale_rank != y_rank, + "DynamicQuantMatMulFp8 requires A scale rank to be 2 or match Y rank."); + for (size_t dim = 0; dim < y_rank - 2; ++dim) { + ORT_RETURN_IF(a_scale->Shape()[dim] != y->Shape()[dim], + "DynamicQuantMatMulFp8 requires A scale batch dimensions to match Y."); + } + } + // Scale tensor block counts must match the explicit block-size attributes before reading scale data. 
+ ORT_RETURN_IF(blocks_m == 0, "DynamicQuantMatMulFp8 requires non-zero A scale M dimension."); + ORT_RETURN_IF(blocks_k == 0, "DynamicQuantMatMulFp8 requires non-zero A scale K dimension."); + ORT_RETURN_IF(K % block_size_k_ != 0, + "DynamicQuantMatMulFp8 requires K to be divisible by block_size_k."); + const size_t expected_blocks_m = CeilDiv(M, block_size_m_); + const size_t expected_blocks_k = K / block_size_k_; + // If the scale tensor says it has a different number of M blocks than ceil(M / block_size_m), + // return an error instead of running with wrong scale indexing. + ORT_RETURN_IF(blocks_m != expected_blocks_m, + "DynamicQuantMatMulFp8 requires A scale M dimension to be ceil(M / block_size_m)."); + ORT_RETURN_IF(blocks_k != expected_blocks_k, + "DynamicQuantMatMulFp8 requires A scale K dimension to be K / block_size_k."); + + ORT_RETURN_IF(b_scale->Shape().NumDimensions() != 2, + "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor."); + ORT_RETURN_IF(b_zero_point->Shape().NumDimensions() != 2, + "DynamicQuantMatMulFp8 requires B zero point to be a 2D tensor."); + const size_t blocks_n = static_cast(b_scale->Shape()[1]); + ORT_RETURN_IF(b_zero_point->Shape()[0] != b_scale->Shape()[0] || + b_zero_point->Shape()[1] != b_scale->Shape()[1], + "DynamicQuantMatMulFp8 requires B scale and zero point to have the same shape."); + ORT_RETURN_IF(static_cast(b_scale->Shape()[0]) != blocks_k, + "DynamicQuantMatMulFp8 requires B scale K dimension to match A scale K dimension."); + ORT_RETURN_IF(blocks_n == 0, "DynamicQuantMatMulFp8 requires non-zero B scale N dimension."); + ORT_RETURN_IF(N % block_size_n_ != 0, + "DynamicQuantMatMulFp8 requires N to be divisible by block_size_n."); + const size_t expected_blocks_n = N / block_size_n_; + ORT_RETURN_IF(blocks_n != expected_blocks_n, + "DynamicQuantMatMulFp8 requires B scale N dimension to be N / block_size_n."); + + size_t a_scale_prefix = 1; + if (a_scale_rank > 2) { + for (size_t dim = 0; dim < a_scale_rank - 
2; ++dim) { + a_scale_prefix = SafeMul( + a_scale_prefix, static_cast(a_scale->Shape()[dim])); + } + } + ORT_RETURN_IF(a_scale_prefix != 1 && a_scale_prefix != num_gemms, + "DynamicQuantMatMulFp8 requires A scale batch dimensions to match the number of gemms."); + const size_t a_scale_batch_stride = SafeMul(blocks_m, blocks_k); + const size_t a_zp_count = SafeMul(a_scale_prefix, a_scale_batch_stride); + const size_t b_zp_count = SafeMul(blocks_k, blocks_n); + + ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*a_zero_point, a_zp_count, "A zero point")); + ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*b_zero_point, b_zp_count, "B zero point")); + + AllocatorPtr temp_allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&temp_allocator)); + + const float* a_scales = nullptr; + IAllocatorUniquePtr a_scale_float; + const size_t a_scale_elems = static_cast(a_scale->Shape().Size()); + if (a_scale->IsDataType()) { + a_scales = a_scale->Data(); + } else { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + a_scale_float = IAllocator::MakeUniquePtr(allocator, a_scale_elems, true); + if (a_scale->IsDataType()) { + for (size_t i = 0; i < a_scale_elems; ++i) { + a_scale_float.get()[i] = static_cast(a_scale->Data()[i]); + } + } else if (a_scale->IsDataType()) { + for (size_t i = 0; i < a_scale_elems; ++i) { + a_scale_float.get()[i] = static_cast(a_scale->Data()[i]); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires A scale input to be float, float16, or bfloat16."); + } + a_scales = a_scale_float.get(); + } + + const float* b_scales = nullptr; + IAllocatorUniquePtr b_scale_float; + const size_t b_scale_elems = static_cast(b_scale->Shape().Size()); + if (b_scale->IsDataType()) { + b_scales = b_scale->Data(); + } else { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + b_scale_float = IAllocator::MakeUniquePtr(allocator, 
b_scale_elems, true); + if (b_scale->IsDataType()) { + for (size_t i = 0; i < b_scale_elems; ++i) { + b_scale_float.get()[i] = static_cast(b_scale->Data()[i]); + } + } else if (b_scale->IsDataType()) { + for (size_t i = 0; i < b_scale_elems; ++i) { + b_scale_float.get()[i] = static_cast(b_scale->Data()[i]); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires B scale input to be float, float16, or bfloat16."); + } + b_scales = b_scale_float.get(); + } + + // MLAS FP8 GEMM accumulates and stores float output. Use scratch for lower-precision Y, + // then convert once after all batched GEMMs complete. + IAllocatorUniquePtr y_float_buffer; + float* y_float_data = nullptr; + if (y->IsDataType()) { + y_float_data = y->MutableData(); + } else if (y->IsDataType() || y->IsDataType()) { + y_float_buffer = IAllocator::MakeUniquePtr(temp_allocator, y_size, true); + y_float_data = y_float_buffer.get(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16."); + } + + MLAS_FP8_GEMM_SHAPE_PARAMS gemm_shape; + gemm_shape.M = M; + gemm_shape.N = N; + gemm_shape.K = K; + + std::vector gemm_data(num_gemms); + std::vector> a_fp8_buffers(num_gemms); + ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(a_scales, a_scale_elems, "A scale")); + ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(b_scales, b_scale_elems, "B scale")); + + const size_t a_fp8_size = SafeMul(M, K); + for (size_t gemm_idx = 0; gemm_idx < num_gemms; ++gemm_idx) { + auto& params = gemm_data[gemm_idx]; + const size_t a_offset = helper.LeftOffsets()[gemm_idx]; + const size_t scale_batch_index = (a_scale_prefix == 1) ? 
0 : gemm_idx; + const size_t a_scale_batch_offset = SafeMul(scale_batch_index, a_scale_batch_stride); + const float* a_scales_batch = a_scales + a_scale_batch_offset; + a_fp8_buffers[gemm_idx] = IAllocator::MakeUniquePtr(temp_allocator, a_fp8_size, true); + if (a->IsDataType()) { + QuantizeBlockwiseFp8A(a->Data() + a_offset, M, K, block_size_m_, block_size_k_, blocks_k, a_scales_batch, + nullptr, a_type, + a_fp8_buffers[gemm_idx].get()); + } else if (a->IsDataType()) { + QuantizeBlockwiseFp8A(a->Data() + a_offset, M, K, block_size_m_, block_size_k_, blocks_k, + a_scales_batch, + nullptr, a_type, + a_fp8_buffers[gemm_idx].get()); + } else if (a->IsDataType()) { + QuantizeBlockwiseFp8A(a->Data() + a_offset, M, K, block_size_m_, block_size_k_, blocks_k, + a_scales_batch, + nullptr, a_type, + a_fp8_buffers[gemm_idx].get()); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16."); + } + params.A = a_fp8_buffers[gemm_idx].get(); + params.lda = K; + params.B = b_fp8 + helper.RightOffsets()[gemm_idx]; + params.ldb = N; + params.C = y_float_data + helper.OutputOffsets()[gemm_idx]; + params.ldc = N; + params.ScaleA = a_scales_batch; + params.ScaleB = b_scales; + params.ScaleY = y_scale_data; + params.ZeroPointA = nullptr; + params.ZeroPointB = nullptr; + params.ZeroPointY = nullptr; + params.Fp8Type = a_type; + params.BlockSizeM = block_size_m_; + params.BlockSizeK = block_size_k_; + params.BlockSizeN = block_size_n_; + params.BlocksM = blocks_m; + params.BlocksK = blocks_k; + params.BlocksN = blocks_n; + params.ScaleAStrideK = 1; + params.ScaleAStrideM = blocks_k; + params.ScaleBStrideN = 1; + params.ScaleBStrideK = blocks_n; + params.ZeroPointAStrideK = 1; + params.ZeroPointAStrideM = blocks_k; + params.ZeroPointBStrideN = 1; + params.ZeroPointBStrideK = blocks_n; + } + + MlasFp8GemmBatch(gemm_shape, gemm_data.data(), num_gemms, context->GetOperatorThreadPool()); + + if (y_float_buffer != 
nullptr) { + if (y->IsDataType()) { + auto* y_data = y->MutableData(); + for (size_t i = 0; i < y_size; ++i) { + y_data[i] = static_cast(y_float_data[i]); + } + } else { + auto* y_data = y->MutableData(); + for (size_t i = 0; i < y_size; ++i) { + y_data[i] = BFloat16(y_float_data[i]); + } + } + } + + return Status::OK(); +} + +#endif // !defined(DISABLE_FLOAT8_TYPES) + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h new file mode 100644 index 0000000000000..62360e80fd180 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/prepacked_weights.h" +#include "core/graph/onnx_protobuf.h" +#include "core/mlas/inc/mlas.h" + +namespace onnxruntime { +namespace contrib { + +class DynamicQuantMatMulFp8 final : public OpKernel { + public: + DynamicQuantMatMulFp8(const OpKernelInfo& info); + + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status Compute(OpKernelContext* context) const override; + + enum InputTensors : int { + IN_A = 0, + IN_A_SCALE = 1, + IN_A_ZERO_POINT = 2, + IN_B = 3, + IN_B_SCALE = 4, + IN_B_ZERO_POINT = 5, + IN_Y_SCALE = 6, + IN_Y_ZERO_POINT = 7 + }; + + enum OutputTensors : int { OUT_Y = 0 }; + + static Status GetFp8Type(const Tensor& tensor, mlas_fp8_mode& out_type); + static Status GetFp8Type(ONNX_NAMESPACE::TensorProto_DataType elem_type, mlas_fp8_mode& out_type); + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + 
/*out*/ bool& used_shared_buffers) override; + + private: + static constexpr int GetBIdx() { return IN_B; } + IAllocatorUniquePtr quantized_b_; + size_t quantized_b_size_{0}; + TensorShape b_shape_; + mlas_fp8_mode b_type_{static_cast(0)}; + bool has_b_type_{false}; + size_t block_size_m_{128}; + size_t block_size_k_{128}; + size_t block_size_n_{128}; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 59f97c222ceb2..6894b4e7eed40 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -25,6 +25,9 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MulInteger); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QGemm); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicQuantMatMulFp8); +#endif class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAdd); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConcat); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearWhere); @@ -139,6 +142,9 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); +#if !defined(DISABLE_FLOAT8_TYPES) + fn(GetOpSchema()); +#endif fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 5a3cd86b04492..d055aa9c9333a 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -946,6 +946,216 @@ ONNX_MS_OPERATOR_SET_SCHEMA( updateOutputShape(ctx, 0, {first_input_shape.dim(transA ? 1 : 0), second_input_shape.dim(transB ? 
0 : 1)}); } })); + +#if !defined(DISABLE_FLOAT8_TYPES) +ONNX_MS_OPERATOR_SET_SCHEMA( + DynamicQuantMatMulFp8, 1, + OpSchema() + .SetDoc("Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from " + "float16/bfloat16/float) and runtime casting of activations to fp8 using tile-wise scales. " + "All zero-point inputs, when provided, must encode 0.0.") + .Input(0, "A", "Input tensor A.", "TA") + .Input(1, "a_scale", + "Scale of quantized input 'A'. Must be a tile-wise tensor with shape " + "(ceil(M / block_size_m), K / block_size_k), or the same shape with output batch dimensions prefixed.", + "TS") + .Input(2, "a_zero_point", + "Zero point tensor for input 'A'. Must have the same shape as a_scale and all values must encode 0.0.", + "TZ") + .Input(3, "B", + "Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only " + "supported when B is a constant initializer that can be quantized during prepack.", + "TB") + .Input(4, "b_scale", + "Scale of input 'B'. Must be a tile-wise tensor with shape " + "(K / block_size_k, N / block_size_n).", + "TS") + .Input(5, "b_zero_point", + "Zero point tensor for input 'B'. Must have the same shape as b_scale and all values must encode 0.0.", + "TZ") + .Input(6, "y_scale", "Scale of output 'Y'. Must be a scalar when provided.", "TS", + OpSchema::Optional) + .Input(7, "y_zero_point", + "Zero point tensor for output 'Y'. 
Must be a scalar encoding 0.0 when provided.", "TZ", + OpSchema::Optional) + .Output(0, "Y", "Output tensor of shape (..., M, N).", "TY") + .Attr("block_size_m", "Block size along M for A tile-wise scales.", AttributeProto::INT, + static_cast(128)) + .Attr("block_size_k", "Block size along K for A and B tile-wise scales.", AttributeProto::INT, + static_cast(128)) + .Attr("block_size_n", "Block size along N for B tile-wise scales.", AttributeProto::INT, + static_cast(128)) + .TypeConstraint("TA", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, + "Constrain input A type to float16, bfloat16, or float.") + .TypeConstraint("TB", + {"tensor(float16)", "tensor(bfloat16)", "tensor(float)", + "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", + "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"}, + "Constrain input B type to fp8, or to float16, bfloat16, or float for constant initializers.") + .TypeConstraint("TZ", {"tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"}, + "Constrain zero point types to fp8. Only zero-valued zero points are supported.") + // Scale tensors are upcast to float by the CPU kernel before calling MLAS. 
+ .TypeConstraint("TS", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain scale types to float, float16, or bfloat16.") + .TypeConstraint("TY", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, + "Constrain output type to float16, bfloat16, or float.") + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + const int64_t block_size_m = getAttribute(ctx, "block_size_m", static_cast(128)); + const int64_t block_size_k = getAttribute(ctx, "block_size_k", static_cast(128)); + const int64_t block_size_n = getAttribute(ctx, "block_size_n", static_cast(128)); + if (block_size_m <= 0 || block_size_k <= 0 || block_size_n <= 0) { + fail_type_inference("block_size_m, block_size_k, and block_size_n must be greater than zero."); + } + const auto ceil_div = [](int64_t value, int64_t divisor) { + return value == 0 ? int64_t{0} : ((value - 1) / divisor) + 1; + }; + + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 3)) { + ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 3); + } + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { + auto& a_shape = getInputShape(ctx, 0); + auto& a_scale_shape = getInputShape(ctx, 1); + const int a_rank = a_shape.dim_size(); + if (a_rank < 2) { + fail_type_inference("A must be at least 2D."); + } + const int a_scale_rank = a_scale_shape.dim_size(); + if (a_scale_rank < 2) { + fail_type_inference("A scale must have rank 2 or the same rank as Y."); + } + if (a_scale_rank != 2) { + if (hasInputShape(ctx, 3)) { + const auto& y_shape = ctx.getOutputType(0)->tensor_type().shape(); + const int y_rank = y_shape.dim_size(); + if (a_scale_rank != y_rank) { + fail_type_inference("A scale must have rank 2 or the same rank as Y."); + } + for (int i = 0; i < y_rank - 2; ++i) { + if (y_shape.dim(i).has_dim_value() && a_scale_shape.dim(i).has_dim_value() && + y_shape.dim(i).dim_value() != a_scale_shape.dim(i).dim_value()) { + fail_type_inference("A scale batch dimensions must match Y."); + } + } + } else if 
(a_scale_rank != a_rank) { + fail_type_inference("A scale must have rank 2 or the same rank as Y."); + } + } + if (a_shape.dim(a_rank - 2).has_dim_value() && a_scale_shape.dim(a_scale_rank - 2).has_dim_value() && + a_shape.dim(a_rank - 1).has_dim_value() && a_scale_shape.dim(a_scale_rank - 1).has_dim_value()) { + const auto m = a_shape.dim(a_rank - 2).dim_value(); + const auto k = a_shape.dim(a_rank - 1).dim_value(); + const auto m_blocks = ceil_div(m, block_size_m); + if (a_scale_shape.dim(a_scale_rank - 2).dim_value() != m_blocks) { + fail_type_inference("A scale second-to-last dimension must be ceil(M / block_size_m)."); + } + if ((k % block_size_k) != 0 || + a_scale_shape.dim(a_scale_rank - 1).dim_value() != (k / block_size_k)) { + fail_type_inference("A scale last dimension must be K / block_size_k."); + } + } + } + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { + auto& a_shape = getInputShape(ctx, 0); + auto& a_zp_shape = getInputShape(ctx, 2); + const int a_rank = a_shape.dim_size(); + if (a_rank < 2) { + fail_type_inference("A must be at least 2D."); + } + const int a_zp_rank = a_zp_shape.dim_size(); + if (a_zp_rank < 2) { + fail_type_inference("A zero point must have rank 2 or the same rank as Y."); + } + if (a_zp_rank != 2) { + if (hasInputShape(ctx, 3)) { + const auto& y_shape = ctx.getOutputType(0)->tensor_type().shape(); + const int y_rank = y_shape.dim_size(); + if (a_zp_rank != y_rank) { + fail_type_inference("A zero point must have rank 2 or the same rank as Y."); + } + for (int i = 0; i < y_rank - 2; ++i) { + if (y_shape.dim(i).has_dim_value() && a_zp_shape.dim(i).has_dim_value() && + y_shape.dim(i).dim_value() != a_zp_shape.dim(i).dim_value()) { + fail_type_inference("A zero point batch dimensions must match Y."); + } + } + } else if (a_zp_rank != a_rank) { + fail_type_inference("A zero point must have rank 2 or the same rank as Y."); + } + } + if (a_shape.dim(a_rank - 2).has_dim_value() && a_zp_shape.dim(a_zp_rank - 2).has_dim_value() && + 
a_shape.dim(a_rank - 1).has_dim_value() && a_zp_shape.dim(a_zp_rank - 1).has_dim_value()) { + const auto m = a_shape.dim(a_rank - 2).dim_value(); + const auto k = a_shape.dim(a_rank - 1).dim_value(); + const auto m_blocks = ceil_div(m, block_size_m); + if (a_zp_shape.dim(a_zp_rank - 2).dim_value() != m_blocks) { + fail_type_inference("A zero point second-to-last dimension must be ceil(M / block_size_m)."); + } + if ((k % block_size_k) != 0 || + a_zp_shape.dim(a_zp_rank - 1).dim_value() != (k / block_size_k)) { + fail_type_inference("A zero point last dimension must be K / block_size_k."); + } + } + } + if (hasInputShape(ctx, 6)) { + auto shape = ctx.getInputType(6)->tensor_type().shape(); + if (shape.dim_size() != 0) { + fail_type_inference("Y scale input must be a scalar."); + } + } + if (hasInputShape(ctx, 7)) { + auto shape = ctx.getInputType(7)->tensor_type().shape(); + if (shape.dim_size() != 0) { + fail_type_inference("Y zero point input must be a scalar."); + } + } + if (hasInputShape(ctx, 3) && hasInputShape(ctx, 4)) { + auto& b_shape = getInputShape(ctx, 3); + auto& b_scale_shape = getInputShape(ctx, 4); + if (b_shape.dim_size() != 2) { + fail_type_inference("B must be 2D."); + } + if (b_scale_shape.dim_size() != 2) { + fail_type_inference("B scale must be 2D."); + } + if (b_shape.dim(1).has_dim_value() && b_scale_shape.dim(1).has_dim_value()) { + const auto n = b_shape.dim(1).dim_value(); + if ((n % block_size_n) != 0 || b_scale_shape.dim(1).dim_value() != (n / block_size_n)) { + fail_type_inference("B scale last dimension must be N / block_size_n."); + } + } + if (b_shape.dim(0).has_dim_value() && b_scale_shape.dim(0).has_dim_value()) { + const auto k = b_shape.dim(0).dim_value(); + if ((k % block_size_k) != 0 || b_scale_shape.dim(0).dim_value() != (k / block_size_k)) { + fail_type_inference("B scale first dimension must be K / block_size_k."); + } + } + } + if (hasInputShape(ctx, 3) && hasInputShape(ctx, 5)) { + auto& b_shape = getInputShape(ctx, 3); + 
auto& b_zp_shape = getInputShape(ctx, 5); + if (b_shape.dim_size() != 2) { + fail_type_inference("B must be 2D."); + } + if (b_zp_shape.dim_size() != 2) { + fail_type_inference("B zero point must be 2D."); + } + if (b_shape.dim(1).has_dim_value() && b_zp_shape.dim(1).has_dim_value()) { + const auto n = b_shape.dim(1).dim_value(); + if ((n % block_size_n) != 0 || b_zp_shape.dim(1).dim_value() != (n / block_size_n)) { + fail_type_inference("B zero point last dimension must be N / block_size_n."); + } + } + if (b_shape.dim(0).has_dim_value() && b_zp_shape.dim(0).has_dim_value()) { + const auto k = b_shape.dim(0).dim_value(); + if ((k % block_size_k) != 0 || b_zp_shape.dim(0).dim_value() != (k / block_size_k)) { + fail_type_inference("B zero point first dimension must be K / block_size_k."); + } + } + } + })); + +#endif ONNX_MS_OPERATOR_SET_SCHEMA( QAttention, 1, OpSchema() diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 04e99d206bd06..806f76266f633 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -726,6 +726,81 @@ bool MLASCALL MlasIsDynamicQGemmAvailable(const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); +/** + * @brief Parameters that define the shape of an FP8 GEMM operation. + */ +struct MLAS_FP8_GEMM_SHAPE_PARAMS { + size_t M = 0; /**< Row size of matrix A */ + size_t N = 0; /**< Column size of matrix B */ + size_t K = 0; /**< Column size of matrix A and row size of matrix B */ +}; + +// FP8 mode is aligned with Arm KleidiAI format/overflow modes. +// Defined here to keep MLAS FP8 APIs platform-agnostic. +#ifndef MLAS_FP8_MODE_DEFINED +#define MLAS_FP8_MODE_DEFINED +enum mlas_fp8_mode : uint8_t { + MLAS_FP8_MODE_E4M3_INF = 0, ///< E4M3 with NaN/Inf on overflow. + MLAS_FP8_MODE_E4M3_SAT = 1, ///< E4M3 with saturation on overflow. + MLAS_FP8_MODE_E5M2_INF = 2, ///< E5M2 with NaN/Inf on overflow. + MLAS_FP8_MODE_E5M2_SAT = 3, ///< E5M2 with saturation on overflow. 
+ MLAS_FP8_MODE_END = 4, ///< End marker. Do not use. +}; +#endif // MLAS_FP8_MODE_DEFINED + +/** + * @brief Parameters that define the data buffers and layout for an FP8 GEMM. + */ +struct MLAS_FP8_GEMM_DATA_PARAMS { + const void* A = nullptr; + size_t lda = 0; + const void* B = nullptr; + size_t ldb = 0; + void* C = nullptr; + size_t ldc = 0; + const float* ScaleA = nullptr; // Tile scales for A: [BlocksM, BlocksK]. + const float* ScaleB = nullptr; // Tile scales for B: [BlocksK, BlocksN]. + const float* ScaleY = nullptr; // Scalar scale for Y. + const void* ZeroPointA = nullptr; // Tile zero-points for A: [BlocksM, BlocksK], fp8. + const void* ZeroPointB = nullptr; // Tile zero-points for B: [BlocksK, BlocksN], fp8. + const float* ZeroPointY = nullptr; // Scalar zero-point for Y, in dequantized float units. + mlas_fp8_mode Fp8Type = static_cast(0); + size_t BlockSizeM = 128; // Block size along M for A quantization. + size_t BlockSizeK = 128; // Block size along K for A/B quantization. + size_t BlockSizeN = 128; // Block size along N for B quantization. + size_t BlocksM = 0; // Number of blocks along M (ceil(M / BlockSizeM)). + size_t BlocksK = 0; // Number of blocks along K (ceil(K / BlockSizeK)). + size_t BlocksN = 0; // Number of blocks along N (ceil(N / BlockSizeN)). + size_t ScaleAStrideK = 0; // ScaleA stride between K blocks (elements). + size_t ScaleAStrideM = 0; // ScaleA stride between M blocks (elements). + size_t ScaleBStrideN = 0; // ScaleB stride between N blocks (elements). + size_t ScaleBStrideK = 0; // ScaleB stride between K blocks (elements). + size_t ZeroPointAStrideK = 0; // ZeroPointA stride between K blocks (elements). + size_t ZeroPointAStrideM = 0; // ZeroPointA stride between M blocks (elements). + size_t ZeroPointBStrideN = 0; // ZeroPointB stride between N blocks (elements). + size_t ZeroPointBStrideK = 0; // ZeroPointB stride between K blocks (elements). 
+}; + +#if !defined(DISABLE_FLOAT8_TYPES) +void +MLASCALL +MlasFp8GemmBatch( + const MLAS_FP8_GEMM_SHAPE_PARAMS& Shape, + const MLAS_FP8_GEMM_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool +); + +inline void +MlasFp8Gemm( + const MLAS_FP8_GEMM_SHAPE_PARAMS& Shape, + const MLAS_FP8_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool +) { + MlasFp8GemmBatch(Shape, DataParams, 1, ThreadPool); +} +#endif // !defined(DISABLE_FLOAT8_TYPES) + // // Symmetric QGEMM has limited buffer overrun. // Currently only supported in ARM64 diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h index a1aa241b89299..7b6b8b5262edb 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -165,7 +165,7 @@ MLASCALL MlasDynamicQGemmBatch( const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, - const size_t BatchN, + const size_t BatchSize, MLAS_THREADPOOL* ThreadPool ); @@ -200,37 +200,3 @@ MlasConv( MLAS_THREADPOOL* ThreadPool ); } - -/*++ - -Routine Description: - - This routine determines if a wraparound will occur when multiplying two size_t variables - Uses __builtin_mul_overflow if available on the current system and if not falls back - to a default implementation to check this wraparound. - -Arguments: - - a - Supplies the first number to be muliplied. - - b - Supplies the second number to be muliplied. - - out - pointer to a size_t which acts as the return value in success cases. 
- -Return Value: - - Returns false if the operation was successful - Returns true if wraparound of size_t was detected - ---*/ -inline bool mul_overflow_size_t_builtin(size_t a, size_t b, size_t* out) { -#if defined(__has_builtin) -# if __has_builtin(__builtin_mul_overflow) - return __builtin_mul_overflow(a, b, out); -# endif -#endif - // Fallback to manual check if builtin not available - if (b != 0 && a > SIZE_MAX / b) return true; - if (out) *out = a * b; - return false; -} diff --git a/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp index f88af056aa156..99816649ac7d0 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp @@ -279,7 +279,7 @@ Return Value: LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr); size_t lhs_resize = 0; - if (mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize)) + if (MlasMultiplyOverflowsSizeT(LhsPackedStride, BatchSize, &lhs_resize)) { // size_t wraparound detected for LhsPackedStride, fallback to MLAS return false; @@ -304,7 +304,7 @@ Return Value: // Multithread pack lhs and rhs RhsPackedStride = ArmKleidiAI::MlasSBGemmPackBSize(TransA, TransB, N, K); size_t rhs_resize = 0; - if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize)) + if (MlasMultiplyOverflowsSizeT(RhsPackedStride, BatchSize, &rhs_resize)) { // size_t wraparound detected for RhsPackedStride, fallback to MLAS return false; @@ -354,7 +354,7 @@ Return Value: // Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop. // Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively. 
size_t max_tile_elems = 0; - if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) { + if (MlasMultiplyOverflowsSizeT(m_step, n_step, &max_tile_elems)) { // size_t wraparound detected for tile size, fallback to MLAS return false; } diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp index 69b83e1e5b49c..ef599cf1346a8 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp @@ -542,7 +542,7 @@ Return Value: LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr); size_t lhs_resize = 0; - if(mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize)) + if(MlasMultiplyOverflowsSizeT(LhsPackedStride, BatchSize, &lhs_resize)) { // size_t wraparound detected for LhsPackedStride, fallback to MLAS return false; @@ -568,7 +568,7 @@ Return Value: // Multithread pack lhs and rhs RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K); size_t rhs_resize = 0; - if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize)) + if (MlasMultiplyOverflowsSizeT(RhsPackedStride, BatchSize, &rhs_resize)) { // size_t wraparound detected for RhsPackedStride, fallback to MLAS return false; @@ -616,7 +616,7 @@ Return Value: // Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop. // Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively. 
size_t max_tile_elems = 0; - if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) { + if (MlasMultiplyOverflowsSizeT(m_step, n_step, &max_tile_elems)) { // size_t wraparound detected for tile size, fallback to MLAS return false; } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 1fa4c90913b24..8618bf38a05b2 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -119,6 +119,33 @@ Module Name: #define MLAS_UNREFERENCED_PARAMETER(parameter) ((void)(parameter))
+//
+// Returns true if a * b wraps size_t; otherwise returns false and, when out is non-null, stores a * b in *out.
+//
+
+MLAS_FORCEINLINE
+bool
+MlasMultiplyOverflowsSizeT(
+    size_t a,
+    size_t b,
+    size_t* out
+    )
+{
+#if defined(__has_builtin)
+#if __has_builtin(__builtin_mul_overflow)
+    size_t result;  // scratch destination so that callers may pass out == nullptr
+    return __builtin_mul_overflow(a, b, out != nullptr ? out : &result);
+#endif
+#endif
+    if (b != 0 && a > std::numeric_limits<size_t>::max() / b) {  // a * b > max  <=>  a > max / b
+        return true;
+    }
+    if (out != nullptr) {
+        *out = a * b;
+    }
+    return false;
+}
+
#ifdef MLAS_NO_EXCEPTION MLAS_FORCEINLINE void diff --git a/onnxruntime/core/mlas/lib/qgemm_fp8.cpp b/onnxruntime/core/mlas/lib/qgemm_fp8.cpp new file mode 100644 index 0000000000000..f434ecab2d87d --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_fp8.cpp @@ -0,0 +1,219 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved.
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+// SPDX-License-Identifier: MIT
+
+#include "mlasi.h"
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+#include "core/common/common.h"
+#include "core/common/float8.h"
+
+namespace {
+
+// Decodes one FP8 byte to float. E4M3_INF/E5M2_INF use the Float8E4M3FN/Float8E5M2 layouts; E4M3_SAT/E5M2_SAT
+// use the Float8E4M3FNUZ/Float8E5M2FNUZ layouts (NOTE(review): confirm against GetFp8Type).
+// Any other mode is reported through MLAS_THROW_EX(std::invalid_argument).
+inline float Fp8ByteToFloat(uint8_t value, mlas_fp8_mode mode) {
+    switch (mode) {
+        case MLAS_FP8_MODE_E4M3_INF:
+            return onnxruntime::Float8E4M3FN(value, onnxruntime::Float8E4M3FN::FromBits()).ToFloat();
+        case MLAS_FP8_MODE_E4M3_SAT:
+            return onnxruntime::Float8E4M3FNUZ(value, onnxruntime::Float8E4M3FNUZ::FromBits()).ToFloat();
+        case MLAS_FP8_MODE_E5M2_INF:
+            return onnxruntime::Float8E5M2(value, onnxruntime::Float8E5M2::FromBits()).ToFloat();
+        case MLAS_FP8_MODE_E5M2_SAT:
+            return onnxruntime::Float8E5M2FNUZ(value, onnxruntime::Float8E5M2FNUZ::FromBits()).ToFloat();
+        default:
+            MLAS_THROW_EX(std::invalid_argument, "Unsupported FP8 GEMM mode.");
+            break;
+    }
+}
+
+// Returns true only for the four concrete FP8 modes; MLAS_FP8_MODE_END and any other value are rejected.
+inline bool IsValidFp8Mode(mlas_fp8_mode mode) {
+    switch (mode) {
+        case MLAS_FP8_MODE_E4M3_INF:
+        case MLAS_FP8_MODE_E4M3_SAT:
+        case MLAS_FP8_MODE_E5M2_INF:
+        case MLAS_FP8_MODE_E5M2_SAT:
+            return true;
+        default:
+            return false;
+    }
+}
+
+// Returns batch_count * m as ptrdiff_t, enforcing that neither the multiply nor the narrowing cast overflows.
+inline ptrdiff_t CheckedWorkItems(size_t batch_count, size_t m) {
+    size_t work_items = 0;
+    ORT_ENFORCE(!MlasMultiplyOverflowsSizeT(batch_count, m, &work_items), "FP8 GEMM work item count overflow.");
+    ORT_ENFORCE(work_items <= static_cast<size_t>(std::numeric_limits<ptrdiff_t>::max()),
+                "FP8 GEMM work item count exceeds ptrdiff_t range.");
+    return static_cast<ptrdiff_t>(work_items);
+}
+
+// Validates the largest row-major strided matrix offset used by the GEMM loops.
+inline void CheckStridedMatrixOffset(size_t rows, size_t cols, size_t row_stride) { + if (rows == 0 || cols == 0) { + return; + } + + size_t offset = 0; + ORT_ENFORCE(!MlasMultiplyOverflowsSizeT(rows - 1, row_stride, &offset), + "FP8 GEMM matrix offset overflow."); + ORT_ENFORCE(cols - 1 <= std::numeric_limits::max() - offset, + "FP8 GEMM matrix offset overflow."); +} + +// Validates the largest block-scale or zero-point offset used by the GEMM loops. +inline void CheckBlockedMatrixOffset(size_t blocks0, size_t stride0, size_t blocks1, size_t stride1) { + if (blocks0 == 0 || blocks1 == 0) { + return; + } + + size_t offset0 = 0; + ORT_ENFORCE(!MlasMultiplyOverflowsSizeT(blocks0 - 1, stride0, &offset0), + "FP8 GEMM block offset overflow."); + size_t offset1 = 0; + ORT_ENFORCE(!MlasMultiplyOverflowsSizeT(blocks1 - 1, stride1, &offset1), + "FP8 GEMM block offset overflow."); + ORT_ENFORCE(offset1 <= std::numeric_limits::max() - offset0, + "FP8 GEMM block offset overflow."); +} + +inline void CheckBlockCount(size_t actual, size_t supplied, const char* dimension_name) { + ORT_ENFORCE(actual == supplied, "FP8 GEMM ", dimension_name, " block count must match shape and block size."); +} + +// Validate caller-provided buffers, strides, and block metadata before parallel workers dereference them. +inline void CheckFp8GemmBatchParams( + const MLAS_FP8_GEMM_SHAPE_PARAMS& shape, + const MLAS_FP8_GEMM_DATA_PARAMS& params) { + ORT_ENFORCE(IsValidFp8Mode(params.Fp8Type), "FP8 GEMM mode must be valid."); + ORT_ENFORCE(params.BlockSizeM != 0 && params.BlockSizeK != 0 && params.BlockSizeN != 0, + "FP8 GEMM block sizes must be non-zero."); + + const bool writes_output = shape.M != 0 && shape.N != 0; + const bool reads_reduction_data = writes_output && shape.K != 0; + + // Empty-output GEMMs do not dereference C, and empty reductions do not read A/B. 
+ if (reads_reduction_data) { + ORT_ENFORCE(params.A != nullptr, "FP8 GEMM A buffer must not be null."); + ORT_ENFORCE(params.B != nullptr, "FP8 GEMM B buffer must not be null."); + ORT_ENFORCE(params.lda >= shape.K, "FP8 GEMM lda must be greater than or equal to K."); + ORT_ENFORCE(params.ldb >= shape.N, "FP8 GEMM ldb must be greater than or equal to N."); + + CheckStridedMatrixOffset(shape.M, shape.K, params.lda); + CheckStridedMatrixOffset(shape.K, shape.N, params.ldb); + } + + if (writes_output) { + ORT_ENFORCE(params.C != nullptr, "FP8 GEMM C buffer must not be null."); + ORT_ENFORCE(params.ldc >= shape.N, "FP8 GEMM ldc must be greater than or equal to N."); + + CheckStridedMatrixOffset(shape.M, shape.N, params.ldc); + } + + const size_t blocks_m = shape.M == 0 ? 0 : ((shape.M - 1) / params.BlockSizeM) + 1; + const size_t blocks_k = shape.K == 0 ? 0 : ((shape.K - 1) / params.BlockSizeK) + 1; + const size_t blocks_n = shape.N == 0 ? 0 : ((shape.N - 1) / params.BlockSizeN) + 1; + + if (reads_reduction_data && params.ScaleA != nullptr) { + CheckBlockCount(blocks_m, params.BlocksM, "M"); + CheckBlockCount(blocks_k, params.BlocksK, "K"); + CheckBlockedMatrixOffset(blocks_m, params.ScaleAStrideM, blocks_k, params.ScaleAStrideK); + } + if (reads_reduction_data && params.ScaleB != nullptr) { + CheckBlockCount(blocks_k, params.BlocksK, "K"); + CheckBlockCount(blocks_n, params.BlocksN, "N"); + CheckBlockedMatrixOffset(blocks_k, params.ScaleBStrideK, blocks_n, params.ScaleBStrideN); + } + if (reads_reduction_data && params.ZeroPointA != nullptr) { + CheckBlockCount(blocks_m, params.BlocksM, "M"); + CheckBlockCount(blocks_k, params.BlocksK, "K"); + CheckBlockedMatrixOffset(blocks_m, params.ZeroPointAStrideM, blocks_k, params.ZeroPointAStrideK); + } + if (reads_reduction_data && params.ZeroPointB != nullptr) { + CheckBlockCount(blocks_k, params.BlocksK, "K"); + CheckBlockCount(blocks_n, params.BlocksN, "N"); + CheckBlockedMatrixOffset(blocks_k, params.ZeroPointBStrideK, 
blocks_n, params.ZeroPointBStrideN); + } +} + +} // namespace + +void +MLASCALL +MlasFp8GemmBatch( + const MLAS_FP8_GEMM_SHAPE_PARAMS& Shape, + const MLAS_FP8_GEMM_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool + ) +{ + const size_t M = Shape.M; + const size_t N = Shape.N; + const size_t K = Shape.K; + + if (BatchN == 0 || M == 0 || N == 0) { + return; + } + + const ptrdiff_t WorkItems = CheckedWorkItems(BatchN, M); + + ORT_ENFORCE(DataParams != nullptr, "FP8 GEMM data parameters must not be null."); + + for (size_t batch = 0; batch < BatchN; ++batch) { + CheckFp8GemmBatchParams(Shape, DataParams[batch]); + } + + MlasTrySimpleParallel(ThreadPool, WorkItems, [&](ptrdiff_t tid) { + const size_t batch = static_cast(tid) / M; + const size_t m = static_cast(tid) % M; + const auto& params = DataParams[batch]; + + const auto* a_fp8 = static_cast(params.A); + const auto* b_fp8 = static_cast(params.B); + auto* c = static_cast(params.C); + const auto* scale_a = params.ScaleA; + const auto* scale_b = params.ScaleB; + const auto* zp_a = static_cast(params.ZeroPointA); + const auto* zp_b = static_cast(params.ZeroPointB); + + const size_t block_m = m / params.BlockSizeM; + for (size_t n = 0; n < N; ++n) { + const size_t block_n = n / params.BlockSizeN; + float acc = 0.0f; + for (size_t k = 0; k < K; ++k) { + const size_t block_k = k / params.BlockSizeK; + + const size_t a_scale_idx = block_m * params.ScaleAStrideM + block_k * params.ScaleAStrideK; + const size_t b_scale_idx = block_k * params.ScaleBStrideK + block_n * params.ScaleBStrideN; + const float scale_a_val = scale_a ? scale_a[a_scale_idx] : 1.0f; + const float scale_b_val = scale_b ? scale_b[b_scale_idx] : 1.0f; + + const size_t a_zp_idx = block_m * params.ZeroPointAStrideM + block_k * params.ZeroPointAStrideK; + const size_t b_zp_idx = block_k * params.ZeroPointBStrideK + block_n * params.ZeroPointBStrideN; + const float zp_a_val = zp_a ? 
Fp8ByteToFloat(zp_a[a_zp_idx], params.Fp8Type) : 0.0f; + const float zp_b_val = zp_b ? Fp8ByteToFloat(zp_b[b_zp_idx], params.Fp8Type) : 0.0f; + + const float a_val = Fp8ByteToFloat(a_fp8[m * params.lda + k], params.Fp8Type); + const float b_val = Fp8ByteToFloat(b_fp8[k * params.ldb + n], params.Fp8Type); + + const float a_deq = (a_val - zp_a_val) * scale_a_val; + const float b_deq = (b_val - zp_b_val) * scale_b_val; + acc += a_deq * b_deq; + } + + if (params.ScaleY != nullptr) { + acc *= params.ScaleY[0]; + } + if (params.ZeroPointY != nullptr) { + acc += params.ZeroPointY[0]; + } + c[m * params.ldc + n] = acc; + } + }); +} + +#endif // !defined(DISABLE_FLOAT8_TYPES) diff --git a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc new file mode 100644 index 0000000000000..af6eb8e006476 --- /dev/null +++ b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc @@ -0,0 +1,997 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. 
+// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "gtest/gtest.h" +#include "core/session/inference_session.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/test_environment.h" +#include "test/unittest_util/conversion.h" +#include "default_providers.h" + +#if !defined(DISABLE_FLOAT8_TYPES) +#include "core/common/float8.h" + +namespace onnxruntime { +namespace test { + +class DynamicQuantMatMulFp8SessionTester : public OpTester { + public: + using BaseTester::ExecuteModel; + using BaseTester::FillFeedsAndOutputNames; + using BaseTester::SetTestFunctionCalled; + using OpTester::BuildModel; + using OpTester::OpTester; +}; + +float QuantizeDequantizeE4M3(float value, float scale) { + return Float8E4M3FN(value / scale, true).ToFloat() * scale; +} + +std::vector ComputeExpectedIdentityAWithQuantizedB(gsl::span b_data, + gsl::span b_scale, + int64_t k, + int64_t n, + int64_t block_size_k, + int64_t block_size_n) { + const int64_t blocks_n = n / block_size_n; + std::vector expected(b_data.size()); + for (int64_t row = 0; row < k; ++row) { + for (int64_t col = 0; col < n; ++col) { + const int64_t scale_idx = (row / block_size_k) * blocks_n + (col / block_size_n); + const size_t data_idx = static_cast(row * n + col); + expected[data_idx] = QuantizeDequantizeE4M3(b_data[data_idx], b_scale[scale_idx]); + } + } + return expected; +} + +void RunDynamicQuantMatMulFp8WithSharedPrepack(DynamicQuantMatMulFp8SessionTester& test, + OrtValue& shared_b, + PrepackedWeightsContainer& prepacked_weights_container, + size_t& shared_prepack_count) { + test.SetTestFunctionCalled(); + + auto& model = test.BuildModel(); + Status status = model.MainGraph().Resolve(); + ASSERT_TRUE(status.IsOK()) << status; + + std::unordered_map feeds; + std::vector output_names; + test.FillFeedsAndOutputNames(feeds, output_names); + + SessionOptions so; + status = so.AddInitializer("B", &shared_b); + 
ASSERT_TRUE(status.IsOK()) << status; + + InferenceSession session{so, GetEnvironment()}; + status = session.AddPrePackedWeightsContainer(&prepacked_weights_container); + ASSERT_TRUE(status.IsOK()) << status; + + status = session.RegisterExecutionProvider(DefaultCpuExecutionProvider()); + ASSERT_TRUE(status.IsOK()) << status; + + test.ExecuteModel(model, + session, + OpTester::ExpectResult::kExpectSuccess, + "", + nullptr, + feeds, + output_names, + kCpuExecutionProvider); + shared_prepack_count = session.GetSessionState().GetUsedSharedPrePackedWeightCounter(); +} + +TEST(DynamicQuantMatMulFp8, WithConstantBInputs) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + const float expected_value = 0.25f * -0.5f * static_cast(K); + std::vector y_data(static_cast(M * N), expected_value); + + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.5f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, WithOmittedOutputQuantizationInputs) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector 
b_data(static_cast(K * N), -0.5f); + const float expected_value = 0.25f * -0.5f * static_cast(K); + std::vector y_data(static_cast(M * N), expected_value); + + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.5f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, WithOnlyYScale) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + constexpr float y_scale_value = 0.5f; + const float expected_value = 0.25f * -0.5f * static_cast(K) * y_scale_value; + std::vector y_data(static_cast(M * N), expected_value); + + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_scale{y_scale_value}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.5f); + test.Run(); +} 
+ +TEST(DynamicQuantMatMulFp8, RejectsNonZeroYZeroPoint) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + std::vector y_data(static_cast(M * N), 0.0f); + + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_zp{Float8E4M3FN(1.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddOptionalInputEdge(); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 supports symmetric quantization only; Y zero point values must be zero."); +} + +TEST(DynamicQuantMatMulFp8, WithConstantBInputsBf16Scales) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + const float expected_value = 0.25f * -0.5f * static_cast(K); + std::vector y_data(static_cast(M * N), expected_value); + + std::vector a_scale = MakeBFloat16({1.0f}); + std::vector b_scale = MakeBFloat16({1.0f}); + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_scale = MakeBFloat16({1.0f}); + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + 
test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.5f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, Float16Output) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + const float expected_value = 0.25f * -0.5f * static_cast(K); + std::vector y_data(static_cast(M * N), expected_value); + + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, FloatsToMLFloat16s(a_data)); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, FloatsToMLFloat16s(y_data)); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, BFloat16Output) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + const float expected_value = 0.25f * -0.5f * static_cast(K); + std::vector y_data(static_cast(M * N), expected_value); 
+ + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, FloatsToBFloat16s(a_data)); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, FloatsToBFloat16s(y_data)); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, RejectsNonConstantB) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + std::vector y_data(static_cast(M * N), 0.0f); + + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure); +} + +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInput) { + constexpr 
int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + const float expected_value = 0.25f * -0.5f * static_cast(K); + std::vector y_data(static_cast(M * N), expected_value); + + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.5f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, RejectsRuntimeFp8BZeroPointTypeMismatch) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector y_data(static_cast(M * N), 0.0f); + + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{Float8E5M2(0.0f)}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp); + 
test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure); +} + +TEST(DynamicQuantMatMulFp8, RejectsNonZeroAZeroPoint) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector y_data(static_cast(M * N), 0.0f); + + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Float8E4M3FN(1.0f)}; + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 supports symmetric quantization only; A zero point values must be zero."); +} + +TEST(DynamicQuantMatMulFp8, RejectsNonZeroBZeroPoint) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + std::vector y_data(static_cast(M * N), 0.0f); + + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{Float8E4M3FN(1.0f)}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K 
/ 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 supports symmetric quantization only; B zero point values must be zero."); +} + +TEST(DynamicQuantMatMulFp8, NonDefaultBlockSizesWithPartialM) { + constexpr int64_t M = 9; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + + std::vector a_data(static_cast(M * K), 0.0f); + for (int64_t m = 0; m < M; ++m) { + for (int64_t k = 0; k < K; ++k) { + a_data[static_cast(m * K + k)] = (m >= 4 && m <= 7) ? 0.04f : static_cast(1 << k); + } + } + for (int64_t k = 0; k < K; ++k) { + a_data[static_cast(3 * K + k)] = 64.0f; + } + + std::vector b_data(static_cast(K * N), 0.0f); + for (int64_t k = 0; k < K; ++k) { + for (int64_t n = 0; n < N; ++n) { + b_data[static_cast(k * N + n)] = (k == n) ? 
1.0f : 0.0f; + } + } + + const std::vector a_scale{1.0f, 1.0f, + 0.01f, 0.01f, + 1.0f, 1.0f}; + const std::vector b_scale{1.0f, 1.0f, + 1.0f, 1.0f}; + std::vector a_zp(a_scale.size(), Float8E4M3FN(0.0f)); + std::vector b_zp(b_scale.size(), Float8E4M3FN(0.0f)); + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + std::vector y_data(static_cast(M * N), 0.0f); + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + sum += a_data[static_cast(m * K + k)] * b_data[static_cast(k * N + n)]; + } + y_data[static_cast(m * N + n)] = sum; + } + } + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("block_size_m", 4); + test.AddAttribute("block_size_k", 2); + test.AddAttribute("block_size_n", 2); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {3, 2}, a_scale); + test.AddInput("A_zero_point", {3, 2}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {2, 2}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {2, 2}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.01f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsRestoresPackedBMetadata) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + + std::vector a_data{ + 1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 0.5f, 1.0f, -0.5f, -1.0f, + 4.0f, 3.0f, 2.0f, 1.0f}; + std::vector b_data{ + 2.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 2.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 2.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 2.0f}; + + std::vector y_data(a_data.size()); + for (size_t i = 0; i < a_data.size(); ++i) { + y_data[i] = 2.0f * a_data[i]; + } + + std::vector a_scale{1.0f, 1.0f, + 1.0f, 1.0f}; + std::vector b_scale{1.0f, 1.0f, + 1.0f, 1.0f}; + std::vector 
a_zp(a_scale.size(), Float8E5M2(0.0f)); + std::vector b_zp(b_scale.size(), Float8E5M2(0.0f)); + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E5M2(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("block_size_m", 2); + test.AddAttribute("block_size_k", 2); + test.AddAttribute("block_size_n", 2); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {2, 2}, a_scale); + test.AddInput("A_zero_point", {2, 2}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {2, 2}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {2, 2}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.01f); + + OrtValue b; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({K, N}), b_data.data(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + SessionOptions so; + ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); + + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t prepack_count_session_1 = 0; + size_t shared_prepack_count = 0; + test.Config(so).ConfigEps(cpu_ep()).RunWithConfig(&prepack_count_session_1, &shared_prepack_count); + ASSERT_EQ(shared_prepack_count, static_cast(0)); + ASSERT_EQ(test.GetNumPrePackedWeightsShared(), prepack_count_session_1); + ASSERT_GT(prepack_count_session_1, static_cast(0)); + + size_t prepack_count_session_2 = 0; + test.Config(so).ConfigEps(cpu_ep()).RunWithConfig(&prepack_count_session_2, &shared_prepack_count); + ASSERT_EQ(prepack_count_session_2, prepack_count_session_1); + ASSERT_EQ(shared_prepack_count, prepack_count_session_2); +} + +TEST(DynamicQuantMatMulFp8, 
SharedPrepackedWeightsWithDifferentBScaleKeepCorrectSemantics) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + constexpr int64_t BlockSize = 2; + + const std::vector a_data{ + 1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}; + std::vector b_data{ + 0.31f, -0.37f, 0.83f, -1.70f, + 1.20f, -2.30f, 3.40f, -4.50f, + 5.50f, -6.25f, 7.75f, -8.50f, + 9.00f, -10.50f, 11.25f, -12.75f}; + + std::vector a_scale{1.0f, 1.0f, + 1.0f, 1.0f}; + std::vector a_zp(a_scale.size(), Float8E4M3FN(0.0f)); + std::vector b_zp(a_scale.size(), Float8E4M3FN(0.0f)); + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + std::vector b_scale_1{1.0f, 1.0f, + 1.0f, 1.0f}; + std::vector b_scale_2{0.10f, 0.25f, + 0.50f, 2.00f}; + std::vector y_data_1 = ComputeExpectedIdentityAWithQuantizedB(b_data, b_scale_1, K, N, + BlockSize, BlockSize); + std::vector y_data_2 = ComputeExpectedIdentityAWithQuantizedB(b_data, b_scale_2, K, N, + BlockSize, BlockSize); + + OrtValue b; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({K, N}), b_data.data(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + PrepackedWeightsContainer prepacked_weights_container; + + DynamicQuantMatMulFp8SessionTester test_1("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test_1.AddAttribute("block_size_m", BlockSize); + test_1.AddAttribute("block_size_k", BlockSize); + test_1.AddAttribute("block_size_n", BlockSize); + test_1.AddInput("A", {M, K}, a_data); + test_1.AddInput("A_scale", {2, 2}, a_scale); + test_1.AddInput("A_zero_point", {2, 2}, a_zp); + test_1.AddInput("B", {K, N}, b_data, true /*initializer*/); + test_1.AddInput("B_scale", {2, 2}, b_scale_1, true /*initializer*/); + test_1.AddInput("B_zero_point", {2, 2}, b_zp, true /*initializer*/); + test_1.AddInput("Y_scale", {}, y_scale); + test_1.AddInput("Y_zero_point", {}, y_zp); + test_1.AddOutput("Y", {M, N}, y_data_1); + 
test_1.SetOutputAbsErr("Y", 1e-5f); + + size_t shared_prepack_count = 0; + RunDynamicQuantMatMulFp8WithSharedPrepack(test_1, b, prepacked_weights_container, shared_prepack_count); + ASSERT_EQ(shared_prepack_count, static_cast(0)); + ASSERT_EQ(prepacked_weights_container.GetNumberOfElements(), static_cast(1)); + + DynamicQuantMatMulFp8SessionTester test_2("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test_2.AddAttribute("block_size_m", BlockSize); + test_2.AddAttribute("block_size_k", BlockSize); + test_2.AddAttribute("block_size_n", BlockSize); + test_2.AddInput("A", {M, K}, a_data); + test_2.AddInput("A_scale", {2, 2}, a_scale); + test_2.AddInput("A_zero_point", {2, 2}, a_zp); + test_2.AddInput("B", {K, N}, b_data, true /*initializer*/); + test_2.AddInput("B_scale", {2, 2}, b_scale_2, true /*initializer*/); + test_2.AddInput("B_zero_point", {2, 2}, b_zp, true /*initializer*/); + test_2.AddInput("Y_scale", {}, y_scale); + test_2.AddInput("Y_zero_point", {}, y_zp); + test_2.AddOutput("Y", {M, N}, y_data_2); + test_2.SetOutputAbsErr("Y", 1e-5f); + + RunDynamicQuantMatMulFp8WithSharedPrepack(test_2, b, prepacked_weights_container, shared_prepack_count); + ASSERT_EQ(shared_prepack_count, static_cast(0)); + ASSERT_EQ(prepacked_weights_container.GetNumberOfElements(), static_cast(2)); +} + +TEST(DynamicQuantMatMulFp8, RejectsMismatchedAScaleBatchPrefix) { + constexpr int64_t Batch = 2; + constexpr int64_t Seq = 3; + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(Batch * Seq * M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(0.5f)); + std::vector y_data(static_cast(Batch * Seq * M * N), 0.0f); + + std::vector a_scale(static_cast(Seq * Batch), 1.0f); + std::vector a_zp(a_scale.size(), Float8E4M3FN(0.0f)); + std::vector b_scale{1.0f}; + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + const std::vector 
a_dim_params{"batch", "seq", "128", "128"}; + const std::vector a_scale_dim_params{"seq", "batch", "1", "1"}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {Batch, Seq, M, K}, a_data, false, &a_dim_params); + test.AddInput("A_scale", {Seq, Batch, 1, 1}, a_scale, false, &a_scale_dim_params); + test.AddInput("A_zero_point", {Seq, Batch, 1, 1}, a_zp, false, &a_scale_dim_params); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {1, 1}, b_scale); + test.AddInput("B_zero_point", {1, 1}, b_zp); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {Batch, Seq, M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires A scale batch dimensions to match Y."); +} + +TEST(DynamicQuantMatMulFp8, RejectsMalformedAScaleShapeBeforeReadingScaleData) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(0.5f)); + std::vector y_data(static_cast(M * N), 0.0f); + + std::vector a_scale = FloatsToMLFloat16s({1.0f}); + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_scale = FloatsToMLFloat16s({1.0f}); + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_scale = FloatsToMLFloat16s({1.0f}); + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("block_size_m", 4); + test.AddAttribute("block_size_k", 4); + test.AddAttribute("block_size_n", 4); + test.AddShapeToTensorData(false); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {1}, a_scale); + test.AddInput("A_zero_point", {1}, a_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {1, 1}, b_scale); + test.AddInput("B_zero_point", {1, 1}, b_zp); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + 
test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires A scale to have rank >= 2."); +} + +TEST(DynamicQuantMatMulFp8, RejectsMalformedBScaleShapeBeforeReadingScaleData) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(0.5f)); + std::vector y_data(static_cast(M * N), 0.0f); + + std::vector a_scale = FloatsToMLFloat16s({1.0f}); + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_scale = FloatsToMLFloat16s({1.0f}); + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_scale = FloatsToMLFloat16s({1.0f}); + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("block_size_m", 4); + test.AddAttribute("block_size_k", 4); + test.AddAttribute("block_size_n", 4); + test.AddShapeToTensorData(false); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {1, 1}, a_scale); + test.AddInput("A_zero_point", {1, 1}, a_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {1}, b_scale); + test.AddInput("B_zero_point", {1}, b_zp); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor."); +} + +TEST(DynamicQuantMatMulFp8, ZeroMInput) { + constexpr int64_t M = 0; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.0f); + std::vector b_data(static_cast(K * N), 0.0f); + std::vector y_data{}; + + std::vector a_scale{}; + std::vector b_scale{1.0f}; + std::vector a_zp{}; + std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + 
test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, ZeroKInput) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 0; + + std::vector a_data{}; + std::vector b_data{}; + std::vector y_data(static_cast(M * N), 0.0f); + + std::vector a_scale{}; + std::vector b_scale{}; + std::vector a_zp{}; + std::vector b_zp{}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, ZeroKInputRejectsInvalidYScaleShape) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 0; + + std::vector a_data{}; + std::vector b_data{}; + std::vector y_data(static_cast(M * N), 0.0f); + + std::vector a_scale{}; + std::vector b_scale{}; + std::vector a_zp{}; + std::vector b_zp{}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + 
test.AddShapeToTensorData(false); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {1}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires Y scale input to be a scalar."); +} + +TEST(DynamicQuantMatMulFp8, ZeroKInputRejectsInvalidYScaleValue) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 0; + + std::vector a_data{}; + std::vector b_data{}; + std::vector y_data(static_cast(M * N), 0.0f); + + std::vector a_scale{}; + std::vector b_scale{}; + std::vector a_zp{}; + std::vector b_zp{}; + std::vector y_scale{0.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Y scale values to be finite and positive."); +} + +TEST(DynamicQuantMatMulFp8, ZeroKInputRejectsInvalidYScaleType) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 0; + + std::vector a_data{}; + std::vector b_data{}; + std::vector y_data(static_cast(M * N), 0.0f); + + 
std::vector a_scale{}; + std::vector b_scale{}; + std::vector a_zp{}; + std::vector b_zp{}; + std::vector y_scale{1}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Type Error"); +} + +TEST(DynamicQuantMatMulFp8, ZeroNInput) { + constexpr int64_t M = 128; + constexpr int64_t N = 0; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.0f); + std::vector b_data{}; + std::vector y_data{}; + + std::vector a_scale{1.0f}; + std::vector b_scale{}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, ZeroNInputRejectsInvalidYScaleValue) { + constexpr int64_t M = 128; + constexpr int64_t N = 0; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.0f); + std::vector 
b_data{}; + std::vector y_data{}; + + std::vector a_scale{1.0f}; + std::vector b_scale{}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{}; + std::vector y_scale{0.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Y scale values to be finite and positive."); +} + +TEST(DynamicQuantMatMulFp8, ZeroNInputRejectsNonFp8B) { + constexpr int64_t M = 128; + constexpr int64_t N = 0; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.0f); + std::vector b_data{}; + std::vector y_data{}; + + std::vector a_scale{1.0f}; + std::vector b_scale{}; + std::vector a_zp{Float8E4M3FN(0.0f)}; + std::vector b_zp{}; + std::vector y_scale{1.0f}; + std::vector y_zp{Float8E4M3FN(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires runtime B input to be FP8."); +} + +} // namespace test +} // 
namespace onnxruntime +#endif // !defined(DISABLE_FLOAT8_TYPES) diff --git a/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp b/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp new file mode 100644 index 0000000000000..78c27c8d487b3 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp @@ -0,0 +1,193 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "mlas.h" +#include "test_util.h" +#include "core/common/float8.h" +#include "core/mlas/inc/mlas.h" + +#include +#include +#include + +#if !defined(DISABLE_FLOAT8_TYPES) + +namespace { + +uint8_t EncodeFp8(float value, mlas_fp8_mode mode) { + using onnxruntime::Float8E4M3FN; + using onnxruntime::Float8E4M3FNUZ; + using onnxruntime::Float8E5M2; + using onnxruntime::Float8E5M2FNUZ; + + switch (mode) { + case MLAS_FP8_MODE_E4M3_INF: + return Float8E4M3FN(value).val; + case MLAS_FP8_MODE_E4M3_SAT: + return Float8E4M3FNUZ(value).val; + case MLAS_FP8_MODE_E5M2_INF: + return Float8E5M2(value).val; + case MLAS_FP8_MODE_E5M2_SAT: + return Float8E5M2FNUZ(value).val; + default: + ORT_THROW("Unsupported FP8 GEMM test mode."); + } +} + +void RunFp8GemmBatchThreaded(mlas_fp8_mode mode) { + constexpr size_t BatchN = 2; + constexpr size_t M = 3; + constexpr size_t N = 2; + constexpr size_t K = 4; + constexpr size_t BlockSizeM = 2; + constexpr size_t BlockSizeK = 2; + constexpr size_t BlockSizeN = 1; + constexpr size_t BlocksM = 2; + constexpr size_t BlocksK = 2; + constexpr size_t BlocksN = 2; + constexpr size_t ScaleElements = BlocksM * BlocksK; + constexpr size_t BScaleElements = BlocksK * BlocksN; + + const std::array a_values{ + 1.0f, 2.0f, -1.0f, 0.5f, + -2.0f, 1.5f, 0.0f, 4.0f, + 0.25f, -0.5f, 3.0f, -4.0f, + -1.0f, 0.5f, 2.0f, -2.0f, + 4.0f, -0.25f, -1.5f, 1.0f, + 0.0f, 3.0f, -4.0f, 0.5f}; + const std::array b_values{ + 1.0f, -1.0f, + 0.5f, 2.0f, + -2.0f, 0.25f, + 1.5f, -0.5f, + 
-0.5f, 1.0f, + 2.0f, -2.0f, + 0.25f, 1.5f, + -1.0f, 0.5f}; + const std::array scale_a{ + 1.0f, 0.5f, + 2.0f, 1.5f, + 0.25f, 1.0f, + 0.5f, 2.0f}; + const std::array scale_b{ + 1.0f, 2.0f, + 0.25f, 1.25f, + 0.5f, 1.0f, + 2.0f, 0.25f}; + const std::array y_scale{0.5f, 2.0f}; + + std::vector a_fp8(a_values.size()); + std::vector b_fp8(b_values.size()); + std::vector zp_a(BatchN * ScaleElements, EncodeFp8(0.0f, mode)); + std::vector zp_b(BatchN * BScaleElements, EncodeFp8(0.0f, mode)); + for (size_t i = 0; i < a_values.size(); ++i) { + a_fp8[i] = EncodeFp8(a_values[i], mode); + } + for (size_t i = 0; i < b_values.size(); ++i) { + b_fp8[i] = EncodeFp8(b_values[i], mode); + } + + std::array output{}; + std::array expected{}; + std::array params{}; + + for (size_t batch = 0; batch < BatchN; ++batch) { + params[batch].A = a_fp8.data() + batch * M * K; + params[batch].lda = K; + params[batch].B = b_fp8.data() + batch * K * N; + params[batch].ldb = N; + params[batch].C = output.data() + batch * M * N; + params[batch].ldc = N; + params[batch].ScaleA = scale_a.data() + batch * ScaleElements; + params[batch].ScaleB = scale_b.data() + batch * BScaleElements; + params[batch].ScaleY = y_scale.data() + batch; + params[batch].ZeroPointA = zp_a.data() + batch * ScaleElements; + params[batch].ZeroPointB = zp_b.data() + batch * BScaleElements; + params[batch].ZeroPointY = nullptr; + params[batch].Fp8Type = mode; + params[batch].BlockSizeM = BlockSizeM; + params[batch].BlockSizeK = BlockSizeK; + params[batch].BlockSizeN = BlockSizeN; + params[batch].BlocksM = BlocksM; + params[batch].BlocksK = BlocksK; + params[batch].BlocksN = BlocksN; + params[batch].ScaleAStrideK = 1; + params[batch].ScaleAStrideM = BlocksK; + params[batch].ScaleBStrideN = 1; + params[batch].ScaleBStrideK = BlocksN; + params[batch].ZeroPointAStrideK = 1; + params[batch].ZeroPointAStrideM = BlocksK; + params[batch].ZeroPointBStrideN = 1; + params[batch].ZeroPointBStrideK = BlocksN; + + for (size_t m = 0; m < M; ++m) { 
+ const size_t block_m = m / BlockSizeM; + for (size_t n = 0; n < N; ++n) { + const size_t block_n = n / BlockSizeN; + float acc = 0.0f; + for (size_t k = 0; k < K; ++k) { + const size_t block_k = k / BlockSizeK; + const size_t a_scale_idx = batch * ScaleElements + block_m * BlocksK + block_k; + const size_t b_scale_idx = batch * BScaleElements + block_k * BlocksN + block_n; + const float a_deq = a_values[batch * M * K + m * K + k] * scale_a[a_scale_idx]; + const float b_deq = b_values[batch * K * N + k * N + n] * scale_b[b_scale_idx]; + acc += a_deq * b_deq; + } + expected[batch * M * N + m * N + n] = acc * y_scale[batch]; + } + } + } + + MLAS_FP8_GEMM_SHAPE_PARAMS shape{M, N, K}; + MLAS_THREADPOOL* threadpool = GetMlasThreadPool(); + if (threadpool == nullptr) { + GTEST_SKIP() << "MlasFp8GemmBatch threaded test requires an MLAS thread pool."; + } + + MlasFp8GemmBatch(shape, params.data(), BatchN, threadpool); + + // Inputs are exactly representable test values, so the scalar fallback should match closely. 
+ for (size_t i = 0; i < output.size(); ++i) { + EXPECT_NEAR(output[i], expected[i], 1e-5f); + } +} + +} // namespace + +TEST(Fp8Gemm, BatchedModesSymmetricThreaded) { + RunFp8GemmBatchThreaded(MLAS_FP8_MODE_E4M3_INF); + RunFp8GemmBatchThreaded(MLAS_FP8_MODE_E4M3_SAT); + RunFp8GemmBatchThreaded(MLAS_FP8_MODE_E5M2_INF); + RunFp8GemmBatchThreaded(MLAS_FP8_MODE_E5M2_SAT); +} + +TEST(Fp8Gemm, EmptyDimensionsSkipUnusedBufferValidation) { + MLAS_FP8_GEMM_DATA_PARAMS params{}; + params.Fp8Type = MLAS_FP8_MODE_E4M3_INF; + params.BlockSizeM = 2; + params.BlockSizeK = 2; + params.BlockSizeN = 2; + + MLAS_FP8_GEMM_SHAPE_PARAMS empty_output_shape{3, 0, 4}; + MlasFp8GemmBatch(empty_output_shape, &params, 1, nullptr); + + std::array output; + output.fill(-1.0f); + + MLAS_FP8_GEMM_SHAPE_PARAMS empty_reduction_shape{3, 2, 0}; + params.C = output.data(); + params.ldc = 2; + MlasFp8GemmBatch(empty_reduction_shape, &params, 1, nullptr); + + for (float value : output) { + EXPECT_EQ(value, 0.0f); + } +} + +TEST(Fp8Gemm, ZeroColumnReturnsBeforeWorkItemOverflow) { + MLAS_FP8_GEMM_SHAPE_PARAMS shape{std::numeric_limits::max(), 0, 4}; + EXPECT_NO_THROW(MlasFp8GemmBatch(shape, nullptr, 2, nullptr)); +} + +#endif // !defined(DISABLE_FLOAT8_TYPES) From b53b1c5ab79eda4e18b4286b89a0e175feee2dff Mon Sep 17 00:00:00 2001 From: melkap01 Date: Fri, 8 May 2026 18:45:14 +0100 Subject: [PATCH 02/11] wording for tile replaced with block Signed-off-by: melkap01 --- .../core/graph/contrib_ops/quantization_defs.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index d055aa9c9333a..b1b641d399c53 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -952,11 +952,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA( DynamicQuantMatMulFp8, 1, OpSchema() .SetDoc("Symmetric quantized MatMul for fp8 weights (with
optional prepack conversion from " - "float16/bfloat16/float) and runtime casting of activations to fp8 using tile-wise scales. " + "float16/bfloat16/float) and runtime casting of activations to fp8 using block-wise scales. " "All zero-point inputs, when provided, must encode 0.0.") .Input(0, "A", "Input tensor A.", "TA") .Input(1, "a_scale", - "Scale of quantized input 'A'. Must be a tile-wise tensor with shape " + "Scale of quantized input 'A'. Must be a block-wise tensor with shape " "(ceil(M / block_size_m), K / block_size_k), or the same shape with output batch dimensions prefixed.", "TS") .Input(2, "a_zero_point", @@ -967,7 +967,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "supported when B is a constant initializer that can be quantized during prepack.", "TB") .Input(4, "b_scale", - "Scale of input 'B'. Must be a tile-wise tensor with shape " + "Scale of input 'B'. Must be a block-wise tensor with shape " "(K / block_size_k, N / block_size_n).", "TS") .Input(5, "b_zero_point", @@ -979,11 +979,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Zero point tensor for output 'Y'. 
Must be a scalar encoding 0.0 when provided.", "TZ", OpSchema::Optional) .Output(0, "Y", "Output tensor of shape (..., M, N).", "TY") - .Attr("block_size_m", "Block size along M for A tile-wise scales.", AttributeProto::INT, + .Attr("block_size_m", "Block size along M for A block-wise scales.", AttributeProto::INT, static_cast(128)) - .Attr("block_size_k", "Block size along K for A and B tile-wise scales.", AttributeProto::INT, + .Attr("block_size_k", "Block size along K for A and B block-wise scales.", AttributeProto::INT, static_cast(128)) - .Attr("block_size_n", "Block size along N for B tile-wise scales.", AttributeProto::INT, + .Attr("block_size_n", "Block size along N for B block-wise scales.", AttributeProto::INT, static_cast(128)) .TypeConstraint("TA", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input A type to float16, bfloat16, or float.") From c5fe8e105c7b7c5c0eb6af5ab844c3c6ca195659 Mon Sep 17 00:00:00 2001 From: melkap01 Date: Mon, 11 May 2026 15:55:31 +0100 Subject: [PATCH 03/11] cleaning the zp checks from MlasFp8GemmBatch after symmetric quantisation enforced Signed-off-by: melkap01 --- .../quantization/dynamic_quant_matmul_fp8.cc | 7 ------ onnxruntime/core/mlas/inc/mlas.h | 7 ------ onnxruntime/core/mlas/lib/qgemm_fp8.cpp | 24 ++----------------- .../test/mlas/unittest/test_qgemm_fp8.cpp | 9 ------- 4 files changed, 2 insertions(+), 45 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc index 1f5ee0378ad16..7aa25c70393af 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -825,9 +825,6 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { params.ScaleA = a_scales_batch; params.ScaleB = b_scales; params.ScaleY = y_scale_data; - params.ZeroPointA = nullptr; - params.ZeroPointB = 
nullptr; - params.ZeroPointY = nullptr; params.Fp8Type = a_type; params.BlockSizeM = block_size_m_; params.BlockSizeK = block_size_k_; @@ -839,10 +836,6 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { params.ScaleAStrideM = blocks_k; params.ScaleBStrideN = 1; params.ScaleBStrideK = blocks_n; - params.ZeroPointAStrideK = 1; - params.ZeroPointAStrideM = blocks_k; - params.ZeroPointBStrideN = 1; - params.ZeroPointBStrideK = blocks_n; } MlasFp8GemmBatch(gemm_shape, gemm_data.data(), num_gemms, context->GetOperatorThreadPool()); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 806f76266f633..f49ed543334bd 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -761,9 +761,6 @@ struct MLAS_FP8_GEMM_DATA_PARAMS { const float* ScaleA = nullptr; // Tile scales for A: [BlocksM, BlocksK]. const float* ScaleB = nullptr; // Tile scales for B: [BlocksK, BlocksN]. const float* ScaleY = nullptr; // Scalar scale for Y. - const void* ZeroPointA = nullptr; // Tile zero-points for A: [BlocksM, BlocksK], fp8. - const void* ZeroPointB = nullptr; // Tile zero-points for B: [BlocksK, BlocksN], fp8. - const float* ZeroPointY = nullptr; // Scalar zero-point for Y, in dequantized float units. mlas_fp8_mode Fp8Type = static_cast(0); size_t BlockSizeM = 128; // Block size along M for A quantization. size_t BlockSizeK = 128; // Block size along K for A/B quantization. @@ -775,10 +772,6 @@ struct MLAS_FP8_GEMM_DATA_PARAMS { size_t ScaleAStrideM = 0; // ScaleA stride between M blocks (elements). size_t ScaleBStrideN = 0; // ScaleB stride between N blocks (elements). size_t ScaleBStrideK = 0; // ScaleB stride between K blocks (elements). - size_t ZeroPointAStrideK = 0; // ZeroPointA stride between K blocks (elements). - size_t ZeroPointAStrideM = 0; // ZeroPointA stride between M blocks (elements). - size_t ZeroPointBStrideN = 0; // ZeroPointB stride between N blocks (elements). 
- size_t ZeroPointBStrideK = 0; // ZeroPointB stride between K blocks (elements). }; #if !defined(DISABLE_FLOAT8_TYPES) diff --git a/onnxruntime/core/mlas/lib/qgemm_fp8.cpp b/onnxruntime/core/mlas/lib/qgemm_fp8.cpp index f434ecab2d87d..a849c7060bed8 100644 --- a/onnxruntime/core/mlas/lib/qgemm_fp8.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_fp8.cpp @@ -128,16 +128,6 @@ inline void CheckFp8GemmBatchParams( CheckBlockCount(blocks_n, params.BlocksN, "N"); CheckBlockedMatrixOffset(blocks_k, params.ScaleBStrideK, blocks_n, params.ScaleBStrideN); } - if (reads_reduction_data && params.ZeroPointA != nullptr) { - CheckBlockCount(blocks_m, params.BlocksM, "M"); - CheckBlockCount(blocks_k, params.BlocksK, "K"); - CheckBlockedMatrixOffset(blocks_m, params.ZeroPointAStrideM, blocks_k, params.ZeroPointAStrideK); - } - if (reads_reduction_data && params.ZeroPointB != nullptr) { - CheckBlockCount(blocks_k, params.BlocksK, "K"); - CheckBlockCount(blocks_n, params.BlocksN, "N"); - CheckBlockedMatrixOffset(blocks_k, params.ZeroPointBStrideK, blocks_n, params.ZeroPointBStrideN); - } } } // namespace @@ -177,8 +167,6 @@ MlasFp8GemmBatch( auto* c = static_cast(params.C); const auto* scale_a = params.ScaleA; const auto* scale_b = params.ScaleB; - const auto* zp_a = static_cast(params.ZeroPointA); - const auto* zp_b = static_cast(params.ZeroPointB); const size_t block_m = m / params.BlockSizeM; for (size_t n = 0; n < N; ++n) { @@ -192,25 +180,17 @@ MlasFp8GemmBatch( const float scale_a_val = scale_a ? scale_a[a_scale_idx] : 1.0f; const float scale_b_val = scale_b ? scale_b[b_scale_idx] : 1.0f; - const size_t a_zp_idx = block_m * params.ZeroPointAStrideM + block_k * params.ZeroPointAStrideK; - const size_t b_zp_idx = block_k * params.ZeroPointBStrideK + block_n * params.ZeroPointBStrideN; - const float zp_a_val = zp_a ? Fp8ByteToFloat(zp_a[a_zp_idx], params.Fp8Type) : 0.0f; - const float zp_b_val = zp_b ? 
Fp8ByteToFloat(zp_b[b_zp_idx], params.Fp8Type) : 0.0f; - const float a_val = Fp8ByteToFloat(a_fp8[m * params.lda + k], params.Fp8Type); const float b_val = Fp8ByteToFloat(b_fp8[k * params.ldb + n], params.Fp8Type); - const float a_deq = (a_val - zp_a_val) * scale_a_val; - const float b_deq = (b_val - zp_b_val) * scale_b_val; + const float a_deq = a_val * scale_a_val; + const float b_deq = b_val * scale_b_val; acc += a_deq * b_deq; } if (params.ScaleY != nullptr) { acc *= params.ScaleY[0]; } - if (params.ZeroPointY != nullptr) { - acc += params.ZeroPointY[0]; - } c[m * params.ldc + n] = acc; } }); diff --git a/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp b/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp index 78c27c8d487b3..b9d9b3a226cbe 100644 --- a/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp +++ b/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp @@ -79,8 +79,6 @@ void RunFp8GemmBatchThreaded(mlas_fp8_mode mode) { std::vector a_fp8(a_values.size()); std::vector b_fp8(b_values.size()); - std::vector zp_a(BatchN * ScaleElements, EncodeFp8(0.0f, mode)); - std::vector zp_b(BatchN * BScaleElements, EncodeFp8(0.0f, mode)); for (size_t i = 0; i < a_values.size(); ++i) { a_fp8[i] = EncodeFp8(a_values[i], mode); } @@ -102,9 +100,6 @@ void RunFp8GemmBatchThreaded(mlas_fp8_mode mode) { params[batch].ScaleA = scale_a.data() + batch * ScaleElements; params[batch].ScaleB = scale_b.data() + batch * BScaleElements; params[batch].ScaleY = y_scale.data() + batch; - params[batch].ZeroPointA = zp_a.data() + batch * ScaleElements; - params[batch].ZeroPointB = zp_b.data() + batch * BScaleElements; - params[batch].ZeroPointY = nullptr; params[batch].Fp8Type = mode; params[batch].BlockSizeM = BlockSizeM; params[batch].BlockSizeK = BlockSizeK; @@ -116,10 +111,6 @@ void RunFp8GemmBatchThreaded(mlas_fp8_mode mode) { params[batch].ScaleAStrideM = BlocksK; params[batch].ScaleBStrideN = 1; params[batch].ScaleBStrideK = BlocksN; - params[batch].ZeroPointAStrideK = 1; - 
params[batch].ZeroPointAStrideM = BlocksK; - params[batch].ZeroPointBStrideN = 1; - params[batch].ZeroPointBStrideK = BlocksN; for (size_t m = 0; m < M; ++m) { const size_t block_m = m / BlockSizeM; From a258ad8d98c47d89468721059199d1535b50662c Mon Sep 17 00:00:00 2001 From: melkap01 Date: Tue, 12 May 2026 14:51:58 +0100 Subject: [PATCH 04/11] documentation updated for failing build, copilot comment addressed Signed-off-by: melkap01 --- docs/ContribOperators.md | 65 +++++++++++++++++- docs/OperatorKernels.md | 15 ++-- .../quantization/dynamic_quant_matmul_fp8.cc | 41 +---------- onnxruntime/core/common/fp8_common.h | 68 +++++++++++++++++++ .../graph/contrib_ops/quantization_defs.cc | 16 ++--- onnxruntime/core/mlas/lib/qgemm_fp8.cpp | 41 ++--------- .../test/mlas/unittest/test_qgemm_fp8.cpp | 3 +- 7 files changed, 156 insertions(+), 93 deletions(-) create mode 100644 onnxruntime/core/common/fp8_common.h diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9aa44a1600ae6..e2771b2e4d7f0 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -26,6 +26,7 @@ Do not modify directly.* * com.microsoft.DequantizeBFP * com.microsoft.DequantizeLinear * com.microsoft.DequantizeWithOrder + * com.microsoft.DynamicQuantMatMulFp8 * com.microsoft.DynamicQuantizeLSTM * com.microsoft.DynamicQuantizeMatMul * com.microsoft.DynamicTimeWarping @@ -1493,6 +1494,69 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.DynamicQuantMatMulFp8** + + Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from float16/bfloat16/float) and runtime casting of activations to fp8 using block-wise scales. All zero-point inputs, when provided, must encode 0.0. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
block_size_k : int (default is 128)
+
Block size along K for A and B block-wise scales.
+
block_size_m : int (default is 128)
+
Block size along M for A block-wise scales.
+
block_size_n : int (default is 128)
+
Block size along N for B block-wise scales.
+
+ +#### Inputs (6 - 8) + +
+
A : TA
+
Input tensor A.
+
A_scale : TS
+
Scale of quantized input 'A'. Must be a block-wise tensor with shape (ceil(M / block_size_m), K / block_size_k), or the same shape with output batch dimensions prefixed.
+
A_zero_point : TZ
+
Zero point tensor for input 'A'. Must have the same shape as A_scale and all values must encode 0.0.
+
B : TB
+
Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only supported when B is a constant initializer that can be quantized during prepack.
+
B_scale : TS
+
Scale of input 'B'. Must be a block-wise tensor with shape (K / block_size_k, N / block_size_n).
+
B_zero_point : TZ
+
Zero point tensor for input 'B'. Must have the same shape as B_scale and all values must encode 0.0.
+
Y_scale (optional) : TS
+
Scale of output 'Y'. Must be a scalar when provided.
+
Y_zero_point (optional) : TZ
+
Zero point tensor for output 'Y'. Must be a scalar encoding 0.0 when provided.
+
+ +#### Outputs + +
+
Y : TY
+
Output tensor of shape (..., M, N).
+
+ +#### Type Constraints + +
+
TA : tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain input A type to float16, bfloat16, or float.
+
TB : tensor(float16), tensor(bfloat16), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
+
Constrain input B type to fp8, or to float16, bfloat16, or float for constant initializers.
+
TZ : tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
+
Constrain zero point types to fp8. Only zero-valued zero points are supported.
+
TS : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain scale types to float, float16, or bfloat16.
+
TY : tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain output type to float16, bfloat16, or float.
+
+ + ### **com.microsoft.DynamicQuantizeLSTM** #### Version @@ -6675,4 +6739,3 @@ No versioning maintained for experimental ops.
Constrain input and output types to float32 tensors.
- diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 7596ab7592b25..ab63fe3be38b6 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -577,6 +577,7 @@ The **OpSet Version** column uses the following notation: |CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)| |DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**T** = tensor(float)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float)| +|DynamicQuantMatMulFp8|*in* A:**TA**
*in* A_scale:**TS**
*in* A_zero_point:**TZ**
*in* B:**TB**
*in* B_scale:**TS**
*in* B_zero_point:**TZ**
*in* Y_scale:**TS**
*in* Y_zero_point:**TZ**
*out* Y:**TY**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**TS** = tensor(bfloat16), tensor(float), tensor(float16)
**TY** = tensor(bfloat16), tensor(float), tensor(float16)
**TZ** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)| |DynamicQuantizeLSTM|*in* X:**T**
*in* W:**T2**
*in* R:**T2**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* W_scale:**T**
*in* W_zero_point:**T2**
*in* R_scale:**T**
*in* R_zero_point:**T2**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(float)
**T1** = tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)| @@ -733,8 +734,7 @@ The **OpSet Version** column uses the following notation: |DynamicSlice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(double), tensor(float), tensor(float16)| |Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| -|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| -|||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| +|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| |||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||[7, 10]|**T** = tensor(bool), tensor(int32), tensor(int64)| |Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| @@ -902,13 +902,11 @@ The **OpSet Version** column uses the following notation: |||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| |ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||[18, 19]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| |ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| |||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||[18, 19]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| |||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| @@ -937,8 +935,7 @@ The **OpSet Version** column uses the following notation: |||[16, 21]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(int64)| |||[10, 15]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |RotaryEmbedding|*in* X:**T**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**M**
*out* Y:**T**|23+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| -|Round|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 21]|**T** = tensor(double), tensor(float), tensor(float16)| +|Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Scan|*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**

or

*in* sequence_lens:**I**
*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**|25+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[23, 24]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc index 7aa25c70393af..e4734e8b7cd00 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -5,6 +5,7 @@ #include "dynamic_quant_matmul_fp8.h" #include "core/common/common.h" +#include "core/common/fp8_common.h" #include "core/framework/op_kernel.h" #include "core/graph/onnx_protobuf.h" #include "core/mlas/inc/mlas.h" @@ -44,44 +45,6 @@ size_t CeilDiv(size_t value, size_t divisor) { return value == 0 ? 
0 : ((value - 1) / divisor) + 1; } -inline float Fp8ByteToFloat(uint8_t value, mlas_fp8_mode mode) { - switch (mode) { - case MLAS_FP8_MODE_E4M3_INF: - return Float8E4M3FN(value, Float8E4M3FN::FromBits()).ToFloat(); - case MLAS_FP8_MODE_E4M3_SAT: - return Float8E4M3FNUZ(value, Float8E4M3FNUZ::FromBits()).ToFloat(); - case MLAS_FP8_MODE_E5M2_INF: - return Float8E5M2(value, Float8E5M2::FromBits()).ToFloat(); - case MLAS_FP8_MODE_E5M2_SAT: - return Float8E5M2FNUZ(value, Float8E5M2FNUZ::FromBits()).ToFloat(); - default: - ORT_THROW("Unsupported FP8 mode."); - } -} - -inline uint8_t FloatToFp8Byte(float value, mlas_fp8_mode mode) { - switch (mode) { - case MLAS_FP8_MODE_E4M3_INF: { - const Float8E4M3FN fp8(value, true); - return fp8.val; - } - case MLAS_FP8_MODE_E4M3_SAT: { - const Float8E4M3FNUZ fp8(value, true); - return fp8.val; - } - case MLAS_FP8_MODE_E5M2_INF: { - const Float8E5M2 fp8(value, true); - return fp8.val; - } - case MLAS_FP8_MODE_E5M2_SAT: { - const Float8E5M2FNUZ fp8(value, true); - return fp8.val; - } - default: - ORT_THROW("Unsupported fp8 mode for DynamicQuantMatMulFp8."); - } -} - bool IsFp8DataType(ONNX_NAMESPACE::TensorProto_DataType elem_type) { return elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ || @@ -207,6 +170,7 @@ void QuantizeBlockwiseFp8A(const SrcT* src, const uint8_t* zero_points, mlas_fp8_mode mode, uint8_t* dst) { + ORT_ENFORCE(onnxruntime::IsValidFp8Mode(mode), "DynamicQuantMatMulFp8 FP8 mode must be valid."); // Block sizes come from op attributes; scale shapes only provide the number of blocks. 
for (size_t m = 0; m < M; ++m) { const size_t block_m = m / block_size_m; @@ -233,6 +197,7 @@ void QuantizeBlockwiseFp8(const SrcT* src, const uint8_t* zero_points, mlas_fp8_mode mode, uint8_t* dst) { + ORT_ENFORCE(onnxruntime::IsValidFp8Mode(mode), "DynamicQuantMatMulFp8 FP8 mode must be valid."); // Block sizes come from op attributes; scale shapes only provide the number of blocks. const size_t blocks_n = N / block_size_n; for (size_t k = 0; k < K; ++k) { diff --git a/onnxruntime/core/common/fp8_common.h b/onnxruntime/core/common/fp8_common.h new file mode 100644 index 0000000000000..e4cb73df9cd80 --- /dev/null +++ b/onnxruntime/core/common/fp8_common.h @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#pragma once + +#if !defined(DISABLE_FLOAT8_TYPES) + +#include "core/common/float8.h" +#include "core/mlas/inc/mlas.h" + +#include + +namespace onnxruntime { + +inline float Fp8ByteToFloat(uint8_t value, mlas_fp8_mode mode) { + switch (mode) { + case MLAS_FP8_MODE_E4M3_INF: + return Float8E4M3FN(value, Float8E4M3FN::FromBits()).ToFloat(); + case MLAS_FP8_MODE_E4M3_SAT: + return Float8E4M3FNUZ(value, Float8E4M3FNUZ::FromBits()).ToFloat(); + case MLAS_FP8_MODE_E5M2_INF: + return Float8E5M2(value, Float8E5M2::FromBits()).ToFloat(); + case MLAS_FP8_MODE_E5M2_SAT: + return Float8E5M2FNUZ(value, Float8E5M2FNUZ::FromBits()).ToFloat(); + default: + return 0.0f; + } +} + +inline uint8_t FloatToFp8Byte(float value, mlas_fp8_mode mode) { + switch (mode) { + case MLAS_FP8_MODE_E4M3_INF: { + const Float8E4M3FN fp8(value, true); + return fp8.val; + } + case MLAS_FP8_MODE_E4M3_SAT: { + const Float8E4M3FNUZ fp8(value, true); + return fp8.val; + } + case MLAS_FP8_MODE_E5M2_INF: { + const Float8E5M2 fp8(value, true); + return fp8.val; + } + case MLAS_FP8_MODE_E5M2_SAT: { + const Float8E5M2FNUZ fp8(value, true); + return fp8.val; + } + default: + 
return 0; + } +} + +inline bool IsValidFp8Mode(mlas_fp8_mode mode) { + switch (mode) { + case MLAS_FP8_MODE_E4M3_INF: + case MLAS_FP8_MODE_E4M3_SAT: + case MLAS_FP8_MODE_E5M2_INF: + case MLAS_FP8_MODE_E5M2_SAT: + return true; + default: + return false; + } +} + +} // namespace onnxruntime + +#endif // !defined(DISABLE_FLOAT8_TYPES) diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index b1b641d399c53..d91bf9bebabcf 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -955,27 +955,27 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "float16/bfloat16/float) and runtime casting of activations to fp8 using block-wise scales. " "All zero-point inputs, when provided, must encode 0.0.") .Input(0, "A", "Input tensor A.", "TA") - .Input(1, "a_scale", + .Input(1, "A_scale", "Scale of quantized input 'A'. Must be a block-wise tensor with shape " "(ceil(M / block_size_m), K / block_size_k), or the same shape with output batch dimensions prefixed.", "TS") - .Input(2, "a_zero_point", - "Zero point tensor for input 'A'. Must have the same shape as a_scale and all values must encode 0.0.", + .Input(2, "A_zero_point", + "Zero point tensor for input 'A'. Must have the same shape as A_scale and all values must encode 0.0.", "TZ") .Input(3, "B", "Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only " "supported when B is a constant initializer that can be quantized during prepack.", "TB") - .Input(4, "b_scale", + .Input(4, "B_scale", "Scale of input 'B'. Must be a block-wise tensor with shape " "(K / block_size_k, N / block_size_n).", "TS") - .Input(5, "b_zero_point", - "Zero point tensor for input 'B'. Must have the same shape as b_scale and all values must encode 0.0.", + .Input(5, "B_zero_point", + "Zero point tensor for input 'B'. 
Must have the same shape as B_scale and all values must encode 0.0.", "TZ") - .Input(6, "y_scale", "Scale of output 'Y'. Must be a scalar when provided.", "TS", + .Input(6, "Y_scale", "Scale of output 'Y'. Must be a scalar when provided.", "TS", OpSchema::Optional) - .Input(7, "y_zero_point", + .Input(7, "Y_zero_point", "Zero point tensor for output 'Y'. Must be a scalar encoding 0.0 when provided.", "TZ", OpSchema::Optional) .Output(0, "Y", "Output tensor of shape (..., M, N).", "TY") diff --git a/onnxruntime/core/mlas/lib/qgemm_fp8.cpp b/onnxruntime/core/mlas/lib/qgemm_fp8.cpp index a849c7060bed8..00f5538547348 100644 --- a/onnxruntime/core/mlas/lib/qgemm_fp8.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_fp8.cpp @@ -7,42 +7,10 @@ #if !defined(DISABLE_FLOAT8_TYPES) #include "core/common/common.h" -#include "core/common/float8.h" +#include "core/common/fp8_common.h" namespace { -inline float Fp8ByteToFloat(uint8_t value, mlas_fp8_mode mode) { - using onnxruntime::Float8E4M3FN; - using onnxruntime::Float8E4M3FNUZ; - using onnxruntime::Float8E5M2; - using onnxruntime::Float8E5M2FNUZ; - switch (mode) { - case MLAS_FP8_MODE_E4M3_INF: - return Float8E4M3FN(value, Float8E4M3FN::FromBits()).ToFloat(); - case MLAS_FP8_MODE_E4M3_SAT: - return Float8E4M3FNUZ(value, Float8E4M3FNUZ::FromBits()).ToFloat(); - case MLAS_FP8_MODE_E5M2_INF: - return Float8E5M2(value, Float8E5M2::FromBits()).ToFloat(); - case MLAS_FP8_MODE_E5M2_SAT: - return Float8E5M2FNUZ(value, Float8E5M2FNUZ::FromBits()).ToFloat(); - default: - MLAS_THROW_EX(std::invalid_argument, "Unsupported FP8 GEMM mode."); - break; - } -} - -inline bool IsValidFp8Mode(mlas_fp8_mode mode) { - switch (mode) { - case MLAS_FP8_MODE_E4M3_INF: - case MLAS_FP8_MODE_E4M3_SAT: - case MLAS_FP8_MODE_E5M2_INF: - case MLAS_FP8_MODE_E5M2_SAT: - return true; - default: - return false; - } -} - // Computes the parallel work item count without wrapping before ptrdiff_t narrowing. 
inline ptrdiff_t CheckedWorkItems(size_t batch_count, size_t m) { size_t work_items = 0; @@ -89,7 +57,7 @@ inline void CheckBlockCount(size_t actual, size_t supplied, const char* dimensio inline void CheckFp8GemmBatchParams( const MLAS_FP8_GEMM_SHAPE_PARAMS& shape, const MLAS_FP8_GEMM_DATA_PARAMS& params) { - ORT_ENFORCE(IsValidFp8Mode(params.Fp8Type), "FP8 GEMM mode must be valid."); + ORT_ENFORCE(onnxruntime::IsValidFp8Mode(params.Fp8Type), "FP8 GEMM mode must be valid."); ORT_ENFORCE(params.BlockSizeM != 0 && params.BlockSizeK != 0 && params.BlockSizeN != 0, "FP8 GEMM block sizes must be non-zero."); @@ -161,6 +129,7 @@ MlasFp8GemmBatch( const size_t batch = static_cast(tid) / M; const size_t m = static_cast(tid) % M; const auto& params = DataParams[batch]; + ORT_ENFORCE(onnxruntime::IsValidFp8Mode(params.Fp8Type), "FP8 GEMM mode must be valid."); const auto* a_fp8 = static_cast(params.A); const auto* b_fp8 = static_cast(params.B); @@ -180,8 +149,8 @@ MlasFp8GemmBatch( const float scale_a_val = scale_a ? scale_a[a_scale_idx] : 1.0f; const float scale_b_val = scale_b ? 
scale_b[b_scale_idx] : 1.0f; - const float a_val = Fp8ByteToFloat(a_fp8[m * params.lda + k], params.Fp8Type); - const float b_val = Fp8ByteToFloat(b_fp8[k * params.ldb + n], params.Fp8Type); + const float a_val = onnxruntime::Fp8ByteToFloat(a_fp8[m * params.lda + k], params.Fp8Type); + const float b_val = onnxruntime::Fp8ByteToFloat(b_fp8[k * params.ldb + n], params.Fp8Type); const float a_deq = a_val * scale_a_val; const float b_deq = b_val * scale_b_val; diff --git a/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp b/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp index b9d9b3a226cbe..4c2e82f1c9e2a 100644 --- a/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp +++ b/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp @@ -178,7 +178,8 @@ TEST(Fp8Gemm, EmptyDimensionsSkipUnusedBufferValidation) { TEST(Fp8Gemm, ZeroColumnReturnsBeforeWorkItemOverflow) { MLAS_FP8_GEMM_SHAPE_PARAMS shape{std::numeric_limits::max(), 0, 4}; - EXPECT_NO_THROW(MlasFp8GemmBatch(shape, nullptr, 2, nullptr)); + MlasFp8GemmBatch(shape, nullptr, 2, nullptr); + SUCCEED(); } #endif // !defined(DISABLE_FLOAT8_TYPES) From 7fb80e7b327fc8936cbf17c72e142fd0870dfdd5 Mon Sep 17 00:00:00 2001 From: melkap01 Date: Tue, 12 May 2026 23:44:00 +0100 Subject: [PATCH 05/11] Reusable A buffer implemented before gemm, tests covering all fp8 types implemented Signed-off-by: melkap01 --- docs/ContribOperators.md | 7 +- .../quantization/dynamic_quant_matmul_fp8.cc | 57 ++++---- .../dynamic_quant_matmul_fp8_test.cc | 134 +++++++++++------- 3 files changed, 115 insertions(+), 83 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e2771b2e4d7f0..49f918a93c5e8 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1505,11 +1505,11 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
-
block_size_k : int (default is 128)
+
block_size_k : int
Block size along K for A and B block-wise scales.
-
block_size_m : int (default is 128)
+
block_size_m : int
Block size along M for A block-wise scales.
-
block_size_n : int (default is 128)
+
block_size_n : int
Block size along N for B block-wise scales.
@@ -6738,4 +6738,3 @@ No versioning maintained for experimental ops.
T : tensor(float)
Constrain input and output types to float32 tensors.
- diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc index e4734e8b7cd00..fd44b38af03e3 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -750,61 +750,60 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { gemm_shape.N = N; gemm_shape.K = K; - std::vector gemm_data(num_gemms); - std::vector> a_fp8_buffers(num_gemms); ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(a_scales, a_scale_elems, "A scale")); ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(b_scales, b_scale_elems, "B scale")); const size_t a_fp8_size = SafeMul(M, K); + // Reuse one quantized A buffer to keep scratch bounded when broadcasting creates many GEMMs. + auto a_fp8_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_fp8_size, true); + MLAS_FP8_GEMM_DATA_PARAMS gemm_data; for (size_t gemm_idx = 0; gemm_idx < num_gemms; ++gemm_idx) { - auto& params = gemm_data[gemm_idx]; const size_t a_offset = helper.LeftOffsets()[gemm_idx]; const size_t scale_batch_index = (a_scale_prefix == 1) ? 
0 : gemm_idx; const size_t a_scale_batch_offset = SafeMul(scale_batch_index, a_scale_batch_stride); const float* a_scales_batch = a_scales + a_scale_batch_offset; - a_fp8_buffers[gemm_idx] = IAllocator::MakeUniquePtr(temp_allocator, a_fp8_size, true); if (a->IsDataType()) { QuantizeBlockwiseFp8A(a->Data() + a_offset, M, K, block_size_m_, block_size_k_, blocks_k, a_scales_batch, nullptr, a_type, - a_fp8_buffers[gemm_idx].get()); + a_fp8_buffer.get()); } else if (a->IsDataType()) { QuantizeBlockwiseFp8A(a->Data() + a_offset, M, K, block_size_m_, block_size_k_, blocks_k, a_scales_batch, nullptr, a_type, - a_fp8_buffers[gemm_idx].get()); + a_fp8_buffer.get()); } else if (a->IsDataType()) { QuantizeBlockwiseFp8A(a->Data() + a_offset, M, K, block_size_m_, block_size_k_, blocks_k, a_scales_batch, nullptr, a_type, - a_fp8_buffers[gemm_idx].get()); + a_fp8_buffer.get()); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16."); } - params.A = a_fp8_buffers[gemm_idx].get(); - params.lda = K; - params.B = b_fp8 + helper.RightOffsets()[gemm_idx]; - params.ldb = N; - params.C = y_float_data + helper.OutputOffsets()[gemm_idx]; - params.ldc = N; - params.ScaleA = a_scales_batch; - params.ScaleB = b_scales; - params.ScaleY = y_scale_data; - params.Fp8Type = a_type; - params.BlockSizeM = block_size_m_; - params.BlockSizeK = block_size_k_; - params.BlockSizeN = block_size_n_; - params.BlocksM = blocks_m; - params.BlocksK = blocks_k; - params.BlocksN = blocks_n; - params.ScaleAStrideK = 1; - params.ScaleAStrideM = blocks_k; - params.ScaleBStrideN = 1; - params.ScaleBStrideK = blocks_n; + gemm_data.A = a_fp8_buffer.get(); + gemm_data.lda = K; + gemm_data.B = b_fp8 + helper.RightOffsets()[gemm_idx]; + gemm_data.ldb = N; + gemm_data.C = y_float_data + helper.OutputOffsets()[gemm_idx]; + gemm_data.ldc = N; + gemm_data.ScaleA = a_scales_batch; + gemm_data.ScaleB = b_scales; + gemm_data.ScaleY = y_scale_data; + 
gemm_data.Fp8Type = a_type; + gemm_data.BlockSizeM = block_size_m_; + gemm_data.BlockSizeK = block_size_k_; + gemm_data.BlockSizeN = block_size_n_; + gemm_data.BlocksM = blocks_m; + gemm_data.BlocksK = blocks_k; + gemm_data.BlocksN = blocks_n; + gemm_data.ScaleAStrideK = 1; + gemm_data.ScaleAStrideM = blocks_k; + gemm_data.ScaleBStrideN = 1; + gemm_data.ScaleBStrideK = blocks_n; + + MlasFp8GemmBatch(gemm_shape, &gemm_data, 1, context->GetOperatorThreadPool()); } - MlasFp8GemmBatch(gemm_shape, gemm_data.data(), num_gemms, context->GetOperatorThreadPool()); - if (y_float_buffer != nullptr) { if (y->IsDataType()) { auto* y_data = y->MutableData(); diff --git a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc index af6eb8e006476..b8f459f0d1d4b 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc @@ -46,6 +46,70 @@ std::vector ComputeExpectedIdentityAWithQuantizedB(gsl::span return expected; } +template +void RunRuntimeFp8BInput() { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Fp8T(-0.5f)); + const float expected_value = 0.25f * -0.5f * static_cast(K); + std::vector y_data(static_cast(M * N), expected_value); + + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Fp8T(0.0f)}; + std::vector b_zp{Fp8T(0.0f)}; + std::vector y_scale{1.0f}; + std::vector y_zp{Fp8T(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp); + 
test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.5f); + test.Run(); +} + +template +void RunConstantBInputs() { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + const float expected_value = 0.25f * -0.5f * static_cast(K); + std::vector y_data(static_cast(M * N), expected_value); + + std::vector a_scale{1.0f}; + std::vector b_scale{1.0f}; + std::vector a_zp{Fp8T(0.0f)}; + std::vector b_zp{Fp8T(0.0f)}; + std::vector y_scale{1.0f}; + std::vector y_zp{Fp8T(0.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("A_scale", {M / 128, K / 128}, a_scale); + test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); + test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("Y_scale", {}, y_scale); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.5f); + test.Run(); +} + void RunDynamicQuantMatMulFp8WithSharedPrepack(DynamicQuantMatMulFp8SessionTester& test, OrtValue& shared_b, PrepackedWeightsContainer& prepacked_weights_container, @@ -83,34 +147,19 @@ void RunDynamicQuantMatMulFp8WithSharedPrepack(DynamicQuantMatMulFp8SessionTeste } TEST(DynamicQuantMatMulFp8, WithConstantBInputs) { - constexpr int64_t M = 128; - constexpr int64_t N = 128; - constexpr int64_t K = 128; + RunConstantBInputs(); +} - std::vector a_data(static_cast(M * K), 0.25f); - std::vector b_data(static_cast(K * N), -0.5f); - const float expected_value = 0.25f * -0.5f * static_cast(K); - std::vector y_data(static_cast(M * N), expected_value); 
+TEST(DynamicQuantMatMulFp8, WithConstantBInputsE4M3FNUZ) { + RunConstantBInputs(); +} - std::vector a_scale{1.0f}; - std::vector b_scale{1.0f}; - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector b_zp{Float8E4M3FN(0.0f)}; - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; +TEST(DynamicQuantMatMulFp8, WithConstantBInputsE5M2) { + RunConstantBInputs(); +} - OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); - test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); - test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); - test.AddOutput("Y", {M, N}, y_data); - test.SetOutputAbsErr("Y", 0.5f); - test.Run(); +TEST(DynamicQuantMatMulFp8, WithConstantBInputsE5M2FNUZ) { + RunConstantBInputs(); } TEST(DynamicQuantMatMulFp8, WithOmittedOutputQuantizationInputs) { @@ -320,34 +369,19 @@ TEST(DynamicQuantMatMulFp8, RejectsNonConstantB) { } TEST(DynamicQuantMatMulFp8, RuntimeFp8BInput) { - constexpr int64_t M = 128; - constexpr int64_t N = 128; - constexpr int64_t K = 128; + RunRuntimeFp8BInput(); +} - std::vector a_data(static_cast(M * K), 0.25f); - std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); - const float expected_value = 0.25f * -0.5f * static_cast(K); - std::vector y_data(static_cast(M * N), expected_value); +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE4M3FNUZ) { + RunRuntimeFp8BInput(); +} - std::vector a_scale{1.0f}; - std::vector b_scale{1.0f}; - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector b_zp{Float8E4M3FN(0.0f)}; - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE5M2) { + RunRuntimeFp8BInput(); +} - OpTester 
test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); - test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); - test.AddInput("B", {K, N}, b_data); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); - test.AddOutput("Y", {M, N}, y_data); - test.SetOutputAbsErr("Y", 0.5f); - test.Run(); +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE5M2FNUZ) { + RunRuntimeFp8BInput(); } TEST(DynamicQuantMatMulFp8, RejectsRuntimeFp8BZeroPointTypeMismatch) { From 1966da42c30a428044c17e5e0fea33e15af3f212 Mon Sep 17 00:00:00 2001 From: melkap01 Date: Wed, 13 May 2026 13:35:56 +0100 Subject: [PATCH 06/11] redundant lines removed Signed-off-by: melkap01 --- onnxruntime/core/mlas/lib/qgemm_fp8.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qgemm_fp8.cpp b/onnxruntime/core/mlas/lib/qgemm_fp8.cpp index 00f5538547348..fa0115db89d5b 100644 --- a/onnxruntime/core/mlas/lib/qgemm_fp8.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_fp8.cpp @@ -49,10 +49,6 @@ inline void CheckBlockedMatrixOffset(size_t blocks0, size_t stride0, size_t bloc "FP8 GEMM block offset overflow."); } -inline void CheckBlockCount(size_t actual, size_t supplied, const char* dimension_name) { - ORT_ENFORCE(actual == supplied, "FP8 GEMM ", dimension_name, " block count must match shape and block size."); -} - // Validate caller-provided buffers, strides, and block metadata before parallel workers dereference them. inline void CheckFp8GemmBatchParams( const MLAS_FP8_GEMM_SHAPE_PARAMS& shape, @@ -87,13 +83,13 @@ inline void CheckFp8GemmBatchParams( const size_t blocks_n = shape.N == 0 ? 
0 : ((shape.N - 1) / params.BlockSizeN) + 1; if (reads_reduction_data && params.ScaleA != nullptr) { - CheckBlockCount(blocks_m, params.BlocksM, "M"); - CheckBlockCount(blocks_k, params.BlocksK, "K"); + ORT_ENFORCE(blocks_m == params.BlocksM, "FP8 GEMM M block count must match shape and block size."); + ORT_ENFORCE(blocks_k == params.BlocksK, "FP8 GEMM K block count must match shape and block size."); CheckBlockedMatrixOffset(blocks_m, params.ScaleAStrideM, blocks_k, params.ScaleAStrideK); } if (reads_reduction_data && params.ScaleB != nullptr) { - CheckBlockCount(blocks_k, params.BlocksK, "K"); - CheckBlockCount(blocks_n, params.BlocksN, "N"); + ORT_ENFORCE(blocks_k == params.BlocksK, "FP8 GEMM K block count must match shape and block size."); + ORT_ENFORCE(blocks_n == params.BlocksN, "FP8 GEMM N block count must match shape and block size."); CheckBlockedMatrixOffset(blocks_k, params.ScaleBStrideK, blocks_n, params.ScaleBStrideN); } } From 98ea9fff22916dfdb3ffdeca250c0e0c42a26b12 Mon Sep 17 00:00:00 2001 From: melkap01 Date: Wed, 13 May 2026 17:51:36 +0100 Subject: [PATCH 07/11] Optimize DynamicQuantMatMulFp8 A quantization Signed-off-by: melkap01 --- docs/ContribOperators.md | 2 +- .../quantization/dynamic_quant_matmul_fp8.cc | 167 ++++++++++-------- .../graph/contrib_ops/quantization_defs.cc | 46 ++--- .../dynamic_quant_matmul_fp8_test.cc | 2 +- 4 files changed, 116 insertions(+), 101 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 95b25a879f9c5..82b2927e526d0 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1519,7 +1519,7 @@ This version of the operator has been available since version 1 of the 'com.micr
A : TA
Input tensor A.
A_scale : TS
-
Scale of quantized input 'A'. Must be a block-wise tensor with shape (ceil(M / block_size_m), K / block_size_k), or the same shape with output batch dimensions prefixed.
+
Scale of quantized input 'A'. Must be a block-wise tensor with shape (ceil(M / block_size_m), K / block_size_k), or the same shape with A batch dimensions prefixed.
A_zero_point : TZ
Zero point tensor for input 'A'. Must have the same shape as A_scale and all values must encode 0.0.
B : TB
diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc index fd44b38af03e3..d5ece3bed83d9 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -12,12 +12,15 @@ #include "core/common/float16.h" #include "core/common/float8.h" #include "core/common/safeint.h" +#include "core/platform/threadpool.h" #include "core/providers/cpu/math/matmul_helper.h" #include #include #include #include +#include +#include namespace onnxruntime { namespace contrib { @@ -160,28 +163,28 @@ Status ValidateZeroPointValuesAreZero(const Tensor& zero_point, size_t expected_ } template -void QuantizeBlockwiseFp8A(const SrcT* src, - size_t M, - size_t K, - size_t block_size_m, - size_t block_size_k, - size_t blocks_k, - const float* scales, - const uint8_t* zero_points, - mlas_fp8_mode mode, - uint8_t* dst) { - ORT_ENFORCE(onnxruntime::IsValidFp8Mode(mode), "DynamicQuantMatMulFp8 FP8 mode must be valid."); - // Block sizes come from op attributes; scale shapes only provide the number of blocks. 
- for (size_t m = 0; m < M; ++m) { - const size_t block_m = m / block_size_m; +void QuantizeBlockwiseFp8ABlock(const SrcT* src, + size_t M, + size_t K, + size_t block_size_m, + size_t block_size_k, + size_t blocks_k, + size_t block_m, + size_t block_k, + const float* scales, + mlas_fp8_mode mode, + uint8_t* dst) { + const size_t m_begin = block_m * block_size_m; + const size_t m_end = std::min(M, m_begin + block_size_m); + const size_t k_begin = block_k * block_size_k; + const size_t k_end = std::min(K, k_begin + block_size_k); + const size_t idx = block_m * blocks_k + block_k; + const float scale = scales[idx]; + for (size_t m = m_begin; m < m_end; ++m) { const size_t row_offset = m * K; - for (size_t k = 0; k < K; ++k) { - const size_t block_k = k / block_size_k; - const size_t idx = block_m * blocks_k + block_k; - const float scale = scales[idx]; - const float zp = zero_points ? Fp8ByteToFloat(zero_points[idx], mode) : 0.0f; + for (size_t k = k_begin; k < k_end; ++k) { const float value = static_cast(src[row_offset + k]); - const float quantized = (value / scale) + zp; + const float quantized = value / scale; dst[row_offset + k] = FloatToFp8Byte(quantized, mode); } } @@ -194,10 +197,7 @@ void QuantizeBlockwiseFp8(const SrcT* src, size_t block_size_k, size_t block_size_n, const float* scales, - const uint8_t* zero_points, - mlas_fp8_mode mode, uint8_t* dst) { - ORT_ENFORCE(onnxruntime::IsValidFp8Mode(mode), "DynamicQuantMatMulFp8 FP8 mode must be valid."); // Block sizes come from op attributes; scale shapes only provide the number of blocks. const size_t blocks_n = N / block_size_n; for (size_t k = 0; k < K; ++k) { @@ -207,9 +207,8 @@ void QuantizeBlockwiseFp8(const SrcT* src, const size_t block_n = n / block_size_n; const size_t scale_idx = block_k * blocks_n + block_n; const float scale = scales[scale_idx]; - const float zp = zero_points ? 
Fp8ByteToFloat(zero_points[scale_idx], mode) : 0.0f; const float value = static_cast(src[row_offset + n]); - const float quantized = (value / scale) + zp; + const float quantized = value / scale; const Fp8T fp8_value(quantized, true); dst[row_offset + n] = fp8_value.val; } @@ -224,25 +223,20 @@ Status QuantizeToFp8ByMode(mlas_fp8_mode fp8_mode, size_t block_size_k, size_t block_size_n, const float* scales, - const uint8_t* zero_points, uint8_t* dst) { // Dispatch quantization using the requested FP8 mode and runtime block sizes. switch (fp8_mode) { case MLAS_FP8_MODE_E4M3_INF: - QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, - zero_points, fp8_mode, dst); + QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, dst); return Status::OK(); case MLAS_FP8_MODE_E4M3_SAT: - QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, - zero_points, fp8_mode, dst); + QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, dst); return Status::OK(); case MLAS_FP8_MODE_E5M2_INF: - QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, zero_points, - fp8_mode, dst); + QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, dst); return Status::OK(); case MLAS_FP8_MODE_E5M2_SAT: - QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, - zero_points, fp8_mode, dst); + QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, dst); return Status::OK(); default: return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -416,13 +410,13 @@ Status DynamicQuantMatMulFp8::PrePack(const Tensor& tensor, int input_idx, Alloc auto* quantized_b_bytes = static_cast(quantized_b_.get()); if (tensor.IsDataType()) { ORT_RETURN_IF_ERROR(QuantizeToFp8ByMode(b_type, tensor.Data(), K, N, block_size_k_, block_size_n_, - b_scales, nullptr, quantized_b_bytes)); + b_scales, quantized_b_bytes)); } else if (tensor.IsDataType()) { ORT_RETURN_IF_ERROR(QuantizeToFp8ByMode(b_type, tensor.Data(), K, N, 
block_size_k_, block_size_n_, - b_scales, nullptr, quantized_b_bytes)); + b_scales, quantized_b_bytes)); } else if (tensor.IsDataType()) { ORT_RETURN_IF_ERROR(QuantizeToFp8ByMode(b_type, tensor.Data(), K, N, block_size_k_, block_size_n_, - b_scales, nullptr, quantized_b_bytes)); + b_scales, quantized_b_bytes)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported B type for DynamicQuantMatMulFp8 prepack."); @@ -625,12 +619,12 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { "DynamicQuantMatMulFp8 requires A scale and zero point to have the same shape."); } if (a_scale_rank != 2) { - const size_t y_rank = y->Shape().NumDimensions(); - ORT_RETURN_IF(a_scale_rank != y_rank, - "DynamicQuantMatMulFp8 requires A scale rank to be 2 or match Y rank."); - for (size_t dim = 0; dim < y_rank - 2; ++dim) { - ORT_RETURN_IF(a_scale->Shape()[dim] != y->Shape()[dim], - "DynamicQuantMatMulFp8 requires A scale batch dimensions to match Y."); + const size_t a_rank = a->Shape().NumDimensions(); + ORT_RETURN_IF(a_scale_rank != a_rank, + "DynamicQuantMatMulFp8 requires A scale rank to be 2 or match A rank."); + for (size_t dim = 0; dim < a_rank - 2; ++dim) { + ORT_RETURN_IF(a_scale->Shape()[dim] != a->Shape()[dim], + "DynamicQuantMatMulFp8 requires A scale batch dimensions to match A."); } } // Scale tensor block counts must match the explicit block-size attributes before reading scale data. 
@@ -671,8 +665,6 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { a_scale_prefix, static_cast(a_scale->Shape()[dim])); } } - ORT_RETURN_IF(a_scale_prefix != 1 && a_scale_prefix != num_gemms, - "DynamicQuantMatMulFp8 requires A scale batch dimensions to match the number of gemms."); const size_t a_scale_batch_stride = SafeMul(blocks_m, blocks_k); const size_t a_zp_count = SafeMul(a_scale_prefix, a_scale_batch_stride); const size_t b_zp_count = SafeMul(blocks_k, blocks_n); @@ -754,33 +746,68 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(b_scales, b_scale_elems, "B scale")); const size_t a_fp8_size = SafeMul(M, K); - // Reuse one quantized A buffer to keep scratch bounded when broadcasting creates many GEMMs. - auto a_fp8_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_fp8_size, true); - MLAS_FP8_GEMM_DATA_PARAMS gemm_data; + const size_t a_num_elements = static_cast(a->Shape().Size()); + ORT_RETURN_IF(a_num_elements % a_fp8_size != 0, + "DynamicQuantMatMulFp8 requires A to contain complete MxK matrices."); + const size_t a_batch_count = a_num_elements / a_fp8_size; + ORT_RETURN_IF(a_scale_prefix != 1 && a_scale_prefix != a_batch_count, + "DynamicQuantMatMulFp8 requires A scale batch dimensions to match A."); + + // Quantize the physical A tensor once. Broadcasted output GEMMs then reuse the same FP8 A slice. 
+ auto a_fp8_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_num_elements, true); + const size_t a_quant_work_items = SafeMul(a_batch_count, a_scale_batch_stride); + ORT_RETURN_IF(a_quant_work_items > static_cast(std::numeric_limits::max()), + "DynamicQuantMatMulFp8 A quantization work item count exceeds ptrdiff_t range."); + const auto a_quant_work_items_i = static_cast(a_quant_work_items); + const size_t a_quant_block_elems = SafeMul(block_size_m_, block_size_k_); + const TensorOpCost a_quant_unit_cost{ + static_cast(SafeMul(a_quant_block_elems, sizeof(float))), + static_cast(SafeMul(a_quant_block_elems, sizeof(uint8_t))), + static_cast(a_quant_block_elems) * 2.0}; + const auto quantize_a_batches = [&](const auto* a_data) { + concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), a_quant_work_items_i, + a_quant_unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t tid = begin; tid < end; ++tid) { + const size_t work_idx = static_cast(tid); + const size_t a_batch_idx = work_idx / a_scale_batch_stride; + const size_t scale_block_idx = work_idx % a_scale_batch_stride; + const size_t block_m = scale_block_idx / blocks_k; + const size_t block_k = scale_block_idx % blocks_k; + const size_t a_batch_offset = a_batch_idx * a_fp8_size; + const size_t scale_batch_index = (a_scale_prefix == 1) ? 
0 : a_batch_idx; + const size_t a_scale_batch_offset = scale_batch_index * a_scale_batch_stride; + QuantizeBlockwiseFp8ABlock(a_data + a_batch_offset, + M, K, block_size_m_, block_size_k_, blocks_k, + block_m, block_k, + a_scales + a_scale_batch_offset, a_type, + a_fp8_buffer.get() + a_batch_offset); + } + }); + }; + if (a->IsDataType()) { + quantize_a_batches(a->Data()); + } else if (a->IsDataType()) { + quantize_a_batches(a->Data()); + } else if (a->IsDataType()) { + quantize_a_batches(a->Data()); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16."); + } + + std::vector gemm_data_vec(num_gemms); for (size_t gemm_idx = 0; gemm_idx < num_gemms; ++gemm_idx) { const size_t a_offset = helper.LeftOffsets()[gemm_idx]; - const size_t scale_batch_index = (a_scale_prefix == 1) ? 0 : gemm_idx; + ORT_RETURN_IF(a_offset >= a_num_elements || (a_offset % a_fp8_size) != 0, + "DynamicQuantMatMulFp8 requires A offsets to reference complete MxK matrices."); + const size_t scale_batch_index = (a_scale_prefix == 1) ? 
0 : a_offset / a_fp8_size; + ORT_RETURN_IF(scale_batch_index >= a_scale_prefix, + "DynamicQuantMatMulFp8 requires A scale batch dimensions to match A."); const size_t a_scale_batch_offset = SafeMul(scale_batch_index, a_scale_batch_stride); const float* a_scales_batch = a_scales + a_scale_batch_offset; - if (a->IsDataType()) { - QuantizeBlockwiseFp8A(a->Data() + a_offset, M, K, block_size_m_, block_size_k_, blocks_k, a_scales_batch, - nullptr, a_type, - a_fp8_buffer.get()); - } else if (a->IsDataType()) { - QuantizeBlockwiseFp8A(a->Data() + a_offset, M, K, block_size_m_, block_size_k_, blocks_k, - a_scales_batch, - nullptr, a_type, - a_fp8_buffer.get()); - } else if (a->IsDataType()) { - QuantizeBlockwiseFp8A(a->Data() + a_offset, M, K, block_size_m_, block_size_k_, blocks_k, - a_scales_batch, - nullptr, a_type, - a_fp8_buffer.get()); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16."); - } - gemm_data.A = a_fp8_buffer.get(); + auto& gemm_data = gemm_data_vec[gemm_idx]; + gemm_data.A = a_fp8_buffer.get() + a_offset; gemm_data.lda = K; gemm_data.B = b_fp8 + helper.RightOffsets()[gemm_idx]; gemm_data.ldb = N; @@ -800,10 +827,10 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { gemm_data.ScaleAStrideM = blocks_k; gemm_data.ScaleBStrideN = 1; gemm_data.ScaleBStrideK = blocks_n; - - MlasFp8GemmBatch(gemm_shape, &gemm_data, 1, context->GetOperatorThreadPool()); } + MlasFp8GemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, context->GetOperatorThreadPool()); + if (y_float_buffer != nullptr) { if (y->IsDataType()) { auto* y_data = y->MutableData(); diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index d91bf9bebabcf..1c459ef1e9fbf 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -957,7 +957,7 @@ 
ONNX_MS_OPERATOR_SET_SCHEMA( .Input(0, "A", "Input tensor A.", "TA") .Input(1, "A_scale", "Scale of quantized input 'A'. Must be a block-wise tensor with shape " - "(ceil(M / block_size_m), K / block_size_k), or the same shape with output batch dimensions prefixed.", + "(ceil(M / block_size_m), K / block_size_k), or the same shape with A batch dimensions prefixed.", "TS") .Input(2, "A_zero_point", "Zero point tensor for input 'A'. Must have the same shape as A_scale and all values must encode 0.0.", @@ -1022,23 +1022,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA( } const int a_scale_rank = a_scale_shape.dim_size(); if (a_scale_rank < 2) { - fail_type_inference("A scale must have rank 2 or the same rank as Y."); + fail_type_inference("A scale must have rank 2 or the same rank as A."); } if (a_scale_rank != 2) { - if (hasInputShape(ctx, 3)) { - const auto& y_shape = ctx.getOutputType(0)->tensor_type().shape(); - const int y_rank = y_shape.dim_size(); - if (a_scale_rank != y_rank) { - fail_type_inference("A scale must have rank 2 or the same rank as Y."); - } - for (int i = 0; i < y_rank - 2; ++i) { - if (y_shape.dim(i).has_dim_value() && a_scale_shape.dim(i).has_dim_value() && - y_shape.dim(i).dim_value() != a_scale_shape.dim(i).dim_value()) { - fail_type_inference("A scale batch dimensions must match Y."); - } + if (a_scale_rank != a_rank) { + fail_type_inference("A scale must have rank 2 or the same rank as A."); + } + for (int i = 0; i < a_rank - 2; ++i) { + if (a_shape.dim(i).has_dim_value() && a_scale_shape.dim(i).has_dim_value() && + a_shape.dim(i).dim_value() != a_scale_shape.dim(i).dim_value()) { + fail_type_inference("A scale batch dimensions must match A."); } - } else if (a_scale_rank != a_rank) { - fail_type_inference("A scale must have rank 2 or the same rank as Y."); } } if (a_shape.dim(a_rank - 2).has_dim_value() && a_scale_shape.dim(a_scale_rank - 2).has_dim_value() && @@ -1064,23 +1058,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA( } const int a_zp_rank = 
a_zp_shape.dim_size(); if (a_zp_rank < 2) { - fail_type_inference("A zero point must have rank 2 or the same rank as Y."); + fail_type_inference("A zero point must have rank 2 or the same rank as A."); } if (a_zp_rank != 2) { - if (hasInputShape(ctx, 3)) { - const auto& y_shape = ctx.getOutputType(0)->tensor_type().shape(); - const int y_rank = y_shape.dim_size(); - if (a_zp_rank != y_rank) { - fail_type_inference("A zero point must have rank 2 or the same rank as Y."); - } - for (int i = 0; i < y_rank - 2; ++i) { - if (y_shape.dim(i).has_dim_value() && a_zp_shape.dim(i).has_dim_value() && - y_shape.dim(i).dim_value() != a_zp_shape.dim(i).dim_value()) { - fail_type_inference("A zero point batch dimensions must match Y."); - } + if (a_zp_rank != a_rank) { + fail_type_inference("A zero point must have rank 2 or the same rank as A."); + } + for (int i = 0; i < a_rank - 2; ++i) { + if (a_shape.dim(i).has_dim_value() && a_zp_shape.dim(i).has_dim_value() && + a_shape.dim(i).dim_value() != a_zp_shape.dim(i).dim_value()) { + fail_type_inference("A zero point batch dimensions must match A."); } - } else if (a_zp_rank != a_rank) { - fail_type_inference("A zero point must have rank 2 or the same rank as Y."); } } if (a_shape.dim(a_rank - 2).has_dim_value() && a_zp_shape.dim(a_zp_rank - 2).has_dim_value() && diff --git a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc index b8f459f0d1d4b..08f5f1cfe5f91 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc @@ -717,7 +717,7 @@ TEST(DynamicQuantMatMulFp8, RejectsMismatchedAScaleBatchPrefix) { test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {Batch, Seq, M, N}, y_data); test.Run(OpTester::ExpectResult::kExpectFailure, - "DynamicQuantMatMulFp8 requires A scale batch dimensions to match Y."); + "DynamicQuantMatMulFp8 requires A scale batch dimensions to 
match A."); } TEST(DynamicQuantMatMulFp8, RejectsMalformedAScaleShapeBeforeReadingScaleData) { From 4655e91426bffa9c3dd697d7825f1f82ac6d5eb5 Mon Sep 17 00:00:00 2001 From: melkap01 Date: Mon, 18 May 2026 22:46:20 +0100 Subject: [PATCH 08/11] documentation difference patched Signed-off-by: melkap01 --- docs/OperatorKernels.md | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 9ee4aa7a43d65..8c5ca73276b5d 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -736,7 +736,8 @@ The **OpSet Version** column uses the following notation: |DynamicSlice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(double), tensor(float), tensor(float16)| |Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| -|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| +|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| +|||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| |||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||[7, 10]|**T** = tensor(bool), tensor(int32), tensor(int64)| |Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| @@ -909,11 +910,13 @@ The **OpSet Version** column uses the following notation: |||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| |ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[18, 19]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| |||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[18, 19]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| |||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| @@ -942,7 +945,8 @@ The **OpSet Version** column uses the following notation: |||[16, 21]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(int64)| |||[10, 15]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |RotaryEmbedding|*in* X:**T**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**M**
*out* Y:**T**|23+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| -|Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| +|Round|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[11, 21]|**T** = tensor(double), tensor(float), tensor(float16)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Scan|*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**

or

*in* sequence_lens:**I**
*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**|25+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[23, 24]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| From 93592f399772758f8d43a2994bd5f0b29fcc65d8 Mon Sep 17 00:00:00 2001 From: melkap01 Date: Wed, 20 May 2026 12:21:14 +0100 Subject: [PATCH 09/11] answering copilot comment regarding N==0 case Signed-off-by: melkap01 --- .../quantization/dynamic_quant_matmul_fp8.cc | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc index d5ece3bed83d9..3daf07e4a2f1a 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -328,7 +328,7 @@ Status DynamicQuantMatMulFp8::PrePack(const Tensor& tensor, int input_idx, Alloc const size_t K = static_cast(b_shape_[0]); const size_t N = static_cast(b_shape_[1]); - if (K == 0 || N == 0) { + if (K == 0) { return Status::OK(); } @@ -357,6 +357,10 @@ Status DynamicQuantMatMulFp8::PrePack(const Tensor& tensor, int input_idx, Alloc b_type_ = b_type; has_b_type_ = true; + if (N == 0) { + return Status::OK(); + } + ORT_RETURN_IF_NOT(b_scale.IsDataType() || b_scale.IsDataType() || b_scale.IsDataType(), "DynamicQuantMatMulFp8 requires B scale input to be float, float16, or bfloat16."); 
ORT_RETURN_IF_NOT(b_scale.Shape().NumDimensions() == 2, @@ -555,18 +559,6 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { static_cast(b_zero_point->GetElementType()); const bool b_is_fp8 = IsFp8DataType(b_elem_type); - // Select the FP8 B buffer: prefer pre-quantized B from PrePack, otherwise accept FP8-typed B input. - const uint8_t* b_fp8 = nullptr; - if (quantized_b_) { - b_fp8 = static_cast(quantized_b_.get()); - } else if (b_is_fp8) { - b_fp8 = static_cast(b->DataRaw()); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "DynamicQuantMatMulFp8 requires runtime B input to be FP8. Non-FP8 B is only supported " - "when B is a constant initializer that can be quantized during prepack."); - } - mlas_fp8_mode a_type{}; ORT_RETURN_IF(a_zero_point == nullptr, "DynamicQuantMatMulFp8 requires FP8 zero point for A."); @@ -592,6 +584,18 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { return Status::OK(); } + // Select the FP8 B buffer: prefer pre-quantized B from PrePack, otherwise accept FP8-typed B input. + const uint8_t* b_fp8 = nullptr; + if (quantized_b_) { + b_fp8 = static_cast(quantized_b_.get()); + } else if (b_is_fp8) { + b_fp8 = static_cast(b->DataRaw()); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires runtime B input to be FP8. 
Non-FP8 B is only supported " + "when B is a constant initializer that can be quantized during prepack."); + } + const size_t num_gemms = helper.OutputOffsets().size(); const size_t a_scale_rank = static_cast(a_scale->Shape().NumDimensions()); const size_t a_zp_rank = static_cast(a_zero_point->Shape().NumDimensions()); From f5424a9b27a660c44ec34488f8e1401d003b7a57 Mon Sep 17 00:00:00 2001 From: melkap01 Date: Wed, 20 May 2026 23:20:11 +0100 Subject: [PATCH 10/11] LHS,RHS block layouts changed, scales adjusted Signed-off-by: melkap01 --- docs/ContribOperators.md | 24 +- docs/OperatorKernels.md | 2 +- .../quantization/dynamic_quant_matmul_fp8.cc | 533 ++++++------ .../quantization/dynamic_quant_matmul_fp8.h | 16 +- .../graph/contrib_ops/quantization_defs.cc | 168 ++-- .../dynamic_quant_matmul_fp8_test.cc | 772 +++++++----------- 6 files changed, 598 insertions(+), 917 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 3d17a1dfd7298..3453e32239571 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1496,7 +1496,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.DynamicQuantMatMulFp8** - Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from float16/bfloat16/float) and runtime casting of activations to fp8 using block-wise scales. All zero-point inputs, when provided, must encode 0.0. + Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from float16/bfloat16/float) and dynamic runtime quantization of activations to fp8 using internally computed block-wise scales. All zero-point inputs, when provided, must encode 0.0. #### Version @@ -1505,28 +1505,26 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
-
block_size_k : int
+
block_size_k : int (default is 128)
Block size along K for A and B block-wise scales.
-
block_size_m : int
-
Block size along M for A block-wise scales.
-
block_size_n : int
+
block_size_m : int (default is 1)
+
Block size along M for A block-wise scales. Must be 1.
+
block_size_n : int (default is 128)
Block size along N for B block-wise scales.
+
fp8_type : int (default is 17)
+
FP8 TensorProto data type used when non-FP8 constant B is dynamically quantized during prepack. Defaults to FLOAT8E4M3FN.
-#### Inputs (6 - 8) +#### Inputs (2 - 6)
A : TA
Input tensor A.
-
A_scale : TS
-
Scale of quantized input 'A'. Must be a block-wise tensor with shape (ceil(M / block_size_m), K / block_size_k), or the same shape with A batch dimensions prefixed.
-
A_zero_point : TZ
-
Zero point tensor for input 'A'. Must have the same shape as A_scale and all values must encode 0.0.
B : TB
Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only supported when B is a constant initializer that can be quantized during prepack.
-
B_scale : TS
-
Scale of input 'B'. Must be a block-wise tensor with shape (K / block_size_k, N / block_size_n).
-
B_zero_point : TZ
+
B_scale (optional) : TS
+
Scale of FP8 input 'B'. Must be a block-wise tensor with shape (N / block_size_n, K / block_size_k). Required when B is already FP8. Ignored for non-FP8 constant B, where scales are computed during prepack.
+
B_zero_point (optional) : TZ
Zero point tensor for input 'B'. Must have the same shape as B_scale and all values must encode 0.0.
Y_scale (optional) : TS
Scale of output 'Y'. Must be a scalar when provided.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 929242146e68b..969083aba3aa0 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -577,7 +577,7 @@ The **OpSet Version** column uses the following notation: |CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)| |DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**T** = tensor(float)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float)| -|DynamicQuantMatMulFp8|*in* A:**TA**
*in* A_scale:**TS**
*in* A_zero_point:**TZ**
*in* B:**TB**
*in* B_scale:**TS**
*in* B_zero_point:**TZ**
*in* Y_scale:**TS**
*in* Y_zero_point:**TZ**
*out* Y:**TY**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**TS** = tensor(bfloat16), tensor(float), tensor(float16)
**TY** = tensor(bfloat16), tensor(float), tensor(float16)
**TZ** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)| +|DynamicQuantMatMulFp8|*in* A:**TA**
*in* B:**TB**
*in* B_scale:**TS**
*in* B_zero_point:**TZ**
*in* Y_scale:**TS**
*in* Y_zero_point:**TZ**
*out* Y:**TY**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**TS** = tensor(bfloat16), tensor(float), tensor(float16)
**TY** = tensor(bfloat16), tensor(float), tensor(float16)
**TZ** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)| |DynamicQuantizeLSTM|*in* X:**T**
*in* W:**T2**
*in* R:**T2**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* W_scale:**T**
*in* W_zero_point:**T2**
*in* R_scale:**T**
*in* R_zero_point:**T2**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(float)
**T1** = tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)| diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc index 3daf07e4a2f1a..bc0204863eb59 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -30,8 +30,9 @@ namespace contrib { namespace { constexpr int64_t kDefaultBlockSize = 128; +constexpr int64_t kDefaultBlockSizeM = 1; constexpr int64_t kPackedBMetadataVersion = 1; -constexpr size_t kPackedBMetadataElementCount = 6; +constexpr size_t kPackedBMetadataElementCount = 7; constexpr size_t kPackedBMetadataSize = kPackedBMetadataElementCount * sizeof(int64_t); enum PackedBMetadataIndex : size_t { @@ -39,15 +40,11 @@ enum PackedBMetadataIndex : size_t { kPackedBMetadataRowsIndex, kPackedBMetadataColsIndex, kPackedBMetadataSizeIndex, + kPackedBMetadataScaleCountIndex, kPackedBMetadataFp8ModeIndex, kPackedBMetadataHasFp8ModeIndex, }; -size_t CeilDiv(size_t value, size_t divisor) { - ORT_ENFORCE(divisor != 0, "CeilDiv divisor must be non-zero."); - return value == 0 ? 
0 : ((value - 1) / divisor) + 1; -} - bool IsFp8DataType(ONNX_NAMESPACE::TensorProto_DataType elem_type) { return elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ || @@ -63,8 +60,10 @@ bool IsValidFp8Mode(int64_t mode) { Status RestorePackedBMetadata(const void* metadata_buffer, size_t metadata_size, size_t quantized_b_buffer_size, + size_t b_scale_buffer_size, TensorShape& b_shape, size_t& quantized_b_size, + size_t& b_scale_count, mlas_fp8_mode& b_type, bool& has_b_type) { ORT_RETURN_IF(metadata_buffer == nullptr, @@ -79,6 +78,8 @@ Status RestorePackedBMetadata(const void* metadata_buffer, "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B shape."); ORT_RETURN_IF(metadata[kPackedBMetadataSizeIndex] <= 0, "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B buffer size."); + ORT_RETURN_IF(metadata[kPackedBMetadataScaleCountIndex] <= 0, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B scale count."); ORT_RETURN_IF(metadata[kPackedBMetadataHasFp8ModeIndex] != 1 || !IsValidFp8Mode(metadata[kPackedBMetadataFp8ModeIndex]), "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid FP8 type."); @@ -90,9 +91,14 @@ Status RestorePackedBMetadata(const void* metadata_buffer, ORT_RETURN_IF(restored_quantized_b_size != expected_quantized_b_size || restored_quantized_b_size != quantized_b_buffer_size, "DynamicQuantMatMulFp8 shared prepacked B metadata does not match the B buffer size."); + const size_t restored_b_scale_count = static_cast(metadata[kPackedBMetadataScaleCountIndex]); + ORT_RETURN_IF(restored_b_scale_count > std::numeric_limits::max() / sizeof(float) || + restored_b_scale_count * sizeof(float) != b_scale_buffer_size, + "DynamicQuantMatMulFp8 shared prepacked B metadata does not match the B scale buffer size."); b_shape = TensorShape({metadata[kPackedBMetadataRowsIndex], metadata[kPackedBMetadataColsIndex]}); 
quantized_b_size = restored_quantized_b_size; + b_scale_count = restored_b_scale_count; b_type = static_cast(metadata[kPackedBMetadataFp8ModeIndex]); has_b_type = true; return Status::OK(); @@ -107,6 +113,22 @@ Status ValidatePositiveFiniteScales(const float* scales, size_t count, const cha return Status::OK(); } +Status GetFp8MaxAbs(mlas_fp8_mode mode, float& max_abs) { + switch (mode) { + case MLAS_FP8_MODE_E4M3_INF: + case MLAS_FP8_MODE_E4M3_SAT: + max_abs = 448.0f; + return Status::OK(); + case MLAS_FP8_MODE_E5M2_INF: + case MLAS_FP8_MODE_E5M2_SAT: + max_abs = 57344.0f; + return Status::OK(); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported fp8 mode for DynamicQuantMatMulFp8."); + } +} + Status ValidateZeroPointValuesAreZero(const Tensor& zero_point, size_t expected_count, const char* zero_point_name) { const size_t actual_count = static_cast(zero_point.Shape().Size()); @@ -163,49 +185,53 @@ Status ValidateZeroPointValuesAreZero(const Tensor& zero_point, size_t expected_ } template -void QuantizeBlockwiseFp8ABlock(const SrcT* src, - size_t M, - size_t K, - size_t block_size_m, - size_t block_size_k, - size_t blocks_k, - size_t block_m, - size_t block_k, - const float* scales, - mlas_fp8_mode mode, - uint8_t* dst) { - const size_t m_begin = block_m * block_size_m; - const size_t m_end = std::min(M, m_begin + block_size_m); +void QuantizeBlockwiseFp8ABlockDynamic(const SrcT* src, + size_t K, + size_t block_size_k, + size_t blocks_k, + size_t row, + size_t block_k, + float fp8_max_abs, + mlas_fp8_mode mode, + uint8_t* dst, + float* scales) { const size_t k_begin = block_k * block_size_k; const size_t k_end = std::min(K, k_begin + block_size_k); - const size_t idx = block_m * blocks_k + block_k; - const float scale = scales[idx]; - for (size_t m = m_begin; m < m_end; ++m) { - const size_t row_offset = m * K; - for (size_t k = k_begin; k < k_end; ++k) { - const float value = static_cast(src[row_offset + k]); - const float quantized = 
value / scale; - dst[row_offset + k] = FloatToFp8Byte(quantized, mode); - } + const size_t idx = row * blocks_k + block_k; + + // Match the KleidiAI LHS packer: one dynamic quantization scale per A row and K block. + float max_abs = 0.0f; + const size_t row_offset = row * K; + for (size_t k = k_begin; k < k_end; ++k) { + max_abs = std::max(max_abs, std::fabs(static_cast(src[row_offset + k]))); + } + + const float scale = max_abs == 0.0f ? 1.0f : max_abs / fp8_max_abs; + scales[idx] = scale; + + for (size_t k = k_begin; k < k_end; ++k) { + const float value = static_cast(src[row_offset + k]); + const float quantized = value / scale; + dst[row_offset + k] = FloatToFp8Byte(quantized, mode); } } template -void QuantizeBlockwiseFp8(const SrcT* src, - size_t K, - size_t N, - size_t block_size_k, - size_t block_size_n, - const float* scales, - uint8_t* dst) { +void QuantizeBlockwiseFp8WithScales(const SrcT* src, + size_t K, + size_t N, + size_t block_size_k, + size_t block_size_n, + const float* scales, + uint8_t* dst) { // Block sizes come from op attributes; scale shapes only provide the number of blocks. 
- const size_t blocks_n = N / block_size_n; + const size_t blocks_k = K / block_size_k; for (size_t k = 0; k < K; ++k) { const size_t block_k = k / block_size_k; const size_t row_offset = k * N; for (size_t n = 0; n < N; ++n) { const size_t block_n = n / block_size_n; - const size_t scale_idx = block_k * blocks_n + block_n; + const size_t scale_idx = block_n * blocks_k + block_k; const float scale = scales[scale_idx]; const float value = static_cast(src[row_offset + n]); const float quantized = value / scale; @@ -216,27 +242,57 @@ void QuantizeBlockwiseFp8(const SrcT* src, } template -Status QuantizeToFp8ByMode(mlas_fp8_mode fp8_mode, - const SrcT* src, - size_t K, - size_t N, - size_t block_size_k, - size_t block_size_n, - const float* scales, - uint8_t* dst) { +void ComputeBlockwiseScalesFromInput(const SrcT* src, + size_t K, + size_t N, + size_t block_size_k, + size_t block_size_n, + float fp8_max_abs, + float* scales) { + // Reference-style dynamic quantization: derive one positive scale from each source block. + const size_t blocks_k = K / block_size_k; + const size_t blocks_n = N / block_size_n; + for (size_t block_k = 0; block_k < blocks_k; ++block_k) { + const size_t k_begin = block_k * block_size_k; + const size_t k_end = k_begin + block_size_k; + for (size_t block_n = 0; block_n < blocks_n; ++block_n) { + const size_t n_begin = block_n * block_size_n; + const size_t n_end = n_begin + block_size_n; + float max_abs = 0.0f; + for (size_t k = k_begin; k < k_end; ++k) { + const size_t row_offset = k * N; + for (size_t n = n_begin; n < n_end; ++n) { + max_abs = std::max(max_abs, std::fabs(static_cast(src[row_offset + n]))); + } + } + // Match KleidiAI RHS scale layout: one scale per N block and K block. + scales[block_n * blocks_k + block_k] = max_abs == 0.0f ? 
1.0f : max_abs / fp8_max_abs; + } + } +} + +template +Status QuantizeToFp8ByModeWithScales(mlas_fp8_mode fp8_mode, + const SrcT* src, + size_t K, + size_t N, + size_t block_size_k, + size_t block_size_n, + const float* scales, + uint8_t* dst) { // Dispatch quantization using the requested FP8 mode and runtime block sizes. switch (fp8_mode) { case MLAS_FP8_MODE_E4M3_INF: - QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, dst); + QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst); return Status::OK(); case MLAS_FP8_MODE_E4M3_SAT: - QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, dst); + QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst); return Status::OK(); case MLAS_FP8_MODE_E5M2_INF: - QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, dst); + QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst); return Status::OK(); case MLAS_FP8_MODE_E5M2_SAT: - QuantizeBlockwiseFp8(src, K, N, block_size_k, block_size_n, scales, dst); + QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst); return Status::OK(); default: return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -263,18 +319,20 @@ ONNX_OPERATOR_KERNEL_EX( DynamicQuantMatMulFp8); DynamicQuantMatMulFp8::DynamicQuantMatMulFp8(const OpKernelInfo& info) : OpKernel(info) { - const int64_t block_size_m = info.GetAttrOrDefault("block_size_m", kDefaultBlockSize); + const int64_t block_size_m = info.GetAttrOrDefault("block_size_m", kDefaultBlockSizeM); const int64_t block_size_k = info.GetAttrOrDefault("block_size_k", kDefaultBlockSize); const int64_t block_size_n = info.GetAttrOrDefault("block_size_n", kDefaultBlockSize); - ORT_ENFORCE(block_size_m > 0, - "DynamicQuantMatMulFp8 requires block_size_m to be greater than zero."); + const int64_t fp8_type = + info.GetAttrOrDefault("fp8_type", ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN); + 
ORT_ENFORCE(block_size_m == 1, + "DynamicQuantMatMulFp8 requires block_size_m to be 1 to match the row-wise A scale layout."); ORT_ENFORCE(block_size_k > 0, "DynamicQuantMatMulFp8 requires block_size_k to be greater than zero."); ORT_ENFORCE(block_size_n > 0, "DynamicQuantMatMulFp8 requires block_size_n to be greater than zero."); - block_size_m_ = static_cast(block_size_m); block_size_k_ = static_cast(block_size_k); block_size_n_ = static_cast(block_size_n); + ORT_THROW_IF_ERROR(GetFp8Type(static_cast(fp8_type), fp8_type_)); } Status DynamicQuantMatMulFp8::GetFp8Type(const Tensor& tensor, mlas_fp8_mode& out_type) { @@ -308,18 +366,6 @@ Status DynamicQuantMatMulFp8::PrePack(const Tensor& tensor, int input_idx, Alloc if (input_idx != GetBIdx()) { return Status::OK(); } - // Only prepack if B scale and zero point are constant initializers. - const OrtValue* b_scale_ort = nullptr; - const OrtValue* b_zp_ort = nullptr; - const bool has_b_scale = Info().TryGetConstantInput(IN_B_SCALE, &b_scale_ort); - const bool has_b_zp = Info().TryGetConstantInput(IN_B_ZERO_POINT, &b_zp_ort); - if (!has_b_scale || !has_b_zp) { - const auto b_elem_type = static_cast(tensor.GetElementType()); - ORT_RETURN_IF(!IsFp8DataType(b_elem_type), - "DynamicQuantMatMulFp8 requires B scale and B zero point to be constant initializers when B " - "is not FP8."); - return Status::OK(); - } b_shape_ = tensor.Shape(); if (b_shape_.NumDimensions() != 2) { @@ -328,123 +374,84 @@ Status DynamicQuantMatMulFp8::PrePack(const Tensor& tensor, int input_idx, Alloc const size_t K = static_cast(b_shape_[0]); const size_t N = static_cast(b_shape_[1]); - if (K == 0) { - return Status::OK(); - } - - const auto& b_scale = b_scale_ort->Get(); - const auto& b_zp = b_zp_ort->Get(); - const auto b_elem_type = static_cast(tensor.GetElementType()); - const auto b_zp_elem_type = static_cast(b_zp.GetElementType()); const bool b_is_fp8 = IsFp8DataType(b_elem_type); - const bool zp_is_fp8 = IsFp8DataType(b_zp_elem_type); - 
mlas_fp8_mode b_type{}; if (b_is_fp8) { - ORT_RETURN_IF_ERROR(GetFp8Type(tensor, b_type)); - if (zp_is_fp8) { - mlas_fp8_mode b_zp_type{}; - ORT_RETURN_IF_ERROR(GetFp8Type(b_zp, b_zp_type)); - ORT_RETURN_IF(b_type != b_zp_type, - "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); - } - } else if (zp_is_fp8) { - ORT_RETURN_IF_ERROR(GetFp8Type(b_zp, b_type)); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "DynamicQuantMatMulFp8 requires fp8 zero points when B is not fp8."); + ORT_RETURN_IF_ERROR(GetFp8Type(tensor, b_type_)); + has_b_type_ = true; + return Status::OK(); } - b_type_ = b_type; - has_b_type_ = true; + b_type_ = fp8_type_; + has_b_type_ = true; + if (K == 0) { + return Status::OK(); + } if (N == 0) { return Status::OK(); } - ORT_RETURN_IF_NOT(b_scale.IsDataType() || b_scale.IsDataType() || b_scale.IsDataType(), - "DynamicQuantMatMulFp8 requires B scale input to be float, float16, or bfloat16."); - ORT_RETURN_IF_NOT(b_scale.Shape().NumDimensions() == 2, - "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor."); - ORT_RETURN_IF_NOT(b_zp.Shape().NumDimensions() == 2, - "DynamicQuantMatMulFp8 requires B zero point to be a 2D tensor."); - ORT_RETURN_IF_NOT(b_zp.Shape()[0] == b_scale.Shape()[0] && - b_zp.Shape()[1] == b_scale.Shape()[1], - "DynamicQuantMatMulFp8 requires B scale and zero point to have the same shape."); - const size_t blocks_k = static_cast(b_scale.Shape()[0]); - const size_t blocks_n = static_cast(b_scale.Shape()[1]); - ORT_RETURN_IF_NOT(blocks_k != 0 && blocks_n != 0, - "DynamicQuantMatMulFp8 requires non-zero B scale dimensions."); ORT_RETURN_IF_NOT(K % block_size_k_ == 0, "DynamicQuantMatMulFp8 requires K to be divisible by block_size_k."); ORT_RETURN_IF_NOT(N % block_size_n_ == 0, "DynamicQuantMatMulFp8 requires N to be divisible by block_size_n."); - const size_t expected_blocks_k = K / block_size_k_; - const size_t expected_blocks_n = N / block_size_n_; - ORT_RETURN_IF_NOT(blocks_k == 
expected_blocks_k, - "DynamicQuantMatMulFp8 requires B scale first dimension to be K / block_size_k."); - ORT_RETURN_IF_NOT(blocks_n == expected_blocks_n, - "DynamicQuantMatMulFp8 requires B scale last dimension to be N / block_size_n."); - - const size_t b_scale_elems = static_cast(b_scale.Shape().Size()); - ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(b_zp, b_scale_elems, "B zero point")); - - const float* b_scales = nullptr; - IAllocatorUniquePtr b_scale_float; - if (b_scale.IsDataType()) { - b_scales = b_scale.Data(); - } else if (b_scale.IsDataType()) { - b_scale_float = IAllocator::MakeUniquePtr(alloc, b_scale_elems, true); - for (size_t i = 0; i < b_scale_elems; ++i) { - b_scale_float.get()[i] = static_cast(b_scale.Data()[i]); - } - b_scales = b_scale_float.get(); + const size_t blocks_k = K / block_size_k_; + const size_t blocks_n = N / block_size_n_; + b_scale_count_ = SafeMul(blocks_k, blocks_n); + b_scales_ = IAllocator::MakeUniquePtr(alloc, b_scale_count_, true); + float fp8_max_abs = 0.0f; + ORT_RETURN_IF_ERROR(GetFp8MaxAbs(b_type_, fp8_max_abs)); + + const size_t quantized_b_size = SafeMul(K, N); + quantized_b_ = IAllocator::MakeUniquePtr(alloc, quantized_b_size, true); + quantized_b_size_ = quantized_b_size; + auto* quantized_b_bytes = static_cast(quantized_b_.get()); + if (tensor.IsDataType()) { + ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_, + fp8_max_abs, b_scales_.get()); + ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_, + b_scales_.get(), quantized_b_bytes)); + } else if (tensor.IsDataType()) { + ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_, + fp8_max_abs, b_scales_.get()); + ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_, + b_scales_.get(), quantized_b_bytes)); + } else if (tensor.IsDataType()) { + ComputeBlockwiseScalesFromInput(tensor.Data(), 
K, N, block_size_k_, block_size_n_, + fp8_max_abs, b_scales_.get()); + ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_, + b_scales_.get(), quantized_b_bytes)); } else { - b_scale_float = IAllocator::MakeUniquePtr(alloc, b_scale_elems, true); - for (size_t i = 0; i < b_scale_elems; ++i) { - b_scale_float.get()[i] = static_cast(b_scale.Data()[i]); - } - b_scales = b_scale_float.get(); - } - ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(b_scales, b_scale_elems, "B scale")); - // If B is not already FP8, quantize it once during prepack and reuse the cached FP8 buffer. - if (!b_is_fp8) { - const size_t quantized_b_size = SafeMul(K, N); - quantized_b_ = IAllocator::MakeUniquePtr(alloc, quantized_b_size, true); - quantized_b_size_ = quantized_b_size; - auto* quantized_b_bytes = static_cast(quantized_b_.get()); - if (tensor.IsDataType()) { - ORT_RETURN_IF_ERROR(QuantizeToFp8ByMode(b_type, tensor.Data(), K, N, block_size_k_, block_size_n_, - b_scales, quantized_b_bytes)); - } else if (tensor.IsDataType()) { - ORT_RETURN_IF_ERROR(QuantizeToFp8ByMode(b_type, tensor.Data(), K, N, block_size_k_, block_size_n_, - b_scales, quantized_b_bytes)); - } else if (tensor.IsDataType()) { - ORT_RETURN_IF_ERROR(QuantizeToFp8ByMode(b_type, tensor.Data(), K, N, block_size_k_, block_size_n_, - b_scales, quantized_b_bytes)); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Unsupported B type for DynamicQuantMatMulFp8 prepack."); - } - if (prepacked_weights != nullptr) { - const std::array metadata_values = { - kPackedBMetadataVersion, - b_shape_[0], - b_shape_[1], - static_cast(quantized_b_size_), - static_cast(b_type_), - 1, - }; - auto metadata = IAllocator::MakeUniquePtr(alloc, kPackedBMetadataSize, true); - std::memcpy(metadata.get(), metadata_values.data(), kPackedBMetadataSize); - prepacked_weights->buffers_.push_back(std::move(quantized_b_)); - 
prepacked_weights->buffer_sizes_.push_back(quantized_b_size_); - prepacked_weights->buffers_.push_back(std::move(metadata)); - prepacked_weights->buffer_sizes_.push_back(kPackedBMetadataSize); - } - is_packed = true; - return Status::OK(); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported B type for DynamicQuantMatMulFp8 prepack."); } + if (prepacked_weights != nullptr) { + const std::array metadata_values = { + kPackedBMetadataVersion, + b_shape_[0], + b_shape_[1], + static_cast(quantized_b_size_), + static_cast(b_scale_count_), + static_cast(b_type_), + 1, + }; + auto metadata = IAllocator::MakeUniquePtr(alloc, kPackedBMetadataSize, true); + std::memcpy(metadata.get(), metadata_values.data(), kPackedBMetadataSize); + auto b_scales_deleter = std::move(b_scales_.get_deleter()); + IAllocatorUniquePtr b_scales_buffer( + b_scales_.release(), + [b_scales_deleter = std::move(b_scales_deleter)](void* p) mutable { + b_scales_deleter(static_cast(p)); + }); + prepacked_weights->buffers_.push_back(std::move(quantized_b_)); + prepacked_weights->buffer_sizes_.push_back(quantized_b_size_); + prepacked_weights->buffers_.push_back(std::move(b_scales_buffer)); + prepacked_weights->buffer_sizes_.push_back(b_scale_count_ * sizeof(float)); + prepacked_weights->buffers_.push_back(std::move(metadata)); + prepacked_weights->buffer_sizes_.push_back(kPackedBMetadataSize); + } + is_packed = true; return Status::OK(); } @@ -457,20 +464,30 @@ Status DynamicQuantMatMulFp8::UseSharedPrePackedBuffers(std::vector( + static_cast(prepacked_buffers[1].release()), + [b_scales_deleter = std::move(b_scales_deleter)](float* p) mutable { + b_scales_deleter(static_cast(p)); + }); used_shared_buffers = true; return Status::OK(); } @@ -478,8 +495,6 @@ Status DynamicQuantMatMulFp8::UseSharedPrePackedBuffers(std::vectorInput(IN_A); const Tensor* b = quantized_b_ ? 
nullptr : context->Input(IN_B); - const Tensor* a_scale = context->Input(IN_A_SCALE); - const Tensor* a_zero_point = context->Input(IN_A_ZERO_POINT); const Tensor* b_scale = context->Input(IN_B_SCALE); const Tensor* b_zero_point = context->Input(IN_B_ZERO_POINT); const Tensor* y_scale = context->Input(IN_Y_SCALE); @@ -555,30 +570,24 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { const auto b_elem_type = b ? static_cast(b->GetElementType()) : static_cast(0); - const auto b_zp_elem_type = - static_cast(b_zero_point->GetElementType()); const bool b_is_fp8 = IsFp8DataType(b_elem_type); - mlas_fp8_mode a_type{}; - ORT_RETURN_IF(a_zero_point == nullptr, - "DynamicQuantMatMulFp8 requires FP8 zero point for A."); - ORT_RETURN_IF_ERROR(GetFp8Type(*a_zero_point, a_type)); mlas_fp8_mode b_type{}; if (has_b_type_) { b_type = b_type_; } else if (b_is_fp8) { ORT_RETURN_IF_ERROR(GetFp8Type(b_elem_type, b_type)); - if (IsFp8DataType(b_zp_elem_type)) { + if (b_zero_point != nullptr) { + const auto b_zp_elem_type = + static_cast(b_zero_point->GetElementType()); mlas_fp8_mode b_zp_type{}; ORT_RETURN_IF_ERROR(GetFp8Type(b_zp_elem_type, b_zp_type)); ORT_RETURN_IF(b_type != b_zp_type, "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); } } else { - ORT_RETURN_IF_ERROR(GetFp8Type(b_zp_elem_type, b_type)); + b_type = fp8_type_; } - ORT_RETURN_IF(a_type != b_type, - "DynamicQuantMatMulFp8 requires A/B FP8 types to match."); if (y_size == 0) { return Status::OK(); @@ -597,120 +606,53 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { } const size_t num_gemms = helper.OutputOffsets().size(); - const size_t a_scale_rank = static_cast(a_scale->Shape().NumDimensions()); - const size_t a_zp_rank = static_cast(a_zero_point->Shape().NumDimensions()); - // The scale tensor layout carries the number of tiles because it physically stores one - // scale per tile. 
- // A scale/zero-point may be [blocks_m, blocks_k] or [prefix..., blocks_m, blocks_k]. - // The scale tensor shape provides the number of quantization blocks. It does not define - // the block size. Block sizes come from the block_size_m/block_size_k/block_size_n attributes - // so future models can choose different tile sizes without changing how scale tensors are - // interpreted. The validation below binds both pieces of information together by requiring: - // blocks_m == ceil(M / block_size_m) - // blocks_k == K / block_size_k - // blocks_n == N / block_size_n - // This prevents silently treating a malformed scale shape as a different runtime block size. - ORT_RETURN_IF(a_scale_rank < 2, - "DynamicQuantMatMulFp8 requires A scale to have rank >= 2."); - ORT_RETURN_IF(a_zp_rank < 2, - "DynamicQuantMatMulFp8 requires A zero point to have rank >= 2."); - ORT_RETURN_IF(a_scale_rank != a_zp_rank, - "DynamicQuantMatMulFp8 requires A scale and zero point to have the same rank."); - const size_t blocks_m = static_cast(a_scale->Shape()[a_scale_rank - 2]); - const size_t blocks_k = static_cast(a_scale->Shape()[a_scale_rank - 1]); - for (size_t dim = 0; dim < a_scale_rank; ++dim) { - ORT_RETURN_IF(a_scale->Shape()[dim] != a_zero_point->Shape()[dim], - "DynamicQuantMatMulFp8 requires A scale and zero point to have the same shape."); - } - if (a_scale_rank != 2) { - const size_t a_rank = a->Shape().NumDimensions(); - ORT_RETURN_IF(a_scale_rank != a_rank, - "DynamicQuantMatMulFp8 requires A scale rank to be 2 or match A rank."); - for (size_t dim = 0; dim < a_rank - 2; ++dim) { - ORT_RETURN_IF(a_scale->Shape()[dim] != a->Shape()[dim], - "DynamicQuantMatMulFp8 requires A scale batch dimensions to match A."); - } - } - // Scale tensor block counts must match the explicit block-size attributes before reading scale data. 
- ORT_RETURN_IF(blocks_m == 0, "DynamicQuantMatMulFp8 requires non-zero A scale M dimension."); - ORT_RETURN_IF(blocks_k == 0, "DynamicQuantMatMulFp8 requires non-zero A scale K dimension."); ORT_RETURN_IF(K % block_size_k_ != 0, "DynamicQuantMatMulFp8 requires K to be divisible by block_size_k."); - const size_t expected_blocks_m = CeilDiv(M, block_size_m_); const size_t expected_blocks_k = K / block_size_k_; - // If the scale tensor says it has a different number of M blocks than ceil(M / block_size_m), - // return an error instead of running with wrong scale indexing. - ORT_RETURN_IF(blocks_m != expected_blocks_m, - "DynamicQuantMatMulFp8 requires A scale M dimension to be ceil(M / block_size_m)."); - ORT_RETURN_IF(blocks_k != expected_blocks_k, - "DynamicQuantMatMulFp8 requires A scale K dimension to be K / block_size_k."); - - ORT_RETURN_IF(b_scale->Shape().NumDimensions() != 2, + const size_t blocks_m = M; + const size_t blocks_k = expected_blocks_k; + + const size_t blocks_n = N / block_size_n_; + ORT_RETURN_IF(!b_scales_ && b_scale == nullptr, + "DynamicQuantMatMulFp8 requires B scale when B is already FP8."); + ORT_RETURN_IF(b_scale != nullptr && b_scale->Shape().NumDimensions() != 2, "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor."); - ORT_RETURN_IF(b_zero_point->Shape().NumDimensions() != 2, - "DynamicQuantMatMulFp8 requires B zero point to be a 2D tensor."); - const size_t blocks_n = static_cast(b_scale->Shape()[1]); - ORT_RETURN_IF(b_zero_point->Shape()[0] != b_scale->Shape()[0] || - b_zero_point->Shape()[1] != b_scale->Shape()[1], - "DynamicQuantMatMulFp8 requires B scale and zero point to have the same shape."); - ORT_RETURN_IF(static_cast(b_scale->Shape()[0]) != blocks_k, - "DynamicQuantMatMulFp8 requires B scale K dimension to match A scale K dimension."); ORT_RETURN_IF(blocks_n == 0, "DynamicQuantMatMulFp8 requires non-zero B scale N dimension."); ORT_RETURN_IF(N % block_size_n_ != 0, "DynamicQuantMatMulFp8 requires N to be divisible by 
block_size_n."); - const size_t expected_blocks_n = N / block_size_n_; - ORT_RETURN_IF(blocks_n != expected_blocks_n, + ORT_RETURN_IF(b_scale != nullptr && static_cast(b_scale->Shape()[0]) != blocks_n, "DynamicQuantMatMulFp8 requires B scale N dimension to be N / block_size_n."); + ORT_RETURN_IF(b_scale != nullptr && static_cast(b_scale->Shape()[1]) != blocks_k, + "DynamicQuantMatMulFp8 requires B scale K dimension to be K / block_size_k."); - size_t a_scale_prefix = 1; - if (a_scale_rank > 2) { - for (size_t dim = 0; dim < a_scale_rank - 2; ++dim) { - a_scale_prefix = SafeMul( - a_scale_prefix, static_cast(a_scale->Shape()[dim])); - } - } const size_t a_scale_batch_stride = SafeMul(blocks_m, blocks_k); - const size_t a_zp_count = SafeMul(a_scale_prefix, a_scale_batch_stride); const size_t b_zp_count = SafeMul(blocks_k, blocks_n); - ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*a_zero_point, a_zp_count, "A zero point")); - ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*b_zero_point, b_zp_count, "B zero point")); + if (b_zero_point != nullptr) { + ORT_RETURN_IF(b_zero_point->Shape().NumDimensions() != 2, + "DynamicQuantMatMulFp8 requires B zero point to be a 2D tensor."); + ORT_RETURN_IF(b_zero_point->Shape()[0] != static_cast(blocks_n) || + b_zero_point->Shape()[1] != static_cast(blocks_k), + "DynamicQuantMatMulFp8 requires B zero point to have shape [N / block_size_n, K / block_size_k]."); + ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*b_zero_point, b_zp_count, "B zero point")); + } AllocatorPtr temp_allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&temp_allocator)); - const float* a_scales = nullptr; - IAllocatorUniquePtr a_scale_float; - const size_t a_scale_elems = static_cast(a_scale->Shape().Size()); - if (a_scale->IsDataType()) { - a_scales = a_scale->Data(); - } else { - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - a_scale_float = IAllocator::MakeUniquePtr(allocator, 
a_scale_elems, true); - if (a_scale->IsDataType()) { - for (size_t i = 0; i < a_scale_elems; ++i) { - a_scale_float.get()[i] = static_cast(a_scale->Data()[i]); - } - } else if (a_scale->IsDataType()) { - for (size_t i = 0; i < a_scale_elems; ++i) { - a_scale_float.get()[i] = static_cast(a_scale->Data()[i]); - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "DynamicQuantMatMulFp8 requires A scale input to be float, float16, or bfloat16."); - } - a_scales = a_scale_float.get(); - } - const float* b_scales = nullptr; IAllocatorUniquePtr b_scale_float; - const size_t b_scale_elems = static_cast(b_scale->Shape().Size()); - if (b_scale->IsDataType()) { + size_t b_scale_elems = 0; + if (b_scales_) { + b_scales = b_scales_.get(); + b_scale_elems = b_scale_count_; + } else if (b_scale->IsDataType()) { b_scales = b_scale->Data(); + b_scale_elems = static_cast(b_scale->Shape().Size()); } else { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + b_scale_elems = static_cast(b_scale->Shape().Size()); b_scale_float = IAllocator::MakeUniquePtr(allocator, b_scale_elems, true); if (b_scale->IsDataType()) { for (size_t i = 0; i < b_scale_elems; ++i) { @@ -746,7 +688,6 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { gemm_shape.N = N; gemm_shape.K = K; - ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(a_scales, a_scale_elems, "A scale")); ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(b_scales, b_scale_elems, "B scale")); const size_t a_fp8_size = SafeMul(M, K); @@ -754,20 +695,22 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { ORT_RETURN_IF(a_num_elements % a_fp8_size != 0, "DynamicQuantMatMulFp8 requires A to contain complete MxK matrices."); const size_t a_batch_count = a_num_elements / a_fp8_size; - ORT_RETURN_IF(a_scale_prefix != 1 && a_scale_prefix != a_batch_count, - "DynamicQuantMatMulFp8 requires A scale batch dimensions to match A."); // Quantize the 
physical A tensor once. Broadcasted output GEMMs then reuse the same FP8 A slice. auto a_fp8_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_num_elements, true); + const size_t a_scale_count = SafeMul(a_batch_count, a_scale_batch_stride); + auto a_scale_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_scale_count, true); const size_t a_quant_work_items = SafeMul(a_batch_count, a_scale_batch_stride); ORT_RETURN_IF(a_quant_work_items > static_cast(std::numeric_limits::max()), "DynamicQuantMatMulFp8 A quantization work item count exceeds ptrdiff_t range."); const auto a_quant_work_items_i = static_cast(a_quant_work_items); - const size_t a_quant_block_elems = SafeMul(block_size_m_, block_size_k_); + const size_t a_quant_block_elems = block_size_k_; const TensorOpCost a_quant_unit_cost{ static_cast(SafeMul(a_quant_block_elems, sizeof(float))), static_cast(SafeMul(a_quant_block_elems, sizeof(uint8_t))), static_cast(a_quant_block_elems) * 2.0}; + float fp8_max_abs = 0.0f; + ORT_RETURN_IF_ERROR(GetFp8MaxAbs(b_type, fp8_max_abs)); const auto quantize_a_batches = [&](const auto* a_data) { concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), a_quant_work_items_i, a_quant_unit_cost, @@ -776,16 +719,16 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { const size_t work_idx = static_cast(tid); const size_t a_batch_idx = work_idx / a_scale_batch_stride; const size_t scale_block_idx = work_idx % a_scale_batch_stride; - const size_t block_m = scale_block_idx / blocks_k; + const size_t row = scale_block_idx / blocks_k; const size_t block_k = scale_block_idx % blocks_k; const size_t a_batch_offset = a_batch_idx * a_fp8_size; - const size_t scale_batch_index = (a_scale_prefix == 1) ? 
0 : a_batch_idx; - const size_t a_scale_batch_offset = scale_batch_index * a_scale_batch_stride; - QuantizeBlockwiseFp8ABlock(a_data + a_batch_offset, - M, K, block_size_m_, block_size_k_, blocks_k, - block_m, block_k, - a_scales + a_scale_batch_offset, a_type, - a_fp8_buffer.get() + a_batch_offset); + const size_t a_scale_batch_offset = a_batch_idx * a_scale_batch_stride; + QuantizeBlockwiseFp8ABlockDynamic( + a_data + a_batch_offset, + K, block_size_k_, blocks_k, + row, block_k, fp8_max_abs, b_type, + a_fp8_buffer.get() + a_batch_offset, + a_scale_buffer.get() + a_scale_batch_offset); } }); }; @@ -805,11 +748,11 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { const size_t a_offset = helper.LeftOffsets()[gemm_idx]; ORT_RETURN_IF(a_offset >= a_num_elements || (a_offset % a_fp8_size) != 0, "DynamicQuantMatMulFp8 requires A offsets to reference complete MxK matrices."); - const size_t scale_batch_index = (a_scale_prefix == 1) ? 0 : a_offset / a_fp8_size; - ORT_RETURN_IF(scale_batch_index >= a_scale_prefix, - "DynamicQuantMatMulFp8 requires A scale batch dimensions to match A."); + const size_t scale_batch_index = a_offset / a_fp8_size; + ORT_RETURN_IF(scale_batch_index >= a_batch_count, + "DynamicQuantMatMulFp8 requires A offsets to reference complete MxK matrices."); const size_t a_scale_batch_offset = SafeMul(scale_batch_index, a_scale_batch_stride); - const float* a_scales_batch = a_scales + a_scale_batch_offset; + const float* a_scales_batch = a_scale_buffer.get() + a_scale_batch_offset; auto& gemm_data = gemm_data_vec[gemm_idx]; gemm_data.A = a_fp8_buffer.get() + a_offset; gemm_data.lda = K; @@ -820,8 +763,8 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { gemm_data.ScaleA = a_scales_batch; gemm_data.ScaleB = b_scales; gemm_data.ScaleY = y_scale_data; - gemm_data.Fp8Type = a_type; - gemm_data.BlockSizeM = block_size_m_; + gemm_data.Fp8Type = b_type; + gemm_data.BlockSizeM = 1; gemm_data.BlockSizeK = 
block_size_k_; gemm_data.BlockSizeN = block_size_n_; gemm_data.BlocksM = blocks_m; @@ -829,8 +772,8 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { gemm_data.BlocksN = blocks_n; gemm_data.ScaleAStrideK = 1; gemm_data.ScaleAStrideM = blocks_k; - gemm_data.ScaleBStrideN = 1; - gemm_data.ScaleBStrideK = blocks_n; + gemm_data.ScaleBStrideN = blocks_k; + gemm_data.ScaleBStrideK = 1; } MlasFp8GemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, context->GetOperatorThreadPool()); diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h index 62360e80fd180..cef88def0ec77 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h @@ -24,13 +24,11 @@ class DynamicQuantMatMulFp8 final : public OpKernel { enum InputTensors : int { IN_A = 0, - IN_A_SCALE = 1, - IN_A_ZERO_POINT = 2, - IN_B = 3, - IN_B_SCALE = 4, - IN_B_ZERO_POINT = 5, - IN_Y_SCALE = 6, - IN_Y_ZERO_POINT = 7 + IN_B = 1, + IN_B_SCALE = 2, + IN_B_ZERO_POINT = 3, + IN_Y_SCALE = 4, + IN_Y_ZERO_POINT = 5 }; enum OutputTensors : int { OUT_Y = 0 }; @@ -47,10 +45,12 @@ class DynamicQuantMatMulFp8 final : public OpKernel { static constexpr int GetBIdx() { return IN_B; } IAllocatorUniquePtr quantized_b_; size_t quantized_b_size_{0}; + IAllocatorUniquePtr b_scales_; + size_t b_scale_count_{0}; TensorShape b_shape_; mlas_fp8_mode b_type_{static_cast(0)}; bool has_b_type_{false}; - size_t block_size_m_{128}; + mlas_fp8_mode fp8_type_{MLAS_FP8_MODE_E4M3_INF}; size_t block_size_k_{128}; size_t block_size_n_{128}; }; diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 1c459ef1e9fbf..8601b53cc15a1 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ 
-952,39 +952,38 @@ ONNX_MS_OPERATOR_SET_SCHEMA( DynamicQuantMatMulFp8, 1, OpSchema() .SetDoc("Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from " - "float16/bfloat16/float) and runtime casting of activations to fp8 using block-wise scales. " - "All zero-point inputs, when provided, must encode 0.0.") + "float16/bfloat16/float) and dynamic runtime quantization of activations to fp8 using " + "internally computed block-wise scales. All zero-point inputs, when provided, must encode 0.0.") .Input(0, "A", "Input tensor A.", "TA") - .Input(1, "A_scale", - "Scale of quantized input 'A'. Must be a block-wise tensor with shape " - "(ceil(M / block_size_m), K / block_size_k), or the same shape with A batch dimensions prefixed.", - "TS") - .Input(2, "A_zero_point", - "Zero point tensor for input 'A'. Must have the same shape as A_scale and all values must encode 0.0.", - "TZ") - .Input(3, "B", + .Input(1, "B", "Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only " "supported when B is a constant initializer that can be quantized during prepack.", "TB") - .Input(4, "B_scale", - "Scale of input 'B'. Must be a block-wise tensor with shape " - "(K / block_size_k, N / block_size_n).", - "TS") - .Input(5, "B_zero_point", + .Input(2, "B_scale", + "Scale of FP8 input 'B'. Must be a block-wise tensor with shape " + "(N / block_size_n, K / block_size_k). Required when B is already FP8. Ignored for non-FP8 " + "constant B, where scales are computed during prepack.", + "TS", OpSchema::Optional) + .Input(3, "B_zero_point", "Zero point tensor for input 'B'. Must have the same shape as B_scale and all values must encode 0.0.", - "TZ") - .Input(6, "Y_scale", "Scale of output 'Y'. Must be a scalar when provided.", "TS", + "TZ", OpSchema::Optional) + .Input(4, "Y_scale", "Scale of output 'Y'. 
Must be a scalar when provided.", "TS", OpSchema::Optional) - .Input(7, "Y_zero_point", + .Input(5, "Y_zero_point", "Zero point tensor for output 'Y'. Must be a scalar encoding 0.0 when provided.", "TZ", OpSchema::Optional) .Output(0, "Y", "Output tensor of shape (..., M, N).", "TY") - .Attr("block_size_m", "Block size along M for A block-wise scales.", AttributeProto::INT, - static_cast(128)) + .Attr("block_size_m", "Block size along M for A block-wise scales. Must be 1.", + AttributeProto::INT, static_cast(1)) .Attr("block_size_k", "Block size along K for A and B block-wise scales.", AttributeProto::INT, static_cast(128)) .Attr("block_size_n", "Block size along N for B block-wise scales.", AttributeProto::INT, static_cast(128)) + .Attr("fp8_type", + "FP8 TensorProto data type used when non-FP8 constant B is dynamically quantized during prepack. " + "Defaults to FLOAT8E4M3FN.", + AttributeProto::INT, + static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN)) .TypeConstraint("TA", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input A type to float16, bfloat16, or float.") .TypeConstraint("TB", @@ -1000,144 +999,71 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("TY", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain output type to float16, bfloat16, or float.") .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { - const int64_t block_size_m = getAttribute(ctx, "block_size_m", static_cast(128)); + const int64_t block_size_m = getAttribute(ctx, "block_size_m", static_cast(1)); const int64_t block_size_k = getAttribute(ctx, "block_size_k", static_cast(128)); const int64_t block_size_n = getAttribute(ctx, "block_size_n", static_cast(128)); - if (block_size_m <= 0 || block_size_k <= 0 || block_size_n <= 0) { - fail_type_inference("block_size_m, block_size_k, and block_size_n must be greater than zero."); + if (block_size_m != 1) { + fail_type_inference("block_size_m must be 1."); } - const auto ceil_div = 
[](int64_t value, int64_t divisor) { - return value == 0 ? int64_t{0} : ((value - 1) / divisor) + 1; - }; - - if (hasInputShape(ctx, 0) && hasInputShape(ctx, 3)) { - ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 3); + if (block_size_k <= 0 || block_size_n <= 0) { + fail_type_inference("block_size_k and block_size_n must be greater than zero."); } if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { - auto& a_shape = getInputShape(ctx, 0); - auto& a_scale_shape = getInputShape(ctx, 1); - const int a_rank = a_shape.dim_size(); - if (a_rank < 2) { - fail_type_inference("A must be at least 2D."); - } - const int a_scale_rank = a_scale_shape.dim_size(); - if (a_scale_rank < 2) { - fail_type_inference("A scale must have rank 2 or the same rank as A."); - } - if (a_scale_rank != 2) { - if (a_scale_rank != a_rank) { - fail_type_inference("A scale must have rank 2 or the same rank as A."); - } - for (int i = 0; i < a_rank - 2; ++i) { - if (a_shape.dim(i).has_dim_value() && a_scale_shape.dim(i).has_dim_value() && - a_shape.dim(i).dim_value() != a_scale_shape.dim(i).dim_value()) { - fail_type_inference("A scale batch dimensions must match A."); - } - } - } - if (a_shape.dim(a_rank - 2).has_dim_value() && a_scale_shape.dim(a_scale_rank - 2).has_dim_value() && - a_shape.dim(a_rank - 1).has_dim_value() && a_scale_shape.dim(a_scale_rank - 1).has_dim_value()) { - const auto m = a_shape.dim(a_rank - 2).dim_value(); - const auto k = a_shape.dim(a_rank - 1).dim_value(); - const auto m_blocks = ceil_div(m, block_size_m); - if (a_scale_shape.dim(a_scale_rank - 2).dim_value() != m_blocks) { - fail_type_inference("A scale second-to-last dimension must be ceil(M / block_size_m)."); - } - if ((k % block_size_k) != 0 || - a_scale_shape.dim(a_scale_rank - 1).dim_value() != (k / block_size_k)) { - fail_type_inference("A scale last dimension must be K / block_size_k."); - } - } - } - if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { - auto& a_shape = getInputShape(ctx, 
0); - auto& a_zp_shape = getInputShape(ctx, 2); - const int a_rank = a_shape.dim_size(); - if (a_rank < 2) { - fail_type_inference("A must be at least 2D."); - } - const int a_zp_rank = a_zp_shape.dim_size(); - if (a_zp_rank < 2) { - fail_type_inference("A zero point must have rank 2 or the same rank as A."); - } - if (a_zp_rank != 2) { - if (a_zp_rank != a_rank) { - fail_type_inference("A zero point must have rank 2 or the same rank as A."); - } - for (int i = 0; i < a_rank - 2; ++i) { - if (a_shape.dim(i).has_dim_value() && a_zp_shape.dim(i).has_dim_value() && - a_shape.dim(i).dim_value() != a_zp_shape.dim(i).dim_value()) { - fail_type_inference("A zero point batch dimensions must match A."); - } - } - } - if (a_shape.dim(a_rank - 2).has_dim_value() && a_zp_shape.dim(a_zp_rank - 2).has_dim_value() && - a_shape.dim(a_rank - 1).has_dim_value() && a_zp_shape.dim(a_zp_rank - 1).has_dim_value()) { - const auto m = a_shape.dim(a_rank - 2).dim_value(); - const auto k = a_shape.dim(a_rank - 1).dim_value(); - const auto m_blocks = ceil_div(m, block_size_m); - if (a_zp_shape.dim(a_zp_rank - 2).dim_value() != m_blocks) { - fail_type_inference("A zero point second-to-last dimension must be ceil(M / block_size_m)."); - } - if ((k % block_size_k) != 0 || - a_zp_shape.dim(a_zp_rank - 1).dim_value() != (k / block_size_k)) { - fail_type_inference("A zero point last dimension must be K / block_size_k."); - } - } + ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1); } - if (hasInputShape(ctx, 6)) { - auto shape = ctx.getInputType(6)->tensor_type().shape(); + if (hasInputShape(ctx, 4)) { + auto shape = ctx.getInputType(4)->tensor_type().shape(); if (shape.dim_size() != 0) { fail_type_inference("Y scale input must be a scalar."); } } - if (hasInputShape(ctx, 7)) { - auto shape = ctx.getInputType(7)->tensor_type().shape(); + if (hasInputShape(ctx, 5)) { + auto shape = ctx.getInputType(5)->tensor_type().shape(); if (shape.dim_size() != 0) { fail_type_inference("Y zero 
point input must be a scalar."); } } - if (hasInputShape(ctx, 3) && hasInputShape(ctx, 4)) { - auto& b_shape = getInputShape(ctx, 3); - auto& b_scale_shape = getInputShape(ctx, 4); + if (hasInputShape(ctx, 1) && hasInputShape(ctx, 2)) { + auto& b_shape = getInputShape(ctx, 1); + auto& b_scale_shape = getInputShape(ctx, 2); if (b_shape.dim_size() != 2) { fail_type_inference("B must be 2D."); } if (b_scale_shape.dim_size() != 2) { fail_type_inference("B scale must be 2D."); } - if (b_shape.dim(1).has_dim_value() && b_scale_shape.dim(1).has_dim_value()) { + if (b_shape.dim(1).has_dim_value() && b_scale_shape.dim(0).has_dim_value()) { const auto n = b_shape.dim(1).dim_value(); - if ((n % block_size_n) != 0 || b_scale_shape.dim(1).dim_value() != (n / block_size_n)) { - fail_type_inference("B scale last dimension must be N / block_size_n."); + if ((n % block_size_n) != 0 || b_scale_shape.dim(0).dim_value() != (n / block_size_n)) { + fail_type_inference("B scale first dimension must be N / block_size_n."); } } - if (b_shape.dim(0).has_dim_value() && b_scale_shape.dim(0).has_dim_value()) { + if (b_shape.dim(0).has_dim_value() && b_scale_shape.dim(1).has_dim_value()) { const auto k = b_shape.dim(0).dim_value(); - if ((k % block_size_k) != 0 || b_scale_shape.dim(0).dim_value() != (k / block_size_k)) { - fail_type_inference("B scale first dimension must be K / block_size_k."); + if ((k % block_size_k) != 0 || b_scale_shape.dim(1).dim_value() != (k / block_size_k)) { + fail_type_inference("B scale last dimension must be K / block_size_k."); } } } - if (hasInputShape(ctx, 3) && hasInputShape(ctx, 5)) { - auto& b_shape = getInputShape(ctx, 3); - auto& b_zp_shape = getInputShape(ctx, 5); + if (hasInputShape(ctx, 1) && hasInputShape(ctx, 3)) { + auto& b_shape = getInputShape(ctx, 1); + auto& b_zp_shape = getInputShape(ctx, 3); if (b_shape.dim_size() != 2) { fail_type_inference("B must be 2D."); } if (b_zp_shape.dim_size() != 2) { fail_type_inference("B zero point must be 2D."); } 
- if (b_shape.dim(1).has_dim_value() && b_zp_shape.dim(1).has_dim_value()) { + if (b_shape.dim(1).has_dim_value() && b_zp_shape.dim(0).has_dim_value()) { const auto n = b_shape.dim(1).dim_value(); - if ((n % block_size_n) != 0 || b_zp_shape.dim(1).dim_value() != (n / block_size_n)) { - fail_type_inference("B zero point last dimension must be N / block_size_n."); + if ((n % block_size_n) != 0 || b_zp_shape.dim(0).dim_value() != (n / block_size_n)) { + fail_type_inference("B zero point first dimension must be N / block_size_n."); } } - if (b_shape.dim(0).has_dim_value() && b_zp_shape.dim(0).has_dim_value()) { + if (b_shape.dim(0).has_dim_value() && b_zp_shape.dim(1).has_dim_value()) { const auto k = b_shape.dim(0).dim_value(); - if ((k % block_size_k) != 0 || b_zp_shape.dim(0).dim_value() != (k / block_size_k)) { - fail_type_inference("B zero point first dimension must be K / block_size_k."); + if ((k % block_size_k) != 0 || b_zp_shape.dim(1).dim_value() != (k / block_size_k)) { + fail_type_inference("B zero point last dimension must be K / block_size_k."); } } } diff --git a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc index 08f5f1cfe5f91..6026ea13c239d 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc @@ -2,6 +2,11 @@ // SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates // SPDX-License-Identifier: MIT +#include +#include +#include +#include + #include "gtest/gtest.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" @@ -24,87 +29,198 @@ class DynamicQuantMatMulFp8SessionTester : public OpTester { using OpTester::OpTester; }; -float QuantizeDequantizeE4M3(float value, float scale) { - return Float8E4M3FN(value / scale, true).ToFloat() * scale; +template +struct Fp8TensorProtoType; + +template <> +struct Fp8TensorProtoType { + 
static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN; +}; + +template <> +struct Fp8TensorProtoType { + static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ; +}; + +template <> +struct Fp8TensorProtoType { + static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; +}; + +template <> +struct Fp8TensorProtoType { + static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ; +}; + +template +float Fp8MaxAbs(); + +template <> +float Fp8MaxAbs() { + return 448.0f; +} + +template <> +float Fp8MaxAbs() { + return 448.0f; +} + +template <> +float Fp8MaxAbs() { + return 57344.0f; +} + +template <> +float Fp8MaxAbs() { + return 57344.0f; } -std::vector ComputeExpectedIdentityAWithQuantizedB(gsl::span b_data, - gsl::span b_scale, - int64_t k, - int64_t n, - int64_t block_size_k, - int64_t block_size_n) { +template +float QuantizeDequantize(float value, float scale) { + return Fp8T(value / scale, true).ToFloat() * scale; +} + +template +std::vector ComputeRowwiseAQuantized(const std::vector& a_data, + int64_t m, + int64_t k, + int64_t block_size_k) { + const int64_t blocks_k = k / block_size_k; + std::vector result(a_data.size()); + for (int64_t row = 0; row < m; ++row) { + for (int64_t block_k = 0; block_k < blocks_k; ++block_k) { + const int64_t k_begin = block_k * block_size_k; + const int64_t k_end = k_begin + block_size_k; + float max_abs = 0.0f; + for (int64_t kk = k_begin; kk < k_end; ++kk) { + max_abs = std::max(max_abs, std::fabs(a_data[static_cast(row * k + kk)])); + } + const float scale = max_abs == 0.0f ? 
1.0f : max_abs / Fp8MaxAbs(); + for (int64_t kk = k_begin; kk < k_end; ++kk) { + const size_t index = static_cast(row * k + kk); + result[index] = QuantizeDequantize(a_data[index], scale); + } + } + } + return result; +} + +template +std::vector ComputeConstantBQuantized(const std::vector& b_data, + int64_t k, + int64_t n, + int64_t block_size_k, + int64_t block_size_n) { + const int64_t blocks_k = k / block_size_k; const int64_t blocks_n = n / block_size_n; - std::vector expected(b_data.size()); - for (int64_t row = 0; row < k; ++row) { + std::vector result(b_data.size()); + for (int64_t block_n = 0; block_n < blocks_n; ++block_n) { + const int64_t n_begin = block_n * block_size_n; + const int64_t n_end = n_begin + block_size_n; + for (int64_t block_k = 0; block_k < blocks_k; ++block_k) { + const int64_t k_begin = block_k * block_size_k; + const int64_t k_end = k_begin + block_size_k; + float max_abs = 0.0f; + for (int64_t kk = k_begin; kk < k_end; ++kk) { + for (int64_t nn = n_begin; nn < n_end; ++nn) { + max_abs = std::max(max_abs, std::fabs(b_data[static_cast(kk * n + nn)])); + } + } + const float scale = max_abs == 0.0f ? 
1.0f : max_abs / Fp8MaxAbs(); + for (int64_t kk = k_begin; kk < k_end; ++kk) { + for (int64_t nn = n_begin; nn < n_end; ++nn) { + const size_t index = static_cast(kk * n + nn); + result[index] = QuantizeDequantize(b_data[index], scale); + } + } + } + } + return result; +} + +template +std::vector ComputeRuntimeBQuantized(const std::vector& b_data, + const std::vector& b_scale, + int64_t k, + int64_t n, + int64_t block_size_k, + int64_t block_size_n) { + const int64_t blocks_k = k / block_size_k; + std::vector result(b_data.size()); + for (int64_t kk = 0; kk < k; ++kk) { + const int64_t block_k = kk / block_size_k; + for (int64_t nn = 0; nn < n; ++nn) { + const int64_t block_n = nn / block_size_n; + const size_t index = static_cast(kk * n + nn); + result[index] = b_data[index].ToFloat() * b_scale[static_cast(block_n * blocks_k + block_k)]; + } + } + return result; +} + +std::vector ComputeMatMul(const std::vector& a_data, + const std::vector& b_data, + int64_t m, + int64_t n, + int64_t k, + float y_scale = 1.0f) { + std::vector y_data(static_cast(m * n), 0.0f); + for (int64_t row = 0; row < m; ++row) { for (int64_t col = 0; col < n; ++col) { - const int64_t scale_idx = (row / block_size_k) * blocks_n + (col / block_size_n); - const size_t data_idx = static_cast(row * n + col); - expected[data_idx] = QuantizeDequantizeE4M3(b_data[data_idx], b_scale[scale_idx]); + float sum = 0.0f; + for (int64_t kk = 0; kk < k; ++kk) { + sum += a_data[static_cast(row * k + kk)] * b_data[static_cast(kk * n + col)]; + } + y_data[static_cast(row * n + col)] = sum * y_scale; } } - return expected; + return y_data; } template -void RunRuntimeFp8BInput() { +void AddZeroPoint(OpTester& test, const char* name, const std::vector& shape, size_t count, bool initializer) { + test.AddInput(name, shape, std::vector(count, Fp8T(0.0f)), initializer); +} + +template +void RunConstantBInputs() { constexpr int64_t M = 128; constexpr int64_t N = 128; constexpr int64_t K = 128; std::vector 
a_data(static_cast(M * K), 0.25f); - std::vector b_data(static_cast(K * N), Fp8T(-0.5f)); - const float expected_value = 0.25f * -0.5f * static_cast(K); - std::vector y_data(static_cast(M * N), expected_value); - - std::vector a_scale{1.0f}; - std::vector b_scale{1.0f}; - std::vector a_zp{Fp8T(0.0f)}; - std::vector b_zp{Fp8T(0.0f)}; - std::vector y_scale{1.0f}; - std::vector y_zp{Fp8T(0.0f)}; + std::vector b_data(static_cast(K * N), -0.5f); + const auto a_quantized = ComputeRowwiseAQuantized(a_data, M, K, 128); + const auto b_quantized = ComputeConstantBQuantized(b_data, K, N, 128, 128); + const auto y_data = ComputeMatMul(a_quantized, b_quantized, M, N, K); OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("fp8_type", Fp8TensorProtoType::value); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); - test.AddInput("B", {K, N}, b_data); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); test.AddOutput("Y", {M, N}, y_data); test.SetOutputAbsErr("Y", 0.5f); test.Run(); } template -void RunConstantBInputs() { +void RunRuntimeFp8BInput() { constexpr int64_t M = 128; constexpr int64_t N = 128; constexpr int64_t K = 128; std::vector a_data(static_cast(M * K), 0.25f); - std::vector b_data(static_cast(K * N), -0.5f); - const float expected_value = 0.25f * -0.5f * static_cast(K); - std::vector y_data(static_cast(M * N), expected_value); - - std::vector a_scale{1.0f}; + std::vector b_data(static_cast(K * N), Fp8T(-0.5f)); std::vector b_scale{1.0f}; - std::vector a_zp{Fp8T(0.0f)}; - std::vector b_zp{Fp8T(0.0f)}; - std::vector y_scale{1.0f}; - std::vector y_zp{Fp8T(0.0f)}; + const auto a_quantized = ComputeRowwiseAQuantized(a_data, M, K, 
128); + const auto b_quantized = ComputeRuntimeBQuantized(b_data, b_scale, K, N, 128, 128); + const auto y_data = ComputeMatMul(a_quantized, b_quantized, M, N, K); OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); - test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); test.AddOutput("Y", {M, N}, y_data); test.SetOutputAbsErr("Y", 0.5f); test.Run(); @@ -162,28 +278,38 @@ TEST(DynamicQuantMatMulFp8, WithConstantBInputsE5M2FNUZ) { RunConstantBInputs(); } +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInput) { + RunRuntimeFp8BInput(); +} + +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE4M3FNUZ) { + RunRuntimeFp8BInput(); +} + +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE5M2) { + RunRuntimeFp8BInput(); +} + +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE5M2FNUZ) { + RunRuntimeFp8BInput(); +} + TEST(DynamicQuantMatMulFp8, WithOmittedOutputQuantizationInputs) { constexpr int64_t M = 128; constexpr int64_t N = 128; constexpr int64_t K = 128; std::vector a_data(static_cast(M * K), 0.25f); - std::vector b_data(static_cast(K * N), -0.5f); - const float expected_value = 0.25f * -0.5f * static_cast(K); - std::vector y_data(static_cast(M * N), expected_value); - - std::vector a_scale{1.0f}; + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); std::vector b_scale{1.0f}; - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector b_zp{Float8E4M3FN(0.0f)}; + const auto y_data = 
ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeRuntimeBQuantized(b_data, b_scale, K, N, 128, 128), M, N, K); OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); - test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); test.AddOutput("Y", {M, N}, y_data); test.SetOutputAbsErr("Y", 0.5f); test.Run(); @@ -193,26 +319,20 @@ TEST(DynamicQuantMatMulFp8, WithOnlyYScale) { constexpr int64_t M = 128; constexpr int64_t N = 128; constexpr int64_t K = 128; + constexpr float YScale = 0.5f; std::vector a_data(static_cast(M * K), 0.25f); - std::vector b_data(static_cast(K * N), -0.5f); - constexpr float y_scale_value = 0.5f; - const float expected_value = 0.25f * -0.5f * static_cast(K) * y_scale_value; - std::vector y_data(static_cast(M * N), expected_value); - - std::vector a_scale{1.0f}; + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); std::vector b_scale{1.0f}; - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector b_zp{Float8E4M3FN(0.0f)}; - std::vector y_scale{y_scale_value}; + std::vector y_scale{YScale}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeRuntimeBQuantized(b_data, b_scale, K, N, 128, 128), M, N, K, YScale); OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); - test.AddInput("B", {K, N}, b_data, true /*initializer*/); - 
test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); test.AddInput("Y_scale", {}, y_scale); test.AddOutput("Y", {M, N}, y_data); test.SetOutputAbsErr("Y", 0.5f); @@ -225,22 +345,16 @@ TEST(DynamicQuantMatMulFp8, RejectsNonZeroYZeroPoint) { constexpr int64_t K = 128; std::vector a_data(static_cast(M * K), 0.25f); - std::vector b_data(static_cast(K * N), -0.5f); - std::vector y_data(static_cast(M * N), 0.0f); - - std::vector a_scale{1.0f}; + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); std::vector b_scale{1.0f}; - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector b_zp{Float8E4M3FN(0.0f)}; + std::vector y_data(static_cast(M * N), 0.0f); std::vector y_zp{Float8E4M3FN(1.0f)}; OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); - test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); test.AddOptionalInputEdge(); test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, y_data); @@ -248,32 +362,25 @@ TEST(DynamicQuantMatMulFp8, RejectsNonZeroYZeroPoint) { "DynamicQuantMatMulFp8 supports symmetric quantization only; Y zero point values must be zero."); } -TEST(DynamicQuantMatMulFp8, WithConstantBInputsBf16Scales) { +TEST(DynamicQuantMatMulFp8, WithRuntimeBInputsBf16Scales) 
{ constexpr int64_t M = 128; constexpr int64_t N = 128; constexpr int64_t K = 128; std::vector a_data(static_cast(M * K), 0.25f); - std::vector b_data(static_cast(K * N), -0.5f); - const float expected_value = 0.25f * -0.5f * static_cast(K); - std::vector y_data(static_cast(M * N), expected_value); - - std::vector a_scale = MakeBFloat16({1.0f}); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale_float{1.0f}; std::vector b_scale = MakeBFloat16({1.0f}); - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector b_zp{Float8E4M3FN(0.0f)}; std::vector y_scale = MakeBFloat16({1.0f}); - std::vector y_zp{Float8E4M3FN(0.0f)}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeRuntimeBQuantized(b_data, b_scale_float, K, N, 128, 128), M, N, K); OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); - test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, y_data); test.SetOutputAbsErr("Y", 0.5f); test.Run(); @@ -286,25 +393,12 @@ TEST(DynamicQuantMatMulFp8, Float16Output) { std::vector a_data(static_cast(M * K), 0.25f); std::vector b_data(static_cast(K * N), -0.5f); - const float expected_value = 0.25f * -0.5f * static_cast(K); - std::vector y_data(static_cast(M * N), expected_value); - - std::vector a_scale{1.0f}; - std::vector b_scale{1.0f}; - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector 
b_zp{Float8E4M3FN(0.0f)}; - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeConstantBQuantized(b_data, K, N, 128, 128), M, N, K); OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, FloatsToMLFloat16s(a_data)); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, FloatsToMLFloat16s(y_data)); test.Run(); } @@ -316,25 +410,12 @@ TEST(DynamicQuantMatMulFp8, BFloat16Output) { std::vector a_data(static_cast(M * K), 0.25f); std::vector b_data(static_cast(K * N), -0.5f); - const float expected_value = 0.25f * -0.5f * static_cast(K); - std::vector y_data(static_cast(M * N), expected_value); - - std::vector a_scale{1.0f}; - std::vector b_scale{1.0f}; - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector b_zp{Float8E4M3FN(0.0f)}; - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeConstantBQuantized(b_data, K, N, 128, 128), M, N, K); OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, FloatsToBFloat16s(a_data)); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); - test.AddInput("Y_scale", {}, y_scale); - 
test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, FloatsToBFloat16s(y_data)); test.Run(); } @@ -348,40 +429,12 @@ TEST(DynamicQuantMatMulFp8, RejectsNonConstantB) { std::vector b_data(static_cast(K * N), -0.5f); std::vector y_data(static_cast(M * N), 0.0f); - std::vector a_scale{1.0f}; - std::vector b_scale{1.0f}; - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector b_zp{Float8E4M3FN(0.0f)}; - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; - OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); test.AddInput("B", {K, N}, b_data); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, y_data); - test.Run(OpTester::ExpectResult::kExpectFailure); -} - -TEST(DynamicQuantMatMulFp8, RuntimeFp8BInput) { - RunRuntimeFp8BInput(); -} - -TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE4M3FNUZ) { - RunRuntimeFp8BInput(); -} - -TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE5M2) { - RunRuntimeFp8BInput(); -} - -TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE5M2FNUZ) { - RunRuntimeFp8BInput(); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires runtime B input to be FP8."); } TEST(DynamicQuantMatMulFp8, RejectsRuntimeFp8BZeroPointTypeMismatch) { @@ -391,56 +444,18 @@ TEST(DynamicQuantMatMulFp8, RejectsRuntimeFp8BZeroPointTypeMismatch) { std::vector a_data(static_cast(M * K), 0.25f); std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); - std::vector y_data(static_cast(M * N), 0.0f); - - std::vector a_scale{1.0f}; std::vector b_scale{1.0f}; - std::vector a_zp{Float8E4M3FN(0.0f)}; std::vector b_zp{Float8E5M2(0.0f)}; - std::vector 
y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; - - OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); - test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); - test.AddInput("B", {K, N}, b_data); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); - test.AddOutput("Y", {M, N}, y_data); - test.Run(OpTester::ExpectResult::kExpectFailure); -} - -TEST(DynamicQuantMatMulFp8, RejectsNonZeroAZeroPoint) { - constexpr int64_t M = 128; - constexpr int64_t N = 128; - constexpr int64_t K = 128; - - std::vector a_data(static_cast(M * K), 0.25f); - std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); std::vector y_data(static_cast(M * N), 0.0f); - std::vector a_scale{1.0f}; - std::vector b_scale{1.0f}; - std::vector a_zp{Float8E4M3FN(1.0f)}; - std::vector b_zp{Float8E4M3FN(0.0f)}; - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; - OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); test.AddInput("B", {K, N}, b_data); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddInput("B_zero_point", {N / 128, K / 128}, b_zp); test.AddOutput("Y", {M, N}, y_data); test.Run(OpTester::ExpectResult::kExpectFailure, - "DynamicQuantMatMulFp8 supports symmetric quantization only; A zero point values must be zero."); + "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); } TEST(DynamicQuantMatMulFp8, 
RejectsNonZeroBZeroPoint) { @@ -449,25 +464,16 @@ TEST(DynamicQuantMatMulFp8, RejectsNonZeroBZeroPoint) { constexpr int64_t K = 128; std::vector a_data(static_cast(M * K), 0.25f); - std::vector b_data(static_cast(K * N), -0.5f); - std::vector y_data(static_cast(M * N), 0.0f); - - std::vector a_scale{1.0f}; + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); std::vector b_scale{1.0f}; - std::vector a_zp{Float8E4M3FN(0.0f)}; std::vector b_zp{Float8E4M3FN(1.0f)}; - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; + std::vector y_data(static_cast(M * N), 0.0f); OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); - test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddInput("B_zero_point", {N / 128, K / 128}, b_zp); test.AddOutput("Y", {M, N}, y_data); test.Run(OpTester::ExpectResult::kExpectFailure, "DynamicQuantMatMulFp8 supports symmetric quantization only; B zero point values must be zero."); @@ -477,66 +483,66 @@ TEST(DynamicQuantMatMulFp8, NonDefaultBlockSizesWithPartialM) { constexpr int64_t M = 9; constexpr int64_t N = 4; constexpr int64_t K = 4; + constexpr int64_t BlockK = 2; + constexpr int64_t BlockN = 2; std::vector a_data(static_cast(M * K), 0.0f); for (int64_t m = 0; m < M; ++m) { for (int64_t k = 0; k < K; ++k) { - a_data[static_cast(m * K + k)] = (m >= 4 && m <= 7) ? 
0.04f : static_cast(1 << k); + a_data[static_cast(m * K + k)] = static_cast((m + 1) * (k + 1)) / 16.0f; } } - for (int64_t k = 0; k < K; ++k) { - a_data[static_cast(3 * K + k)] = 64.0f; - } - std::vector b_data(static_cast(K * N), 0.0f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(0.0f)); for (int64_t k = 0; k < K; ++k) { for (int64_t n = 0; n < N; ++n) { - b_data[static_cast(k * N + n)] = (k == n) ? 1.0f : 0.0f; + b_data[static_cast(k * N + n)] = Float8E4M3FN(k == n ? 1.0f : 0.0f); } } + std::vector b_scale{1.0f, 2.0f, + 3.0f, 4.0f}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, BlockK), + ComputeRuntimeBQuantized(b_data, b_scale, K, N, BlockK, BlockN), M, N, K); - const std::vector a_scale{1.0f, 1.0f, - 0.01f, 0.01f, - 1.0f, 1.0f}; - const std::vector b_scale{1.0f, 1.0f, - 1.0f, 1.0f}; - std::vector a_zp(a_scale.size(), Float8E4M3FN(0.0f)); - std::vector b_zp(b_scale.size(), Float8E4M3FN(0.0f)); - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("block_size_k", BlockK); + test.AddAttribute("block_size_n", BlockN); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / BlockN, K / BlockK}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / BlockN, K / BlockK}, b_scale.size(), false); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.01f); + test.Run(); +} +TEST(DynamicQuantMatMulFp8, RejectsUnsupportedBlockSizeM) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; std::vector y_data(static_cast(M * N), 0.0f); - for (int64_t m = 0; m < M; ++m) { - for (int64_t n = 0; n < N; ++n) { - float sum = 0.0f; - for (int64_t k = 0; k < K; ++k) { - sum += a_data[static_cast(m * K + k)] * 
b_data[static_cast(k * N + n)]; - } - y_data[static_cast(m * N + n)] = sum; - } - } OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); - test.AddAttribute("block_size_m", 4); - test.AddAttribute("block_size_k", 2); - test.AddAttribute("block_size_n", 2); + test.AddAttribute("block_size_m", 2); + test.AddAttribute("block_size_k", 4); + test.AddAttribute("block_size_n", 4); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {3, 2}, a_scale); - test.AddInput("A_zero_point", {3, 2}, a_zp); - test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {2, 2}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {2, 2}, b_zp, true /*initializer*/); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {1, 1}, b_scale); + AddZeroPoint(test, "B_zero_point", {1, 1}, b_scale.size(), false); test.AddOutput("Y", {M, N}, y_data); - test.SetOutputAbsErr("Y", 0.01f); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectFailure, "block_size_m must be 1"); } TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsRestoresPackedBMetadata) { constexpr int64_t M = 4; constexpr int64_t N = 4; constexpr int64_t K = 4; + constexpr int64_t BlockSize = 2; std::vector a_data{ 1.0f, 2.0f, 3.0f, 4.0f, @@ -548,33 +554,16 @@ TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsRestoresPackedBMetadata) { 0.0f, 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.0f}; - - std::vector y_data(a_data.size()); - for (size_t i = 0; i < a_data.size(); ++i) { - y_data[i] = 2.0f * a_data[i]; - } - - std::vector a_scale{1.0f, 1.0f, - 1.0f, 1.0f}; - std::vector b_scale{1.0f, 1.0f, - 1.0f, 1.0f}; - std::vector a_zp(a_scale.size(), Float8E5M2(0.0f)); - std::vector b_zp(b_scale.size(), Float8E5M2(0.0f)); - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E5M2(0.0f)}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 
BlockSize), + ComputeConstantBQuantized(b_data, K, N, BlockSize, BlockSize), + M, N, K); OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); - test.AddAttribute("block_size_m", 2); - test.AddAttribute("block_size_k", 2); - test.AddAttribute("block_size_n", 2); + test.AddAttribute("fp8_type", Fp8TensorProtoType::value); + test.AddAttribute("block_size_k", BlockSize); + test.AddAttribute("block_size_n", BlockSize); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {2, 2}, a_scale); - test.AddInput("A_zero_point", {2, 2}, a_zp); test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {2, 2}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {2, 2}, b_zp, true /*initializer*/); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, y_data); test.SetOutputAbsErr("Y", 0.01f); @@ -606,7 +595,7 @@ TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsRestoresPackedBMetadata) { ASSERT_EQ(shared_prepack_count, prepack_count_session_2); } -TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsWithDifferentBScaleKeepCorrectSemantics) { +TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsWithComputedBScalesReuseCorrectly) { constexpr int64_t M = 4; constexpr int64_t N = 4; constexpr int64_t K = 4; @@ -622,22 +611,9 @@ TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsWithDifferentBScaleKeepCorrect 1.20f, -2.30f, 3.40f, -4.50f, 5.50f, -6.25f, 7.75f, -8.50f, 9.00f, -10.50f, 11.25f, -12.75f}; - - std::vector a_scale{1.0f, 1.0f, - 1.0f, 1.0f}; - std::vector a_zp(a_scale.size(), Float8E4M3FN(0.0f)); - std::vector b_zp(a_scale.size(), Float8E4M3FN(0.0f)); - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; - - std::vector b_scale_1{1.0f, 1.0f, - 1.0f, 1.0f}; - std::vector b_scale_2{0.10f, 0.25f, - 0.50f, 2.00f}; - std::vector y_data_1 = ComputeExpectedIdentityAWithQuantizedB(b_data, b_scale_1, K, N, - BlockSize, BlockSize); - std::vector y_data_2 = 
ComputeExpectedIdentityAWithQuantizedB(b_data, b_scale_2, K, N, - BlockSize, BlockSize); + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, BlockSize), + ComputeConstantBQuantized(b_data, K, N, BlockSize, BlockSize), + M, N, K); OrtValue b; Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({K, N}), b_data.data(), @@ -646,18 +622,11 @@ TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsWithDifferentBScaleKeepCorrect PrepackedWeightsContainer prepacked_weights_container; DynamicQuantMatMulFp8SessionTester test_1("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); - test_1.AddAttribute("block_size_m", BlockSize); test_1.AddAttribute("block_size_k", BlockSize); test_1.AddAttribute("block_size_n", BlockSize); test_1.AddInput("A", {M, K}, a_data); - test_1.AddInput("A_scale", {2, 2}, a_scale); - test_1.AddInput("A_zero_point", {2, 2}, a_zp); test_1.AddInput("B", {K, N}, b_data, true /*initializer*/); - test_1.AddInput("B_scale", {2, 2}, b_scale_1, true /*initializer*/); - test_1.AddInput("B_zero_point", {2, 2}, b_zp, true /*initializer*/); - test_1.AddInput("Y_scale", {}, y_scale); - test_1.AddInput("Y_zero_point", {}, y_zp); - test_1.AddOutput("Y", {M, N}, y_data_1); + test_1.AddOutput("Y", {M, N}, y_data); test_1.SetOutputAbsErr("Y", 1e-5f); size_t shared_prepack_count = 0; @@ -666,95 +635,42 @@ TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsWithDifferentBScaleKeepCorrect ASSERT_EQ(prepacked_weights_container.GetNumberOfElements(), static_cast(1)); DynamicQuantMatMulFp8SessionTester test_2("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); - test_2.AddAttribute("block_size_m", BlockSize); test_2.AddAttribute("block_size_k", BlockSize); test_2.AddAttribute("block_size_n", BlockSize); test_2.AddInput("A", {M, K}, a_data); - test_2.AddInput("A_scale", {2, 2}, a_scale); - test_2.AddInput("A_zero_point", {2, 2}, a_zp); test_2.AddInput("B", {K, N}, b_data, true /*initializer*/); - test_2.AddInput("B_scale", {2, 2}, b_scale_2, true 
/*initializer*/); - test_2.AddInput("B_zero_point", {2, 2}, b_zp, true /*initializer*/); - test_2.AddInput("Y_scale", {}, y_scale); - test_2.AddInput("Y_zero_point", {}, y_zp); - test_2.AddOutput("Y", {M, N}, y_data_2); + test_2.AddOutput("Y", {M, N}, y_data); test_2.SetOutputAbsErr("Y", 1e-5f); RunDynamicQuantMatMulFp8WithSharedPrepack(test_2, b, prepacked_weights_container, shared_prepack_count); - ASSERT_EQ(shared_prepack_count, static_cast(0)); - ASSERT_EQ(prepacked_weights_container.GetNumberOfElements(), static_cast(2)); -} - -TEST(DynamicQuantMatMulFp8, RejectsMismatchedAScaleBatchPrefix) { - constexpr int64_t Batch = 2; - constexpr int64_t Seq = 3; - constexpr int64_t M = 128; - constexpr int64_t N = 128; - constexpr int64_t K = 128; - - std::vector a_data(static_cast(Batch * Seq * M * K), 0.25f); - std::vector b_data(static_cast(K * N), Float8E4M3FN(0.5f)); - std::vector y_data(static_cast(Batch * Seq * M * N), 0.0f); - - std::vector a_scale(static_cast(Seq * Batch), 1.0f); - std::vector a_zp(a_scale.size(), Float8E4M3FN(0.0f)); - std::vector b_scale{1.0f}; - std::vector b_zp{Float8E4M3FN(0.0f)}; - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; - - const std::vector a_dim_params{"batch", "seq", "128", "128"}; - const std::vector a_scale_dim_params{"seq", "batch", "1", "1"}; - - OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); - test.AddInput("A", {Batch, Seq, M, K}, a_data, false, &a_dim_params); - test.AddInput("A_scale", {Seq, Batch, 1, 1}, a_scale, false, &a_scale_dim_params); - test.AddInput("A_zero_point", {Seq, Batch, 1, 1}, a_zp, false, &a_scale_dim_params); - test.AddInput("B", {K, N}, b_data); - test.AddInput("B_scale", {1, 1}, b_scale); - test.AddInput("B_zero_point", {1, 1}, b_zp); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); - test.AddOutput("Y", {Batch, Seq, M, N}, y_data); - test.Run(OpTester::ExpectResult::kExpectFailure, - "DynamicQuantMatMulFp8 requires A scale 
batch dimensions to match A."); + ASSERT_GT(shared_prepack_count, static_cast(0)); + ASSERT_EQ(prepacked_weights_container.GetNumberOfElements(), static_cast(1)); } -TEST(DynamicQuantMatMulFp8, RejectsMalformedAScaleShapeBeforeReadingScaleData) { +TEST(DynamicQuantMatMulFp8, RejectsMalformedBScaleShapeBeforeReadingScaleData) { constexpr int64_t M = 4; constexpr int64_t N = 4; constexpr int64_t K = 4; std::vector a_data(static_cast(M * K), 0.25f); std::vector b_data(static_cast(K * N), Float8E4M3FN(0.5f)); - std::vector y_data(static_cast(M * N), 0.0f); - - std::vector a_scale = FloatsToMLFloat16s({1.0f}); - std::vector a_zp{Float8E4M3FN(0.0f)}; std::vector b_scale = FloatsToMLFloat16s({1.0f}); - std::vector b_zp{Float8E4M3FN(0.0f)}; - std::vector y_scale = FloatsToMLFloat16s({1.0f}); - std::vector y_zp{Float8E4M3FN(0.0f)}; + std::vector y_data(static_cast(M * N), 0.0f); OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); - test.AddAttribute("block_size_m", 4); test.AddAttribute("block_size_k", 4); test.AddAttribute("block_size_n", 4); test.AddShapeToTensorData(false); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {1}, a_scale); - test.AddInput("A_zero_point", {1}, a_zp); test.AddInput("B", {K, N}, b_data); - test.AddInput("B_scale", {1, 1}, b_scale); - test.AddInput("B_zero_point", {1, 1}, b_zp); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); + test.AddInput("B_scale", {1}, b_scale); + AddZeroPoint(test, "B_zero_point", {1, 1}, 1, false); test.AddOutput("Y", {M, N}, y_data); test.Run(OpTester::ExpectResult::kExpectFailure, - "DynamicQuantMatMulFp8 requires A scale to have rank >= 2."); + "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor."); } -TEST(DynamicQuantMatMulFp8, RejectsMalformedBScaleShapeBeforeReadingScaleData) { +TEST(DynamicQuantMatMulFp8, RejectsRuntimeFp8BWithoutBScale) { constexpr int64_t M = 4; constexpr int64_t N = 4; constexpr int64_t K = 4; @@ -763,29 +679,14 @@ 
TEST(DynamicQuantMatMulFp8, RejectsMalformedBScaleShapeBeforeReadingScaleData) { std::vector b_data(static_cast(K * N), Float8E4M3FN(0.5f)); std::vector y_data(static_cast(M * N), 0.0f); - std::vector a_scale = FloatsToMLFloat16s({1.0f}); - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector b_scale = FloatsToMLFloat16s({1.0f}); - std::vector b_zp{Float8E4M3FN(0.0f)}; - std::vector y_scale = FloatsToMLFloat16s({1.0f}); - std::vector y_zp{Float8E4M3FN(0.0f)}; - OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); - test.AddAttribute("block_size_m", 4); test.AddAttribute("block_size_k", 4); test.AddAttribute("block_size_n", 4); - test.AddShapeToTensorData(false); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {1, 1}, a_scale); - test.AddInput("A_zero_point", {1, 1}, a_zp); test.AddInput("B", {K, N}, b_data); - test.AddInput("B_scale", {1}, b_scale); - test.AddInput("B_zero_point", {1}, b_zp); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, y_data); test.Run(OpTester::ExpectResult::kExpectFailure, - "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor."); + "DynamicQuantMatMulFp8 requires B scale when B is already FP8."); } TEST(DynamicQuantMatMulFp8, ZeroMInput) { @@ -793,26 +694,13 @@ TEST(DynamicQuantMatMulFp8, ZeroMInput) { constexpr int64_t N = 128; constexpr int64_t K = 128; - std::vector a_data(static_cast(M * K), 0.0f); + std::vector a_data{}; std::vector b_data(static_cast(K * N), 0.0f); std::vector y_data{}; - std::vector a_scale{}; - std::vector b_scale{1.0f}; - std::vector a_zp{}; - std::vector b_zp{Float8E4M3FN(0.0f)}; - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; - OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); test.AddInput("B", {K, N}, b_data, true 
/*initializer*/); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, y_data); test.Run(); } @@ -826,22 +714,9 @@ TEST(DynamicQuantMatMulFp8, ZeroKInput) { std::vector b_data{}; std::vector y_data(static_cast(M * N), 0.0f); - std::vector a_scale{}; - std::vector b_scale{}; - std::vector a_zp{}; - std::vector b_zp{}; - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; - OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, y_data); test.Run(); } @@ -854,24 +729,15 @@ TEST(DynamicQuantMatMulFp8, ZeroKInputRejectsInvalidYScaleShape) { std::vector a_data{}; std::vector b_data{}; std::vector y_data(static_cast(M * N), 0.0f); - - std::vector a_scale{}; - std::vector b_scale{}; - std::vector a_zp{}; - std::vector b_zp{}; std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddShapeToTensorData(false); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + 
test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); test.AddInput("Y_scale", {1}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, y_data); test.Run(OpTester::ExpectResult::kExpectFailure, "DynamicQuantMatMulFp8 requires Y scale input to be a scalar."); @@ -885,23 +751,14 @@ TEST(DynamicQuantMatMulFp8, ZeroKInputRejectsInvalidYScaleValue) { std::vector a_data{}; std::vector b_data{}; std::vector y_data(static_cast(M * N), 0.0f); - - std::vector a_scale{}; - std::vector b_scale{}; - std::vector a_zp{}; - std::vector b_zp{}; std::vector y_scale{0.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, y_data); test.Run(OpTester::ExpectResult::kExpectFailure, "Y scale values to be finite and positive."); @@ -915,26 +772,16 @@ TEST(DynamicQuantMatMulFp8, ZeroKInputRejectsInvalidYScaleType) { std::vector a_data{}; std::vector b_data{}; std::vector y_data(static_cast(M * N), 0.0f); - - std::vector a_scale{}; - std::vector b_scale{}; - std::vector a_zp{}; - std::vector b_zp{}; std::vector y_scale{1}; - std::vector y_zp{Float8E4M3FN(0.0f)}; OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); test.AddInput("B", {K, N}, b_data, true /*initializer*/); - test.AddInput("B_scale", {K / 
128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, y_data); - test.Run(OpTester::ExpectResult::kExpectFailure, - "Type Error"); + test.Run(OpTester::ExpectResult::kExpectFailure, "Type Error"); } TEST(DynamicQuantMatMulFp8, ZeroNInput) { @@ -945,23 +792,12 @@ TEST(DynamicQuantMatMulFp8, ZeroNInput) { std::vector a_data(static_cast(M * K), 0.0f); std::vector b_data{}; std::vector y_data{}; - - std::vector a_scale{1.0f}; std::vector b_scale{}; - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector b_zp{}; - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); test.AddInput("B", {K, N}, b_data); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); test.AddOutput("Y", {M, N}, y_data); test.Run(); } @@ -974,29 +810,21 @@ TEST(DynamicQuantMatMulFp8, ZeroNInputRejectsInvalidYScaleValue) { std::vector a_data(static_cast(M * K), 0.0f); std::vector b_data{}; std::vector y_data{}; - - std::vector a_scale{1.0f}; std::vector b_scale{}; - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector b_zp{}; std::vector y_scale{0.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K 
/ 128}, a_zp); test.AddInput("B", {K, N}, b_data); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddOptionalInputEdge(); test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); test.AddOutput("Y", {M, N}, y_data); test.Run(OpTester::ExpectResult::kExpectFailure, "Y scale values to be finite and positive."); } -TEST(DynamicQuantMatMulFp8, ZeroNInputRejectsNonFp8B) { +TEST(DynamicQuantMatMulFp8, ZeroNInputWithConstantNonFp8B) { constexpr int64_t M = 128; constexpr int64_t N = 0; constexpr int64_t K = 128; @@ -1005,25 +833,11 @@ TEST(DynamicQuantMatMulFp8, ZeroNInputRejectsNonFp8B) { std::vector b_data{}; std::vector y_data{}; - std::vector a_scale{1.0f}; - std::vector b_scale{}; - std::vector a_zp{Float8E4M3FN(0.0f)}; - std::vector b_zp{}; - std::vector y_scale{1.0f}; - std::vector y_zp{Float8E4M3FN(0.0f)}; - OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); test.AddInput("A", {M, K}, a_data); - test.AddInput("A_scale", {M / 128, K / 128}, a_scale); - test.AddInput("A_zero_point", {M / 128, K / 128}, a_zp); - test.AddInput("B", {K, N}, b_data); - test.AddInput("B_scale", {K / 128, N / 128}, b_scale, true /*initializer*/); - test.AddInput("B_zero_point", {K / 128, N / 128}, b_zp, true /*initializer*/); - test.AddInput("Y_scale", {}, y_scale); - test.AddInput("Y_zero_point", {}, y_zp); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); test.AddOutput("Y", {M, N}, y_data); - test.Run(OpTester::ExpectResult::kExpectFailure, - "DynamicQuantMatMulFp8 requires runtime B input to be FP8."); + test.Run(); } } // namespace test From 3696f3673e1f581cd1be5d914081d00744f4fe5d Mon Sep 17 00:00:00 2001 From: melkap01 Date: Thu, 21 May 2026 13:06:16 +0100 Subject: [PATCH 11/11] review comments addressed, docs patched Signed-off-by: melkap01 --- 
docs/ContribOperators.md | 8 +++---- .../quantization/dynamic_quant_matmul_fp8.cc | 17 ++++++++------- .../graph/contrib_ops/quantization_defs.cc | 12 +++++------ onnxruntime/core/mlas/inc/mlas.h | 6 ++++-- .../dynamic_quant_matmul_fp8_test.cc | 21 +++++++++++++++++++ 5 files changed, 44 insertions(+), 20 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 3453e32239571..454051fe896fa 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1505,13 +1505,13 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
-
block_size_k : int (default is 128)
+
block_size_k : int
Block size along K for A and B block-wise scales.
-
block_size_m : int (default is 1)
+
block_size_m : int
Block size along M for A block-wise scales. Must be 1.
-
block_size_n : int (default is 128)
+
block_size_n : int
Block size along N for B block-wise scales.
-
fp8_type : int (default is 17)
+
fp8_type : int
FP8 TensorProto data type used when non-FP8 constant B is dynamically quantized during prepack. Defaults to FLOAT8E4M3FN.
diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc index bc0204863eb59..350de4ac9cfa3 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -577,18 +577,19 @@ Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { b_type = b_type_; } else if (b_is_fp8) { ORT_RETURN_IF_ERROR(GetFp8Type(b_elem_type, b_type)); - if (b_zero_point != nullptr) { - const auto b_zp_elem_type = - static_cast(b_zero_point->GetElementType()); - mlas_fp8_mode b_zp_type{}; - ORT_RETURN_IF_ERROR(GetFp8Type(b_zp_elem_type, b_zp_type)); - ORT_RETURN_IF(b_type != b_zp_type, - "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); - } } else { b_type = fp8_type_; } + if (b_zero_point != nullptr) { + const auto b_zp_elem_type = + static_cast(b_zero_point->GetElementType()); + mlas_fp8_mode b_zp_type{}; + ORT_RETURN_IF_ERROR(GetFp8Type(b_zp_elem_type, b_zp_type)); + ORT_RETURN_IF(b_type != b_zp_type, + "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); + } + if (y_size == 0) { return Status::OK(); } diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 8601b53cc15a1..86800e4d1ec09 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -1008,6 +1008,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( if (block_size_k <= 0 || block_size_n <= 0) { fail_type_inference("block_size_k and block_size_n must be greater than zero."); } + if (hasInputShape(ctx, 1)) { + auto& b_shape = getInputShape(ctx, 1); + if (b_shape.dim_size() != 2) { + fail_type_inference("B must be 2D."); + } + } if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1); } @@ 
-1026,9 +1032,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( if (hasInputShape(ctx, 1) && hasInputShape(ctx, 2)) { auto& b_shape = getInputShape(ctx, 1); auto& b_scale_shape = getInputShape(ctx, 2); - if (b_shape.dim_size() != 2) { - fail_type_inference("B must be 2D."); - } if (b_scale_shape.dim_size() != 2) { fail_type_inference("B scale must be 2D."); } @@ -1048,9 +1051,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( if (hasInputShape(ctx, 1) && hasInputShape(ctx, 3)) { auto& b_shape = getInputShape(ctx, 1); auto& b_zp_shape = getInputShape(ctx, 3); - if (b_shape.dim_size() != 2) { - fail_type_inference("B must be 2D."); - } if (b_zp_shape.dim_size() != 2) { fail_type_inference("B zero point must be 2D."); } diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index d47773ecac86a..866b4430a9ce7 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -758,8 +758,10 @@ struct MLAS_FP8_GEMM_DATA_PARAMS { size_t ldb = 0; void* C = nullptr; size_t ldc = 0; - const float* ScaleA = nullptr; // Tile scales for A: [BlocksM, BlocksK]. - const float* ScaleB = nullptr; // Tile scales for B: [BlocksK, BlocksN]. + // Block-wise scales for A, indexed as block_m * ScaleAStrideM + block_k * ScaleAStrideK. + const float* ScaleA = nullptr; + // Block-wise scales for B, indexed as block_k * ScaleBStrideK + block_n * ScaleBStrideN. + const float* ScaleB = nullptr; const float* ScaleY = nullptr; // Scalar scale for Y. mlas_fp8_mode Fp8Type = static_cast(0); size_t BlockSizeM = 128; // Block size along M for A quantization. 
diff --git a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc index 6026ea13c239d..763660460d855 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc @@ -458,6 +458,27 @@ TEST(DynamicQuantMatMulFp8, RejectsRuntimeFp8BZeroPointTypeMismatch) { "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); } +TEST(DynamicQuantMatMulFp8, RejectsConstantFp8BZeroPointTypeMismatch) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + std::vector b_zp{Float8E5M2(0.0f)}; + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddInput("B_zero_point", {N / 128, K / 128}, b_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); +} + TEST(DynamicQuantMatMulFp8, RejectsNonZeroBZeroPoint) { constexpr int64_t M = 128; constexpr int64_t N = 128;