From 537bff952665ed36d8d8cce4ed4ef70c793c394d Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Wed, 7 May 2025 15:57:46 +0100 Subject: [PATCH 1/5] Optimize the transposed RHS packing function for matmul_clamp_f32_qai8dxp_qsi4c32p using Neon Signed-off-by: Evie Wright --- .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c | 163 ++++++++++++------ 1 file changed, 109 insertions(+), 54 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c index 9af7841d..51f6cce1 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c @@ -153,6 +153,10 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( uint8_t* dst_row = (uint8_t*)rhs_packed; + const uint8x8_t top_mask = vdup_n_u8(0xF0); + const uint8x8_t bottom_mask = vdup_n_u8(0x0F); + const uint8x8_t zero_point_conversion_mask = vdup_n_u8(0x88); + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; dst_row_idx += nr) { float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks); @@ -178,84 +182,135 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t k0_idx_i = dst_qblock_idx * bl; - for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) { - for (size_t segment_idx = 0; segment_idx < 16 / block_length_in_bytes; ++segment_idx) { + if (block_length_in_bytes == 8) { + for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) { for (size_t nr_idx = 0; nr_idx < nr; ++nr_idx) { const size_t n0_idx = dst_row_idx + nr_idx; - // Two int4 values are stored in one byte. - // The lower order part of the byte (low) holds the first nibble (K-index + 0). - // The higher order of the byte holds the second nibble (K-index + 16). - size_t k0_idx = k0_idx_i; - size_t k1_idx = k0_idx_i + 16; - // Clamp the index to avoid out-of-bound reads const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); float d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); - int32_t partial_sum = 0; - size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; - - for (size_t block_byte_idx = 0; block_byte_idx < block_length_in_bytes; block_byte_idx += 2) { - // Initialize the byte with the zero-point (8) - // e.g. 
uint8_t byte0 = 8 | 8 << 4 - uint8_t byte0 = 136; - uint8_t byte1 = 136; - uint8_t byte2 = 136; - uint8_t byte3 = 136; - - if (k0_idx < k) { - byte0 = rhs[src_addr_byte0]; - } - - if (k1_idx < k) { - byte1 = rhs[src_addr_byte0 + 8]; - } - - if (k0_idx + 1 < k) { - byte2 = byte0; - } + size_t src_block_addr = ((k0_idx_i / 2) + dst_byte_idx) + n0_valid_idx * rhs_stride; - if (k1_idx + 1 < k) { - byte3 = byte1; - } + uint64_t ld0 = *(const uint64_t*)(rhs + src_block_addr); + uint64_t ld1 = *(const uint64_t*)(rhs + src_block_addr + 8); - k0_idx += 2; - k1_idx += 2; + const uint8x8_t vld0 = vcreate_u8(ld0); + const uint8x8_t vld1 = vcreate_u8(ld1); - const uint8_t src_x0_lo = byte0 & 0x0F; - const uint8_t src_x0_hi = byte1 & 0x0F; - const uint8_t src_x1_lo = (byte2 >> 4) & 0x0F; - const uint8_t src_x1_hi = (byte3 >> 4) & 0x0F; + // Calculate row sum + for (size_t idx = 0; idx < 16; ++idx) { + const int32_t e0 = (int32_t)(ld0 & 0x0F); + const int32_t e1 = (int32_t)(ld1 & 0x0F); + partial_sum += e0 + e1; + ld0 = ld0 >> 4; + ld1 = ld1 >> 4; + } + partial_sum -= (32 * 8); - partial_sum += (int32_t)src_x0_lo; - partial_sum += (int32_t)src_x0_hi; - partial_sum += (int32_t)src_x1_lo; - partial_sum += (int32_t)src_x1_hi; - partial_sum -= 32; // 4 * zero_point (8) + const uint8x8_t s1s = vand_u8(vld0, bottom_mask); + const uint8x8_t s0s = vshr_n_u8(vld0, 4); + const uint8x8_t s17s = vshl_n_u8(vld1, 4); + const uint8x8_t s16s = vand_u8(vld1, top_mask); - const uint16_t dst_q = - ((src_x0_lo)) | ((src_x0_hi) << 4) | ((src_x1_lo) << 8) | ((src_x1_hi) << 12); + const uint8x8_t s16s0s_lower = vorr_u8(vzip1_u8(s1s, s0s), vzip1_u8(s17s, s16s)); + const uint8x8_t s16s0s_upper = vorr_u8(vzip2_u8(s1s, s0s), vzip2_u8(s17s, s16s)); - *((uint16_t*)dst_row) = dst_q ^ 0x8888; + vst1_u8((uint8_t*)dst_row, veor_s8(s16s0s_lower, zero_point_conversion_mask)); + vst1_u8( + (uint8_t*)dst_row + (block_length_in_bytes * nr), + veor_s8(s16s0s_upper, zero_point_conversion_mask)); - dst_row += 2; - src_addr_byte0 += 1; - } // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) sums[nr_idx] += (float)partial_sum * d; // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - } - k0_idx_i += block_length_in_bytes; + dst_row += block_length_in_bytes; + } + // skip to end of qblock + dst_row += nr * block_length_in_bytes; + } + } else { + for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) { + for (size_t segment_idx = 0; segment_idx < 16 / block_length_in_bytes; ++segment_idx) { + for (size_t nr_idx = 0; nr_idx < nr; ++nr_idx) { + const size_t n0_idx = dst_row_idx + nr_idx; + + // Two int4 values are stored in one byte. + // The lower order part of the byte (low) holds the first nibble (K-index + 0). + // The higher order of the byte holds the second nibble (K-index + 16). + size_t k0_idx = k0_idx_i; + size_t k1_idx = k0_idx_i + 16; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + float d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); + + int32_t partial_sum = 0; + + size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; + + for (size_t block_byte_idx = 0; block_byte_idx < block_length_in_bytes; + block_byte_idx += 2) { + // Initialize the byte with the zero-point (8) + // e.g. 
uint8_t byte0 = 8 | 8 << 4 + uint8_t byte0 = 136; + uint8_t byte1 = 136; + uint8_t byte2 = 136; + uint8_t byte3 = 136; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte0 + 8]; + } + + if (k0_idx + 1 < k) { + byte2 = byte0; + } + + if (k1_idx + 1 < k) { + byte3 = byte1; + } + + k0_idx += 2; + k1_idx += 2; + + const uint8_t src_x0_lo = byte0 & 0x0F; + const uint8_t src_x0_hi = byte1 & 0x0F; + const uint8_t src_x1_lo = (byte2 >> 4) & 0x0F; + const uint8_t src_x1_hi = (byte3 >> 4) & 0x0F; + + partial_sum += (int32_t)src_x0_lo; + partial_sum += (int32_t)src_x0_hi; + partial_sum += (int32_t)src_x1_lo; + partial_sum += (int32_t)src_x1_hi; + partial_sum -= 32; // 4 * zero_point (8) + + const uint16_t dst_q = + ((src_x0_lo)) | ((src_x0_hi) << 4) | ((src_x1_lo) << 8) | ((src_x1_hi) << 12); + + *((uint16_t*)dst_row) = dst_q ^ 0x8888; + + dst_row += 2; + src_addr_byte0 += 1; + } + // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + sums[nr_idx] += (float)partial_sum * d; + // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + } + k0_idx_i += block_length_in_bytes; + } + k0_idx_i += 16; } - k0_idx_i += 16; } // Move the pointer after scales dst_row += num_bytes_multiplier_rhs * nr; } - // Move the pointer after the row sum dst_row += kai_num_bytes_sum_rhs * nr; -- GitLab From c798e4bd387a7aeb09e7719a5883686f03a48034 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Wed, 7 May 2025 16:20:18 +0100 Subject: [PATCH 2/5] change eor intrinsic used to avoid implicit conversion Signed-off-by: Evie Wright --- .../matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c index 51f6cce1..d8e568e9 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c @@ -218,10 +218,10 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( const uint8x8_t s16s0s_lower = vorr_u8(vzip1_u8(s1s, s0s), vzip1_u8(s17s, s16s)); const uint8x8_t s16s0s_upper = vorr_u8(vzip2_u8(s1s, s0s), vzip2_u8(s17s, s16s)); - vst1_u8((uint8_t*)dst_row, veor_s8(s16s0s_lower, zero_point_conversion_mask)); + vst1_u8((uint8_t*)dst_row, veor_u8(s16s0s_lower, zero_point_conversion_mask)); vst1_u8( (uint8_t*)dst_row + (block_length_in_bytes * nr), - veor_s8(s16s0s_upper, zero_point_conversion_mask)); + veor_u8(s16s0s_upper, zero_point_conversion_mask)); // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) sums[nr_idx] += (float)partial_sum * d; -- GitLab From bf222c1eca08398042aeb54d12850e8c17b0bd29 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Thu, 8 May 2025 14:47:34 +0100 Subject: [PATCH 3/5] add extra macros to prevent MSVC build failure Signed-off-by: Evie Wright --- .../pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c index d8e568e9..bf07d98f 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c @@ -5,6 +5,10 @@ // #include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" +#if defined(__ARM_NEON) | defined(_M_ARM64) +#include 
+#endif // defined(__ARM_NEON) | defined(_M_ARM64) + #include #include #include @@ -182,6 +186,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t k0_idx_i = dst_qblock_idx * bl; +#if defined(__ARM_NEON) | defined(_M_ARM64) if (block_length_in_bytes == 8) { for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) { for (size_t nr_idx = 0; nr_idx < nr; ++nr_idx) { @@ -233,6 +238,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( dst_row += nr * block_length_in_bytes; } } else { +#endif // defined(__ARM_NEON) | defined(_M_ARM64) for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) { for (size_t segment_idx = 0; segment_idx < 16 / block_length_in_bytes; ++segment_idx) { for (size_t nr_idx = 0; nr_idx < nr; ++nr_idx) { @@ -307,8 +313,10 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( } k0_idx_i += 16; } +#if defined(__ARM_NEON) | defined(_M_ARM64) } - // Move the pointer after scales +#endif // defined(__ARM_NEON) | defined(_M_ARM64) + // Move the pointer after scales dst_row += num_bytes_multiplier_rhs * nr; } // Move the pointer after the row sum -- GitLab From 15225edfd259d11bab19fe0a5833741739620181 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Tue, 13 May 2025 15:22:05 +0100 Subject: [PATCH 4/5] adapt vectorized code to work with four blocks at a time Signed-off-by: Evie Wright --- .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c | 163 ++++++++++++++---- 1 file changed, 131 insertions(+), 32 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c index bf07d98f..a5cd3cab 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c @@ -189,52 +189,151 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( #if defined(__ARM_NEON) | defined(_M_ARM64) if (block_length_in_bytes == 8) { for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) { - for (size_t nr_idx = 0; nr_idx < nr; ++nr_idx) { - const size_t n0_idx = dst_row_idx + nr_idx; + for (size_t nr_idx = 0; nr_idx < nr; nr_idx += 4) { + // Clamp the indices to avoid out-of-bound reads + const size_t n0_idx = KAI_MIN(dst_row_idx + nr_idx, n - 1); + const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1); + const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1); + const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1); - // Clamp the index to avoid out-of-bound reads - const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); - float d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); - int32_t partial_sum = 0; + const float d0 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 0]); + const float d1 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 1]); + const float d2 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 2]); + const float d3 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 3]); - size_t src_block_addr = ((k0_idx_i / 2) + dst_byte_idx) + n0_valid_idx * rhs_stride; + int32_t partial_sum0 = 0; + int32_t partial_sum1 = 0; + int32_t partial_sum2 = 0; + int32_t partial_sum3 = 0; - uint64_t ld0 = *(const uint64_t*)(rhs + src_block_addr); - uint64_t ld1 = *(const uint64_t*)(rhs + src_block_addr + 8); + const size_t src_block_base_addr = ((k0_idx_i / 2) + dst_byte_idx); - const uint8x8_t vld0 = vcreate_u8(ld0); - const uint8x8_t vld1 = vcreate_u8(ld1); + uint64_t ld0_0 = *(const uint64_t*)(rhs + src_block_base_addr + 
n0_idx * rhs_stride); + uint64_t ld0_1 = *(const uint64_t*)(rhs + src_block_base_addr + n0_idx * rhs_stride + 8); - // Calculate row sum - for (size_t idx = 0; idx < 16; ++idx) { - const int32_t e0 = (int32_t)(ld0 & 0x0F); - const int32_t e1 = (int32_t)(ld1 & 0x0F); - partial_sum += e0 + e1; - ld0 = ld0 >> 4; - ld1 = ld1 >> 4; - } - partial_sum -= (32 * 8); + uint64_t ld1_0 = *(const uint64_t*)(rhs + src_block_base_addr + n1_idx * rhs_stride); + uint64_t ld1_1 = *(const uint64_t*)(rhs + src_block_base_addr + n1_idx * rhs_stride + 8); + + uint64_t ld2_0 = *(const uint64_t*)(rhs + src_block_base_addr + n2_idx * rhs_stride); + uint64_t ld2_1 = *(const uint64_t*)(rhs + src_block_base_addr + n2_idx * rhs_stride + 8); + + uint64_t ld3_0 = *(const uint64_t*)(rhs + src_block_base_addr + n3_idx * rhs_stride); + uint64_t ld3_1 = *(const uint64_t*)(rhs + src_block_base_addr + n3_idx * rhs_stride + 8); + + // Copy to vector registers + const uint8x8_t vld0_0 = vcreate_u8(ld0_0); + const uint8x8_t vld0_1 = vcreate_u8(ld0_1); + + const uint8x8_t vld1_0 = vcreate_u8(ld1_0); + const uint8x8_t vld1_1 = vcreate_u8(ld1_1); - const uint8x8_t s1s = vand_u8(vld0, bottom_mask); - const uint8x8_t s0s = vshr_n_u8(vld0, 4); - const uint8x8_t s17s = vshl_n_u8(vld1, 4); - const uint8x8_t s16s = vand_u8(vld1, top_mask); + const uint8x8_t vld2_0 = vcreate_u8(ld2_0); + const uint8x8_t vld2_1 = vcreate_u8(ld2_1); - const uint8x8_t s16s0s_lower = vorr_u8(vzip1_u8(s1s, s0s), vzip1_u8(s17s, s16s)); - const uint8x8_t s16s0s_upper = vorr_u8(vzip2_u8(s1s, s0s), vzip2_u8(s17s, s16s)); + const uint8x8_t vld3_0 = vcreate_u8(ld3_0); + const uint8x8_t vld3_1 = vcreate_u8(ld3_1); - vst1_u8((uint8_t*)dst_row, veor_u8(s16s0s_lower, zero_point_conversion_mask)); + // Calculate sums + for (size_t idx = 0; idx < 16; ++idx) { + const int32_t e0_0 = (int32_t)(ld0_0 & 0x0F); + const int32_t e0_1 = (int32_t)(ld0_1 & 0x0F); + partial_sum0 += e0_0 + e0_1; + ld0_0 = ld0_0 >> 4; + ld0_1 = ld0_1 >> 4; + + const int32_t e1_0 = (int32_t)(ld1_0 & 0x0F); + const int32_t e1_1 = (int32_t)(ld1_1 & 0x0F); + partial_sum1 += e1_0 + e1_1; + ld1_0 = ld1_0 >> 4; + ld1_1 = ld1_1 >> 4; + + const int32_t e2_0 = (int32_t)(ld2_0 & 0x0F); + const int32_t e2_1 = (int32_t)(ld2_1 & 0x0F); + partial_sum2 += e2_0 + e2_1; + ld2_0 = ld2_0 >> 4; + ld2_1 = ld2_1 >> 4; + + const int32_t e3_0 = (int32_t)(ld3_0 & 0x0F); + const int32_t e3_1 = (int32_t)(ld3_1 & 0x0F); + partial_sum3 += e3_0 + e3_1; + ld3_0 = ld3_0 >> 4; + ld3_1 = ld3_1 >> 4; + } + partial_sum0 -= (32 * 8); + partial_sum1 -= (32 * 8); + partial_sum2 -= (32 * 8); + partial_sum3 -= (32 * 8); + + const uint8x8_t vld0_s1s = vand_u8(vld0_0, bottom_mask); + const uint8x8_t vld0_s0s = vshr_n_u8(vld0_0, 4); + const uint8x8_t vld0_s17s = vshl_n_u8(vld0_1, 4); + const uint8x8_t vld0_s16s = vand_u8(vld0_1, top_mask); + + const uint8x8_t vld0_s16s0s_lower = + vorr_u8(vzip1_u8(vld0_s1s, vld0_s0s), vzip1_u8(vld0_s17s, vld0_s16s)); + const uint8x8_t vld0_s16s0s_upper = + vorr_u8(vzip2_u8(vld0_s1s, vld0_s0s), vzip2_u8(vld0_s17s, vld0_s16s)); + + const uint8x8_t vld1_s1s = vand_u8(vld1_0, bottom_mask); + const uint8x8_t vld1_s0s = vshr_n_u8(vld1_0, 4); + const uint8x8_t vld1_s17s = vshl_n_u8(vld1_1, 4); + const uint8x8_t vld1_s16s = vand_u8(vld1_1, top_mask); + + const uint8x8_t vld1_s16s0s_lower = + vorr_u8(vzip1_u8(vld1_s1s, vld1_s0s), vzip1_u8(vld1_s17s, vld1_s16s)); + const uint8x8_t vld1_s16s0s_upper = + vorr_u8(vzip2_u8(vld1_s1s, vld1_s0s), vzip2_u8(vld1_s17s, vld1_s16s)); + + const uint8x8_t vld2_s1s = vand_u8(vld2_0, 
bottom_mask); + const uint8x8_t vld2_s0s = vshr_n_u8(vld2_0, 4); + const uint8x8_t vld2_s17s = vshl_n_u8(vld2_1, 4); + const uint8x8_t vld2_s16s = vand_u8(vld2_1, top_mask); + + const uint8x8_t vld2_s16s0s_lower = + vorr_u8(vzip1_u8(vld2_s1s, vld2_s0s), vzip1_u8(vld2_s17s, vld2_s16s)); + const uint8x8_t vld2_s16s0s_upper = + vorr_u8(vzip2_u8(vld2_s1s, vld2_s0s), vzip2_u8(vld2_s17s, vld2_s16s)); + + const uint8x8_t vld3_s1s = vand_u8(vld3_0, bottom_mask); + const uint8x8_t vld3_s0s = vshr_n_u8(vld3_0, 4); + const uint8x8_t vld3_s17s = vshl_n_u8(vld3_1, 4); + const uint8x8_t vld3_s16s = vand_u8(vld3_1, top_mask); + + const uint8x8_t vld3_s16s0s_lower = + vorr_u8(vzip1_u8(vld3_s1s, vld3_s0s), vzip1_u8(vld3_s17s, vld3_s16s)); + const uint8x8_t vld3_s16s0s_upper = + vorr_u8(vzip2_u8(vld3_s1s, vld3_s0s), vzip2_u8(vld3_s17s, vld3_s16s)); + + // Store repacked values + vst1_u8((uint8_t*)dst_row, veor_u8(vld0_s16s0s_lower, zero_point_conversion_mask)); + vst1_u8((uint8_t*)dst_row + 8, veor_u8(vld1_s16s0s_lower, zero_point_conversion_mask)); + vst1_u8((uint8_t*)dst_row + 16, veor_u8(vld2_s16s0s_lower, zero_point_conversion_mask)); + vst1_u8((uint8_t*)dst_row + 24, veor_u8(vld3_s16s0s_lower, zero_point_conversion_mask)); + + vst1_u8( + (uint8_t*)dst_row + (nr * block_length_in_bytes), + veor_u8(vld0_s16s0s_upper, zero_point_conversion_mask)); + vst1_u8( + (uint8_t*)dst_row + (nr * block_length_in_bytes) + 8, + veor_u8(vld1_s16s0s_upper, zero_point_conversion_mask)); + vst1_u8( + (uint8_t*)dst_row + (nr * block_length_in_bytes) + 16, + veor_u8(vld2_s16s0s_upper, zero_point_conversion_mask)); vst1_u8( - (uint8_t*)dst_row + (block_length_in_bytes * nr), - veor_u8(s16s0s_upper, zero_point_conversion_mask)); + (uint8_t*)dst_row + (nr * block_length_in_bytes) + 24, + veor_u8(vld3_s16s0s_upper, zero_point_conversion_mask)); + // Add to row sums // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - sums[nr_idx] += (float)partial_sum * d; + sums[nr_idx + 0] += (float)partial_sum0 * d0; + sums[nr_idx + 1] += (float)partial_sum1 * d1; + sums[nr_idx + 2] += (float)partial_sum2 * d2; + sums[nr_idx + 3] += (float)partial_sum3 * d3; // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - dst_row += block_length_in_bytes; + dst_row += block_length_in_bytes * 4; } - // skip to end of qblock + // Skip to end of qblock dst_row += nr * block_length_in_bytes; } } else { -- GitLab From d0a91d056647f87549c0e110cf37ce056432532a Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Thu, 15 May 2025 17:50:15 +0100 Subject: [PATCH 5/5] address code review comments Signed-off-by: Evie Wright --- .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c index a5cd3cab..35c7bf17 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c @@ -157,10 +157,6 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( uint8_t* dst_row = (uint8_t*)rhs_packed; - const uint8x8_t top_mask = vdup_n_u8(0xF0); - const uint8x8_t bottom_mask = vdup_n_u8(0x0F); - const uint8x8_t zero_point_conversion_mask = vdup_n_u8(0x88); - for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; dst_row_idx += nr) { float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks); @@ -187,6 +183,10 @@ void 
kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t k0_idx_i = dst_qblock_idx * bl; #if defined(__ARM_NEON) | defined(_M_ARM64) + const uint8x8_t top_mask = vdup_n_u8(0xF0); + const uint8x8_t bottom_mask = vdup_n_u8(0x0F); + const uint8x8_t zero_point_conversion_mask = vdup_n_u8(0x88); + if (block_length_in_bytes == 8) { for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) { for (size_t nr_idx = 0; nr_idx < nr; nr_idx += 4) { @@ -201,24 +201,26 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( const float d2 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 2]); const float d3 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 3]); - int32_t partial_sum0 = 0; - int32_t partial_sum1 = 0; - int32_t partial_sum2 = 0; - int32_t partial_sum3 = 0; + // Take zero-point (-8) into account + int32_t partial_sum0 = -(32 * 8); + int32_t partial_sum1 = -(32 * 8); + int32_t partial_sum2 = -(32 * 8); + int32_t partial_sum3 = -(32 * 8); - const size_t src_block_base_addr = ((k0_idx_i / 2) + dst_byte_idx); + const uint8_t* src_block_base = rhs + ((k0_idx_i / 2) + dst_byte_idx); - uint64_t ld0_0 = *(const uint64_t*)(rhs + src_block_base_addr + n0_idx * rhs_stride); - uint64_t ld0_1 = *(const uint64_t*)(rhs + src_block_base_addr + n0_idx * rhs_stride + 8); + // Load elements as uint64_ts to calculate sums more efficiently + uint64_t ld0_0 = *(const uint64_t*)(src_block_base + n0_idx * rhs_stride); + uint64_t ld0_1 = *(const uint64_t*)(src_block_base + n0_idx * rhs_stride + 8); - uint64_t ld1_0 = *(const uint64_t*)(rhs + src_block_base_addr + n1_idx * rhs_stride); - uint64_t ld1_1 = *(const uint64_t*)(rhs + src_block_base_addr + n1_idx * rhs_stride + 8); + uint64_t ld1_0 = *(const uint64_t*)(src_block_base + n1_idx * rhs_stride); + uint64_t ld1_1 = *(const uint64_t*)(src_block_base + n1_idx * rhs_stride + 8); - uint64_t ld2_0 = *(const uint64_t*)(rhs + src_block_base_addr + n2_idx * rhs_stride); - uint64_t ld2_1 = *(const uint64_t*)(rhs + src_block_base_addr + n2_idx * rhs_stride + 8); + uint64_t ld2_0 = *(const uint64_t*)(src_block_base + n2_idx * rhs_stride); + uint64_t ld2_1 = *(const uint64_t*)(src_block_base + n2_idx * rhs_stride + 8); - uint64_t ld3_0 = *(const uint64_t*)(rhs + src_block_base_addr + n3_idx * rhs_stride); - uint64_t ld3_1 = *(const uint64_t*)(rhs + src_block_base_addr + n3_idx * rhs_stride + 8); + uint64_t ld3_0 = *(const uint64_t*)(src_block_base + n3_idx * rhs_stride); + uint64_t ld3_1 = *(const uint64_t*)(src_block_base + n3_idx * rhs_stride + 8); // Copy to vector registers const uint8x8_t vld0_0 = vcreate_u8(ld0_0); @@ -259,10 +261,6 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( ld3_0 = ld3_0 >> 4; ld3_1 = ld3_1 >> 4; } - partial_sum0 -= (32 * 8); - partial_sum1 -= (32 * 8); - partial_sum2 -= (32 * 8); - partial_sum3 -= (32 * 8); const uint8x8_t vld0_s1s = vand_u8(vld0_0, bottom_mask); const uint8x8_t vld0_s0s = vshr_n_u8(vld0_0, 4); @@ -304,7 +302,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( const uint8x8_t vld3_s16s0s_upper = vorr_u8(vzip2_u8(vld3_s1s, vld3_s0s), vzip2_u8(vld3_s17s, vld3_s16s)); - // Store repacked values + // Convert to unsigned int4 and store repacked values vst1_u8((uint8_t*)dst_row, veor_u8(vld0_s16s0s_lower, zero_point_conversion_mask)); vst1_u8((uint8_t*)dst_row + 8, veor_u8(vld1_s16s0s_lower, zero_point_conversion_mask)); vst1_u8((uint8_t*)dst_row + 16, veor_u8(vld2_s16s0s_lower, zero_point_conversion_mask)); @@ -418,6 +416,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( // Move the 
pointer after scales dst_row += num_bytes_multiplier_rhs * nr; } + // Move the pointer after the row sum dst_row += kai_num_bytes_sum_rhs * nr; @@ -431,6 +430,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( ((float*)dst_row)[i] = bias[src_row_idx]; } } + // Move the pointer after the row sum dst_row += kai_num_bytes_bias * nr; } -- GitLab
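Note on the repacking scheme: the Neon path added in patches 1/5 and 4/5 is a vectorised form of the nibble shuffle that the scalar fallback performs two bytes at a time. The standalone sketch below (an illustration with a hypothetical helper name, not code from the patch) shows the same transformation for a single 32-element quantized block of one source row; the interleaving of nr output rows and the split of the 16 result bytes across two 8-byte stores (the vzip1_u8/vzip2_u8 pair and the nr * block_length_in_bytes store offset) are omitted for clarity.

#include <stdint.h>

/*
 * Scalar sketch of the per-block repacking vectorised by the Neon path.
 * Source layout: byte i (i = 0..7) holds the nibbles for K-indices 2*i (low)
 * and 2*i+1 (high); byte i+8 holds K-indices 2*i+16 and 2*i+17.
 * Destination layout: each output byte pairs K-index j (low nibble) with
 * K-index j+16 (high nibble). The XOR with 0x88 flips the top bit of both
 * nibbles, i.e. adds/subtracts the zero-point of 8 modulo 16.
 */
static int32_t pack_one_block_scalar(const uint8_t src[16], uint8_t dst[16]) {
    int32_t sum = -(32 * 8); /* remove the zero-point contribution up front, as in patch 5/5 */
    for (int i = 0; i < 8; ++i) {
        const uint8_t lo = src[i];     /* nibbles for K = 2*i and 2*i + 1       */
        const uint8_t hi = src[i + 8]; /* nibbles for K = 2*i + 16 and 2*i + 17 */
        sum += (lo & 0x0F) + (lo >> 4) + (hi & 0x0F) + (hi >> 4);
        dst[2 * i + 0] = (uint8_t)(((lo & 0x0F) | ((hi & 0x0F) << 4)) ^ 0x88);
        dst[2 * i + 1] = (uint8_t)(((lo >> 4) | (hi & 0xF0)) ^ 0x88);
    }
    return sum; /* the caller scales this by the block's bf16 scale, as done for sums[nr_idx] */
}

The vectorised version produces the same 16 destination bytes per block (for four rows at a time after patch 4/5): vand_u8/vshr_n_u8/vshl_n_u8 isolate the four nibble streams, vzip1_u8/vzip2_u8 pair them, vorr_u8 merges the pairs and veor_u8 applies the zero-point conversion, while the row sums are still accumulated from the plain 64-bit loads rather than in vector registers.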