diff --git a/CHANGELOG.md b/CHANGELOG.md index 070054f05496b7fa6b78fb1fbc39fdd66737cf88..437929ea78936b6b471349408749f0f5728ce3bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- New Advanced SIMD micro-kernels: + - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F32 output, optimized for FEAT_I8MM. + ## v1.4.0 - New Advanced SIMD micro-kernels: diff --git a/CMakeLists.txt b/CMakeLists.txt index c4ff1c8aac333d29b75e3b9648b7bcedb90a3682..28a80a6c9477344a120bd241a34d380880bc9157 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -123,6 +123,8 @@ set(KLEIDIAI_FILES_NEON kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon.c + kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c ) set(KLEIDIAI_FILES_NEON_DOTPROD_ASM @@ -166,6 +168,8 @@ set(KLEIDIAI_FILES_NEON_I8MM_ASM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.c ) set(KLEIDIAI_FILES_NEON_I8MM @@ -331,6 +335,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp + test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp ) endif() diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 74214896393d65149c6d09a5b303ad91bdba0cdc..a201a2b1715fa2af4893b43b74b98812145db6c3 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -30,9 +30,11 @@ SCALAR_KERNELS = [ NEON_KERNELS = [ "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla", "pack/kai_lhs_quant_pack_qsi8d32p_f32_neon", + "pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon", "pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon", "pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", "pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon", + "pack/kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon", "pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", "pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon", "pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0", @@ -110,6 +112,7 @@ I8MM_KERNELS = [ "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm", "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm", "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm", + "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm", ] @@ 
-119,6 +122,7 @@ I8MM_KERNELS_ASM = [ "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", + "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", ] # buildifier: keep sorted diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.c new file mode 100644 index 0000000000000000000000000000000000000000..8c04303f117a5d1407a63cd0ae1c67c480b2dba8 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.c @@ -0,0 +1,177 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) && !defined(_M_ARM64) +#error "I8mm extension required to compile this micro-kernel" +#else // Architectural features check. + +#include "kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h" + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +typedef struct { + float* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; + size_t num_subblocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 8; +static const size_t kai_n_step = 4; +// Packing args +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +// RHS format args (num.
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_offset_rhs = sizeof(float); + +// DST format args +static const size_t kai_num_bytes_dst_value = 4; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_bl = 32; + +inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) { + return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs; +} + +inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + size_t num_bytes_per_block_rhs = + (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; + return num_bytes_per_block_rhs; +} + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) { + return kai_mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl); +} + +inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); + + size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row); + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm( + size_t m_idx, size_t k, size_t bl) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k, bl); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // 
NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); + + if (m == 0) { + return; + } + const size_t num_subblocks = bl / kai_bl; + const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + args.num_subblocks = num_subblocks; + + kai_kernel_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h new file mode 100644 index 0000000000000000000000000000000000000000..83355d247a7c04d472d04a3c76c371defddb0825 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h @@ -0,0 +1,150 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon to pack the RHS NxK matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step value +size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-block (multiple of 32) quantization (qsi8d32) +/// values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm( + size_t m_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) +/// values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm( + size_t n_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Symmetric Signed 8-bit with per-block (multiple of 32) quantization (qsi8d32) and packed. +/// RHS matrix: Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) and packed. +/// Output tile: (rows x cols) = m_step x n_step. +/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: i8mm +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes. +/// @param[in] scalar_min Min value used to clamp the final result.
+/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..f0ed88486464de4236801a931f79b4d3b917baf7 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm_asm.S @@ -0,0 +1,487 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm) + stp x20, x21, [sp, -144]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x21, #0x10 + movi v12.16b, #0xf0 + mov x6, #0x80 + ldr x20, [x0, #0x28] + ldr x7, [x0, #0x40] + ldr x8, [x0, #0x38] + ldr x17, [x0, #0x8] + ldr x16, [x0, #0x10] + ldr x15, [x0, #0x30] + mov x14, x20 + madd x6, x7, x6, x21 + ldr x13, [x0, #0x0] + ldr x12, [x0, #0x20] + ldr x11, [x0, #0x18] + cmp x14, #0x8 + mul x6, x8, x6 + blt label_11 +KAI_ASM_LABEL(label_1) // Row loop + mov x10, x16 + mov x9, x15 + add x28, x13, x12, LSL #3 +KAI_ASM_LABEL(label_2) // Column loop + mov x23, x17 + movi v13.16b, #0x0 + movi v18.16b, #0x0 + mov x22, x8 + movi v29.16b, #0x0 + movi v14.16b, #0x0 + movi v15.16b, #0x0 + movi v11.16b, #0x0 + movi v4.16b, #0x0 + movi v6.16b, #0x0 + add x21, x23, x6 +KAI_ASM_LABEL(label_3) // Block loop + movi v25.4s, #0x0 + movi v7.4s, #0x0 + mov x20, x7 + add x23, x23, #0x10 + movi v10.4s, #0x0 + movi v17.4s, #0x0 + add x21, x21, #0x10 + movi v3.4s, #0x0 + movi v8.4s, #0x0 + movi v9.4s, #0x0 + movi v27.4s, #0x0 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q20, [x10, #0x0] + ldr q23, [x10, #0x10] + subs x20, x20, #0x1 + ldr q31, [x23, #0x0] + ldr q22, [x23, #0x10] + ldr q19, [x21, #0x0] + ldr q28, [x21, #0x10] + ldr q5, [x10, #0x20] + ldr q16, [x10, #0x30] + shl v26.16b, v20.16b, #0x4 + shl v0.16b, v23.16b, #0x4 + ldr q1, [x23, #0x20] + ldr q21, [x23, #0x30] + and v20.16b, v20.16b, v12.16b + and v23.16b, v23.16b, v12.16b + ldr q30, [x21, #0x20] + ldr q24, [x21, #0x30] + add x10, x10, #0x40 + ldr q2, [x23, #0x40] + KAI_ASM_INST(0x4e9aa7f9) // smmla v25.4s, v31.16b, v26.16b + KAI_ASM_INST(0x4e80a7e7) // smmla v7.4s, v31.16b, v0.16b + ldr q31, [x23, #0x50] + KAI_ASM_INST(0x4e9aa6ca) // smmla v10.4s, v22.16b, v26.16b + KAI_ASM_INST(0x4e80a6d1) // smmla v17.4s, v22.16b, v0.16b + ldr q22, [x21, #0x40] + KAI_ASM_INST(0x4e9aa663) // smmla v3.4s, v19.16b, v26.16b + KAI_ASM_INST(0x4e80a668) // smmla v8.4s, v19.16b, v0.16b + ldr q19, [x21, #0x50] + KAI_ASM_INST(0x4e9aa789) // smmla v9.4s, v28.16b, v26.16b + ldr q26, [x23, #0x60] + KAI_ASM_INST(0x4e80a79b) // smmla v27.4s, v28.16b, v0.16b + ldr q0, [x23, #0x70] + shl v28.16b, v5.16b, #0x4 + and v5.16b, v5.16b, v12.16b + add x23, x23, #0x80 + KAI_ASM_INST(0x4e9ca439) // smmla v25.4s, v1.16b, v28.16b + KAI_ASM_INST(0x4e9ca6aa) // smmla v10.4s, v21.16b, v28.16b + KAI_ASM_INST(0x4e9ca7c3) // smmla v3.4s, v30.16b, v28.16b + KAI_ASM_INST(0x4e9ca709) // smmla v9.4s, v24.16b, v28.16b + ldr q28, [x21, #0x60] + KAI_ASM_INST(0x4e94a459) // smmla v25.4s, v2.16b, v20.16b + KAI_ASM_INST(0x4e94a7ea) // smmla v10.4s, v31.16b, v20.16b + KAI_ASM_INST(0x4e94a6c3) // smmla v3.4s, v22.16b, v20.16b + KAI_ASM_INST(0x4e94a669) // smmla v9.4s, v19.16b, v20.16b + ldr q20, [x21, #0x70] + add x21, x21, #0x80 + KAI_ASM_INST(0x4e85a759) // smmla v25.4s, v26.16b, v5.16b + KAI_ASM_INST(0x4e85a40a) // smmla v10.4s, v0.16b, v5.16b + KAI_ASM_INST(0x4e85a783) // smmla v3.4s, v28.16b, v5.16b + KAI_ASM_INST(0x4e85a689) // smmla v9.4s, v20.16b, v5.16b + shl v5.16b, v16.16b, #0x4 + and v16.16b, v16.16b, v12.16b + KAI_ASM_INST(0x4e85a427) // smmla v7.4s, v1.16b, v5.16b + KAI_ASM_INST(0x4e85a6b1) // smmla v17.4s, v21.16b, v5.16b + KAI_ASM_INST(0x4e85a7c8) // smmla v8.4s, v30.16b, v5.16b + KAI_ASM_INST(0x4e85a71b) // smmla v27.4s, v24.16b, v5.16b + KAI_ASM_INST(0x4e97a447) // smmla v7.4s, v2.16b, v23.16b + KAI_ASM_INST(0x4e97a7f1) // smmla v17.4s, v31.16b, v23.16b + 
KAI_ASM_INST(0x4e97a6c8) // smmla v8.4s, v22.16b, v23.16b + KAI_ASM_INST(0x4e97a67b) // smmla v27.4s, v19.16b, v23.16b + KAI_ASM_INST(0x4e90a747) // smmla v7.4s, v26.16b, v16.16b + KAI_ASM_INST(0x4e90a411) // smmla v17.4s, v0.16b, v16.16b + KAI_ASM_INST(0x4e90a788) // smmla v8.4s, v28.16b, v16.16b + KAI_ASM_INST(0x4e90a69b) // smmla v27.4s, v20.16b, v16.16b + bgt label_4 + ldr q30, [x10, #0x0] + ld1 { v31.4s }, [x23] + add x23, x23, #0x10 + uzp1 v23.2d, v25.2d, v7.2d + ldr q26, [x10, #0x10] + ldr q16, [x23, #0x0] + uzp2 v24.2d, v25.2d, v7.2d + uzp1 v21.2d, v10.2d, v17.2d + uzp2 v20.2d, v10.2d, v17.2d + add x10, x10, #0x20 + add x23, x23, #0x10 + fmla v13.4s, v30.4s, v31.s[0] + fmla v18.4s, v30.4s, v31.s[1] + fmla v29.4s, v30.4s, v31.s[2] + scvtf v23.4s, v23.4s + fmla v14.4s, v30.4s, v31.s[3] + fmul v19.4s, v26.4s, v16.s[0] + fmul v1.4s, v26.4s, v16.s[1] + scvtf v24.4s, v24.4s + fmul v22.4s, v26.4s, v16.s[2] + scvtf v21.4s, v21.4s + fmul v16.4s, v26.4s, v16.s[3] + scvtf v20.4s, v20.4s + fmla v13.4s, v23.4s, v19.4s + fmla v18.4s, v24.4s, v1.4s + fmla v29.4s, v21.4s, v22.4s + fmla v14.4s, v20.4s, v16.4s + ld1 { v7.4s }, [x21] + add x21, x21, #0x10 + uzp1 v23.2d, v3.2d, v8.2d + uzp2 v22.2d, v3.2d, v8.2d + ldr q16, [x21, #0x0] + uzp1 v21.2d, v9.2d, v27.2d + uzp2 v20.2d, v9.2d, v27.2d + add x21, x21, #0x10 + fmla v15.4s, v30.4s, v7.s[0] + fmla v11.4s, v30.4s, v7.s[1] + fmla v4.4s, v30.4s, v7.s[2] + fmla v6.4s, v30.4s, v7.s[3] + scvtf v23.4s, v23.4s + fmul v19.4s, v26.4s, v16.s[0] + fmul v8.4s, v26.4s, v16.s[1] + scvtf v22.4s, v22.4s + fmul v30.4s, v26.4s, v16.s[2] + scvtf v21.4s, v21.4s + fmul v16.4s, v26.4s, v16.s[3] + scvtf v20.4s, v20.4s + fmla v15.4s, v23.4s, v19.4s + fmla v11.4s, v22.4s, v8.4s + fmla v4.4s, v21.4s, v30.4s + fmla v6.4s, v20.4s, v16.4s + subs x22, x22, #0x1 + bgt label_3 + ldr q21, [x10, #0x0] + ld1r { v7.4s }, [x11] + add x20, x11, #0x4 + cmp x9, #0x4 + ld1r { v16.4s }, [x20] + add x10, x10, #0x10 + fadd v13.4s, v13.4s, v21.4s + fadd v18.4s, v18.4s, v21.4s + fadd v29.4s, v29.4s, v21.4s + fadd v14.4s, v14.4s, v21.4s + fadd v15.4s, v15.4s, v21.4s + fadd v11.4s, v11.4s, v21.4s + fadd v4.4s, v4.4s, v21.4s + fadd v6.4s, v6.4s, v21.4s + fmax v13.4s, v13.4s, v7.4s + fmax v18.4s, v18.4s, v7.4s + fmax v29.4s, v29.4s, v7.4s + fmax v14.4s, v14.4s, v7.4s + fmax v15.4s, v15.4s, v7.4s + fmax v11.4s, v11.4s, v7.4s + fmax v4.4s, v4.4s, v7.4s + fmax v6.4s, v6.4s, v7.4s + fmin v13.4s, v13.4s, v16.4s + fmin v18.4s, v18.4s, v16.4s + fmin v29.4s, v29.4s, v16.4s + fmin v14.4s, v14.4s, v16.4s + fmin v15.4s, v15.4s, v16.4s + fmin v11.4s, v11.4s, v16.4s + fmin v4.4s, v4.4s, v16.4s + fmin v6.4s, v6.4s, v16.4s + blt label_7 + mov x20, x13 + str q13, [x20, #0x0] + add x20, x20, x12 + str q18, [x20, #0x0] + add x20, x20, x12 + str q29, [x20, #0x0] + add x20, x20, x12 + str q14, [x20, #0x0] + add x20, x20, x12 + str q15, [x20, #0x0] + add x20, x20, x12 + str q11, [x20, #0x0] + add x20, x20, x12 + str q4, [x20, #0x0] + add x20, x20, x12 + str q6, [x20, #0x0] + b label_10 +KAI_ASM_LABEL(label_7) // Partial output + mov x27, x13 + add x26, x27, x12, LSL #2 + add x25, x26, x12, LSL #1 + add x24, x26, x12 + add x23, x25, x12 + add x22, x27, x12, LSL #1 + add x21, x27, x12 + add x20, x22, x12 + tbz x9, #1, label_8 + st1 { v6.d }[0], [x23], #0x8 + st1 { v4.d }[0], [x25], #0x8 + st1 { v11.d }[0], [x24], #0x8 + st1 { v15.d }[0], [x26], #0x8 + st1 { v14.d }[0], [x20], #0x8 + st1 { v29.d }[0], [x22], #0x8 + st1 { v18.d }[0], [x21], #0x8 + st1 { v13.d }[0], [x27], #0x8 + tbz x9, #0, label_9 + st1 { v6.s }[2], [x23] 
+ st1 { v4.s }[2], [x25] + st1 { v11.s }[2], [x24] + st1 { v15.s }[2], [x26] + st1 { v14.s }[2], [x20] + st1 { v29.s }[2], [x22] + st1 { v18.s }[2], [x21] + st1 { v13.s }[2], [x27] + b label_9 +KAI_ASM_LABEL(label_8) // Output block 0: partial_1_0 + st1 { v6.s }[0], [x23] + st1 { v4.s }[0], [x25] + st1 { v11.s }[0], [x24] + st1 { v15.s }[0], [x26] + st1 { v14.s }[0], [x20] + st1 { v29.s }[0], [x22] + st1 { v18.s }[0], [x21] + st1 { v13.s }[0], [x27] +KAI_ASM_LABEL(label_9) // Output block 0: Done +KAI_ASM_LABEL(label_10) // Output stage exit + subs x9, x9, #0x4 + add x13, x13, #0x10 + bgt label_2 + mov x20, #0x2 + sub x14, x14, #0x8 + cmp x14, #0x8 + mov x13, x28 + madd x17, x20, x6, x17 + bge label_1 +KAI_ASM_LABEL(label_11) // Row loop skip + cbz x14, label_21 +KAI_ASM_LABEL(label_12) // Row tail: Row loop + mov x26, x16 + mov x25, x15 + add x24, x13, x12, LSL #2 +KAI_ASM_LABEL(label_13) // Row tail: Column loop + movi v13.16b, #0x0 + movi v18.16b, #0x0 + mov x23, x17 + mov x21, x8 + movi v29.16b, #0x0 + movi v14.16b, #0x0 +KAI_ASM_LABEL(label_14) // Row tail: Block loop + movi v25.4s, #0x0 + movi v7.4s, #0x0 + mov x20, x7 + add x23, x23, #0x10 + movi v10.4s, #0x0 + movi v17.4s, #0x0 +KAI_ASM_LABEL(label_15) // Row tail: Sub block loop + ldr q0, [x26, #0x0] + ldr q31, [x26, #0x10] + subs x20, x20, #0x1 + ldr q30, [x23, #0x0] + ldr q28, [x23, #0x10] + ldr q27, [x26, #0x20] + ldr q26, [x26, #0x30] + add x26, x26, #0x40 + ldr q6, [x23, #0x20] + ldr q24, [x23, #0x30] + shl v23.16b, v0.16b, #0x4 + shl v22.16b, v31.16b, #0x4 + ldr q21, [x23, #0x40] + ldr q20, [x23, #0x50] + and v0.16b, v0.16b, v12.16b + and v31.16b, v31.16b, v12.16b + ldr q19, [x23, #0x60] + ldr q2, [x23, #0x70] + shl v11.16b, v27.16b, #0x4 + shl v16.16b, v26.16b, #0x4 + KAI_ASM_INST(0x4e97a7d9) // smmla v25.4s, v30.16b, v23.16b + KAI_ASM_INST(0x4e96a7c7) // smmla v7.4s, v30.16b, v22.16b + and v27.16b, v27.16b, v12.16b + add x23, x23, #0x80 + KAI_ASM_INST(0x4e97a78a) // smmla v10.4s, v28.16b, v23.16b + KAI_ASM_INST(0x4e96a791) // smmla v17.4s, v28.16b, v22.16b + and v26.16b, v26.16b, v12.16b + KAI_ASM_INST(0x4e8ba4d9) // smmla v25.4s, v6.16b, v11.16b + KAI_ASM_INST(0x4e90a4c7) // smmla v7.4s, v6.16b, v16.16b + KAI_ASM_INST(0x4e8ba70a) // smmla v10.4s, v24.16b, v11.16b + KAI_ASM_INST(0x4e90a711) // smmla v17.4s, v24.16b, v16.16b + KAI_ASM_INST(0x4e80a6b9) // smmla v25.4s, v21.16b, v0.16b + KAI_ASM_INST(0x4e9fa6a7) // smmla v7.4s, v21.16b, v31.16b + KAI_ASM_INST(0x4e80a68a) // smmla v10.4s, v20.16b, v0.16b + KAI_ASM_INST(0x4e9fa691) // smmla v17.4s, v20.16b, v31.16b + KAI_ASM_INST(0x4e9ba679) // smmla v25.4s, v19.16b, v27.16b + KAI_ASM_INST(0x4e9aa667) // smmla v7.4s, v19.16b, v26.16b + KAI_ASM_INST(0x4e9ba44a) // smmla v10.4s, v2.16b, v27.16b + KAI_ASM_INST(0x4e9aa451) // smmla v17.4s, v2.16b, v26.16b + bgt label_15 + ldr q6, [x26, #0x0] + ld1 { v3.4s }, [x23] + add x23, x23, #0x10 + uzp1 v24.2d, v25.2d, v7.2d + ldr q23, [x26, #0x10] + ldr q16, [x23, #0x0] + uzp2 v22.2d, v25.2d, v7.2d + uzp1 v21.2d, v10.2d, v17.2d + uzp2 v20.2d, v10.2d, v17.2d + add x26, x26, #0x20 + add x23, x23, #0x10 + fmla v13.4s, v6.4s, v3.s[0] + fmla v18.4s, v6.4s, v3.s[1] + fmla v29.4s, v6.4s, v3.s[2] + scvtf v24.4s, v24.4s + fmla v14.4s, v6.4s, v3.s[3] + fmul v19.4s, v23.4s, v16.s[0] + fmul v31.4s, v23.4s, v16.s[1] + scvtf v22.4s, v22.4s + fmul v17.4s, v23.4s, v16.s[2] + scvtf v21.4s, v21.4s + fmul v16.4s, v23.4s, v16.s[3] + scvtf v20.4s, v20.4s + fmla v13.4s, v24.4s, v19.4s + fmla v18.4s, v22.4s, v31.4s + fmla v29.4s, v21.4s, v17.4s + fmla v14.4s, 
v20.4s, v16.4s + subs x21, x21, #0x1 + bgt label_14 + ldr q20, [x26, #0x0] + ld1r { v17.4s }, [x11] + add x20, x11, #0x4 + cmp x25, #0x4 + ld1r { v16.4s }, [x20] + add x26, x26, #0x10 + fadd v13.4s, v13.4s, v20.4s + fadd v18.4s, v18.4s, v20.4s + fadd v29.4s, v29.4s, v20.4s + fadd v14.4s, v14.4s, v20.4s + fmax v13.4s, v13.4s, v17.4s + fmax v18.4s, v18.4s, v17.4s + fmax v29.4s, v29.4s, v17.4s + fmax v14.4s, v14.4s, v17.4s + fmin v13.4s, v13.4s, v16.4s + fmin v18.4s, v18.4s, v16.4s + fmin v29.4s, v29.4s, v16.4s + fmin v14.4s, v14.4s, v16.4s + blt label_17 + mov x20, x13 + cmp x14, #0x1 + str q13, [x20, #0x0] + add x20, x20, x12 + ble label_20 + cmp x14, #0x2 + str q18, [x20, #0x0] + add x20, x20, x12 + ble label_20 + cmp x14, #0x3 + str q29, [x20, #0x0] + add x20, x20, x12 + ble label_20 + str q14, [x20, #0x0] + b label_20 +KAI_ASM_LABEL(label_17) // Row tail: Partial output + mov x23, x13 + cmp x14, #0x1 + add x22, x23, x12 + csel x22, x22, x23, GT + cmp x14, #0x2 + add x21, x23, x12, LSL #1 + csel x21, x21, x22, GT + cmp x14, #0x3 + add x20, x21, x12 + csel x20, x20, x21, GT + tbz x25, #1, label_18 + st1 { v14.d }[0], [x20], #0x8 + st1 { v29.d }[0], [x21], #0x8 + st1 { v18.d }[0], [x22], #0x8 + st1 { v13.d }[0], [x23], #0x8 + tbz x25, #0, label_19 + st1 { v14.s }[2], [x20] + st1 { v29.s }[2], [x21] + st1 { v18.s }[2], [x22] + st1 { v13.s }[2], [x23] + b label_19 +KAI_ASM_LABEL(label_18) // Row tail: Output block 0: partial_1_0 + st1 { v14.s }[0], [x20] + st1 { v29.s }[0], [x21] + st1 { v18.s }[0], [x22] + st1 { v13.s }[0], [x23] +KAI_ASM_LABEL(label_19) // Row tail: Output block 0: Done +KAI_ASM_LABEL(label_20) // Row tail: Output stage exit + subs x25, x25, #0x4 + add x13, x13, #0x10 + bgt label_13 + subs x14, x14, #0x4 + add x17, x17, x6 + mov x13, x24 + bgt label_12 +KAI_ASM_LABEL(label_21) // Row tail: Row loop skip + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d10, d11, [sp, 72] + ldp d12, d13, [sp, 88] + ldp d14, d15, [sp, 104] + ldp d8, d9, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p_qai4c32p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p_qai4c32p_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..b6ba7545f10810f21dc12d20c35a01f36cbba945 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p_qai4c32p_interface.h @@ -0,0 +1,54 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernel variants of the same type share the same interface +// In this case, the micro-kernel type is: matmul_clamp_f32_qsi8d32p_qai4c32p + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_mr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_sr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_lhs_packed_offset_func_t)( + size_t m_idx, size_t k, size_t bl); +typedef size_t (*kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_rhs_packed_offset_func_t)( + size_t n_idx, size_t k, size_t bl); +typedef size_t (*kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_dst_offset_func_t)( + size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_matmul_clamp_f32_qsi8d32p_qai4c32p_run_matmul_func_t)( + size_t m, size_t n, size_t k, size_t bl, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +/// Micro-kernel interface +struct kai_matmul_clamp_f32_qsi8d32p_qai4c32p_ukernel { + kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_m_step_func_t get_m_step; + kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_n_step_func_t get_n_step; + kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_mr_func_t get_mr; + kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_nr_func_t get_nr; + kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_kr_func_t get_kr; + kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_sr_func_t get_sr; + kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_f32_qsi8d32p_qai4c32p_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_f32_qsi8d32p_qai4c32p_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c new file mode 100644 index 0000000000000000000000000000000000000000..49f4a1beb2acdaa32026c4f9e2d809c154c90940 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c @@ -0,0 +1,145 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check. 
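Since every variant of this micro-kernel type exposes the same set of entry points, a caller can bind the new i8mm kernel to the interface once and dispatch through function pointers. A minimal sketch (the instance name is illustrative; every referenced function is declared in the headers added by this patch):

```c
// Hypothetical wiring: bind the 8x4 i8mm variant to the shared interface so
// callers can dispatch generically over matmul_clamp_f32_qsi8d32p_qai4c32p kernels.
#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p_qai4c32p_interface.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h"

static const struct kai_matmul_clamp_f32_qsi8d32p_qai4c32p_ukernel ukernel_8x4_i8mm = {
    .get_m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm,
    .get_n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm,
    .get_mr = kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm,
    .get_nr = kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm,
    .get_kr = kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm,
    .get_sr = kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm,
    .get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm,
    .get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm,
    .get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm,
    .get_dst_size = kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm,
    .run_matmul = kai_run_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm,
};
```

This is the same binding pattern the new test exercises via the UKERNEL_MATMUL_VARIANT macro below.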
+ +#include "kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.h" + +#include <arm_neon.h> +#include <float.h> +#include <math.h> +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_sum = sizeof(float); +static const size_t kai_num_bytes_multiplier = sizeof(float); +static const size_t kai_bl_multiple_of = 32; + +inline static size_t kai_get_num_bytes_per_block(size_t bl) { + return bl * sizeof(int8_t) + kai_num_bytes_multiplier + kai_num_bytes_sum; +} + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSERT((k % bl) == 0); + return k / bl; +} + +inline static size_t kai_get_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) { + KAI_UNUSED(kr); + + return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block(bl); +} + +size_t kai_get_m_step_lhs_quant_pack_qsi8d32pscalef32_f32_neon(size_t mr) { + KAI_UNUSED(mr); + return 1; +} + +size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(size_t m_idx, size_t lhs_stride) { + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon( + size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((m_idx % mr) == 0); + + KAI_UNUSED(sr); + return (m_idx / mr) * kai_get_lhs_packed_stride(k, mr, kr, bl); +} + +size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon( + size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + + KAI_UNUSED(sr); + + const size_t num_rows = kai_roundup(m, mr) / mr; + + return (num_rows * kai_get_lhs_packed_stride(k, mr, kr, bl)); +} +void kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon( + size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, + size_t lhs_stride, void* lhs_packed) { + KAI_ASSERT((kr % sr) == 0); + KAI_ASSUME((bl % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl_multiple_of) == 0); + + if (m == 0) { + return; + } + const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, mr, kr, bl); + const size_t num_rows = m; + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl); + const size_t mr_block_size = mr * num_bytes_per_block; + + const int32_t k_block_len = (int32_t)(kr / sr); + + for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { + const float* row_src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride); + const size_t dst_idx = ((row_idx + m_idx_start) % mr); + for (size_t blk_idx = 0; blk_idx < num_blocks_per_row; ++blk_idx) { + const float* src_ptr = row_src_ptr + blk_idx * bl; + int8_t* dst_ptr = (int8_t*)lhs_packed + dst_idx * k_block_len * sizeof(int8_t) + blk_idx * mr_block_size; + int8_t* param_ptr = (int8_t*)lhs_packed + blk_idx * mr_block_size + bl * mr + dst_idx * kai_num_bytes_sum; + // Find absmax for each block + float absmax = -FLT_MAX; + int32_t k_idx = 0; + float32x4_t vabsmax = vdupq_n_f32(-FLT_MAX); + for (; k_idx < ((int32_t)bl); k_idx += 8) { + const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx); + const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx); + // Calculate the max + vabsmax = vmaxq_f32(vabsq_f32(src0_0), vmaxq_f32(vabsmax, vabsq_f32(src0_1))); + } + // Get the absmax + absmax = vmaxvq_f32(vabsmax); + // Maximum/minimum int8 values + const float
qmax = (float)INT8_MAX; + const float scale0 = absmax == 0.0F ? 0.0F : qmax / absmax; + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; + int32_t qsum = 0; + // Quantize the blocks + for (k_idx = 0; k_idx < (int32_t)bl; k_idx += k_block_len) { + for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; ++k_block_idx) { + // Clamp at the last valid k-index + const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k - 1); + + const float src0_0 = *(src_ptr + k_idx_start); + + // Scale the values + int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); + + v0_s32 = KAI_MAX(v0_s32, INT8_MIN); + v0_s32 = KAI_MIN(v0_s32, INT8_MAX); + qsum += v0_s32; + + *(dst_ptr) = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + *((float*)(param_ptr)) = ((float)qsum) * recip_scale0; + param_ptr += mr * kai_num_bytes_sum; + *((float*)(param_ptr)) = recip_scale0; + } + // Move to the next row if we have interleaved all Mr rows + if ((((row_idx + 1) + m_idx_start) % mr) == 0) { + lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride); + } + } +} +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..ab2ef40bd14d5c357669a093bf6585a15d547b8c --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.h @@ -0,0 +1,83 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @param[in] mr The number of M rows to interleave on the same output row. +/// +/// @return the m step value +size_t kai_get_m_step_lhs_quant_pack_qsi8d32pscalef32_f32_neon(size_t mr); + +/// Gets the offset in bytes for the LHS matrix (not packed) +/// +/// This function should be called before passing the pointer to the LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] lhs_stride The number of bytes in each row of the LHS matrix (not packed) +/// +/// @return the offset in bytes to the LHS matrix +size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(size_t m_idx, size_t lhs_stride); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed 8-bit quantized symmetric per-block (qsi8d32) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of mr. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] bl The block length. +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon( + size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + +/// Gets the size in bytes for the quantized and packed LHS matrix +/// +/// @param[in] m Total number of rows in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a +/// multiple of 32. +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the packed LHS matrix size in bytes +size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon( + size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + +/// Runs the micro-kernel to quantize and pack the LHS matrix. +/// +/// @param[in] m The number of output rows written. +/// @param[in] k The number of channels. The common dimension of LHS & RHS. It must be a multiple of 8. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be +/// a multiple of 32. +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// However, kr must be a multiple of sr. +/// @param[in] m_idx_start The starting M index. +/// @param[in] lhs LHS matrix. +/// @param[in] lhs_stride Stride in bytes between two rows of LHS. +/// @param[out] lhs_packed The quantized and packed LHS matrix. +void kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon( + size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, + size_t lhs_stride, void* lhs_packed); + +#ifdef __cplusplus +} +#endif diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon.c new file mode 100644 index 0000000000000000000000000000000000000000..1ff7f725062ea1ec387c392809b72d1183659b67 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon.c @@ -0,0 +1,164 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check.
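The qsi8d32 LHS format produced by the packer above stores, for every block of 32 values, the int8 quantized data plus two f32 parameters: the block sum pre-multiplied by the reciprocal scale, and the reciprocal scale itself. A scalar sketch of the per-block math that kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon implements with NEON (the helper name is illustrative; the mr-row interleaving of the packed buffer is omitted):

```c
// Scalar model of the per-block qsi8d32 quantization (illustrative helper).
#include <math.h>
#include <stdint.h>

static void quantize_block_qsi8d32(const float src[32], int8_t dst[32], float* recip_scale, float* scaled_sum) {
    // Find the largest magnitude in the block.
    float absmax = 0.0F;
    for (int i = 0; i < 32; ++i) {
        const float a = fabsf(src[i]);
        absmax = a > absmax ? a : absmax;
    }
    // Map the largest magnitude onto the int8 range; all-zero blocks get scale 0.
    const float scale = absmax == 0.0F ? 0.0F : 127.0F / absmax;
    *recip_scale = scale == 0.0F ? 0.0F : 1.0F / scale;

    int32_t qsum = 0;
    for (int i = 0; i < 32; ++i) {
        int32_t q = (int32_t)roundf(src[i] * scale);
        q = q < INT8_MIN ? INT8_MIN : (q > INT8_MAX ? INT8_MAX : q);
        dst[i] = (int8_t)q;
        qsum += q;
    }
    // The packed layout stores the block sum pre-multiplied by the reciprocal
    // scale, followed by the reciprocal scale itself (both as f32).
    *scaled_sum = (float)qsum * *recip_scale;
}
```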
+#include "kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon.h" + +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include "kai/kai_common.h" +static const size_t kai_num_bytes_offset_rhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); +static const size_t kai_bl_multiple_of = 32; + +inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} +inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} +inline static size_t kai_num_bytes_per_block(size_t bl) { + return (bl / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; +} +inline static size_t kai_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kr) == 0); + KAI_ASSUME((bl % kai_bl_multiple_of) == 0); + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); + return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias); +} +size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon(size_t n_idx, size_t rhs_stride) { + return n_idx * rhs_stride; +} +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon( + size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % nr) == 0); + KAI_UNUSED(kr); + return (n_idx / nr) * kai_rhs_packed_stride(k, nr, kr, bl); +} +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon( + size_t n, size_t k, size_t nr, size_t kr, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_UNUSED(kr); + const size_t num_rows = kai_roundup(n, nr) / nr; + return num_rows * kai_rhs_packed_stride(k, nr, kr, bl); +} +void kai_run_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, + const float* zero, const float* bias, const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon_params* params) { + // Temporary asserts + KAI_ASSUME(num_groups == 1); + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % 32) == 0); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(sr == 2); + KAI_ASSUME(kr >= 1 && kr <= 16); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(zero != NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(params != NULL); + KAI_ASSUME(params->rhs_zero_point == 8); + KAI_ASSUME(params->lhs_zero_point == 1); + // Note: The input matrix (rhs) is expected with: + // "k" columns and "n" rows (NxK) + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t block_num_per_row = k_internal / bl; + const size_t rhs_stride = k; + const size_t k_interleaved_v = 16U; + const size_t dst_packed_stride = kai_rhs_packed_stride(k, nr, kr, bl); + const size_t dst_packed_block_size =
kai_num_bytes_per_block(bl) * nr; + const size_t dst_block_data_size = (bl / 2) * nr; + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_bias_offset = block_num_per_row * dst_packed_block_size; + const size_t k_block_length_in_bytes = kr / sr; + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * dst_packed_stride; + float* dst_row_bias = (float*)(dst_row + dst_bias_offset); + for (size_t block_idx = 0; block_idx < block_num_per_row; block_idx++) { + uint8_t* block_dst_row = dst_row + block_idx * dst_packed_block_size; + for (size_t block_byte_idx = 0; block_byte_idx < dst_block_data_size; ++block_byte_idx) { + const size_t dst_byte_idx = block_byte_idx; + const size_t k_block_idx = dst_byte_idx / k_block_length_in_bytes; + const size_t k_block_byte_idx = dst_byte_idx % k_block_length_in_bytes; + const size_t super_k_block_idx = k_block_idx / nr; + const size_t nr_idx = k_block_idx % nr; + const size_t k_adjustment = + ((k_block_byte_idx + super_k_block_idx * k_block_length_in_bytes) / k_interleaved_v) * + k_interleaved_v; + const size_t k0_idx = k_block_byte_idx + super_k_block_idx * k_block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + const size_t src_addr_byte0 = (k0_idx + n0_valid_idx * rhs_stride + block_idx * bl) / 2; + const size_t src_addr_byte1 = (k1_idx + n0_valid_idx * rhs_stride + block_idx * bl) / 2; + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + const size_t shift_right_x0 = (k0_idx % 2) * 4; + const size_t shift_right_x1 = (k1_idx % 2) * 4; + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + *block_dst_row = dst_qs0 ^ 0x88; + block_dst_row += sizeof(uint8_t); + } + // Adjust the zero point + for (size_t i = 0; i < nr; ++i) { + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + const float* block_zero = zero + block_num_per_row * src_row_idx; + *((float*)(block_dst_row)) = block_zero[block_idx]; + block_dst_row += sizeof(float); + } + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + const float* block_scale = scale + block_num_per_row * src_row_idx; + *((float*)(block_dst_row)) = block_scale[block_idx] * 0.0625F; + block_dst_row += sizeof(float); + } + } + // Set the bias + if (bias == NULL) { + memset(dst_row_bias, 0, nr * kai_num_bytes_bias); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + dst_row_bias[i] = bias[src_row_idx]; + } + } + } +} +#endif diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..b17e691b56b51ceb7b6065374c5274534fdfb8ae --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon.h 
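The packer above pairs the values at K-index k and k + 16 into one packed byte and converts the unsigned nibbles (zero point 8) to signed two's-complement nibbles with a single XOR. A sketch of the per-byte transform (the helper name is illustrative):

```c
// XOR with 0x88 flips the top bit of each nibble, i.e. subtracts the zero
// point of 8 from both u4 values at once, yielding signed s4 nibbles.
#include <stdint.h>

static uint8_t pack_s4x2(uint8_t u4_lo /* value at k */, uint8_t u4_hi /* value at k + 16 */) {
    const uint8_t packed = (uint8_t)((u4_lo & 0x0F) | ((u4_hi & 0x0F) << 4));
    return packed ^ 0x88;
}
```

The 0.0625F (1/16) factor applied to the scales during packing appears to compensate for the matmul kernel consuming the nibbles shifted left by four bits (the shl #0x4 / 0xF0-mask pairs in the assembly above).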
@@ -0,0 +1,111 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif +struct kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon_params { + int8_t lhs_zero_point; + uint8_t rhs_zero_point; +}; +/// Gets the offset in bytes for the RHS matrix (not packed), which holds +/// the int4 values in a N x K matrix, where N is number of rows and K is the number of columns. +/// +/// Two int4 K values are stored in one byte. These values are stored in blocks, where each block +/// has its own zero point and scale factor. Both are expected to be f32 values, provided through the separate +/// zero and scale arguments of the packing micro-kernel rather than embedded in the matrix. +/// The first byte in the block holds the K-index + 0 and K-index + 1 values. +/// The K-index + 0 value is stored in the lower order part of the byte (low nibble) while +/// the K-index + 1 value is stored in the higher order part (high nibble). +/// For example, if the block length is 32, the values are stored in the following order: +/// |byte(s1, s0),byte(s3, s2),byte(s5, s4),...,byte(s31, s30)| +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] rhs_stride The number of bytes in each row of the RHS matrix (not packed) +/// +/// @return the offset in bytes to the RHS matrix (not packed) +size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon( + size_t n_idx, // + size_t rhs_stride); // +/// Gets the offset in bytes for the packed RHS matrix. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] k The common dimension between the LHS and RHS matrix (K) +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a +/// multiple of 32. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon( + size_t n_idx, // + size_t k, // + size_t nr, // + size_t kr, // + size_t bl // +); +/// Gets the size in bytes for the quantized and packed RHS matrix. +/// +/// @param[in] n The number of rows in the RHS matrix (not packed) +/// @param[in] k The number of columns in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a multiple +/// of 32. +/// +/// @return the packed RHS matrix size in bytes +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon( + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t bl // +); +/// Runs the RHS packing micro-kernel. +/// +/// The int4 values are stored in a N x K matrix, where N is number of rows and K is the number of columns. +/// Two int4 values are stored in one byte. The lower order part of the byte (low) holds +/// the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). +/// +/// @param[in] num_groups The number of groups. It must be 1. +/// @param[in] n The number of columns of the output matrix (N).
+/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// @param[in] nr The number of columns written by the matmul micro-kernel. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// However, kr must be a multiple of sr. +/// @param[in] bl The block length, which defines the number of +/// K values stored in a single block. It must be a multiple of 32. +/// @param[in] rhs The RHS matrix containing the 4-bit values. +/// Size in bytes is expected to be greater than or equal to n * k * (sizeof(uint8_t) / 2). +/// @param[in] zero Asymmetric quant zero point. +/// @param[in] bias The biases. +/// @param[in] scale The scale for each output channel. +/// @param[out] rhs_packed The packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. +/// @param[in] params Parameters for the micro-kernel. +void kai_run_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + const float* zero, // + const float* bias, // + const float* scale, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon_params* params); +#ifdef __cplusplus +} +#endif diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp index 82a09fd7614db1b9f25a927414f7a49ccc177c56..b76e57b2419396555730f6299285585d92448c4f 100644 --- a/test/reference/quantize.cpp +++ b/test/reference/quantize.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -15,6 +15,7 @@ #include <vector> #include "test/common/bfloat16.hpp" +#include "test/common/float16.hpp" #include "test/common/int4.hpp" #include "test/common/memory.hpp" #include "test/common/numeric_limits.hpp" @@ -28,8 +29,8 @@ namespace { template <typename FloatData, typename IntData> std::tuple<FloatData, IntData> get_scale_zero_point_from_range(FloatData min_value, FloatData max_value) { - constexpr FloatData q_min = std::numeric_limits<IntData>::min(); - constexpr FloatData q_max = std::numeric_limits<IntData>::max(); + const FloatData q_min = numeric_lowest<IntData>; + const FloatData q_max = numeric_highest<IntData>; if (min_value > 0) { min_value = 0; @@ -289,5 +290,7 @@ quantize_asymmetric_per_block_dynamic( template std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, std::vector<uint8_t>> quantize_asymmetric_per_block_dynamic( const void* src, size_t height, size_t width, size_t quant_width); - +template std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, std::vector<uint8_t>> +quantize_asymmetric_per_block_dynamic<float, Int4, float, float>( + const void* src, size_t height, size_t width, size_t quant_width); } // namespace kai::test diff --git a/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp b/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..702ffc445256df2102e5bb349f7bff8055db7248 --- /dev/null +++ b/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp @@ -0,0 +1,320 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include <gtest/gtest.h> + +#include <array> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <cstring> +#include <limits> +#include <string> +#include <tuple> +#include <vector> + +#include
"kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p_qai4c32p_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon.h" +#include "test/common/cpu_info.hpp" +#include "test/common/float16.hpp" +#include "test/common/int4.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/memory.hpp" +#include "test/common/round.hpp" +#include "test/common/test_suite.hpp" +#include "test/reference/cast.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/pack.hpp" +#include "test/reference/quantize.hpp" + +namespace kai::test { + +static const std::array, 1> + variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p = { + {{UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm), + "kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", cpu_has_i8mm}}}; + +class MatMulTest_f32_qsi8d32p_qai4c32p : public ::testing::TestWithParam {}; + +TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, EndToEnd_NullBias) { + const auto& [variant_index, matmul_shape, portion] = GetParam(); + const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_index); + + if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { + GTEST_SKIP() << "Kernel not supported"; + } + + const std::uint64_t seed = 0; + + const size_t M = matmul_shape.m; + const size_t N = matmul_shape.n; + const size_t K = matmul_shape.k; + const size_t bl = 32; + + const auto mr = ukernel_variant.interface.get_mr(); + const auto nr = ukernel_variant.interface.get_nr(); + const auto kr = ukernel_variant.interface.get_kr(); + const auto sr = ukernel_variant.interface.get_sr(); + + if (mr == 1 && M > 1) { + GTEST_SKIP() << "Kernel does not support M != 1"; + } + + auto m_step = ukernel_variant.interface.get_m_step(); + ASSERT_TRUE(m_step % mr == 0); + + auto n_step = ukernel_variant.interface.get_n_step(); + ASSERT_TRUE(n_step % nr == 0); + + const auto rect = portion.compute_portion(M, N, m_step, n_step); + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP() << "Test Portion size is 0!"; + } + + // Generates input data. + const auto ref_lhs = fill_random(M * K, seed + 0); + const auto ref_rhs = fill_random(N * K, seed + 1); + + // Runs the reference implementation. + // * Quantizes the LHS matrix using 8-bit symmetric quantization. + // * Quantizes the RHS matrix using 8-bit asymmetric quantization. + // * Performs GEMM. + const auto [ref_lhs_qvalues, ref_lhs_scales] = + quantize_symmetric_per_block_dynamic(ref_lhs.data(), M, K, bl); + const auto [ref_rhs_qai4, ref_rhs_scales, ref_rhs_zero_points] = + quantize_asymmetric_per_block_dynamic(ref_rhs.data(), N, K, bl); + + const auto ref_dst = matmul_clamp_nt_t( + M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), nullptr, bl, ref_rhs_qai4.data(), ref_rhs_scales.data(), + ref_rhs_zero_points.data(), bl, nullptr, std::numeric_limits::lowest(), + std::numeric_limits::max()); + + // Runs the LHS packing micro-kernel. 
+    const auto lhs_start_row = rect.start_row();
+    const auto imp_packed_lhs_size =
+        kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon(M, K, bl, mr, kr, sr);
+    std::vector<uint8_t> imp_packed_lhs(imp_packed_lhs_size);
+
+    auto lhs_stride = K * sizeof(float);
+    auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, lhs_stride);
+    auto lhs_packed_offset =
+        kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, K, bl, mr, kr, sr);
+    auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K, bl);
+
+    ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
+
+    kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
+        rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset),
+        lhs_stride, imp_packed_lhs.data() + lhs_packed_offset);
+
+    // Runs the RHS packing micro-kernel.
+    const auto ref_rhs_qau4 = cast_qsu4_qsi4(ref_rhs_qai4.data(), N * K);
+
+    const auto imp_packed_rhs_size =
+        kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon(N, K, nr, kr, bl);
+    std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
+    const auto rhs_start_row = rect.start_col();
+    auto rhs_packed_offset =
+        kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon(rhs_start_row, K, nr, kr, bl);
+    auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
+    ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
+
+    const kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon_params params{.lhs_zero_point = 1, .rhs_zero_point = 8};
+    kai_run_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon(
+        1, N, K, nr, kr, sr, bl, ref_rhs_qau4.data(), reinterpret_cast<const float*>(ref_rhs_zero_points.data()),
+        nullptr, reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, &params);
+
+    const auto dst_stride_row = N * sizeof(float);
+    const auto dst_stride_col = sizeof(float);
+    const auto dst_offset =
+        ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
+    const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
+    ASSERT_EQ(dst_offset, ref_dst_offset);
+
+    // Runs the GEMM micro-kernel.
+    const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
+    ASSERT_EQ(imp_dst_size, ref_dst.size());
+    std::vector<uint8_t> imp_dst(imp_dst_size);
+    ukernel_variant.interface.run_matmul(
+        rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset,
+        imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
+        dst_stride_row, dst_stride_col, std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
+
+    // Compares the output of the micro-kernels against the output of the reference implementation for the portion
+    // tested.
+    for (size_t y = 0; y < rect.height(); ++y) {
+        for (size_t x = 0; x < rect.width(); ++x) {
+            const auto imp_value =
+                read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto ref_value =
+                read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value;
+
+            if (rel_error > 0.0001F) {
+                ASSERT_EQ(imp_value, ref_value);
+            }
+        }
+    }
+}
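Both end-to-end tests use the same acceptance criterion, shown at the end of the loop above: values within a relative tolerance of 1e-4 pass, and anything outside it falls through to ASSERT_EQ so that GTest prints the offending pair. One subtlety worth noting: when the reference value is zero, the code compares the raw imp_value rather than its magnitude, so a negative implementation value would slip under the threshold. A sign-safe restatement of the same criterion could look like this (hypothetical helper, illustrative only):

    #include <cmath>

    // Sign-safe variant of the tolerance check used by the test loops above.
    inline bool matches_reference(float imp, float ref, float rel_tol = 0.0001F) {
        const float err = (ref != 0.0F) ? std::fabs((imp - ref) / ref) : std::fabs(imp);
        return err <= rel_tol;
    }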
+
+TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, EndToEnd) {
+    const auto& [variant_index, matmul_shape, portion] = GetParam();
+    const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_index);
+
+    if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
+        GTEST_SKIP() << "Kernel not supported";
+    }
+
+    const std::uint64_t seed = 0;
+
+    const size_t M = matmul_shape.m;
+    const size_t N = matmul_shape.n;
+    const size_t K = matmul_shape.k;
+    const size_t bl = 32;
+
+    const auto mr = ukernel_variant.interface.get_mr();
+    const auto nr = ukernel_variant.interface.get_nr();
+    const auto kr = ukernel_variant.interface.get_kr();
+    const auto sr = ukernel_variant.interface.get_sr();
+
+    if (mr == 1 && M > 1) {
+        GTEST_SKIP() << "Kernel does not support M != 1";
+    }
+
+    auto m_step = ukernel_variant.interface.get_m_step();
+    ASSERT_TRUE(m_step % mr == 0);
+
+    auto n_step = ukernel_variant.interface.get_n_step();
+    ASSERT_TRUE(n_step % nr == 0);
+
+    const auto rect = portion.compute_portion(M, N, m_step, n_step);
+    if (rect.height() == 0 || rect.width() == 0) {
+        GTEST_SKIP() << "Test portion size is 0!";
+    }
+
+    // Generates input data.
+    const auto ref_lhs = fill_random<float>(M * K, seed + 0);
+    const auto ref_rhs = fill_random<float>(N * K, seed + 1);
+    const auto ref_biases = fill_random<float>(N, seed + 2);
+
+    // Runs the reference implementation.
+    //   * Quantizes the LHS matrix using 8-bit symmetric quantization.
+    //   * Quantizes the RHS matrix using 4-bit asymmetric quantization.
+    //   * Performs GEMM.
+    const auto [ref_lhs_qvalues, ref_lhs_scales] =
+        quantize_symmetric_per_block_dynamic<float, int8_t, float>(ref_lhs.data(), M, K, bl);
+    const auto [ref_rhs_qai4, ref_rhs_scales, ref_rhs_zero_points] =
+        quantize_asymmetric_per_block_dynamic<float, Int4, float, float>(ref_rhs.data(), N, K, bl);
+
+    const auto ref_dst = matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, float, float, int32_t, float>(
+        M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), nullptr, bl, ref_rhs_qai4.data(),
+        ref_rhs_scales.data(), ref_rhs_zero_points.data(), bl, ref_biases.data(),
+        std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
+
+    // Runs the LHS packing micro-kernel.
+    const auto lhs_start_row = rect.start_row();
+    const auto imp_packed_lhs_size =
+        kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon(M, K, bl, mr, kr, sr);
+    std::vector<uint8_t> imp_packed_lhs(imp_packed_lhs_size);
+
+    auto lhs_stride = K * sizeof(float);
+    auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, lhs_stride);
+    auto lhs_packed_offset =
+        kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, K, bl, mr, kr, sr);
+    auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K, bl);
+
+    ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
+
+    kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
+        rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset),
+        lhs_stride, imp_packed_lhs.data() + lhs_packed_offset);
+
+    // Runs the RHS packing micro-kernel.
+    const auto ref_rhs_qau4 = cast_qsu4_qsi4(ref_rhs_qai4.data(), N * K);
+
+    const auto imp_packed_rhs_size =
+        kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon(N, K, nr, kr, bl);
+    std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
+    const auto rhs_start_row = rect.start_col();
+    auto rhs_packed_offset =
+        kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon(rhs_start_row, K, nr, kr, bl);
+    auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
+    ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
+
+    const kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon_params params{.lhs_zero_point = 1, .rhs_zero_point = 8};
+    kai_run_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon(
+        1, N, K, nr, kr, sr, bl, ref_rhs_qau4.data(), reinterpret_cast<const float*>(ref_rhs_zero_points.data()),
+        reinterpret_cast<const float*>(ref_biases.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()),
+        imp_packed_rhs.data(), 0, &params);
+
+    const auto dst_stride_row = N * sizeof(float);
+    const auto dst_stride_col = sizeof(float);
+    const auto dst_offset =
+        ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
+    const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
+    ASSERT_EQ(dst_offset, ref_dst_offset);
+
+    // Runs the GEMM micro-kernel.
+    const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
+    ASSERT_EQ(imp_dst_size, ref_dst.size());
+    std::vector<uint8_t> imp_dst(imp_dst_size);
+    ukernel_variant.interface.run_matmul(
+        rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset,
+        imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
+        dst_stride_row, dst_stride_col, std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
+
+    // Compares the output of the micro-kernels against the output of the reference implementation for the portion
+    // tested.
+    for (size_t y = 0; y < rect.height(); ++y) {
+        for (size_t x = 0; x < rect.width(); ++x) {
+            const auto imp_value =
+                read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto ref_value =
+                read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value;
+
+            if (rel_error > 0.0001F) {
+                ASSERT_EQ(imp_value, ref_value);
+            }
+        }
+    }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    MatMul, MatMulTest_f32_qsi8d32p_qai4c32p,
+    testing::Combine(
+        testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.size()),
+        testing::Values(
+            MatMulShape{1, 2, 32},    //
+            MatMulShape{32, 64, 64},  //
+            MatMulShape{16, 32, 64},  //
+            MatMulShape{8, 32, 64},   //
+            MatMulShape{15, 32, 32},  //
+            MatMulShape{77, 99, 64}),
+        testing::Values(
+            MatrixPortion(0, 0, 1, 1),     // Full matrix.
+            MatrixPortion(0, 0, 1, 0.25),  // Leftmost portion.
+            MatrixPortion(0, 0.75, 1, 1),  // Rightmost portion.
+            MatrixPortion(0, 0.5, 1, 0.8)  // Somewhere in the middle.
+            )),
+    [](const auto& info) {
+        const auto variant_idx = std::get<0>(info.param);
+        const std::string name{variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_idx).name};
+        const auto shape = std::get<MatMulShape>(info.param);
+        const auto portion = std::get<2>(info.param);
+
+        std::stringstream sstream;
+        sstream << name << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k            //
+                << "__PortionStartRow_" << static_cast<int>(portion.start_row() * 1000)           //
+                << "__PortionStartCol_" << static_cast<int>(portion.start_col() * 1000)           //
+                << "__PortionHeight_" << static_cast<int>(portion.height() * 1000)                //
+                << "__PortionWidth_" << static_cast<int>(portion.width() * 1000);
+        return sstream.str();
+    });
+
+}  // namespace kai::test
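Taken together, the two tests also document the intended integration sequence for the new micro-kernel pair. A condensed sketch follows. Assumptions: the kai_get_*/kai_run_matmul_* names follow the usual KleidiAI naming convention for this variant and are not all shown in this patch; lhs_f32, rhs_qau4, rhs_scales, rhs_zero_points, bias, and the *_packed/dst_f32 buffers are caller-provided placeholders sized via the kai_get_*_size functions; error handling is omitted.

    // 1. Query the packing parameters of the matmul variant.
    const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm();
    const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm();
    const size_t kr = kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm();
    const size_t sr = kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm();

    // 2. Dynamically quantize and pack the f32 LHS into qsi8d32p blocks (bl = 32).
    kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
        m, k, bl, mr, kr, sr, 0 /* m_idx_start */, lhs_f32, k * sizeof(float), lhs_packed);

    // 3. Pack the qau4c32 RHS with its per-block scales and zero points, plus the per-channel bias.
    const struct kai_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon_params params = {
        .lhs_zero_point = 1, .rhs_zero_point = 8};
    kai_run_rhs_pack_nxk_qai4c32pscalef32_qau4c32s1s0_neon(
        1, n, k, nr, kr, sr, bl, rhs_qau4, rhs_zero_points, bias, rhs_scales, rhs_packed, 0, &params);

    // 4. Run the FEAT_I8MM matmul micro-kernel, producing clamped f32 output.
    kai_run_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(
        m, n, k, bl, lhs_packed, rhs_packed, dst_f32, n * sizeof(float), sizeof(float), -FLT_MAX, FLT_MAX);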