From 482a5f37b94e82db84037c0b7dee1d8a1f8c774c Mon Sep 17 00:00:00 2001 From: johmcl01 Date: Tue, 15 Jul 2025 15:28:36 +0100 Subject: [PATCH 1/2] Implement lhs_quant_pack_qsi8d32p_f32 using Intrinsics * Added vectorized Advanced SIMD to improve performance * Implemented targeting mr 4, kr 16, sr 2 & bl 32 * New files kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c & .h Signed-off-by: johmcl01 --- CHANGELOG.md | 2 + CMakeLists.txt | 1 + kai/ukernels/matmul/BUILD.bazel | 1 + kai/ukernels/matmul/pack/README.md | 9 ++ ...i_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c | 152 ++++++++++++++++++ ...i_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h | 84 ++++++++++ ...atmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp | 14 +- 7 files changed, 260 insertions(+), 3 deletions(-) create mode 100644 kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c create mode 100644 kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h diff --git a/CHANGELOG.md b/CHANGELOG.md index 99cac4f3..e110f1d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- Improve performance of lhs_quant_pack_qsi8d32p_f32 using Advanced SIMD, reimplemented as lhs_quant_pack_qsi8d32p4x16sb_f32_neon + ## v1.12.0 - New Advanced SIMD micro-kernels: diff --git a/CMakeLists.txt b/CMakeLists.txt index 1146a801..b3e71357 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -159,6 +159,7 @@ set(KLEIDIAI_FILES_NEON_ASM kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla_asm.S kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c + kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 498dbd76..38027085 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -31,6 +31,7 @@ SCALAR_KERNELS = [ # buildifier: keep sorted NEON_KERNELS = [ "pack/kai_lhs_quant_pack_qai8dxp_bf16_neon", + "pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon", "pack/kai_lhs_quant_pack_qsi8d32p_f32_neon", "pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon", "pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon", diff --git a/kai/ukernels/matmul/pack/README.md b/kai/ukernels/matmul/pack/README.md index 90b14eff..28b8bc11 100644 --- a/kai/ukernels/matmul/pack/README.md +++ b/kai/ukernels/matmul/pack/README.md @@ -61,6 +61,15 @@ Output LHS packed matrix containing quantized (q) symmetric (s) signed int8 (i8) values, with block-wise quantization (d32p) parameters, i.e. the quantized elements are stored in blocks and each block has a scale factor. +#### kai_run_lhs_quant_pack_qsi8d32p4x16sb_f32_neon() + +This routine follows the same format as kai_run_lhs_quant_pack_qsi8d32p_f32() above. + +However, it differs in the following ways: + +1. Functionality is implemented using vectorized Advanced SIMD to improve performance +1. The packing routine targets a specific shape with mr 4, kr 16, sr 2 & bl 32 + #### kai_run_lhs_quant_pack_qai8dxp_f32() Quantize and pack LHS matrix with per-dimension(row) quantization parameters.
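For reference, a minimal usage sketch of the new routine (illustrative only, not part of the patch): the `kai_*` calls are the API added by this change, while `pack_lhs_qsi8d32` and the malloc-based buffer handling are assumptions made for the example.

```c
#include <stdlib.h>

#include "kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h"

// Quantize and pack an m x k f32 LHS matrix for the qsi8d32p4x16sb layout.
static void* pack_lhs_qsi8d32(const float* lhs, size_t m, size_t k) {
    // The shape this routine targets; k must satisfy the kernel's KAI_ASSUME
    // checks (a multiple of bl and of kr).
    const size_t mr = 4;
    const size_t kr = 16;
    const size_t sr = 2;
    const size_t bl = 32;

    const size_t packed_size =
        kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x16sb_f32_neon(m, k, bl, mr, kr, sr);
    void* lhs_packed = malloc(packed_size);

    // Pack all m rows starting at row 0, with a dense row stride of k floats.
    kai_run_lhs_quant_pack_qsi8d32p4x16sb_f32_neon(
        m, k, bl, mr, kr, sr, 0 /* m_idx_start */, lhs, k * sizeof(float), lhs_packed);

    return lhs_packed;  // caller frees; consumed by the matching qsi8d32p matmul micro-kernel
}
```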
diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c new file mode 100644 index 00000000..153a2084 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c @@ -0,0 +1,152 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h" + +#include <arm_neon.h> +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_multiplier = sizeof(uint16_t); + +inline static size_t kai_num_bytes_per_block(size_t bl) { + return bl * sizeof(int8_t) + kai_num_bytes_multiplier; +} + +inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSERT((k % bl) == 0); + return k / bl; +} + +inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) { + KAI_UNUSED(kr); + return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl); +} + +size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x16sb_f32_neon(size_t mr) { + KAI_UNUSED(mr); + return 1; +} + +size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x16sb_f32_neon(size_t m_idx, size_t lhs_stride) { + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( + size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + + KAI_UNUSED(sr); + KAI_UNUSED(kr); + + return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl); +} + +size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( + size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + + KAI_UNUSED(sr); + KAI_UNUSED(kr); + + const size_t num_rows = kai_roundup(m, mr) / mr; + + return num_rows * kai_lhs_packed_stride(k, mr, kr, bl); +} + +void kai_run_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( + size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, + size_t lhs_stride, void* lhs_packed) { + if (m == 0) { + return; + } + + const size_t num_rows = m; + const size_t k_block_len = kr / sr; + const size_t lhs_packed_stride = kai_lhs_packed_stride(k, mr, kr, bl); + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); + + // block length 8 + for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { + const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride); + + for (size_t b = 0; b < num_blocks_per_row; ++b) { + float abs_max = 0.0F; + + const size_t dst_x = ((row_idx + m_idx_start) % mr); + int8_t* dst_ptr = (int8_t*)lhs_packed + (b * mr) * num_bytes_per_block; + + float32x4_t v_f32_abs_values; + float32x4_t v_f32_maxvals; + float32x4_t v_currentmax = vdupq_n_f32(0); + + KAI_ASSERT((bl % 16) == 0); + for (size_t idx_v = 0; idx_v < bl; idx_v += 4) { + v_f32_maxvals = vld1q_f32(src_ptr + idx_v); + v_f32_abs_values = vabsq_f32(v_f32_maxvals); + v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax); + } + abs_max = vmaxvq_f32(v_currentmax); + + // Calculate scale and reciprocal + const float scale = abs_max / ((1 << 7) - 1); + const float rep_scale = scale ?
1.0F / scale : 0.0F; + + *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale); + dst_ptr += mr * kai_num_bytes_multiplier; + + dst_ptr += dst_x * k_block_len * sizeof(int8_t); + + // Quantize and pack the block + for (size_t k_idx = 0; k_idx < bl; k_idx += k_block_len * 2) { + const float32x4_t v_f32_block1 = vld1q_f32(src_ptr + k_idx); + const float32x4_t v_f32_sblock1 = vmulq_n_f32(v_f32_block1, rep_scale); + const int32x4_t v_i32_block1 = vcvtnq_s32_f32(v_f32_sblock1); + + const float32x4_t v_f32_block2 = vld1q_f32(src_ptr + k_idx + 4); + const float32x4_t v_f32_sblock2 = vmulq_n_f32(v_f32_block2, rep_scale); + const int32x4_t v_i32_block2 = vcvtnq_s32_f32(v_f32_sblock2); + + const int16x8_t v_full_i16_block1 = + vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block1), vreinterpretq_s16_s32(v_i32_block2)); + + const float32x4_t v_f32_block3 = vld1q_f32(src_ptr + k_idx + 8); + const float32x4_t v_f32_sblock3 = vmulq_n_f32(v_f32_block3, rep_scale); + const int32x4_t v_i32_block3 = vcvtnq_s32_f32(v_f32_sblock3); + + const float32x4_t v_f32_block4 = vld1q_f32(src_ptr + k_idx + 12); + const float32x4_t v_f32_sblock4 = vmulq_n_f32(v_f32_block4, rep_scale); + const int32x4_t v_i32_block4 = vcvtnq_s32_f32(v_f32_sblock4); + + const int16x8_t v_full_i16_block2 = + vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block3), vreinterpretq_s16_s32(v_i32_block4)); + + const int8x16_t v_full_i8_block = + vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block1), vreinterpretq_s8_s16(v_full_i16_block2)); + + vst1_s8(dst_ptr, vget_low_s8(v_full_i8_block)); + dst_ptr += 8 * sizeof(int8_t); + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + + vst1_s8(dst_ptr, vget_high_s8(v_full_i8_block)); + dst_ptr += 8 * sizeof(int8_t); + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + + src_ptr += bl; + } + // Move to the next row if we have interleaved all Mr rows + if ((((row_idx + 1) + m_idx_start) % mr) == 0) { + lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride); + } + } +} diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h new file mode 100644 index 00000000..878693ab --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h @@ -0,0 +1,84 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include <stddef.h> +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @param[in] mr The number of M rows to interleave on the same output row. +/// +/// @return the m step value +size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x16sb_f32_neon(size_t mr); + +/// Gets the offset in bytes for the LHS matrix (not packed) +/// +/// This function should be called before passing the pointer to the LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] lhs_stride The number of bytes in each row of the LHS matrix (not packed) +/// +/// @return the offset in bytes to the LHS matrix +size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x16sb_f32_neon(size_t m_idx, size_t lhs_stride); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed 8-bit quantized symmetric per-block (qsi8d32) values.
+/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] bl The block length. +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( + size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + +/// Gets the size in bytes for the quantized and packed LHS matrix +/// +/// @param[in] m Total number of rows in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a +/// multiple of 32. +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the packed LHS matrix size in bytes +size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( + size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + +/// Run the micro-kernel to quantize and pack the LHS matrix. +/// +/// @param[in] m The number of output rows written. +/// @param[in] k The number of channels. The common dimension of LHS & RHS. It must be a multiple of 8. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be +/// a multiple of 32. +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// However, kr must be a multiple of sr. +/// @param[in] m_idx_start The starting M index. +/// @param[in] lhs LHS matrix. +/// @param[in] lhs_stride Stride in bytes between two rows of LHS. +/// @param[out] lhs_packed The quantized and packed LHS matrix.
+void kai_run_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( + size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, + size_t lhs_stride, void* lhs_packed); + +#ifdef __cplusplus +} +#endif diff --git a/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp index ac69e119..6fa4d4e2 100644 --- a/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp @@ -23,6 +23,7 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" @@ -217,9 +218,16 @@ TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, EndToEnd) { ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); - ukernel_variant.pack_interface.lhs_pack( - rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), - lhs_stride, imp_packed_lhs.data() + lhs_packed_offset); + if (ukernel_variant.ukernel.name.find("clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm") == + std::string_view::npos) { + ukernel_variant.pack_interface.lhs_pack( + rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), + lhs_stride, imp_packed_lhs.data() + lhs_packed_offset); + } else { + kai_run_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( + rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), + lhs_stride, imp_packed_lhs.data() + lhs_packed_offset); + } // Runs the RHS packing micro-kernel.
const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); -- GitLab From f0d955bb28915ffa4f9d54efb421599efaf25437 Mon Sep 17 00:00:00 2001 From: John McLoughlin Date: Fri, 25 Jul 2025 13:47:57 +0100 Subject: [PATCH 2/2] Implement lhs_quant_pack_qsi8d32p_f32 using Intrinsics * Added vectorized Advanced SIMD to improve performance * Implemented targeting mr 4, kr 16, sr 2 & bl 32 * New files kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c & .h Signed-off-by: John McLoughlin --- CHANGELOG.md | 2 +- CMakeLists.txt | 2 +- kai/ukernels/matmul/BUILD.bazel | 2 +- kai/ukernels/matmul/pack/README.md | 2 +- ...i_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c | 152 -------- ...ai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c | 357 ++++++++++++++++++ ...i_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h} | 10 +- ...atmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp | 17 +- 8 files changed, 371 insertions(+), 173 deletions(-) delete mode 100644 kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c create mode 100644 kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c rename kai/ukernels/matmul/pack/{kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h => kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h} (90%) diff --git a/CHANGELOG.md b/CHANGELOG.md index e110f1d2..5f634825 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release -- Improve performance of lhs_quant_pack_qsi8d32p_f32 using Advanced SIMD, reimplemented as lhs_quant_pack_qsi8d32p4x16sb_f32_neon +- Improve performance of lhs_quant_pack_qsi8d32p_f32 using Advanced SIMD, reimplemented as lhs_quant_pack_qsi8d32p4x8sb_f32_neon ## v1.12.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index b3e71357..dd619180 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -159,7 +159,7 @@ set(KLEIDIAI_FILES_NEON_ASM kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla_asm.S kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c - kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c + kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 38027085..6d96d77d 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -31,7 +31,7 @@ SCALAR_KERNELS = [ # buildifier: keep sorted NEON_KERNELS = [ "pack/kai_lhs_quant_pack_qai8dxp_bf16_neon", - "pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon", + "pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon", "pack/kai_lhs_quant_pack_qsi8d32p_f32_neon", "pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon", "pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon", diff --git a/kai/ukernels/matmul/pack/README.md b/kai/ukernels/matmul/pack/README.md index 28b8bc11..69159f1a 100644 --- a/kai/ukernels/matmul/pack/README.md +++ b/kai/ukernels/matmul/pack/README.md @@ -61,7 +61,7 @@ Output LHS packed matrix containing quantized (q) symmetric (s) signed int8 (i8) values, with block-wise quantization (d32p) parameters, i.e. the quantized elements are stored in blocks and each block has a scale factor.
-#### kai_run_lhs_quant_pack_qsi8d32p4x16sb_f32_neon() +#### kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon() This routine follows the same format as kai_run_lhs_quant_pack_qsi8d32p_f32() above. diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c deleted file mode 100644 index 153a2084..00000000 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.c +++ /dev/null @@ -1,152 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#include "kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h" - -#include <arm_neon.h> -#include <stddef.h> -#include <stdint.h> - -#include "kai/kai_common.h" - -static const size_t kai_num_bytes_multiplier = sizeof(uint16_t); - -inline static size_t kai_num_bytes_per_block(size_t bl) { - return bl * sizeof(int8_t) + kai_num_bytes_multiplier; -} - -inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { - KAI_ASSERT((k % bl) == 0); - return k / bl; -} - -inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) { - KAI_UNUSED(kr); - return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl); -} - -size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x16sb_f32_neon(size_t mr) { - KAI_UNUSED(mr); - return 1; -} - -size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x16sb_f32_neon(size_t m_idx, size_t lhs_stride) { - return m_idx * lhs_stride; -} - -size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( - size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { - KAI_ASSUME((k % 2) == 0); - KAI_ASSUME((k % kr) == 0); - KAI_ASSUME((k % bl) == 0); - - KAI_UNUSED(sr); - KAI_UNUSED(kr); - - return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl); -} - -size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( - size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { - KAI_ASSUME((k % 2) == 0); - KAI_ASSUME((k % kr) == 0); - KAI_ASSUME((k % bl) == 0); - - KAI_UNUSED(sr); - KAI_UNUSED(kr); - - const size_t num_rows = kai_roundup(m, mr) / mr; - - return num_rows * kai_lhs_packed_stride(k, mr, kr, bl); -} - -void kai_run_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( - size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, - size_t lhs_stride, void* lhs_packed) { - if (m == 0) { - return; - } - - const size_t num_rows = m; - const size_t k_block_len = kr / sr; - const size_t lhs_packed_stride = kai_lhs_packed_stride(k, mr, kr, bl); - const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); - - // block length 8 - for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { - const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride); - - for (size_t b = 0; b < num_blocks_per_row; ++b) { - float abs_max = 0.0F; - - const size_t dst_x = ((row_idx + m_idx_start) % mr); - int8_t* dst_ptr = (int8_t*)lhs_packed + (b * mr) * num_bytes_per_block; - - float32x4_t v_f32_abs_values; - float32x4_t v_f32_maxvals; - float32x4_t v_currentmax = vdupq_n_f32(0); - - KAI_ASSERT((bl % 16) == 0); - for (size_t idx_v = 0; idx_v < bl; idx_v += 4) { - v_f32_maxvals = vld1q_f32(src_ptr + idx_v); - v_f32_abs_values = vabsq_f32(v_f32_maxvals); - v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax); - } - abs_max = vmaxvq_f32(v_currentmax); - - // Calculate scale and reciprocal - const
float scale = abs_max / ((1 << 7) - 1); - const float rep_scale = scale ? 1.0F / scale : 0.0F; - - *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale); - dst_ptr += mr * kai_num_bytes_multiplier; - - dst_ptr += dst_x * k_block_len * sizeof(int8_t); - - // Quantize and pack the block - for (size_t k_idx = 0; k_idx < bl; k_idx += k_block_len * 2) { - const float32x4_t v_f32_block1 = vld1q_f32(src_ptr + k_idx); - const float32x4_t v_f32_sblock1 = vmulq_n_f32(v_f32_block1, rep_scale); - const int32x4_t v_i32_block1 = vcvtnq_s32_f32(v_f32_sblock1); - - const float32x4_t v_f32_block2 = vld1q_f32(src_ptr + k_idx + 4); - const float32x4_t v_f32_sblock2 = vmulq_n_f32(v_f32_block2, rep_scale); - const int32x4_t v_i32_block2 = vcvtnq_s32_f32(v_f32_sblock2); - - const int16x8_t v_full_i16_block1 = - vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block1), vreinterpretq_s16_s32(v_i32_block2)); - - const float32x4_t v_f32_block3 = vld1q_f32(src_ptr + k_idx + 8); - const float32x4_t v_f32_sblock3 = vmulq_n_f32(v_f32_block3, rep_scale); - const int32x4_t v_i32_block3 = vcvtnq_s32_f32(v_f32_sblock3); - - const float32x4_t v_f32_block4 = vld1q_f32(src_ptr + k_idx + 12); - const float32x4_t v_f32_sblock4 = vmulq_n_f32(v_f32_block4, rep_scale); - const int32x4_t v_i32_block4 = vcvtnq_s32_f32(v_f32_sblock4); - - const int16x8_t v_full_i16_block2 = - vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block3), vreinterpretq_s16_s32(v_i32_block4)); - - const int8x16_t v_full_i8_block = - vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block1), vreinterpretq_s8_s16(v_full_i16_block2)); - - vst1_s8(dst_ptr, vget_low_s8(v_full_i8_block)); - dst_ptr += 8 * sizeof(int8_t); - dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); - - vst1_s8(dst_ptr, vget_high_s8(v_full_i8_block)); - dst_ptr += 8 * sizeof(int8_t); - dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); - } - - src_ptr += bl; - } - // Move to the next row if we have interleaved all Mr rows - if ((((row_idx + 1) + m_idx_start) % mr) == 0) { - lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride); - } - } -} diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c new file mode 100644 index 00000000..a81dcc57 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c @@ -0,0 +1,357 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h" + +#include <arm_neon.h> +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_multiplier = sizeof(uint16_t); + +inline static size_t kai_num_bytes_per_block(size_t bl) { + return bl * sizeof(int8_t) + kai_num_bytes_multiplier; +} + +inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSERT((k % bl) == 0); + return k / bl; +} + +inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) { + KAI_UNUSED(kr); + return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl); +} + +size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t mr) { + KAI_UNUSED(mr); + return 1; +} + +size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t m_idx, size_t lhs_stride) { + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon( + size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { + KAI_ASSUME((k % 2)
== 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + + KAI_UNUSED(sr); + KAI_UNUSED(kr); + + return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl); +} + +size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon( + size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + + KAI_UNUSED(sr); + KAI_UNUSED(kr); + + const size_t num_rows = kai_roundup(m, mr) / mr; + + return num_rows * kai_lhs_packed_stride(k, mr, kr, bl); +} + +void kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon( + size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, + size_t lhs_stride, void* lhs_packed) { + if (m == 0) { + return; + } + + KAI_ASSUME(bl == 32); + KAI_ASSUME(mr == 4); + KAI_ASSUME(kr == 16); + KAI_ASSUME(sr == 2); + + const size_t local_bl = 32; + const size_t local_mr = 4; + const size_t local_kr = 16; + const size_t local_sr = 2; + const size_t num_rows = m; + const size_t k_block_len = local_kr / local_sr; + const size_t lhs_packed_stride = kai_lhs_packed_stride(k, local_mr, local_kr, local_bl); + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, local_bl); + const size_t num_bytes_per_block = kai_num_bytes_per_block(local_bl); + + size_t row_idx = 0; + + const size_t write_mem_increment = 2 * k_block_len * sizeof(int8_t); + const size_t read_mem_increment = num_blocks_per_row * local_bl * sizeof(int8_t); + + if (num_rows >= 4) { + for (; row_idx + 4 <= num_rows; row_idx += 4) { + const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride); + + for (size_t b = 0; b < num_blocks_per_row; ++b) { + const size_t dst_x = ((row_idx + m_idx_start) % local_mr); + int8_t* dst_ptr = (int8_t*)lhs_packed + (b * local_mr) * num_bytes_per_block; + + float32x4_t v_f32_abs_values; + float32x4_t v_f32_maxvals; + float32x4_t v_currentmax; + float abs_max = 0.0F; + + v_currentmax = vdupq_n_f32(0); + + for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) { + v_f32_maxvals = vld1q_f32(src_ptr + idx_v); + v_f32_abs_values = vabsq_f32(v_f32_maxvals); + v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax); + } + abs_max = vmaxvq_f32(v_currentmax); + + // Calculate scale and reciprocals + const float scale1 = abs_max / ((1 << 7) - 1); + const float rep_scale1 = scale1 ? 1.0F / scale1 : 0.0F; + *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale1); + + v_currentmax = vdupq_n_f32(0); + + for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) { + v_f32_maxvals = vld1q_f32(src_ptr + idx_v + read_mem_increment); + v_f32_abs_values = vabsq_f32(v_f32_maxvals); + v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax); + } + abs_max = vmaxvq_f32(v_currentmax); + + const float scale2 = abs_max / ((1 << 7) - 1); + const float rep_scale2 = scale2 ? 1.0F / scale2 : 0.0F; + *((uint16_t*)(dst_ptr + ((1 + m_idx_start) % local_mr) * kai_num_bytes_multiplier)) = + kai_cast_f16_f32(scale2); + + v_currentmax = vdupq_n_f32(0); + + for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) { + v_f32_maxvals = vld1q_f32(src_ptr + idx_v + 2 * read_mem_increment); + v_f32_abs_values = vabsq_f32(v_f32_maxvals); + v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax); + } + abs_max = vmaxvq_f32(v_currentmax); + + const float scale3 = abs_max / ((1 << 7) - 1); + const float rep_scale3 = scale3 ? 
1.0F / scale3 : 0.0F; + *((uint16_t*)(dst_ptr + ((2 + m_idx_start) % local_mr) * kai_num_bytes_multiplier)) = + kai_cast_f16_f32(scale3); + + v_currentmax = vdupq_n_f32(0); + + for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) { + v_f32_maxvals = vld1q_f32(src_ptr + idx_v + 3 * read_mem_increment); + v_f32_abs_values = vabsq_f32(v_f32_maxvals); + v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax); + } + abs_max = vmaxvq_f32(v_currentmax); + + const float scale4 = abs_max / ((1 << 7) - 1); + const float rep_scale4 = scale4 ? 1.0F / scale4 : 0.0F; + *((uint16_t*)(dst_ptr + ((3 + m_idx_start) % local_mr) * kai_num_bytes_multiplier)) = + kai_cast_f16_f32(scale4); + + dst_ptr += local_mr * kai_num_bytes_multiplier; + + dst_ptr += dst_x * k_block_len * sizeof(int8_t); + + // Quantize and pack the blocks + for (size_t k_idx = 0; k_idx < local_bl; k_idx += k_block_len * 2) { + // Row 1 blocks + const float32x4_t v_f32_block1 = vld1q_f32(src_ptr + k_idx); + const float32x4_t v_f32_sblock1 = vmulq_n_f32(v_f32_block1, rep_scale1); + const int32x4_t v_i32_block1 = vcvtnq_s32_f32(v_f32_sblock1); + + const float32x4_t v_f32_block2 = vld1q_f32(src_ptr + k_idx + 4); + const float32x4_t v_f32_sblock2 = vmulq_n_f32(v_f32_block2, rep_scale1); + const int32x4_t v_i32_block2 = vcvtnq_s32_f32(v_f32_sblock2); + + const int16x8_t v_full_i16_block1 = + vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block1), vreinterpretq_s16_s32(v_i32_block2)); + + const float32x4_t v_f32_block3 = vld1q_f32(src_ptr + k_idx + 8); + const float32x4_t v_f32_sblock3 = vmulq_n_f32(v_f32_block3, rep_scale1); + const int32x4_t v_i32_block3 = vcvtnq_s32_f32(v_f32_sblock3); + + const float32x4_t v_f32_block4 = vld1q_f32(src_ptr + k_idx + 12); + const float32x4_t v_f32_sblock4 = vmulq_n_f32(v_f32_block4, rep_scale1); + const int32x4_t v_i32_block4 = vcvtnq_s32_f32(v_f32_sblock4); + + const int16x8_t v_full_i16_block2 = + vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block3), vreinterpretq_s16_s32(v_i32_block4)); + + // Row 2 blocks + const float32x4_t v_f32_block5 = vld1q_f32(src_ptr + k_idx + read_mem_increment); + const float32x4_t v_f32_sblock5 = vmulq_n_f32(v_f32_block5, rep_scale2); + const int32x4_t v_i32_block5 = vcvtnq_s32_f32(v_f32_sblock5); + + const float32x4_t v_f32_block6 = vld1q_f32(src_ptr + k_idx + 4 + read_mem_increment); + const float32x4_t v_f32_sblock6 = vmulq_n_f32(v_f32_block6, rep_scale2); + const int32x4_t v_i32_block6 = vcvtnq_s32_f32(v_f32_sblock6); + + const int16x8_t v_full_i16_block3 = + vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block5), vreinterpretq_s16_s32(v_i32_block6)); + + const float32x4_t v_f32_block7 = vld1q_f32(src_ptr + k_idx + 8 + read_mem_increment); + const float32x4_t v_f32_sblock7 = vmulq_n_f32(v_f32_block7, rep_scale2); + const int32x4_t v_i32_block7 = vcvtnq_s32_f32(v_f32_sblock7); + + const float32x4_t v_f32_block8 = vld1q_f32(src_ptr + k_idx + 12 + read_mem_increment); + const float32x4_t v_f32_sblock8 = vmulq_n_f32(v_f32_block8, rep_scale2); + const int32x4_t v_i32_block8 = vcvtnq_s32_f32(v_f32_sblock8); + + const int16x8_t v_full_i16_block4 = + vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block7), vreinterpretq_s16_s32(v_i32_block8)); + + // Row 3 blocks + const float32x4_t v_f32_block9 = vld1q_f32(src_ptr + k_idx + 2 * read_mem_increment); + const float32x4_t v_f32_sblock9 = vmulq_n_f32(v_f32_block9, rep_scale3); + const int32x4_t v_i32_block9 = vcvtnq_s32_f32(v_f32_sblock9); + + const float32x4_t v_f32_blockA = vld1q_f32(src_ptr + k_idx + 4 + 2 * read_mem_increment); + const float32x4_t v_f32_sblockA = 
vmulq_n_f32(v_f32_blockA, rep_scale3); + const int32x4_t v_i32_blockA = vcvtnq_s32_f32(v_f32_sblockA); + + const int16x8_t v_full_i16_block5 = + vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block9), vreinterpretq_s16_s32(v_i32_blockA)); + + const float32x4_t v_f32_blockB = vld1q_f32(src_ptr + k_idx + 8 + 2 * read_mem_increment); + const float32x4_t v_f32_sblockB = vmulq_n_f32(v_f32_blockB, rep_scale3); + const int32x4_t v_i32_blockB = vcvtnq_s32_f32(v_f32_sblockB); + + const float32x4_t v_f32_blockC = vld1q_f32(src_ptr + k_idx + 12 + 2 * read_mem_increment); + const float32x4_t v_f32_sblockC = vmulq_n_f32(v_f32_blockC, rep_scale3); + const int32x4_t v_i32_blockC = vcvtnq_s32_f32(v_f32_sblockC); + + const int16x8_t v_full_i16_block6 = + vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockB), vreinterpretq_s16_s32(v_i32_blockC)); + + // Row 4 blocks + const float32x4_t v_f32_blockD = vld1q_f32(src_ptr + k_idx + 3 * read_mem_increment); + const float32x4_t v_f32_sblockD = vmulq_n_f32(v_f32_blockD, rep_scale4); + const int32x4_t v_i32_blockD = vcvtnq_s32_f32(v_f32_sblockD); + + const float32x4_t v_f32_blockE = vld1q_f32(src_ptr + k_idx + 4 + 3 * read_mem_increment); + const float32x4_t v_f32_sblockE = vmulq_n_f32(v_f32_blockE, rep_scale4); + const int32x4_t v_i32_blockE = vcvtnq_s32_f32(v_f32_sblockE); + + const int16x8_t v_full_i16_block7 = + vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockD), vreinterpretq_s16_s32(v_i32_blockE)); + + const float32x4_t v_f32_blockF = vld1q_f32(src_ptr + k_idx + 8 + 3 * read_mem_increment); + const float32x4_t v_f32_sblockF = vmulq_n_f32(v_f32_blockF, rep_scale4); + const int32x4_t v_i32_blockF = vcvtnq_s32_f32(v_f32_sblockF); + + const float32x4_t v_f32_block0 = vld1q_f32(src_ptr + k_idx + 12 + 3 * read_mem_increment); + const float32x4_t v_f32_sblock0 = vmulq_n_f32(v_f32_block0, rep_scale4); + const int32x4_t v_i32_block0 = vcvtnq_s32_f32(v_f32_sblock0); + + const int16x8_t v_full_i16_block8 = + vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockF), vreinterpretq_s16_s32(v_i32_block0)); + + const int8x16_t v_i8_block1_3 = + vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block1), vreinterpretq_s8_s16(v_full_i16_block3)); + vst1q_s8(dst_ptr, v_i8_block1_3); + dst_ptr += write_mem_increment; + + const int8x16_t v_i8_block5_7 = + vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block5), vreinterpretq_s8_s16(v_full_i16_block7)); + vst1q_s8(dst_ptr, v_i8_block5_7); + dst_ptr += write_mem_increment; + + const int8x16_t v_i8_block2_4 = + vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block2), vreinterpretq_s8_s16(v_full_i16_block4)); + vst1q_s8(dst_ptr, v_i8_block2_4); + dst_ptr += write_mem_increment; + + const int8x16_t v_i8_block6_8 = + vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block6), vreinterpretq_s8_s16(v_full_i16_block8)); + vst1q_s8(dst_ptr, v_i8_block6_8); + dst_ptr += write_mem_increment; + } + src_ptr += local_bl; + } + lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride); + } + } + if (num_rows % 4 != 0) { + for (; row_idx < num_rows; ++row_idx) { + const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride); + + for (size_t b = 0; b < num_blocks_per_row; ++b) { + float abs_max = 0.0F; + + const size_t dst_x = ((row_idx + m_idx_start) % local_mr); + int8_t* dst_ptr = (int8_t*)lhs_packed + (b * local_mr) * num_bytes_per_block; + + float32x4_t v_f32_abs_values; + float32x4_t v_f32_maxvals; + float32x4_t v_currentmax = vdupq_n_f32(0); + + for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) { + v_f32_maxvals = vld1q_f32(src_ptr + idx_v); + 
v_f32_abs_values = vabsq_f32(v_f32_maxvals); + v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax); + } + abs_max = vmaxvq_f32(v_currentmax); + + // Calculate scale and reciprocal + const float scale = abs_max / ((1 << 7) - 1); + const float rep_scale = scale ? 1.0F / scale : 0.0F; + + *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale); + dst_ptr += local_mr * kai_num_bytes_multiplier; + + dst_ptr += dst_x * k_block_len * sizeof(int8_t); + + // Quantize and pack the block + for (size_t k_idx = 0; k_idx < local_bl; k_idx += k_block_len * 2) { + const float32x4_t v_f32_block1 = vld1q_f32(src_ptr + k_idx); + const float32x4_t v_f32_sblock1 = vmulq_n_f32(v_f32_block1, rep_scale); + const int32x4_t v_i32_block1 = vcvtnq_s32_f32(v_f32_sblock1); + + const float32x4_t v_f32_block2 = vld1q_f32(src_ptr + k_idx + 4); + const float32x4_t v_f32_sblock2 = vmulq_n_f32(v_f32_block2, rep_scale); + const int32x4_t v_i32_block2 = vcvtnq_s32_f32(v_f32_sblock2); + + const int16x8_t v_full_i16_block1 = + vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block1), vreinterpretq_s16_s32(v_i32_block2)); + + const float32x4_t v_f32_block3 = vld1q_f32(src_ptr + k_idx + 8); + const float32x4_t v_f32_sblock3 = vmulq_n_f32(v_f32_block3, rep_scale); + const int32x4_t v_i32_block3 = vcvtnq_s32_f32(v_f32_sblock3); + + const float32x4_t v_f32_block4 = vld1q_f32(src_ptr + k_idx + 12); + const float32x4_t v_f32_sblock4 = vmulq_n_f32(v_f32_block4, rep_scale); + const int32x4_t v_i32_block4 = vcvtnq_s32_f32(v_f32_sblock4); + + const int16x8_t v_full_i16_block2 = + vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block3), vreinterpretq_s16_s32(v_i32_block4)); + + const int8x16_t v_full_i8_block = + vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block1), vreinterpretq_s8_s16(v_full_i16_block2)); + + vst1_s8(dst_ptr, vget_low_s8(v_full_i8_block)); + dst_ptr += 8 * sizeof(int8_t); + dst_ptr += (local_mr - 1) * k_block_len * sizeof(int8_t); + + vst1_s8(dst_ptr, vget_high_s8(v_full_i8_block)); + dst_ptr += 8 * sizeof(int8_t); + dst_ptr += (local_mr - 1) * k_block_len * sizeof(int8_t); + } + src_ptr += local_bl; + } + // Move to the next row if we have interleaved all Mr rows + if ((((row_idx + 1) + m_idx_start) % local_mr) == 0) { + lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride); + } + } + } +} diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h similarity index 90% rename from kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h rename to kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h index 878693ab..86873a7e 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h @@ -19,7 +19,7 @@ extern "C" { /// @param[in] mr The number of M rows to interleave on the same output row. 
/// /// @return the m step value -size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x16sb_f32_neon(size_t mr); +size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t mr); /// Gets the offset in bytes for the LHS matrix (not packed) /// /// This function should be called before passing the pointer to the LHS matrix to the micro-kernel. /// @@ -29,7 +29,7 @@ size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x16sb_f32_neon(size_t mr); /// @param[in] lhs_stride The number of bytes in each row of the LHS matrix (not packed) /// /// @return the offset in bytes to the LHS matrix -size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x16sb_f32_neon(size_t m_idx, size_t lhs_stride); +size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t m_idx, size_t lhs_stride); /// Gets the offset in bytes for the packed LHS matrix, /// which contains the packed 8-bit quantized symmetric per-block (qsi8d32) values. @@ -44,7 +44,7 @@ size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x16sb_f32_neon(size_t m_idx, s /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// /// @return the offset in bytes to the packed LHS matrix -size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon( size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); /// Gets the size in bytes for the quantized and packed LHS matrix @@ -58,7 +58,7 @@ size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// /// @return the packed LHS matrix size in bytes -size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( +size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon( size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); /// Run the micro-kernel to quantize and pack the LHS matrix. @@ -75,7 +75,7 @@ size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( /// @param[in] lhs LHS matrix. /// @param[in] lhs_stride Stride in bytes between two rows of LHS. /// @param[out] lhs_packed The quantized and packed LHS matrix.
-void kai_run_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( +void kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon( size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, size_t lhs_stride, void* lhs_packed); diff --git a/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp index 6fa4d4e2..bfc2a1c3 100644 --- a/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp @@ -23,7 +23,7 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" -#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x16sb_f32_neon.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" @@ -76,7 +76,7 @@ static const std::array< clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p_f32, rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0), UKERNEL_MATMUL_PACK_VARIANT( - clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p_f32, + clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p4x8sb_f32_neon, rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0), UKERNEL_MATMUL_PACK_VARIANT( clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32, @@ -218,16 +218,9 @@ TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, EndToEnd) { ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); - if (ukernel_variant.ukernel.name.find("clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm") == - std::string_view::npos) { - ukernel_variant.pack_interface.lhs_pack( - rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), - lhs_stride, imp_packed_lhs.data() + lhs_packed_offset); - } else { - kai_run_lhs_quant_pack_qsi8d32p4x16sb_f32_neon( - rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), - lhs_stride, imp_packed_lhs.data() + lhs_packed_offset); - } + ukernel_variant.pack_interface.lhs_pack( + rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), + lhs_stride, imp_packed_lhs.data() + lhs_packed_offset); // Runs the RHS packing micro-kernel. const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); -- GitLab
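For review, the arithmetic both NEON implementations perform on each 32-element block reduces to the following scalar sketch (illustrative only, not part of the patch; quantize_block_qsi8d32 is a hypothetical name):

```c
#include <math.h>
#include <stddef.h>
#include <stdint.h>

// One bl-length block: symmetric int8 quantization with a per-block scale.
// Mirrors the vabsq/vmaxq/vmaxvq absmax search and the vcvtnq_s32_f32
// round-to-nearest conversion used by the intrinsics above.
static float quantize_block_qsi8d32(const float* src, int8_t* dst, size_t bl) {
    float abs_max = 0.0F;
    for (size_t i = 0; i < bl; ++i) {
        abs_max = fmaxf(abs_max, fabsf(src[i]));
    }

    const float scale = abs_max / ((1 << 7) - 1);         // map absmax onto [-127, 127]
    const float rep_scale = scale ? 1.0F / scale : 0.0F;  // guard all-zero blocks

    for (size_t i = 0; i < bl; ++i) {
        dst[i] = (int8_t)lrintf(src[i] * rep_scale);      // round to nearest
    }

    return scale;  // stored as fp16 (kai_cast_f16_f32) ahead of the block's int8 data
}
```

The packing step then interleaves mr = 4 rows per output group: the four fp16 scales are written first, followed by each row's int8 data in alternating k_block_len = 8 chunks, which is what the vuzp1q/vst1 sequences and the (mr - 1) * k_block_len pointer strides in the kernels produce.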