diff --git a/CHANGELOG.md b/CHANGELOG.md
index ea7379f8b4712bcb4c3b38df58390e6c09ea2750..d03bb04187abd1054acd22332fb82f27d5c0e5b0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,9 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo
 
 ## Upcoming Release
 
+- New Advanced SIMD micro-kernels:
+  - Optimized version of kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 kernel for block depth of 4 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon`)
+
 ## v1.10.0
 
 - Convert SME and SME2 imatmul micro-kernels to use pure assembly, and add MSVC support. Affects:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a7d6185cf9b038648da0897e041b72b53cc8bda1..1673084c61bfddbf2dc6f480553db9f832881f16 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -158,6 +158,7 @@ set(KLEIDIAI_FILES_NEON_ASM
     kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c
     kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c
     kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c
+    kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.c
 )
 
 set(KLEIDIAI_FILES_NEON
diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel
index 038a4c0189cc8e9dc4920d10c9de9de3ad42401c..ae9628a3650894573f7b269283cb6aaf25fa3d73 100644
--- a/kai/ukernels/matmul/BUILD.bazel
+++ b/kai/ukernels/matmul/BUILD.bazel
@@ -35,6 +35,7 @@ NEON_KERNELS = [
    "pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon",
    "pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon",
    "pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon",
+   "pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon",
    "pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon",
    "pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon",
    "pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0",
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.c
new file mode 100644
index 0000000000000000000000000000000000000000..4b6403de28f3483ca4b09f8ba773eb908c632cee
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.c
@@ -0,0 +1,353 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+// nrx4 => this function can take in generic nr values but the input is expected to have a block depth of 4.
+// Block depth is calculated as kr / sr. The values of these parameters are defined in the matmul ukernel.
+
+#if !defined(__aarch64__) && !defined(_M_ARM64)
+#error This file must be compiled for AArch64.
+#else  // Architectural features check.
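+
+// Packed layout per group of nr source rows, as implied by the size/offset helpers
+// below: for each quantization block, (bl / 2) bytes of interleaved int4 data for all
+// nr rows followed by nr BF16 block scales; after all blocks, nr FP32 row sums and
+// then nr FP32 biases.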
+
+#include "kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h"
+
+#include <arm_neon.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_num_bytes_sum_rhs = sizeof(float);
+static const size_t kai_num_bytes_bias = sizeof(float);
+static const size_t kai_nr_multiple_of = 4;
+static const size_t kai_bl_multiple_of = 32;
+
+inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    return kai_roundup(k, bl) / bl;
+}
+
+inline static size_t kai_get_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) {
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    return (bl / 2) + num_bytes_multiplier_rhs;
+}
+
+inline static size_t kai_get_rhs_packed_offset_end_of_all_blocks(
+    size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) {
+    KAI_ASSERT((bl % kr) == 0);
+    KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+
+    const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
+
+    return (nr * num_bytes_per_block * num_blocks_per_row);
+}
+
+size_t kai_get_n_step_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(size_t nr) {
+    return nr;
+}
+
+size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+    size_t n_idx,  //
+    size_t rhs_stride) {
+    return n_idx * rhs_stride;
+}
+
+size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+    size_t k,   //
+    size_t nr,  //
+    size_t kr,  //
+    size_t sr,  //
+    size_t bl,  //
+    enum kai_datatype scale_dt) {
+    KAI_ASSERT((k % bl) == 0);
+    KAI_ASSERT((bl % kr) == 0);
+    KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    KAI_ASSERT(scale_dt == kai_dt_bf16);
+
+    KAI_UNUSED(kr);
+    KAI_UNUSED(sr);
+
+    const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
+    const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
+
+    return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
+}
+
+size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+    size_t n_idx,  //
+    size_t k,      //
+    size_t nr,     //
+    size_t kr,     //
+    size_t sr,     //
+    size_t bl,     //
+    enum kai_datatype scale_dt) {
+    KAI_ASSERT((n_idx % nr) == 0);
+    KAI_ASSERT((k % bl) == 0);
+    KAI_ASSERT((bl % kr) == 0);
+    KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    KAI_ASSERT(scale_dt == kai_dt_bf16);
+
+    return (n_idx / nr) *
+        kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt);
+}
+
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+    size_t n,   //
+    size_t k,   //
+    size_t nr,  //
+    size_t kr,  //
+    size_t sr,  //
+    size_t bl,  //
+    enum kai_datatype scale_dt) {
+    KAI_ASSERT((k % bl) == 0);
+    KAI_ASSERT((bl % kr) == 0);
+    KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    KAI_ASSERT(scale_dt == kai_dt_bf16);
+
+    const size_t num_rows = kai_roundup(n, nr) / nr;
+
+    return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt);
+}
+
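+// Worked example of the stride math above (illustrative values, not a requirement):
+// with k = 256, bl = 32, nr = 4 and BF16 scales, there are 256 / 32 = 8 blocks per
+// row, each taking (32 / 2) + 2 = 18 bytes, so the packed row stride is
+// 4 * ((18 * 8) + 4 + 4) = 608 bytes for every group of 4 rows.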
+void kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+    size_t num_groups,   //
+    size_t n,            //
+    size_t k,            //
+    size_t nr,           //
+    size_t kr,           //
+    size_t sr,           //
+    size_t bl,           //
+    const uint8_t* rhs,  //
+    size_t rhs_stride,   //
+    const float* bias,   //
+    const void* scale,   //
+    size_t scale_stride,  //
+    void* rhs_packed,     //
+    size_t extra_bytes,   //
+    const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params) {
+    KAI_ASSERT(num_groups == 1);
+    KAI_ASSERT(extra_bytes == 0);
+    KAI_ASSERT(rhs != NULL);
+    KAI_ASSERT(scale != NULL);
+    KAI_ASSERT(rhs_packed != NULL);
+    KAI_ASSERT(params != NULL);
+    KAI_ASSERT(params->rhs_zero_point == 8);
+    KAI_ASSERT(params->lhs_zero_point == 1);
+
+    KAI_ASSERT((k % bl) == 0);
+    KAI_ASSERT((bl % kr) == 0);
+    KAI_ASSERT((kr % sr) == 0);
+    KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    KAI_ASSERT(params->scale_dt == kai_dt_bf16);
+
+    // Note: The input matrix (rhs) is expected with:
+    // "k" columns and "n" rows (NxK)
+    const enum kai_datatype scale_dt = params->scale_dt;
+    const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
+    const size_t rhs_packed_offset_end_of_all_blocks =
+        kai_get_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs);
+    const size_t num_qblocks_per_row = kai_get_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block_k = bl / 2;
+    const size_t dst_num_rows = kai_roundup(n, nr);
+    const size_t block_length_in_bytes = kr / sr;
+    KAI_ASSERT(block_length_in_bytes == 4);
+
+    uint8_t* dst_row = (uint8_t*)rhs_packed;
+
+    for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; dst_row_idx += nr) {
+        float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks);
+
+        // Initialize the RHS reduction sums to zero
+        memset(sums, 0, nr * kai_num_bytes_sum_rhs);
+
+        // Iterate over the quantized blocks
+        for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) {
+            // Store the scales after packing all K values in the block
+            uint8_t* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr;
+            const uint8_t* scale_ptr = (const uint8_t*)scale + dst_qblock_idx * num_bytes_multiplier_rhs;
+
+            for (size_t i = 0; i < nr; ++i) {
+                const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
+                const void* src_scales_ptr = scale_ptr + src_row_idx * scale_stride;
+                void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs;
+
+                memcpy(
+                    dst_scales_ptr,             //
+                    src_scales_ptr,             //
+                    num_bytes_multiplier_rhs);  //
+            }
+
+            size_t k0_idx_i = dst_qblock_idx * bl;
+            const uint8x8_t top_mask = vdup_n_u8(0xF0);
+            const uint8x8_t bottom_mask = vdup_n_u8(0x0F);
+            const uint32x2_t zero_point_conversion_mask = vdup_n_u32(0x88888888);
+
+            for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) {
+                for (size_t nr_idx = 0; nr_idx < nr; nr_idx += 4) {
+                    // Clamp the indices to avoid out-of-bound reads
+                    const size_t n0_idx = KAI_MIN(dst_row_idx + nr_idx, n - 1);
+                    const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1);
+                    const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1);
+                    const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1);
+
+                    const float d0 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 0]);
+                    const float d1 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 1]);
+                    const float d2 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 2]);
+                    const float d3 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 3]);
+
+                    // Initialize partial sum taking new zero-point (8) into account
+                    int32_t partial_sum0 = -(32 * 8);
+                    int32_t partial_sum1 = -(32 * 8);
+                    int32_t partial_sum2 = -(32 * 8);
+                    int32_t partial_sum3 = -(32 * 8);
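+                    // Each pass of this loop consumes 16 bytes, i.e. 32 int4 values per
+                    // row, so pre-subtracting 32 * 8 turns the unsigned nibble sum
+                    // accumulated below into the signed sum: sum(v - 8) = sum(v) - 32 * 8.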
+
+                    const uint8_t* src_block_base = rhs + ((k0_idx_i / 2) + dst_byte_idx);
+
+                    const uint8x8_t vld0_0 = vld1_u8(src_block_base + n0_idx * rhs_stride);
+                    const uint8x8_t vld0_1 = vld1_u8(src_block_base + n0_idx * rhs_stride + 8);
+                    const uint8x8_t vld1_0 = vld1_u8(src_block_base + n1_idx * rhs_stride);
+                    const uint8x8_t vld1_1 = vld1_u8(src_block_base + n1_idx * rhs_stride + 8);
+                    const uint8x8_t vld2_0 = vld1_u8(src_block_base + n2_idx * rhs_stride);
+                    const uint8x8_t vld2_1 = vld1_u8(src_block_base + n2_idx * rhs_stride + 8);
+                    const uint8x8_t vld3_0 = vld1_u8(src_block_base + n3_idx * rhs_stride);
+                    const uint8x8_t vld3_1 = vld1_u8(src_block_base + n3_idx * rhs_stride + 8);
+
+                    // Reorder blocks to give correct packing
+                    const uint8x8_t vld0_0_lower = vand_u8(vld0_0, bottom_mask);
+                    const uint8x8_t vld0_1_lower = vshl_n_u8(vld0_1, 4);
+                    const uint8x8_t vld0_0_upper = vshr_n_u8(vld0_0, 4);
+                    const uint8x8_t vld0_1_upper = vand_u8(vld0_1, top_mask);
+                    const uint8x8_t vstr0_04 =
+                        vorr_u8(vzip1_u8(vld0_0_lower, vld0_0_upper), vzip1_u8(vld0_1_lower, vld0_1_upper));
+                    const uint8x8_t vstr0_46 =
+                        vorr_u8(vzip2_u8(vld0_0_lower, vld0_0_upper), vzip2_u8(vld0_1_lower, vld0_1_upper));
+
+                    const uint8x8_t vld1_0_lower = vand_u8(vld1_0, bottom_mask);
+                    const uint8x8_t vld1_1_lower = vshl_n_u8(vld1_1, 4);
+                    const uint8x8_t vld1_0_upper = vshr_n_u8(vld1_0, 4);
+                    const uint8x8_t vld1_1_upper = vand_u8(vld1_1, top_mask);
+                    const uint8x8_t vstr0_04_1 =
+                        vorr_u8(vzip1_u8(vld1_0_lower, vld1_0_upper), vzip1_u8(vld1_1_lower, vld1_1_upper));
+                    const uint8x8_t vstr0_46_1 =
+                        vorr_u8(vzip2_u8(vld1_0_lower, vld1_0_upper), vzip2_u8(vld1_1_lower, vld1_1_upper));
+
+                    const uint8x8_t vld2_0_lower = vand_u8(vld2_0, bottom_mask);
+                    const uint8x8_t vld2_1_lower = vshl_n_u8(vld2_1, 4);
+                    const uint8x8_t vld2_0_upper = vshr_n_u8(vld2_0, 4);
+                    const uint8x8_t vld2_1_upper = vand_u8(vld2_1, top_mask);
+                    const uint8x8_t vstr0_15 =
+                        vorr_u8(vzip1_u8(vld2_0_lower, vld2_0_upper), vzip1_u8(vld2_1_lower, vld2_1_upper));
+                    const uint8x8_t vstr0_57 =
+                        vorr_u8(vzip2_u8(vld2_0_lower, vld2_0_upper), vzip2_u8(vld2_1_lower, vld2_1_upper));
+
+                    const uint8x8_t vld3_0_lower = vand_u8(vld3_0, bottom_mask);
+                    const uint8x8_t vld3_1_lower = vshl_n_u8(vld3_1, 4);
+                    const uint8x8_t vld3_0_upper = vshr_n_u8(vld3_0, 4);
+                    const uint8x8_t vld3_1_upper = vand_u8(vld3_1, top_mask);
+                    const uint8x8_t vstr0_15_1 =
+                        vorr_u8(vzip1_u8(vld3_0_lower, vld3_0_upper), vzip1_u8(vld3_1_lower, vld3_1_upper));
+                    const uint8x8_t vstr0_57_1 =
+                        vorr_u8(vzip2_u8(vld3_0_lower, vld3_0_upper), vzip2_u8(vld3_1_lower, vld3_1_upper));
+
+                    const uint32x2_t vstr0_0 =
+                        vzip1_u32(vreinterpret_u32_u8(vstr0_04), vreinterpret_u32_u8(vstr0_04_1));
+                    const uint32x2_t vstr0_4 =
+                        vzip1_u32(vreinterpret_u32_u8(vstr0_46), vreinterpret_u32_u8(vstr0_46_1));
+                    const uint32x2_t vstr0_2 =
+                        vzip2_u32(vreinterpret_u32_u8(vstr0_04), vreinterpret_u32_u8(vstr0_04_1));
+                    const uint32x2_t vstr0_6 =
+                        vzip2_u32(vreinterpret_u32_u8(vstr0_46), vreinterpret_u32_u8(vstr0_46_1));
+                    const uint32x2_t vstr0_1 =
+                        vzip1_u32(vreinterpret_u32_u8(vstr0_15), vreinterpret_u32_u8(vstr0_15_1));
+                    const uint32x2_t vstr0_5 =
+                        vzip1_u32(vreinterpret_u32_u8(vstr0_57), vreinterpret_u32_u8(vstr0_57_1));
+                    const uint32x2_t vstr0_3 =
+                        vzip2_u32(vreinterpret_u32_u8(vstr0_15), vreinterpret_u32_u8(vstr0_15_1));
+                    const uint32x2_t vstr0_7 =
+                        vzip2_u32(vreinterpret_u32_u8(vstr0_57), vreinterpret_u32_u8(vstr0_57_1));
+
+                    // Convert to signed int4 and store repacked values
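+                    // (XOR with 0x88888888 flips the top bit of every nibble, which maps
+                    // the unsigned encoding with zero point 8 to signed two's-complement
+                    // int4, i.e. v - 8, without unpacking the nibbles.)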
+                    vst1_u32((uint32_t*)dst_row + 0, veor_u32(vstr0_0, zero_point_conversion_mask));
+                    vst1_u32((uint32_t*)dst_row + 2, veor_u32(vstr0_1, zero_point_conversion_mask));
+
+                    vst1_u32(
+                        (uint32_t*)(dst_row + nr * block_length_in_bytes) + 0,
+                        veor_u32(vstr0_2, zero_point_conversion_mask));
+                    vst1_u32(
+                        (uint32_t*)(dst_row + nr * block_length_in_bytes) + 2,
+                        veor_u32(vstr0_3, zero_point_conversion_mask));
+
+                    vst1_u32(
+                        (uint32_t*)(dst_row + (2 * nr * block_length_in_bytes)) + 0,
+                        veor_u32(vstr0_4, zero_point_conversion_mask));
+                    vst1_u32(
+                        (uint32_t*)(dst_row + (2 * nr * block_length_in_bytes)) + 2,
+                        veor_u32(vstr0_5, zero_point_conversion_mask));
+
+                    vst1_u32(
+                        (uint32_t*)(dst_row + (3 * nr * block_length_in_bytes)) + 0,
+                        veor_u32(vstr0_6, zero_point_conversion_mask));
+                    vst1_u32(
+                        (uint32_t*)(dst_row + (3 * nr * block_length_in_bytes)) + 2,
+                        veor_u32(vstr0_7, zero_point_conversion_mask));
+
+                    // Calculate and store row sums
+                    partial_sum0 += (int32_t)vaddlvq_u16(vaddl_u8(
+                        vadd_u8(vld0_0_lower, vand_u8(vld0_1, bottom_mask)),
+                        vadd_u8(vld0_0_upper, vshr_n_u8(vld0_1, 4))));
+                    partial_sum1 += (int32_t)vaddlvq_u16(vaddl_u8(
+                        vadd_u8(vld1_0_lower, vand_u8(vld1_1, bottom_mask)),
+                        vadd_u8(vld1_0_upper, vshr_n_u8(vld1_1, 4))));
+                    partial_sum2 += (int32_t)vaddlvq_u16(vaddl_u8(
+                        vadd_u8(vld2_0_lower, vand_u8(vld2_1, bottom_mask)),
+                        vadd_u8(vld2_0_upper, vshr_n_u8(vld2_1, 4))));
+                    partial_sum3 += (int32_t)vaddlvq_u16(vaddl_u8(
+                        vadd_u8(vld3_0_lower, vand_u8(vld3_1, bottom_mask)),
+                        vadd_u8(vld3_0_upper, vshr_n_u8(vld3_1, 4))));
+
+                    // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
+                    sums[nr_idx + 0] += (float)partial_sum0 * d0;
+                    sums[nr_idx + 1] += (float)partial_sum1 * d1;
+                    sums[nr_idx + 2] += (float)partial_sum2 * d2;
+                    sums[nr_idx + 3] += (float)partial_sum3 * d3;
+                    // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
+
+                    dst_row += (4 * block_length_in_bytes);
+                }
+                // Skip to end of qblock
+                dst_row += 3 * nr * block_length_in_bytes;
+            }
+
+            // Move the pointer after scales
+            dst_row += num_bytes_multiplier_rhs * nr;
+        }
+
+        // Move the pointer after the row sum
+        dst_row += kai_num_bytes_sum_rhs * nr;
+
+        // Set the bias
+        if (bias == NULL) {
+            memset(dst_row, 0, nr * kai_num_bytes_bias);
+        } else {
+            for (size_t i = 0; i < nr; ++i) {
+                // Clamp the row index to avoid out-of-bound reads
+                const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
+                ((float*)dst_row)[i] = bias[src_row_idx];
+            }
+        }
+
+        // Move the pointer after the bias
+        dst_row += kai_num_bytes_bias * nr;
+    }
+}
+#endif  // Architectural features check.
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h
new file mode 100644
index 0000000000000000000000000000000000000000..4f7ba1f53c99a65fa451f3e9fc2f507c8dee8f1d
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h
@@ -0,0 +1,146 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+
+#include "kai/kai_common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
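+// Typical call sequence (sketch only; buffer variables here are illustrative):
+//
+//   struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params = {
+//       .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_dt_bf16};
+//   const size_t packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+//       n, k, nr, kr, sr, bl, kai_dt_bf16);
+//   // ... allocate packed_size bytes for rhs_packed, then:
+//   kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+//       1, n, k, nr, kr, sr, bl, rhs, rhs_stride, bias, scales, scale_stride, rhs_packed, 0, &params);
+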
+/// Get the n step value.
+/// The micro-kernel can process any N values. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @param[in] nr The number of columns written by the matmul micro-kernel
+///
+/// @return the n step value
+size_t kai_get_n_step_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(size_t nr);
+
+/// Gets the offset in bytes for the RHS matrix (not packed), which holds
+/// the int4 values in a N x K matrix, where N is the number of rows and K is the number of columns.
+///
+/// Two int4 values are stored in one byte.
+/// The lower order part of the byte (low) holds the first nibble (K-index + 0).
+/// The higher order of the byte holds the second nibble (K-index + 1).
+///
+/// @param[in] n_idx      Row index in the RHS matrix (not packed).
+/// @param[in] rhs_stride The number of bytes in each row of the RHS matrix (not packed)
+///
+/// @return the offset in bytes to the RHS matrix (not packed)
+size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+    size_t n_idx,        //
+    size_t rhs_stride);  //
+
+/// Get the row stride in bytes to the packed RHS matrix
+///
+/// @param[in] k        The number of columns in the RHS matrix (not packed).
+/// @param[in] nr       The number of columns written by the matmul micro-kernel.
+/// @param[in] kr       The number of columns loaded in the single innermost loop of the matmul micro-kernel.
+/// @param[in] sr       The number of kr splits. It can be 1 (no splits) up to kr.
+/// @param[in] bl       The block length, which defines the number of K values stored in a single block. It must be a
+///                     multiple of 32.
+/// @param[in] scale_dt Block scale data type
+///
+/// @return the stride in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+    size_t k,   //
+    size_t nr,  //
+    size_t kr,  //
+    size_t sr,  //
+    size_t bl,  //
+    enum kai_datatype scale_dt);  //
+
+/// Gets the offset in bytes for the packed RHS matrix.
+///
+/// @param[in] n_idx    Row index in the RHS matrix (not packed).
+/// @param[in] k        The number of columns in the RHS matrix (not packed).
+/// @param[in] nr       The number of columns written by the matmul micro-kernel
+/// @param[in] kr       The number of columns loaded in the single innermost loop of the matmul micro-kernel.
+/// @param[in] sr       The number of kr splits. It can be 1 (no splits) up to kr.
+/// @param[in] bl       The block length, which defines the number of K values stored in a single block. It must be a
+///                     multiple of 32.
+/// @param[in] scale_dt Block scale data type
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+    size_t n_idx,  //
+    size_t k,      //
+    size_t nr,     //
+    size_t kr,     //
+    size_t sr,     //
+    size_t bl,     //
+    enum kai_datatype scale_dt);  //
+
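+// Note (inferred from the implementation): the packed buffer consists of
+// kai_roundup(n, nr) / nr row groups, each exactly one packed row stride in size.
+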
+/// Gets the size in bytes for the quantized and packed RHS matrix.
+///
+/// @param[in] n        The number of rows in the RHS matrix (not packed)
+/// @param[in] k        The number of columns in the RHS matrix (not packed).
+/// @param[in] nr       The number of columns written by the matmul micro-kernel
+/// @param[in] kr       The number of columns loaded in the single innermost loop of the matmul micro-kernel.
+/// @param[in] sr       The number of kr splits. It can be 1 (no splits) up to kr.
+/// @param[in] bl       The block length, which defines the number of K values stored in a single block. It must be a
+///                     multiple of 32.
+/// @param[in] scale_dt Block scale data type
+///
+/// @return the packed RHS matrix size in bytes
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+    size_t n,   //
+    size_t k,   //
+    size_t nr,  //
+    size_t kr,  //
+    size_t sr,  //
+    size_t bl,  //
+    enum kai_datatype scale_dt);  //
+
+/// Runs the RHS packing micro-kernel.
+///
+/// The int4 values are stored in a N x K matrix, where N is the number of rows and K is the number of columns.
+/// Two int4 values are stored in one byte. The lower order part of the byte (low) holds
+/// the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1).
+///
+/// @param[in]  num_groups   The number of groups. It must be 1.
+/// @param[in]  n            The number of rows in the RHS matrix (not packed).
+/// @param[in]  k            The number of columns in the RHS matrix (not packed).
+/// @param[in]  nr           The number of columns written by the matmul micro-kernel. It must be a multiple of 4.
+/// @param[in]  kr           The number of columns loaded in the single innermost loop of the matmul micro-kernel.
+/// @param[in]  sr           The number of kr splits. It can be 1 (no splits) up to kr.
+///                          However, kr must be a multiple of sr.
+/// @param[in]  bl           The block length, which defines the number of
+///                          K values stored in a single block. It must be a multiple of 32.
+/// @param[in]  rhs          The RHS matrix containing the 4-bit values.
+///                          Size in bytes is expected to be greater than or equal to (n * k) / 2,
+///                          as two int4 values are stored in one byte.
+/// @param[in]  rhs_stride   The number of bytes per row of the RHS matrix
+/// @param[in]  bias         The biases.
+/// @param[in]  scale        The per-block quantization scales.
+///                          The scale data type must be provided with the params object.
+///                          Only the BF16 scale data type is currently supported.
+/// @param[in]  scale_stride The number of bytes per row of the scale matrix
+/// @param[out] rhs_packed   The packed RHS matrix.
+/// @param[in]  extra_bytes  Extra bytes to append to the end of each row of the packed RHS matrix.
+/// @param[in]  params       Parameters for the micro-kernel.
+void kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+    size_t num_groups,   //
+    size_t n,            //
+    size_t k,            //
+    size_t nr,           //
+    size_t kr,           //
+    size_t sr,           //
+    size_t bl,           //
+    const uint8_t* rhs,  //
+    size_t rhs_stride,   //
+    const float* bias,   //
+    const void* scale,   //
+    size_t scale_stride,  //
+    void* rhs_packed,     //
+    size_t extra_bytes,   //
+    const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);  //
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
index bd7d45c908aaecf6a56e9c16b17f14083f7866fa..5cd255d5bdf91da650d834c19f355acca341a939 100644
--- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
+++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
@@ -34,6 +34,7 @@
 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"
+#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h"
 #include "test/common/bfloat16.hpp"
 #include "test/common/buffer.hpp"
@@ -264,7 +265,7 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_nxk) {
     // Runs the GEMM micro-kernel.
     Buffer imp_dst(imp_dst_size);
     if (kr / sr == 8) {
-        // Test that vectorized packing kernel gives same output as scalar
+        // Test that vectorized packing kernel for nrx8 gives same output as scalar
         const auto imp_packed_rhs_size_neon =
             kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt);
         ASSERT_EQ(imp_packed_rhs_size_neon, imp_packed_rhs_size);
@@ -290,6 +291,37 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_nxk) {
             imp_packed_rhs_neon.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
             N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
+        DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
+        DataFormat dst_format = DataFormat(DataType::FP32);
+        const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
+        ASSERT_TRUE(success);
+    } else if (kr / sr == 4) {
+        // Test that vectorized packing kernel for nrx4 gives same output as scalar
+        const auto imp_packed_rhs_size_neon =
+            kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt);
+        ASSERT_EQ(imp_packed_rhs_size_neon, imp_packed_rhs_size);
+
+        Buffer imp_packed_rhs_neon(imp_packed_rhs_size_neon);
+
+        auto rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+            rhs_start_row, K, nr, kr, sr, bl, scale_dt);
+        ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset);
+
+        auto rhs_offset_neon =
+            kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(rhs_start_row, ref_rhs_qsu4_stride);
+
+        kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
+            1, rect.width() /* n */, K, nr, kr, sr, bl,
+            reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_padded.data() + rhs_offset_neon), ref_rhs_qsu4_stride,
+            reinterpret_cast<const float*>(ref_biases.data() + bias_offset),
+            reinterpret_cast<const float*>(ref_rhs_scales.data() + scale_offset), ref_rhs_scales_stride,
+            imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, &params);
+
+        ukernel_variant.interface.run_matmul(
+            rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset,
+            imp_packed_rhs_neon.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
+            N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
+
         DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
         DataFormat dst_format = DataFormat(DataType::FP32);
         const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);