Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2108,7 +2108,8 @@ This version of the operator has been available since version 1 of the 'com.micr
3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
to dequantize the output.
4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
5. For uint8 data, the `gather_axis` must be 0.
5. For uint8 data, the `gather_axis` must be 0. The supported `bits` values for uint8 data are 2, 4, and 8;
for `bits` < 8 the values are packed along the last dimension (low-order bits first).

#### Version

Expand All @@ -2118,7 +2119,7 @@ This version of the operator has been available since version 1 of the 'com.micr

<dl>
<dt><tt>bits</tt> : int</dt>
<dd>Number of bits used for weight quantization. Must be either 4 or 8. </dd>
<dd>Number of bits used for weight quantization. Must be 2, 4 or 8. </dd>
<dt><tt>block_size</tt> : int</dt>
<dd>(Optional) block size used for weight quantization. It needs to be a power of 2 and not smaller than 16.</dd>
<dt><tt>gather_axis</tt> : int</dt>
Expand Down
48 changes: 40 additions & 8 deletions onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ int32_t Get4BitElement<uint8_t>(const uint8_t* data_ptr, int64_t data_idx) {
return data_val;
}

// Reads a single 2-bit value from a uint8_t buffer holding four values per byte.
// Inside each byte the values are laid out low-order first: logical index 0 occupies
// bits [1:0], index 1 bits [3:2], index 2 bits [5:4] and index 3 bits [7:6].
// `data_idx` is the logical (unpacked) element index; the result is in the range [0, 3].
int32_t Get2BitElementUint8(const uint8_t* data_ptr, int64_t data_idx) {
  // Four elements share one byte, so the byte position is the index divided by four.
  const int64_t byte_pos = data_idx >> 2;
  // Which of the four 2-bit slots inside that byte holds the requested element.
  const int slot = static_cast<int>(data_idx & 3);
  const uint32_t packed = data_ptr[byte_pos];
  return static_cast<int32_t>((packed >> (2 * slot)) & 0x3u);
}

} // namespace

