From 53431745446e83c885eb2e4d04a57421cfed0a57 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Mon, 2 Jun 2025 13:58:37 +0100 Subject: [PATCH 1/8] rename new packing function to kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon Signed-off-by: Evie Wright --- CMakeLists.txt | 1 + kai/kai_common.h | 8 + kai/ukernels/matmul/BUILD.bazel | 1 + .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h | 8 +- ...s_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c | 359 ++++++++++++++++++ ...s_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h | 146 +++++++ ...matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp | 74 ++-- 7 files changed, 563 insertions(+), 34 deletions(-) create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a36b210..8b6c23ad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -163,6 +163,7 @@ set(KLEIDIAI_FILES_NEON kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c ) set(KLEIDIAI_FILES_NEON_DOTPROD_ASM diff --git a/kai/kai_common.h b/kai/kai_common.h index b34d699a..2f4d2920 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -184,6 +184,14 @@ struct kai_rhs_pack_qsi8cx_params { float scale_multiplier; ///< Product of input (refers to lhs and rhs) and output quantization scales. 
}; +/// Parameter struct for RHS matrix packing (Quantized Symmetric Integer 4-bit with per-block quantization and s1s0 +/// nibble ordering) +struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params { + int8_t lhs_zero_point; + uint8_t rhs_zero_point; + enum kai_datatype scale_dt; +}; + /// Parameter struct for RHS matrix packing struct kai_rhs_pack_qs4cxs1s0_param { int8_t lhs_zero_point; ///< LHS Matrix quantization zero-point diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index af899695..2ebb4f68 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -35,6 +35,7 @@ NEON_KERNELS = [ "pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon", "pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon", "pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", + "pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon", "pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon", "pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0", "pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon", diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h index 2822740b..baa051d1 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -13,12 +13,6 @@ extern "C" { #endif -struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params { - int8_t lhs_zero_point; - uint8_t rhs_zero_point; - enum kai_datatype scale_dt; -}; - /// Get the n step value. /// The micro-kernel can process any N values. However, the starting N index to /// be processed must be a multiple of n step. 
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c new file mode 100644 index 00000000..57ff79c8 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c @@ -0,0 +1,359 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_NEON) & !defined(_M_ARM64) +#error This file must be compiled for AArch64. +#else // Architectural features check. + +#include "kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h" + +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_sum_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); +static const size_t kai_nr_multiple_of = 4; +static const size_t kai_bl_multiple_of = 32; + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_get_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) { + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return (bl / 2) + num_bytes_multiplier_rhs; +} + +inline static size_t kai_get_rhs_packed_offset_end_of_all_blocks( + size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) { + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + + return (nr * num_bytes_per_block * num_blocks_per_row); +} + +size_t kai_get_n_step_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(size_t nr) { + return nr; +} + +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + size_t n_idx, // + size_t rhs_stride) { + return 
n_idx * rhs_stride; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt) { + KAI_ASSERT((k % bl) == 0); + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(scale_dt == kai_dt_bf16); + + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt); + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + + return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + size_t n_idx, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt) { + KAI_ASSERT((n_idx % nr) == 0); + KAI_ASSERT((k % bl) == 0); + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(scale_dt == kai_dt_bf16); + + return (n_idx / nr) * + kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt); +} + +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt) { + KAI_ASSERT((k % bl) == 0); + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(scale_dt == kai_dt_bf16); + + const size_t num_rows = kai_roundup(n, nr) / nr; + + return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt); +} + +void kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( 
+ size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + size_t rhs_stride, // + const float* bias, // + const void* scale, // + size_t scale_stride, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + KAI_ASSERT((k % bl) == 0); + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(params->scale_dt == kai_dt_bf16); + + // Note: The input matrix (rhs) is expected with: + // "k" columns and "n" rows (NxK) + const enum kai_datatype scale_dt = params->scale_dt; + const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt); + const size_t rhs_packed_offset_end_of_all_blocks = + kai_get_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs); + const size_t num_qblocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block_k = bl / 2; + const size_t dst_num_rows = kai_roundup(n, nr); + const size_t block_length_in_bytes = kr / sr; + KAI_ASSERT(block_length_in_bytes == 8); + + uint8_t* dst_row = (uint8_t*)rhs_packed; + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; dst_row_idx += nr) { + float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks); + + // Initialize the RHS reduction sums to zero + memset(sums, 0, nr * kai_num_bytes_sum_rhs); + + // Iterate over the quantized blocks + for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) { + // Store the scales after packing all K values in the block + uint8_t* 
rhs_packed_scale = dst_row + num_bytes_per_block_k * nr; + const uint8_t* scale_ptr = (const uint8_t*)scale + dst_qblock_idx * num_bytes_multiplier_rhs; + + for (size_t i = 0; i < nr; ++i) { + const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1); + const void* src_scales_ptr = scale_ptr + src_row_idx * scale_stride; + void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs; + + memcpy( + dst_scales_ptr, // + src_scales_ptr, // + num_bytes_multiplier_rhs); // + } + + size_t k0_idx_i = dst_qblock_idx * bl; + const uint8x8_t top_mask = vdup_n_u8(0xF0); + const uint8x8_t bottom_mask = vdup_n_u8(0x0F); + const uint8x8_t zero_point_conversion_mask = vdup_n_u8(0x88); + + for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) { + for (size_t nr_idx = 0; nr_idx < nr; nr_idx += 4) { + // Clamp the indices to avoid out-of-bound reads + const size_t n0_idx = KAI_MIN(dst_row_idx + nr_idx, n - 1); + const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1); + const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1); + const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1); + + const float d0 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 0]); + const float d1 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 1]); + const float d2 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 2]); + const float d3 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 3]); + + // Take zero-point (-8) into account + int32_t partial_sum0 = -(32 * 8); + int32_t partial_sum1 = -(32 * 8); + int32_t partial_sum2 = -(32 * 8); + int32_t partial_sum3 = -(32 * 8); + + const uint8_t* src_block_base = rhs + ((k0_idx_i / 2) + dst_byte_idx); + + // Load elements as uint64_ts to calculate sums more efficiently + uint64_t ld0_0 = *(const uint64_t*)(src_block_base + n0_idx * rhs_stride); + uint64_t ld0_1 = *(const uint64_t*)(src_block_base + n0_idx * rhs_stride + 8); + + uint64_t ld1_0 = *(const uint64_t*)(src_block_base + n1_idx * rhs_stride); 
+ uint64_t ld1_1 = *(const uint64_t*)(src_block_base + n1_idx * rhs_stride + 8); + + uint64_t ld2_0 = *(const uint64_t*)(src_block_base + n2_idx * rhs_stride); + uint64_t ld2_1 = *(const uint64_t*)(src_block_base + n2_idx * rhs_stride + 8); + + uint64_t ld3_0 = *(const uint64_t*)(src_block_base + n3_idx * rhs_stride); + uint64_t ld3_1 = *(const uint64_t*)(src_block_base + n3_idx * rhs_stride + 8); + + // Copy to vector registers + const uint8x8_t vld0_0 = vcreate_u8(ld0_0); + const uint8x8_t vld0_1 = vcreate_u8(ld0_1); + + const uint8x8_t vld1_0 = vcreate_u8(ld1_0); + const uint8x8_t vld1_1 = vcreate_u8(ld1_1); + + const uint8x8_t vld2_0 = vcreate_u8(ld2_0); + const uint8x8_t vld2_1 = vcreate_u8(ld2_1); + + const uint8x8_t vld3_0 = vcreate_u8(ld3_0); + const uint8x8_t vld3_1 = vcreate_u8(ld3_1); + + // Calculate sums + for (size_t idx = 0; idx < 16; ++idx) { + const int32_t e0_0 = (int32_t)(ld0_0 & 0x0F); + const int32_t e0_1 = (int32_t)(ld0_1 & 0x0F); + partial_sum0 += e0_0 + e0_1; + ld0_0 = ld0_0 >> 4; + ld0_1 = ld0_1 >> 4; + + const int32_t e1_0 = (int32_t)(ld1_0 & 0x0F); + const int32_t e1_1 = (int32_t)(ld1_1 & 0x0F); + partial_sum1 += e1_0 + e1_1; + ld1_0 = ld1_0 >> 4; + ld1_1 = ld1_1 >> 4; + + const int32_t e2_0 = (int32_t)(ld2_0 & 0x0F); + const int32_t e2_1 = (int32_t)(ld2_1 & 0x0F); + partial_sum2 += e2_0 + e2_1; + ld2_0 = ld2_0 >> 4; + ld2_1 = ld2_1 >> 4; + + const int32_t e3_0 = (int32_t)(ld3_0 & 0x0F); + const int32_t e3_1 = (int32_t)(ld3_1 & 0x0F); + partial_sum3 += e3_0 + e3_1; + ld3_0 = ld3_0 >> 4; + ld3_1 = ld3_1 >> 4; + } + + const uint8x8_t vld0_s1s = vand_u8(vld0_0, bottom_mask); + const uint8x8_t vld0_s0s = vshr_n_u8(vld0_0, 4); + const uint8x8_t vld0_s17s = vshl_n_u8(vld0_1, 4); + const uint8x8_t vld0_s16s = vand_u8(vld0_1, top_mask); + + const uint8x8_t vld0_s16s0s_lower = + vorr_u8(vzip1_u8(vld0_s1s, vld0_s0s), vzip1_u8(vld0_s17s, vld0_s16s)); + const uint8x8_t vld0_s16s0s_upper = + vorr_u8(vzip2_u8(vld0_s1s, vld0_s0s), vzip2_u8(vld0_s17s, 
vld0_s16s)); + + const uint8x8_t vld1_s1s = vand_u8(vld1_0, bottom_mask); + const uint8x8_t vld1_s0s = vshr_n_u8(vld1_0, 4); + const uint8x8_t vld1_s17s = vshl_n_u8(vld1_1, 4); + const uint8x8_t vld1_s16s = vand_u8(vld1_1, top_mask); + + const uint8x8_t vld1_s16s0s_lower = + vorr_u8(vzip1_u8(vld1_s1s, vld1_s0s), vzip1_u8(vld1_s17s, vld1_s16s)); + const uint8x8_t vld1_s16s0s_upper = + vorr_u8(vzip2_u8(vld1_s1s, vld1_s0s), vzip2_u8(vld1_s17s, vld1_s16s)); + + const uint8x8_t vld2_s1s = vand_u8(vld2_0, bottom_mask); + const uint8x8_t vld2_s0s = vshr_n_u8(vld2_0, 4); + const uint8x8_t vld2_s17s = vshl_n_u8(vld2_1, 4); + const uint8x8_t vld2_s16s = vand_u8(vld2_1, top_mask); + + const uint8x8_t vld2_s16s0s_lower = + vorr_u8(vzip1_u8(vld2_s1s, vld2_s0s), vzip1_u8(vld2_s17s, vld2_s16s)); + const uint8x8_t vld2_s16s0s_upper = + vorr_u8(vzip2_u8(vld2_s1s, vld2_s0s), vzip2_u8(vld2_s17s, vld2_s16s)); + + const uint8x8_t vld3_s1s = vand_u8(vld3_0, bottom_mask); + const uint8x8_t vld3_s0s = vshr_n_u8(vld3_0, 4); + const uint8x8_t vld3_s17s = vshl_n_u8(vld3_1, 4); + const uint8x8_t vld3_s16s = vand_u8(vld3_1, top_mask); + + const uint8x8_t vld3_s16s0s_lower = + vorr_u8(vzip1_u8(vld3_s1s, vld3_s0s), vzip1_u8(vld3_s17s, vld3_s16s)); + const uint8x8_t vld3_s16s0s_upper = + vorr_u8(vzip2_u8(vld3_s1s, vld3_s0s), vzip2_u8(vld3_s17s, vld3_s16s)); + + // Convert to unsigned int4 and store repacked values + vst1_u8((uint8_t*)dst_row, veor_u8(vld0_s16s0s_lower, zero_point_conversion_mask)); + vst1_u8((uint8_t*)dst_row + 8, veor_u8(vld1_s16s0s_lower, zero_point_conversion_mask)); + vst1_u8((uint8_t*)dst_row + 16, veor_u8(vld2_s16s0s_lower, zero_point_conversion_mask)); + vst1_u8((uint8_t*)dst_row + 24, veor_u8(vld3_s16s0s_lower, zero_point_conversion_mask)); + + vst1_u8( + (uint8_t*)dst_row + (nr * block_length_in_bytes), + veor_u8(vld0_s16s0s_upper, zero_point_conversion_mask)); + vst1_u8( + (uint8_t*)dst_row + (nr * block_length_in_bytes) + 8, + veor_u8(vld1_s16s0s_upper, 
zero_point_conversion_mask)); + vst1_u8( + (uint8_t*)dst_row + (nr * block_length_in_bytes) + 16, + veor_u8(vld2_s16s0s_upper, zero_point_conversion_mask)); + vst1_u8( + (uint8_t*)dst_row + (nr * block_length_in_bytes) + 24, + veor_u8(vld3_s16s0s_upper, zero_point_conversion_mask)); + + // Add to row sums + // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + sums[nr_idx + 0] += (float)partial_sum0 * d0; + sums[nr_idx + 1] += (float)partial_sum1 * d1; + sums[nr_idx + 2] += (float)partial_sum2 * d2; + sums[nr_idx + 3] += (float)partial_sum3 * d3; + // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + + dst_row += block_length_in_bytes * 4; + } + // Skip to end of qblock + dst_row += nr * block_length_in_bytes; + } + + // Move the pointer after scales + dst_row += num_bytes_multiplier_rhs * nr; + } + + // Move the pointer after the row sum + dst_row += kai_num_bytes_sum_rhs * nr; + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * kai_num_bytes_bias); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + + // Move the pointer after the bias + dst_row += kai_num_bytes_bias * nr; + } +} +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h new file mode 100644 index 00000000..4f331cfa --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h @@ -0,0 +1,146 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Get the n step value. 
+/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// +/// @return the n step value +size_t kai_get_n_step_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(size_t nr); + +/// Gets the offset in bytes for the RHS matrix (not packed), which holds +/// the int4 values in a N x K matrix, where N is number of rows and K is the number of columns. +/// +/// Two int4 values are stored in one byte. +/// The lower order part of the byte (low) holds the first nibble (K-index + 0). +/// The higher order of the byte holds the second nibble (K-index + 1). +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] rhs_stride The number of bytes in each row of the RHS matrix (not packed) +/// +/// @return the offset in bytes to the RHS matrix (not packed) +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + size_t n_idx, // + size_t rhs_stride); // + +/// Get the row stride in bytes to the packed RHS matrix +/// +/// @param[in] k The number of columns in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a +/// multiple of 32. +/// @param[in] scale_dt Block scale data type +/// +/// @return the stride in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt); // + +/// Gets the offset in bytes for the packed RHS matrix. 
+/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] k The number of columns in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a +/// multiple of 32. +/// @param[in] scale_dt Block scale data type +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + size_t n_idx, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt); // + +/// Gets the size in bytes for the quantized and packed RHS matrix. +/// +/// @param[in] n The number of rows in the RHS matrix (not packed) +/// @param[in] k The number of columns in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a multiple +/// of 32. +/// @param[in] scale_dt Block scale data type +/// +/// @return the packed RHS matrix size in bytes +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt); // + +/// Runs the RHS packing micro-kernel. +/// +/// The int4 values are stored in a N x K matrix, where N is number of rows and K is the number of columns. +/// Two int4 values are stored in one byte. 
The lower order part of the byte (low) holds +/// the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). +/// +/// @param[in] num_groups The number of groups. It must be 1. +/// @param[in] n The number of rows in the RHS matrix (not packed). +/// @param[in] k The number of columns in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel. It must be a multiple of 4. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// However, kr must be multiple of sr. +/// @param[in] bl The block length, which defines the number of +/// K values stored in a single block. It must be a multiple of 32. +/// @param[in] rhs The RHS matrix containing the 4-bit values. +/// Size in bytes is expected to be greater than or equal to n * (k / 2) * sizeof(uint8_t). +/// @param[in] rhs_stride The number of bytes per row of the RHS matrix +/// @param[in] bias The biases. +/// @param[in] scale The per-block quantization scales. +/// The scale data type must be provided with the params object. +/// The only supported scale data type is BF16. +/// @param[in] scale_stride The number of bytes per row of the scale matrix +/// @param[out] rhs_packed The packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. +/// @param[in] params Parameters for the micro-kernel. 
+void kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + size_t rhs_stride, // + const float* bias, // + const void* scale, // + size_t scale_stride, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params); // + +#ifdef __cplusplus +} +#endif diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp index b191fd4e..bd7d45c9 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp @@ -34,8 +34,10 @@ #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h" #include "test/common/bfloat16.hpp" #include "test/common/buffer.hpp" +#include "test/common/compare.hpp" #include "test/common/cpu_info.hpp" #include "test/common/int4.hpp" #include "test/common/matmul_test_common.hpp" @@ -261,6 +263,39 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_nxk) { // Runs the GEMM micro-kernel. 
Buffer imp_dst(imp_dst_size); + if (kr / sr == 8) { + // Test that vectorized packing kernel gives same output as scalar + const auto imp_packed_rhs_size_neon = + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt); + ASSERT_EQ(imp_packed_rhs_size_neon, imp_packed_rhs_size); + + Buffer imp_packed_rhs_neon(imp_packed_rhs_size_neon); + + auto rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + rhs_start_row, K, nr, kr, sr, bl, scale_dt); + ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset); + + auto rhs_offset_neon = + kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(rhs_start_row, ref_rhs_qsu4_stride); + + kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + 1, rect.width() /* n */, K, nr, kr, sr, bl, + reinterpret_cast(ref_rhs_qsu4_padded.data() + rhs_offset_neon), ref_rhs_qsu4_stride, + reinterpret_cast(ref_biases.data() + bias_offset), + reinterpret_cast(ref_rhs_scales.data() + scale_offset), ref_rhs_scales_stride, + imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, ¶ms); + + ukernel_variant.interface.run_matmul( + rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset, + imp_packed_rhs_neon.data() + rhs_matmul_offset, reinterpret_cast(imp_dst.data() + dst_offset), + N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max()); + + DefaultMismatchHandler handler(0, 0.1, 0, 0.05); + DataFormat dst_format = DataFormat(DataType::FP32); + const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); + ASSERT_TRUE(success); + } + ukernel_variant.interface.run_matmul( rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset, imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast(imp_dst.data() + dst_offset), @@ -268,19 +303,12 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_nxk) { // Compares the output of the 
micro-kernels against the output of the reference implementation for the portion // tested. - for (size_t y = 0; y < rect.height(); ++y) { - for (size_t x = 0; x < rect.width(); ++x) { - const auto imp_value = - read_array(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); - const auto ref_value = - read_array(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); - const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value; - - if (rel_error > 0.0001F) { - ASSERT_EQ(imp_value, ref_value); - } - } - } + // Compares the output of the micro-kernels against the output of the reference implementation for the portion + // tested. + DefaultMismatchHandler handler(0, 0.1, 0, 0.05); + DataFormat dst_format = DataFormat(DataType::FP32); + const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); + ASSERT_TRUE(success); } TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_kxn) { @@ -410,20 +438,12 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_kxn) { imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast(imp_dst.data() + dst_offset), N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max()); - // Compares the output of the micro-kernels against the output of the reference implementation. - for (size_t y = 0; y < rect.height(); ++y) { - for (size_t x = 0; x < rect.width(); ++x) { - const auto imp_value = - read_array(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); - const auto ref_value = - read_array(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); - const auto rel_error = ref_value != 0 ? 
std::abs((imp_value - ref_value) / ref_value) : imp_value; - - if (rel_error > 0.0001F) { - ASSERT_EQ(imp_value, ref_value); - } - } - } + // Compares the output of the micro-kernels against the output of the reference implementation for the portion + // tested. + DefaultMismatchHandler handler(0, 0.1, 0, 0.05); + DataFormat dst_format = DataFormat(DataType::FP32); + const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); + ASSERT_TRUE(success); } INSTANTIATE_TEST_SUITE_P( -- GitLab From 02446e651211efdabe58e320af901a0812d31b28 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Mon, 2 Jun 2025 17:48:32 +0100 Subject: [PATCH 2/8] alter cmakelists to enable msvc support Signed-off-by: Evie Wright --- CMakeLists.txt | 9 ++++++++- .../kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8b6c23ad..68c707dc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -152,7 +152,12 @@ set(KLEIDIAI_FILES_NEON_FP16_BF16 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.c ) +set(KLEIDIAI_FILES_NEON_ASM + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c +) + set(KLEIDIAI_FILES_NEON + ${KLEIDIAI_FILES_NEON_ASM} kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla_asm.S kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c @@ -163,7 +168,6 @@ set(KLEIDIAI_FILES_NEON kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c - kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c ) set(KLEIDIAI_FILES_NEON_DOTPROD_ASM @@ -315,6 +319,7 @@ if(NOT MSVC) 
set_source_files_properties(${KLEIDIAI_FILES_SME} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}") set_source_files_properties(${KLEIDIAI_FILES_SME2} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}") else() + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME_ASM}) @@ -323,10 +328,12 @@ else() set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_NEON_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set(KLEIDIAI_FILES_ASM ${KLEIDIAI_FILES_SME_ASM} ${KLEIDIAI_FILES_SME2_ASM} + ${KLEIDIAI_FILES_NEON_ASM} ${KLEIDIAI_FILES_NEON_DOTPROD_ASM} ${KLEIDIAI_FILES_NEON_I8MM_ASM}) list(FILTER KLEIDIAI_FILES_ASM INCLUDE REGEX "^.*\.S$") diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c index 57ff79c8..40a21008 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c @@ -9,6 +9,7 @@ #include "kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h" +#include #include #include #include -- GitLab From 8ac2526139200d3c0ec49f82a8f245c4c23afc3e Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Tue, 3 Jun 2025 09:21:28 +0100 Subject: [PATCH 3/8] add scalar 
files to matching bazel category Signed-off-by: Evie Wright --- kai/ukernels/matmul/BUILD.bazel | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 2ebb4f68..c3cc2192 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -22,7 +22,9 @@ package(default_visibility = ["//visibility:private"]) SCALAR_KERNELS = [ "pack/kai_lhs_quant_pack_qai8dxp_f32", "pack/kai_lhs_quant_pack_qsi8d32p_f32", + "pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", "pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0", + "pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", "pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0", ] @@ -31,10 +33,8 @@ NEON_KERNELS = [ "pack/kai_lhs_quant_pack_qsi8d32p_f32_neon", "pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon", "pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon", - "pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", "pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon", "pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon", - "pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", "pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon", "pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon", "pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0", -- GitLab From c4e69c462543a893e5db9ba26062f08d44db9893 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Tue, 3 Jun 2025 17:08:47 +0100 Subject: [PATCH 4/8] local rebase to avoid changelog merge conflict Signed-off-by: Evie Wright --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b682edd..38a17421 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - New Advanced SIMD micro-kernels: - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_DotProd. 
- Added Convolution example using SME Indirect Matmul Kernels +- New Advanced SIMD micro-kernels: + - Vectorized implementation of non-transposed per-block packing function `kai_rhs_pack_nxk_qsi4c32p_qsu4cxs1s0` for block depth of 8 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4cxs1s0_neon`) ## v1.9.0 -- GitLab From ba33a928394068a68e18c4714f88d06d0c32306e Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Wed, 4 Jun 2025 15:44:49 +0100 Subject: [PATCH 5/8] update README to expand on naming convention, standardize compilation guards Signed-off-by: Evie Wright --- kai/ukernels/matmul/pack/README.md | 10 +++++++++- .../kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/kai/ukernels/matmul/pack/README.md b/kai/ukernels/matmul/pack/README.md index 950a69ac..90b14eff 100644 --- a/kai/ukernels/matmul/pack/README.md +++ b/kai/ukernels/matmul/pack/README.md @@ -1,5 +1,5 @@ @@ -110,3 +110,11 @@ RHS packed matrix (N x K) contains quantized (q) symmetric (s) 4-bit signed int #### kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0() Same as kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0() with the input RHS matrix dimensions as K x N. + +### Vectorized packing routines with predefined block depth + +Alternative versions of certain packing functions are provided using Advanced SIMD, specialized for a predefined block depth (equal to kr / sr). + +#### kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon() + +This takes the same input and provides the same output as kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(), with faster execution time where Advanced SIMD instructions are supported. The nrx8 included within the name indicates that this routine works only where kr / sr = 8, and for any value of nr that fits within the wider constraints. 
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c index 40a21008..5a6985fd 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__ARM_NEON) & !defined(_M_ARM64) +#if !defined(__aarch64__) && !defined(_M_ARM64) #error This file must be compiled for AArch64. #else // Architectural features check. -- GitLab From 9ba9c29ea4b26886806f8576119c866e92e8ea5f Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Thu, 5 Jun 2025 10:51:21 +0100 Subject: [PATCH 6/8] add comment in packing file providing clarity on naming convention Signed-off-by: Evie Wright --- .../pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c index 5a6985fd..3cb7f111 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c @@ -3,6 +3,10 @@ // // SPDX-License-Identifier: Apache-2.0 // + +// nrx8 => this function can take in generic nr values but the input is expected to have a block depth of 8. +// Block depth is calculated as kr / sr. The values of these parameters are defined in the matmul ukernel. + #if !defined(__aarch64__) && !defined(_M_ARM64) #error This file must be compiled for AArch64. #else // Architectural features check. 
-- GitLab From 1f1de70ab4f36a293db9618ced22d2a158647d19 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 5 Jun 2025 12:53:44 +0100 Subject: [PATCH 7/8] Fix Changelog Signed-off-by: Anitha Raj --- CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 38a17421..cacb5cd5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,9 +22,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme - New Advanced SIMD micro-kernels: - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_DotProd. + - Optimized RHS packing function for block depth of 8 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon`) - Added Convolution example using SME Indirect Matmul Kernels -- New Advanced SIMD micro-kernels: - - Vectorized implementation of non-transposed per-block packing function `kai_rhs_pack_nxk_qsi4c32p_qsu4cxs1s0` for block depth of 8 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4cxs1s0_neon`) ## v1.9.0 -- GitLab From 5b4cdcaccbb234a23233620a6cf71173d11f4e68 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 5 Jun 2025 16:57:47 +0100 Subject: [PATCH 8/8] Update Changelog to address review comment Signed-off-by: Anitha Raj --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cacb5cd5..10a4d450 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme - New Advanced SIMD micro-kernels: - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_DotProd. 
- - Optimized RHS packing function for block depth of 8 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon`) + - Optimized version of kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 kernel for block depth of 8 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon`) - Added Convolution example using SME Indirect Matmul Kernels ## v1.9.0 -- GitLab