diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b5d8567c650d7e3dc6bce5cb210fc5e30fd60ae..17b6f967ebd0013a60f536a75316ad2208663842 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - Matrix multiplication (1xN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F32 output, optimized for FEAT_DotProd. - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_I8MM. - Matrix multiplication (1xN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_DotProd. + - Matrix multiplication (MxN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with F16 output, optimized for FEAT_I8MM and FEAT_DotProd. + - Matrix multiplication (1xN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with F16 output, optimized for FEAT_DotProd. ## v1.6.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index 8e0c713024da474aadb962f705def01a39e46ffa..5ff5d8057e6642ede9c66fc0879c0be168a40baf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,11 +99,18 @@ set(KLEIDIAI_FILES_NEON_FP16 kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f16_neon.c + kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.c ) set(KLEIDIAI_FILES_NEON_FP16_DOTPROD_ASM kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod.c ) set(KLEIDIAI_FILES_NEON_FP16_DOTPROD @@ -112,6 +119,8 @@ set(KLEIDIAI_FILES_NEON_FP16_DOTPROD set(KLEIDIAI_FILES_NEON_FP16_I8MM_ASM kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm.c ) set(KLEIDIAI_FILES_NEON_FP16_I8MM @@ -368,6 +377,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp test/tests/matmul_clamp_f16_qsi8d32p_qai4c32p_test.cpp + test/tests/matmul_clamp_f16_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp test/tests/matmul_test.cpp ) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 
79bfac8b50910ed212b0d13f1a9b3378e82b0394..0532f4347bf8689e953703b65882565a1b9a33c6 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -49,6 +49,7 @@ NEON_KERNELS_ASM = [ # buildifier: keep sorted FP16_KERNELS = [ "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla", + "pack/kai_lhs_quant_pack_qai8dxp_f16_neon", "pack/kai_lhs_quant_pack_qsi8d32pscalef32_f16_neon", "pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", ] @@ -72,11 +73,15 @@ FP16_BF16_KERNELS = [ # buildifier: keep sorted FP16_DOTPROD_KERNELS_ASM = [ + "matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod", + "matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod", + "matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod", "matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod", ] # buildifier: keep sorted FP16_I8MM_KERNELS_ASM = [ + "matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm", "matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", ] diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.c new file mode 100644 index 0000000000000000000000000000000000000000..dfac28289111d99e740ba62431cdfc4e59c77f05 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.c @@ -0,0 +1,165 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && \ + !defined(_M_ARM64) +#error "Dotprod extension and fp16 vector arithmetic required to compile this micro-kernel" +#else // Architectural features check. + +#include "kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h" + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +typedef struct { + void* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; +// Packing args +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; +static const size_t kai_kr = 8; +static const size_t kai_sr = 2; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num.
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 2; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +static const size_t kai_bl = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include the + // number of bytes to hold the zero point value + lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs; + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t rhs_packed_stride = kai_nr * (k_internal / kai_num_bytes_recip_qvalue_rhs); + + rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, //
+ float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + if (m == 0) { + return; + } + const size_t k_internal = kai_get_k_roundedup(k); + size_t num_blocks = k_internal / kai_bl; + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + + kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h new file mode 100644 index 0000000000000000000000000000000000000000..8755135825abede60fc3679cf9b03b9c5453823c --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h @@ -0,0 +1,139 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qai8dxp_f16_neon to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 to pack the RHS NxK matrix. +/// -# @ref kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 to pack the RHS KxN matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be 1. +/// @param[in] k Total number of columns in the LHS matrix (not packed).
+/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Symmetric Signed 4-bit with per-channel quantization (qsi4cx) values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod( + size_t n_idx, // + size_t k); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be 1. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed. +/// RHS matrix: Quantized Symmetric Signed 4-bit with per-channel quantization (qsi4cx) and packed. +/// Output tile: (rows x cols) = m_step x n_step. +/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: dotprod +/// +/// @param[in] m The number of output rows written. It must be 1. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result.
+void kai_run_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..cdff336b250b694e95473c42db23cfd06cfcf626 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod_asm.S @@ -0,0 +1,146 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod) + stp x20, x21, [sp, -80]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + mov x13, #0x20 + movi v27.16b, #0xf0 + mov x21, #0x8 + ldr x12, [x0, #0x38] + ldr x20, [x0, #0x28] + ldr x11, [x0, #0x8] + ldr x10, [x0, #0x10] + ldr x9, [x0, #0x30] + ldr x28, [x0, #0x0] + ldr x27, [x0, #0x20] + madd x13, x12, x13, x21 + ldr x26, [x0, #0x18] + mov x25, x20 +KAI_ASM_LABEL(label_1) // Row loop + mov x24, x10 + mov x23, x9 + add x22, x28, x27 +KAI_ASM_LABEL(label_2) // Column loop + mov x21, x11 + movi v26.4s, #0x0 + mov x20, x12 +KAI_ASM_LABEL(label_3) // Sub block loop + ldr q25, [x24, #0x0] + ldr q24, [x21, #0x0] + subs x20, x20, #0x1 + ldr q23, [x24, #0x10] + ldr q22, [x24, #0x20] + ldr q21, [x24, #0x30] + ldr q20, [x21, #0x10] + add x24, x24, #0x40 + add x21, x21, #0x20 + shl v19.16b, v25.16b, #0x4 + and v25.16b, v25.16b, v27.16b + shl v18.16b, v23.16b, #0x4 + shl v17.16b, v22.16b, #0x4 + shl v16.16b, v21.16b, #0x4 + and v23.16b, v23.16b, v27.16b + KAI_ASM_INST(0x4f98e27a) // sdot v26.4s, v19.16b, v24.4b[0] + and v22.16b, v22.16b, v27.16b + and v21.16b, v21.16b, v27.16b + KAI_ASM_INST(0x4fb8e25a) // sdot v26.4s, v18.16b, v24.4b[1] + KAI_ASM_INST(0x4f98ea3a) // sdot v26.4s, v17.16b, v24.4b[2] + KAI_ASM_INST(0x4fb8ea1a) // sdot v26.4s, v16.16b, v24.4b[3] + KAI_ASM_INST(0x4f94e33a) // sdot v26.4s, v25.16b, v20.4b[0] + KAI_ASM_INST(0x4fb4e2fa) // sdot v26.4s, v23.16b, v20.4b[1] + KAI_ASM_INST(0x4f94eada) // sdot v26.4s, v22.16b, v20.4b[2] + KAI_ASM_INST(0x4fb4eaba) // sdot v26.4s, v21.16b, v20.4b[3] + bgt label_3 + ldr q22, [x24, #0x0] + ld1r { v21.4s }, [x21] + add x21, x21, #0x4 + add x20, x26, #0x4 + ld1r { v20.4s }, [x21] + ldr q16, [x24, #0x10] + cmp x23, #0x4 + ldr q19, [x24, #0x20] + ld1r { v18.4s }, [x26] + add x24, x24, #0x30 + ld1r { v17.4s }, [x20] + mla v26.4s, v22.4s, v21.s[0] + fmul v16.4s, v16.4s, v20.4s + scvtf v26.4s, v26.4s + fmul v16.4s, v26.4s, v16.4s + fadd v16.4s, v16.4s, v19.4s + fmax v16.4s, v16.4s, v18.4s + fmin v16.4s, v16.4s, v17.4s + fcvtn v16.4h, v16.4s + blt label_4 + str d16, [x28, #0x0] + b label_7 +KAI_ASM_LABEL(label_4) // Partial output + mov x20, x28 + tbz x23, #1, label_5 + st1 { v16.s }[0], [x20], #0x4 + tbz x23, #0, label_6 + st1 { v16.h }[2], [x20] + b label_6 +KAI_ASM_LABEL(label_5) // Output block 0: partial_1_0 + st1 { v16.h }[0], [x20] +KAI_ASM_LABEL(label_6) // Output block 0: Done +KAI_ASM_LABEL(label_7) // Stores done + subs x23, x23, #0x4 + add x28, x28, #0x8 + bgt label_2 + subs x25, x25, #0x1 + add x11, x11, x13 + mov x28, x22 + bgt label_1 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.c new file mode 100644 index 0000000000000000000000000000000000000000..6e5316c60546e1aaa5022de991b5f3c930e0612d --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.c @@ -0,0 +1,165 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && \ + !defined(_M_ARM64) 
+#error "Dotprod extension and fp16 vector arithmetic required to compile this micro-kernel" +#else // Architectural features check. + +#include "kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.h" + +#include +#include + +#include "kai/kai_common.h" + +typedef struct { + uint16_t* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; +// Packing args +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 2; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +static const size_t kai_bl = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include the + // the number of bytes to hold the zero point value + lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs; + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t rhs_packed_stride = kai_nr * (k_internal / kai_num_bytes_recip_qvalue_rhs); + + rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 
0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + if (m == 0) { + return; + } + const size_t k_internal = kai_get_k_roundedup(k); + size_t num_blocks = k_internal / kai_bl; + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + + kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.h new file mode 100644 index 0000000000000000000000000000000000000000..c3812a14049339196b9e0a5fa041fe0052b28f41 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.h @@ -0,0 +1,139 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qai8dxp_f16_neon to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 to pack the RHS NxK matrix. +/// -# @ref kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 to pack the RHS KxN matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. 
+/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be 1. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Symmetric Signed 4-bit with per-channel quantization (qsi4cx) values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod( + size_t n_idx, // + size_t k); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be 1. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of n_step. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed. +/// RHS matrix: Quantized Symmetric Signed 4-bit with per-channel quantization (qsi4cx) and packed. +/// Output tile: (rows x cols) = m_step x n_step. +/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: dotprod +/// +/// @param[in] m The number of output rows written. It must be 1. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. 
The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..54a67797cf634338cd91e58247aeed7e8e3f25b5 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod_asm.S @@ -0,0 +1,149 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod) + stp x20, x21, [sp, -80]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + mov x13, #0x20 + movi v30.16b, #0xf0 + mov x21, #0x8 + ldr x12, [x0, #0x38] + ldr x20, [x0, #0x28] + ldr x11, [x0, #0x8] + ldr x10, [x0, #0x10] + ldr x9, [x0, #0x30] + ldr x28, [x0, #0x0] + ldr x27, [x0, #0x20] + madd x13, x12, x13, x21 + ldr x26, [x0, #0x18] + mov x25, x20 +KAI_ASM_LABEL(label_1) // Row loop + mov x24, x10 + mov x23, x9 + add x22, x28, x27 +KAI_ASM_LABEL(label_2) // Column loop + mov x21, x11 + movi v29.4s, #0x0 + movi v28.4s, #0x0 + mov x20, x12 +KAI_ASM_LABEL(label_3) // Sub block loop + ldr q27, [x24, #0x0] + ldr q26, [x24, #0x10] + subs x20, x20, #0x1 + ld1r { v25.2d }, [x21], #0x8 + ldr q24, [x24, #0x20] + ldr q23, [x24, #0x30] + add x24, x24, #0x40 + ld1r { v22.2d }, [x21], #0x8 + ld1r { v21.2d }, [x21], #0x8 + shl v20.16b, v27.16b, #0x4 + shl v19.16b, v26.16b, #0x4 + ld1r { v18.2d }, [x21], #0x8 + shl v17.16b, v24.16b, #0x4 + and v27.16b, v27.16b, v30.16b + shl v16.16b, v23.16b, #0x4 + and v26.16b, v26.16b, v30.16b + KAI_ASM_INST(0x4e99969d) // sdot v29.4s, v20.16b, v25.16b + KAI_ASM_INST(0x4e99967c) // sdot v28.4s, v19.16b, v25.16b + and v24.16b, v24.16b, v30.16b + and v23.16b, v23.16b, v30.16b + KAI_ASM_INST(0x4e96963d) // sdot v29.4s, v17.16b, v22.16b + KAI_ASM_INST(0x4e96961c) // sdot v28.4s, v16.16b, v22.16b + KAI_ASM_INST(0x4e95977d) // sdot v29.4s, v27.16b, v21.16b + KAI_ASM_INST(0x4e95975c) // sdot v28.4s, v26.16b, v21.16b + KAI_ASM_INST(0x4e92971d) // sdot v29.4s, v24.16b, v18.16b + KAI_ASM_INST(0x4e9296fc) // sdot v28.4s, v23.16b, v18.16b + bgt label_3 + ldr q22, [x24, #0x0] + ld1r { v21.4s }, [x21] + addp v29.4s, v29.4s, v28.4s + add x21, x21, #0x4 + ld1r { v20.4s }, [x21] + ldr q16, [x24, #0x10] + add x20, x26, #0x4 + cmp x23, #0x4 + ldr q19, [x24, #0x20] + ld1r { v18.4s }, [x26] + add x24, x24, #0x30 + ld1r { v17.4s }, [x20] + mla v29.4s, v22.4s, v21.s[0] + fmul v16.4s, v16.4s, v20.4s + scvtf v29.4s, v29.4s + fmul v16.4s, v29.4s, v16.4s + fadd v16.4s, v16.4s, v19.4s + fmax v16.4s, v16.4s, v18.4s + fmin v16.4s, v16.4s, v17.4s + fcvtn v16.4h, v16.4s + blt label_4 + str d16, [x28, #0x0] + b label_7 +KAI_ASM_LABEL(label_4) // Partial output + mov x20, x28 + tbz x23, #1, label_5 + st1 { v16.s }[0], [x20], #0x4 + tbz x23, #0, label_6 + st1 { v16.h }[2], [x20] + b label_6 +KAI_ASM_LABEL(label_5) // Output block 0: partial_1_0 + st1 { v16.h }[0], [x20] +KAI_ASM_LABEL(label_6) // Output block 0: Done +KAI_ASM_LABEL(label_7) // Stores done + subs x23, x23, #0x4 + add x28, x28, #0x8 + bgt label_2 + subs x25, x25, #0x1 + add x11, x11, x13 + mov x28, x22 + bgt label_1 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod.c new file mode 100644 index 0000000000000000000000000000000000000000..a0296f937088feea7636f9fa3835ead1c4bf5590 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod.c @@ -0,0 +1,165 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) && 
!defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && \ + !defined(_M_ARM64) +#error "Dotprod extension and fp16 vector arithmetic required to compile this micro-kernel" +#else // Architectural features check. + +#include "kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod.h" + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +typedef struct { + void* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 16; +static const size_t kai_n_step = 4; +// Packing args +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 8; +static const size_t kai_sr = 2; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 2; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +static const size_t kai_bl = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include the + // number of bytes to hold the zero point value + lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs; + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t rhs_packed_stride = kai_nr * (k_internal / kai_num_bytes_recip_qvalue_rhs); + + rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(void) { + return kai_sr; +} + +size_t
kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + if (m == 0) { + return; + } + const size_t k_internal = kai_get_k_roundedup(k); + size_t num_blocks = k_internal / kai_bl; + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + + kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod.h new file mode 100644 index 0000000000000000000000000000000000000000..d6ade6530172311e910e70cec42539dfc8d31358 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod.h @@ -0,0 +1,139 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qai8dxp_f16_neon to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 to pack the RHS NxK matrix. +/// -# @ref kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 to pack the RHS KxN matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step.
+/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Symmetric Signed 4-bit with per-channel quantization (qsi4cx) values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod( + size_t n_idx, // + size_t k); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed. +/// RHS matrix: Quantized Symmetric Signed 4-bit with per-channel quantization (qsi4cx) and packed. +/// Output tile: (rows x cols) = m_step x n_step. +/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: dotprod +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written.
+/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..2ed3769931f29f2b60a9e397cc185a054109a1f3 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod_asm.S @@ -0,0 +1,723 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod) + stp x20, x21, [sp, -144]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x6, #0x80 + mov x21, #0x20 + ldr x20, [x0, #0x28] + ldr x7, [x0, #0x38] + ldr x8, [x0, #0x8] + ldr x17, [x0, #0x10] + ldr x16, [x0, #0x30] + ldr x15, [x0, #0x0] + mov x14, x20 + ldr x13, [x0, #0x20] + madd x6, x7, x6, x21 + ldr x12, [x0, #0x18] + cmp x14, #0x10 + blt label_14 +KAI_ASM_LABEL(label_1) // Row loop + mov x11, x17 + mov x10, x16 + add x9, x15, x13, LSL #4 +KAI_ASM_LABEL(label_2) // Column loop + mov x27, x8 + movi v31.4s, #0x0 + movi v30.4s, #0x0 + mov x23, x7 + movi v29.4s, #0x0 + movi v28.4s, #0x0 + movi v27.4s, #0x0 + movi v26.4s, #0x0 + add x22, x27, x6 + add x21, x22, x6 + add x20, x21, x6 + movi v25.4s, #0x0 + movi v24.4s, #0x0 + movi v23.4s, #0x0 + movi v22.4s, #0x0 + movi v21.4s, #0x0 + movi v20.4s, #0x0 + movi v19.4s, #0x0 + movi v18.4s, #0x0 + movi v17.4s, #0x0 + movi v16.4s, #0x0 +KAI_ASM_LABEL(label_3) // Sub block loop + ldr q11, [x11, #0x0] + ldr q8, [x27, #0x0] + movi v6.16b, #0xf0 + subs x23, x23, #0x1 + ldr q1, [x22, #0x0] + ldr q15, [x21, #0x0] + ldr q3, [x20, #0x0] + ldr q13, [x11, #0x10] + ldr q10, [x27, #0x10] + ldr q7, [x22, #0x10] + shl v9.16b, v11.16b, #0x4 + and v11.16b, v11.16b, v6.16b + ldr q4, [x21, #0x10] + ldr q0, [x20, #0x10] + ldr q5, [x11, #0x20] + ldr q2, [x27, #0x20] + shl v12.16b, v13.16b, #0x4 + and v13.16b, v13.16b, v6.16b + ldr q14, [x22, #0x20] + KAI_ASM_INST(0x4f88e13f) // sdot v31.4s, v9.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e13e) // sdot v30.4s, v9.16b, v8.4b[1] + KAI_ASM_INST(0x4f88e93d) // sdot v29.4s, v9.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8e93c) // sdot v28.4s, v9.16b, v8.4b[3] + ldr q8, [x21, #0x20] + KAI_ASM_INST(0x4f81e13b) // sdot v27.4s, v9.16b, v1.4b[0] + KAI_ASM_INST(0x4fa1e13a) // sdot v26.4s, v9.16b, v1.4b[1] + KAI_ASM_INST(0x4f81e939) // sdot v25.4s, v9.16b, v1.4b[2] + KAI_ASM_INST(0x4fa1e938) // sdot v24.4s, v9.16b, v1.4b[3] + ldr q1, [x20, #0x20] + KAI_ASM_INST(0x4f8fe137) // sdot v23.4s, v9.16b, v15.4b[0] + KAI_ASM_INST(0x4fafe136) // sdot v22.4s, v9.16b, v15.4b[1] + KAI_ASM_INST(0x4f8fe935) // sdot v21.4s, v9.16b, v15.4b[2] + KAI_ASM_INST(0x4fafe934) // sdot v20.4s, v9.16b, v15.4b[3] + ldr q15, [x11, #0x30] + add x11, x11, #0x40 + KAI_ASM_INST(0x4f83e133) // sdot v19.4s, v9.16b, v3.4b[0] + KAI_ASM_INST(0x4fa3e132) // sdot v18.4s, v9.16b, v3.4b[1] + KAI_ASM_INST(0x4f83e931) // sdot v17.4s, v9.16b, v3.4b[2] + KAI_ASM_INST(0x4fa3e930) // sdot v16.4s, v9.16b, v3.4b[3] + ldr q3, [x27, #0x30] + ldr q9, [x22, #0x30] + KAI_ASM_INST(0x4f8ae19f) // sdot v31.4s, v12.16b, v10.4b[0] + KAI_ASM_INST(0x4faae19e) // sdot v30.4s, v12.16b, v10.4b[1] + KAI_ASM_INST(0x4f8ae99d) // sdot v29.4s, v12.16b, v10.4b[2] + KAI_ASM_INST(0x4faae99c) // sdot v28.4s, v12.16b, v10.4b[3] + ldr q10, [x21, #0x30] + KAI_ASM_INST(0x4f87e19b) // sdot v27.4s, v12.16b, v7.4b[0] + KAI_ASM_INST(0x4fa7e19a) // sdot v26.4s, v12.16b, v7.4b[1] + KAI_ASM_INST(0x4f87e999) // sdot v25.4s, v12.16b, v7.4b[2] + KAI_ASM_INST(0x4fa7e998) // sdot v24.4s, v12.16b, v7.4b[3] + ldr q7, [x20, #0x30] + KAI_ASM_INST(0x4f84e197) // sdot v23.4s, v12.16b, v4.4b[0] + KAI_ASM_INST(0x4fa4e196) // sdot v22.4s, v12.16b, v4.4b[1] + KAI_ASM_INST(0x4f84e995) // sdot v21.4s, v12.16b, v4.4b[2] + KAI_ASM_INST(0x4fa4e994) // sdot v20.4s, v12.16b, v4.4b[3] + ldr q4, [x27, #0x40] + KAI_ASM_INST(0x4f80e193) // sdot v19.4s, v12.16b, v0.4b[0] + KAI_ASM_INST(0x4fa0e192) // sdot v18.4s, v12.16b, v0.4b[1] + 
KAI_ASM_INST(0x4f80e991) // sdot v17.4s, v12.16b, v0.4b[2] + KAI_ASM_INST(0x4fa0e990) // sdot v16.4s, v12.16b, v0.4b[3] + ldr q0, [x22, #0x40] + shl v12.16b, v5.16b, #0x4 + and v5.16b, v5.16b, v6.16b + KAI_ASM_INST(0x4f82e19f) // sdot v31.4s, v12.16b, v2.4b[0] + KAI_ASM_INST(0x4fa2e19e) // sdot v30.4s, v12.16b, v2.4b[1] + KAI_ASM_INST(0x4f82e99d) // sdot v29.4s, v12.16b, v2.4b[2] + KAI_ASM_INST(0x4fa2e99c) // sdot v28.4s, v12.16b, v2.4b[3] + ldr q2, [x21, #0x40] + KAI_ASM_INST(0x4f8ee19b) // sdot v27.4s, v12.16b, v14.4b[0] + KAI_ASM_INST(0x4faee19a) // sdot v26.4s, v12.16b, v14.4b[1] + KAI_ASM_INST(0x4f8ee999) // sdot v25.4s, v12.16b, v14.4b[2] + KAI_ASM_INST(0x4faee998) // sdot v24.4s, v12.16b, v14.4b[3] + ldr q14, [x20, #0x40] + KAI_ASM_INST(0x4f88e197) // sdot v23.4s, v12.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e196) // sdot v22.4s, v12.16b, v8.4b[1] + KAI_ASM_INST(0x4f88e995) // sdot v21.4s, v12.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8e994) // sdot v20.4s, v12.16b, v8.4b[3] + ldr q8, [x27, #0x50] + KAI_ASM_INST(0x4f81e193) // sdot v19.4s, v12.16b, v1.4b[0] + KAI_ASM_INST(0x4fa1e192) // sdot v18.4s, v12.16b, v1.4b[1] + KAI_ASM_INST(0x4f81e991) // sdot v17.4s, v12.16b, v1.4b[2] + KAI_ASM_INST(0x4fa1e990) // sdot v16.4s, v12.16b, v1.4b[3] + ldr q12, [x22, #0x50] + shl v1.16b, v15.16b, #0x4 + and v15.16b, v15.16b, v6.16b + ldr q6, [x21, #0x50] + KAI_ASM_INST(0x4f83e03f) // sdot v31.4s, v1.16b, v3.4b[0] + KAI_ASM_INST(0x4fa3e03e) // sdot v30.4s, v1.16b, v3.4b[1] + KAI_ASM_INST(0x4f83e83d) // sdot v29.4s, v1.16b, v3.4b[2] + KAI_ASM_INST(0x4fa3e83c) // sdot v28.4s, v1.16b, v3.4b[3] + ldr q3, [x20, #0x50] + KAI_ASM_INST(0x4f89e03b) // sdot v27.4s, v1.16b, v9.4b[0] + KAI_ASM_INST(0x4fa9e03a) // sdot v26.4s, v1.16b, v9.4b[1] + KAI_ASM_INST(0x4f89e839) // sdot v25.4s, v1.16b, v9.4b[2] + KAI_ASM_INST(0x4fa9e838) // sdot v24.4s, v1.16b, v9.4b[3] + ldr q9, [x27, #0x60] + KAI_ASM_INST(0x4f8ae037) // sdot v23.4s, v1.16b, v10.4b[0] + KAI_ASM_INST(0x4faae036) // sdot v22.4s, v1.16b, v10.4b[1] + KAI_ASM_INST(0x4f8ae835) // sdot v21.4s, v1.16b, v10.4b[2] + KAI_ASM_INST(0x4faae834) // sdot v20.4s, v1.16b, v10.4b[3] + ldr q10, [x22, #0x60] + KAI_ASM_INST(0x4f87e033) // sdot v19.4s, v1.16b, v7.4b[0] + KAI_ASM_INST(0x4fa7e032) // sdot v18.4s, v1.16b, v7.4b[1] + KAI_ASM_INST(0x4f87e831) // sdot v17.4s, v1.16b, v7.4b[2] + KAI_ASM_INST(0x4fa7e830) // sdot v16.4s, v1.16b, v7.4b[3] + ldr q7, [x21, #0x60] + ldr q1, [x20, #0x60] + KAI_ASM_INST(0x4f84e17f) // sdot v31.4s, v11.16b, v4.4b[0] + KAI_ASM_INST(0x4fa4e17e) // sdot v30.4s, v11.16b, v4.4b[1] + KAI_ASM_INST(0x4f84e97d) // sdot v29.4s, v11.16b, v4.4b[2] + KAI_ASM_INST(0x4fa4e97c) // sdot v28.4s, v11.16b, v4.4b[3] + ldr q4, [x27, #0x70] + add x27, x27, #0x80 + KAI_ASM_INST(0x4f80e17b) // sdot v27.4s, v11.16b, v0.4b[0] + KAI_ASM_INST(0x4fa0e17a) // sdot v26.4s, v11.16b, v0.4b[1] + KAI_ASM_INST(0x4f80e979) // sdot v25.4s, v11.16b, v0.4b[2] + KAI_ASM_INST(0x4fa0e978) // sdot v24.4s, v11.16b, v0.4b[3] + ldr q0, [x22, #0x70] + add x22, x22, #0x80 + KAI_ASM_INST(0x4f82e177) // sdot v23.4s, v11.16b, v2.4b[0] + KAI_ASM_INST(0x4fa2e176) // sdot v22.4s, v11.16b, v2.4b[1] + KAI_ASM_INST(0x4f82e975) // sdot v21.4s, v11.16b, v2.4b[2] + KAI_ASM_INST(0x4fa2e974) // sdot v20.4s, v11.16b, v2.4b[3] + ldr q2, [x21, #0x70] + add x21, x21, #0x80 + KAI_ASM_INST(0x4f8ee173) // sdot v19.4s, v11.16b, v14.4b[0] + KAI_ASM_INST(0x4faee172) // sdot v18.4s, v11.16b, v14.4b[1] + KAI_ASM_INST(0x4f8ee971) // sdot v17.4s, v11.16b, v14.4b[2] + KAI_ASM_INST(0x4faee970) // sdot v16.4s, v11.16b, v14.4b[3] + 
ldr q11, [x20, #0x70] + add x20, x20, #0x80 + KAI_ASM_INST(0x4f88e1bf) // sdot v31.4s, v13.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e1be) // sdot v30.4s, v13.16b, v8.4b[1] + KAI_ASM_INST(0x4f88e9bd) // sdot v29.4s, v13.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8e9bc) // sdot v28.4s, v13.16b, v8.4b[3] + KAI_ASM_INST(0x4f8ce1bb) // sdot v27.4s, v13.16b, v12.4b[0] + KAI_ASM_INST(0x4face1ba) // sdot v26.4s, v13.16b, v12.4b[1] + KAI_ASM_INST(0x4f8ce9b9) // sdot v25.4s, v13.16b, v12.4b[2] + KAI_ASM_INST(0x4face9b8) // sdot v24.4s, v13.16b, v12.4b[3] + KAI_ASM_INST(0x4f86e1b7) // sdot v23.4s, v13.16b, v6.4b[0] + KAI_ASM_INST(0x4fa6e1b6) // sdot v22.4s, v13.16b, v6.4b[1] + KAI_ASM_INST(0x4f86e9b5) // sdot v21.4s, v13.16b, v6.4b[2] + KAI_ASM_INST(0x4fa6e9b4) // sdot v20.4s, v13.16b, v6.4b[3] + KAI_ASM_INST(0x4f83e1b3) // sdot v19.4s, v13.16b, v3.4b[0] + KAI_ASM_INST(0x4fa3e1b2) // sdot v18.4s, v13.16b, v3.4b[1] + KAI_ASM_INST(0x4f83e9b1) // sdot v17.4s, v13.16b, v3.4b[2] + KAI_ASM_INST(0x4fa3e9b0) // sdot v16.4s, v13.16b, v3.4b[3] + KAI_ASM_INST(0x4f89e0bf) // sdot v31.4s, v5.16b, v9.4b[0] + KAI_ASM_INST(0x4fa9e0be) // sdot v30.4s, v5.16b, v9.4b[1] + KAI_ASM_INST(0x4f89e8bd) // sdot v29.4s, v5.16b, v9.4b[2] + KAI_ASM_INST(0x4fa9e8bc) // sdot v28.4s, v5.16b, v9.4b[3] + KAI_ASM_INST(0x4f8ae0bb) // sdot v27.4s, v5.16b, v10.4b[0] + KAI_ASM_INST(0x4faae0ba) // sdot v26.4s, v5.16b, v10.4b[1] + KAI_ASM_INST(0x4f8ae8b9) // sdot v25.4s, v5.16b, v10.4b[2] + KAI_ASM_INST(0x4faae8b8) // sdot v24.4s, v5.16b, v10.4b[3] + KAI_ASM_INST(0x4f87e0b7) // sdot v23.4s, v5.16b, v7.4b[0] + KAI_ASM_INST(0x4fa7e0b6) // sdot v22.4s, v5.16b, v7.4b[1] + KAI_ASM_INST(0x4f87e8b5) // sdot v21.4s, v5.16b, v7.4b[2] + KAI_ASM_INST(0x4fa7e8b4) // sdot v20.4s, v5.16b, v7.4b[3] + KAI_ASM_INST(0x4f81e0b3) // sdot v19.4s, v5.16b, v1.4b[0] + KAI_ASM_INST(0x4fa1e0b2) // sdot v18.4s, v5.16b, v1.4b[1] + KAI_ASM_INST(0x4f81e8b1) // sdot v17.4s, v5.16b, v1.4b[2] + KAI_ASM_INST(0x4fa1e8b0) // sdot v16.4s, v5.16b, v1.4b[3] + KAI_ASM_INST(0x4f84e1ff) // sdot v31.4s, v15.16b, v4.4b[0] + KAI_ASM_INST(0x4fa4e1fe) // sdot v30.4s, v15.16b, v4.4b[1] + KAI_ASM_INST(0x4f84e9fd) // sdot v29.4s, v15.16b, v4.4b[2] + KAI_ASM_INST(0x4fa4e9fc) // sdot v28.4s, v15.16b, v4.4b[3] + KAI_ASM_INST(0x4f80e1fb) // sdot v27.4s, v15.16b, v0.4b[0] + KAI_ASM_INST(0x4fa0e1fa) // sdot v26.4s, v15.16b, v0.4b[1] + KAI_ASM_INST(0x4f80e9f9) // sdot v25.4s, v15.16b, v0.4b[2] + KAI_ASM_INST(0x4fa0e9f8) // sdot v24.4s, v15.16b, v0.4b[3] + KAI_ASM_INST(0x4f82e1f7) // sdot v23.4s, v15.16b, v2.4b[0] + KAI_ASM_INST(0x4fa2e1f6) // sdot v22.4s, v15.16b, v2.4b[1] + KAI_ASM_INST(0x4f82e9f5) // sdot v21.4s, v15.16b, v2.4b[2] + KAI_ASM_INST(0x4fa2e9f4) // sdot v20.4s, v15.16b, v2.4b[3] + KAI_ASM_INST(0x4f8be1f3) // sdot v19.4s, v15.16b, v11.4b[0] + KAI_ASM_INST(0x4fabe1f2) // sdot v18.4s, v15.16b, v11.4b[1] + KAI_ASM_INST(0x4f8be9f1) // sdot v17.4s, v15.16b, v11.4b[2] + KAI_ASM_INST(0x4fabe9f0) // sdot v16.4s, v15.16b, v11.4b[3] + bgt label_3 + ldr q5, [x11, #0x0] + ld1 { v1.4s }, [x27] + add x27, x27, #0x10 + ldr q4, [x11, #0x10] + ldr q0, [x27, #0x0] + add x11, x11, #0x20 + mla v31.4s, v5.4s, v1.s[0] + mla v30.4s, v5.4s, v1.s[1] + mla v29.4s, v5.4s, v1.s[2] + mla v28.4s, v5.4s, v1.s[3] + fmul v3.4s, v4.4s, v0.s[0] + fmul v2.4s, v4.4s, v0.s[1] + fmul v1.4s, v4.4s, v0.s[2] + scvtf v31.4s, v31.4s + fmul v0.4s, v4.4s, v0.s[3] + scvtf v30.4s, v30.4s + scvtf v29.4s, v29.4s + scvtf v28.4s, v28.4s + fmul v31.4s, v31.4s, v3.4s + fmul v30.4s, v30.4s, v2.4s + fmul v29.4s, v29.4s, v1.4s + fmul v28.4s, v28.4s, 
v0.4s + ld1 { v1.4s }, [x22] + add x22, x22, #0x10 + ldr q0, [x22, #0x0] + mla v27.4s, v5.4s, v1.s[0] + mla v26.4s, v5.4s, v1.s[1] + mla v25.4s, v5.4s, v1.s[2] + mla v24.4s, v5.4s, v1.s[3] + fmul v3.4s, v4.4s, v0.s[0] + fmul v2.4s, v4.4s, v0.s[1] + fmul v1.4s, v4.4s, v0.s[2] + scvtf v27.4s, v27.4s + fmul v0.4s, v4.4s, v0.s[3] + scvtf v26.4s, v26.4s + scvtf v25.4s, v25.4s + scvtf v24.4s, v24.4s + fmul v27.4s, v27.4s, v3.4s + fmul v26.4s, v26.4s, v2.4s + fmul v25.4s, v25.4s, v1.4s + fmul v24.4s, v24.4s, v0.4s + ld1 { v1.4s }, [x21] + add x21, x21, #0x10 + ldr q0, [x21, #0x0] + mla v23.4s, v5.4s, v1.s[0] + mla v22.4s, v5.4s, v1.s[1] + mla v21.4s, v5.4s, v1.s[2] + mla v20.4s, v5.4s, v1.s[3] + fmul v3.4s, v4.4s, v0.s[0] + fmul v2.4s, v4.4s, v0.s[1] + fmul v1.4s, v4.4s, v0.s[2] + scvtf v23.4s, v23.4s + fmul v0.4s, v4.4s, v0.s[3] + scvtf v22.4s, v22.4s + scvtf v21.4s, v21.4s + scvtf v20.4s, v20.4s + fmul v23.4s, v23.4s, v3.4s + fmul v22.4s, v22.4s, v2.4s + fmul v21.4s, v21.4s, v1.4s + fmul v20.4s, v20.4s, v0.4s + ld1 { v1.4s }, [x20] + add x20, x20, #0x10 + ldr q0, [x20, #0x0] + mla v19.4s, v5.4s, v1.s[0] + mla v18.4s, v5.4s, v1.s[1] + mla v17.4s, v5.4s, v1.s[2] + mla v16.4s, v5.4s, v1.s[3] + fmul v3.4s, v4.4s, v0.s[0] + fmul v2.4s, v4.4s, v0.s[1] + fmul v1.4s, v4.4s, v0.s[2] + scvtf v19.4s, v19.4s + fmul v0.4s, v4.4s, v0.s[3] + scvtf v18.4s, v18.4s + scvtf v17.4s, v17.4s + scvtf v16.4s, v16.4s + fmul v19.4s, v19.4s, v3.4s + fmul v18.4s, v18.4s, v2.4s + fmul v17.4s, v17.4s, v1.4s + fmul v16.4s, v16.4s, v0.4s + ldr q2, [x11, #0x0] + ld1r { v1.4s }, [x12] + add x20, x12, #0x4 + cmp x10, #0x4 + ld1r { v0.4s }, [x20] + add x11, x11, #0x10 + fadd v31.4s, v31.4s, v2.4s + fadd v30.4s, v30.4s, v2.4s + fadd v29.4s, v29.4s, v2.4s + fadd v28.4s, v28.4s, v2.4s + fadd v27.4s, v27.4s, v2.4s + fadd v26.4s, v26.4s, v2.4s + fadd v25.4s, v25.4s, v2.4s + fadd v24.4s, v24.4s, v2.4s + fadd v23.4s, v23.4s, v2.4s + fadd v22.4s, v22.4s, v2.4s + fadd v21.4s, v21.4s, v2.4s + fadd v20.4s, v20.4s, v2.4s + fadd v19.4s, v19.4s, v2.4s + fadd v18.4s, v18.4s, v2.4s + fadd v17.4s, v17.4s, v2.4s + fadd v16.4s, v16.4s, v2.4s + fmax v31.4s, v31.4s, v1.4s + fmax v30.4s, v30.4s, v1.4s + fmax v29.4s, v29.4s, v1.4s + fmax v28.4s, v28.4s, v1.4s + fmax v27.4s, v27.4s, v1.4s + fmax v26.4s, v26.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + fmax v24.4s, v24.4s, v1.4s + fmax v23.4s, v23.4s, v1.4s + fmax v22.4s, v22.4s, v1.4s + fmax v21.4s, v21.4s, v1.4s + fmax v20.4s, v20.4s, v1.4s + fmax v19.4s, v19.4s, v1.4s + fmax v18.4s, v18.4s, v1.4s + fmax v17.4s, v17.4s, v1.4s + fmax v16.4s, v16.4s, v1.4s + fmin v31.4s, v31.4s, v0.4s + fmin v30.4s, v30.4s, v0.4s + fmin v29.4s, v29.4s, v0.4s + fmin v28.4s, v28.4s, v0.4s + fmin v27.4s, v27.4s, v0.4s + fmin v26.4s, v26.4s, v0.4s + fmin v25.4s, v25.4s, v0.4s + fmin v24.4s, v24.4s, v0.4s + fmin v23.4s, v23.4s, v0.4s + fmin v22.4s, v22.4s, v0.4s + fmin v21.4s, v21.4s, v0.4s + fmin v20.4s, v20.4s, v0.4s + fmin v19.4s, v19.4s, v0.4s + fmin v18.4s, v18.4s, v0.4s + fmin v17.4s, v17.4s, v0.4s + fmin v16.4s, v16.4s, v0.4s + fcvtn v31.4h, v31.4s + fcvtn v30.4h, v30.4s + fcvtn v29.4h, v29.4s + fcvtn v28.4h, v28.4s + fcvtn v27.4h, v27.4s + fcvtn v26.4h, v26.4s + fcvtn v25.4h, v25.4s + fcvtn v24.4h, v24.4s + fcvtn v23.4h, v23.4s + fcvtn v22.4h, v22.4s + fcvtn v21.4h, v21.4s + fcvtn v20.4h, v20.4s + fcvtn v19.4h, v19.4s + fcvtn v18.4h, v18.4s + fcvtn v17.4h, v17.4s + fcvtn v16.4h, v16.4s + blt label_8 + mov x20, x15 + str d31, [x20, #0x0] + add x20, x20, x13 + str d30, [x20, #0x0] + add x20, x20, x13 + str d29, [x20, #0x0] 
+ add x20, x20, x13 + str d28, [x20, #0x0] + add x20, x20, x13 + str d27, [x20, #0x0] + add x20, x20, x13 + str d26, [x20, #0x0] + add x20, x20, x13 + str d25, [x20, #0x0] + add x20, x20, x13 + str d24, [x20, #0x0] + add x20, x20, x13 + str d23, [x20, #0x0] + add x20, x20, x13 + str d22, [x20, #0x0] + add x20, x20, x13 + str d21, [x20, #0x0] + add x20, x20, x13 + str d20, [x20, #0x0] + add x20, x20, x13 + str d19, [x20, #0x0] + add x20, x20, x13 + str d18, [x20, #0x0] + add x20, x20, x13 + str d17, [x20, #0x0] + add x20, x20, x13 + str d16, [x20, #0x0] + b label_13 +KAI_ASM_LABEL(label_8) // Partial output + mov x28, x15 + add x26, x28, x13, LSL #2 + add x25, x26, x13, LSL #1 + add x24, x26, x13 + add x23, x25, x13 + add x22, x28, x13, LSL #1 + add x21, x28, x13 + add x20, x22, x13 + add x27, x23, x13 + tbz x10, #1, label_9 + st1 { v24.s }[0], [x23], #0x4 + st1 { v25.s }[0], [x25], #0x4 + st1 { v26.s }[0], [x24], #0x4 + st1 { v27.s }[0], [x26], #0x4 + st1 { v28.s }[0], [x20], #0x4 + st1 { v29.s }[0], [x22], #0x4 + st1 { v30.s }[0], [x21], #0x4 + st1 { v31.s }[0], [x28], #0x4 + tbz x10, #0, label_10 + st1 { v24.h }[2], [x23] + st1 { v25.h }[2], [x25] + st1 { v26.h }[2], [x24] + st1 { v27.h }[2], [x26] + st1 { v28.h }[2], [x20] + st1 { v29.h }[2], [x22] + st1 { v30.h }[2], [x21] + st1 { v31.h }[2], [x28] + b label_10 +KAI_ASM_LABEL(label_9) // Output block 0: partial_1_0 + st1 { v24.h }[0], [x23] + st1 { v25.h }[0], [x25] + st1 { v26.h }[0], [x24] + st1 { v27.h }[0], [x26] + st1 { v28.h }[0], [x20] + st1 { v29.h }[0], [x22] + st1 { v30.h }[0], [x21] + st1 { v31.h }[0], [x28] +KAI_ASM_LABEL(label_10) // Output block 0: Done + add x26, x27, x13, LSL #2 + add x25, x27, x13, LSL #1 + add x24, x26, x13, LSL #1 + add x23, x27, x13 + add x22, x25, x13 + add x21, x26, x13 + add x20, x24, x13 + tbz x10, #1, label_11 + st1 { v16.s }[0], [x20], #0x4 + st1 { v17.s }[0], [x24], #0x4 + st1 { v18.s }[0], [x21], #0x4 + st1 { v19.s }[0], [x26], #0x4 + st1 { v20.s }[0], [x22], #0x4 + st1 { v21.s }[0], [x25], #0x4 + st1 { v22.s }[0], [x23], #0x4 + st1 { v23.s }[0], [x27], #0x4 + tbz x10, #0, label_12 + st1 { v16.h }[2], [x20] + st1 { v17.h }[2], [x24] + st1 { v18.h }[2], [x21] + st1 { v19.h }[2], [x26] + st1 { v20.h }[2], [x22] + st1 { v21.h }[2], [x25] + st1 { v22.h }[2], [x23] + st1 { v23.h }[2], [x27] + b label_12 +KAI_ASM_LABEL(label_11) // Output block 1: partial_1_0 + st1 { v16.h }[0], [x20] + st1 { v17.h }[0], [x24] + st1 { v18.h }[0], [x21] + st1 { v19.h }[0], [x26] + st1 { v20.h }[0], [x22] + st1 { v21.h }[0], [x25] + st1 { v22.h }[0], [x23] + st1 { v23.h }[0], [x27] +KAI_ASM_LABEL(label_12) // Output block 1: Done +KAI_ASM_LABEL(label_13) // Output stage exit + subs x10, x10, #0x4 + add x15, x15, #0x8 + bgt label_2 + mov x20, #0x4 + sub x14, x14, #0x10 + cmp x14, #0x10 + mov x15, x9 + madd x8, x20, x6, x8 + bge label_1 +KAI_ASM_LABEL(label_14) // Row loop skip + cbz x14, label_23 +KAI_ASM_LABEL(label_15) // Row tail: Row loop + mov x26, x17 + mov x25, x16 + add x24, x15, x13, LSL #2 +KAI_ASM_LABEL(label_16) // Row tail: Column loop + mov x27, x8 + movi v31.4s, #0x0 + movi v30.4s, #0x0 + mov x20, x7 + movi v29.4s, #0x0 + movi v28.4s, #0x0 +KAI_ASM_LABEL(label_17) // Row tail: Sub block loop + ldr q4, [x26, #0x0] + ldr q3, [x27, #0x0] + movi v2.16b, #0xf0 + subs x20, x20, #0x1 + ldr q1, [x26, #0x10] + ldr q0, [x27, #0x10] + ldr q27, [x26, #0x20] + ldr q26, [x27, #0x20] + ldr q25, [x26, #0x30] + ldr q24, [x27, #0x30] + shl v23.16b, v4.16b, #0x4 + and v4.16b, v4.16b, v2.16b + ldr q22, [x27, #0x40] + ldr 
q21, [x27, #0x50] + shl v20.16b, v1.16b, #0x4 + and v1.16b, v1.16b, v2.16b + ldr q19, [x27, #0x60] + ldr q18, [x27, #0x70] + shl v17.16b, v27.16b, #0x4 + and v27.16b, v27.16b, v2.16b + KAI_ASM_INST(0x4f83e2ff) // sdot v31.4s, v23.16b, v3.4b[0] + KAI_ASM_INST(0x4fa3e2fe) // sdot v30.4s, v23.16b, v3.4b[1] + shl v16.16b, v25.16b, #0x4 + add x26, x26, #0x40 + KAI_ASM_INST(0x4f83eafd) // sdot v29.4s, v23.16b, v3.4b[2] + KAI_ASM_INST(0x4fa3eafc) // sdot v28.4s, v23.16b, v3.4b[3] + and v25.16b, v25.16b, v2.16b + add x27, x27, #0x80 + KAI_ASM_INST(0x4f80e29f) // sdot v31.4s, v20.16b, v0.4b[0] + KAI_ASM_INST(0x4fa0e29e) // sdot v30.4s, v20.16b, v0.4b[1] + KAI_ASM_INST(0x4f80ea9d) // sdot v29.4s, v20.16b, v0.4b[2] + KAI_ASM_INST(0x4fa0ea9c) // sdot v28.4s, v20.16b, v0.4b[3] + KAI_ASM_INST(0x4f9ae23f) // sdot v31.4s, v17.16b, v26.4b[0] + KAI_ASM_INST(0x4fbae23e) // sdot v30.4s, v17.16b, v26.4b[1] + KAI_ASM_INST(0x4f9aea3d) // sdot v29.4s, v17.16b, v26.4b[2] + KAI_ASM_INST(0x4fbaea3c) // sdot v28.4s, v17.16b, v26.4b[3] + KAI_ASM_INST(0x4f98e21f) // sdot v31.4s, v16.16b, v24.4b[0] + KAI_ASM_INST(0x4fb8e21e) // sdot v30.4s, v16.16b, v24.4b[1] + KAI_ASM_INST(0x4f98ea1d) // sdot v29.4s, v16.16b, v24.4b[2] + KAI_ASM_INST(0x4fb8ea1c) // sdot v28.4s, v16.16b, v24.4b[3] + KAI_ASM_INST(0x4f96e09f) // sdot v31.4s, v4.16b, v22.4b[0] + KAI_ASM_INST(0x4fb6e09e) // sdot v30.4s, v4.16b, v22.4b[1] + KAI_ASM_INST(0x4f96e89d) // sdot v29.4s, v4.16b, v22.4b[2] + KAI_ASM_INST(0x4fb6e89c) // sdot v28.4s, v4.16b, v22.4b[3] + KAI_ASM_INST(0x4f95e03f) // sdot v31.4s, v1.16b, v21.4b[0] + KAI_ASM_INST(0x4fb5e03e) // sdot v30.4s, v1.16b, v21.4b[1] + KAI_ASM_INST(0x4f95e83d) // sdot v29.4s, v1.16b, v21.4b[2] + KAI_ASM_INST(0x4fb5e83c) // sdot v28.4s, v1.16b, v21.4b[3] + KAI_ASM_INST(0x4f93e37f) // sdot v31.4s, v27.16b, v19.4b[0] + KAI_ASM_INST(0x4fb3e37e) // sdot v30.4s, v27.16b, v19.4b[1] + KAI_ASM_INST(0x4f93eb7d) // sdot v29.4s, v27.16b, v19.4b[2] + KAI_ASM_INST(0x4fb3eb7c) // sdot v28.4s, v27.16b, v19.4b[3] + KAI_ASM_INST(0x4f92e33f) // sdot v31.4s, v25.16b, v18.4b[0] + KAI_ASM_INST(0x4fb2e33e) // sdot v30.4s, v25.16b, v18.4b[1] + KAI_ASM_INST(0x4f92eb3d) // sdot v29.4s, v25.16b, v18.4b[2] + KAI_ASM_INST(0x4fb2eb3c) // sdot v28.4s, v25.16b, v18.4b[3] + bgt label_17 + ldr q18, [x26, #0x0] + ld1 { v17.4s }, [x27] + add x27, x27, #0x10 + ldr q20, [x26, #0x10] + ldr q16, [x27, #0x0] + add x26, x26, #0x20 + mla v31.4s, v18.4s, v17.s[0] + mla v30.4s, v18.4s, v17.s[1] + mla v29.4s, v18.4s, v17.s[2] + mla v28.4s, v18.4s, v17.s[3] + fmul v19.4s, v20.4s, v16.s[0] + fmul v18.4s, v20.4s, v16.s[1] + fmul v17.4s, v20.4s, v16.s[2] + scvtf v31.4s, v31.4s + fmul v16.4s, v20.4s, v16.s[3] + scvtf v30.4s, v30.4s + scvtf v29.4s, v29.4s + scvtf v28.4s, v28.4s + fmul v31.4s, v31.4s, v19.4s + fmul v30.4s, v30.4s, v18.4s + fmul v29.4s, v29.4s, v17.4s + fmul v28.4s, v28.4s, v16.4s + ldr q18, [x26, #0x0] + ld1r { v17.4s }, [x12] + add x20, x12, #0x4 + cmp x25, #0x4 + ld1r { v16.4s }, [x20] + add x26, x26, #0x10 + fadd v31.4s, v31.4s, v18.4s + fadd v30.4s, v30.4s, v18.4s + fadd v29.4s, v29.4s, v18.4s + fadd v28.4s, v28.4s, v18.4s + fmax v31.4s, v31.4s, v17.4s + fmax v30.4s, v30.4s, v17.4s + fmax v29.4s, v29.4s, v17.4s + fmax v28.4s, v28.4s, v17.4s + fmin v31.4s, v31.4s, v16.4s + fmin v30.4s, v30.4s, v16.4s + fmin v29.4s, v29.4s, v16.4s + fmin v28.4s, v28.4s, v16.4s + fcvtn v19.4h, v31.4s + fcvtn v18.4h, v30.4s + fcvtn v17.4h, v29.4s + fcvtn v16.4h, v28.4s + blt label_19 + mov x20, x15 + cmp x14, #0x1 + str d19, [x20, #0x0] + add x20, x20, x13 + ble 
label_22
+ cmp x14, #0x2
+ str d18, [x20, #0x0]
+ add x20, x20, x13
+ ble label_22
+ cmp x14, #0x3
+ str d17, [x20, #0x0]
+ add x20, x20, x13
+ ble label_22
+ str d16, [x20, #0x0]
+ b label_22
+KAI_ASM_LABEL(label_19) // Row tail: Partial output
+ mov x23, x15
+ cmp x14, #0x1
+ add x22, x23, x13
+ csel x22, x22, x23, GT
+ cmp x14, #0x2
+ add x21, x23, x13, LSL #1
+ csel x21, x21, x22, GT
+ cmp x14, #0x3
+ add x20, x21, x13
+ csel x20, x20, x21, GT
+ tbz x25, #1, label_20
+ st1 { v16.s }[0], [x20], #0x4
+ st1 { v17.s }[0], [x21], #0x4
+ st1 { v18.s }[0], [x22], #0x4
+ st1 { v19.s }[0], [x23], #0x4
+ tbz x25, #0, label_21
+ st1 { v16.h }[2], [x20]
+ st1 { v17.h }[2], [x21]
+ st1 { v18.h }[2], [x22]
+ st1 { v19.h }[2], [x23]
+ b label_21
+KAI_ASM_LABEL(label_20) // Row tail: Output block 0: partial_1_0
+ st1 { v16.h }[0], [x20]
+ st1 { v17.h }[0], [x21]
+ st1 { v18.h }[0], [x22]
+ st1 { v19.h }[0], [x23]
+KAI_ASM_LABEL(label_21) // Row tail: Output block 0: Done
+KAI_ASM_LABEL(label_22) // Row tail: Output stage exit
+ subs x25, x25, #0x4
+ add x15, x15, #0x8
+ bgt label_16
+ subs x14, x14, #0x4
+ add x8, x8, x6
+ mov x15, x24
+ bgt label_15
+KAI_ASM_LABEL(label_23) // Row tail: Row loop skip
+ ldp x22, x23, [sp, 16]
+ ldp x24, x25, [sp, 32]
+ ldp x26, x27, [sp, 48]
+ ldr x28, [sp, 64]
+ ldp d10, d11, [sp, 72]
+ ldp d12, d13, [sp, 88]
+ ldp d14, d15, [sp, 104]
+ ldp d8, d9, [sp, 120]
+ ldp x20, x21, [sp], 144
+ ret
+ KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod)
+
+ KAI_ASM_END
diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm.c
new file mode 100644
index 0000000000000000000000000000000000000000..6b63e6f3955a752f90c3c915a026095bf44fc11e
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm.c
@@ -0,0 +1,165 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) && !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && \
+ !defined(_M_ARM64)
+#error "I8mm extension and fp16 vector arithmetic required to compile this micro-kernel"
+#else // Architectural features check.
+
+#include "kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+typedef struct {
+ void* dst;
+ const void* lhs_packed;
+ const void* rhs_packed;
+ const float* clamp_vals;
+ size_t dst_stride_row;
+ size_t m;
+ size_t n;
+ size_t num_blocks;
+} KernelArgs;
+
+void kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(KernelArgs* args_ptr);
+
+// Compute args
+static const size_t kai_m_step = 16;
+static const size_t kai_n_step = 4;
+// Packing args
+static const size_t kai_mr = 4;
+static const size_t kai_nr = 4;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
+static const size_t kai_num_bytes_qvalue_lhs = 1;
+static const size_t kai_num_bytes_multiplier_lhs = 4;
+static const size_t kai_num_bytes_zp_lhs = 4;
+// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
+// asymmetric))
+static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
+static const size_t kai_num_bytes_multiplier_rhs = 4;
+static const size_t kai_num_bytes_rsum_rhs = 4;
+// DST format args
+static const size_t kai_num_bytes_dst_value = 2;
+// Extra args
+static const size_t kai_num_bytes_bias = 4;
+static const size_t kai_k_multiple_of = 32;
+
+static const size_t kai_bl = 32;
+
+inline static size_t kai_get_k_roundedup(size_t k) {
+ return kai_roundup(k, kai_k_multiple_of);
+}
+
+inline static size_t kai_get_lhs_packed_stride(size_t k) {
+ const size_t k_internal = kai_get_k_roundedup(k);
+ size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
+ // Since the LHS matrix is asymmetric with per-row quantization, we must include
+ // the number of bytes to hold the zero point value
+ lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs;
+
+ return lhs_packed_stride;
+}
+
+inline static size_t kai_get_rhs_packed_stride(size_t k) {
+ const size_t k_internal = kai_get_k_roundedup(k);
+ size_t rhs_packed_stride = kai_nr * (k_internal / kai_num_bytes_recip_qvalue_rhs);
+
+ rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs;
+ // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include
+ // the number of bytes for the reduction sum
+ rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs;
+ // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
+ rhs_packed_stride += kai_nr * kai_num_bytes_bias;
+
+ return rhs_packed_stride;
+}
+
+size_t kai_get_m_step_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(void) {
+ return kai_m_step;
+}
+
+size_t kai_get_n_step_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(void) {
+ return kai_n_step;
+}
+
+size_t kai_get_mr_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(void) {
+ return kai_mr;
+}
+
+size_t kai_get_nr_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(void) {
+ return kai_nr;
+}
+
+size_t kai_get_kr_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(void) {
+ return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(void) {
+ return kai_sr;
+}
+
+size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(size_t m_idx, size_t k) {
+ KAI_ASSUME((m_idx % kai_m_step) == 0);
+
+ return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k);
+}
+
+size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(size_t n_idx, size_t k) {
+ KAI_ASSUME((n_idx % kai_n_step) == 0);
+
+ return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k);
+}
+
+size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(
+ size_t m_idx, size_t n_idx, size_t dst_stride) {
+ KAI_ASSUME((m_idx % kai_m_step) == 0);
+ KAI_ASSUME((n_idx % kai_n_step) == 0);
+
+ return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
+}
+
+size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(size_t m, size_t n) {
+ return m * n * kai_num_bytes_dst_value;
+}
+
+void kai_run_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(
+ size_t m, //
+ size_t n, //
+ size_t k, //
+ const void* restrict lhs_packed, //
+ const void* restrict rhs_packed, //
+ void* restrict dst, // NOLINT(readability-non-const-parameter)
+ size_t dst_stride_row, //
+ size_t dst_stride_col, //
+ float scalar_min, //
+ float scalar_max) {
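+ // Note: the packing routines round k up to a multiple of 32 (kai_k_multiple_of), so num_blocks below counts whole 32-deep sub-blocks (kai_bl) for the assembly kernel.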
+ KAI_ASSUME(dst_stride_col == sizeof(uint16_t));
+ if (m == 0) {
+ return;
+ }
+ const size_t k_internal = kai_get_k_roundedup(k);
+ size_t num_blocks = k_internal / kai_bl;
+ const float clamp_vals[2] = {scalar_min, scalar_max};
+
+ KernelArgs args;
+
+ args.dst = dst;
+ args.lhs_packed = lhs_packed;
+ args.rhs_packed = rhs_packed;
+ args.clamp_vals = clamp_vals;
+ args.dst_stride_row = dst_stride_row;
+ args.m = m;
+ args.n = n;
+ args.num_blocks = num_blocks;
+
+ kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(&args);
+}
+
+#endif // Architectural features check.
diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm.h
new file mode 100644
index 0000000000000000000000000000000000000000..dc8eb0e8fa514ba2375c22d8b711466e18a4f144
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm.h
@@ -0,0 +1,139 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/// Micro-kernel dependencies
+///
+/// -# @ref kai_lhs_quant_pack_qai8dxp_f16_neon to dynamically quantize and pack the LHS matrix in a single step.
+/// -# @ref kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 to pack the RHS NxK matrix.
+/// -# @ref kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 to pack the RHS KxN matrix.
+
+/// --------------------------------------------------
+
+/// Gets the m step value.
+/// The micro-kernel can process any M values. However, the starting M index to
+/// be processed must be a multiple of m step.
+///
+/// @return the m step value
+size_t kai_get_m_step_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(void);
+
+/// Gets the n step value.
+/// The micro-kernel can process any N values. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @return the n step value
+size_t kai_get_n_step_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(void);
+
+/// Gets the mr value, which must be used to pack the LHS matrix.
+///
+/// @return the mr value
+size_t kai_get_mr_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(void);
+
+/// Gets the nr value, which must be used to pack the RHS matrix.
+///
+/// @return the nr value
+size_t kai_get_nr_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(void);
+
+/// Gets the kr value, which must be used to pack the LHS and RHS matrices.
+///
+/// @return the kr value
+size_t kai_get_kr_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(void);
+
+/// Gets the sr value, which must be used to pack the LHS and RHS matrices.
+///
+/// @return the sr value
+size_t kai_get_sr_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(void);
+
+/// Gets the offset in bytes for the packed LHS matrix,
+/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values.
+///
+/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step.
+/// @param[in] k Total number of columns in the LHS matrix (not packed).
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(
+ size_t m_idx, //
+ size_t k); //
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed Quantized Symmetric Signed 4-bit with per-channel quantization (qsi4cx) values.
+///
+/// @param[in] n_idx Column index in the RHS matrix (not packed). It must be a multiple of n_step.
+/// @param[in] k The common dimension between the LHS and RHS matrix (K).
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(
+ size_t n_idx, //
+ size_t k); //
+
+/// Gets the offset in bytes for the DST matrix
+///
+/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(
+ size_t m_idx, //
+ size_t n_idx, //
+ size_t dst_stride); //
+
+/// Gets the size in bytes for the destination (DST) matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm(
+ size_t m, //
+ size_t n); //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed.
+/// RHS matrix: Quantized Symmetric Signed 4-bit with per-channel quantization (qsi4cx) and packed.
+/// Output tile: (rows x cols) = m_step x n_step.
+///
+/// Note: Please refer to the get functions for m_step and n_step for the exact values.
+///
+/// Features used: i8mm
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] n The number of output columns written.
+/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
+/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the
+/// top of this file.
+/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the
+/// top of this file.
+/// @param[out] dst The DST matrix.
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes.
+/// @param[in] scalar_min Min value used to clamp the final result.
+/// @param[in] scalar_max Max value used to clamp the final result.
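+///
+/// Illustrative call flow (a sketch; the pack micro-kernels are the dependencies listed at the top of this file):
+///   1. Query mr, nr, kr and sr through the get functions above.
+///   2. Quantize and pack the LHS with kai_lhs_quant_pack_qai8dxp_f16_neon.
+///   3. Pack the RHS with kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 (or the kxn variant).
+///   4. Call this function with the packed buffers and the F16 destination.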
+void kai_run_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..67c7e1113c4404ba126d02611b39271e7ff357ef --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm_asm.S @@ -0,0 +1,663 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm) + stp x20, x21, [sp, -144]! 
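+ // Spill the remaining callee-saved registers (x22-x28 and d8-d15), then load the kernel arguments from the KernelArgs structure in x0.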
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x6, #0x80 + mov x21, #0x20 + ldr x20, [x0, #0x28] + ldr x7, [x0, #0x38] + ldr x8, [x0, #0x8] + ldr x17, [x0, #0x10] + ldr x16, [x0, #0x30] + ldr x15, [x0, #0x0] + mov x14, x20 + ldr x13, [x0, #0x20] + madd x6, x7, x6, x21 + ldr x12, [x0, #0x18] + cmp x14, #0x10 + blt label_14 +KAI_ASM_LABEL(label_1) // Row loop + mov x11, x17 + mov x10, x16 + add x9, x15, x13, LSL #4 +KAI_ASM_LABEL(label_2) // Column loop + mov x27, x8 + movi v31.4s, #0x0 + movi v30.4s, #0x0 + mov x23, x7 + movi v29.4s, #0x0 + movi v28.4s, #0x0 + movi v27.4s, #0x0 + movi v26.4s, #0x0 + add x22, x27, x6 + add x21, x22, x6 + add x20, x21, x6 + movi v25.4s, #0x0 + movi v24.4s, #0x0 + movi v23.4s, #0x0 + movi v22.4s, #0x0 + movi v21.4s, #0x0 + movi v20.4s, #0x0 + movi v19.4s, #0x0 + movi v18.4s, #0x0 + movi v17.4s, #0x0 + movi v16.4s, #0x0 +KAI_ASM_LABEL(label_3) // Sub block loop + ldr q14, [x11, #0x0] + ldr q12, [x11, #0x10] + movi v1.16b, #0xf0 + subs x23, x23, #0x1 + ldr q11, [x27, #0x0] + ldr q10, [x27, #0x10] + ldr q9, [x22, #0x0] + ldr q8, [x22, #0x10] + ldr q7, [x21, #0x0] + ldr q6, [x21, #0x10] + shl v4.16b, v14.16b, #0x4 + shl v5.16b, v12.16b, #0x4 + ldr q3, [x20, #0x0] + ldr q2, [x20, #0x10] + and v14.16b, v14.16b, v1.16b + and v12.16b, v12.16b, v1.16b + ldr q13, [x11, #0x20] + ldr q15, [x11, #0x30] + add x11, x11, #0x40 + ldr q0, [x27, #0x20] + KAI_ASM_INST(0x4e84a57f) // smmla v31.4s, v11.16b, v4.16b + KAI_ASM_INST(0x4e85a57e) // smmla v30.4s, v11.16b, v5.16b + ldr q11, [x27, #0x30] + KAI_ASM_INST(0x4e84a55d) // smmla v29.4s, v10.16b, v4.16b + KAI_ASM_INST(0x4e85a55c) // smmla v28.4s, v10.16b, v5.16b + ldr q10, [x22, #0x20] + KAI_ASM_INST(0x4e84a53b) // smmla v27.4s, v9.16b, v4.16b + KAI_ASM_INST(0x4e85a53a) // smmla v26.4s, v9.16b, v5.16b + ldr q9, [x22, #0x30] + KAI_ASM_INST(0x4e84a519) // smmla v25.4s, v8.16b, v4.16b + KAI_ASM_INST(0x4e85a518) // smmla v24.4s, v8.16b, v5.16b + ldr q8, [x21, #0x20] + KAI_ASM_INST(0x4e84a4f7) // smmla v23.4s, v7.16b, v4.16b + KAI_ASM_INST(0x4e85a4f6) // smmla v22.4s, v7.16b, v5.16b + ldr q7, [x21, #0x30] + KAI_ASM_INST(0x4e84a4d5) // smmla v21.4s, v6.16b, v4.16b + KAI_ASM_INST(0x4e85a4d4) // smmla v20.4s, v6.16b, v5.16b + ldr q6, [x20, #0x20] + KAI_ASM_INST(0x4e84a473) // smmla v19.4s, v3.16b, v4.16b + KAI_ASM_INST(0x4e85a472) // smmla v18.4s, v3.16b, v5.16b + ldr q3, [x20, #0x30] + KAI_ASM_INST(0x4e84a451) // smmla v17.4s, v2.16b, v4.16b + ldr q4, [x27, #0x40] + KAI_ASM_INST(0x4e85a450) // smmla v16.4s, v2.16b, v5.16b + ldr q2, [x27, #0x50] + shl v5.16b, v13.16b, #0x4 + and v13.16b, v13.16b, v1.16b + KAI_ASM_INST(0x4e85a41f) // smmla v31.4s, v0.16b, v5.16b + KAI_ASM_INST(0x4e85a57d) // smmla v29.4s, v11.16b, v5.16b + KAI_ASM_INST(0x4e85a55b) // smmla v27.4s, v10.16b, v5.16b + KAI_ASM_INST(0x4e85a539) // smmla v25.4s, v9.16b, v5.16b + KAI_ASM_INST(0x4e85a517) // smmla v23.4s, v8.16b, v5.16b + KAI_ASM_INST(0x4e85a4f5) // smmla v21.4s, v7.16b, v5.16b + KAI_ASM_INST(0x4e85a4d3) // smmla v19.4s, v6.16b, v5.16b + KAI_ASM_INST(0x4e85a471) // smmla v17.4s, v3.16b, v5.16b + shl v5.16b, v15.16b, #0x4 + KAI_ASM_INST(0x4e8ea49f) // smmla v31.4s, v4.16b, v14.16b + KAI_ASM_INST(0x4e8ea45d) // smmla v29.4s, v2.16b, v14.16b + and v15.16b, v15.16b, v1.16b + ldr q1, [x22, #0x40] + KAI_ASM_INST(0x4e85a41e) // smmla v30.4s, v0.16b, v5.16b + ldr q0, [x22, #0x50] + KAI_ASM_INST(0x4e85a57c) // smmla 
v28.4s, v11.16b, v5.16b + ldr q11, [x21, #0x40] + KAI_ASM_INST(0x4e85a55a) // smmla v26.4s, v10.16b, v5.16b + ldr q10, [x21, #0x50] + KAI_ASM_INST(0x4e85a538) // smmla v24.4s, v9.16b, v5.16b + ldr q9, [x20, #0x40] + KAI_ASM_INST(0x4e85a516) // smmla v22.4s, v8.16b, v5.16b + ldr q8, [x20, #0x50] + KAI_ASM_INST(0x4e85a4f4) // smmla v20.4s, v7.16b, v5.16b + ldr q7, [x27, #0x60] + KAI_ASM_INST(0x4e85a4d2) // smmla v18.4s, v6.16b, v5.16b + ldr q6, [x27, #0x70] + KAI_ASM_INST(0x4e85a470) // smmla v16.4s, v3.16b, v5.16b + ldr q5, [x22, #0x60] + KAI_ASM_INST(0x4e8ca49e) // smmla v30.4s, v4.16b, v12.16b + ldr q4, [x22, #0x70] + ldr q3, [x21, #0x60] + KAI_ASM_INST(0x4e8ca45c) // smmla v28.4s, v2.16b, v12.16b + KAI_ASM_INST(0x4e8ea43b) // smmla v27.4s, v1.16b, v14.16b + ldr q2, [x21, #0x70] + KAI_ASM_INST(0x4e8ca43a) // smmla v26.4s, v1.16b, v12.16b + ldr q1, [x20, #0x60] + KAI_ASM_INST(0x4e8ea419) // smmla v25.4s, v0.16b, v14.16b + KAI_ASM_INST(0x4e8ca418) // smmla v24.4s, v0.16b, v12.16b + ldr q0, [x20, #0x70] + KAI_ASM_INST(0x4e8ea577) // smmla v23.4s, v11.16b, v14.16b + add x27, x27, #0x80 + KAI_ASM_INST(0x4e8ca576) // smmla v22.4s, v11.16b, v12.16b + KAI_ASM_INST(0x4e8ea555) // smmla v21.4s, v10.16b, v14.16b + add x22, x22, #0x80 + add x21, x21, #0x80 + KAI_ASM_INST(0x4e8ca554) // smmla v20.4s, v10.16b, v12.16b + KAI_ASM_INST(0x4e8ea533) // smmla v19.4s, v9.16b, v14.16b + add x20, x20, #0x80 + KAI_ASM_INST(0x4e8ca532) // smmla v18.4s, v9.16b, v12.16b + KAI_ASM_INST(0x4e8ea511) // smmla v17.4s, v8.16b, v14.16b + KAI_ASM_INST(0x4e8ca510) // smmla v16.4s, v8.16b, v12.16b + KAI_ASM_INST(0x4e8da4ff) // smmla v31.4s, v7.16b, v13.16b + KAI_ASM_INST(0x4e8fa4fe) // smmla v30.4s, v7.16b, v15.16b + KAI_ASM_INST(0x4e8da4dd) // smmla v29.4s, v6.16b, v13.16b + KAI_ASM_INST(0x4e8fa4dc) // smmla v28.4s, v6.16b, v15.16b + KAI_ASM_INST(0x4e8da4bb) // smmla v27.4s, v5.16b, v13.16b + KAI_ASM_INST(0x4e8fa4ba) // smmla v26.4s, v5.16b, v15.16b + KAI_ASM_INST(0x4e8da499) // smmla v25.4s, v4.16b, v13.16b + KAI_ASM_INST(0x4e8fa498) // smmla v24.4s, v4.16b, v15.16b + KAI_ASM_INST(0x4e8da477) // smmla v23.4s, v3.16b, v13.16b + KAI_ASM_INST(0x4e8fa476) // smmla v22.4s, v3.16b, v15.16b + KAI_ASM_INST(0x4e8da455) // smmla v21.4s, v2.16b, v13.16b + KAI_ASM_INST(0x4e8fa454) // smmla v20.4s, v2.16b, v15.16b + KAI_ASM_INST(0x4e8da433) // smmla v19.4s, v1.16b, v13.16b + KAI_ASM_INST(0x4e8fa432) // smmla v18.4s, v1.16b, v15.16b + KAI_ASM_INST(0x4e8da411) // smmla v17.4s, v0.16b, v13.16b + KAI_ASM_INST(0x4e8fa410) // smmla v16.4s, v0.16b, v15.16b + bgt label_3 + ldr q7, [x11, #0x0] + ld1 { v4.4s }, [x27] + uzp1 v3.2d, v31.2d, v30.2d + uzp2 v2.2d, v31.2d, v30.2d + ldr q6, [x11, #0x10] + uzp1 v1.2d, v29.2d, v28.2d + uzp2 v0.2d, v29.2d, v28.2d + add x27, x27, #0x10 + ldr q28, [x27, #0x0] + add x11, x11, #0x20 + mla v3.4s, v7.4s, v4.s[0] + mla v2.4s, v7.4s, v4.s[1] + mla v1.4s, v7.4s, v4.s[2] + mla v0.4s, v7.4s, v4.s[3] + fmul v31.4s, v6.4s, v28.s[0] + fmul v30.4s, v6.4s, v28.s[1] + fmul v29.4s, v6.4s, v28.s[2] + fmul v28.4s, v6.4s, v28.s[3] + scvtf v3.4s, v3.4s + scvtf v2.4s, v2.4s + scvtf v1.4s, v1.4s + scvtf v0.4s, v0.4s + fmul v31.4s, v3.4s, v31.4s + fmul v30.4s, v2.4s, v30.4s + fmul v29.4s, v1.4s, v29.4s + fmul v28.4s, v0.4s, v28.4s + ld1 { v5.4s }, [x22] + uzp1 v4.2d, v27.2d, v26.2d + uzp2 v3.2d, v27.2d, v26.2d + add x22, x22, #0x10 + ldr q2, [x22, #0x0] + uzp1 v1.2d, v25.2d, v24.2d + uzp2 v0.2d, v25.2d, v24.2d + mla v4.4s, v7.4s, v5.s[0] + mla v3.4s, v7.4s, v5.s[1] + mla v1.4s, v7.4s, v5.s[2] + mla v0.4s, v7.4s, v5.s[3] + fmul 
v27.4s, v6.4s, v2.s[0] + fmul v26.4s, v6.4s, v2.s[1] + fmul v25.4s, v6.4s, v2.s[2] + scvtf v4.4s, v4.4s + fmul v24.4s, v6.4s, v2.s[3] + scvtf v3.4s, v3.4s + scvtf v1.4s, v1.4s + scvtf v0.4s, v0.4s + fmul v27.4s, v4.4s, v27.4s + fmul v26.4s, v3.4s, v26.4s + fmul v25.4s, v1.4s, v25.4s + fmul v24.4s, v0.4s, v24.4s + ld1 { v5.4s }, [x21] + uzp1 v4.2d, v23.2d, v22.2d + uzp2 v3.2d, v23.2d, v22.2d + add x21, x21, #0x10 + ldr q2, [x21, #0x0] + uzp1 v1.2d, v21.2d, v20.2d + uzp2 v0.2d, v21.2d, v20.2d + mla v4.4s, v7.4s, v5.s[0] + mla v3.4s, v7.4s, v5.s[1] + mla v1.4s, v7.4s, v5.s[2] + mla v0.4s, v7.4s, v5.s[3] + fmul v23.4s, v6.4s, v2.s[0] + fmul v22.4s, v6.4s, v2.s[1] + fmul v21.4s, v6.4s, v2.s[2] + scvtf v4.4s, v4.4s + fmul v20.4s, v6.4s, v2.s[3] + scvtf v3.4s, v3.4s + scvtf v1.4s, v1.4s + scvtf v0.4s, v0.4s + fmul v23.4s, v4.4s, v23.4s + fmul v22.4s, v3.4s, v22.4s + fmul v21.4s, v1.4s, v21.4s + fmul v20.4s, v0.4s, v20.4s + ld1 { v5.4s }, [x20] + uzp1 v4.2d, v19.2d, v18.2d + uzp2 v3.2d, v19.2d, v18.2d + add x20, x20, #0x10 + ldr q2, [x20, #0x0] + uzp1 v1.2d, v17.2d, v16.2d + uzp2 v0.2d, v17.2d, v16.2d + mla v4.4s, v7.4s, v5.s[0] + mla v3.4s, v7.4s, v5.s[1] + mla v1.4s, v7.4s, v5.s[2] + mla v0.4s, v7.4s, v5.s[3] + fmul v19.4s, v6.4s, v2.s[0] + fmul v18.4s, v6.4s, v2.s[1] + fmul v17.4s, v6.4s, v2.s[2] + scvtf v4.4s, v4.4s + fmul v16.4s, v6.4s, v2.s[3] + scvtf v3.4s, v3.4s + scvtf v1.4s, v1.4s + scvtf v0.4s, v0.4s + fmul v19.4s, v4.4s, v19.4s + fmul v18.4s, v3.4s, v18.4s + fmul v17.4s, v1.4s, v17.4s + fmul v16.4s, v0.4s, v16.4s + ldr q2, [x11, #0x0] + ld1r { v1.4s }, [x12] + add x20, x12, #0x4 + cmp x10, #0x4 + ld1r { v0.4s }, [x20] + add x11, x11, #0x10 + fadd v31.4s, v31.4s, v2.4s + fadd v30.4s, v30.4s, v2.4s + fadd v29.4s, v29.4s, v2.4s + fadd v28.4s, v28.4s, v2.4s + fadd v27.4s, v27.4s, v2.4s + fadd v26.4s, v26.4s, v2.4s + fadd v25.4s, v25.4s, v2.4s + fadd v24.4s, v24.4s, v2.4s + fadd v23.4s, v23.4s, v2.4s + fadd v22.4s, v22.4s, v2.4s + fadd v21.4s, v21.4s, v2.4s + fadd v20.4s, v20.4s, v2.4s + fadd v19.4s, v19.4s, v2.4s + fadd v18.4s, v18.4s, v2.4s + fadd v17.4s, v17.4s, v2.4s + fadd v16.4s, v16.4s, v2.4s + fmax v31.4s, v31.4s, v1.4s + fmax v30.4s, v30.4s, v1.4s + fmax v29.4s, v29.4s, v1.4s + fmax v28.4s, v28.4s, v1.4s + fmax v27.4s, v27.4s, v1.4s + fmax v26.4s, v26.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + fmax v24.4s, v24.4s, v1.4s + fmax v23.4s, v23.4s, v1.4s + fmax v22.4s, v22.4s, v1.4s + fmax v21.4s, v21.4s, v1.4s + fmax v20.4s, v20.4s, v1.4s + fmax v19.4s, v19.4s, v1.4s + fmax v18.4s, v18.4s, v1.4s + fmax v17.4s, v17.4s, v1.4s + fmax v16.4s, v16.4s, v1.4s + fmin v31.4s, v31.4s, v0.4s + fmin v30.4s, v30.4s, v0.4s + fmin v29.4s, v29.4s, v0.4s + fmin v28.4s, v28.4s, v0.4s + fmin v27.4s, v27.4s, v0.4s + fmin v26.4s, v26.4s, v0.4s + fmin v25.4s, v25.4s, v0.4s + fmin v24.4s, v24.4s, v0.4s + fmin v23.4s, v23.4s, v0.4s + fmin v22.4s, v22.4s, v0.4s + fmin v21.4s, v21.4s, v0.4s + fmin v20.4s, v20.4s, v0.4s + fmin v19.4s, v19.4s, v0.4s + fmin v18.4s, v18.4s, v0.4s + fmin v17.4s, v17.4s, v0.4s + fmin v16.4s, v16.4s, v0.4s + fcvtn v31.4h, v31.4s + fcvtn v30.4h, v30.4s + fcvtn v29.4h, v29.4s + fcvtn v28.4h, v28.4s + fcvtn v27.4h, v27.4s + fcvtn v26.4h, v26.4s + fcvtn v25.4h, v25.4s + fcvtn v24.4h, v24.4s + fcvtn v23.4h, v23.4s + fcvtn v22.4h, v22.4s + fcvtn v21.4h, v21.4s + fcvtn v20.4h, v20.4s + fcvtn v19.4h, v19.4s + fcvtn v18.4h, v18.4s + fcvtn v17.4h, v17.4s + fcvtn v16.4h, v16.4s + blt label_8 + mov x20, x15 + str d31, [x20, #0x0] + add x20, x20, x13 + str d30, [x20, #0x0] + add x20, x20, x13 + str 
d29, [x20, #0x0] + add x20, x20, x13 + str d28, [x20, #0x0] + add x20, x20, x13 + str d27, [x20, #0x0] + add x20, x20, x13 + str d26, [x20, #0x0] + add x20, x20, x13 + str d25, [x20, #0x0] + add x20, x20, x13 + str d24, [x20, #0x0] + add x20, x20, x13 + str d23, [x20, #0x0] + add x20, x20, x13 + str d22, [x20, #0x0] + add x20, x20, x13 + str d21, [x20, #0x0] + add x20, x20, x13 + str d20, [x20, #0x0] + add x20, x20, x13 + str d19, [x20, #0x0] + add x20, x20, x13 + str d18, [x20, #0x0] + add x20, x20, x13 + str d17, [x20, #0x0] + add x20, x20, x13 + str d16, [x20, #0x0] + b label_13 +KAI_ASM_LABEL(label_8) // Partial output + mov x28, x15 + add x26, x28, x13, LSL #2 + add x25, x26, x13, LSL #1 + add x24, x26, x13 + add x23, x25, x13 + add x22, x28, x13, LSL #1 + add x21, x28, x13 + add x20, x22, x13 + add x27, x23, x13 + tbz x10, #1, label_9 + st1 { v24.s }[0], [x23], #0x4 + st1 { v25.s }[0], [x25], #0x4 + st1 { v26.s }[0], [x24], #0x4 + st1 { v27.s }[0], [x26], #0x4 + st1 { v28.s }[0], [x20], #0x4 + st1 { v29.s }[0], [x22], #0x4 + st1 { v30.s }[0], [x21], #0x4 + st1 { v31.s }[0], [x28], #0x4 + tbz x10, #0, label_10 + st1 { v24.h }[2], [x23] + st1 { v25.h }[2], [x25] + st1 { v26.h }[2], [x24] + st1 { v27.h }[2], [x26] + st1 { v28.h }[2], [x20] + st1 { v29.h }[2], [x22] + st1 { v30.h }[2], [x21] + st1 { v31.h }[2], [x28] + b label_10 +KAI_ASM_LABEL(label_9) // Output block 0: partial_1_0 + st1 { v24.h }[0], [x23] + st1 { v25.h }[0], [x25] + st1 { v26.h }[0], [x24] + st1 { v27.h }[0], [x26] + st1 { v28.h }[0], [x20] + st1 { v29.h }[0], [x22] + st1 { v30.h }[0], [x21] + st1 { v31.h }[0], [x28] +KAI_ASM_LABEL(label_10) // Output block 0: Done + add x26, x27, x13, LSL #2 + add x25, x27, x13, LSL #1 + add x24, x26, x13, LSL #1 + add x23, x27, x13 + add x22, x25, x13 + add x21, x26, x13 + add x20, x24, x13 + tbz x10, #1, label_11 + st1 { v16.s }[0], [x20], #0x4 + st1 { v17.s }[0], [x24], #0x4 + st1 { v18.s }[0], [x21], #0x4 + st1 { v19.s }[0], [x26], #0x4 + st1 { v20.s }[0], [x22], #0x4 + st1 { v21.s }[0], [x25], #0x4 + st1 { v22.s }[0], [x23], #0x4 + st1 { v23.s }[0], [x27], #0x4 + tbz x10, #0, label_12 + st1 { v16.h }[2], [x20] + st1 { v17.h }[2], [x24] + st1 { v18.h }[2], [x21] + st1 { v19.h }[2], [x26] + st1 { v20.h }[2], [x22] + st1 { v21.h }[2], [x25] + st1 { v22.h }[2], [x23] + st1 { v23.h }[2], [x27] + b label_12 +KAI_ASM_LABEL(label_11) // Output block 1: partial_1_0 + st1 { v16.h }[0], [x20] + st1 { v17.h }[0], [x24] + st1 { v18.h }[0], [x21] + st1 { v19.h }[0], [x26] + st1 { v20.h }[0], [x22] + st1 { v21.h }[0], [x25] + st1 { v22.h }[0], [x23] + st1 { v23.h }[0], [x27] +KAI_ASM_LABEL(label_12) // Output block 1: Done +KAI_ASM_LABEL(label_13) // Output stage exit + subs x10, x10, #0x4 + add x15, x15, #0x8 + bgt label_2 + mov x20, #0x4 + sub x14, x14, #0x10 + cmp x14, #0x10 + mov x15, x9 + madd x8, x20, x6, x8 + bge label_1 +KAI_ASM_LABEL(label_14) // Row loop skip + cbz x14, label_23 +KAI_ASM_LABEL(label_15) // Row tail: Row loop + mov x26, x17 + mov x25, x16 + add x24, x15, x13, LSL #2 +KAI_ASM_LABEL(label_16) // Row tail: Column loop + mov x27, x8 + movi v31.4s, #0x0 + movi v30.4s, #0x0 + mov x20, x7 + movi v29.4s, #0x0 + movi v28.4s, #0x0 +KAI_ASM_LABEL(label_17) // Row tail: Sub block loop + ldr q4, [x26, #0x0] + ldr q3, [x26, #0x10] + movi v2.16b, #0xf0 + subs x20, x20, #0x1 + ldr q1, [x27, #0x0] + ldr q0, [x27, #0x10] + ldr q27, [x26, #0x20] + ldr q26, [x26, #0x30] + add x26, x26, #0x40 + ldr q25, [x27, #0x20] + ldr q24, [x27, #0x30] + shl v23.16b, v4.16b, #0x4 + shl v22.16b, 
v3.16b, #0x4 + ldr q21, [x27, #0x40] + ldr q20, [x27, #0x50] + and v4.16b, v4.16b, v2.16b + and v3.16b, v3.16b, v2.16b + ldr q19, [x27, #0x60] + ldr q18, [x27, #0x70] + shl v17.16b, v27.16b, #0x4 + shl v16.16b, v26.16b, #0x4 + KAI_ASM_INST(0x4e97a43f) // smmla v31.4s, v1.16b, v23.16b + KAI_ASM_INST(0x4e96a43e) // smmla v30.4s, v1.16b, v22.16b + and v27.16b, v27.16b, v2.16b + add x27, x27, #0x80 + KAI_ASM_INST(0x4e97a41d) // smmla v29.4s, v0.16b, v23.16b + KAI_ASM_INST(0x4e96a41c) // smmla v28.4s, v0.16b, v22.16b + and v26.16b, v26.16b, v2.16b + KAI_ASM_INST(0x4e91a73f) // smmla v31.4s, v25.16b, v17.16b + KAI_ASM_INST(0x4e90a73e) // smmla v30.4s, v25.16b, v16.16b + KAI_ASM_INST(0x4e91a71d) // smmla v29.4s, v24.16b, v17.16b + KAI_ASM_INST(0x4e90a71c) // smmla v28.4s, v24.16b, v16.16b + KAI_ASM_INST(0x4e84a6bf) // smmla v31.4s, v21.16b, v4.16b + KAI_ASM_INST(0x4e83a6be) // smmla v30.4s, v21.16b, v3.16b + KAI_ASM_INST(0x4e84a69d) // smmla v29.4s, v20.16b, v4.16b + KAI_ASM_INST(0x4e83a69c) // smmla v28.4s, v20.16b, v3.16b + KAI_ASM_INST(0x4e9ba67f) // smmla v31.4s, v19.16b, v27.16b + KAI_ASM_INST(0x4e9aa67e) // smmla v30.4s, v19.16b, v26.16b + KAI_ASM_INST(0x4e9ba65d) // smmla v29.4s, v18.16b, v27.16b + KAI_ASM_INST(0x4e9aa65c) // smmla v28.4s, v18.16b, v26.16b + bgt label_17 + ldr q18, [x26, #0x0] + ld1 { v17.4s }, [x27] + uzp1 v24.2d, v31.2d, v30.2d + uzp2 v23.2d, v31.2d, v30.2d + ldr q22, [x26, #0x10] + uzp1 v21.2d, v29.2d, v28.2d + uzp2 v20.2d, v29.2d, v28.2d + add x27, x27, #0x10 + ldr q16, [x27, #0x0] + add x26, x26, #0x20 + mla v24.4s, v18.4s, v17.s[0] + mla v23.4s, v18.4s, v17.s[1] + mla v21.4s, v18.4s, v17.s[2] + mla v20.4s, v18.4s, v17.s[3] + fmul v19.4s, v22.4s, v16.s[0] + fmul v18.4s, v22.4s, v16.s[1] + fmul v17.4s, v22.4s, v16.s[2] + fmul v16.4s, v22.4s, v16.s[3] + scvtf v24.4s, v24.4s + scvtf v23.4s, v23.4s + scvtf v21.4s, v21.4s + scvtf v20.4s, v20.4s + fmul v31.4s, v24.4s, v19.4s + fmul v30.4s, v23.4s, v18.4s + fmul v29.4s, v21.4s, v17.4s + fmul v28.4s, v20.4s, v16.4s + ldr q18, [x26, #0x0] + ld1r { v17.4s }, [x12] + add x20, x12, #0x4 + cmp x25, #0x4 + ld1r { v16.4s }, [x20] + add x26, x26, #0x10 + fadd v31.4s, v31.4s, v18.4s + fadd v30.4s, v30.4s, v18.4s + fadd v29.4s, v29.4s, v18.4s + fadd v28.4s, v28.4s, v18.4s + fmax v31.4s, v31.4s, v17.4s + fmax v30.4s, v30.4s, v17.4s + fmax v29.4s, v29.4s, v17.4s + fmax v28.4s, v28.4s, v17.4s + fmin v31.4s, v31.4s, v16.4s + fmin v30.4s, v30.4s, v16.4s + fmin v29.4s, v29.4s, v16.4s + fmin v28.4s, v28.4s, v16.4s + fcvtn v19.4h, v31.4s + fcvtn v18.4h, v30.4s + fcvtn v17.4h, v29.4s + fcvtn v16.4h, v28.4s + blt label_19 + mov x20, x15 + cmp x14, #0x1 + str d19, [x20, #0x0] + add x20, x20, x13 + ble label_22 + cmp x14, #0x2 + str d18, [x20, #0x0] + add x20, x20, x13 + ble label_22 + cmp x14, #0x3 + str d17, [x20, #0x0] + add x20, x20, x13 + ble label_22 + str d16, [x20, #0x0] + b label_22 +KAI_ASM_LABEL(label_19) // Row tail: Partial output + mov x23, x15 + cmp x14, #0x1 + add x22, x23, x13 + csel x22, x22, x23, GT + cmp x14, #0x2 + add x21, x23, x13, LSL #1 + csel x21, x21, x22, GT + cmp x14, #0x3 + add x20, x21, x13 + csel x20, x20, x21, GT + tbz x25, #1, label_20 + st1 { v16.s }[0], [x20], #0x4 + st1 { v17.s }[0], [x21], #0x4 + st1 { v18.s }[0], [x22], #0x4 + st1 { v19.s }[0], [x23], #0x4 + tbz x25, #0, label_21 + st1 { v16.h }[2], [x20] + st1 { v17.h }[2], [x21] + st1 { v18.h }[2], [x22] + st1 { v19.h }[2], [x23] + b label_21 +KAI_ASM_LABEL(label_20) // Row tail: Output block 0: partial_1_0 + st1 { v16.h }[0], [x20] + st1 { v17.h }[0], 
[x21]
+ st1 { v18.h }[0], [x22]
+ st1 { v19.h }[0], [x23]
+KAI_ASM_LABEL(label_21) // Row tail: Output block 0: Done
+KAI_ASM_LABEL(label_22) // Row tail: Output stage exit
+ subs x25, x25, #0x4
+ add x15, x15, #0x8
+ bgt label_16
+ subs x14, x14, #0x4
+ add x8, x8, x6
+ mov x15, x24
+ bgt label_15
+KAI_ASM_LABEL(label_23) // Row tail: Row loop skip
+ ldp x22, x23, [sp, 16]
+ ldp x24, x25, [sp, 32]
+ ldp x26, x27, [sp, 48]
+ ldr x28, [sp, 64]
+ ldp d10, d11, [sp, 72]
+ ldp d12, d13, [sp, 88]
+ ldp d14, d15, [sp, 104]
+ ldp d8, d9, [sp, 120]
+ ldp x20, x21, [sp], 144
+ ret
+ KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm)
+
+ KAI_ASM_END
diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp_qsi4cxp_interface.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp_qsi4cxp_interface.h
new file mode 100644
index 0000000000000000000000000000000000000000..b58487868d624d6fd81d5dbb202c1f6a64d234a8
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp_qsi4cxp_interface.h
@@ -0,0 +1,52 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// All micro-kernel variants of the same type share the same interfaces
+// In this case, the micro-kernel type is: matmul_clamp_f16_qai8dxp_qsi4cxp
+
+/// Micro-kernel helper functions ("get" methods)
+typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_m_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_n_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_mr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_nr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_kr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_sr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_dst_offset_func_t)(
+ size_t m_idx, size_t n_idx, size_t dst_stride);
+typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_dst_size_func_t)(size_t m, size_t n);
+
+/// Micro-kernel core function ("run" method)
+typedef void (*kai_matmul_clamp_f16_qai8dxp_qsi4cxp_run_matmul_func_t)(
+ size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, void* dst, size_t dst_stride_row,
+ size_t dst_stride_col, float scalar_min, float scalar_max);
+
+/// Micro-kernel interface
+struct kai_matmul_clamp_f16_qai8dxp_qsi4cxp_ukernel {
+ kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_m_step_func_t get_m_step;
+ kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_n_step_func_t get_n_step;
+ kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_mr_func_t get_mr;
+ kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_nr_func_t get_nr;
+ kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_kr_func_t get_kr;
+ kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_sr_func_t get_sr;
+ kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_lhs_packed_offset_func_t get_lhs_packed_offset;
+ kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset;
+ kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_dst_offset_func_t get_dst_offset;
+ kai_matmul_clamp_f16_qai8dxp_qsi4cxp_get_dst_size_func_t get_dst_size;
+ kai_matmul_clamp_f16_qai8dxp_qsi4cxp_run_matmul_func_t run_matmul;
+};
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.c
new file mode 100644
index 0000000000000000000000000000000000000000..14f4768edeedddcdf28b201fce6fc505190b028f
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.c
@@ -0,0 +1,237 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || \
+ !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#error This file must be compiled for AArch64, FEAT_FP16.
+#else // Architectural features check.
+
+#include "kai_lhs_quant_pack_qai8dxp_f16_neon.h"
+
+#include <arm_neon.h>
+#include <float.h>
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+#define FLT16_MAX 65504.0
+#define FLT16_MIN (-65504.0F)
+
+static const size_t kai_num_bytes_per_multiplier = sizeof(float);
+static const size_t kai_num_bytes_per_offset = sizeof(int32_t);
+
+inline static size_t kai_k_roundedup(size_t k) {
+ // Round up k to be a multiple of 32.
+ size_t kai_k_multiple_of = 32;
+ return kai_roundup(k, kai_k_multiple_of);
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) {
+ KAI_UNUSED(kr);
+ KAI_UNUSED(sr);
+
+ const size_t k_internal = kai_k_roundedup(k);
+
+ KAI_ASSERT((k_internal % 2) == 0);
+
+ return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
+}
+
+size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f16_neon(size_t mr) {
+ KAI_UNUSED(mr);
+ return 1;
+}
+
+size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f16_neon(size_t m_idx, size_t lhs_stride) {
+ return m_idx * lhs_stride;
+}
+
+size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f16_neon(
+ size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
+ // It always points to the beginning of the row
+ return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr);
+}
+
+size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
+ const size_t num_rows = kai_roundup(m, mr) / mr;
+
+ return num_rows * kai_lhs_packed_stride(k, mr, kr, sr);
+}
+
+void kai_run_lhs_quant_pack_qai8dxp_f16_neon(
+ size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* restrict lhs,
+ size_t lhs_stride, void* restrict lhs_packed) {
+ KAI_ASSERT((kr % sr) == 0);
+ KAI_ASSUME((kr / sr == 8) || (kr / sr == 4));
+
+ if (m == 0) {
+ return;
+ }
+
+ const size_t num_rows = m;
+
+ float16_t const* src_ptr = (float16_t const*)lhs;
+
+ const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr);
+ const size_t k_internal = kai_k_roundedup(k);
+ const int32_t k_block_len = (int32_t)(kr / sr);
+
+ const int32_t num_blocks_k = (int32_t)(k / k_block_len);
+ const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len);
+
+ for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
+ // Find min/max for each channel
+ int32_t k_idx = 0;
+ float16_t absmax = (float16_t)(-FLT16_MAX);
+
+ float16x8_t vmax0 = vdupq_n_f16(absmax);
+ float16x8_t vmin0 = vdupq_n_f16(-absmax);
+
+ for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
+ const float16x8_t src0_0 = vld1q_f16(src_ptr + (size_t)k_idx);
+ vmax0 = vmaxq_f16(vmax0, src0_0);
+ vmin0 = vminq_f16(vmin0, src0_0);
+ }
+ // Get the max/min
+ float16_t max0 = vmaxvq_f16(vmax0);
+ float16_t min0 = vminvq_f16(vmin0);
+
+ for (; k_idx < (int32_t)k; ++k_idx) {
+ const float16_t src0 = *(src_ptr + (size_t)k_idx);
+ max0 = vmaxh_f16(src0, max0);
+ min0 = vminh_f16(src0, min0);
+ }
+
+ // Maximum/minimum int8 values
+ const float qmin = (float)INT8_MIN;
+ const float qmax = (float)INT8_MAX;
+
+ const float rmin0 = fminf(0.0F, min0);
+ const float rmax0 = fmaxf(0.0F, max0);
+ const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0);
+
+ // Reciprocal to quantize
+ const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
+
+ const float descaled_min0 = rmin0 * scale0;
+ const float descaled_max0 = rmax0 * scale0;
+
+ const float zero_point_from_min_error0 = qmin + descaled_min0;
+ const float zero_point_from_max_error0 = qmax + descaled_max0;
+
+ float zero_point0 =
+ zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0;
+
+ zero_point0 = fmaxf(zero_point0, qmin);
+ zero_point0 = fminf(zero_point0, qmax);
+
+ // Round to nearest integer
+ const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0);
+
+ const size_t dst_x = ((row_idx + m_idx_start) % mr);
+
+ uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t));
+
+ // Quantize the channels
+ int32_t block_idx = 0;
+
+ if (k_block_len == 8) {
+ for (; block_idx < num_blocks_k; ++block_idx) {
+ // Process one full block of eight values
+ const int32_t k_idx_start = block_idx * k_block_len;
+
+ const float16x8_t src_0 = vld1q_f16(src_ptr + k_idx_start);
+
+ // Scale the values
+ float32x4_t v0_f32 = vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src_0)), scale0);
+ float32x4_t v1_f32 = vmulq_n_f32(vcvt_high_f32_f16(src_0), scale0);
+ int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32);
+ int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32);
+
+ int16x4_t v0_s16 = vqmovn_s32(v0_s32);
+ int16x4_t v1_s16 = vqmovn_s32(v1_s32);
+ int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16);
+
+ // Add zero points
+ int16_t nzp_s16 = (int16_t)nudged_zero_point0;
+ int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16);
+ v_s16 = vaddq_s16(v_s16, vnzp_s16);
+ v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN));
+ v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX));
+
+ int8x8_t v0_s8 = vqmovn_s16(v_s16);
+ vst1_s8((int8_t*)(dst_ptr), v0_s8);
+ dst_ptr += 8 * sizeof(int8_t);
+ dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
+ }
+ } else {
+ for (; block_idx < num_blocks_k; ++block_idx) {
+ for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
+ const int32_t k_idx_start = (block_idx * k_block_len) + k_block_idx;
+
+ const float src0 = (float)(*(src_ptr + k_idx_start));
+
+ // Scale the values
+ int32_t v0_s32 = (int32_t)(roundf(src0 * scale0));
+
+ v0_s32 = v0_s32 + nudged_zero_point0;
+ v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
+ v0_s32 = KAI_MIN(v0_s32, INT8_MAX);
+
+ *((int8_t*)(dst_ptr)) = (int8_t)v0_s32;
+ dst_ptr += sizeof(int8_t);
+ }
+ dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
+ }
+ }
+
+ for (; block_idx < num_blocks_k_internal; ++block_idx) {
+ // Leftover k: pad by repeating the last valid value
+ for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
+ // Clamp at the last valid k-index
+ const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1);
+
+ const float src0 = (float)(*(src_ptr + k_idx_start));
+
+ // Scale the values
+ int32_t v0_s32 = (int32_t)(roundf(src0 * scale0));
+
+ v0_s32 = v0_s32 + nudged_zero_point0;
+ v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
+ v0_s32 = KAI_MIN(v0_s32, INT8_MAX);
+
(int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + + dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); + + dst_ptr += dst_x * kai_num_bytes_per_offset; + + // LHS offset at the beginning of the row + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + + // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier + KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); + + dst_ptr += mr * kai_num_bytes_per_offset; + + // Store the scale quantization params + *((float*)(dst_ptr)) = recip_scale0; + + src_ptr += (lhs_stride / sizeof(float16_t)); + + // Move to the next row if we have interleaved all Mr rows + if ((((row_idx + 1) + m_idx_start) % mr) == 0) { + lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); + } + } +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..7785c52175e7266a69729c604f2d7ec3e4c0aec4 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.h @@ -0,0 +1,79 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @param[in] mr The number of M rows to interleave on the same output row. +/// +/// @return the m step value +size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f16_neon(size_t mr); + +/// Gets the offset in bytes for the LHS matrix (not packed) +/// +/// This function should be called before passing the pointer to the LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] lhs_stride The number of bytes in in each row of the LHS matrix (not packed) +/// +/// @return the offset in bytes to the LHS matrix +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f16_neon(size_t m_idx, size_t lhs_stride); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f16_neon( + size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + +/// Gets the size in bytes for the quantized and packed LHS matrix +/// +/// @param[in] m Total number of rows in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. 
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+///
+/// @return the packed LHS matrix size in bytes
+size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr);
+
+/// Runs the micro-kernel to quantize and pack the LHS matrix.
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] k The number of channels. The common dimension of LHS & RHS. It must be a multiple of 8.
+/// @param[in] mr The number of M rows to interleave on the same output row.
+/// @param[in] kr The number of columns loaded in the single innermost loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+///              However, kr must be a multiple of sr.
+/// @param[in] m_idx_start The starting M index.
+/// @param[in] lhs LHS of the vector-by-matrix operation.
+/// @param[in] lhs_stride Stride in bytes between two rows of LHS.
+/// @param[out] lhs_packed The quantized and packed LHS matrix.
+void kai_run_lhs_quant_pack_qai8dxp_f16_neon(
+    size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
+    void* lhs_packed);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/test/tests/matmul_clamp_f16_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f16_qai8dxp_qsi4cxp_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1b8bc29ad65583f8bee665d131974a367dadf9ba
--- /dev/null
+++ b/test/tests/matmul_clamp_f16_qai8dxp_qsi4cxp_test.cpp
@@ -0,0 +1,238 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h"
+#include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.h"
+#include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod.h"
+#include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm.h"
+#include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp_qsi4cxp_interface.h"
+#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.h"
+#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h"
+#include "test/common/compare.hpp"
+#include "test/common/cpu_info.hpp"
+#include "test/common/data_format.hpp"
+#include "test/common/float16.hpp"
+#include "test/common/int4.hpp"
+#include "test/common/matrix_portion.hpp"
+#include "test/common/memory.hpp"
+#include "test/common/round.hpp"
+#include "test/common/test_suite.hpp"
+#include "test/reference/cast.hpp"
+#include "test/reference/clamp.hpp"
+#include "test/reference/fill.hpp"
+#include "test/reference/matmul.hpp"
+#include "test/reference/pack.hpp"
+#include "test/reference/pad.hpp"
+#include "test/reference/quantize.hpp"
+
+namespace kai::test {
+
+static auto cpu_has_dotprod_and_fp16 = []() { return cpu_has_dotprod() && cpu_has_fp16(); };
+static auto cpu_has_i8mm_and_fp16 = []() { return cpu_has_i8mm() && cpu_has_fp16(); };
+
+static const std::array<UkernelVariant<kai_matmul_clamp_f16_qai8dxp_qsi4cxp_ukernel>, 4>
+    variants_kai_matmul_clamp_f16_qai8dxp_qsi4cxp = {{
+        {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod),
"kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod", cpu_has_dotprod_and_fp16}, + {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod), + "kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod", cpu_has_dotprod_and_fp16}, + {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod), + "kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod", cpu_has_dotprod_and_fp16}, + {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm), + "kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm", cpu_has_i8mm_and_fp16}, + }}; + +class MatMulTest_f16_qai8dxp_qsi4cxp : public ::testing::TestWithParam {}; + +TEST_P(MatMulTest_f16_qai8dxp_qsi4cxp, EndToEnd) { + const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam(); + const auto& ukernel_variant = variants_kai_matmul_clamp_f16_qai8dxp_qsi4cxp.at(variant_index); + + if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { + GTEST_SKIP() << "CPU features are not supported by current CPU"; + } + + const std::uint32_t seed = 0; + + const size_t M = matmul_shape.m; + const size_t N = matmul_shape.n; + const size_t K = matmul_shape.k; + + const auto mr = ukernel_variant.interface.get_mr(); + const auto nr = ukernel_variant.interface.get_nr(); + const auto kr = ukernel_variant.interface.get_kr(); + const auto sr = ukernel_variant.interface.get_sr(); + + if (mr == 1 && M > 1) { + GTEST_SKIP() << "Kernel does not support M != 1"; + } + + auto m_step = ukernel_variant.interface.get_m_step(); + ASSERT_TRUE(m_step % mr == 0); + + auto n_step = ukernel_variant.interface.get_n_step(); + ASSERT_TRUE(n_step % nr == 0); + + const auto rect = portion.compute_portion(M, N, m_step, n_step); + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; + } + + // Generates input data. + const auto ref_lhs_f16 = fill_random(M * K, seed + 0); + const auto ref_rhs = fill_random(N * K, seed + 1); + std::vector ref_biases; + + if (has_bias) { + ref_biases = fill_random(N, seed + 2); + } + // For reference implementation, Casting FP16 input to FP32 type and FP32 output back to FP16 because the matmul + // implementation works with FP32 accumulation and casts the result to FP16 + const auto ref_lhs = cast(ref_lhs_f16.data(), ref_lhs_f16.size() * 8 / size_in_bits); + + // Runs the reference implementation. + // * Quantizes the LHS matrix using 8-bit symmetric quantization. + // * Quantizes the RHS matrix using 8-bit asymmetric quantization. + // * Performs GEMM. + const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = + quantize_asymmetric_per_block_dynamic(ref_lhs.data(), M, K, K); + const auto [ref_rhs_qsi4, ref_rhs_scales] = + quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, K); + + const auto ref_dst_no_clamp = + matmul_nt_t_quantized( + M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K, + ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, K, has_bias ? ref_biases.data() : nullptr, nullptr, + nullptr, 1); + + // Clamps the reference output. 
+    const auto clamp_ratio = 0.8F;
+    const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_no_clamp.data(), M * N, clamp_ratio);
+    const auto ref_dst_float = clamp<float>(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max);
+
+    // Casts the reference output to F16.
+    auto ref_dst = cast<Float16, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>);
+
+    // Runs the LHS packing micro-kernel.
+    const auto lhs_start_row = rect.start_row();
+    const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f16_neon(M, K, mr, kr, sr);
+    std::vector<uint8_t> imp_packed_lhs(imp_packed_lhs_size);
+
+    auto lhs_stride = K * sizeof(uint16_t);
+    auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f16_neon(lhs_start_row, lhs_stride);
+    auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f16_neon(lhs_start_row, K, mr, kr, sr);
+    auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
+
+    ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
+
+    kai_run_lhs_quant_pack_qai8dxp_f16_neon(
+        rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_f16.data() + lhs_offset, lhs_stride,
+        imp_packed_lhs.data() + lhs_packed_offset);
+
+    const auto ref_rhs_qsi4_padded = pad_row<Int4>(
+        ref_rhs_qsi4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2));
+
+    const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(N, K, nr, kr, sr);
+    std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
+    const auto rhs_start_row = rect.start_col();
+    auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr);
+    auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
+    ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
+
+    // Runs the RHS packing micro-kernel.
+    kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{};
+    params.lhs_zero_point = 1;
+    params.rhs_zero_point = 0;
+
+    kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(
+        1, N, K, nr, kr, sr, ref_rhs_qsi4_padded.data(),
+        has_bias ? reinterpret_cast<const float*>(ref_biases.data()) : nullptr,
+        reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, &params);
+
+    const auto dst_stride_row = N * sizeof(uint16_t);
+    const auto dst_stride_col = sizeof(uint16_t);
+    const auto dst_offset =
+        ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
+    const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
+    ASSERT_EQ(dst_offset, ref_dst_offset);
+
+    // Runs the GEMM micro-kernel.
+    const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
+    ASSERT_EQ(imp_dst_size, ref_dst.size());
+    std::vector<uint8_t> imp_dst(imp_dst_size);
+    ukernel_variant.interface.run_matmul(
+        rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset,
+        imp_packed_rhs.data() + rhs_matmul_offset, imp_dst.data() + dst_offset, dst_stride_row, dst_stride_col,
+        clamp_min, clamp_max);
+
+    // Compares the output of the micro-kernels against the output of the reference implementation for the portion
+    // tested.
+    DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
+    DataFormat dst_format = DataFormat(DataType::FP16);
+    const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
+    ASSERT_TRUE(success);
+}
+INSTANTIATE_TEST_SUITE_P(
+    MatMul, MatMulTest_f16_qai8dxp_qsi4cxp,
+    testing::Combine(
+        testing::Range<size_t>(0, variants_kai_matmul_clamp_f16_qai8dxp_qsi4cxp.size()),
+        testing::Values(
+            MatMulShape{1, 2, 32},     //
+            MatMulShape{1, 3, 32},     //
+            MatMulShape{1, 4, 32},     //
+            MatMulShape{1, 5, 31},     //
+            MatMulShape{3, 3, 32},     //
+            MatMulShape{4, 4, 32},     //
+            MatMulShape{5, 5, 31},     //
+            MatMulShape{16, 32, 64},   //
+            MatMulShape{16, 32, 36},   //
+            MatMulShape{15, 35, 65},   //
+            MatMulShape{8, 32, 64},    //
+            MatMulShape{15, 31, 45},   //
+            MatMulShape{1, 35, 65},    //
+            MatMulShape{1, 128, 32},   //
+            MatMulShape{64, 128, 32},  //
+            MatMulShape{77, 99, 64}),
+        testing::Values(
+            MatrixPortion(0, 0, 1, 1),         // Full matrix.
+            MatrixPortion(0, 0, 1, 0.25),      // Leftmost portion.
+            MatrixPortion(0, 0.75, 1, 1),      // Rightmost portion.
+            MatrixPortion(0, 0.5, 1, 0.8),     // Somewhere in the middle.
+            MatrixPortion(0.75, 0.75, 1, 1),   // Bottom-right corner.
+            MatrixPortion(0.75, 0, 1, 1),      // Partial rows.
+            MatrixPortion(0.4, 0.5, 0.6, 0.8)  // Somewhere in the middle.
+        ),
+        testing::Bool()),
+    [](const auto& info) {
+        const auto variant_idx = std::get<0>(info.param);
+        const std::string name{variants_kai_matmul_clamp_f16_qai8dxp_qsi4cxp.at(variant_idx).name};
+        const auto shape = std::get<MatMulShape>(info.param);
+        const auto portion = std::get<2>(info.param);
+        const auto has_bias = std::get<3>(info.param);
+
+        std::stringstream sstream;
+        sstream << name << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k  //
+                << "__PortionStartRow_" << static_cast<int>(portion.start_row() * 1000)  //
+                << "__PortionStartCol_" << static_cast<int>(portion.start_col() * 1000)  //
+                << "__PortionHeight_" << static_cast<int>(portion.height() * 1000)       //
+                << "__PortionWidth_" << static_cast<int>(portion.width() * 1000)         //
+                << (has_bias ? "__Bias" : "");
+        return sstream.str();
+    });
+
+}  // namespace kai::test
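
For reference, the end-to-end call sequence these files introduce can be summarized in plain C. The sketch below is illustrative only, not part of the diff: the 1x64x256 GEMV shape, the zero-filled input buffers, and the full-FP16-range clamp bounds are assumptions chosen for brevity, and error handling is omitted. It pairs the new F16 LHS quantize-and-pack routine with the existing kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 RHS packing and the new 1x4 DotProd micro-kernel, mirroring what the test above drives through the generic interface struct.

    // Illustrative sketch (assumed shapes and zeroed inputs), not part of the patch.
    #include <stddef.h>
    #include <stdint.h>
    #include <stdlib.h>

    #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.h"
    #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.h"
    #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h"

    int main(void) {
        const size_t m = 1, n = 64, k = 256;  // Example GEMV shape (assumption).

        // Packing parameters are queried from the matmul micro-kernel.
        const size_t mr = kai_get_mr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod();
        const size_t nr = kai_get_nr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod();
        const size_t kr = kai_get_kr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod();
        const size_t sr = kai_get_sr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod();

        // F16 LHS (m x k), int4 RHS (n x k, two nibbles per byte), per-channel RHS scales.
        // Left zero-filled here; real code fills these with actual model data.
        uint16_t* lhs_f16 = calloc(m * k, sizeof(uint16_t));
        uint8_t* rhs_qsi4 = calloc(n * k / 2, sizeof(uint8_t));
        float* rhs_scales = calloc(n, sizeof(float));

        void* lhs_packed = malloc(kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f16_neon(m, k, mr, kr, sr));
        void* rhs_packed = malloc(kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, nr, kr, sr));
        uint16_t* dst = calloc(m * n, sizeof(uint16_t));  // F16 output.

        // Dynamically quantize (8-bit asymmetric, per-row) and pack the F16 LHS.
        kai_run_lhs_quant_pack_qai8dxp_f16_neon(
            m, k, mr, kr, sr, 0 /* m_idx_start */, lhs_f16, k * sizeof(uint16_t), lhs_packed);

        // Pack the int4 RHS together with its per-channel scales (no bias here).
        struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params = {.lhs_zero_point = 1, .rhs_zero_point = 0};
        kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(
            1, n, k, nr, kr, sr, rhs_qsi4, NULL /* bias */, rhs_scales, rhs_packed, 0, &params);

        // Run the matmul micro-kernel; the full FP16 range effectively disables clamping.
        kai_run_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(
            m, n, k, lhs_packed, rhs_packed, dst, n * sizeof(uint16_t), sizeof(uint16_t), -65504.0F, 65504.0F);

        free(lhs_f16); free(rhs_qsi4); free(rhs_scales);
        free(lhs_packed); free(rhs_packed); free(dst);
        return 0;
    }

When tiling over a larger output, the corresponding get_lhs_packed_offset/get_rhs_packed_offset/get_dst_offset queries are added before the run call, exactly as the portion logic in the test does.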