diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fb19843a7bb97c24f2b66820fba4af7749ced86..ef8489ee2120b94043afdd52b2d1bb2d7e6e6279 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme - kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme - New Advanced SIMD micro-kernels: + - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F32 output, optimized for FEAT_DotProd. - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_DotProd. - Optimized version of kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 kernel for block depth of 8 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon`) - Added Convolution example using SME Indirect Matmul Kernels diff --git a/CMakeLists.txt b/CMakeLists.txt index c52c2f8bc5a632d5efa4a4aa4fa9ddc8efdcedee..b90e58b01fa65f547c11eaafeb08f90aadece4f4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,8 +185,10 @@ set(KLEIDIAI_FILES_NEON_DOTPROD_ASM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.c + 
kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod_asm.S ) set(KLEIDIAI_FILES_NEON_DOTPROD diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 0531d57c3e556ce926537e1f25185be0b69d66bc..a7218a025fd5d7cf48071ebc73b1c21bb6223f4e 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -116,6 +116,7 @@ DOTPROD_KERNELS_ASM = [ "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod", "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod", + "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod", ] # buildifier: keep sorted diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.c new file mode 100644 index 0000000000000000000000000000000000000000..6b42d442ccfef5b9a583e58ecf5c21685f4a9ab3 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.c @@ -0,0 +1,176 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(_M_ARM64) +#error "Dotprod extension required to compile this micro-kernel" +#else // Architectural features check. 
+ +#include "kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.h" + +#include <stddef.h> + +#include "kai/kai_common.h" + +typedef struct { + float* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; + size_t num_subblocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 8; +static const size_t kai_n_step = 4; +// Packing args +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 8; +static const size_t kai_sr = 2; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_sum_lhs = 4; +// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_offset_rhs = 4; + +// DST format args +static const size_t kai_num_bytes_dst_value = 4; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_bl = 32; + +inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) { + return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs; +} + +inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + size_t num_bytes_per_block_rhs = + (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; + return num_bytes_per_block_rhs; +} + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + return kai_roundup(k, bl) / bl; +} + +inline static size_t 
kai_get_lhs_packed_stride(size_t k, size_t bl) { + return kai_mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl); +} + +inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); + + size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row); + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod( + size_t m_idx, size_t k, size_t bl) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k, bl); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + 
KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); + if (m == 0) { + return; + } + const size_t num_subblocks = bl / kai_bl; + const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + args.num_subblocks = num_subblocks; + + kai_kernel_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(&args); +} + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.h new file mode 100644 index 0000000000000000000000000000000000000000..d69eca6d518c3cf4f6f76e6de47659355d2ed70d --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.h @@ -0,0 +1,150 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon to dynamically quantize and pack the LHS matrix in a single +/// step. +/// -# @ref kai_run_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon to pack the RHS NxK matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. 
+/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-block (multiple of 32) quantization (qsi8d32) +/// values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod( + size_t m_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) +/// values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. 
+/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod( + size_t n_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of n_step. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Symmetric Signed 8-bit with per-block (multiple of 32) quantization (qsi8d32) and packed. +/// RHS matrix: Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) and packed. +/// Output tile: (rows x cols) = m_step x n_step. +/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: dotprod +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. 
The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..ea8af633508e419d4967446421291ee4f6390cf4 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod_asm.S @@ -0,0 +1,520 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END 
+#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x21, #0x20 + movi v13.16b, #0xf0 + mov x6, #0x80 + ldr x20, [x0, #0x28] + ldr x7, [x0, #0x40] + ldr x8, [x0, #0x38] + ldr x17, [x0, #0x8] + ldr x16, [x0, #0x10] + ldr x15, [x0, #0x30] + mov x14, x20 + madd x6, x7, x6, x21 + ldr x13, [x0, #0x0] + ldr x12, [x0, #0x20] + ldr x11, [x0, #0x18] + cmp x14, #0x8 + mul x6, x8, x6 + blt label_11 +KAI_ASM_LABEL(label_1) // Row loop + mov x10, x16 + mov x9, x15 + add x28, x13, x12, LSL #3 +KAI_ASM_LABEL(label_2) // Column loop + mov x23, x17 + movi v24.16b, #0x0 + movi v11.16b, #0x0 + mov x22, x8 + movi v23.16b, #0x0 + movi v14.16b, #0x0 + movi v12.16b, #0x0 + movi v31.16b, #0x0 + movi v7.16b, #0x0 + movi v25.16b, #0x0 + add x21, x23, x6 +KAI_ASM_LABEL(label_3) // Block loop + movi v6.4s, #0x0 + movi v2.4s, #0x0 + mov x20, x7 + movi v22.4s, #0x0 + 
movi v3.4s, #0x0 + movi v9.4s, #0x0 + movi v20.4s, #0x0 + movi v5.4s, #0x0 + movi v0.4s, #0x0 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q30, [x10, #0x0] + ldr q15, [x23, #0x0] + subs x20, x20, #0x1 + ldr q18, [x21, #0x0] + ldr q4, [x10, #0x10] + ldr q29, [x23, #0x10] + ldr q10, [x21, #0x10] + ldr q26, [x10, #0x20] + ldr q28, [x23, #0x20] + shl v19.16b, v30.16b, #0x4 + and v30.16b, v30.16b, v13.16b + ldr q17, [x21, #0x20] + ldr q16, [x10, #0x30] + shl v27.16b, v4.16b, #0x4 + and v4.16b, v4.16b, v13.16b + ldr q8, [x23, #0x30] + ldr q21, [x21, #0x30] + add x10, x10, #0x40 + ldr q1, [x23, #0x40] + KAI_ASM_INST(0x4f8fe266) // sdot v6.4s, v19.16b, v15.4b[0] + KAI_ASM_INST(0x4fafe262) // sdot v2.4s, v19.16b, v15.4b[1] + KAI_ASM_INST(0x4f8fea76) // sdot v22.4s, v19.16b, v15.4b[2] + KAI_ASM_INST(0x4fafea63) // sdot v3.4s, v19.16b, v15.4b[3] + ldr q15, [x21, #0x40] + KAI_ASM_INST(0x4f92e269) // sdot v9.4s, v19.16b, v18.4b[0] + KAI_ASM_INST(0x4fb2e274) // sdot v20.4s, v19.16b, v18.4b[1] + KAI_ASM_INST(0x4f92ea65) // sdot v5.4s, v19.16b, v18.4b[2] + KAI_ASM_INST(0x4fb2ea60) // sdot v0.4s, v19.16b, v18.4b[3] + ldr q18, [x23, #0x50] + ldr q19, [x21, #0x50] + KAI_ASM_INST(0x4f9de366) // sdot v6.4s, v27.16b, v29.4b[0] + KAI_ASM_INST(0x4fbde362) // sdot v2.4s, v27.16b, v29.4b[1] + KAI_ASM_INST(0x4f9deb76) // sdot v22.4s, v27.16b, v29.4b[2] + KAI_ASM_INST(0x4fbdeb63) // sdot v3.4s, v27.16b, v29.4b[3] + ldr q29, [x23, #0x60] + KAI_ASM_INST(0x4f8ae369) // sdot v9.4s, v27.16b, v10.4b[0] + KAI_ASM_INST(0x4faae374) // sdot v20.4s, v27.16b, v10.4b[1] + KAI_ASM_INST(0x4f8aeb65) // sdot v5.4s, v27.16b, v10.4b[2] + KAI_ASM_INST(0x4faaeb60) // sdot v0.4s, v27.16b, v10.4b[3] + ldr q10, [x21, #0x60] + shl v27.16b, v26.16b, #0x4 + and v26.16b, v26.16b, v13.16b + KAI_ASM_INST(0x4f9ce366) // sdot v6.4s, v27.16b, v28.4b[0] + KAI_ASM_INST(0x4fbce362) // sdot v2.4s, v27.16b, v28.4b[1] + KAI_ASM_INST(0x4f9ceb76) // sdot v22.4s, v27.16b, v28.4b[2] + KAI_ASM_INST(0x4fbceb63) // sdot v3.4s, 
v27.16b, v28.4b[3] + ldr q28, [x23, #0x70] + add x23, x23, #0x80 + KAI_ASM_INST(0x4f91e369) // sdot v9.4s, v27.16b, v17.4b[0] + KAI_ASM_INST(0x4fb1e374) // sdot v20.4s, v27.16b, v17.4b[1] + KAI_ASM_INST(0x4f91eb65) // sdot v5.4s, v27.16b, v17.4b[2] + KAI_ASM_INST(0x4fb1eb60) // sdot v0.4s, v27.16b, v17.4b[3] + ldr q27, [x21, #0x70] + shl v17.16b, v16.16b, #0x4 + and v16.16b, v16.16b, v13.16b + add x21, x21, #0x80 + KAI_ASM_INST(0x4f88e226) // sdot v6.4s, v17.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e222) // sdot v2.4s, v17.16b, v8.4b[1] + KAI_ASM_INST(0x4f88ea36) // sdot v22.4s, v17.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8ea23) // sdot v3.4s, v17.16b, v8.4b[3] + KAI_ASM_INST(0x4f95e229) // sdot v9.4s, v17.16b, v21.4b[0] + KAI_ASM_INST(0x4fb5e234) // sdot v20.4s, v17.16b, v21.4b[1] + KAI_ASM_INST(0x4f95ea25) // sdot v5.4s, v17.16b, v21.4b[2] + KAI_ASM_INST(0x4fb5ea20) // sdot v0.4s, v17.16b, v21.4b[3] + KAI_ASM_INST(0x4f81e3c6) // sdot v6.4s, v30.16b, v1.4b[0] + KAI_ASM_INST(0x4fa1e3c2) // sdot v2.4s, v30.16b, v1.4b[1] + KAI_ASM_INST(0x4f81ebd6) // sdot v22.4s, v30.16b, v1.4b[2] + KAI_ASM_INST(0x4fa1ebc3) // sdot v3.4s, v30.16b, v1.4b[3] + KAI_ASM_INST(0x4f8fe3c9) // sdot v9.4s, v30.16b, v15.4b[0] + KAI_ASM_INST(0x4fafe3d4) // sdot v20.4s, v30.16b, v15.4b[1] + KAI_ASM_INST(0x4f8febc5) // sdot v5.4s, v30.16b, v15.4b[2] + KAI_ASM_INST(0x4fafebc0) // sdot v0.4s, v30.16b, v15.4b[3] + KAI_ASM_INST(0x4f92e086) // sdot v6.4s, v4.16b, v18.4b[0] + KAI_ASM_INST(0x4fb2e082) // sdot v2.4s, v4.16b, v18.4b[1] + KAI_ASM_INST(0x4f92e896) // sdot v22.4s, v4.16b, v18.4b[2] + KAI_ASM_INST(0x4fb2e883) // sdot v3.4s, v4.16b, v18.4b[3] + KAI_ASM_INST(0x4f93e089) // sdot v9.4s, v4.16b, v19.4b[0] + KAI_ASM_INST(0x4fb3e094) // sdot v20.4s, v4.16b, v19.4b[1] + KAI_ASM_INST(0x4f93e885) // sdot v5.4s, v4.16b, v19.4b[2] + KAI_ASM_INST(0x4fb3e880) // sdot v0.4s, v4.16b, v19.4b[3] + KAI_ASM_INST(0x4f9de346) // sdot v6.4s, v26.16b, v29.4b[0] + KAI_ASM_INST(0x4fbde342) // sdot v2.4s, v26.16b, v29.4b[1] + 
KAI_ASM_INST(0x4f9deb56) // sdot v22.4s, v26.16b, v29.4b[2] + KAI_ASM_INST(0x4fbdeb43) // sdot v3.4s, v26.16b, v29.4b[3] + KAI_ASM_INST(0x4f8ae349) // sdot v9.4s, v26.16b, v10.4b[0] + KAI_ASM_INST(0x4faae354) // sdot v20.4s, v26.16b, v10.4b[1] + KAI_ASM_INST(0x4f8aeb45) // sdot v5.4s, v26.16b, v10.4b[2] + KAI_ASM_INST(0x4faaeb40) // sdot v0.4s, v26.16b, v10.4b[3] + KAI_ASM_INST(0x4f9ce206) // sdot v6.4s, v16.16b, v28.4b[0] + KAI_ASM_INST(0x4fbce202) // sdot v2.4s, v16.16b, v28.4b[1] + KAI_ASM_INST(0x4f9cea16) // sdot v22.4s, v16.16b, v28.4b[2] + KAI_ASM_INST(0x4fbcea03) // sdot v3.4s, v16.16b, v28.4b[3] + KAI_ASM_INST(0x4f9be209) // sdot v9.4s, v16.16b, v27.4b[0] + KAI_ASM_INST(0x4fbbe214) // sdot v20.4s, v16.16b, v27.4b[1] + KAI_ASM_INST(0x4f9bea05) // sdot v5.4s, v16.16b, v27.4b[2] + KAI_ASM_INST(0x4fbbea00) // sdot v0.4s, v16.16b, v27.4b[3] + bgt label_4 + ldr q1, [x10, #0x0] + ld1 { v17.4s }, [x23] + add x23, x23, #0x10 + scvtf v6.4s, v6.4s + ldr q29, [x10, #0x10] + ldr q16, [x23, #0x0] + scvtf v2.4s, v2.4s + scvtf v22.4s, v22.4s + scvtf v3.4s, v3.4s + add x10, x10, #0x20 + add x23, x23, #0x10 + fmla v24.4s, v1.4s, v17.s[0] + fmla v11.4s, v1.4s, v17.s[1] + fmla v23.4s, v1.4s, v17.s[2] + fmla v14.4s, v1.4s, v17.s[3] + fmul v8.4s, v29.4s, v16.s[0] + fmul v18.4s, v29.4s, v16.s[1] + fmul v17.4s, v29.4s, v16.s[2] + fmul v16.4s, v29.4s, v16.s[3] + fmla v24.4s, v6.4s, v8.4s + fmla v11.4s, v2.4s, v18.4s + fmla v23.4s, v22.4s, v17.4s + fmla v14.4s, v3.4s, v16.4s + ld1 { v17.4s }, [x21] + add x21, x21, #0x10 + scvtf v9.4s, v9.4s + scvtf v20.4s, v20.4s + ldr q16, [x21, #0x0] + scvtf v5.4s, v5.4s + scvtf v0.4s, v0.4s + add x21, x21, #0x10 + fmla v12.4s, v1.4s, v17.s[0] + fmla v31.4s, v1.4s, v17.s[1] + fmla v7.4s, v1.4s, v17.s[2] + fmla v25.4s, v1.4s, v17.s[3] + fmul v19.4s, v29.4s, v16.s[0] + fmul v18.4s, v29.4s, v16.s[1] + fmul v17.4s, v29.4s, v16.s[2] + fmul v16.4s, v29.4s, v16.s[3] + fmla v12.4s, v9.4s, v19.4s + fmla v31.4s, v20.4s, v18.4s + fmla v7.4s, v5.4s, v17.4s + 
fmla v25.4s, v0.4s, v16.4s + subs x22, x22, #0x1 + bgt label_3 + ldr q18, [x10, #0x0] + ld1r { v17.4s }, [x11] + add x20, x11, #0x4 + cmp x9, #0x4 + ld1r { v16.4s }, [x20] + add x10, x10, #0x10 + fadd v24.4s, v24.4s, v18.4s + fadd v11.4s, v11.4s, v18.4s + fadd v23.4s, v23.4s, v18.4s + fadd v14.4s, v14.4s, v18.4s + fadd v12.4s, v12.4s, v18.4s + fadd v31.4s, v31.4s, v18.4s + fadd v7.4s, v7.4s, v18.4s + fadd v25.4s, v25.4s, v18.4s + fmax v24.4s, v24.4s, v17.4s + fmax v11.4s, v11.4s, v17.4s + fmax v23.4s, v23.4s, v17.4s + fmax v14.4s, v14.4s, v17.4s + fmax v12.4s, v12.4s, v17.4s + fmax v31.4s, v31.4s, v17.4s + fmax v7.4s, v7.4s, v17.4s + fmax v25.4s, v25.4s, v17.4s + fmin v24.4s, v24.4s, v16.4s + fmin v11.4s, v11.4s, v16.4s + fmin v23.4s, v23.4s, v16.4s + fmin v14.4s, v14.4s, v16.4s + fmin v12.4s, v12.4s, v16.4s + fmin v31.4s, v31.4s, v16.4s + fmin v7.4s, v7.4s, v16.4s + fmin v25.4s, v25.4s, v16.4s + blt label_7 + mov x20, x13 + str q24, [x20, #0x0] + add x20, x20, x12 + str q11, [x20, #0x0] + add x20, x20, x12 + str q23, [x20, #0x0] + add x20, x20, x12 + str q14, [x20, #0x0] + add x20, x20, x12 + str q12, [x20, #0x0] + add x20, x20, x12 + str q31, [x20, #0x0] + add x20, x20, x12 + str q7, [x20, #0x0] + add x20, x20, x12 + str q25, [x20, #0x0] + b label_10 +KAI_ASM_LABEL(label_7) // Partial output + mov x27, x13 + add x26, x27, x12, LSL #2 + add x25, x26, x12, LSL #1 + add x24, x26, x12 + add x23, x25, x12 + add x22, x27, x12, LSL #1 + add x21, x27, x12 + add x20, x22, x12 + tbz x9, #1, label_8 + st1 { v25.d }[0], [x23], #0x8 + st1 { v7.d }[0], [x25], #0x8 + st1 { v31.d }[0], [x24], #0x8 + st1 { v12.d }[0], [x26], #0x8 + st1 { v14.d }[0], [x20], #0x8 + st1 { v23.d }[0], [x22], #0x8 + st1 { v11.d }[0], [x21], #0x8 + st1 { v24.d }[0], [x27], #0x8 + tbz x9, #0, label_9 + st1 { v25.s }[2], [x23] + st1 { v7.s }[2], [x25] + st1 { v31.s }[2], [x24] + st1 { v12.s }[2], [x26] + st1 { v14.s }[2], [x20] + st1 { v23.s }[2], [x22] + st1 { v11.s }[2], [x21] + st1 { v24.s }[2], [x27] 
+ b label_9 +KAI_ASM_LABEL(label_8) // Output block 0: partial_1_0 + st1 { v25.s }[0], [x23] + st1 { v7.s }[0], [x25] + st1 { v31.s }[0], [x24] + st1 { v12.s }[0], [x26] + st1 { v14.s }[0], [x20] + st1 { v23.s }[0], [x22] + st1 { v11.s }[0], [x21] + st1 { v24.s }[0], [x27] +KAI_ASM_LABEL(label_9) // Output block 0: Done +KAI_ASM_LABEL(label_10) // Output stage exit + subs x9, x9, #0x4 + add x13, x13, #0x10 + bgt label_2 + mov x20, #0x2 + sub x14, x14, #0x8 + cmp x14, #0x8 + mov x13, x28 + madd x17, x20, x6, x17 + bge label_1 +KAI_ASM_LABEL(label_11) // Row loop skip + cbz x14, label_21 +KAI_ASM_LABEL(label_12) // Row tail: Row loop + mov x26, x16 + mov x25, x15 + add x24, x13, x12, LSL #2 +KAI_ASM_LABEL(label_13) // Row tail: Column loop + movi v24.16b, #0x0 + movi v11.16b, #0x0 + mov x23, x17 + mov x21, x8 + movi v23.16b, #0x0 + movi v14.16b, #0x0 +KAI_ASM_LABEL(label_14) // Row tail: Block loop + movi v6.4s, #0x0 + movi v2.4s, #0x0 + mov x20, x7 + movi v22.4s, #0x0 + movi v3.4s, #0x0 +KAI_ASM_LABEL(label_15) // Row tail: Sub block loop + ldr q31, [x26, #0x0] + ldr q30, [x23, #0x0] + subs x20, x20, #0x1 + ldr q29, [x26, #0x10] + ldr q28, [x23, #0x10] + ldr q5, [x26, #0x20] + ldr q27, [x23, #0x20] + ldr q26, [x26, #0x30] + ldr q25, [x23, #0x30] + shl v7.16b, v31.16b, #0x4 + and v31.16b, v31.16b, v13.16b + ldr q10, [x23, #0x40] + ldr q8, [x23, #0x50] + shl v21.16b, v29.16b, #0x4 + and v29.16b, v29.16b, v13.16b + ldr q20, [x23, #0x60] + ldr q18, [x23, #0x70] + shl v17.16b, v5.16b, #0x4 + and v5.16b, v5.16b, v13.16b + KAI_ASM_INST(0x4f9ee0e6) // sdot v6.4s, v7.16b, v30.4b[0] + KAI_ASM_INST(0x4fbee0e2) // sdot v2.4s, v7.16b, v30.4b[1] + shl v16.16b, v26.16b, #0x4 + add x26, x26, #0x40 + KAI_ASM_INST(0x4f9ee8f6) // sdot v22.4s, v7.16b, v30.4b[2] + KAI_ASM_INST(0x4fbee8e3) // sdot v3.4s, v7.16b, v30.4b[3] + and v26.16b, v26.16b, v13.16b + add x23, x23, #0x80 + KAI_ASM_INST(0x4f9ce2a6) // sdot v6.4s, v21.16b, v28.4b[0] + KAI_ASM_INST(0x4fbce2a2) // sdot v2.4s, v21.16b, 
v28.4b[1] + KAI_ASM_INST(0x4f9ceab6) // sdot v22.4s, v21.16b, v28.4b[2] + KAI_ASM_INST(0x4fbceaa3) // sdot v3.4s, v21.16b, v28.4b[3] + KAI_ASM_INST(0x4f9be226) // sdot v6.4s, v17.16b, v27.4b[0] + KAI_ASM_INST(0x4fbbe222) // sdot v2.4s, v17.16b, v27.4b[1] + KAI_ASM_INST(0x4f9bea36) // sdot v22.4s, v17.16b, v27.4b[2] + KAI_ASM_INST(0x4fbbea23) // sdot v3.4s, v17.16b, v27.4b[3] + KAI_ASM_INST(0x4f99e206) // sdot v6.4s, v16.16b, v25.4b[0] + KAI_ASM_INST(0x4fb9e202) // sdot v2.4s, v16.16b, v25.4b[1] + KAI_ASM_INST(0x4f99ea16) // sdot v22.4s, v16.16b, v25.4b[2] + KAI_ASM_INST(0x4fb9ea03) // sdot v3.4s, v16.16b, v25.4b[3] + KAI_ASM_INST(0x4f8ae3e6) // sdot v6.4s, v31.16b, v10.4b[0] + KAI_ASM_INST(0x4faae3e2) // sdot v2.4s, v31.16b, v10.4b[1] + KAI_ASM_INST(0x4f8aebf6) // sdot v22.4s, v31.16b, v10.4b[2] + KAI_ASM_INST(0x4faaebe3) // sdot v3.4s, v31.16b, v10.4b[3] + KAI_ASM_INST(0x4f88e3a6) // sdot v6.4s, v29.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e3a2) // sdot v2.4s, v29.16b, v8.4b[1] + KAI_ASM_INST(0x4f88ebb6) // sdot v22.4s, v29.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8eba3) // sdot v3.4s, v29.16b, v8.4b[3] + KAI_ASM_INST(0x4f94e0a6) // sdot v6.4s, v5.16b, v20.4b[0] + KAI_ASM_INST(0x4fb4e0a2) // sdot v2.4s, v5.16b, v20.4b[1] + KAI_ASM_INST(0x4f94e8b6) // sdot v22.4s, v5.16b, v20.4b[2] + KAI_ASM_INST(0x4fb4e8a3) // sdot v3.4s, v5.16b, v20.4b[3] + KAI_ASM_INST(0x4f92e346) // sdot v6.4s, v26.16b, v18.4b[0] + KAI_ASM_INST(0x4fb2e342) // sdot v2.4s, v26.16b, v18.4b[1] + KAI_ASM_INST(0x4f92eb56) // sdot v22.4s, v26.16b, v18.4b[2] + KAI_ASM_INST(0x4fb2eb43) // sdot v3.4s, v26.16b, v18.4b[3] + bgt label_15 + ldr q18, [x26, #0x0] + ld1 { v17.4s }, [x23] + add x23, x23, #0x10 + scvtf v6.4s, v6.4s + ldr q20, [x26, #0x10] + ldr q16, [x23, #0x0] + scvtf v2.4s, v2.4s + scvtf v22.4s, v22.4s + scvtf v3.4s, v3.4s + add x26, x26, #0x20 + add x23, x23, #0x10 + fmla v24.4s, v18.4s, v17.s[0] + fmla v11.4s, v18.4s, v17.s[1] + fmla v23.4s, v18.4s, v17.s[2] + fmla v14.4s, v18.4s, v17.s[3] + fmul v17.4s, 
v20.4s, v16.s[0] + fmul v18.4s, v20.4s, v16.s[1] + fmul v8.4s, v20.4s, v16.s[2] + fmul v16.4s, v20.4s, v16.s[3] + fmla v24.4s, v6.4s, v17.4s + fmla v11.4s, v2.4s, v18.4s + fmla v23.4s, v22.4s, v8.4s + fmla v14.4s, v3.4s, v16.4s + subs x21, x21, #0x1 + bgt label_14 + ldr q18, [x26, #0x0] + ld1r { v17.4s }, [x11] + add x20, x11, #0x4 + cmp x25, #0x4 + ld1r { v16.4s }, [x20] + add x26, x26, #0x10 + fadd v24.4s, v24.4s, v18.4s + fadd v11.4s, v11.4s, v18.4s + fadd v23.4s, v23.4s, v18.4s + fadd v14.4s, v14.4s, v18.4s + fmax v24.4s, v24.4s, v17.4s + fmax v11.4s, v11.4s, v17.4s + fmax v23.4s, v23.4s, v17.4s + fmax v14.4s, v14.4s, v17.4s + fmin v24.4s, v24.4s, v16.4s + fmin v11.4s, v11.4s, v16.4s + fmin v23.4s, v23.4s, v16.4s + fmin v14.4s, v14.4s, v16.4s + blt label_17 + mov x20, x13 + cmp x14, #0x1 + str q24, [x20, #0x0] + add x20, x20, x12 + ble label_20 + cmp x14, #0x2 + str q11, [x20, #0x0] + add x20, x20, x12 + ble label_20 + cmp x14, #0x3 + str q23, [x20, #0x0] + add x20, x20, x12 + ble label_20 + str q14, [x20, #0x0] + b label_20 +KAI_ASM_LABEL(label_17) // Row tail: Partial output + mov x23, x13 + cmp x14, #0x1 + add x22, x23, x12 + csel x22, x22, x23, GT + cmp x14, #0x2 + add x21, x23, x12, LSL #1 + csel x21, x21, x22, GT + cmp x14, #0x3 + add x20, x21, x12 + csel x20, x20, x21, GT + tbz x25, #1, label_18 + st1 { v14.d }[0], [x20], #0x8 + st1 { v23.d }[0], [x21], #0x8 + st1 { v11.d }[0], [x22], #0x8 + st1 { v24.d }[0], [x23], #0x8 + tbz x25, #0, label_19 + st1 { v14.s }[2], [x20] + st1 { v23.s }[2], [x21] + st1 { v11.s }[2], [x22] + st1 { v24.s }[2], [x23] + b label_19 +KAI_ASM_LABEL(label_18) // Row tail: Output block 0: partial_1_0 + st1 { v14.s }[0], [x20] + st1 { v23.s }[0], [x21] + st1 { v11.s }[0], [x22] + st1 { v24.s }[0], [x23] +KAI_ASM_LABEL(label_19) // Row tail: Output block 0: Done +KAI_ASM_LABEL(label_20) // Row tail: Output stage exit + subs x25, x25, #0x4 + add x13, x13, #0x10 + bgt label_13 + subs x14, x14, #0x4 + add x17, x17, x6 + mov x13, x24 + 
bgt label_12 +KAI_ASM_LABEL(label_21) // Row tail: Row loop skip + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d10, d11, [sp, 72] + ldp d12, d13, [sp, 88] + ldp d14, d15, [sp, 104] + ldp d8, d9, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c index 1cfcaf71c727465834ee47a236cf2dbffecb3144..ea87dd8648239ced3e6f9a44386aa9bc95b2bf3a 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c @@ -75,7 +75,7 @@ void kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon( KAI_ASSUME((bl % kr) == 0); KAI_ASSUME((k % bl) == 0); KAI_ASSUME((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT((kr / sr) % 8 == 0); + KAI_ASSERT(((kr / sr) == 8) || ((kr / sr) == 4)); if (m == 0) { return; @@ -119,12 +119,10 @@ void kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon( int32_t qsum = 0; // Quantize the blocks - for (k_idx = 0; k_idx < (int32_t)bl; k_idx += k_block_len) { - size_t k_block_idx = 0; - for (; k_block_idx <= (size_t)k_block_len - 8; k_block_idx += 8) { - // Clamp at the last valid k-index - const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k - 1); - + for (k_idx = 0; k_idx <= (int32_t)bl - k_block_len; k_idx += k_block_len) { + // Clamp at the last valid k-index + const size_t k_idx_start = KAI_MIN((size_t)k_idx, k - 1); + if (k_block_len == 8) { const float32x4_t vsrc_0 = vld1q_f32(src_ptr + k_idx_start); const float32x4_t vsrc_1 = vld1q_f32(src_ptr + k_idx_start + 4); @@ -145,6 +143,29 @@ void kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon( int8x8_t v0_s8 = vqmovn_s16(v_s16); vst1_s8(dst_ptr, v0_s8); dst_ptr += 8 * sizeof(int8_t); + } else if 
(k_block_len == 4) { + const float32x2_t vsrc_0 = vld1_f32(src_ptr + k_idx_start); + const float32x2_t vsrc_1 = vld1_f32(src_ptr + k_idx_start + 2); + + // Scale the values + float32x2_t v0_f32 = vmul_n_f32(vsrc_0, scale0); + float32x2_t v1_f32 = vmul_n_f32(vsrc_1, scale0); + + int32x2_t v0_s32 = vcvtn_s32_f32(v0_f32); + int32x2_t v1_s32 = vcvtn_s32_f32(v1_f32); + int16x4_t v_s16 = vqmovn_s32(vcombine_s32(v0_s32, v1_s32)); + + v_s16 = vmax_s16(v_s16, vdup_n_s16(INT8_MIN)); + v_s16 = vmin_s16(v_s16, vdup_n_s16(INT8_MAX)); + + // Update the sum + qsum += vaddv_s16(v_s16); + + dst_ptr[0] = vqmovnh_s16(vget_lane_s16(v_s16, 0)); + dst_ptr[1] = vqmovnh_s16(vget_lane_s16(v_s16, 1)); + dst_ptr[2] = vqmovnh_s16(vget_lane_s16(v_s16, 2)); + dst_ptr[3] = vqmovnh_s16(vget_lane_s16(v_s16, 3)); + dst_ptr += 4 * sizeof(int8_t); } dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); } diff --git a/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp b/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp index 699ced2ad8d79632aec179f14cf2e5515cfce48d..a36ff48b65e5cd470edb6d1a22fa3d130e52b639 100644 --- a/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp @@ -15,6 +15,7 @@ #include #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p_qai4c32p_interface.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.h" @@ -36,15 +37,90 @@ namespace kai::test { -static const std::array, 2> +static const std::array, 3> variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p = { 
{{UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod), "kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod", cpu_has_dotprod}, {UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm), - "kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", cpu_has_i8mm}}}; + "kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", cpu_has_i8mm}, + {UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod), + "kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod", cpu_has_dotprod}}}; class MatMulTest_f32_qsi8d32p_qai4c32p : public ::testing::TestWithParam {}; +TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, LhsPackedWithSameBlockdepth) { + // Verify LHS quant and pack int8 kernel behaves same for int4 and int8 matmul kernels, + // when the block-depth is same for different values of kr, sr. + + const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam(); + const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_index); + + const std::uint32_t seed = 0; + + const size_t M = matmul_shape.m; + const size_t N = matmul_shape.n; + const size_t K = matmul_shape.k; + const size_t bl = 32; + + const auto mr = ukernel_variant.interface.get_mr(); + const auto nr = ukernel_variant.interface.get_nr(); + const auto kr = ukernel_variant.interface.get_kr(); + const auto sr = ukernel_variant.interface.get_sr(); + + auto m_step = ukernel_variant.interface.get_m_step(); + ASSERT_TRUE(m_step % mr == 0); + + auto n_step = ukernel_variant.interface.get_n_step(); + ASSERT_TRUE(n_step % nr == 0); + + const auto rect = portion.compute_portion(M, N, m_step, n_step); + + // Generates input data. + const auto ref_lhs = fill_random(M * K, seed + 0); + + // Runs the reference implementation. + // * Quantizes the LHS matrix using 8-bit symmetric quantization. 
+ const auto [ref_lhs_qvalues, ref_lhs_scales] = + quantize_symmetric_per_block_dynamic(ref_lhs.data(), M, K, bl); + + // Runs the LHS packing micro-kernel. + const auto lhs_start_row = rect.start_row(); + const auto imp_packed_lhs_size = + kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon(M, K, bl, mr, kr, sr); + Buffer imp_packed_lhs(imp_packed_lhs_size, 0); + + auto lhs_stride = K * sizeof(float); + auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, lhs_stride); + auto lhs_packed_offset = + kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, K, bl, mr, kr, sr); + + kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon( + rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast(ref_lhs.data() + lhs_offset), + lhs_stride, imp_packed_lhs.data() + lhs_packed_offset); + + const size_t kr_qsi8 = kr / sr; + const size_t sr_qsi8 = 1; + const auto imp_packed_lhs_qsi8_size = + kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon(M, K, bl, mr, kr_qsi8, sr_qsi8); + Buffer imp_packed_lhs_qsi8(imp_packed_lhs_qsi8_size, 0); + + auto lhs_qsi8_packed_offset = + kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, K, bl, mr, kr_qsi8, sr_qsi8); + + ASSERT_EQ(lhs_qsi8_packed_offset, lhs_packed_offset); + + kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon( + rect.height() /* m */, K, bl, mr, kr_qsi8, sr_qsi8, 0, + reinterpret_cast(ref_lhs.data() + lhs_offset), lhs_stride, + imp_packed_lhs_qsi8.data() + lhs_qsi8_packed_offset); + + auto* imp_packed_lhs_ptr = reinterpret_cast(imp_packed_lhs.data()); + auto* imp_packed_lhs_qsi8_ptr = reinterpret_cast(imp_packed_lhs_qsi8.data()); + for (size_t i = 0; i < imp_packed_lhs_qsi8_size; i++) { + ASSERT_EQ(imp_packed_lhs_ptr[i], imp_packed_lhs_qsi8_ptr[i]); + } +} + TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, EndToEnd) { const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam(); const auto& ukernel_variant 
= variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_index); @@ -112,7 +188,7 @@ TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, EndToEnd) { const auto lhs_start_row = rect.start_row(); const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon(M, K, bl, mr, kr, sr); - Buffer imp_packed_lhs(imp_packed_lhs_size); + Buffer imp_packed_lhs(imp_packed_lhs_size, 0); auto lhs_stride = K * sizeof(float); auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, lhs_stride);