From 6b09fcdf77ff1e8dea671740000388d0adb2e5a8 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Tue, 11 Feb 2025 13:34:49 +0000 Subject: [PATCH 1/5] Rename and Optimize QSI4CXP RHS pack micro-kernel * Rename kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon to kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon to represent signed and unsigned input support * Optimize the kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon micro-kernel using Advanced SIMD Resolves: #COMPMID-7829 Signed-off-by: Anitha Raj --- CHANGELOG.md | 2 + CMakeLists.txt | 2 +- kai/ukernels/matmul/BUILD.bazel | 2 +- ...8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h | 12 +- ..._qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h | 12 +- ..._rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c | 250 ++++++++++++++++++ ...rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.h} | 18 +- ...rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c | 165 ------------ .../matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp | 68 +++-- 9 files changed, 305 insertions(+), 226 deletions(-) create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c rename kai/ukernels/matmul/pack/{kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h => kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.h} (86%) delete mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c diff --git a/CHANGELOG.md b/CHANGELOG.md index 006b7bbd..5740f417 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - New 4x8 block size variant of matrix multiplication of QAI8DXP LHS and QSI4C32P RHS with F32 output. - Optimizations for FEAT_DotProd. - Added demonstration of integration using CMake in F16 Arm® Neon™ matrix multiplication example. +- Rename kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon to kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon to represent signed and unsigned input support +- Optimize the kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon micro-kernel using Advanced SIMD. 
- Fixes: - Fix the RHS packing micro-kernel kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon to handle null bias. diff --git a/CMakeLists.txt b/CMakeLists.txt index 0df28b83..486748c4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -122,7 +122,7 @@ set(KLEIDIAI_FILES_NEON kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla_asm.S kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c - kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c ) set(KLEIDIAI_FILES_NEON_DOTPROD_ASM diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index ad4769a3..fa781400 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -36,7 +36,7 @@ NEON_KERNELS = [ "pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", "pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon", "pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0", - "pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon", + "pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon", "pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon", ] diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h index e4523092..20337d7f 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // 
SPDX-License-Identifier: Apache-2.0 // @@ -15,7 +15,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon to pack the RHS matrix +/// -# kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon to pack the RHS matrix /// -------------------------------------------------- @@ -40,19 +40,19 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mo size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void); /// Gets the nr value, which must be used to pack the RHS matrix with -/// the @ref kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon micro-kernel +/// the @ref kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon micro-kernel /// /// @return the nr value size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void); /// Gets the kr value, which must be used to pack the RHS matrix with -/// the @ref kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon micro-kernel +/// the @ref kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon micro-kernel /// /// @return the kr value size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void); /// Gets the sr value, which must be used to pack the RHS matrix with -/// the @ref kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon micro-kernel +/// the @ref kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon micro-kernel /// /// @return the sr value size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void); @@ -108,7 +108,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_ /// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs /// both the dynamic quantization to 8-bit and activation packing in a single step. 
/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref -/// kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon +/// kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h index 15e13801..07676b7a 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -15,7 +15,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon to pack the RHS matrix +/// -# kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon to pack the RHS matrix /// -------------------------------------------------- @@ -40,19 +40,19 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(v size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void); /// Gets the nr value, which must be used to pack the RHS matrix with -/// the @ref kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon micro-kernel +/// the @ref kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon micro-kernel /// /// @return the nr value size_t 
kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void); /// Gets the kr value, which must be used to pack the RHS matrix with -/// the @ref kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon micro-kernel +/// the @ref kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon micro-kernel /// /// @return the kr value size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void); /// Gets the sr value, which must be used to pack the RHS matrix with -/// the @ref kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon micro-kernel +/// the @ref kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon micro-kernel /// /// @return the sr value size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void); @@ -108,7 +108,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot /// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs /// both the dynamic quantization to 8-bit and activation packing in a single step. /// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref -/// kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon +/// kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c new file mode 100644 index 00000000..9f2af08c --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c @@ -0,0 +1,250 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check. 
+ +#include "kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.h" + +#include <arm_neon.h> +#include <stdint.h> +#include <string.h> + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_k_roundedup(size_t k) { + // Round up k to be a multiple of 32. + size_t kai_k_multiple_of = 32; + return kai_roundup(k, kai_k_multiple_of); +} + +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon(size_t nr) { + return nr; +} + +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon(size_t n_idx, size_t rhs_stride) { + return n_idx * rhs_stride; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon(size_t k, size_t nr, size_t kr, size_t sr) { + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + const size_t k_internal = kai_k_roundedup(k); + + // multiple of 2 because 2 elements in a byte + KAI_ASSERT((k_internal % 2) == 0); + + return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon( + size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { + KAI_ASSERT((n_idx % nr) == 0); + + return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon(k, nr, kr, sr); +} + +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon( + size_t n, size_t k, size_t nr, size_t kr, size_t sr) { + const size_t num_rows = kai_roundup(n, nr) / nr; + + return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon(k, nr, kr, sr); +} + +void kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon_params* params) 
{ + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->lhs_zero_point == 1); + KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); + + KAI_ASSERT(kr == 4); + KAI_ASSERT(nr % 4 == 0); + + const uint8_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (k_internal / 2); + const size_t block_length_in_bytes = + kr * sizeof(uint8_t) / 2; // Dividing by 2, as we have 2 int4 elements in 1 byte. + const size_t dst_nr_block_size = nr * block_length_in_bytes; + const size_t rhs_stride = kai_roundup(k, 2) / 2; + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + int8_t* dst_row = (int8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + dst_num_bytes_per_row); + float* const scaling_factors = (float*)((uint8_t*)sums + (nr * kai_num_bytes_sum_rhs)); + // Update destination row pointer + float* const biases = (float*)((uint8_t*)scaling_factors + (nr * kai_num_bytes_multiplier_rhs)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + size_t row_idx = dst_row_idx * nr; + size_t rows_left = n - row_idx; + // Saving scales. 
+ if (rows_left >= nr) { + memcpy(scaling_factors, &scale[row_idx], nr * kai_num_bytes_multiplier_rhs); + } else { + // Fill remaining values + memcpy(scaling_factors, &scale[row_idx], rows_left * kai_num_bytes_multiplier_rhs); + // Set leftover to 0 + memset(&scaling_factors[rows_left], 0, (nr - rows_left) * kai_num_bytes_multiplier_rhs); + } + if (bias == NULL) { + // Set bias to 0 + memset(biases, 0, nr * kai_num_bytes_bias); + } else { + if (rows_left >= nr) { + memcpy(biases, &bias[row_idx], nr * kai_num_bytes_bias); + } else { + // Fill remaining values + memcpy(biases, &bias[row_idx], rows_left * kai_num_bytes_bias); + // Set leftover to 0 + memset(&biases[rows_left], 0, (nr - rows_left) * kai_num_bytes_bias); + } + } + size_t nr_idx = 0; + if ((n % nr == 0) && (k % 16 == 0)) { + // 4 rows at a time + for (; nr_idx <= nr - 4; nr_idx += 4) { + // Each iteration processes 16 elements across k + for (size_t k_idx = 0; k_idx <= k - 16; k_idx += 16) { + const size_t n0_idx = dst_row_idx * nr + nr_idx; + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + const size_t src_addr_byte = (k_idx / 2) + n0_valid_idx * rhs_stride; + + // Load 16 kr values for each nr_idx + uint8x8_t vec0_u8 = vld1_u8(rhs + src_addr_byte); + uint8x8_t vec1_u8 = vld1_u8(rhs + src_addr_byte + 1 * rhs_stride); + uint8x8_t vec2_u8 = vld1_u8(rhs + src_addr_byte + 2 * rhs_stride); + uint8x8_t vec3_u8 = vld1_u8(rhs + src_addr_byte + 3 * rhs_stride); + + uint16x4_t vec0_u16 = vreinterpret_u16_u8(vec0_u8); + uint16x4_t vec1_u16 = vreinterpret_u16_u8(vec1_u8); + uint16x4_t vec2_u16 = vreinterpret_u16_u8(vec2_u8); + uint16x4_t vec3_u16 = vreinterpret_u16_u8(vec3_u8); + + uint16x4_t vec01_lo_u16 = vzip1_u16(vec0_u16, vec1_u16); + uint16x4_t vec01_hi_u16 = vzip2_u16(vec0_u16, vec1_u16); + uint16x4_t vec23_lo_u16 = vzip1_u16(vec2_u16, vec3_u16); + uint16x4_t vec23_hi_u16 = vzip2_u16(vec2_u16, vec3_u16); + + uint32x2_t vec01_lo_u32 = vreinterpret_u32_u16(vec01_lo_u16); + uint32x2_t vec01_hi_u32 = 
vreinterpret_u32_u16(vec01_hi_u16); + uint32x2_t vec23_lo_u32 = vreinterpret_u32_u16(vec23_lo_u16); + uint32x2_t vec23_hi_u32 = vreinterpret_u32_u16(vec23_hi_u16); + + uint32x2_t vin0_u32 = vzip1_u32(vec01_lo_u32, vec23_lo_u32); + uint32x2_t vin1_u32 = vzip2_u32(vec01_lo_u32, vec23_lo_u32); + uint32x2_t vin2_u32 = vzip1_u32(vec01_hi_u32, vec23_hi_u32); + uint32x2_t vin3_u32 = vzip2_u32(vec01_hi_u32, vec23_hi_u32); + + uint8x8_t vin0_u8 = vreinterpret_u8_u32(vin0_u32); + uint8x8_t vin1_u8 = vreinterpret_u8_u32(vin1_u32); + uint8x8_t vin2_u8 = vreinterpret_u8_u32(vin2_u32); + uint8x8_t vin3_u8 = vreinterpret_u8_u32(vin3_u32); + + int16x8_t vin0_s16 = vreinterpretq_s16_u16(vmovl_u8(vin0_u8)); + int16x8_t vin1_s16 = vreinterpretq_s16_u16(vmovl_u8(vin1_u8)); + int16x8_t vin2_s16 = vreinterpretq_s16_u16(vmovl_u8(vin2_u8)); + int16x8_t vin3_s16 = vreinterpretq_s16_u16(vmovl_u8(vin3_u8)); + + int16x8x4_t src_arr_s16 = {{vin0_s16, vin1_s16, vin2_s16, vin3_s16}}; + + for (int i = 0; i < 4; i++) { + int16x8_t vsrc_s16 = src_arr_s16.val[i]; + int16x8_t vlo_s16 = vandq_s16(vsrc_s16, vdupq_n_s16(0x0F)); + int16x8_t vhi_s16 = vshrq_n_s16(vsrc_s16, 4); + int32_t rhs_zero_point = params->rhs_zero_point; + + int32x4_t vzp_s32 = vdupq_n_s32(rhs_zero_point); + int32x4_t vlo_s32_1 = vsubq_s32(vmovl_s16(vget_low_s16(vlo_s16)), vzp_s32); + int32x4_t vlo_s32_2 = vsubq_s32(vmovl_s16(vget_high_s16(vlo_s16)), vzp_s32); + int32x4_t vhi_s32_1 = vsubq_s32(vmovl_s16(vget_low_s16(vhi_s16)), vzp_s32); + int32x4_t vhi_s32_2 = vsubq_s32(vmovl_s16(vget_high_s16(vhi_s16)), vzp_s32); + + int32x4_t vsum_s32_1 = vaddq_s32(vlo_s32_1, vhi_s32_1); + int32x4_t vsum_s32_2 = vaddq_s32(vlo_s32_2, vhi_s32_2); + + sums[nr_idx] += vaddv_s32(vget_low_s32(vsum_s32_1)); + sums[nr_idx + 1] += vaddv_s32(vget_high_s32(vsum_s32_1)); + sums[nr_idx + 2] += vaddv_s32(vget_low_s32(vsum_s32_2)); + sums[nr_idx + 3] += vaddv_s32(vget_high_s32(vsum_s32_2)); + + int16x8_t vlo_1_s16 = vcombine_s16(vmovn_s32(vlo_s32_1), 
vmovn_s32(vlo_s32_2)); + int16x8_t vhi_1_s16 = vcombine_s16(vmovn_s32(vhi_s32_1), vmovn_s32(vhi_s32_2)); + + int16x8_t vdst_s16 = + vorrq_s16(vandq_s16(vlo_1_s16, vdupq_n_s16(0x0F)), vshlq_n_s16(vhi_1_s16, 4)); + uint8x8_t vdst_u8 = vreinterpret_u8_s8(vmovn_s16(vdst_s16)); + + size_t offset = ((k_idx + i * 4) / kr) * (nr * kr / 2) + nr_idx * 2; + uint8_t* dst_row_offset = dst_row + offset; + vst1_u8(dst_row_offset, vdst_u8); + } + } + } + dst_row += dst_num_bytes_per_row; + } + // Leftover + // Iterate over rows in the nr row block + for (; nr_idx < nr; ++nr_idx) { + const uint8_t* const src_row = rhs + ((row_idx + nr_idx) * rhs_stride); + // Go to the first kr block for this row in the nr block + int8_t* dst_kr_block = dst_row + (nr_idx * kr / 2); + + int32_t sum = 0; + + // Iterate over k src columns in blocks of kr columns + for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) { + // Iterate over columns in the kr block + // Kr checked to be multiple of 2 (because 2 values per byte) + for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) { + // We pad dst with 0s if the rounded k or n values have been exceeded + if (row_idx + nr_idx >= n || col_idx + kr_block_idx >= k) { + dst_kr_block[kr_block_idx / 2] = 0; + continue; + } + + // Load the 2 u4 values from source + const uint8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2]; + + // extract i8 values from the 2 u4 values + const int32_t first_value = (dst_byte & 0xF) - rhs_zero_point; + const int32_t second_value = col_idx + kr_block_idx + 1 >= k ? 
0 : (dst_byte >> 4) - rhs_zero_point; + + // Add the i4 value to the row sum + sum += first_value + second_value; + + // Truncate i8 to i4 and write to dst + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + dst_kr_block[kr_block_idx / 2] = (second_value << 4) | (first_value & 0xF); + } + + // Go to the next kr block for this row in the nr rows + dst_kr_block += dst_nr_block_size; + } + + // save sum + sums[nr_idx] = sum; + } + } +} +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.h similarity index 86% rename from kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h rename to kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.h index 6e94913a..187ea3b6 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.h @@ -14,8 +14,8 @@ extern "C" { #endif -#ifndef kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon_params -#define kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon_params kai_rhs_pack_qs4cxs1s0_param +#ifndef kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon_params +#define kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon_params kai_rhs_pack_qs4cxs1s0_param #endif /// Get the n step value. 
@@ -26,7 +26,7 @@ extern "C" { /// @param[in] nr The number of columns written by the matmul micro-kernel /// /// @return the n step value -size_t kai_get_n_step_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t nr); +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon(size_t nr); /// Gets the offset in bytes for the RHS matrix (not packed) /// @@ -38,7 +38,7 @@ size_t kai_get_n_step_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t nr); /// @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) /// /// @return the offset in bytes to the RHS matrix (not packed) -size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t n_idx, size_t rhs_stride); +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon(size_t n_idx, size_t rhs_stride); /// Get the row stride in bytes to the packed RHS matrix /// @@ -48,7 +48,7 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t n_idx, /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// /// @return the stride in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t k, size_t nr, size_t kr, size_t sr); +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon(size_t k, size_t nr, size_t kr, size_t sr); /// Gets the offset in bytes for the packed RHS matrix, /// which contains the packed 4-bit quantized symmetric per-channel (qsi4cx) values. @@ -60,7 +60,7 @@ size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. 
/// /// @return the offset in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr); /// @brief Gets the size in bytes for the packed RHS matrix @@ -72,7 +72,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// /// @return the packed RHS matrix size in bytes -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon( size_t n, size_t k, size_t nr, size_t kr, size_t sr); /// Run the micro-kernel to pack the RHS matrix. @@ -94,10 +94,10 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( /// @param[out] rhs_packed The packed RHS matrix. /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. /// @param[in] params Parameters for the micro-kernel. 
-void kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( +void kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, const float* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon_params* params); + const struct kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon_params* params); #ifdef __cplusplus } diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c deleted file mode 100644 index d0c66276..00000000 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c +++ /dev/null @@ -1,165 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#if !defined(__aarch64__) -#error This file must be compiled for AArch64. -#else // Architectural features check. - -#include "kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h" - -#include <stdint.h> -#include <string.h> - -#include "kai/kai_common.h" - -static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); -static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); -static const size_t kai_num_bytes_bias = sizeof(float); - -inline static size_t kai_k_roundedup(size_t k) { - // Round up k to be a multiple of 32. 
- size_t kai_k_multiple_of = 32; - return kai_roundup(k, kai_k_multiple_of); -} - -size_t kai_get_n_step_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t nr) { - return nr; -} - -size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t n_idx, size_t rhs_stride) { - return n_idx * rhs_stride; -} - -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t k, size_t nr, size_t kr, size_t sr) { - KAI_UNUSED(kr); - KAI_UNUSED(sr); - - const size_t k_internal = kai_k_roundedup(k); - - // multiple of 2 because 2 elements in a byte - KAI_ASSERT((k_internal % 2) == 0); - - return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); -} - -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( - size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { - KAI_ASSERT((n_idx % nr) == 0); - - return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr); -} - -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( - size_t n, size_t k, size_t nr, size_t kr, size_t sr) { - const size_t num_rows = kai_roundup(n, nr) / nr; - - return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr); -} - -void kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, - const float* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon_params* params) { - const size_t k_internal = kai_k_roundedup(k); - - KAI_ASSERT((k_internal % kr) == 0); - KAI_ASSERT(num_groups == 1); - KAI_ASSERT(extra_bytes == 0); - KAI_ASSERT((kr % sr) == 0); - KAI_ASSERT(rhs != NULL); - KAI_ASSERT(scale != NULL); - KAI_ASSERT(rhs_packed != NULL); - KAI_ASSERT(params != NULL); - KAI_ASSERT(params->lhs_zero_point == 1); - KAI_ASSERT(params->rhs_zero_point == 0 || 
params->rhs_zero_point == 8); - - // Note: The input matrix (rhs) is expected with: - // "k" columns and "n" rows (NxK) - - const int32_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_stride = kai_roundup(k, 2) / 2; - const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr); - const size_t dst_nr_block_size = nr * kr * sizeof(uint8_t) / 2; - - // Iterate over n src rows in blocks of nr rows - for (size_t row_idx = 0; row_idx < n; row_idx += nr) { - int8_t* const dst_row = (int8_t*)rhs_packed + ((row_idx / nr) * rhs_packed_stride); - - int32_t* const sums = (int32_t*)(dst_row + (nr * (k_internal / 2))); - float* const scaling_factors = (float*)((uint8_t*)sums + (nr * kai_num_bytes_sum_rhs)); - // Update destination row pointer - float* const biases = (float*)((uint8_t*)scaling_factors + (nr * kai_num_bytes_multiplier_rhs)); - - // initialize sums to 0 - memset(sums, 0, nr * kai_num_bytes_sum_rhs); - - // Copy the scaling factors and bias - size_t rows_left = n - row_idx; - // Saving scales. 
- if (rows_left >= nr) { - memcpy(scaling_factors, &scale[row_idx], nr * kai_num_bytes_multiplier_rhs); - } else { - // Fill remaining values - memcpy(scaling_factors, &scale[row_idx], rows_left * kai_num_bytes_multiplier_rhs); - // Set leftover to 0 - memset(&scaling_factors[rows_left], 0, (nr - rows_left) * kai_num_bytes_multiplier_rhs); - } - if (bias == NULL) { - // Set bias to 0 - memset(biases, 0, nr * kai_num_bytes_bias); - } else { - if (rows_left >= nr) { - memcpy(biases, &bias[row_idx], nr * kai_num_bytes_bias); - } else { - // Fill remaining values - memcpy(biases, &bias[row_idx], rows_left * kai_num_bytes_bias); - // Set leftover to 0 - memset(&biases[rows_left], 0, (nr - rows_left) * kai_num_bytes_bias); - } - } - // Iterate over rows in the nr row block - for (size_t nr_block_idx = 0; nr_block_idx < nr; ++nr_block_idx) { - const uint8_t* const src_row = rhs + ((row_idx + nr_block_idx) * rhs_stride); - // Go to the first kr block for this row in the nr block - int8_t* dst_kr_block = dst_row + (nr_block_idx * kr / 2); - - int32_t sum = 0; - - // Iterate over k src columns in blocks of kr columns - for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) { - // Iterate over columns in the kr block - // Kr checked to be multiple of 2 (because 2 values per byte) - for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) { - // We pad dst with 0s if the rounded k or n values have been exceeded - if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) { - dst_kr_block[kr_block_idx / 2] = 0; - continue; - } - - // Load the 2 u4 values from source - const uint8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2]; - - // extract i8 values from the 2 u4 values - const int32_t first_value = (dst_byte & 0xF) - rhs_zero_point; - const int32_t second_value = col_idx + kr_block_idx + 1 >= k ? 
0 : (dst_byte >> 4) - rhs_zero_point; - - // Add the i4 value to the row sum - sum += first_value + second_value; - - // Truncate i8 to i4 and write to dst - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - dst_kr_block[kr_block_idx / 2] = (second_value << 4) | (first_value & 0xF); - } - - // Go to the next kr block for this row in the nr rows - dst_kr_block += dst_nr_block_size; - } - - // save sum - sums[nr_block_idx] = sum; - } - } -} -#endif // Architectural features check. diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp index c0018dd6..0e0e5bcd 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp @@ -32,7 +32,7 @@ #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.h" #include "test/common/cpu_info.hpp" #include "test/common/int4.hpp" #include "test/common/matrix_portion.hpp" @@ -56,7 +56,7 @@ enum class RhsPackType { NxK, KxN }; using ukernel_rhs_pack_function = std::function; using ukernel_get_rhs_packed_size = std::function; using ukernel_get_rhs_packed_offset = std::function; -using ukernel_get_rhs_offset = std::function; +using ukernel_get_rhs_offset = std::function; template struct UkernelVariantCustom : public UkernelVariant { @@ -65,7 +65,6 @@ struct UkernelVariantCustom : public UkernelVariant { ukernel_get_rhs_packed_offset get_rhs_packed_offset; ukernel_get_rhs_offset get_rhs_offset; RhsPackType rhs_pack_type; - bool signed_integer_support; UkernelVariantCustom() = delete; @@ -73,14 +72,13 @@ struct UkernelVariantCustom : public UkernelVariant 
{ T interface, std::string_view name, const std::function& fn_is_supported, ukernel_rhs_pack_function run_rhs_pack, ukernel_get_rhs_packed_size get_rhs_packed_size, ukernel_get_rhs_packed_offset get_rhs_packed_offset, ukernel_get_rhs_offset get_rhs_offset, - const RhsPackType pack_type, const bool signed_integer_support) : + const RhsPackType pack_type) : UkernelVariant(interface, name, fn_is_supported), run_rhs_pack(std::move(run_rhs_pack)), get_rhs_packed_size(std::move(get_rhs_packed_size)), get_rhs_packed_offset(std::move(get_rhs_packed_offset)), get_rhs_offset(std::move(get_rhs_offset)), - rhs_pack_type(pack_type), - signed_integer_support(signed_integer_support) { + rhs_pack_type(pack_type) { } }; @@ -88,118 +86,118 @@ static const std::array Date: Tue, 11 Feb 2025 16:43:22 +0000 Subject: [PATCH 2/5] Fix warnings Signed-off-by: Anitha Raj --- .../pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c index 9f2af08c..b515202e 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c @@ -193,11 +193,11 @@ void kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon( int16x8_t vdst_s16 = vorrq_s16(vandq_s16(vlo_1_s16, vdupq_n_s16(0x0F)), vshlq_n_s16(vhi_1_s16, 4)); - uint8x8_t vdst_u8 = vreinterpret_u8_s8(vmovn_s16(vdst_s16)); + int8x8_t vdst_s8 = vmovn_s16(vdst_s16); size_t offset = ((k_idx + i * 4) / kr) * (nr * kr / 2) + nr_idx * 2; - uint8_t* dst_row_offset = dst_row + offset; - vst1_u8(dst_row_offset, vdst_u8); + int8_t* dst_row_offset = dst_row + offset; + vst1_s8(dst_row_offset, vdst_s8); } } } -- GitLab From 6f71d0cc60ca0fcebf7ece303f35b9f78dee573a Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Tue, 11 Feb 2025 16:57:50 +0000 Subject: [PATCH 3/5] Fix 
Clang tidy warnings Signed-off-by: Anitha Raj --- .../matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c index b515202e..5807dfa3 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c @@ -168,7 +168,7 @@ void kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon( int16x8x4_t src_arr_s16 = {{vin0_s16, vin1_s16, vin2_s16, vin3_s16}}; - for (int i = 0; i < 4; i++) { + for (size_t i = 0; i < 4; i++) { int16x8_t vsrc_s16 = src_arr_s16.val[i]; int16x8_t vlo_s16 = vandq_s16(vsrc_s16, vdupq_n_s16(0x0F)); int16x8_t vhi_s16 = vshrq_n_s16(vsrc_s16, 4); -- GitLab From 69f1b988954d4fcc03fafcd86252210022c925f4 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Tue, 11 Feb 2025 17:19:05 +0000 Subject: [PATCH 4/5] Clamp the src index to avoid out of bound reads Signed-off-by: Anitha Raj --- .../pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c index 5807dfa3..0939636b 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon.c @@ -206,7 +206,11 @@ void kai_run_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon( // Leftover // Iterate over rows in the nr row block for (; nr_idx < nr; ++nr_idx) { - const uint8_t* const src_row = rhs + ((row_idx + nr_idx) * rhs_stride); + // Clamp the index to avoid out-of-bound reads + const size_t n_idx = row_idx + nr_idx; + const size_t n_valid_idx = KAI_MIN(n_idx, n - 1); + const uint8_t* const src_row = rhs + (n_valid_idx * rhs_stride); 
+ // Go to the first kr block for this row in the nr block int8_t* dst_kr_block = dst_row + (nr_idx * kr / 2); -- GitLab From 9db59ac576e1bdec413e803cdf4d7d7ea239a73c Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Wed, 12 Feb 2025 13:18:35 +0000 Subject: [PATCH 5/5] Update Changelog Signed-off-by: Anitha Raj --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5740f417..21068140 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - New 4x8 block size variant of matrix multiplication of QAI8DXP LHS and QSI4C32P RHS with F32 output. - Optimizations for FEAT_DotProd. - Added demonstration of integration using CMake in F16 Arm® Neon™ matrix multiplication example. -- Rename kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon to kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon to represent signed and unsigned input support +- Rename kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon to kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon to represent signed and unsigned input support - Optimize the kai_rhs_pack_nxk_qsi4cxps1s0_qs4cxs1s0_neon micro-kernel using Advanced SIMD. - Fixes: - Fix the RHS packing micro-kernel kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon to handle null bias. -- GitLab