diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index d3c125dbc..ad28e8870 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -544,6 +544,19 @@ static void init_row_q4x4x2(block_q4_0 * x, int64_t k) { } } + +static inline void unpack_q4_1_quants(uint8_t * y, const block_q4_1 * x, int b) { + for (int i = 0; i < QK4_1 / 2; i++) { + y[b * QK4_1 / 2 + i + 000] = x->qs[i] & 0x0F; + y[b * QK4_1 / 2 + i + 128] = x->qs[i] >> 4; + } +} +static inline void pack_q4_1_quants(block_q4_1 * y, const uint8_t * x, int b) { + for (int i = 0; i < QK4_1 / 2; i++) { + y->qs[i] = (x[b * QK4_1 / 2 + i + 000] & 0x0F) | (x[b * QK4_1 / 2 + i + 128] << 4); + } +} + // repack q4_0 data into q4x4x2 tensor static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) { int64_t nrows = ggml_nrows(t); @@ -605,6 +618,237 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) ggml_aligned_free(buf_rp, row_size_rp); } + +static void repack_row_q4x4x2_q4_1(uint8_t * y, const block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_1x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers + + const int dblk_size = 8 * 2; // 8x __fp16 for d + const int mblk_size = 8 * 2; // 8x __fp16 for m + const int qblk_size = qk / 2; // int4 + const int qrow_size = k / 2; // int4 (not padded to blocks) + const int drow_size = (k + qk - 1) / qk * (qk / 32 * 2); // padded drow_size + + uint8_t * y_q = y + 0; // quants first + uint8_t * y_d = y + qrow_size; // then scales + uint8_t * y_m = y + qrow_size + drow_size; // then mins + + // Repack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_1x4x2]; // unpacked quants + unpack_q4_1_quants(qs, &x[i * 8 + 0], 0); + unpack_q4_1_quants(qs, &x[i * 8 + 1], 1); + unpack_q4_1_quants(qs, &x[i * 8 + 2], 2); + unpack_q4_1_quants(qs, &x[i * 8 + 3], 3); + unpack_q4_1_quants(qs, &x[i * 8 + 4], 4); + unpack_q4_1_quants(qs, &x[i * 8 + 5], 5); + unpack_q4_1_quants(qs, &x[i * 8 + 6], 6); + unpack_q4_1_quants(qs, &x[i * 8 + 7], 7); + + bool partial = (nloe && i == nb-1); + + uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; + } + } + + // Repack the scales and minimums + for (int i = 0; i < nb; i++) { + // Repack the scales + ggml_half * d = (ggml_half *) (y_d + i * dblk_size); + d[0] = x[i * 8 + 0].d; + d[1] = x[i * 8 + 1].d; + d[2] = x[i * 8 + 2].d; + d[3] = x[i * 8 + 3].d; + d[4] = x[i * 8 + 4].d; + d[5] = x[i * 8 + 5].d; + d[6] = x[i * 8 + 6].d; + d[7] = x[i * 8 + 7].d; + + // Repack the minimums + ggml_half * m = (ggml_half *) (y_m + i * mblk_size); + m[0] = x[i * 8 + 0].m; + m[1] = x[i * 8 + 1].m; + m[2] = x[i * 8 + 2].m; + m[3] = x[i * 8 + 3].m; + m[4] = x[i * 8 + 4].m; + m[5] = x[i * 8 + 5].m; + m[6] = x[i * 8 + 6].m; + m[7] = x[i * 8 + 7].m; + } +} + +static void unpack_row_q4x4x2_q4_1(block_q4_1 * y, const uint8_t * x, int64_t k) { + static const int qk = QK_Q4_1x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers + + const int dblk_size = 8 * 2; // 8x __fp16 for d + const int mblk_size = 8 * 2; // 8x __fp16 for m + const int qblk_size = qk / 2; // int4 + const int qrow_size = k / 2; // int4 (not padded to blocks) + const int drow_size = (k + qk - 1) / qk * (qk / 32 * 2); // padded drow_size + + const uint8_t * x_q = x + 0; // quants first + const uint8_t * x_d = x + qrow_size; // then scales + const uint8_t * x_m = x + qrow_size + drow_size; // then mins + + // Unpack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_1x4x2]; // unpacked quants + + bool partial = (nloe && i == nb-1); + + const uint8_t * q = x_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + if (partial) { + qs[j*2+0] = q[j] & 0x0F; + qs[j*2+1] = q[j] >> 4; + } else { + qs[j+000] = q[j] & 0x0F; + qs[j+128] = q[j] >> 4; + } + } + + pack_q4_1_quants(&y[i * 8 + 0], qs, 0); + pack_q4_1_quants(&y[i * 8 + 1], qs, 1); + pack_q4_1_quants(&y[i * 8 + 2], qs, 2); + pack_q4_1_quants(&y[i * 8 + 3], qs, 3); + pack_q4_1_quants(&y[i * 8 + 4], qs, 4); + pack_q4_1_quants(&y[i * 8 + 5], qs, 5); + pack_q4_1_quants(&y[i * 8 + 6], qs, 6); + pack_q4_1_quants(&y[i * 8 + 7], qs, 7); + } + + // Unpack the scales and minimums + for (int i = 0; i < nb; i++) { + // Unpack the scales + const ggml_half * d = (const ggml_half *) (x_d + i * dblk_size); + y[i * 8 + 0].d = d[0]; + y[i * 8 + 1].d = d[1]; + y[i * 8 + 2].d = d[2]; + y[i * 8 + 3].d = d[3]; + y[i * 8 + 4].d = d[4]; + y[i * 8 + 5].d = d[5]; + y[i * 8 + 6].d = d[6]; + y[i * 8 + 7].d = d[7]; + + // Unpack the minimums + const ggml_half * m = (const ggml_half *) (x_m + i * mblk_size); + y[i * 8 + 0].m = m[0]; + y[i * 8 + 1].m = m[1]; + y[i * 8 + 2].m = m[2]; + y[i * 8 + 3].m = m[3]; + y[i * 8 + 4].m = m[4]; + y[i * 8 + 5].m = m[5]; + y[i * 8 + 6].m = m[6]; + y[i * 8 + 7].m = m[7]; + } +} + +static inline void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) { + const int nb = (k + QK_Q4_1x4x2 - 1) / QK_Q4_1x4x2; + memset(x, 0, nb * QK_Q4_1x4x2 / 2 + nb * 8 * 2 * 2); +} + +static void repack_q4_1_q4x4x2(ggml_tensor * t, const void * data, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_1x4x2)); // extra elements for the pad + size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4_1-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + memcpy(buf_pd, src, row_size); + repack_row_q4x4x2_q4_1((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + memcpy(buf_pd, src, n_rem_bytes); + repack_row_q4x4x2_q4_1((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +static void repack_q4x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_1x4x2)); // extra elements for the pad + size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_1 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + memset(buf_pd, 0, row_size_pd); + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + memcpy(buf_pd, src, row_size); + unpack_row_q4x4x2_q4_1((block_q4_1 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + memcpy(buf_pd, src, n_rem_bytes); + unpack_row_q4x4x2_q4_1((block_q4_1 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + + // repack q4x4x2 tensor into q4_0 data static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) { int64_t nrows = ggml_nrows(t); @@ -1365,6 +1609,12 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, repack_q4_0_q4x4x2(tensor, data, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4_1_q4x4x2(tensor, data, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1407,6 +1657,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, repack_q4x4x2_q4_0(data, tensor, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4x4x2_q4_1(data, tensor, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -2327,6 +2583,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: @@ -2377,6 +2634,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: @@ -3558,6 +3816,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { // Basic sanity checks to make sure definitions match static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_Q4_1 == (unsigned int) GGML_TYPE_Q4_1, + "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0, "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 9e8c9966e..375dd5b9f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -62,6 +62,8 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) { case HTP_TYPE_Q4_0: case HTP_TYPE_IQ4_NL: return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb + case HTP_TYPE_Q4_1: + return (size_t) nb * (QK_Q4_1x4x2 / 2 + HMX_X4X2_DBLK_SIZE + HMX_X4X2_DBLK_SIZE); // 160 * nb case HTP_TYPE_Q8_0: return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb case HTP_TYPE_MXFP4: @@ -177,6 +179,53 @@ static int hmx_compute_chunks(size_t vtcm_total, // --- x4x2 format dequantizers --- + +static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx( + const uint8_t *packed_32, bool upper_nibbles, + const __fp16 *scale, const __fp16 *min, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_scales = hvx_vec_splat_f16(*scale); + HVX_Vector v_mins = hvx_vec_splat_f16(*min); + // q4x4x2 stores two int4 values per byte. Keep only the selected nibble. + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + // Shuffle before LUT + v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + HVX_Vector v_scaled = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); + return Q6_Vhf_vadd_VhfVhf(v_scaled, v_mins); +} + +static inline void dequantize_x4x2_q4_1_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_4, const __fp16 *mins_4, const HVX_Vector vlut_cvt, + HVX_Vector out[4]) { + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + // Shuffle before LUT + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + HVX_Vector v_scales = Q6_Vh_vdeal_Vh(Q6_V_vshuff_Vb(*(const HVX_UVector *) scales_4)); + HVX_Vector v_mins = Q6_Vh_vdeal_Vh(Q6_V_vshuff_Vb(*(const HVX_UVector *) mins_4)); + + HVX_VectorPair vp0 = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_VectorPair vp1 = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 1); + HVX_VectorPair vp2 = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 2); + HVX_VectorPair vp3 = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 3); + + out[0] = Q6_Vhf_vadd_VhfVhf(Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp0), v_scales)), v_mins); + out[1] = Q6_Vhf_vadd_VhfVhf(Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp0), v_scales)), v_mins); + out[2] = Q6_Vhf_vadd_VhfVhf(Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp1), v_scales)), v_mins); + out[3] = Q6_Vhf_vadd_VhfVhf(Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp1), v_scales)), v_mins); +} + + // Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. // In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles // of the same 32 packed bytes. @@ -336,7 +385,7 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( int start_tile, int end_tile) { const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS; - const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); + const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL); const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block; const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : @@ -363,6 +412,8 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales + unsigned min_off = qrow_size + (k_block / 16) + blk_idx * HMX_X4X2_DBLK_SIZE + + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive mins __fp16 *tile_bases[4]; for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } @@ -375,14 +426,22 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { HVX_Vector v0[2]; const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); + if (weight_type == HTP_TYPE_Q4_1) { + dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), (const __fp16 *)(r0 + min_off), vlut_cvt, v0); + } else { + dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); + } Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); r0 = vtcm_src + row_offset; row_offset += row_stride; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); + if (weight_type == HTP_TYPE_Q4_1) { + dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), (const __fp16 *)(r0 + min_off), vlut_cvt, v0); + } else { + dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); + } Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); @@ -452,6 +511,7 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( bool upper = (sub_blk >= 4); unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + unsigned min_off = qrow_size + (k_block / 16) + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); HVX_Vector v_off = v_scat_base; // reset to column 0 unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; @@ -460,12 +520,22 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx( - r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector v1 = (row1 < n_cols) - ? dequantize_x4x2_q4_0_group_hvx( - r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) - : Q6_V_vzero(); + HVX_Vector v0, v1; + if (weight_type == HTP_TYPE_Q4_1) { + v0 = dequantize_x4x2_q4_1_group_hvx( + r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), (const __fp16 *)(r0 + min_off), vlut_cvt); + v1 = (row1 < n_cols) + ? dequantize_x4x2_q4_1_group_hvx( + r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), (const __fp16 *)(r1 + min_off), vlut_cvt) + : Q6_V_vzero(); + } else { + v0 = dequantize_x4x2_q4_0_group_hvx( + r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + v1 = (row1 < n_cols) + ? dequantize_x4x2_q4_0_group_hvx( + r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) + : Q6_V_vzero(); + } Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); @@ -1018,6 +1088,13 @@ static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct h dma_queue_push(ctx->dma[0], dma_make_ptr(dst, src + quant_off), dst_stride, src_stride, quant_width, n_rows); // 2D DMA: scales sub-range dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width, src + scale_off), dst_stride, src_stride, scale_width, n_rows); + + if (weight_type == HTP_TYPE_Q4_1) { + const size_t full_drow = k / 16; + const size_t min_off = full_qrow + full_drow + blk_start * scale_blk_size; + const size_t min_width = nb_sub * scale_blk_size; + dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width + scale_width, src + min_off), dst_stride, src_stride, min_width, n_rows); + } } TIMER_STOP(fetch); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 6203e3848..2241f3aac 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -20,7 +20,9 @@ enum htp_data_type { HTP_TYPE_F32 = 0, HTP_TYPE_F16 = 1, HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q4_1 = 3, HTP_TYPE_Q8_0 = 8, + HTP_TYPE_Q8_1 = 9, HTP_TYPE_IQ4_NL = 20, HTP_TYPE_I32 = 26, HTP_TYPE_I64 = 27, @@ -28,7 +30,9 @@ enum htp_data_type { // types used internally for repack, dyn.quant, etc HTP_TYPE_Q4_0x4x2 = 200, + HTP_TYPE_Q4_1x4x2 = 202, HTP_TYPE_Q8_0x4x2, + HTP_TYPE_Q8_1x4x2, HTP_TYPE_MXFP4x4x2, HTP_TYPE_INVALID @@ -36,7 +40,9 @@ enum htp_data_type { // Constats for internal types #define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) +#define QK_Q4_1x4x2 256 // 4x Q4_1 blocks packed with next 4x Q4_1 blocks (size in bytes 160) #define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks +#define QK_Q8_1x4x2 256 // 4x Q8_1 blocks concat with next 4x Q8_1 blocks #define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 2461ae617..0f419bf94 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -999,6 +999,330 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } + +static void vec_dot_q4_1x4x2_q8_1x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + const uint32_t qk = QK_Q4_1x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; + const uint32_t x_mblk_size = 8 * 4 * 2; + const uint32_t x_qblk_size = qk / 2; + const uint32_t x_qrow_size = n / 2; + const uint32_t x_drow_size = n / 32 * 2; + + const uint32_t y_dblk_size = 8 * 4 * 2; + const uint32_t y_sblk_size = 8 * 4 * 2; + const uint32_t y_qblk_size = qk; + const uint32_t y_qrow_size = n; + const uint32_t y_drow_size = n / 32 * 2; + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); + const uint8_t * restrict r0_x_m = ((const uint8_t *) vx0 + x_qrow_size + x_drow_size); + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); + const uint8_t * restrict y_s = ((const uint8_t *) vy0 + y_qrow_size + y_drow_size); + + HVX_Vector r0_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_vq = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_vq = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_vq, vy_vq)); + + HVX_Vector vy_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_vs = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_s + i * y_sblk_size)); + + HVX_Vector r0_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r0_vm = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_mblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vd, vy_vd))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vm, vy_vs))); + + HVX_Vector r0_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd), r0_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_vq = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_vq = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_vq, vy_vq)); + + HVX_Vector vy_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_vs = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_s + i * y_sblk_size)); + + HVX_Vector r0_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r0_vm = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_mblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vd, vy_vd))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vm, vy_vs))); + + HVX_Vector r0_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd), r0_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + HVX_Vector s0_vec = hvx_vec_reduce_sum_f32(r0_sum); + hvx_vec_store_u(s0, 4, s0_vec); +} + +static void vec_dot_q4_1x4x2_q8_1x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const uint32_t qk = QK_Q4_1x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; + const uint32_t x_mblk_size = 8 * 4 * 2; + const uint32_t x_qblk_size = qk / 2; + const uint32_t x_qrow_size = n / 2; + const uint32_t x_drow_size = n / 32 * 2; + + const uint32_t y_dblk_size = 8 * 4 * 2; + const uint32_t y_sblk_size = 8 * 4 * 2; + const uint32_t y_qblk_size = qk; + const uint32_t y_qrow_size = n; + const uint32_t y_drow_size = n / 32 * 2; + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); + const uint8_t * restrict r0_x_m = ((const uint8_t *) vx0 + x_qrow_size + x_drow_size); + + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1 + 0); + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1 + x_qrow_size); + const uint8_t * restrict r1_x_m = ((const uint8_t *) vx1 + x_qrow_size + x_drow_size); + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); + const uint8_t * restrict y_s = ((const uint8_t *) vy0 + y_qrow_size + y_drow_size); + + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_vq = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_vq = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_vq = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_vq, vy_vq)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_vq, vy_vq)); + + HVX_Vector vy_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_vs = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_s + i * y_sblk_size)); + + HVX_Vector r0_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r0_vm = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_mblk_size)); + + HVX_Vector r1_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r1_vm = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_m + i * x_mblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vd, vy_vd))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vm, vy_vs))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_vd, vy_vd))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_vm, vy_vs))); + + HVX_Vector r0_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd), r0_ms); + HVX_Vector r1_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd), r1_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_vq = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_vq = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_vq = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_vq, vy_vq)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_vq, vy_vq)); + + HVX_Vector vy_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_vs = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_s + i * y_sblk_size)); + + HVX_Vector r0_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r0_vm = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_mblk_size)); + + HVX_Vector r1_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r1_vm = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_m + i * x_mblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vd, vy_vd))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vm, vy_vs))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_vd, vy_vd))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_vm, vy_vs))); + + HVX_Vector r0_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd), r0_ms); + HVX_Vector r1_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd), r1_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector s0_vec = hvx_vec_reduce_sum_f32(r0_sum); + HVX_Vector s1_vec = hvx_vec_reduce_sum_f32(r1_sum); + hvx_vec_store_u(&s0[0], 4, s0_vec); + hvx_vec_store_u(&s0[1], 4, s1_vec); +} + +static void vec_dot_q4_1x4x2_q8_1x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const uint32_t qk = QK_Q4_1x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; + const uint32_t x_mblk_size = 8 * 4 * 2; + const uint32_t x_qblk_size = qk / 2; + const uint32_t x_qrow_size = n / 2; + const uint32_t x_drow_size = n / 32 * 2; + + const uint32_t y_dblk_size = 8 * 4 * 2; + const uint32_t y_sblk_size = 8 * 4 * 2; + const uint32_t y_qblk_size = qk; + const uint32_t y_qrow_size = n; + const uint32_t y_drow_size = n / 32 * 2; + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); + const uint8_t * restrict r0_x_m = ((const uint8_t *) vx0 + x_qrow_size + x_drow_size); + + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1 + 0); + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1 + x_qrow_size); + const uint8_t * restrict r1_x_m = ((const uint8_t *) vx1 + x_qrow_size + x_drow_size); + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0 + 0); + const uint8_t * restrict y0_d = ((const uint8_t *) vy0 + y_qrow_size); + const uint8_t * restrict y0_s = ((const uint8_t *) vy0 + y_qrow_size + y_drow_size); + + const uint8_t * restrict y1_q = ((const uint8_t *) vy1 + 0); + const uint8_t * restrict y1_d = ((const uint8_t *) vy1 + y_qrow_size); + const uint8_t * restrict y1_s = ((const uint8_t *) vy1 + y_qrow_size + y_drow_size); + + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy0_vq = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_vq = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + + HVX_Vector_x8 r0_vq = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_vq = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_vq, vy0_vq)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_vq, vy1_vq)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_vq, vy0_vq)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_vq, vy1_vq)); + + HVX_Vector vy0_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy0_vs = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_s + i * y_sblk_size)); + + HVX_Vector vy1_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector vy1_vs = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_s + i * y_sblk_size)); + + HVX_Vector r0_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r0_vm = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_mblk_size)); + + HVX_Vector r1_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r1_vm = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_m + i * x_mblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vd, vy0_vd))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vm, vy0_vs))); + + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vd, vy1_vd))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vm, vy1_vs))); + + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_vd, vy0_vd))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_vm, vy0_vs))); + + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_vd, vy1_vd))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_vm, vy1_vs))); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd), r0_c0_ms); + HVX_Vector r0_c1_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd), r0_c1_ms); + HVX_Vector r1_c0_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd), r1_c0_ms); + HVX_Vector r1_c1_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd), r1_c1_ms); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy0_vq = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_vq = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + + HVX_Vector_x8 r0_vq = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_vq = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_vq, vy0_vq)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_vq, vy1_vq)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_vq, vy0_vq)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_vq, vy1_vq)); + + HVX_Vector vy0_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy0_vs = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_s + i * y_sblk_size)); + + HVX_Vector vy1_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector vy1_vs = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_s + i * y_sblk_size)); + + HVX_Vector r0_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r0_vm = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_mblk_size)); + + HVX_Vector r1_vd = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r1_vm = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_m + i * x_mblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vd, vy0_vd))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vm, vy0_vs))); + + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vd, vy1_vd))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_vm, vy1_vs))); + + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_vd, vy0_vd))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_vm, vy0_vs))); + + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_vd, vy1_vd))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_vm, vy1_vs))); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd), r0_c0_ms); + HVX_Vector r0_c1_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd), r0_c1_ms); + HVX_Vector r1_c0_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd), r1_c0_ms); + HVX_Vector r1_c1_fa = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd), r1_c1_ms); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + HVX_Vector s0_c0_vec = hvx_vec_reduce_sum_f32(r0_c0_sum); + HVX_Vector s1_c0_vec = hvx_vec_reduce_sum_f32(r1_c0_sum); + HVX_Vector s0_c1_vec = hvx_vec_reduce_sum_f32(r0_c1_sum); + HVX_Vector s1_c1_vec = hvx_vec_reduce_sum_f32(r1_c1_sum); + + hvx_vec_store_u(&s0[0], 4, s0_c0_vec); + hvx_vec_store_u(&s0[1], 4, s1_c0_vec); + hvx_vec_store_u(&s1[0], 4, s0_c1_vec); + hvx_vec_store_u(&s1[1], 4, s1_c1_vec); +} + + // ======== IQ4_NL x Q8_0 vec_dot kernels ======== // Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue). // Scale format is identical to Q4_0 (fp16 scales). @@ -2574,6 +2898,160 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric } // Overrides input x + +static inline void quantize_block_f32_q8_1x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d, uint8_t * restrict y_s) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_q % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + // Use reduce max fp32 to find max(abs(e)) first + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + // Load and convert into QF32 + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + // Convert to QF32 + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + // Combine and convert to fp16 + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + hvx_vec_store_u(y_d + 0, 2, vd01_hf); + HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64); + hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf); + + hvx_vec_store_u(y_d + 4, 2, vd23_hf); + rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64); + hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf); + + // Divide input by the scale + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + // Convert to int8 + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + *(HVX_Vector *) y_q = vx_i8; + + // Sum calculation for Q8_1 + HVX_Vector ones = Q6_Vb_vsplat_R(1); + HVX_Vector sum4 = Q6_Vw_vrmpy_VbVb(vx_i8, ones); + + HVX_Vector sum_a = Q6_Vw_vadd_VwVw(sum4, Q6_V_vror_VR(sum4, 4)); + HVX_Vector sum_b = Q6_Vw_vadd_VwVw(sum_a, Q6_V_vror_VR(sum_a, 8)); + HVX_Vector sum_c = Q6_Vw_vadd_VwVw(sum_b, Q6_V_vror_VR(sum_b, 16)); + + int32_t s0 = Q6_R_vextract_VR(sum_c, 0); + int32_t s1 = Q6_R_vextract_VR(sum_c, 32); + int32_t s2 = Q6_R_vextract_VR(sum_c, 64); + int32_t s3 = Q6_R_vextract_VR(sum_c, 96); + + // Reconstruct d + float d0 = (float)*(__fp16*)(y_d + 0); + float d1 = (float)*(__fp16*)(y_d + 2); + float d2 = (float)*(__fp16*)(y_d + 4); + float d3 = (float)*(__fp16*)(y_d + 6); + + __fp16 hs0 = (__fp16)(d0 * s0); + __fp16 hs1 = (__fp16)(d1 * s1); + __fp16 hs2 = (__fp16)(d2 * s2); + __fp16 hs3 = (__fp16)(d3 * s3); + + hvx_vec_store_u(y_s + 0, 2, Q6_Vh_vsplat_R(*(int*)&hs0)); + hvx_vec_store_u(y_s + 2, 2, Q6_Vh_vsplat_R(*(int*)&hs1)); + hvx_vec_store_u(y_s + 4, 2, Q6_Vh_vsplat_R(*(int*)&hs2)); + hvx_vec_store_u(y_s + 6, 2, Q6_Vh_vsplat_R(*(int*)&hs3)); +} + +static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { + const uint32_t qk = QK_Q8_1x4x2; + const uint32_t nb = k / qk; + + const uint32_t dblk_size = 8 * 2; // fp16 + const uint32_t qblk_size = qk; // int8 + const uint32_t sblk_size = 8 * 2; // fp16 + + const uint32_t qrow_size = k; + const uint32_t drow_size = k / 32 * 2; + const uint32_t srow_size = k / 32 * 2; + + uint8_t * restrict y_q = y + 0; + uint8_t * restrict y_d = y + qrow_size; + uint8_t * restrict y_s = y + qrow_size + drow_size; + + uint8_t t_d[256]; + uint8_t t_s[256]; + + for (uint32_t i = 0; i < nb; i++) { + quantize_block_f32_q8_1x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2, t_s + (i*2 + 0) * sblk_size/2); + quantize_block_f32_q8_1x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2, t_s + (i*2 + 1) * sblk_size/2); + } + + hvx_copy_f16_ua(y_d, t_d, nb * 8); + hvx_copy_f16_ua(y_s, t_s, nb * 8); +} + +static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = (struct htp_matmul_context *)data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = (uint8_t *)octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + + uint32_t nb01 = src->nb[1]; + uint32_t ne0 = src->ne[0]; + uint32_t ne1 = src->ne[1]; + + uint32_t r_start = ith * nrows_per_thread; + uint32_t r_end = r_start + nrows_per_thread; + if (r_end > ne1) r_end = ne1; + if (r_start >= r_end) return; + + size_t src1_row_size = ne0 + ne0 / 32 * 2 + ne0 / 32 * 2; + size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_1x4x2 * sizeof(float)); + size_t src_row_size_padded = hex_round_up(ne0 * sizeof(float), QK_Q8_1x4x2 * sizeof(float)); + + uint8_t * tmp_data = (uint8_t *)spad->data + ith * spad->size_per_thread; + + for (uint32_t r = r_start; r < r_end; r++) { + uint8_t * src_data = (uint8_t *)src->data + r * nb01; + uint8_t * dst_data = dst + r * src1_row_size_padded; + + memcpy(tmp_data, src_data, ne0 * sizeof(float)); + if (src_row_size_padded > ne0 * sizeof(float)) { + memset(tmp_data + ne0 * sizeof(float), 0, src_row_size_padded - ne0 * sizeof(float)); + } + + quantize_row_f32_q8_1x4x2((float *) tmp_data, dst_data, ne0); + } +} + static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { assert(k % 32 == 0); const uint32_t qk = QK_Q8_0x4x2; @@ -2752,6 +3230,12 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; return 0; + case HTP_TYPE_Q4_1: + mmctx->type = "q4_1x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8_1x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8_1x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8_1x4x2_2x2; + return 0; case HTP_TYPE_Q8_0: mmctx->type = "q8x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; @@ -2894,8 +3378,13 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = ne10 + ne10 / 32 * 2 + ne10 / 32 * 2; // q8_1x4x2 size + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); } @@ -3100,8 +3589,13 @@ int op_matmul_id(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = ne10 + ne10 / 32 * 2 + ne10 / 32 * 2; + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);