diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 7ae18db235ccb..7edfdbf504732 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -24,6 +24,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sgemm.cpp ${MLAS_SRC_DIR}/halfgemm.cpp ${MLAS_SRC_DIR}/qgemm.cpp + ${MLAS_SRC_DIR}/qgemm_fp8.cpp ${MLAS_SRC_DIR}/qdwconv.cpp ${MLAS_SRC_DIR}/convolve.cpp ${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_greater_than_1.cpp diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 3608f1246450f..454051fe896fa 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -26,6 +26,7 @@ Do not modify directly.* * com.microsoft.DequantizeBFP * com.microsoft.DequantizeLinear * com.microsoft.DequantizeWithOrder + * com.microsoft.DynamicQuantMatMulFp8 * com.microsoft.DynamicQuantizeLSTM * com.microsoft.DynamicQuantizeMatMul * com.microsoft.DynamicTimeWarping @@ -1493,6 +1494,67 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.DynamicQuantMatMulFp8** + + Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from float16/bfloat16/float) and dynamic runtime quantization of activations to fp8 using internally computed block-wise scales. All zero-point inputs, when provided, must encode 0.0. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
block_size_k : int
+
Block size along K for A and B block-wise scales.
+
block_size_m : int
+
Block size along M for A block-wise scales. Must be 1.
+
block_size_n : int
+
Block size along N for B block-wise scales.
+
fp8_type : int
+
FP8 TensorProto data type used when non-FP8 constant B is dynamically quantized during prepack. Defaults to FLOAT8E4M3FN.
+
+ +#### Inputs (2 - 6) + +
+
A : TA
+
Input tensor A.
+
B : TB
+
Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only supported when B is a constant initializer that can be quantized during prepack.
+
B_scale (optional) : TS
+
Scale of FP8 input 'B'. Must be a block-wise tensor with shape (N / block_size_n, K / block_size_k). Required when B is already FP8. Ignored for non-FP8 constant B, where scales are computed during prepack.
+
B_zero_point (optional) : TZ
+
Zero point tensor for input 'B'. Must have the same shape as B_scale and all values must encode 0.0.
+
Y_scale (optional) : TS
+
Scale of output 'Y'. Must be a scalar when provided.
+
Y_zero_point (optional) : TZ
+
Zero point tensor for output 'Y'. Must be a scalar encoding 0.0 when provided.
+
+ +#### Outputs + +
+
Y : TY
+
Output tensor of shape (..., M, N).
+
+ +#### Type Constraints + +
+
TA : tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain input A type to float16, bfloat16, or float.
+
TB : tensor(float16), tensor(bfloat16), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
+
Constrain input B type to fp8, or to float16, bfloat16, or float for constant initializers.
+
TZ : tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
+
Constrain zero point types to fp8. Only zero-valued zero points are supported.
+
TS : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain scale types to float, float16, or bfloat16.
+
TY : tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain output type to float16, bfloat16, or float.
+
+ + ### **com.microsoft.DynamicQuantizeLSTM** #### Version @@ -6690,5 +6752,3 @@ No versioning maintained for experimental ops.
T : tensor(float)
Constrain input and output types to float32 tensors.
- - diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d0d8e750285d4..969083aba3aa0 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -577,6 +577,7 @@ The **OpSet Version** column uses the following notation: |CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)| |DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**T** = tensor(float)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float)| +|DynamicQuantMatMulFp8|*in* A:**TA**
*in* B:**TB**
*in* B_scale:**TS**
*in* B_zero_point:**TZ**
*in* Y_scale:**TS**
*in* Y_zero_point:**TZ**
*out* Y:**TY**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**TS** = tensor(bfloat16), tensor(float), tensor(float16)
**TY** = tensor(bfloat16), tensor(float), tensor(float16)
**TZ** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)| |DynamicQuantizeLSTM|*in* X:**T**
*in* W:**T2**
*in* R:**T2**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* W_scale:**T**
*in* W_zero_point:**T2**
*in* R_scale:**T**
*in* R_zero_point:**T2**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(float)
**T1** = tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)| diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 0749457f5a182..b581064b180a2 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -116,6 +116,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, NhwcMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, NhwcMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicQuantMatMulFp8); +#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE); @@ -284,6 +287,9 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, +#if !defined(DISABLE_FLOAT8_TYPES) + BuildKernelCreateInfo, +#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc new file mode 100644 index 0000000000000..350de4ac9cfa3 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc @@ -0,0 +1,802 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "dynamic_quant_matmul_fp8.h" + +#include "core/common/common.h" +#include "core/common/fp8_common.h" +#include "core/framework/op_kernel.h" +#include "core/graph/onnx_protobuf.h" +#include "core/mlas/inc/mlas.h" +#include "core/common/float16.h" +#include "core/common/float8.h" +#include "core/common/safeint.h" +#include "core/platform/threadpool.h" +#include "core/providers/cpu/math/matmul_helper.h" + +#include +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +#if !defined(DISABLE_FLOAT8_TYPES) + +namespace { + +constexpr int64_t kDefaultBlockSize = 128; +constexpr int64_t kDefaultBlockSizeM = 1; +constexpr int64_t kPackedBMetadataVersion = 1; +constexpr size_t kPackedBMetadataElementCount = 7; +constexpr size_t kPackedBMetadataSize = kPackedBMetadataElementCount * sizeof(int64_t); + +enum PackedBMetadataIndex : size_t { + kPackedBMetadataVersionIndex = 0, + kPackedBMetadataRowsIndex, + kPackedBMetadataColsIndex, + kPackedBMetadataSizeIndex, + kPackedBMetadataScaleCountIndex, + kPackedBMetadataFp8ModeIndex, + kPackedBMetadataHasFp8ModeIndex, +}; + +bool IsFp8DataType(ONNX_NAMESPACE::TensorProto_DataType elem_type) { + return elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || + elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ || + elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2 || + elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ; +} + +bool IsValidFp8Mode(int64_t mode) { + return mode >= static_cast(MLAS_FP8_MODE_E4M3_INF) && + mode < static_cast(MLAS_FP8_MODE_END); +} + +Status RestorePackedBMetadata(const void* metadata_buffer, + size_t metadata_size, + size_t quantized_b_buffer_size, + size_t b_scale_buffer_size, + TensorShape& b_shape, + size_t& quantized_b_size, + size_t& b_scale_count, + mlas_fp8_mode& b_type, + bool& has_b_type) { + ORT_RETURN_IF(metadata_buffer == nullptr, + "DynamicQuantMatMulFp8 requires shared prepacked B metadata."); + ORT_RETURN_IF(metadata_size != kPackedBMetadataSize, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an unexpected size."); + + const auto* metadata = static_cast(metadata_buffer); + ORT_RETURN_IF(metadata[kPackedBMetadataVersionIndex] != kPackedBMetadataVersion, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an unsupported version."); + ORT_RETURN_IF(metadata[kPackedBMetadataRowsIndex] <= 0 || metadata[kPackedBMetadataColsIndex] <= 0, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B shape."); + ORT_RETURN_IF(metadata[kPackedBMetadataSizeIndex] <= 0, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B buffer size."); + ORT_RETURN_IF(metadata[kPackedBMetadataScaleCountIndex] <= 0, + "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid B scale count."); + ORT_RETURN_IF(metadata[kPackedBMetadataHasFp8ModeIndex] != 1 || + !IsValidFp8Mode(metadata[kPackedBMetadataFp8ModeIndex]), + "DynamicQuantMatMulFp8 shared prepacked B metadata has an invalid FP8 type."); + + const size_t rows = static_cast(metadata[kPackedBMetadataRowsIndex]); + const size_t cols = static_cast(metadata[kPackedBMetadataColsIndex]); + const size_t expected_quantized_b_size = SafeMul(rows, cols); + const size_t restored_quantized_b_size = static_cast(metadata[kPackedBMetadataSizeIndex]); + ORT_RETURN_IF(restored_quantized_b_size != expected_quantized_b_size || + restored_quantized_b_size != quantized_b_buffer_size, + "DynamicQuantMatMulFp8 shared prepacked B metadata does not match the B buffer size."); + const size_t restored_b_scale_count = static_cast(metadata[kPackedBMetadataScaleCountIndex]); + ORT_RETURN_IF(restored_b_scale_count > std::numeric_limits::max() / sizeof(float) || + restored_b_scale_count * sizeof(float) != b_scale_buffer_size, + "DynamicQuantMatMulFp8 shared prepacked B metadata does not match the B scale buffer size."); + + b_shape = TensorShape({metadata[kPackedBMetadataRowsIndex], metadata[kPackedBMetadataColsIndex]}); + quantized_b_size = restored_quantized_b_size; + b_scale_count = restored_b_scale_count; + b_type = static_cast(metadata[kPackedBMetadataFp8ModeIndex]); + has_b_type = true; + return Status::OK(); +} + +// Reject invalid scales before quantization divides by them or MLAS dequantizes with them. +Status ValidatePositiveFiniteScales(const float* scales, size_t count, const char* scale_name) { + for (size_t i = 0; i < count; ++i) { + ORT_RETURN_IF(!std::isfinite(scales[i]) || scales[i] <= 0.0f, + "DynamicQuantMatMulFp8 requires ", scale_name, " values to be finite and positive."); + } + return Status::OK(); +} + +Status GetFp8MaxAbs(mlas_fp8_mode mode, float& max_abs) { + switch (mode) { + case MLAS_FP8_MODE_E4M3_INF: + case MLAS_FP8_MODE_E4M3_SAT: + max_abs = 448.0f; + return Status::OK(); + case MLAS_FP8_MODE_E5M2_INF: + case MLAS_FP8_MODE_E5M2_SAT: + max_abs = 57344.0f; + return Status::OK(); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported fp8 mode for DynamicQuantMatMulFp8."); + } +} + +Status ValidateZeroPointValuesAreZero(const Tensor& zero_point, size_t expected_count, + const char* zero_point_name) { + const size_t actual_count = static_cast(zero_point.Shape().Size()); + ORT_RETURN_IF(actual_count != expected_count, + "DynamicQuantMatMulFp8 requires ", zero_point_name, " to have the expected number of elements."); + + const auto reject_non_zero = [zero_point_name](float value) { + ORT_RETURN_IF(value != 0.0f, + "DynamicQuantMatMulFp8 supports symmetric quantization only; ", + zero_point_name, " values must be zero."); + return Status::OK(); + }; + + if (zero_point.IsDataType()) { + const auto* zp = static_cast(zero_point.DataRaw()); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], MLAS_FP8_MODE_E4M3_INF))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = static_cast(zero_point.DataRaw()); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], MLAS_FP8_MODE_E4M3_SAT))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = static_cast(zero_point.DataRaw()); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], MLAS_FP8_MODE_E5M2_INF))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = static_cast(zero_point.DataRaw()); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(Fp8ByteToFloat(zp[i], MLAS_FP8_MODE_E5M2_SAT))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = zero_point.Data(); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(zp[i])); + } + } else if (zero_point.IsDataType()) { + const auto* zp = zero_point.Data(); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(static_cast(zp[i]))); + } + } else if (zero_point.IsDataType()) { + const auto* zp = zero_point.Data(); + for (size_t i = 0; i < actual_count; ++i) { + ORT_RETURN_IF_ERROR(reject_non_zero(static_cast(zp[i]))); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported zero point type for DynamicQuantMatMulFp8."); + } + return Status::OK(); +} + +template +void QuantizeBlockwiseFp8ABlockDynamic(const SrcT* src, + size_t K, + size_t block_size_k, + size_t blocks_k, + size_t row, + size_t block_k, + float fp8_max_abs, + mlas_fp8_mode mode, + uint8_t* dst, + float* scales) { + const size_t k_begin = block_k * block_size_k; + const size_t k_end = std::min(K, k_begin + block_size_k); + const size_t idx = row * blocks_k + block_k; + + // Match the KleidiAI LHS packer: one dynamic quantization scale per A row and K block. + float max_abs = 0.0f; + const size_t row_offset = row * K; + for (size_t k = k_begin; k < k_end; ++k) { + max_abs = std::max(max_abs, std::fabs(static_cast(src[row_offset + k]))); + } + + const float scale = max_abs == 0.0f ? 1.0f : max_abs / fp8_max_abs; + scales[idx] = scale; + + for (size_t k = k_begin; k < k_end; ++k) { + const float value = static_cast(src[row_offset + k]); + const float quantized = value / scale; + dst[row_offset + k] = FloatToFp8Byte(quantized, mode); + } +} + +template +void QuantizeBlockwiseFp8WithScales(const SrcT* src, + size_t K, + size_t N, + size_t block_size_k, + size_t block_size_n, + const float* scales, + uint8_t* dst) { + // Block sizes come from op attributes; scale shapes only provide the number of blocks. + const size_t blocks_k = K / block_size_k; + for (size_t k = 0; k < K; ++k) { + const size_t block_k = k / block_size_k; + const size_t row_offset = k * N; + for (size_t n = 0; n < N; ++n) { + const size_t block_n = n / block_size_n; + const size_t scale_idx = block_n * blocks_k + block_k; + const float scale = scales[scale_idx]; + const float value = static_cast(src[row_offset + n]); + const float quantized = value / scale; + const Fp8T fp8_value(quantized, true); + dst[row_offset + n] = fp8_value.val; + } + } +} + +template +void ComputeBlockwiseScalesFromInput(const SrcT* src, + size_t K, + size_t N, + size_t block_size_k, + size_t block_size_n, + float fp8_max_abs, + float* scales) { + // Reference-style dynamic quantization: derive one positive scale from each source block. + const size_t blocks_k = K / block_size_k; + const size_t blocks_n = N / block_size_n; + for (size_t block_k = 0; block_k < blocks_k; ++block_k) { + const size_t k_begin = block_k * block_size_k; + const size_t k_end = k_begin + block_size_k; + for (size_t block_n = 0; block_n < blocks_n; ++block_n) { + const size_t n_begin = block_n * block_size_n; + const size_t n_end = n_begin + block_size_n; + float max_abs = 0.0f; + for (size_t k = k_begin; k < k_end; ++k) { + const size_t row_offset = k * N; + for (size_t n = n_begin; n < n_end; ++n) { + max_abs = std::max(max_abs, std::fabs(static_cast(src[row_offset + n]))); + } + } + // Match KleidiAI RHS scale layout: one scale per N block and K block. + scales[block_n * blocks_k + block_k] = max_abs == 0.0f ? 1.0f : max_abs / fp8_max_abs; + } + } +} + +template +Status QuantizeToFp8ByModeWithScales(mlas_fp8_mode fp8_mode, + const SrcT* src, + size_t K, + size_t N, + size_t block_size_k, + size_t block_size_n, + const float* scales, + uint8_t* dst) { + // Dispatch quantization using the requested FP8 mode and runtime block sizes. + switch (fp8_mode) { + case MLAS_FP8_MODE_E4M3_INF: + QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst); + return Status::OK(); + case MLAS_FP8_MODE_E4M3_SAT: + QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst); + return Status::OK(); + case MLAS_FP8_MODE_E5M2_INF: + QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst); + return Status::OK(); + case MLAS_FP8_MODE_E5M2_SAT: + QuantizeBlockwiseFp8WithScales(src, K, N, block_size_k, block_size_n, scales, dst); + return Status::OK(); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported fp8 mode for DynamicQuantMatMulFp8."); + } +} + +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + DynamicQuantMatMulFp8, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("TA", std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("TB", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("TZ", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("TS", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("TY", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + DynamicQuantMatMulFp8); + +DynamicQuantMatMulFp8::DynamicQuantMatMulFp8(const OpKernelInfo& info) : OpKernel(info) { + const int64_t block_size_m = info.GetAttrOrDefault("block_size_m", kDefaultBlockSizeM); + const int64_t block_size_k = info.GetAttrOrDefault("block_size_k", kDefaultBlockSize); + const int64_t block_size_n = info.GetAttrOrDefault("block_size_n", kDefaultBlockSize); + const int64_t fp8_type = + info.GetAttrOrDefault("fp8_type", ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN); + ORT_ENFORCE(block_size_m == 1, + "DynamicQuantMatMulFp8 requires block_size_m to be 1 to match the row-wise A scale layout."); + ORT_ENFORCE(block_size_k > 0, + "DynamicQuantMatMulFp8 requires block_size_k to be greater than zero."); + ORT_ENFORCE(block_size_n > 0, + "DynamicQuantMatMulFp8 requires block_size_n to be greater than zero."); + block_size_k_ = static_cast(block_size_k); + block_size_n_ = static_cast(block_size_n); + ORT_THROW_IF_ERROR(GetFp8Type(static_cast(fp8_type), fp8_type_)); +} + +Status DynamicQuantMatMulFp8::GetFp8Type(const Tensor& tensor, mlas_fp8_mode& out_type) { + return GetFp8Type(static_cast(tensor.GetElementType()), out_type); +} + +Status DynamicQuantMatMulFp8::GetFp8Type(ONNX_NAMESPACE::TensorProto_DataType elem_type, + mlas_fp8_mode& out_type) { + switch (elem_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + out_type = MLAS_FP8_MODE_E4M3_INF; + return Status::OK(); + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ: + out_type = MLAS_FP8_MODE_E4M3_SAT; + return Status::OK(); + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + out_type = MLAS_FP8_MODE_E5M2_INF; + return Status::OK(); + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: + out_type = MLAS_FP8_MODE_E5M2_SAT; + return Status::OK(); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported fp8 type for DynamicQuantMatMulFp8."); + } +} + +Status DynamicQuantMatMulFp8::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + if (input_idx != GetBIdx()) { + return Status::OK(); + } + + b_shape_ = tensor.Shape(); + if (b_shape_.NumDimensions() != 2) { + return Status::OK(); + } + + const size_t K = static_cast(b_shape_[0]); + const size_t N = static_cast(b_shape_[1]); + const auto b_elem_type = static_cast(tensor.GetElementType()); + const bool b_is_fp8 = IsFp8DataType(b_elem_type); + if (b_is_fp8) { + ORT_RETURN_IF_ERROR(GetFp8Type(tensor, b_type_)); + has_b_type_ = true; + return Status::OK(); + } + + b_type_ = fp8_type_; + has_b_type_ = true; + if (K == 0) { + return Status::OK(); + } + if (N == 0) { + return Status::OK(); + } + + ORT_RETURN_IF_NOT(K % block_size_k_ == 0, + "DynamicQuantMatMulFp8 requires K to be divisible by block_size_k."); + ORT_RETURN_IF_NOT(N % block_size_n_ == 0, + "DynamicQuantMatMulFp8 requires N to be divisible by block_size_n."); + const size_t blocks_k = K / block_size_k_; + const size_t blocks_n = N / block_size_n_; + b_scale_count_ = SafeMul(blocks_k, blocks_n); + b_scales_ = IAllocator::MakeUniquePtr(alloc, b_scale_count_, true); + float fp8_max_abs = 0.0f; + ORT_RETURN_IF_ERROR(GetFp8MaxAbs(b_type_, fp8_max_abs)); + + const size_t quantized_b_size = SafeMul(K, N); + quantized_b_ = IAllocator::MakeUniquePtr(alloc, quantized_b_size, true); + quantized_b_size_ = quantized_b_size; + auto* quantized_b_bytes = static_cast(quantized_b_.get()); + if (tensor.IsDataType()) { + ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_, + fp8_max_abs, b_scales_.get()); + ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_, + b_scales_.get(), quantized_b_bytes)); + } else if (tensor.IsDataType()) { + ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_, + fp8_max_abs, b_scales_.get()); + ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_, + b_scales_.get(), quantized_b_bytes)); + } else if (tensor.IsDataType()) { + ComputeBlockwiseScalesFromInput(tensor.Data(), K, N, block_size_k_, block_size_n_, + fp8_max_abs, b_scales_.get()); + ORT_RETURN_IF_ERROR(QuantizeToFp8ByModeWithScales(b_type_, tensor.Data(), K, N, block_size_k_, block_size_n_, + b_scales_.get(), quantized_b_bytes)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported B type for DynamicQuantMatMulFp8 prepack."); + } + + if (prepacked_weights != nullptr) { + const std::array metadata_values = { + kPackedBMetadataVersion, + b_shape_[0], + b_shape_[1], + static_cast(quantized_b_size_), + static_cast(b_scale_count_), + static_cast(b_type_), + 1, + }; + auto metadata = IAllocator::MakeUniquePtr(alloc, kPackedBMetadataSize, true); + std::memcpy(metadata.get(), metadata_values.data(), kPackedBMetadataSize); + auto b_scales_deleter = std::move(b_scales_.get_deleter()); + IAllocatorUniquePtr b_scales_buffer( + b_scales_.release(), + [b_scales_deleter = std::move(b_scales_deleter)](void* p) mutable { + b_scales_deleter(static_cast(p)); + }); + prepacked_weights->buffers_.push_back(std::move(quantized_b_)); + prepacked_weights->buffer_sizes_.push_back(quantized_b_size_); + prepacked_weights->buffers_.push_back(std::move(b_scales_buffer)); + prepacked_weights->buffer_sizes_.push_back(b_scale_count_ * sizeof(float)); + prepacked_weights->buffers_.push_back(std::move(metadata)); + prepacked_weights->buffer_sizes_.push_back(kPackedBMetadataSize); + } + is_packed = true; + return Status::OK(); +} + +Status DynamicQuantMatMulFp8::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + if (input_idx != GetBIdx()) { + return Status::OK(); + } + + ORT_RETURN_IF(prepacked_buffers.size() != 3 || prepacked_buffer_sizes.size() != 3, + "DynamicQuantMatMulFp8 requires shared prepacked B data, scale, and metadata buffers."); + ORT_RETURN_IF(prepacked_buffers[0].get() == nullptr, + "DynamicQuantMatMulFp8 requires shared prepacked B data."); + ORT_RETURN_IF(prepacked_buffers[1].get() == nullptr, + "DynamicQuantMatMulFp8 requires shared prepacked B scales."); + + // Buffer 0 owns quantized B bytes; buffer 1 owns computed B scales; buffer 2 restores kernel state. + ORT_RETURN_IF_ERROR(RestorePackedBMetadata(prepacked_buffers[2].get(), + prepacked_buffer_sizes[2], + prepacked_buffer_sizes[0], + prepacked_buffer_sizes[1], + b_shape_, + quantized_b_size_, + b_scale_count_, + b_type_, + has_b_type_)); + quantized_b_ = std::move(prepacked_buffers[0]); + auto b_scales_deleter = prepacked_buffers[1].get_deleter(); + b_scales_ = IAllocatorUniquePtr( + static_cast(prepacked_buffers[1].release()), + [b_scales_deleter = std::move(b_scales_deleter)](float* p) mutable { + b_scales_deleter(static_cast(p)); + }); + used_shared_buffers = true; + return Status::OK(); +} + +Status DynamicQuantMatMulFp8::Compute(OpKernelContext* context) const { + const Tensor* a = context->Input(IN_A); + const Tensor* b = quantized_b_ ? nullptr : context->Input(IN_B); + const Tensor* b_scale = context->Input(IN_B_SCALE); + const Tensor* b_zero_point = context->Input(IN_B_ZERO_POINT); + const Tensor* y_scale = context->Input(IN_Y_SCALE); + const Tensor* y_zero_point = context->Input(IN_Y_ZERO_POINT); + + // Runtime B uses one 2D B scale/zero-point layout, so reject batched B before MatMul broadcasts it. + ORT_RETURN_IF(!quantized_b_ && b->Shape().NumDimensions() != 2, + "DynamicQuantMatMulFp8 requires runtime B to be a 2D tensor."); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), + quantized_b_ ? b_shape_ : b->Shape(), + nullptr, + nullptr)); + + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + Tensor* y = context->Output(OUT_Y, helper.OutputShape()); + const size_t y_size = static_cast(y->Shape().Size()); + ORT_RETURN_IF(!y->IsDataType() && !y->IsDataType() && !y->IsDataType(), + "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16."); + + if (y_zero_point != nullptr) { + // Runtime tensors must match the schema scalar contract before reading element 0. + ORT_RETURN_IF(y_zero_point->Shape().NumDimensions() != 0 || y_zero_point->Shape().Size() != 1, + "DynamicQuantMatMulFp8 requires Y zero point input to be a scalar."); + ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*y_zero_point, 1, "Y zero point")); + } + + float y_scale_storage = 0.0f; + const float* y_scale_data = nullptr; + if (y_scale != nullptr) { + // Runtime tensors must match the schema scalar contract before reading element 0. + ORT_RETURN_IF(y_scale->Shape().NumDimensions() != 0 || y_scale->Shape().Size() != 1, + "DynamicQuantMatMulFp8 requires Y scale input to be a scalar."); + if (y_scale->IsDataType()) { + y_scale_data = y_scale->Data(); + } else if (y_scale->IsDataType()) { + y_scale_storage = static_cast(y_scale->Data()[0]); + y_scale_data = &y_scale_storage; + } else if (y_scale->IsDataType()) { + y_scale_storage = static_cast(y_scale->Data()[0]); + y_scale_data = &y_scale_storage; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires Y scale input to be float, float16, or bfloat16."); + } + ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(y_scale_data, 1, "Y scale")); + } + + // Empty reduction does not need B data, so fill zeros before enforcing runtime FP8 B. + if (K == 0) { + if (y_size == 0) { + return Status::OK(); + } + if (y->IsDataType()) { + std::fill_n(y->MutableData(), y_size, 0.0f); + } else if (y->IsDataType()) { + std::fill_n(y->MutableData(), y_size, MLFloat16::FromBits(0)); + } else if (y->IsDataType()) { + std::fill_n(y->MutableData(), y_size, BFloat16::FromBits(0)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16."); + } + return Status::OK(); + } + + const bool a_is_supported = + a->IsDataType() || a->IsDataType() || a->IsDataType(); + ORT_RETURN_IF(!a_is_supported, "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16."); + + const auto b_elem_type = b ? static_cast(b->GetElementType()) + : static_cast(0); + const bool b_is_fp8 = IsFp8DataType(b_elem_type); + + mlas_fp8_mode b_type{}; + if (has_b_type_) { + b_type = b_type_; + } else if (b_is_fp8) { + ORT_RETURN_IF_ERROR(GetFp8Type(b_elem_type, b_type)); + } else { + b_type = fp8_type_; + } + + if (b_zero_point != nullptr) { + const auto b_zp_elem_type = + static_cast(b_zero_point->GetElementType()); + mlas_fp8_mode b_zp_type{}; + ORT_RETURN_IF_ERROR(GetFp8Type(b_zp_elem_type, b_zp_type)); + ORT_RETURN_IF(b_type != b_zp_type, + "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); + } + + if (y_size == 0) { + return Status::OK(); + } + + // Select the FP8 B buffer: prefer pre-quantized B from PrePack, otherwise accept FP8-typed B input. + const uint8_t* b_fp8 = nullptr; + if (quantized_b_) { + b_fp8 = static_cast(quantized_b_.get()); + } else if (b_is_fp8) { + b_fp8 = static_cast(b->DataRaw()); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires runtime B input to be FP8. Non-FP8 B is only supported " + "when B is a constant initializer that can be quantized during prepack."); + } + + const size_t num_gemms = helper.OutputOffsets().size(); + ORT_RETURN_IF(K % block_size_k_ != 0, + "DynamicQuantMatMulFp8 requires K to be divisible by block_size_k."); + const size_t expected_blocks_k = K / block_size_k_; + const size_t blocks_m = M; + const size_t blocks_k = expected_blocks_k; + + const size_t blocks_n = N / block_size_n_; + ORT_RETURN_IF(!b_scales_ && b_scale == nullptr, + "DynamicQuantMatMulFp8 requires B scale when B is already FP8."); + ORT_RETURN_IF(b_scale != nullptr && b_scale->Shape().NumDimensions() != 2, + "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor."); + ORT_RETURN_IF(blocks_n == 0, "DynamicQuantMatMulFp8 requires non-zero B scale N dimension."); + ORT_RETURN_IF(N % block_size_n_ != 0, + "DynamicQuantMatMulFp8 requires N to be divisible by block_size_n."); + ORT_RETURN_IF(b_scale != nullptr && static_cast(b_scale->Shape()[0]) != blocks_n, + "DynamicQuantMatMulFp8 requires B scale N dimension to be N / block_size_n."); + ORT_RETURN_IF(b_scale != nullptr && static_cast(b_scale->Shape()[1]) != blocks_k, + "DynamicQuantMatMulFp8 requires B scale K dimension to be K / block_size_k."); + + const size_t a_scale_batch_stride = SafeMul(blocks_m, blocks_k); + const size_t b_zp_count = SafeMul(blocks_k, blocks_n); + + if (b_zero_point != nullptr) { + ORT_RETURN_IF(b_zero_point->Shape().NumDimensions() != 2, + "DynamicQuantMatMulFp8 requires B zero point to be a 2D tensor."); + ORT_RETURN_IF(b_zero_point->Shape()[0] != static_cast(blocks_n) || + b_zero_point->Shape()[1] != static_cast(blocks_k), + "DynamicQuantMatMulFp8 requires B zero point to have shape [N / block_size_n, K / block_size_k]."); + ORT_RETURN_IF_ERROR(ValidateZeroPointValuesAreZero(*b_zero_point, b_zp_count, "B zero point")); + } + + AllocatorPtr temp_allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&temp_allocator)); + + const float* b_scales = nullptr; + IAllocatorUniquePtr b_scale_float; + size_t b_scale_elems = 0; + if (b_scales_) { + b_scales = b_scales_.get(); + b_scale_elems = b_scale_count_; + } else if (b_scale->IsDataType()) { + b_scales = b_scale->Data(); + b_scale_elems = static_cast(b_scale->Shape().Size()); + } else { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + b_scale_elems = static_cast(b_scale->Shape().Size()); + b_scale_float = IAllocator::MakeUniquePtr(allocator, b_scale_elems, true); + if (b_scale->IsDataType()) { + for (size_t i = 0; i < b_scale_elems; ++i) { + b_scale_float.get()[i] = static_cast(b_scale->Data()[i]); + } + } else if (b_scale->IsDataType()) { + for (size_t i = 0; i < b_scale_elems; ++i) { + b_scale_float.get()[i] = static_cast(b_scale->Data()[i]); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires B scale input to be float, float16, or bfloat16."); + } + b_scales = b_scale_float.get(); + } + + // MLAS FP8 GEMM accumulates and stores float output. Use scratch for lower-precision Y, + // then convert once after all batched GEMMs complete. + IAllocatorUniquePtr y_float_buffer; + float* y_float_data = nullptr; + if (y->IsDataType()) { + y_float_data = y->MutableData(); + } else if (y->IsDataType() || y->IsDataType()) { + y_float_buffer = IAllocator::MakeUniquePtr(temp_allocator, y_size, true); + y_float_data = y_float_buffer.get(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires Y to be float, float16, or bfloat16."); + } + + MLAS_FP8_GEMM_SHAPE_PARAMS gemm_shape; + gemm_shape.M = M; + gemm_shape.N = N; + gemm_shape.K = K; + + ORT_RETURN_IF_ERROR(ValidatePositiveFiniteScales(b_scales, b_scale_elems, "B scale")); + + const size_t a_fp8_size = SafeMul(M, K); + const size_t a_num_elements = static_cast(a->Shape().Size()); + ORT_RETURN_IF(a_num_elements % a_fp8_size != 0, + "DynamicQuantMatMulFp8 requires A to contain complete MxK matrices."); + const size_t a_batch_count = a_num_elements / a_fp8_size; + + // Quantize the physical A tensor once. Broadcasted output GEMMs then reuse the same FP8 A slice. + auto a_fp8_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_num_elements, true); + const size_t a_scale_count = SafeMul(a_batch_count, a_scale_batch_stride); + auto a_scale_buffer = IAllocator::MakeUniquePtr(temp_allocator, a_scale_count, true); + const size_t a_quant_work_items = SafeMul(a_batch_count, a_scale_batch_stride); + ORT_RETURN_IF(a_quant_work_items > static_cast(std::numeric_limits::max()), + "DynamicQuantMatMulFp8 A quantization work item count exceeds ptrdiff_t range."); + const auto a_quant_work_items_i = static_cast(a_quant_work_items); + const size_t a_quant_block_elems = block_size_k_; + const TensorOpCost a_quant_unit_cost{ + static_cast(SafeMul(a_quant_block_elems, sizeof(float))), + static_cast(SafeMul(a_quant_block_elems, sizeof(uint8_t))), + static_cast(a_quant_block_elems) * 2.0}; + float fp8_max_abs = 0.0f; + ORT_RETURN_IF_ERROR(GetFp8MaxAbs(b_type, fp8_max_abs)); + const auto quantize_a_batches = [&](const auto* a_data) { + concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), a_quant_work_items_i, + a_quant_unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t tid = begin; tid < end; ++tid) { + const size_t work_idx = static_cast(tid); + const size_t a_batch_idx = work_idx / a_scale_batch_stride; + const size_t scale_block_idx = work_idx % a_scale_batch_stride; + const size_t row = scale_block_idx / blocks_k; + const size_t block_k = scale_block_idx % blocks_k; + const size_t a_batch_offset = a_batch_idx * a_fp8_size; + const size_t a_scale_batch_offset = a_batch_idx * a_scale_batch_stride; + QuantizeBlockwiseFp8ABlockDynamic( + a_data + a_batch_offset, + K, block_size_k_, blocks_k, + row, block_k, fp8_max_abs, b_type, + a_fp8_buffer.get() + a_batch_offset, + a_scale_buffer.get() + a_scale_batch_offset); + } + }); + }; + if (a->IsDataType()) { + quantize_a_batches(a->Data()); + } else if (a->IsDataType()) { + quantize_a_batches(a->Data()); + } else if (a->IsDataType()) { + quantize_a_batches(a->Data()); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DynamicQuantMatMulFp8 requires A to be float, float16, or bfloat16."); + } + + std::vector gemm_data_vec(num_gemms); + for (size_t gemm_idx = 0; gemm_idx < num_gemms; ++gemm_idx) { + const size_t a_offset = helper.LeftOffsets()[gemm_idx]; + ORT_RETURN_IF(a_offset >= a_num_elements || (a_offset % a_fp8_size) != 0, + "DynamicQuantMatMulFp8 requires A offsets to reference complete MxK matrices."); + const size_t scale_batch_index = a_offset / a_fp8_size; + ORT_RETURN_IF(scale_batch_index >= a_batch_count, + "DynamicQuantMatMulFp8 requires A offsets to reference complete MxK matrices."); + const size_t a_scale_batch_offset = SafeMul(scale_batch_index, a_scale_batch_stride); + const float* a_scales_batch = a_scale_buffer.get() + a_scale_batch_offset; + auto& gemm_data = gemm_data_vec[gemm_idx]; + gemm_data.A = a_fp8_buffer.get() + a_offset; + gemm_data.lda = K; + gemm_data.B = b_fp8 + helper.RightOffsets()[gemm_idx]; + gemm_data.ldb = N; + gemm_data.C = y_float_data + helper.OutputOffsets()[gemm_idx]; + gemm_data.ldc = N; + gemm_data.ScaleA = a_scales_batch; + gemm_data.ScaleB = b_scales; + gemm_data.ScaleY = y_scale_data; + gemm_data.Fp8Type = b_type; + gemm_data.BlockSizeM = 1; + gemm_data.BlockSizeK = block_size_k_; + gemm_data.BlockSizeN = block_size_n_; + gemm_data.BlocksM = blocks_m; + gemm_data.BlocksK = blocks_k; + gemm_data.BlocksN = blocks_n; + gemm_data.ScaleAStrideK = 1; + gemm_data.ScaleAStrideM = blocks_k; + gemm_data.ScaleBStrideN = blocks_k; + gemm_data.ScaleBStrideK = 1; + } + + MlasFp8GemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, context->GetOperatorThreadPool()); + + if (y_float_buffer != nullptr) { + if (y->IsDataType()) { + auto* y_data = y->MutableData(); + for (size_t i = 0; i < y_size; ++i) { + y_data[i] = static_cast(y_float_data[i]); + } + } else { + auto* y_data = y->MutableData(); + for (size_t i = 0; i < y_size; ++i) { + y_data[i] = BFloat16(y_float_data[i]); + } + } + } + + return Status::OK(); +} + +#endif // !defined(DISABLE_FLOAT8_TYPES) + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h new file mode 100644 index 0000000000000..cef88def0ec77 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/prepacked_weights.h" +#include "core/graph/onnx_protobuf.h" +#include "core/mlas/inc/mlas.h" + +namespace onnxruntime { +namespace contrib { + +class DynamicQuantMatMulFp8 final : public OpKernel { + public: + DynamicQuantMatMulFp8(const OpKernelInfo& info); + + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status Compute(OpKernelContext* context) const override; + + enum InputTensors : int { + IN_A = 0, + IN_B = 1, + IN_B_SCALE = 2, + IN_B_ZERO_POINT = 3, + IN_Y_SCALE = 4, + IN_Y_ZERO_POINT = 5 + }; + + enum OutputTensors : int { OUT_Y = 0 }; + + static Status GetFp8Type(const Tensor& tensor, mlas_fp8_mode& out_type); + static Status GetFp8Type(ONNX_NAMESPACE::TensorProto_DataType elem_type, mlas_fp8_mode& out_type); + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + /*out*/ bool& used_shared_buffers) override; + + private: + static constexpr int GetBIdx() { return IN_B; } + IAllocatorUniquePtr quantized_b_; + size_t quantized_b_size_{0}; + IAllocatorUniquePtr b_scales_; + size_t b_scale_count_{0}; + TensorShape b_shape_; + mlas_fp8_mode b_type_{static_cast(0)}; + bool has_b_type_{false}; + mlas_fp8_mode fp8_type_{MLAS_FP8_MODE_E4M3_INF}; + size_t block_size_k_{128}; + size_t block_size_n_{128}; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/common/fp8_common.h b/onnxruntime/core/common/fp8_common.h new file mode 100644 index 0000000000000..e4cb73df9cd80 --- /dev/null +++ b/onnxruntime/core/common/fp8_common.h @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#pragma once + +#if !defined(DISABLE_FLOAT8_TYPES) + +#include "core/common/float8.h" +#include "core/mlas/inc/mlas.h" + +#include + +namespace onnxruntime { + +inline float Fp8ByteToFloat(uint8_t value, mlas_fp8_mode mode) { + switch (mode) { + case MLAS_FP8_MODE_E4M3_INF: + return Float8E4M3FN(value, Float8E4M3FN::FromBits()).ToFloat(); + case MLAS_FP8_MODE_E4M3_SAT: + return Float8E4M3FNUZ(value, Float8E4M3FNUZ::FromBits()).ToFloat(); + case MLAS_FP8_MODE_E5M2_INF: + return Float8E5M2(value, Float8E5M2::FromBits()).ToFloat(); + case MLAS_FP8_MODE_E5M2_SAT: + return Float8E5M2FNUZ(value, Float8E5M2FNUZ::FromBits()).ToFloat(); + default: + return 0.0f; + } +} + +inline uint8_t FloatToFp8Byte(float value, mlas_fp8_mode mode) { + switch (mode) { + case MLAS_FP8_MODE_E4M3_INF: { + const Float8E4M3FN fp8(value, true); + return fp8.val; + } + case MLAS_FP8_MODE_E4M3_SAT: { + const Float8E4M3FNUZ fp8(value, true); + return fp8.val; + } + case MLAS_FP8_MODE_E5M2_INF: { + const Float8E5M2 fp8(value, true); + return fp8.val; + } + case MLAS_FP8_MODE_E5M2_SAT: { + const Float8E5M2FNUZ fp8(value, true); + return fp8.val; + } + default: + return 0; + } +} + +inline bool IsValidFp8Mode(mlas_fp8_mode mode) { + switch (mode) { + case MLAS_FP8_MODE_E4M3_INF: + case MLAS_FP8_MODE_E4M3_SAT: + case MLAS_FP8_MODE_E5M2_INF: + case MLAS_FP8_MODE_E5M2_SAT: + return true; + default: + return false; + } +} + +} // namespace onnxruntime + +#endif // !defined(DISABLE_FLOAT8_TYPES) diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 59f97c222ceb2..6894b4e7eed40 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -25,6 +25,9 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MulInteger); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QGemm); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicQuantMatMulFp8); +#endif class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAdd); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConcat); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearWhere); @@ -139,6 +142,9 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); +#if !defined(DISABLE_FLOAT8_TYPES) + fn(GetOpSchema()); +#endif fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 5a3cd86b04492..86800e4d1ec09 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -946,6 +946,130 @@ ONNX_MS_OPERATOR_SET_SCHEMA( updateOutputShape(ctx, 0, {first_input_shape.dim(transA ? 1 : 0), second_input_shape.dim(transB ? 0 : 1)}); } })); + +#if !defined(DISABLE_FLOAT8_TYPES) +ONNX_MS_OPERATOR_SET_SCHEMA( + DynamicQuantMatMulFp8, 1, + OpSchema() + .SetDoc("Symmetric quantized MatMul for fp8 weights (with optional prepack conversion from " + "float16/bfloat16/float) and dynamic runtime quantization of activations to fp8 using " + "internally computed block-wise scales. All zero-point inputs, when provided, must encode 0.0.") + .Input(0, "A", "Input tensor A.", "TA") + .Input(1, "B", + "Input tensor B. FP8 B may be provided at runtime. Float, float16, and bfloat16 B are only " + "supported when B is a constant initializer that can be quantized during prepack.", + "TB") + .Input(2, "B_scale", + "Scale of FP8 input 'B'. Must be a block-wise tensor with shape " + "(N / block_size_n, K / block_size_k). Required when B is already FP8. Ignored for non-FP8 " + "constant B, where scales are computed during prepack.", + "TS", OpSchema::Optional) + .Input(3, "B_zero_point", + "Zero point tensor for input 'B'. Must have the same shape as B_scale and all values must encode 0.0.", + "TZ", OpSchema::Optional) + .Input(4, "Y_scale", "Scale of output 'Y'. Must be a scalar when provided.", "TS", + OpSchema::Optional) + .Input(5, "Y_zero_point", + "Zero point tensor for output 'Y'. Must be a scalar encoding 0.0 when provided.", "TZ", + OpSchema::Optional) + .Output(0, "Y", "Output tensor of shape (..., M, N).", "TY") + .Attr("block_size_m", "Block size along M for A block-wise scales. Must be 1.", + AttributeProto::INT, static_cast(1)) + .Attr("block_size_k", "Block size along K for A and B block-wise scales.", AttributeProto::INT, + static_cast(128)) + .Attr("block_size_n", "Block size along N for B block-wise scales.", AttributeProto::INT, + static_cast(128)) + .Attr("fp8_type", + "FP8 TensorProto data type used when non-FP8 constant B is dynamically quantized during prepack. " + "Defaults to FLOAT8E4M3FN.", + AttributeProto::INT, + static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN)) + .TypeConstraint("TA", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, + "Constrain input A type to float16, bfloat16, or float.") + .TypeConstraint("TB", + {"tensor(float16)", "tensor(bfloat16)", "tensor(float)", + "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", + "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"}, + "Constrain input B type to fp8, or to float16, bfloat16, or float for constant initializers.") + .TypeConstraint("TZ", {"tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"}, + "Constrain zero point types to fp8. Only zero-valued zero points are supported.") + // Scale tensors are upcast to float by the CPU kernel before calling MLAS. + .TypeConstraint("TS", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain scale types to float, float16, or bfloat16.") + .TypeConstraint("TY", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, + "Constrain output type to float16, bfloat16, or float.") + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + const int64_t block_size_m = getAttribute(ctx, "block_size_m", static_cast(1)); + const int64_t block_size_k = getAttribute(ctx, "block_size_k", static_cast(128)); + const int64_t block_size_n = getAttribute(ctx, "block_size_n", static_cast(128)); + if (block_size_m != 1) { + fail_type_inference("block_size_m must be 1."); + } + if (block_size_k <= 0 || block_size_n <= 0) { + fail_type_inference("block_size_k and block_size_n must be greater than zero."); + } + if (hasInputShape(ctx, 1)) { + auto& b_shape = getInputShape(ctx, 1); + if (b_shape.dim_size() != 2) { + fail_type_inference("B must be 2D."); + } + } + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { + ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1); + } + if (hasInputShape(ctx, 4)) { + auto shape = ctx.getInputType(4)->tensor_type().shape(); + if (shape.dim_size() != 0) { + fail_type_inference("Y scale input must be a scalar."); + } + } + if (hasInputShape(ctx, 5)) { + auto shape = ctx.getInputType(5)->tensor_type().shape(); + if (shape.dim_size() != 0) { + fail_type_inference("Y zero point input must be a scalar."); + } + } + if (hasInputShape(ctx, 1) && hasInputShape(ctx, 2)) { + auto& b_shape = getInputShape(ctx, 1); + auto& b_scale_shape = getInputShape(ctx, 2); + if (b_scale_shape.dim_size() != 2) { + fail_type_inference("B scale must be 2D."); + } + if (b_shape.dim(1).has_dim_value() && b_scale_shape.dim(0).has_dim_value()) { + const auto n = b_shape.dim(1).dim_value(); + if ((n % block_size_n) != 0 || b_scale_shape.dim(0).dim_value() != (n / block_size_n)) { + fail_type_inference("B scale first dimension must be N / block_size_n."); + } + } + if (b_shape.dim(0).has_dim_value() && b_scale_shape.dim(1).has_dim_value()) { + const auto k = b_shape.dim(0).dim_value(); + if ((k % block_size_k) != 0 || b_scale_shape.dim(1).dim_value() != (k / block_size_k)) { + fail_type_inference("B scale last dimension must be K / block_size_k."); + } + } + } + if (hasInputShape(ctx, 1) && hasInputShape(ctx, 3)) { + auto& b_shape = getInputShape(ctx, 1); + auto& b_zp_shape = getInputShape(ctx, 3); + if (b_zp_shape.dim_size() != 2) { + fail_type_inference("B zero point must be 2D."); + } + if (b_shape.dim(1).has_dim_value() && b_zp_shape.dim(0).has_dim_value()) { + const auto n = b_shape.dim(1).dim_value(); + if ((n % block_size_n) != 0 || b_zp_shape.dim(0).dim_value() != (n / block_size_n)) { + fail_type_inference("B zero point first dimension must be N / block_size_n."); + } + } + if (b_shape.dim(0).has_dim_value() && b_zp_shape.dim(1).has_dim_value()) { + const auto k = b_shape.dim(0).dim_value(); + if ((k % block_size_k) != 0 || b_zp_shape.dim(1).dim_value() != (k / block_size_k)) { + fail_type_inference("B zero point last dimension must be K / block_size_k."); + } + } + } + })); + +#endif ONNX_MS_OPERATOR_SET_SCHEMA( QAttention, 1, OpSchema() diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index ddb9daa5e244b..866b4430a9ce7 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -726,6 +726,76 @@ bool MLASCALL MlasIsDynamicQGemmAvailable(const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); +/** + * @brief Parameters that define the shape of an FP8 GEMM operation. + */ +struct MLAS_FP8_GEMM_SHAPE_PARAMS { + size_t M = 0; /**< Row size of matrix A */ + size_t N = 0; /**< Column size of matrix B */ + size_t K = 0; /**< Column size of matrix A and row size of matrix B */ +}; + +// FP8 mode is aligned with Arm KleidiAI format/overflow modes. +// Defined here to keep MLAS FP8 APIs platform-agnostic. +#ifndef MLAS_FP8_MODE_DEFINED +#define MLAS_FP8_MODE_DEFINED +enum mlas_fp8_mode : uint8_t { + MLAS_FP8_MODE_E4M3_INF = 0, ///< E4M3 with NaN/Inf on overflow. + MLAS_FP8_MODE_E4M3_SAT = 1, ///< E4M3 with saturation on overflow. + MLAS_FP8_MODE_E5M2_INF = 2, ///< E5M2 with NaN/Inf on overflow. + MLAS_FP8_MODE_E5M2_SAT = 3, ///< E5M2 with saturation on overflow. + MLAS_FP8_MODE_END = 4, ///< End marker. Do not use. +}; +#endif // MLAS_FP8_MODE_DEFINED + +/** + * @brief Parameters that define the data buffers and layout for an FP8 GEMM. + */ +struct MLAS_FP8_GEMM_DATA_PARAMS { + const void* A = nullptr; + size_t lda = 0; + const void* B = nullptr; + size_t ldb = 0; + void* C = nullptr; + size_t ldc = 0; + // Block-wise scales for A, indexed as block_m * ScaleAStrideM + block_k * ScaleAStrideK. + const float* ScaleA = nullptr; + // Block-wise scales for B, indexed as block_k * ScaleBStrideK + block_n * ScaleBStrideN. + const float* ScaleB = nullptr; + const float* ScaleY = nullptr; // Scalar scale for Y. + mlas_fp8_mode Fp8Type = static_cast(0); + size_t BlockSizeM = 128; // Block size along M for A quantization. + size_t BlockSizeK = 128; // Block size along K for A/B quantization. + size_t BlockSizeN = 128; // Block size along N for B quantization. + size_t BlocksM = 0; // Number of blocks along M (ceil(M / BlockSizeM)). + size_t BlocksK = 0; // Number of blocks along K (ceil(K / BlockSizeK)). + size_t BlocksN = 0; // Number of blocks along N (ceil(N / BlockSizeN)). + size_t ScaleAStrideK = 0; // ScaleA stride between K blocks (elements). + size_t ScaleAStrideM = 0; // ScaleA stride between M blocks (elements). + size_t ScaleBStrideN = 0; // ScaleB stride between N blocks (elements). + size_t ScaleBStrideK = 0; // ScaleB stride between K blocks (elements). +}; + +#if !defined(DISABLE_FLOAT8_TYPES) +void +MLASCALL +MlasFp8GemmBatch( + const MLAS_FP8_GEMM_SHAPE_PARAMS& Shape, + const MLAS_FP8_GEMM_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool +); + +inline void +MlasFp8Gemm( + const MLAS_FP8_GEMM_SHAPE_PARAMS& Shape, + const MLAS_FP8_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool +) { + MlasFp8GemmBatch(Shape, DataParams, 1, ThreadPool); +} +#endif // !defined(DISABLE_FLOAT8_TYPES) + // // Symmetric QGEMM has limited buffer overrun. // Currently only supported in ARM64 diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h index 3c9f398ece887..79561d246bac9 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -165,7 +165,7 @@ MLASCALL MlasDynamicQGemmBatch( const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, - const size_t BatchN, + const size_t BatchSize, MLAS_THREADPOOL* ThreadPool ); @@ -225,37 +225,3 @@ MlasConvSymmetricChannelsLast2DFloatPackW( MLAS_THREADPOOL* ThreadPool ); } - -/*++ - -Routine Description: - - This routine determines if a wraparound will occur when multiplying two size_t variables - Uses __builtin_mul_overflow if available on the current system and if not falls back - to a default implementation to check this wraparound. - -Arguments: - - a - Supplies the first number to be muliplied. - - b - Supplies the second number to be muliplied. - - out - pointer to a size_t which acts as the return value in success cases. - -Return Value: - - Returns false if the operation was successful - Returns true if wraparound of size_t was detected - ---*/ -inline bool mul_overflow_size_t_builtin(size_t a, size_t b, size_t* out) { -#if defined(__has_builtin) -# if __has_builtin(__builtin_mul_overflow) - return __builtin_mul_overflow(a, b, out); -# endif -#endif - // Fallback to manual check if builtin not available - if (b != 0 && a > SIZE_MAX / b) return true; - if (out) *out = a * b; - return false; -} diff --git a/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp index f88af056aa156..99816649ac7d0 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp @@ -279,7 +279,7 @@ Return Value: LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr); size_t lhs_resize = 0; - if (mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize)) + if (MlasMultiplyOverflowsSizeT(LhsPackedStride, BatchSize, &lhs_resize)) { // size_t wraparound detected for LhsPackedStride, fallback to MLAS return false; @@ -304,7 +304,7 @@ Return Value: // Multithread pack lhs and rhs RhsPackedStride = ArmKleidiAI::MlasSBGemmPackBSize(TransA, TransB, N, K); size_t rhs_resize = 0; - if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize)) + if (MlasMultiplyOverflowsSizeT(RhsPackedStride, BatchSize, &rhs_resize)) { // size_t wraparound detected for RhsPackedStride, fallback to MLAS return false; @@ -354,7 +354,7 @@ Return Value: // Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop. // Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively. size_t max_tile_elems = 0; - if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) { + if (MlasMultiplyOverflowsSizeT(m_step, n_step, &max_tile_elems)) { // size_t wraparound detected for tile size, fallback to MLAS return false; } diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp index 69b83e1e5b49c..ef599cf1346a8 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp @@ -542,7 +542,7 @@ Return Value: LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr); size_t lhs_resize = 0; - if(mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize)) + if(MlasMultiplyOverflowsSizeT(LhsPackedStride, BatchSize, &lhs_resize)) { // size_t wraparound detected for LhsPackedStride, fallback to MLAS return false; @@ -568,7 +568,7 @@ Return Value: // Multithread pack lhs and rhs RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K); size_t rhs_resize = 0; - if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize)) + if (MlasMultiplyOverflowsSizeT(RhsPackedStride, BatchSize, &rhs_resize)) { // size_t wraparound detected for RhsPackedStride, fallback to MLAS return false; @@ -616,7 +616,7 @@ Return Value: // Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop. // Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively. size_t max_tile_elems = 0; - if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) { + if (MlasMultiplyOverflowsSizeT(m_step, n_step, &max_tile_elems)) { // size_t wraparound detected for tile size, fallback to MLAS return false; } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index dbb414505ff38..b4ef99b1f17b2 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -119,6 +119,33 @@ Module Name: #define MLAS_UNREFERENCED_PARAMETER(parameter) ((void)(parameter)) +// +// Reports whether multiplying two size_t values would overflow. +// + +MLAS_FORCEINLINE +bool +MlasMultiplyOverflowsSizeT( + size_t a, + size_t b, + size_t* out + ) +{ +#if defined(__has_builtin) +#if __has_builtin(__builtin_mul_overflow) + size_t result; + return __builtin_mul_overflow(a, b, out != nullptr ? out : &result); +#endif +#endif + if (b != 0 && a > std::numeric_limits::max() / b) { + return true; + } + if (out != nullptr) { + *out = a * b; + } + return false; +} + #ifdef MLAS_NO_EXCEPTION MLAS_FORCEINLINE void diff --git a/onnxruntime/core/mlas/lib/qgemm_fp8.cpp b/onnxruntime/core/mlas/lib/qgemm_fp8.cpp new file mode 100644 index 0000000000000..fa0115db89d5b --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_fp8.cpp @@ -0,0 +1,164 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "mlasi.h" + +#if !defined(DISABLE_FLOAT8_TYPES) + +#include "core/common/common.h" +#include "core/common/fp8_common.h" + +namespace { + +// Computes the parallel work item count without wrapping before ptrdiff_t narrowing. +inline ptrdiff_t CheckedWorkItems(size_t batch_count, size_t m) { + size_t work_items = 0; + ORT_ENFORCE(!MlasMultiplyOverflowsSizeT(batch_count, m, &work_items), "FP8 GEMM work item count overflow."); + ORT_ENFORCE(work_items <= static_cast(std::numeric_limits::max()), + "FP8 GEMM work item count exceeds ptrdiff_t range."); + return static_cast(work_items); +} + +// Validates the largest row-major strided matrix offset used by the GEMM loops. +inline void CheckStridedMatrixOffset(size_t rows, size_t cols, size_t row_stride) { + if (rows == 0 || cols == 0) { + return; + } + + size_t offset = 0; + ORT_ENFORCE(!MlasMultiplyOverflowsSizeT(rows - 1, row_stride, &offset), + "FP8 GEMM matrix offset overflow."); + ORT_ENFORCE(cols - 1 <= std::numeric_limits::max() - offset, + "FP8 GEMM matrix offset overflow."); +} + +// Validates the largest block-scale or zero-point offset used by the GEMM loops. +inline void CheckBlockedMatrixOffset(size_t blocks0, size_t stride0, size_t blocks1, size_t stride1) { + if (blocks0 == 0 || blocks1 == 0) { + return; + } + + size_t offset0 = 0; + ORT_ENFORCE(!MlasMultiplyOverflowsSizeT(blocks0 - 1, stride0, &offset0), + "FP8 GEMM block offset overflow."); + size_t offset1 = 0; + ORT_ENFORCE(!MlasMultiplyOverflowsSizeT(blocks1 - 1, stride1, &offset1), + "FP8 GEMM block offset overflow."); + ORT_ENFORCE(offset1 <= std::numeric_limits::max() - offset0, + "FP8 GEMM block offset overflow."); +} + +// Validate caller-provided buffers, strides, and block metadata before parallel workers dereference them. +inline void CheckFp8GemmBatchParams( + const MLAS_FP8_GEMM_SHAPE_PARAMS& shape, + const MLAS_FP8_GEMM_DATA_PARAMS& params) { + ORT_ENFORCE(onnxruntime::IsValidFp8Mode(params.Fp8Type), "FP8 GEMM mode must be valid."); + ORT_ENFORCE(params.BlockSizeM != 0 && params.BlockSizeK != 0 && params.BlockSizeN != 0, + "FP8 GEMM block sizes must be non-zero."); + + const bool writes_output = shape.M != 0 && shape.N != 0; + const bool reads_reduction_data = writes_output && shape.K != 0; + + // Empty-output GEMMs do not dereference C, and empty reductions do not read A/B. + if (reads_reduction_data) { + ORT_ENFORCE(params.A != nullptr, "FP8 GEMM A buffer must not be null."); + ORT_ENFORCE(params.B != nullptr, "FP8 GEMM B buffer must not be null."); + ORT_ENFORCE(params.lda >= shape.K, "FP8 GEMM lda must be greater than or equal to K."); + ORT_ENFORCE(params.ldb >= shape.N, "FP8 GEMM ldb must be greater than or equal to N."); + + CheckStridedMatrixOffset(shape.M, shape.K, params.lda); + CheckStridedMatrixOffset(shape.K, shape.N, params.ldb); + } + + if (writes_output) { + ORT_ENFORCE(params.C != nullptr, "FP8 GEMM C buffer must not be null."); + ORT_ENFORCE(params.ldc >= shape.N, "FP8 GEMM ldc must be greater than or equal to N."); + + CheckStridedMatrixOffset(shape.M, shape.N, params.ldc); + } + + const size_t blocks_m = shape.M == 0 ? 0 : ((shape.M - 1) / params.BlockSizeM) + 1; + const size_t blocks_k = shape.K == 0 ? 0 : ((shape.K - 1) / params.BlockSizeK) + 1; + const size_t blocks_n = shape.N == 0 ? 0 : ((shape.N - 1) / params.BlockSizeN) + 1; + + if (reads_reduction_data && params.ScaleA != nullptr) { + ORT_ENFORCE(blocks_m == params.BlocksM, "FP8 GEMM M block count must match shape and block size."); + ORT_ENFORCE(blocks_k == params.BlocksK, "FP8 GEMM K block count must match shape and block size."); + CheckBlockedMatrixOffset(blocks_m, params.ScaleAStrideM, blocks_k, params.ScaleAStrideK); + } + if (reads_reduction_data && params.ScaleB != nullptr) { + ORT_ENFORCE(blocks_k == params.BlocksK, "FP8 GEMM K block count must match shape and block size."); + ORT_ENFORCE(blocks_n == params.BlocksN, "FP8 GEMM N block count must match shape and block size."); + CheckBlockedMatrixOffset(blocks_k, params.ScaleBStrideK, blocks_n, params.ScaleBStrideN); + } +} + +} // namespace + +void +MLASCALL +MlasFp8GemmBatch( + const MLAS_FP8_GEMM_SHAPE_PARAMS& Shape, + const MLAS_FP8_GEMM_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool + ) +{ + const size_t M = Shape.M; + const size_t N = Shape.N; + const size_t K = Shape.K; + + if (BatchN == 0 || M == 0 || N == 0) { + return; + } + + const ptrdiff_t WorkItems = CheckedWorkItems(BatchN, M); + + ORT_ENFORCE(DataParams != nullptr, "FP8 GEMM data parameters must not be null."); + + for (size_t batch = 0; batch < BatchN; ++batch) { + CheckFp8GemmBatchParams(Shape, DataParams[batch]); + } + + MlasTrySimpleParallel(ThreadPool, WorkItems, [&](ptrdiff_t tid) { + const size_t batch = static_cast(tid) / M; + const size_t m = static_cast(tid) % M; + const auto& params = DataParams[batch]; + ORT_ENFORCE(onnxruntime::IsValidFp8Mode(params.Fp8Type), "FP8 GEMM mode must be valid."); + + const auto* a_fp8 = static_cast(params.A); + const auto* b_fp8 = static_cast(params.B); + auto* c = static_cast(params.C); + const auto* scale_a = params.ScaleA; + const auto* scale_b = params.ScaleB; + + const size_t block_m = m / params.BlockSizeM; + for (size_t n = 0; n < N; ++n) { + const size_t block_n = n / params.BlockSizeN; + float acc = 0.0f; + for (size_t k = 0; k < K; ++k) { + const size_t block_k = k / params.BlockSizeK; + + const size_t a_scale_idx = block_m * params.ScaleAStrideM + block_k * params.ScaleAStrideK; + const size_t b_scale_idx = block_k * params.ScaleBStrideK + block_n * params.ScaleBStrideN; + const float scale_a_val = scale_a ? scale_a[a_scale_idx] : 1.0f; + const float scale_b_val = scale_b ? scale_b[b_scale_idx] : 1.0f; + + const float a_val = onnxruntime::Fp8ByteToFloat(a_fp8[m * params.lda + k], params.Fp8Type); + const float b_val = onnxruntime::Fp8ByteToFloat(b_fp8[k * params.ldb + n], params.Fp8Type); + + const float a_deq = a_val * scale_a_val; + const float b_deq = b_val * scale_b_val; + acc += a_deq * b_deq; + } + + if (params.ScaleY != nullptr) { + acc *= params.ScaleY[0]; + } + c[m * params.ldc + n] = acc; + } + }); +} + +#endif // !defined(DISABLE_FLOAT8_TYPES) diff --git a/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc new file mode 100644 index 0000000000000..763660460d855 --- /dev/null +++ b/onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc @@ -0,0 +1,866 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "core/session/inference_session.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/test_environment.h" +#include "test/unittest_util/conversion.h" +#include "default_providers.h" + +#if !defined(DISABLE_FLOAT8_TYPES) +#include "core/common/float8.h" + +namespace onnxruntime { +namespace test { + +class DynamicQuantMatMulFp8SessionTester : public OpTester { + public: + using BaseTester::ExecuteModel; + using BaseTester::FillFeedsAndOutputNames; + using BaseTester::SetTestFunctionCalled; + using OpTester::BuildModel; + using OpTester::OpTester; +}; + +template +struct Fp8TensorProtoType; + +template <> +struct Fp8TensorProtoType { + static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN; +}; + +template <> +struct Fp8TensorProtoType { + static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ; +}; + +template <> +struct Fp8TensorProtoType { + static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; +}; + +template <> +struct Fp8TensorProtoType { + static constexpr int64_t value = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ; +}; + +template +float Fp8MaxAbs(); + +template <> +float Fp8MaxAbs() { + return 448.0f; +} + +template <> +float Fp8MaxAbs() { + return 448.0f; +} + +template <> +float Fp8MaxAbs() { + return 57344.0f; +} + +template <> +float Fp8MaxAbs() { + return 57344.0f; +} + +template +float QuantizeDequantize(float value, float scale) { + return Fp8T(value / scale, true).ToFloat() * scale; +} + +template +std::vector ComputeRowwiseAQuantized(const std::vector& a_data, + int64_t m, + int64_t k, + int64_t block_size_k) { + const int64_t blocks_k = k / block_size_k; + std::vector result(a_data.size()); + for (int64_t row = 0; row < m; ++row) { + for (int64_t block_k = 0; block_k < blocks_k; ++block_k) { + const int64_t k_begin = block_k * block_size_k; + const int64_t k_end = k_begin + block_size_k; + float max_abs = 0.0f; + for (int64_t kk = k_begin; kk < k_end; ++kk) { + max_abs = std::max(max_abs, std::fabs(a_data[static_cast(row * k + kk)])); + } + const float scale = max_abs == 0.0f ? 1.0f : max_abs / Fp8MaxAbs(); + for (int64_t kk = k_begin; kk < k_end; ++kk) { + const size_t index = static_cast(row * k + kk); + result[index] = QuantizeDequantize(a_data[index], scale); + } + } + } + return result; +} + +template +std::vector ComputeConstantBQuantized(const std::vector& b_data, + int64_t k, + int64_t n, + int64_t block_size_k, + int64_t block_size_n) { + const int64_t blocks_k = k / block_size_k; + const int64_t blocks_n = n / block_size_n; + std::vector result(b_data.size()); + for (int64_t block_n = 0; block_n < blocks_n; ++block_n) { + const int64_t n_begin = block_n * block_size_n; + const int64_t n_end = n_begin + block_size_n; + for (int64_t block_k = 0; block_k < blocks_k; ++block_k) { + const int64_t k_begin = block_k * block_size_k; + const int64_t k_end = k_begin + block_size_k; + float max_abs = 0.0f; + for (int64_t kk = k_begin; kk < k_end; ++kk) { + for (int64_t nn = n_begin; nn < n_end; ++nn) { + max_abs = std::max(max_abs, std::fabs(b_data[static_cast(kk * n + nn)])); + } + } + const float scale = max_abs == 0.0f ? 1.0f : max_abs / Fp8MaxAbs(); + for (int64_t kk = k_begin; kk < k_end; ++kk) { + for (int64_t nn = n_begin; nn < n_end; ++nn) { + const size_t index = static_cast(kk * n + nn); + result[index] = QuantizeDequantize(b_data[index], scale); + } + } + } + } + return result; +} + +template +std::vector ComputeRuntimeBQuantized(const std::vector& b_data, + const std::vector& b_scale, + int64_t k, + int64_t n, + int64_t block_size_k, + int64_t block_size_n) { + const int64_t blocks_k = k / block_size_k; + std::vector result(b_data.size()); + for (int64_t kk = 0; kk < k; ++kk) { + const int64_t block_k = kk / block_size_k; + for (int64_t nn = 0; nn < n; ++nn) { + const int64_t block_n = nn / block_size_n; + const size_t index = static_cast(kk * n + nn); + result[index] = b_data[index].ToFloat() * b_scale[static_cast(block_n * blocks_k + block_k)]; + } + } + return result; +} + +std::vector ComputeMatMul(const std::vector& a_data, + const std::vector& b_data, + int64_t m, + int64_t n, + int64_t k, + float y_scale = 1.0f) { + std::vector y_data(static_cast(m * n), 0.0f); + for (int64_t row = 0; row < m; ++row) { + for (int64_t col = 0; col < n; ++col) { + float sum = 0.0f; + for (int64_t kk = 0; kk < k; ++kk) { + sum += a_data[static_cast(row * k + kk)] * b_data[static_cast(kk * n + col)]; + } + y_data[static_cast(row * n + col)] = sum * y_scale; + } + } + return y_data; +} + +template +void AddZeroPoint(OpTester& test, const char* name, const std::vector& shape, size_t count, bool initializer) { + test.AddInput(name, shape, std::vector(count, Fp8T(0.0f)), initializer); +} + +template +void RunConstantBInputs() { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + const auto a_quantized = ComputeRowwiseAQuantized(a_data, M, K, 128); + const auto b_quantized = ComputeConstantBQuantized(b_data, K, N, 128, 128); + const auto y_data = ComputeMatMul(a_quantized, b_quantized, M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("fp8_type", Fp8TensorProtoType::value); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.5f); + test.Run(); +} + +template +void RunRuntimeFp8BInput() { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Fp8T(-0.5f)); + std::vector b_scale{1.0f}; + const auto a_quantized = ComputeRowwiseAQuantized(a_data, M, K, 128); + const auto b_quantized = ComputeRuntimeBQuantized(b_data, b_scale, K, N, 128, 128); + const auto y_data = ComputeMatMul(a_quantized, b_quantized, M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.5f); + test.Run(); +} + +void RunDynamicQuantMatMulFp8WithSharedPrepack(DynamicQuantMatMulFp8SessionTester& test, + OrtValue& shared_b, + PrepackedWeightsContainer& prepacked_weights_container, + size_t& shared_prepack_count) { + test.SetTestFunctionCalled(); + + auto& model = test.BuildModel(); + Status status = model.MainGraph().Resolve(); + ASSERT_TRUE(status.IsOK()) << status; + + std::unordered_map feeds; + std::vector output_names; + test.FillFeedsAndOutputNames(feeds, output_names); + + SessionOptions so; + status = so.AddInitializer("B", &shared_b); + ASSERT_TRUE(status.IsOK()) << status; + + InferenceSession session{so, GetEnvironment()}; + status = session.AddPrePackedWeightsContainer(&prepacked_weights_container); + ASSERT_TRUE(status.IsOK()) << status; + + status = session.RegisterExecutionProvider(DefaultCpuExecutionProvider()); + ASSERT_TRUE(status.IsOK()) << status; + + test.ExecuteModel(model, + session, + OpTester::ExpectResult::kExpectSuccess, + "", + nullptr, + feeds, + output_names, + kCpuExecutionProvider); + shared_prepack_count = session.GetSessionState().GetUsedSharedPrePackedWeightCounter(); +} + +TEST(DynamicQuantMatMulFp8, WithConstantBInputs) { + RunConstantBInputs(); +} + +TEST(DynamicQuantMatMulFp8, WithConstantBInputsE4M3FNUZ) { + RunConstantBInputs(); +} + +TEST(DynamicQuantMatMulFp8, WithConstantBInputsE5M2) { + RunConstantBInputs(); +} + +TEST(DynamicQuantMatMulFp8, WithConstantBInputsE5M2FNUZ) { + RunConstantBInputs(); +} + +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInput) { + RunRuntimeFp8BInput(); +} + +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE4M3FNUZ) { + RunRuntimeFp8BInput(); +} + +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE5M2) { + RunRuntimeFp8BInput(); +} + +TEST(DynamicQuantMatMulFp8, RuntimeFp8BInputE5M2FNUZ) { + RunRuntimeFp8BInput(); +} + +TEST(DynamicQuantMatMulFp8, WithOmittedOutputQuantizationInputs) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeRuntimeBQuantized(b_data, b_scale, K, N, 128, 128), M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.5f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, WithOnlyYScale) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + constexpr float YScale = 0.5f; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + std::vector y_scale{YScale}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeRuntimeBQuantized(b_data, b_scale, K, N, 128, 128), M, N, K, YScale); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); + test.AddInput("Y_scale", {}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.5f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, RejectsNonZeroYZeroPoint) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + std::vector y_data(static_cast(M * N), 0.0f); + std::vector y_zp{Float8E4M3FN(1.0f)}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); + test.AddOptionalInputEdge(); + test.AddInput("Y_zero_point", {}, y_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 supports symmetric quantization only; Y zero point values must be zero."); +} + +TEST(DynamicQuantMatMulFp8, WithRuntimeBInputsBf16Scales) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale_float{1.0f}; + std::vector b_scale = MakeBFloat16({1.0f}); + std::vector y_scale = MakeBFloat16({1.0f}); + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeRuntimeBQuantized(b_data, b_scale_float, K, N, 128, 128), M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / 128, K / 128}, b_scale.size(), false); + test.AddInput("Y_scale", {}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.5f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, Float16Output) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeConstantBQuantized(b_data, K, N, 128, 128), M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, FloatsToMLFloat16s(a_data)); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, FloatsToMLFloat16s(y_data)); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, BFloat16Output) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, 128), + ComputeConstantBQuantized(b_data, K, N, 128, 128), M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, FloatsToBFloat16s(a_data)); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, FloatsToBFloat16s(y_data)); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, RejectsNonConstantB) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), -0.5f); + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires runtime B input to be FP8."); +} + +TEST(DynamicQuantMatMulFp8, RejectsRuntimeFp8BZeroPointTypeMismatch) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + std::vector b_zp{Float8E5M2(0.0f)}; + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddInput("B_zero_point", {N / 128, K / 128}, b_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); +} + +TEST(DynamicQuantMatMulFp8, RejectsConstantFp8BZeroPointTypeMismatch) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + std::vector b_zp{Float8E5M2(0.0f)}; + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddInput("B_zero_point", {N / 128, K / 128}, b_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires B and B zero point FP8 types to match."); +} + +TEST(DynamicQuantMatMulFp8, RejectsNonZeroBZeroPoint) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + std::vector b_zp{Float8E4M3FN(1.0f)}; + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddInput("B_zero_point", {N / 128, K / 128}, b_zp); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 supports symmetric quantization only; B zero point values must be zero."); +} + +TEST(DynamicQuantMatMulFp8, NonDefaultBlockSizesWithPartialM) { + constexpr int64_t M = 9; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + constexpr int64_t BlockK = 2; + constexpr int64_t BlockN = 2; + + std::vector a_data(static_cast(M * K), 0.0f); + for (int64_t m = 0; m < M; ++m) { + for (int64_t k = 0; k < K; ++k) { + a_data[static_cast(m * K + k)] = static_cast((m + 1) * (k + 1)) / 16.0f; + } + } + + std::vector b_data(static_cast(K * N), Float8E4M3FN(0.0f)); + for (int64_t k = 0; k < K; ++k) { + for (int64_t n = 0; n < N; ++n) { + b_data[static_cast(k * N + n)] = Float8E4M3FN(k == n ? 1.0f : 0.0f); + } + } + std::vector b_scale{1.0f, 2.0f, + 3.0f, 4.0f}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, BlockK), + ComputeRuntimeBQuantized(b_data, b_scale, K, N, BlockK, BlockN), M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("block_size_k", BlockK); + test.AddAttribute("block_size_n", BlockN); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / BlockN, K / BlockK}, b_scale); + AddZeroPoint(test, "B_zero_point", {N / BlockN, K / BlockK}, b_scale.size(), false); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.01f); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, RejectsUnsupportedBlockSizeM) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(-0.5f)); + std::vector b_scale{1.0f}; + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("block_size_m", 2); + test.AddAttribute("block_size_k", 4); + test.AddAttribute("block_size_n", 4); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {1, 1}, b_scale); + AddZeroPoint(test, "B_zero_point", {1, 1}, b_scale.size(), false); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, "block_size_m must be 1"); +} + +TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsRestoresPackedBMetadata) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + constexpr int64_t BlockSize = 2; + + std::vector a_data{ + 1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 0.5f, 1.0f, -0.5f, -1.0f, + 4.0f, 3.0f, 2.0f, 1.0f}; + std::vector b_data{ + 2.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 2.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 2.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 2.0f}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, BlockSize), + ComputeConstantBQuantized(b_data, K, N, BlockSize, BlockSize), + M, N, K); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("fp8_type", Fp8TensorProtoType::value); + test.AddAttribute("block_size_k", BlockSize); + test.AddAttribute("block_size_n", BlockSize); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, y_data); + test.SetOutputAbsErr("Y", 0.01f); + + OrtValue b; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({K, N}), b_data.data(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + SessionOptions so; + ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); + + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t prepack_count_session_1 = 0; + size_t shared_prepack_count = 0; + test.Config(so).ConfigEps(cpu_ep()).RunWithConfig(&prepack_count_session_1, &shared_prepack_count); + ASSERT_EQ(shared_prepack_count, static_cast(0)); + ASSERT_EQ(test.GetNumPrePackedWeightsShared(), prepack_count_session_1); + ASSERT_GT(prepack_count_session_1, static_cast(0)); + + size_t prepack_count_session_2 = 0; + test.Config(so).ConfigEps(cpu_ep()).RunWithConfig(&prepack_count_session_2, &shared_prepack_count); + ASSERT_EQ(prepack_count_session_2, prepack_count_session_1); + ASSERT_EQ(shared_prepack_count, prepack_count_session_2); +} + +TEST(DynamicQuantMatMulFp8, SharedPrepackedWeightsWithComputedBScalesReuseCorrectly) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + constexpr int64_t BlockSize = 2; + + const std::vector a_data{ + 1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}; + std::vector b_data{ + 0.31f, -0.37f, 0.83f, -1.70f, + 1.20f, -2.30f, 3.40f, -4.50f, + 5.50f, -6.25f, 7.75f, -8.50f, + 9.00f, -10.50f, 11.25f, -12.75f}; + const auto y_data = ComputeMatMul(ComputeRowwiseAQuantized(a_data, M, K, BlockSize), + ComputeConstantBQuantized(b_data, K, N, BlockSize, BlockSize), + M, N, K); + + OrtValue b; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({K, N}), b_data.data(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + PrepackedWeightsContainer prepacked_weights_container; + + DynamicQuantMatMulFp8SessionTester test_1("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test_1.AddAttribute("block_size_k", BlockSize); + test_1.AddAttribute("block_size_n", BlockSize); + test_1.AddInput("A", {M, K}, a_data); + test_1.AddInput("B", {K, N}, b_data, true /*initializer*/); + test_1.AddOutput("Y", {M, N}, y_data); + test_1.SetOutputAbsErr("Y", 1e-5f); + + size_t shared_prepack_count = 0; + RunDynamicQuantMatMulFp8WithSharedPrepack(test_1, b, prepacked_weights_container, shared_prepack_count); + ASSERT_EQ(shared_prepack_count, static_cast(0)); + ASSERT_EQ(prepacked_weights_container.GetNumberOfElements(), static_cast(1)); + + DynamicQuantMatMulFp8SessionTester test_2("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test_2.AddAttribute("block_size_k", BlockSize); + test_2.AddAttribute("block_size_n", BlockSize); + test_2.AddInput("A", {M, K}, a_data); + test_2.AddInput("B", {K, N}, b_data, true /*initializer*/); + test_2.AddOutput("Y", {M, N}, y_data); + test_2.SetOutputAbsErr("Y", 1e-5f); + + RunDynamicQuantMatMulFp8WithSharedPrepack(test_2, b, prepacked_weights_container, shared_prepack_count); + ASSERT_GT(shared_prepack_count, static_cast(0)); + ASSERT_EQ(prepacked_weights_container.GetNumberOfElements(), static_cast(1)); +} + +TEST(DynamicQuantMatMulFp8, RejectsMalformedBScaleShapeBeforeReadingScaleData) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(0.5f)); + std::vector b_scale = FloatsToMLFloat16s({1.0f}); + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("block_size_k", 4); + test.AddAttribute("block_size_n", 4); + test.AddShapeToTensorData(false); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {1}, b_scale); + AddZeroPoint(test, "B_zero_point", {1, 1}, 1, false); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires B scale to be a 2D tensor."); +} + +TEST(DynamicQuantMatMulFp8, RejectsRuntimeFp8BWithoutBScale) { + constexpr int64_t M = 4; + constexpr int64_t N = 4; + constexpr int64_t K = 4; + + std::vector a_data(static_cast(M * K), 0.25f); + std::vector b_data(static_cast(K * N), Float8E4M3FN(0.5f)); + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddAttribute("block_size_k", 4); + test.AddAttribute("block_size_n", 4); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires B scale when B is already FP8."); +} + +TEST(DynamicQuantMatMulFp8, ZeroMInput) { + constexpr int64_t M = 0; + constexpr int64_t N = 128; + constexpr int64_t K = 128; + + std::vector a_data{}; + std::vector b_data(static_cast(K * N), 0.0f); + std::vector y_data{}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, y_data); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, ZeroKInput) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 0; + + std::vector a_data{}; + std::vector b_data{}; + std::vector y_data(static_cast(M * N), 0.0f); + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, y_data); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, ZeroKInputRejectsInvalidYScaleShape) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 0; + + std::vector a_data{}; + std::vector b_data{}; + std::vector y_data(static_cast(M * N), 0.0f); + std::vector y_scale{1.0f}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddShapeToTensorData(false); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddInput("Y_scale", {1}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "DynamicQuantMatMulFp8 requires Y scale input to be a scalar."); +} + +TEST(DynamicQuantMatMulFp8, ZeroKInputRejectsInvalidYScaleValue) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 0; + + std::vector a_data{}; + std::vector b_data{}; + std::vector y_data(static_cast(M * N), 0.0f); + std::vector y_scale{0.0f}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddInput("Y_scale", {}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Y scale values to be finite and positive."); +} + +TEST(DynamicQuantMatMulFp8, ZeroKInputRejectsInvalidYScaleType) { + constexpr int64_t M = 128; + constexpr int64_t N = 128; + constexpr int64_t K = 0; + + std::vector a_data{}; + std::vector b_data{}; + std::vector y_data(static_cast(M * N), 0.0f); + std::vector y_scale{1}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddInput("Y_scale", {}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, "Type Error"); +} + +TEST(DynamicQuantMatMulFp8, ZeroNInput) { + constexpr int64_t M = 128; + constexpr int64_t N = 0; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.0f); + std::vector b_data{}; + std::vector y_data{}; + std::vector b_scale{}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddOutput("Y", {M, N}, y_data); + test.Run(); +} + +TEST(DynamicQuantMatMulFp8, ZeroNInputRejectsInvalidYScaleValue) { + constexpr int64_t M = 128; + constexpr int64_t N = 0; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.0f); + std::vector b_data{}; + std::vector y_data{}; + std::vector b_scale{}; + std::vector y_scale{0.0f}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("B_scale", {N / 128, K / 128}, b_scale); + test.AddOptionalInputEdge(); + test.AddInput("Y_scale", {}, y_scale); + test.AddOutput("Y", {M, N}, y_data); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Y scale values to be finite and positive."); +} + +TEST(DynamicQuantMatMulFp8, ZeroNInputWithConstantNonFp8B) { + constexpr int64_t M = 128; + constexpr int64_t N = 0; + constexpr int64_t K = 128; + + std::vector a_data(static_cast(M * K), 0.0f); + std::vector b_data{}; + std::vector y_data{}; + + OpTester test("DynamicQuantMatMulFp8", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data, true /*initializer*/); + test.AddOutput("Y", {M, N}, y_data); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime +#endif // !defined(DISABLE_FLOAT8_TYPES) diff --git a/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp b/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp new file mode 100644 index 0000000000000..4c2e82f1c9e2a --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp @@ -0,0 +1,185 @@ +// Copyright (c) 2026 Arm Limited. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "mlas.h" +#include "test_util.h" +#include "core/common/float8.h" +#include "core/mlas/inc/mlas.h" + +#include +#include +#include + +#if !defined(DISABLE_FLOAT8_TYPES) + +namespace { + +uint8_t EncodeFp8(float value, mlas_fp8_mode mode) { + using onnxruntime::Float8E4M3FN; + using onnxruntime::Float8E4M3FNUZ; + using onnxruntime::Float8E5M2; + using onnxruntime::Float8E5M2FNUZ; + + switch (mode) { + case MLAS_FP8_MODE_E4M3_INF: + return Float8E4M3FN(value).val; + case MLAS_FP8_MODE_E4M3_SAT: + return Float8E4M3FNUZ(value).val; + case MLAS_FP8_MODE_E5M2_INF: + return Float8E5M2(value).val; + case MLAS_FP8_MODE_E5M2_SAT: + return Float8E5M2FNUZ(value).val; + default: + ORT_THROW("Unsupported FP8 GEMM test mode."); + } +} + +void RunFp8GemmBatchThreaded(mlas_fp8_mode mode) { + constexpr size_t BatchN = 2; + constexpr size_t M = 3; + constexpr size_t N = 2; + constexpr size_t K = 4; + constexpr size_t BlockSizeM = 2; + constexpr size_t BlockSizeK = 2; + constexpr size_t BlockSizeN = 1; + constexpr size_t BlocksM = 2; + constexpr size_t BlocksK = 2; + constexpr size_t BlocksN = 2; + constexpr size_t ScaleElements = BlocksM * BlocksK; + constexpr size_t BScaleElements = BlocksK * BlocksN; + + const std::array a_values{ + 1.0f, 2.0f, -1.0f, 0.5f, + -2.0f, 1.5f, 0.0f, 4.0f, + 0.25f, -0.5f, 3.0f, -4.0f, + -1.0f, 0.5f, 2.0f, -2.0f, + 4.0f, -0.25f, -1.5f, 1.0f, + 0.0f, 3.0f, -4.0f, 0.5f}; + const std::array b_values{ + 1.0f, -1.0f, + 0.5f, 2.0f, + -2.0f, 0.25f, + 1.5f, -0.5f, + -0.5f, 1.0f, + 2.0f, -2.0f, + 0.25f, 1.5f, + -1.0f, 0.5f}; + const std::array scale_a{ + 1.0f, 0.5f, + 2.0f, 1.5f, + 0.25f, 1.0f, + 0.5f, 2.0f}; + const std::array scale_b{ + 1.0f, 2.0f, + 0.25f, 1.25f, + 0.5f, 1.0f, + 2.0f, 0.25f}; + const std::array y_scale{0.5f, 2.0f}; + + std::vector a_fp8(a_values.size()); + std::vector b_fp8(b_values.size()); + for (size_t i = 0; i < a_values.size(); ++i) { + a_fp8[i] = EncodeFp8(a_values[i], mode); + } + for (size_t i = 0; i < b_values.size(); ++i) { + b_fp8[i] = EncodeFp8(b_values[i], mode); + } + + std::array output{}; + std::array expected{}; + std::array params{}; + + for (size_t batch = 0; batch < BatchN; ++batch) { + params[batch].A = a_fp8.data() + batch * M * K; + params[batch].lda = K; + params[batch].B = b_fp8.data() + batch * K * N; + params[batch].ldb = N; + params[batch].C = output.data() + batch * M * N; + params[batch].ldc = N; + params[batch].ScaleA = scale_a.data() + batch * ScaleElements; + params[batch].ScaleB = scale_b.data() + batch * BScaleElements; + params[batch].ScaleY = y_scale.data() + batch; + params[batch].Fp8Type = mode; + params[batch].BlockSizeM = BlockSizeM; + params[batch].BlockSizeK = BlockSizeK; + params[batch].BlockSizeN = BlockSizeN; + params[batch].BlocksM = BlocksM; + params[batch].BlocksK = BlocksK; + params[batch].BlocksN = BlocksN; + params[batch].ScaleAStrideK = 1; + params[batch].ScaleAStrideM = BlocksK; + params[batch].ScaleBStrideN = 1; + params[batch].ScaleBStrideK = BlocksN; + + for (size_t m = 0; m < M; ++m) { + const size_t block_m = m / BlockSizeM; + for (size_t n = 0; n < N; ++n) { + const size_t block_n = n / BlockSizeN; + float acc = 0.0f; + for (size_t k = 0; k < K; ++k) { + const size_t block_k = k / BlockSizeK; + const size_t a_scale_idx = batch * ScaleElements + block_m * BlocksK + block_k; + const size_t b_scale_idx = batch * BScaleElements + block_k * BlocksN + block_n; + const float a_deq = a_values[batch * M * K + m * K + k] * scale_a[a_scale_idx]; + const float b_deq = b_values[batch * K * N + k * N + n] * scale_b[b_scale_idx]; + acc += a_deq * b_deq; + } + expected[batch * M * N + m * N + n] = acc * y_scale[batch]; + } + } + } + + MLAS_FP8_GEMM_SHAPE_PARAMS shape{M, N, K}; + MLAS_THREADPOOL* threadpool = GetMlasThreadPool(); + if (threadpool == nullptr) { + GTEST_SKIP() << "MlasFp8GemmBatch threaded test requires an MLAS thread pool."; + } + + MlasFp8GemmBatch(shape, params.data(), BatchN, threadpool); + + // Inputs are exactly representable test values, so the scalar fallback should match closely. + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_NEAR(output[i], expected[i], 1e-5f); + } +} + +} // namespace + +TEST(Fp8Gemm, BatchedModesSymmetricThreaded) { + RunFp8GemmBatchThreaded(MLAS_FP8_MODE_E4M3_INF); + RunFp8GemmBatchThreaded(MLAS_FP8_MODE_E4M3_SAT); + RunFp8GemmBatchThreaded(MLAS_FP8_MODE_E5M2_INF); + RunFp8GemmBatchThreaded(MLAS_FP8_MODE_E5M2_SAT); +} + +TEST(Fp8Gemm, EmptyDimensionsSkipUnusedBufferValidation) { + MLAS_FP8_GEMM_DATA_PARAMS params{}; + params.Fp8Type = MLAS_FP8_MODE_E4M3_INF; + params.BlockSizeM = 2; + params.BlockSizeK = 2; + params.BlockSizeN = 2; + + MLAS_FP8_GEMM_SHAPE_PARAMS empty_output_shape{3, 0, 4}; + MlasFp8GemmBatch(empty_output_shape, ¶ms, 1, nullptr); + + std::array output; + output.fill(-1.0f); + + MLAS_FP8_GEMM_SHAPE_PARAMS empty_reduction_shape{3, 2, 0}; + params.C = output.data(); + params.ldc = 2; + MlasFp8GemmBatch(empty_reduction_shape, ¶ms, 1, nullptr); + + for (float value : output) { + EXPECT_EQ(value, 0.0f); + } +} + +TEST(Fp8Gemm, ZeroColumnReturnsBeforeWorkItemOverflow) { + MLAS_FP8_GEMM_SHAPE_PARAMS shape{std::numeric_limits::max(), 0, 4}; + MlasFp8GemmBatch(shape, nullptr, 2, nullptr); + SUCCEED(); +} + +#endif // !defined(DISABLE_FLOAT8_TYPES)