template <typename T1, typename Tind>
Expand All @@ -53,7 +61,12 @@ class GatherBlockQuantized : public OpKernel {

constexpr int64_t default_bits = 4;
info.GetAttrOrDefault("bits", &bits_, default_bits);
ORT_ENFORCE(bits_ == 4 || bits_ == 8, "GatherBlockQuantized only support bits==4 or 8");
if constexpr (std::is_same_v<T1, uint8_t>) {
ORT_ENFORCE(bits_ == 2 || bits_ == 4 || bits_ == 8,
"GatherBlockQuantized with uint8 data only supports bits==2, 4, or 8");
} else {
ORT_ENFORCE(bits_ == 4, "GatherBlockQuantized with int4/uint4 data only supports bits==4");
}

ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0,
"'block_size' must be 2's power and not less than 16.");
Expand Down Expand Up @@ -221,10 +234,12 @@ Status GatherBlockQuantized<T1, Tind>::CopyDataAndDequantize(const T1* data_ptr,
int32_t data_val;
if constexpr (!std::is_same_v<T1, uint8_t>) {
data_val = Get4BitElement(data_ptr, data_idx);
} else { // unit8_t
if (bits_ == 4) {
} else { // uint8_t
if (bits_ == 2) {
data_val = Get2BitElementUint8(data_ptr, data_idx);
} else if (bits_ == 4) {
data_val = Get4BitElement(data_ptr, data_idx);
} else { // buts_ == 8
} else { // bits_ == 8
data_val = static_cast<int32_t>(data_ptr[data_idx]);
}
}
Expand All @@ -238,9 +253,25 @@ Status GatherBlockQuantized<T1, Tind>::CopyDataAndDequantize(const T1* data_ptr,

if constexpr (std::is_same_v<T1, uint8_t>) {
if (zero_points_ptr) {
if (bits_ == 4) {
uint8_t packed = zero_points_ptr[scale_idx >> 1];
if (scale_idx & 1) {
// For uint8 we enforce quantize_axis == last dim, which makes quantize_N == 1
// and scale_full_block == scale_qaxis_dim. Zero points are packed only along
// the quantize axis, so the packed byte must be addressed using the scale row
// index and the within-row quantize-axis index, not the flat scale_idx; the
// latter crosses row boundaries when scale_qaxis_dim is not a multiple of the
// packing factor.
const int64_t scale_qaxis_dim = scale_full_block;
const int64_t scale_row = scale_idx / scale_qaxis_dim;
const int64_t q_in_row = scale_idx % scale_qaxis_dim;
if (bits_ == 2) {
const int64_t packed_zp_qaxis_dim = (scale_qaxis_dim + 3) / 4;
const int64_t byte_idx = scale_row * packed_zp_qaxis_dim + (q_in_row >> 2);
const int shift = static_cast<int>((q_in_row & 3) * 2);
zp_val = static_cast<int32_t>((zero_points_ptr[byte_idx] >> shift) & 0x03);
} else if (bits_ == 4) {
const int64_t packed_zp_qaxis_dim = (scale_qaxis_dim + 1) / 2;
const int64_t byte_idx = scale_row * packed_zp_qaxis_dim + (q_in_row >> 1);
uint8_t packed = zero_points_ptr[byte_idx];
if (q_in_row & 1) {
zp_val = static_cast<int32_t>((packed >> 4) & 0x0F);
} else {
zp_val = static_cast<int32_t>(packed & 0x0F);
Expand All @@ -249,7 +280,8 @@ Status GatherBlockQuantized<T1, Tind>::CopyDataAndDequantize(const T1* data_ptr,
zp_val = static_cast<int32_t>(zero_points_ptr[scale_idx]);
}
} else {
const int32_t default_zero_point = bits_ == 4 ? 8 : 128;
// Default zero point is 2^(bits-1): 2 for 2-bit, 8 for 4-bit, 128 for 8-bit.
const int32_t default_zero_point = 1 << (static_cast<int>(bits_) - 1);
zp_val = default_zero_point;
}
} else {
Expand Down
104 changes: 95 additions & 9 deletions onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias);

bool is_4bit = bits_ == 4;
const bool is_2bit = bits_ == 2;
const bool is_4bit = bits_ == 4;
const std::string unpack = (is_signed_) ? "unpack4xI8" : "unpack4xU8";

shader.MainFunctionBody()
Expand Down Expand Up @@ -57,7 +58,18 @@
shader.MainFunctionBody()
<< " let data_offset = " << x_shape.IndicesToOffset("data_indices") << ";\n";

if (is_4bit) {
if (is_2bit) {
// 2-bit values are packed 4 per byte (LSB first). x is the original uint8 tensor with
// Flatten=4 (4 bytes per u32); the input_shape uniform here is the *dequantized* shape,
// so data_offset is the dequantized 2-bit-element index.
shader.MainFunctionBody()
<< " let byte_idx_2b = data_offset / 4;\n"
<< " let bit_shift_2b = (data_offset % 4) * 2;\n"
<< " let packed_word_2b = " << x.GetByOffset("byte_idx_2b / 4") << ";\n"
<< " let byte_in_word_2b = byte_idx_2b % 4;\n"
<< " let unpacked_bytes_2b = " << unpack << "(u32(packed_word_2b));\n"
<< " var quantized_data = (unpacked_bytes_2b[byte_in_word_2b] >> bit_shift_2b) & 0x3;\n";
} else if (is_4bit) {
shader.MainFunctionBody()
<< " let data_index = data_offset % 8;\n"
<< " let packed_4bit_quantized_data = " << x.GetByOffset("data_offset / 8") << ";\n"
Expand All @@ -83,15 +95,41 @@
<< " var scale = " << scales.GetByIndices("scale_indices") << ";\n";

if (!has_zeropoint_) {
const std::string default_zero_point = is_uint8_ ? is_4bit ? "input_element_t(8)" : "input_element_t(128)" : "input_element_t(0)";
std::string default_zero_point;

Check warning on line 98 in onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc:98: Add #include <string> for string [build/include_what_you_use] [4]

Check warning on line 98 in onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc:98: Add #include <string> for string [build/include_what_you_use] [4]
if (is_uint8_) {
if (is_2bit) {
default_zero_point = "input_element_t(2)";
} else if (is_4bit) {
default_zero_point = "input_element_t(8)";
} else {
default_zero_point = "input_element_t(128)";
}
} else {
default_zero_point = "input_element_t(0)";
}
shader.MainFunctionBody()
<< " let zero_point = " << default_zero_point << ";\n";
} else {
const auto& zero_point = shader.AddInput("zero_point", ShaderUsage::None);
shader.MainFunctionBody()
<< " let zero_point_indices = scale_indices;\n"
<< " let zero_point_offset = " << scales.IndicesToOffset("zero_point_indices") << ";\n";
if (is_4bit) {
if (is_2bit) {
// 2-bit zero points are packed 4-per-byte along the quantize axis only. The scales
// tensor's flat offset cannot be used directly because dividing it by 4 crosses row
// boundaries when scale_qaxis_dim is not a multiple of 4 (e.g. scales {2,3,1} has
// packed zp shape {2,3,1} with one usable 2-bit value per byte per row). Derive the
// packed byte index from the scale row index plus the within-row quantize-axis index.
shader.MainFunctionBody()
<< " let q_idx_2b = " << scales.IndicesGet("scale_indices", "uniforms.quantize_axis") << ";\n"
<< " let scale_row_2b = zero_point_offset / uniforms.scale_qaxis_dim;\n"
<< " let zp_byte_offset_2b = scale_row_2b * uniforms.zp_packed_qaxis_dim + q_idx_2b / 4u;\n"
<< " let zp_bit_shift_2b = (q_idx_2b % 4u) * 2u;\n"
<< " let packed_zp_word_2b = " << zero_point.GetByOffset("zp_byte_offset_2b / 4") << ";\n"
<< " let zp_byte_in_word_2b = zp_byte_offset_2b % 4;\n"
<< " let zp_unpacked_2b = " << unpack << "(u32(packed_zp_word_2b));\n"
<< " var zero_point = (zp_unpacked_2b[zp_byte_in_word_2b] >> zp_bit_shift_2b) & 0x3;\n";
} else if (is_4bit) {
shader.MainFunctionBody()
<< " let zero_point_index = zero_point_offset % 8;\n"
<< " let packed_4bit_zero_points = " << zero_point.GetByOffset("zero_point_offset / 8") << ";\n"
Expand Down Expand Up @@ -142,6 +180,16 @@
int64_t x_dtype = x->GetElementType();
bool is_signed = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4;
bool is_int8 = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
bool is_uint8 = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;

// Only uint8 storage supports the full bits set {2, 4, 8}. The packed int4/uint4 types
// can only carry bits==4, matching the CPU kernel's constraint.
if (is_uint8) {
ORT_RETURN_IF_NOT(bits_ == 2 || bits_ == 4 || bits_ == 8,
"'bits' must be 2, 4 or 8 for uint8 input.");
} else {
ORT_RETURN_IF_NOT(bits_ == 4, "'bits' must be 4 for non-uint8 input.");
}

std::optional<Tensor> data_representation_4bit;
std::optional<Tensor> zero_points_representation_4bit;
Expand Down Expand Up @@ -174,7 +222,19 @@
zero_points = zero_points_representation_4bit.has_value() ? &zero_points_representation_4bit.value() : zero_points;
}

const auto& x_shape = x->Shape();
const auto& x_shape_intrinsic = x->Shape();
// For bits == 2 with uint8 storage we don't construct a packed-type reinterpret (no UInt2x4 type
// exists). Instead, build a logical "dequantized" shape (last dim x4) and feed that to the shader
// as the input_shape uniform. The buffer remains the original uint8 storage with Flatten=4, and
// the shader does explicit byte+bit-position extraction.
TensorShape x_shape;
if (bits_ == 2 && is_uint8) {
TensorShapeVector v = x_shape_intrinsic.AsShapeVector();
v.back() *= 4;
x_shape = TensorShape(std::move(v));

Check warning on line 234 in onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc:234: Add #include <utility> for move [build/include_what_you_use] [4]

Check warning on line 234 in onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc:234: Add #include <utility> for move [build/include_what_you_use] [4]
} else {
x_shape = x_shape_intrinsic;
}

size_t indices_rank = indices->Shape().NumDimensions();
const auto scales_shape = scales->Shape();
Expand All @@ -195,6 +255,17 @@
int64_t output_size = output_shape.Size();
auto* output_tensor = context.Output(0, output_shape);

// For the 2-bit zero-point path we need to address the packed byte using the scale row index
// and the within-row quantize-axis index (not the flat scales offset, which crosses row
// boundaries when scale_qaxis_dim isn't a multiple of the packing factor). To keep the shader
// simple we require quantize_axis to be the last dim for uint8 2-bit, matching the CPU kernel.
if (bits_ == 2 && is_uint8) {
ORT_RETURN_IF_NOT(quantize_axis == x_rank - 1,
"For uint8 2-bit data, quantize_axis must be the last dimension.");
}
const uint32_t scale_qaxis_dim = static_cast<uint32_t>(scales_shape[quantize_axis]);
const uint32_t zp_packed_qaxis_dim = (scale_qaxis_dim + 3) / 4;

GatherBlockQuantizedProgram program{is_signed, is_int8, indices_rank, gather_axis, bits_, zero_points != nullptr, x_shape, output_shape};

program
Expand All @@ -208,12 +279,27 @@
.AddUniformVariables({{static_cast<uint32_t>(quantize_axis)}})
.AddUniformVariables({{static_cast<uint32_t>(gather_axis)}})
.AddUniformVariables({{static_cast<uint32_t>(block_size_)}})
.CacheHint(std::to_string(gather_axis), std::to_string(quantize_axis), std::to_string(block_size_));
.AddUniformVariables({{scale_qaxis_dim}})
.AddUniformVariables({{zp_packed_qaxis_dim}})
.CacheHint(std::to_string(bits_), std::to_string(gather_axis), std::to_string(quantize_axis), std::to_string(block_size_));

if (zero_points != nullptr) {
ORT_RETURN_IF_NOT(scales_shape == zero_points->Shape(),
"scales and zero_points must have the same shape.");
auto zero_points_shape = zero_points->Shape();
if (bits_ == 2 && is_uint8) {
// 2-bit zero points are packed 4 per byte along the quantize axis.
const auto& zp_shape = zero_points->Shape();
ORT_RETURN_IF_NOT(zp_shape.NumDimensions() == scales_shape.NumDimensions(),
"scales and zero_points must have the same rank.");
for (size_t i = 0; i < scales_shape.NumDimensions(); ++i) {
int64_t expected = (i == static_cast<size_t>(quantize_axis))
? (scales_shape[i] + 3) / 4
: scales_shape[i];
ORT_RETURN_IF_NOT(zp_shape[i] == expected,
"zero_points shape does not match expected packed shape for 2-bit data.");
}
} else {
ORT_RETURN_IF_NOT(scales_shape == zero_points->Shape(),
"scales and zero_points must have the same shape.");
}
program.AddInputs({{zero_points, ProgramTensorMetadataDependency::None, ProgramInput::Flatten, (bits_ == 4) ? 8 : 4}});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ class GatherBlockQuantizedProgram final : public Program<GatherBlockQuantizedPro
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
{"quantize_axis", ProgramUniformVariableDataType::Uint32},
{"gather_axis", ProgramUniformVariableDataType::Uint32},
{"block_size", ProgramUniformVariableDataType::Uint32});
{"block_size", ProgramUniformVariableDataType::Uint32},
{"scale_qaxis_dim", ProgramUniformVariableDataType::Uint32},
{"zp_packed_qaxis_dim", ProgramUniformVariableDataType::Uint32});

private:
bool is_signed_;
Expand All @@ -52,11 +54,10 @@ class GatherBlockQuantized final : public WebGpuKernel {
quantize_axis_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("quantize_axis", 1));
bits_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("bits", 4));

ORT_ENFORCE(bits_ == 4 || bits_ == 8, "'bits' must be 4 or 8.");
ORT_ENFORCE(bits_ == 2 || bits_ == 4 || bits_ == 8, "'bits' must be 2, 4 or 8.");
Comment thread
HectorSVC marked this conversation as resolved.
ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0,
"'block_size' must be 2's power and not less than 16.");
}

Status ComputeInternal(ComputeContext& context) const override;

private:
Expand Down
9 changes: 5 additions & 4 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3692,7 +3692,8 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
to dequantize the output.
4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
5. For uint8 data, the `gather_axis` must be 0.
5. For uint8 data, the `gather_axis` must be 0. The supported `bits` values for uint8 data are 2, 4, and 8;
for `bits` < 8 the values are packed along the last dimension (low-order bits first).
)DOC";

ONNX_CONTRIB_OPERATOR_SCHEMA(GatherBlockQuantized)
Expand All @@ -3712,7 +3713,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
AttributeProto::INT,
static_cast<int64_t>(128))
.Attr("bits",
"Number of bits used for weight quantization. Must be either 4 or 8. ",
"Number of bits used for weight quantization. Must be 2, 4 or 8. ",
AttributeProto::INT,
static_cast<int64_t>(4))
.Input(0, "data", "Tensor of rank r >= 1. Block-wise quantized.", "T1")
Expand Down Expand Up @@ -3799,9 +3800,9 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
if (!zp_shape.dim(i).has_dim_value() ||
zp_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value()) {
if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 &&
bits == 4 &&
components > 1 &&
i == quantize_axis &&
zp_shape.dim(i).dim_value() == (scales_shape.dim(i).dim_value() + 1) / 2) {
zp_shape.dim(i).dim_value() == (scales_shape.dim(i).dim_value() + components - 1) / components) {
continue;
}
fail_shape_inference("zero points shape and scales shape do not match");
Expand Down
Loading
Loading