From d621b840da5cb1975f92331268a8f7587bf8f847 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Thu, 26 Jun 2025 10:55:01 +0100 Subject: [PATCH 1/2] Update kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon with vectorized row summation Signed-off-by: Evie Wright --- CHANGELOG.md | 1 + ...s_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c | 73 +++++-------------- 2 files changed, 20 insertions(+), 54 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d03bb041..3e38e874 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - New Advanced SIMD micro-kernels: - Optimized version of kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 kernel for block depth of 4 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon`) +- Update `kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon` with vectorized row summation ## v1.10.0 diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c index 3cb7f111..8a798c45 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c @@ -205,7 +205,7 @@ void kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( const float d2 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 2]); const float d3 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 3]); - // Take zero-point (-8) into account + // Initialize partial sum taking new zero-point (8) into account int32_t partial_sum0 = -(32 * 8); int32_t partial_sum1 = -(32 * 8); int32_t partial_sum2 = -(32 * 8); @@ -213,58 +213,14 @@ void kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( const uint8_t* src_block_base = rhs + ((k0_idx_i / 2) + dst_byte_idx); - // Load elements as uint64_ts to calculate sums more efficiently - uint64_t ld0_0 = *(const uint64_t*)(src_block_base + n0_idx * rhs_stride); - uint64_t ld0_1 = *(const uint64_t*)(src_block_base + n0_idx * rhs_stride + 8); - - uint64_t ld1_0 = *(const uint64_t*)(src_block_base + n1_idx * rhs_stride); - uint64_t ld1_1 = *(const uint64_t*)(src_block_base + n1_idx * rhs_stride + 8); - - uint64_t ld2_0 = *(const uint64_t*)(src_block_base + n2_idx * rhs_stride); - uint64_t ld2_1 = *(const uint64_t*)(src_block_base + n2_idx * rhs_stride + 8); - - uint64_t ld3_0 = *(const uint64_t*)(src_block_base + n3_idx * rhs_stride); - uint64_t ld3_1 = *(const uint64_t*)(src_block_base + n3_idx * rhs_stride + 8); - - // Copy to vector registers - const uint8x8_t vld0_0 = vcreate_u8(ld0_0); - const uint8x8_t vld0_1 = vcreate_u8(ld0_1); - - const uint8x8_t vld1_0 = vcreate_u8(ld1_0); - const uint8x8_t vld1_1 = vcreate_u8(ld1_1); - - const uint8x8_t vld2_0 = vcreate_u8(ld2_0); - const uint8x8_t vld2_1 = vcreate_u8(ld2_1); - - const uint8x8_t vld3_0 = vcreate_u8(ld3_0); - const uint8x8_t vld3_1 = vcreate_u8(ld3_1); - - // Calculate sums - for (size_t idx = 0; idx < 16; ++idx) { - const int32_t e0_0 = (int32_t)(ld0_0 & 0x0F); - const int32_t e0_1 = (int32_t)(ld0_1 & 0x0F); - partial_sum0 += e0_0 + e0_1; - ld0_0 = ld0_0 >> 4; - ld0_1 = ld0_1 >> 4; - - const int32_t e1_0 = (int32_t)(ld1_0 & 0x0F); - const int32_t e1_1 = (int32_t)(ld1_1 & 0x0F); - partial_sum1 += e1_0 + e1_1; - ld1_0 = ld1_0 >> 4; - ld1_1 = ld1_1 >> 4; - - const int32_t e2_0 = (int32_t)(ld2_0 & 0x0F); - const int32_t e2_1 = (int32_t)(ld2_1 & 0x0F); - partial_sum2 += e2_0 + e2_1; - ld2_0 = ld2_0 >> 4; - ld2_1 = ld2_1 >> 4; - - const int32_t e3_0 = (int32_t)(ld3_0 & 0x0F); - const int32_t e3_1 = (int32_t)(ld3_1 & 0x0F); - partial_sum3 += e3_0 + e3_1; - ld3_0 = ld3_0 >> 4; - ld3_1 = ld3_1 >> 4; - } + const uint8x8_t vld0_0 = vld1_u8(src_block_base + n0_idx * rhs_stride); + const uint8x8_t vld0_1 = vld1_u8(src_block_base + n0_idx * rhs_stride + 8); + const uint8x8_t vld1_0 = vld1_u8(src_block_base + n1_idx * rhs_stride); + const uint8x8_t vld1_1 = vld1_u8(src_block_base + n1_idx * rhs_stride + 8); + const uint8x8_t vld2_0 = vld1_u8(src_block_base + n2_idx * rhs_stride); + const uint8x8_t vld2_1 = vld1_u8(src_block_base + n2_idx * rhs_stride + 8); + const uint8x8_t vld3_0 = vld1_u8(src_block_base + n3_idx * rhs_stride); + const uint8x8_t vld3_1 = vld1_u8(src_block_base + n3_idx * rhs_stride + 8); const uint8x8_t vld0_s1s = vand_u8(vld0_0, bottom_mask); const uint8x8_t vld0_s0s = vshr_n_u8(vld0_0, 4); @@ -325,7 +281,16 @@ void kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( (uint8_t*)dst_row + (nr * block_length_in_bytes) + 24, veor_u8(vld3_s16s0s_upper, zero_point_conversion_mask)); - // Add to row sums + // Calculate and store row sums + partial_sum0 += (int32_t)vaddlvq_u16(vaddl_u8( + vadd_u8(vld0_s1s, vand_u8(vld0_1, bottom_mask)), vadd_u8(vld0_s0s, vshr_n_u8(vld0_1, 4)))); + partial_sum1 += (int32_t)vaddlvq_u16(vaddl_u8( + vadd_u8(vld1_s1s, vand_u8(vld1_1, bottom_mask)), vadd_u8(vld1_s0s, vshr_n_u8(vld1_1, 4)))); + partial_sum2 += (int32_t)vaddlvq_u16(vaddl_u8( + vadd_u8(vld2_s1s, vand_u8(vld2_1, bottom_mask)), vadd_u8(vld2_s0s, vshr_n_u8(vld2_1, 4)))); + partial_sum3 += (int32_t)vaddlvq_u16(vaddl_u8( + vadd_u8(vld3_s1s, vand_u8(vld3_1, bottom_mask)), vadd_u8(vld3_s0s, vshr_n_u8(vld3_1, 4)))); + // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) sums[nr_idx + 0] += (float)partial_sum0 * d0; sums[nr_idx + 1] += (float)partial_sum1 * d1; -- GitLab From c675ae741e9b9186ae04dd2eceb91055b0d39bdd Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Thu, 26 Jun 2025 14:44:50 +0100 Subject: [PATCH 2/2] update changelog summary Signed-off-by: Evie Wright --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e38e874..a4b90d29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - New Advanced SIMD micro-kernels: - Optimized version of kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 kernel for block depth of 4 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon`) -- Update `kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon` with vectorized row summation +- Improve performance of `kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon` ## v1.10.0 -- GitLab