From a70785b69eb0c628db8ef40f9ec7300681078019 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Thu, 22 May 2025 15:24:02 +0200 Subject: [PATCH 1/6] Split FP32 SME kernels into seperate assembly file. Move the assembly blocks of the following kernels into their own files: - rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme - lhs_pack_f32p2vlx1_f32_sme - matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla - matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla - matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa Signed-off-by: Jens Elofsson --- CMakeLists.txt | 31 +- kai/ukernels/matmul/BUILD.bazel | 34 +- ...lamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c | 526 +----------- ...lamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h | 16 +- ..._f32_f32_f32p16vlx1b_1x16vl_sme2_mla_asm.S | 511 ++++++++++++ ...clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c | 778 +----------------- ...clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h | 16 +- ...p_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_asm.S | 763 +++++++++++++++++ ...f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c | 448 +--------- ...f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h | 24 +- ...f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_asm.S | 252 ++++++ .../pack/kai_lhs_pack_f32p2vlx1_f32_sme.c | 344 ++------ .../pack/kai_lhs_pack_f32p2vlx1_f32_sme.h | 4 +- .../pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S | 304 +++++++ ...hs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c | 166 +--- ...hs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h | 28 +- ...ack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S | 156 ++++ 17 files changed, 2323 insertions(+), 2078 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_asm.S create mode 100644 kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index ea7b9926..49f4b09f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -221,33 +221,46 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c ) +set(KLEIDIAI_FILES_SME_ASM + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S + kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S +) + set(KLEIDIAI_FILES_SME + ${KLEIDIAI_FILES_SME_ASM} kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c ) +set(KLEIDIAI_FILES_SME2_ASM + kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c + kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla_asm.S + kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c + kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_asm.S + kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_asm.S +) + set(KLEIDIAI_FILES_SME2 + ${KLEIDIAI_FILES_SME2_ASM} kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c - kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c - kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c @@ -293,14 +306,20 @@ if(NOT MSVC) else() target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM_ASM}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME_ASM}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2_ASM}) set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SME_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SME2_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set(KLEIDIAI_FILES_ASM ${KLEIDIAI_FILES_NEON_DOTPROD_ASM} - ${KLEIDIAI_FILES_NEON_I8MM_ASM}) + ${KLEIDIAI_FILES_NEON_I8MM_ASM} + ${KLEIDIAI_FILES_SME_ASM} + ${KLEIDIAI_FILES_SME2_ASM}) list(FILTER KLEIDIAI_FILES_ASM INCLUDE REGEX "^.*\.S$") set_source_files_properties(${KLEIDIAI_FILES_ASM} PROPERTIES LANGUAGE ASM_MARMASM) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 2da2800c..ece4c604 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -136,13 +136,18 @@ I8MM_KERNELS_ASM = [ "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", ] +# buildifier: keep sorted +SME_KERNELS_ASM = [ + "pack/kai_lhs_pack_f32p2vlx1_f32_sme", + "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", +] + # buildifier: keep sorted SME_KERNELS = [ "pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", "pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme", "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme", "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", - "pack/kai_lhs_pack_f32p2vlx1_f32_sme", "pack/kai_lhs_pack_x16p2vlx2_x16_sme", "pack/kai_lhs_pack_x8p2vlx4_x8_sme", "pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", @@ -150,13 +155,19 @@ SME_KERNELS = [ "pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", - "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", "pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", "pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme", "pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme", "pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme", ] +# buildifier: keep sorted +SME2_KERNELS_ASM = [ + "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla", + "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla", + "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa", +] + # buildifier: keep sorted SME2_KERNELS = [ "imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", @@ -164,9 +175,6 @@ SME2_KERNELS = [ "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", - "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla", - "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla", - "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa", @@ -265,6 +273,13 @@ kai_c_library( textual_hdrs = [ukernel + ".h" for ukernel in I8MM_KERNELS_ASM], ) +kai_c_library( + name = "sme_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in SME_KERNELS_ASM] + [ukernel + ".c" for ukernel in SME_KERNELS_ASM], + cpu_uarch = kai_cpu_sme(), + textual_hdrs = [ukernel + ".h" for ukernel in SME_KERNELS_ASM], +) + kai_c_library( name = "sme_impl", srcs = [ukernel + ".c" for ukernel in SME_KERNELS], @@ -272,6 +287,13 @@ kai_c_library( textual_hdrs = [ukernel + ".h" for ukernel in SME_KERNELS], ) +kai_c_library( + name = "sme2_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in SME2_KERNELS_ASM] + [ukernel + ".c" for ukernel in SME2_KERNELS_ASM], + cpu_uarch = kai_cpu_sme2(), + textual_hdrs = [ukernel + ".h" for ukernel in SME2_KERNELS_ASM], +) + kai_c_library( name = "sme2_impl", srcs = [ukernel + ".c" for ukernel in SME2_KERNELS], @@ -297,6 +319,8 @@ kai_c_library( ":neon_impl_asm", ":scalar_impl", ":sme2_impl", + ":sme2_impl_asm", ":sme_impl", + ":sme_impl_asm", ], ) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c index 627ab688..3760cc45 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c @@ -7,7 +7,7 @@ // Do not flag up inline assembly blocks #pragma GCC diagnostic ignored "-Woverlength-strings" -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. @@ -18,21 +18,35 @@ #include "kai/kai_common.h" -static const size_t kai_mr = 1; +typedef struct { + float maxval; + float minval; + const void* A_ptr; + const void* B_ptr; + size_t N; + size_t K; + void* output_ptr; + uint64_t flags; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(KernelArgs* args_ptr); + +static const size_t kai_m_step = 1; static const size_t kai_nr = 16; +static const size_t kai_n_step = 16; static const size_t kai_kr = 1; static const size_t kai_sr = 1; size_t kai_get_m_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(void) { - return kai_mr * kai_get_sme_vector_length_u32(); + return kai_m_step; } size_t kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_n_step * kai_get_sme_vector_length_u32() / kai_kr; } size_t kai_get_nr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_sme_vector_length_u32() / kai_kr; } size_t kai_get_kr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(void) { @@ -43,20 +57,27 @@ size_t kai_get_sr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(void) { return kai_sr; } -size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(size_t m_idx, size_t lhs_stride) { - KAI_ASSUME(m_idx % kai_mr == 0); +size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx == 0); + + return m_idx * k; +} - return m_idx * lhs_stride; +static size_t kai_get_rhs_packed_stride_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(size_t k) { + return kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla() * + (kai_roundup(k, kai_kr) * sizeof(float) + sizeof(float)); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla() == 0); - return n_idx * (k * sizeof(float) + sizeof(float)); + + const size_t block_idx = n_idx / kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(); + return block_idx * kai_get_rhs_packed_stride_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(k); } size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla() == 0); + KAI_ASSUME(m_idx == 0); KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla() == 0); return (m_idx * dst_stride) + (n_idx * sizeof(float)); @@ -69,486 +90,27 @@ size_t kai_get_dst_size_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(size_t void kai_run_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla( size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) { - KAI_UNUSED(lhs_stride); KAI_UNUSED(dst_stride_row); KAI_UNUSED(dst_stride_col); - KAI_ASSUME(m == 1); - - typedef struct { - float maxval; - float minval; - } KernelArgs; + KAI_UNUSED(lhs_stride); - KernelArgs ka; - ka.maxval = clamp_max; - ka.minval = clamp_min; + KAI_ASSUME(m == 1); - size_t N = n; - size_t K = k; + uint64_t flags = 2; - const void* A_ptr = lhs; - const void* B_ptr = rhs_packed; - void* output_ptr = dst; + KernelArgs args; - uint64_t flags = 2; + args.maxval = clamp_max; + args.minval = clamp_min; + args.A_ptr = lhs; + args.B_ptr = rhs_packed; + args.N = n; + args.K = k; + args.output_ptr = dst; + args.flags = flags; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x9, #0x0\n" - "mov x27, %x[B_ptr]\n" - "cntw x26, ALL, MUL #4\n" - "mov x25, %x[output_ptr]\n" - "add x24, %x[N], x26\n" - "ptrue p1.b\n" - "sub x24, x24, #0x1\n" - ".inst 0x25207811 // ptrue pn9.b\n" - "udiv x24, x24, x26\n" - "mov x22, #0x1\n" - "add x21, x24, #0x3\n" - "and x21, x21, #0xfffffffffffffffc\n" - "mul x21, x21, x26\n" - "mul x21, x21, %x[K]\n" - "lsl x21, x21, #0x2\n" - "1:" // RHS size check loop - "cmp x21, #0x200000\n" - "blt 2f\n" - "tbnz x21, #0, 3f\n" - "lsr x21, x21, #0x1\n" - "lsl x22, x22, #0x1\n" - "b 1b\n" - "2:" // RHS do prefetch - "lsl x20, x21, #0x26\n" - "sub x22, x22, #0x1\n" - "lsl x22, x22, #0x16\n" - "orr x21, x21, x20\n" - "orr x21, x21, x22\n" - ".inst 0xf8b54b7a // rprfm pldonce, x21, [x27]\n" - "3:" // RHS prefetch exit - "4:" // Column loop - "cmp x24, #0x4\n" - "bge 22f\n" - "cmp x24, #0x2\n" - "bgt 16f\n" - "beq 10f\n" - ".inst 0xa040c774 // ld1w { z20.s-z23.s }, pn9.b/Z, [x27]\n" - "mov x23, %x[K]\n" - "mov x21, %x[N]\n" - "mov x22, %x[A_ptr]\n" - "lsl x20, %x[K], #0x2\n" - ".inst 0x25b567f0 // whilelt p8.s, XZR, x21, VLx4\n" - "cmp x23, #0x4\n" - ".inst 0xf8b44ad8 // rprfm pldmany, x20, [x22]\n" - ".inst 0xc0042e80 // mova za.d[x9, #0], { z20.d-z23.d }\n" - "addvl x27, x27, #16\n" - "ble 6f\n" - "5:" // Width 1: Multiply loop: Main loop head - "whilelt p0.s, XZR, x23\n" - ".inst 0xa040c77d // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #16\n" - "ld1rqw { z2.s }, p0/Z, [x22]\n" - "sub x23, x23, #0x4\n" - "add x22, x22, #0x10\n" - ".inst 0xa040c771 // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #16\n" - "cmp x23, #0x4\n" - ".inst 0xa040c779 // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #16\n" - ".inst 0xc152a380 // fmla za.s[x9, 0], { z28.s-z31.s }, z2.s[0]\n" - ".inst 0xa040c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #16\n" - ".inst 0xc152a600 // fmla za.s[x9, 0], { z16.s-z19.s }, z2.s[1]\n" - ".inst 0xc152ab00 // fmla za.s[x9, 0], { z24.s-z27.s }, z2.s[2]\n" - ".inst 0xc152ad80 // fmla za.s[x9, 0], { z12.s-z15.s }, z2.s[3]\n" - "bgt 5b\n" - "6:" // Width 1: Multiply loop: Single iteration only - "whilelt p0.s, XZR, x23\n" - ".inst 0xa040c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27]\n" - "subs x23, x23, #0x1\n" - "ld1rqw { z3.s }, p0/Z, [x22]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a180 // fmla za.s[x9, 0], { z12.s-z15.s }, z3.s[0]\n" - "ble 7f\n" - ".inst 0xa040c765 // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x27]\n" - "subs x23, x23, #0x1\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a480 // fmla za.s[x9, 0], { z4.s-z7.s }, z3.s[1]\n" - "ble 7f\n" - ".inst 0xa040c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27]\n" - "subs x23, x23, #0x1\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a980 // fmla za.s[x9, 0], { z12.s-z15.s }, z3.s[2]\n" - "ble 7f\n" - ".inst 0xa040c769 // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x27]\n" - ".inst 0xc153ad00 // fmla za.s[x9, 0], { z8.s-z11.s }, z3.s[3]\n" - "7:" // Width 1: Multiply loop: multiply skip - "tbz %x[flags], #1, 8f\n" - "add x21, %x[args_ptr], %[offset_min]\n" - "add x20, %x[args_ptr], %[offset_max]\n" - ".inst 0xc0062c00 // mova { z0.d-z3.d }, za.d[x9, #0]\n" - "ld1rw { z23.s }, p1/Z, [x21]\n" - "ld1rw { z22.s }, p1/Z, [x20]\n" - ".inst 0xc1b6cae0 // fclamp { z0.s-z3.s }, z23.s, z22.s\n" - ".inst 0xa060c320 // st1w { z0.s-z3.s }, p8, [x25]\n" - "b 9f\n" - "8:" // Width 1: No activation - ".inst 0xc0062c00 // mova { z0.d-z3.d }, za.d[x9, #0]\n" - ".inst 0xa060c320 // st1w { z0.s-z3.s }, p8, [x25]\n" - "9:" // Width 1: Output done - "b 28f\n" - "10:" // Width 2 - ".inst 0xa040c77c // ld1w { z28.s-z31.s }, pn9.b/Z, [x27]\n" - "mov x23, %x[K]\n" - "sub x21, %x[N], x26\n" - ".inst 0xa041c764 // ld1w { z4.s-z7.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - "mov x22, %x[A_ptr]\n" - "lsl x20, %x[K], #0x2\n" - ".inst 0x25b567f0 // whilelt p8.s, XZR, x21, VLx4\n" - "cmp x23, #0x4\n" - ".inst 0xf8b44ad8 // rprfm pldmany, x20, [x22]\n" - ".inst 0xc0042f80 // mova za.d[x9, #0], { z28.d-z31.d }\n" - "addvl x27, x27, #8\n" - ".inst 0xc0042c81 // mova za.d[x9, #1], { z4.d-z7.d }\n" - "ble 12f\n" - "11:" // Width 2: Multiply loop: Main loop head - "whilelt p0.s, XZR, x23\n" - ".inst 0xa040c77d // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x27]\n" - "sub x23, x23, #0x4\n" - "ld1rqw { z1.s }, p0/Z, [x22]\n" - "cmp x23, #0x4\n" - "add x22, x22, #0x10\n" - ".inst 0xa041c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xa040c771 // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x27]\n" - ".inst 0xc151a380 // fmla za.s[x9, 0], { z28.s-z31.s }, z1.s[0]\n" - ".inst 0xa041c779 // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xc151a181 // fmla za.s[x9, 1], { z12.s-z15.s }, z1.s[0]\n" - ".inst 0xa040c77d // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x27]\n" - ".inst 0xa041c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xa040c769 // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x27]\n" - ".inst 0xc151a600 // fmla za.s[x9, 0], { z16.s-z19.s }, z1.s[1]\n" - ".inst 0xa041c775 // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xc151a701 // fmla za.s[x9, 1], { z24.s-z27.s }, z1.s[1]\n" - ".inst 0xc151ab80 // fmla za.s[x9, 0], { z28.s-z31.s }, z1.s[2]\n" - ".inst 0xc151a981 // fmla za.s[x9, 1], { z12.s-z15.s }, z1.s[2]\n" - ".inst 0xc151ad00 // fmla za.s[x9, 0], { z8.s-z11.s }, z1.s[3]\n" - ".inst 0xc151ae81 // fmla za.s[x9, 1], { z20.s-z23.s }, z1.s[3]\n" - "bgt 11b\n" - "12:" // Width 2: Multiply loop: Single iteration only - "whilelt p0.s, XZR, x23\n" - ".inst 0xa040c771 // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x27]\n" - "subs x23, x23, #0x1\n" - "ld1rqw { z3.s }, p0/Z, [x22]\n" - ".inst 0xa041c77d // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a200 // fmla za.s[x9, 0], { z16.s-z19.s }, z3.s[0]\n" - ".inst 0xc153a381 // fmla za.s[x9, 1], { z28.s-z31.s }, z3.s[0]\n" - "ble 13f\n" - ".inst 0xa040c775 // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x27]\n" - "subs x23, x23, #0x1\n" - ".inst 0xa041c771 // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a680 // fmla za.s[x9, 0], { z20.s-z23.s }, z3.s[1]\n" - ".inst 0xc153a601 // fmla za.s[x9, 1], { z16.s-z19.s }, z3.s[1]\n" - "ble 13f\n" - ".inst 0xa040c765 // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x27]\n" - "subs x23, x23, #0x1\n" - ".inst 0xa041c771 // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a880 // fmla za.s[x9, 0], { z4.s-z7.s }, z3.s[2]\n" - ".inst 0xc153aa01 // fmla za.s[x9, 1], { z16.s-z19.s }, z3.s[2]\n" - "ble 13f\n" - ".inst 0xa040c77d // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x27]\n" - ".inst 0xa041c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xc153af80 // fmla za.s[x9, 0], { z28.s-z31.s }, z3.s[3]\n" - ".inst 0xc153ad81 // fmla za.s[x9, 1], { z12.s-z15.s }, z3.s[3]\n" - "13:" // Width 2: Multiply loop: multiply skip - "tbz %x[flags], #1, 14f\n" - "add x21, %x[args_ptr], %[offset_min]\n" - "add x20, %x[args_ptr], %[offset_max]\n" - ".inst 0xc0062c04 // mova { z4.d-z7.d }, za.d[x9, #0]\n" - ".inst 0xc0062c28 // mova { z8.d-z11.d }, za.d[x9, #1]\n" - "ld1rw { z17.s }, p1/Z, [x21]\n" - "ld1rw { z23.s }, p1/Z, [x20]\n" - ".inst 0xc1b7ca24 // fclamp { z4.s-z7.s }, z17.s, z23.s\n" - ".inst 0xc1b7ca28 // fclamp { z8.s-z11.s }, z17.s, z23.s\n" - ".inst 0xa060c724 // st1w { z4.s-z7.s }, pn9.b, [x25]\n" - ".inst 0xa061c328 // st1w { z8.s-z11.s }, p8, [x25, #0x4, MUL VL]\n" - "b 15f\n" - "14:" // Width 2: No activation - ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n" - ".inst 0xc0062c30 // mova { z16.d-z19.d }, za.d[x9, #1]\n" - ".inst 0xa060c728 // st1w { z8.s-z11.s }, pn9.b, [x25]\n" - ".inst 0xa061c330 // st1w { z16.s-z19.s }, p8, [x25, #0x4, MUL VL]\n" - "15:" // Width 2: Output done - "b 28f\n" - "16:" // Width 3 - "mov x20, #0x2\n" - ".inst 0xa040c768 // ld1w { z8.s-z11.s }, pn9.b/Z, [x27]\n" - "mov x23, %x[K]\n" - ".inst 0xa041c760 // ld1w { z0.s-z3.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - "msub x21, x26, x20, %x[N]\n" - "mov x22, %x[A_ptr]\n" - ".inst 0xa042c764 // ld1w { z4.s-z7.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - "lsl x20, %x[K], #0x2\n" - ".inst 0x25b567f0 // whilelt p8.s, XZR, x21, VLx4\n" - "cmp x23, #0x4\n" - ".inst 0xf8b44ad8 // rprfm pldmany, x20, [x22]\n" - ".inst 0xc0042d00 // mova za.d[x9, #0], { z8.d-z11.d }\n" - ".inst 0xc0042c01 // mova za.d[x9, #1], { z0.d-z3.d }\n" - "addvl x27, x27, #16\n" - ".inst 0xc0042c82 // mova za.d[x9, #2], { z4.d-z7.d }\n" - "ble 18f\n" - "17:" // Width 3: Multiply loop: Main loop head - "whilelt p0.s, XZR, x23\n" - ".inst 0xa040c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27]\n" - "sub x23, x23, #0x4\n" - "ld1rqw { z3.s }, p0/Z, [x22]\n" - "cmp x23, #0x4\n" - "add x22, x22, #0x10\n" - ".inst 0xa041c769 // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa042c765 // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a180 // fmla za.s[x9, 0], { z12.s-z15.s }, z3.s[0]\n" - ".inst 0xa040c771 // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x27]\n" - ".inst 0xc153a101 // fmla za.s[x9, 1], { z8.s-z11.s }, z3.s[0]\n" - ".inst 0xa041c775 // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xc153a082 // fmla za.s[x9, 2], { z4.s-z7.s }, z3.s[0]\n" - ".inst 0xa042c779 // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xa040c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27]\n" - ".inst 0xc153a600 // fmla za.s[x9, 0], { z16.s-z19.s }, z3.s[1]\n" - ".inst 0xa041c77d // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xc153a681 // fmla za.s[x9, 1], { z20.s-z23.s }, z3.s[1]\n" - ".inst 0xa042c769 // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a702 // fmla za.s[x9, 2], { z24.s-z27.s }, z3.s[1]\n" - ".inst 0xa040c765 // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x27]\n" - ".inst 0xa041c775 // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xc153a980 // fmla za.s[x9, 0], { z12.s-z15.s }, z3.s[2]\n" - ".inst 0xa042c779 // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153ab81 // fmla za.s[x9, 1], { z28.s-z31.s }, z3.s[2]\n" - ".inst 0xc153a902 // fmla za.s[x9, 2], { z8.s-z11.s }, z3.s[2]\n" - ".inst 0xc153ac80 // fmla za.s[x9, 0], { z4.s-z7.s }, z3.s[3]\n" - ".inst 0xc153ae81 // fmla za.s[x9, 1], { z20.s-z23.s }, z3.s[3]\n" - ".inst 0xc153af02 // fmla za.s[x9, 2], { z24.s-z27.s }, z3.s[3]\n" - "bgt 17b\n" - "18:" // Width 3: Multiply loop: Single iteration only - "whilelt p0.s, XZR, x23\n" - ".inst 0xa040c775 // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x27]\n" - "subs x23, x23, #0x1\n" - "ld1rqw { z3.s }, p0/Z, [x22]\n" - ".inst 0xa041c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa042c765 // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a280 // fmla za.s[x9, 0], { z20.s-z23.s }, z3.s[0]\n" - ".inst 0xc153a181 // fmla za.s[x9, 1], { z12.s-z15.s }, z3.s[0]\n" - ".inst 0xc153a082 // fmla za.s[x9, 2], { z4.s-z7.s }, z3.s[0]\n" - "ble 19f\n" - ".inst 0xa040c775 // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x27]\n" - "subs x23, x23, #0x1\n" - ".inst 0xa041c769 // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa042c771 // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a680 // fmla za.s[x9, 0], { z20.s-z23.s }, z3.s[1]\n" - ".inst 0xc153a501 // fmla za.s[x9, 1], { z8.s-z11.s }, z3.s[1]\n" - ".inst 0xc153a602 // fmla za.s[x9, 2], { z16.s-z19.s }, z3.s[1]\n" - "ble 19f\n" - ".inst 0xa040c77d // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x27]\n" - "subs x23, x23, #0x1\n" - ".inst 0xa041c779 // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa042c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153ab80 // fmla za.s[x9, 0], { z28.s-z31.s }, z3.s[2]\n" - ".inst 0xc153ab01 // fmla za.s[x9, 1], { z24.s-z27.s }, z3.s[2]\n" - ".inst 0xc153a982 // fmla za.s[x9, 2], { z12.s-z15.s }, z3.s[2]\n" - "ble 19f\n" - ".inst 0xa040c769 // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x27]\n" - ".inst 0xa041c77d // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa042c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - ".inst 0xc153ad00 // fmla za.s[x9, 0], { z8.s-z11.s }, z3.s[3]\n" - ".inst 0xc153af81 // fmla za.s[x9, 1], { z28.s-z31.s }, z3.s[3]\n" - ".inst 0xc153ad82 // fmla za.s[x9, 2], { z12.s-z15.s }, z3.s[3]\n" - "19:" // Width 3: Multiply loop: multiply skip - "tbz %x[flags], #1, 20f\n" - "add x21, %x[args_ptr], %[offset_min]\n" - "add x20, %x[args_ptr], %[offset_max]\n" - ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n" - ".inst 0xc0062c2c // mova { z12.d-z15.d }, za.d[x9, #1]\n" - "ld1rw { z21.s }, p1/Z, [x21]\n" - ".inst 0xc0062c50 // mova { z16.d-z19.d }, za.d[x9, #2]\n" - "ld1rw { z20.s }, p1/Z, [x20]\n" - ".inst 0xc1b4caa8 // fclamp { z8.s-z11.s }, z21.s, z20.s\n" - ".inst 0xc1b4caac // fclamp { z12.s-z15.s }, z21.s, z20.s\n" - ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" - ".inst 0xa060c728 // st1w { z8.s-z11.s }, pn9.b, [x25]\n" - ".inst 0xa061c72c // st1w { z12.s-z15.s }, pn9.b, [x25, #0x4, MUL VL]\n" - ".inst 0xa062c330 // st1w { z16.s-z19.s }, p8, [x25, #0x8, MUL VL]\n" - "b 21f\n" - "20:" // Width 3: No activation - ".inst 0xc0062c04 // mova { z4.d-z7.d }, za.d[x9, #0]\n" - ".inst 0xc0062c2c // mova { z12.d-z15.d }, za.d[x9, #1]\n" - ".inst 0xc0062c5c // mova { z28.d-z31.d }, za.d[x9, #2]\n" - ".inst 0xa060c724 // st1w { z4.s-z7.s }, pn9.b, [x25]\n" - ".inst 0xa061c72c // st1w { z12.s-z15.s }, pn9.b, [x25, #0x4, MUL VL]\n" - ".inst 0xa062c33c // st1w { z28.s-z31.s }, p8, [x25, #0x8, MUL VL]\n" - "21:" // Width 3: Output done - "b 28f\n" - "22:" // Width 4 - "mov x20, #0x3\n" - ".inst 0xa040c764 // ld1w { z4.s-z7.s }, pn9.b/Z, [x27]\n" - "mov x23, %x[K]\n" - ".inst 0xa041c76c // ld1w { z12.s-z15.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - "msub x21, x26, x20, %x[N]\n" - "mov x22, %x[A_ptr]\n" - ".inst 0xa042c77c // ld1w { z28.s-z31.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - "lsl x20, %x[K], #0x2\n" - ".inst 0x25b567f0 // whilelt p8.s, XZR, x21, VLx4\n" - ".inst 0xa043c770 // ld1w { z16.s-z19.s }, pn9.b/Z, [x27, #0xc, MUL VL]\n" - "cmp x23, #0x4\n" - ".inst 0xf8b44ad8 // rprfm pldmany, x20, [x22]\n" - ".inst 0xc0042c80 // mova za.d[x9, #0], { z4.d-z7.d }\n" - ".inst 0xc0042d81 // mova za.d[x9, #1], { z12.d-z15.d }\n" - "addvl x27, x27, #16\n" - ".inst 0xc0042f82 // mova za.d[x9, #2], { z28.d-z31.d }\n" - ".inst 0xc0042e03 // mova za.d[x9, #3], { z16.d-z19.d }\n" - "ble 24f\n" - "23:" // Width 4: Multiply loop: Main loop head - "whilelt p0.s, XZR, x23\n" - ".inst 0xa040c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27]\n" - "sub x23, x23, #0x4\n" - "ld1rqw { z3.s }, p0/Z, [x22]\n" - "cmp x23, #0x4\n" - "add x22, x22, #0x10\n" - ".inst 0xa041c775 // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa042c771 // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - ".inst 0xa043c769 // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x27, #0xc, MUL VL]\n" - ".inst 0xc153a180 // fmla za.s[x9, 0], { z12.s-z15.s }, z3.s[0]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a281 // fmla za.s[x9, 1], { z20.s-z23.s }, z3.s[0]\n" - ".inst 0xa040c779 // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x27]\n" - ".inst 0xc153a202 // fmla za.s[x9, 2], { z16.s-z19.s }, z3.s[0]\n" - ".inst 0xa041c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xc153a103 // fmla za.s[x9, 3], { z8.s-z11.s }, z3.s[0]\n" - ".inst 0xa042c769 // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - ".inst 0xa043c765 // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x27, #0xc, MUL VL]\n" - ".inst 0xc153a700 // fmla za.s[x9, 0], { z24.s-z27.s }, z3.s[1]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a581 // fmla za.s[x9, 1], { z12.s-z15.s }, z3.s[1]\n" - ".inst 0xa040c779 // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x27]\n" - ".inst 0xc153a502 // fmla za.s[x9, 2], { z8.s-z11.s }, z3.s[1]\n" - ".inst 0xa041c769 // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xc153a483 // fmla za.s[x9, 3], { z4.s-z7.s }, z3.s[1]\n" - ".inst 0xa042c771 // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - ".inst 0xa043c765 // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x27, #0xc, MUL VL]\n" - ".inst 0xc153ab00 // fmla za.s[x9, 0], { z24.s-z27.s }, z3.s[2]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a901 // fmla za.s[x9, 1], { z8.s-z11.s }, z3.s[2]\n" - ".inst 0xa040c769 // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x27]\n" - ".inst 0xc153aa02 // fmla za.s[x9, 2], { z16.s-z19.s }, z3.s[2]\n" - ".inst 0xa041c77d // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xc153a883 // fmla za.s[x9, 3], { z4.s-z7.s }, z3.s[2]\n" - ".inst 0xa042c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - ".inst 0xa043c775 // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x27, #0xc, MUL VL]\n" - ".inst 0xc153ad00 // fmla za.s[x9, 0], { z8.s-z11.s }, z3.s[3]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153af81 // fmla za.s[x9, 1], { z28.s-z31.s }, z3.s[3]\n" - ".inst 0xc153ad82 // fmla za.s[x9, 2], { z12.s-z15.s }, z3.s[3]\n" - ".inst 0xc153ae83 // fmla za.s[x9, 3], { z20.s-z23.s }, z3.s[3]\n" - "bgt 23b\n" - "24:" // Width 4: Multiply loop: Single iteration only - "whilelt p0.s, XZR, x23\n" - ".inst 0xa040c771 // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x27]\n" - "subs x23, x23, #0x1\n" - "ld1rqw { z3.s }, p0/Z, [x22]\n" - ".inst 0xa041c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa042c77d // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - ".inst 0xa043c775 // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x27, #0xc, MUL VL]\n" - ".inst 0xc153a200 // fmla za.s[x9, 0], { z16.s-z19.s }, z3.s[0]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a181 // fmla za.s[x9, 1], { z12.s-z15.s }, z3.s[0]\n" - ".inst 0xc153a382 // fmla za.s[x9, 2], { z28.s-z31.s }, z3.s[0]\n" - ".inst 0xc153a283 // fmla za.s[x9, 3], { z20.s-z23.s }, z3.s[0]\n" - "ble 25f\n" - ".inst 0xa040c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27]\n" - "subs x23, x23, #0x1\n" - ".inst 0xa041c765 // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa042c779 // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - ".inst 0xa043c775 // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x27, #0xc, MUL VL]\n" - ".inst 0xc153a580 // fmla za.s[x9, 0], { z12.s-z15.s }, z3.s[1]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a481 // fmla za.s[x9, 1], { z4.s-z7.s }, z3.s[1]\n" - ".inst 0xc153a702 // fmla za.s[x9, 2], { z24.s-z27.s }, z3.s[1]\n" - ".inst 0xc153a683 // fmla za.s[x9, 3], { z20.s-z23.s }, z3.s[1]\n" - "ble 25f\n" - ".inst 0xa040c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27]\n" - "subs x23, x23, #0x1\n" - ".inst 0xa041c769 // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa042c775 // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - ".inst 0xa043c771 // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x27, #0xc, MUL VL]\n" - ".inst 0xc153a980 // fmla za.s[x9, 0], { z12.s-z15.s }, z3.s[2]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153a901 // fmla za.s[x9, 1], { z8.s-z11.s }, z3.s[2]\n" - ".inst 0xc153aa82 // fmla za.s[x9, 2], { z20.s-z23.s }, z3.s[2]\n" - ".inst 0xc153aa03 // fmla za.s[x9, 3], { z16.s-z19.s }, z3.s[2]\n" - "ble 25f\n" - ".inst 0xa040c76d // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x27]\n" - ".inst 0xa041c769 // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa042c775 // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x27, #0x8, MUL VL]\n" - ".inst 0xa043c771 // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x27, #0xc, MUL VL]\n" - ".inst 0xc153ad80 // fmla za.s[x9, 0], { z12.s-z15.s }, z3.s[3]\n" - "addvl x27, x27, #16\n" - ".inst 0xc153ad01 // fmla za.s[x9, 1], { z8.s-z11.s }, z3.s[3]\n" - ".inst 0xc153ae82 // fmla za.s[x9, 2], { z20.s-z23.s }, z3.s[3]\n" - ".inst 0xc153ae03 // fmla za.s[x9, 3], { z16.s-z19.s }, z3.s[3]\n" - "25:" // Width 4: Multiply loop: multiply skip - "tbz %x[flags], #1, 26f\n" - "add x21, %x[args_ptr], %[offset_min]\n" - "add x20, %x[args_ptr], %[offset_max]\n" - ".inst 0xc0062c04 // mova { z4.d-z7.d }, za.d[x9, #0]\n" - ".inst 0xc0062c20 // mova { z0.d-z3.d }, za.d[x9, #1]\n" - "ld1rw { z21.s }, p1/Z, [x21]\n" - ".inst 0xc0062c4c // mova { z12.d-z15.d }, za.d[x9, #2]\n" - "ld1rw { z20.s }, p1/Z, [x20]\n" - ".inst 0xc0062c70 // mova { z16.d-z19.d }, za.d[x9, #3]\n" - ".inst 0xc1b4caa4 // fclamp { z4.s-z7.s }, z21.s, z20.s\n" - ".inst 0xc1b4caa0 // fclamp { z0.s-z3.s }, z21.s, z20.s\n" - ".inst 0xc1b4caac // fclamp { z12.s-z15.s }, z21.s, z20.s\n" - ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" - ".inst 0xa060c724 // st1w { z4.s-z7.s }, pn9.b, [x25]\n" - ".inst 0xa061c720 // st1w { z0.s-z3.s }, pn9.b, [x25, #0x4, MUL VL]\n" - ".inst 0xa062c72c // st1w { z12.s-z15.s }, pn9.b, [x25, #0x8, MUL VL]\n" - ".inst 0xa063c330 // st1w { z16.s-z19.s }, p8, [x25, #0xc, MUL VL]\n" - "addvl x25, x25, #16\n" - "b 27f\n" - "26:" // Width 4: No activation - ".inst 0xc0062c0c // mova { z12.d-z15.d }, za.d[x9, #0]\n" - ".inst 0xc0062c20 // mova { z0.d-z3.d }, za.d[x9, #1]\n" - ".inst 0xc0062c50 // mova { z16.d-z19.d }, za.d[x9, #2]\n" - ".inst 0xc0062c64 // mova { z4.d-z7.d }, za.d[x9, #3]\n" - ".inst 0xa060c72c // st1w { z12.s-z15.s }, pn9.b, [x25]\n" - ".inst 0xa061c720 // st1w { z0.s-z3.s }, pn9.b, [x25, #0x4, MUL VL]\n" - ".inst 0xa062c730 // st1w { z16.s-z19.s }, pn9.b, [x25, #0x8, MUL VL]\n" - ".inst 0xa063c324 // st1w { z4.s-z7.s }, p8, [x25, #0xc, MUL VL]\n" - "addvl x25, x25, #16\n" - "27:" // Width 4: Output done - "subs x24, x24, #0x4\n" - "sub %x[N], %x[N], x26, LSL #2\n" - "bgt 4b\n" - "28:" // Exit - ".inst 0xd503467f // SMSTOP\n" - : [N] "+&r"(N) - : [A_ptr] "r"(A_ptr), [B_ptr] "r"(B_ptr), [K] "r"(K), [args_ptr] "r"(&ka), [flags] "r"(flags), - [offset_max] "I"(offsetof(KernelArgs, maxval)), [offset_min] "I"(offsetof(KernelArgs, minval)), - [output_ptr] "r"(output_ptr) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x9", "z0", "z1", "z10", "z11", "z12", - "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", - "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h index 7426f09e..346f40a7 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,8 @@ #include +#include "kai/kai_common.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -55,15 +57,15 @@ size_t kai_get_sr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(void); /// Gets the offset in bytes to the data element in the LHS matrix buffer. /// -/// @param[in] m_idx Row index. -/// @param[in] lhs_stride Row stride in bytes. +/// @param[in] m_idx Row index. This must be 0. +/// @param[in] k Columns of unpacked LHS. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(size_t m_idx, size_t lhs_stride); +size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(size_t m_idx, size_t k); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// -/// @param[in] n_idx Column index in the unpacked RHS matrix. +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of n_step /// @param[in] k Number of rows in the unpacked RHS matrix. /// /// @return The offset in bytes to the data element. @@ -71,8 +73,8 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_ml /// Gets the offset in bytes to the data element in the destination matrix buffer. /// -/// @param[in] m_idx Row index. -/// @param[in] n_idx Column index. +/// @param[in] m_idx Row index. Must be 0 +/// @param[in] n_idx Column index. Must be multiple of n_step /// @param[in] dst_stride Row stride in bytes. /// /// @return The offset in bytes to the data element. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla_asm.S new file mode 100644 index 00000000..44f1aeee --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla_asm.S @@ -0,0 +1,511 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla) + stp x20, x21, [sp, -80]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x8, #0x0 + ldr x11, [x0, #0x18] + cntw x10, ALL, MUL #4 + ptrue p1.b + ldr x9, [x0, #0x20] + KAI_ASM_INST(0x25207811) // ptrue pn9.b + mov x22, #0x1 + ldr x21, [x0, #0x10] + add x28, x11, x10 + ldr x20, [x0, #0x28] + sub x28, x28, #0x1 + ldr x27, [x0, #0x8] + udiv x28, x28, x10 + ldr x26, [x0, #0x30] + mov x25, x21 + add x21, x28, #0x3 + mov x24, x20 + and x21, x21, #0xfffffffffffffffc + mul x21, x21, x10 + mul x21, x21, x9 + lsl x21, x21, #0x2 +KAI_ASM_LABEL(label_1) // RHS size check loop + cmp x21, #0x200, LSL #12 + blt label_2 + tbnz x21, #0, label_3 + lsr x21, x21, #0x1 + lsl x22, x22, #0x1 + b label_1 +KAI_ASM_LABEL(label_2) // RHS do prefetch + lsl x20, x21, #0x26 + sub x22, x22, #0x1 + lsl x22, x22, #0x16 + orr x21, x21, x20 + orr x21, x21, x22 + KAI_ASM_INST(0xf8b54b3a) // rprfm pldonce, x21, [x25] +KAI_ASM_LABEL(label_3) // RHS prefetch exit +KAI_ASM_LABEL(label_4) // Column loop + cmp x28, #0x4 + bge label_22 + cmp x28, #0x2 + bgt label_16 + beq label_10 + KAI_ASM_INST(0xa040c734) // ld1w { z20.s-z23.s }, pn9.b/Z, [x25] + mov x23, x9 + mov x21, x11 + mov x22, x27 + lsl x20, x9, #0x2 + KAI_ASM_INST(0x25b567f0) // whilelt p8.s, XZR, x21, VLx4 + cmp x23, #0x4 + KAI_ASM_INST(0xf8b44ad8) // rprfm pldmany, x20, [x22] + KAI_ASM_INST(0xc0040e80) // mova za.d[x8, #0], { z20.d-z23.d } + addvl x25, x25, #16 + ble label_6 +KAI_ASM_LABEL(label_5) // Width 1: Multiply loop: Main loop head + whilelt p0.s, XZR, x23 + KAI_ASM_INST(0xa040c73d) // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x25] + addvl x25, x25, #16 + ld1rqw { z2.s }, p0/Z, [x22] + sub x23, x23, #0x4 + add x22, x22, #0x10 + KAI_ASM_INST(0xa040c731) // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x25] + addvl x25, x25, #16 + cmp x23, #0x4 + KAI_ASM_INST(0xa040c739) // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x25] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1528380) // fmla za.s[x8, 0], { z28.s-z31.s }, z2.s[0] + KAI_ASM_INST(0xa040c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1528600) // fmla za.s[x8, 0], { z16.s-z19.s }, z2.s[1] + KAI_ASM_INST(0xc1528b00) // fmla za.s[x8, 0], { z24.s-z27.s }, z2.s[2] + KAI_ASM_INST(0xc1528d80) // fmla za.s[x8, 0], { z12.s-z15.s }, z2.s[3] + bgt label_5 +KAI_ASM_LABEL(label_6) // Width 1: Multiply loop: Single iteration only + whilelt p0.s, XZR, x23 + KAI_ASM_INST(0xa040c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25] + subs x23, x23, #0x1 + ld1rqw { z3.s }, p0/Z, [x22] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538180) // fmla za.s[x8, 0], { z12.s-z15.s }, z3.s[0] + ble label_7 + KAI_ASM_INST(0xa040c725) // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x25] + subs x23, x23, #0x1 + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538480) // fmla za.s[x8, 0], { z4.s-z7.s }, z3.s[1] + ble label_7 + KAI_ASM_INST(0xa040c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25] + subs x23, x23, #0x1 + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538980) // fmla za.s[x8, 0], { z12.s-z15.s }, z3.s[2] + ble label_7 + KAI_ASM_INST(0xa040c729) // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xc1538d00) // fmla za.s[x8, 0], { z8.s-z11.s }, z3.s[3] +KAI_ASM_LABEL(label_7) // Width 1: Multiply loop: multiply skip + tbz x26, #1, label_8 + add x21, x0, #0x4 + add x20, x0, #0x0 + KAI_ASM_INST(0xc0060c00) // mova { z0.d-z3.d }, za.d[x8, #0] + ld1rw { z23.s }, p1/Z, [x21] + ld1rw { z22.s }, p1/Z, [x20] + KAI_ASM_INST(0xc1b6cae0) // fclamp { z0.s-z3.s }, z23.s, z22.s + KAI_ASM_INST(0xa060c300) // st1w { z0.s-z3.s }, p8, [x24] + b label_9 +KAI_ASM_LABEL(label_8) // Width 1: No activation + KAI_ASM_INST(0xc0060c00) // mova { z0.d-z3.d }, za.d[x8, #0] + KAI_ASM_INST(0xa060c300) // st1w { z0.s-z3.s }, p8, [x24] +KAI_ASM_LABEL(label_9) // Width 1: Output done + b label_28 +KAI_ASM_LABEL(label_10) // Width 2 + KAI_ASM_INST(0xa040c73c) // ld1w { z28.s-z31.s }, pn9.b/Z, [x25] + mov x23, x9 + sub x21, x11, x10 + KAI_ASM_INST(0xa041c724) // ld1w { z4.s-z7.s }, pn9.b/Z, [x25, #0x4, MUL VL] + mov x22, x27 + lsl x20, x9, #0x2 + KAI_ASM_INST(0x25b567f0) // whilelt p8.s, XZR, x21, VLx4 + cmp x23, #0x4 + KAI_ASM_INST(0xf8b44ad8) // rprfm pldmany, x20, [x22] + KAI_ASM_INST(0xc0040f80) // mova za.d[x8, #0], { z28.d-z31.d } + addvl x25, x25, #16 + KAI_ASM_INST(0xc0040c81) // mova za.d[x8, #1], { z4.d-z7.d } + ble label_12 +KAI_ASM_LABEL(label_11) // Width 2: Multiply loop: Main loop head + whilelt p0.s, XZR, x23 + KAI_ASM_INST(0xa040c73d) // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x25] + sub x23, x23, #0x4 + ld1rqw { z1.s }, p0/Z, [x22] + cmp x23, #0x4 + add x22, x22, #0x10 + KAI_ASM_INST(0xa041c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25, #0x4, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xa040c731) // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xc1518380) // fmla za.s[x8, 0], { z28.s-z31.s }, z1.s[0] + KAI_ASM_INST(0xa041c739) // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x25, #0x4, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1518181) // fmla za.s[x8, 1], { z12.s-z15.s }, z1.s[0] + KAI_ASM_INST(0xa040c73d) // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xa041c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25, #0x4, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xa040c729) // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xc1518600) // fmla za.s[x8, 0], { z16.s-z19.s }, z1.s[1] + KAI_ASM_INST(0xa041c735) // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x25, #0x4, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1518701) // fmla za.s[x8, 1], { z24.s-z27.s }, z1.s[1] + KAI_ASM_INST(0xc1518b80) // fmla za.s[x8, 0], { z28.s-z31.s }, z1.s[2] + KAI_ASM_INST(0xc1518981) // fmla za.s[x8, 1], { z12.s-z15.s }, z1.s[2] + KAI_ASM_INST(0xc1518d00) // fmla za.s[x8, 0], { z8.s-z11.s }, z1.s[3] + KAI_ASM_INST(0xc1518e81) // fmla za.s[x8, 1], { z20.s-z23.s }, z1.s[3] + bgt label_11 +KAI_ASM_LABEL(label_12) // Width 2: Multiply loop: Single iteration only + whilelt p0.s, XZR, x23 + KAI_ASM_INST(0xa040c731) // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x25] + subs x23, x23, #0x1 + ld1rqw { z3.s }, p0/Z, [x22] + KAI_ASM_INST(0xa041c73d) // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x25, #0x4, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538200) // fmla za.s[x8, 0], { z16.s-z19.s }, z3.s[0] + KAI_ASM_INST(0xc1538381) // fmla za.s[x8, 1], { z28.s-z31.s }, z3.s[0] + ble label_13 + KAI_ASM_INST(0xa040c735) // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x25] + subs x23, x23, #0x1 + KAI_ASM_INST(0xa041c731) // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x25, #0x4, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538680) // fmla za.s[x8, 0], { z20.s-z23.s }, z3.s[1] + KAI_ASM_INST(0xc1538601) // fmla za.s[x8, 1], { z16.s-z19.s }, z3.s[1] + ble label_13 + KAI_ASM_INST(0xa040c725) // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x25] + subs x23, x23, #0x1 + KAI_ASM_INST(0xa041c731) // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x25, #0x4, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538880) // fmla za.s[x8, 0], { z4.s-z7.s }, z3.s[2] + KAI_ASM_INST(0xc1538a01) // fmla za.s[x8, 1], { z16.s-z19.s }, z3.s[2] + ble label_13 + KAI_ASM_INST(0xa040c73d) // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xa041c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xc1538f80) // fmla za.s[x8, 0], { z28.s-z31.s }, z3.s[3] + KAI_ASM_INST(0xc1538d81) // fmla za.s[x8, 1], { z12.s-z15.s }, z3.s[3] +KAI_ASM_LABEL(label_13) // Width 2: Multiply loop: multiply skip + tbz x26, #1, label_14 + add x21, x0, #0x4 + add x20, x0, #0x0 + KAI_ASM_INST(0xc0060c04) // mova { z4.d-z7.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c28) // mova { z8.d-z11.d }, za.d[x8, #1] + ld1rw { z17.s }, p1/Z, [x21] + ld1rw { z23.s }, p1/Z, [x20] + KAI_ASM_INST(0xc1b7ca24) // fclamp { z4.s-z7.s }, z17.s, z23.s + KAI_ASM_INST(0xc1b7ca28) // fclamp { z8.s-z11.s }, z17.s, z23.s + KAI_ASM_INST(0xa060c704) // st1w { z4.s-z7.s }, pn9.b, [x24] + KAI_ASM_INST(0xa061c308) // st1w { z8.s-z11.s }, p8, [x24, #0x4, MUL VL] + b label_15 +KAI_ASM_LABEL(label_14) // Width 2: No activation + KAI_ASM_INST(0xc0060c08) // mova { z8.d-z11.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c30) // mova { z16.d-z19.d }, za.d[x8, #1] + KAI_ASM_INST(0xa060c708) // st1w { z8.s-z11.s }, pn9.b, [x24] + KAI_ASM_INST(0xa061c310) // st1w { z16.s-z19.s }, p8, [x24, #0x4, MUL VL] +KAI_ASM_LABEL(label_15) // Width 2: Output done + b label_28 +KAI_ASM_LABEL(label_16) // Width 3 + mov x20, #0x2 + KAI_ASM_INST(0xa040c728) // ld1w { z8.s-z11.s }, pn9.b/Z, [x25] + mov x23, x9 + KAI_ASM_INST(0xa041c720) // ld1w { z0.s-z3.s }, pn9.b/Z, [x25, #0x4, MUL VL] + msub x21, x10, x20, x11 + mov x22, x27 + KAI_ASM_INST(0xa042c724) // ld1w { z4.s-z7.s }, pn9.b/Z, [x25, #0x8, MUL VL] + lsl x20, x9, #0x2 + KAI_ASM_INST(0x25b567f0) // whilelt p8.s, XZR, x21, VLx4 + cmp x23, #0x4 + KAI_ASM_INST(0xf8b44ad8) // rprfm pldmany, x20, [x22] + KAI_ASM_INST(0xc0040d00) // mova za.d[x8, #0], { z8.d-z11.d } + KAI_ASM_INST(0xc0040c01) // mova za.d[x8, #1], { z0.d-z3.d } + addvl x25, x25, #16 + KAI_ASM_INST(0xc0040c82) // mova za.d[x8, #2], { z4.d-z7.d } + ble label_18 +KAI_ASM_LABEL(label_17) // Width 3: Multiply loop: Main loop head + whilelt p0.s, XZR, x23 + KAI_ASM_INST(0xa040c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25] + sub x23, x23, #0x4 + ld1rqw { z3.s }, p0/Z, [x22] + cmp x23, #0x4 + add x22, x22, #0x10 + KAI_ASM_INST(0xa041c729) // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xa042c725) // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x25, #0x8, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538180) // fmla za.s[x8, 0], { z12.s-z15.s }, z3.s[0] + KAI_ASM_INST(0xa040c731) // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xc1538101) // fmla za.s[x8, 1], { z8.s-z11.s }, z3.s[0] + KAI_ASM_INST(0xa041c735) // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xc1538082) // fmla za.s[x8, 2], { z4.s-z7.s }, z3.s[0] + KAI_ASM_INST(0xa042c739) // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x25, #0x8, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xa040c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xc1538600) // fmla za.s[x8, 0], { z16.s-z19.s }, z3.s[1] + KAI_ASM_INST(0xa041c73d) // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xc1538681) // fmla za.s[x8, 1], { z20.s-z23.s }, z3.s[1] + KAI_ASM_INST(0xa042c729) // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x25, #0x8, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538702) // fmla za.s[x8, 2], { z24.s-z27.s }, z3.s[1] + KAI_ASM_INST(0xa040c725) // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xa041c735) // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xc1538980) // fmla za.s[x8, 0], { z12.s-z15.s }, z3.s[2] + KAI_ASM_INST(0xa042c739) // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x25, #0x8, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538b81) // fmla za.s[x8, 1], { z28.s-z31.s }, z3.s[2] + KAI_ASM_INST(0xc1538902) // fmla za.s[x8, 2], { z8.s-z11.s }, z3.s[2] + KAI_ASM_INST(0xc1538c80) // fmla za.s[x8, 0], { z4.s-z7.s }, z3.s[3] + KAI_ASM_INST(0xc1538e81) // fmla za.s[x8, 1], { z20.s-z23.s }, z3.s[3] + KAI_ASM_INST(0xc1538f02) // fmla za.s[x8, 2], { z24.s-z27.s }, z3.s[3] + bgt label_17 +KAI_ASM_LABEL(label_18) // Width 3: Multiply loop: Single iteration only + whilelt p0.s, XZR, x23 + KAI_ASM_INST(0xa040c735) // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x25] + subs x23, x23, #0x1 + ld1rqw { z3.s }, p0/Z, [x22] + KAI_ASM_INST(0xa041c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xa042c725) // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x25, #0x8, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538280) // fmla za.s[x8, 0], { z20.s-z23.s }, z3.s[0] + KAI_ASM_INST(0xc1538181) // fmla za.s[x8, 1], { z12.s-z15.s }, z3.s[0] + KAI_ASM_INST(0xc1538082) // fmla za.s[x8, 2], { z4.s-z7.s }, z3.s[0] + ble label_19 + KAI_ASM_INST(0xa040c735) // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x25] + subs x23, x23, #0x1 + KAI_ASM_INST(0xa041c729) // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xa042c731) // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x25, #0x8, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538680) // fmla za.s[x8, 0], { z20.s-z23.s }, z3.s[1] + KAI_ASM_INST(0xc1538501) // fmla za.s[x8, 1], { z8.s-z11.s }, z3.s[1] + KAI_ASM_INST(0xc1538602) // fmla za.s[x8, 2], { z16.s-z19.s }, z3.s[1] + ble label_19 + KAI_ASM_INST(0xa040c73d) // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x25] + subs x23, x23, #0x1 + KAI_ASM_INST(0xa041c739) // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xa042c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25, #0x8, MUL VL] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538b80) // fmla za.s[x8, 0], { z28.s-z31.s }, z3.s[2] + KAI_ASM_INST(0xc1538b01) // fmla za.s[x8, 1], { z24.s-z27.s }, z3.s[2] + KAI_ASM_INST(0xc1538982) // fmla za.s[x8, 2], { z12.s-z15.s }, z3.s[2] + ble label_19 + KAI_ASM_INST(0xa040c729) // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xa041c73d) // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xa042c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25, #0x8, MUL VL] + KAI_ASM_INST(0xc1538d00) // fmla za.s[x8, 0], { z8.s-z11.s }, z3.s[3] + KAI_ASM_INST(0xc1538f81) // fmla za.s[x8, 1], { z28.s-z31.s }, z3.s[3] + KAI_ASM_INST(0xc1538d82) // fmla za.s[x8, 2], { z12.s-z15.s }, z3.s[3] +KAI_ASM_LABEL(label_19) // Width 3: Multiply loop: multiply skip + tbz x26, #1, label_20 + add x21, x0, #0x4 + add x20, x0, #0x0 + KAI_ASM_INST(0xc0060c08) // mova { z8.d-z11.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c2c) // mova { z12.d-z15.d }, za.d[x8, #1] + ld1rw { z21.s }, p1/Z, [x21] + KAI_ASM_INST(0xc0060c50) // mova { z16.d-z19.d }, za.d[x8, #2] + ld1rw { z20.s }, p1/Z, [x20] + KAI_ASM_INST(0xc1b4caa8) // fclamp { z8.s-z11.s }, z21.s, z20.s + KAI_ASM_INST(0xc1b4caac) // fclamp { z12.s-z15.s }, z21.s, z20.s + KAI_ASM_INST(0xc1b4cab0) // fclamp { z16.s-z19.s }, z21.s, z20.s + KAI_ASM_INST(0xa060c708) // st1w { z8.s-z11.s }, pn9.b, [x24] + KAI_ASM_INST(0xa061c70c) // st1w { z12.s-z15.s }, pn9.b, [x24, #0x4, MUL VL] + KAI_ASM_INST(0xa062c310) // st1w { z16.s-z19.s }, p8, [x24, #0x8, MUL VL] + b label_21 +KAI_ASM_LABEL(label_20) // Width 3: No activation + KAI_ASM_INST(0xc0060c04) // mova { z4.d-z7.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c2c) // mova { z12.d-z15.d }, za.d[x8, #1] + KAI_ASM_INST(0xc0060c5c) // mova { z28.d-z31.d }, za.d[x8, #2] + KAI_ASM_INST(0xa060c704) // st1w { z4.s-z7.s }, pn9.b, [x24] + KAI_ASM_INST(0xa061c70c) // st1w { z12.s-z15.s }, pn9.b, [x24, #0x4, MUL VL] + KAI_ASM_INST(0xa062c31c) // st1w { z28.s-z31.s }, p8, [x24, #0x8, MUL VL] +KAI_ASM_LABEL(label_21) // Width 3: Output done + b label_28 +KAI_ASM_LABEL(label_22) // Width 4 + mov x20, #0x3 + KAI_ASM_INST(0xa040c724) // ld1w { z4.s-z7.s }, pn9.b/Z, [x25] + mov x23, x9 + KAI_ASM_INST(0xa041c72c) // ld1w { z12.s-z15.s }, pn9.b/Z, [x25, #0x4, MUL VL] + msub x21, x10, x20, x11 + mov x22, x27 + KAI_ASM_INST(0xa042c73c) // ld1w { z28.s-z31.s }, pn9.b/Z, [x25, #0x8, MUL VL] + lsl x20, x9, #0x2 + KAI_ASM_INST(0x25b567f0) // whilelt p8.s, XZR, x21, VLx4 + KAI_ASM_INST(0xa043c730) // ld1w { z16.s-z19.s }, pn9.b/Z, [x25, #0xc, MUL VL] + cmp x23, #0x4 + KAI_ASM_INST(0xf8b44ad8) // rprfm pldmany, x20, [x22] + KAI_ASM_INST(0xc0040c80) // mova za.d[x8, #0], { z4.d-z7.d } + KAI_ASM_INST(0xc0040d81) // mova za.d[x8, #1], { z12.d-z15.d } + addvl x25, x25, #16 + KAI_ASM_INST(0xc0040f82) // mova za.d[x8, #2], { z28.d-z31.d } + KAI_ASM_INST(0xc0040e03) // mova za.d[x8, #3], { z16.d-z19.d } + ble label_24 +KAI_ASM_LABEL(label_23) // Width 4: Multiply loop: Main loop head + whilelt p0.s, XZR, x23 + KAI_ASM_INST(0xa040c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25] + sub x23, x23, #0x4 + ld1rqw { z3.s }, p0/Z, [x22] + cmp x23, #0x4 + add x22, x22, #0x10 + KAI_ASM_INST(0xa041c735) // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xa042c731) // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x25, #0x8, MUL VL] + KAI_ASM_INST(0xa043c729) // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x25, #0xc, MUL VL] + KAI_ASM_INST(0xc1538180) // fmla za.s[x8, 0], { z12.s-z15.s }, z3.s[0] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538281) // fmla za.s[x8, 1], { z20.s-z23.s }, z3.s[0] + KAI_ASM_INST(0xa040c739) // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xc1538202) // fmla za.s[x8, 2], { z16.s-z19.s }, z3.s[0] + KAI_ASM_INST(0xa041c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xc1538103) // fmla za.s[x8, 3], { z8.s-z11.s }, z3.s[0] + KAI_ASM_INST(0xa042c729) // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x25, #0x8, MUL VL] + KAI_ASM_INST(0xa043c725) // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x25, #0xc, MUL VL] + KAI_ASM_INST(0xc1538700) // fmla za.s[x8, 0], { z24.s-z27.s }, z3.s[1] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538581) // fmla za.s[x8, 1], { z12.s-z15.s }, z3.s[1] + KAI_ASM_INST(0xa040c739) // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xc1538502) // fmla za.s[x8, 2], { z8.s-z11.s }, z3.s[1] + KAI_ASM_INST(0xa041c729) // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xc1538483) // fmla za.s[x8, 3], { z4.s-z7.s }, z3.s[1] + KAI_ASM_INST(0xa042c731) // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x25, #0x8, MUL VL] + KAI_ASM_INST(0xa043c725) // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x25, #0xc, MUL VL] + KAI_ASM_INST(0xc1538b00) // fmla za.s[x8, 0], { z24.s-z27.s }, z3.s[2] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538901) // fmla za.s[x8, 1], { z8.s-z11.s }, z3.s[2] + KAI_ASM_INST(0xa040c729) // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xc1538a02) // fmla za.s[x8, 2], { z16.s-z19.s }, z3.s[2] + KAI_ASM_INST(0xa041c73d) // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xc1538883) // fmla za.s[x8, 3], { z4.s-z7.s }, z3.s[2] + KAI_ASM_INST(0xa042c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25, #0x8, MUL VL] + KAI_ASM_INST(0xa043c735) // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x25, #0xc, MUL VL] + KAI_ASM_INST(0xc1538d00) // fmla za.s[x8, 0], { z8.s-z11.s }, z3.s[3] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538f81) // fmla za.s[x8, 1], { z28.s-z31.s }, z3.s[3] + KAI_ASM_INST(0xc1538d82) // fmla za.s[x8, 2], { z12.s-z15.s }, z3.s[3] + KAI_ASM_INST(0xc1538e83) // fmla za.s[x8, 3], { z20.s-z23.s }, z3.s[3] + bgt label_23 +KAI_ASM_LABEL(label_24) // Width 4: Multiply loop: Single iteration only + whilelt p0.s, XZR, x23 + KAI_ASM_INST(0xa040c731) // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x25] + subs x23, x23, #0x1 + ld1rqw { z3.s }, p0/Z, [x22] + KAI_ASM_INST(0xa041c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xa042c73d) // ldnt1w { z28.s-z31.s }, pn9.b/Z, [x25, #0x8, MUL VL] + KAI_ASM_INST(0xa043c735) // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x25, #0xc, MUL VL] + KAI_ASM_INST(0xc1538200) // fmla za.s[x8, 0], { z16.s-z19.s }, z3.s[0] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538181) // fmla za.s[x8, 1], { z12.s-z15.s }, z3.s[0] + KAI_ASM_INST(0xc1538382) // fmla za.s[x8, 2], { z28.s-z31.s }, z3.s[0] + KAI_ASM_INST(0xc1538283) // fmla za.s[x8, 3], { z20.s-z23.s }, z3.s[0] + ble label_25 + KAI_ASM_INST(0xa040c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25] + subs x23, x23, #0x1 + KAI_ASM_INST(0xa041c725) // ldnt1w { z4.s-z7.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xa042c739) // ldnt1w { z24.s-z27.s }, pn9.b/Z, [x25, #0x8, MUL VL] + KAI_ASM_INST(0xa043c735) // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x25, #0xc, MUL VL] + KAI_ASM_INST(0xc1538580) // fmla za.s[x8, 0], { z12.s-z15.s }, z3.s[1] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538481) // fmla za.s[x8, 1], { z4.s-z7.s }, z3.s[1] + KAI_ASM_INST(0xc1538702) // fmla za.s[x8, 2], { z24.s-z27.s }, z3.s[1] + KAI_ASM_INST(0xc1538683) // fmla za.s[x8, 3], { z20.s-z23.s }, z3.s[1] + ble label_25 + KAI_ASM_INST(0xa040c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25] + subs x23, x23, #0x1 + KAI_ASM_INST(0xa041c729) // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xa042c735) // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x25, #0x8, MUL VL] + KAI_ASM_INST(0xa043c731) // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x25, #0xc, MUL VL] + KAI_ASM_INST(0xc1538980) // fmla za.s[x8, 0], { z12.s-z15.s }, z3.s[2] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538901) // fmla za.s[x8, 1], { z8.s-z11.s }, z3.s[2] + KAI_ASM_INST(0xc1538a82) // fmla za.s[x8, 2], { z20.s-z23.s }, z3.s[2] + KAI_ASM_INST(0xc1538a03) // fmla za.s[x8, 3], { z16.s-z19.s }, z3.s[2] + ble label_25 + KAI_ASM_INST(0xa040c72d) // ldnt1w { z12.s-z15.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xa041c729) // ldnt1w { z8.s-z11.s }, pn9.b/Z, [x25, #0x4, MUL VL] + KAI_ASM_INST(0xa042c735) // ldnt1w { z20.s-z23.s }, pn9.b/Z, [x25, #0x8, MUL VL] + KAI_ASM_INST(0xa043c731) // ldnt1w { z16.s-z19.s }, pn9.b/Z, [x25, #0xc, MUL VL] + KAI_ASM_INST(0xc1538d80) // fmla za.s[x8, 0], { z12.s-z15.s }, z3.s[3] + addvl x25, x25, #16 + KAI_ASM_INST(0xc1538d01) // fmla za.s[x8, 1], { z8.s-z11.s }, z3.s[3] + KAI_ASM_INST(0xc1538e82) // fmla za.s[x8, 2], { z20.s-z23.s }, z3.s[3] + KAI_ASM_INST(0xc1538e03) // fmla za.s[x8, 3], { z16.s-z19.s }, z3.s[3] +KAI_ASM_LABEL(label_25) // Width 4: Multiply loop: multiply skip + tbz x26, #1, label_26 + add x21, x0, #0x4 + add x20, x0, #0x0 + KAI_ASM_INST(0xc0060c04) // mova { z4.d-z7.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c20) // mova { z0.d-z3.d }, za.d[x8, #1] + ld1rw { z21.s }, p1/Z, [x21] + KAI_ASM_INST(0xc0060c4c) // mova { z12.d-z15.d }, za.d[x8, #2] + ld1rw { z20.s }, p1/Z, [x20] + KAI_ASM_INST(0xc0060c70) // mova { z16.d-z19.d }, za.d[x8, #3] + KAI_ASM_INST(0xc1b4caa4) // fclamp { z4.s-z7.s }, z21.s, z20.s + KAI_ASM_INST(0xc1b4caa0) // fclamp { z0.s-z3.s }, z21.s, z20.s + KAI_ASM_INST(0xc1b4caac) // fclamp { z12.s-z15.s }, z21.s, z20.s + KAI_ASM_INST(0xc1b4cab0) // fclamp { z16.s-z19.s }, z21.s, z20.s + KAI_ASM_INST(0xa060c704) // st1w { z4.s-z7.s }, pn9.b, [x24] + KAI_ASM_INST(0xa061c700) // st1w { z0.s-z3.s }, pn9.b, [x24, #0x4, MUL VL] + KAI_ASM_INST(0xa062c70c) // st1w { z12.s-z15.s }, pn9.b, [x24, #0x8, MUL VL] + KAI_ASM_INST(0xa063c310) // st1w { z16.s-z19.s }, p8, [x24, #0xc, MUL VL] + addvl x24, x24, #16 + b label_27 +KAI_ASM_LABEL(label_26) // Width 4: No activation + KAI_ASM_INST(0xc0060c0c) // mova { z12.d-z15.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c20) // mova { z0.d-z3.d }, za.d[x8, #1] + KAI_ASM_INST(0xc0060c50) // mova { z16.d-z19.d }, za.d[x8, #2] + KAI_ASM_INST(0xc0060c64) // mova { z4.d-z7.d }, za.d[x8, #3] + KAI_ASM_INST(0xa060c70c) // st1w { z12.s-z15.s }, pn9.b, [x24] + KAI_ASM_INST(0xa061c700) // st1w { z0.s-z3.s }, pn9.b, [x24, #0x4, MUL VL] + KAI_ASM_INST(0xa062c710) // st1w { z16.s-z19.s }, pn9.b, [x24, #0x8, MUL VL] + KAI_ASM_INST(0xa063c304) // st1w { z4.s-z7.s }, p8, [x24, #0xc, MUL VL] + addvl x24, x24, #16 +KAI_ASM_LABEL(label_27) // Width 4: Output done + subs x28, x28, #0x4 + sub x11, x11, x10, LSL #2 + bgt label_4 +KAI_ASM_LABEL(label_28) // Exit + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c index cf060840..747fc276 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c @@ -7,7 +7,7 @@ // Do not flag up inline assembly blocks #pragma GCC diagnostic ignored "-Woverlength-strings" -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. @@ -18,22 +18,35 @@ #include "kai/kai_common.h" +typedef struct { + float maxval; + float minval; + const void* A_ptr; + const void* B_ptr; + size_t N; + size_t K; + void* output_ptr; + uint64_t flags; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(KernelArgs* args_ptr); + static const size_t kai_m_step = 1; -static const size_t kai_n_step = 16; static const size_t kai_nr = 2; +static const size_t kai_n_step = 16; static const size_t kai_kr = 1; static const size_t kai_sr = 1; size_t kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(void) { - return kai_m_step * kai_get_sme_vector_length_u32(); + return kai_m_step; } size_t kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(void) { - return kai_n_step * kai_get_sme_vector_length_u32(); + return kai_n_step * kai_get_sme_vector_length_u32() / kai_kr; } size_t kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_sme_vector_length_u32() / kai_kr; } size_t kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(void) { @@ -44,20 +57,27 @@ size_t kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(void) { return kai_sr; } -size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(size_t m_idx, size_t lhs_stride) { - KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla() == 0); +size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx == 0); - return m_idx * lhs_stride; + return m_idx * k; +} + +static size_t kai_get_rhs_packed_stride_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(size_t k) { + return kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla() * + (kai_roundup(k, kai_kr) * sizeof(float) + sizeof(float)); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla() == 0); - return n_idx * (k * sizeof(float) + sizeof(float)); + + const size_t block_idx = n_idx / kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(); + return block_idx * kai_get_rhs_packed_stride_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(k); } size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla() == 0); + KAI_ASSUME(m_idx == 0); KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla() == 0); return (m_idx * dst_stride) + (n_idx * sizeof(float)); @@ -70,739 +90,27 @@ size_t kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(size_t m void kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla( size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) { - KAI_UNUSED(lhs_stride); KAI_UNUSED(dst_stride_row); KAI_UNUSED(dst_stride_col); - KAI_ASSUME(m == 1); - - typedef struct { - float maxval; - float minval; - } KernelArgs; + KAI_UNUSED(lhs_stride); - KernelArgs ka; - ka.maxval = clamp_max; - ka.minval = clamp_min; + KAI_ASSUME(m == 1); - size_t N = n; - size_t K = k; + uint64_t flags = 2; - const void* A_ptr = lhs; - const void* B_ptr = rhs_packed; - void* output_ptr = dst; + KernelArgs args; - uint64_t flags = 2; + args.maxval = clamp_max; + args.minval = clamp_min; + args.A_ptr = lhs; + args.B_ptr = rhs_packed; + args.N = n; + args.K = k; + args.output_ptr = dst; + args.flags = flags; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x8, #0x0\n" - "mov x16, %x[B_ptr]\n" - "cntw x15, ALL, MUL #4\n" - "mov x14, %x[output_ptr]\n" - "add x13, %x[N], x15\n" - "ptrue p1.b\n" - "sub x13, x13, #0x1\n" - ".inst 0x25207811 // ptrue pn9.b\n" - "udiv x13, x13, x15\n" - "mov x22, #0x1\n" - "add x21, x13, #0x3\n" - "and x21, x21, #0xfffffffffffffffc\n" - "mul x21, x21, x15\n" - "mul x21, x21, %x[K]\n" - "lsl x21, x21, #0x2\n" - "1:" // RHS size check loop - "cmp x21, #0x200000\n" - "blt 2f\n" - "tbnz x21, #0, 3f\n" - "lsr x21, x21, #0x1\n" - "lsl x22, x22, #0x1\n" - "b 1b\n" - "2:" // RHS do prefetch - "lsl x20, x21, #0x26\n" - "sub x22, x22, #0x1\n" - "lsl x22, x22, #0x16\n" - "orr x21, x21, x20\n" - "orr x21, x21, x22\n" - ".inst 0xf8b54a1a // rprfm pldonce, x21, [x16]\n" - "3:" // RHS prefetch exit - "mov x12, %x[K]\n" - "cntw x20, ALL, MUL #2\n" - "lsl x12, x12, #0x2\n" - "add x12, x12, #0x4\n" - "mul x12, x12, x20\n" - "4:" // Column loop - "cmp x13, #0x4\n" - "bge 22f\n" - "cmp x13, #0x2\n" - "bgt 16f\n" - "beq 10f\n" - "cntw x20, ALL, MUL #2\n" - "add x22, x16, x12\n" - ".inst 0xa0404614 // ld1w { z20.s-z21.s }, pn9.b/Z, [x16]\n" - "cmp %x[N], x20\n" - "mov x11, %x[K]\n" - "csel x22, x22, x16, GT\n" - "mov x21, %x[N]\n" - ".inst 0xa04046d6 // ld1w { z22.s-z23.s }, pn9.b/Z, [x22]\n" - "mov x10, %x[A_ptr]\n" - "lsl x20, %x[K], #0x2\n" - ".inst 0x25b567f0 // whilelt p8.s, XZR, x21, VLx4\n" - "cmp x11, #0x4\n" - ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" - "addvl x16, x16, #2\n" - "addvl x22, x22, #2\n" - ".inst 0xc0040e80 // mova za.d[x8, #0], { z20.d-z23.d }\n" - "ble 6f\n" - "5:" // Width 1: Multiply loop: Main loop head - "whilelt p0.s, XZR, x11\n" - ".inst 0xa0404605 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - "ld1rqw { z15.s }, p0/Z, [x10]\n" - "sub x11, x11, #0x4\n" - "add x10, x10, #0x10\n" - ".inst 0xa04046c7 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - "cmp x11, #0x4\n" - ".inst 0xa040461d // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046df // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc15f8080 // fmla za.s[x8, 0], { z4.s-z7.s }, z15.s[0]\n" - ".inst 0xa0404601 // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046c3 // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0404615 // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046d7 // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc15f8780 // fmla za.s[x8, 0], { z28.s-z31.s }, z15.s[1]\n" - ".inst 0xc15f8800 // fmla za.s[x8, 0], { z0.s-z3.s }, z15.s[2]\n" - ".inst 0xc15f8e80 // fmla za.s[x8, 0], { z20.s-z23.s }, z15.s[3]\n" - "bgt 5b\n" - "6:" // Width 1: Multiply loop: Single iteration only - "whilelt p0.s, XZR, x11\n" - ".inst 0xa0404601 // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x1\n" - "ld1rqw { z8.s }, p0/Z, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046c3 // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1588000 // fmla za.s[x8, 0], { z0.s-z3.s }, z8.s[0]\n" - "ble 7f\n" - ".inst 0xa0404611 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x1\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046d3 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1588600 // fmla za.s[x8, 0], { z16.s-z19.s }, z8.s[1]\n" - "ble 7f\n" - ".inst 0xa0404615 // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x1\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046d7 // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1588a80 // fmla za.s[x8, 0], { z20.s-z23.s }, z8.s[2]\n" - "ble 7f\n" - ".inst 0xa040460d // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x16]\n" - ".inst 0xa04046cf // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x22]\n" - ".inst 0xc1588d80 // fmla za.s[x8, 0], { z12.s-z15.s }, z8.s[3]\n" - "7:" // Width 1: Multiply loop: multiply skip - "tbz %x[flags], #1, 8f\n" - "add x21, %x[args_ptr], %[offset_min]\n" - "add x20, %x[args_ptr], %[offset_max]\n" - ".inst 0xc0060c08 // mova { z8.d-z11.d }, za.d[x8, #0]\n" - "ld1rw { z21.s }, p1/Z, [x21]\n" - "ld1rw { z29.s }, p1/Z, [x20]\n" - ".inst 0xc1bdcaa8 // fclamp { z8.s-z11.s }, z21.s, z29.s\n" - ".inst 0xa060c1c8 // st1w { z8.s-z11.s }, p8, [x14]\n" - "b 9f\n" - "8:" // Width 1: No activation - ".inst 0xc0060c08 // mova { z8.d-z11.d }, za.d[x8, #0]\n" - ".inst 0xa060c1c8 // st1w { z8.s-z11.s }, p8, [x14]\n" - "9:" // Width 1: Output done - "b 28f\n" - "10:" // Width 2 - "add x24, x16, x12, LSL #1\n" - "cntw x20, ALL, MUL #6\n" - ".inst 0xa0404604 // ld1w { z4.s-z5.s }, pn9.b/Z, [x16]\n" - "add x23, x24, x12\n" - "cmp %x[N], x20\n" - ".inst 0xa0404700 // ld1w { z0.s-z1.s }, pn9.b/Z, [x24]\n" - "add x22, x16, x12\n" - "csel x23, x23, x16, GT\n" - ".inst 0xa04046c6 // ld1w { z6.s-z7.s }, pn9.b/Z, [x22]\n" - "mov x11, %x[K]\n" - "sub x21, %x[N], x15\n" - ".inst 0xa04046e2 // ld1w { z2.s-z3.s }, pn9.b/Z, [x23]\n" - "mov x10, %x[A_ptr]\n" - "lsl x20, %x[K], #0x2\n" - ".inst 0x25b567f0 // whilelt p8.s, XZR, x21, VLx4\n" - "cmp x11, #0x4\n" - ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xc0040c80 // mova za.d[x8, #0], { z4.d-z7.d }\n" - "addvl x22, x22, #2\n" - "addvl x24, x24, #2\n" - ".inst 0xc0040c01 // mova za.d[x8, #1], { z0.d-z3.d }\n" - "addvl x23, x23, #2\n" - "ble 12f\n" - "11:" // Width 2: Multiply loop: Main loop head - "whilelt p0.s, XZR, x11\n" - ".inst 0xa0404605 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - "ld1rqw { z0.s }, p0/Z, [x10]\n" - "sub x11, x11, #0x4\n" - "add x10, x10, #0x10\n" - ".inst 0xa04046c7 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - "cmp x11, #0x4\n" - ".inst 0xa0404715 // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04046f7 // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1508080 // fmla za.s[x8, 0], { z4.s-z7.s }, z0.s[0]\n" - ".inst 0xa0404619 // ldnt1w { z24.s-z25.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046db // ldnt1w { z26.s-z27.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1508281 // fmla za.s[x8, 1], { z20.s-z23.s }, z0.s[0]\n" - ".inst 0xa0404709 // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04046eb // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1508700 // fmla za.s[x8, 0], { z24.s-z27.s }, z0.s[1]\n" - ".inst 0xa040461d // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046df // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1508501 // fmla za.s[x8, 1], { z8.s-z11.s }, z0.s[1]\n" - ".inst 0xa0404709 // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04046eb // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1508b80 // fmla za.s[x8, 0], { z28.s-z31.s }, z0.s[2]\n" - ".inst 0xa0404619 // ldnt1w { z24.s-z25.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046db // ldnt1w { z26.s-z27.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1508901 // fmla za.s[x8, 1], { z8.s-z11.s }, z0.s[2]\n" - ".inst 0xa040470d // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04046ef // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1508f00 // fmla za.s[x8, 0], { z24.s-z27.s }, z0.s[3]\n" - ".inst 0xc1508d81 // fmla za.s[x8, 1], { z12.s-z15.s }, z0.s[3]\n" - "bgt 11b\n" - "12:" // Width 2: Multiply loop: Single iteration only - "whilelt p0.s, XZR, x11\n" - ".inst 0xa0404605 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x1\n" - "ld1rqw { z8.s }, p0/Z, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046c7 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0404715 // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04046f7 // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1588080 // fmla za.s[x8, 0], { z4.s-z7.s }, z8.s[0]\n" - ".inst 0xc1588281 // fmla za.s[x8, 1], { z20.s-z23.s }, z8.s[0]\n" - "ble 13f\n" - ".inst 0xa040460d // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x1\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046cf // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa040471d // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04046ff // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1588580 // fmla za.s[x8, 0], { z12.s-z15.s }, z8.s[1]\n" - ".inst 0xc1588781 // fmla za.s[x8, 1], { z28.s-z31.s }, z8.s[1]\n" - "ble 13f\n" - ".inst 0xa040461d // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x1\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046df // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0404701 // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04046e3 // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1588b80 // fmla za.s[x8, 0], { z28.s-z31.s }, z8.s[2]\n" - ".inst 0xc1588801 // fmla za.s[x8, 1], { z0.s-z3.s }, z8.s[2]\n" - "ble 13f\n" - ".inst 0xa0404615 // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x16]\n" - ".inst 0xa04046d7 // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x22]\n" - ".inst 0xa0404701 // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x24]\n" - ".inst 0xa04046e3 // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x23]\n" - ".inst 0xc1588e80 // fmla za.s[x8, 0], { z20.s-z23.s }, z8.s[3]\n" - ".inst 0xc1588c01 // fmla za.s[x8, 1], { z0.s-z3.s }, z8.s[3]\n" - "13:" // Width 2: Multiply loop: multiply skip - "tbz %x[flags], #1, 14f\n" - "add x21, %x[args_ptr], %[offset_min]\n" - "add x20, %x[args_ptr], %[offset_max]\n" - ".inst 0xc0060c1c // mova { z28.d-z31.d }, za.d[x8, #0]\n" - ".inst 0xc0060c24 // mova { z4.d-z7.d }, za.d[x8, #1]\n" - "ld1rw { z17.s }, p1/Z, [x21]\n" - "ld1rw { z9.s }, p1/Z, [x20]\n" - ".inst 0xc1a9ca3c // fclamp { z28.s-z31.s }, z17.s, z9.s\n" - ".inst 0xc1a9ca24 // fclamp { z4.s-z7.s }, z17.s, z9.s\n" - ".inst 0xa060c5dc // st1w { z28.s-z31.s }, pn9.b, [x14]\n" - ".inst 0xa061c1c4 // st1w { z4.s-z7.s }, p8, [x14, #0x4, MUL VL]\n" - "b 15f\n" - "14:" // Width 2: No activation - ".inst 0xc0060c04 // mova { z4.d-z7.d }, za.d[x8, #0]\n" - ".inst 0xc0060c3c // mova { z28.d-z31.d }, za.d[x8, #1]\n" - ".inst 0xa060c5c4 // st1w { z4.s-z7.s }, pn9.b, [x14]\n" - ".inst 0xa061c1dc // st1w { z28.s-z31.s }, p8, [x14, #0x4, MUL VL]\n" - "15:" // Width 2: Output done - "b 28f\n" - "16:" // Width 3 - "add x26, x16, x12, LSL #2\n" - "cntw x20, ALL, MUL #10\n" - ".inst 0xa0404614 // ld1w { z20.s-z21.s }, pn9.b/Z, [x16]\n" - "add x25, x16, x12, LSL #1\n" - "add x24, x26, x12\n" - ".inst 0xa0404740 // ld1w { z0.s-z1.s }, pn9.b/Z, [x26]\n" - "cmp %x[N], x20\n" - "add x23, x16, x12\n" - ".inst 0xa0404730 // ld1w { z16.s-z17.s }, pn9.b/Z, [x25]\n" - "add x22, x25, x12\n" - "csel x24, x24, x16, GT\n" - ".inst 0xa04046f6 // ld1w { z22.s-z23.s }, pn9.b/Z, [x23]\n" - "mov x20, #0x2\n" - ".inst 0xa04046d2 // ld1w { z18.s-z19.s }, pn9.b/Z, [x22]\n" - "mov x11, %x[K]\n" - ".inst 0xa0404702 // ld1w { z2.s-z3.s }, pn9.b/Z, [x24]\n" - "msub x21, x15, x20, %x[N]\n" - "mov x10, %x[A_ptr]\n" - "lsl x20, %x[K], #0x2\n" - ".inst 0x25b567f0 // whilelt p8.s, XZR, x21, VLx4\n" - ".inst 0xc0040e80 // mova za.d[x8, #0], { z20.d-z23.d }\n" - "cmp x11, #0x4\n" - ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" - ".inst 0xc0040e01 // mova za.d[x8, #1], { z16.d-z19.d }\n" - "addvl x16, x16, #2\n" - "addvl x23, x23, #2\n" - ".inst 0xc0040c02 // mova za.d[x8, #2], { z0.d-z3.d }\n" - "addvl x25, x25, #2\n" - "addvl x22, x22, #2\n" - "addvl x26, x26, #2\n" - "addvl x24, x24, #2\n" - "ble 18f\n" - "17:" // Width 3: Multiply loop: Main loop head - "whilelt p0.s, XZR, x11\n" - ".inst 0xa040460d // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - "ld1rqw { z3.s }, p0/Z, [x10]\n" - "sub x11, x11, #0x4\n" - "add x10, x10, #0x10\n" - ".inst 0xa04046ef // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - "cmp x11, #0x4\n" - ".inst 0xa0404729 // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04046cb // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0404751 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x26]\n" - ".inst 0xc1538180 // fmla za.s[x8, 0], { z12.s-z15.s }, z3.s[0]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0404713 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1538101 // fmla za.s[x8, 1], { z8.s-z11.s }, z3.s[0]\n" - ".inst 0xa0404609 // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046eb // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1538202 // fmla za.s[x8, 2], { z16.s-z19.s }, z3.s[0]\n" - ".inst 0xa0404731 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04046d3 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0404745 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x26]\n" - ".inst 0xc1538500 // fmla za.s[x8, 0], { z8.s-z11.s }, z3.s[1]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0404707 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1538601 // fmla za.s[x8, 1], { z16.s-z19.s }, z3.s[1]\n" - ".inst 0xa0404609 // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046eb // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1538482 // fmla za.s[x8, 2], { z4.s-z7.s }, z3.s[1]\n" - ".inst 0xa0404731 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04046d3 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0404745 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x26]\n" - ".inst 0xc1538900 // fmla za.s[x8, 0], { z8.s-z11.s }, z3.s[2]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0404707 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1538a01 // fmla za.s[x8, 1], { z16.s-z19.s }, z3.s[2]\n" - ".inst 0xa0404615 // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046f7 // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1538882 // fmla za.s[x8, 2], { z4.s-z7.s }, z3.s[2]\n" - ".inst 0xa0404739 // ldnt1w { z24.s-z25.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04046db // ldnt1w { z26.s-z27.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0404751 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x26]\n" - ".inst 0xc1538e80 // fmla za.s[x8, 0], { z20.s-z23.s }, z3.s[3]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0404713 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1538f01 // fmla za.s[x8, 1], { z24.s-z27.s }, z3.s[3]\n" - ".inst 0xc1538e02 // fmla za.s[x8, 2], { z16.s-z19.s }, z3.s[3]\n" - "bgt 17b\n" - "18:" // Width 3: Multiply loop: Single iteration only - "whilelt p0.s, XZR, x11\n" - ".inst 0xa0404605 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x1\n" - "ld1rqw { z8.s }, p0/Z, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046e7 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa040473d // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04046df // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0404755 // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x26]\n" - ".inst 0xc1588080 // fmla za.s[x8, 0], { z4.s-z7.s }, z8.s[0]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0404717 // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1588381 // fmla za.s[x8, 1], { z28.s-z31.s }, z8.s[0]\n" - ".inst 0xc1588282 // fmla za.s[x8, 2], { z20.s-z23.s }, z8.s[0]\n" - "ble 19f\n" - ".inst 0xa040460d // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x1\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046ef // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0404725 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04046c7 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0404751 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x26]\n" - ".inst 0xc1588580 // fmla za.s[x8, 0], { z12.s-z15.s }, z8.s[1]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0404713 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1588481 // fmla za.s[x8, 1], { z4.s-z7.s }, z8.s[1]\n" - ".inst 0xc1588602 // fmla za.s[x8, 2], { z16.s-z19.s }, z8.s[1]\n" - "ble 19f\n" - ".inst 0xa0404601 // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x1\n" - "addvl x16, x16, #2\n" - ".inst 0xa04046e3 // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa040472d // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04046cf // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0404751 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x26]\n" - ".inst 0xc1588800 // fmla za.s[x8, 0], { z0.s-z3.s }, z8.s[2]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0404713 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1588981 // fmla za.s[x8, 1], { z12.s-z15.s }, z8.s[2]\n" - ".inst 0xc1588a02 // fmla za.s[x8, 2], { z16.s-z19.s }, z8.s[2]\n" - "ble 19f\n" - ".inst 0xa0404605 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x16]\n" - ".inst 0xa04046e7 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x23]\n" - ".inst 0xa040472d // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x25]\n" - ".inst 0xa04046cf // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x22]\n" - ".inst 0xa0404755 // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x26]\n" - ".inst 0xc1588c80 // fmla za.s[x8, 0], { z4.s-z7.s }, z8.s[3]\n" - ".inst 0xa0404717 // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x24]\n" - ".inst 0xc1588d81 // fmla za.s[x8, 1], { z12.s-z15.s }, z8.s[3]\n" - ".inst 0xc1588e82 // fmla za.s[x8, 2], { z20.s-z23.s }, z8.s[3]\n" - "19:" // Width 3: Multiply loop: multiply skip - "tbz %x[flags], #1, 20f\n" - "add x21, %x[args_ptr], %[offset_min]\n" - "add x20, %x[args_ptr], %[offset_max]\n" - ".inst 0xc0060c08 // mova { z8.d-z11.d }, za.d[x8, #0]\n" - ".inst 0xc0060c2c // mova { z12.d-z15.d }, za.d[x8, #1]\n" - "ld1rw { z21.s }, p1/Z, [x21]\n" - ".inst 0xc0060c50 // mova { z16.d-z19.d }, za.d[x8, #2]\n" - "ld1rw { z20.s }, p1/Z, [x20]\n" - ".inst 0xc1b4caa8 // fclamp { z8.s-z11.s }, z21.s, z20.s\n" - ".inst 0xc1b4caac // fclamp { z12.s-z15.s }, z21.s, z20.s\n" - ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" - ".inst 0xa060c5c8 // st1w { z8.s-z11.s }, pn9.b, [x14]\n" - ".inst 0xa061c5cc // st1w { z12.s-z15.s }, pn9.b, [x14, #0x4, MUL VL]\n" - ".inst 0xa062c1d0 // st1w { z16.s-z19.s }, p8, [x14, #0x8, MUL VL]\n" - "b 21f\n" - "20:" // Width 3: No activation - ".inst 0xc0060c04 // mova { z4.d-z7.d }, za.d[x8, #0]\n" - ".inst 0xc0060c20 // mova { z0.d-z3.d }, za.d[x8, #1]\n" - ".inst 0xc0060c50 // mova { z16.d-z19.d }, za.d[x8, #2]\n" - ".inst 0xa060c5c4 // st1w { z4.s-z7.s }, pn9.b, [x14]\n" - ".inst 0xa061c5c0 // st1w { z0.s-z3.s }, pn9.b, [x14, #0x4, MUL VL]\n" - ".inst 0xa062c1d0 // st1w { z16.s-z19.s }, p8, [x14, #0x8, MUL VL]\n" - "21:" // Width 3: Output done - "b 28f\n" - "22:" // Width 4 - "add x9, x16, x12, LSL #2\n" - "cntw x20, ALL, MUL #14\n" - ".inst 0xa040460c // ld1w { z12.s-z13.s }, pn9.b/Z, [x16]\n" - "add x28, x9, x12, LSL #1\n" - "add x27, x16, x12, LSL #1\n" - ".inst 0xa0404528 // ld1w { z8.s-z9.s }, pn9.b/Z, [x9]\n" - "add x26, x28, x12\n" - "cmp %x[N], x20\n" - ".inst 0xa0404760 // ld1w { z0.s-z1.s }, pn9.b/Z, [x27]\n" - "add x25, x16, x12\n" - "add x24, x27, x12\n" - ".inst 0xa0404790 // ld1w { z16.s-z17.s }, pn9.b/Z, [x28]\n" - "add x23, x9, x12\n" - "csel x26, x26, x16, GT\n" - ".inst 0xa040472e // ld1w { z14.s-z15.s }, pn9.b/Z, [x25]\n" - "mov x20, #0x3\n" - ".inst 0xa0404702 // ld1w { z2.s-z3.s }, pn9.b/Z, [x24]\n" - "mov x11, %x[K]\n" - ".inst 0xa04046ea // ld1w { z10.s-z11.s }, pn9.b/Z, [x23]\n" - "msub x21, x15, x20, %x[N]\n" - "mov x10, %x[A_ptr]\n" - ".inst 0xa0404752 // ld1w { z18.s-z19.s }, pn9.b/Z, [x26]\n" - "lsl x20, %x[K], #0x2\n" - ".inst 0x25b567f0 // whilelt p8.s, XZR, x21, VLx4\n" - ".inst 0xc0040d80 // mova za.d[x8, #0], { z12.d-z15.d }\n" - "cmp x11, #0x4\n" - ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" - ".inst 0xc0040c01 // mova za.d[x8, #1], { z0.d-z3.d }\n" - "add x22, x16, x12, LSL #3\n" - "addvl x16, x16, #2\n" - ".inst 0xc0040d02 // mova za.d[x8, #2], { z8.d-z11.d }\n" - "addvl x25, x25, #2\n" - "addvl x27, x27, #2\n" - ".inst 0xc0040e03 // mova za.d[x8, #3], { z16.d-z19.d }\n" - "addvl x24, x24, #2\n" - "addvl x9, x9, #2\n" - "addvl x23, x23, #2\n" - "addvl x28, x28, #2\n" - "addvl x26, x26, #2\n" - "ble 24f\n" - "23:" // Width 4: Multiply loop: Main loop head - "whilelt p0.s, XZR, x11\n" - ".inst 0xa0404609 // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - "ld1rqw { z13.s }, p0/Z, [x10]\n" - "sub x11, x11, #0x4\n" - "add x10, x10, #0x10\n" - ".inst 0xa040472b // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - "cmp x11, #0x4\n" - ".inst 0xa0404765 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0404707 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0404531 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x9]\n" - ".inst 0xc15d8100 // fmla za.s[x8, 0], { z8.s-z11.s }, z13.s[0]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04046f3 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0404781 // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x28]\n" - ".inst 0xc15d8081 // fmla za.s[x8, 1], { z4.s-z7.s }, z13.s[0]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0404743 // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc15d8202 // fmla za.s[x8, 2], { z16.s-z19.s }, z13.s[0]\n" - ".inst 0xa040461d // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa040473f // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xc15d8003 // fmla za.s[x8, 3], { z0.s-z3.s }, z13.s[0]\n" - ".inst 0xa0404761 // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0404703 // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0404529 // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x9]\n" - ".inst 0xc15d8780 // fmla za.s[x8, 0], { z28.s-z31.s }, z13.s[1]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04046eb // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0404791 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x28]\n" - ".inst 0xc15d8401 // fmla za.s[x8, 1], { z0.s-z3.s }, z13.s[1]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0404753 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc15d8502 // fmla za.s[x8, 2], { z8.s-z11.s }, z13.s[1]\n" - ".inst 0xa0404605 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa0404727 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xc15d8603 // fmla za.s[x8, 3], { z16.s-z19.s }, z13.s[1]\n" - ".inst 0xa0404761 // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0404703 // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0404529 // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x9]\n" - ".inst 0xc15d8880 // fmla za.s[x8, 0], { z4.s-z7.s }, z13.s[2]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04046eb // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0404791 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x28]\n" - ".inst 0xc15d8801 // fmla za.s[x8, 1], { z0.s-z3.s }, z13.s[2]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0404753 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc15d8902 // fmla za.s[x8, 2], { z8.s-z11.s }, z13.s[2]\n" - ".inst 0xa0404615 // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa0404737 // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xc15d8a03 // fmla za.s[x8, 3], { z16.s-z19.s }, z13.s[2]\n" - ".inst 0xa0404779 // ldnt1w { z24.s-z25.s }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa040471b // ldnt1w { z26.s-z27.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0404529 // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x9]\n" - ".inst 0xc15d8e80 // fmla za.s[x8, 0], { z20.s-z23.s }, z13.s[3]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04046eb // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0404795 // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x28]\n" - ".inst 0xc15d8f01 // fmla za.s[x8, 1], { z24.s-z27.s }, z13.s[3]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0404757 // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc15d8d02 // fmla za.s[x8, 2], { z8.s-z11.s }, z13.s[3]\n" - ".inst 0xc15d8e83 // fmla za.s[x8, 3], { z20.s-z23.s }, z13.s[3]\n" - "bgt 23b\n" - "24:" // Width 4: Multiply loop: Single iteration only - "whilelt p0.s, XZR, x11\n" - ".inst 0xa0404605 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x1\n" - "ld1rqw { z8.s }, p0/Z, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xa0404727 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa0404761 // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0404703 // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa040452d // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x9]\n" - ".inst 0xc1588080 // fmla za.s[x8, 0], { z4.s-z7.s }, z8.s[0]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04046ef // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0404791 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x28]\n" - ".inst 0xc1588001 // fmla za.s[x8, 1], { z0.s-z3.s }, z8.s[0]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0404753 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc1588182 // fmla za.s[x8, 2], { z12.s-z15.s }, z8.s[0]\n" - ".inst 0xc1588203 // fmla za.s[x8, 3], { z16.s-z19.s }, z8.s[0]\n" - "ble 25f\n" - ".inst 0xa040461d // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x1\n" - "addvl x16, x16, #2\n" - ".inst 0xa040473f // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa0404761 // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0404703 // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0404525 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x9]\n" - ".inst 0xc1588780 // fmla za.s[x8, 0], { z28.s-z31.s }, z8.s[1]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04046e7 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0404791 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x28]\n" - ".inst 0xc1588401 // fmla za.s[x8, 1], { z0.s-z3.s }, z8.s[1]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0404753 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc1588482 // fmla za.s[x8, 2], { z4.s-z7.s }, z8.s[1]\n" - ".inst 0xc1588603 // fmla za.s[x8, 3], { z16.s-z19.s }, z8.s[1]\n" - "ble 25f\n" - ".inst 0xa040461d // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x1\n" - "addvl x16, x16, #2\n" - ".inst 0xa040473f // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa040476d // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa040470f // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0404521 // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x9]\n" - ".inst 0xc1588b80 // fmla za.s[x8, 0], { z28.s-z31.s }, z8.s[2]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04046e3 // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0404791 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x28]\n" - ".inst 0xc1588981 // fmla za.s[x8, 1], { z12.s-z15.s }, z8.s[2]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0404753 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc1588802 // fmla za.s[x8, 2], { z0.s-z3.s }, z8.s[2]\n" - ".inst 0xc1588a03 // fmla za.s[x8, 3], { z16.s-z19.s }, z8.s[2]\n" - "ble 25f\n" - ".inst 0xa0404605 // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x16]\n" - ".inst 0xa0404727 // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x25]\n" - ".inst 0xa040476d // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x27]\n" - ".inst 0xa040470f // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x24]\n" - ".inst 0xa0404535 // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x9]\n" - ".inst 0xc1588c80 // fmla za.s[x8, 0], { z4.s-z7.s }, z8.s[3]\n" - ".inst 0xa04046f7 // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x23]\n" - ".inst 0xa0404791 // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x28]\n" - ".inst 0xc1588d81 // fmla za.s[x8, 1], { z12.s-z15.s }, z8.s[3]\n" - ".inst 0xa0404753 // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x26]\n" - ".inst 0xc1588e82 // fmla za.s[x8, 2], { z20.s-z23.s }, z8.s[3]\n" - ".inst 0xc1588e03 // fmla za.s[x8, 3], { z16.s-z19.s }, z8.s[3]\n" - "25:" // Width 4: Multiply loop: multiply skip - "tbz %x[flags], #1, 26f\n" - "add x21, %x[args_ptr], %[offset_min]\n" - "add x20, %x[args_ptr], %[offset_max]\n" - ".inst 0xc0060c04 // mova { z4.d-z7.d }, za.d[x8, #0]\n" - ".inst 0xc0060c20 // mova { z0.d-z3.d }, za.d[x8, #1]\n" - "ld1rw { z21.s }, p1/Z, [x21]\n" - ".inst 0xc0060c4c // mova { z12.d-z15.d }, za.d[x8, #2]\n" - "ld1rw { z20.s }, p1/Z, [x20]\n" - ".inst 0xc0060c70 // mova { z16.d-z19.d }, za.d[x8, #3]\n" - ".inst 0xc1b4caa4 // fclamp { z4.s-z7.s }, z21.s, z20.s\n" - ".inst 0xc1b4caa0 // fclamp { z0.s-z3.s }, z21.s, z20.s\n" - ".inst 0xc1b4caac // fclamp { z12.s-z15.s }, z21.s, z20.s\n" - ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" - ".inst 0xa060c5c4 // st1w { z4.s-z7.s }, pn9.b, [x14]\n" - ".inst 0xa061c5c0 // st1w { z0.s-z3.s }, pn9.b, [x14, #0x4, MUL VL]\n" - ".inst 0xa062c5cc // st1w { z12.s-z15.s }, pn9.b, [x14, #0x8, MUL VL]\n" - ".inst 0xa063c1d0 // st1w { z16.s-z19.s }, p8, [x14, #0xc, MUL VL]\n" - "addvl x14, x14, #16\n" - "b 27f\n" - "26:" // Width 4: No activation - ".inst 0xc0060c0c // mova { z12.d-z15.d }, za.d[x8, #0]\n" - ".inst 0xc0060c20 // mova { z0.d-z3.d }, za.d[x8, #1]\n" - ".inst 0xc0060c50 // mova { z16.d-z19.d }, za.d[x8, #2]\n" - ".inst 0xc0060c64 // mova { z4.d-z7.d }, za.d[x8, #3]\n" - ".inst 0xa060c5cc // st1w { z12.s-z15.s }, pn9.b, [x14]\n" - ".inst 0xa061c5c0 // st1w { z0.s-z3.s }, pn9.b, [x14, #0x4, MUL VL]\n" - ".inst 0xa062c5d0 // st1w { z16.s-z19.s }, pn9.b, [x14, #0x8, MUL VL]\n" - ".inst 0xa063c1c4 // st1w { z4.s-z7.s }, p8, [x14, #0xc, MUL VL]\n" - "addvl x14, x14, #16\n" - "27:" // Width 4: Output done - "subs x13, x13, #0x4\n" - "mov x16, x22\n" - "sub %x[N], %x[N], x15, LSL #2\n" - "bgt 4b\n" - "28:" // Exit - ".inst 0xd503467f // SMSTOP\n" - : [N] "+&r"(N) - : [A_ptr] "r"(A_ptr), [B_ptr] "r"(B_ptr), [K] "r"(K), [args_ptr] "r"(&ka), [flags] "r"(flags), - [offset_max] "I"(offsetof(KernelArgs, maxval)), [offset_min] "I"(offsetof(KernelArgs, minval)), - [output_ptr] "r"(output_ptr) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", - "x27", "x28", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", - "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", - "z6", "z7", "z8", "z9"); + kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h index 62f0d159..f7a72970 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,8 @@ #include +#include "kai/kai_common.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -55,15 +57,15 @@ size_t kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(void); /// Gets the offset in bytes to the data element in the LHS matrix buffer. /// -/// @param[in] m_idx Row index. -/// @param[in] lhs_stride Row stride in bytes. +/// @param[in] m_idx Row index. This must be 0. +/// @param[in] k Columns of unpacked LHS. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(size_t m_idx, size_t lhs_stride); +size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(size_t m_idx, size_t k); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// -/// @param[in] n_idx Column index in the unpacked RHS matrix. +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of n_step /// @param[in] k Number of rows in the unpacked RHS matrix. /// /// @return The offset in bytes to the data element. @@ -71,8 +73,8 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla /// Gets the offset in bytes to the data element in the destination matrix buffer. /// -/// @param[in] m_idx Row index. -/// @param[in] n_idx Column index. +/// @param[in] m_idx Row index. Must be 0 +/// @param[in] n_idx Column index. Must be multiple of n_step /// @param[in] dst_stride Row stride in bytes. /// /// @return The offset in bytes to the data element. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_asm.S new file mode 100644 index 00000000..ffdce16e --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_asm.S @@ -0,0 +1,763 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla) + stp x20, x21, [sp, -80]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x8, #0x0 + ldr x5, [x0, #0x18] + cntw x6, ALL, MUL #4 + ptrue p1.b + ldr x7, [x0, #0x20] + KAI_ASM_INST(0x25207811) // ptrue pn9.b + mov x22, #0x1 + ldr x21, [x0, #0x10] + add x17, x5, x6 + ldr x20, [x0, #0x28] + sub x17, x17, #0x1 + ldr x16, [x0, #0x8] + udiv x17, x17, x6 + ldr x15, [x0, #0x30] + mov x14, x21 + add x21, x17, #0x3 + mov x13, x20 + and x21, x21, #0xfffffffffffffffc + mul x21, x21, x6 + mul x21, x21, x7 + lsl x21, x21, #0x2 +KAI_ASM_LABEL(label_1) // RHS size check loop + cmp x21, #0x200, LSL #12 + blt label_2 + tbnz x21, #0, label_3 + lsr x21, x21, #0x1 + lsl x22, x22, #0x1 + b label_1 +KAI_ASM_LABEL(label_2) // RHS do prefetch + lsl x20, x21, #0x26 + sub x22, x22, #0x1 + lsl x22, x22, #0x16 + orr x21, x21, x20 + orr x21, x21, x22 + KAI_ASM_INST(0xf8b549da) // rprfm pldonce, x21, [x14] +KAI_ASM_LABEL(label_3) // RHS prefetch exit + mov x12, x7 + cntw x20, ALL, MUL #2 + lsl x12, x12, #0x2 + add x12, x12, #0x4 + mul x12, x12, x20 +KAI_ASM_LABEL(label_4) // Column loop + cmp x17, #0x4 + bge label_22 + cmp x17, #0x2 + bgt label_16 + beq label_10 + cntw x20, ALL, MUL #2 + add x22, x14, x12 + KAI_ASM_INST(0xa04045d4) // ld1w { z20.s-z21.s }, pn9.b/Z, [x14] + cmp x5, x20 + mov x11, x7 + csel x22, x22, x14, GT + mov x21, x5 + KAI_ASM_INST(0xa04046d6) // ld1w { z22.s-z23.s }, pn9.b/Z, [x22] + mov x10, x16 + lsl x20, x7, #0x2 + KAI_ASM_INST(0x25b567f0) // whilelt p8.s, XZR, x21, VLx4 + cmp x11, #0x4 + KAI_ASM_INST(0xf8b44958) // rprfm pldmany, x20, [x10] + addvl x14, x14, #2 + addvl x22, x22, #2 + KAI_ASM_INST(0xc0040e80) // mova za.d[x8, #0], { z20.d-z23.d } + ble label_6 +KAI_ASM_LABEL(label_5) // Width 1: Multiply loop: Main loop head + whilelt p0.s, XZR, x11 + KAI_ASM_INST(0xa04045c5) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + ld1rqw { z15.s }, p0/Z, [x10] + sub x11, x11, #0x4 + add x10, x10, #0x10 + KAI_ASM_INST(0xa04046c7) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + cmp x11, #0x4 + KAI_ASM_INST(0xa04045dd) // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046df) // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc15f8080) // fmla za.s[x8, 0], { z4.s-z7.s }, z15.s[0] + KAI_ASM_INST(0xa04045c1) // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046c3) // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa04045d5) // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046d7) // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc15f8780) // fmla za.s[x8, 0], { z28.s-z31.s }, z15.s[1] + KAI_ASM_INST(0xc15f8800) // fmla za.s[x8, 0], { z0.s-z3.s }, z15.s[2] + KAI_ASM_INST(0xc15f8e80) // fmla za.s[x8, 0], { z20.s-z23.s }, z15.s[3] + bgt label_5 +KAI_ASM_LABEL(label_6) // Width 1: Multiply loop: Single iteration only + whilelt p0.s, XZR, x11 + KAI_ASM_INST(0xa04045c1) // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x14] + subs x11, x11, #0x1 + ld1rqw { z8.s }, p0/Z, [x10] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046c3) // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1588000) // fmla za.s[x8, 0], { z0.s-z3.s }, z8.s[0] + ble label_7 + KAI_ASM_INST(0xa04045d1) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x14] + subs x11, x11, #0x1 + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046d3) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1588600) // fmla za.s[x8, 0], { z16.s-z19.s }, z8.s[1] + ble label_7 + KAI_ASM_INST(0xa04045d5) // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x14] + subs x11, x11, #0x1 + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046d7) // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1588a80) // fmla za.s[x8, 0], { z20.s-z23.s }, z8.s[2] + ble label_7 + KAI_ASM_INST(0xa04045cd) // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x14] + KAI_ASM_INST(0xa04046cf) // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x22] + KAI_ASM_INST(0xc1588d80) // fmla za.s[x8, 0], { z12.s-z15.s }, z8.s[3] +KAI_ASM_LABEL(label_7) // Width 1: Multiply loop: multiply skip + tbz x15, #1, label_8 + add x21, x0, #0x4 + add x20, x0, #0x0 + KAI_ASM_INST(0xc0060c08) // mova { z8.d-z11.d }, za.d[x8, #0] + ld1rw { z21.s }, p1/Z, [x21] + ld1rw { z29.s }, p1/Z, [x20] + KAI_ASM_INST(0xc1bdcaa8) // fclamp { z8.s-z11.s }, z21.s, z29.s + KAI_ASM_INST(0xa060c1a8) // st1w { z8.s-z11.s }, p8, [x13] + b label_9 +KAI_ASM_LABEL(label_8) // Width 1: No activation + KAI_ASM_INST(0xc0060c08) // mova { z8.d-z11.d }, za.d[x8, #0] + KAI_ASM_INST(0xa060c1a8) // st1w { z8.s-z11.s }, p8, [x13] +KAI_ASM_LABEL(label_9) // Width 1: Output done + b label_28 +KAI_ASM_LABEL(label_10) // Width 2 + add x24, x14, x12, LSL #1 + cntw x20, ALL, MUL #6 + KAI_ASM_INST(0xa04045c4) // ld1w { z4.s-z5.s }, pn9.b/Z, [x14] + add x23, x24, x12 + cmp x5, x20 + KAI_ASM_INST(0xa0404700) // ld1w { z0.s-z1.s }, pn9.b/Z, [x24] + add x22, x14, x12 + csel x23, x23, x14, GT + KAI_ASM_INST(0xa04046c6) // ld1w { z6.s-z7.s }, pn9.b/Z, [x22] + mov x11, x7 + sub x21, x5, x6 + KAI_ASM_INST(0xa04046e2) // ld1w { z2.s-z3.s }, pn9.b/Z, [x23] + mov x10, x16 + lsl x20, x7, #0x2 + KAI_ASM_INST(0x25b567f0) // whilelt p8.s, XZR, x21, VLx4 + cmp x11, #0x4 + KAI_ASM_INST(0xf8b44958) // rprfm pldmany, x20, [x10] + addvl x14, x14, #2 + KAI_ASM_INST(0xc0040c80) // mova za.d[x8, #0], { z4.d-z7.d } + addvl x22, x22, #2 + addvl x24, x24, #2 + KAI_ASM_INST(0xc0040c01) // mova za.d[x8, #1], { z0.d-z3.d } + addvl x23, x23, #2 + ble label_12 +KAI_ASM_LABEL(label_11) // Width 2: Multiply loop: Main loop head + whilelt p0.s, XZR, x11 + KAI_ASM_INST(0xa04045c5) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + ld1rqw { z0.s }, p0/Z, [x10] + sub x11, x11, #0x4 + add x10, x10, #0x10 + KAI_ASM_INST(0xa04046c7) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + cmp x11, #0x4 + KAI_ASM_INST(0xa0404715) // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04046f7) // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1508080) // fmla za.s[x8, 0], { z4.s-z7.s }, z0.s[0] + KAI_ASM_INST(0xa04045d9) // ldnt1w { z24.s-z25.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046db) // ldnt1w { z26.s-z27.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1508281) // fmla za.s[x8, 1], { z20.s-z23.s }, z0.s[0] + KAI_ASM_INST(0xa0404709) // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04046eb) // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1508700) // fmla za.s[x8, 0], { z24.s-z27.s }, z0.s[1] + KAI_ASM_INST(0xa04045dd) // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046df) // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1508501) // fmla za.s[x8, 1], { z8.s-z11.s }, z0.s[1] + KAI_ASM_INST(0xa0404709) // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04046eb) // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1508b80) // fmla za.s[x8, 0], { z28.s-z31.s }, z0.s[2] + KAI_ASM_INST(0xa04045d9) // ldnt1w { z24.s-z25.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046db) // ldnt1w { z26.s-z27.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1508901) // fmla za.s[x8, 1], { z8.s-z11.s }, z0.s[2] + KAI_ASM_INST(0xa040470d) // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04046ef) // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1508f00) // fmla za.s[x8, 0], { z24.s-z27.s }, z0.s[3] + KAI_ASM_INST(0xc1508d81) // fmla za.s[x8, 1], { z12.s-z15.s }, z0.s[3] + bgt label_11 +KAI_ASM_LABEL(label_12) // Width 2: Multiply loop: Single iteration only + whilelt p0.s, XZR, x11 + KAI_ASM_INST(0xa04045c5) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x14] + subs x11, x11, #0x1 + ld1rqw { z8.s }, p0/Z, [x10] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046c7) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0404715) // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04046f7) // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1588080) // fmla za.s[x8, 0], { z4.s-z7.s }, z8.s[0] + KAI_ASM_INST(0xc1588281) // fmla za.s[x8, 1], { z20.s-z23.s }, z8.s[0] + ble label_13 + KAI_ASM_INST(0xa04045cd) // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x14] + subs x11, x11, #0x1 + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046cf) // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa040471d) // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04046ff) // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1588580) // fmla za.s[x8, 0], { z12.s-z15.s }, z8.s[1] + KAI_ASM_INST(0xc1588781) // fmla za.s[x8, 1], { z28.s-z31.s }, z8.s[1] + ble label_13 + KAI_ASM_INST(0xa04045dd) // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x14] + subs x11, x11, #0x1 + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046df) // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0404701) // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04046e3) // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1588b80) // fmla za.s[x8, 0], { z28.s-z31.s }, z8.s[2] + KAI_ASM_INST(0xc1588801) // fmla za.s[x8, 1], { z0.s-z3.s }, z8.s[2] + ble label_13 + KAI_ASM_INST(0xa04045d5) // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x14] + KAI_ASM_INST(0xa04046d7) // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x22] + KAI_ASM_INST(0xa0404701) // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x24] + KAI_ASM_INST(0xa04046e3) // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x23] + KAI_ASM_INST(0xc1588e80) // fmla za.s[x8, 0], { z20.s-z23.s }, z8.s[3] + KAI_ASM_INST(0xc1588c01) // fmla za.s[x8, 1], { z0.s-z3.s }, z8.s[3] +KAI_ASM_LABEL(label_13) // Width 2: Multiply loop: multiply skip + tbz x15, #1, label_14 + add x21, x0, #0x4 + add x20, x0, #0x0 + KAI_ASM_INST(0xc0060c1c) // mova { z28.d-z31.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c24) // mova { z4.d-z7.d }, za.d[x8, #1] + ld1rw { z17.s }, p1/Z, [x21] + ld1rw { z9.s }, p1/Z, [x20] + KAI_ASM_INST(0xc1a9ca3c) // fclamp { z28.s-z31.s }, z17.s, z9.s + KAI_ASM_INST(0xc1a9ca24) // fclamp { z4.s-z7.s }, z17.s, z9.s + KAI_ASM_INST(0xa060c5bc) // st1w { z28.s-z31.s }, pn9.b, [x13] + KAI_ASM_INST(0xa061c1a4) // st1w { z4.s-z7.s }, p8, [x13, #0x4, MUL VL] + b label_15 +KAI_ASM_LABEL(label_14) // Width 2: No activation + KAI_ASM_INST(0xc0060c04) // mova { z4.d-z7.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c3c) // mova { z28.d-z31.d }, za.d[x8, #1] + KAI_ASM_INST(0xa060c5a4) // st1w { z4.s-z7.s }, pn9.b, [x13] + KAI_ASM_INST(0xa061c1bc) // st1w { z28.s-z31.s }, p8, [x13, #0x4, MUL VL] +KAI_ASM_LABEL(label_15) // Width 2: Output done + b label_28 +KAI_ASM_LABEL(label_16) // Width 3 + add x26, x14, x12, LSL #2 + cntw x20, ALL, MUL #10 + KAI_ASM_INST(0xa04045d4) // ld1w { z20.s-z21.s }, pn9.b/Z, [x14] + add x25, x14, x12, LSL #1 + add x24, x26, x12 + KAI_ASM_INST(0xa0404740) // ld1w { z0.s-z1.s }, pn9.b/Z, [x26] + cmp x5, x20 + add x23, x14, x12 + KAI_ASM_INST(0xa0404730) // ld1w { z16.s-z17.s }, pn9.b/Z, [x25] + add x22, x25, x12 + csel x24, x24, x14, GT + KAI_ASM_INST(0xa04046f6) // ld1w { z22.s-z23.s }, pn9.b/Z, [x23] + mov x20, #0x2 + KAI_ASM_INST(0xa04046d2) // ld1w { z18.s-z19.s }, pn9.b/Z, [x22] + mov x11, x7 + KAI_ASM_INST(0xa0404702) // ld1w { z2.s-z3.s }, pn9.b/Z, [x24] + msub x21, x6, x20, x5 + mov x10, x16 + lsl x20, x7, #0x2 + KAI_ASM_INST(0x25b567f0) // whilelt p8.s, XZR, x21, VLx4 + KAI_ASM_INST(0xc0040e80) // mova za.d[x8, #0], { z20.d-z23.d } + cmp x11, #0x4 + KAI_ASM_INST(0xf8b44958) // rprfm pldmany, x20, [x10] + KAI_ASM_INST(0xc0040e01) // mova za.d[x8, #1], { z16.d-z19.d } + addvl x14, x14, #2 + addvl x23, x23, #2 + KAI_ASM_INST(0xc0040c02) // mova za.d[x8, #2], { z0.d-z3.d } + addvl x25, x25, #2 + addvl x22, x22, #2 + addvl x26, x26, #2 + addvl x24, x24, #2 + ble label_18 +KAI_ASM_LABEL(label_17) // Width 3: Multiply loop: Main loop head + whilelt p0.s, XZR, x11 + KAI_ASM_INST(0xa04045cd) // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + ld1rqw { z3.s }, p0/Z, [x10] + sub x11, x11, #0x4 + add x10, x10, #0x10 + KAI_ASM_INST(0xa04046ef) // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + cmp x11, #0x4 + KAI_ASM_INST(0xa0404729) // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04046cb) // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0404751) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1538180) // fmla za.s[x8, 0], { z12.s-z15.s }, z3.s[0] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0404713) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1538101) // fmla za.s[x8, 1], { z8.s-z11.s }, z3.s[0] + KAI_ASM_INST(0xa04045c9) // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046eb) // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1538202) // fmla za.s[x8, 2], { z16.s-z19.s }, z3.s[0] + KAI_ASM_INST(0xa0404731) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04046d3) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0404745) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1538500) // fmla za.s[x8, 0], { z8.s-z11.s }, z3.s[1] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0404707) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1538601) // fmla za.s[x8, 1], { z16.s-z19.s }, z3.s[1] + KAI_ASM_INST(0xa04045c9) // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046eb) // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1538482) // fmla za.s[x8, 2], { z4.s-z7.s }, z3.s[1] + KAI_ASM_INST(0xa0404731) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04046d3) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0404745) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1538900) // fmla za.s[x8, 0], { z8.s-z11.s }, z3.s[2] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0404707) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1538a01) // fmla za.s[x8, 1], { z16.s-z19.s }, z3.s[2] + KAI_ASM_INST(0xa04045d5) // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046f7) // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1538882) // fmla za.s[x8, 2], { z4.s-z7.s }, z3.s[2] + KAI_ASM_INST(0xa0404739) // ldnt1w { z24.s-z25.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04046db) // ldnt1w { z26.s-z27.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0404751) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1538e80) // fmla za.s[x8, 0], { z20.s-z23.s }, z3.s[3] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0404713) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1538f01) // fmla za.s[x8, 1], { z24.s-z27.s }, z3.s[3] + KAI_ASM_INST(0xc1538e02) // fmla za.s[x8, 2], { z16.s-z19.s }, z3.s[3] + bgt label_17 +KAI_ASM_LABEL(label_18) // Width 3: Multiply loop: Single iteration only + whilelt p0.s, XZR, x11 + KAI_ASM_INST(0xa04045c5) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x14] + subs x11, x11, #0x1 + ld1rqw { z8.s }, p0/Z, [x10] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046e7) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa040473d) // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04046df) // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0404755) // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1588080) // fmla za.s[x8, 0], { z4.s-z7.s }, z8.s[0] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0404717) // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1588381) // fmla za.s[x8, 1], { z28.s-z31.s }, z8.s[0] + KAI_ASM_INST(0xc1588282) // fmla za.s[x8, 2], { z20.s-z23.s }, z8.s[0] + ble label_19 + KAI_ASM_INST(0xa04045cd) // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x14] + subs x11, x11, #0x1 + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046ef) // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0404725) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04046c7) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0404751) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1588580) // fmla za.s[x8, 0], { z12.s-z15.s }, z8.s[1] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0404713) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1588481) // fmla za.s[x8, 1], { z4.s-z7.s }, z8.s[1] + KAI_ASM_INST(0xc1588602) // fmla za.s[x8, 2], { z16.s-z19.s }, z8.s[1] + ble label_19 + KAI_ASM_INST(0xa04045c1) // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x14] + subs x11, x11, #0x1 + addvl x14, x14, #2 + KAI_ASM_INST(0xa04046e3) // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa040472d) // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04046cf) // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0404751) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1588800) // fmla za.s[x8, 0], { z0.s-z3.s }, z8.s[2] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0404713) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1588981) // fmla za.s[x8, 1], { z12.s-z15.s }, z8.s[2] + KAI_ASM_INST(0xc1588a02) // fmla za.s[x8, 2], { z16.s-z19.s }, z8.s[2] + ble label_19 + KAI_ASM_INST(0xa04045c5) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x14] + KAI_ASM_INST(0xa04046e7) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x23] + KAI_ASM_INST(0xa040472d) // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xa04046cf) // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x22] + KAI_ASM_INST(0xa0404755) // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1588c80) // fmla za.s[x8, 0], { z4.s-z7.s }, z8.s[3] + KAI_ASM_INST(0xa0404717) // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x24] + KAI_ASM_INST(0xc1588d81) // fmla za.s[x8, 1], { z12.s-z15.s }, z8.s[3] + KAI_ASM_INST(0xc1588e82) // fmla za.s[x8, 2], { z20.s-z23.s }, z8.s[3] +KAI_ASM_LABEL(label_19) // Width 3: Multiply loop: multiply skip + tbz x15, #1, label_20 + add x21, x0, #0x4 + add x20, x0, #0x0 + KAI_ASM_INST(0xc0060c08) // mova { z8.d-z11.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c2c) // mova { z12.d-z15.d }, za.d[x8, #1] + ld1rw { z21.s }, p1/Z, [x21] + KAI_ASM_INST(0xc0060c50) // mova { z16.d-z19.d }, za.d[x8, #2] + ld1rw { z20.s }, p1/Z, [x20] + KAI_ASM_INST(0xc1b4caa8) // fclamp { z8.s-z11.s }, z21.s, z20.s + KAI_ASM_INST(0xc1b4caac) // fclamp { z12.s-z15.s }, z21.s, z20.s + KAI_ASM_INST(0xc1b4cab0) // fclamp { z16.s-z19.s }, z21.s, z20.s + KAI_ASM_INST(0xa060c5a8) // st1w { z8.s-z11.s }, pn9.b, [x13] + KAI_ASM_INST(0xa061c5ac) // st1w { z12.s-z15.s }, pn9.b, [x13, #0x4, MUL VL] + KAI_ASM_INST(0xa062c1b0) // st1w { z16.s-z19.s }, p8, [x13, #0x8, MUL VL] + b label_21 +KAI_ASM_LABEL(label_20) // Width 3: No activation + KAI_ASM_INST(0xc0060c04) // mova { z4.d-z7.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c20) // mova { z0.d-z3.d }, za.d[x8, #1] + KAI_ASM_INST(0xc0060c50) // mova { z16.d-z19.d }, za.d[x8, #2] + KAI_ASM_INST(0xa060c5a4) // st1w { z4.s-z7.s }, pn9.b, [x13] + KAI_ASM_INST(0xa061c5a0) // st1w { z0.s-z3.s }, pn9.b, [x13, #0x4, MUL VL] + KAI_ASM_INST(0xa062c1b0) // st1w { z16.s-z19.s }, p8, [x13, #0x8, MUL VL] +KAI_ASM_LABEL(label_21) // Width 3: Output done + b label_28 +KAI_ASM_LABEL(label_22) // Width 4 + add x9, x14, x12, LSL #2 + cntw x20, ALL, MUL #14 + KAI_ASM_INST(0xa04045cc) // ld1w { z12.s-z13.s }, pn9.b/Z, [x14] + add x28, x9, x12, LSL #1 + add x27, x14, x12, LSL #1 + KAI_ASM_INST(0xa0404528) // ld1w { z8.s-z9.s }, pn9.b/Z, [x9] + add x26, x28, x12 + cmp x5, x20 + KAI_ASM_INST(0xa0404760) // ld1w { z0.s-z1.s }, pn9.b/Z, [x27] + add x25, x14, x12 + add x24, x27, x12 + KAI_ASM_INST(0xa0404790) // ld1w { z16.s-z17.s }, pn9.b/Z, [x28] + add x23, x9, x12 + csel x26, x26, x14, GT + KAI_ASM_INST(0xa040472e) // ld1w { z14.s-z15.s }, pn9.b/Z, [x25] + mov x20, #0x3 + KAI_ASM_INST(0xa0404702) // ld1w { z2.s-z3.s }, pn9.b/Z, [x24] + mov x11, x7 + KAI_ASM_INST(0xa04046ea) // ld1w { z10.s-z11.s }, pn9.b/Z, [x23] + msub x21, x6, x20, x5 + mov x10, x16 + KAI_ASM_INST(0xa0404752) // ld1w { z18.s-z19.s }, pn9.b/Z, [x26] + lsl x20, x7, #0x2 + KAI_ASM_INST(0x25b567f0) // whilelt p8.s, XZR, x21, VLx4 + KAI_ASM_INST(0xc0040d80) // mova za.d[x8, #0], { z12.d-z15.d } + cmp x11, #0x4 + KAI_ASM_INST(0xf8b44958) // rprfm pldmany, x20, [x10] + KAI_ASM_INST(0xc0040c01) // mova za.d[x8, #1], { z0.d-z3.d } + add x22, x14, x12, LSL #3 + addvl x14, x14, #2 + KAI_ASM_INST(0xc0040d02) // mova za.d[x8, #2], { z8.d-z11.d } + addvl x25, x25, #2 + addvl x27, x27, #2 + KAI_ASM_INST(0xc0040e03) // mova za.d[x8, #3], { z16.d-z19.d } + addvl x24, x24, #2 + addvl x9, x9, #2 + addvl x23, x23, #2 + addvl x28, x28, #2 + addvl x26, x26, #2 + ble label_24 +KAI_ASM_LABEL(label_23) // Width 4: Multiply loop: Main loop head + whilelt p0.s, XZR, x11 + KAI_ASM_INST(0xa04045c9) // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + ld1rqw { z13.s }, p0/Z, [x10] + sub x11, x11, #0x4 + add x10, x10, #0x10 + KAI_ASM_INST(0xa040472b) // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + cmp x11, #0x4 + KAI_ASM_INST(0xa0404765) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0404707) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0404531) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc15d8100) // fmla za.s[x8, 0], { z8.s-z11.s }, z13.s[0] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04046f3) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0404781) // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc15d8081) // fmla za.s[x8, 1], { z4.s-z7.s }, z13.s[0] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0404743) // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc15d8202) // fmla za.s[x8, 2], { z16.s-z19.s }, z13.s[0] + KAI_ASM_INST(0xa04045dd) // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa040473f) // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xc15d8003) // fmla za.s[x8, 3], { z0.s-z3.s }, z13.s[0] + KAI_ASM_INST(0xa0404761) // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0404703) // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0404529) // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc15d8780) // fmla za.s[x8, 0], { z28.s-z31.s }, z13.s[1] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04046eb) // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0404791) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc15d8401) // fmla za.s[x8, 1], { z0.s-z3.s }, z13.s[1] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0404753) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc15d8502) // fmla za.s[x8, 2], { z8.s-z11.s }, z13.s[1] + KAI_ASM_INST(0xa04045c5) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa0404727) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xc15d8603) // fmla za.s[x8, 3], { z16.s-z19.s }, z13.s[1] + KAI_ASM_INST(0xa0404761) // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0404703) // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0404529) // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc15d8880) // fmla za.s[x8, 0], { z4.s-z7.s }, z13.s[2] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04046eb) // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0404791) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc15d8801) // fmla za.s[x8, 1], { z0.s-z3.s }, z13.s[2] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0404753) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc15d8902) // fmla za.s[x8, 2], { z8.s-z11.s }, z13.s[2] + KAI_ASM_INST(0xa04045d5) // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa0404737) // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xc15d8a03) // fmla za.s[x8, 3], { z16.s-z19.s }, z13.s[2] + KAI_ASM_INST(0xa0404779) // ldnt1w { z24.s-z25.s }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa040471b) // ldnt1w { z26.s-z27.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0404529) // ldnt1w { z8.s-z9.s }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc15d8e80) // fmla za.s[x8, 0], { z20.s-z23.s }, z13.s[3] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04046eb) // ldnt1w { z10.s-z11.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0404795) // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc15d8f01) // fmla za.s[x8, 1], { z24.s-z27.s }, z13.s[3] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0404757) // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc15d8d02) // fmla za.s[x8, 2], { z8.s-z11.s }, z13.s[3] + KAI_ASM_INST(0xc15d8e83) // fmla za.s[x8, 3], { z20.s-z23.s }, z13.s[3] + bgt label_23 +KAI_ASM_LABEL(label_24) // Width 4: Multiply loop: Single iteration only + whilelt p0.s, XZR, x11 + KAI_ASM_INST(0xa04045c5) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x14] + subs x11, x11, #0x1 + ld1rqw { z8.s }, p0/Z, [x10] + addvl x14, x14, #2 + KAI_ASM_INST(0xa0404727) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa0404761) // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0404703) // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa040452d) // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc1588080) // fmla za.s[x8, 0], { z4.s-z7.s }, z8.s[0] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04046ef) // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0404791) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc1588001) // fmla za.s[x8, 1], { z0.s-z3.s }, z8.s[0] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0404753) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc1588182) // fmla za.s[x8, 2], { z12.s-z15.s }, z8.s[0] + KAI_ASM_INST(0xc1588203) // fmla za.s[x8, 3], { z16.s-z19.s }, z8.s[0] + ble label_25 + KAI_ASM_INST(0xa04045dd) // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x14] + subs x11, x11, #0x1 + addvl x14, x14, #2 + KAI_ASM_INST(0xa040473f) // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa0404761) // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0404703) // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0404525) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc1588780) // fmla za.s[x8, 0], { z28.s-z31.s }, z8.s[1] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04046e7) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0404791) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc1588401) // fmla za.s[x8, 1], { z0.s-z3.s }, z8.s[1] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0404753) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc1588482) // fmla za.s[x8, 2], { z4.s-z7.s }, z8.s[1] + KAI_ASM_INST(0xc1588603) // fmla za.s[x8, 3], { z16.s-z19.s }, z8.s[1] + ble label_25 + KAI_ASM_INST(0xa04045dd) // ldnt1w { z28.s-z29.s }, pn9.b/Z, [x14] + subs x11, x11, #0x1 + addvl x14, x14, #2 + KAI_ASM_INST(0xa040473f) // ldnt1w { z30.s-z31.s }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa040476d) // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa040470f) // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0404521) // ldnt1w { z0.s-z1.s }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc1588b80) // fmla za.s[x8, 0], { z28.s-z31.s }, z8.s[2] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04046e3) // ldnt1w { z2.s-z3.s }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0404791) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc1588981) // fmla za.s[x8, 1], { z12.s-z15.s }, z8.s[2] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0404753) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc1588802) // fmla za.s[x8, 2], { z0.s-z3.s }, z8.s[2] + KAI_ASM_INST(0xc1588a03) // fmla za.s[x8, 3], { z16.s-z19.s }, z8.s[2] + ble label_25 + KAI_ASM_INST(0xa04045c5) // ldnt1w { z4.s-z5.s }, pn9.b/Z, [x14] + KAI_ASM_INST(0xa0404727) // ldnt1w { z6.s-z7.s }, pn9.b/Z, [x25] + KAI_ASM_INST(0xa040476d) // ldnt1w { z12.s-z13.s }, pn9.b/Z, [x27] + KAI_ASM_INST(0xa040470f) // ldnt1w { z14.s-z15.s }, pn9.b/Z, [x24] + KAI_ASM_INST(0xa0404535) // ldnt1w { z20.s-z21.s }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc1588c80) // fmla za.s[x8, 0], { z4.s-z7.s }, z8.s[3] + KAI_ASM_INST(0xa04046f7) // ldnt1w { z22.s-z23.s }, pn9.b/Z, [x23] + KAI_ASM_INST(0xa0404791) // ldnt1w { z16.s-z17.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc1588d81) // fmla za.s[x8, 1], { z12.s-z15.s }, z8.s[3] + KAI_ASM_INST(0xa0404753) // ldnt1w { z18.s-z19.s }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1588e82) // fmla za.s[x8, 2], { z20.s-z23.s }, z8.s[3] + KAI_ASM_INST(0xc1588e03) // fmla za.s[x8, 3], { z16.s-z19.s }, z8.s[3] +KAI_ASM_LABEL(label_25) // Width 4: Multiply loop: multiply skip + tbz x15, #1, label_26 + add x21, x0, #0x4 + add x20, x0, #0x0 + KAI_ASM_INST(0xc0060c04) // mova { z4.d-z7.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c20) // mova { z0.d-z3.d }, za.d[x8, #1] + ld1rw { z21.s }, p1/Z, [x21] + KAI_ASM_INST(0xc0060c4c) // mova { z12.d-z15.d }, za.d[x8, #2] + ld1rw { z20.s }, p1/Z, [x20] + KAI_ASM_INST(0xc0060c70) // mova { z16.d-z19.d }, za.d[x8, #3] + KAI_ASM_INST(0xc1b4caa4) // fclamp { z4.s-z7.s }, z21.s, z20.s + KAI_ASM_INST(0xc1b4caa0) // fclamp { z0.s-z3.s }, z21.s, z20.s + KAI_ASM_INST(0xc1b4caac) // fclamp { z12.s-z15.s }, z21.s, z20.s + KAI_ASM_INST(0xc1b4cab0) // fclamp { z16.s-z19.s }, z21.s, z20.s + KAI_ASM_INST(0xa060c5a4) // st1w { z4.s-z7.s }, pn9.b, [x13] + KAI_ASM_INST(0xa061c5a0) // st1w { z0.s-z3.s }, pn9.b, [x13, #0x4, MUL VL] + KAI_ASM_INST(0xa062c5ac) // st1w { z12.s-z15.s }, pn9.b, [x13, #0x8, MUL VL] + KAI_ASM_INST(0xa063c1b0) // st1w { z16.s-z19.s }, p8, [x13, #0xc, MUL VL] + addvl x13, x13, #16 + b label_27 +KAI_ASM_LABEL(label_26) // Width 4: No activation + KAI_ASM_INST(0xc0060c0c) // mova { z12.d-z15.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c20) // mova { z0.d-z3.d }, za.d[x8, #1] + KAI_ASM_INST(0xc0060c50) // mova { z16.d-z19.d }, za.d[x8, #2] + KAI_ASM_INST(0xc0060c64) // mova { z4.d-z7.d }, za.d[x8, #3] + KAI_ASM_INST(0xa060c5ac) // st1w { z12.s-z15.s }, pn9.b, [x13] + KAI_ASM_INST(0xa061c5a0) // st1w { z0.s-z3.s }, pn9.b, [x13, #0x4, MUL VL] + KAI_ASM_INST(0xa062c5b0) // st1w { z16.s-z19.s }, pn9.b, [x13, #0x8, MUL VL] + KAI_ASM_INST(0xa063c1a4) // st1w { z4.s-z7.s }, p8, [x13, #0xc, MUL VL] + addvl x13, x13, #16 +KAI_ASM_LABEL(label_27) // Width 4: Output done + subs x17, x17, #0x4 + mov x14, x22 + sub x5, x5, x6, LSL #2 + bgt label_4 +KAI_ASM_LABEL(label_28) // Exit + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c index 418ed2c0..b9d9c72a 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c @@ -4,13 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. - #include "kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" #include @@ -18,25 +14,47 @@ #include "kai/kai_common.h" +typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + float min; + float max; + void* accumulator_buffer; + uint64_t flags; +} KernelArgs; + static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 1; static const size_t kai_sr = 1; +void kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(KernelArgs* args); + +// Returns a constant value specific to this kernel that's relative to vector length +static size_t kai_get_kernel_vec_length_constant(void) { + const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u32() / kai_kr; + return kernel_vec_length_constant; +} + size_t kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u32(); + return kai_mr * kai_get_kernel_vec_length_constant(); } size_t kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_kernel_vec_length_constant(); } size_t kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u32(); + return kai_mr * kai_get_kernel_vec_length_constant(); } size_t kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_kernel_vec_length_constant(); } size_t kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { @@ -49,20 +67,26 @@ size_t kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { size_t kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t m_idx, size_t k) { KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0); - return m_idx * k * sizeof(float); + return m_idx * kai_roundup(k, kai_kr) * sizeof(float); +} + +static size_t kai_get_rhs_packed_stride_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t k) { + return kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() * + (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(float)); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0); - return n_idx * (k * sizeof(float) + sizeof(float)); + const size_t block_idx = n_idx / kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + return block_idx * kai_get_rhs_packed_stride_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(k); } size_t kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_stride) { + size_t m_idx, size_t n_idx, size_t dst_stride_row) { KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0); KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0); - return m_idx * dst_stride + n_idx * sizeof(float); + return m_idx * dst_stride_row + n_idx * sizeof(float); } size_t kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t m, size_t n) { @@ -73,26 +97,10 @@ void kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) { KAI_ASSUME(dst_stride_col == sizeof(float)); - - typedef struct { - const void* A; - const void* B; - - void* C; - uint64_t ldcb; - uint64_t M, N, K; - float min; - float max; - - void* accumulator_buffer; - uint64_t flags; - } KernelArgs; - KernelArgs args; args.A = lhs_packed; args.B = rhs_packed; - args.C = dst; args.ldcb = dst_stride_row; args.M = m; @@ -100,390 +108,10 @@ void kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( args.K = k; args.min = clamp_min; args.max = clamp_max; - args.accumulator_buffer = NULL; args.flags = 0; - __asm__ __volatile__( - "ldr x17, [%x[args], %[offsetof_flags]]\n" - ".inst 0xd503477f // SMSTART ZA\n" - "ptrue p0.b\n" - ".inst 0x25207811 // ptrue pn9.b\n" - "ldr x16, [%x[args], %[offsetof_accumulator_buffer]]\n" - "ldr x15, [%x[args], %[offsetof_accumulator_buffer]]\n" - "tbz x17, #0, 2f\n" - "mov x12, #0x0\n" - "cntw x20\n" - "1:" // Initial accumulator load from buffer: Loop - ".inst 0xa040c618 // ld1w { z24.s-z27.s }, pn9.b/Z, [x16]\n" - ".inst 0xa041c60c // ld1w { z12.s-z15.s }, pn9.b/Z, [x16, #0x4, MUL VL]\n" - ".inst 0xa042c600 // ld1w { z0.s-z3.s }, pn9.b/Z, [x16, #0x8, MUL VL]\n" - ".inst 0xa043c610 // ld1w { z16.s-z19.s }, pn9.b/Z, [x16, #0xc, MUL VL]\n" - ".inst 0xc0840700 // mova za0h.s[x12], { z24.s-z27.s }\n" - "addvl x16, x16, #16\n" - ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n" - ".inst 0xc0840402 // mova za2h.s[x12], { z0.s-z3.s }\n" - ".inst 0xc0840603 // mova za3h.s[x12], { z16.s-z19.s }\n" - "add x12, x12, #0x4\n" - "cmp x12, x20\n" - "blt 1b\n" - "2:" // Initial accumulator load from buffer: End - "ldr w14, [%x[args], %[offsetof_M]]\n" - "mov x13, #0x0\n" - "mov x11, #0x0\n" - "ldr w10, [%x[args], %[offsetof_N]]\n" - "ldr x9, [%x[args], %[offsetof_A]]\n" - "3:" // M loop - "ldr x28, [%x[args], %[offsetof_B]]\n" - "4:" // N loop - "mov x27, x9\n" - ".inst 0x25aa4570 // whilelt pn8.s, x11, x10, VLx2\n" - "tbnz x17, #0, 5f\n" - "fmov z17.s, #1.0\n" - ".inst 0xa040438a // ld1w { z10.s-z11.s }, p8/Z, [x28]\n" // Load bias - ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" - "addvl x28, x28, #2\n" - ".inst 0x808a0220 // fmopa za0.s, p0/M, p0/M, z17.s, z10.s\n" - ".inst 0x808b0221 // fmopa za1.s, p0/M, p0/M, z17.s, z11.s\n" - ".inst 0x808a0222 // fmopa za2.s, p0/M, p0/M, z17.s, z10.s\n" - ".inst 0x808b0223 // fmopa za3.s, p0/M, p0/M, z17.s, z11.s\n" - "5:" // Prepare accumulators: Test for last block - "mov x20, x11\n" - "mov x21, x13\n" - "incw x20, ALL, MUL #2\n" - "incw x21, ALL, MUL #2\n" - "cmp x20, x10\n" - "mov x20, x17\n" - "csel x21, x13, x21, LT\n" - "bfm x17, XZR, #0x0, #0x0 // bfc x17, #0x0, #0x1\n" - "cmp x21, x14\n" - "csel x17, x20, x17, LT\n" - "ldr x20, [%x[args], %[offsetof_K]]\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 9f\n" - "subs x21, x21, #0x1\n" - ".inst 0xa0404776 // ld1w { z22.s-z23.s }, pn9.b/Z, [x27]\n" - ".inst 0xa1404787 // ld1w { z7.s, z15.s }, pn9.b/Z, [x28]\n" - ".inst 0xa1414766 // ld1w { z6.s, z14.s }, pn9.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0xa0414794 // ld1w { z20.s-z21.s }, pn9.b/Z, [x28, #0x2, MUL VL]\n" - ".inst 0xa1424762 // ld1w { z2.s, z10.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa1424783 // ld1w { z3.s, z11.s }, pn9.b/Z, [x28, #0x4, MUL VL]\n" - ".inst 0xa1434761 // ld1w { z1.s, z9.s }, pn9.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - ".inst 0xa0434784 // ld1w { z4.s-z5.s }, pn9.b/Z, [x28, #0x6, MUL VL]\n" - "addvl x28, x28, #8\n" - "ble 8f\n" - "7:" // K loop - ".inst 0x808702c0 // fmopa za0.s, p0/M, p0/M, z22.s, z7.s\n" - "subs x21, x21, #0x1\n" - ".inst 0x808f02c1 // fmopa za1.s, p0/M, p0/M, z22.s, z15.s\n" - ".inst 0x808702e2 // fmopa za2.s, p0/M, p0/M, z23.s, z7.s\n" - ".inst 0x808f02e3 // fmopa za3.s, p0/M, p0/M, z23.s, z15.s\n" - ".inst 0xa0404776 // ld1w { z22.s-z23.s }, pn9.b/Z, [x27]\n" - ".inst 0x809400c0 // fmopa za0.s, p0/M, p0/M, z6.s, z20.s\n" - ".inst 0xa1404787 // ld1w { z7.s, z15.s }, pn9.b/Z, [x28]\n" - ".inst 0x809500c1 // fmopa za1.s, p0/M, p0/M, z6.s, z21.s\n" - ".inst 0x809401c2 // fmopa za2.s, p0/M, p0/M, z14.s, z20.s\n" - ".inst 0x809501c3 // fmopa za3.s, p0/M, p0/M, z14.s, z21.s\n" - ".inst 0xa1414766 // ld1w { z6.s, z14.s }, pn9.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0x80830040 // fmopa za0.s, p0/M, p0/M, z2.s, z3.s\n" - ".inst 0xa0414794 // ld1w { z20.s-z21.s }, pn9.b/Z, [x28, #0x2, MUL VL]\n" - ".inst 0x808b0041 // fmopa za1.s, p0/M, p0/M, z2.s, z11.s\n" - ".inst 0x80830142 // fmopa za2.s, p0/M, p0/M, z10.s, z3.s\n" - ".inst 0x808b0143 // fmopa za3.s, p0/M, p0/M, z10.s, z11.s\n" - ".inst 0xa1424762 // ld1w { z2.s, z10.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa1424783 // ld1w { z3.s, z11.s }, pn9.b/Z, [x28, #0x4, MUL VL]\n" - ".inst 0x80840020 // fmopa za0.s, p0/M, p0/M, z1.s, z4.s\n" - ".inst 0x80850021 // fmopa za1.s, p0/M, p0/M, z1.s, z5.s\n" - ".inst 0x80840122 // fmopa za2.s, p0/M, p0/M, z9.s, z4.s\n" - ".inst 0x80850123 // fmopa za3.s, p0/M, p0/M, z9.s, z5.s\n" - ".inst 0xa1434761 // ld1w { z1.s, z9.s }, pn9.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - ".inst 0xa0434784 // ld1w { z4.s-z5.s }, pn9.b/Z, [x28, #0x6, MUL VL]\n" - "addvl x28, x28, #8\n" - "bgt 7b\n" - "8:" // K loop tail - ".inst 0x808702c0 // fmopa za0.s, p0/M, p0/M, z22.s, z7.s\n" - ".inst 0x808f02c1 // fmopa za1.s, p0/M, p0/M, z22.s, z15.s\n" - ".inst 0x808702e2 // fmopa za2.s, p0/M, p0/M, z23.s, z7.s\n" - ".inst 0x808f02e3 // fmopa za3.s, p0/M, p0/M, z23.s, z15.s\n" - ".inst 0x809400c0 // fmopa za0.s, p0/M, p0/M, z6.s, z20.s\n" - ".inst 0x809500c1 // fmopa za1.s, p0/M, p0/M, z6.s, z21.s\n" - ".inst 0x809401c2 // fmopa za2.s, p0/M, p0/M, z14.s, z20.s\n" - ".inst 0x809501c3 // fmopa za3.s, p0/M, p0/M, z14.s, z21.s\n" - ".inst 0x80830040 // fmopa za0.s, p0/M, p0/M, z2.s, z3.s\n" - ".inst 0x808b0041 // fmopa za1.s, p0/M, p0/M, z2.s, z11.s\n" - ".inst 0x80830142 // fmopa za2.s, p0/M, p0/M, z10.s, z3.s\n" - ".inst 0x808b0143 // fmopa za3.s, p0/M, p0/M, z10.s, z11.s\n" - ".inst 0x80840020 // fmopa za0.s, p0/M, p0/M, z1.s, z4.s\n" - ".inst 0x80850021 // fmopa za1.s, p0/M, p0/M, z1.s, z5.s\n" - ".inst 0x80840122 // fmopa za2.s, p0/M, p0/M, z9.s, z4.s\n" - ".inst 0x80850123 // fmopa za3.s, p0/M, p0/M, z9.s, z5.s\n" - "9:" // K oddments - "cbz x20, 11f\n" - "10:" // K oddments: Loop - ".inst 0xa040476a // ld1w { z10.s-z11.s }, pn9.b/Z, [x27]\n" - "subs x20, x20, #0x1\n" - "addvl x27, x27, #2\n" - ".inst 0xa040478e // ld1w { z14.s-z15.s }, pn9.b/Z, [x28]\n" - "addvl x28, x28, #2\n" - ".inst 0x808e0140 // fmopa za0.s, p0/M, p0/M, z10.s, z14.s\n" - ".inst 0x808f0141 // fmopa za1.s, p0/M, p0/M, z10.s, z15.s\n" - ".inst 0x808e0162 // fmopa za2.s, p0/M, p0/M, z11.s, z14.s\n" - ".inst 0x808f0163 // fmopa za3.s, p0/M, p0/M, z11.s, z15.s\n" - "bgt 10b\n" - "11:" // K oddments: End - "tbz x17, #1, 15f\n" - "tbz x17, #0, 13f\n" - "mov x12, #0x0\n" - "cntw x20\n" - "12:" // Store to partial result buffer: Store and refill: Loop - ".inst 0xa040c600 // ld1w { z0.s-z3.s }, pn9.b/Z, [x16]\n" - ".inst 0xc0860414 // mova { z20.s-z23.s }, za0h.s[x12]\n" - ".inst 0xc086043c // mova { z28.s-z31.s }, za1h.s[x12]\n" - ".inst 0xa041c604 // ld1w { z4.s-z7.s }, pn9.b/Z, [x16, #0x4, MUL VL]\n" - ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" - ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" - ".inst 0xa042c610 // ld1w { z16.s-z19.s }, pn9.b/Z, [x16, #0x8, MUL VL]\n" - ".inst 0xa043c618 // ld1w { z24.s-z27.s }, pn9.b/Z, [x16, #0xc, MUL VL]\n" - ".inst 0xc0840400 // mova za0h.s[x12], { z0.s-z3.s }\n" - "addvl x16, x16, #16\n" - ".inst 0xc0840481 // mova za1h.s[x12], { z4.s-z7.s }\n" - ".inst 0xa060c5f4 // st1w { z20.s-z23.s }, pn9.b, [x15]\n" - ".inst 0xc0840602 // mova za2h.s[x12], { z16.s-z19.s }\n" - ".inst 0xa061c5fc // st1w { z28.s-z31.s }, pn9.b, [x15, #0x4, MUL VL]\n" - ".inst 0xc0840703 // mova za3h.s[x12], { z24.s-z27.s }\n" - "add x12, x12, #0x4\n" - ".inst 0xa062c5e8 // st1w { z8.s-z11.s }, pn9.b, [x15, #0x8, MUL VL]\n" - "cmp x12, x20\n" - ".inst 0xa063c5ec // st1w { z12.s-z15.s }, pn9.b, [x15, #0xc, MUL VL]\n" - "addvl x15, x15, #16\n" - "blt 12b\n" - "b 31f\n" - "13:" // Store to partial result buffer: Store only - "mov x12, #0x0\n" - "cntw x20\n" - "14:" // Store to partial result buffer: Store only: Loop - ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" - ".inst 0xc0860430 // mova { z16.s-z19.s }, za1h.s[x12]\n" - ".inst 0xc086045c // mova { z28.s-z31.s }, za2h.s[x12]\n" - ".inst 0xc0860474 // mova { z20.s-z23.s }, za3h.s[x12]\n" - ".inst 0xa060c5e0 // st1w { z0.s-z3.s }, pn9.b, [x15]\n" - "add x12, x12, #0x4\n" - ".inst 0xa061c5f0 // st1w { z16.s-z19.s }, pn9.b, [x15, #0x4, MUL VL]\n" - "cmp x12, x20\n" - ".inst 0xa062c5fc // st1w { z28.s-z31.s }, pn9.b, [x15, #0x8, MUL VL]\n" - ".inst 0xa063c5f4 // st1w { z20.s-z23.s }, pn9.b, [x15, #0xc, MUL VL]\n" - "addvl x15, x15, #16\n" - "blt 14b\n" - "b 31f\n" - "15:" // Store to output array - "ldr x26, [%x[args], %[offsetof_C]]\n" - "sub x25, x14, x13\n" - "ldr x24, [%x[args], %[offsetof_ldcb]]\n" - "add x26, x26, x11, LSL #2\n" // C += n - "madd x26, x13, x24, x26\n" // C += m * ldc - "tbz x17, #2, 22f\n" - "cntw x23\n" - "mov x12, #0x0\n" - "cmp x25, x23\n" - "csel x22, x25, x23, LT\n" - "lsr x21, x22, #0x2\n" - "and x20, x22, #0x3\n" - "cbz x21, 17f\n" - "16:" // Store to output array: Skip activation: Accumulator row 0 loop - ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" - ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" - ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "add x12, x12, #0x4\n" - ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "cmp x12, x21, LSL #2\n" - ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" - "add x26, x26, x24\n" - ".inst 0xa1604347 // st1w { z7.s, z15.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "blt 16b\n" - "17:" // Store to output array: Skip activation: Accumulator row 0 oddments - "cbz x20, 18f\n" - "subs x20, x20, #0x1\n" - ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" - ".inst 0xc0860428 // mova { z8.s-z11.s }, za1h.s[x12]\n" - ".inst 0xa1604340 // st1w { z0.s, z8.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "beq 18f\n" - "subs x20, x20, #0x1\n" - ".inst 0xa1604341 // st1w { z1.s, z9.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "beq 18f\n" - ".inst 0xa1604342 // st1w { z2.s, z10.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "18:" // Store to output array: Skip activation: Accumulator row 0 oddments: End - "subs x25, x25, x22\n" - "beq 22f\n" - "cmp x25, x23\n" - "mov x12, #0x0\n" - "csel x22, x25, x23, LT\n" - "lsr x21, x22, #0x2\n" - "and x20, x22, #0x3\n" - "cbz x21, 20f\n" - "19:" // Store to output array: Skip activation: Accumulator row 1 loop - ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n" - ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" - ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "add x12, x12, #0x4\n" - ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "cmp x12, x21, LSL #2\n" - ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" - "add x26, x26, x24\n" - ".inst 0xa1604347 // st1w { z7.s, z15.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "blt 19b\n" - "20:" // Store to output array: Skip activation: Accumulator row 1 oddments - "cbz x20, 21f\n" - "subs x20, x20, #0x1\n" - ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n" - ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" - ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "beq 21f\n" - "subs x20, x20, #0x1\n" - ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "beq 21f\n" - ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "21:" // Store to output array: Skip activation: Accumulator row 1 oddments: End - "subs x25, x25, x22\n" - "beq 22f\n" - "b 29f\n" - "22:" // Store to output array: Skip activation: End - "cntw x23\n" - "ld1rw { z21.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" - "mov x12, #0x0\n" - "cmp x25, x23\n" - "ld1rw { z20.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" - "csel x22, x25, x23, LT\n" - "lsr x21, x22, #0x2\n" - "and x20, x22, #0x3\n" - "cbz x21, 24f\n" - "23:" // Store to output array: Accumulator row 0 loop - ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" - ".inst 0xc0860438 // mova { z24.s-z27.s }, za1h.s[x12]\n" - ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" - ".inst 0xc1b4cab8 // fclamp { z24.s-z27.s }, z21.s, z20.s\n" - "add x12, x12, #0x4\n" - "cmp x12, x21, LSL #2\n" - ".inst 0xa1604350 // st1w { z16.s, z24.s }, p8, [x26]\n" - "add x26, x26, x24\n" - ".inst 0xa1604351 // st1w { z17.s, z25.s }, p8, [x26]\n" - "add x26, x26, x24\n" - ".inst 0xa1604352 // st1w { z18.s, z26.s }, p8, [x26]\n" - "add x26, x26, x24\n" - ".inst 0xa1604353 // st1w { z19.s, z27.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "blt 23b\n" - "24:" // Store to output array: Accumulator row 0 oddments - "cbz x20, 25f\n" - ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" - ".inst 0xc0860438 // mova { z24.s-z27.s }, za1h.s[x12]\n" - "subs x20, x20, #0x1\n" - ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" - ".inst 0xc1b4cab8 // fclamp { z24.s-z27.s }, z21.s, z20.s\n" - ".inst 0xa1604350 // st1w { z16.s, z24.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "beq 25f\n" - "subs x20, x20, #0x1\n" - ".inst 0xa1604351 // st1w { z17.s, z25.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "beq 25f\n" - ".inst 0xa1604352 // st1w { z18.s, z26.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "25:" // Store to output array: Accumulator row 0 oddments: End - "subs x25, x25, x22\n" - "beq 29f\n" - "cmp x25, x23\n" - "mov x12, #0x0\n" - "csel x20, x25, x23, LT\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 27f\n" - "26:" // Store to output array: Accumulator row 1 loop - ".inst 0xc0860440 // mova { z0.s-z3.s }, za2h.s[x12]\n" - ".inst 0xc0860468 // mova { z8.s-z11.s }, za3h.s[x12]\n" - ".inst 0xc1b4caa0 // fclamp { z0.s-z3.s }, z21.s, z20.s\n" - ".inst 0xc1b4caa8 // fclamp { z8.s-z11.s }, z21.s, z20.s\n" - "add x12, x12, #0x4\n" - "cmp x12, x21, LSL #2\n" - ".inst 0xa1604340 // st1w { z0.s, z8.s }, p8, [x26]\n" - "add x26, x26, x24\n" - ".inst 0xa1604341 // st1w { z1.s, z9.s }, p8, [x26]\n" - "add x26, x26, x24\n" - ".inst 0xa1604342 // st1w { z2.s, z10.s }, p8, [x26]\n" - "add x26, x26, x24\n" - ".inst 0xa1604343 // st1w { z3.s, z11.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "blt 26b\n" - "27:" // Store to output array: Accumulator row 1 oddments - "cbz x20, 28f\n" - ".inst 0xc0860450 // mova { z16.s-z19.s }, za2h.s[x12]\n" - ".inst 0xc0860478 // mova { z24.s-z27.s }, za3h.s[x12]\n" - "subs x20, x20, #0x1\n" - ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" - ".inst 0xc1b4cab8 // fclamp { z24.s-z27.s }, z21.s, z20.s\n" - ".inst 0xa1604350 // st1w { z16.s, z24.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "beq 28f\n" - "subs x20, x20, #0x1\n" - ".inst 0xa1604351 // st1w { z17.s, z25.s }, p8, [x26]\n" - "add x26, x26, x24\n" - "beq 28f\n" - ".inst 0xa1604352 // st1w { z18.s, z26.s }, p8, [x26]\n" - "28:" // Store to output array: Accumulator row 1 oddments: End - "29:" // Store to output array: End - "tbz x17, #0, 31f\n" - "mov x12, #0x0\n" - "cntw x20\n" - "30:" // Store to output array: Refill accumulators: Loop - ".inst 0xa040c608 // ld1w { z8.s-z11.s }, pn9.b/Z, [x16]\n" - ".inst 0xa041c600 // ld1w { z0.s-z3.s }, pn9.b/Z, [x16, #0x4, MUL VL]\n" - ".inst 0xa042c604 // ld1w { z4.s-z7.s }, pn9.b/Z, [x16, #0x8, MUL VL]\n" - ".inst 0xa043c60c // ld1w { z12.s-z15.s }, pn9.b/Z, [x16, #0xc, MUL VL]\n" - ".inst 0xc0840500 // mova za0h.s[x12], { z8.s-z11.s }\n" - "addvl x16, x16, #16\n" - ".inst 0xc0840401 // mova za1h.s[x12], { z0.s-z3.s }\n" - ".inst 0xc0840482 // mova za2h.s[x12], { z4.s-z7.s }\n" - ".inst 0xc0840583 // mova za3h.s[x12], { z12.s-z15.s }\n" - "add x12, x12, #0x4\n" - "cmp x12, x20\n" - "blt 30b\n" - "31:" // End block - "incw x11, ALL, MUL #2\n" - "cmp x11, x10\n" - "blt 4b\n" - "incw x13, ALL, MUL #2\n" - "mov x11, #0x0\n" - "cmp x13, x14\n" - "mov x9, x27\n" - "blt 3b\n" - ".inst 0xd503467f // SMSTOP\n" - : - : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), - [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), - [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), - [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), [offsetof_M] "I"(offsetof(KernelArgs, M)), - [offsetof_N] "I"(offsetof(KernelArgs, N)), - [offsetof_accumulator_buffer] "I"(offsetof(KernelArgs, accumulator_buffer)), - [offsetof_flags] "I"(offsetof(KernelArgs, flags)), [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) - : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", - "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25", - "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", - "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", - "z29", "z30", "z31"); + kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h index 1dcc3404..1dc9a19f 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -61,7 +61,7 @@ size_t kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); /// Gets the offset in bytes to the data element in the packed LHS matrix buffer. /// -/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. /// @param[in] k Number of columns in the unpacked LHS matrix. /// /// @return The offset in bytes to the data element. @@ -69,21 +69,21 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// -/// @param[in] n_idx Column index in the unpacked RHS matrix. -/// @param[in] k Number of rows in the unpacked RHS matrix. +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of `n_step`. +/// @param[in] k Number of columns in the unpacked LHS matrix. /// /// @return The offset in bytes to the data element. size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// -/// @param[in] m_idx Row index. -/// @param[in] n_idx Column index. -/// @param[in] stride Row stride in bytes. +/// @param[in] m_idx Row index. Must be a multiple of `m_step`. +/// @param[in] n_idx Column index. Must be a multiple of `n_step`. +/// @param[in] dst_stride_row Row stride in bytes. /// /// @return The offset in bytes to the data element. size_t kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_stride); + size_t m_idx, size_t n_idx, size_t dst_stride_row); /// Gets the size in bytes of the destination matrix buffer. /// @@ -104,12 +104,12 @@ size_t kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(si /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. -/// @param[in] k Common dimension of the LHS and RHS operands. -/// @param[in] packed_lhs Packed LHS matrix buffer. -/// @param[in] packed_rhs Packed RHS matrix buffer. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] lhs_packed Packed LHS matrix buffer. +/// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. /// @param[in] dst_stride_row Row stride in bytes of the output matrix. -/// @param[in] dst_stride_col Column stride in bytes of the output matrix. +/// @param[in] dst_stride_col Column stride in bytes of the output matrix. Must be 4 /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. void kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_asm.S new file mode 100644 index 00000000..23d32b15 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_asm.S @@ -0,0 +1,252 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa) + stp x20, x21, [sp, -80]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x14, #0x0 + ptrue p0.b + KAI_ASM_INST(0x25207811) // ptrue pn9.b + ldr w13, [x0, #0x20] + ldr w11, [x0, #0x28] + mov x10, #0x0 + ldr x9, [x0, #0x0] +KAI_ASM_LABEL(label_1) // M loop + ldr x28, [x0, #0x8] +KAI_ASM_LABEL(label_2) // N loop + KAI_ASM_INST(0x25ab4550) // whilelt pn8.s, x10, x11, VLx2 + fmov z13.s, #1.0 + KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } + mov x27, x9 + KAI_ASM_INST(0xa040438e) // ld1w { z14.s-z15.s }, p8/Z, [x28] // Load bias + addvl x28, x28, #2 + KAI_ASM_INST(0x808e01a0) // fmopa za0.s, p0/M, p0/M, z13.s, z14.s + KAI_ASM_INST(0x808f01a1) // fmopa za1.s, p0/M, p0/M, z13.s, z15.s + KAI_ASM_INST(0x808e01a2) // fmopa za2.s, p0/M, p0/M, z13.s, z14.s + KAI_ASM_INST(0x808f01a3) // fmopa za3.s, p0/M, p0/M, z13.s, z15.s + ldr x20, [x0, #0x30] + lsr x21, x20, #0x2 + and x20, x20, #0x3 + cbz x21, label_6 + subs x21, x21, #0x1 + KAI_ASM_INST(0xa1404772) // ld1w { z18.s, z26.s }, pn9.b/Z, [x27] + KAI_ASM_INST(0xa0404794) // ld1w { z20.s-z21.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xa1414764) // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL] + KAI_ASM_INST(0xa041478a) // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL] + KAI_ASM_INST(0xa1424773) // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL] + KAI_ASM_INST(0xa0424798) // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL] + KAI_ASM_INST(0xa043476e) // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL] + addvl x27, x27, #8 + KAI_ASM_INST(0xa1434796) // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL] + addvl x28, x28, #8 + ble label_5 +KAI_ASM_LABEL(label_4) // K loop + KAI_ASM_INST(0x80940240) // fmopa za0.s, p0/M, p0/M, z18.s, z20.s + subs x21, x21, #0x1 + KAI_ASM_INST(0x80950241) // fmopa za1.s, p0/M, p0/M, z18.s, z21.s + KAI_ASM_INST(0x80940342) // fmopa za2.s, p0/M, p0/M, z26.s, z20.s + KAI_ASM_INST(0x80950343) // fmopa za3.s, p0/M, p0/M, z26.s, z21.s + KAI_ASM_INST(0xa1404772) // ld1w { z18.s, z26.s }, pn9.b/Z, [x27] + KAI_ASM_INST(0x808a0080) // fmopa za0.s, p0/M, p0/M, z4.s, z10.s + KAI_ASM_INST(0xa0404794) // ld1w { z20.s-z21.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0x808b0081) // fmopa za1.s, p0/M, p0/M, z4.s, z11.s + KAI_ASM_INST(0x808a0182) // fmopa za2.s, p0/M, p0/M, z12.s, z10.s + KAI_ASM_INST(0x808b0183) // fmopa za3.s, p0/M, p0/M, z12.s, z11.s + KAI_ASM_INST(0xa1414764) // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL] + KAI_ASM_INST(0x80980260) // fmopa za0.s, p0/M, p0/M, z19.s, z24.s + KAI_ASM_INST(0xa041478a) // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL] + KAI_ASM_INST(0x80990261) // fmopa za1.s, p0/M, p0/M, z19.s, z25.s + KAI_ASM_INST(0x80980362) // fmopa za2.s, p0/M, p0/M, z27.s, z24.s + KAI_ASM_INST(0x80990363) // fmopa za3.s, p0/M, p0/M, z27.s, z25.s + KAI_ASM_INST(0xa1424773) // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL] + KAI_ASM_INST(0xa0424798) // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL] + KAI_ASM_INST(0x809601c0) // fmopa za0.s, p0/M, p0/M, z14.s, z22.s + KAI_ASM_INST(0x809e01c1) // fmopa za1.s, p0/M, p0/M, z14.s, z30.s + KAI_ASM_INST(0x809601e2) // fmopa za2.s, p0/M, p0/M, z15.s, z22.s + KAI_ASM_INST(0x809e01e3) // fmopa za3.s, p0/M, p0/M, z15.s, z30.s + KAI_ASM_INST(0xa043476e) // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL] + addvl x27, x27, #8 + KAI_ASM_INST(0xa1434796) // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL] + addvl x28, x28, #8 + bgt label_4 +KAI_ASM_LABEL(label_5) // K loop tail + KAI_ASM_INST(0x80940240) // fmopa za0.s, p0/M, p0/M, z18.s, z20.s + KAI_ASM_INST(0x80950241) // fmopa za1.s, p0/M, p0/M, z18.s, z21.s + KAI_ASM_INST(0x80940342) // fmopa za2.s, p0/M, p0/M, z26.s, z20.s + KAI_ASM_INST(0x80950343) // fmopa za3.s, p0/M, p0/M, z26.s, z21.s + KAI_ASM_INST(0x808a0080) // fmopa za0.s, p0/M, p0/M, z4.s, z10.s + KAI_ASM_INST(0x808b0081) // fmopa za1.s, p0/M, p0/M, z4.s, z11.s + KAI_ASM_INST(0x808a0182) // fmopa za2.s, p0/M, p0/M, z12.s, z10.s + KAI_ASM_INST(0x808b0183) // fmopa za3.s, p0/M, p0/M, z12.s, z11.s + KAI_ASM_INST(0x80980260) // fmopa za0.s, p0/M, p0/M, z19.s, z24.s + KAI_ASM_INST(0x80990261) // fmopa za1.s, p0/M, p0/M, z19.s, z25.s + KAI_ASM_INST(0x80980362) // fmopa za2.s, p0/M, p0/M, z27.s, z24.s + KAI_ASM_INST(0x80990363) // fmopa za3.s, p0/M, p0/M, z27.s, z25.s + KAI_ASM_INST(0x809601c0) // fmopa za0.s, p0/M, p0/M, z14.s, z22.s + KAI_ASM_INST(0x809e01c1) // fmopa za1.s, p0/M, p0/M, z14.s, z30.s + KAI_ASM_INST(0x809601e2) // fmopa za2.s, p0/M, p0/M, z15.s, z22.s + KAI_ASM_INST(0x809e01e3) // fmopa za3.s, p0/M, p0/M, z15.s, z30.s +KAI_ASM_LABEL(label_6) // K oddments + cbz x20, label_8 +KAI_ASM_LABEL(label_7) // K oddments: Loop + KAI_ASM_INST(0xa040477c) // ld1w { z28.s-z29.s }, pn9.b/Z, [x27] + subs x20, x20, #0x1 + addvl x27, x27, #2 + KAI_ASM_INST(0xa1404787) // ld1w { z7.s, z15.s }, pn9.b/Z, [x28] + addvl x28, x28, #2 + KAI_ASM_INST(0x80870380) // fmopa za0.s, p0/M, p0/M, z28.s, z7.s + KAI_ASM_INST(0x808f0381) // fmopa za1.s, p0/M, p0/M, z28.s, z15.s + KAI_ASM_INST(0x808703a2) // fmopa za2.s, p0/M, p0/M, z29.s, z7.s + KAI_ASM_INST(0x808f03a3) // fmopa za3.s, p0/M, p0/M, z29.s, z15.s + bgt label_7 +KAI_ASM_LABEL(label_8) // K oddments: End + ldr x26, [x0, #0x10] + sub x25, x13, x14 + cntw x24 + KAI_ASM_INST(0x854ec013) // ld1rw { z19.s }, p0/Z, [x0, #56] + ldr x23, [x0, #0x18] + cmp x25, x24 + KAI_ASM_INST(0x854fc01a) // ld1rw { z26.s }, p0/Z, [x0, #60] + mov x12, #0x0 + csel x22, x25, x24, LT + add x26, x26, x10, LSL #2 // C += n + lsr x21, x22, #0x2 + madd x26, x14, x23, x26 // C += m * ldc + and x20, x22, #0x3 + cbz x21, label_11 +KAI_ASM_LABEL(label_10) // Store to output array: Accumulator row 0 loop + KAI_ASM_INST(0xc0860404) // mova { z4.s-z7.s }, za0h.s[x12] + KAI_ASM_INST(0xc086042c) // mova { z12.s-z15.s }, za1h.s[x12] + KAI_ASM_INST(0xc1baca64) // fclamp { z4.s-z7.s }, z19.s, z26.s + KAI_ASM_INST(0xc1baca6c) // fclamp { z12.s-z15.s }, z19.s, z26.s + add x12, x12, #0x4 + cmp x12, x21, LSL #2 + KAI_ASM_INST(0xa1604344) // st1w { z4.s, z12.s }, p8, [x26] + add x26, x26, x23 + KAI_ASM_INST(0xa1604345) // st1w { z5.s, z13.s }, p8, [x26] + add x26, x26, x23 + KAI_ASM_INST(0xa1604346) // st1w { z6.s, z14.s }, p8, [x26] + add x26, x26, x23 + KAI_ASM_INST(0xa1604347) // st1w { z7.s, z15.s }, p8, [x26] + add x26, x26, x23 + blt label_10 +KAI_ASM_LABEL(label_11) // Store to output array: Accumulator row 0 oddments + cbz x20, label_12 + KAI_ASM_INST(0xc0860400) // mova { z0.s-z3.s }, za0h.s[x12] + KAI_ASM_INST(0xc0860428) // mova { z8.s-z11.s }, za1h.s[x12] + subs x20, x20, #0x1 + KAI_ASM_INST(0xc1baca60) // fclamp { z0.s-z3.s }, z19.s, z26.s + KAI_ASM_INST(0xc1baca68) // fclamp { z8.s-z11.s }, z19.s, z26.s + KAI_ASM_INST(0xa1604340) // st1w { z0.s, z8.s }, p8, [x26] + add x26, x26, x23 + beq label_12 + subs x20, x20, #0x1 + KAI_ASM_INST(0xa1604341) // st1w { z1.s, z9.s }, p8, [x26] + add x26, x26, x23 + beq label_12 + KAI_ASM_INST(0xa1604342) // st1w { z2.s, z10.s }, p8, [x26] + add x26, x26, x23 +KAI_ASM_LABEL(label_12) // Store to output array: Accumulator row 0 oddments: End + subs x25, x25, x22 + beq label_16 + cmp x25, x24 + mov x12, #0x0 + csel x20, x25, x24, LT + lsr x21, x20, #0x2 + and x20, x20, #0x3 + cbz x21, label_14 +KAI_ASM_LABEL(label_13) // Store to output array: Accumulator row 1 loop + KAI_ASM_INST(0xc0860454) // mova { z20.s-z23.s }, za2h.s[x12] + KAI_ASM_INST(0xc086047c) // mova { z28.s-z31.s }, za3h.s[x12] + KAI_ASM_INST(0xc1baca74) // fclamp { z20.s-z23.s }, z19.s, z26.s + KAI_ASM_INST(0xc1baca7c) // fclamp { z28.s-z31.s }, z19.s, z26.s + add x12, x12, #0x4 + cmp x12, x21, LSL #2 + KAI_ASM_INST(0xa1604354) // st1w { z20.s, z28.s }, p8, [x26] + add x26, x26, x23 + KAI_ASM_INST(0xa1604355) // st1w { z21.s, z29.s }, p8, [x26] + add x26, x26, x23 + KAI_ASM_INST(0xa1604356) // st1w { z22.s, z30.s }, p8, [x26] + add x26, x26, x23 + KAI_ASM_INST(0xa1604357) // st1w { z23.s, z31.s }, p8, [x26] + add x26, x26, x23 + blt label_13 +KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments + cbz x20, label_15 + KAI_ASM_INST(0xc0860444) // mova { z4.s-z7.s }, za2h.s[x12] + KAI_ASM_INST(0xc086046c) // mova { z12.s-z15.s }, za3h.s[x12] + subs x20, x20, #0x1 + KAI_ASM_INST(0xc1baca64) // fclamp { z4.s-z7.s }, z19.s, z26.s + KAI_ASM_INST(0xc1baca6c) // fclamp { z12.s-z15.s }, z19.s, z26.s + KAI_ASM_INST(0xa1604344) // st1w { z4.s, z12.s }, p8, [x26] + add x26, x26, x23 + beq label_15 + subs x20, x20, #0x1 + KAI_ASM_INST(0xa1604345) // st1w { z5.s, z13.s }, p8, [x26] + add x26, x26, x23 + beq label_15 + KAI_ASM_INST(0xa1604346) // st1w { z6.s, z14.s }, p8, [x26] +KAI_ASM_LABEL(label_15) // Store to output array: Accumulator row 1 oddments: End +KAI_ASM_LABEL(label_16) // Store to output array: End + incw x10, ALL, MUL #2 + cmp x10, x11 + blt label_2 + incw x14, ALL, MUL #2 + mov x10, #0x0 + cmp x14, x13 + mov x9, x27 + blt label_1 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c index e2636770..434e55e7 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c @@ -4,10 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. @@ -18,327 +15,120 @@ #include "kai/kai_common.h" -static const size_t kai_mr = 2; -static const size_t kai_kr = 1; -static const size_t kai_sr = 1; +typedef struct { + size_t m; + size_t k; + size_t mr; + size_t kr; + size_t sr; + size_t m_idx_start; + const void* lhs; + size_t lhs_stride; + void* lhs_packed; + size_t height; + size_t width; + const void* const* in; + size_t row_offset; + void* out; +} KernelArgs; + +void kai_kernel_lhs_pack_f32p2vlx1_f32_sme(const KernelArgs* args_ptr); + +enum { + MR = 2, + KR = 1, + MAX_M_STEP = (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(float)) / KR), + SR = 1, +}; + +static size_t kai_get_mr_lhs_pack_f32p2vlx1_f32_sme(void) { + return MR * kai_get_sme_vector_length_u32() / KR; +} size_t kai_get_m_step_lhs_pack_f32p2vlx1_f32_sme(size_t mr) { - KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u32()); + KAI_ASSUME(mr == kai_get_mr_lhs_pack_f32p2vlx1_f32_sme()); KAI_UNUSED(mr); - return kai_mr * kai_get_sme_vector_length_u32(); + return kai_get_mr_lhs_pack_f32p2vlx1_f32_sme(); } size_t kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t lhs_stride) { - KAI_ASSUME(m_idx % (kai_mr * kai_get_sme_vector_length_u32()) == 0); + KAI_ASSUME(m_idx % kai_get_mr_lhs_pack_f32p2vlx1_f32_sme() == 0); return m_idx * lhs_stride; } size_t kai_get_lhs_packed_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { - const size_t scaled_mr = kai_mr * kai_get_sme_vector_length_u32(); - KAI_ASSUME(m_idx % scaled_mr == 0); - KAI_ASSUME(mr == scaled_mr); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(m_idx % kai_get_m_step_lhs_pack_f32p2vlx1_f32_sme(mr) == 0); + KAI_ASSUME(mr == kai_get_mr_lhs_pack_f32p2vlx1_f32_sme()); + KAI_ASSUME(kr == KR); + KAI_ASSUME(sr == SR); KAI_UNUSED(mr); KAI_UNUSED(kr); KAI_UNUSED(sr); - return m_idx * k * sizeof(float); + return m_idx * kai_roundup(k, kr) * sizeof(float); } size_t kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { - KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u32()); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(mr == kai_get_mr_lhs_pack_f32p2vlx1_f32_sme()); + KAI_ASSUME(kr == KR); + KAI_ASSUME(sr == SR); KAI_UNUSED(mr); KAI_UNUSED(kr); KAI_UNUSED(sr); - return kai_roundup(m, kai_mr * kai_get_sme_vector_length_u32()) * k * sizeof(float); + return kai_roundup(m, kai_get_mr_lhs_pack_f32p2vlx1_f32_sme()) * kai_roundup(k, KR) * sizeof(float); } void kai_run_lhs_pack_f32p2vlx1_f32_sme( size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { - KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u32()); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(mr == kai_get_mr_lhs_pack_f32p2vlx1_f32_sme()); + KAI_ASSUME(kr == KR); + KAI_ASSUME(sr == SR); KAI_ASSUME(lhs != NULL); KAI_ASSUME(lhs_packed != NULL); KAI_ASSUME(m_idx_start == 0); - const size_t block_height = kai_mr * kai_get_sme_vector_length_u32(); + const size_t m_step = kai_get_mr_lhs_pack_f32p2vlx1_f32_sme(); + const size_t block_height = mr; const size_t width = k; const size_t row_offset = 0; - const void* in[block_height]; + KAI_ASSERT(m_step <= MAX_M_STEP); + const void* in[MAX_M_STEP]; uint8_t* lhs_packed_ptr = lhs_packed; const uint8_t* lhs_ptr = lhs; for (size_t block_y = 0; block_y < m; block_y += block_height) { const size_t height = KAI_MIN(m - block_y, block_height); - void* out = (void*)((char*)lhs_packed_ptr + block_y * k * sizeof(float)); + void* out = lhs_packed_ptr + block_y * kai_roundup(k, KR) * sizeof(float); for (size_t y = 0; y < height; y++) { in[y] = lhs_ptr + (block_y + y) * lhs_stride; } - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x21, %x[width]\n" - "mov x20, %x[width]\n" - "incw x21\n" - "cntw x17\n" - "sub x21, x21, #0x1\n" - "sub x16, x17, #0x1\n" - "udiv x21, x21, x17\n" // n_passes = ceildiv(width, VL) - "ands x16, x20, x16\n" - "sub x20, x21, #0x1\n" - "sub x15, x17, #0x2\n" - "mov x14, #0x0\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - "cntw x9, ALL, MUL #2\n" - "ldr x28, [x11, #0x0]\n" - "cntw x27, ALL, MUL #3\n" - "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 - "ldr x26, [x10, #0x0]\n" - "and x25, x21, #0x1\n" // odd_tail = bool(n_passes & 0x1) - "csel x16, x16, x17, NE\n" - "ldr x24, [x11, #0x8]\n" - "ptrue p12.s\n" - "whilelt p11.s, XZR, %x[height]\n" - "ldr x21, [x10, #0x8]\n" - "whilelt p10.s, x17, %x[height]\n" - "mov x23, %x[row_offset]\n" - "mov x22, %x[out]\n" - "whilelt p9.s, x14, %x[width]\n" - "whilelt p8.s, x14, %x[width]\n" - "add x11, x11, #0x10\n" - "add x10, x10, #0x10\n" - "mov x12, #0x0\n" - "cbz x15, 2f\n" - "1:" // K loop: Charge: Loop - ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" - ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" - "add x12, x12, #0x2\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "cmp x12, x15\n" - "blt 1b\n" - "2:" // K loop: Charge: End - ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" - "ldr x28, [x11, #0x0]\n" - "incw x14\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "incw x23\n" - "cbz x20, 8f\n" - "mov x20, x20\n" - "3:" // K loop: Main loop - "whilelt p8.s, x14, %x[width]\n" - "mov x13, #0x0\n" - "cbz x15, 5f\n" - "4:" // K loop: Main loop: First: Loop - ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" - ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" - ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" - ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" - ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" - ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "add x13, x13, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x13, x15\n" - "blt 4b\n" - "5:" // K loop: Main loop: First: Tail - ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" - ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" - ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" - ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" - "ldr x28, [x11, #0x0]\n" - "mov x12, #0x0\n" - ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" - ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" - "whilelt p9.s, x14, %x[width]\n" - "incw x14\n" - ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "incw x23\n" - ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "addvl x22, x22, #4\n" - "whilelt p8.s, x14, %x[width]\n" - "cbz x15, 7f\n" - "6:" // K loop: Main loop: Second: Loop - ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" - ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" - ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "add x12, x12, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x12, x15\n" - "blt 6b\n" - "7:" // K loop: Main loop: Second: Tail - ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" - ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" - "whilelt p9.s, x14, %x[width]\n" - "subs x20, x20, #0x1\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "incw x14\n" - ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "addvl x22, x22, #4\n" - "incw x23\n" - "bgt 3b\n" - "8:" // K loop: Tails - "cbnz x25, 11f\n" - "mov x11, %x[in]\n" - "whilelt p8.s, x14, %x[width]\n" - "mov x12, #0x0\n" - "9:" // K loop: Tails: Even: First - ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25306161 // psel p1.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306140 // psel p0.s, p8.s/Z, p10.s[w12]\n" - ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0b18ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" - "addvl x22, x22, #2\n" - "ldr x21, [x11, #0x0]\n" - "ldr x20, [x11, x17, LSL #0x3]\n" - "add x11, x11, #0x8\n" - ".inst 0xe09706a8 // ld1w { za2h.s[x12] }, p1/Z, [x21, x23, LSL #2]\n" - ".inst 0xe097028c // ld1w { za3h.s[x12] }, p0/Z, [x20, x23, LSL #2]\n" - "add x12, x12, #0x1\n" - "cmp x12, x17\n" - "blt 9b\n" - "whilelt p9.s, x14, %x[width]\n" - "whilelt p8.s, x14, %x[width]\n" - "mov x12, #0x0\n" - "10:" // K loop: Tails: Even: Second - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0b182cc // st1w { za3v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x16\n" - "blt 10b\n" - "whilelt p8.s, x14, %x[width]\n" - "b 13f\n" - "11:" // K loop: Tails: Odd - "mov x12, #0x0\n" - "12:" // K loop: Tails: Odd: Loop - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0b182c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x16\n" - "blt 12b\n" - "13:" // K loop: End - "mov %x[out], x22\n" - ".inst 0xd503467f // SMSTOP\n" - : [out] "+&r"(out) - : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", - "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", - "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", - "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", - "z26", "z27", "z28", "z29", "z30", "z31"); + KernelArgs args; + args.m = m; + args.k = k; + args.mr = mr; + args.kr = kr; + args.sr = sr; + args.m_idx_start = m_idx_start; + args.lhs = lhs; + args.lhs_stride = lhs_stride; + args.lhs_packed = lhs_packed; + args.height = height; + args.width = width; + args.in = in; + args.row_offset = row_offset; + args.out = out; + + kai_kernel_lhs_pack_f32p2vlx1_f32_sme(&args); } } diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h index 82c5db48..540f24f7 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -61,7 +61,7 @@ size_t kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(size_t m, size_t k, si /// /// @param[in] m Number of rows of the unpacked LHS matrix. /// @param[in] k Common dimension between the LHS and RHS matrix. -/// @param[in] mr Block size in M dimension. It must be 2 * kai_get_sme_vector_length_u32(). +/// @param[in] mr Block size in M dimension. It must be kai_get_m_step_lhs_pack_f32p2vlx1_f32_sme(). /// @param[in] kr Block size in K dimension. It must be 1. /// @param[in] sr Number of kr splits. It must be 1. /// @param[in] m_idx_start Unused. Must be 0. diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S new file mode 100644 index 00000000..4439e11a --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S @@ -0,0 +1,304 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(lhs_pack_f32p2vlx1_f32_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_lhs_pack_f32p2vlx1_f32_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_pack_f32p2vlx1_f32_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_pack_f32p2vlx1_f32_sme) + stp x20, x21, [sp, -80]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x6, #0x0 + ldr x7, [x0, #0x50] + cntw x8 + cntw x17, ALL, MUL #2 + ldr x16, [x0, #0x58] + sub x15, x8, #0x1 + sub x14, x8, #0x2 + ldr x11, [x0, #0x48] + cntw x10, ALL, MUL #3 + ptrue p12.s + mov x21, x7 + ldr x22, [x0, #0x60] + mov x20, x7 + incw x21 + ands x15, x20, x15 + ldr x9, [x0, #0x68] + sub x21, x21, #0x1 + mov x28, x16 + udiv x21, x21, x8 // n_passes = ceildiv(width, VL) + add x27, x16, x8, LSL #3 + ldr x26, [x28, #0x0] + sub x20, x21, #0x1 + and x25, x21, #0x1 // odd_tail = bool(n_passes & 0x1) + ldr x24, [x27, #0x0] + lsr x20, x20, #0x1 // n_loops = (n_passes - 1) / 2 + csel x15, x15, x8, NE + ldr x23, [x28, #0x8] + whilelt p11.s, XZR, x11 + whilelt p10.s, x8, x11 + ldr x21, [x27, #0x8] + mov x22, x22 + whilelt p9.s, x6, x7 + whilelt p8.s, x6, x7 + add x28, x28, #0x10 + add x27, x27, #0x10 + mov x12, #0x0 + cbz x14, label_2 +KAI_ASM_LABEL(label_1) // K loop: Charge: Loop + KAI_ASM_INST(0x25306163) // psel p3.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1] + KAI_ASM_INST(0x25706140) // psel p0.s, p8.s/Z, p10.s[w12, #1] + KAI_ASM_INST(0xe0960f40) // ld1w { za0h.s[x12] }, p3/Z, [x26, x22, LSL #2] + ldr x26, [x28, #0x0] + KAI_ASM_INST(0xe0960b04) // ld1w { za1h.s[x12] }, p2/Z, [x24, x22, LSL #2] + ldr x24, [x27, #0x0] + KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2] + ldr x23, [x28, #0x8] + add x28, x28, #0x10 + KAI_ASM_INST(0xe09602a5) // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x22, LSL #2] + add x12, x12, #0x2 + ldr x21, [x27, #0x8] + add x27, x27, #0x10 + cmp x12, x14 + blt label_1 +KAI_ASM_LABEL(label_2) // K loop: Charge: End + KAI_ASM_INST(0x25306163) // psel p3.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1] + KAI_ASM_INST(0x25706140) // psel p0.s, p8.s/Z, p10.s[w12, #1] + mov x28, x16 + add x27, x16, x8, LSL #3 + KAI_ASM_INST(0xe0960f40) // ld1w { za0h.s[x12] }, p3/Z, [x26, x22, LSL #2] + ldr x26, [x28, #0x0] + incw x6 + KAI_ASM_INST(0xe0960b04) // ld1w { za1h.s[x12] }, p2/Z, [x24, x22, LSL #2] + ldr x24, [x27, #0x0] + KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2] + ldr x23, [x28, #0x8] + add x28, x28, #0x10 + KAI_ASM_INST(0xe09602a5) // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x22, LSL #2] + ldr x21, [x27, #0x8] + add x27, x27, #0x10 + incw x22 + cbz x20, label_8 + mov x20, x20 +KAI_ASM_LABEL(label_3) // K loop: Main loop + whilelt p8.s, x6, x7 + mov x13, #0x0 + cbz x14, label_5 +KAI_ASM_LABEL(label_4) // K loop: Main loop: First: Loop + KAI_ASM_INST(0x25316160) // psel p0.s, p8.s/Z, p11.s[w13] + KAI_ASM_INST(0x25316142) // psel p2.s, p8.s/Z, p10.s[w13] + KAI_ASM_INST(0x25716161) // psel p1.s, p8.s/Z, p11.s[w13, #1] + KAI_ASM_INST(0x25716143) // psel p3.s, p8.s/Z, p10.s[w13, #1] + KAI_ASM_INST(0xe0962348) // ld1w { za2h.s[x13] }, p0/Z, [x26, x22, LSL #2] + KAI_ASM_INST(0x25317120) // psel p0.s, p12.s/Z, p9.s[w13] + ldr x26, [x28, #0x0] + KAI_ASM_INST(0xe0962b0c) // ld1w { za3h.s[x13] }, p2/Z, [x24, x22, LSL #2] + KAI_ASM_INST(0x25317122) // psel p2.s, p12.s/Z, p9.s[w13] + ldr x24, [x27, #0x0] + KAI_ASM_INST(0xe09626e9) // ld1w { za2h.s[x13, #1] }, p1/Z, [x23, x22, LSL #2] + KAI_ASM_INST(0x25717121) // psel p1.s, p12.s/Z, p9.s[w13, #1] + ldr x23, [x28, #0x8] + add x28, x28, #0x10 + KAI_ASM_INST(0xe0962ead) // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x22, LSL #2] + ldr x21, [x27, #0x8] + KAI_ASM_INST(0xe0bfa120) // st1w { za0v.s[x13] }, p0/Z, [x9, XZR, LSL #2] + KAI_ASM_INST(0x25717120) // psel p0.s, p12.s/Z, p9.s[w13, #1] + KAI_ASM_INST(0xe0a8a924) // st1w { za1v.s[x13] }, p2/Z, [x9, x8, LSL #2] + add x27, x27, #0x10 + KAI_ASM_INST(0xe0b1a521) // st1w { za0v.s[x13, #1] }, p1/Z, [x9, x17, LSL #2] + KAI_ASM_INST(0xe0aaa125) // st1w { za1v.s[x13, #1] }, p0/Z, [x9, x10, LSL #2] + add x13, x13, #0x2 + addvl x9, x9, #4 + cmp x13, x14 + blt label_4 +KAI_ASM_LABEL(label_5) // K loop: Main loop: First: Tail + KAI_ASM_INST(0x25316160) // psel p0.s, p8.s/Z, p11.s[w13] + KAI_ASM_INST(0x25316142) // psel p2.s, p8.s/Z, p10.s[w13] + KAI_ASM_INST(0x25716161) // psel p1.s, p8.s/Z, p11.s[w13, #1] + KAI_ASM_INST(0x25716143) // psel p3.s, p8.s/Z, p10.s[w13, #1] + mov x28, x16 + add x27, x16, x8, LSL #3 + KAI_ASM_INST(0xe0962348) // ld1w { za2h.s[x13] }, p0/Z, [x26, x22, LSL #2] + KAI_ASM_INST(0x25317120) // psel p0.s, p12.s/Z, p9.s[w13] + ldr x26, [x28, #0x0] + mov x12, #0x0 + KAI_ASM_INST(0xe0962b0c) // ld1w { za3h.s[x13] }, p2/Z, [x24, x22, LSL #2] + KAI_ASM_INST(0x25317122) // psel p2.s, p12.s/Z, p9.s[w13] + ldr x24, [x27, #0x0] + KAI_ASM_INST(0xe09626e9) // ld1w { za2h.s[x13, #1] }, p1/Z, [x23, x22, LSL #2] + KAI_ASM_INST(0x25717121) // psel p1.s, p12.s/Z, p9.s[w13, #1] + ldr x23, [x28, #0x8] + add x28, x28, #0x10 + KAI_ASM_INST(0xe0962ead) // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x22, LSL #2] + ldr x21, [x27, #0x8] + KAI_ASM_INST(0xe0bfa120) // st1w { za0v.s[x13] }, p0/Z, [x9, XZR, LSL #2] + KAI_ASM_INST(0x25717120) // psel p0.s, p12.s/Z, p9.s[w13, #1] + KAI_ASM_INST(0xe0a8a924) // st1w { za1v.s[x13] }, p2/Z, [x9, x8, LSL #2] + whilelt p9.s, x6, x7 + incw x6 + KAI_ASM_INST(0xe0b1a521) // st1w { za0v.s[x13, #1] }, p1/Z, [x9, x17, LSL #2] + add x27, x27, #0x10 + incw x22 + KAI_ASM_INST(0xe0aaa125) // st1w { za1v.s[x13, #1] }, p0/Z, [x9, x10, LSL #2] + addvl x9, x9, #4 + whilelt p8.s, x6, x7 + cbz x14, label_7 +KAI_ASM_LABEL(label_6) // K loop: Main loop: Second: Loop + KAI_ASM_INST(0x25306160) // psel p0.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1] + KAI_ASM_INST(0x25706143) // psel p3.s, p8.s/Z, p10.s[w12, #1] + KAI_ASM_INST(0xe0960340) // ld1w { za0h.s[x12] }, p0/Z, [x26, x22, LSL #2] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + ldr x26, [x28, #0x0] + KAI_ASM_INST(0xe0960b04) // ld1w { za1h.s[x12] }, p2/Z, [x24, x22, LSL #2] + KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12] + ldr x24, [x27, #0x0] + KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2] + KAI_ASM_INST(0x25707121) // psel p1.s, p12.s/Z, p9.s[w12, #1] + ldr x23, [x28, #0x8] + add x28, x28, #0x10 + KAI_ASM_INST(0xe0960ea5) // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x22, LSL #2] + ldr x21, [x27, #0x8] + KAI_ASM_INST(0xe0bf8128) // st1w { za2v.s[x12] }, p0/Z, [x9, XZR, LSL #2] + KAI_ASM_INST(0x25707120) // psel p0.s, p12.s/Z, p9.s[w12, #1] + KAI_ASM_INST(0xe0a8892c) // st1w { za3v.s[x12] }, p2/Z, [x9, x8, LSL #2] + add x27, x27, #0x10 + KAI_ASM_INST(0xe0b18529) // st1w { za2v.s[x12, #1] }, p1/Z, [x9, x17, LSL #2] + KAI_ASM_INST(0xe0aa812d) // st1w { za3v.s[x12, #1] }, p0/Z, [x9, x10, LSL #2] + add x12, x12, #0x2 + addvl x9, x9, #4 + cmp x12, x14 + blt label_6 +KAI_ASM_LABEL(label_7) // K loop: Main loop: Second: Tail + KAI_ASM_INST(0x25306160) // psel p0.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1] + KAI_ASM_INST(0x25706143) // psel p3.s, p8.s/Z, p10.s[w12, #1] + mov x28, x16 + add x27, x16, x8, LSL #3 + KAI_ASM_INST(0xe0960340) // ld1w { za0h.s[x12] }, p0/Z, [x26, x22, LSL #2] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + ldr x26, [x28, #0x0] + KAI_ASM_INST(0xe0960b04) // ld1w { za1h.s[x12] }, p2/Z, [x24, x22, LSL #2] + KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12] + ldr x24, [x27, #0x0] + KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2] + KAI_ASM_INST(0x25707121) // psel p1.s, p12.s/Z, p9.s[w12, #1] + ldr x23, [x28, #0x8] + add x28, x28, #0x10 + KAI_ASM_INST(0xe0960ea5) // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x22, LSL #2] + ldr x21, [x27, #0x8] + KAI_ASM_INST(0xe0bf8128) // st1w { za2v.s[x12] }, p0/Z, [x9, XZR, LSL #2] + KAI_ASM_INST(0x25707120) // psel p0.s, p12.s/Z, p9.s[w12, #1] + KAI_ASM_INST(0xe0a8892c) // st1w { za3v.s[x12] }, p2/Z, [x9, x8, LSL #2] + whilelt p9.s, x6, x7 + subs x20, x20, #0x1 + KAI_ASM_INST(0xe0b18529) // st1w { za2v.s[x12, #1] }, p1/Z, [x9, x17, LSL #2] + add x27, x27, #0x10 + incw x6 + KAI_ASM_INST(0xe0aa812d) // st1w { za3v.s[x12, #1] }, p0/Z, [x9, x10, LSL #2] + addvl x9, x9, #4 + incw x22 + bgt label_3 +KAI_ASM_LABEL(label_8) // K loop: Tails + cbnz x25, label_11 + mov x28, x16 + whilelt p8.s, x6, x7 + mov x12, #0x0 +KAI_ASM_LABEL(label_9) // K loop: Tails: Even: First + KAI_ASM_INST(0x25307123) // psel p3.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306161) // psel p1.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306140) // psel p0.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0xe0bf8d20) // st1w { za0v.s[x12] }, p3/Z, [x9, XZR, LSL #2] + KAI_ASM_INST(0xe0a88924) // st1w { za1v.s[x12] }, p2/Z, [x9, x8, LSL #2] + addvl x9, x9, #2 + ldr x21, [x28, #0x0] + ldr x20, [x28, x8, LSL #0x3] + add x28, x28, #0x8 + KAI_ASM_INST(0xe09606a8) // ld1w { za2h.s[x12] }, p1/Z, [x21, x22, LSL #2] + KAI_ASM_INST(0xe096028c) // ld1w { za3h.s[x12] }, p0/Z, [x20, x22, LSL #2] + add x12, x12, #0x1 + cmp x12, x8 + blt label_9 + whilelt p9.s, x6, x7 + whilelt p8.s, x6, x7 + mov x12, #0x0 +KAI_ASM_LABEL(label_10) // K loop: Tails: Even: Second + KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf8528) // st1w { za2v.s[x12] }, p1/Z, [x9, XZR, LSL #2] + KAI_ASM_INST(0xe0a8812c) // st1w { za3v.s[x12] }, p0/Z, [x9, x8, LSL #2] + add x12, x12, #0x1 + addvl x9, x9, #2 + cmp x12, x15 + blt label_10 + whilelt p8.s, x6, x7 + b label_13 +KAI_ASM_LABEL(label_11) // K loop: Tails: Odd + mov x12, #0x0 +KAI_ASM_LABEL(label_12) // K loop: Tails: Odd: Loop + KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf8520) // st1w { za0v.s[x12] }, p1/Z, [x9, XZR, LSL #2] + KAI_ASM_INST(0xe0a88124) // st1w { za1v.s[x12] }, p0/Z, [x9, x8, LSL #2] + add x12, x12, #0x1 + addvl x9, x9, #2 + cmp x12, x15 + blt label_12 +KAI_ASM_LABEL(label_13) // K loop: End + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_lhs_pack_f32p2vlx1_f32_sme) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c index 55e49474..9e3c013c 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c @@ -4,10 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. - #include "kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" #include @@ -15,46 +14,64 @@ #include "kai/kai_common.h" -static const size_t kai_nr = 2; -static const size_t kai_kr = 1; -static const size_t kai_data_size_in_bytes = sizeof(uint32_t); -static const size_t kai_bias_size_in_bytes = sizeof(uint32_t); +enum { + NR = 2, + KR = 1, +}; + +typedef struct { + const void* bias_ptr; + size_t width; + size_t height; + size_t in_stride; + size_t out_stride; + const void* in; + void* out; +} KernelArgs; + +static const size_t kai_num_bytes_input = sizeof(uint32_t); +static const size_t kai_num_bytes_output = sizeof(uint32_t); +static const size_t kai_num_bytes_bias = sizeof(float); + +void kai_kernel_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(const KernelArgs* args_ptr); size_t kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return NR * kai_get_sme_vector_length_u32() / KR; } size_t kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx) { - KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u32()) == 0); + KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme() == 0); - return n_idx * kai_data_size_in_bytes; + return n_idx * kai_num_bytes_input; } size_t kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx) { - return n_idx * kai_bias_size_in_bytes; + return n_idx * kai_num_bytes_bias; } size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t k) { - return kai_nr * kai_get_sme_vector_length_u32() * (kai_bias_size_in_bytes + k * kai_data_size_in_bytes); + return kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme() * + (kai_num_bytes_bias + kai_roundup(k, KR) * kai_num_bytes_output); } size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t k) { - KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u32()) == 0); + KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme() == 0); - return n_idx * (kai_bias_size_in_bytes + k * kai_data_size_in_bytes); + const size_t block_idx = n_idx / kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(); + return block_idx * kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(k); } size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n, size_t k) { - return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( - kai_roundup(n, kai_nr * kai_get_sme_vector_length_u32()), k); + const size_t n_nr_blocks = kai_roundup(n, kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme()); + return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(n_nr_blocks, k); } void kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride_row, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { KAI_ASSUME(num_groups == 1); - KAI_ASSUME(nr == kai_nr * kai_get_sme_vector_length_u32()); - KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(nr == kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme()); + KAI_ASSUME(kr == KR); KAI_ASSUME(sr == 1); KAI_ASSUME(rhs != NULL); KAI_ASSUME(bias != NULL); @@ -63,109 +80,16 @@ void kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( KAI_ASSUME(extra_bytes == 0); KAI_ASSUME(params == NULL); - size_t height = k; - const size_t width = n; - const void* in = rhs; - void* out = rhs_packed; - const size_t in_stride = rhs_stride; - size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(height); - - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x22, %x[out]\n" - "mov x21, %x[width]\n" - "ptrue p2.b\n" - "1:" // Bias: Full loop - "mov x20, x21\n" - "decw x21, ALL, MUL #2\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "cmp x21, #0x0\n" - "ld1w { z17.s }, p1/Z, [%x[bias]]\n" - "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" - "incb %x[bias], ALL, MUL #2\n" - "st1w { z17.s }, p2, [x22]\n" - "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" - "add x22, x22, %x[out_stride]\n" - "bgt 1b\n" - "cmp %x[height], #0x4\n" - "incb %x[out], ALL, MUL #2\n" - "blt 5f\n" - "2:" // Main row loop: Head - "mov x26, %x[in]\n" - "mov x25, %x[out]\n" - "add x24, x26, %x[in_stride]\n" - "sub %x[height], %x[height], #0x4\n" - "add x23, x24, %x[in_stride]\n" - "mov x22, %x[width]\n" - "add x21, x23, %x[in_stride]\n" - "add %x[in], x21, %x[in_stride]\n" - "3:" // Main row loop: Column loop - "mov x20, x22\n" - "decw x22, ALL, MUL #2\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "cmp x22, #0x0\n" - "ld1w { z23.s }, p1/Z, [x26]\n" - "ld1w { z22.s }, p0/Z, [x26, #1, MUL VL]\n" - "addvl x26, x26, #2\n" - "ld1w { z21.s }, p1/Z, [x24]\n" - "ld1w { z20.s }, p0/Z, [x24, #1, MUL VL]\n" - "addvl x24, x24, #2\n" - "ld1w { z19.s }, p1/Z, [x23]\n" - "ld1w { z18.s }, p0/Z, [x23, #1, MUL VL]\n" - "addvl x23, x23, #2\n" - "ld1w { z17.s }, p1/Z, [x21]\n" - "ld1w { z16.s }, p0/Z, [x21, #1, MUL VL]\n" - "addvl x21, x21, #2\n" - "st1w { z23.s }, p2, [x25]\n" - "st1w { z22.s }, p2, [x25, #1, MUL VL]\n" - "st1w { z21.s }, p2, [x25, #2, MUL VL]\n" - "st1w { z20.s }, p2, [x25, #3, MUL VL]\n" - "st1w { z19.s }, p2, [x25, #4, MUL VL]\n" - "st1w { z18.s }, p2, [x25, #5, MUL VL]\n" - "st1w { z17.s }, p2, [x25, #6, MUL VL]\n" - "st1w { z16.s }, p2, [x25, #7, MUL VL]\n" - "add x25, x25, %x[out_stride]\n" - "bgt 3b\n" - "cmp %x[height], #0x4\n" - "addvl %x[out], %x[out], #8\n" - "bge 2b\n" - "cbz %x[height], 9f\n" - "5:" // Main loop skip - "6:" // Tail row loop: Head - "mov x26, %x[in]\n" - "mov x25, %x[out]\n" - "add %x[in], x26, %x[in_stride]\n" - "sub %x[height], %x[height], #0x1\n" - "mov x21, %x[width]\n" - "7:" // Tail row loop: Column loop - "mov x20, x21\n" - "decw x21, ALL, MUL #2\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "cmp x21, #0x0\n" - "ld1w { z17.s }, p1/Z, [x26]\n" - "ld1w { z16.s }, p0/Z, [x26, #1, MUL VL]\n" - "addvl x26, x26, #2\n" - "st1w { z17.s }, p2, [x25]\n" - "st1w { z16.s }, p2, [x25, #1, MUL VL]\n" - "add x25, x25, %x[out_stride]\n" - "bgt 7b\n" - "cmp %x[height], #0x1\n" - "addvl %x[out], %x[out], #2\n" - "bge 6b\n" - "9:" // Done - ".inst 0xd503467f // SMSTOP\n" - : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) - : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", - "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", - "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", - "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + KernelArgs args; + args.bias_ptr = bias; + args.height = k; + args.width = n; + args.in = rhs; + args.out = rhs_packed; + args.in_stride = rhs_stride_row; + args.out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(args.height); + + kai_kernel_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h index 602e24c2..600f465d 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -14,14 +14,14 @@ extern "C" { /// Gets n step value. /// -/// The starting row index must be divisible by `n_step`. +/// The starting column index must be divisible by `n_step`. /// /// @return The n step value. size_t kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. /// -/// @param[in] n_idx Column index. +/// @param[in] n_idx Column index. Must be divisible by `n_step` /// /// @return The offset in bytes to the data element. size_t kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx); @@ -42,16 +42,16 @@ size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_ /// Gets the offset in bytes to the data element in the packed RHS buffer. /// -/// @param[in] n_idx Row index. -/// @param[in] k Number of columns. +/// @param[in] n_idx Column index. Must be divisible by `n_step` +/// @param[in] k Number of rows. /// /// @return The offset in bytes to the data element. size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t k); /// Gets the size in bytes of the packed RHS buffer. /// -/// @param[in] n Number of rows. -/// @param[in] k Number of columns. +/// @param[in] n Number of columns. +/// @param[in] k Number of rows. /// /// @return The size in bytes of the packed RHS buffer. size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n, size_t k); @@ -62,24 +62,24 @@ size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t /// calculated using the following functions: /// /// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme. -/// * Bias: @ref kai_get_packed_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme. -/// * Output: @ref kai_get_dst_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme. +/// * Bias: @ref kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme. /// /// @param[in] num_groups Number of groups. It must be 1. /// @param[in] n Number of columns of the output matrix. -/// @param[in] k Common dimension between the LHS and RHS matrix. -/// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u32(). +/// @param[in] k Number of rows. +/// @param[in] nr Block size in N dimension. It must be `get_n_step` /// @param[in] kr Block size in K dimension. It must be 1. /// @param[in] sr Number of kr splits. It must be 1. -/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs_stride_row Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. /// @param[in] scale Scale data buffer. It must be NULL. /// @param[out] rhs_packed Packed RHS matrix. /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. -/// @param[in] params Extra packing parameters. It must be NULL. +/// @param[in] params Packing parameters. It must be NULL. void kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride_row, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); #ifdef __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S new file mode 100644 index 00000000..387f50d8 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S @@ -0,0 +1,156 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme) + stp x20, x21, [sp, -80]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + ldr x12, [x0, #0x8] + ptrue p2.b + ldr x11, [x0, #0x30] + ldr x24, [x0, #0x0] + ldr x23, [x0, #0x10] + mov x22, x12 + ldr x10, [x0, #0x18] + mov x21, x11 + ldr x9, [x0, #0x20] + ldr x28, [x0, #0x28] +KAI_ASM_LABEL(label_1) // Bias: Full loop + mov x20, x22 + decw x22, ALL, MUL #2 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + cmp x22, #0x0 + ld1w { z17.s }, p1/Z, [x24] + ld1w { z16.s }, p0/Z, [x24, #1, MUL VL] + incb x24, ALL, MUL #2 + st1w { z17.s }, p2, [x21] + st1w { z16.s }, p2, [x21, #1, MUL VL] + add x21, x21, x9 + bgt label_1 + mov x27, x23 + incb x11, ALL, MUL #2 + cmp x27, #0x4 + blt label_5 +KAI_ASM_LABEL(label_2) // Main row loop: Head + mov x26, x28 + mov x25, x11 + add x24, x26, x10 + sub x27, x27, #0x4 + add x23, x24, x10 + mov x22, x12 + add x21, x23, x10 + add x28, x21, x10 +KAI_ASM_LABEL(label_3) // Main row loop: Column loop + mov x20, x22 + decw x22, ALL, MUL #2 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + cmp x22, #0x0 + ld1w { z23.s }, p1/Z, [x26] + ld1w { z22.s }, p0/Z, [x26, #1, MUL VL] + addvl x26, x26, #2 + ld1w { z21.s }, p1/Z, [x24] + ld1w { z20.s }, p0/Z, [x24, #1, MUL VL] + addvl x24, x24, #2 + ld1w { z19.s }, p1/Z, [x23] + ld1w { z18.s }, p0/Z, [x23, #1, MUL VL] + addvl x23, x23, #2 + ld1w { z17.s }, p1/Z, [x21] + ld1w { z16.s }, p0/Z, [x21, #1, MUL VL] + addvl x21, x21, #2 + st1w { z23.s }, p2, [x25] + st1w { z22.s }, p2, [x25, #1, MUL VL] + st1w { z21.s }, p2, [x25, #2, MUL VL] + st1w { z20.s }, p2, [x25, #3, MUL VL] + st1w { z19.s }, p2, [x25, #4, MUL VL] + st1w { z18.s }, p2, [x25, #5, MUL VL] + st1w { z17.s }, p2, [x25, #6, MUL VL] + st1w { z16.s }, p2, [x25, #7, MUL VL] + add x25, x25, x9 + bgt label_3 + cmp x27, #0x4 + addvl x11, x11, #8 + bge label_2 + cbz x27, label_9 +KAI_ASM_LABEL(label_5) // Main loop skip +KAI_ASM_LABEL(label_6) // Tail row loop: Head + mov x26, x28 + cntw x22, ALL, MUL #8 + add x28, x26, x10 + mov x25, x11 + sub x27, x27, #0x1 + mov x21, x12 +KAI_ASM_LABEL(label_7) // Tail row loop: Column loop + mov x20, x21 + decw x21, ALL, MUL #2 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + cmp x21, #0x0 + ld1w { z17.s }, p1/Z, [x26] + ld1w { z16.s }, p0/Z, [x26, #1, MUL VL] + add x26, x26, x22 + st1w { z17.s }, p2, [x25] + st1w { z16.s }, p2, [x25, #1, MUL VL] + add x25, x25, x9 + bgt label_7 + cmp x27, #0x1 + addvl x11, x11, #2 + bge label_6 +KAI_ASM_LABEL(label_9) // Done + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme) + + KAI_ASM_END -- GitLab From 208590ead0679840d51689e689006f6ded64ed6a Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 4 Jun 2025 11:14:17 +0200 Subject: [PATCH 2/6] Fix assembly files not saving/restoring all neccessary registers. Signed-off-by: Jens Elofsson --- ..._f32_f32_f32p16vlx1b_1x16vl_sme2_mla_asm.S | 12 +- ...p_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_asm.S | 12 +- ...f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_asm.S | 162 +++++++++--------- .../pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S | 12 +- ...ack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S | 12 +- 5 files changed, 121 insertions(+), 89 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla_asm.S index 44f1aeee..4ee2ec11 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla) KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA mov x8, #0x0 ldr x11, [x0, #0x18] @@ -504,7 +508,11 @@ KAI_ASM_LABEL(label_28) // Exit ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_asm.S index ffdce16e..0e897be4 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla) KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA mov x8, #0x0 ldr x5, [x0, #0x18] @@ -756,7 +760,11 @@ KAI_ASM_LABEL(label_28) // Exit ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_asm.S index 23d32b15..75cefc19 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_asm.S @@ -42,94 +42,90 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa) KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA - mov x14, #0x0 + mov x15, #0x0 ptrue p0.b KAI_ASM_INST(0x25207811) // ptrue pn9.b + ldr x14, [x0, #0x30] ldr w13, [x0, #0x20] - ldr w11, [x0, #0x28] - mov x10, #0x0 + mov x11, #0x0 + ldr w10, [x0, #0x28] ldr x9, [x0, #0x0] KAI_ASM_LABEL(label_1) // M loop ldr x28, [x0, #0x8] KAI_ASM_LABEL(label_2) // N loop - KAI_ASM_INST(0x25ab4550) // whilelt pn8.s, x10, x11, VLx2 - fmov z13.s, #1.0 + fmov z22.s, #1.0 + KAI_ASM_INST(0xa040478c) // ld1w { z12.s-z13.s }, pn9.b/Z, [x28] KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } mov x27, x9 - KAI_ASM_INST(0xa040438e) // ld1w { z14.s-z15.s }, p8/Z, [x28] // Load bias + KAI_ASM_INST(0x25aa4570) // whilelt pn8.s, x11, x10, VLx2 addvl x28, x28, #2 - KAI_ASM_INST(0x808e01a0) // fmopa za0.s, p0/M, p0/M, z13.s, z14.s - KAI_ASM_INST(0x808f01a1) // fmopa za1.s, p0/M, p0/M, z13.s, z15.s - KAI_ASM_INST(0x808e01a2) // fmopa za2.s, p0/M, p0/M, z13.s, z14.s - KAI_ASM_INST(0x808f01a3) // fmopa za3.s, p0/M, p0/M, z13.s, z15.s - ldr x20, [x0, #0x30] - lsr x21, x20, #0x2 - and x20, x20, #0x3 + KAI_ASM_INST(0x808c02c0) // fmopa za0.s, p0/M, p0/M, z22.s, z12.s + KAI_ASM_INST(0x808d02c1) // fmopa za1.s, p0/M, p0/M, z22.s, z13.s + KAI_ASM_INST(0x808c02c2) // fmopa za2.s, p0/M, p0/M, z22.s, z12.s + KAI_ASM_INST(0x808d02c3) // fmopa za3.s, p0/M, p0/M, z22.s, z13.s + lsr x21, x14, #0x2 + and x20, x14, #0x3 cbz x21, label_6 subs x21, x21, #0x1 - KAI_ASM_INST(0xa1404772) // ld1w { z18.s, z26.s }, pn9.b/Z, [x27] - KAI_ASM_INST(0xa0404794) // ld1w { z20.s-z21.s }, pn9.b/Z, [x28] - KAI_ASM_INST(0xa1414764) // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL] - KAI_ASM_INST(0xa041478a) // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL] - KAI_ASM_INST(0xa1424773) // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL] - KAI_ASM_INST(0xa0424798) // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL] - KAI_ASM_INST(0xa043476e) // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL] + KAI_ASM_INST(0xa040c764) // ld1w { z4.s-z7.s }, pn9.b/Z, [x27] + KAI_ASM_INST(0xa041c768) // ld1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x4, MUL VL] addvl x27, x27, #8 - KAI_ASM_INST(0xa1434796) // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL] + KAI_ASM_INST(0xa040c794) // ld1w { z20.s-z23.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xa041c78c) // ld1w { z12.s-z15.s }, pn9.b/Z, [x28, #0x4, MUL VL] addvl x28, x28, #8 ble label_5 KAI_ASM_LABEL(label_4) // K loop - KAI_ASM_INST(0x80940240) // fmopa za0.s, p0/M, p0/M, z18.s, z20.s + KAI_ASM_INST(0x80940080) // fmopa za0.s, p0/M, p0/M, z4.s, z20.s subs x21, x21, #0x1 - KAI_ASM_INST(0x80950241) // fmopa za1.s, p0/M, p0/M, z18.s, z21.s - KAI_ASM_INST(0x80940342) // fmopa za2.s, p0/M, p0/M, z26.s, z20.s - KAI_ASM_INST(0x80950343) // fmopa za3.s, p0/M, p0/M, z26.s, z21.s - KAI_ASM_INST(0xa1404772) // ld1w { z18.s, z26.s }, pn9.b/Z, [x27] - KAI_ASM_INST(0x808a0080) // fmopa za0.s, p0/M, p0/M, z4.s, z10.s - KAI_ASM_INST(0xa0404794) // ld1w { z20.s-z21.s }, pn9.b/Z, [x28] - KAI_ASM_INST(0x808b0081) // fmopa za1.s, p0/M, p0/M, z4.s, z11.s - KAI_ASM_INST(0x808a0182) // fmopa za2.s, p0/M, p0/M, z12.s, z10.s - KAI_ASM_INST(0x808b0183) // fmopa za3.s, p0/M, p0/M, z12.s, z11.s - KAI_ASM_INST(0xa1414764) // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL] - KAI_ASM_INST(0x80980260) // fmopa za0.s, p0/M, p0/M, z19.s, z24.s - KAI_ASM_INST(0xa041478a) // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL] - KAI_ASM_INST(0x80990261) // fmopa za1.s, p0/M, p0/M, z19.s, z25.s - KAI_ASM_INST(0x80980362) // fmopa za2.s, p0/M, p0/M, z27.s, z24.s - KAI_ASM_INST(0x80990363) // fmopa za3.s, p0/M, p0/M, z27.s, z25.s - KAI_ASM_INST(0xa1424773) // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL] - KAI_ASM_INST(0xa0424798) // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL] - KAI_ASM_INST(0x809601c0) // fmopa za0.s, p0/M, p0/M, z14.s, z22.s - KAI_ASM_INST(0x809e01c1) // fmopa za1.s, p0/M, p0/M, z14.s, z30.s - KAI_ASM_INST(0x809601e2) // fmopa za2.s, p0/M, p0/M, z15.s, z22.s - KAI_ASM_INST(0x809e01e3) // fmopa za3.s, p0/M, p0/M, z15.s, z30.s - KAI_ASM_INST(0xa043476e) // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL] + KAI_ASM_INST(0x80950081) // fmopa za1.s, p0/M, p0/M, z4.s, z21.s + KAI_ASM_INST(0x809400a2) // fmopa za2.s, p0/M, p0/M, z5.s, z20.s + KAI_ASM_INST(0x809500a3) // fmopa za3.s, p0/M, p0/M, z5.s, z21.s + KAI_ASM_INST(0x809600c0) // fmopa za0.s, p0/M, p0/M, z6.s, z22.s + KAI_ASM_INST(0x809700c1) // fmopa za1.s, p0/M, p0/M, z6.s, z23.s + KAI_ASM_INST(0x809600e2) // fmopa za2.s, p0/M, p0/M, z7.s, z22.s + KAI_ASM_INST(0x809700e3) // fmopa za3.s, p0/M, p0/M, z7.s, z23.s + KAI_ASM_INST(0xa040c764) // ld1w { z4.s-z7.s }, pn9.b/Z, [x27] + KAI_ASM_INST(0x808c0100) // fmopa za0.s, p0/M, p0/M, z8.s, z12.s + KAI_ASM_INST(0xa040c794) // ld1w { z20.s-z23.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0x808d0101) // fmopa za1.s, p0/M, p0/M, z8.s, z13.s + KAI_ASM_INST(0x808c0122) // fmopa za2.s, p0/M, p0/M, z9.s, z12.s + KAI_ASM_INST(0x808d0123) // fmopa za3.s, p0/M, p0/M, z9.s, z13.s + KAI_ASM_INST(0x808e0140) // fmopa za0.s, p0/M, p0/M, z10.s, z14.s + KAI_ASM_INST(0x808f0141) // fmopa za1.s, p0/M, p0/M, z10.s, z15.s + KAI_ASM_INST(0x808e0162) // fmopa za2.s, p0/M, p0/M, z11.s, z14.s + KAI_ASM_INST(0x808f0163) // fmopa za3.s, p0/M, p0/M, z11.s, z15.s + KAI_ASM_INST(0xa041c768) // ld1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x4, MUL VL] addvl x27, x27, #8 - KAI_ASM_INST(0xa1434796) // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL] + KAI_ASM_INST(0xa041c78c) // ld1w { z12.s-z15.s }, pn9.b/Z, [x28, #0x4, MUL VL] addvl x28, x28, #8 bgt label_4 KAI_ASM_LABEL(label_5) // K loop tail - KAI_ASM_INST(0x80940240) // fmopa za0.s, p0/M, p0/M, z18.s, z20.s - KAI_ASM_INST(0x80950241) // fmopa za1.s, p0/M, p0/M, z18.s, z21.s - KAI_ASM_INST(0x80940342) // fmopa za2.s, p0/M, p0/M, z26.s, z20.s - KAI_ASM_INST(0x80950343) // fmopa za3.s, p0/M, p0/M, z26.s, z21.s - KAI_ASM_INST(0x808a0080) // fmopa za0.s, p0/M, p0/M, z4.s, z10.s - KAI_ASM_INST(0x808b0081) // fmopa za1.s, p0/M, p0/M, z4.s, z11.s - KAI_ASM_INST(0x808a0182) // fmopa za2.s, p0/M, p0/M, z12.s, z10.s - KAI_ASM_INST(0x808b0183) // fmopa za3.s, p0/M, p0/M, z12.s, z11.s - KAI_ASM_INST(0x80980260) // fmopa za0.s, p0/M, p0/M, z19.s, z24.s - KAI_ASM_INST(0x80990261) // fmopa za1.s, p0/M, p0/M, z19.s, z25.s - KAI_ASM_INST(0x80980362) // fmopa za2.s, p0/M, p0/M, z27.s, z24.s - KAI_ASM_INST(0x80990363) // fmopa za3.s, p0/M, p0/M, z27.s, z25.s - KAI_ASM_INST(0x809601c0) // fmopa za0.s, p0/M, p0/M, z14.s, z22.s - KAI_ASM_INST(0x809e01c1) // fmopa za1.s, p0/M, p0/M, z14.s, z30.s - KAI_ASM_INST(0x809601e2) // fmopa za2.s, p0/M, p0/M, z15.s, z22.s - KAI_ASM_INST(0x809e01e3) // fmopa za3.s, p0/M, p0/M, z15.s, z30.s + KAI_ASM_INST(0x80940080) // fmopa za0.s, p0/M, p0/M, z4.s, z20.s + KAI_ASM_INST(0x80950081) // fmopa za1.s, p0/M, p0/M, z4.s, z21.s + KAI_ASM_INST(0x809400a2) // fmopa za2.s, p0/M, p0/M, z5.s, z20.s + KAI_ASM_INST(0x809500a3) // fmopa za3.s, p0/M, p0/M, z5.s, z21.s + KAI_ASM_INST(0x809600c0) // fmopa za0.s, p0/M, p0/M, z6.s, z22.s + KAI_ASM_INST(0x809700c1) // fmopa za1.s, p0/M, p0/M, z6.s, z23.s + KAI_ASM_INST(0x809600e2) // fmopa za2.s, p0/M, p0/M, z7.s, z22.s + KAI_ASM_INST(0x809700e3) // fmopa za3.s, p0/M, p0/M, z7.s, z23.s + KAI_ASM_INST(0x808c0100) // fmopa za0.s, p0/M, p0/M, z8.s, z12.s + KAI_ASM_INST(0x808d0101) // fmopa za1.s, p0/M, p0/M, z8.s, z13.s + KAI_ASM_INST(0x808c0122) // fmopa za2.s, p0/M, p0/M, z9.s, z12.s + KAI_ASM_INST(0x808d0123) // fmopa za3.s, p0/M, p0/M, z9.s, z13.s + KAI_ASM_INST(0x808e0140) // fmopa za0.s, p0/M, p0/M, z10.s, z14.s + KAI_ASM_INST(0x808f0141) // fmopa za1.s, p0/M, p0/M, z10.s, z15.s + KAI_ASM_INST(0x808e0162) // fmopa za2.s, p0/M, p0/M, z11.s, z14.s + KAI_ASM_INST(0x808f0163) // fmopa za3.s, p0/M, p0/M, z11.s, z15.s KAI_ASM_LABEL(label_6) // K oddments cbz x20, label_8 KAI_ASM_LABEL(label_7) // K oddments: Loop @@ -145,24 +141,24 @@ KAI_ASM_LABEL(label_7) // K oddments: Loop bgt label_7 KAI_ASM_LABEL(label_8) // K oddments: End ldr x26, [x0, #0x10] - sub x25, x13, x14 + sub x25, x13, x15 cntw x24 - KAI_ASM_INST(0x854ec013) // ld1rw { z19.s }, p0/Z, [x0, #56] + ld1rw { z26.s }, p0/Z, [x0, #56] ldr x23, [x0, #0x18] cmp x25, x24 - KAI_ASM_INST(0x854fc01a) // ld1rw { z26.s }, p0/Z, [x0, #60] + ld1rw { z24.s }, p0/Z, [x0, #60] mov x12, #0x0 csel x22, x25, x24, LT - add x26, x26, x10, LSL #2 // C += n + add x26, x26, x11, LSL #2 // C += n lsr x21, x22, #0x2 - madd x26, x14, x23, x26 // C += m * ldc + madd x26, x15, x23, x26 // C += m * ldc and x20, x22, #0x3 cbz x21, label_11 KAI_ASM_LABEL(label_10) // Store to output array: Accumulator row 0 loop KAI_ASM_INST(0xc0860404) // mova { z4.s-z7.s }, za0h.s[x12] KAI_ASM_INST(0xc086042c) // mova { z12.s-z15.s }, za1h.s[x12] - KAI_ASM_INST(0xc1baca64) // fclamp { z4.s-z7.s }, z19.s, z26.s - KAI_ASM_INST(0xc1baca6c) // fclamp { z12.s-z15.s }, z19.s, z26.s + KAI_ASM_INST(0xc1b8cb44) // fclamp { z4.s-z7.s }, z26.s, z24.s + KAI_ASM_INST(0xc1b8cb4c) // fclamp { z12.s-z15.s }, z26.s, z24.s add x12, x12, #0x4 cmp x12, x21, LSL #2 KAI_ASM_INST(0xa1604344) // st1w { z4.s, z12.s }, p8, [x26] @@ -179,8 +175,8 @@ KAI_ASM_LABEL(label_11) // Store to output array: Accumulator row 0 oddments KAI_ASM_INST(0xc0860400) // mova { z0.s-z3.s }, za0h.s[x12] KAI_ASM_INST(0xc0860428) // mova { z8.s-z11.s }, za1h.s[x12] subs x20, x20, #0x1 - KAI_ASM_INST(0xc1baca60) // fclamp { z0.s-z3.s }, z19.s, z26.s - KAI_ASM_INST(0xc1baca68) // fclamp { z8.s-z11.s }, z19.s, z26.s + KAI_ASM_INST(0xc1b8cb40) // fclamp { z0.s-z3.s }, z26.s, z24.s + KAI_ASM_INST(0xc1b8cb48) // fclamp { z8.s-z11.s }, z26.s, z24.s KAI_ASM_INST(0xa1604340) // st1w { z0.s, z8.s }, p8, [x26] add x26, x26, x23 beq label_12 @@ -202,8 +198,8 @@ KAI_ASM_LABEL(label_12) // Store to output array: Accumulator row 0 oddments: E KAI_ASM_LABEL(label_13) // Store to output array: Accumulator row 1 loop KAI_ASM_INST(0xc0860454) // mova { z20.s-z23.s }, za2h.s[x12] KAI_ASM_INST(0xc086047c) // mova { z28.s-z31.s }, za3h.s[x12] - KAI_ASM_INST(0xc1baca74) // fclamp { z20.s-z23.s }, z19.s, z26.s - KAI_ASM_INST(0xc1baca7c) // fclamp { z28.s-z31.s }, z19.s, z26.s + KAI_ASM_INST(0xc1b8cb54) // fclamp { z20.s-z23.s }, z26.s, z24.s + KAI_ASM_INST(0xc1b8cb5c) // fclamp { z28.s-z31.s }, z26.s, z24.s add x12, x12, #0x4 cmp x12, x21, LSL #2 KAI_ASM_INST(0xa1604354) // st1w { z20.s, z28.s }, p8, [x26] @@ -220,8 +216,8 @@ KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments KAI_ASM_INST(0xc0860444) // mova { z4.s-z7.s }, za2h.s[x12] KAI_ASM_INST(0xc086046c) // mova { z12.s-z15.s }, za3h.s[x12] subs x20, x20, #0x1 - KAI_ASM_INST(0xc1baca64) // fclamp { z4.s-z7.s }, z19.s, z26.s - KAI_ASM_INST(0xc1baca6c) // fclamp { z12.s-z15.s }, z19.s, z26.s + KAI_ASM_INST(0xc1b8cb44) // fclamp { z4.s-z7.s }, z26.s, z24.s + KAI_ASM_INST(0xc1b8cb4c) // fclamp { z12.s-z15.s }, z26.s, z24.s KAI_ASM_INST(0xa1604344) // st1w { z4.s, z12.s }, p8, [x26] add x26, x26, x23 beq label_15 @@ -232,12 +228,12 @@ KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments KAI_ASM_INST(0xa1604346) // st1w { z6.s, z14.s }, p8, [x26] KAI_ASM_LABEL(label_15) // Store to output array: Accumulator row 1 oddments: End KAI_ASM_LABEL(label_16) // Store to output array: End - incw x10, ALL, MUL #2 - cmp x10, x11 + incw x11, ALL, MUL #2 + cmp x11, x10 blt label_2 - incw x14, ALL, MUL #2 - mov x10, #0x0 - cmp x14, x13 + incw x15, ALL, MUL #2 + mov x11, #0x0 + cmp x15, x13 mov x9, x27 blt label_1 KAI_ASM_INST(0xd503467f) // SMSTOP @@ -245,7 +241,11 @@ KAI_ASM_LABEL(label_16) // Store to output array: End ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa) diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S index 4439e11a..01796c59 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_pack_f32p2vlx1_f32_sme) KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_pack_f32p2vlx1_f32_sme) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA mov x6, #0x0 ldr x7, [x0, #0x50] @@ -297,7 +301,11 @@ KAI_ASM_LABEL(label_13) // K loop: End ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_lhs_pack_f32p2vlx1_f32_sme) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S index 387f50d8..a35c53c1 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme) KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA ldr x12, [x0, #0x8] ptrue p2.b @@ -149,7 +153,11 @@ KAI_ASM_LABEL(label_9) // Done ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme) -- GitLab From dfb7e9132c869d5068eb522ba3d948f342fcfafa Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 4 Jun 2025 15:18:23 +0200 Subject: [PATCH 3/6] Add pure assembly for rhs_pack_kxn_f32p16vlx1b_f32_f32_sme Signed-off-by: Jens Elofsson --- CMakeLists.txt | 3 +- kai/ukernels/matmul/BUILD.bazel | 2 +- ...kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c | 241 ++++-------------- ...kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.h | 31 ++- ...rhs_pack_kxn_f32p16vlx1b_f32_f32_sme_asm.S | 237 +++++++++++++++++ 5 files changed, 307 insertions(+), 207 deletions(-) create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index 49f4b09f..cea55ab0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -226,6 +226,8 @@ set(KLEIDIAI_FILES_SME_ASM kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme_asm.S ) set(KLEIDIAI_FILES_SME @@ -238,7 +240,6 @@ set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index ece4c604..bef6d589 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -139,6 +139,7 @@ I8MM_KERNELS_ASM = [ # buildifier: keep sorted SME_KERNELS_ASM = [ "pack/kai_lhs_pack_f32p2vlx1_f32_sme", + "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", ] @@ -154,7 +155,6 @@ SME_KERNELS = [ "pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme", "pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", - "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", "pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", "pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme", "pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme", diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c index 1180920c..4b213483 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c @@ -4,13 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. - #include "kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.h" #include @@ -18,40 +14,64 @@ #include "kai/kai_common.h" -static const size_t kai_nr = 16; -static const size_t kai_kr = 1; +enum { + NR = 16, + KR = 1, +}; + +typedef struct { + const void* bias_ptr; + size_t width; + size_t height; + size_t in_stride; + size_t out_stride; + const void* in; + void* out; +} KernelArgs; + +static const size_t kai_num_bytes_input = sizeof(uint32_t); +static const size_t kai_num_bytes_output = sizeof(uint32_t); +static const size_t kai_num_bytes_bias = sizeof(float); + +void kai_kernel_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(const KernelArgs* args_ptr); size_t kai_get_n_step_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return NR * kai_get_sme_vector_length_u32() / KR; } size_t kai_get_rhs_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(size_t n_idx) { - KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u32()) == 0); + KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme() == 0); - return n_idx * sizeof(uint32_t); + return n_idx * kai_num_bytes_input; } size_t kai_get_bias_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(size_t n_idx) { - return n_idx * sizeof(uint32_t); + return n_idx * kai_num_bytes_bias; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(size_t k) { + return kai_get_n_step_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme() * + (kai_num_bytes_bias + kai_roundup(k, KR) * kai_num_bytes_output); } size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(size_t n_idx, size_t k) { - KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u32()) == 0); + KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme() == 0); - return n_idx * (sizeof(uint32_t) + k * sizeof(uint32_t)); + const size_t block_idx = n_idx / kai_get_n_step_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(); + return block_idx * kai_get_rhs_packed_stride_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(k); } size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(size_t n, size_t k) { - return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme( - kai_roundup(n, kai_nr * kai_get_sme_vector_length_u32()), k); + const size_t n_nr_blocks = kai_roundup(n, kai_get_n_step_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme()); + return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(n_nr_blocks, k); } void kai_run_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride_row, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { KAI_ASSUME(num_groups == 1); - KAI_ASSUME(nr == kai_nr * kai_get_sme_vector_length_u32()); - KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(nr == kai_get_n_step_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme()); + KAI_ASSUME(kr == KR); KAI_ASSUME(sr == 1); KAI_ASSUME(rhs != NULL); KAI_ASSUME(bias != NULL); @@ -60,181 +80,16 @@ void kai_run_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme( KAI_ASSUME(extra_bytes == 0); KAI_ASSUME(params == NULL); - size_t height = k; - const size_t width = n; - const void* in = rhs; - void* out = rhs_packed; - const size_t in_stride = rhs_stride; - size_t out_stride = kai_nr * kai_get_sme_vector_length_u8() * (height + 1); - - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x24, %x[out]\n" - "mov x23, %x[width]\n" - "ptrue p7.b\n" - "1:" // Bias: Full loop - "mov x22, x23\n" - "mov x21, %x[bias]\n" - "whilelt p0.s, XZR, x22\n" - "decw x22\n" - "whilelt p1.s, XZR, x22\n" - "decw x22\n" - "ld1w { z31.s }, p0/Z, [x21]\n" - "whilelt p0.s, XZR, x22\n" - "decw x22\n" - "ld1w { z30.s }, p1/Z, [x21, #1, MUL VL]\n" - "whilelt p1.s, XZR, x22\n" - "decw x22\n" - "ld1w { z29.s }, p0/Z, [x21, #2, MUL VL]\n" - "whilelt p0.s, XZR, x22\n" - "decw x22\n" - "ld1w { z28.s }, p1/Z, [x21, #3, MUL VL]\n" - "whilelt p1.s, XZR, x22\n" - "decw x22\n" - "ld1w { z27.s }, p0/Z, [x21, #4, MUL VL]\n" - "whilelt p0.s, XZR, x22\n" - "decw x22\n" - "ld1w { z26.s }, p1/Z, [x21, #5, MUL VL]\n" - "whilelt p1.s, XZR, x22\n" - "decw x22\n" - "ld1w { z25.s }, p0/Z, [x21, #6, MUL VL]\n" - "whilelt p0.s, XZR, x22\n" - "decw x22\n" - "ld1w { z24.s }, p1/Z, [x21, #7, MUL VL]\n" - "whilelt p6.s, XZR, x22\n" - "decw x22\n" - "whilelt p5.s, XZR, x22\n" - "decw x22\n" - "whilelt p4.s, XZR, x22\n" - "decw x22\n" - "whilelt p3.s, XZR, x22\n" - "decw x22\n" - "whilelt p2.s, XZR, x22\n" - "decw x22\n" - "whilelt p1.s, XZR, x22\n" - "decw x22\n" - "addvl x21, x21, #16\n" - "mov x20, x24\n" - "decw x23, ALL, MUL #16\n" - "ld1w { z23.s }, p0/Z, [x21, #-8, MUL VL]\n" - "whilelt p0.s, XZR, x22\n" - "ld1w { z22.s }, p6/Z, [x21, #-7, MUL VL]\n" - "cmp x23, #0x0\n" - "incb %x[bias], ALL, MUL #16\n" - "ld1w { z21.s }, p5/Z, [x21, #-6, MUL VL]\n" - "add x24, x24, %x[out_stride]\n" - "ld1w { z20.s }, p4/Z, [x21, #-5, MUL VL]\n" - "ld1w { z19.s }, p3/Z, [x21, #-4, MUL VL]\n" - "ld1w { z18.s }, p2/Z, [x21, #-3, MUL VL]\n" - "ld1w { z17.s }, p1/Z, [x21, #-2, MUL VL]\n" - "ld1w { z16.s }, p0/Z, [x21, #-1, MUL VL]\n" - "st1w { z31.s }, p7, [x20]\n" - "st1w { z30.s }, p7, [x20, #1, MUL VL]\n" - "st1w { z29.s }, p7, [x20, #2, MUL VL]\n" - "st1w { z28.s }, p7, [x20, #3, MUL VL]\n" - "st1w { z27.s }, p7, [x20, #4, MUL VL]\n" - "st1w { z26.s }, p7, [x20, #5, MUL VL]\n" - "st1w { z25.s }, p7, [x20, #6, MUL VL]\n" - "st1w { z24.s }, p7, [x20, #7, MUL VL]\n" - "addvl x20, x20, #16\n" - "st1w { z23.s }, p7, [x20, #-8, MUL VL]\n" - "st1w { z22.s }, p7, [x20, #-7, MUL VL]\n" - "st1w { z21.s }, p7, [x20, #-6, MUL VL]\n" - "st1w { z20.s }, p7, [x20, #-5, MUL VL]\n" - "st1w { z19.s }, p7, [x20, #-4, MUL VL]\n" - "st1w { z18.s }, p7, [x20, #-3, MUL VL]\n" - "st1w { z17.s }, p7, [x20, #-2, MUL VL]\n" - "st1w { z16.s }, p7, [x20, #-1, MUL VL]\n" - "bgt 1b\n" - "incb %x[out], ALL, MUL #16\n" - "2:" // Main row loop: Head - "mov x24, %x[in]\n" - "mov x23, %x[out]\n" - "add %x[in], x24, %x[in_stride]\n" - "sub %x[height], %x[height], #0x1\n" - "mov x22, %x[width]\n" - "3:" // Main row loop: Column loop - "mov x21, x22\n" - "mov x20, x23\n" - "whilelt p0.s, XZR, x21\n" - "decw x21\n" - "whilelt p1.s, XZR, x21\n" - "decw x21\n" - "ld1w { z31.s }, p0/Z, [x24]\n" - "whilelt p0.s, XZR, x21\n" - "decw x21\n" - "ld1w { z30.s }, p1/Z, [x24, #1, MUL VL]\n" - "whilelt p1.s, XZR, x21\n" - "decw x21\n" - "ld1w { z29.s }, p0/Z, [x24, #2, MUL VL]\n" - "whilelt p0.s, XZR, x21\n" - "decw x21\n" - "ld1w { z28.s }, p1/Z, [x24, #3, MUL VL]\n" - "whilelt p1.s, XZR, x21\n" - "decw x21\n" - "ld1w { z27.s }, p0/Z, [x24, #4, MUL VL]\n" - "whilelt p0.s, XZR, x21\n" - "decw x21\n" - "ld1w { z26.s }, p1/Z, [x24, #5, MUL VL]\n" - "whilelt p1.s, XZR, x21\n" - "decw x21\n" - "ld1w { z25.s }, p0/Z, [x24, #6, MUL VL]\n" - "whilelt p0.s, XZR, x21\n" - "decw x21\n" - "ld1w { z24.s }, p1/Z, [x24, #7, MUL VL]\n" - "whilelt p6.s, XZR, x21\n" - "decw x21\n" - "whilelt p5.s, XZR, x21\n" - "decw x21\n" - "whilelt p4.s, XZR, x21\n" - "decw x21\n" - "whilelt p3.s, XZR, x21\n" - "decw x21\n" - "whilelt p2.s, XZR, x21\n" - "decw x21\n" - "whilelt p1.s, XZR, x21\n" - "decw x21\n" - "addvl x24, x24, #16\n" - "decw x22, ALL, MUL #16\n" - "ld1w { z23.s }, p0/Z, [x24, #-8, MUL VL]\n" - "whilelt p0.s, XZR, x21\n" - "cmp x22, #0x0\n" - "ld1w { z22.s }, p6/Z, [x24, #-7, MUL VL]\n" - "add x23, x23, %x[out_stride]\n" - "ld1w { z21.s }, p5/Z, [x24, #-6, MUL VL]\n" - "ld1w { z20.s }, p4/Z, [x24, #-5, MUL VL]\n" - "ld1w { z19.s }, p3/Z, [x24, #-4, MUL VL]\n" - "ld1w { z18.s }, p2/Z, [x24, #-3, MUL VL]\n" - "ld1w { z17.s }, p1/Z, [x24, #-2, MUL VL]\n" - "ld1w { z16.s }, p0/Z, [x24, #-1, MUL VL]\n" - "st1w { z31.s }, p7, [x20]\n" - "st1w { z30.s }, p7, [x20, #1, MUL VL]\n" - "st1w { z29.s }, p7, [x20, #2, MUL VL]\n" - "st1w { z28.s }, p7, [x20, #3, MUL VL]\n" - "st1w { z27.s }, p7, [x20, #4, MUL VL]\n" - "st1w { z26.s }, p7, [x20, #5, MUL VL]\n" - "st1w { z25.s }, p7, [x20, #6, MUL VL]\n" - "st1w { z24.s }, p7, [x20, #7, MUL VL]\n" - "addvl x20, x20, #16\n" - "st1w { z23.s }, p7, [x20, #-8, MUL VL]\n" - "st1w { z22.s }, p7, [x20, #-7, MUL VL]\n" - "st1w { z21.s }, p7, [x20, #-6, MUL VL]\n" - "st1w { z20.s }, p7, [x20, #-5, MUL VL]\n" - "st1w { z19.s }, p7, [x20, #-4, MUL VL]\n" - "st1w { z18.s }, p7, [x20, #-3, MUL VL]\n" - "st1w { z17.s }, p7, [x20, #-2, MUL VL]\n" - "st1w { z16.s }, p7, [x20, #-1, MUL VL]\n" - "bgt 3b\n" - "cmp %x[height], #0x1\n" - "addvl %x[out], %x[out], #16\n" - "bge 2b\n" - ".inst 0xd503467f // SMSTOP\n" - : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) - : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", - "p15", "x20", "x21", "x22", "x23", "x24", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", - "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", - "z26", "z27", "z28", "z29", "z30", "z31"); + KernelArgs args; + args.bias_ptr = bias; + args.height = k; + args.width = n; + args.in = rhs; + args.out = rhs_packed; + args.in_stride = rhs_stride_row; + args.out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(args.height); + + kai_kernel_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.h index f8178d00..b54f1099 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -14,14 +14,14 @@ extern "C" { /// Gets n step value. /// -/// The starting row index must be divisible by `n_step`. +/// The starting column index must be divisible by `n_step`. /// /// @return The n step value. size_t kai_get_n_step_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. /// -/// @param[in] n_idx Column index. +/// @param[in] n_idx Column index. Must be divisible by `n_step` /// /// @return The offset in bytes to the data element. size_t kai_get_rhs_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(size_t n_idx); @@ -33,18 +33,25 @@ size_t kai_get_rhs_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(size_t n_idx); /// @return The offset in bytes to the data element. size_t kai_get_bias_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(size_t n_idx); +/// Gets row stride in bytes of the packed RHS matrix. +/// +/// @param[in] k Number of columns of the unpacked RHS matrix. +/// +/// @return Row stride in bytes. +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(size_t k); + /// Gets the offset in bytes to the data element in the packed RHS buffer. /// -/// @param[in] n_idx Row index. -/// @param[in] k Number of columns. +/// @param[in] n_idx Column index. Must be divisible by `n_step` +/// @param[in] k Number of rows. /// /// @return The offset in bytes to the data element. size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(size_t n_idx, size_t k); /// Gets the size in bytes of the packed RHS buffer. /// -/// @param[in] n Number of rows. -/// @param[in] k Number of columns. +/// @param[in] n Number of columns. +/// @param[in] k Number of rows. /// /// @return The size in bytes of the packed RHS buffer. size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(size_t n, size_t k); @@ -60,19 +67,19 @@ size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme(size_t n, si /// /// @param[in] num_groups Number of groups. It must be 1. /// @param[in] n Number of columns of the output matrix. -/// @param[in] k Common dimension between the LHS and RHS matrix. -/// @param[in] nr Block size in N dimension. It must be 16 * kai_get_sme_vector_length_u32(). +/// @param[in] k Number of rows. +/// @param[in] nr Block size in N dimension. It must be `get_n_step` /// @param[in] kr Block size in K dimension. It must be 1. /// @param[in] sr Number of kr splits. It must be 1. -/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs_stride_row Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. /// @param[in] scale Scale data buffer. It must be NULL. /// @param[out] rhs_packed Packed RHS matrix. /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. -/// @param[in] params Extra packing parameters. It must be NULL. +/// @param[in] params Packing parameters. It must be NULL. void kai_run_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride_row, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); #ifdef __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme_asm.S new file mode 100644 index 00000000..d1480986 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme_asm.S @@ -0,0 +1,237 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(rhs_pack_kxn_f32p16vlx1b_f32_f32_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + ldr x11, [x0, #0x8] + ptrue p7.b + ldr x10, [x0, #0x30] + ldr x25, [x0, #0x0] + ldr x9, [x0, #0x10] + mov x24, x11 + ldr x28, [x0, #0x18] + mov x23, x10 + ldr x27, [x0, #0x20] + ldr x26, [x0, #0x28] +KAI_ASM_LABEL(label_1) // Bias: Full loop + mov x22, x24 + mov x21, x25 + whilelt p0.s, XZR, x22 + decw x22 + whilelt p1.s, XZR, x22 + decw x22 + ld1w { z31.s }, p0/Z, [x21] + whilelt p0.s, XZR, x22 + decw x22 + ld1w { z30.s }, p1/Z, [x21, #1, MUL VL] + whilelt p1.s, XZR, x22 + decw x22 + ld1w { z29.s }, p0/Z, [x21, #2, MUL VL] + whilelt p0.s, XZR, x22 + decw x22 + ld1w { z28.s }, p1/Z, [x21, #3, MUL VL] + whilelt p1.s, XZR, x22 + decw x22 + ld1w { z27.s }, p0/Z, [x21, #4, MUL VL] + whilelt p0.s, XZR, x22 + decw x22 + ld1w { z26.s }, p1/Z, [x21, #5, MUL VL] + whilelt p1.s, XZR, x22 + decw x22 + ld1w { z25.s }, p0/Z, [x21, #6, MUL VL] + whilelt p0.s, XZR, x22 + decw x22 + ld1w { z24.s }, p1/Z, [x21, #7, MUL VL] + whilelt p6.s, XZR, x22 + decw x22 + whilelt p5.s, XZR, x22 + decw x22 + whilelt p4.s, XZR, x22 + decw x22 + whilelt p3.s, XZR, x22 + decw x22 + whilelt p2.s, XZR, x22 + decw x22 + whilelt p1.s, XZR, x22 + decw x22 + addvl x21, x21, #16 + mov x20, x23 + decw x24, ALL, MUL #16 + ld1w { z23.s }, p0/Z, [x21, #-8, MUL VL] + whilelt p0.s, XZR, x22 + ld1w { z22.s }, p6/Z, [x21, #-7, MUL VL] + cmp x24, #0x0 + incb x25, ALL, MUL #16 + ld1w { z21.s }, p5/Z, [x21, #-6, MUL VL] + add x23, x23, x27 + ld1w { z20.s }, p4/Z, [x21, #-5, MUL VL] + ld1w { z19.s }, p3/Z, [x21, #-4, MUL VL] + ld1w { z18.s }, p2/Z, [x21, #-3, MUL VL] + ld1w { z17.s }, p1/Z, [x21, #-2, MUL VL] + ld1w { z16.s }, p0/Z, [x21, #-1, MUL VL] + st1w { z31.s }, p7, [x20] + st1w { z30.s }, p7, [x20, #1, MUL VL] + st1w { z29.s }, p7, [x20, #2, MUL VL] + st1w { z28.s }, p7, [x20, #3, MUL VL] + st1w { z27.s }, p7, [x20, #4, MUL VL] + st1w { z26.s }, p7, [x20, #5, MUL VL] + st1w { z25.s }, p7, [x20, #6, MUL VL] + st1w { z24.s }, p7, [x20, #7, MUL VL] + addvl x20, x20, #16 + st1w { z23.s }, p7, [x20, #-8, MUL VL] + st1w { z22.s }, p7, [x20, #-7, MUL VL] + st1w { z21.s }, p7, [x20, #-6, MUL VL] + st1w { z20.s }, p7, [x20, #-5, MUL VL] + st1w { z19.s }, p7, [x20, #-4, MUL VL] + st1w { z18.s }, p7, [x20, #-3, MUL VL] + st1w { z17.s }, p7, [x20, #-2, MUL VL] + st1w { z16.s }, p7, [x20, #-1, MUL VL] + bgt label_1 + incb x10, ALL, MUL #16 + mov x25, x9 + cbz x9, label_5 +KAI_ASM_LABEL(label_2) // Main row loop: Head + mov x24, x26 + mov x23, x10 + add x26, x24, x28 + sub x25, x25, #0x1 + mov x22, x11 +KAI_ASM_LABEL(label_3) // Main row loop: Column loop + mov x21, x22 + mov x20, x23 + whilelt p0.s, XZR, x21 + decw x21 + whilelt p1.s, XZR, x21 + decw x21 + ld1w { z31.s }, p0/Z, [x24] + whilelt p0.s, XZR, x21 + decw x21 + ld1w { z30.s }, p1/Z, [x24, #1, MUL VL] + whilelt p1.s, XZR, x21 + decw x21 + ld1w { z29.s }, p0/Z, [x24, #2, MUL VL] + whilelt p0.s, XZR, x21 + decw x21 + ld1w { z28.s }, p1/Z, [x24, #3, MUL VL] + whilelt p1.s, XZR, x21 + decw x21 + ld1w { z27.s }, p0/Z, [x24, #4, MUL VL] + whilelt p0.s, XZR, x21 + decw x21 + ld1w { z26.s }, p1/Z, [x24, #5, MUL VL] + whilelt p1.s, XZR, x21 + decw x21 + ld1w { z25.s }, p0/Z, [x24, #6, MUL VL] + whilelt p0.s, XZR, x21 + decw x21 + ld1w { z24.s }, p1/Z, [x24, #7, MUL VL] + whilelt p6.s, XZR, x21 + decw x21 + whilelt p5.s, XZR, x21 + decw x21 + whilelt p4.s, XZR, x21 + decw x21 + whilelt p3.s, XZR, x21 + decw x21 + whilelt p2.s, XZR, x21 + decw x21 + whilelt p1.s, XZR, x21 + decw x21 + addvl x24, x24, #16 + decw x22, ALL, MUL #16 + ld1w { z23.s }, p0/Z, [x24, #-8, MUL VL] + whilelt p0.s, XZR, x21 + cmp x22, #0x0 + ld1w { z22.s }, p6/Z, [x24, #-7, MUL VL] + add x23, x23, x27 + ld1w { z21.s }, p5/Z, [x24, #-6, MUL VL] + ld1w { z20.s }, p4/Z, [x24, #-5, MUL VL] + ld1w { z19.s }, p3/Z, [x24, #-4, MUL VL] + ld1w { z18.s }, p2/Z, [x24, #-3, MUL VL] + ld1w { z17.s }, p1/Z, [x24, #-2, MUL VL] + ld1w { z16.s }, p0/Z, [x24, #-1, MUL VL] + st1w { z31.s }, p7, [x20] + st1w { z30.s }, p7, [x20, #1, MUL VL] + st1w { z29.s }, p7, [x20, #2, MUL VL] + st1w { z28.s }, p7, [x20, #3, MUL VL] + st1w { z27.s }, p7, [x20, #4, MUL VL] + st1w { z26.s }, p7, [x20, #5, MUL VL] + st1w { z25.s }, p7, [x20, #6, MUL VL] + st1w { z24.s }, p7, [x20, #7, MUL VL] + addvl x20, x20, #16 + st1w { z23.s }, p7, [x20, #-8, MUL VL] + st1w { z22.s }, p7, [x20, #-7, MUL VL] + st1w { z21.s }, p7, [x20, #-6, MUL VL] + st1w { z20.s }, p7, [x20, #-5, MUL VL] + st1w { z19.s }, p7, [x20, #-4, MUL VL] + st1w { z18.s }, p7, [x20, #-3, MUL VL] + st1w { z17.s }, p7, [x20, #-2, MUL VL] + st1w { z16.s }, p7, [x20, #-1, MUL VL] + bgt label_3 + cmp x25, #0x1 + addvl x10, x10, #16 + bge label_2 +KAI_ASM_LABEL(label_5) // Done + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme) + + KAI_ASM_END -- GitLab From 218c8bede6da7d2e6facffc141159c83f12744a1 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Tue, 10 Jun 2025 13:08:32 +0200 Subject: [PATCH 4/6] Address review comments - Add CHANGELOG entry - Minor formatting changes - Use enum values (MR, KR, SR) instead of function parameters Signed-off-by: Jens Elofsson --- CHANGELOG.md | 7 ++++++ CMakeLists.txt | 12 +++++----- .../pack/kai_lhs_pack_f32p2vlx1_f32_sme.c | 23 +++++++++---------- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e430695d..74c2e2e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,13 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme - kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme - kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme +- Convert SME and SME2 matmul micro-kernels to use pure assembly, and add MSVC support. Affects: + - kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla + - kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla + - kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa + - kai_lhs_pack_f32p2vlx1_f32_sme + - kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme + - kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme - New Advanced SIMD micro-kernels: - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_DotProd. - Optimized version of kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 kernel for block depth of 8 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon`) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5df77395..90aac72b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -229,12 +229,6 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME_ASM - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S - kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c @@ -247,6 +241,12 @@ set(KLEIDIAI_FILES_SME_ASM kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S + kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme_asm.S + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme_asm.S ) set(KLEIDIAI_FILES_SME diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c index 434e55e7..76a368cb 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c @@ -15,6 +15,13 @@ #include "kai/kai_common.h" +enum { + MR = 2, + KR = 1, + MAX_M_STEP = (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(float)) / KR), + SR = 1, +}; + typedef struct { size_t m; size_t k; @@ -34,13 +41,6 @@ typedef struct { void kai_kernel_lhs_pack_f32p2vlx1_f32_sme(const KernelArgs* args_ptr); -enum { - MR = 2, - KR = 1, - MAX_M_STEP = (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(float)) / KR), - SR = 1, -}; - static size_t kai_get_mr_lhs_pack_f32p2vlx1_f32_sme(void) { return MR * kai_get_sme_vector_length_u32() / KR; } @@ -68,7 +68,7 @@ size_t kai_get_lhs_packed_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t KAI_UNUSED(kr); KAI_UNUSED(sr); - return m_idx * kai_roundup(k, kr) * sizeof(float); + return m_idx * kai_roundup(k, KR) * sizeof(float); } size_t kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { @@ -91,7 +91,6 @@ void kai_run_lhs_pack_f32p2vlx1_f32_sme( KAI_ASSUME(sr == SR); KAI_ASSUME(lhs != NULL); KAI_ASSUME(lhs_packed != NULL); - KAI_ASSUME(m_idx_start == 0); const size_t m_step = kai_get_mr_lhs_pack_f32p2vlx1_f32_sme(); @@ -115,9 +114,9 @@ void kai_run_lhs_pack_f32p2vlx1_f32_sme( KernelArgs args; args.m = m; args.k = k; - args.mr = mr; - args.kr = kr; - args.sr = sr; + args.mr = MR; + args.kr = KR; + args.sr = SR; args.m_idx_start = m_idx_start; args.lhs = lhs; args.lhs_stride = lhs_stride; -- GitLab From 6a32b16fa4bf7c96c3557c047069f9de8e8e6122 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Tue, 10 Jun 2025 13:53:46 +0200 Subject: [PATCH 5/6] Address review comments - Minor formatting fixes Signed-off-by: Jens Elofsson --- CMakeLists.txt | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 90aac72b..203aa929 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -235,14 +235,14 @@ set(KLEIDIAI_FILES_SME_ASM kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S + kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S - kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c @@ -260,6 +260,12 @@ set(KLEIDIAI_FILES_SME ) set(KLEIDIAI_FILES_SME2_ASM + kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S + kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S + kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla_asm.S kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c @@ -270,12 +276,6 @@ set(KLEIDIAI_FILES_SME2_ASM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S - kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c - kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S - kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c - kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S - kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c - kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S ) set(KLEIDIAI_FILES_SME2 @@ -335,8 +335,8 @@ else() set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_SME_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_SME2_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SME_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SME2_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set(KLEIDIAI_FILES_ASM ${KLEIDIAI_FILES_SME_ASM} -- GitLab From 7673b7d9c80ff316040f08deb9485dc299e675c6 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 11 Jun 2025 08:59:22 +0200 Subject: [PATCH 6/6] Address review comments - Minor formatting changes Signed-off-by: Jens Elofsson --- .../kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c | 4 ++-- .../kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c index 3760cc45..e9453870 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c @@ -29,14 +29,14 @@ typedef struct { uint64_t flags; } KernelArgs; -void kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(KernelArgs* args_ptr); - static const size_t kai_m_step = 1; static const size_t kai_nr = 16; static const size_t kai_n_step = 16; static const size_t kai_kr = 1; static const size_t kai_sr = 1; +void kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(KernelArgs* args_ptr); + size_t kai_get_m_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(void) { return kai_m_step; } diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c index 747fc276..b2fe03aa 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c @@ -29,14 +29,14 @@ typedef struct { uint64_t flags; } KernelArgs; -void kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(KernelArgs* args_ptr); - static const size_t kai_m_step = 1; static const size_t kai_nr = 2; static const size_t kai_n_step = 16; static const size_t kai_kr = 1; static const size_t kai_sr = 1; +void kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(KernelArgs* args_ptr); + size_t kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(void) { return kai_m_step; } -- GitLab