diff --git a/CHANGELOG.md b/CHANGELOG.md
index 99cac4f3f5b09e9ab7fb5ebc524ab6e3b07ef06d..5f634825d0f0fd431d27a72d801ff8b4fefdd580 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo

 ## Upcoming Release

+- Improve performance of lhs_quant_pack_qsi8d32p_f32 by reimplementing it with Advanced SIMD as lhs_quant_pack_qsi8d32p4x8sb_f32_neon
+
 ## v1.12.0

 - New Advanced SIMD micro-kernels:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1146a801f5a708afb596185e6282347a261d00c3..dd6191801414daf108f2c93feaa7d6b7ca3567ff 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -159,6 +159,7 @@ set(KLEIDIAI_FILES_NEON_ASM
     kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla_asm.S
     kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c
     kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
+    kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
     kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c
     kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c
     kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c
diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel
index 498dbd76561300eea8d541678786711d4607d371..6d96d77d366621ab4f48ba27abda1e79cace149c 100644
--- a/kai/ukernels/matmul/BUILD.bazel
+++ b/kai/ukernels/matmul/BUILD.bazel
@@ -31,6 +31,7 @@ SCALAR_KERNELS = [
 # buildifier: keep sorted
 NEON_KERNELS = [
     "pack/kai_lhs_quant_pack_qai8dxp_bf16_neon",
+    "pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon",
     "pack/kai_lhs_quant_pack_qsi8d32p_f32_neon",
     "pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon",
     "pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon",
diff --git a/kai/ukernels/matmul/pack/README.md b/kai/ukernels/matmul/pack/README.md
index 90b14eff16399cf18ee3e26f294f556bf9111acd..69159f1a191858352b9f2a8f595c81dd0701800a 100644
--- a/kai/ukernels/matmul/pack/README.md
+++ b/kai/ukernels/matmul/pack/README.md
@@ -61,6 +61,15 @@
 Output LHS packed matrix containing quantized (q) symmetric (s) signed int8 (i8) values, with block-wise
 quantization (d32p) parameters, i.e. the quantized elements are stored in blocks and each block has a scale factor.

+#### kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon()
+
+This routine follows the same format as kai_run_lhs_quant_pack_qsi8d32p_f32() above.
+
+However, it differs in the following ways:
+
+1. Functionality is implemented with vectorized Advanced SIMD instructions to improve performance.
+1. The packing routine targets a specific shape: mr = 4, kr = 16, sr = 2, and bl = 32.
+
 #### kai_run_lhs_quant_pack_qai8dxp_f32()

 Quantize and pack LHS matrix with per-dimension (row) quantization parameters.
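Reviewer note: a minimal usage sketch of the new routine, not part of this patch. The shape values `m` and `k` and the helper `pack_lhs_example()` are illustrative (with `k` assumed to be a multiple of `bl`); `mr`, `kr`, `sr`, and `bl` are fixed by this micro-kernel. It only shows how the size-query and run entry points declared in this patch fit together.

```c
#include <stddef.h>
#include <stdlib.h>

#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"

void pack_lhs_example(const float* lhs, size_t m, size_t k) {
    const size_t mr = 4;   // Fixed: M rows interleaved per output row
    const size_t kr = 16;  // Fixed: columns per inner-most matmul loop
    const size_t sr = 2;   // Fixed: number of kr splits
    const size_t bl = 32;  // Fixed: quantization block length

    // Query the packed buffer size before allocating.
    const size_t dst_size = kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(m, k, bl, mr, kr, sr);
    void* lhs_packed = malloc(dst_size);

    // Quantize to qsi8d32 and pack, starting from row 0 of the LHS matrix.
    kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
        m, k, bl, mr, kr, sr, 0 /* m_idx_start */, lhs, k * sizeof(float) /* lhs_stride */, lhs_packed);

    // ... pass lhs_packed to a matching matmul micro-kernel, then free it.
    free(lhs_packed);
}
```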
diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
new file mode 100644
index 0000000000000000000000000000000000000000..a81dcc571257a94a8313d135db7d73fcc4fcf2c4
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
@@ -0,0 +1,357 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
+
+#include <arm_neon.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);
+
+inline static size_t kai_num_bytes_per_block(size_t bl) {
+    return bl * sizeof(int8_t) + kai_num_bytes_multiplier;
+}
+
+inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSERT((k % bl) == 0);
+    return k / bl;
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) {
+    KAI_UNUSED(kr);
+    return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl);
+}
+
+size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t mr) {
+    KAI_UNUSED(mr);
+    return 1;
+}
+
+size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t m_idx, size_t lhs_stride) {
+    return m_idx * lhs_stride;
+}
+
+size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
+    size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
+    KAI_ASSUME((k % 2) == 0);
+    KAI_ASSUME((k % kr) == 0);
+    KAI_ASSUME((k % bl) == 0);
+
+    KAI_UNUSED(sr);
+    KAI_UNUSED(kr);
+
+    return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl);
+}
+
+size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
+    size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
+    KAI_ASSUME((k % 2) == 0);
+    KAI_ASSUME((k % kr) == 0);
+    KAI_ASSUME((k % bl) == 0);
+
+    KAI_UNUSED(sr);
+    KAI_UNUSED(kr);
+
+    const size_t num_rows = kai_roundup(m, mr) / mr;
+
+    return num_rows * kai_lhs_packed_stride(k, mr, kr, bl);
+}
+
+void kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
+    size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
+    size_t lhs_stride, void* lhs_packed) {
+    if (m == 0) {
+        return;
+    }
+
+    KAI_ASSUME(bl == 32);
+    KAI_ASSUME(mr == 4);
+    KAI_ASSUME(kr == 16);
+    KAI_ASSUME(sr == 2);
+
+    const size_t local_bl = 32;
+    const size_t local_mr = 4;
+    const size_t local_kr = 16;
+    const size_t local_sr = 2;
+    const size_t num_rows = m;
+    const size_t k_block_len = local_kr / local_sr;
+    const size_t lhs_packed_stride = kai_lhs_packed_stride(k, local_mr, local_kr, local_bl);
+    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, local_bl);
+    const size_t num_bytes_per_block = kai_num_bytes_per_block(local_bl);
+
+    size_t row_idx = 0;
+
+    const size_t write_mem_increment = 2 * k_block_len * sizeof(int8_t);
+    const size_t read_mem_increment = num_blocks_per_row * local_bl * sizeof(int8_t);
+
+    if (num_rows >= 4) {
+        for (; row_idx + 4 <= num_rows; row_idx += 4) {
+            const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride);
+
+            for (size_t b = 0; b < num_blocks_per_row; ++b) {
+                const size_t dst_x = ((row_idx + m_idx_start) % local_mr);
+                int8_t* dst_ptr = (int8_t*)lhs_packed + (b * local_mr) * num_bytes_per_block;
+
+                float32x4_t v_f32_abs_values;
+                float32x4_t v_f32_maxvals;
+                float32x4_t v_currentmax;
+                float abs_max = 0.0F;
+
+                v_currentmax = vdupq_n_f32(0);
+
+                for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) {
+                    v_f32_maxvals = vld1q_f32(src_ptr + idx_v);
+                    v_f32_abs_values = vabsq_f32(v_f32_maxvals);
+                    v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax);
+                }
+                abs_max = vmaxvq_f32(v_currentmax);
+
+                // Calculate scale and reciprocals
+                const float scale1 = abs_max / ((1 << 7) - 1);
+                const float rep_scale1 = scale1 ? 1.0F / scale1 : 0.0F;
+                *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale1);
+
+                v_currentmax = vdupq_n_f32(0);
+
+                for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) {
+                    v_f32_maxvals = vld1q_f32(src_ptr + idx_v + read_mem_increment);
+                    v_f32_abs_values = vabsq_f32(v_f32_maxvals);
+                    v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax);
+                }
+                abs_max = vmaxvq_f32(v_currentmax);
+
+                const float scale2 = abs_max / ((1 << 7) - 1);
+                const float rep_scale2 = scale2 ? 1.0F / scale2 : 0.0F;
+                *((uint16_t*)(dst_ptr + ((1 + m_idx_start) % local_mr) * kai_num_bytes_multiplier)) =
+                    kai_cast_f16_f32(scale2);
+
+                v_currentmax = vdupq_n_f32(0);
+
+                for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) {
+                    v_f32_maxvals = vld1q_f32(src_ptr + idx_v + 2 * read_mem_increment);
+                    v_f32_abs_values = vabsq_f32(v_f32_maxvals);
+                    v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax);
+                }
+                abs_max = vmaxvq_f32(v_currentmax);
+
+                const float scale3 = abs_max / ((1 << 7) - 1);
+                const float rep_scale3 = scale3 ? 1.0F / scale3 : 0.0F;
+                *((uint16_t*)(dst_ptr + ((2 + m_idx_start) % local_mr) * kai_num_bytes_multiplier)) =
+                    kai_cast_f16_f32(scale3);
+
+                v_currentmax = vdupq_n_f32(0);
+
+                for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) {
+                    v_f32_maxvals = vld1q_f32(src_ptr + idx_v + 3 * read_mem_increment);
+                    v_f32_abs_values = vabsq_f32(v_f32_maxvals);
+                    v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax);
+                }
+                abs_max = vmaxvq_f32(v_currentmax);
+
+                const float scale4 = abs_max / ((1 << 7) - 1);
+                const float rep_scale4 = scale4 ? 1.0F / scale4 : 0.0F;
+                *((uint16_t*)(dst_ptr + ((3 + m_idx_start) % local_mr) * kai_num_bytes_multiplier)) =
+                    kai_cast_f16_f32(scale4);
+
+                dst_ptr += local_mr * kai_num_bytes_multiplier;
+
+                dst_ptr += dst_x * k_block_len * sizeof(int8_t);
+
+                // Quantize and pack the blocks
+                for (size_t k_idx = 0; k_idx < local_bl; k_idx += k_block_len * 2) {
+                    // Row 1 blocks
+                    const float32x4_t v_f32_block1 = vld1q_f32(src_ptr + k_idx);
+                    const float32x4_t v_f32_sblock1 = vmulq_n_f32(v_f32_block1, rep_scale1);
+                    const int32x4_t v_i32_block1 = vcvtnq_s32_f32(v_f32_sblock1);
+
+                    const float32x4_t v_f32_block2 = vld1q_f32(src_ptr + k_idx + 4);
+                    const float32x4_t v_f32_sblock2 = vmulq_n_f32(v_f32_block2, rep_scale1);
+                    const int32x4_t v_i32_block2 = vcvtnq_s32_f32(v_f32_sblock2);
+
+                    const int16x8_t v_full_i16_block1 =
+                        vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block1), vreinterpretq_s16_s32(v_i32_block2));
+
+                    const float32x4_t v_f32_block3 = vld1q_f32(src_ptr + k_idx + 8);
+                    const float32x4_t v_f32_sblock3 = vmulq_n_f32(v_f32_block3, rep_scale1);
+                    const int32x4_t v_i32_block3 = vcvtnq_s32_f32(v_f32_sblock3);
+
+                    const float32x4_t v_f32_block4 = vld1q_f32(src_ptr + k_idx + 12);
+                    const float32x4_t v_f32_sblock4 = vmulq_n_f32(v_f32_block4, rep_scale1);
+                    const int32x4_t v_i32_block4 = vcvtnq_s32_f32(v_f32_sblock4);
+
+                    const int16x8_t v_full_i16_block2 =
+                        vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block3), vreinterpretq_s16_s32(v_i32_block4));
+
+                    // Row 2 blocks
+                    const float32x4_t v_f32_block5 = vld1q_f32(src_ptr + k_idx + read_mem_increment);
+                    const float32x4_t v_f32_sblock5 = vmulq_n_f32(v_f32_block5, rep_scale2);
+                    const int32x4_t v_i32_block5 = vcvtnq_s32_f32(v_f32_sblock5);
+
+                    const float32x4_t v_f32_block6 = vld1q_f32(src_ptr + k_idx + 4 + read_mem_increment);
+                    const float32x4_t v_f32_sblock6 = vmulq_n_f32(v_f32_block6, rep_scale2);
+                    const int32x4_t v_i32_block6 = vcvtnq_s32_f32(v_f32_sblock6);
+
+                    const int16x8_t v_full_i16_block3 =
+                        vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block5), vreinterpretq_s16_s32(v_i32_block6));
+
+                    const float32x4_t v_f32_block7 = vld1q_f32(src_ptr + k_idx + 8 + read_mem_increment);
+                    const float32x4_t v_f32_sblock7 = vmulq_n_f32(v_f32_block7, rep_scale2);
+                    const int32x4_t v_i32_block7 = vcvtnq_s32_f32(v_f32_sblock7);
+
+                    const float32x4_t v_f32_block8 = vld1q_f32(src_ptr + k_idx + 12 + read_mem_increment);
+                    const float32x4_t v_f32_sblock8 = vmulq_n_f32(v_f32_block8, rep_scale2);
+                    const int32x4_t v_i32_block8 = vcvtnq_s32_f32(v_f32_sblock8);
+
+                    const int16x8_t v_full_i16_block4 =
+                        vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block7), vreinterpretq_s16_s32(v_i32_block8));
+
+                    // Row 3 blocks
+                    const float32x4_t v_f32_block9 = vld1q_f32(src_ptr + k_idx + 2 * read_mem_increment);
+                    const float32x4_t v_f32_sblock9 = vmulq_n_f32(v_f32_block9, rep_scale3);
+                    const int32x4_t v_i32_block9 = vcvtnq_s32_f32(v_f32_sblock9);
+
+                    const float32x4_t v_f32_blockA = vld1q_f32(src_ptr + k_idx + 4 + 2 * read_mem_increment);
+                    const float32x4_t v_f32_sblockA = vmulq_n_f32(v_f32_blockA, rep_scale3);
+                    const int32x4_t v_i32_blockA = vcvtnq_s32_f32(v_f32_sblockA);
+
+                    const int16x8_t v_full_i16_block5 =
+                        vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block9), vreinterpretq_s16_s32(v_i32_blockA));
+
+                    const float32x4_t v_f32_blockB = vld1q_f32(src_ptr + k_idx + 8 + 2 * read_mem_increment);
+                    const float32x4_t v_f32_sblockB = vmulq_n_f32(v_f32_blockB, rep_scale3);
+                    const int32x4_t v_i32_blockB = vcvtnq_s32_f32(v_f32_sblockB);
+
+                    const float32x4_t v_f32_blockC = vld1q_f32(src_ptr + k_idx + 12 + 2 * read_mem_increment);
+                    const float32x4_t v_f32_sblockC = vmulq_n_f32(v_f32_blockC, rep_scale3);
+                    const int32x4_t v_i32_blockC = vcvtnq_s32_f32(v_f32_sblockC);
+
+                    const int16x8_t v_full_i16_block6 =
+                        vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockB), vreinterpretq_s16_s32(v_i32_blockC));
+
+                    // Row 4 blocks
+                    const float32x4_t v_f32_blockD = vld1q_f32(src_ptr + k_idx + 3 * read_mem_increment);
+                    const float32x4_t v_f32_sblockD = vmulq_n_f32(v_f32_blockD, rep_scale4);
+                    const int32x4_t v_i32_blockD = vcvtnq_s32_f32(v_f32_sblockD);
+
+                    const float32x4_t v_f32_blockE = vld1q_f32(src_ptr + k_idx + 4 + 3 * read_mem_increment);
+                    const float32x4_t v_f32_sblockE = vmulq_n_f32(v_f32_blockE, rep_scale4);
+                    const int32x4_t v_i32_blockE = vcvtnq_s32_f32(v_f32_sblockE);
+
+                    const int16x8_t v_full_i16_block7 =
+                        vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockD), vreinterpretq_s16_s32(v_i32_blockE));
+
+                    const float32x4_t v_f32_blockF = vld1q_f32(src_ptr + k_idx + 8 + 3 * read_mem_increment);
+                    const float32x4_t v_f32_sblockF = vmulq_n_f32(v_f32_blockF, rep_scale4);
+                    const int32x4_t v_i32_blockF = vcvtnq_s32_f32(v_f32_sblockF);
+
+                    const float32x4_t v_f32_block0 = vld1q_f32(src_ptr + k_idx + 12 + 3 * read_mem_increment);
+                    const float32x4_t v_f32_sblock0 = vmulq_n_f32(v_f32_block0, rep_scale4);
+                    const int32x4_t v_i32_block0 = vcvtnq_s32_f32(v_f32_sblock0);
+
+                    const int16x8_t v_full_i16_block8 =
+                        vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockF), vreinterpretq_s16_s32(v_i32_block0));
+
+                    const int8x16_t v_i8_block1_3 =
+                        vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block1), vreinterpretq_s8_s16(v_full_i16_block3));
+                    vst1q_s8(dst_ptr, v_i8_block1_3);
+                    dst_ptr += write_mem_increment;
+
+                    const int8x16_t v_i8_block5_7 =
+                        vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block5), vreinterpretq_s8_s16(v_full_i16_block7));
+                    vst1q_s8(dst_ptr, v_i8_block5_7);
+                    dst_ptr += write_mem_increment;
+
+                    const int8x16_t v_i8_block2_4 =
+                        vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block2), vreinterpretq_s8_s16(v_full_i16_block4));
+                    vst1q_s8(dst_ptr, v_i8_block2_4);
+                    dst_ptr += write_mem_increment;
+
+                    const int8x16_t v_i8_block6_8 =
+                        vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block6), vreinterpretq_s8_s16(v_full_i16_block8));
+                    vst1q_s8(dst_ptr, v_i8_block6_8);
+                    dst_ptr += write_mem_increment;
+                }
+                src_ptr += local_bl;
+            }
+            lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride);
+        }
+    }
+    if (num_rows % 4 != 0) {
+        for (; row_idx < num_rows; ++row_idx) {
+            const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride);
+
+            for (size_t b = 0; b < num_blocks_per_row; ++b) {
+                float abs_max = 0.0F;
+
+                const size_t dst_x = ((row_idx + m_idx_start) % local_mr);
+                int8_t* dst_ptr = (int8_t*)lhs_packed + (b * local_mr) * num_bytes_per_block;
+
+                float32x4_t v_f32_abs_values;
+                float32x4_t v_f32_maxvals;
+                float32x4_t v_currentmax = vdupq_n_f32(0);
+
+                for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) {
+                    v_f32_maxvals = vld1q_f32(src_ptr + idx_v);
+                    v_f32_abs_values = vabsq_f32(v_f32_maxvals);
+                    v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax);
+                }
+                abs_max = vmaxvq_f32(v_currentmax);
+
+                // Calculate scale and reciprocal
+                const float scale = abs_max / ((1 << 7) - 1);
+                const float rep_scale = scale ? 1.0F / scale : 0.0F;
+
+                *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale);
+                dst_ptr += local_mr * kai_num_bytes_multiplier;
+
+                dst_ptr += dst_x * k_block_len * sizeof(int8_t);
+
+                // Quantize and pack the block
+                for (size_t k_idx = 0; k_idx < local_bl; k_idx += k_block_len * 2) {
+                    const float32x4_t v_f32_block1 = vld1q_f32(src_ptr + k_idx);
+                    const float32x4_t v_f32_sblock1 = vmulq_n_f32(v_f32_block1, rep_scale);
+                    const int32x4_t v_i32_block1 = vcvtnq_s32_f32(v_f32_sblock1);
+
+                    const float32x4_t v_f32_block2 = vld1q_f32(src_ptr + k_idx + 4);
+                    const float32x4_t v_f32_sblock2 = vmulq_n_f32(v_f32_block2, rep_scale);
+                    const int32x4_t v_i32_block2 = vcvtnq_s32_f32(v_f32_sblock2);
+
+                    const int16x8_t v_full_i16_block1 =
+                        vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block1), vreinterpretq_s16_s32(v_i32_block2));
+
+                    const float32x4_t v_f32_block3 = vld1q_f32(src_ptr + k_idx + 8);
+                    const float32x4_t v_f32_sblock3 = vmulq_n_f32(v_f32_block3, rep_scale);
+                    const int32x4_t v_i32_block3 = vcvtnq_s32_f32(v_f32_sblock3);
+
+                    const float32x4_t v_f32_block4 = vld1q_f32(src_ptr + k_idx + 12);
+                    const float32x4_t v_f32_sblock4 = vmulq_n_f32(v_f32_block4, rep_scale);
+                    const int32x4_t v_i32_block4 = vcvtnq_s32_f32(v_f32_sblock4);
+
+                    const int16x8_t v_full_i16_block2 =
+                        vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block3), vreinterpretq_s16_s32(v_i32_block4));
+
+                    const int8x16_t v_full_i8_block =
+                        vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block1), vreinterpretq_s8_s16(v_full_i16_block2));
+
+                    vst1_s8(dst_ptr, vget_low_s8(v_full_i8_block));
+                    dst_ptr += 8 * sizeof(int8_t);
+                    dst_ptr += (local_mr - 1) * k_block_len * sizeof(int8_t);
+
+                    vst1_s8(dst_ptr, vget_high_s8(v_full_i8_block));
+                    dst_ptr += 8 * sizeof(int8_t);
+                    dst_ptr += (local_mr - 1) * k_block_len * sizeof(int8_t);
+                }
+                src_ptr += local_bl;
+            }
+            // Move to the next row if we have interleaved all Mr rows
+            if ((((row_idx + 1) + m_idx_start) % local_mr) == 0) {
+                lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride);
+            }
+        }
+    }
+}
diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h
new file mode 100644
index 0000000000000000000000000000000000000000..86873a7effbd51278d99270252e4edd8a397ddec
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h
@@ -0,0 +1,84 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Gets the m step value.
+/// The micro-kernel can process any M values. However, the starting M index to
+/// be processed must be a multiple of m step.
+///
+/// @param[in] mr The number of M rows to interleave on the same output row.
+///
+/// @return the m step value
+size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t mr);
+
+/// Gets the offset in bytes for the LHS matrix (not packed)
+///
+/// This function should be called before passing the pointer to the LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed).
+/// @param[in] lhs_stride The number of bytes in each row of the LHS matrix (not packed)
+///
+/// @return the offset in bytes to the LHS matrix
+size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t m_idx, size_t lhs_stride);
+
+/// Gets the offset in bytes for the packed LHS matrix,
+/// which contains the packed 8-bit quantized symmetric per-block (qsi8d32) values.
+///
+/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed).
+/// @param[in] k Total number of columns in the LHS matrix (not packed).
+/// @param[in] bl The block length.
+/// @param[in] mr The number of M rows to interleave on the same output row.
+/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
+    size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
+
+/// Gets the size in bytes for the quantized and packed LHS matrix
+///
+/// @param[in] m Total number of rows in the LHS matrix (not packed).
+/// @param[in] k Total number of columns in the LHS matrix (not packed).
+/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a
+///               multiple of 32.
+/// @param[in] mr The number of M rows to interleave on the same output row.
+/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+///
+/// @return the packed LHS matrix size in bytes
+size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
+    size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
+
+/// Run the micro-kernel to quantize and pack the LHS matrix.
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] k The number of channels. The common dimension of LHS & RHS. It must be a multiple of 8.
+/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be
+///               a multiple of 32.
+/// @param[in] mr The number of M rows to interleave on the same output row.
+/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+///               However, kr must be a multiple of sr.
+/// @param[in] m_idx_start The starting M index.
+/// @param[in] lhs LHS matrix.
+/// @param[in] lhs_stride Stride in bytes between two rows of LHS.
+/// @param[out] lhs_packed The quantized and packed LHS matrix.
+void kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
+    size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
+    size_t lhs_stride, void* lhs_packed);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp
index ac69e11926b465eaed659be032d7a9ce2d862b1d..bfc2a1c379139731d7030218b54f2891457db93f 100644
--- a/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp
+++ b/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp
@@ -23,6 +23,7 @@
 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"
 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
+#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.h"
 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
@@ -75,7 +76,7 @@ static const std::array<
         clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p_f32,
         rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0),
     UKERNEL_MATMUL_PACK_VARIANT(
-        clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p_f32,
+        clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
         rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0),
     UKERNEL_MATMUL_PACK_VARIANT(
         clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32,
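Reviewer note: a scalar sketch (not part of the patch) of the per-block quantization that the Advanced SIMD paths above vectorize, assuming bl = 32; the helper name `quantize_block_reference` is hypothetical. Each block is scaled so its absolute maximum maps to ±127, values are rounded to nearest, and the scale is stored as f16 alongside the block. Because the quantized values are guaranteed to lie in [-127, 127], the kernel's s32 → s16 → s8 narrowing via `vuzp1q` needs no saturation.

```c
#include <math.h>
#include <stdint.h>

// Reference for one 32-element block: find the absolute maximum, derive the
// scale so that the largest magnitude maps to +/-127, then round-to-nearest.
static void quantize_block_reference(const float src[32], int8_t dst[32], float* scale_out) {
    float abs_max = 0.0F;
    for (int i = 0; i < 32; ++i) {
        const float a = fabsf(src[i]);
        abs_max = a > abs_max ? a : abs_max;
    }

    const float scale = abs_max / 127.0F;  // matches abs_max / ((1 << 7) - 1)
    const float rep_scale = scale ? 1.0F / scale : 0.0F;

    for (int i = 0; i < 32; ++i) {
        // Round to nearest, ties to even (default FP rounding mode),
        // matching vcvtnq_s32_f32; the result is within [-127, 127].
        dst[i] = (int8_t)lrintf(src[i] * rep_scale);
    }
    *scale_out = scale;  // stored as f16 in the packed buffer via kai_cast_f16_f32
}
```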