diff --git a/CHANGELOG.md b/CHANGELOG.md index d72d0653553c34faf8b3462bc33fbeb51c842acd..6b682edda8436411ba351c50667b3c640e27ded2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,16 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- Convert SME and SME2 imatmul micro-kernels to use pure assembly, and add MSVC support. Affects: + - kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa + - kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa + - kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa + - kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme + - kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme + - kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme + - kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme + - kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme + - kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme - New Advanced SIMD micro-kernels: - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_DotProd. 
- Added Convolution example using SME Indirect Matmul Kernels diff --git a/CMakeLists.txt b/CMakeLists.txt index adf696b81fec7e2a6eebd0300fbb34a3396f417d..4a36b210bf8bfda700a127e2d1f881a1bfd4aa27 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -223,16 +223,26 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c ) -set(KLEIDIAI_FILES_SME +set(KLEIDIAI_FILES_SME_ASM kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c + kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c + kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c + kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c + kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c + kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c + kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S +) + +set(KLEIDIAI_FILES_SME + ${KLEIDIAI_FILES_SME_ASM} + kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c 
kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -242,6 +252,12 @@ set(KLEIDIAI_FILES_SME ) set(KLEIDIAI_FILES_SME2_ASM + kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S + kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S + kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c @@ -250,9 +266,6 @@ set(KLEIDIAI_FILES_SME2_ASM set(KLEIDIAI_FILES_SME2 ${KLEIDIAI_FILES_SME2_ASM} - kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c - kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c - kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c @@ -303,12 +316,16 @@ 
if(NOT MSVC) else() target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM_ASM}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME_ASM}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2_ASM}) - set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set(KLEIDIAI_FILES_ASM + ${KLEIDIAI_FILES_SME_ASM} + ${KLEIDIAI_FILES_SME2_ASM} ${KLEIDIAI_FILES_NEON_DOTPROD_ASM} ${KLEIDIAI_FILES_NEON_I8MM_ASM}) list(FILTER KLEIDIAI_FILES_ASM INCLUDE REGEX "^.*\.S$") diff --git a/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt b/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt index a9499b8417c3787216fb453c9e51b44e025def8e..afc8a4c551a9118cf5af50c3966eb79f2d2abefd 100644 --- a/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt +++ b/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt @@ -8,6 +8,8 @@ cmake_minimum_required(VERSION 3.16) project(conv2d_imatmul_clamp_f16_f16_f16p_sme2) +enable_language(ASM) + set(CMAKE_CXX_STANDARD 17) set(KAI_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../) set(KAI_BUILD ${KAI_PATH}/build) @@ -15,14 +17,19 @@ set(KAI_BUILD ${KAI_PATH}/build) include_directories(${KAI_PATH}) set(KAI_SOURCES + ${KAI_PATH}/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S ${KAI_PATH}/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c + 
${KAI_PATH}/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S ${KAI_PATH}/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c - ${KAI_PATH}/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c) + ${KAI_PATH}/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S + ${KAI_PATH}/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c +) # Files requires to build the executable -add_executable( - conv2d_imatmul_clamp_f16_f16_f16p_sme2 conv2d_imatmul_clamp_f16_f16_f16p.cpp - ${KAI_SOURCES}) +add_executable(conv2d_imatmul_clamp_f16_f16_f16p_sme2 + conv2d_imatmul_clamp_f16_f16_f16p.cpp + ${KAI_SOURCES} +) target_compile_options(conv2d_imatmul_clamp_f16_f16_f16p_sme2 PRIVATE "-march=armv8.2-a+sve+sve2;-fno-tree-vectorize" diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index a6209876983e35513861f07609550ee8464fa5bd..af8996952a367afc9f98baa3f816e802b729310a 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -139,16 +139,10 @@ I8MM_KERNELS_ASM = [ # buildifier: keep sorted SME_KERNELS = [ - "pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", - "pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme", - "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme", "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_lhs_pack_f32p2vlx1_f32_sme", "pack/kai_lhs_pack_x16p2vlx2_x16_sme", "pack/kai_lhs_pack_x8p2vlx4_x8_sme", - "pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", - "pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme", - "pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", @@ -159,16 +153,17 @@ SME_KERNELS = [ ] # buildifier: keep sorted -SME2_KERNELS_ASM = [ - "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa", - 
"matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot", +SME_KERNELS_ASM = [ + "pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", + "pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme", + "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme", + "pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", + "pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme", + "pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme", ] # buildifier: keep sorted SME2_KERNELS = [ - "imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", - "imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa", - "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla", @@ -183,6 +178,15 @@ SME2_KERNELS = [ "matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", ] +# buildifier: keep sorted +SME2_KERNELS_ASM = [ + "imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", + "imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa", + "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", + "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa", + "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot", +] + kai_c_library( name = "interface", textual_hdrs = glob(["**/*_interface.h"]), @@ -279,6 +283,13 @@ kai_c_library( textual_hdrs = [ukernel + ".h" for ukernel in SME_KERNELS], ) +kai_c_library( + name = "sme_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in 
SME_KERNELS_ASM] + [ukernel + ".c" for ukernel in SME_KERNELS_ASM], + cpu_uarch = kai_cpu_sme(), + textual_hdrs = [ukernel + ".h" for ukernel in SME_KERNELS_ASM], +) + kai_c_library( name = "sme2_impl", srcs = [ukernel + ".c" for ukernel in SME2_KERNELS], @@ -313,5 +324,6 @@ kai_c_library( ":sme2_impl", ":sme2_impl_asm", ":sme_impl", + ":sme_impl_asm", ], ) diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c index 7e77125a4a58f190bc36cb868b52e6e348b539f5..51c869ec7817b9edbce5f5267033b2ff8e0dd929 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c @@ -4,13 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
- #include "kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" #include <stddef.h> @@ -18,30 +14,51 @@ #include "kai/kai_common.h" +typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + uint16_t min; + uint16_t max; + void* accumulator_buffer; + uint64_t flags; +} KernelArgs; + static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 2; +void kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(KernelArgs* args); +uint16_t kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(float value); + +// Returns a constant value specific to this kernel that's relative to vector length +static size_t kai_get_kernel_vec_length_constant(void) { + const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u16() / kai_kr; + return kernel_vec_length_constant; +} + size_t kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u32(); + return kai_mr * kai_get_kernel_vec_length_constant(); } size_t kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_kernel_vec_length_constant(); } size_t kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - return m_idx * indirect_k * sizeof(uint16_t); + return m_idx * k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(uint16_t); } static size_t kai_get_rhs_packed_stride_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( size_t k_chunk_count, size_t k_chunk_length) { - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); 
return kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() * - (sizeof(uint16_t) + indirect_k * sizeof(uint16_t)); + (sizeof(uint16_t) + k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(uint16_t)); } size_t kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( @@ -54,11 +71,11 @@ size_t kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_s } size_t kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_row_stride) { + size_t m_idx, size_t n_idx, size_t dst_stride_row) { KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); - return m_idx * dst_row_stride + n_idx * sizeof(uint16_t); + return m_idx * dst_stride_row + n_idx * sizeof(uint16_t); } size_t kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(size_t m, size_t n) { @@ -67,186 +84,22 @@ size_t kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( void kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, float clamp_min, float clamp_max) { - typedef struct { - const void* A; - const void* B; - void* C; - uint64_t ldcb; - uint64_t M; - uint64_t N; - uint64_t K; - float16_t min; - float16_t max; - void* accumulator_buffer; - uint64_t flags; - } KernelArgs; - + void* dst, size_t dst_stride_row, float clamp_min, float clamp_max) { KernelArgs args; args.A = lhs_packed; args.B = rhs_packed; - - size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - args.C = dst; - args.ldcb = dst_row_stride; + args.ldcb = dst_stride_row; args.M = m; args.N = n; - args.K = indirect_k; - args.min = (float16_t)clamp_min; - args.max = 
(float16_t)clamp_max; - + args.K = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + args.min = kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(clamp_min); + args.max = kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(clamp_max); args.accumulator_buffer = NULL; args.flags = 0; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "ldr w13, [%x[args], %[offsetof_M]]\n" - "mov x11, #0x0\n" - "mov x10, #0x0\n" - "ptrue p1.b\n" - ".inst 0x25207810 // ptrue pn8.b\n" - "ldr w9, [%x[args], %[offsetof_N]]\n" - "ldr x28, [%x[args], %[offsetof_A]]\n" - "1:" // M loop - "ldr x27, [%x[args], %[offsetof_B]]\n" - "2:" // N loop - "fmov z24.h, #0.0\n" - "ld1h { z5.h }, p1/Z, [x27]\n" - "fmov z27.h, #1.0\n" - "mov x26, x28\n" - ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" - "inch x27, ALL, MUL #2\n" - "zip1 z30.h, z5.h, z24.h\n" - "zip2 z20.h, z5.h, z24.h\n" - ".inst 0x81be2760 // fmopa za0.s, p1/M, p1/M, z27.h, z30.h\n" - ".inst 0x81b42761 // fmopa za1.s, p1/M, p1/M, z27.h, z20.h\n" - ".inst 0x81be2762 // fmopa za2.s, p1/M, p1/M, z27.h, z30.h\n" - ".inst 0x81b42763 // fmopa za3.s, p1/M, p1/M, z27.h, z20.h\n" - "ldr x20, [%x[args], %[offsetof_K]]\n" - "add x20, x20, #0x1\n" - "lsr x20, x20, #0x1\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 6f\n" - "subs x21, x21, #0x1\n" - ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" - ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" - ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" - ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" - ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" - "addvl x26, x26, #8\n" - ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, 
MUL VL]\n" - "addvl x27, x27, #8\n" - "ble 5f\n" - "4:" // K loop - ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" - "subs x21, x21, #0x1\n" - ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" - ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" - ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" - ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" - ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" - ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" - ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" - ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" - ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" - ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" - ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" - ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" - ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" - ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" - ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" - ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" - ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" - ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" - ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" - ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" - "addvl x26, x26, #8\n" - ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - "bgt 4b\n" - "5:" // K loop tail - ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" - ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" - ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" - ".inst 0x81b12663 // 
fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" - ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" - ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" - ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" - ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" - ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" - ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" - ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" - ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" - ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" - ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" - ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" - ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" - "6:" // K oddments - "cbz x20, 8f\n" - "7:" // K oddments: Loop - ".inst 0xa1402345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26]\n" - "subs x20, x20, #0x1\n" - "addvl x26, x26, #2\n" - ".inst 0xa040236e // ld1h { z14.h-z15.h }, pn8.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0x81ae24a0 // fmopa za0.s, p1/M, p1/M, z5.h, z14.h\n" - ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" - ".inst 0x81ae25a2 // fmopa za2.s, p1/M, p1/M, z13.h, z14.h\n" - ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" - "bgt 7b\n" - "8:" // K oddments: End - "ldr x25, [%x[args], %[offsetof_C]]\n" - "sub x24, x13, x11\n" - "cntw x23, ALL, MUL #2\n" - "ld1rh { z17.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" - "ldr x22, [%x[args], %[offsetof_ldcb]]\n" - "whilelt p0.h, x10, x9\n" - "cmp x24, x23\n" - "ld1rh { z16.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" - "mov x12, #0x0\n" - "mov x21, #0x0\n" - "add x25, x25, x10, LSL #1\n" // C += n - "mov x20, #0x2\n" - "madd x25, x11, x22, x25\n" // C += m * ldc - "csel x24, x24, x23, LT\n" - "10:" // Store to output array: Accumulator loop - ".inst 0xc006000e // mova { z14.b-z15.b }, za0h.b[x12, 0:1]\n" - "add x12, x12, #0x4\n" 
- "cmp x12, x23, LSL #1\n" - "add x21, x21, #0x1\n" - ".inst 0xc120e1cc // fcvt z12.h, { z14.s-z15.s }\n" - "csel x12, x12, x20, LT\n" - "cmp x21, x24\n" - ".inst 0x6470262c // fclamp z12.h, z17.h, z16.h\n" - "st1h { z12.h }, p0, [x25]\n" - "add x25, x25, x22\n" - "blt 10b\n" - "incw x10, ALL, MUL #2\n" - "cmp x10, x9\n" - "blt 2b\n" - "incw x11, ALL, MUL #2\n" - "mov x10, #0x0\n" - "cmp x11, x13\n" - "mov x28, x26\n" - "blt 1b\n" - ".inst 0xd503467f // SMSTOP\n" - : - : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), - [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), - [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), - [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), [offsetof_M] "I"(offsetof(KernelArgs, M)), - [offsetof_N] "I"(offsetof(KernelArgs, N)), [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", - "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", - "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(&args); } #endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h index 79c52a4205dd9b02c07bfc70a730ca96df8b77ee..6814438233714c9a5e366080195fa4394320e33c 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h @@ -55,11 +55,11 @@ size_t kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_s /// /// @param[in] m_idx Row index. Must be a multiple of `m_step`. /// @param[in] n_idx Column index. Must be a multiple of `n_step`. -/// @param[in] dst_row_stride. Distance between start of two rows in the output buffer. +/// @param[in] dst_stride_row Row stride in bytes. /// /// @return The offset in bytes to the data element. size_t kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_row_stride); + size_t m_idx, size_t n_idx, size_t dst_stride_row); /// Gets the size in bytes of the destination matrix buffer. /// @@ -81,16 +81,16 @@ size_t kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length of a LHS column split +/// @param[in] k_chunk_length Length of a LHS column split. /// @param[in] lhs_packed Packed LHS matrix buffer. /// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. -/// @param[in] dst_row_stride Row stride in bytes of the output matrix. +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. 
/// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. void kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + void* dst, size_t dst_stride_row, float clamp_min, float clamp_max); #ifdef __cplusplus } // extern "C" diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..d6bae7c79e5e85e680c893271b4375b2655c6aa5 --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S @@ -0,0 +1,202 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 
+ #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + KAI_ASM_GLOBAL(kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + fcvt h0, s0 + fmov w0, h0 + ret + KAI_ASM_FUNCTION_END(kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART (enables both PSTATE.SM and PSTATE.ZA) + mov x14, #0x0 + ldr x13, [x0, #0x30] + ptrue p1.b + KAI_ASM_INST(0x25207810) // ptrue pn8.b + ldr w11, [x0, #0x20] + mov x10, #0x0 + ldr w9, [x0, #0x28] + add x13, x13, #0x1 + ldr x28, [x0, #0x0] + lsr x13, x13, #0x1 +KAI_ASM_LABEL(label_1) // M loop + ldr x27, [x0, #0x8] +KAI_ASM_LABEL(label_2) // N loop + fmov z23.h, #0.0 + ld1h { z18.h }, p1/Z, [x27] + fmov z2.h, #1.0 + mov x26, x28 + KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } + inch x27, ALL, MUL #2 + zip1 z14.h, z18.h, z23.h + zip2 z3.h, z18.h, z23.h + KAI_ASM_INST(0x81ae2440) // fmopa za0.s, p1/M, p1/M, z2.h, z14.h + KAI_ASM_INST(0x81a32441) // fmopa za1.s, p1/M, p1/M, z2.h, z3.h + KAI_ASM_INST(0x81ae2442) // fmopa za2.s, p1/M, p1/M, z2.h, z14.h + KAI_ASM_INST(0x81a32443) // fmopa za3.s, p1/M, p1/M, z2.h, z3.h + lsr x21, x13, #0x2 + 
and x20, x13, #0x3 + cbz x21, label_6 + subs x21, x21, #0x1 + KAI_ASM_INST(0xa040a350) // ld1h { z16.h-z19.h }, pn8.b/Z, [x26] + KAI_ASM_INST(0xa041a35c) // ld1h { z28.h-z31.h }, pn8.b/Z, [x26, #0x4, MUL VL] + addvl x26, x26, #8 + KAI_ASM_INST(0xa040a360) // ld1h { z0.h-z3.h }, pn8.b/Z, [x27] + KAI_ASM_INST(0xa041a368) // ld1h { z8.h-z11.h }, pn8.b/Z, [x27, #0x4, MUL VL] + addvl x27, x27, #8 + ble label_5 +KAI_ASM_LABEL(label_4) // K loop + KAI_ASM_INST(0x81a02600) // fmopa za0.s, p1/M, p1/M, z16.h, z0.h + subs x21, x21, #0x1 + KAI_ASM_INST(0x81a12601) // fmopa za1.s, p1/M, p1/M, z16.h, z1.h + KAI_ASM_INST(0x81a02622) // fmopa za2.s, p1/M, p1/M, z17.h, z0.h + KAI_ASM_INST(0x81a12623) // fmopa za3.s, p1/M, p1/M, z17.h, z1.h + KAI_ASM_INST(0x81a22640) // fmopa za0.s, p1/M, p1/M, z18.h, z2.h + KAI_ASM_INST(0x81a32641) // fmopa za1.s, p1/M, p1/M, z18.h, z3.h + KAI_ASM_INST(0x81a22662) // fmopa za2.s, p1/M, p1/M, z19.h, z2.h + KAI_ASM_INST(0x81a32663) // fmopa za3.s, p1/M, p1/M, z19.h, z3.h + KAI_ASM_INST(0xa040a350) // ld1h { z16.h-z19.h }, pn8.b/Z, [x26] + KAI_ASM_INST(0x81a82780) // fmopa za0.s, p1/M, p1/M, z28.h, z8.h + KAI_ASM_INST(0xa040a360) // ld1h { z0.h-z3.h }, pn8.b/Z, [x27] + KAI_ASM_INST(0x81a92781) // fmopa za1.s, p1/M, p1/M, z28.h, z9.h + KAI_ASM_INST(0x81a827a2) // fmopa za2.s, p1/M, p1/M, z29.h, z8.h + KAI_ASM_INST(0x81a927a3) // fmopa za3.s, p1/M, p1/M, z29.h, z9.h + KAI_ASM_INST(0x81aa27c0) // fmopa za0.s, p1/M, p1/M, z30.h, z10.h + KAI_ASM_INST(0x81ab27c1) // fmopa za1.s, p1/M, p1/M, z30.h, z11.h + KAI_ASM_INST(0x81aa27e2) // fmopa za2.s, p1/M, p1/M, z31.h, z10.h + KAI_ASM_INST(0x81ab27e3) // fmopa za3.s, p1/M, p1/M, z31.h, z11.h + KAI_ASM_INST(0xa041a35c) // ld1h { z28.h-z31.h }, pn8.b/Z, [x26, #0x4, MUL VL] + addvl x26, x26, #8 + KAI_ASM_INST(0xa041a368) // ld1h { z8.h-z11.h }, pn8.b/Z, [x27, #0x4, MUL VL] + addvl x27, x27, #8 + bgt label_4 +KAI_ASM_LABEL(label_5) // K loop tail + KAI_ASM_INST(0x81a02600) // fmopa za0.s, p1/M, p1/M, z16.h, z0.h + 
KAI_ASM_INST(0x81a12601) // fmopa za1.s, p1/M, p1/M, z16.h, z1.h + KAI_ASM_INST(0x81a02622) // fmopa za2.s, p1/M, p1/M, z17.h, z0.h + KAI_ASM_INST(0x81a12623) // fmopa za3.s, p1/M, p1/M, z17.h, z1.h + KAI_ASM_INST(0x81a22640) // fmopa za0.s, p1/M, p1/M, z18.h, z2.h + KAI_ASM_INST(0x81a32641) // fmopa za1.s, p1/M, p1/M, z18.h, z3.h + KAI_ASM_INST(0x81a22662) // fmopa za2.s, p1/M, p1/M, z19.h, z2.h + KAI_ASM_INST(0x81a32663) // fmopa za3.s, p1/M, p1/M, z19.h, z3.h + KAI_ASM_INST(0x81a82780) // fmopa za0.s, p1/M, p1/M, z28.h, z8.h + KAI_ASM_INST(0x81a92781) // fmopa za1.s, p1/M, p1/M, z28.h, z9.h + KAI_ASM_INST(0x81a827a2) // fmopa za2.s, p1/M, p1/M, z29.h, z8.h + KAI_ASM_INST(0x81a927a3) // fmopa za3.s, p1/M, p1/M, z29.h, z9.h + KAI_ASM_INST(0x81aa27c0) // fmopa za0.s, p1/M, p1/M, z30.h, z10.h + KAI_ASM_INST(0x81ab27c1) // fmopa za1.s, p1/M, p1/M, z30.h, z11.h + KAI_ASM_INST(0x81aa27e2) // fmopa za2.s, p1/M, p1/M, z31.h, z10.h + KAI_ASM_INST(0x81ab27e3) // fmopa za3.s, p1/M, p1/M, z31.h, z11.h +KAI_ASM_LABEL(label_6) // K oddments + cbz x20, label_8 +KAI_ASM_LABEL(label_7) // K oddments: Loop + KAI_ASM_INST(0xa1402345) // ld1h { z5.h, z13.h }, pn8.b/Z, [x26] + subs x20, x20, #0x1 + addvl x26, x26, #2 + KAI_ASM_INST(0xa040236e) // ld1h { z14.h-z15.h }, pn8.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0x81ae24a0) // fmopa za0.s, p1/M, p1/M, z5.h, z14.h + KAI_ASM_INST(0x81af24a1) // fmopa za1.s, p1/M, p1/M, z5.h, z15.h + KAI_ASM_INST(0x81ae25a2) // fmopa za2.s, p1/M, p1/M, z13.h, z14.h + KAI_ASM_INST(0x81af25a3) // fmopa za3.s, p1/M, p1/M, z13.h, z15.h + bgt label_7 +KAI_ASM_LABEL(label_8) // K oddments: End + ldr x25, [x0, #0x10] + sub x24, x11, x14 + cntw x23, ALL, MUL #2 + KAI_ASM_INST(0x84dca411) // ld1rh { z17.h }, p1/Z, [x0, #56] + ldr x22, [x0, #0x18] + whilelt p0.h, x10, x9 + cmp x24, x23 + KAI_ASM_INST(0x84dda410) // ld1rh { z16.h }, p1/Z, [x0, #58] + mov x12, #0x0 + mov x21, #0x0 + add x25, x25, x10, LSL #1 // C += n + mov x20, #0x2 + madd x25, x14, x22, x25 
// C += m * ldc + csel x24, x24, x23, LT +KAI_ASM_LABEL(label_10) // Store to output array: Accumulator loop + KAI_ASM_INST(0xc006000e) // mova { z14.b-z15.b }, za0h.b[x12, 0:1] + add x12, x12, #0x4 + cmp x12, x23, LSL #1 + add x21, x21, #0x1 + KAI_ASM_INST(0xc120e1c4) // fcvt z4.h, { z14.s-z15.s } + csel x12, x12, x20, LT + cmp x21, x24 + KAI_ASM_INST(0x64702624) // fclamp z4.h, z17.h, z16.h + st1h { z4.h }, p0, [x25] + add x25, x25, x22 + blt label_10 + incw x10, ALL, MUL #2 + cmp x10, x9 + blt label_2 + incw x14, ALL, MUL #2 + mov x10, #0x0 + cmp x14, x11 + mov x28, x26 + blt label_1 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h index bbc2b318bd11a02d8b576a6af9aae3a27fb70b43..b2eef4903bcc96e53c43b818c16a12ea831ca183 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h @@ -27,7 +27,7 @@ typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_dst_size_func_t)(size_t m, /// Micro-kernel core function ("run" method) typedef void (*kai_imatmul_clamp_f16_f16p_f16p_run_imatmul_func_t)( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + void* dst, size_t dst_stride_row, float clamp_min, float clamp_max); /// Micro-kernel interface struct kai_imatmul_clamp_f16_f16p_f16p_ukernel { 
diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c index a927e2b7fd4cb4ec3badbdc4f6d33a022b16dec4..8e64712ce1adbe67f744dcef58d7aa847c6d3ce5 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c @@ -4,13 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. - #include "kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" #include @@ -18,30 +14,50 @@ #include "kai/kai_common.h" +typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + float min; + float max; + void* accumulator_buffer; + uint64_t flags; +} KernelArgs; + static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 1; +void kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(KernelArgs* args); + +// Returns a constant value specific to this kernel that's relative to vector length +static size_t kai_get_kernel_vec_length_constant(void) { + const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u32() / kai_kr; + return kernel_vec_length_constant; +} + size_t kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u32(); + return kai_mr * kai_get_kernel_vec_length_constant(); } size_t 
kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_kernel_vec_length_constant(); } size_t kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() == 0); - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - return m_idx * indirect_k * sizeof(float); + return m_idx * k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(float); } static size_t kai_get_rhs_packed_stride_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( size_t k_chunk_count, size_t k_chunk_length) { - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); return kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() * - (sizeof(float) + indirect_k * sizeof(float)); + (sizeof(float) + k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(float)); } size_t kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( @@ -54,11 +70,11 @@ size_t kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_ } size_t kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_row_stride) { + size_t m_idx, size_t n_idx, size_t dst_stride_row) { KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() == 0); KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() == 0); - return m_idx * dst_row_stride + n_idx * sizeof(float); + return m_idx * dst_stride_row + n_idx * sizeof(float); } size_t kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(size_t m, size_t n) { @@ -67,245 +83,22 @@ size_t 
kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa void kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, float clamp_min, float clamp_max) { - typedef struct { - const void* A; - const void* B; - void* C; - uint64_t ldcb; - uint64_t M; - uint64_t N; - uint64_t K; - float min; - float max; - void* accumulator_buffer; - uint64_t flags; - } KernelArgs; - + void* dst, size_t dst_stride_row, float clamp_min, float clamp_max) { KernelArgs args; args.A = lhs_packed; args.B = rhs_packed; - - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - args.C = dst; - args.ldcb = dst_row_stride; + args.ldcb = dst_stride_row; args.M = m; args.N = n; - args.K = indirect_k; + args.K = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); args.min = clamp_min; args.max = clamp_max; - args.accumulator_buffer = NULL; args.flags = 0; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "ldr w14, [%x[args], %[offsetof_M]]\n" - "mov x13, #0x0\n" - "mov x11, #0x0\n" - "ptrue p0.b\n" - ".inst 0x25207811 // ptrue pn9.b\n" - "ldr w10, [%x[args], %[offsetof_N]]\n" - "ldr x9, [%x[args], %[offsetof_A]]\n" - "1:" // M loop - "ldr x28, [%x[args], %[offsetof_B]]\n" - "2:" // N loop - ".inst 0x25aa4570 // whilelt pn8.s, x11, x10, VLx2\n" - "fmov z13.s, #1.0\n" - ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" - "mov x27, x9\n" - ".inst 0xa040438e // ld1w { z14.s-z15.s }, p8/Z, [x28]\n" // Load bias - "addvl x28, x28, #2\n" - ".inst 0x808e01a0 // fmopa za0.s, p0/M, p0/M, z13.s, z14.s\n" - ".inst 0x808f01a1 // fmopa za1.s, p0/M, p0/M, z13.s, z15.s\n" - ".inst 0x808e01a2 // fmopa za2.s, p0/M, p0/M, z13.s, z14.s\n" - ".inst 0x808f01a3 // fmopa za3.s, p0/M, p0/M, z13.s, z15.s\n" - "ldr x20, [%x[args], %[offsetof_K]]\n" - "lsr x21, x20, #0x2\n" - "and 
x20, x20, #0x3\n" - "cbz x21, 6f\n" - "subs x21, x21, #0x1\n" - ".inst 0xa1404772 // ld1w { z18.s, z26.s }, pn9.b/Z, [x27]\n" - ".inst 0xa0404794 // ld1w { z20.s-z21.s }, pn9.b/Z, [x28]\n" - ".inst 0xa1414764 // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0xa041478a // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL]\n" - ".inst 0xa1424773 // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa0424798 // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL]\n" - ".inst 0xa043476e // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - ".inst 0xa1434796 // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL]\n" - "addvl x28, x28, #8\n" - "ble 5f\n" - "4:" // K loop - ".inst 0x80940240 // fmopa za0.s, p0/M, p0/M, z18.s, z20.s\n" - "subs x21, x21, #0x1\n" - ".inst 0x80950241 // fmopa za1.s, p0/M, p0/M, z18.s, z21.s\n" - ".inst 0x80940342 // fmopa za2.s, p0/M, p0/M, z26.s, z20.s\n" - ".inst 0x80950343 // fmopa za3.s, p0/M, p0/M, z26.s, z21.s\n" - ".inst 0xa1404772 // ld1w { z18.s, z26.s }, pn9.b/Z, [x27]\n" - ".inst 0x808a0080 // fmopa za0.s, p0/M, p0/M, z4.s, z10.s\n" - ".inst 0xa0404794 // ld1w { z20.s-z21.s }, pn9.b/Z, [x28]\n" - ".inst 0x808b0081 // fmopa za1.s, p0/M, p0/M, z4.s, z11.s\n" - ".inst 0x808a0182 // fmopa za2.s, p0/M, p0/M, z12.s, z10.s\n" - ".inst 0x808b0183 // fmopa za3.s, p0/M, p0/M, z12.s, z11.s\n" - ".inst 0xa1414764 // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0x80980260 // fmopa za0.s, p0/M, p0/M, z19.s, z24.s\n" - ".inst 0xa041478a // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL]\n" - ".inst 0x80990261 // fmopa za1.s, p0/M, p0/M, z19.s, z25.s\n" - ".inst 0x80980362 // fmopa za2.s, p0/M, p0/M, z27.s, z24.s\n" - ".inst 0x80990363 // fmopa za3.s, p0/M, p0/M, z27.s, z25.s\n" - ".inst 0xa1424773 // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa0424798 // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL]\n" - ".inst 0x809601c0 // fmopa za0.s, 
p0/M, p0/M, z14.s, z22.s\n" - ".inst 0x809e01c1 // fmopa za1.s, p0/M, p0/M, z14.s, z30.s\n" - ".inst 0x809601e2 // fmopa za2.s, p0/M, p0/M, z15.s, z22.s\n" - ".inst 0x809e01e3 // fmopa za3.s, p0/M, p0/M, z15.s, z30.s\n" - ".inst 0xa043476e // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - ".inst 0xa1434796 // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL]\n" - "addvl x28, x28, #8\n" - "bgt 4b\n" - "5:" // K loop tail - ".inst 0x80940240 // fmopa za0.s, p0/M, p0/M, z18.s, z20.s\n" - ".inst 0x80950241 // fmopa za1.s, p0/M, p0/M, z18.s, z21.s\n" - ".inst 0x80940342 // fmopa za2.s, p0/M, p0/M, z26.s, z20.s\n" - ".inst 0x80950343 // fmopa za3.s, p0/M, p0/M, z26.s, z21.s\n" - ".inst 0x808a0080 // fmopa za0.s, p0/M, p0/M, z4.s, z10.s\n" - ".inst 0x808b0081 // fmopa za1.s, p0/M, p0/M, z4.s, z11.s\n" - ".inst 0x808a0182 // fmopa za2.s, p0/M, p0/M, z12.s, z10.s\n" - ".inst 0x808b0183 // fmopa za3.s, p0/M, p0/M, z12.s, z11.s\n" - ".inst 0x80980260 // fmopa za0.s, p0/M, p0/M, z19.s, z24.s\n" - ".inst 0x80990261 // fmopa za1.s, p0/M, p0/M, z19.s, z25.s\n" - ".inst 0x80980362 // fmopa za2.s, p0/M, p0/M, z27.s, z24.s\n" - ".inst 0x80990363 // fmopa za3.s, p0/M, p0/M, z27.s, z25.s\n" - ".inst 0x809601c0 // fmopa za0.s, p0/M, p0/M, z14.s, z22.s\n" - ".inst 0x809e01c1 // fmopa za1.s, p0/M, p0/M, z14.s, z30.s\n" - ".inst 0x809601e2 // fmopa za2.s, p0/M, p0/M, z15.s, z22.s\n" - ".inst 0x809e01e3 // fmopa za3.s, p0/M, p0/M, z15.s, z30.s\n" - "6:" // K oddments - "cbz x20, 8f\n" - "7:" // K oddments: Loop - ".inst 0xa040477c // ld1w { z28.s-z29.s }, pn9.b/Z, [x27]\n" - "subs x20, x20, #0x1\n" - "addvl x27, x27, #2\n" - ".inst 0xa1404787 // ld1w { z7.s, z15.s }, pn9.b/Z, [x28]\n" - "addvl x28, x28, #2\n" - ".inst 0x80870380 // fmopa za0.s, p0/M, p0/M, z28.s, z7.s\n" - ".inst 0x808f0381 // fmopa za1.s, p0/M, p0/M, z28.s, z15.s\n" - ".inst 0x808703a2 // fmopa za2.s, p0/M, p0/M, z29.s, z7.s\n" - ".inst 0x808f03a3 // fmopa za3.s, p0/M, p0/M, z29.s, 
z15.s\n" - "bgt 7b\n" - "8:" // K oddments: End - "ldr x26, [%x[args], %[offsetof_C]]\n" - "sub x25, x14, x13\n" - "cntw x24\n" - "ld1rw { z19.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" - "ldr x23, [%x[args], %[offsetof_ldcb]]\n" - "cmp x25, x24\n" - "ld1rw { z26.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" - "mov x12, #0x0\n" - "csel x22, x25, x24, LT\n" - "add x26, x26, x11, LSL #2\n" // C += n - "lsr x21, x22, #0x2\n" - "madd x26, x13, x23, x26\n" // C += m * ldc - "and x20, x22, #0x3\n" - "cbz x21, 11f\n" - "10:" // Store to output array: Accumulator row 0 loop - ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" - ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" - ".inst 0xc1baca64 // fclamp { z4.s-z7.s }, z19.s, z26.s\n" - ".inst 0xc1baca6c // fclamp { z12.s-z15.s }, z19.s, z26.s\n" - "add x12, x12, #0x4\n" - "cmp x12, x21, LSL #2\n" - ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" - "add x26, x26, x23\n" - ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" - "add x26, x26, x23\n" - ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" - "add x26, x26, x23\n" - ".inst 0xa1604347 // st1w { z7.s, z15.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "blt 10b\n" - "11:" // Store to output array: Accumulator row 0 oddments - "cbz x20, 12f\n" - ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" - ".inst 0xc0860428 // mova { z8.s-z11.s }, za1h.s[x12]\n" - "subs x20, x20, #0x1\n" - ".inst 0xc1baca60 // fclamp { z0.s-z3.s }, z19.s, z26.s\n" - ".inst 0xc1baca68 // fclamp { z8.s-z11.s }, z19.s, z26.s\n" - ".inst 0xa1604340 // st1w { z0.s, z8.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "beq 12f\n" - "subs x20, x20, #0x1\n" - ".inst 0xa1604341 // st1w { z1.s, z9.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "beq 12f\n" - ".inst 0xa1604342 // st1w { z2.s, z10.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "12:" // Store to output array: Accumulator row 0 oddments: End - "subs x25, x25, x22\n" - "beq 16f\n" - "cmp x25, x24\n" - 
"mov x12, #0x0\n" - "csel x20, x25, x24, LT\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 14f\n" - "13:" // Store to output array: Accumulator row 1 loop - ".inst 0xc0860454 // mova { z20.s-z23.s }, za2h.s[x12]\n" - ".inst 0xc086047c // mova { z28.s-z31.s }, za3h.s[x12]\n" - ".inst 0xc1baca74 // fclamp { z20.s-z23.s }, z19.s, z26.s\n" - ".inst 0xc1baca7c // fclamp { z28.s-z31.s }, z19.s, z26.s\n" - "add x12, x12, #0x4\n" - "cmp x12, x21, LSL #2\n" - ".inst 0xa1604354 // st1w { z20.s, z28.s }, p8, [x26]\n" - "add x26, x26, x23\n" - ".inst 0xa1604355 // st1w { z21.s, z29.s }, p8, [x26]\n" - "add x26, x26, x23\n" - ".inst 0xa1604356 // st1w { z22.s, z30.s }, p8, [x26]\n" - "add x26, x26, x23\n" - ".inst 0xa1604357 // st1w { z23.s, z31.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "blt 13b\n" - "14:" // Store to output array: Accumulator row 1 oddments - "cbz x20, 15f\n" - ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n" - ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" - "subs x20, x20, #0x1\n" - ".inst 0xc1baca64 // fclamp { z4.s-z7.s }, z19.s, z26.s\n" - ".inst 0xc1baca6c // fclamp { z12.s-z15.s }, z19.s, z26.s\n" - ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "beq 15f\n" - "subs x20, x20, #0x1\n" - ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "beq 15f\n" - ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" - "15:" // Store to output array: Accumulator row 1 oddments: End - "16:" // Store to output array: End - "incw x11, ALL, MUL #2\n" - "cmp x11, x10\n" - "blt 2b\n" - "incw x13, ALL, MUL #2\n" - "mov x11, #0x0\n" - "cmp x13, x14\n" - "mov x9, x27\n" - "blt 1b\n" - ".inst 0xd503467f // SMSTOP\n" - : - : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), - [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), - [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), - 
[offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), [offsetof_M] "I"(offsetof(KernelArgs, M)), - [offsetof_N] "I"(offsetof(KernelArgs, N)), [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", - "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", - "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", - "z9"); + kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h index c7ac5fa1fb8cc5537e05616d1f4bb9fd674b73fd..655ce4d1a13f48150f70ebea718e0b26a01ba5e1 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h @@ -55,11 +55,11 @@ size_t kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_ /// /// @param[in] m_idx Row index. Must be a multiple of `m_step`. /// @param[in] n_idx Column index. Must be a multiple of `n_step`. -/// @param[in] dst_row_stride. Distance between start of two rows in the output buffer. +/// @param[in] dst_stride_row Row stride in bytes. /// /// @return The offset in bytes to the data element. 
size_t kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_row_stride); + size_t m_idx, size_t n_idx, size_t dst_stride_row); /// Gets the size in bytes of the destination matrix buffer. /// @@ -81,16 +81,16 @@ size_t kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length of a LHS column split +/// @param[in] k_chunk_length Length of a LHS column split. /// @param[in] lhs_packed Packed LHS matrix buffer. /// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. -/// @param[in] dst_row_stride Row stride in bytes of the output matrix. +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. 
void kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + void* dst, size_t dst_stride_row, float clamp_min, float clamp_max); #ifdef __cplusplus } // extern "C" diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..71bcc59df4d8d21f293a5b27c21c7df752f53b91 --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S @@ -0,0 +1,252 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + 
KAI_ASM_CODE(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x15, #0x0 + ptrue p0.b + KAI_ASM_INST(0x25207811) // ptrue pn9.b + ldr x14, [x0, #0x30] + ldr w13, [x0, #0x20] + mov x11, #0x0 + ldr w10, [x0, #0x28] + ldr x9, [x0, #0x0] +KAI_ASM_LABEL(label_1) // M loop + ldr x28, [x0, #0x8] +KAI_ASM_LABEL(label_2) // N loop + fmov z22.s, #1.0 + KAI_ASM_INST(0xa040478c) // ld1w { z12.s-z13.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } + mov x27, x9 + KAI_ASM_INST(0x25aa4570) // whilelt pn8.s, x11, x10, VLx2 + addvl x28, x28, #2 + KAI_ASM_INST(0x808c02c0) // fmopa za0.s, p0/M, p0/M, z22.s, z12.s + KAI_ASM_INST(0x808d02c1) // fmopa za1.s, p0/M, p0/M, z22.s, z13.s + KAI_ASM_INST(0x808c02c2) // fmopa za2.s, p0/M, p0/M, z22.s, z12.s + KAI_ASM_INST(0x808d02c3) // fmopa za3.s, p0/M, p0/M, z22.s, z13.s + lsr x21, x14, #0x2 + and x20, x14, #0x3 + cbz x21, label_6 + subs x21, x21, #0x1 + KAI_ASM_INST(0xa040c764) // ld1w { z4.s-z7.s }, pn9.b/Z, [x27] + KAI_ASM_INST(0xa041c768) // ld1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x4, MUL VL] + addvl x27, x27, #8 + KAI_ASM_INST(0xa040c794) // ld1w { z20.s-z23.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xa041c78c) // ld1w { z12.s-z15.s }, pn9.b/Z, [x28, #0x4, MUL VL] + addvl x28, x28, #8 + ble label_5 +KAI_ASM_LABEL(label_4) // K loop + KAI_ASM_INST(0x80940080) // fmopa za0.s, p0/M, p0/M, z4.s, z20.s + subs x21, x21, #0x1 + 
KAI_ASM_INST(0x80950081) // fmopa za1.s, p0/M, p0/M, z4.s, z21.s + KAI_ASM_INST(0x809400a2) // fmopa za2.s, p0/M, p0/M, z5.s, z20.s + KAI_ASM_INST(0x809500a3) // fmopa za3.s, p0/M, p0/M, z5.s, z21.s + KAI_ASM_INST(0x809600c0) // fmopa za0.s, p0/M, p0/M, z6.s, z22.s + KAI_ASM_INST(0x809700c1) // fmopa za1.s, p0/M, p0/M, z6.s, z23.s + KAI_ASM_INST(0x809600e2) // fmopa za2.s, p0/M, p0/M, z7.s, z22.s + KAI_ASM_INST(0x809700e3) // fmopa za3.s, p0/M, p0/M, z7.s, z23.s + KAI_ASM_INST(0xa040c764) // ld1w { z4.s-z7.s }, pn9.b/Z, [x27] + KAI_ASM_INST(0x808c0100) // fmopa za0.s, p0/M, p0/M, z8.s, z12.s + KAI_ASM_INST(0xa040c794) // ld1w { z20.s-z23.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0x808d0101) // fmopa za1.s, p0/M, p0/M, z8.s, z13.s + KAI_ASM_INST(0x808c0122) // fmopa za2.s, p0/M, p0/M, z9.s, z12.s + KAI_ASM_INST(0x808d0123) // fmopa za3.s, p0/M, p0/M, z9.s, z13.s + KAI_ASM_INST(0x808e0140) // fmopa za0.s, p0/M, p0/M, z10.s, z14.s + KAI_ASM_INST(0x808f0141) // fmopa za1.s, p0/M, p0/M, z10.s, z15.s + KAI_ASM_INST(0x808e0162) // fmopa za2.s, p0/M, p0/M, z11.s, z14.s + KAI_ASM_INST(0x808f0163) // fmopa za3.s, p0/M, p0/M, z11.s, z15.s + KAI_ASM_INST(0xa041c768) // ld1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x4, MUL VL] + addvl x27, x27, #8 + KAI_ASM_INST(0xa041c78c) // ld1w { z12.s-z15.s }, pn9.b/Z, [x28, #0x4, MUL VL] + addvl x28, x28, #8 + bgt label_4 +KAI_ASM_LABEL(label_5) // K loop tail + KAI_ASM_INST(0x80940080) // fmopa za0.s, p0/M, p0/M, z4.s, z20.s + KAI_ASM_INST(0x80950081) // fmopa za1.s, p0/M, p0/M, z4.s, z21.s + KAI_ASM_INST(0x809400a2) // fmopa za2.s, p0/M, p0/M, z5.s, z20.s + KAI_ASM_INST(0x809500a3) // fmopa za3.s, p0/M, p0/M, z5.s, z21.s + KAI_ASM_INST(0x809600c0) // fmopa za0.s, p0/M, p0/M, z6.s, z22.s + KAI_ASM_INST(0x809700c1) // fmopa za1.s, p0/M, p0/M, z6.s, z23.s + KAI_ASM_INST(0x809600e2) // fmopa za2.s, p0/M, p0/M, z7.s, z22.s + KAI_ASM_INST(0x809700e3) // fmopa za3.s, p0/M, p0/M, z7.s, z23.s + KAI_ASM_INST(0x808c0100) // fmopa za0.s, p0/M, p0/M, z8.s, z12.s + 
KAI_ASM_INST(0x808d0101) // fmopa za1.s, p0/M, p0/M, z8.s, z13.s + KAI_ASM_INST(0x808c0122) // fmopa za2.s, p0/M, p0/M, z9.s, z12.s + KAI_ASM_INST(0x808d0123) // fmopa za3.s, p0/M, p0/M, z9.s, z13.s + KAI_ASM_INST(0x808e0140) // fmopa za0.s, p0/M, p0/M, z10.s, z14.s + KAI_ASM_INST(0x808f0141) // fmopa za1.s, p0/M, p0/M, z10.s, z15.s + KAI_ASM_INST(0x808e0162) // fmopa za2.s, p0/M, p0/M, z11.s, z14.s + KAI_ASM_INST(0x808f0163) // fmopa za3.s, p0/M, p0/M, z11.s, z15.s +KAI_ASM_LABEL(label_6) // K oddments + cbz x20, label_8 +KAI_ASM_LABEL(label_7) // K oddments: Loop + KAI_ASM_INST(0xa040477c) // ld1w { z28.s-z29.s }, pn9.b/Z, [x27] + subs x20, x20, #0x1 + addvl x27, x27, #2 + KAI_ASM_INST(0xa1404787) // ld1w { z7.s, z15.s }, pn9.b/Z, [x28] + addvl x28, x28, #2 + KAI_ASM_INST(0x80870380) // fmopa za0.s, p0/M, p0/M, z28.s, z7.s + KAI_ASM_INST(0x808f0381) // fmopa za1.s, p0/M, p0/M, z28.s, z15.s + KAI_ASM_INST(0x808703a2) // fmopa za2.s, p0/M, p0/M, z29.s, z7.s + KAI_ASM_INST(0x808f03a3) // fmopa za3.s, p0/M, p0/M, z29.s, z15.s + bgt label_7 +KAI_ASM_LABEL(label_8) // K oddments: End + ldr x26, [x0, #0x10] + sub x25, x13, x15 + cntw x24 + KAI_ASM_INST(0x854ec01a) // ld1rw { z26.s }, p0/Z, [x0, #56] + ldr x23, [x0, #0x18] + cmp x25, x24 + KAI_ASM_INST(0x854fc018) // ld1rw { z24.s }, p0/Z, [x0, #60] + mov x12, #0x0 + csel x22, x25, x24, LT + add x26, x26, x11, LSL #2 // C += n + lsr x21, x22, #0x2 + madd x26, x15, x23, x26 // C += m * ldc + and x20, x22, #0x3 + cbz x21, label_11 +KAI_ASM_LABEL(label_10) // Store to output array: Accumulator row 0 loop + KAI_ASM_INST(0xc0860404) // mova { z4.s-z7.s }, za0h.s[x12] + KAI_ASM_INST(0xc086042c) // mova { z12.s-z15.s }, za1h.s[x12] + KAI_ASM_INST(0xc1b8cb44) // fclamp { z4.s-z7.s }, z26.s, z24.s + KAI_ASM_INST(0xc1b8cb4c) // fclamp { z12.s-z15.s }, z26.s, z24.s + add x12, x12, #0x4 + cmp x12, x21, LSL #2 + KAI_ASM_INST(0xa1604344) // st1w { z4.s, z12.s }, p8, [x26] + add x26, x26, x23 + KAI_ASM_INST(0xa1604345) // st1w { z5.s, 
z13.s }, p8, [x26] + add x26, x26, x23 + KAI_ASM_INST(0xa1604346) // st1w { z6.s, z14.s }, p8, [x26] + add x26, x26, x23 + KAI_ASM_INST(0xa1604347) // st1w { z7.s, z15.s }, p8, [x26] + add x26, x26, x23 + blt label_10 +KAI_ASM_LABEL(label_11) // Store to output array: Accumulator row 0 oddments + cbz x20, label_12 + KAI_ASM_INST(0xc0860400) // mova { z0.s-z3.s }, za0h.s[x12] + KAI_ASM_INST(0xc0860428) // mova { z8.s-z11.s }, za1h.s[x12] + subs x20, x20, #0x1 + KAI_ASM_INST(0xc1b8cb40) // fclamp { z0.s-z3.s }, z26.s, z24.s + KAI_ASM_INST(0xc1b8cb48) // fclamp { z8.s-z11.s }, z26.s, z24.s + KAI_ASM_INST(0xa1604340) // st1w { z0.s, z8.s }, p8, [x26] + add x26, x26, x23 + beq label_12 + subs x20, x20, #0x1 + KAI_ASM_INST(0xa1604341) // st1w { z1.s, z9.s }, p8, [x26] + add x26, x26, x23 + beq label_12 + KAI_ASM_INST(0xa1604342) // st1w { z2.s, z10.s }, p8, [x26] + add x26, x26, x23 +KAI_ASM_LABEL(label_12) // Store to output array: Accumulator row 0 oddments: End + subs x25, x25, x22 + beq label_16 + cmp x25, x24 + mov x12, #0x0 + csel x20, x25, x24, LT + lsr x21, x20, #0x2 + and x20, x20, #0x3 + cbz x21, label_14 +KAI_ASM_LABEL(label_13) // Store to output array: Accumulator row 1 loop + KAI_ASM_INST(0xc0860454) // mova { z20.s-z23.s }, za2h.s[x12] + KAI_ASM_INST(0xc086047c) // mova { z28.s-z31.s }, za3h.s[x12] + KAI_ASM_INST(0xc1b8cb54) // fclamp { z20.s-z23.s }, z26.s, z24.s + KAI_ASM_INST(0xc1b8cb5c) // fclamp { z28.s-z31.s }, z26.s, z24.s + add x12, x12, #0x4 + cmp x12, x21, LSL #2 + KAI_ASM_INST(0xa1604354) // st1w { z20.s, z28.s }, p8, [x26] + add x26, x26, x23 + KAI_ASM_INST(0xa1604355) // st1w { z21.s, z29.s }, p8, [x26] + add x26, x26, x23 + KAI_ASM_INST(0xa1604356) // st1w { z22.s, z30.s }, p8, [x26] + add x26, x26, x23 + KAI_ASM_INST(0xa1604357) // st1w { z23.s, z31.s }, p8, [x26] + add x26, x26, x23 + blt label_13 +KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments + cbz x20, label_15 + KAI_ASM_INST(0xc0860444) // mova { z4.s-z7.s 
}, za2h.s[x12] + KAI_ASM_INST(0xc086046c) // mova { z12.s-z15.s }, za3h.s[x12] + subs x20, x20, #0x1 + KAI_ASM_INST(0xc1b8cb44) // fclamp { z4.s-z7.s }, z26.s, z24.s + KAI_ASM_INST(0xc1b8cb4c) // fclamp { z12.s-z15.s }, z26.s, z24.s + KAI_ASM_INST(0xa1604344) // st1w { z4.s, z12.s }, p8, [x26] + add x26, x26, x23 + beq label_15 + subs x20, x20, #0x1 + KAI_ASM_INST(0xa1604345) // st1w { z5.s, z13.s }, p8, [x26] + add x26, x26, x23 + beq label_15 + KAI_ASM_INST(0xa1604346) // st1w { z6.s, z14.s }, p8, [x26] +KAI_ASM_LABEL(label_15) // Store to output array: Accumulator row 1 oddments: End +KAI_ASM_LABEL(label_16) // Store to output array: End + incw x11, ALL, MUL #2 + cmp x11, x10 + blt label_2 + incw x15, ALL, MUL #2 + mov x11, #0x0 + cmp x15, x13 + mov x9, x27 + blt label_1 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h index 6e629274e7da42d24f484f5e1c5f70dd2ba06cda..58440f4b3b9f20418ec738b2701c0f3aad09a740 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h @@ -27,7 +27,7 @@ typedef size_t (*kai_imatmul_clamp_f32_f32p_f32p_get_dst_size_func_t)(size_t m, /// Micro-kernel core function ("run" method) typedef void (*kai_imatmul_clamp_f32_f32p_f32p_run_imatmul_func_t)( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t 
dst_row_stride, float clamp_min, float clamp_max); + void* dst, size_t dst_stride_row, float clamp_min, float clamp_max); /// Micro-kernel interface struct kai_imatmul_clamp_f32_f32p_f32p_ukernel { diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c index b2eeb5efcc17aeb745cc6e71c9e5104b70607bfe..0db463d4516cdf60d5616ef718f3ccfa99dacb88 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c @@ -4,13 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
- #include "kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" #include @@ -18,30 +14,51 @@ #include "kai/kai_common.h" +typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + int32_t min; + int32_t max; + int32_t result_zero_point; + void* accumulator_buffer; + uint64_t flags; +} KernelArgs; + static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 4; +void kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(KernelArgs* args); + +// Returns a constant value specific to this kernel that's relative to vector length +static size_t kai_get_kernel_vec_length_constant(void) { + const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u8() / kai_kr; + return kernel_vec_length_constant; +} + size_t kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u32(); + return kai_mr * kai_get_kernel_vec_length_constant(); } size_t kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_kernel_vec_length_constant(); } size_t kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - return m_idx * indirect_k * sizeof(int8_t); + return m_idx * k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(int8_t); } static size_t kai_get_rhs_packed_stride_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( size_t k_chunk_count, size_t k_chunk_length) { - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); return 
kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() * - (sizeof(int32_t) + indirect_k * sizeof(int8_t) + sizeof(float)); + (sizeof(int32_t) + k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(int8_t) + sizeof(float)); } size_t kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( @@ -54,11 +71,11 @@ size_t kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2v } size_t kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_row_stride) { + size_t m_idx, size_t n_idx, size_t dst_stride_row) { KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); - return m_idx * dst_row_stride + n_idx * sizeof(int8_t); + return m_idx * dst_stride_row + n_idx * sizeof(int8_t); } size_t kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t m, size_t n) { @@ -67,334 +84,23 @@ size_t kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme void kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params) { - typedef struct { - const void* A; - const void* B; - void* C; - uint64_t ldcb; - uint64_t M; - uint64_t N; - uint64_t K; - int32_t min; - int32_t max; - int32_t result_zero_point; - const int n_0; - void* accumulator_buffer; - uint64_t flags; - } KernelArgs; - + void* dst, size_t dst_stride_row, const struct kai_matmul_requantize32_params* params) { KernelArgs args; args.A = lhs_packed; args.B = rhs_packed; - - size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - args.C = dst; - args.ldcb 
= dst_row_stride; + args.ldcb = dst_stride_row; args.M = m; args.N = n; - args.K = indirect_k; + args.K = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); args.min = params->min_value; args.max = params->max_value; args.result_zero_point = params->output_zero_point; - args.accumulator_buffer = NULL; args.flags = 0; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "ldr w14, [%x[args], %[offsetof_M]]\n" - "mov x13, #0x0\n" - "mov x11, #0x0\n" - "ptrue p1.b\n" - ".inst 0x25207811 // ptrue pn9.b\n" - "ldr w10, [%x[args], %[offsetof_N]]\n" - "ldr x9, [%x[args], %[offsetof_A]]\n" - "1:" // M loop - "ldr x28, [%x[args], %[offsetof_B]]\n" - "2:" // N loop - ".inst 0x25aa4570 // whilelt pn8.s, x11, x10, VLx2\n" - ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" - "mov x27, x9\n" - ".inst 0xa040438e // ld1w { z14.s-z15.s }, p8/Z, [x28]\n" // Load bias - "addvl x28, x28, #2\n" - ".inst 0xc09025c0 // addha za0.s, p1/M, p1/M, z14.s\n" - ".inst 0xc09025e1 // addha za1.s, p1/M, p1/M, z15.s\n" - ".inst 0xc09025c2 // addha za2.s, p1/M, p1/M, z14.s\n" - ".inst 0xc09025e3 // addha za3.s, p1/M, p1/M, z15.s\n" - "ldr x20, [%x[args], %[offsetof_K]]\n" - "add x20, x20, #0x3\n" - "lsr x20, x20, #0x2\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 6f\n" - "subs x21, x21, #0x1\n" - ".inst 0xa0400762 // ld1b { z2.b-z3.b }, pn9.b/Z, [x27]\n" - ".inst 0xa1400780 // ld1b { z0.b, z8.b }, pn9.b/Z, [x28]\n" - ".inst 0xa0410772 // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0xa0410794 // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL]\n" - ".inst 0xa042077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa0420796 // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL]\n" - ".inst 0xa0430778 // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - ".inst 0xa0430784 // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL]\n" - "addvl x28, x28, #8\n" - "ble 5f\n" - "4:" // K 
loop - ".inst 0xa0802440 // smopa za0.s, p1/M, p1/M, z2.b, z0.b\n" - "subs x21, x21, #0x1\n" - ".inst 0xa0882441 // smopa za1.s, p1/M, p1/M, z2.b, z8.b\n" - ".inst 0xa0802462 // smopa za2.s, p1/M, p1/M, z3.b, z0.b\n" - ".inst 0xa0882463 // smopa za3.s, p1/M, p1/M, z3.b, z8.b\n" - ".inst 0xa0400762 // ld1b { z2.b-z3.b }, pn9.b/Z, [x27]\n" - ".inst 0xa0942640 // smopa za0.s, p1/M, p1/M, z18.b, z20.b\n" - ".inst 0xa1400780 // ld1b { z0.b, z8.b }, pn9.b/Z, [x28]\n" - ".inst 0xa0952641 // smopa za1.s, p1/M, p1/M, z18.b, z21.b\n" - ".inst 0xa0942662 // smopa za2.s, p1/M, p1/M, z19.b, z20.b\n" - ".inst 0xa0952663 // smopa za3.s, p1/M, p1/M, z19.b, z21.b\n" - ".inst 0xa0410772 // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0xa0962740 // smopa za0.s, p1/M, p1/M, z26.b, z22.b\n" - ".inst 0xa0410794 // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL]\n" - ".inst 0xa0972741 // smopa za1.s, p1/M, p1/M, z26.b, z23.b\n" - ".inst 0xa0962762 // smopa za2.s, p1/M, p1/M, z27.b, z22.b\n" - ".inst 0xa0972763 // smopa za3.s, p1/M, p1/M, z27.b, z23.b\n" - ".inst 0xa042077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa0420796 // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL]\n" - ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" - ".inst 0xa0852701 // smopa za1.s, p1/M, p1/M, z24.b, z5.b\n" - ".inst 0xa0842722 // smopa za2.s, p1/M, p1/M, z25.b, z4.b\n" - ".inst 0xa0852723 // smopa za3.s, p1/M, p1/M, z25.b, z5.b\n" - ".inst 0xa0430778 // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - ".inst 0xa0430784 // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL]\n" - "addvl x28, x28, #8\n" - "bgt 4b\n" - "5:" // K loop tail - ".inst 0xa0802440 // smopa za0.s, p1/M, p1/M, z2.b, z0.b\n" - ".inst 0xa0882441 // smopa za1.s, p1/M, p1/M, z2.b, z8.b\n" - ".inst 0xa0802462 // smopa za2.s, p1/M, p1/M, z3.b, z0.b\n" - ".inst 0xa0882463 // smopa za3.s, p1/M, p1/M, z3.b, z8.b\n" - ".inst 0xa0942640 // smopa za0.s, 
p1/M, p1/M, z18.b, z20.b\n" - ".inst 0xa0952641 // smopa za1.s, p1/M, p1/M, z18.b, z21.b\n" - ".inst 0xa0942662 // smopa za2.s, p1/M, p1/M, z19.b, z20.b\n" - ".inst 0xa0952663 // smopa za3.s, p1/M, p1/M, z19.b, z21.b\n" - ".inst 0xa0962740 // smopa za0.s, p1/M, p1/M, z26.b, z22.b\n" - ".inst 0xa0972741 // smopa za1.s, p1/M, p1/M, z26.b, z23.b\n" - ".inst 0xa0962762 // smopa za2.s, p1/M, p1/M, z27.b, z22.b\n" - ".inst 0xa0972763 // smopa za3.s, p1/M, p1/M, z27.b, z23.b\n" - ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" - ".inst 0xa0852701 // smopa za1.s, p1/M, p1/M, z24.b, z5.b\n" - ".inst 0xa0842722 // smopa za2.s, p1/M, p1/M, z25.b, z4.b\n" - ".inst 0xa0852723 // smopa za3.s, p1/M, p1/M, z25.b, z5.b\n" - "6:" // K oddments - "cbz x20, 8f\n" - "7:" // K oddments: Loop - ".inst 0xa0400770 // ld1b { z16.b-z17.b }, pn9.b/Z, [x27]\n" - "subs x20, x20, #0x1\n" - "addvl x27, x27, #2\n" - ".inst 0xa0400788 // ld1b { z8.b-z9.b }, pn9.b/Z, [x28]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0882600 // smopa za0.s, p1/M, p1/M, z16.b, z8.b\n" - ".inst 0xa0892601 // smopa za1.s, p1/M, p1/M, z16.b, z9.b\n" - ".inst 0xa0882622 // smopa za2.s, p1/M, p1/M, z17.b, z8.b\n" - ".inst 0xa0892623 // smopa za3.s, p1/M, p1/M, z17.b, z9.b\n" - "bgt 7b\n" - "8:" // K oddments: End - "ldr x26, [%x[args], %[offsetof_C]]\n" - "sub x25, x14, x13\n" - "cntw x24\n" - "ld1rw { z27.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" - "ldr x23, [%x[args], %[offsetof_ldcb]]\n" - "whilelt p0.h, x11, x10\n" - "cmp x25, x24\n" - "ld1rw { z1.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" - "csel x22, x25, x24, LT\n" - "ld1rw { z0.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_result_zero_point]]\n" - "mov x12, #0x0\n" - "add x26, x26, x11\n" // C += n - "lsr x21, x22, #0x2\n" - "ld1w { z22.s }, p1/Z, [x28]\n" - "madd x26, x13, x23, x26\n" // C += m * ldc - "ld1w { z26.s }, p1/Z, [x28, #1, MUL VL]\n" - "and x20, x22, #0x3\n" - "addvl x28, x28, #2\n" - "cbz x21, 11f\n" - "10:" // Store to 
output array: Accumulator row 0 loop - ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" - ".inst 0xc086043c // mova { z28.s-z31.s }, za1h.s[x12]\n" - ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" - "fmul z16.s, z16.s, z22.s\n" - "fmul z17.s, z17.s, z22.s\n" - "add x12, x12, #0x4\n" - "fmul z18.s, z18.s, z22.s\n" - "fmul z19.s, z19.s, z22.s\n" - "cmp x12, x21, LSL #2\n" - "fmul z28.s, z28.s, z26.s\n" - "fmul z29.s, z29.s, z26.s\n" - "fmul z30.s, z30.s, z26.s\n" - "fmul z31.s, z31.s, z26.s\n" - ".inst 0xc1b8e210 // frintn { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc131e210 // fcvtzs { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc1b8e39c // frintn { z28.s-z31.s }, { z28.s-z31.s }\n" - ".inst 0xc1a0ab10 // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s\n" - ".inst 0xc131e39c // fcvtzs { z28.s-z31.s }, { z28.s-z31.s }\n" - ".inst 0xc1a0ab1c // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s\n" - ".inst 0xc1a1cf70 // sclamp { z16.s-z19.s }, z27.s, z1.s\n" - ".inst 0xc1a1cf7c // sclamp { z28.s-z31.s }, z27.s, z1.s\n" - "uzp1 z5.h, z16.h, z28.h\n" - "uzp1 z20.h, z17.h, z29.h\n" - "uzp1 z17.h, z18.h, z30.h\n" - "uzp1 z16.h, z19.h, z31.h\n" - "st1b { z5.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z20.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z17.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "blt 10b\n" - "11:" // Store to output array: Accumulator row 0 oddments - "cbz x20, 12f\n" - ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" - ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" - ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" - "fmul z4.s, z4.s, z22.s\n" - "fmul z5.s, z5.s, z22.s\n" - "subs x20, x20, #0x1\n" - "fmul z6.s, z6.s, z22.s\n" - "fmul z7.s, z7.s, z22.s\n" - "fmul z12.s, z12.s, z26.s\n" - "fmul z13.s, z13.s, 
z26.s\n" - "fmul z14.s, z14.s, z26.s\n" - "fmul z15.s, z15.s, z26.s\n" - ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc1b8e18c // frintn { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" - ".inst 0xc131e18c // fcvtzs { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc1a0ab0c // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s\n" - ".inst 0xc1a1cf64 // sclamp { z4.s-z7.s }, z27.s, z1.s\n" - ".inst 0xc1a1cf6c // sclamp { z12.s-z15.s }, z27.s, z1.s\n" - "uzp1 z16.h, z4.h, z12.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "beq 12f\n" - "subs x20, x20, #0x1\n" - "uzp1 z16.h, z5.h, z13.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "beq 12f\n" - "uzp1 z16.h, z6.h, z14.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "12:" // Store to output array: Accumulator row 0 oddments: End - "subs x25, x25, x22\n" - "beq 16f\n" - "cmp x25, x24\n" - "mov x12, #0x0\n" - "csel x20, x25, x24, LT\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 14f\n" - "13:" // Store to output array: Accumulator row 1 loop - ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" - ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" - ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" - "fmul z8.s, z8.s, z22.s\n" - "fmul z9.s, z9.s, z22.s\n" - "add x12, x12, #0x4\n" - "fmul z10.s, z10.s, z22.s\n" - "fmul z11.s, z11.s, z22.s\n" - "cmp x12, x21, LSL #2\n" - "fmul z16.s, z16.s, z26.s\n" - "fmul z17.s, z17.s, z26.s\n" - "fmul z18.s, z18.s, z26.s\n" - "fmul z19.s, z19.s, z26.s\n" - ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc1b8e210 // frintn { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc1a0ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s\n" - 
".inst 0xc131e210 // fcvtzs { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc1a0ab10 // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s\n" - ".inst 0xc1a1cf68 // sclamp { z8.s-z11.s }, z27.s, z1.s\n" - ".inst 0xc1a1cf70 // sclamp { z16.s-z19.s }, z27.s, z1.s\n" - "uzp1 z21.h, z8.h, z16.h\n" - "uzp1 z20.h, z9.h, z17.h\n" - "uzp1 z17.h, z10.h, z18.h\n" - "uzp1 z16.h, z11.h, z19.h\n" - "st1b { z21.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z20.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z17.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "blt 13b\n" - "14:" // Store to output array: Accumulator row 1 oddments - "cbz x20, 15f\n" - ".inst 0xc086044c // mova { z12.s-z15.s }, za2h.s[x12]\n" - ".inst 0xc0860464 // mova { z4.s-z7.s }, za3h.s[x12]\n" - ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" - "fmul z12.s, z12.s, z22.s\n" - "fmul z13.s, z13.s, z22.s\n" - "subs x20, x20, #0x1\n" - "fmul z14.s, z14.s, z22.s\n" - "fmul z15.s, z15.s, z22.s\n" - "fmul z4.s, z4.s, z26.s\n" - "fmul z5.s, z5.s, z26.s\n" - "fmul z6.s, z6.s, z26.s\n" - "fmul z7.s, z7.s, z26.s\n" - ".inst 0xc1b8e18c // frintn { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc131e18c // fcvtzs { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc1a0ab0c // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s\n" - ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" - ".inst 0xc1a1cf6c // sclamp { z12.s-z15.s }, z27.s, z1.s\n" - ".inst 0xc1a1cf64 // sclamp { z4.s-z7.s }, z27.s, z1.s\n" - "uzp1 z16.h, z12.h, z4.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "beq 15f\n" - "subs x20, x20, #0x1\n" - "uzp1 z16.h, z13.h, z5.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "beq 15f\n" - "uzp1 z16.h, z14.h, z6.h\n" - "st1b { z16.h }, p0, 
[x26]\n" - "15:" // Store to output array: Accumulator row 1 oddments: End - "16:" // Store to output array: End - "incw x11, ALL, MUL #2\n" - "cmp x11, x10\n" - "blt 2b\n" - "incw x13, ALL, MUL #2\n" - "mov x11, #0x0\n" - "cmp x13, x14\n" - "mov x9, x27\n" - "blt 1b\n" - ".inst 0xd503467f // SMSTOP\n" - : - : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), - [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), - [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), - [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), - [offsetof_KernelArgs_result_zero_point] "I"(offsetof(KernelArgs, result_zero_point)), - [offsetof_M] "I"(offsetof(KernelArgs, M)), [offsetof_N] "I"(offsetof(KernelArgs, N)), - [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", - "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", - "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", - "z9"); + kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(&args); } #endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h index 2f52001b7a2f7ff25392f3f85f37c18ba636511a..31ed3e5cc5084764e8ca0a661313180a4da9c7e6 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h @@ -37,7 +37,6 @@ size_t kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_ /// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. /// @param[in] k_chunk_count Number of LHS column splits. /// @param[in] k_chunk_length Length of a LHS column split. -/// @param[in] k Number of columns in the unpacked LHS matrix. /// /// @return The offset in bytes to the data element. size_t kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( @@ -57,11 +56,11 @@ size_t kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2v /// /// @param[in] m_idx Row index. Must be a multiple of `m_step`. /// @param[in] n_idx Column index. Must be a multiple of `n_step`. -/// @param[in] dst_row_stride. Distance between start of two rows in the output buffer. +/// @param[in] dst_stride_row Row stride in bytes. /// /// @return The offset in bytes to the data element. size_t kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_row_stride); + size_t m_idx, size_t n_idx, size_t dst_stride_row); /// Gets the size in bytes of the destination matrix buffer. /// @@ -83,17 +82,15 @@ size_t kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme /// @param[in] m Number of output rows to be computed. 
/// @param[in] n Number of output columns to be computed. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length of a LHS column split +/// @param[in] k_chunk_length Length of a LHS column split. /// @param[in] lhs_packed Packed LHS matrix buffer. /// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. -/// @param[in] dst_row_stride Row stride in bytes of the output matrix. - +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. /// @param[in] params Requantization and clamp parameters. - void kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params); + void* dst, size_t dst_stride_row, const struct kai_matmul_requantize32_params* params); #ifdef __cplusplus } // extern "C" diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..750c7c6b45c710c643693d894961868bcf517bcc --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S @@ -0,0 +1,334 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define 
KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) + stp x20, x21, [sp, -144]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART (enters streaming mode and enables ZA) + mov x15, #0x0 + ldr x14, [x0, #0x30] + ptrue p1.b + KAI_ASM_INST(0x25207810) // ptrue pn8.b + ldr w13, [x0, #0x20] + mov x11, #0x0 + ldr w10, [x0, #0x28] + add x14, x14, #0x3 + ldr x9, [x0, #0x0] + lsr x14, x14, #0x2 +KAI_ASM_LABEL(label_1) // M loop + ldr x28, [x0, #0x8] +KAI_ASM_LABEL(label_2) // N loop + KAI_ASM_INST(0xa0404382) // ld1w { z2.s-z3.s }, pn8.b/Z, [x28] + KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } + mov x27, x9 + addvl x28, x28, #2 + KAI_ASM_INST(0xc0902440) // addha za0.s, p1/M, p1/M, z2.s + KAI_ASM_INST(0xc0902461) // addha za1.s, p1/M, p1/M, z3.s + KAI_ASM_INST(0xc0902442) // addha za2.s, p1/M, p1/M, z2.s + KAI_ASM_INST(0xc0902463) // addha za3.s, p1/M, p1/M, z3.s + lsr x21, x14, #0x2 + and x20, x14, #0x3 + cbz x21, label_6 + subs x21, x21, #0x1 + KAI_ASM_INST(0xa1408362) // ld1b { z2.b, z6.b, z10.b, z14.b }, pn8.b/Z, [x27] + KAI_ASM_INST(0xa1418360) // ld1b { z0.b, z4.b, z8.b, z12.b }, pn8.b/Z, [x27, #0x4, MUL VL] + addvl x27, x27, #8 + KAI_ASM_INST(0xa0408390) // ld1b { z16.b-z19.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xa041839c) // ld1b { z28.b-z31.b }, pn8.b/Z, [x28, #0x4, MUL VL] + addvl x28, x28, #8 + ble label_5 +KAI_ASM_LABEL(label_4) // K loop + KAI_ASM_INST(0xa0902440) // smopa za0.s, p1/M, p1/M, z2.b, z16.b + subs x21, x21, #0x1 + KAI_ASM_INST(0xa0912441) // smopa za1.s, p1/M, p1/M, z2.b, z17.b + KAI_ASM_INST(0xa09024c2) // smopa za2.s, p1/M, p1/M, z6.b, z16.b + KAI_ASM_INST(0xa09124c3) // smopa za3.s, p1/M, p1/M, z6.b, z17.b + KAI_ASM_INST(0xa0922540) // smopa za0.s, p1/M, p1/M, z10.b, z18.b + KAI_ASM_INST(0xa0932541) // smopa za1.s, p1/M, p1/M, z10.b, z19.b + KAI_ASM_INST(0xa09225c2) // smopa za2.s, p1/M, p1/M, z14.b, z18.b + KAI_ASM_INST(0xa09325c3) // smopa 
za3.s, p1/M, p1/M, z14.b, z19.b + KAI_ASM_INST(0xa1408362) // ld1b { z2.b, z6.b, z10.b, z14.b }, pn8.b/Z, [x27] + KAI_ASM_INST(0xa09c2400) // smopa za0.s, p1/M, p1/M, z0.b, z28.b + KAI_ASM_INST(0xa0408390) // ld1b { z16.b-z19.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xa09d2401) // smopa za1.s, p1/M, p1/M, z0.b, z29.b + KAI_ASM_INST(0xa09c2482) // smopa za2.s, p1/M, p1/M, z4.b, z28.b + KAI_ASM_INST(0xa09d2483) // smopa za3.s, p1/M, p1/M, z4.b, z29.b + KAI_ASM_INST(0xa09e2500) // smopa za0.s, p1/M, p1/M, z8.b, z30.b + KAI_ASM_INST(0xa09f2501) // smopa za1.s, p1/M, p1/M, z8.b, z31.b + KAI_ASM_INST(0xa09e2582) // smopa za2.s, p1/M, p1/M, z12.b, z30.b + KAI_ASM_INST(0xa09f2583) // smopa za3.s, p1/M, p1/M, z12.b, z31.b + KAI_ASM_INST(0xa1418360) // ld1b { z0.b, z4.b, z8.b, z12.b }, pn8.b/Z, [x27, #0x4, MUL VL] + addvl x27, x27, #8 + KAI_ASM_INST(0xa041839c) // ld1b { z28.b-z31.b }, pn8.b/Z, [x28, #0x4, MUL VL] + addvl x28, x28, #8 + bgt label_4 +KAI_ASM_LABEL(label_5) // K loop tail + KAI_ASM_INST(0xa0902440) // smopa za0.s, p1/M, p1/M, z2.b, z16.b + KAI_ASM_INST(0xa0912441) // smopa za1.s, p1/M, p1/M, z2.b, z17.b + KAI_ASM_INST(0xa09024c2) // smopa za2.s, p1/M, p1/M, z6.b, z16.b + KAI_ASM_INST(0xa09124c3) // smopa za3.s, p1/M, p1/M, z6.b, z17.b + KAI_ASM_INST(0xa0922540) // smopa za0.s, p1/M, p1/M, z10.b, z18.b + KAI_ASM_INST(0xa0932541) // smopa za1.s, p1/M, p1/M, z10.b, z19.b + KAI_ASM_INST(0xa09225c2) // smopa za2.s, p1/M, p1/M, z14.b, z18.b + KAI_ASM_INST(0xa09325c3) // smopa za3.s, p1/M, p1/M, z14.b, z19.b + KAI_ASM_INST(0xa09c2400) // smopa za0.s, p1/M, p1/M, z0.b, z28.b + KAI_ASM_INST(0xa09d2401) // smopa za1.s, p1/M, p1/M, z0.b, z29.b + KAI_ASM_INST(0xa09c2482) // smopa za2.s, p1/M, p1/M, z4.b, z28.b + KAI_ASM_INST(0xa09d2483) // smopa za3.s, p1/M, p1/M, z4.b, z29.b + KAI_ASM_INST(0xa09e2500) // smopa za0.s, p1/M, p1/M, z8.b, z30.b + KAI_ASM_INST(0xa09f2501) // smopa za1.s, p1/M, p1/M, z8.b, z31.b + KAI_ASM_INST(0xa09e2582) // smopa za2.s, p1/M, p1/M, z12.b, z30.b + 
KAI_ASM_INST(0xa09f2583) // smopa za3.s, p1/M, p1/M, z12.b, z31.b +KAI_ASM_LABEL(label_6) // K oddments + cbz x20, label_8 +KAI_ASM_LABEL(label_7) // K oddments: Loop + KAI_ASM_INST(0xa0400370) // ld1b { z16.b-z17.b }, pn8.b/Z, [x27] + subs x20, x20, #0x1 + addvl x27, x27, #2 + KAI_ASM_INST(0xa1400385) // ld1b { z5.b, z13.b }, pn8.b/Z, [x28] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0852600) // smopa za0.s, p1/M, p1/M, z16.b, z5.b + KAI_ASM_INST(0xa08d2601) // smopa za1.s, p1/M, p1/M, z16.b, z13.b + KAI_ASM_INST(0xa0852622) // smopa za2.s, p1/M, p1/M, z17.b, z5.b + KAI_ASM_INST(0xa08d2623) // smopa za3.s, p1/M, p1/M, z17.b, z13.b + bgt label_7 +KAI_ASM_LABEL(label_8) // K oddments: End + ldr x26, [x0, #0x10] + sub x25, x13, x15 + cntw x24 + KAI_ASM_INST(0x854ec41a) // ld1rw { z26.s }, p1/Z, [x0, #56] + ldr x23, [x0, #0x18] + whilelt p0.h, x11, x10 + cmp x25, x24 + KAI_ASM_INST(0x854fc417) // ld1rw { z23.s }, p1/Z, [x0, #60] + csel x22, x25, x24, LT + KAI_ASM_INST(0x8550c400) // ld1rw { z0.s }, p1/Z, [x0, #64] + mov x12, #0x0 + add x26, x26, x11 // C += n + lsr x21, x22, #0x2 + KAI_ASM_INST(0xa0404382) // ld1w { z2.s-z3.s }, pn8.b/Z, [x28] + madd x26, x15, x23, x26 // C += m * ldc + addvl x28, x28, #2 + and x20, x22, #0x3 + cbz x21, label_11 +KAI_ASM_LABEL(label_10) // Store to output array: Accumulator row 0 loop + KAI_ASM_INST(0xc0860408) // mova { z8.s-z11.s }, za0h.s[x12] + KAI_ASM_INST(0xc0860430) // mova { z16.s-z19.s }, za1h.s[x12] + KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc132e210) // scvtf { z16.s-z19.s }, { z16.s-z19.s } + fmul z8.s, z8.s, z2.s + fmul z9.s, z9.s, z2.s + add x12, x12, #0x4 + fmul z10.s, z10.s, z2.s + fmul z11.s, z11.s, z2.s + cmp x12, x21, LSL #2 + fmul z16.s, z16.s, z3.s + fmul z17.s, z17.s, z3.s + fmul z18.s, z18.s, z3.s + fmul z19.s, z19.s, z3.s + KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } + 
KAI_ASM_INST(0xc1b8e210) // frintn { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s + KAI_ASM_INST(0xc131e210) // fcvtzs { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc1a0ab10) // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s + KAI_ASM_INST(0xc1b7cf48) // sclamp { z8.s-z11.s }, z26.s, z23.s + KAI_ASM_INST(0xc1b7cf50) // sclamp { z16.s-z19.s }, z26.s, z23.s + uzp1 z5.h, z8.h, z16.h + uzp1 z14.h, z9.h, z17.h + uzp1 z17.h, z10.h, z18.h + uzp1 z16.h, z11.h, z19.h + st1b { z5.h }, p0, [x26] + add x26, x26, x23 + st1b { z14.h }, p0, [x26] + add x26, x26, x23 + st1b { z17.h }, p0, [x26] + add x26, x26, x23 + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + blt label_10 +KAI_ASM_LABEL(label_11) // Store to output array: Accumulator row 0 oddments + cbz x20, label_12 + KAI_ASM_INST(0xc0860408) // mova { z8.s-z11.s }, za0h.s[x12] + KAI_ASM_INST(0xc086042c) // mova { z12.s-z15.s }, za1h.s[x12] + KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc132e18c) // scvtf { z12.s-z15.s }, { z12.s-z15.s } + fmul z8.s, z8.s, z2.s + fmul z9.s, z9.s, z2.s + subs x20, x20, #0x1 + fmul z10.s, z10.s, z2.s + fmul z11.s, z11.s, z2.s + fmul z12.s, z12.s, z3.s + fmul z13.s, z13.s, z3.s + fmul z14.s, z14.s, z3.s + fmul z15.s, z15.s, z3.s + KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc1b8e18c) // frintn { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s + KAI_ASM_INST(0xc131e18c) // fcvtzs { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc1a0ab0c) // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s + KAI_ASM_INST(0xc1b7cf48) // sclamp { z8.s-z11.s }, z26.s, z23.s + KAI_ASM_INST(0xc1b7cf4c) // sclamp { z12.s-z15.s }, z26.s, z23.s + uzp1 z16.h, z8.h, z12.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + beq label_12 + subs x20, x20, #0x1 + uzp1 z16.h, z9.h, 
z13.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + beq label_12 + uzp1 z16.h, z10.h, z14.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 +KAI_ASM_LABEL(label_12) // Store to output array: Accumulator row 0 oddments: End + subs x25, x25, x22 + beq label_16 + cmp x25, x24 + mov x12, #0x0 + csel x20, x25, x24, LT + lsr x21, x20, #0x2 + and x20, x20, #0x3 + cbz x21, label_14 +KAI_ASM_LABEL(label_13) // Store to output array: Accumulator row 1 loop + KAI_ASM_INST(0xc0860448) // mova { z8.s-z11.s }, za2h.s[x12] + KAI_ASM_INST(0xc0860470) // mova { z16.s-z19.s }, za3h.s[x12] + KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc132e210) // scvtf { z16.s-z19.s }, { z16.s-z19.s } + fmul z8.s, z8.s, z2.s + fmul z9.s, z9.s, z2.s + add x12, x12, #0x4 + fmul z10.s, z10.s, z2.s + fmul z11.s, z11.s, z2.s + cmp x12, x21, LSL #2 + fmul z16.s, z16.s, z3.s + fmul z17.s, z17.s, z3.s + fmul z18.s, z18.s, z3.s + fmul z19.s, z19.s, z3.s + KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc1b8e210) // frintn { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s + KAI_ASM_INST(0xc131e210) // fcvtzs { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc1a0ab10) // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s + KAI_ASM_INST(0xc1b7cf48) // sclamp { z8.s-z11.s }, z26.s, z23.s + KAI_ASM_INST(0xc1b7cf50) // sclamp { z16.s-z19.s }, z26.s, z23.s + uzp1 z21.h, z8.h, z16.h + uzp1 z20.h, z9.h, z17.h + uzp1 z17.h, z10.h, z18.h + uzp1 z16.h, z11.h, z19.h + st1b { z21.h }, p0, [x26] + add x26, x26, x23 + st1b { z20.h }, p0, [x26] + add x26, x26, x23 + st1b { z17.h }, p0, [x26] + add x26, x26, x23 + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + blt label_13 +KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments + cbz x20, label_15 + KAI_ASM_INST(0xc086044c) // mova { z12.s-z15.s }, za2h.s[x12] + 
KAI_ASM_INST(0xc0860464) // mova { z4.s-z7.s }, za3h.s[x12] + KAI_ASM_INST(0xc132e18c) // scvtf { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc132e084) // scvtf { z4.s-z7.s }, { z4.s-z7.s } + fmul z12.s, z12.s, z2.s + fmul z13.s, z13.s, z2.s + subs x20, x20, #0x1 + fmul z14.s, z14.s, z2.s + fmul z15.s, z15.s, z2.s + fmul z4.s, z4.s, z3.s + fmul z5.s, z5.s, z3.s + fmul z6.s, z6.s, z3.s + fmul z7.s, z7.s, z3.s + KAI_ASM_INST(0xc1b8e18c) // frintn { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc131e18c) // fcvtzs { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc1b8e084) // frintn { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc1a0ab0c) // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s + KAI_ASM_INST(0xc131e084) // fcvtzs { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc1a0ab04) // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s + KAI_ASM_INST(0xc1b7cf4c) // sclamp { z12.s-z15.s }, z26.s, z23.s + KAI_ASM_INST(0xc1b7cf44) // sclamp { z4.s-z7.s }, z26.s, z23.s + uzp1 z16.h, z12.h, z4.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + beq label_15 + subs x20, x20, #0x1 + uzp1 z16.h, z13.h, z5.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + beq label_15 + uzp1 z16.h, z14.h, z6.h + st1b { z16.h }, p0, [x26] +KAI_ASM_LABEL(label_15) // Store to output array: Accumulator row 1 oddments: End +KAI_ASM_LABEL(label_16) // Store to output array: End + incw x11, ALL, MUL #2 + cmp x11, x10 + blt label_2 + incw x15, ALL, MUL #2 + mov x11, #0x0 + cmp x15, x13 + mov x9, x27 + blt label_1 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h 
b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h index 84ca66b1bd14a73af0e5b263592267e9dac87ce8..c515466884aefe817753bc312036537d1d13282d 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h @@ -22,13 +22,13 @@ typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_lhs_packed_offset_func typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_rhs_packed_offset_func_t)( size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_offset_func_t)( - size_t m_idx, size_t n_idx, size_t dst_stride); + size_t m_idx, size_t n_idx, size_t dst_stride_row); typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_size_func_t)(size_t m, size_t n); /// Micro-kernel core function ("run" method) typedef void (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_imatmul_func_t)( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params); + void* dst, size_t dst_stride_row, const struct kai_matmul_requantize32_params* params); /// Micro-kernel interface struct kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel { diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c index f996bd4889613749a1f5a1ca25193c62b1bdd48b..b3bbdbee10041a207847eeb456a66b7060690d98 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c @@ -4,13 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if 
!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. - #include "kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" #include @@ -18,9 +14,14 @@ #include "kai/kai_common.h" -#define MR 2 -#define KR 2 -#define MAX_M_STEP (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR) +enum { + MR = 2, + KR = 2, + MAX_M_STEP = MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR, +}; + +void kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t height, size_t width, const void* in, size_t row_offset, void* out); static size_t kai_get_mr_lhs_imatmul_pack_x16p2vlx2_x16p_sme(void) { return MR * kai_get_sme_vector_length_u16() / KR; @@ -69,270 +70,8 @@ void kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme( } } - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x22, %x[width]\n" - "mov x21, %x[width]\n" - "cnth x20\n" - "inch x22\n" - "sub x7, x20, #0x1\n" - "sub x22, x22, #0x1\n" - "ands x7, x21, x7\n" - "cntw x8\n" - "udiv x22, x22, x20\n" // n_passes = ceildiv(width, VL) - "csel x7, x7, x20, NE\n" - "sub x13, x22, #0x1\n" - "add x7, x7, #0x1\n" - "sub x17, x8, #0x2\n" - "lsl x21, %x[height], #0x1\n" // height * 2 - "lsl x20, x8, #0x1\n" - "mov x16, #0x0\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - "cntw x9, ALL, MUL #2\n" - "cntw x28, ALL, MUL #3\n" - "ldr x27, [x11, #0x0]\n" - "lsr x13, x13, #0x1\n" // n_loops = (n_passes - 1) / 2 - "and x26, x22, #0x1\n" // odd_tail = bool(n_passes & 0x1) - "ldr x25, [x10, #0x0]\n" - "lsr x7, x7, #0x1\n" - "ptrue p12.s\n" - "ldr x24, [x11, #0x8]\n" - "whilelt p11.h, XZR, x21\n" - "whilelt p10.h, x20, x21\n" - "ldr x21, [x10, #0x8]\n" - "mov x23, %x[row_offset]\n" - "mov x22, %x[out]\n" - "whilelt p9.h, x16, %x[width]\n" - "whilelt p8.h, x16, %x[width]\n" - "add x11, x11, #0x10\n" - "add x10, x10, #0x10\n" - "mov x12, #0x0\n" - 
"cbz x17, 2f\n" - "1:" // K loop: Charge: Loop - ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" - ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" - ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" - ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" - ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" - "add x12, x12, #0x4\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "cmp x12, x17, LSL #1\n" - "blt 1b\n" - "2:" // K loop: Charge: End - ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" - ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" - ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" - ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" - "ldr x27, [x11, #0x0]\n" - "inch x16\n" - ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "inch x23\n" - "cbz x13, 8f\n" - "mov x20, x13\n" - "3:" // K loop: Main loop - "whilelt p8.h, x16, %x[width]\n" - "mov x15, #0x0\n" - "mov x14, #0x0\n" - "cbz x17, 5f\n" - "4:" // K loop: Main loop: First: Loop - ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" - ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" - ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" - ".inst 
0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" - ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - "add x10, x10, #0x10\n" - "add x15, x15, #0x4\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x14, x14, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x14, x17\n" - "blt 4b\n" - "5:" // K loop: Main loop: First: Tail - ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" - ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" - ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" - ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" - "ldr x27, [x11, #0x0]\n" - "mov x13, #0x0\n" - ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" - "ldr x25, [x10, #0x0]\n" - "mov x12, #0x0\n" - ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" - "ldr 
x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - "whilelt p9.h, x16, %x[width]\n" - "inch x16\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "inch x23\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "whilelt p8.h, x16, %x[width]\n" - "cbz x17, 7f\n" - "6:" // K loop: Main loop: Second: Loop - ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" - ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" - ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" - ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" - ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "add x10, x10, #0x10\n" - "add x13, x13, #0x4\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x12, x12, #0x2\n" - "addvl 
x22, x22, #4\n" - "cmp x12, x17\n" - "blt 6b\n" - "7:" // K loop: Main loop: Second: Tail - ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" - ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" - ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" - ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "whilelt p9.h, x16, %x[width]\n" - "subs x20, x20, #0x1\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "inch x16\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "inch x23\n" - "bgt 3b\n" - "8:" // K loop: Tails - "cbnz x26, 11f\n" - "mov x11, %x[in]\n" - "whilelt p8.h, x16, %x[width]\n" - "mov x13, #0x0\n" - "mov x12, #0x0\n" - "9:" // K loop: Tails: Even: First - ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25396161 // psel p1.h, p8.h/Z, p11.h[w13, #1]\n" - ".inst 0x25396140 // psel p0.h, p8.h/Z, p10.h[w13, #1]\n" - ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a88ac4 
// st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "ldr x21, [x11, #0x0]\n" - "cmp x12, x8\n" - "ldr x20, [x11, x8, LSL #0x3]\n" - "add x11, x11, #0x8\n" - ".inst 0xe05726a1 // ld1h { za0h.h[x13, #1] }, p1/Z, [x21, x23, LSL #1]\n" - ".inst 0xe0572289 // ld1h { za1h.h[x13, #1] }, p0/Z, [x20, x23, LSL #1]\n" - "add x13, x13, #0x2\n" - "blt 9b\n" - "whilelt p9.h, x16, %x[width]\n" - "whilelt p8.h, x16, %x[width]\n" - "mov x20, #0x0\n" - "mov x12, #0x0\n" - "10:" // K loop: Tails: Even: Second - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - "add x20, x20, #0x2\n" - ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 10b\n" - "whilelt p8.h, x16, %x[width]\n" - "b 13f\n" - "11:" // K loop: Tails: Odd - "mov x12, #0x0\n" - "12:" // K loop: Tails: Odd: Loop - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 12b\n" - "13:" // K loop: End - "mov %x[out], x22\n" - ".inst 0xd503467f // SMSTOP\n" - : [out] "+&r"(out) - : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", - "p7", "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", - "x24", "x25", "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", - "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", - "z28", "z29", 
"z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + height, width, in, row_offset, out); // NOLINT(bugprone-multi-level-implicit-pointer-conversion) out_base += m_step * kai_roundup(k_chunk_length, KR) * sizeof(uint16_t); } } diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h index a0343938bd99c5aa466595ee45344d9763e0721d..2990bad7e007ef20b6d9921000ddf4b5c38991ac 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h @@ -16,7 +16,7 @@ extern "C" { /// /// The starting row index must be divisible by `m_step`. /// -/// @return Step size for row index +/// @return The m step value. size_t kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme(void); /// Gets the offset in bytes to the data element in the packed LHS buffer. diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..cdd955deb113b1f61f5410159942d76b21933d3c --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S @@ -0,0 +1,320 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define 
KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(lhs_imatmul_pack_x16p2vlx2_x16p_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x8, #0x0 + cnth x22 + mov x21, x1 + inch x21 + mov x20, x1 + sub x17, x22, #0x1 + sub x21, x21, #0x1 + ands x17, x20, x17 + cntw x16 + udiv x21, x21, x22 // n_passes = ceildiv(width, VL) + csel x17, x17, x22, NE + sub x13, x21, #0x1 + add x17, x17, #0x1 + sub x11, x16, #0x2 + lsl x22, x0, #0x1 // height * 2 + lsl x20, x16, #0x1 + mov x10, x2 + add x9, x2, x16, LSL #3 + cntw x28, ALL, MUL #2 + ldr x27, [x10, #0x0] + cntw x26, ALL, MUL #3 + lsr x13, x13, #0x1 // n_loops = (n_passes - 1) / 2 + ldr x25, [x9, #0x0] + and x24, x21, #0x1 // odd_tail = bool(n_passes & 0x1) + lsr x17, x17, #0x1 + ldr x23, [x10, #0x8] + ptrue p12.s + whilelt p11.h, XZR, x22 + ldr x21, [x9, #0x8] + whilelt p10.h, x20, x22 + mov x22, x3 + whilelt p9.h, x8, x1 + whilelt p8.h, x8, x1 + add x10, x10, #0x10 + add x9, x9, #0x10 + mov x12, #0x0 + cbz x11, label_2 +KAI_ASM_LABEL(label_1) // K loop: Charge: Loop + KAI_ASM_INST(0x25286163) // psel p3.h, p8.h/Z, p11.h[w12] + 
KAI_ASM_INST(0x25286142) // psel p2.h, p8.h/Z, p10.h[w12] + KAI_ASM_INST(0x25686161) // psel p1.h, p8.h/Z, p11.h[w12, #2] + KAI_ASM_INST(0x25686140) // psel p0.h, p8.h/Z, p10.h[w12, #2] + KAI_ASM_INST(0xe0560f60) // ld1h { za0h.h[x12] }, p3/Z, [x27, x22, LSL #1] + ldr x27, [x10, #0x0] + KAI_ASM_INST(0xe0560b28) // ld1h { za1h.h[x12] }, p2/Z, [x25, x22, LSL #1] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe05606e2) // ld1h { za0h.h[x12, #2] }, p1/Z, [x23, x22, LSL #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe05602aa) // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x22, LSL #1] + add x12, x12, #0x4 + ldr x21, [x9, #0x8] + add x9, x9, #0x10 + cmp x12, x11, LSL #1 + blt label_1 +KAI_ASM_LABEL(label_2) // K loop: Charge: End + KAI_ASM_INST(0x25286163) // psel p3.h, p8.h/Z, p11.h[w12] + KAI_ASM_INST(0x25286142) // psel p2.h, p8.h/Z, p10.h[w12] + KAI_ASM_INST(0x25686161) // psel p1.h, p8.h/Z, p11.h[w12, #2] + KAI_ASM_INST(0x25686140) // psel p0.h, p8.h/Z, p10.h[w12, #2] + mov x10, x2 + add x9, x2, x16, LSL #3 + KAI_ASM_INST(0xe0560f60) // ld1h { za0h.h[x12] }, p3/Z, [x27, x22, LSL #1] + ldr x27, [x10, #0x0] + inch x8 + KAI_ASM_INST(0xe0560b28) // ld1h { za1h.h[x12] }, p2/Z, [x25, x22, LSL #1] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe05606e2) // ld1h { za0h.h[x12, #2] }, p1/Z, [x23, x22, LSL #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe05602aa) // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x22, LSL #1] + ldr x21, [x9, #0x8] + add x9, x9, #0x10 + inch x22 + cbz x13, label_8 + mov x20, x13 +KAI_ASM_LABEL(label_3) // K loop: Main loop + whilelt p8.h, x8, x1 + mov x15, #0x0 + mov x14, #0x0 + cbz x11, label_5 +KAI_ASM_LABEL(label_4) // K loop: Main loop: First: Loop + KAI_ASM_INST(0x253b6160) // psel p0.h, p8.h/Z, p11.h[w15, #1] + KAI_ASM_INST(0x253b6142) // psel p2.h, p8.h/Z, p10.h[w15, #1] + KAI_ASM_INST(0x257b6161) // psel p1.h, p8.h/Z, p11.h[w15, #3] + KAI_ASM_INST(0x257b6143) // psel p3.h, p8.h/Z, p10.h[w15, #3] + KAI_ASM_INST(0xe0566361) // ld1h { 
za0h.h[x15, #1] }, p0/Z, [x27, x22, LSL #1] + KAI_ASM_INST(0x252a7120) // psel p0.h, p12.h/Z, p9.h[w14] + ldr x27, [x10, #0x0] + KAI_ASM_INST(0xe0566b29) // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x22, LSL #1] + KAI_ASM_INST(0x252a7122) // psel p2.h, p12.h/Z, p9.h[w14] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe05666e3) // ld1h { za0h.h[x15, #3] }, p1/Z, [x23, x22, LSL #1] + KAI_ASM_INST(0x253a7121) // psel p1.h, p12.h/Z, p9.h[w14, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0566eab) // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x22, LSL #1] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bfc080) // st1w { za0v.s[x14] }, p0/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0x253a7120) // psel p0.h, p12.h/Z, p9.h[w14, #1] + KAI_ASM_INST(0xe0b0c884) // st1w { za1v.s[x14] }, p2/Z, [x4, x16, LSL #2] + add x9, x9, #0x10 + add x15, x15, #0x4 + KAI_ASM_INST(0xe0bcc481) // st1w { za0v.s[x14, #1] }, p1/Z, [x4, x28, LSL #2] + KAI_ASM_INST(0xe0bac085) // st1w { za1v.s[x14, #1] }, p0/Z, [x4, x26, LSL #2] + add x14, x14, #0x2 + addvl x4, x4, #4 + cmp x14, x11 + blt label_4 +KAI_ASM_LABEL(label_5) // K loop: Main loop: First: Tail + KAI_ASM_INST(0x253b6160) // psel p0.h, p8.h/Z, p11.h[w15, #1] + KAI_ASM_INST(0x253b6142) // psel p2.h, p8.h/Z, p10.h[w15, #1] + KAI_ASM_INST(0x257b6161) // psel p1.h, p8.h/Z, p11.h[w15, #3] + KAI_ASM_INST(0x257b6143) // psel p3.h, p8.h/Z, p10.h[w15, #3] + mov x10, x2 + add x9, x2, x16, LSL #3 + KAI_ASM_INST(0xe0566361) // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x22, LSL #1] + KAI_ASM_INST(0x252a7120) // psel p0.h, p12.h/Z, p9.h[w14] + ldr x27, [x10, #0x0] + mov x13, #0x0 + KAI_ASM_INST(0xe0566b29) // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x22, LSL #1] + KAI_ASM_INST(0x252a7122) // psel p2.h, p12.h/Z, p9.h[w14] + ldr x25, [x9, #0x0] + mov x12, #0x0 + KAI_ASM_INST(0xe05666e3) // ld1h { za0h.h[x15, #3] }, p1/Z, [x23, x22, LSL #1] + KAI_ASM_INST(0x253a7121) // psel p1.h, p12.h/Z, p9.h[w14, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0566eab) 
// ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x22, LSL #1] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bfc080) // st1w { za0v.s[x14] }, p0/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0x253a7120) // psel p0.h, p12.h/Z, p9.h[w14, #1] + KAI_ASM_INST(0xe0b0c884) // st1w { za1v.s[x14] }, p2/Z, [x4, x16, LSL #2] + whilelt p9.h, x8, x1 + inch x8 + KAI_ASM_INST(0xe0bcc481) // st1w { za0v.s[x14, #1] }, p1/Z, [x4, x28, LSL #2] + add x9, x9, #0x10 + inch x22 + KAI_ASM_INST(0xe0bac085) // st1w { za1v.s[x14, #1] }, p0/Z, [x4, x26, LSL #2] + addvl x4, x4, #4 + whilelt p8.h, x8, x1 + cbz x11, label_7 +KAI_ASM_LABEL(label_6) // K loop: Main loop: Second: Loop + KAI_ASM_INST(0x25296160) // psel p0.h, p8.h/Z, p11.h[w13] + KAI_ASM_INST(0x25296142) // psel p2.h, p8.h/Z, p10.h[w13] + KAI_ASM_INST(0x25696161) // psel p1.h, p8.h/Z, p11.h[w13, #2] + KAI_ASM_INST(0x25696143) // psel p3.h, p8.h/Z, p10.h[w13, #2] + KAI_ASM_INST(0xe0562360) // ld1h { za0h.h[x13] }, p0/Z, [x27, x22, LSL #1] + KAI_ASM_INST(0x25287120) // psel p0.h, p12.h/Z, p9.h[w12] + ldr x27, [x10, #0x0] + KAI_ASM_INST(0xe0562b28) // ld1h { za1h.h[x13] }, p2/Z, [x25, x22, LSL #1] + KAI_ASM_INST(0x25287122) // psel p2.h, p12.h/Z, p9.h[w12] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe05626e2) // ld1h { za0h.h[x13, #2] }, p1/Z, [x23, x22, LSL #1] + KAI_ASM_INST(0x25387121) // psel p1.h, p12.h/Z, p9.h[w12, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0562eaa) // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x22, LSL #1] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bf8088) // st1w { za2v.s[x12] }, p0/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0x25387120) // psel p0.h, p12.h/Z, p9.h[w12, #1] + KAI_ASM_INST(0xe0b0888c) // st1w { za3v.s[x12] }, p2/Z, [x4, x16, LSL #2] + add x9, x9, #0x10 + add x13, x13, #0x4 + KAI_ASM_INST(0xe0bc8489) // st1w { za2v.s[x12, #1] }, p1/Z, [x4, x28, LSL #2] + KAI_ASM_INST(0xe0ba808d) // st1w { za3v.s[x12, #1] }, p0/Z, [x4, x26, LSL #2] + add x12, x12, #0x2 + addvl x4, x4, #4 + cmp x12, x11 + blt label_6 
+KAI_ASM_LABEL(label_7) // K loop: Main loop: Second: Tail + KAI_ASM_INST(0x25296160) // psel p0.h, p8.h/Z, p11.h[w13] + KAI_ASM_INST(0x25296142) // psel p2.h, p8.h/Z, p10.h[w13] + KAI_ASM_INST(0x25696161) // psel p1.h, p8.h/Z, p11.h[w13, #2] + KAI_ASM_INST(0x25696143) // psel p3.h, p8.h/Z, p10.h[w13, #2] + mov x10, x2 + add x9, x2, x16, LSL #3 + KAI_ASM_INST(0xe0562360) // ld1h { za0h.h[x13] }, p0/Z, [x27, x22, LSL #1] + KAI_ASM_INST(0x25287120) // psel p0.h, p12.h/Z, p9.h[w12] + ldr x27, [x10, #0x0] + KAI_ASM_INST(0xe0562b28) // ld1h { za1h.h[x13] }, p2/Z, [x25, x22, LSL #1] + KAI_ASM_INST(0x25287122) // psel p2.h, p12.h/Z, p9.h[w12] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe05626e2) // ld1h { za0h.h[x13, #2] }, p1/Z, [x23, x22, LSL #1] + KAI_ASM_INST(0x25387121) // psel p1.h, p12.h/Z, p9.h[w12, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0562eaa) // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x22, LSL #1] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bf8088) // st1w { za2v.s[x12] }, p0/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0x25387120) // psel p0.h, p12.h/Z, p9.h[w12, #1] + KAI_ASM_INST(0xe0b0888c) // st1w { za3v.s[x12] }, p2/Z, [x4, x16, LSL #2] + whilelt p9.h, x8, x1 + subs x20, x20, #0x1 + KAI_ASM_INST(0xe0bc8489) // st1w { za2v.s[x12, #1] }, p1/Z, [x4, x28, LSL #2] + add x9, x9, #0x10 + inch x8 + KAI_ASM_INST(0xe0ba808d) // st1w { za3v.s[x12, #1] }, p0/Z, [x4, x26, LSL #2] + addvl x4, x4, #4 + inch x22 + bgt label_3 +KAI_ASM_LABEL(label_8) // K loop: Tails + cbnz x24, label_11 + mov x10, x2 + whilelt p8.h, x8, x1 + mov x13, #0x0 + mov x12, #0x0 +KAI_ASM_LABEL(label_9) // K loop: Tails: Even: First + KAI_ASM_INST(0x25307123) // psel p3.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25396161) // psel p1.h, p8.h/Z, p11.h[w13, #1] + KAI_ASM_INST(0x25396140) // psel p0.h, p8.h/Z, p10.h[w13, #1] + KAI_ASM_INST(0xe0bf8c80) // st1w { za0v.s[x12] }, p3/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0xe0b08884) // st1w { 
za1v.s[x12] }, p2/Z, [x4, x16, LSL #2] + add x12, x12, #0x1 + addvl x4, x4, #2 + ldr x21, [x10, #0x0] + cmp x12, x16 + ldr x20, [x10, x16, LSL #0x3] + add x10, x10, #0x8 + KAI_ASM_INST(0xe05626a1) // ld1h { za0h.h[x13, #1] }, p1/Z, [x21, x22, LSL #1] + KAI_ASM_INST(0xe0562289) // ld1h { za1h.h[x13, #1] }, p0/Z, [x20, x22, LSL #1] + add x13, x13, #0x2 + blt label_9 + whilelt p9.h, x8, x1 + whilelt p8.h, x8, x1 + mov x20, #0x0 + mov x12, #0x0 +KAI_ASM_LABEL(label_10) // K loop: Tails: Even: Second + KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + add x20, x20, #0x2 + KAI_ASM_INST(0xe0bf8488) // st1w { za2v.s[x12] }, p1/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0xe0b0808c) // st1w { za3v.s[x12] }, p0/Z, [x4, x16, LSL #2] + add x12, x12, #0x1 + addvl x4, x4, #2 + cmp x12, x17 + blt label_10 + whilelt p8.h, x8, x1 + b label_13 +KAI_ASM_LABEL(label_11) // K loop: Tails: Odd + mov x12, #0x0 +KAI_ASM_LABEL(label_12) // K loop: Tails: Odd: Loop + KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf8480) // st1w { za0v.s[x12] }, p1/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0xe0b08084) // st1w { za1v.s[x12] }, p0/Z, [x4, x16, LSL #2] + add x12, x12, #0x1 + addvl x4, x4, #2 + cmp x12, x17 + blt label_12 +KAI_ASM_LABEL(label_13) // K loop: End + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c index bef728fe02e8265aef20b94d99c72227dd19e4ea..4f0925993237b331bdbfb08e2db68ff01e48fff1 100644 --- 
a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c @@ -4,13 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. - #include "kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h" #include @@ -18,9 +14,14 @@ #include "kai/kai_common.h" -#define MR 2 -#define KR 1 -#define MAX_M_STEP (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(float)) / KR) +enum { + MR = 2, + KR = 1, + MAX_M_STEP = MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(float)) / KR, +}; + +void kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + size_t height, size_t width, const void* in, size_t row_offset, void* out); static size_t kai_get_mr_lhs_imatmul_pack_x32p2vlx1_x32p_sme(void) { return MR * kai_get_sme_vector_length_u32() / KR; @@ -69,257 +70,8 @@ void kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme( } } - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x21, %x[width]\n" - "mov x20, %x[width]\n" - "incw x21\n" - "cntw x17\n" - "sub x21, x21, #0x1\n" - "sub x16, x17, #0x1\n" - "udiv x21, x21, x17\n" // n_passes = ceildiv(width, VL) - "ands x16, x20, x16\n" - "sub x20, x21, #0x1\n" - "sub x15, x17, #0x2\n" - "mov x14, #0x0\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - "cntw x9, ALL, MUL #2\n" - "ldr x28, [x11, #0x0]\n" - "cntw x27, ALL, MUL #3\n" - "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 - "ldr x26, [x10, #0x0]\n" - "and x25, x21, #0x1\n" // odd_tail = bool(n_passes & 0x1) - "csel x16, x16, x17, NE\n" - "ldr x24, [x11, #0x8]\n" - "ptrue p12.s\n" - "whilelt p11.s, XZR, %x[height]\n" - "ldr x21, [x10, #0x8]\n" - "whilelt p10.s, x17, %x[height]\n" - "mov 
x23, %x[row_offset]\n" - "mov x22, %x[out]\n" - "whilelt p9.s, x14, %x[width]\n" - "whilelt p8.s, x14, %x[width]\n" - "add x11, x11, #0x10\n" - "add x10, x10, #0x10\n" - "mov x12, #0x0\n" - "cbz x15, 2f\n" - "1:" // K loop: Charge: Loop - ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" - ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" - "add x12, x12, #0x2\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "cmp x12, x15\n" - "blt 1b\n" - "2:" // K loop: Charge: End - ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" - "ldr x28, [x11, #0x0]\n" - "incw x14\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "incw x23\n" - "cbz x20, 8f\n" - "mov x20, x20\n" - "3:" // K loop: Main loop - "whilelt p8.s, x14, %x[width]\n" - "mov x13, #0x0\n" - "cbz x15, 5f\n" - "4:" // K loop: Main loop: First: Loop - ".inst 0x25316160 // 
psel p0.s, p8.s/Z, p11.s[w13]\n" - ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" - ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" - ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" - ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" - ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "add x13, x13, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x13, x15\n" - "blt 4b\n" - "5:" // K loop: Main loop: First: Tail - ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" - ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" - ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" - ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" - "ldr x28, [x11, #0x0]\n" - "mov x12, #0x0\n" - ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, 
[x24, x23, LSL #2]\n" - ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" - ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" - "whilelt p9.s, x14, %x[width]\n" - "incw x14\n" - ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "incw x23\n" - ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "addvl x22, x22, #4\n" - "whilelt p8.s, x14, %x[width]\n" - "cbz x15, 7f\n" - "6:" // K loop: Main loop: Second: Loop - ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" - ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" - ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bb82cd // st1w { za3v.s[x12, 
#1] }, p0/Z, [x22, x27, LSL #2]\n" - "add x12, x12, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x12, x15\n" - "blt 6b\n" - "7:" // K loop: Main loop: Second: Tail - ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" - ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" - "whilelt p9.s, x14, %x[width]\n" - "subs x20, x20, #0x1\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "incw x14\n" - ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "addvl x22, x22, #4\n" - "incw x23\n" - "bgt 3b\n" - "8:" // K loop: Tails - "cbnz x25, 11f\n" - "mov x11, %x[in]\n" - "whilelt p8.s, x14, %x[width]\n" - "mov x12, #0x0\n" - "9:" // K loop: Tails: Even: First - ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25306161 // psel p1.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306140 // psel p0.s, p8.s/Z, p10.s[w12]\n" - ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, 
p3/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0b18ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" - "addvl x22, x22, #2\n" - "ldr x21, [x11, #0x0]\n" - "ldr x20, [x11, x17, LSL #0x3]\n" - "add x11, x11, #0x8\n" - ".inst 0xe09706a8 // ld1w { za2h.s[x12] }, p1/Z, [x21, x23, LSL #2]\n" - ".inst 0xe097028c // ld1w { za3h.s[x12] }, p0/Z, [x20, x23, LSL #2]\n" - "add x12, x12, #0x1\n" - "cmp x12, x17\n" - "blt 9b\n" - "whilelt p9.s, x14, %x[width]\n" - "whilelt p8.s, x14, %x[width]\n" - "mov x12, #0x0\n" - "10:" // K loop: Tails: Even: Second - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0b182cc // st1w { za3v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x16\n" - "blt 10b\n" - "whilelt p8.s, x14, %x[width]\n" - "b 13f\n" - "11:" // K loop: Tails: Odd - "mov x12, #0x0\n" - "12:" // K loop: Tails: Odd: Loop - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0b182c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x16\n" - "blt 12b\n" - "13:" // K loop: End - "mov %x[out], x22\n" - ".inst 0xd503467f // SMSTOP\n" - : [out] "+&r"(out) - : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", - "p7", "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", - "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", - "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", - "z30", "z31", "z4", "z5", 
"z6", "z7", "z8", "z9"); + kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + height, width, in, row_offset, out); // NOLINT(bugprone-multi-level-implicit-pointer-conversion) out_base += m_step * kai_roundup(k_chunk_length, KR) * sizeof(float); } } diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h index 5f6c68a945a969a23f39f392444ac41702e1e897..416c3130e85540a9adaeebebba0fb560ab790dba 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h @@ -16,7 +16,7 @@ extern "C" { /// /// The starting row index must be divisible by `m_step`. /// -/// @return Step size for row index +/// @return The m step value. size_t kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme(void); /// Gets the offset in bytes to the data element in the packed LHS buffer. diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..f2a144e672bd996c4dd7dbffe9cd4384f7147052 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S @@ -0,0 +1,307 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define 
KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(lhs_imatmul_pack_x32p2vlx1_x32p_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x16, #0x0 + mov x21, x1 + cntw x15 + incw x21 + mov x20, x1 + sub x21, x21, #0x1 + sub x14, x15, #0x1 + udiv x21, x21, x15 // n_passes = ceildiv(width, VL) + ands x14, x20, x14 + sub x20, x21, #0x1 + sub x11, x15, #0x2 + mov x10, x2 + add x9, x2, x15, LSL #3 + cntw x28, ALL, MUL #2 + cntw x27, ALL, MUL #3 + ldr x26, [x10, #0x0] + lsr x20, x20, #0x1 // n_loops = (n_passes - 1) / 2 + and x25, x21, #0x1 // odd_tail = bool(n_passes & 0x1) + ldr x24, [x9, #0x0] + csel x14, x14, x15, NE + ptrue p12.s + ldr x23, [x10, #0x8] + whilelt p11.s, XZR, x0 + whilelt p10.s, x15, x0 + ldr x21, [x9, #0x8] + mov x22, x3 + whilelt p9.s, x16, x1 + whilelt p8.s, x16, x1 + add x10, x10, #0x10 + add x9, x9, #0x10 + mov x12, #0x0 + cbz x11, label_2 +KAI_ASM_LABEL(label_1) // K loop: Charge: Loop + KAI_ASM_INST(0x25306163) // psel p3.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1] + KAI_ASM_INST(0x25706140) // psel p0.s, p8.s/Z, 
p10.s[w12, #1] + KAI_ASM_INST(0xe0960f40) // ld1w { za0h.s[x12] }, p3/Z, [x26, x22, LSL #2] + ldr x26, [x10, #0x0] + KAI_ASM_INST(0xe0960b04) // ld1w { za1h.s[x12] }, p2/Z, [x24, x22, LSL #2] + ldr x24, [x9, #0x0] + KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe09602a5) // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x22, LSL #2] + add x12, x12, #0x2 + ldr x21, [x9, #0x8] + add x9, x9, #0x10 + cmp x12, x11 + blt label_1 +KAI_ASM_LABEL(label_2) // K loop: Charge: End + KAI_ASM_INST(0x25306163) // psel p3.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1] + KAI_ASM_INST(0x25706140) // psel p0.s, p8.s/Z, p10.s[w12, #1] + mov x10, x2 + add x9, x2, x15, LSL #3 + KAI_ASM_INST(0xe0960f40) // ld1w { za0h.s[x12] }, p3/Z, [x26, x22, LSL #2] + ldr x26, [x10, #0x0] + incw x16 + KAI_ASM_INST(0xe0960b04) // ld1w { za1h.s[x12] }, p2/Z, [x24, x22, LSL #2] + ldr x24, [x9, #0x0] + KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe09602a5) // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x22, LSL #2] + ldr x21, [x9, #0x8] + add x9, x9, #0x10 + incw x22 + cbz x20, label_8 + mov x20, x20 +KAI_ASM_LABEL(label_3) // K loop: Main loop + whilelt p8.s, x16, x1 + mov x13, #0x0 + cbz x11, label_5 +KAI_ASM_LABEL(label_4) // K loop: Main loop: First: Loop + KAI_ASM_INST(0x25316160) // psel p0.s, p8.s/Z, p11.s[w13] + KAI_ASM_INST(0x25316142) // psel p2.s, p8.s/Z, p10.s[w13] + KAI_ASM_INST(0x25716161) // psel p1.s, p8.s/Z, p11.s[w13, #1] + KAI_ASM_INST(0x25716143) // psel p3.s, p8.s/Z, p10.s[w13, #1] + KAI_ASM_INST(0xe0962348) // ld1w { za2h.s[x13] }, p0/Z, [x26, x22, LSL #2] + KAI_ASM_INST(0x25317120) // psel p0.s, p12.s/Z, p9.s[w13] + ldr x26, [x10, #0x0] + KAI_ASM_INST(0xe0962b0c) // ld1w { za3h.s[x13] }, p2/Z, [x24, x22, LSL #2] 
+ KAI_ASM_INST(0x25317122) // psel p2.s, p12.s/Z, p9.s[w13] + ldr x24, [x9, #0x0] + KAI_ASM_INST(0xe09626e9) // ld1w { za2h.s[x13, #1] }, p1/Z, [x23, x22, LSL #2] + KAI_ASM_INST(0x25717121) // psel p1.s, p12.s/Z, p9.s[w13, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0962ead) // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x22, LSL #2] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bfa080) // st1w { za0v.s[x13] }, p0/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0x25717120) // psel p0.s, p12.s/Z, p9.s[w13, #1] + KAI_ASM_INST(0xe0afa884) // st1w { za1v.s[x13] }, p2/Z, [x4, x15, LSL #2] + add x9, x9, #0x10 + KAI_ASM_INST(0xe0bca481) // st1w { za0v.s[x13, #1] }, p1/Z, [x4, x28, LSL #2] + KAI_ASM_INST(0xe0bba085) // st1w { za1v.s[x13, #1] }, p0/Z, [x4, x27, LSL #2] + add x13, x13, #0x2 + addvl x4, x4, #4 + cmp x13, x11 + blt label_4 +KAI_ASM_LABEL(label_5) // K loop: Main loop: First: Tail + KAI_ASM_INST(0x25316160) // psel p0.s, p8.s/Z, p11.s[w13] + KAI_ASM_INST(0x25316142) // psel p2.s, p8.s/Z, p10.s[w13] + KAI_ASM_INST(0x25716161) // psel p1.s, p8.s/Z, p11.s[w13, #1] + KAI_ASM_INST(0x25716143) // psel p3.s, p8.s/Z, p10.s[w13, #1] + mov x10, x2 + add x9, x2, x15, LSL #3 + KAI_ASM_INST(0xe0962348) // ld1w { za2h.s[x13] }, p0/Z, [x26, x22, LSL #2] + KAI_ASM_INST(0x25317120) // psel p0.s, p12.s/Z, p9.s[w13] + ldr x26, [x10, #0x0] + mov x12, #0x0 + KAI_ASM_INST(0xe0962b0c) // ld1w { za3h.s[x13] }, p2/Z, [x24, x22, LSL #2] + KAI_ASM_INST(0x25317122) // psel p2.s, p12.s/Z, p9.s[w13] + ldr x24, [x9, #0x0] + KAI_ASM_INST(0xe09626e9) // ld1w { za2h.s[x13, #1] }, p1/Z, [x23, x22, LSL #2] + KAI_ASM_INST(0x25717121) // psel p1.s, p12.s/Z, p9.s[w13, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0962ead) // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x22, LSL #2] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bfa080) // st1w { za0v.s[x13] }, p0/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0x25717120) // psel p0.s, p12.s/Z, p9.s[w13, #1] + KAI_ASM_INST(0xe0afa884) // st1w { 
za1v.s[x13] }, p2/Z, [x4, x15, LSL #2] + whilelt p9.s, x16, x1 + incw x16 + KAI_ASM_INST(0xe0bca481) // st1w { za0v.s[x13, #1] }, p1/Z, [x4, x28, LSL #2] + add x9, x9, #0x10 + incw x22 + KAI_ASM_INST(0xe0bba085) // st1w { za1v.s[x13, #1] }, p0/Z, [x4, x27, LSL #2] + addvl x4, x4, #4 + whilelt p8.s, x16, x1 + cbz x11, label_7 +KAI_ASM_LABEL(label_6) // K loop: Main loop: Second: Loop + KAI_ASM_INST(0x25306160) // psel p0.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1] + KAI_ASM_INST(0x25706143) // psel p3.s, p8.s/Z, p10.s[w12, #1] + KAI_ASM_INST(0xe0960340) // ld1w { za0h.s[x12] }, p0/Z, [x26, x22, LSL #2] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + ldr x26, [x10, #0x0] + KAI_ASM_INST(0xe0960b04) // ld1w { za1h.s[x12] }, p2/Z, [x24, x22, LSL #2] + KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12] + ldr x24, [x9, #0x0] + KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2] + KAI_ASM_INST(0x25707121) // psel p1.s, p12.s/Z, p9.s[w12, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0960ea5) // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x22, LSL #2] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bf8088) // st1w { za2v.s[x12] }, p0/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0x25707120) // psel p0.s, p12.s/Z, p9.s[w12, #1] + KAI_ASM_INST(0xe0af888c) // st1w { za3v.s[x12] }, p2/Z, [x4, x15, LSL #2] + add x9, x9, #0x10 + KAI_ASM_INST(0xe0bc8489) // st1w { za2v.s[x12, #1] }, p1/Z, [x4, x28, LSL #2] + KAI_ASM_INST(0xe0bb808d) // st1w { za3v.s[x12, #1] }, p0/Z, [x4, x27, LSL #2] + add x12, x12, #0x2 + addvl x4, x4, #4 + cmp x12, x11 + blt label_6 +KAI_ASM_LABEL(label_7) // K loop: Main loop: Second: Tail + KAI_ASM_INST(0x25306160) // psel p0.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1] + KAI_ASM_INST(0x25706143) // psel p3.s, 
p8.s/Z, p10.s[w12, #1] + mov x10, x2 + add x9, x2, x15, LSL #3 + KAI_ASM_INST(0xe0960340) // ld1w { za0h.s[x12] }, p0/Z, [x26, x22, LSL #2] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + ldr x26, [x10, #0x0] + KAI_ASM_INST(0xe0960b04) // ld1w { za1h.s[x12] }, p2/Z, [x24, x22, LSL #2] + KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12] + ldr x24, [x9, #0x0] + KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2] + KAI_ASM_INST(0x25707121) // psel p1.s, p12.s/Z, p9.s[w12, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0960ea5) // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x22, LSL #2] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bf8088) // st1w { za2v.s[x12] }, p0/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0x25707120) // psel p0.s, p12.s/Z, p9.s[w12, #1] + KAI_ASM_INST(0xe0af888c) // st1w { za3v.s[x12] }, p2/Z, [x4, x15, LSL #2] + whilelt p9.s, x16, x1 + subs x20, x20, #0x1 + KAI_ASM_INST(0xe0bc8489) // st1w { za2v.s[x12, #1] }, p1/Z, [x4, x28, LSL #2] + add x9, x9, #0x10 + incw x16 + KAI_ASM_INST(0xe0bb808d) // st1w { za3v.s[x12, #1] }, p0/Z, [x4, x27, LSL #2] + addvl x4, x4, #4 + incw x22 + bgt label_3 +KAI_ASM_LABEL(label_8) // K loop: Tails + cbnz x25, label_11 + mov x10, x2 + whilelt p8.s, x16, x1 + mov x12, #0x0 +KAI_ASM_LABEL(label_9) // K loop: Tails: Even: First + KAI_ASM_INST(0x25307123) // psel p3.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306161) // psel p1.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306140) // psel p0.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0xe0bf8c80) // st1w { za0v.s[x12] }, p3/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0xe0af8884) // st1w { za1v.s[x12] }, p2/Z, [x4, x15, LSL #2] + addvl x4, x4, #2 + ldr x21, [x10, #0x0] + ldr x20, [x10, x15, LSL #0x3] + add x10, x10, #0x8 + KAI_ASM_INST(0xe09606a8) // ld1w { za2h.s[x12] }, p1/Z, [x21, x22, LSL #2] + KAI_ASM_INST(0xe096028c) // ld1w { za3h.s[x12] }, p0/Z, [x20, x22, LSL #2] + add x12, x12, 
#0x1 + cmp x12, x15 + blt label_9 + whilelt p9.s, x16, x1 + whilelt p8.s, x16, x1 + mov x12, #0x0 +KAI_ASM_LABEL(label_10) // K loop: Tails: Even: Second + KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf8488) // st1w { za2v.s[x12] }, p1/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0xe0af808c) // st1w { za3v.s[x12] }, p0/Z, [x4, x15, LSL #2] + add x12, x12, #0x1 + addvl x4, x4, #2 + cmp x12, x14 + blt label_10 + whilelt p8.s, x16, x1 + b label_13 +KAI_ASM_LABEL(label_11) // K loop: Tails: Odd + mov x12, #0x0 +KAI_ASM_LABEL(label_12) // K loop: Tails: Odd: Loop + KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf8480) // st1w { za0v.s[x12] }, p1/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0xe0af8084) // st1w { za1v.s[x12] }, p0/Z, [x4, x15, LSL #2] + add x12, x12, #0x1 + addvl x4, x4, #2 + cmp x12, x14 + blt label_12 +KAI_ASM_LABEL(label_13) // K loop: End + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c index 25a48afc6d939aede29405f2ae2b0cf30ad56181..537992349623ff9e140e636a22950349e0d0d42e 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c @@ -4,13 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if 
(!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. - #include "kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h" #include @@ -18,9 +14,14 @@ #include "kai/kai_common.h" -#define MR 2 -#define KR 4 -#define MAX_M_STEP (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(int8_t)) / KR) +enum { + MR = 2, + KR = 4, + MAX_M_STEP = MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(int8_t)) / KR, +}; + +void kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme( + size_t height, size_t width, const void* in, size_t row_offset, void* out); static size_t kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void) { return MR * kai_get_sme_vector_length_u8() / KR; @@ -69,271 +70,8 @@ void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( } } - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x23, %x[width]\n" - "mov x21, %x[width]\n" - "cntb x20\n" - "incb x23\n" - "sub x7, x20, #0x1\n" - "cntw x8\n" - "sub x23, x23, #0x1\n" - "ands x7, x21, x7\n" - "udiv x23, x23, x20\n" // n_passes = ceildiv(width, VL) - "csel x7, x7, x20, NE\n" - "lsl x22, %x[height], #0x1\n" // height * 2 - "lsl x21, x8, #0x1\n" - "sub x20, x23, #0x1\n" - "add x7, x7, #0x3\n" - "sub x17, x8, #0x2\n" - "whilelt p9.b, XZR, x22\n" - "whilelt p8.b, x21, x22\n" - "mov x16, #0x0\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - "cntw x9, ALL, MUL #2\n" - "cntw x28, ALL, MUL #3\n" - "ldr x27, [x11, #0x0]\n" - "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 - "and x26, x23, #0x1\n" // odd_tail = bool(n_passes & 0x1) - "ldr x25, [x10, #0x0]\n" - "lsr x7, x7, #0x2\n" - "ptrue p11.s\n" - "ldr x24, [x11, #0x8]\n" - "zip1 p10.b, p9.b, p8.b\n" - "mov x23, %x[row_offset]\n" - "ldr x21, [x10, #0x8]\n" - "mov x22, %x[out]\n" - "whilelt p9.b, x16, %x[width]\n" - "whilelt p8.b, x16, %x[width]\n" - "add x11, x11, #0x10\n" - "add x10, x10, #0x10\n" - "mov x12, #0x0\n" - "cbz x17, 2f\n" - "1:" // K loop: Charge: Loop 
- ".inst 0x25246143 // psel p3.b, p8.b/Z, p10.b[w12]\n" - ".inst 0x252c6142 // psel p2.b, p8.b/Z, p10.b[w12, #1]\n" - ".inst 0x25646141 // psel p1.b, p8.b/Z, p10.b[w12, #4]\n" - ".inst 0x256c6140 // psel p0.b, p8.b/Z, p10.b[w12, #5]\n" - ".inst 0xe0170f60 // ld1b { za0h.b[x12] }, p3/Z, [x27, x23]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0170b21 // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x23]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0170704 // ld1b { za0h.b[x12, #4] }, p1/Z, [x24, x23]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01702a5 // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x23]\n" - "add x12, x12, #0x8\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "cmp x12, x17, LSL #2\n" - "blt 1b\n" - "2:" // K loop: Charge: End - ".inst 0x25246143 // psel p3.b, p8.b/Z, p10.b[w12]\n" - ".inst 0x252c6142 // psel p2.b, p8.b/Z, p10.b[w12, #1]\n" - ".inst 0x25646141 // psel p1.b, p8.b/Z, p10.b[w12, #4]\n" - ".inst 0x256c6140 // psel p0.b, p8.b/Z, p10.b[w12, #5]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0170f60 // ld1b { za0h.b[x12] }, p3/Z, [x27, x23]\n" - "ldr x27, [x11, #0x0]\n" - "incb x16\n" - ".inst 0xe0170b21 // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x23]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0170704 // ld1b { za0h.b[x12, #4] }, p1/Z, [x24, x23]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01702a5 // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "incb x23\n" - "cbz x20, 8f\n" - "mov x20, x20\n" - "3:" // K loop: Main loop - "whilelt p8.b, x16, %x[width]\n" - "mov x15, #0x0\n" - "mov x14, #0x0\n" - "cbz x17, 5f\n" - "4:" // K loop: Main loop: First: Loop - ".inst 0x25376143 // psel p3.b, p8.b/Z, p10.b[w15, #2]\n" - ".inst 0x253f6142 // psel p2.b, p8.b/Z, p10.b[w15, #3]\n" - ".inst 0x25776141 // psel p1.b, p8.b/Z, p10.b[w15, #6]\n" - ".inst 0x257f6140 // psel p0.b, p8.b/Z, p10.b[w15, #7]\n" - ".inst 0xe0176f62 // ld1b { za0h.b[x15, 
#2] }, p3/Z, [x27, x23]\n" - ".inst 0x25266d23 // psel p3.b, p11.b/Z, p9.b[w14]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0176b23 // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x23]\n" - ".inst 0x25266d22 // psel p2.b, p11.b/Z, p9.b[w14]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0176706 // ld1b { za0h.b[x15, #6] }, p1/Z, [x24, x23]\n" - ".inst 0x252e6d21 // psel p1.b, p11.b/Z, p9.b[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01762a7 // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0x252e6d20 // psel p0.b, p11.b/Z, p9.b[w14, #1]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0bfcec0 // st1w { za0v.s[x14] }, p3/Z, [x22, XZR, LSL #2]\n" - "add x15, x15, #0x8\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x14, x14, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x14, x17\n" - "blt 4b\n" - "5:" // K loop: Main loop: First: Tail - ".inst 0x25376143 // psel p3.b, p8.b/Z, p10.b[w15, #2]\n" - ".inst 0x253f6142 // psel p2.b, p8.b/Z, p10.b[w15, #3]\n" - ".inst 0x25776141 // psel p1.b, p8.b/Z, p10.b[w15, #6]\n" - ".inst 0x257f6140 // psel p0.b, p8.b/Z, p10.b[w15, #7]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0176f62 // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x23]\n" - ".inst 0x25266d23 // psel p3.b, p11.b/Z, p9.b[w14]\n" - "ldr x27, [x11, #0x0]\n" - "mov x13, #0x0\n" - ".inst 0xe0176b23 // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x23]\n" - ".inst 0x25266d22 // psel p2.b, p11.b/Z, p9.b[w14]\n" - "ldr x25, [x10, #0x0]\n" - "mov x12, #0x0\n" - ".inst 0xe0176706 // ld1b { za0h.b[x15, #6] }, p1/Z, [x24, x23]\n" - ".inst 0x252e6d21 // psel p1.b, p11.b/Z, p9.b[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01762a7 // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 
0x252e6d20 // psel p0.b, p11.b/Z, p9.b[w14, #1]\n" - "whilelt p9.b, x16, %x[width]\n" - ".inst 0xe0bfcec0 // st1w { za0v.s[x14] }, p3/Z, [x22, XZR, LSL #2]\n" - "incb x16\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - "incb x23\n" - "whilelt p8.b, x16, %x[width]\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "cbz x17, 7f\n" - "6:" // K loop: Main loop: Second: Loop - ".inst 0x25256143 // psel p3.b, p8.b/Z, p10.b[w13]\n" - ".inst 0x252d6142 // psel p2.b, p8.b/Z, p10.b[w13, #1]\n" - ".inst 0x25656141 // psel p1.b, p8.b/Z, p10.b[w13, #4]\n" - ".inst 0x256d6140 // psel p0.b, p8.b/Z, p10.b[w13, #5]\n" - ".inst 0xe0172f60 // ld1b { za0h.b[x13] }, p3/Z, [x27, x23]\n" - ".inst 0x25246d23 // psel p3.b, p11.b/Z, p9.b[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0172b21 // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x23]\n" - ".inst 0x25246d22 // psel p2.b, p11.b/Z, p9.b[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0172704 // ld1b { za0h.b[x13, #4] }, p1/Z, [x24, x23]\n" - ".inst 0x252c6d21 // psel p1.b, p11.b/Z, p9.b[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01722a5 // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0x252c6d20 // psel p0.b, p11.b/Z, p9.b[w12, #1]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0bf8ec8 // st1w { za2v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - "add x13, x13, #0x8\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x12, x12, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x12, x17\n" - "blt 6b\n" - "7:" // K loop: Main loop: Second: Tail - ".inst 0x25256143 // psel p3.b, p8.b/Z, p10.b[w13]\n" - ".inst 0x252d6142 // psel p2.b, 
p8.b/Z, p10.b[w13, #1]\n" - ".inst 0x25656141 // psel p1.b, p8.b/Z, p10.b[w13, #4]\n" - ".inst 0x256d6140 // psel p0.b, p8.b/Z, p10.b[w13, #5]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0172f60 // ld1b { za0h.b[x13] }, p3/Z, [x27, x23]\n" - ".inst 0x25246d23 // psel p3.b, p11.b/Z, p9.b[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0172b21 // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x23]\n" - ".inst 0x25246d22 // psel p2.b, p11.b/Z, p9.b[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0172704 // ld1b { za0h.b[x13, #4] }, p1/Z, [x24, x23]\n" - ".inst 0x252c6d21 // psel p1.b, p11.b/Z, p9.b[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01722a5 // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0x252c6d20 // psel p0.b, p11.b/Z, p9.b[w12, #1]\n" - "whilelt p9.b, x16, %x[width]\n" - ".inst 0xe0bf8ec8 // st1w { za2v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - "subs x20, x20, #0x1\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "incb x16\n" - "incb x23\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "bgt 3b\n" - "8:" // K loop: Tails - "cbnz x26, 11f\n" - "mov x11, %x[in]\n" - "whilelt p8.b, x16, %x[width]\n" - "mov x13, #0x0\n" - "mov x12, #0x0\n" - "9:" // K loop: Tails: Even: First - ".inst 0x25306d23 // psel p3.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25306d22 // psel p2.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25356141 // psel p1.b, p8.b/Z, p10.b[w13, #2]\n" - ".inst 0x253d6140 // psel p0.b, p8.b/Z, p10.b[w13, #3]\n" - ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "ldr x21, [x11, #0x0]\n" - "cmp x12, x8\n" - "ldr x20, [x11, x8, LSL #0x3]\n" - "add x11, x11, 
#0x8\n" - ".inst 0xe01726a2 // ld1b { za0h.b[x13, #2] }, p1/Z, [x21, x23]\n" - ".inst 0xe0172283 // ld1b { za0h.b[x13, #3] }, p0/Z, [x20, x23]\n" - "add x13, x13, #0x4\n" - "blt 9b\n" - "whilelt p9.b, x16, %x[width]\n" - "whilelt p8.b, x16, %x[width]\n" - "mov x20, #0x0\n" - "mov x12, #0x0\n" - "10:" // K loop: Tails: Even: Second - ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25306d20 // psel p0.s, p11.s/Z, p9.s[w12]\n" - "add x20, x20, #0x4\n" - ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 10b\n" - "whilelt p8.b, x16, %x[width]\n" - "b 13f\n" - "11:" // K loop: Tails: Odd - "mov x12, #0x0\n" - "12:" // K loop: Tails: Odd: Loop - ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25306d20 // psel p0.s, p11.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 12b\n" - "13:" // K loop: End - "mov %x[out], x22\n" - ".inst 0xd503467f // SMSTOP\n" - : [out] "+&r"(out) - : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", - "p7", "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", - "x24", "x25", "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", - "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", - "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme( + height, width, in, row_offset, out); // NOLINT(bugprone-multi-level-implicit-pointer-conversion) out_base += 
m_step * kai_roundup(k_chunk_length, KR) * sizeof(int8_t); } } diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h index 7136d837aa68230e701292cc152c81a883aabba8..1adc97bc07c64b52a3dd6b4ed51d40fb89adf50c 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h @@ -16,14 +16,14 @@ extern "C" { /// /// The starting row index must be divisible by `m_step`. /// -/// @return Step size for row index +/// @return The m step value. size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void); /// Gets the offset in bytes to the data element in the packed LHS buffer. /// /// @param[in] m_idx Row index in the unpacked LHS matrix. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. +/// @param[in] k_chunk_length Length of a LHS column split. /// /// @return The offset in bytes to the data element. size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme( @@ -33,7 +33,7 @@ size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme( /// /// @param[in] m Number of rows in the unpacked LHS matrix. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. +/// @param[in] k_chunk_length Length of a LHS column split. /// /// @return The size in bytes of the packed LHS buffer. size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length); @@ -42,9 +42,9 @@ size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme(size_t m, size_ /// /// @param[in] m Number of rows of the unpacked LHS matrix. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. 
+/// @param[in] k_chunk_length Length of a LHS column split. /// @param[in] lhs_ptrs Pointer to an array of input pointers consisting of -/// t `m * k_chunk_count` pointers. +/// `m * k_chunk_count` pointers. /// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs /// array, excluding zero pointers. /// @param[in] pad_ptr Pointer to chunk used for padding. @ref lhs_ptr_offset is diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..5040ab58f152f7a7dfa4573af4f55e0a1d40fb27 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S @@ -0,0 +1,321 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(lhs_imatmul_pack_x8p2vlx4_x8p_sme) + KAI_ASM_ALIGN + + 
KAI_ASM_GLOBAL(kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART (enables both PSTATE.SM and PSTATE.ZA) + mov x8, #0x0 + cntb x21 + mov x23, x1 // x1 = width (2nd argument) + incb x23 + mov x20, x1 + sub x17, x21, #0x1 + cntw x16 + sub x23, x23, #0x1 + ands x17, x20, x17 + udiv x23, x23, x21 // n_passes = ceildiv(width, VL) + csel x17, x17, x21, NE + lsl x22, x0, #0x1 // height * 2 + lsl x21, x16, #0x1 + sub x20, x23, #0x1 + add x17, x17, #0x3 + sub x11, x16, #0x2 + whilelt p9.b, XZR, x22 + whilelt p8.b, x21, x22 + mov x10, x2 // x2 = in (3rd argument), table of row pointers + add x9, x2, x16, LSL #3 + cntw x28, ALL, MUL #2 + ldr x27, [x10, #0x0] + cntw x26, ALL, MUL #3 + lsr x20, x20, #0x1 // n_loops = (n_passes - 1) / 2 + ldr x25, [x9, #0x0] + and x24, x23, #0x1 // odd_tail = bool(n_passes & 0x1) + lsr x17, x17, #0x2 + ldr x23, [x10, #0x8] + ptrue p11.s + zip1 p10.b, p9.b, p8.b + ldr x21, [x9, #0x8] + mov x22, x3 // x3 = row_offset (4th argument) + whilelt p9.b, x8, x1 + whilelt p8.b, x8, x1 + add x10, x10, #0x10 + add x9, x9, #0x10 + mov x12, #0x0 + cbz x11, label_2 +KAI_ASM_LABEL(label_1) // K loop: Charge: Loop + KAI_ASM_INST(0x25246143) // psel p3.b, p8.b/Z, p10.b[w12] + KAI_ASM_INST(0x252c6142) // psel p2.b, p8.b/Z, p10.b[w12, #1] + KAI_ASM_INST(0x25646141) // psel p1.b, p8.b/Z, p10.b[w12, #4] + KAI_ASM_INST(0x256c6140) // psel p0.b, p8.b/Z, p10.b[w12, #5] + KAI_ASM_INST(0xe0160f60) // ld1b { za0h.b[x12] }, p3/Z, [x27, x22] + ldr x27, [x10, #0x0] + KAI_ASM_INST(0xe0160b21) // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x22] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe01606e4) // ld1b { za0h.b[x12, #4] }, p1/Z, [x23, x22] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe01602a5) // ld1b { za0h.b[x12, #5] },
p0/Z, [x21, x22] + add x12, x12, #0x8 + ldr x21, [x9, #0x8] + add x9, x9, #0x10 + cmp x12, x11, LSL #2 + blt label_1 +KAI_ASM_LABEL(label_2) // K loop: Charge: End + KAI_ASM_INST(0x25246143) // psel p3.b, p8.b/Z, p10.b[w12] + KAI_ASM_INST(0x252c6142) // psel p2.b, p8.b/Z, p10.b[w12, #1] + KAI_ASM_INST(0x25646141) // psel p1.b, p8.b/Z, p10.b[w12, #4] + KAI_ASM_INST(0x256c6140) // psel p0.b, p8.b/Z, p10.b[w12, #5] + mov x10, x2 + add x9, x2, x16, LSL #3 + KAI_ASM_INST(0xe0160f60) // ld1b { za0h.b[x12] }, p3/Z, [x27, x22] + ldr x27, [x10, #0x0] + incb x8 + KAI_ASM_INST(0xe0160b21) // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x22] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe01606e4) // ld1b { za0h.b[x12, #4] }, p1/Z, [x23, x22] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe01602a5) // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x22] + ldr x21, [x9, #0x8] + add x9, x9, #0x10 + incb x22 + cbz x20, label_8 + mov x20, x20 // no-op: register moved onto itself +KAI_ASM_LABEL(label_3) // K loop: Main loop + whilelt p8.b, x8, x1 + mov x15, #0x0 + mov x14, #0x0 + cbz x11, label_5 +KAI_ASM_LABEL(label_4) // K loop: Main loop: First: Loop + KAI_ASM_INST(0x25376143) // psel p3.b, p8.b/Z, p10.b[w15, #2] + KAI_ASM_INST(0x253f6142) // psel p2.b, p8.b/Z, p10.b[w15, #3] + KAI_ASM_INST(0x25776141) // psel p1.b, p8.b/Z, p10.b[w15, #6] + KAI_ASM_INST(0x257f6140) // psel p0.b, p8.b/Z, p10.b[w15, #7] + KAI_ASM_INST(0xe0166f62) // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x22] + KAI_ASM_INST(0x25266d23) // psel p3.b, p11.b/Z, p9.b[w14] + ldr x27, [x10, #0x0] + KAI_ASM_INST(0xe0166b23) // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x22] + KAI_ASM_INST(0x25266d22) // psel p2.b, p11.b/Z, p9.b[w14] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe01666e6) // ld1b { za0h.b[x15, #6] }, p1/Z, [x23, x22] + KAI_ASM_INST(0x252e6d21) // psel p1.b, p11.b/Z, p9.b[w14, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe01662a7) // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x22] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0x252e6d20) // psel p0.b,
p11.b/Z, p9.b[w14, #1] + add x9, x9, #0x10 + KAI_ASM_INST(0xe0bfcc80) // st1w { za0v.s[x14] }, p3/Z, [x4, XZR, LSL #2] + add x15, x15, #0x8 + KAI_ASM_INST(0xe0b0c884) // st1w { za1v.s[x14] }, p2/Z, [x4, x16, LSL #2] + KAI_ASM_INST(0xe0bcc481) // st1w { za0v.s[x14, #1] }, p1/Z, [x4, x28, LSL #2] + KAI_ASM_INST(0xe0bac085) // st1w { za1v.s[x14, #1] }, p0/Z, [x4, x26, LSL #2] + add x14, x14, #0x2 + addvl x4, x4, #4 + cmp x14, x11 + blt label_4 +KAI_ASM_LABEL(label_5) // K loop: Main loop: First: Tail + KAI_ASM_INST(0x25376143) // psel p3.b, p8.b/Z, p10.b[w15, #2] + KAI_ASM_INST(0x253f6142) // psel p2.b, p8.b/Z, p10.b[w15, #3] + KAI_ASM_INST(0x25776141) // psel p1.b, p8.b/Z, p10.b[w15, #6] + KAI_ASM_INST(0x257f6140) // psel p0.b, p8.b/Z, p10.b[w15, #7] + mov x10, x2 + add x9, x2, x16, LSL #3 + KAI_ASM_INST(0xe0166f62) // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x22] + KAI_ASM_INST(0x25266d23) // psel p3.b, p11.b/Z, p9.b[w14] + ldr x27, [x10, #0x0] + mov x13, #0x0 + KAI_ASM_INST(0xe0166b23) // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x22] + KAI_ASM_INST(0x25266d22) // psel p2.b, p11.b/Z, p9.b[w14] + ldr x25, [x9, #0x0] + mov x12, #0x0 + KAI_ASM_INST(0xe01666e6) // ld1b { za0h.b[x15, #6] }, p1/Z, [x23, x22] + KAI_ASM_INST(0x252e6d21) // psel p1.b, p11.b/Z, p9.b[w14, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe01662a7) // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x22] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0x252e6d20) // psel p0.b, p11.b/Z, p9.b[w14, #1] + whilelt p9.b, x8, x1 + KAI_ASM_INST(0xe0bfcc80) // st1w { za0v.s[x14] }, p3/Z, [x4, XZR, LSL #2] + incb x8 + add x9, x9, #0x10 + KAI_ASM_INST(0xe0b0c884) // st1w { za1v.s[x14] }, p2/Z, [x4, x16, LSL #2] + incb x22 + whilelt p8.b, x8, x1 + KAI_ASM_INST(0xe0bcc481) // st1w { za0v.s[x14, #1] }, p1/Z, [x4, x28, LSL #2] + KAI_ASM_INST(0xe0bac085) // st1w { za1v.s[x14, #1] }, p0/Z, [x4, x26, LSL #2] + addvl x4, x4, #4 + cbz x11, label_7 +KAI_ASM_LABEL(label_6) // K loop: Main loop: Second: Loop +
KAI_ASM_INST(0x25256143) // psel p3.b, p8.b/Z, p10.b[w13] + KAI_ASM_INST(0x252d6142) // psel p2.b, p8.b/Z, p10.b[w13, #1] + KAI_ASM_INST(0x25656141) // psel p1.b, p8.b/Z, p10.b[w13, #4] + KAI_ASM_INST(0x256d6140) // psel p0.b, p8.b/Z, p10.b[w13, #5] + KAI_ASM_INST(0xe0162f60) // ld1b { za0h.b[x13] }, p3/Z, [x27, x22] + KAI_ASM_INST(0x25246d23) // psel p3.b, p11.b/Z, p9.b[w12] + ldr x27, [x10, #0x0] + KAI_ASM_INST(0xe0162b21) // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x22] + KAI_ASM_INST(0x25246d22) // psel p2.b, p11.b/Z, p9.b[w12] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe01626e4) // ld1b { za0h.b[x13, #4] }, p1/Z, [x23, x22] + KAI_ASM_INST(0x252c6d21) // psel p1.b, p11.b/Z, p9.b[w12, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe01622a5) // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x22] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0x252c6d20) // psel p0.b, p11.b/Z, p9.b[w12, #1] + add x9, x9, #0x10 + KAI_ASM_INST(0xe0bf8c88) // st1w { za2v.s[x12] }, p3/Z, [x4, XZR, LSL #2] + add x13, x13, #0x8 + KAI_ASM_INST(0xe0b0888c) // st1w { za3v.s[x12] }, p2/Z, [x4, x16, LSL #2] + KAI_ASM_INST(0xe0bc8489) // st1w { za2v.s[x12, #1] }, p1/Z, [x4, x28, LSL #2] + KAI_ASM_INST(0xe0ba808d) // st1w { za3v.s[x12, #1] }, p0/Z, [x4, x26, LSL #2] + add x12, x12, #0x2 + addvl x4, x4, #4 + cmp x12, x11 + blt label_6 +KAI_ASM_LABEL(label_7) // K loop: Main loop: Second: Tail + KAI_ASM_INST(0x25256143) // psel p3.b, p8.b/Z, p10.b[w13] + KAI_ASM_INST(0x252d6142) // psel p2.b, p8.b/Z, p10.b[w13, #1] + KAI_ASM_INST(0x25656141) // psel p1.b, p8.b/Z, p10.b[w13, #4] + KAI_ASM_INST(0x256d6140) // psel p0.b, p8.b/Z, p10.b[w13, #5] + mov x10, x2 + add x9, x2, x16, LSL #3 + KAI_ASM_INST(0xe0162f60) // ld1b { za0h.b[x13] }, p3/Z, [x27, x22] + KAI_ASM_INST(0x25246d23) // psel p3.b, p11.b/Z, p9.b[w12] + ldr x27, [x10, #0x0] + KAI_ASM_INST(0xe0162b21) // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x22] + KAI_ASM_INST(0x25246d22) // psel p2.b, p11.b/Z, p9.b[w12] + ldr x25, [x9, #0x0] +
KAI_ASM_INST(0xe01626e4) // ld1b { za0h.b[x13, #4] }, p1/Z, [x23, x22] + KAI_ASM_INST(0x252c6d21) // psel p1.b, p11.b/Z, p9.b[w12, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe01622a5) // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x22] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0x252c6d20) // psel p0.b, p11.b/Z, p9.b[w12, #1] + whilelt p9.b, x8, x1 + KAI_ASM_INST(0xe0bf8c88) // st1w { za2v.s[x12] }, p3/Z, [x4, XZR, LSL #2] + subs x20, x20, #0x1 + add x9, x9, #0x10 + KAI_ASM_INST(0xe0b0888c) // st1w { za3v.s[x12] }, p2/Z, [x4, x16, LSL #2] + incb x8 + incb x22 + KAI_ASM_INST(0xe0bc8489) // st1w { za2v.s[x12, #1] }, p1/Z, [x4, x28, LSL #2] + KAI_ASM_INST(0xe0ba808d) // st1w { za3v.s[x12, #1] }, p0/Z, [x4, x26, LSL #2] + addvl x4, x4, #4 + bgt label_3 +KAI_ASM_LABEL(label_8) // K loop: Tails + cbnz x24, label_11 + mov x10, x2 + whilelt p8.b, x8, x1 + mov x13, #0x0 + mov x12, #0x0 +KAI_ASM_LABEL(label_9) // K loop: Tails: Even: First + KAI_ASM_INST(0x25306d23) // psel p3.s, p11.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306d22) // psel p2.s, p11.s/Z, p9.s[w12] + KAI_ASM_INST(0x25356141) // psel p1.b, p8.b/Z, p10.b[w13, #2] + KAI_ASM_INST(0x253d6140) // psel p0.b, p8.b/Z, p10.b[w13, #3] + KAI_ASM_INST(0xe0bf8c80) // st1w { za0v.s[x12] }, p3/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0xe0b08884) // st1w { za1v.s[x12] }, p2/Z, [x4, x16, LSL #2] + add x12, x12, #0x1 + addvl x4, x4, #2 + ldr x21, [x10, #0x0] + cmp x12, x16 + ldr x20, [x10, x16, LSL #0x3] + add x10, x10, #0x8 + KAI_ASM_INST(0xe01626a2) // ld1b { za0h.b[x13, #2] }, p1/Z, [x21, x22] + KAI_ASM_INST(0xe0162283) // ld1b { za0h.b[x13, #3] }, p0/Z, [x20, x22] + add x13, x13, #0x4 + blt label_9 + whilelt p9.b, x8, x1 + whilelt p8.b, x8, x1 + mov x20, #0x0 + mov x12, #0x0 +KAI_ASM_LABEL(label_10) // K loop: Tails: Even: Second + KAI_ASM_INST(0x25306d21) // psel p1.s, p11.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306d20) // psel p0.s, p11.s/Z, p9.s[w12] + add x20, x20, #0x4 + KAI_ASM_INST(0xe0bf8488) // st1w { za2v.s[x12] }, p1/Z,
[x4, XZR, LSL #2] + KAI_ASM_INST(0xe0b0808c) // st1w { za3v.s[x12] }, p0/Z, [x4, x16, LSL #2] + add x12, x12, #0x1 + addvl x4, x4, #2 + cmp x12, x17 + blt label_10 + whilelt p8.b, x8, x1 + b label_13 +KAI_ASM_LABEL(label_11) // K loop: Tails: Odd + mov x12, #0x0 +KAI_ASM_LABEL(label_12) // K loop: Tails: Odd: Loop + KAI_ASM_INST(0x25306d21) // psel p1.s, p11.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306d20) // psel p0.s, p11.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf8480) // st1w { za0v.s[x12] }, p1/Z, [x4, XZR, LSL #2] + KAI_ASM_INST(0xe0b08084) // st1w { za1v.s[x12] }, p0/Z, [x4, x16, LSL #2] + add x12, x12, #0x1 + addvl x4, x4, #2 + cmp x12, x17 + blt label_12 +KAI_ASM_LABEL(label_13) // K loop: End + KAI_ASM_INST(0xd503467f) // SMSTOP (clears both PSTATE.SM and PSTATE.ZA) + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c index 8cd1201ad528ac7d02c59c51d602007899877be5..c8b814c63e4d9a5aebea8928be74b6c80a111c3f 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -4,10 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check.
#include "kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" @@ -18,14 +15,33 @@ #include "kai/kai_common.h" +enum { + NR = 2, + KR = 4, + MAX_N_STEP = NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint8_t)) / KR), +}; + +typedef struct { + const void* bias_ptr; + const void* scale_ptr; + int32_t input_zero_point; + float scale_multiplier; + size_t width; + size_t height; + size_t k_chunk_count; + size_t in_stride; + size_t out_stride; + const void* in; + void* out; + const void* pad_row; +} KernelArgs; + static const size_t kai_num_bytes_input = sizeof(uint8_t); static const size_t kai_num_bytes_output = sizeof(uint8_t); static const size_t kai_num_bytes_bias = sizeof(int32_t); -static const size_t kai_num_bytes_scale = sizeof(float32_t); +static const size_t kai_num_bytes_scale = sizeof(float); -#define NR 2 -#define KR 4 -#define MAX_N_STEP (NR * KAI_SME_VEC_LENGTH_MAX_BYTES / KR) +void kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(const KernelArgs* args_ptr); size_t kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void) { return NR * kai_get_sme_vector_length_u8() / KR; @@ -63,13 +79,13 @@ size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( size_t n, size_t k_chunk_count, size_t k_chunk_length) { - const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme()); + const size_t n_nr_blocks = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme()); return kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( - n_rounded_up, k_chunk_count, k_chunk_length); + n_nr_blocks, k_chunk_count, k_chunk_length); } void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( - size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + size_t n, 
size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias, const void* scale, void* rhs_packed, const struct kai_rhs_pack_qsi8cx_params* params) { KAI_ASSUME(rhs != NULL); KAI_ASSUME(bias != NULL); @@ -77,201 +93,25 @@ void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( KAI_ASSUME(rhs_packed != NULL); KAI_ASSUME(params != NULL); - size_t height = k_chunk_length; - const size_t width = n; - const void* in = rhs; - void* out = rhs_packed; - const size_t in_stride = rhs_row_stride; - KAI_ASSERT(kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() <= MAX_N_STEP); - uint8_t pad_row[MAX_N_STEP]; - if (height % KR) { - memset(pad_row, 0, MAX_N_STEP * sizeof(uint8_t)); - } - - size_t out_stride = + static const uint8_t pad_row[MAX_N_STEP] = {0}; + + KernelArgs args; + args.bias_ptr = bias; + args.scale_ptr = scale; + args.height = k_chunk_length; + args.width = n; + args.in = rhs; + args.out = rhs_packed; + args.k_chunk_count = k_chunk_count; + args.in_stride = rhs_stride_row; + args.out_stride = kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k_chunk_count, k_chunk_length); - const int32_t input_zero_point = params->lhs_zero_point; - const float scale_multiplier = params->scale_multiplier; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x12, %x[out]\n" - "mov x11, %x[k_chunk_count]\n" - "ptrue p2.b\n" - "incb %x[out], ALL, MUL #2\n" - "1:" // Chunk Loop - "mov x10, %x[height]\n" - "cmp x10, #0x8\n" - "blt 5f\n" - "2:" // Main row loop: Head - "mov x9, %x[in]\n" - "mov x28, %x[out]\n" - "add x27, x9, %x[in_stride]\n" - "sub x10, x10, #0x8\n" - "add x26, x27, %x[in_stride]\n" - "mov x24, %x[width]\n" - "add x25, x26, %x[in_stride]\n" - "add x23, x25, %x[in_stride]\n" - "add x22, x23, %x[in_stride]\n" - "add x21, x22, %x[in_stride]\n" - "add x20, x21, %x[in_stride]\n" - "add %x[in], x20, %x[in_stride]\n" - "3:" // Main row loop: Column loop - 
"whilelt p0.b, XZR, x24\n" - "decw x24, ALL, MUL #2\n" - "ld1b { z18.b }, p0/Z, [x9]\n" - "cmp x24, #0x0\n" - "incd x9, ALL, MUL #4\n" - "ld1b { z22.b }, p0/Z, [x27]\n" - "incd x27, ALL, MUL #4\n" - "ld1b { z17.b }, p0/Z, [x26]\n" - "incd x26, ALL, MUL #4\n" - "ld1b { z16.b }, p0/Z, [x25]\n" - "incd x25, ALL, MUL #4\n" - "ld1b { z20.b }, p0/Z, [x23]\n" - "incd x23, ALL, MUL #4\n" - "ld1b { z19.b }, p0/Z, [x22]\n" - "zip1 z21.b, z18.b, z17.b\n" - "incd x22, ALL, MUL #4\n" - "ld1b { z18.b }, p0/Z, [x21]\n" - "zip1 z17.b, z22.b, z16.b\n" - "incd x21, ALL, MUL #4\n" - "ld1b { z16.b }, p0/Z, [x20]\n" - "incd x20, ALL, MUL #4\n" - "zip1 z20.b, z20.b, z18.b\n" - "zip1 z16.b, z19.b, z16.b\n" - "zip1 z19.b, z21.b, z17.b\n" - "zip2 z18.b, z21.b, z17.b\n" - "zip1 z17.b, z20.b, z16.b\n" - "zip2 z16.b, z20.b, z16.b\n" - "st1b { z19.b }, p2, [x28]\n" - "st1b { z18.b }, p2, [x28, #1, MUL VL]\n" - "st1b { z17.b }, p2, [x28, #2, MUL VL]\n" - "st1b { z16.b }, p2, [x28, #3, MUL VL]\n" - "add x28, x28, %x[out_stride]\n" - "bgt 3b\n" - "cmp x10, #0x8\n" - "addvl %x[out], %x[out], #4\n" - "bge 2b\n" - "cbz x10, 9f\n" - "5:" // Main loop skip - "6:" // Tail row loop: Head - "mov x9, %x[in]\n" - "cmp x10, #0x3\n" - "add x27, x9, %x[in_stride]\n" - "cntw x24, ALL, MUL #2\n" - "add x26, x27, %x[in_stride]\n" - "csel x23, x24, XZR, GT\n" - "add x25, x26, %x[in_stride]\n" - "csel x22, x24, XZR, GE\n" - "add %x[in], x25, %x[in_stride]\n" - "mov x28, %x[out]\n" - "csel %x[in], %x[in], x25, GT\n" - "csel x25, x25, %x[pad_row], GT\n" - "csel %x[in], %x[in], x26, GE\n" - "csel x26, x26, %x[pad_row], GE\n" - "cmp x10, #0x1\n" - "sub x10, x10, #0x4\n" - "csel %x[in], %x[in], x27, GT\n" - "csel x27, x27, %x[pad_row], GT\n" - "csel x21, x24, XZR, GT\n" - "mov x20, %x[width]\n" - "7:" // Tail row loop: Column loop - "whilelt p0.b, XZR, x20\n" - "decw x20, ALL, MUL #2\n" - "ld1b { z18.b }, p0/Z, [x9]\n" - "cmp x20, #0x0\n" - "add x9, x9, x24\n" - "ld1b { z19.b }, p0/Z, [x27]\n" - "add x27, x27, x21\n" - 
"ld1b { z17.b }, p0/Z, [x26]\n" - "add x26, x26, x22\n" - "ld1b { z16.b }, p0/Z, [x25]\n" - "add x25, x25, x23\n" - "zip1 z18.b, z18.b, z17.b\n" - "zip1 z16.b, z19.b, z16.b\n" - "zip1 z17.b, z18.b, z16.b\n" - "zip2 z16.b, z18.b, z16.b\n" - "st1b { z17.b }, p2, [x28]\n" - "st1b { z16.b }, p2, [x28, #1, MUL VL]\n" - "add x28, x28, %x[out_stride]\n" - "bgt 7b\n" - "cmp x10, #0x1\n" - "addvl %x[out], %x[out], #2\n" - "bge 6b\n" - "9:" // Done - "sub x11, x11, #0x1\n" - "cbnz x11, 1b\n" - "mov x22, %x[out]\n" - "mov x21, %x[width]\n" - "dup z18.s, %w[scale_multiplier]\n" - "cbz %x[scale], 11f\n" - "10:" // Scale: Full loop - "mov x20, x21\n" - "decw x21, ALL, MUL #2\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "ld1w { z17.s }, p1/Z, [%x[scale]]\n" - "cmp x21, #0x0\n" - "ld1w { z16.s }, p0/Z, [%x[scale], #1, MUL VL]\n" - "incb %x[scale], ALL, MUL #2\n" - "fmul z17.s, z17.s, z18.s\n" - "fmul z16.s, z16.s, z18.s\n" - "st1w { z17.s }, p2, [x22]\n" - "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" - "add x22, x22, %x[out_stride]\n" - "bgt 10b\n" - "11:" // Scale: Done - "cbz %x[width], 14f\n" - "cbz %x[height], 14f\n" - "dup z21.s, %w[input_zero_point]\n" - "add x25, %x[height], #0x3\n" - "cntw x24, ALL, MUL #2\n" - "mov z20.b, #0x1\n" - "lsr x25, x25, #0x2\n" - "mov x23, %x[width]\n" - "mul x25, %x[k_chunk_count], x25\n" - "addvl x22, x12, #2\n" - "neg z21.s, p2/M, z21.s\n" - "12:" // Bias: N loop - "mov x21, x22\n" - "mov x20, x25\n" - "mov z19.s, #0x0\n" - "mov z18.s, #0x0\n" - "13:" // Bias: K loop - "ld1b { z17.b }, p2/Z, [x21]\n" - "subs x20, x20, #0x1\n" - "ld1b { z16.b }, p2/Z, [x21, #1, MUL VL]\n" - "addvl x21, x21, #2\n" - "sdot z19.s, z17.b, z20.b\n" - "sdot z18.s, z16.b, z20.b\n" - "bgt 13b\n" - "mov x20, x23\n" - "add x22, x22, %x[out_stride]\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "ld1w { z17.s }, p1/Z, [%x[bias]]\n" - "subs x23, x23, x24\n" - "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" - 
"addvl %x[bias], %x[bias], #2\n" - "mla z17.s, p2/M, z19.s, z21.s\n" - "mla z16.s, p2/M, z18.s, z21.s\n" - "st1w { z17.s }, p2, [x12]\n" - "st1w { z16.s }, p2, [x12, #1, MUL VL]\n" - "add x12, x12, %x[out_stride]\n" - "bgt 12b\n" - "14:" // Bias: Done - ".inst 0xd503467f // SMSTOP\n" - : [bias] "+&r"(bias), [in] "+&r"(in), [out] "+&r"(out), [scale] "+&r"(scale) - : [height] "r"(height), [in_stride] "r"(in_stride), [input_zero_point] "r"(input_zero_point), - [k_chunk_count] "r"(k_chunk_count), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), - [scale_multiplier] "r"(scale_multiplier), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", - "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", - "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + args.input_zero_point = params->lhs_zero_point; + args.scale_multiplier = params->scale_multiplier; + args.pad_row = pad_row; + + kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h index 77b1dbde73f1273d9de6b1acf1009421898b2015..e8325a105c7306ea8da1c9e569b1b78bc4176492 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h @@ -18,7 +18,7 @@ extern "C" { /// /// The starting column index must be divisible by `n_step`. /// -/// @return Step size for column index. +/// @return The n step value. 
size_t kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. @@ -75,14 +75,14 @@ size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32 /// @param[in] n Number of columns of the output matrix. /// @param[in] k_chunk_count Number of chunks. /// @param[in] k_chunk_length Number of rows in each chunk. -/// @param[in] rhs_row_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs_stride_row Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. /// @param[in] scale Scale data buffer. /// @param[out] rhs_packed Packed RHS matrix. -/// @param[in] params Extra packing parameters. +/// @param[in] params Extra quantization packing parameters. void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( - size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias, const void* scale, void* rhs_packed, const struct kai_rhs_pack_qsi8cx_params* params); #ifdef __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..6713f68dd1c77657a84838144ed0428276de5921 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S @@ -0,0 +1,248 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define 
KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) + stp x20, x21, [sp, -144]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + ldr x2, [x0, #0x28] + ptrue p2.b + ldr x3, [x0, #0x48] + ldr x4, [x0, #0x0] + ldr x5, [x0, #0x8] + mov x6, x2 + ldr w7, [x0, #0x10] + mov x8, x3 + incb x3, ALL, MUL #2 + ldr w17, [x0, #0x14] + ldr x16, [x0, #0x18] + ldr x15, [x0, #0x20] + ldr x14, [x0, #0x30] + ldr x13, [x0, #0x38] + ldr x12, [x0, #0x40] + ldr x11, [x0, #0x50] +KAI_ASM_LABEL(label_1) // Chunk Loop + mov x10, x15 + cmp x10, #0x8 + blt label_5 +KAI_ASM_LABEL(label_2) // Main row loop: Head + mov x9, x12 + mov x28, x3 + add x27, x9, x14 + sub x10, x10, #0x8 + add x26, x27, x14 + mov x24, x16 + add x25, x26, x14 + add x23, x25, x14 + add x22, x23, x14 + add x21, x22, x14 + add x20, x21, x14 + add x12, x20, x14 +KAI_ASM_LABEL(label_3) // Main row loop: Column loop + whilelt p0.b, XZR, x24 + decw x24, ALL, MUL #2 + ld1b { z18.b }, p0/Z, [x9] + cmp x24, #0x0 + incd x9, ALL, MUL #4 + ld1b { z22.b }, p0/Z, [x27] + incd x27, ALL, MUL #4 + ld1b { z17.b }, p0/Z, [x26] + incd x26, ALL, MUL #4 + ld1b { z16.b }, p0/Z, [x25] + incd x25, ALL, MUL #4 + ld1b { z20.b }, p0/Z, [x23] + incd x23, ALL, MUL #4 + ld1b { z19.b }, p0/Z, [x22] + zip1 z21.b, z18.b, z17.b + incd x22, ALL, MUL #4 + ld1b { z18.b }, p0/Z, [x21] + zip1 z17.b, z22.b, z16.b + incd x21, ALL, MUL #4 + ld1b { z16.b }, p0/Z, [x20] + incd x20, ALL, MUL #4 + zip1 z20.b, z20.b, z18.b + zip1 z16.b, z19.b, z16.b + zip1 z19.b, z21.b, z17.b + zip2 z18.b, z21.b, z17.b + zip1 z17.b, z20.b, z16.b + zip2 z16.b, z20.b, z16.b + st1b { z19.b }, p2, [x28] + st1b { z18.b }, p2, [x28, #1, MUL VL] + st1b { z17.b }, p2, [x28, #2, MUL VL] + st1b { z16.b }, p2, [x28, #3, MUL VL] + add x28, x28, x13 + bgt label_3 + cmp x10, #0x8 + addvl x3, x3, #4 + bge label_2 + cbz x10, label_9 +KAI_ASM_LABEL(label_5) // Main loop skip 
+KAI_ASM_LABEL(label_6) // Tail row loop: Head + mov x9, x12 + cmp x10, #0x3 + add x27, x9, x14 + cntw x24, ALL, MUL #2 + add x26, x27, x14 + csel x23, x24, XZR, GT + add x25, x26, x14 + csel x22, x24, XZR, GE + add x12, x25, x14 + mov x28, x3 + csel x12, x12, x25, GT + csel x25, x25, x11, GT + csel x12, x12, x26, GE + csel x26, x26, x11, GE + cmp x10, #0x1 + sub x10, x10, #0x4 + csel x12, x12, x27, GT + csel x27, x27, x11, GT + csel x21, x24, XZR, GT + mov x20, x16 +KAI_ASM_LABEL(label_7) // Tail row loop: Column loop + whilelt p0.b, XZR, x20 + decw x20, ALL, MUL #2 + ld1b { z18.b }, p0/Z, [x9] + cmp x20, #0x0 + add x9, x9, x24 + ld1b { z19.b }, p0/Z, [x27] + add x27, x27, x21 + ld1b { z17.b }, p0/Z, [x26] + add x26, x26, x22 + ld1b { z16.b }, p0/Z, [x25] + add x25, x25, x23 + zip1 z18.b, z18.b, z17.b + zip1 z16.b, z19.b, z16.b + zip1 z17.b, z18.b, z16.b + zip2 z16.b, z18.b, z16.b + st1b { z17.b }, p2, [x28] + st1b { z16.b }, p2, [x28, #1, MUL VL] + add x28, x28, x13 + bgt label_7 + cmp x10, #0x1 + addvl x3, x3, #2 + bge label_6 +KAI_ASM_LABEL(label_9) // Done + sub x6, x6, #0x1 + cbnz x6, label_1 + mov x22, x3 + mov x21, x16 + dup z18.s, w17 + cbz x5, label_11 +KAI_ASM_LABEL(label_10) // Scale: Full loop + mov x20, x21 + decw x21, ALL, MUL #2 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + ld1w { z17.s }, p1/Z, [x5] + cmp x21, #0x0 + ld1w { z16.s }, p0/Z, [x5, #1, MUL VL] + incb x5, ALL, MUL #2 + fmul z17.s, z17.s, z18.s + fmul z16.s, z16.s, z18.s + st1w { z17.s }, p2, [x22] + st1w { z16.s }, p2, [x22, #1, MUL VL] + add x22, x22, x13 + bgt label_10 +KAI_ASM_LABEL(label_11) // Scale: Done + cbz x16, label_14 + cbz x15, label_14 + dup z21.s, w7 + add x25, x15, #0x3 + cntw x24, ALL, MUL #2 + mov z20.b, #0x1 + lsr x25, x25, #0x2 + mov x23, x16 + mul x25, x2, x25 + addvl x22, x8, #2 + neg z21.s, p2/M, z21.s +KAI_ASM_LABEL(label_12) // Bias: N loop + mov x21, x22 + mov x20, x25 + mov z19.s, #0x0 + mov z18.s, #0x0 +KAI_ASM_LABEL(label_13) // Bias: K loop 
+ ld1b { z17.b }, p2/Z, [x21] + subs x20, x20, #0x1 + ld1b { z16.b }, p2/Z, [x21, #1, MUL VL] + addvl x21, x21, #2 + sdot z19.s, z17.b, z20.b + sdot z18.s, z16.b, z20.b + bgt label_13 + mov x20, x23 + add x22, x22, x13 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + ld1w { z17.s }, p1/Z, [x4] + subs x23, x23, x24 + ld1w { z16.s }, p0/Z, [x4, #1, MUL VL] + addvl x4, x4, #2 + mla z17.s, p2/M, z19.s, z21.s + mla z16.s, p2/M, z18.s, z21.s + st1w { z17.s }, p2, [x8] + st1w { z16.s }, p2, [x8, #1, MUL VL] + add x8, x8, x13 + bgt label_12 +KAI_ASM_LABEL(label_14) // Bias: Done + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c index a9c0bb73ab711e1793a31611172e551ee368eef2..4cc50d1d39f49e5f73c6e6c3f4c58ffd3d93356e 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c @@ -4,10 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
#include "kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h" @@ -18,13 +15,29 @@ #include "kai/kai_common.h" -#define NR 2 -#define KR 2 +enum { + NR = 2, + KR = 2, + MAX_N_STEP = NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR), +}; + +typedef struct { + const void* bias_ptr; + size_t width; + size_t height; + size_t k_chunk_count; + size_t in_stride; + size_t out_stride; + const void* in; + void* out; + const void* pad_row; +} KernelArgs; + static const size_t kai_num_bytes_input = sizeof(uint16_t); static const size_t kai_num_bytes_output = sizeof(uint16_t); static const size_t kai_num_bytes_bias = sizeof(uint16_t); -#define MAX_N_STEP (NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR)) +void kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(const KernelArgs* args_ptr); size_t kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(void) { return NR * kai_get_sme_vector_length_u16() / KR; @@ -57,148 +70,34 @@ size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( size_t n, size_t k_chunk_count, size_t k_chunk_length) { - const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme()); + const size_t n_nr_blocks = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme()); return kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( - n_rounded_up, k_chunk_count, k_chunk_length); + n_nr_blocks, k_chunk_count, k_chunk_length); } void kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( - size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias, void* rhs_packed) { KAI_ASSUME(rhs != NULL); KAI_ASSUME(bias != NULL); KAI_ASSUME(rhs_packed != NULL); - size_t height = k_chunk_length; - const size_t 
width = n; - const void* in = rhs; - void* out = rhs_packed; - const size_t in_stride = rhs_row_stride; - KAI_ASSERT(kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() <= MAX_N_STEP); - uint16_t pad_row[MAX_N_STEP]; - if (height % KR) { - memset(pad_row, 0, MAX_N_STEP * sizeof(uint16_t)); - } - - size_t out_stride = + static const uint16_t pad_row[MAX_N_STEP] = {0}; + + KernelArgs args; + args.bias_ptr = bias; + args.height = k_chunk_length; + args.width = n; + args.in = rhs; + args.out = rhs_packed; + args.k_chunk_count = k_chunk_count; + args.in_stride = rhs_stride_row; + args.out_stride = kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(k_chunk_count, k_chunk_length); - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x21, %x[out]\n" - "mov x20, %x[width]\n" - "ptrue p1.b\n" - "1:" // Bias: Full loop - "whilelt p0.h, XZR, x20\n" - "dech x20\n" - "cmp x20, #0x0\n" - "ld1h { z16.h }, p0/Z, [%x[bias]]\n" - "incb %x[bias]\n" - "st1h { z16.h }, p1, [x21]\n" - "add x21, x21, %x[out_stride]\n" - "bgt 1b\n" - "incb %x[out]\n" - "mov x11, %x[k_chunk_count]\n" - "2:" // Chunk Loop - "mov x10, %x[height]\n" - "cmp x10, #0x8\n" - "blt 6f\n" - "3:" // Main row loop: Head - "mov x9, %x[in]\n" - "mov x28, %x[out]\n" - "add x27, x9, %x[in_stride]\n" - "sub x10, x10, #0x8\n" - "add x26, x27, %x[in_stride]\n" - "mov x25, %x[width]\n" - "add x24, x26, %x[in_stride]\n" - "add x23, x24, %x[in_stride]\n" - "add x22, x23, %x[in_stride]\n" - "add x21, x22, %x[in_stride]\n" - "add x20, x21, %x[in_stride]\n" - "add %x[in], x20, %x[in_stride]\n" - "4:" // Main row loop: Column loop - "whilelt p0.h, XZR, x25\n" - "decw x25, ALL, MUL #2\n" - "ld1h { z20.h }, p0/Z, [x9]\n" - "cmp x25, #0x0\n" - "addvl x9, x9, #1\n" - "ld1h { z17.h }, p0/Z, [x27]\n" - "addvl x27, x27, #1\n" - "ld1h { z19.h }, p0/Z, [x26]\n" - "addvl x26, x26, #1\n" - "ld1h { z16.h }, p0/Z, [x24]\n" - "addvl x24, x24, #1\n" - "ld1h { z18.h }, p0/Z, [x23]\n" - "addvl x23, x23, 
#1\n" - "zip1 z24.h, z20.h, z17.h\n" - "zip2 z23.h, z20.h, z17.h\n" - "ld1h { z17.h }, p0/Z, [x22]\n" - "addvl x22, x22, #1\n" - "ld1h { z22.h }, p0/Z, [x21]\n" - "addvl x21, x21, #1\n" - "zip1 z21.h, z19.h, z16.h\n" - "zip2 z20.h, z19.h, z16.h\n" - "ld1h { z16.h }, p0/Z, [x20]\n" - "addvl x20, x20, #1\n" - "zip1 z19.h, z18.h, z17.h\n" - "zip2 z18.h, z18.h, z17.h\n" - "st1h { z24.h }, p1, [x28]\n" - "st1h { z23.h }, p1, [x28, #1, MUL VL]\n" - "zip1 z17.h, z22.h, z16.h\n" - "zip2 z16.h, z22.h, z16.h\n" - "st1h { z21.h }, p1, [x28, #2, MUL VL]\n" - "st1h { z20.h }, p1, [x28, #3, MUL VL]\n" - "st1h { z19.h }, p1, [x28, #4, MUL VL]\n" - "st1h { z18.h }, p1, [x28, #5, MUL VL]\n" - "st1h { z17.h }, p1, [x28, #6, MUL VL]\n" - "st1h { z16.h }, p1, [x28, #7, MUL VL]\n" - "add x28, x28, %x[out_stride]\n" - "bgt 4b\n" - "cmp x10, #0x8\n" - "addvl %x[out], %x[out], #8\n" - "bge 3b\n" - "cbz x10, 10f\n" - "6:" // Main loop skip - "7:" // Tail row loop: Head - "mov x9, %x[in]\n" - "cntw x22, ALL, MUL #4\n" - "add x27, x9, %x[in_stride]\n" - "cmp x10, #0x1\n" - "add %x[in], x27, %x[in_stride]\n" - "mov x28, %x[out]\n" - "csel %x[in], %x[in], x27, GT\n" - "csel x27, x27, %x[pad_row], GT\n" - "csel x21, x22, XZR, GT\n" - "sub x10, x10, #0x2\n" - "mov x20, %x[width]\n" - "8:" // Tail row loop: Column loop - "whilelt p0.h, XZR, x20\n" - "decw x20, ALL, MUL #2\n" - "ld1h { z18.h }, p0/Z, [x9]\n" - "cmp x20, #0x0\n" - "add x9, x9, x22\n" - "ld1h { z16.h }, p0/Z, [x27]\n" - "add x27, x27, x21\n" - "zip1 z17.h, z18.h, z16.h\n" - "zip2 z16.h, z18.h, z16.h\n" - "st1h { z17.h }, p1, [x28]\n" - "st1h { z16.h }, p1, [x28, #1, MUL VL]\n" - "add x28, x28, %x[out_stride]\n" - "bgt 8b\n" - "cmp x10, #0x1\n" - "addvl %x[out], %x[out], #2\n" - "bge 7b\n" - "10:" // Done - "sub x11, x11, #0x1\n" - "cbnz x11, 2b\n" - ".inst 0xd503467f // SMSTOP\n" - : [bias] "+&r"(bias), [in] "+&r"(in), [out] "+&r"(out) - : [height] "r"(height), [in_stride] "r"(in_stride), [k_chunk_count] "r"(k_chunk_count), - 
[out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", - "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", - "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + args.pad_row = pad_row; + + kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h index ebf1aec23a0effb6b858ec57e823d27485637588..9dd33d72698e3f78dfe37c5a01c8af5817ef995e 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h @@ -16,7 +16,7 @@ extern "C" { /// /// The starting column index must be divisible by `n_step`. /// -/// @return Step size for column index. +/// @return The n step value. size_t kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. @@ -65,12 +65,12 @@ size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( /// @param[in] n Number of columns of the output matrix. /// @param[in] k_chunk_count Number of chunks. /// @param[in] k_chunk_length Number of rows in each chunk. -/// @param[in] rhs_row_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs_stride_row Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. /// @param[out] rhs_packed Packed RHS matrix. 
void kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( - size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias, void* rhs_packed); #ifdef __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..2ff504eee814e74ffff26170f11bd4a16bdf337c --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S @@ -0,0 +1,183 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) + 
+KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + ldr x8, [x0, #0x8] + ptrue p1.b + ldr x17, [x0, #0x38] + ldr x23, [x0, #0x0] + ldr x16, [x0, #0x10] + mov x22, x8 + ldr x21, [x0, #0x18] + mov x20, x17 + ldr x15, [x0, #0x20] + ldr x14, [x0, #0x28] + ldr x13, [x0, #0x30] + ldr x12, [x0, #0x40] +KAI_ASM_LABEL(label_1) // Bias: Full loop + whilelt p0.h, XZR, x22 + dech x22 + cmp x22, #0x0 + ld1h { z16.h }, p0/Z, [x23] + incb x23 + st1h { z16.h }, p1, [x20] + add x20, x20, x14 + bgt label_1 + incb x17 + mov x11, x21 +KAI_ASM_LABEL(label_2) // Chunk Loop + mov x10, x16 + cmp x10, #0x8 + blt label_6 +KAI_ASM_LABEL(label_3) // Main row loop: Head + mov x9, x13 + mov x28, x17 + add x27, x9, x15 + sub x10, x10, #0x8 + add x26, x27, x15 + mov x25, x8 + add x24, x26, x15 + add x23, x24, x15 + add x22, x23, x15 + add x21, x22, x15 + add x20, x21, x15 + add x13, x20, x15 +KAI_ASM_LABEL(label_4) // Main row loop: Column loop + whilelt p0.h, XZR, x25 + decw x25, ALL, MUL #2 + ld1h { z20.h }, p0/Z, [x9] + cmp x25, #0x0 + addvl x9, x9, #1 + ld1h { z17.h }, p0/Z, [x27] + addvl x27, x27, #1 + ld1h { z19.h }, p0/Z, [x26] + addvl x26, x26, #1 + ld1h { z16.h }, p0/Z, [x24] + addvl x24, x24, #1 + ld1h { z18.h }, p0/Z, [x23] + addvl x23, x23, #1 + zip1 z24.h, z20.h, z17.h + zip2 z23.h, z20.h, z17.h + ld1h { z17.h }, p0/Z, [x22] + addvl x22, x22, #1 + ld1h { z22.h }, p0/Z, [x21] + addvl x21, x21, #1 + zip1 z21.h, z19.h, z16.h + zip2 z20.h, z19.h, z16.h + ld1h { z16.h }, p0/Z, [x20] + addvl x20, x20, #1 + zip1 z19.h, z18.h, z17.h + zip2 z18.h, z18.h, z17.h + st1h { z24.h }, p1, [x28] + st1h { z23.h }, p1, [x28, #1, MUL VL] 
+ zip1 z17.h, z22.h, z16.h + zip2 z16.h, z22.h, z16.h + st1h { z21.h }, p1, [x28, #2, MUL VL] + st1h { z20.h }, p1, [x28, #3, MUL VL] + st1h { z19.h }, p1, [x28, #4, MUL VL] + st1h { z18.h }, p1, [x28, #5, MUL VL] + st1h { z17.h }, p1, [x28, #6, MUL VL] + st1h { z16.h }, p1, [x28, #7, MUL VL] + add x28, x28, x14 + bgt label_4 + cmp x10, #0x8 + addvl x17, x17, #8 + bge label_3 + cbz x10, label_10 +KAI_ASM_LABEL(label_6) // Main loop skip +KAI_ASM_LABEL(label_7) // Tail row loop: Head + mov x9, x13 + cntw x22, ALL, MUL #4 + add x27, x9, x15 + cmp x10, #0x1 + add x13, x27, x15 + mov x28, x17 + csel x13, x13, x27, GT + csel x27, x27, x12, GT + csel x21, x22, XZR, GT + sub x10, x10, #0x2 + mov x20, x8 +KAI_ASM_LABEL(label_8) // Tail row loop: Column loop + whilelt p0.h, XZR, x20 + decw x20, ALL, MUL #2 + ld1h { z18.h }, p0/Z, [x9] + cmp x20, #0x0 + add x9, x9, x22 + ld1h { z16.h }, p0/Z, [x27] + add x27, x27, x21 + zip1 z17.h, z18.h, z16.h + zip2 z16.h, z18.h, z16.h + st1h { z17.h }, p1, [x28] + st1h { z16.h }, p1, [x28, #1, MUL VL] + add x28, x28, x14 + bgt label_8 + cmp x10, #0x1 + addvl x17, x17, #2 + bge label_7 +KAI_ASM_LABEL(label_10) // Done + sub x11, x11, #0x1 + cbnz x11, label_2 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c index 46c626f1260af5fe2769bc4069f15ca726c524bd..2a7c8100abba8a09aca6f4057c5e9fbed1c4bf42 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c @@ -4,10 +4,7 
@@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. #include "kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h" @@ -17,11 +14,27 @@ #include "kai/kai_common.h" -#define NR 2 -#define KR 1 +enum { + NR = 2, + KR = 1, +}; + +typedef struct { + const void* bias_ptr; + size_t width; + size_t height; + size_t k_chunk_count; + size_t in_stride; + size_t out_stride; + const void* in; + void* out; +} KernelArgs; + static const size_t kai_num_bytes_input = sizeof(uint32_t); static const size_t kai_num_bytes_output = sizeof(uint32_t); -static const size_t kai_num_bytes_bias = sizeof(uint32_t); +static const size_t kai_num_bytes_bias = sizeof(float); + +void kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(const KernelArgs* args_ptr); size_t kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(void) { return NR * kai_get_sme_vector_length_u32() / KR; @@ -54,129 +67,30 @@ size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( size_t n, size_t k_chunk_count, size_t k_chunk_length) { - const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme()); + const size_t n_nr_blocks = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme()); return kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( - n_rounded_up, k_chunk_count, k_chunk_length); + n_nr_blocks, k_chunk_count, k_chunk_length); } void kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( - size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + size_t 
n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias, void* rhs_packed) { KAI_ASSUME(rhs != NULL); KAI_ASSUME(bias != NULL); KAI_ASSUME(rhs_packed != NULL); - size_t height = k_chunk_length; - const size_t width = n; - const void* in = rhs; - void* out = rhs_packed; - const size_t in_stride = rhs_row_stride; - - size_t out_stride = + KernelArgs args; + args.bias_ptr = bias; + args.height = k_chunk_length; + args.width = n; + args.in = rhs; + args.out = rhs_packed; + args.k_chunk_count = k_chunk_count; + args.in_stride = rhs_stride_row; + args.out_stride = kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(k_chunk_count, k_chunk_length); - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x22, %x[out]\n" - "mov x21, %x[width]\n" - "ptrue p2.b\n" - "1:" // Bias: Full loop - "mov x20, x21\n" - "decw x21, ALL, MUL #2\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "cmp x21, #0x0\n" - "ld1w { z17.s }, p1/Z, [%x[bias]]\n" - "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" - "incb %x[bias], ALL, MUL #2\n" - "st1w { z17.s }, p2, [x22]\n" - "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" - "add x22, x22, %x[out_stride]\n" - "bgt 1b\n" - "incb %x[out], ALL, MUL #2\n" - "mov x28, %x[k_chunk_count]\n" - "2:" // Chunk Loop - "mov x27, %x[height]\n" - "cmp x27, #0x4\n" - "blt 6f\n" - "3:" // Main row loop: Head - "mov x26, %x[in]\n" - "mov x25, %x[out]\n" - "add x24, x26, %x[in_stride]\n" - "sub x27, x27, #0x4\n" - "add x23, x24, %x[in_stride]\n" - "mov x22, %x[width]\n" - "add x21, x23, %x[in_stride]\n" - "add %x[in], x21, %x[in_stride]\n" - "4:" // Main row loop: Column loop - "mov x20, x22\n" - "decw x22, ALL, MUL #2\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "cmp x22, #0x0\n" - "ld1w { z23.s }, p1/Z, [x26]\n" - "ld1w { z22.s }, p0/Z, [x26, #1, MUL VL]\n" - "addvl x26, x26, #2\n" - "ld1w { z21.s }, p1/Z, [x24]\n" - "ld1w { z20.s }, 
p0/Z, [x24, #1, MUL VL]\n" - "addvl x24, x24, #2\n" - "ld1w { z19.s }, p1/Z, [x23]\n" - "ld1w { z18.s }, p0/Z, [x23, #1, MUL VL]\n" - "addvl x23, x23, #2\n" - "ld1w { z17.s }, p1/Z, [x21]\n" - "ld1w { z16.s }, p0/Z, [x21, #1, MUL VL]\n" - "addvl x21, x21, #2\n" - "st1w { z23.s }, p2, [x25]\n" - "st1w { z22.s }, p2, [x25, #1, MUL VL]\n" - "st1w { z21.s }, p2, [x25, #2, MUL VL]\n" - "st1w { z20.s }, p2, [x25, #3, MUL VL]\n" - "st1w { z19.s }, p2, [x25, #4, MUL VL]\n" - "st1w { z18.s }, p2, [x25, #5, MUL VL]\n" - "st1w { z17.s }, p2, [x25, #6, MUL VL]\n" - "st1w { z16.s }, p2, [x25, #7, MUL VL]\n" - "add x25, x25, %x[out_stride]\n" - "bgt 4b\n" - "cmp x27, #0x4\n" - "addvl %x[out], %x[out], #8\n" - "bge 3b\n" - "cbz x27, 10f\n" - "6:" // Main loop skip - "7:" // Tail row loop: Head - "mov x26, %x[in]\n" - "cntw x22, ALL, MUL #8\n" - "add %x[in], x26, %x[in_stride]\n" - "mov x25, %x[out]\n" - "sub x27, x27, #0x1\n" - "mov x21, %x[width]\n" - "8:" // Tail row loop: Column loop - "mov x20, x21\n" - "decw x21, ALL, MUL #2\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "cmp x21, #0x0\n" - "ld1w { z17.s }, p1/Z, [x26]\n" - "ld1w { z16.s }, p0/Z, [x26, #1, MUL VL]\n" - "add x26, x26, x22\n" - "st1w { z17.s }, p2, [x25]\n" - "st1w { z16.s }, p2, [x25, #1, MUL VL]\n" - "add x25, x25, %x[out_stride]\n" - "bgt 8b\n" - "cmp x27, #0x1\n" - "addvl %x[out], %x[out], #2\n" - "bge 7b\n" - "10:" // Done - "sub x28, x28, #0x1\n" - "cbnz x28, 2b\n" - ".inst 0xd503467f // SMSTOP\n" - : [bias] "+&r"(bias), [in] "+&r"(in), [out] "+&r"(out) - : [height] "r"(height), [in_stride] "r"(in_stride), [k_chunk_count] "r"(k_chunk_count), - [out_stride] "r"(out_stride), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z10", "z11", "z12", - "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", 
"z21", "z22", "z23", "z24", "z25", "z26", "z27", - "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + + kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h index ea16c9df140741431697d6579970596c2ed51b38..4af70ce1db727e21b3a902338659ab6387645dad 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h @@ -16,7 +16,7 @@ extern "C" { /// /// The starting column index must be divisible by `n_step`. /// -/// @return Step size for column index. +/// @return The n step value. size_t kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. @@ -65,12 +65,12 @@ size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( /// @param[in] n Number of columns of the output matrix. /// @param[in] k_chunk_count Number of chunks. /// @param[in] k_chunk_length Number of rows in each chunk. -/// @param[in] rhs_row_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs_stride_row Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. /// @param[out] rhs_packed Packed RHS matrix. 
void kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( - size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias, void* rhs_packed); #ifdef __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..6189678a3157e569a1741d32d180d293a5bf51d3 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S @@ -0,0 +1,169 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) + 
+KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + ldr x14, [x0, #0x8] + ptrue p2.b + ldr x13, [x0, #0x38] + ldr x24, [x0, #0x0] + ldr x12, [x0, #0x10] + mov x23, x14 + ldr x22, [x0, #0x18] + mov x21, x13 + ldr x11, [x0, #0x20] + ldr x10, [x0, #0x28] + ldr x9, [x0, #0x30] +KAI_ASM_LABEL(label_1) // Bias: Full loop + mov x20, x23 + decw x23, ALL, MUL #2 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + cmp x23, #0x0 + ld1w { z17.s }, p1/Z, [x24] + ld1w { z16.s }, p0/Z, [x24, #1, MUL VL] + incb x24, ALL, MUL #2 + st1w { z17.s }, p2, [x21] + st1w { z16.s }, p2, [x21, #1, MUL VL] + add x21, x21, x10 + bgt label_1 + incb x13, ALL, MUL #2 + mov x28, x22 +KAI_ASM_LABEL(label_2) // Chunk Loop + mov x27, x12 + cmp x27, #0x4 + blt label_6 +KAI_ASM_LABEL(label_3) // Main row loop: Head + mov x26, x9 + mov x25, x13 + add x24, x26, x11 + sub x27, x27, #0x4 + add x23, x24, x11 + mov x22, x14 + add x21, x23, x11 + add x9, x21, x11 +KAI_ASM_LABEL(label_4) // Main row loop: Column loop + mov x20, x22 + decw x22, ALL, MUL #2 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + cmp x22, #0x0 + ld1w { z23.s }, p1/Z, [x26] + ld1w { z22.s }, p0/Z, [x26, #1, MUL VL] + addvl x26, x26, #2 + ld1w { z21.s }, p1/Z, [x24] + ld1w { z20.s }, p0/Z, [x24, #1, MUL VL] + addvl x24, x24, #2 + ld1w { z19.s }, p1/Z, [x23] + ld1w { z18.s }, p0/Z, [x23, #1, MUL VL] + addvl x23, x23, #2 + ld1w { z17.s }, p1/Z, [x21] + ld1w { z16.s }, p0/Z, [x21, #1, MUL VL] + addvl x21, x21, #2 + st1w { z23.s }, p2, [x25] + st1w { z22.s }, p2, [x25, #1, MUL VL] + st1w { z21.s }, p2, [x25, #2, MUL VL] + st1w { z20.s }, p2, [x25, #3, 
MUL VL] + st1w { z19.s }, p2, [x25, #4, MUL VL] + st1w { z18.s }, p2, [x25, #5, MUL VL] + st1w { z17.s }, p2, [x25, #6, MUL VL] + st1w { z16.s }, p2, [x25, #7, MUL VL] + add x25, x25, x10 + bgt label_4 + cmp x27, #0x4 + addvl x13, x13, #8 + bge label_3 + cbz x27, label_10 +KAI_ASM_LABEL(label_6) // Main loop skip +KAI_ASM_LABEL(label_7) // Tail row loop: Head + mov x26, x9 + cntw x22, ALL, MUL #8 + add x9, x26, x11 + mov x25, x13 + sub x27, x27, #0x1 + mov x21, x14 +KAI_ASM_LABEL(label_8) // Tail row loop: Column loop + mov x20, x21 + decw x21, ALL, MUL #2 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + cmp x21, #0x0 + ld1w { z17.s }, p1/Z, [x26] + ld1w { z16.s }, p0/Z, [x26, #1, MUL VL] + add x26, x26, x22 + st1w { z17.s }, p2, [x25] + st1w { z16.s }, p2, [x25, #1, MUL VL] + add x25, x25, x10 + bgt label_8 + cmp x27, #0x1 + addvl x13, x13, #2 + bge label_7 +KAI_ASM_LABEL(label_10) // Done + sub x28, x28, #0x1 + cbnz x28, label_2 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) + + KAI_ASM_END diff --git a/test/tests/imatmul_test.cpp b/test/tests/imatmul_test.cpp index 70b28e555f74aca74ecb2a3b0559f7ae0946d885..ca4b71b37918afd86ce97b270a39629a7d260326 100644 --- a/test/tests/imatmul_test.cpp +++ b/test/tests/imatmul_test.cpp @@ -77,7 +77,7 @@ struct MatMulIndirectKernel { std::function get_kr; std::function get_lhs_packed_offset; std::function get_rhs_packed_offset; - std::function get_dst_offset; + std::function get_dst_offset; std::function get_dst_size; std::function get_m_step; std::function get_n_step; - std::function get_packed_lhs_offset; - std::function get_packed_rhs_offset; - std::function get_dst_offset; + std::function get_lhs_packed_offset; + 
std::function get_rhs_packed_offset; + std::function get_dst_offset; std::function get_dst_size; std::function& get_indirect_gemm_variants() { variants[0].rhs_pack.pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; variants[0].matmul.get_m_step = ukernel.get_m_step; variants[0].matmul.get_n_step = ukernel.get_n_step; - variants[0].matmul.get_packed_lhs_offset = ukernel.get_lhs_packed_offset; - variants[0].matmul.get_packed_rhs_offset = ukernel.get_rhs_packed_offset; + variants[0].matmul.get_lhs_packed_offset = ukernel.get_lhs_packed_offset; + variants[0].matmul.get_rhs_packed_offset = ukernel.get_rhs_packed_offset; variants[0].matmul.get_dst_offset = ukernel.get_dst_offset; variants[0].matmul.get_dst_size = ukernel.get_dst_size; variants[0].matmul.imatmul = ukernel.run_imatmul; @@ -845,8 +845,8 @@ static Buffer matmul( const Buffer& packed_rhs, const MatMulShape& shape, const KChunk& k_chunk) { // Calculate portion offsets. size_t dst_offset = variant.get_dst_offset(portion.start_row(), portion.start_col(), shape.n); - size_t lhs_offset = variant.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); - size_t rhs_offset = variant.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length); + size_t lhs_offset = variant.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length); + size_t rhs_offset = variant.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length); // Allocate output buffer const size_t dst_size = variant.get_dst_size(shape.m, shape.n);