From 819e992ffa80ad82595ca1129d699317c3fb02c2 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Thu, 8 May 2025 17:44:49 +0200 Subject: [PATCH 1/9] imatmul kernels use pure assembly This change moves the kernel inline assembly blocks out to separate assembly files. Signed-off-by: Emil Ohlsson --- CMakeLists.txt | 9 + kai/ukernels/matmul/BUILD.bazel | 42 +- ...16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c | 217 ++--------- ...16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h | 10 +- ...16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S | 194 ++++++++++ ...2_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c | 276 ++----------- ...2_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h | 10 +- ...2p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S | 252 ++++++++++++ ...8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c | 366 ++---------------- ...8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h | 13 +- ...lx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S | 336 ++++++++++++++++ .../kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c | 287 +------------- .../kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h | 2 +- ..._lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S | 312 +++++++++++++++ .../kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c | 274 +------------ .../kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h | 2 +- ..._lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S | 299 ++++++++++++++ .../kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c | 287 +------------- .../kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h | 10 +- ...ai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S | 313 +++++++++++++++ ...ack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c | 252 ++---------- ...ack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h | 8 +- ...kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S | 240 ++++++++++++ ..._imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c | 179 ++------- ..._imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h | 6 +- ...tmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S | 175 +++++++++ ..._imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c | 158 ++------ ..._imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h | 6 +- ...tmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S | 
161 ++++++++ 29 files changed, 2595 insertions(+), 2101 deletions(-) create mode 100644 kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S create mode 100644 kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S create mode 100644 kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S create mode 100644 kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S create mode 100644 kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S create mode 100644 kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S create mode 100644 kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S create mode 100644 kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S create mode 100644 kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index ea7b9926..1627d694 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -223,14 +223,20 @@ set(KLEIDIAI_FILES_NEON_I8MM set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c + kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c + kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c + kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c + 
kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c + kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c + kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -241,8 +247,11 @@ set(KLEIDIAI_FILES_SME set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 2da2800c..bf89cb54 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -138,16 +138,10 @@ I8MM_KERNELS_ASM = [ # buildifier: keep sorted SME_KERNELS = [ - 
"pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", - "pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme", - "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme", "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_lhs_pack_f32p2vlx1_f32_sme", "pack/kai_lhs_pack_x16p2vlx2_x16_sme", "pack/kai_lhs_pack_x8p2vlx4_x8_sme", - "pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", - "pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme", - "pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", @@ -157,11 +151,18 @@ SME_KERNELS = [ "pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme", ] +# buildifier: keep sorted +SME_KERNELS_ASM = [ + "pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", + "pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme", + "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme", + "pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", + "pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme", + "pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme", +] + # buildifier: keep sorted SME2_KERNELS = [ - "imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", - "imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa", - "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla", @@ -176,6 +177,13 @@ SME2_KERNELS = [ "matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", ] +# buildifier: keep sorted +SME2_KERNELS_ASM = [ + "imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", + 
"imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa", + "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", +] + kai_c_library( name = "interface", textual_hdrs = glob(["**/*_interface.h"]), @@ -272,6 +280,13 @@ kai_c_library( textual_hdrs = [ukernel + ".h" for ukernel in SME_KERNELS], ) +kai_c_library( + name = "sme_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in SME_KERNELS_ASM] + [ukernel + ".c" for ukernel in SME_KERNELS_ASM], + cpu_uarch = kai_cpu_sme(), + textual_hdrs = [ukernel + ".h" for ukernel in SME_KERNELS_ASM], +) + kai_c_library( name = "sme2_impl", srcs = [ukernel + ".c" for ukernel in SME2_KERNELS], @@ -279,6 +294,13 @@ kai_c_library( textual_hdrs = [ukernel + ".h" for ukernel in SME2_KERNELS], ) +kai_c_library( + name = "sme2_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in SME2_KERNELS_ASM] + [ukernel + ".c" for ukernel in SME2_KERNELS_ASM], + cpu_uarch = kai_cpu_sme2(), + textual_hdrs = [ukernel + ".h" for ukernel in SME2_KERNELS_ASM], +) + kai_c_library( name = "matmul", visibility = ["//visibility:public"], @@ -297,6 +319,8 @@ kai_c_library( ":neon_impl_asm", ":scalar_impl", ":sme2_impl", + ":sme2_impl_asm", ":sme_impl", + ":sme_impl_asm", ], ) diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c index 7e77125a..cc2e7e41 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c @@ -4,13 +4,6 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || 
!defined(__ARM_FEATURE_SVE2) -#error This file must be compiled for AArch64, FEAT_SVE2. -#else // Architectural features check. - #include "kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" #include @@ -18,30 +11,50 @@ #include "kai/kai_common.h" +typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + float16_t min; + float16_t max; + void* accumulator_buffer; + uint64_t flags; +} KernelArgs; + static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 2; +void kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(KernelArgs* args); + +// Returns a constant value specific to this kernel that's relative to vector length +static size_t kai_get_kernel_vec_length_constant(void) { + const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u16() / kai_kr; + return kernel_vec_length_constant; +} + size_t kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u32(); + return kai_mr * kai_get_kernel_vec_length_constant(); } size_t kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_kernel_vec_length_constant(); } size_t kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - return m_idx * indirect_k * sizeof(uint16_t); + return m_idx * k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(uint16_t); } static size_t kai_get_rhs_packed_stride_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( size_t k_chunk_count, size_t k_chunk_length) { - const size_t indirect_k = k_chunk_count * 
kai_roundup(k_chunk_length, kai_kr); return kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() * - (sizeof(uint16_t) + indirect_k * sizeof(uint16_t)); + (sizeof(uint16_t) + k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(uint16_t)); } size_t kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( @@ -54,11 +67,11 @@ size_t kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_s } size_t kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_row_stride) { + size_t m_idx, size_t n_idx, size_t dst_stride_row) { KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); - return m_idx * dst_row_stride + n_idx * sizeof(uint16_t); + return m_idx * dst_stride_row + n_idx * sizeof(uint16_t); } size_t kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(size_t m, size_t n) { @@ -67,186 +80,20 @@ size_t kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( void kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, float clamp_min, float clamp_max) { - typedef struct { - const void* A; - const void* B; - void* C; - uint64_t ldcb; - uint64_t M; - uint64_t N; - uint64_t K; - float16_t min; - float16_t max; - void* accumulator_buffer; - uint64_t flags; - } KernelArgs; - + void* dst, size_t dst_stride_row, float clamp_min, float clamp_max) { KernelArgs args; args.A = lhs_packed; args.B = rhs_packed; - - size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - args.C = dst; - args.ldcb = dst_row_stride; + args.ldcb = dst_stride_row; args.M = m; args.N = n; - args.K = indirect_k; + args.K = 
k_chunk_count * kai_roundup(k_chunk_length, kai_kr); args.min = (float16_t)clamp_min; args.max = (float16_t)clamp_max; - args.accumulator_buffer = NULL; args.flags = 0; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "ldr w13, [%x[args], %[offsetof_M]]\n" - "mov x11, #0x0\n" - "mov x10, #0x0\n" - "ptrue p1.b\n" - ".inst 0x25207810 // ptrue pn8.b\n" - "ldr w9, [%x[args], %[offsetof_N]]\n" - "ldr x28, [%x[args], %[offsetof_A]]\n" - "1:" // M loop - "ldr x27, [%x[args], %[offsetof_B]]\n" - "2:" // N loop - "fmov z24.h, #0.0\n" - "ld1h { z5.h }, p1/Z, [x27]\n" - "fmov z27.h, #1.0\n" - "mov x26, x28\n" - ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" - "inch x27, ALL, MUL #2\n" - "zip1 z30.h, z5.h, z24.h\n" - "zip2 z20.h, z5.h, z24.h\n" - ".inst 0x81be2760 // fmopa za0.s, p1/M, p1/M, z27.h, z30.h\n" - ".inst 0x81b42761 // fmopa za1.s, p1/M, p1/M, z27.h, z20.h\n" - ".inst 0x81be2762 // fmopa za2.s, p1/M, p1/M, z27.h, z30.h\n" - ".inst 0x81b42763 // fmopa za3.s, p1/M, p1/M, z27.h, z20.h\n" - "ldr x20, [%x[args], %[offsetof_K]]\n" - "add x20, x20, #0x1\n" - "lsr x20, x20, #0x1\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 6f\n" - "subs x21, x21, #0x1\n" - ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" - ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" - ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" - ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" - ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" - "addvl x26, x26, #8\n" - ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - "ble 5f\n" - "4:" // K loop - ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" - "subs x21, x21, #0x1\n" - ".inst 
0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" - ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" - ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" - ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" - ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" - ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" - ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" - ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" - ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" - ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" - ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" - ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" - ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" - ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" - ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" - ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" - ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" - ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" - ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" - ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" - "addvl x26, x26, #8\n" - ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - "bgt 4b\n" - "5:" // K loop tail - ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" - ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" - ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" - ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" - ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" - ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" - 
".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" - ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" - ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" - ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" - ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" - ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" - ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" - ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" - ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" - ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" - "6:" // K oddments - "cbz x20, 8f\n" - "7:" // K oddments: Loop - ".inst 0xa1402345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26]\n" - "subs x20, x20, #0x1\n" - "addvl x26, x26, #2\n" - ".inst 0xa040236e // ld1h { z14.h-z15.h }, pn8.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0x81ae24a0 // fmopa za0.s, p1/M, p1/M, z5.h, z14.h\n" - ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" - ".inst 0x81ae25a2 // fmopa za2.s, p1/M, p1/M, z13.h, z14.h\n" - ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" - "bgt 7b\n" - "8:" // K oddments: End - "ldr x25, [%x[args], %[offsetof_C]]\n" - "sub x24, x13, x11\n" - "cntw x23, ALL, MUL #2\n" - "ld1rh { z17.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" - "ldr x22, [%x[args], %[offsetof_ldcb]]\n" - "whilelt p0.h, x10, x9\n" - "cmp x24, x23\n" - "ld1rh { z16.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" - "mov x12, #0x0\n" - "mov x21, #0x0\n" - "add x25, x25, x10, LSL #1\n" // C += n - "mov x20, #0x2\n" - "madd x25, x11, x22, x25\n" // C += m * ldc - "csel x24, x24, x23, LT\n" - "10:" // Store to output array: Accumulator loop - ".inst 0xc006000e // mova { z14.b-z15.b }, za0h.b[x12, 0:1]\n" - "add x12, x12, #0x4\n" - "cmp x12, x23, LSL #1\n" - "add x21, x21, #0x1\n" - ".inst 0xc120e1cc // fcvt z12.h, { z14.s-z15.s }\n" - "csel x12, x12, x20, LT\n" - "cmp x21, x24\n" - ".inst 
0x6470262c // fclamp z12.h, z17.h, z16.h\n" - "st1h { z12.h }, p0, [x25]\n" - "add x25, x25, x22\n" - "blt 10b\n" - "incw x10, ALL, MUL #2\n" - "cmp x10, x9\n" - "blt 2b\n" - "incw x11, ALL, MUL #2\n" - "mov x10, #0x0\n" - "cmp x11, x13\n" - "mov x28, x26\n" - "blt 1b\n" - ".inst 0xd503467f // SMSTOP\n" - : - : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), - [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), - [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), - [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), [offsetof_M] "I"(offsetof(KernelArgs, M)), - [offsetof_N] "I"(offsetof(KernelArgs, N)), [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", - "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", - "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(&args); } - -#endif // Architectural features check. diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h index 79c52a42..68144382 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h @@ -55,11 +55,11 @@ size_t kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_s /// /// @param[in] m_idx Row index. Must be a multiple of `m_step`. 
/// @param[in] n_idx Column index. Must be a multiple of `n_step`. -/// @param[in] dst_row_stride. Distance between start of two rows in the output buffer. +/// @param[in] dst_stride_row Row stride in bytes. /// /// @return The offset in bytes to the data element. size_t kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_row_stride); + size_t m_idx, size_t n_idx, size_t dst_stride_row); /// Gets the size in bytes of the destination matrix buffer. /// @@ -81,16 +81,16 @@ size_t kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length of a LHS column split +/// @param[in] k_chunk_length Length of a LHS column split. /// @param[in] lhs_packed Packed LHS matrix buffer. /// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. -/// @param[in] dst_row_stride Row stride in bytes of the output matrix. +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. 
void kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + void* dst, size_t dst_stride_row, float clamp_min, float clamp_max); #ifdef __cplusplus } // extern "C" diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S new file mode 100644 index 00000000..24d69bf5 --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S @@ -0,0 +1,194 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + KAI_ASM_ALIGN + 
+    KAI_ASM_GLOBAL(kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa)
+
+// Entry point called with x0 = pointer to the kernel argument block.
+// Offsets read below: 0x00/0x08 packed operand pointers, 0x10 output pointer (C),
+// 0x18 output row stride in bytes, 0x20/0x28 M and N (low 32 bits), 0x30 K,
+// 56/58 the lower/upper f16 clamp bounds.
+// NOTE(review): field names are inferred from the loop comments; confirm against the
+// KernelArgs struct in the matching .c file.
+// x20-x28 and d8-d15 are callee-saved under AAPCS64 and are spilled to a 144-byte
+// frame: changing PSTATE.SM (SMSTART/SMSTOP) zeroes the vector registers, so d8-d15
+// would otherwise be lost to the caller.
+KAI_ASM_FUNCTION_TYPE(kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa)
+KAI_ASM_FUNCTION_LABEL(kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa)
+    stp x20, x21, [sp, -144]!  // 72 bytes of GPRs + 64 bytes of d8-d15, rounded up to 16
+    stp x22, x23, [sp, 16]
+    stp x24, x25, [sp, 32]
+    stp x26, x27, [sp, 48]
+    str x28, [sp, 64]
+    stp d8, d9, [sp, 72]  // must happen before SMSTART, which zeroes the vector registers
+    stp d10, d11, [sp, 88]
+    stp d12, d13, [sp, 104]
+    stp d14, d15, [sp, 120]
+    KAI_ASM_INST(0xd503477f)  // SMSTART (PSTATE.SM = 1 and PSTATE.ZA = 1)
+    mov x13, #0x0  // m = 0
+    ptrue p1.b
+    KAI_ASM_INST(0x25207810)  // ptrue pn8.b
+    ldr w11, [x0, #0x20]  // M loop bound
+    ldr w10, [x0, #0x28]  // N loop bound
+    mov x9, #0x0  // n = 0
+    ldr x28, [x0, #0x0]  // first-operand pointer for the current M block
+KAI_ASM_LABEL(label_1)  // M loop
+    ldr x27, [x0, #0x8]  // second-operand pointer, rewound for every M block
+KAI_ASM_LABEL(label_2)  // N loop
+    fmov z24.h, #0.0
+    ld1h { z5.h }, p1/Z, [x27]  // presumably the bias row that leads each packed N block - TODO confirm
+    fmov z27.h, #1.0
+    mov x26, x28
+    KAI_ASM_INST(0xc00800ff)  // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }
+    inch x27, ALL, MUL #2
+    zip1 z30.h, z5.h, z24.h  // interleave the loaded values with zeros
+    zip2 z20.h, z5.h, z24.h
+    KAI_ASM_INST(0x81be2760)  // fmopa za0.s, p1/M, p1/M, z27.h, z30.h
+    KAI_ASM_INST(0x81b42761)  // fmopa za1.s, p1/M, p1/M, z27.h, z20.h
+    KAI_ASM_INST(0x81be2762)  // fmopa za2.s, p1/M, p1/M, z27.h, z30.h
+    KAI_ASM_INST(0x81b42763)  // fmopa za3.s, p1/M, p1/M, z27.h, z20.h
+    ldr x20, [x0, #0x30]
+    add x20, x20, #0x1
+    lsr x20, x20, #0x1  // x20 = ceil(K / 2)
+    lsr x21, x20, #0x2  // x21 = iterations of the 4x unrolled K loop
+    and x20, x20, #0x3  // x20 = leftover single steps
+    cbz x21, label_6
+    subs x21, x21, #0x1
+    KAI_ASM_INST(0xa0402352)  // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]
+    KAI_ASM_INST(0xa0402370)  // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]
+    KAI_ASM_INST(0xa1412342)  // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]
+    KAI_ASM_INST(0xa041237e)  // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]
+    KAI_ASM_INST(0xa042235c)  // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]
+    KAI_ASM_INST(0xa1422366)  // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]
+    KAI_ASM_INST(0xa1432345)  // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]
+    addvl x26, x26, #8
+    KAI_ASM_INST(0xa1432367)  // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]
+    addvl x27, x27, #8
+    ble label_5
+KAI_ASM_LABEL(label_4)  // K loop
+    KAI_ASM_INST(0x81b02640)  // fmopa za0.s, p1/M, p1/M, z18.h, z16.h
+    subs x21, x21, #0x1
+    KAI_ASM_INST(0x81b12641)  // fmopa za1.s, p1/M, p1/M, z18.h, z17.h
+    KAI_ASM_INST(0x81b02662)  // fmopa za2.s, p1/M, p1/M, z19.h, z16.h
+    KAI_ASM_INST(0x81b12663)  // fmopa za3.s, p1/M, p1/M, z19.h, z17.h
+    KAI_ASM_INST(0xa0402352)  // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]
+    KAI_ASM_INST(0x81be2440)  // fmopa za0.s, p1/M, p1/M, z2.h, z30.h
+    KAI_ASM_INST(0xa0402370)  // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]
+    KAI_ASM_INST(0x81bf2441)  // fmopa za1.s, p1/M, p1/M, z2.h, z31.h
+    KAI_ASM_INST(0x81be2542)  // fmopa za2.s, p1/M, p1/M, z10.h, z30.h
+    KAI_ASM_INST(0x81bf2543)  // fmopa za3.s, p1/M, p1/M, z10.h, z31.h
+    KAI_ASM_INST(0xa1412342)  // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]
+    KAI_ASM_INST(0x81a62780)  // fmopa za0.s, p1/M, p1/M, z28.h, z6.h
+    KAI_ASM_INST(0xa041237e)  // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]
+    KAI_ASM_INST(0x81ae2781)  // fmopa za1.s, p1/M, p1/M, z28.h, z14.h
+    KAI_ASM_INST(0x81a627a2)  // fmopa za2.s, p1/M, p1/M, z29.h, z6.h
+    KAI_ASM_INST(0x81ae27a3)  // fmopa za3.s, p1/M, p1/M, z29.h, z14.h
+    KAI_ASM_INST(0xa042235c)  // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]
+    KAI_ASM_INST(0xa1422366)  // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]
+    KAI_ASM_INST(0x81a724a0)  // fmopa za0.s, p1/M, p1/M, z5.h, z7.h
+    KAI_ASM_INST(0x81af24a1)  // fmopa za1.s, p1/M, p1/M, z5.h, z15.h
+    KAI_ASM_INST(0x81a725a2)  // fmopa za2.s, p1/M, p1/M, z13.h, z7.h
+    KAI_ASM_INST(0x81af25a3)  // fmopa za3.s, p1/M, p1/M, z13.h, z15.h
+    KAI_ASM_INST(0xa1432345)  // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]
+    addvl x26, x26, #8
+    KAI_ASM_INST(0xa1432367)  // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]
+    addvl x27, x27, #8
+    bgt label_4
+KAI_ASM_LABEL(label_5)  // K loop tail
+    KAI_ASM_INST(0x81b02640)  // fmopa za0.s, p1/M, p1/M, z18.h, z16.h
+    KAI_ASM_INST(0x81b12641)  // fmopa za1.s, p1/M, p1/M, z18.h, z17.h
+    KAI_ASM_INST(0x81b02662)  // fmopa za2.s, p1/M, p1/M, z19.h, z16.h
+    KAI_ASM_INST(0x81b12663)  // fmopa za3.s, p1/M, p1/M, z19.h, z17.h
+    KAI_ASM_INST(0x81be2440)  // fmopa za0.s, p1/M, p1/M, z2.h, z30.h
+    KAI_ASM_INST(0x81bf2441)  // fmopa za1.s, p1/M, p1/M, z2.h, z31.h
+    KAI_ASM_INST(0x81be2542)  // fmopa za2.s, p1/M, p1/M, z10.h, z30.h
+    KAI_ASM_INST(0x81bf2543)  // fmopa za3.s, p1/M, p1/M, z10.h, z31.h
+    KAI_ASM_INST(0x81a62780)  // fmopa za0.s, p1/M, p1/M, z28.h, z6.h
+    KAI_ASM_INST(0x81ae2781)  // fmopa za1.s, p1/M, p1/M, z28.h, z14.h
+    KAI_ASM_INST(0x81a627a2)  // fmopa za2.s, p1/M, p1/M, z29.h, z6.h
+    KAI_ASM_INST(0x81ae27a3)  // fmopa za3.s, p1/M, p1/M, z29.h, z14.h
+    KAI_ASM_INST(0x81a724a0)  // fmopa za0.s, p1/M, p1/M, z5.h, z7.h
+    KAI_ASM_INST(0x81af24a1)  // fmopa za1.s, p1/M, p1/M, z5.h, z15.h
+    KAI_ASM_INST(0x81a725a2)  // fmopa za2.s, p1/M, p1/M, z13.h, z7.h
+    KAI_ASM_INST(0x81af25a3)  // fmopa za3.s, p1/M, p1/M, z13.h, z15.h
+KAI_ASM_LABEL(label_6)  // K oddments
+    cbz x20, label_8
+KAI_ASM_LABEL(label_7)  // K oddments: Loop
+    KAI_ASM_INST(0xa1402345)  // ld1h { z5.h, z13.h }, pn8.b/Z, [x26]
+    subs x20, x20, #0x1
+    addvl x26, x26, #2
+    KAI_ASM_INST(0xa040236e)  // ld1h { z14.h-z15.h }, pn8.b/Z, [x27]
+    addvl x27, x27, #2
+    KAI_ASM_INST(0x81ae24a0)  // fmopa za0.s, p1/M, p1/M, z5.h, z14.h
+    KAI_ASM_INST(0x81af24a1)  // fmopa za1.s, p1/M, p1/M, z5.h, z15.h
+    KAI_ASM_INST(0x81ae25a2)  // fmopa za2.s, p1/M, p1/M, z13.h, z14.h
+    KAI_ASM_INST(0x81af25a3)  // fmopa za3.s, p1/M, p1/M, z13.h, z15.h
+    bgt label_7
+KAI_ASM_LABEL(label_8)  // K oddments: End
+    ldr x25, [x0, #0x10]
+    sub x24, x11, x13  // rows still to produce = M - m
+    cntw x23, ALL, MUL #2
+    ld1rh { z17.h }, p1/Z, [x0, #56]  // lower clamp bound, broadcast
+    ldr x22, [x0, #0x18]
+    whilelt p0.h, x9, x10  // mask off columns at or beyond N
+    cmp x24, x23
+    ld1rh { z16.h }, p1/Z, [x0, #58]  // upper clamp bound, broadcast
+    mov x12, #0x0
+    mov x21, #0x0
+    add x25, x25, x9, LSL #1  // C += n
+    mov x20, #0x2
+    madd x25, x13, x22, x25  // C += m * ldc
+    csel x24, x24, x23, LT  // rows to store = min(M - m, 2 * cntw)
+KAI_ASM_LABEL(label_10)  // Store to output array: Accumulator loop
+    KAI_ASM_INST(0xc006000e)  // mova { z14.b-z15.b }, za0h.b[x12, 0:1]
+    add x12, x12, #0x4
+    cmp x12, x23, LSL #1
+    add x21, x21, #0x1
+    KAI_ASM_INST(0xc120e1cc)  // fcvt z12.h, { z14.s-z15.s }
+    csel x12, x12, x20, LT
+    cmp x21, x24
+    KAI_ASM_INST(0x6470262c)  // fclamp z12.h, z17.h, z16.h
+    st1h { z12.h }, p0, [x25]
+    add x25, x25, x22
+    blt label_10
+    incw x9, ALL, MUL #2
+    cmp x9, x10
+    blt label_2
+    incw x13, ALL, MUL #2
+    mov x9, #0x0
+    cmp x13, x11
+    mov x28, x26  // continue the first operand where the K loops stopped
+    blt label_1
+    KAI_ASM_INST(0xd503467f)  // SMSTOP (PSTATE.SM = 0 and PSTATE.ZA = 0)
+    ldp x22, x23, [sp, 16]
+    ldp x24, x25, [sp, 32]
+    ldp x26, x27, [sp, 48]
+    ldr x28, [sp, 64]
+    ldp d8, d9, [sp, 72]  // reload only after SMSTOP, which zeroes the vector registers again
+    ldp d10, d11, [sp, 88]
+    ldp d12, d13, [sp, 104]
+    ldp d14, d15, [sp, 120]
+    ldp x20, x21, [sp], 144
+    ret
+    KAI_ASM_FUNCTION_END(kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa)
+
+    KAI_ASM_END
diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c index a927e2b7..c3071ab7 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c @@ -4,13 +4,6 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) -#error This file must be compiled for AArch64, FEAT_SVE2. -#else // Architectural features check.
- #include "kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" #include @@ -18,30 +11,50 @@ #include "kai/kai_common.h" +typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + float min; + float max; + void* accumulator_buffer; + uint64_t flags; +} KernelArgs; + static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 1; +void kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(KernelArgs* args); + +// Returns a constant value specific to this kernel that's relative to vector length +static size_t kai_get_kernel_vec_length_constant(void) { + const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u32() / kai_kr; + return kernel_vec_length_constant; +} + size_t kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u32(); + return kai_mr * kai_get_kernel_vec_length_constant(); } size_t kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_kernel_vec_length_constant(); } size_t kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() == 0); - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - return m_idx * indirect_k * sizeof(float); + return m_idx * k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(float); } static size_t kai_get_rhs_packed_stride_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( size_t k_chunk_count, size_t k_chunk_length) { - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); return kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() * - (sizeof(float) + 
indirect_k * sizeof(float)); + (sizeof(float) + k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(float)); } size_t kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( @@ -54,11 +67,11 @@ size_t kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_ } size_t kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_row_stride) { + size_t m_idx, size_t n_idx, size_t dst_stride_row) { KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() == 0); KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() == 0); - return m_idx * dst_row_stride + n_idx * sizeof(float); + return m_idx * dst_stride_row + n_idx * sizeof(float); } size_t kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(size_t m, size_t n) { @@ -67,245 +80,20 @@ size_t kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa void kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, float clamp_min, float clamp_max) { - typedef struct { - const void* A; - const void* B; - void* C; - uint64_t ldcb; - uint64_t M; - uint64_t N; - uint64_t K; - float min; - float max; - void* accumulator_buffer; - uint64_t flags; - } KernelArgs; - + void* dst, size_t dst_stride_row, float clamp_min, float clamp_max) { KernelArgs args; args.A = lhs_packed; args.B = rhs_packed; - - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - args.C = dst; - args.ldcb = dst_row_stride; + args.ldcb = dst_stride_row; args.M = m; args.N = n; - args.K = indirect_k; + args.K = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); args.min = clamp_min; args.max = clamp_max; - args.accumulator_buffer = NULL; args.flags = 0; - 
__asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "ldr w14, [%x[args], %[offsetof_M]]\n" - "mov x13, #0x0\n" - "mov x11, #0x0\n" - "ptrue p0.b\n" - ".inst 0x25207811 // ptrue pn9.b\n" - "ldr w10, [%x[args], %[offsetof_N]]\n" - "ldr x9, [%x[args], %[offsetof_A]]\n" - "1:" // M loop - "ldr x28, [%x[args], %[offsetof_B]]\n" - "2:" // N loop - ".inst 0x25aa4570 // whilelt pn8.s, x11, x10, VLx2\n" - "fmov z13.s, #1.0\n" - ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" - "mov x27, x9\n" - ".inst 0xa040438e // ld1w { z14.s-z15.s }, p8/Z, [x28]\n" // Load bias - "addvl x28, x28, #2\n" - ".inst 0x808e01a0 // fmopa za0.s, p0/M, p0/M, z13.s, z14.s\n" - ".inst 0x808f01a1 // fmopa za1.s, p0/M, p0/M, z13.s, z15.s\n" - ".inst 0x808e01a2 // fmopa za2.s, p0/M, p0/M, z13.s, z14.s\n" - ".inst 0x808f01a3 // fmopa za3.s, p0/M, p0/M, z13.s, z15.s\n" - "ldr x20, [%x[args], %[offsetof_K]]\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 6f\n" - "subs x21, x21, #0x1\n" - ".inst 0xa1404772 // ld1w { z18.s, z26.s }, pn9.b/Z, [x27]\n" - ".inst 0xa0404794 // ld1w { z20.s-z21.s }, pn9.b/Z, [x28]\n" - ".inst 0xa1414764 // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0xa041478a // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL]\n" - ".inst 0xa1424773 // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa0424798 // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL]\n" - ".inst 0xa043476e // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - ".inst 0xa1434796 // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL]\n" - "addvl x28, x28, #8\n" - "ble 5f\n" - "4:" // K loop - ".inst 0x80940240 // fmopa za0.s, p0/M, p0/M, z18.s, z20.s\n" - "subs x21, x21, #0x1\n" - ".inst 0x80950241 // fmopa za1.s, p0/M, p0/M, z18.s, z21.s\n" - ".inst 0x80940342 // fmopa za2.s, p0/M, p0/M, z26.s, z20.s\n" - ".inst 0x80950343 // fmopa za3.s, p0/M, p0/M, z26.s, z21.s\n" - ".inst 0xa1404772 // ld1w 
{ z18.s, z26.s }, pn9.b/Z, [x27]\n" - ".inst 0x808a0080 // fmopa za0.s, p0/M, p0/M, z4.s, z10.s\n" - ".inst 0xa0404794 // ld1w { z20.s-z21.s }, pn9.b/Z, [x28]\n" - ".inst 0x808b0081 // fmopa za1.s, p0/M, p0/M, z4.s, z11.s\n" - ".inst 0x808a0182 // fmopa za2.s, p0/M, p0/M, z12.s, z10.s\n" - ".inst 0x808b0183 // fmopa za3.s, p0/M, p0/M, z12.s, z11.s\n" - ".inst 0xa1414764 // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0x80980260 // fmopa za0.s, p0/M, p0/M, z19.s, z24.s\n" - ".inst 0xa041478a // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL]\n" - ".inst 0x80990261 // fmopa za1.s, p0/M, p0/M, z19.s, z25.s\n" - ".inst 0x80980362 // fmopa za2.s, p0/M, p0/M, z27.s, z24.s\n" - ".inst 0x80990363 // fmopa za3.s, p0/M, p0/M, z27.s, z25.s\n" - ".inst 0xa1424773 // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa0424798 // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL]\n" - ".inst 0x809601c0 // fmopa za0.s, p0/M, p0/M, z14.s, z22.s\n" - ".inst 0x809e01c1 // fmopa za1.s, p0/M, p0/M, z14.s, z30.s\n" - ".inst 0x809601e2 // fmopa za2.s, p0/M, p0/M, z15.s, z22.s\n" - ".inst 0x809e01e3 // fmopa za3.s, p0/M, p0/M, z15.s, z30.s\n" - ".inst 0xa043476e // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - ".inst 0xa1434796 // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL]\n" - "addvl x28, x28, #8\n" - "bgt 4b\n" - "5:" // K loop tail - ".inst 0x80940240 // fmopa za0.s, p0/M, p0/M, z18.s, z20.s\n" - ".inst 0x80950241 // fmopa za1.s, p0/M, p0/M, z18.s, z21.s\n" - ".inst 0x80940342 // fmopa za2.s, p0/M, p0/M, z26.s, z20.s\n" - ".inst 0x80950343 // fmopa za3.s, p0/M, p0/M, z26.s, z21.s\n" - ".inst 0x808a0080 // fmopa za0.s, p0/M, p0/M, z4.s, z10.s\n" - ".inst 0x808b0081 // fmopa za1.s, p0/M, p0/M, z4.s, z11.s\n" - ".inst 0x808a0182 // fmopa za2.s, p0/M, p0/M, z12.s, z10.s\n" - ".inst 0x808b0183 // fmopa za3.s, p0/M, p0/M, z12.s, z11.s\n" - ".inst 0x80980260 // fmopa za0.s, p0/M, p0/M, z19.s, z24.s\n" - ".inst 
0x80990261 // fmopa za1.s, p0/M, p0/M, z19.s, z25.s\n" - ".inst 0x80980362 // fmopa za2.s, p0/M, p0/M, z27.s, z24.s\n" - ".inst 0x80990363 // fmopa za3.s, p0/M, p0/M, z27.s, z25.s\n" - ".inst 0x809601c0 // fmopa za0.s, p0/M, p0/M, z14.s, z22.s\n" - ".inst 0x809e01c1 // fmopa za1.s, p0/M, p0/M, z14.s, z30.s\n" - ".inst 0x809601e2 // fmopa za2.s, p0/M, p0/M, z15.s, z22.s\n" - ".inst 0x809e01e3 // fmopa za3.s, p0/M, p0/M, z15.s, z30.s\n" - "6:" // K oddments - "cbz x20, 8f\n" - "7:" // K oddments: Loop - ".inst 0xa040477c // ld1w { z28.s-z29.s }, pn9.b/Z, [x27]\n" - "subs x20, x20, #0x1\n" - "addvl x27, x27, #2\n" - ".inst 0xa1404787 // ld1w { z7.s, z15.s }, pn9.b/Z, [x28]\n" - "addvl x28, x28, #2\n" - ".inst 0x80870380 // fmopa za0.s, p0/M, p0/M, z28.s, z7.s\n" - ".inst 0x808f0381 // fmopa za1.s, p0/M, p0/M, z28.s, z15.s\n" - ".inst 0x808703a2 // fmopa za2.s, p0/M, p0/M, z29.s, z7.s\n" - ".inst 0x808f03a3 // fmopa za3.s, p0/M, p0/M, z29.s, z15.s\n" - "bgt 7b\n" - "8:" // K oddments: End - "ldr x26, [%x[args], %[offsetof_C]]\n" - "sub x25, x14, x13\n" - "cntw x24\n" - "ld1rw { z19.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" - "ldr x23, [%x[args], %[offsetof_ldcb]]\n" - "cmp x25, x24\n" - "ld1rw { z26.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" - "mov x12, #0x0\n" - "csel x22, x25, x24, LT\n" - "add x26, x26, x11, LSL #2\n" // C += n - "lsr x21, x22, #0x2\n" - "madd x26, x13, x23, x26\n" // C += m * ldc - "and x20, x22, #0x3\n" - "cbz x21, 11f\n" - "10:" // Store to output array: Accumulator row 0 loop - ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" - ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" - ".inst 0xc1baca64 // fclamp { z4.s-z7.s }, z19.s, z26.s\n" - ".inst 0xc1baca6c // fclamp { z12.s-z15.s }, z19.s, z26.s\n" - "add x12, x12, #0x4\n" - "cmp x12, x21, LSL #2\n" - ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" - "add x26, x26, x23\n" - ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" - "add x26, x26, x23\n" 
- ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" - "add x26, x26, x23\n" - ".inst 0xa1604347 // st1w { z7.s, z15.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "blt 10b\n" - "11:" // Store to output array: Accumulator row 0 oddments - "cbz x20, 12f\n" - ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" - ".inst 0xc0860428 // mova { z8.s-z11.s }, za1h.s[x12]\n" - "subs x20, x20, #0x1\n" - ".inst 0xc1baca60 // fclamp { z0.s-z3.s }, z19.s, z26.s\n" - ".inst 0xc1baca68 // fclamp { z8.s-z11.s }, z19.s, z26.s\n" - ".inst 0xa1604340 // st1w { z0.s, z8.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "beq 12f\n" - "subs x20, x20, #0x1\n" - ".inst 0xa1604341 // st1w { z1.s, z9.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "beq 12f\n" - ".inst 0xa1604342 // st1w { z2.s, z10.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "12:" // Store to output array: Accumulator row 0 oddments: End - "subs x25, x25, x22\n" - "beq 16f\n" - "cmp x25, x24\n" - "mov x12, #0x0\n" - "csel x20, x25, x24, LT\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 14f\n" - "13:" // Store to output array: Accumulator row 1 loop - ".inst 0xc0860454 // mova { z20.s-z23.s }, za2h.s[x12]\n" - ".inst 0xc086047c // mova { z28.s-z31.s }, za3h.s[x12]\n" - ".inst 0xc1baca74 // fclamp { z20.s-z23.s }, z19.s, z26.s\n" - ".inst 0xc1baca7c // fclamp { z28.s-z31.s }, z19.s, z26.s\n" - "add x12, x12, #0x4\n" - "cmp x12, x21, LSL #2\n" - ".inst 0xa1604354 // st1w { z20.s, z28.s }, p8, [x26]\n" - "add x26, x26, x23\n" - ".inst 0xa1604355 // st1w { z21.s, z29.s }, p8, [x26]\n" - "add x26, x26, x23\n" - ".inst 0xa1604356 // st1w { z22.s, z30.s }, p8, [x26]\n" - "add x26, x26, x23\n" - ".inst 0xa1604357 // st1w { z23.s, z31.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "blt 13b\n" - "14:" // Store to output array: Accumulator row 1 oddments - "cbz x20, 15f\n" - ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n" - ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" - "subs x20, x20, #0x1\n" - ".inst 
0xc1baca64 // fclamp { z4.s-z7.s }, z19.s, z26.s\n" - ".inst 0xc1baca6c // fclamp { z12.s-z15.s }, z19.s, z26.s\n" - ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "beq 15f\n" - "subs x20, x20, #0x1\n" - ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" - "add x26, x26, x23\n" - "beq 15f\n" - ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" - "15:" // Store to output array: Accumulator row 1 oddments: End - "16:" // Store to output array: End - "incw x11, ALL, MUL #2\n" - "cmp x11, x10\n" - "blt 2b\n" - "incw x13, ALL, MUL #2\n" - "mov x11, #0x0\n" - "cmp x13, x14\n" - "mov x9, x27\n" - "blt 1b\n" - ".inst 0xd503467f // SMSTOP\n" - : - : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), - [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), - [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), - [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), [offsetof_M] "I"(offsetof(KernelArgs, M)), - [offsetof_N] "I"(offsetof(KernelArgs, N)), [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", - "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", - "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", - "z9"); + kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(&args); } - -#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h index c7ac5fa1..655ce4d1 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h @@ -55,11 +55,11 @@ size_t kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_ /// /// @param[in] m_idx Row index. Must be a multiple of `m_step`. /// @param[in] n_idx Column index. Must be a multiple of `n_step`. -/// @param[in] dst_row_stride. Distance between start of two rows in the output buffer. +/// @param[in] dst_stride_row Row stride in bytes. /// /// @return The offset in bytes to the data element. size_t kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_row_stride); + size_t m_idx, size_t n_idx, size_t dst_stride_row); /// Gets the size in bytes of the destination matrix buffer. /// @@ -81,16 +81,16 @@ size_t kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length of a LHS column split +/// @param[in] k_chunk_length Length of a LHS column split. /// @param[in] lhs_packed Packed LHS matrix buffer. /// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. -/// @param[in] dst_row_stride Row stride in bytes of the output matrix. +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. /// @param[in] clamp_min Minimum value to clamp the final result. 
/// @param[in] clamp_max Maximum value to clamp the final result. void kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + void* dst, size_t dst_stride_row, float clamp_min, float clamp_max); #ifdef __cplusplus } // extern "C"
diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S
new file mode 100644
index 00000000..bb60a77d
--- /dev/null
+++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S
@@ -0,0 +1,270 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if defined(_MSC_VER)
+    #define KAI_ASM_GLOBAL(name) GLOBAL name
+    #define KAI_ASM_FUNCTION_TYPE(name)
+    #define KAI_ASM_FUNCTION_LABEL(name) name PROC
+    #define KAI_ASM_FUNCTION_END(name) ENDP
+
+    #define KAI_ASM_CODE(name) AREA name, CODE, READONLY
+    #define KAI_ASM_ALIGN
+    #define KAI_ASM_LABEL(name) name
+    #define KAI_ASM_INST(hex) DCD hex
+    #define KAI_ASM_END END
+#else
+    #if defined(__APPLE__)
+        #define KAI_ASM_GLOBAL(name) .globl _##name
+        #define KAI_ASM_FUNCTION_TYPE(name)
+        #define KAI_ASM_FUNCTION_LABEL(name) _##name:
+        #define KAI_ASM_FUNCTION_END(name)
+    #else
+        #define KAI_ASM_GLOBAL(name) .global name
+        #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function
+        #define KAI_ASM_FUNCTION_LABEL(name) name:
+        #define KAI_ASM_FUNCTION_END(name) .size name, .-name
+    #endif
+
+    #define KAI_ASM_CODE(name) .text
+    #define KAI_ASM_ALIGN .p2align 4,,11
+    #define KAI_ASM_LABEL(name) name:
+    #define KAI_ASM_INST(hex) .inst hex
+    #define KAI_ASM_END
+#endif
+
+    KAI_ASM_CODE(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa)
+    KAI_ASM_ALIGN
+
+    KAI_ASM_GLOBAL(kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa)
+
+// void kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(KernelArgs* args)
+//
+// x0 = args. KernelArgs fields read below, by byte offset:
+//   0x00 A, 0x08 B, 0x10 C, 0x18 ldcb, 0x20 M (low 32 bits), 0x28 N (low 32 bits),
+//   0x30 K, 56 min, 60 max.
+// Each pass of the N loop zeroes ZA, seeds it with the bias row read from B, accumulates
+// K outer products, then clamps the tile to [min, max] and stores it to C.
+// x20-x28 and d8-d15 are callee-saved under AAPCS64 and are spilled to a 144-byte
+// frame: changing PSTATE.SM (SMSTART/SMSTOP) zeroes the vector registers, so d8-d15
+// would otherwise be lost to the caller.
+KAI_ASM_FUNCTION_TYPE(kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa)
+KAI_ASM_FUNCTION_LABEL(kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa)
+    stp x20, x21, [sp, -144]!  // 72 bytes of GPRs + 64 bytes of d8-d15, rounded up to 16
+    stp x22, x23, [sp, 16]
+    stp x24, x25, [sp, 32]
+    stp x26, x27, [sp, 48]
+    str x28, [sp, 64]
+    stp d8, d9, [sp, 72]  // must happen before SMSTART, which zeroes the vector registers
+    stp d10, d11, [sp, 88]
+    stp d12, d13, [sp, 104]
+    stp d14, d15, [sp, 120]
+    KAI_ASM_INST(0xd503477f)  // SMSTART (PSTATE.SM = 1 and PSTATE.ZA = 1)
+    mov x14, #0x0  // m = 0
+    ptrue p0.b
+    KAI_ASM_INST(0x25207811)  // ptrue pn9.b
+    ldr w13, [x0, #0x20]  // M
+    ldr w11, [x0, #0x28]  // N
+    mov x10, #0x0  // n = 0
+    ldr x9, [x0, #0x0]  // A
+KAI_ASM_LABEL(label_1)  // M loop
+    ldr x28, [x0, #0x8]  // B, rewound for every M block
+KAI_ASM_LABEL(label_2)  // N loop
+    KAI_ASM_INST(0x25ab4550)  // whilelt pn8.s, x10, x11, VLx2
+    fmov z13.s, #1.0
+    KAI_ASM_INST(0xc00800ff)  // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }
+    mov x27, x9
+    KAI_ASM_INST(0xa040438e)  // ld1w { z14.s-z15.s }, p8/Z, [x28] // Load bias
+    addvl x28, x28, #2
+    KAI_ASM_INST(0x808e01a0)  // fmopa za0.s, p0/M, p0/M, z13.s, z14.s
+    KAI_ASM_INST(0x808f01a1)  // fmopa za1.s, p0/M, p0/M, z13.s, z15.s
+    KAI_ASM_INST(0x808e01a2)  // fmopa za2.s, p0/M, p0/M, z13.s, z14.s
+    KAI_ASM_INST(0x808f01a3)  // fmopa za3.s, p0/M, p0/M, z13.s, z15.s
+    ldr x20, [x0, #0x30]  // K
+    lsr x21, x20, #0x2  // x21 = iterations of the 4x unrolled K loop
+    and x20, x20, #0x3  // x20 = leftover single steps
+    cbz x21, label_6
+    subs x21, x21, #0x1
+    KAI_ASM_INST(0xa1404772)  // ld1w { z18.s, z26.s }, pn9.b/Z, [x27]
+    KAI_ASM_INST(0xa0404794)  // ld1w { z20.s-z21.s }, pn9.b/Z, [x28]
+    KAI_ASM_INST(0xa1414764)  // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL]
+    KAI_ASM_INST(0xa041478a)  // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL]
+    KAI_ASM_INST(0xa1424773)  // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL]
+    KAI_ASM_INST(0xa0424798)  // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL]
+    KAI_ASM_INST(0xa043476e)  // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL]
+    addvl x27, x27, #8
+    KAI_ASM_INST(0xa1434796)  // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL]
+    addvl x28, x28, #8
+    ble label_5
+KAI_ASM_LABEL(label_4)  // K loop
+    KAI_ASM_INST(0x80940240)  // fmopa za0.s, p0/M, p0/M, z18.s, z20.s
+    subs x21, x21, #0x1
+    KAI_ASM_INST(0x80950241)  // fmopa za1.s, p0/M, p0/M, z18.s, z21.s
+    KAI_ASM_INST(0x80940342)  // fmopa za2.s, p0/M, p0/M, z26.s, z20.s
+    KAI_ASM_INST(0x80950343)  // fmopa za3.s, p0/M, p0/M, z26.s, z21.s
+    KAI_ASM_INST(0xa1404772)  // ld1w { z18.s, z26.s }, pn9.b/Z, [x27]
+    KAI_ASM_INST(0x808a0080)  // fmopa za0.s, p0/M, p0/M, z4.s, z10.s
+    KAI_ASM_INST(0xa0404794)  // ld1w { z20.s-z21.s }, pn9.b/Z, [x28]
+    KAI_ASM_INST(0x808b0081)  // fmopa za1.s, p0/M, p0/M, z4.s, z11.s
+    KAI_ASM_INST(0x808a0182)  // fmopa za2.s, p0/M, p0/M, z12.s, z10.s
+    KAI_ASM_INST(0x808b0183)  // fmopa za3.s, p0/M, p0/M, z12.s, z11.s
+    KAI_ASM_INST(0xa1414764)  // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL]
+    KAI_ASM_INST(0x80980260)  // fmopa za0.s, p0/M, p0/M, z19.s, z24.s
+    KAI_ASM_INST(0xa041478a)  // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL]
+    KAI_ASM_INST(0x80990261)  // fmopa za1.s, p0/M, p0/M, z19.s, z25.s
+    KAI_ASM_INST(0x80980362)  // fmopa za2.s, p0/M, p0/M, z27.s, z24.s
+    KAI_ASM_INST(0x80990363)  // fmopa za3.s, p0/M, p0/M, z27.s, z25.s
+    KAI_ASM_INST(0xa1424773)  // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL]
+    KAI_ASM_INST(0xa0424798)  // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL]
+    KAI_ASM_INST(0x809601c0)  // fmopa za0.s, p0/M, p0/M, z14.s, z22.s
+    KAI_ASM_INST(0x809e01c1)  // fmopa za1.s, p0/M, p0/M, z14.s, z30.s
+    KAI_ASM_INST(0x809601e2)  // fmopa za2.s, p0/M, p0/M, z15.s, z22.s
+    KAI_ASM_INST(0x809e01e3)  // fmopa za3.s, p0/M, p0/M, z15.s, z30.s
+    KAI_ASM_INST(0xa043476e)  // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL]
+    addvl x27, x27, #8
+    KAI_ASM_INST(0xa1434796)  // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL]
+    addvl x28, x28, #8
+    bgt label_4
+KAI_ASM_LABEL(label_5)  // K loop tail
+    KAI_ASM_INST(0x80940240)  // fmopa za0.s, p0/M, p0/M, z18.s, z20.s
+    KAI_ASM_INST(0x80950241)  // fmopa za1.s, p0/M, p0/M, z18.s, z21.s
+    KAI_ASM_INST(0x80940342)  // fmopa za2.s, p0/M, p0/M, z26.s, z20.s
+    KAI_ASM_INST(0x80950343)  // fmopa za3.s, p0/M, p0/M, z26.s, z21.s
+    KAI_ASM_INST(0x808a0080)  // fmopa za0.s, p0/M, p0/M, z4.s, z10.s
+    KAI_ASM_INST(0x808b0081)  // fmopa za1.s, p0/M, p0/M, z4.s, z11.s
+    KAI_ASM_INST(0x808a0182)  // fmopa za2.s, p0/M, p0/M, z12.s, z10.s
+    KAI_ASM_INST(0x808b0183)  // fmopa za3.s, p0/M, p0/M, z12.s, z11.s
+    KAI_ASM_INST(0x80980260)  // fmopa za0.s, p0/M, p0/M, z19.s, z24.s
+    KAI_ASM_INST(0x80990261)  // fmopa za1.s, p0/M, p0/M, z19.s, z25.s
+    KAI_ASM_INST(0x80980362)  // fmopa za2.s, p0/M, p0/M, z27.s, z24.s
+    KAI_ASM_INST(0x80990363)  // fmopa za3.s, p0/M, p0/M, z27.s, z25.s
+    KAI_ASM_INST(0x809601c0)  // fmopa za0.s, p0/M, p0/M, z14.s, z22.s
+    KAI_ASM_INST(0x809e01c1)  // fmopa za1.s, p0/M, p0/M, z14.s, z30.s
+    KAI_ASM_INST(0x809601e2)  // fmopa za2.s, p0/M, p0/M, z15.s, z22.s
+    KAI_ASM_INST(0x809e01e3)  // fmopa za3.s, p0/M, p0/M, z15.s, z30.s
+KAI_ASM_LABEL(label_6)  // K oddments
+    cbz x20, label_8
+KAI_ASM_LABEL(label_7)  // K oddments: Loop
+    KAI_ASM_INST(0xa040477c)  // ld1w { z28.s-z29.s }, pn9.b/Z, [x27]
+    subs x20, x20, #0x1
+    addvl x27, x27, #2
+    KAI_ASM_INST(0xa1404787)  // ld1w { z7.s, z15.s }, pn9.b/Z, [x28]
+    addvl x28, x28, #2
+    KAI_ASM_INST(0x80870380)  // fmopa za0.s, p0/M, p0/M, z28.s, z7.s
+    KAI_ASM_INST(0x808f0381)  // fmopa za1.s, p0/M, p0/M, z28.s, z15.s
+    KAI_ASM_INST(0x808703a2)  // fmopa za2.s, p0/M, p0/M, z29.s, z7.s
+    KAI_ASM_INST(0x808f03a3)  // fmopa za3.s, p0/M, p0/M, z29.s, z15.s
+    bgt label_7
+KAI_ASM_LABEL(label_8)  // K oddments: End
+    ldr x26, [x0, #0x10]  // C
+    sub x25, x13, x14  // rows still to produce = M - m
+    cntw x24
+    ld1rw { z19.s }, p0/Z, [x0, #56]  // min, broadcast
+    ldr x23, [x0, #0x18]  // ldcb
+    cmp x25, x24
+    ld1rw { z26.s }, p0/Z, [x0, #60]  // max, broadcast
+    mov x12, #0x0
+    csel x22, x25, x24, LT  // rows taken from za0/za1 = min(M - m, cntw)
+    add x26, x26, x10, LSL #2  // C += n
+    lsr x21, x22, #0x2
+    madd x26, x14, x23, x26  // C += m * ldc
+    and x20, x22, #0x3
+    cbz x21, label_11
+KAI_ASM_LABEL(label_10)  // Store to output array: Accumulator row 0 loop
+    KAI_ASM_INST(0xc0860404)  // mova { z4.s-z7.s }, za0h.s[x12]
+    KAI_ASM_INST(0xc086042c)  // mova { z12.s-z15.s }, za1h.s[x12]
+    KAI_ASM_INST(0xc1baca64)  // fclamp { z4.s-z7.s }, z19.s, z26.s
+    KAI_ASM_INST(0xc1baca6c)  // fclamp { z12.s-z15.s }, z19.s, z26.s
+    add x12, x12, #0x4
+    cmp x12, x21, LSL #2
+    KAI_ASM_INST(0xa1604344)  // st1w { z4.s, z12.s }, p8, [x26]
+    add x26, x26, x23
+    KAI_ASM_INST(0xa1604345)  // st1w { z5.s, z13.s }, p8, [x26]
+    add x26, x26, x23
+    KAI_ASM_INST(0xa1604346)  // st1w { z6.s, z14.s }, p8, [x26]
+    add x26, x26, x23
+    KAI_ASM_INST(0xa1604347)  // st1w { z7.s, z15.s }, p8, [x26]
+    add x26, x26, x23
+    blt label_10
+KAI_ASM_LABEL(label_11)  // Store to output array: Accumulator row 0 oddments
+    cbz x20, label_12
+    KAI_ASM_INST(0xc0860400)  // mova { z0.s-z3.s }, za0h.s[x12]
+    KAI_ASM_INST(0xc0860428)  // mova { z8.s-z11.s }, za1h.s[x12]
+    subs x20, x20, #0x1
+    KAI_ASM_INST(0xc1baca60)  // fclamp { z0.s-z3.s }, z19.s, z26.s
+    KAI_ASM_INST(0xc1baca68)  // fclamp { z8.s-z11.s }, z19.s, z26.s
+    KAI_ASM_INST(0xa1604340)  // st1w { z0.s, z8.s }, p8, [x26]
+    add x26, x26, x23
+    beq label_12
+    subs x20, x20, #0x1
+    KAI_ASM_INST(0xa1604341)  // st1w { z1.s, z9.s }, p8, [x26]
+    add x26, x26, x23
+    beq label_12
+    KAI_ASM_INST(0xa1604342)  // st1w { z2.s, z10.s }, p8, [x26]
+    add x26, x26, x23
+KAI_ASM_LABEL(label_12)  // Store to output array: Accumulator row 0 oddments: End
+    subs x25, x25, x22
+    beq label_16
+    cmp x25, x24
+    mov x12, #0x0
+    csel x20, x25, x24, LT
+    lsr x21, x20, #0x2
+    and x20, x20, #0x3
+    cbz x21, label_14
+KAI_ASM_LABEL(label_13)  // Store to output array: Accumulator row 1 loop
+    KAI_ASM_INST(0xc0860454)  // mova { z20.s-z23.s }, za2h.s[x12]
+    KAI_ASM_INST(0xc086047c)  // mova { z28.s-z31.s }, za3h.s[x12]
+    KAI_ASM_INST(0xc1baca74)  // fclamp { z20.s-z23.s }, z19.s, z26.s
+    KAI_ASM_INST(0xc1baca7c)  // fclamp { z28.s-z31.s }, z19.s, z26.s
+    add x12, x12, #0x4
+    cmp x12, x21, LSL #2
+    KAI_ASM_INST(0xa1604354)  // st1w { z20.s, z28.s }, p8, [x26]
+    add x26, x26, x23
+    KAI_ASM_INST(0xa1604355)  // st1w { z21.s, z29.s }, p8, [x26]
+    add x26, x26, x23
+    KAI_ASM_INST(0xa1604356)  // st1w { z22.s, z30.s }, p8, [x26]
+    add x26, x26, x23
+    KAI_ASM_INST(0xa1604357)  // st1w { z23.s, z31.s }, p8, [x26]
+    add x26, x26, x23
+    blt label_13
+KAI_ASM_LABEL(label_14)  // Store to output array: Accumulator row 1 oddments
+    cbz x20, label_15
+    KAI_ASM_INST(0xc0860444)  // mova { z4.s-z7.s }, za2h.s[x12]
+    KAI_ASM_INST(0xc086046c)  // mova { z12.s-z15.s }, za3h.s[x12]
+    subs x20, x20, #0x1
+    KAI_ASM_INST(0xc1baca64)  // fclamp { z4.s-z7.s }, z19.s, z26.s
+    KAI_ASM_INST(0xc1baca6c)  // fclamp { z12.s-z15.s }, z19.s, z26.s
+    KAI_ASM_INST(0xa1604344)  // st1w { z4.s, z12.s }, p8, [x26]
+    add x26, x26, x23
+    beq label_15
+    subs x20, x20, #0x1
+    KAI_ASM_INST(0xa1604345)  // st1w { z5.s, z13.s }, p8, [x26]
+    add x26, x26, x23
+    beq label_15
+    KAI_ASM_INST(0xa1604346)  // st1w { z6.s, z14.s }, p8, [x26]
+KAI_ASM_LABEL(label_15)  // Store to output array: Accumulator row 1 oddments: End
+KAI_ASM_LABEL(label_16)  // Store to output array: End
+    incw x10, ALL, MUL #2
+    cmp x10, x11
+    blt label_2
+    incw x14, ALL, MUL #2
+    mov x10, #0x0
+    cmp x14, x13
+    mov x9, x27  // continue A where the K loops stopped
+    blt label_1
+    KAI_ASM_INST(0xd503467f)  // SMSTOP (PSTATE.SM = 0 and PSTATE.ZA = 0)
+    ldp x22, x23, [sp, 16]
+    ldp x24, x25, [sp, 32]
+    ldp x26, x27, [sp, 48]
+    ldr x28, [sp, 64]
+    ldp d8, d9, [sp, 72]  // reload only after SMSTOP, which zeroes the vector registers again
+    ldp d10, d11, [sp, 88]
+    ldp d12, d13, [sp, 104]
+    ldp d14, d15, [sp, 120]
+    ldp x20, x21, [sp], 144
+    ret
+    KAI_ASM_FUNCTION_END(kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa)
+
+    KAI_ASM_END
diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c index b2eeb5ef..c8e36d6a 100644 ---
a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c @@ -4,13 +4,6 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) -#error This file must be compiled for AArch64, FEAT_SVE2. -#else // Architectural features check. - #include "kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" #include @@ -18,30 +11,52 @@ #include "kai/kai_common.h" +typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + int32_t min; + int32_t max; + int32_t result_zero_point; + const int n_0; + void* accumulator_buffer; + uint64_t flags; +} KernelArgs; + static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 4; +void kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(KernelArgs* args); + +// Returns a constant value specific to this kernel that's relative to vector length +static size_t kai_get_kernel_vec_length_constant(void) { + const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u8() / kai_kr; + return kernel_vec_length_constant; +} + size_t kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u32(); + return kai_mr * kai_get_kernel_vec_length_constant(); } size_t kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_kernel_vec_length_constant(); } size_t kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { 
KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - return m_idx * indirect_k * sizeof(int8_t); + return m_idx * k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(int8_t); } static size_t kai_get_rhs_packed_stride_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( size_t k_chunk_count, size_t k_chunk_length) { - const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); return kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() * - (sizeof(int32_t) + indirect_k * sizeof(int8_t) + sizeof(float)); + (sizeof(int32_t) + k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(int8_t) + sizeof(float)); } size_t kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( @@ -54,11 +69,11 @@ size_t kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2v } size_t kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_row_stride) { + size_t m_idx, size_t n_idx, size_t dst_stride_row) { KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); - return m_idx * dst_row_stride + n_idx * sizeof(int8_t); + return m_idx * dst_stride_row + n_idx * sizeof(int8_t); } size_t kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t m, size_t n) { @@ -67,334 +82,21 @@ size_t kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme void kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, const struct 
kai_matmul_requantize32_params* params) { - typedef struct { - const void* A; - const void* B; - void* C; - uint64_t ldcb; - uint64_t M; - uint64_t N; - uint64_t K; - int32_t min; - int32_t max; - int32_t result_zero_point; - const int n_0; - void* accumulator_buffer; - uint64_t flags; - } KernelArgs; - + void* dst, size_t dst_stride_row, const struct kai_matmul_requantize32_params* params) { KernelArgs args; args.A = lhs_packed; args.B = rhs_packed; - - size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - args.C = dst; - args.ldcb = dst_row_stride; + args.ldcb = dst_stride_row; args.M = m; args.N = n; - args.K = indirect_k; + args.K = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); args.min = params->min_value; args.max = params->max_value; args.result_zero_point = params->output_zero_point; - args.accumulator_buffer = NULL; args.flags = 0; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "ldr w14, [%x[args], %[offsetof_M]]\n" - "mov x13, #0x0\n" - "mov x11, #0x0\n" - "ptrue p1.b\n" - ".inst 0x25207811 // ptrue pn9.b\n" - "ldr w10, [%x[args], %[offsetof_N]]\n" - "ldr x9, [%x[args], %[offsetof_A]]\n" - "1:" // M loop - "ldr x28, [%x[args], %[offsetof_B]]\n" - "2:" // N loop - ".inst 0x25aa4570 // whilelt pn8.s, x11, x10, VLx2\n" - ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" - "mov x27, x9\n" - ".inst 0xa040438e // ld1w { z14.s-z15.s }, p8/Z, [x28]\n" // Load bias - "addvl x28, x28, #2\n" - ".inst 0xc09025c0 // addha za0.s, p1/M, p1/M, z14.s\n" - ".inst 0xc09025e1 // addha za1.s, p1/M, p1/M, z15.s\n" - ".inst 0xc09025c2 // addha za2.s, p1/M, p1/M, z14.s\n" - ".inst 0xc09025e3 // addha za3.s, p1/M, p1/M, z15.s\n" - "ldr x20, [%x[args], %[offsetof_K]]\n" - "add x20, x20, #0x3\n" - "lsr x20, x20, #0x2\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 6f\n" - "subs x21, x21, #0x1\n" - ".inst 0xa0400762 // ld1b { z2.b-z3.b }, pn9.b/Z, [x27]\n" - ".inst 0xa1400780 // ld1b { z0.b, 
z8.b }, pn9.b/Z, [x28]\n" - ".inst 0xa0410772 // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0xa0410794 // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL]\n" - ".inst 0xa042077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa0420796 // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL]\n" - ".inst 0xa0430778 // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - ".inst 0xa0430784 // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL]\n" - "addvl x28, x28, #8\n" - "ble 5f\n" - "4:" // K loop - ".inst 0xa0802440 // smopa za0.s, p1/M, p1/M, z2.b, z0.b\n" - "subs x21, x21, #0x1\n" - ".inst 0xa0882441 // smopa za1.s, p1/M, p1/M, z2.b, z8.b\n" - ".inst 0xa0802462 // smopa za2.s, p1/M, p1/M, z3.b, z0.b\n" - ".inst 0xa0882463 // smopa za3.s, p1/M, p1/M, z3.b, z8.b\n" - ".inst 0xa0400762 // ld1b { z2.b-z3.b }, pn9.b/Z, [x27]\n" - ".inst 0xa0942640 // smopa za0.s, p1/M, p1/M, z18.b, z20.b\n" - ".inst 0xa1400780 // ld1b { z0.b, z8.b }, pn9.b/Z, [x28]\n" - ".inst 0xa0952641 // smopa za1.s, p1/M, p1/M, z18.b, z21.b\n" - ".inst 0xa0942662 // smopa za2.s, p1/M, p1/M, z19.b, z20.b\n" - ".inst 0xa0952663 // smopa za3.s, p1/M, p1/M, z19.b, z21.b\n" - ".inst 0xa0410772 // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0xa0962740 // smopa za0.s, p1/M, p1/M, z26.b, z22.b\n" - ".inst 0xa0410794 // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL]\n" - ".inst 0xa0972741 // smopa za1.s, p1/M, p1/M, z26.b, z23.b\n" - ".inst 0xa0962762 // smopa za2.s, p1/M, p1/M, z27.b, z22.b\n" - ".inst 0xa0972763 // smopa za3.s, p1/M, p1/M, z27.b, z23.b\n" - ".inst 0xa042077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa0420796 // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL]\n" - ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" - ".inst 0xa0852701 // smopa za1.s, p1/M, p1/M, z24.b, z5.b\n" - ".inst 0xa0842722 // smopa za2.s, p1/M, p1/M, z25.b, z4.b\n" - ".inst 0xa0852723 // 
smopa za3.s, p1/M, p1/M, z25.b, z5.b\n" - ".inst 0xa0430778 // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - ".inst 0xa0430784 // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL]\n" - "addvl x28, x28, #8\n" - "bgt 4b\n" - "5:" // K loop tail - ".inst 0xa0802440 // smopa za0.s, p1/M, p1/M, z2.b, z0.b\n" - ".inst 0xa0882441 // smopa za1.s, p1/M, p1/M, z2.b, z8.b\n" - ".inst 0xa0802462 // smopa za2.s, p1/M, p1/M, z3.b, z0.b\n" - ".inst 0xa0882463 // smopa za3.s, p1/M, p1/M, z3.b, z8.b\n" - ".inst 0xa0942640 // smopa za0.s, p1/M, p1/M, z18.b, z20.b\n" - ".inst 0xa0952641 // smopa za1.s, p1/M, p1/M, z18.b, z21.b\n" - ".inst 0xa0942662 // smopa za2.s, p1/M, p1/M, z19.b, z20.b\n" - ".inst 0xa0952663 // smopa za3.s, p1/M, p1/M, z19.b, z21.b\n" - ".inst 0xa0962740 // smopa za0.s, p1/M, p1/M, z26.b, z22.b\n" - ".inst 0xa0972741 // smopa za1.s, p1/M, p1/M, z26.b, z23.b\n" - ".inst 0xa0962762 // smopa za2.s, p1/M, p1/M, z27.b, z22.b\n" - ".inst 0xa0972763 // smopa za3.s, p1/M, p1/M, z27.b, z23.b\n" - ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" - ".inst 0xa0852701 // smopa za1.s, p1/M, p1/M, z24.b, z5.b\n" - ".inst 0xa0842722 // smopa za2.s, p1/M, p1/M, z25.b, z4.b\n" - ".inst 0xa0852723 // smopa za3.s, p1/M, p1/M, z25.b, z5.b\n" - "6:" // K oddments - "cbz x20, 8f\n" - "7:" // K oddments: Loop - ".inst 0xa0400770 // ld1b { z16.b-z17.b }, pn9.b/Z, [x27]\n" - "subs x20, x20, #0x1\n" - "addvl x27, x27, #2\n" - ".inst 0xa0400788 // ld1b { z8.b-z9.b }, pn9.b/Z, [x28]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0882600 // smopa za0.s, p1/M, p1/M, z16.b, z8.b\n" - ".inst 0xa0892601 // smopa za1.s, p1/M, p1/M, z16.b, z9.b\n" - ".inst 0xa0882622 // smopa za2.s, p1/M, p1/M, z17.b, z8.b\n" - ".inst 0xa0892623 // smopa za3.s, p1/M, p1/M, z17.b, z9.b\n" - "bgt 7b\n" - "8:" // K oddments: End - "ldr x26, [%x[args], %[offsetof_C]]\n" - "sub x25, x14, x13\n" - "cntw x24\n" - "ld1rw { z27.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" - 
"ldr x23, [%x[args], %[offsetof_ldcb]]\n" - "whilelt p0.h, x11, x10\n" - "cmp x25, x24\n" - "ld1rw { z1.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" - "csel x22, x25, x24, LT\n" - "ld1rw { z0.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_result_zero_point]]\n" - "mov x12, #0x0\n" - "add x26, x26, x11\n" // C += n - "lsr x21, x22, #0x2\n" - "ld1w { z22.s }, p1/Z, [x28]\n" - "madd x26, x13, x23, x26\n" // C += m * ldc - "ld1w { z26.s }, p1/Z, [x28, #1, MUL VL]\n" - "and x20, x22, #0x3\n" - "addvl x28, x28, #2\n" - "cbz x21, 11f\n" - "10:" // Store to output array: Accumulator row 0 loop - ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" - ".inst 0xc086043c // mova { z28.s-z31.s }, za1h.s[x12]\n" - ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" - "fmul z16.s, z16.s, z22.s\n" - "fmul z17.s, z17.s, z22.s\n" - "add x12, x12, #0x4\n" - "fmul z18.s, z18.s, z22.s\n" - "fmul z19.s, z19.s, z22.s\n" - "cmp x12, x21, LSL #2\n" - "fmul z28.s, z28.s, z26.s\n" - "fmul z29.s, z29.s, z26.s\n" - "fmul z30.s, z30.s, z26.s\n" - "fmul z31.s, z31.s, z26.s\n" - ".inst 0xc1b8e210 // frintn { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc131e210 // fcvtzs { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc1b8e39c // frintn { z28.s-z31.s }, { z28.s-z31.s }\n" - ".inst 0xc1a0ab10 // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s\n" - ".inst 0xc131e39c // fcvtzs { z28.s-z31.s }, { z28.s-z31.s }\n" - ".inst 0xc1a0ab1c // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s\n" - ".inst 0xc1a1cf70 // sclamp { z16.s-z19.s }, z27.s, z1.s\n" - ".inst 0xc1a1cf7c // sclamp { z28.s-z31.s }, z27.s, z1.s\n" - "uzp1 z5.h, z16.h, z28.h\n" - "uzp1 z20.h, z17.h, z29.h\n" - "uzp1 z17.h, z18.h, z30.h\n" - "uzp1 z16.h, z19.h, z31.h\n" - "st1b { z5.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z20.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z17.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z16.h }, p0, [x26]\n" - 
"add x26, x26, x23\n" - "blt 10b\n" - "11:" // Store to output array: Accumulator row 0 oddments - "cbz x20, 12f\n" - ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" - ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" - ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" - "fmul z4.s, z4.s, z22.s\n" - "fmul z5.s, z5.s, z22.s\n" - "subs x20, x20, #0x1\n" - "fmul z6.s, z6.s, z22.s\n" - "fmul z7.s, z7.s, z22.s\n" - "fmul z12.s, z12.s, z26.s\n" - "fmul z13.s, z13.s, z26.s\n" - "fmul z14.s, z14.s, z26.s\n" - "fmul z15.s, z15.s, z26.s\n" - ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc1b8e18c // frintn { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" - ".inst 0xc131e18c // fcvtzs { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc1a0ab0c // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s\n" - ".inst 0xc1a1cf64 // sclamp { z4.s-z7.s }, z27.s, z1.s\n" - ".inst 0xc1a1cf6c // sclamp { z12.s-z15.s }, z27.s, z1.s\n" - "uzp1 z16.h, z4.h, z12.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "beq 12f\n" - "subs x20, x20, #0x1\n" - "uzp1 z16.h, z5.h, z13.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "beq 12f\n" - "uzp1 z16.h, z6.h, z14.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "12:" // Store to output array: Accumulator row 0 oddments: End - "subs x25, x25, x22\n" - "beq 16f\n" - "cmp x25, x24\n" - "mov x12, #0x0\n" - "csel x20, x25, x24, LT\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 14f\n" - "13:" // Store to output array: Accumulator row 1 loop - ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" - ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" - ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" - 
"fmul z8.s, z8.s, z22.s\n" - "fmul z9.s, z9.s, z22.s\n" - "add x12, x12, #0x4\n" - "fmul z10.s, z10.s, z22.s\n" - "fmul z11.s, z11.s, z22.s\n" - "cmp x12, x21, LSL #2\n" - "fmul z16.s, z16.s, z26.s\n" - "fmul z17.s, z17.s, z26.s\n" - "fmul z18.s, z18.s, z26.s\n" - "fmul z19.s, z19.s, z26.s\n" - ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc1b8e210 // frintn { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc1a0ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s\n" - ".inst 0xc131e210 // fcvtzs { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc1a0ab10 // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s\n" - ".inst 0xc1a1cf68 // sclamp { z8.s-z11.s }, z27.s, z1.s\n" - ".inst 0xc1a1cf70 // sclamp { z16.s-z19.s }, z27.s, z1.s\n" - "uzp1 z21.h, z8.h, z16.h\n" - "uzp1 z20.h, z9.h, z17.h\n" - "uzp1 z17.h, z10.h, z18.h\n" - "uzp1 z16.h, z11.h, z19.h\n" - "st1b { z21.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z20.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z17.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "blt 13b\n" - "14:" // Store to output array: Accumulator row 1 oddments - "cbz x20, 15f\n" - ".inst 0xc086044c // mova { z12.s-z15.s }, za2h.s[x12]\n" - ".inst 0xc0860464 // mova { z4.s-z7.s }, za3h.s[x12]\n" - ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" - "fmul z12.s, z12.s, z22.s\n" - "fmul z13.s, z13.s, z22.s\n" - "subs x20, x20, #0x1\n" - "fmul z14.s, z14.s, z22.s\n" - "fmul z15.s, z15.s, z22.s\n" - "fmul z4.s, z4.s, z26.s\n" - "fmul z5.s, z5.s, z26.s\n" - "fmul z6.s, z6.s, z26.s\n" - "fmul z7.s, z7.s, z26.s\n" - ".inst 0xc1b8e18c // frintn { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc131e18c // fcvtzs { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc1a0ab0c // add { z12.s-z15.s 
}, { z12.s-z15.s }, z0.s\n" - ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" - ".inst 0xc1a1cf6c // sclamp { z12.s-z15.s }, z27.s, z1.s\n" - ".inst 0xc1a1cf64 // sclamp { z4.s-z7.s }, z27.s, z1.s\n" - "uzp1 z16.h, z12.h, z4.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "beq 15f\n" - "subs x20, x20, #0x1\n" - "uzp1 z16.h, z13.h, z5.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "beq 15f\n" - "uzp1 z16.h, z14.h, z6.h\n" - "st1b { z16.h }, p0, [x26]\n" - "15:" // Store to output array: Accumulator row 1 oddments: End - "16:" // Store to output array: End - "incw x11, ALL, MUL #2\n" - "cmp x11, x10\n" - "blt 2b\n" - "incw x13, ALL, MUL #2\n" - "mov x11, #0x0\n" - "cmp x13, x14\n" - "mov x9, x27\n" - "blt 1b\n" - ".inst 0xd503467f // SMSTOP\n" - : - : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), - [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), - [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), - [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), - [offsetof_KernelArgs_result_zero_point] "I"(offsetof(KernelArgs, result_zero_point)), - [offsetof_M] "I"(offsetof(KernelArgs, M)), [offsetof_N] "I"(offsetof(KernelArgs, N)), - [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", - "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", - "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", - "z9"); + kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(&args); } - -#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h index 2f52001b..31ed3e5c 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h @@ -37,7 +37,6 @@ size_t kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_ /// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. /// @param[in] k_chunk_count Number of LHS column splits. /// @param[in] k_chunk_length Length of a LHS column split. -/// @param[in] k Number of columns in the unpacked LHS matrix. /// /// @return The offset in bytes to the data element. size_t kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( @@ -57,11 +56,11 @@ size_t kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2v /// /// @param[in] m_idx Row index. Must be a multiple of `m_step`. /// @param[in] n_idx Column index. Must be a multiple of `n_step`. -/// @param[in] dst_row_stride. Distance between start of two rows in the output buffer. +/// @param[in] dst_stride_row Row stride in bytes. /// /// @return The offset in bytes to the data element. size_t kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_row_stride); + size_t m_idx, size_t n_idx, size_t dst_stride_row); /// Gets the size in bytes of the destination matrix buffer. /// @@ -83,17 +82,15 @@ size_t kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. 
/// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length of a LHS column split +/// @param[in] k_chunk_length Length of a LHS column split. /// @param[in] lhs_packed Packed LHS matrix buffer. /// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. -/// @param[in] dst_row_stride Row stride in bytes of the output matrix. - +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. /// @param[in] params Requantization and clamp parameters. - void kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params); + void* dst, size_t dst_stride_row, const struct kai_matmul_requantize32_params* params); #ifdef __cplusplus } // extern "C" diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S new file mode 100644 index 00000000..c73aaad2 --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S @@ -0,0 +1,336 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + 
#define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) + stp x20, x21, [sp, -80]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x14, #0x0 + ptrue p1.b + KAI_ASM_INST(0x25207811) // ptrue pn9.b + ldr w13, [x0, #0x20] + ldr w11, [x0, #0x28] + mov x10, #0x0 + ldr x9, [x0, #0x0] +KAI_ASM_LABEL(label_1) // M loop + ldr x28, [x0, #0x8] +KAI_ASM_LABEL(label_2) // N loop + KAI_ASM_INST(0x25ab4550) // whilelt pn8.s, x10, x11, VLx2 + KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } + mov x27, x9 + KAI_ASM_INST(0xa040438e) // ld1w { z14.s-z15.s }, p8/Z, [x28] // Load bias + addvl x28, x28, #2 + KAI_ASM_INST(0xc09025c0) // addha za0.s, p1/M, p1/M, z14.s + KAI_ASM_INST(0xc09025e1) // addha za1.s, p1/M, p1/M, z15.s + KAI_ASM_INST(0xc09025c2) // addha za2.s, p1/M, p1/M, z14.s + KAI_ASM_INST(0xc09025e3) // addha za3.s, p1/M, p1/M, z15.s + ldr x20, [x0, #0x30] + add x20, x20, #0x3 + lsr x20, x20, #0x2 + lsr x21, x20, #0x2 + and x20, x20, #0x3 + cbz x21, label_6 + subs x21, x21, #0x1 + KAI_ASM_INST(0xa0400762) // 
ld1b { z2.b-z3.b }, pn9.b/Z, [x27] + KAI_ASM_INST(0xa1400780) // ld1b { z0.b, z8.b }, pn9.b/Z, [x28] + KAI_ASM_INST(0xa0410772) // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL] + KAI_ASM_INST(0xa0410794) // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL] + KAI_ASM_INST(0xa042077a) // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL] + KAI_ASM_INST(0xa0420796) // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL] + KAI_ASM_INST(0xa0430778) // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL] + addvl x27, x27, #8 + KAI_ASM_INST(0xa0430784) // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL] + addvl x28, x28, #8 + ble label_5 +KAI_ASM_LABEL(label_4) // K loop + KAI_ASM_INST(0xa0802440) // smopa za0.s, p1/M, p1/M, z2.b, z0.b + subs x21, x21, #0x1 + KAI_ASM_INST(0xa0882441) // smopa za1.s, p1/M, p1/M, z2.b, z8.b + KAI_ASM_INST(0xa0802462) // smopa za2.s, p1/M, p1/M, z3.b, z0.b + KAI_ASM_INST(0xa0882463) // smopa za3.s, p1/M, p1/M, z3.b, z8.b + KAI_ASM_INST(0xa0400762) // ld1b { z2.b-z3.b }, pn9.b/Z, [x27] + KAI_ASM_INST(0xa0942640) // smopa za0.s, p1/M, p1/M, z18.b, z20.b + KAI_ASM_INST(0xa1400780) // ld1b { z0.b, z8.b }, pn9.b/Z, [x28] + KAI_ASM_INST(0xa0952641) // smopa za1.s, p1/M, p1/M, z18.b, z21.b + KAI_ASM_INST(0xa0942662) // smopa za2.s, p1/M, p1/M, z19.b, z20.b + KAI_ASM_INST(0xa0952663) // smopa za3.s, p1/M, p1/M, z19.b, z21.b + KAI_ASM_INST(0xa0410772) // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL] + KAI_ASM_INST(0xa0962740) // smopa za0.s, p1/M, p1/M, z26.b, z22.b + KAI_ASM_INST(0xa0410794) // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL] + KAI_ASM_INST(0xa0972741) // smopa za1.s, p1/M, p1/M, z26.b, z23.b + KAI_ASM_INST(0xa0962762) // smopa za2.s, p1/M, p1/M, z27.b, z22.b + KAI_ASM_INST(0xa0972763) // smopa za3.s, p1/M, p1/M, z27.b, z23.b + KAI_ASM_INST(0xa042077a) // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL] + KAI_ASM_INST(0xa0420796) // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL] + KAI_ASM_INST(0xa0842700) // smopa 
za0.s, p1/M, p1/M, z24.b, z4.b + KAI_ASM_INST(0xa0852701) // smopa za1.s, p1/M, p1/M, z24.b, z5.b + KAI_ASM_INST(0xa0842722) // smopa za2.s, p1/M, p1/M, z25.b, z4.b + KAI_ASM_INST(0xa0852723) // smopa za3.s, p1/M, p1/M, z25.b, z5.b + KAI_ASM_INST(0xa0430778) // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL] + addvl x27, x27, #8 + KAI_ASM_INST(0xa0430784) // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL] + addvl x28, x28, #8 + bgt label_4 +KAI_ASM_LABEL(label_5) // K loop tail + KAI_ASM_INST(0xa0802440) // smopa za0.s, p1/M, p1/M, z2.b, z0.b + KAI_ASM_INST(0xa0882441) // smopa za1.s, p1/M, p1/M, z2.b, z8.b + KAI_ASM_INST(0xa0802462) // smopa za2.s, p1/M, p1/M, z3.b, z0.b + KAI_ASM_INST(0xa0882463) // smopa za3.s, p1/M, p1/M, z3.b, z8.b + KAI_ASM_INST(0xa0942640) // smopa za0.s, p1/M, p1/M, z18.b, z20.b + KAI_ASM_INST(0xa0952641) // smopa za1.s, p1/M, p1/M, z18.b, z21.b + KAI_ASM_INST(0xa0942662) // smopa za2.s, p1/M, p1/M, z19.b, z20.b + KAI_ASM_INST(0xa0952663) // smopa za3.s, p1/M, p1/M, z19.b, z21.b + KAI_ASM_INST(0xa0962740) // smopa za0.s, p1/M, p1/M, z26.b, z22.b + KAI_ASM_INST(0xa0972741) // smopa za1.s, p1/M, p1/M, z26.b, z23.b + KAI_ASM_INST(0xa0962762) // smopa za2.s, p1/M, p1/M, z27.b, z22.b + KAI_ASM_INST(0xa0972763) // smopa za3.s, p1/M, p1/M, z27.b, z23.b + KAI_ASM_INST(0xa0842700) // smopa za0.s, p1/M, p1/M, z24.b, z4.b + KAI_ASM_INST(0xa0852701) // smopa za1.s, p1/M, p1/M, z24.b, z5.b + KAI_ASM_INST(0xa0842722) // smopa za2.s, p1/M, p1/M, z25.b, z4.b + KAI_ASM_INST(0xa0852723) // smopa za3.s, p1/M, p1/M, z25.b, z5.b +KAI_ASM_LABEL(label_6) // K oddments + cbz x20, label_8 +KAI_ASM_LABEL(label_7) // K oddments: Loop + KAI_ASM_INST(0xa0400770) // ld1b { z16.b-z17.b }, pn9.b/Z, [x27] + subs x20, x20, #0x1 + addvl x27, x27, #2 + KAI_ASM_INST(0xa0400788) // ld1b { z8.b-z9.b }, pn9.b/Z, [x28] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0882600) // smopa za0.s, p1/M, p1/M, z16.b, z8.b + KAI_ASM_INST(0xa0892601) // smopa za1.s, p1/M, p1/M, z16.b, z9.b + 
KAI_ASM_INST(0xa0882622) // smopa za2.s, p1/M, p1/M, z17.b, z8.b + KAI_ASM_INST(0xa0892623) // smopa za3.s, p1/M, p1/M, z17.b, z9.b + bgt label_7 +KAI_ASM_LABEL(label_8) // K oddments: End + ldr x26, [x0, #0x10] + sub x25, x13, x14 + cntw x24 + ld1rw { z27.s }, p1/Z, [x0, #56] + ldr x23, [x0, #0x18] + whilelt p0.h, x10, x11 + cmp x25, x24 + ld1rw { z1.s }, p1/Z, [x0, #60] + csel x22, x25, x24, LT + ld1rw { z0.s }, p1/Z, [x0, #64] + mov x12, #0x0 + add x26, x26, x10 // C += n + lsr x21, x22, #0x2 + ld1w { z22.s }, p1/Z, [x28] + madd x26, x14, x23, x26 // C += m * ldc + ld1w { z26.s }, p1/Z, [x28, #1, MUL VL] + and x20, x22, #0x3 + addvl x28, x28, #2 + cbz x21, label_11 +KAI_ASM_LABEL(label_10) // Store to output array: Accumulator row 0 loop + KAI_ASM_INST(0xc0860410) // mova { z16.s-z19.s }, za0h.s[x12] + KAI_ASM_INST(0xc086043c) // mova { z28.s-z31.s }, za1h.s[x12] + KAI_ASM_INST(0xc132e210) // scvtf { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc132e39c) // scvtf { z28.s-z31.s }, { z28.s-z31.s } + fmul z16.s, z16.s, z22.s + fmul z17.s, z17.s, z22.s + add x12, x12, #0x4 + fmul z18.s, z18.s, z22.s + fmul z19.s, z19.s, z22.s + cmp x12, x21, LSL #2 + fmul z28.s, z28.s, z26.s + fmul z29.s, z29.s, z26.s + fmul z30.s, z30.s, z26.s + fmul z31.s, z31.s, z26.s + KAI_ASM_INST(0xc1b8e210) // frintn { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc131e210) // fcvtzs { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc1b8e39c) // frintn { z28.s-z31.s }, { z28.s-z31.s } + KAI_ASM_INST(0xc1a0ab10) // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s + KAI_ASM_INST(0xc131e39c) // fcvtzs { z28.s-z31.s }, { z28.s-z31.s } + KAI_ASM_INST(0xc1a0ab1c) // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s + KAI_ASM_INST(0xc1a1cf70) // sclamp { z16.s-z19.s }, z27.s, z1.s + KAI_ASM_INST(0xc1a1cf7c) // sclamp { z28.s-z31.s }, z27.s, z1.s + uzp1 z5.h, z16.h, z28.h + uzp1 z20.h, z17.h, z29.h + uzp1 z17.h, z18.h, z30.h + uzp1 z16.h, z19.h, z31.h + st1b { z5.h }, p0, [x26] + add x26, x26, x23 + st1b { 
z20.h }, p0, [x26] + add x26, x26, x23 + st1b { z17.h }, p0, [x26] + add x26, x26, x23 + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + blt label_10 +KAI_ASM_LABEL(label_11) // Store to output array: Accumulator row 0 oddments + cbz x20, label_12 + KAI_ASM_INST(0xc0860404) // mova { z4.s-z7.s }, za0h.s[x12] + KAI_ASM_INST(0xc086042c) // mova { z12.s-z15.s }, za1h.s[x12] + KAI_ASM_INST(0xc132e084) // scvtf { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc132e18c) // scvtf { z12.s-z15.s }, { z12.s-z15.s } + fmul z4.s, z4.s, z22.s + fmul z5.s, z5.s, z22.s + subs x20, x20, #0x1 + fmul z6.s, z6.s, z22.s + fmul z7.s, z7.s, z22.s + fmul z12.s, z12.s, z26.s + fmul z13.s, z13.s, z26.s + fmul z14.s, z14.s, z26.s + fmul z15.s, z15.s, z26.s + KAI_ASM_INST(0xc1b8e084) // frintn { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc131e084) // fcvtzs { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc1b8e18c) // frintn { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc1a0ab04) // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s + KAI_ASM_INST(0xc131e18c) // fcvtzs { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc1a0ab0c) // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s + KAI_ASM_INST(0xc1a1cf64) // sclamp { z4.s-z7.s }, z27.s, z1.s + KAI_ASM_INST(0xc1a1cf6c) // sclamp { z12.s-z15.s }, z27.s, z1.s + uzp1 z16.h, z4.h, z12.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + beq label_12 + subs x20, x20, #0x1 + uzp1 z16.h, z5.h, z13.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + beq label_12 + uzp1 z16.h, z6.h, z14.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 +KAI_ASM_LABEL(label_12) // Store to output array: Accumulator row 0 oddments: End + subs x25, x25, x22 + beq label_16 + cmp x25, x24 + mov x12, #0x0 + csel x20, x25, x24, LT + lsr x21, x20, #0x2 + and x20, x20, #0x3 + cbz x21, label_14 +KAI_ASM_LABEL(label_13) // Store to output array: Accumulator row 1 loop + KAI_ASM_INST(0xc0860448) // mova { z8.s-z11.s }, za2h.s[x12] + KAI_ASM_INST(0xc0860470) // mova { z16.s-z19.s }, za3h.s[x12] + 
KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc132e210) // scvtf { z16.s-z19.s }, { z16.s-z19.s } + fmul z8.s, z8.s, z22.s + fmul z9.s, z9.s, z22.s + add x12, x12, #0x4 + fmul z10.s, z10.s, z22.s + fmul z11.s, z11.s, z22.s + cmp x12, x21, LSL #2 + fmul z16.s, z16.s, z26.s + fmul z17.s, z17.s, z26.s + fmul z18.s, z18.s, z26.s + fmul z19.s, z19.s, z26.s + KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc1b8e210) // frintn { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s + KAI_ASM_INST(0xc131e210) // fcvtzs { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc1a0ab10) // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s + KAI_ASM_INST(0xc1a1cf68) // sclamp { z8.s-z11.s }, z27.s, z1.s + KAI_ASM_INST(0xc1a1cf70) // sclamp { z16.s-z19.s }, z27.s, z1.s + uzp1 z21.h, z8.h, z16.h + uzp1 z20.h, z9.h, z17.h + uzp1 z17.h, z10.h, z18.h + uzp1 z16.h, z11.h, z19.h + st1b { z21.h }, p0, [x26] + add x26, x26, x23 + st1b { z20.h }, p0, [x26] + add x26, x26, x23 + st1b { z17.h }, p0, [x26] + add x26, x26, x23 + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + blt label_13 +KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments + cbz x20, label_15 + KAI_ASM_INST(0xc086044c) // mova { z12.s-z15.s }, za2h.s[x12] + KAI_ASM_INST(0xc0860464) // mova { z4.s-z7.s }, za3h.s[x12] + KAI_ASM_INST(0xc132e18c) // scvtf { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc132e084) // scvtf { z4.s-z7.s }, { z4.s-z7.s } + fmul z12.s, z12.s, z22.s + fmul z13.s, z13.s, z22.s + subs x20, x20, #0x1 + fmul z14.s, z14.s, z22.s + fmul z15.s, z15.s, z22.s + fmul z4.s, z4.s, z26.s + fmul z5.s, z5.s, z26.s + fmul z6.s, z6.s, z26.s + fmul z7.s, z7.s, z26.s + KAI_ASM_INST(0xc1b8e18c) // frintn { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc131e18c) // fcvtzs { z12.s-z15.s }, { z12.s-z15.s } + 
KAI_ASM_INST(0xc1b8e084) // frintn { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc1a0ab0c) // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s + KAI_ASM_INST(0xc131e084) // fcvtzs { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc1a0ab04) // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s + KAI_ASM_INST(0xc1a1cf6c) // sclamp { z12.s-z15.s }, z27.s, z1.s + KAI_ASM_INST(0xc1a1cf64) // sclamp { z4.s-z7.s }, z27.s, z1.s + uzp1 z16.h, z12.h, z4.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + beq label_15 + subs x20, x20, #0x1 + uzp1 z16.h, z13.h, z5.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + beq label_15 + uzp1 z16.h, z14.h, z6.h + st1b { z16.h }, p0, [x26] +KAI_ASM_LABEL(label_15) // Store to output array: Accumulator row 1 oddments: End +KAI_ASM_LABEL(label_16) // Store to output array: End + incw x10, ALL, MUL #2 + cmp x10, x11 + blt label_2 + incw x14, ALL, MUL #2 + mov x10, #0x0 + cmp x14, x13 + mov x9, x27 + blt label_1 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c index f996bd48..4058a06e 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c @@ -3,14 +3,6 @@ // // SPDX-License-Identifier: Apache-2.0 // - -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) -#error This file must be compiled for AArch64, FEAT_SVE2. -#else // Architectural features check. 
- #include "kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" #include @@ -18,9 +10,14 @@ #include "kai/kai_common.h" -#define MR 2 -#define KR 2 -#define MAX_M_STEP (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR) +enum { + MR = 2, + KR = 2, + MAX_M_STEP = MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR, +}; + +void kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t height, size_t width, void* in, size_t row_offset, void* out); static size_t kai_get_mr_lhs_imatmul_pack_x16p2vlx2_x16p_sme(void) { return MR * kai_get_sme_vector_length_u16() / KR; @@ -69,273 +66,9 @@ void kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme( } } - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x22, %x[width]\n" - "mov x21, %x[width]\n" - "cnth x20\n" - "inch x22\n" - "sub x7, x20, #0x1\n" - "sub x22, x22, #0x1\n" - "ands x7, x21, x7\n" - "cntw x8\n" - "udiv x22, x22, x20\n" // n_passes = ceildiv(width, VL) - "csel x7, x7, x20, NE\n" - "sub x13, x22, #0x1\n" - "add x7, x7, #0x1\n" - "sub x17, x8, #0x2\n" - "lsl x21, %x[height], #0x1\n" // height * 2 - "lsl x20, x8, #0x1\n" - "mov x16, #0x0\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - "cntw x9, ALL, MUL #2\n" - "cntw x28, ALL, MUL #3\n" - "ldr x27, [x11, #0x0]\n" - "lsr x13, x13, #0x1\n" // n_loops = (n_passes - 1) / 2 - "and x26, x22, #0x1\n" // odd_tail = bool(n_passes & 0x1) - "ldr x25, [x10, #0x0]\n" - "lsr x7, x7, #0x1\n" - "ptrue p12.s\n" - "ldr x24, [x11, #0x8]\n" - "whilelt p11.h, XZR, x21\n" - "whilelt p10.h, x20, x21\n" - "ldr x21, [x10, #0x8]\n" - "mov x23, %x[row_offset]\n" - "mov x22, %x[out]\n" - "whilelt p9.h, x16, %x[width]\n" - "whilelt p8.h, x16, %x[width]\n" - "add x11, x11, #0x10\n" - "add x10, x10, #0x10\n" - "mov x12, #0x0\n" - "cbz x17, 2f\n" - "1:" // K loop: Charge: Loop - ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" - ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" - ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" - ".inst 0x25686140 // 
psel p0.h, p8.h/Z, p10.h[w12, #2]\n" - ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" - "add x12, x12, #0x4\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "cmp x12, x17, LSL #1\n" - "blt 1b\n" - "2:" // K loop: Charge: End - ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" - ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" - ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" - ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" - "ldr x27, [x11, #0x0]\n" - "inch x16\n" - ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "inch x23\n" - "cbz x13, 8f\n" - "mov x20, x13\n" - "3:" // K loop: Main loop - "whilelt p8.h, x16, %x[width]\n" - "mov x15, #0x0\n" - "mov x14, #0x0\n" - "cbz x17, 5f\n" - "4:" // K loop: Main loop: First: Loop - ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" - ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" - ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" - ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" - ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0576b29 // ld1h { 
za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - "add x10, x10, #0x10\n" - "add x15, x15, #0x4\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x14, x14, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x14, x17\n" - "blt 4b\n" - "5:" // K loop: Main loop: First: Tail - ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" - ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" - ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" - ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" - "ldr x27, [x11, #0x0]\n" - "mov x13, #0x0\n" - ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" - "ldr x25, [x10, #0x0]\n" - "mov x12, #0x0\n" - ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x253a7120 
// psel p0.h, p12.h/Z, p9.h[w14, #1]\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - "whilelt p9.h, x16, %x[width]\n" - "inch x16\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "inch x23\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "whilelt p8.h, x16, %x[width]\n" - "cbz x17, 7f\n" - "6:" // K loop: Main loop: Second: Loop - ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" - ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" - ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" - ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" - ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "add x10, x10, #0x10\n" - "add x13, x13, #0x4\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x12, x12, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x12, x17\n" - "blt 6b\n" - "7:" // K loop: Main loop: Second: Tail - ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" - ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" - ".inst 0x25696161 // psel p1.h, 
p8.h/Z, p11.h[w13, #2]\n" - ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "whilelt p9.h, x16, %x[width]\n" - "subs x20, x20, #0x1\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "inch x16\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "inch x23\n" - "bgt 3b\n" - "8:" // K loop: Tails - "cbnz x26, 11f\n" - "mov x11, %x[in]\n" - "whilelt p8.h, x16, %x[width]\n" - "mov x13, #0x0\n" - "mov x12, #0x0\n" - "9:" // K loop: Tails: Even: First - ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25396161 // psel p1.h, p8.h/Z, p11.h[w13, #1]\n" - ".inst 0x25396140 // psel p0.h, p8.h/Z, p10.h[w13, #1]\n" - ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "ldr x21, [x11, #0x0]\n" - "cmp x12, x8\n" - "ldr x20, [x11, x8, LSL #0x3]\n" - "add x11, x11, #0x8\n" - ".inst 0xe05726a1 // ld1h 
{ za0h.h[x13, #1] }, p1/Z, [x21, x23, LSL #1]\n" - ".inst 0xe0572289 // ld1h { za1h.h[x13, #1] }, p0/Z, [x20, x23, LSL #1]\n" - "add x13, x13, #0x2\n" - "blt 9b\n" - "whilelt p9.h, x16, %x[width]\n" - "whilelt p8.h, x16, %x[width]\n" - "mov x20, #0x0\n" - "mov x12, #0x0\n" - "10:" // K loop: Tails: Even: Second - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - "add x20, x20, #0x2\n" - ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 10b\n" - "whilelt p8.h, x16, %x[width]\n" - "b 13f\n" - "11:" // K loop: Tails: Odd - "mov x12, #0x0\n" - "12:" // K loop: Tails: Odd: Loop - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 12b\n" - "13:" // K loop: End - "mov %x[out], x22\n" - ".inst 0xd503467f // SMSTOP\n" - : [out] "+&r"(out) - : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", - "p7", "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", - "x24", "x25", "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", - "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", - "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + height, width, in, row_offset, out); // NOLINT(bugprone-multi-level-implicit-pointer-conversion) out_base += m_step * 
kai_roundup(k_chunk_length, KR) * sizeof(uint16_t); } } } - -#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h index a0343938..2990bad7 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h @@ -16,7 +16,7 @@ extern "C" { /// /// The starting row index must be divisible by `m_step`. /// -/// @return Step size for row index +/// @return The m step value. size_t kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme(void); /// Gets the offset in bytes to the data element in the packed LHS buffer.
diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S
new file mode 100644
index 00000000..d82b08d7
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S
@@ -0,0 +1,335 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if defined(_MSC_VER)
+    #define KAI_ASM_GLOBAL(name) GLOBAL name
+    #define KAI_ASM_FUNCTION_TYPE(name)
+    #define KAI_ASM_FUNCTION_LABEL(name) name PROC
+    #define KAI_ASM_FUNCTION_END(name) ENDP
+
+    #define KAI_ASM_CODE(name) AREA name, CODE, READONLY
+    #define KAI_ASM_ALIGN
+    #define KAI_ASM_LABEL(name) name
+    #define KAI_ASM_INST(hex) DCD hex
+    #define KAI_ASM_END END
+#else
+    #if defined(__APPLE__)
+        #define KAI_ASM_GLOBAL(name) .globl _##name
+        #define KAI_ASM_FUNCTION_TYPE(name)
+        #define KAI_ASM_FUNCTION_LABEL(name) _##name:
+        #define KAI_ASM_FUNCTION_END(name)
+    #else
+        #define KAI_ASM_GLOBAL(name) .global name
+        #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function
+        #define KAI_ASM_FUNCTION_LABEL(name) name:
+        #define KAI_ASM_FUNCTION_END(name) .size name, .-name
+    #endif
+
+    #define KAI_ASM_CODE(name) .text
+    #define KAI_ASM_ALIGN .p2align 4,,11
+    #define KAI_ASM_LABEL(name) name:
+    #define KAI_ASM_INST(hex) .inst hex
+    #define KAI_ASM_END
+#endif
+
+    KAI_ASM_CODE(lhs_imatmul_pack_x16p2vlx2_x16p_sme)
+    KAI_ASM_ALIGN
+
+    KAI_ASM_GLOBAL(kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme)
+
+// kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme(height, width, in, row_offset, out)
+//
+// Rows of 16-bit elements are loaded into horizontal ZA slices (ld1h) and
+// written back as vertical 32-bit slices (st1w).
+//
+// Arguments, in AAPCS64 order:
+//   x0 = height      rows at or beyond this index are loaded with an all-false predicate
+//   x1 = width       number of 16-bit elements taken from each row
+//   x2 = in          table of 8-byte row pointers (two groups of cntw entries)
+//   x3 = row_offset  element offset added to every row pointer (scaled by 2 bytes)
+//   x4 = out         destination pointer, advanced with addvl as slices are stored
+//
+// x20-x28 and d8-d15 are callee-saved and are preserved here. d8-d15 have to be
+// spilled explicitly because SMSTART/SMSTOP change PSTATE.SM, which does not
+// preserve the SIMD&FP/SVE register contents.
+KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme)
+KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme)
+    stp x20, x21, [sp, -144]! // 9 x-regs + 8 d-regs = 136 bytes, rounded up to 16-byte alignment
+    stp x22, x23, [sp, 16]
+    stp x24, x25, [sp, 32]
+    stp x26, x27, [sp, 48]
+    str x28, [sp, 64]
+    stp d8, d9, [sp, 72]
+    stp d10, d11, [sp, 88]
+    stp d12, d13, [sp, 104]
+    stp d14, d15, [sp, 120]
+    KAI_ASM_INST(0xd503477f) // SMSTART (enables streaming mode and ZA)
+    mov x8, #0x0
+    cnth x22
+    mov x21, x1
+    inch x21
+    mov x20, x1
+    sub x17, x22, #0x1
+    sub x21, x21, #0x1
+    ands x17, x20, x17
+    cntw x16
+    udiv x21, x21, x22 // n_passes = ceildiv(width, VL)
+    csel x17, x17, x22, NE
+    sub x13, x21, #0x1
+    add x17, x17, #0x1
+    sub x11, x16, #0x2
+    lsl x22, x0, #0x1 // height * 2
+    lsl x20, x16, #0x1
+    mov x10, x2
+    add x9, x2, x16, LSL #3
+    cntw x28, ALL, MUL #2
+    ldr x27, [x10, #0x0]
+    cntw x26, ALL, MUL #3
+    lsr x13, x13, #0x1 // n_loops = (n_passes - 1) / 2
+    ldr x25, [x9, #0x0]
+    and x24, x21, #0x1 // odd_tail = bool(n_passes & 0x1)
+    lsr x17, x17, #0x1
+    ldr x23, [x10, #0x8]
+    ptrue p12.s
+    whilelt p11.h, XZR, x22
+    ldr x21, [x9, #0x8]
+    whilelt p10.h, x20, x22
+    mov x22, x3
+    whilelt p9.h, x8, x1
+    whilelt p8.h, x8, x1
+    add x10, x10, #0x10
+    add x9, x9, #0x10
+    mov x12, #0x0
+    cbz x11, label_2
+KAI_ASM_LABEL(label_1) // K loop: Charge: Loop
+    KAI_ASM_INST(0x25286163) // psel p3.h, p8.h/Z, p11.h[w12]
+    KAI_ASM_INST(0x25286142) // psel p2.h, p8.h/Z, p10.h[w12]
+    KAI_ASM_INST(0x25686161) // psel p1.h, p8.h/Z, p11.h[w12, #2]
+    KAI_ASM_INST(0x25686140) // psel p0.h, p8.h/Z, p10.h[w12, #2]
+    KAI_ASM_INST(0xe0560f60) // ld1h { za0h.h[x12] }, p3/Z, [x27, x22, LSL #1]
+    ldr x27, [x10, #0x0]
+    KAI_ASM_INST(0xe0560b28) // ld1h { za1h.h[x12] }, p2/Z, [x25, x22, LSL #1]
+    ldr x25, [x9, #0x0]
+    KAI_ASM_INST(0xe05606e2) // ld1h { za0h.h[x12, #2] }, p1/Z, [x23, x22, LSL #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe05602aa) // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x22, LSL #1]
+    add x12, x12, #0x4
+    ldr x21, [x9, #0x8]
+    add x9, x9, #0x10
+    cmp x12, x11, LSL #1
+    blt label_1
+KAI_ASM_LABEL(label_2) // K loop: Charge: End
+    KAI_ASM_INST(0x25286163) // psel p3.h, p8.h/Z, p11.h[w12]
+    KAI_ASM_INST(0x25286142) // psel p2.h, p8.h/Z, p10.h[w12]
+    KAI_ASM_INST(0x25686161) // psel p1.h, p8.h/Z, p11.h[w12, #2]
+    KAI_ASM_INST(0x25686140) // psel p0.h, p8.h/Z, p10.h[w12, #2]
+    mov x10, x2
+    add x9, x2, x16, LSL #3
+    KAI_ASM_INST(0xe0560f60) // ld1h { za0h.h[x12] }, p3/Z, [x27, x22, LSL #1]
+    ldr x27, [x10, #0x0]
+    inch x8
+    KAI_ASM_INST(0xe0560b28) // ld1h { za1h.h[x12] }, p2/Z, [x25, x22, LSL #1]
+    ldr x25, [x9, #0x0]
+    KAI_ASM_INST(0xe05606e2) // ld1h { za0h.h[x12, #2] }, p1/Z, [x23, x22, LSL #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe05602aa) // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x22, LSL #1]
+    ldr x21, [x9, #0x8]
+    add x9, x9, #0x10
+    inch x22
+    cbz x13, label_8
+    mov x20, x13
+KAI_ASM_LABEL(label_3) // K loop: Main loop
+    whilelt p8.h, x8, x1
+    mov x15, #0x0
+    mov x14, #0x0
+    cbz x11, label_5
+KAI_ASM_LABEL(label_4) // K loop: Main loop: First: Loop
+    KAI_ASM_INST(0x253b6160) // psel p0.h, p8.h/Z, p11.h[w15, #1]
+    KAI_ASM_INST(0x253b6142) // psel p2.h, p8.h/Z, p10.h[w15, #1]
+    KAI_ASM_INST(0x257b6161) // psel p1.h, p8.h/Z, p11.h[w15, #3]
+    KAI_ASM_INST(0x257b6143) // psel p3.h, p8.h/Z, p10.h[w15, #3]
+    KAI_ASM_INST(0xe0566361) // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x22, LSL #1]
+    KAI_ASM_INST(0x252a7120) // psel p0.h, p12.h/Z, p9.h[w14]
+    ldr x27, [x10, #0x0]
+    KAI_ASM_INST(0xe0566b29) // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x22, LSL #1]
+    KAI_ASM_INST(0x252a7122) // psel p2.h, p12.h/Z, p9.h[w14]
+    ldr x25, [x9, #0x0]
+    KAI_ASM_INST(0xe05666e3) // ld1h { za0h.h[x15, #3] }, p1/Z, [x23, x22, LSL #1]
+    KAI_ASM_INST(0x253a7121) // psel p1.h, p12.h/Z, p9.h[w14, #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe0566eab) // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x22, LSL #1]
+    ldr x21, [x9, #0x8]
+    KAI_ASM_INST(0xe0bfc080) // st1w { za0v.s[x14] }, p0/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0x253a7120) // psel p0.h, p12.h/Z, p9.h[w14, #1]
+    KAI_ASM_INST(0xe0b0c884) // st1w { za1v.s[x14] }, p2/Z, [x4, x16, LSL #2]
+    add x9, x9, #0x10
+    add x15, x15, #0x4
+    KAI_ASM_INST(0xe0bcc481) // st1w { za0v.s[x14, #1] }, p1/Z, [x4, x28, LSL #2]
+    KAI_ASM_INST(0xe0bac085) // st1w { za1v.s[x14, #1] }, p0/Z, [x4, x26, LSL #2]
+    add x14, x14, #0x2
+    addvl x4, x4, #4
+    cmp x14, x11
+    blt label_4
+KAI_ASM_LABEL(label_5) // K loop: Main loop: First: Tail
+    KAI_ASM_INST(0x253b6160) // psel p0.h, p8.h/Z, p11.h[w15, #1]
+    KAI_ASM_INST(0x253b6142) // psel p2.h, p8.h/Z, p10.h[w15, #1]
+    KAI_ASM_INST(0x257b6161) // psel p1.h, p8.h/Z, p11.h[w15, #3]
+    KAI_ASM_INST(0x257b6143) // psel p3.h, p8.h/Z, p10.h[w15, #3]
+    mov x10, x2
+    add x9, x2, x16, LSL #3
+    KAI_ASM_INST(0xe0566361) // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x22, LSL #1]
+    KAI_ASM_INST(0x252a7120) // psel p0.h, p12.h/Z, p9.h[w14]
+    ldr x27, [x10, #0x0]
+    mov x13, #0x0
+    KAI_ASM_INST(0xe0566b29) // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x22, LSL #1]
+    KAI_ASM_INST(0x252a7122) // psel p2.h, p12.h/Z, p9.h[w14]
+    ldr x25, [x9, #0x0]
+    mov x12, #0x0
+    KAI_ASM_INST(0xe05666e3) // ld1h { za0h.h[x15, #3] }, p1/Z, [x23, x22, LSL #1]
+    KAI_ASM_INST(0x253a7121) // psel p1.h, p12.h/Z, p9.h[w14, #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe0566eab) // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x22, LSL #1]
+    ldr x21, [x9, #0x8]
+    KAI_ASM_INST(0xe0bfc080) // st1w { za0v.s[x14] }, p0/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0x253a7120) // psel p0.h, p12.h/Z, p9.h[w14, #1]
+    KAI_ASM_INST(0xe0b0c884) // st1w { za1v.s[x14] }, p2/Z, [x4, x16, LSL #2]
+    whilelt p9.h, x8, x1
+    inch x8
+    KAI_ASM_INST(0xe0bcc481) // st1w { za0v.s[x14, #1] }, p1/Z, [x4, x28, LSL #2]
+    add x9, x9, #0x10
+    inch x22
+    KAI_ASM_INST(0xe0bac085) // st1w { za1v.s[x14, #1] }, p0/Z, [x4, x26, LSL #2]
+    addvl x4, x4, #4
+    whilelt p8.h, x8, x1
+    cbz x11, label_7
+KAI_ASM_LABEL(label_6) // K loop: Main loop: Second: Loop
+    KAI_ASM_INST(0x25296160) // psel p0.h, p8.h/Z, p11.h[w13]
+    KAI_ASM_INST(0x25296142) // psel p2.h, p8.h/Z, p10.h[w13]
+    KAI_ASM_INST(0x25696161) // psel p1.h, p8.h/Z, p11.h[w13, #2]
+    KAI_ASM_INST(0x25696143) // psel p3.h, p8.h/Z, p10.h[w13, #2]
+    KAI_ASM_INST(0xe0562360) // ld1h { za0h.h[x13] }, p0/Z, [x27, x22, LSL #1]
+    KAI_ASM_INST(0x25287120) // psel p0.h, p12.h/Z, p9.h[w12]
+    ldr x27, [x10, #0x0]
+    KAI_ASM_INST(0xe0562b28) // ld1h { za1h.h[x13] }, p2/Z, [x25, x22, LSL #1]
+    KAI_ASM_INST(0x25287122) // psel p2.h, p12.h/Z, p9.h[w12]
+    ldr x25, [x9, #0x0]
+    KAI_ASM_INST(0xe05626e2) // ld1h { za0h.h[x13, #2] }, p1/Z, [x23, x22, LSL #1]
+    KAI_ASM_INST(0x25387121) // psel p1.h, p12.h/Z, p9.h[w12, #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe0562eaa) // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x22, LSL #1]
+    ldr x21, [x9, #0x8]
+    KAI_ASM_INST(0xe0bf8088) // st1w { za2v.s[x12] }, p0/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0x25387120) // psel p0.h, p12.h/Z, p9.h[w12, #1]
+    KAI_ASM_INST(0xe0b0888c) // st1w { za3v.s[x12] }, p2/Z, [x4, x16, LSL #2]
+    add x9, x9, #0x10
+    add x13, x13, #0x4
+    KAI_ASM_INST(0xe0bc8489) // st1w { za2v.s[x12, #1] }, p1/Z, [x4, x28, LSL #2]
+    KAI_ASM_INST(0xe0ba808d) // st1w { za3v.s[x12, #1] }, p0/Z, [x4, x26, LSL #2]
+    add x12, x12, #0x2
+    addvl x4, x4, #4
+    cmp x12, x11
+    blt label_6
+KAI_ASM_LABEL(label_7) // K loop: Main loop: Second: Tail
+    KAI_ASM_INST(0x25296160) // psel p0.h, p8.h/Z, p11.h[w13]
+    KAI_ASM_INST(0x25296142) // psel p2.h, p8.h/Z, p10.h[w13]
+    KAI_ASM_INST(0x25696161) // psel p1.h, p8.h/Z, p11.h[w13, #2]
+    KAI_ASM_INST(0x25696143) // psel p3.h, p8.h/Z, p10.h[w13, #2]
+    mov x10, x2
+    add x9, x2, x16, LSL #3
+    KAI_ASM_INST(0xe0562360) // ld1h { za0h.h[x13] }, p0/Z, [x27, x22, LSL #1]
+    KAI_ASM_INST(0x25287120) // psel p0.h, p12.h/Z, p9.h[w12]
+    ldr x27, [x10, #0x0]
+    KAI_ASM_INST(0xe0562b28) // ld1h { za1h.h[x13] }, p2/Z, [x25, x22, LSL #1]
+    KAI_ASM_INST(0x25287122) // psel p2.h, p12.h/Z, p9.h[w12]
+    ldr x25, [x9, #0x0]
+    KAI_ASM_INST(0xe05626e2) // ld1h { za0h.h[x13, #2] }, p1/Z, [x23, x22, LSL #1]
+    KAI_ASM_INST(0x25387121) // psel p1.h, p12.h/Z, p9.h[w12, #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe0562eaa) // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x22, LSL #1]
+    ldr x21, [x9, #0x8]
+    KAI_ASM_INST(0xe0bf8088) // st1w { za2v.s[x12] }, p0/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0x25387120) // psel p0.h, p12.h/Z, p9.h[w12, #1]
+    KAI_ASM_INST(0xe0b0888c) // st1w { za3v.s[x12] }, p2/Z, [x4, x16, LSL #2]
+    whilelt p9.h, x8, x1
+    subs x20, x20, #0x1
+    KAI_ASM_INST(0xe0bc8489) // st1w { za2v.s[x12, #1] }, p1/Z, [x4, x28, LSL #2]
+    add x9, x9, #0x10
+    inch x8
+    KAI_ASM_INST(0xe0ba808d) // st1w { za3v.s[x12, #1] }, p0/Z, [x4, x26, LSL #2]
+    addvl x4, x4, #4
+    inch x22
+    bgt label_3
+KAI_ASM_LABEL(label_8) // K loop: Tails
+    cbnz x24, label_11
+    mov x10, x2
+    whilelt p8.h, x8, x1
+    mov x13, #0x0
+    mov x12, #0x0
+KAI_ASM_LABEL(label_9) // K loop: Tails: Even: First
+    KAI_ASM_INST(0x25307123) // psel p3.s, p12.s/Z, p9.s[w12]
+    KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12]
+    KAI_ASM_INST(0x25396161) // psel p1.h, p8.h/Z, p11.h[w13, #1]
+    KAI_ASM_INST(0x25396140) // psel p0.h, p8.h/Z, p10.h[w13, #1]
+    KAI_ASM_INST(0xe0bf8c80) // st1w { za0v.s[x12] }, p3/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0xe0b08884) // st1w { za1v.s[x12] }, p2/Z, [x4, x16, LSL #2]
+    add x12, x12, #0x1
+    addvl x4, x4, #2
+    ldr x21, [x10, #0x0]
+    cmp x12, x16
+    ldr x20, [x10, x16, LSL #0x3]
+    add x10, x10, #0x8
+    KAI_ASM_INST(0xe05626a1) // ld1h { za0h.h[x13, #1] }, p1/Z, [x21, x22, LSL #1]
+    KAI_ASM_INST(0xe0562289) // ld1h { za1h.h[x13, #1] }, p0/Z, [x20, x22, LSL #1]
+    add x13, x13, #0x2
+    blt label_9
+    whilelt p9.h, x8, x1
+    whilelt p8.h, x8, x1
+    mov x20, #0x0
+    mov x12, #0x0
+KAI_ASM_LABEL(label_10) // K loop: Tails: Even: Second
+    KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12]
+    KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12]
+    add x20, x20, #0x2
+    KAI_ASM_INST(0xe0bf8488) // st1w { za2v.s[x12] }, p1/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0xe0b0808c) // st1w { za3v.s[x12] }, p0/Z, [x4, x16, LSL #2]
+    add x12, x12, #0x1
+    addvl x4, x4, #2
+    cmp x12, x17
+    blt label_10
+    whilelt p8.h, x8, x1
+    b label_13
+KAI_ASM_LABEL(label_11) // K loop: Tails: Odd
+    mov x12, #0x0
+KAI_ASM_LABEL(label_12) // K loop: Tails: Odd: Loop
+    KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12]
+    KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12]
+    KAI_ASM_INST(0xe0bf8480) // st1w { za0v.s[x12] }, p1/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0xe0b08084) // st1w { za1v.s[x12] }, p0/Z, [x4, x16, LSL #2]
+    add x12, x12, #0x1
+    addvl x4, x4, #2
+    cmp x12, x17
+    blt label_12
+KAI_ASM_LABEL(label_13) // K loop: End
+    KAI_ASM_INST(0xd503467f) // SMSTOP
+    ldp x22, x23, [sp, 16]
+    ldp x24, x25, [sp, 32]
+    ldp x26, x27, [sp, 48]
+    ldr x28, [sp, 64]
+    ldp d8, d9, [sp, 72] // restore only after SMSTOP; the mode switch would discard them
+    ldp d10, d11, [sp, 88]
+    ldp d12, d13, [sp, 104]
+    ldp d14, d15, [sp, 120]
+    ldp x20, x21, [sp], 144
+    ret
+    KAI_ASM_FUNCTION_END(kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme)
+
+    KAI_ASM_END
diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c index bef728fe..4d2d272b 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c @@ -3,14 +3,6 @@ // // SPDX-License-Identifier: Apache-2.0 // - -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) -#error This file must be compiled for AArch64, FEAT_SVE2. -#else // Architectural features check.
- #include "kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h" #include @@ -18,9 +10,14 @@ #include "kai/kai_common.h" -#define MR 2 -#define KR 1 -#define MAX_M_STEP (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(float)) / KR) +enum { + MR = 2, + KR = 1, + MAX_M_STEP = MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(float)) / KR, +}; + +void kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + size_t height, size_t width, void* in, size_t row_offset, void* out); static size_t kai_get_mr_lhs_imatmul_pack_x32p2vlx1_x32p_sme(void) { return MR * kai_get_sme_vector_length_u32() / KR; @@ -69,260 +66,9 @@ void kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme( } } - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x21, %x[width]\n" - "mov x20, %x[width]\n" - "incw x21\n" - "cntw x17\n" - "sub x21, x21, #0x1\n" - "sub x16, x17, #0x1\n" - "udiv x21, x21, x17\n" // n_passes = ceildiv(width, VL) - "ands x16, x20, x16\n" - "sub x20, x21, #0x1\n" - "sub x15, x17, #0x2\n" - "mov x14, #0x0\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - "cntw x9, ALL, MUL #2\n" - "ldr x28, [x11, #0x0]\n" - "cntw x27, ALL, MUL #3\n" - "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 - "ldr x26, [x10, #0x0]\n" - "and x25, x21, #0x1\n" // odd_tail = bool(n_passes & 0x1) - "csel x16, x16, x17, NE\n" - "ldr x24, [x11, #0x8]\n" - "ptrue p12.s\n" - "whilelt p11.s, XZR, %x[height]\n" - "ldr x21, [x10, #0x8]\n" - "whilelt p10.s, x17, %x[height]\n" - "mov x23, %x[row_offset]\n" - "mov x22, %x[out]\n" - "whilelt p9.s, x14, %x[width]\n" - "whilelt p8.s, x14, %x[width]\n" - "add x11, x11, #0x10\n" - "add x10, x10, #0x10\n" - "mov x12, #0x0\n" - "cbz x15, 2f\n" - "1:" // K loop: Charge: Loop - ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" - ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" - 
"ldr x28, [x11, #0x0]\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" - "add x12, x12, #0x2\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "cmp x12, x15\n" - "blt 1b\n" - "2:" // K loop: Charge: End - ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" - "ldr x28, [x11, #0x0]\n" - "incw x14\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "incw x23\n" - "cbz x20, 8f\n" - "mov x20, x20\n" - "3:" // K loop: Main loop - "whilelt p8.s, x14, %x[width]\n" - "mov x13, #0x0\n" - "cbz x15, 5f\n" - "4:" // K loop: Main loop: First: Loop - ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" - ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" - ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" - ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" - ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0972709 // 
ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" - ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "add x13, x13, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x13, x15\n" - "blt 4b\n" - "5:" // K loop: Main loop: First: Tail - ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" - ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" - ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" - ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" - "ldr x28, [x11, #0x0]\n" - "mov x12, #0x0\n" - ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" - ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" - "whilelt p9.s, x14, %x[width]\n" - "incw x14\n" - ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, 
p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "incw x23\n" - ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "addvl x22, x22, #4\n" - "whilelt p8.s, x14, %x[width]\n" - "cbz x15, 7f\n" - "6:" // K loop: Main loop: Second: Loop - ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" - ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" - ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "add x12, x12, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x12, x15\n" - "blt 6b\n" - "7:" // K loop: Main loop: Second: Tail - ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25307120 // 
psel p0.s, p12.s/Z, p9.s[w12]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" - ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" - "whilelt p9.s, x14, %x[width]\n" - "subs x20, x20, #0x1\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "incw x14\n" - ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "addvl x22, x22, #4\n" - "incw x23\n" - "bgt 3b\n" - "8:" // K loop: Tails - "cbnz x25, 11f\n" - "mov x11, %x[in]\n" - "whilelt p8.s, x14, %x[width]\n" - "mov x12, #0x0\n" - "9:" // K loop: Tails: Even: First - ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25306161 // psel p1.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306140 // psel p0.s, p8.s/Z, p10.s[w12]\n" - ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0b18ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" - "addvl x22, x22, #2\n" - "ldr x21, [x11, #0x0]\n" - "ldr x20, [x11, x17, LSL #0x3]\n" - "add x11, x11, #0x8\n" - ".inst 0xe09706a8 // ld1w { za2h.s[x12] }, p1/Z, [x21, x23, LSL #2]\n" - ".inst 0xe097028c // ld1w { za3h.s[x12] }, p0/Z, [x20, x23, LSL #2]\n" - "add x12, x12, #0x1\n" - "cmp x12, x17\n" - "blt 9b\n" - "whilelt p9.s, x14, %x[width]\n" - "whilelt p8.s, x14, %x[width]\n" - "mov x12, #0x0\n" - "10:" // K loop: Tails: Even: Second - 
".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0b182cc // st1w { za3v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x16\n" - "blt 10b\n" - "whilelt p8.s, x14, %x[width]\n" - "b 13f\n" - "11:" // K loop: Tails: Odd - "mov x12, #0x0\n" - "12:" // K loop: Tails: Odd: Loop - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0b182c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x16\n" - "blt 12b\n" - "13:" // K loop: End - "mov %x[out], x22\n" - ".inst 0xd503467f // SMSTOP\n" - : [out] "+&r"(out) - : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", - "p7", "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", - "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", - "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", - "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + height, width, in, row_offset, out); // NOLINT(bugprone-multi-level-implicit-pointer-conversion) out_base += m_step * kai_roundup(k_chunk_length, KR) * sizeof(float); } } } - -#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h
index 5f6c68a9..416c3130 100644
--- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h
+++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h
@@ -16,7 +16,7 @@ extern "C" {
 ///
 /// The starting row index must be divisible by `m_step`.
 ///
-/// @return Step size for row index
+/// @return The m step value.
 size_t kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme(void);
 
 /// Gets the offset in bytes to the data element in the packed LHS buffer.
diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S
new file mode 100644
index 00000000..673ae8a1
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S
@@ -0,0 +1,320 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if defined(_MSC_VER)
+    #define KAI_ASM_GLOBAL(name) GLOBAL name
+    #define KAI_ASM_FUNCTION_TYPE(name)
+    #define KAI_ASM_FUNCTION_LABEL(name) name PROC
+    #define KAI_ASM_FUNCTION_END(name) ENDP
+
+    #define KAI_ASM_CODE(name) AREA name, CODE, READONLY
+    #define KAI_ASM_ALIGN
+    #define KAI_ASM_LABEL(name) name
+    #define KAI_ASM_INST(hex) DCD hex
+    #define KAI_ASM_END END
+#else
+    #if defined(__APPLE__)
+        #define KAI_ASM_GLOBAL(name) .globl _##name
+        #define KAI_ASM_FUNCTION_TYPE(name)
+        #define KAI_ASM_FUNCTION_LABEL(name) _##name:
+        #define KAI_ASM_FUNCTION_END(name)
+    #else
+        #define KAI_ASM_GLOBAL(name) .global name
+        #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function
+        #define KAI_ASM_FUNCTION_LABEL(name) name:
+        #define KAI_ASM_FUNCTION_END(name) .size name, .-name
+    #endif
+
+    #define KAI_ASM_CODE(name) .text
+    #define KAI_ASM_ALIGN .p2align 4,,11
+    #define KAI_ASM_LABEL(name) name:
+    #define KAI_ASM_INST(hex) .inst hex
+    #define KAI_ASM_END
+#endif
+
+    KAI_ASM_CODE(lhs_imatmul_pack_x32p2vlx1_x32p_sme)
+    KAI_ASM_ALIGN
+
+    KAI_ASM_GLOBAL(kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme)
+
+// void kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme(
+//     size_t height, size_t width, void* in, size_t row_offset, void* out);
+//
+// Arguments as consumed below (AAPCS64 register order):
+//   x0 height     - limit of the row predicates p11 (from row 0) and p10 (from row cntw)
+//   x1 width      - limit of the column predicates p8 and p9
+//   x2 in         - table of 64-bit row pointers; a second cursor starts at x2 + cntw * 8
+//   x3 row_offset - copied to x22, the 32-bit element index applied to every row pointer
+//   x4 out        - store base, advanced with addvl after each group of ZA slice stores
+//
+// A change of PSTATE.SM does not preserve the SIMD and FP registers, while AAPCS64
+// makes the low 64 bits of v8-v15 callee-saved. d8-d15 are therefore spilled next to
+// x20-x28 before SMSTART and reloaded after SMSTOP.
+KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme)
+KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme)
+    stp x20, x21, [sp, -144]!
+    stp x22, x23, [sp, 16]
+    stp x24, x25, [sp, 32]
+    stp x26, x27, [sp, 48]
+    str x28, [sp, 64]
+    stp d8, d9, [sp, 80] // callee-saved FP registers, lost when streaming mode is toggled
+    stp d10, d11, [sp, 96]
+    stp d12, d13, [sp, 112]
+    stp d14, d15, [sp, 128]
+    KAI_ASM_INST(0xd503477f) // SMSTART ZA
+    mov x16, #0x0
+    mov x21, x1
+    cntw x15
+    incw x21
+    mov x20, x1
+    sub x21, x21, #0x1
+    sub x14, x15, #0x1
+    udiv x21, x21, x15 // n_passes = ceildiv(width, VL)
+    ands x14, x20, x14
+    sub x20, x21, #0x1
+    sub x11, x15, #0x2
+    mov x10, x2
+    add x9, x2, x15, LSL #3
+    cntw x28, ALL, MUL #2
+    cntw x27, ALL, MUL #3
+    ldr x26, [x10, #0x0]
+    lsr x20, x20, #0x1 // n_loops = (n_passes - 1) / 2
+    and x25, x21, #0x1 // odd_tail = bool(n_passes & 0x1)
+    ldr x24, [x9, #0x0]
+    csel x14, x14, x15, NE
+    ptrue p12.s
+    ldr x23, [x10, #0x8]
+    whilelt p11.s, XZR, x0
+    whilelt p10.s, x15, x0
+    ldr x21, [x9, #0x8]
+    mov x22, x3
+    whilelt p9.s, x16, x1
+    whilelt p8.s, x16, x1
+    add x10, x10, #0x10
+    add x9, x9, #0x10
+    mov x12, #0x0
+    cbz x11, label_2
+KAI_ASM_LABEL(label_1) // K loop: Charge: Loop
+    KAI_ASM_INST(0x25306163) // psel p3.s, p8.s/Z, p11.s[w12]
+    KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12]
+    KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1]
+    KAI_ASM_INST(0x25706140) // psel p0.s, p8.s/Z, p10.s[w12, #1]
+    KAI_ASM_INST(0xe0960f40) // ld1w { za0h.s[x12] }, p3/Z, [x26, x22, LSL #2]
+    ldr x26, [x10, #0x0]
+    KAI_ASM_INST(0xe0960b04) // ld1w { za1h.s[x12] }, p2/Z, [x24, x22, LSL #2]
+    ldr x24, [x9, #0x0]
+    KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe09602a5) // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x22, LSL #2]
+    add x12, x12, #0x2
+    ldr x21, [x9, #0x8]
+    add x9, x9, #0x10
+    cmp x12, x11
+    blt label_1
+KAI_ASM_LABEL(label_2) // K loop: Charge: End
+    KAI_ASM_INST(0x25306163) // psel p3.s, p8.s/Z, p11.s[w12]
+    KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12]
+    KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1]
+    KAI_ASM_INST(0x25706140) // psel p0.s, p8.s/Z, p10.s[w12, #1]
+    mov x10, x2
+    add x9, x2, x15, LSL #3
+    KAI_ASM_INST(0xe0960f40) // ld1w { za0h.s[x12] }, p3/Z, [x26, x22, LSL #2]
+    ldr x26, [x10, #0x0]
+    incw x16
+    KAI_ASM_INST(0xe0960b04) // ld1w { za1h.s[x12] }, p2/Z, [x24, x22, LSL #2]
+    ldr x24, [x9, #0x0]
+    KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe09602a5) // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x22, LSL #2]
+    ldr x21, [x9, #0x8]
+    add x9, x9, #0x10
+    incw x22
+    cbz x20, label_8
+    mov x20, x20
+KAI_ASM_LABEL(label_3) // K loop: Main loop
+    whilelt p8.s, x16, x1
+    mov x13, #0x0
+    cbz x11, label_5
+KAI_ASM_LABEL(label_4) // K loop: Main loop: First: Loop
+    KAI_ASM_INST(0x25316160) // psel p0.s, p8.s/Z, p11.s[w13]
+    KAI_ASM_INST(0x25316142) // psel p2.s, p8.s/Z, p10.s[w13]
+    KAI_ASM_INST(0x25716161) // psel p1.s, p8.s/Z, p11.s[w13, #1]
+    KAI_ASM_INST(0x25716143) // psel p3.s, p8.s/Z, p10.s[w13, #1]
+    KAI_ASM_INST(0xe0962348) // ld1w { za2h.s[x13] }, p0/Z, [x26, x22, LSL #2]
+    KAI_ASM_INST(0x25317120) // psel p0.s, p12.s/Z, p9.s[w13]
+    ldr x26, [x10, #0x0]
+    KAI_ASM_INST(0xe0962b0c) // ld1w { za3h.s[x13] }, p2/Z, [x24, x22, LSL #2]
+    KAI_ASM_INST(0x25317122) // psel p2.s, p12.s/Z, p9.s[w13]
+    ldr x24, [x9, #0x0]
+    KAI_ASM_INST(0xe09626e9) // ld1w { za2h.s[x13, #1] }, p1/Z, [x23, x22, LSL #2]
+    KAI_ASM_INST(0x25717121) // psel p1.s, p12.s/Z, p9.s[w13, #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe0962ead) // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x22, LSL #2]
+    ldr x21, [x9, #0x8]
+    KAI_ASM_INST(0xe0bfa080) // st1w { za0v.s[x13] }, p0/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0x25717120) // psel p0.s, p12.s/Z, p9.s[w13, #1]
+    KAI_ASM_INST(0xe0afa884) // st1w { za1v.s[x13] }, p2/Z, [x4, x15, LSL #2]
+    add x9, x9, #0x10
+    KAI_ASM_INST(0xe0bca481) // st1w { za0v.s[x13, #1] }, p1/Z, [x4, x28, LSL #2]
+    KAI_ASM_INST(0xe0bba085) // st1w { za1v.s[x13, #1] }, p0/Z, [x4, x27, LSL #2]
+    add x13, x13, #0x2
+    addvl x4, x4, #4
+    cmp x13, x11
+    blt label_4
+KAI_ASM_LABEL(label_5) // K loop: Main loop: First: Tail
+    KAI_ASM_INST(0x25316160) // psel p0.s, p8.s/Z, p11.s[w13]
+    KAI_ASM_INST(0x25316142) // psel p2.s, p8.s/Z, p10.s[w13]
+    KAI_ASM_INST(0x25716161) // psel p1.s, p8.s/Z, p11.s[w13, #1]
+    KAI_ASM_INST(0x25716143) // psel p3.s, p8.s/Z, p10.s[w13, #1]
+    mov x10, x2
+    add x9, x2, x15, LSL #3
+    KAI_ASM_INST(0xe0962348) // ld1w { za2h.s[x13] }, p0/Z, [x26, x22, LSL #2]
+    KAI_ASM_INST(0x25317120) // psel p0.s, p12.s/Z, p9.s[w13]
+    ldr x26, [x10, #0x0]
+    mov x12, #0x0
+    KAI_ASM_INST(0xe0962b0c) // ld1w { za3h.s[x13] }, p2/Z, [x24, x22, LSL #2]
+    KAI_ASM_INST(0x25317122) // psel p2.s, p12.s/Z, p9.s[w13]
+    ldr x24, [x9, #0x0]
+    KAI_ASM_INST(0xe09626e9) // ld1w { za2h.s[x13, #1] }, p1/Z, [x23, x22, LSL #2]
+    KAI_ASM_INST(0x25717121) // psel p1.s, p12.s/Z, p9.s[w13, #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe0962ead) // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x22, LSL #2]
+    ldr x21, [x9, #0x8]
+    KAI_ASM_INST(0xe0bfa080) // st1w { za0v.s[x13] }, p0/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0x25717120) // psel p0.s, p12.s/Z, p9.s[w13, #1]
+    KAI_ASM_INST(0xe0afa884) // st1w { za1v.s[x13] }, p2/Z, [x4, x15, LSL #2]
+    whilelt p9.s, x16, x1
+    incw x16
+    KAI_ASM_INST(0xe0bca481) // st1w { za0v.s[x13, #1] }, p1/Z, [x4, x28, LSL #2]
+    add x9, x9, #0x10
+    incw x22
+    KAI_ASM_INST(0xe0bba085) // st1w { za1v.s[x13, #1] }, p0/Z, [x4, x27, LSL #2]
+    addvl x4, x4, #4
+    whilelt p8.s, x16, x1
+    cbz x11, label_7
+KAI_ASM_LABEL(label_6) // K loop: Main loop: Second: Loop
+    KAI_ASM_INST(0x25306160) // psel p0.s, p8.s/Z, p11.s[w12]
+    KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12]
+    KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1]
+    KAI_ASM_INST(0x25706143) // psel p3.s, p8.s/Z, p10.s[w12, #1]
+    KAI_ASM_INST(0xe0960340) // ld1w { za0h.s[x12] }, p0/Z, [x26, x22, LSL #2]
+    KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12]
+    ldr x26, [x10, #0x0]
+    KAI_ASM_INST(0xe0960b04) // ld1w { za1h.s[x12] }, p2/Z, [x24, x22, LSL #2]
+    KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12]
+    ldr x24, [x9, #0x0]
+    KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2]
+    KAI_ASM_INST(0x25707121) // psel p1.s, p12.s/Z, p9.s[w12, #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe0960ea5) // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x22, LSL #2]
+    ldr x21, [x9, #0x8]
+    KAI_ASM_INST(0xe0bf8088) // st1w { za2v.s[x12] }, p0/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0x25707120) // psel p0.s, p12.s/Z, p9.s[w12, #1]
+    KAI_ASM_INST(0xe0af888c) // st1w { za3v.s[x12] }, p2/Z, [x4, x15, LSL #2]
+    add x9, x9, #0x10
+    KAI_ASM_INST(0xe0bc8489) // st1w { za2v.s[x12, #1] }, p1/Z, [x4, x28, LSL #2]
+    KAI_ASM_INST(0xe0bb808d) // st1w { za3v.s[x12, #1] }, p0/Z, [x4, x27, LSL #2]
+    add x12, x12, #0x2
+    addvl x4, x4, #4
+    cmp x12, x11
+    blt label_6
+KAI_ASM_LABEL(label_7) // K loop: Main loop: Second: Tail
+    KAI_ASM_INST(0x25306160) // psel p0.s, p8.s/Z, p11.s[w12]
+    KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12]
+    KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1]
+    KAI_ASM_INST(0x25706143) // psel p3.s, p8.s/Z, p10.s[w12, #1]
+    mov x10, x2
+    add x9, x2, x15, LSL #3
+    KAI_ASM_INST(0xe0960340) // ld1w { za0h.s[x12] }, p0/Z, [x26, x22, LSL #2]
+    KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12]
+    ldr x26, [x10, #0x0]
+    KAI_ASM_INST(0xe0960b04) // ld1w { za1h.s[x12] }, p2/Z, [x24, x22, LSL #2]
+    KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12]
+    ldr x24, [x9, #0x0]
+    KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2]
+    KAI_ASM_INST(0x25707121) // psel p1.s, p12.s/Z, p9.s[w12, #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe0960ea5) // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x22, LSL #2]
+    ldr x21, [x9, #0x8]
+    KAI_ASM_INST(0xe0bf8088) // st1w { za2v.s[x12] }, p0/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0x25707120) // psel p0.s, p12.s/Z, p9.s[w12, #1]
+    KAI_ASM_INST(0xe0af888c) // st1w { za3v.s[x12] }, p2/Z, [x4, x15, LSL #2]
+    whilelt p9.s, x16, x1
+    subs x20, x20, #0x1
+    KAI_ASM_INST(0xe0bc8489) // st1w { za2v.s[x12, #1] }, p1/Z, [x4, x28, LSL #2]
+    add x9, x9, #0x10
+    incw x16
+    KAI_ASM_INST(0xe0bb808d) // st1w { za3v.s[x12, #1] }, p0/Z, [x4, x27, LSL #2]
+    addvl x4, x4, #4
+    incw x22
+    bgt label_3
+KAI_ASM_LABEL(label_8) // K loop: Tails
+    cbnz x25, label_11
+    mov x10, x2
+    whilelt p8.s, x16, x1
+    mov x12, #0x0
+KAI_ASM_LABEL(label_9) // K loop: Tails: Even: First
+    KAI_ASM_INST(0x25307123) // psel p3.s, p12.s/Z, p9.s[w12]
+    KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12]
+    KAI_ASM_INST(0x25306161) // psel p1.s, p8.s/Z, p11.s[w12]
+    KAI_ASM_INST(0x25306140) // psel p0.s, p8.s/Z, p10.s[w12]
+    KAI_ASM_INST(0xe0bf8c80) // st1w { za0v.s[x12] }, p3/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0xe0af8884) // st1w { za1v.s[x12] }, p2/Z, [x4, x15, LSL #2]
+    addvl x4, x4, #2
+    ldr x21, [x10, #0x0]
+    ldr x20, [x10, x15, LSL #0x3]
+    add x10, x10, #0x8
+    KAI_ASM_INST(0xe09606a8) // ld1w { za2h.s[x12] }, p1/Z, [x21, x22, LSL #2]
+    KAI_ASM_INST(0xe096028c) // ld1w { za3h.s[x12] }, p0/Z, [x20, x22, LSL #2]
+    add x12, x12, #0x1
+    cmp x12, x15
+    blt label_9
+    whilelt p9.s, x16, x1
+    whilelt p8.s, x16, x1
+    mov x12, #0x0
+KAI_ASM_LABEL(label_10) // K loop: Tails: Even: Second
+    KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12]
+    KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12]
+    KAI_ASM_INST(0xe0bf8488) // st1w { za2v.s[x12] }, p1/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0xe0af808c) // st1w { za3v.s[x12] }, p0/Z, [x4, x15, LSL #2]
+    add x12, x12, #0x1
+    addvl x4, x4, #2
+    cmp x12, x14
+    blt label_10
+    whilelt p8.s, x16, x1
+    b label_13
+KAI_ASM_LABEL(label_11) // K loop: Tails: Odd
+    mov x12, #0x0
+KAI_ASM_LABEL(label_12) // K loop: Tails: Odd: Loop
+    KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12]
+    KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12]
+    KAI_ASM_INST(0xe0bf8480) // st1w { za0v.s[x12] }, p1/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0xe0af8084) // st1w { za1v.s[x12] }, p0/Z, [x4, x15, LSL #2]
+    add x12, x12, #0x1
+    addvl x4, x4, #2
+    cmp x12, x14
+    blt label_12
+KAI_ASM_LABEL(label_13) // K loop: End
+    KAI_ASM_INST(0xd503467f) // SMSTOP
+    ldp d8, d9, [sp, 80] // reload only after SMSTOP: leaving streaming mode also resets the FP registers
+    ldp d10, d11, [sp, 96]
+    ldp d12, d13, [sp, 112]
+    ldp d14, d15, [sp, 128]
+    ldp x22, x23, [sp, 16]
+    ldp x24, x25, [sp, 32]
+    ldp x26, x27, [sp, 48]
+    ldr x28, [sp, 64]
+    ldp x20, x21, [sp], 144
+    ret
+    KAI_ASM_FUNCTION_END(kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme)
+
+    KAI_ASM_END
diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c
index 25a48afc..b5fbd436 100644
--- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c
+++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c
@@ -3,14 +3,6 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 //
-
-// Do not flag up inline assembly blocks
-#pragma GCC diagnostic ignored "-Woverlength-strings"
-
-#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)
-#error This file must be compiled for AArch64, FEAT_SVE2.
-#else // Architectural features check.
- #include "kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h" #include @@ -18,9 +10,13 @@ #include "kai/kai_common.h" -#define MR 2 -#define KR 4 -#define MAX_M_STEP (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(int8_t)) / KR) +enum { + MR = 2, + KR = 4, + MAX_M_STEP = MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(int8_t)) / KR, +}; + +void kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme(size_t height, size_t width, void* in, size_t row_offset, void* out); static size_t kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void) { return MR * kai_get_sme_vector_length_u8() / KR; @@ -69,274 +65,9 @@ void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( } } - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x23, %x[width]\n" - "mov x21, %x[width]\n" - "cntb x20\n" - "incb x23\n" - "sub x7, x20, #0x1\n" - "cntw x8\n" - "sub x23, x23, #0x1\n" - "ands x7, x21, x7\n" - "udiv x23, x23, x20\n" // n_passes = ceildiv(width, VL) - "csel x7, x7, x20, NE\n" - "lsl x22, %x[height], #0x1\n" // height * 2 - "lsl x21, x8, #0x1\n" - "sub x20, x23, #0x1\n" - "add x7, x7, #0x3\n" - "sub x17, x8, #0x2\n" - "whilelt p9.b, XZR, x22\n" - "whilelt p8.b, x21, x22\n" - "mov x16, #0x0\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - "cntw x9, ALL, MUL #2\n" - "cntw x28, ALL, MUL #3\n" - "ldr x27, [x11, #0x0]\n" - "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 - "and x26, x23, #0x1\n" // odd_tail = bool(n_passes & 0x1) - "ldr x25, [x10, #0x0]\n" - "lsr x7, x7, #0x2\n" - "ptrue p11.s\n" - "ldr x24, [x11, #0x8]\n" - "zip1 p10.b, p9.b, p8.b\n" - "mov x23, %x[row_offset]\n" - "ldr x21, [x10, #0x8]\n" - "mov x22, %x[out]\n" - "whilelt p9.b, x16, %x[width]\n" - "whilelt p8.b, x16, %x[width]\n" - "add x11, x11, #0x10\n" - "add x10, x10, #0x10\n" - "mov x12, #0x0\n" - "cbz x17, 2f\n" - "1:" // K loop: Charge: Loop - ".inst 0x25246143 // psel p3.b, p8.b/Z, p10.b[w12]\n" - ".inst 0x252c6142 // psel p2.b, p8.b/Z, p10.b[w12, #1]\n" - ".inst 0x25646141 // psel p1.b, p8.b/Z, p10.b[w12, #4]\n" - 
".inst 0x256c6140 // psel p0.b, p8.b/Z, p10.b[w12, #5]\n" - ".inst 0xe0170f60 // ld1b { za0h.b[x12] }, p3/Z, [x27, x23]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0170b21 // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x23]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0170704 // ld1b { za0h.b[x12, #4] }, p1/Z, [x24, x23]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01702a5 // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x23]\n" - "add x12, x12, #0x8\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "cmp x12, x17, LSL #2\n" - "blt 1b\n" - "2:" // K loop: Charge: End - ".inst 0x25246143 // psel p3.b, p8.b/Z, p10.b[w12]\n" - ".inst 0x252c6142 // psel p2.b, p8.b/Z, p10.b[w12, #1]\n" - ".inst 0x25646141 // psel p1.b, p8.b/Z, p10.b[w12, #4]\n" - ".inst 0x256c6140 // psel p0.b, p8.b/Z, p10.b[w12, #5]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0170f60 // ld1b { za0h.b[x12] }, p3/Z, [x27, x23]\n" - "ldr x27, [x11, #0x0]\n" - "incb x16\n" - ".inst 0xe0170b21 // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x23]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0170704 // ld1b { za0h.b[x12, #4] }, p1/Z, [x24, x23]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01702a5 // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "incb x23\n" - "cbz x20, 8f\n" - "mov x20, x20\n" - "3:" // K loop: Main loop - "whilelt p8.b, x16, %x[width]\n" - "mov x15, #0x0\n" - "mov x14, #0x0\n" - "cbz x17, 5f\n" - "4:" // K loop: Main loop: First: Loop - ".inst 0x25376143 // psel p3.b, p8.b/Z, p10.b[w15, #2]\n" - ".inst 0x253f6142 // psel p2.b, p8.b/Z, p10.b[w15, #3]\n" - ".inst 0x25776141 // psel p1.b, p8.b/Z, p10.b[w15, #6]\n" - ".inst 0x257f6140 // psel p0.b, p8.b/Z, p10.b[w15, #7]\n" - ".inst 0xe0176f62 // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x23]\n" - ".inst 0x25266d23 // psel p3.b, p11.b/Z, p9.b[w14]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0176b23 // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x23]\n" 
- ".inst 0x25266d22 // psel p2.b, p11.b/Z, p9.b[w14]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0176706 // ld1b { za0h.b[x15, #6] }, p1/Z, [x24, x23]\n" - ".inst 0x252e6d21 // psel p1.b, p11.b/Z, p9.b[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01762a7 // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0x252e6d20 // psel p0.b, p11.b/Z, p9.b[w14, #1]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0bfcec0 // st1w { za0v.s[x14] }, p3/Z, [x22, XZR, LSL #2]\n" - "add x15, x15, #0x8\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x14, x14, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x14, x17\n" - "blt 4b\n" - "5:" // K loop: Main loop: First: Tail - ".inst 0x25376143 // psel p3.b, p8.b/Z, p10.b[w15, #2]\n" - ".inst 0x253f6142 // psel p2.b, p8.b/Z, p10.b[w15, #3]\n" - ".inst 0x25776141 // psel p1.b, p8.b/Z, p10.b[w15, #6]\n" - ".inst 0x257f6140 // psel p0.b, p8.b/Z, p10.b[w15, #7]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0176f62 // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x23]\n" - ".inst 0x25266d23 // psel p3.b, p11.b/Z, p9.b[w14]\n" - "ldr x27, [x11, #0x0]\n" - "mov x13, #0x0\n" - ".inst 0xe0176b23 // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x23]\n" - ".inst 0x25266d22 // psel p2.b, p11.b/Z, p9.b[w14]\n" - "ldr x25, [x10, #0x0]\n" - "mov x12, #0x0\n" - ".inst 0xe0176706 // ld1b { za0h.b[x15, #6] }, p1/Z, [x24, x23]\n" - ".inst 0x252e6d21 // psel p1.b, p11.b/Z, p9.b[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01762a7 // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0x252e6d20 // psel p0.b, p11.b/Z, p9.b[w14, #1]\n" - "whilelt p9.b, x16, %x[width]\n" - ".inst 0xe0bfcec0 // st1w { za0v.s[x14] }, p3/Z, [x22, XZR, LSL #2]\n" - "incb x16\n" - 
"add x10, x10, #0x10\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - "incb x23\n" - "whilelt p8.b, x16, %x[width]\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "cbz x17, 7f\n" - "6:" // K loop: Main loop: Second: Loop - ".inst 0x25256143 // psel p3.b, p8.b/Z, p10.b[w13]\n" - ".inst 0x252d6142 // psel p2.b, p8.b/Z, p10.b[w13, #1]\n" - ".inst 0x25656141 // psel p1.b, p8.b/Z, p10.b[w13, #4]\n" - ".inst 0x256d6140 // psel p0.b, p8.b/Z, p10.b[w13, #5]\n" - ".inst 0xe0172f60 // ld1b { za0h.b[x13] }, p3/Z, [x27, x23]\n" - ".inst 0x25246d23 // psel p3.b, p11.b/Z, p9.b[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0172b21 // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x23]\n" - ".inst 0x25246d22 // psel p2.b, p11.b/Z, p9.b[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0172704 // ld1b { za0h.b[x13, #4] }, p1/Z, [x24, x23]\n" - ".inst 0x252c6d21 // psel p1.b, p11.b/Z, p9.b[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01722a5 // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0x252c6d20 // psel p0.b, p11.b/Z, p9.b[w12, #1]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0bf8ec8 // st1w { za2v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - "add x13, x13, #0x8\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x12, x12, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x12, x17\n" - "blt 6b\n" - "7:" // K loop: Main loop: Second: Tail - ".inst 0x25256143 // psel p3.b, p8.b/Z, p10.b[w13]\n" - ".inst 0x252d6142 // psel p2.b, p8.b/Z, p10.b[w13, #1]\n" - ".inst 0x25656141 // psel p1.b, p8.b/Z, p10.b[w13, #4]\n" - ".inst 0x256d6140 // psel p0.b, p8.b/Z, p10.b[w13, #5]\n" - "mov x11, %x[in]\n" - "add 
x10, %x[in], x8, LSL #3\n" - ".inst 0xe0172f60 // ld1b { za0h.b[x13] }, p3/Z, [x27, x23]\n" - ".inst 0x25246d23 // psel p3.b, p11.b/Z, p9.b[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0172b21 // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x23]\n" - ".inst 0x25246d22 // psel p2.b, p11.b/Z, p9.b[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0172704 // ld1b { za0h.b[x13, #4] }, p1/Z, [x24, x23]\n" - ".inst 0x252c6d21 // psel p1.b, p11.b/Z, p9.b[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01722a5 // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0x252c6d20 // psel p0.b, p11.b/Z, p9.b[w12, #1]\n" - "whilelt p9.b, x16, %x[width]\n" - ".inst 0xe0bf8ec8 // st1w { za2v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - "subs x20, x20, #0x1\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "incb x16\n" - "incb x23\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "bgt 3b\n" - "8:" // K loop: Tails - "cbnz x26, 11f\n" - "mov x11, %x[in]\n" - "whilelt p8.b, x16, %x[width]\n" - "mov x13, #0x0\n" - "mov x12, #0x0\n" - "9:" // K loop: Tails: Even: First - ".inst 0x25306d23 // psel p3.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25306d22 // psel p2.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25356141 // psel p1.b, p8.b/Z, p10.b[w13, #2]\n" - ".inst 0x253d6140 // psel p0.b, p8.b/Z, p10.b[w13, #3]\n" - ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "ldr x21, [x11, #0x0]\n" - "cmp x12, x8\n" - "ldr x20, [x11, x8, LSL #0x3]\n" - "add x11, x11, #0x8\n" - ".inst 0xe01726a2 // ld1b { za0h.b[x13, #2] }, p1/Z, [x21, x23]\n" - ".inst 0xe0172283 // ld1b { za0h.b[x13, #3] }, p0/Z, [x20, x23]\n" - "add x13, x13, #0x4\n" - "blt 
9b\n" - "whilelt p9.b, x16, %x[width]\n" - "whilelt p8.b, x16, %x[width]\n" - "mov x20, #0x0\n" - "mov x12, #0x0\n" - "10:" // K loop: Tails: Even: Second - ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25306d20 // psel p0.s, p11.s/Z, p9.s[w12]\n" - "add x20, x20, #0x4\n" - ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 10b\n" - "whilelt p8.b, x16, %x[width]\n" - "b 13f\n" - "11:" // K loop: Tails: Odd - "mov x12, #0x0\n" - "12:" // K loop: Tails: Odd: Loop - ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25306d20 // psel p0.s, p11.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 12b\n" - "13:" // K loop: End - "mov %x[out], x22\n" - ".inst 0xd503467f // SMSTOP\n" - : [out] "+&r"(out) - : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", - "p7", "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", - "x24", "x25", "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", - "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", - "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme( + height, width, in, row_offset, out); // NOLINT(bugprone-multi-level-implicit-pointer-conversion) out_base += m_step * kai_roundup(k_chunk_length, KR) * sizeof(int8_t); } } } - -#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h index 7136d837..1adc97bc 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h @@ -16,14 +16,14 @@ extern "C" { /// /// The starting row index must be divisible by `m_step`. /// -/// @return Step size for row index +/// @return The m step value. size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void); /// Gets the offset in bytes to the data element in the packed LHS buffer. /// /// @param[in] m_idx Row index in the unpacked LHS matrix. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. +/// @param[in] k_chunk_length Length of a LHS column split. /// /// @return The offset in bytes to the data element. size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme( @@ -33,7 +33,7 @@ size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme( /// /// @param[in] m Number of rows in the unpacked LHS matrix. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. +/// @param[in] k_chunk_length Length of a LHS column split. /// /// @return The size in bytes of the packed LHS buffer. size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length); @@ -42,9 +42,9 @@ size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme(size_t m, size_ /// /// @param[in] m Number of rows of the unpacked LHS matrix. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. +/// @param[in] k_chunk_length Length of a LHS column split. /// @param[in] lhs_ptrs Pointer to an array of input pointers consisting of -/// t `m * k_chunk_count` pointers. 
+/// `m * k_chunk_count` pointers.
 /// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs
 /// array, excluding zero pointers.
 /// @param[in] pad_ptr Pointer to chunk used for padding. @ref lhs_ptr_offset is
diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S
new file mode 100644
index 00000000..18ea1f78
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S
@@ -0,0 +1,332 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if defined(_MSC_VER)
+    #define KAI_ASM_GLOBAL(name) GLOBAL name
+    #define KAI_ASM_FUNCTION_TYPE(name)
+    #define KAI_ASM_FUNCTION_LABEL(name) name PROC
+    #define KAI_ASM_FUNCTION_END(name) ENDP
+
+    #define KAI_ASM_CODE(name) AREA name, CODE, READONLY
+    #define KAI_ASM_ALIGN
+    #define KAI_ASM_LABEL(name) name
+    #define KAI_ASM_INST(hex) DCD hex
+    #define KAI_ASM_END END
+#else
+    #if defined(__APPLE__)
+        #define KAI_ASM_GLOBAL(name) .globl _##name
+        #define KAI_ASM_FUNCTION_TYPE(name)
+        #define KAI_ASM_FUNCTION_LABEL(name) _##name:
+        #define KAI_ASM_FUNCTION_END(name)
+    #else
+        #define KAI_ASM_GLOBAL(name) .global name
+        #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function
+        #define KAI_ASM_FUNCTION_LABEL(name) name:
+        #define KAI_ASM_FUNCTION_END(name) .size name, .-name
+    #endif
+
+    #define KAI_ASM_CODE(name) .text
+    #define KAI_ASM_ALIGN .p2align 4,,11
+    #define KAI_ASM_LABEL(name) name:
+    #define KAI_ASM_INST(hex) .inst hex
+    #define KAI_ASM_END
+#endif
+
+    KAI_ASM_CODE(lhs_imatmul_pack_x8p2vlx4_x8p_sme)
+    KAI_ASM_ALIGN
+
+    KAI_ASM_GLOBAL(kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme)
+
+// C entry point: kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme(height, width, in, row_offset, out)
+//
+// Register arguments, in the order passed by the C caller:
+//   x0 = height, x1 = width, x2 = in (array of row pointers),
+//   x3 = row_offset (byte offset applied to each row pointer), x4 = out.
+//
+// SMSTART/SMSTOP zero the vector registers, so the callee-saved d8-d15 are
+// spilled next to x20-x28 on entry and reloaded after SMSTOP; without this the
+// caller's live values in d8-d15 would be lost across the call.
+KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme)
+KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme)
+    stp x20, x21, [sp, -144]!
+    stp x22, x23, [sp, 16]
+    stp x24, x25, [sp, 32]
+    stp x26, x27, [sp, 48]
+    str x28, [sp, 64]
+    // Callee-saved FP registers: must be stored before SMSTART clears them.
+    stp d8, d9, [sp, 80]
+    stp d10, d11, [sp, 96]
+    stp d12, d13, [sp, 112]
+    stp d14, d15, [sp, 128]
+    KAI_ASM_INST(0xd503477f) // SMSTART (enables streaming mode and ZA)
+    mov x8, #0x0
+    cntb x21
+    mov x23, x1
+    incb x23
+    mov x20, x1
+    sub x17, x21, #0x1
+    cntw x16
+    sub x23, x23, #0x1
+    ands x17, x20, x17
+    udiv x23, x23, x21 // n_passes = ceildiv(width, VL)
+    csel x17, x17, x21, NE
+    lsl x22, x0, #0x1 // height * 2
+    lsl x21, x16, #0x1
+    sub x20, x23, #0x1
+    add x17, x17, #0x3
+    sub x11, x16, #0x2
+    whilelt p9.b, XZR, x22
+    whilelt p8.b, x21, x22
+    mov x10, x2
+    add x9, x2, x16, LSL #3
+    cntw x28, ALL, MUL #2
+    ldr x27, [x10, #0x0]
+    cntw x26, ALL, MUL #3
+    lsr x20, x20, #0x1 // n_loops = (n_passes - 1) / 2
+    ldr x25, [x9, #0x0]
+    and x24, x23, #0x1 // odd_tail = bool(n_passes & 0x1)
+    lsr x17, x17, #0x2
+    ldr x23, [x10, #0x8]
+    ptrue p11.s
+    zip1 p10.b, p9.b, p8.b
+    ldr x21, [x9, #0x8]
+    mov x22, x3
+    whilelt p9.b, x8, x1
+    whilelt p8.b, x8, x1
+    add x10, x10, #0x10
+    add x9, x9, #0x10
+    mov x12, #0x0
+    cbz x11, label_2
+KAI_ASM_LABEL(label_1) // K loop: Charge: Loop
+    KAI_ASM_INST(0x25246143) // psel p3.b, p8.b/Z, p10.b[w12]
+    KAI_ASM_INST(0x252c6142) // psel p2.b, p8.b/Z, p10.b[w12, #1]
+    KAI_ASM_INST(0x25646141) // psel p1.b, p8.b/Z, p10.b[w12, #4]
+    KAI_ASM_INST(0x256c6140) // psel p0.b, p8.b/Z, p10.b[w12, #5]
+    KAI_ASM_INST(0xe0160f60) // ld1b { za0h.b[x12] }, p3/Z, [x27, x22]
+    ldr x27, [x10, #0x0]
+    KAI_ASM_INST(0xe0160b21) // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x22]
+    ldr x25, [x9, #0x0]
+    KAI_ASM_INST(0xe01606e4) // ld1b { za0h.b[x12, #4] }, p1/Z, [x23, x22]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe01602a5) // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x22]
+    add x12, x12, #0x8
+    ldr x21, [x9, #0x8]
+    add x9, x9, #0x10
+    cmp x12, x11, LSL #2
+    blt label_1
+KAI_ASM_LABEL(label_2) // K loop: Charge: End
+    KAI_ASM_INST(0x25246143) // psel p3.b, p8.b/Z, p10.b[w12]
+    KAI_ASM_INST(0x252c6142) // psel p2.b, p8.b/Z, p10.b[w12, #1]
+    KAI_ASM_INST(0x25646141) // psel p1.b, p8.b/Z, p10.b[w12, #4]
+    KAI_ASM_INST(0x256c6140) // psel p0.b, p8.b/Z, p10.b[w12, #5]
+    mov x10, x2
+    add x9, x2, x16, LSL #3
+    KAI_ASM_INST(0xe0160f60) // ld1b { za0h.b[x12] }, p3/Z, [x27, x22]
+    ldr x27, [x10, #0x0]
+    incb x8
+    KAI_ASM_INST(0xe0160b21) // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x22]
+    ldr x25, [x9, #0x0]
+    KAI_ASM_INST(0xe01606e4) // ld1b { za0h.b[x12, #4] }, p1/Z, [x23, x22]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe01602a5) // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x22]
+    ldr x21, [x9, #0x8]
+    add x9, x9, #0x10
+    incb x22
+    cbz x20, label_8
+    mov x20, x20
+KAI_ASM_LABEL(label_3) // K loop: Main loop
+    whilelt p8.b, x8, x1
+    mov x15, #0x0
+    mov x14, #0x0
+    cbz x11, label_5
+KAI_ASM_LABEL(label_4) // K loop: Main loop: First: Loop
+    KAI_ASM_INST(0x25376143) // psel p3.b, p8.b/Z, p10.b[w15, #2]
+    KAI_ASM_INST(0x253f6142) // psel p2.b, p8.b/Z, p10.b[w15, #3]
+    KAI_ASM_INST(0x25776141) // psel p1.b, p8.b/Z, p10.b[w15, #6]
+    KAI_ASM_INST(0x257f6140) // psel p0.b, p8.b/Z, p10.b[w15, #7]
+    KAI_ASM_INST(0xe0166f62) // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x22]
+    KAI_ASM_INST(0x25266d23) // psel p3.b, p11.b/Z, p9.b[w14]
+    ldr x27, [x10, #0x0]
+    KAI_ASM_INST(0xe0166b23) // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x22]
+    KAI_ASM_INST(0x25266d22) // psel p2.b, p11.b/Z, p9.b[w14]
+    ldr x25, [x9, #0x0]
+    KAI_ASM_INST(0xe01666e6) // ld1b { za0h.b[x15, #6] }, p1/Z, [x23, x22]
+    KAI_ASM_INST(0x252e6d21) // psel p1.b, p11.b/Z, p9.b[w14, #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe01662a7) // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x22]
+    ldr x21, [x9, #0x8]
+    KAI_ASM_INST(0x252e6d20) // psel p0.b, p11.b/Z, p9.b[w14, #1]
+    add x9, x9, #0x10
+    KAI_ASM_INST(0xe0bfcc80) // st1w { za0v.s[x14] }, p3/Z, [x4, XZR, LSL #2]
+    add x15, x15, #0x8
+    KAI_ASM_INST(0xe0b0c884) // st1w { za1v.s[x14] }, p2/Z, [x4, x16, LSL #2]
+    KAI_ASM_INST(0xe0bcc481) // st1w { za0v.s[x14, #1] }, p1/Z, [x4, x28, LSL #2]
+    KAI_ASM_INST(0xe0bac085) // st1w { za1v.s[x14, #1] }, p0/Z, [x4, x26, LSL #2]
+    add x14, x14, #0x2
+    addvl x4, x4, #4
+    cmp x14, x11
+    blt label_4
+KAI_ASM_LABEL(label_5) // K loop: Main loop: First: Tail
+    KAI_ASM_INST(0x25376143) // psel p3.b, p8.b/Z, p10.b[w15, #2]
+    KAI_ASM_INST(0x253f6142) // psel p2.b, p8.b/Z, p10.b[w15, #3]
+    KAI_ASM_INST(0x25776141) // psel p1.b, p8.b/Z, p10.b[w15, #6]
+    KAI_ASM_INST(0x257f6140) // psel p0.b, p8.b/Z, p10.b[w15, #7]
+    mov x10, x2
+    add x9, x2, x16, LSL #3
+    KAI_ASM_INST(0xe0166f62) // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x22]
+    KAI_ASM_INST(0x25266d23) // psel p3.b, p11.b/Z, p9.b[w14]
+    ldr x27, [x10, #0x0]
+    mov x13, #0x0
+    KAI_ASM_INST(0xe0166b23) // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x22]
+    KAI_ASM_INST(0x25266d22) // psel p2.b, p11.b/Z, p9.b[w14]
+    ldr x25, [x9, #0x0]
+    mov x12, #0x0
+    KAI_ASM_INST(0xe01666e6) // ld1b { za0h.b[x15, #6] }, p1/Z, [x23, x22]
+    KAI_ASM_INST(0x252e6d21) // psel p1.b, p11.b/Z, p9.b[w14, #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe01662a7) // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x22]
+    ldr x21, [x9, #0x8]
+    KAI_ASM_INST(0x252e6d20) // psel p0.b, p11.b/Z, p9.b[w14, #1]
+    whilelt p9.b, x8, x1
+    KAI_ASM_INST(0xe0bfcc80) // st1w { za0v.s[x14] }, p3/Z, [x4, XZR, LSL #2]
+    incb x8
+    add x9, x9, #0x10
+    KAI_ASM_INST(0xe0b0c884) // st1w { za1v.s[x14] }, p2/Z, [x4, x16, LSL #2]
+    incb x22
+    whilelt p8.b, x8, x1
+    KAI_ASM_INST(0xe0bcc481) // st1w { za0v.s[x14, #1] }, p1/Z, [x4, x28, LSL #2]
+    KAI_ASM_INST(0xe0bac085) // st1w { za1v.s[x14, #1] }, p0/Z, [x4, x26, LSL #2]
+    addvl x4, x4, #4
+    cbz x11, label_7
+KAI_ASM_LABEL(label_6) // K loop: Main loop: Second: Loop
+    KAI_ASM_INST(0x25256143) // psel p3.b, p8.b/Z, p10.b[w13]
+    KAI_ASM_INST(0x252d6142) // psel p2.b, p8.b/Z, p10.b[w13, #1]
+    KAI_ASM_INST(0x25656141) // psel p1.b, p8.b/Z, p10.b[w13, #4]
+    KAI_ASM_INST(0x256d6140) // psel p0.b, p8.b/Z, p10.b[w13, #5]
+    KAI_ASM_INST(0xe0162f60) // ld1b { za0h.b[x13] }, p3/Z, [x27, x22]
+    KAI_ASM_INST(0x25246d23) // psel p3.b, p11.b/Z, p9.b[w12]
+    ldr x27, [x10, #0x0]
+    KAI_ASM_INST(0xe0162b21) // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x22]
+    KAI_ASM_INST(0x25246d22) // psel p2.b, p11.b/Z, p9.b[w12]
+    ldr x25, [x9, #0x0]
+    KAI_ASM_INST(0xe01626e4) // ld1b { za0h.b[x13, #4] }, p1/Z, [x23, x22]
+    KAI_ASM_INST(0x252c6d21) // psel p1.b, p11.b/Z, p9.b[w12, #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe01622a5) // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x22]
+    ldr x21, [x9, #0x8]
+    KAI_ASM_INST(0x252c6d20) // psel p0.b, p11.b/Z, p9.b[w12, #1]
+    add x9, x9, #0x10
+    KAI_ASM_INST(0xe0bf8c88) // st1w { za2v.s[x12] }, p3/Z, [x4, XZR, LSL #2]
+    add x13, x13, #0x8
+    KAI_ASM_INST(0xe0b0888c) // st1w { za3v.s[x12] }, p2/Z, [x4, x16, LSL #2]
+    KAI_ASM_INST(0xe0bc8489) // st1w { za2v.s[x12, #1] }, p1/Z, [x4, x28, LSL #2]
+    KAI_ASM_INST(0xe0ba808d) // st1w { za3v.s[x12, #1] }, p0/Z, [x4, x26, LSL #2]
+    add x12, x12, #0x2
+    addvl x4, x4, #4
+    cmp x12, x11
+    blt label_6
+KAI_ASM_LABEL(label_7) // K loop: Main loop: Second: Tail
+    KAI_ASM_INST(0x25256143) // psel p3.b, p8.b/Z, p10.b[w13]
+    KAI_ASM_INST(0x252d6142) // psel p2.b, p8.b/Z, p10.b[w13, #1]
+    KAI_ASM_INST(0x25656141) // psel p1.b, p8.b/Z, p10.b[w13, #4]
+    KAI_ASM_INST(0x256d6140) // psel p0.b, p8.b/Z, p10.b[w13, #5]
+    mov x10, x2
+    add x9, x2, x16, LSL #3
+    KAI_ASM_INST(0xe0162f60) // ld1b { za0h.b[x13] }, p3/Z, [x27, x22]
+    KAI_ASM_INST(0x25246d23) // psel p3.b, p11.b/Z, p9.b[w12]
+    ldr x27, [x10, #0x0]
+    KAI_ASM_INST(0xe0162b21) // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x22]
+    KAI_ASM_INST(0x25246d22) // psel p2.b, p11.b/Z, p9.b[w12]
+    ldr x25, [x9, #0x0]
+    KAI_ASM_INST(0xe01626e4) // ld1b { za0h.b[x13, #4] }, p1/Z, [x23, x22]
+    KAI_ASM_INST(0x252c6d21) // psel p1.b, p11.b/Z, p9.b[w12, #1]
+    ldr x23, [x10, #0x8]
+    add x10, x10, #0x10
+    KAI_ASM_INST(0xe01622a5) // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x22]
+    ldr x21, [x9, #0x8]
+    KAI_ASM_INST(0x252c6d20) // psel p0.b, p11.b/Z, p9.b[w12, #1]
+    whilelt p9.b, x8, x1
+    KAI_ASM_INST(0xe0bf8c88) // st1w { za2v.s[x12] }, p3/Z, [x4, XZR, LSL #2]
+    subs x20, x20, #0x1
+    add x9, x9, #0x10
+    KAI_ASM_INST(0xe0b0888c) // st1w { za3v.s[x12] }, p2/Z, [x4, x16, LSL #2]
+    incb x8
+    incb x22
+    KAI_ASM_INST(0xe0bc8489) // st1w { za2v.s[x12, #1] }, p1/Z, [x4, x28, LSL #2]
+    KAI_ASM_INST(0xe0ba808d) // st1w { za3v.s[x12, #1] }, p0/Z, [x4, x26, LSL #2]
+    addvl x4, x4, #4
+    bgt label_3
+KAI_ASM_LABEL(label_8) // K loop: Tails
+    cbnz x24, label_11
+    mov x10, x2
+    whilelt p8.b, x8, x1
+    mov x13, #0x0
+    mov x12, #0x0
+KAI_ASM_LABEL(label_9) // K loop: Tails: Even: First
+    KAI_ASM_INST(0x25306d23) // psel p3.s, p11.s/Z, p9.s[w12]
+    KAI_ASM_INST(0x25306d22) // psel p2.s, p11.s/Z, p9.s[w12]
+    KAI_ASM_INST(0x25356141) // psel p1.b, p8.b/Z, p10.b[w13, #2]
+    KAI_ASM_INST(0x253d6140) // psel p0.b, p8.b/Z, p10.b[w13, #3]
+    KAI_ASM_INST(0xe0bf8c80) // st1w { za0v.s[x12] }, p3/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0xe0b08884) // st1w { za1v.s[x12] }, p2/Z, [x4, x16, LSL #2]
+    add x12, x12, #0x1
+    addvl x4, x4, #2
+    ldr x21, [x10, #0x0]
+    cmp x12, x16
+    ldr x20, [x10, x16, LSL #0x3]
+    add x10, x10, #0x8
+    KAI_ASM_INST(0xe01626a2) // ld1b { za0h.b[x13, #2] }, p1/Z, [x21, x22]
+    KAI_ASM_INST(0xe0162283) // ld1b { za0h.b[x13, #3] }, p0/Z, [x20, x22]
+    add x13, x13, #0x4
+    blt label_9
+    whilelt p9.b, x8, x1
+    whilelt p8.b, x8, x1
+    mov x20, #0x0
+    mov x12, #0x0
+KAI_ASM_LABEL(label_10) // K loop: Tails: Even: Second
+    KAI_ASM_INST(0x25306d21) // psel p1.s, p11.s/Z, p9.s[w12]
+    KAI_ASM_INST(0x25306d20) // psel p0.s, p11.s/Z, p9.s[w12]
+    add x20, x20, #0x4
+    KAI_ASM_INST(0xe0bf8488) // st1w { za2v.s[x12] }, p1/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0xe0b0808c) // st1w { za3v.s[x12] }, p0/Z, [x4, x16, LSL #2]
+    add x12, x12, #0x1
+    addvl x4, x4, #2
+    cmp x12, x17
+    blt label_10
+    whilelt p8.b, x8, x1
+    b label_13
+KAI_ASM_LABEL(label_11) // K loop: Tails: Odd
+    mov x12, #0x0
+KAI_ASM_LABEL(label_12) // K loop: Tails: Odd: Loop
+    KAI_ASM_INST(0x25306d21) // psel p1.s, p11.s/Z, p9.s[w12]
+    KAI_ASM_INST(0x25306d20) // psel p0.s, p11.s/Z, p9.s[w12]
+    KAI_ASM_INST(0xe0bf8480) // st1w { za0v.s[x12] }, p1/Z, [x4, XZR, LSL #2]
+    KAI_ASM_INST(0xe0b08084) // st1w { za1v.s[x12] }, p0/Z, [x4, x16, LSL #2]
+    add x12, x12, #0x1
+    addvl x4, x4, #2
+    cmp x12, x17
+    blt label_12
+KAI_ASM_LABEL(label_13) // K loop: End
+    KAI_ASM_INST(0xd503467f) // SMSTOP
+    // Reload d8-d15 only after SMSTOP, which clears the vector registers again.
+    ldp d8, d9, [sp, 80]
+    ldp d10, d11, [sp, 96]
+    ldp d12, d13, [sp, 112]
+    ldp d14, d15, [sp, 128]
+    ldp x22, x23, [sp, 16]
+    ldp x24, x25, [sp, 32]
+    ldp x26, x27, [sp, 48]
+    ldr x28, [sp, 64]
+    ldp x20, x21, [sp], 144
+    ret
+    KAI_ASM_FUNCTION_END(kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme)
+
+    KAI_ASM_END
diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c
index 8cd1201a..40922c37 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c
@@ -3,13 +3,6 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 //
-
-// Do not flag up inline assembly blocks
-#pragma GCC diagnostic ignored "-Woverlength-strings"
-
-#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)
-#error This file must be compiled for AArch64, FEAT_SVE2.
-#else // Architectural features check.
#include "kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" #include @@ -18,14 +11,33 @@ #include "kai/kai_common.h" +enum { + NR = 2, + KR = 4, + MAX_N_STEP = NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint8_t)) / KR), +}; + +typedef struct { + const void* bias_ptr; + const void* scale_ptr; + int32_t input_zero_point; + float scale_multiplier; + size_t width; + size_t height; + size_t k_chunk_count; + size_t in_stride; + size_t out_stride; + const void* in; + void* out; + const void* pad_row; +} KernelArgs; + static const size_t kai_num_bytes_input = sizeof(uint8_t); static const size_t kai_num_bytes_output = sizeof(uint8_t); static const size_t kai_num_bytes_bias = sizeof(int32_t); -static const size_t kai_num_bytes_scale = sizeof(float32_t); +static const size_t kai_num_bytes_scale = sizeof(float); -#define NR 2 -#define KR 4 -#define MAX_N_STEP (NR * KAI_SME_VEC_LENGTH_MAX_BYTES / KR) +void kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(const KernelArgs* args_ptr); size_t kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void) { return NR * kai_get_sme_vector_length_u8() / KR; @@ -63,13 +75,13 @@ size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( size_t n, size_t k_chunk_count, size_t k_chunk_length) { - const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme()); + const size_t n_nr_blocks = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme()); return kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( - n_rounded_up, k_chunk_count, k_chunk_length); + n_nr_blocks, k_chunk_count, k_chunk_length); } void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( - size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + 
size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias, const void* scale, void* rhs_packed, const struct kai_rhs_pack_qsi8cx_params* params) { KAI_ASSUME(rhs != NULL); KAI_ASSUME(bias != NULL); @@ -77,201 +89,23 @@ void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( KAI_ASSUME(rhs_packed != NULL); KAI_ASSUME(params != NULL); - size_t height = k_chunk_length; - const size_t width = n; - const void* in = rhs; - void* out = rhs_packed; - const size_t in_stride = rhs_row_stride; - KAI_ASSERT(kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() <= MAX_N_STEP); - uint8_t pad_row[MAX_N_STEP]; - if (height % KR) { - memset(pad_row, 0, MAX_N_STEP * sizeof(uint8_t)); - } - - size_t out_stride = + static const uint8_t pad_row[MAX_N_STEP] = {0}; + + KernelArgs args; + args.bias_ptr = bias; + args.scale_ptr = scale; + args.height = k_chunk_length; + args.width = n; + args.in = rhs; + args.out = rhs_packed; + args.k_chunk_count = k_chunk_count; + args.in_stride = rhs_stride_row; + args.out_stride = kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k_chunk_count, k_chunk_length); - const int32_t input_zero_point = params->lhs_zero_point; - const float scale_multiplier = params->scale_multiplier; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x12, %x[out]\n" - "mov x11, %x[k_chunk_count]\n" - "ptrue p2.b\n" - "incb %x[out], ALL, MUL #2\n" - "1:" // Chunk Loop - "mov x10, %x[height]\n" - "cmp x10, #0x8\n" - "blt 5f\n" - "2:" // Main row loop: Head - "mov x9, %x[in]\n" - "mov x28, %x[out]\n" - "add x27, x9, %x[in_stride]\n" - "sub x10, x10, #0x8\n" - "add x26, x27, %x[in_stride]\n" - "mov x24, %x[width]\n" - "add x25, x26, %x[in_stride]\n" - "add x23, x25, %x[in_stride]\n" - "add x22, x23, %x[in_stride]\n" - "add x21, x22, %x[in_stride]\n" - "add x20, x21, %x[in_stride]\n" - "add %x[in], x20, %x[in_stride]\n" - "3:" // Main row loop: Column 
loop - "whilelt p0.b, XZR, x24\n" - "decw x24, ALL, MUL #2\n" - "ld1b { z18.b }, p0/Z, [x9]\n" - "cmp x24, #0x0\n" - "incd x9, ALL, MUL #4\n" - "ld1b { z22.b }, p0/Z, [x27]\n" - "incd x27, ALL, MUL #4\n" - "ld1b { z17.b }, p0/Z, [x26]\n" - "incd x26, ALL, MUL #4\n" - "ld1b { z16.b }, p0/Z, [x25]\n" - "incd x25, ALL, MUL #4\n" - "ld1b { z20.b }, p0/Z, [x23]\n" - "incd x23, ALL, MUL #4\n" - "ld1b { z19.b }, p0/Z, [x22]\n" - "zip1 z21.b, z18.b, z17.b\n" - "incd x22, ALL, MUL #4\n" - "ld1b { z18.b }, p0/Z, [x21]\n" - "zip1 z17.b, z22.b, z16.b\n" - "incd x21, ALL, MUL #4\n" - "ld1b { z16.b }, p0/Z, [x20]\n" - "incd x20, ALL, MUL #4\n" - "zip1 z20.b, z20.b, z18.b\n" - "zip1 z16.b, z19.b, z16.b\n" - "zip1 z19.b, z21.b, z17.b\n" - "zip2 z18.b, z21.b, z17.b\n" - "zip1 z17.b, z20.b, z16.b\n" - "zip2 z16.b, z20.b, z16.b\n" - "st1b { z19.b }, p2, [x28]\n" - "st1b { z18.b }, p2, [x28, #1, MUL VL]\n" - "st1b { z17.b }, p2, [x28, #2, MUL VL]\n" - "st1b { z16.b }, p2, [x28, #3, MUL VL]\n" - "add x28, x28, %x[out_stride]\n" - "bgt 3b\n" - "cmp x10, #0x8\n" - "addvl %x[out], %x[out], #4\n" - "bge 2b\n" - "cbz x10, 9f\n" - "5:" // Main loop skip - "6:" // Tail row loop: Head - "mov x9, %x[in]\n" - "cmp x10, #0x3\n" - "add x27, x9, %x[in_stride]\n" - "cntw x24, ALL, MUL #2\n" - "add x26, x27, %x[in_stride]\n" - "csel x23, x24, XZR, GT\n" - "add x25, x26, %x[in_stride]\n" - "csel x22, x24, XZR, GE\n" - "add %x[in], x25, %x[in_stride]\n" - "mov x28, %x[out]\n" - "csel %x[in], %x[in], x25, GT\n" - "csel x25, x25, %x[pad_row], GT\n" - "csel %x[in], %x[in], x26, GE\n" - "csel x26, x26, %x[pad_row], GE\n" - "cmp x10, #0x1\n" - "sub x10, x10, #0x4\n" - "csel %x[in], %x[in], x27, GT\n" - "csel x27, x27, %x[pad_row], GT\n" - "csel x21, x24, XZR, GT\n" - "mov x20, %x[width]\n" - "7:" // Tail row loop: Column loop - "whilelt p0.b, XZR, x20\n" - "decw x20, ALL, MUL #2\n" - "ld1b { z18.b }, p0/Z, [x9]\n" - "cmp x20, #0x0\n" - "add x9, x9, x24\n" - "ld1b { z19.b }, p0/Z, [x27]\n" - "add x27, x27, 
x21\n" - "ld1b { z17.b }, p0/Z, [x26]\n" - "add x26, x26, x22\n" - "ld1b { z16.b }, p0/Z, [x25]\n" - "add x25, x25, x23\n" - "zip1 z18.b, z18.b, z17.b\n" - "zip1 z16.b, z19.b, z16.b\n" - "zip1 z17.b, z18.b, z16.b\n" - "zip2 z16.b, z18.b, z16.b\n" - "st1b { z17.b }, p2, [x28]\n" - "st1b { z16.b }, p2, [x28, #1, MUL VL]\n" - "add x28, x28, %x[out_stride]\n" - "bgt 7b\n" - "cmp x10, #0x1\n" - "addvl %x[out], %x[out], #2\n" - "bge 6b\n" - "9:" // Done - "sub x11, x11, #0x1\n" - "cbnz x11, 1b\n" - "mov x22, %x[out]\n" - "mov x21, %x[width]\n" - "dup z18.s, %w[scale_multiplier]\n" - "cbz %x[scale], 11f\n" - "10:" // Scale: Full loop - "mov x20, x21\n" - "decw x21, ALL, MUL #2\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "ld1w { z17.s }, p1/Z, [%x[scale]]\n" - "cmp x21, #0x0\n" - "ld1w { z16.s }, p0/Z, [%x[scale], #1, MUL VL]\n" - "incb %x[scale], ALL, MUL #2\n" - "fmul z17.s, z17.s, z18.s\n" - "fmul z16.s, z16.s, z18.s\n" - "st1w { z17.s }, p2, [x22]\n" - "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" - "add x22, x22, %x[out_stride]\n" - "bgt 10b\n" - "11:" // Scale: Done - "cbz %x[width], 14f\n" - "cbz %x[height], 14f\n" - "dup z21.s, %w[input_zero_point]\n" - "add x25, %x[height], #0x3\n" - "cntw x24, ALL, MUL #2\n" - "mov z20.b, #0x1\n" - "lsr x25, x25, #0x2\n" - "mov x23, %x[width]\n" - "mul x25, %x[k_chunk_count], x25\n" - "addvl x22, x12, #2\n" - "neg z21.s, p2/M, z21.s\n" - "12:" // Bias: N loop - "mov x21, x22\n" - "mov x20, x25\n" - "mov z19.s, #0x0\n" - "mov z18.s, #0x0\n" - "13:" // Bias: K loop - "ld1b { z17.b }, p2/Z, [x21]\n" - "subs x20, x20, #0x1\n" - "ld1b { z16.b }, p2/Z, [x21, #1, MUL VL]\n" - "addvl x21, x21, #2\n" - "sdot z19.s, z17.b, z20.b\n" - "sdot z18.s, z16.b, z20.b\n" - "bgt 13b\n" - "mov x20, x23\n" - "add x22, x22, %x[out_stride]\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "ld1w { z17.s }, p1/Z, [%x[bias]]\n" - "subs x23, x23, x24\n" - "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL 
VL]\n" - "addvl %x[bias], %x[bias], #2\n" - "mla z17.s, p2/M, z19.s, z21.s\n" - "mla z16.s, p2/M, z18.s, z21.s\n" - "st1w { z17.s }, p2, [x12]\n" - "st1w { z16.s }, p2, [x12, #1, MUL VL]\n" - "add x12, x12, %x[out_stride]\n" - "bgt 12b\n" - "14:" // Bias: Done - ".inst 0xd503467f // SMSTOP\n" - : [bias] "+&r"(bias), [in] "+&r"(in), [out] "+&r"(out), [scale] "+&r"(scale) - : [height] "r"(height), [in_stride] "r"(in_stride), [input_zero_point] "r"(input_zero_point), - [k_chunk_count] "r"(k_chunk_count), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), - [scale_multiplier] "r"(scale_multiplier), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", - "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", - "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); -} + args.input_zero_point = params->lhs_zero_point; + args.scale_multiplier = params->scale_multiplier; + args.pad_row = pad_row; -#endif // Architectural features check. + kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(&args); +} diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h index 77b1dbde..e8325a10 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h @@ -18,7 +18,7 @@ extern "C" { /// /// The starting column index must be divisible by `n_step`. /// -/// @return Step size for column index. +/// @return The n step value. 
size_t kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. @@ -75,14 +75,14 @@ size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32 /// @param[in] n Number of columns of the output matrix. /// @param[in] k_chunk_count Number of chunks. /// @param[in] k_chunk_length Number of rows in each chunk. -/// @param[in] rhs_row_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs_stride_row Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. /// @param[in] scale Scale data buffer. /// @param[out] rhs_packed Packed RHS matrix. -/// @param[in] params Extra packing parameters. +/// @param[in] params Extra quantization packing parameters. void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( - size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias, const void* scale, void* rhs_packed, const struct kai_rhs_pack_qsi8cx_params* params); #ifdef __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S new file mode 100644 index 00000000..28933124 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S @@ -0,0 +1,257 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE,
READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) + +// Entry point called from the C wrapper with x0 = pointer to its argument structure. The fields +// are read at fixed byte offsets by the ldr sequence below, so the structure layout and these +// offsets must be kept in sync. +// +// The routine has a normal (non-streaming) AAPCS64 interface but runs its body in streaming +// mode. Changing PSTATE.SM zeroes the vector registers, so the callee-saved d8-d15 are spilled +// before SMSTART and reloaded after SMSTOP, alongside the callee-saved x20-x28. +// +// Frame (144 bytes): x20-x28 at [sp, 0..71], d8-d15 at [sp, 80..143]. +KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) + stp x20, x21, [sp, -144]!
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 80] // AAPCS64 callee-saved; would otherwise be zeroed by the mode switch + stp d10, d11, [sp, 96] + stp d12, d13, [sp, 112] + stp d14, d15, [sp, 128] + KAI_ASM_INST(0xd503477f) // SMSTART (enables streaming mode and ZA) + ldr x2, [x0, #0x28] // chunk count (Chunk Loop counter, also scales the Bias K loop) + ptrue p2.b + ldr x3, [x0, #0x48] // output pointer + ldr x4, [x0, #0x0] // bias pointer (read in the Bias N loop) + ldr x5, [x0, #0x8] // scale pointer (may be zero, see cbz before the Scale loop) + mov x6, x2 + ldr w7, [x0, #0x10] // 32-bit value negated for the bias correction (presumably input_zero_point) + mov x8, x3 + incb x3, ALL, MUL #2 // skip the two leading vectors of each block; the Bias pass fills them via x8 + ldr w17, [x0, #0x14] // 32-bit factor multiplied into every scale (presumably scale_multiplier) + ldr x16, [x0, #0x18] // width (columns) + ldr x15, [x0, #0x20] // height (rows per chunk) + ldr x14, [x0, #0x30] // input row stride in bytes + ldr x13, [x0, #0x38] // output stride in bytes + ldr x12, [x0, #0x40] // input pointer + ldr x11, [x0, #0x50] // pad row, read in place of missing tail rows +KAI_ASM_LABEL(label_1) // Chunk Loop + mov x10, x15 + cmp x10, #0x8 + blt label_5 +KAI_ASM_LABEL(label_2) // Main row loop: Head + mov x9, x12 + mov x28, x3 + add x27, x9, x14 + sub x10, x10, #0x8 + add x26, x27, x14 + mov x24, x16 + add x25, x26, x14 + add x23, x25, x14 + add x22, x23, x14 + add x21, x22, x14 + add x20, x21, x14 + add x12, x20, x14 +KAI_ASM_LABEL(label_3) // Main row loop: Column loop + whilelt p0.b, XZR, x24 + decw x24, ALL, MUL #2 + ld1b { z18.b }, p0/Z, [x9] + cmp x24, #0x0 + incd x9, ALL, MUL #4 + ld1b { z22.b }, p0/Z, [x27] + incd x27, ALL, MUL #4 + ld1b { z17.b }, p0/Z, [x26] + incd x26, ALL, MUL #4 + ld1b { z16.b }, p0/Z, [x25] + incd x25, ALL, MUL #4 + ld1b { z20.b }, p0/Z, [x23] + incd x23, ALL, MUL #4 + ld1b { z19.b }, p0/Z, [x22] + zip1 z21.b, z18.b, z17.b + incd x22, ALL, MUL #4 + ld1b { z18.b }, p0/Z, [x21] + zip1 z17.b, z22.b, z16.b + incd x21, ALL, MUL #4 + ld1b { z16.b }, p0/Z, [x20] + incd x20, ALL, MUL #4 + zip1 z20.b, z20.b, z18.b + zip1 z16.b, z19.b, z16.b + zip1 z19.b, z21.b, z17.b + zip2 z18.b, z21.b, z17.b + zip1 z17.b, z20.b, z16.b + zip2 z16.b, z20.b, z16.b + st1b { z19.b }, p2, [x28] + st1b { z18.b }, p2, [x28, #1, MUL VL] + st1b { z17.b }, p2, [x28, #2, MUL VL] + st1b { z16.b }, p2, [x28, #3, MUL VL] + add x28, x28, x13 + bgt label_3 + cmp x10, #0x8 + addvl x3, x3, #4 + bge label_2 + cbz x10, label_9 +KAI_ASM_LABEL(label_5) // Main loop skip +KAI_ASM_LABEL(label_6) // Tail row loop: Head + mov x9, x12 + cmp x10, #0x3 + add x27, x9, x14 + cntw x24, ALL, MUL #2 +
add x26, x27, x14 + csel x23, x24, XZR, GT + add x25, x26, x14 + csel x22, x24, XZR, GE + add x12, x25, x14 + mov x28, x3 + csel x12, x12, x25, GT + csel x25, x25, x11, GT // fewer than 4 rows left: read row 3 from the pad row + csel x12, x12, x26, GE + csel x26, x26, x11, GE // fewer than 3 rows left: read row 2 from the pad row + cmp x10, #0x1 + sub x10, x10, #0x4 + csel x12, x12, x27, GT + csel x27, x27, x11, GT // single row left: read row 1 from the pad row + csel x21, x24, XZR, GT + mov x20, x16 +KAI_ASM_LABEL(label_7) // Tail row loop: Column loop + whilelt p0.b, XZR, x20 + decw x20, ALL, MUL #2 + ld1b { z18.b }, p0/Z, [x9] + cmp x20, #0x0 + add x9, x9, x24 + ld1b { z19.b }, p0/Z, [x27] + add x27, x27, x21 + ld1b { z17.b }, p0/Z, [x26] + add x26, x26, x22 + ld1b { z16.b }, p0/Z, [x25] + add x25, x25, x23 + zip1 z18.b, z18.b, z17.b + zip1 z16.b, z19.b, z16.b + zip1 z17.b, z18.b, z16.b + zip2 z16.b, z18.b, z16.b + st1b { z17.b }, p2, [x28] + st1b { z16.b }, p2, [x28, #1, MUL VL] + add x28, x28, x13 + bgt label_7 + cmp x10, #0x1 + addvl x3, x3, #2 + bge label_6 +KAI_ASM_LABEL(label_9) // Done + sub x6, x6, #0x1 + cbnz x6, label_1 + mov x22, x3 + mov x21, x16 + dup z18.s, w17 + cbz x5, label_11 // no scale pointer: skip the scale pass +KAI_ASM_LABEL(label_10) // Scale: Full loop + mov x20, x21 + decw x21, ALL, MUL #2 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + ld1w { z17.s }, p1/Z, [x5] + cmp x21, #0x0 + ld1w { z16.s }, p0/Z, [x5, #1, MUL VL] + incb x5, ALL, MUL #2 + fmul z17.s, z17.s, z18.s + fmul z16.s, z16.s, z18.s + st1w { z17.s }, p2, [x22] + st1w { z16.s }, p2, [x22, #1, MUL VL] + add x22, x22, x13 + bgt label_10 +KAI_ASM_LABEL(label_11) // Scale: Done + cbz x16, label_14 + cbz x15, label_14 + dup z21.s, w7 + add x25, x15, #0x3 + cntw x24, ALL, MUL #2 + mov z20.b, #0x1 // all-ones bytes: sdot against them accumulates sums of the packed bytes + lsr x25, x25, #0x2 + mov x23, x16 + mul x25, x2, x25 + addvl x22, x8, #2 // first packed data vector, after the two leading bias vectors + neg z21.s, p2/M, z21.s // negated so the mla below subtracts sum * (value at 0x10) from the bias +KAI_ASM_LABEL(label_12) // Bias: N loop + mov x21, x22 + mov x20, x25 + mov z19.s, #0x0 + mov z18.s, #0x0 +KAI_ASM_LABEL(label_13) // Bias: K loop + ld1b { z17.b }, p2/Z, [x21] + subs x20, x20, #0x1 + ld1b { z16.b }, p2/Z, [x21, #1, MUL VL] + addvl x21, x21, #2 + sdot
z19.s, z17.b, z20.b + sdot z18.s, z16.b, z20.b + bgt label_13 + mov x20, x23 + add x22, x22, x13 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + ld1w { z17.s }, p1/Z, [x4] + subs x23, x23, x24 + ld1w { z16.s }, p0/Z, [x4, #1, MUL VL] + addvl x4, x4, #2 + mla z17.s, p2/M, z19.s, z21.s + mla z16.s, p2/M, z18.s, z21.s + st1w { z17.s }, p2, [x8] + st1w { z16.s }, p2, [x8, #1, MUL VL] + add x8, x8, x13 + bgt label_12 +KAI_ASM_LABEL(label_14) // Bias: Done + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp d8, d9, [sp, 80] // reload only after SMSTOP: leaving streaming mode zeroes the vector registers + ldp d10, d11, [sp, 96] + ldp d12, d13, [sp, 112] + ldp d14, d15, [sp, 128] + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c index a9c0bb73..228f3c0b 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c @@ -3,13 +3,6 @@ // // SPDX-License-Identifier: Apache-2.0 // - -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) -#error This file must be compiled for AArch64, FEAT_SVE2. -#else // Architectural features check.
#include "kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h" #include @@ -18,13 +11,29 @@ #include "kai/kai_common.h" -#define NR 2 -#define KR 2 +enum { + NR = 2, + KR = 2, + MAX_N_STEP = NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR), +}; + +typedef struct { // read by the assembly kernel at fixed offsets from x0 (8-byte fields); keep the field order in sync + const void* bias_ptr; // 0x00 + size_t width; // 0x08 + size_t height; // 0x10 + size_t k_chunk_count; // 0x18 + size_t in_stride; // 0x20 + size_t out_stride; // 0x28 + const void* in; // 0x30 + void* out; // 0x38 + const void* pad_row; // 0x40 +} KernelArgs; + static const size_t kai_num_bytes_input = sizeof(uint16_t); static const size_t kai_num_bytes_output = sizeof(uint16_t); static const size_t kai_num_bytes_bias = sizeof(uint16_t); -#define MAX_N_STEP (NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR)) +void kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(const KernelArgs* args_ptr); size_t kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(void) { return NR * kai_get_sme_vector_length_u16() / KR; @@ -57,148 +66,32 @@ size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( size_t n, size_t k_chunk_count, size_t k_chunk_length) { - const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme()); + const size_t n_nr_blocks = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme()); // NOTE(review): result of kai_roundup(n, n_step), presumably a rounded-up n rather than a block count - confirm naming return kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( - n_rounded_up, k_chunk_count, k_chunk_length); + n_nr_blocks, k_chunk_count, k_chunk_length); } void kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( - size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias, void* rhs_packed) { KAI_ASSUME(rhs != NULL); KAI_ASSUME(bias != NULL); KAI_ASSUME(rhs_packed != NULL); - size_t height = k_chunk_length; -
const size_t width = n; - const void* in = rhs; - void* out = rhs_packed; - const size_t in_stride = rhs_row_stride; KAI_ASSERT(kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() <= MAX_N_STEP); - uint16_t pad_row[MAX_N_STEP]; - if (height % KR) { - memset(pad_row, 0, MAX_N_STEP * sizeof(uint16_t)); - } - - size_t out_stride = + static const uint16_t pad_row[MAX_N_STEP] = {0}; // all-zero row the kernel reads in place of the missing second row when a single tail row remains + + KernelArgs args; + args.bias_ptr = bias; + args.height = k_chunk_length; + args.width = n; + args.in = rhs; + args.out = rhs_packed; + args.k_chunk_count = k_chunk_count; + args.in_stride = rhs_stride_row; + args.out_stride = kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(k_chunk_count, k_chunk_length); - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x21, %x[out]\n" - "mov x20, %x[width]\n" - "ptrue p1.b\n" - "1:" // Bias: Full loop - "whilelt p0.h, XZR, x20\n" - "dech x20\n" - "cmp x20, #0x0\n" - "ld1h { z16.h }, p0/Z, [%x[bias]]\n" - "incb %x[bias]\n" - "st1h { z16.h }, p1, [x21]\n" - "add x21, x21, %x[out_stride]\n" - "bgt 1b\n" - "incb %x[out]\n" - "mov x11, %x[k_chunk_count]\n" - "2:" // Chunk Loop - "mov x10, %x[height]\n" - "cmp x10, #0x8\n" - "blt 6f\n" - "3:" // Main row loop: Head - "mov x9, %x[in]\n" - "mov x28, %x[out]\n" - "add x27, x9, %x[in_stride]\n" - "sub x10, x10, #0x8\n" - "add x26, x27, %x[in_stride]\n" - "mov x25, %x[width]\n" - "add x24, x26, %x[in_stride]\n" - "add x23, x24, %x[in_stride]\n" - "add x22, x23, %x[in_stride]\n" - "add x21, x22, %x[in_stride]\n" - "add x20, x21, %x[in_stride]\n" - "add %x[in], x20, %x[in_stride]\n" - "4:" // Main row loop: Column loop - "whilelt p0.h, XZR, x25\n" - "decw x25, ALL, MUL #2\n" - "ld1h { z20.h }, p0/Z, [x9]\n" - "cmp x25, #0x0\n" - "addvl x9, x9, #1\n" - "ld1h { z17.h }, p0/Z, [x27]\n" - "addvl x27, x27, #1\n" - "ld1h { z19.h }, p0/Z, [x26]\n" - "addvl x26, x26, #1\n" - "ld1h { z16.h }, p0/Z, [x24]\n" - "addvl x24, x24, #1\n" - "ld1h { z18.h }, p0/Z, [x23]\n" - "addvl
x23, x23, #1\n" - "zip1 z24.h, z20.h, z17.h\n" - "zip2 z23.h, z20.h, z17.h\n" - "ld1h { z17.h }, p0/Z, [x22]\n" - "addvl x22, x22, #1\n" - "ld1h { z22.h }, p0/Z, [x21]\n" - "addvl x21, x21, #1\n" - "zip1 z21.h, z19.h, z16.h\n" - "zip2 z20.h, z19.h, z16.h\n" - "ld1h { z16.h }, p0/Z, [x20]\n" - "addvl x20, x20, #1\n" - "zip1 z19.h, z18.h, z17.h\n" - "zip2 z18.h, z18.h, z17.h\n" - "st1h { z24.h }, p1, [x28]\n" - "st1h { z23.h }, p1, [x28, #1, MUL VL]\n" - "zip1 z17.h, z22.h, z16.h\n" - "zip2 z16.h, z22.h, z16.h\n" - "st1h { z21.h }, p1, [x28, #2, MUL VL]\n" - "st1h { z20.h }, p1, [x28, #3, MUL VL]\n" - "st1h { z19.h }, p1, [x28, #4, MUL VL]\n" - "st1h { z18.h }, p1, [x28, #5, MUL VL]\n" - "st1h { z17.h }, p1, [x28, #6, MUL VL]\n" - "st1h { z16.h }, p1, [x28, #7, MUL VL]\n" - "add x28, x28, %x[out_stride]\n" - "bgt 4b\n" - "cmp x10, #0x8\n" - "addvl %x[out], %x[out], #8\n" - "bge 3b\n" - "cbz x10, 10f\n" - "6:" // Main loop skip - "7:" // Tail row loop: Head - "mov x9, %x[in]\n" - "cntw x22, ALL, MUL #4\n" - "add x27, x9, %x[in_stride]\n" - "cmp x10, #0x1\n" - "add %x[in], x27, %x[in_stride]\n" - "mov x28, %x[out]\n" - "csel %x[in], %x[in], x27, GT\n" - "csel x27, x27, %x[pad_row], GT\n" - "csel x21, x22, XZR, GT\n" - "sub x10, x10, #0x2\n" - "mov x20, %x[width]\n" - "8:" // Tail row loop: Column loop - "whilelt p0.h, XZR, x20\n" - "decw x20, ALL, MUL #2\n" - "ld1h { z18.h }, p0/Z, [x9]\n" - "cmp x20, #0x0\n" - "add x9, x9, x22\n" - "ld1h { z16.h }, p0/Z, [x27]\n" - "add x27, x27, x21\n" - "zip1 z17.h, z18.h, z16.h\n" - "zip2 z16.h, z18.h, z16.h\n" - "st1h { z17.h }, p1, [x28]\n" - "st1h { z16.h }, p1, [x28, #1, MUL VL]\n" - "add x28, x28, %x[out_stride]\n" - "bgt 8b\n" - "cmp x10, #0x1\n" - "addvl %x[out], %x[out], #2\n" - "bge 7b\n" - "10:" // Done - "sub x11, x11, #0x1\n" - "cbnz x11, 2b\n" - ".inst 0xd503467f // SMSTOP\n" - : [bias] "+&r"(bias), [in] "+&r"(in), [out] "+&r"(out) - : [height] "r"(height), [in_stride] "r"(in_stride), [k_chunk_count]
"r"(k_chunk_count), - [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", - "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", - "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); -} + args.pad_row = pad_row; -#endif // Architectural features check. + kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(&args); // out-of-line routine: the removed clobber list (incl. z8-z15) no longer applies, so the routine itself is responsible for preserving d8-d15 +} diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h index ebf1aec2..9dd33d72 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h @@ -16,7 +16,7 @@ extern "C" { /// /// The starting column index must be divisible by `n_step`. /// -/// @return Step size for column index. +/// @return The n step value, i.e. the step size for the column index. size_t kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. @@ -65,12 +65,12 @@ size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( /// @param[in] n Number of columns of the output matrix. /// @param[in] k_chunk_count Number of chunks. /// @param[in] k_chunk_length Number of rows in each chunk. -/// @param[in] rhs_row_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs_stride_row Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. /// @param[out] rhs_packed Packed RHS matrix.
void kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( - size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias, void* rhs_packed); #ifdef __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S new file mode 100644 index 00000000..d6230333 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S @@ -0,0 +1,193 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) +
+// Entry point called from the C wrapper with x0 = const KernelArgs*. The fields are read at +// fixed byte offsets, so the structure layout and these offsets must be kept in sync: +// 0x00 bias_ptr, 0x08 width, 0x10 height, 0x18 k_chunk_count, 0x20 in_stride, +// 0x28 out_stride, 0x30 in, 0x38 out, 0x40 pad_row +// +// The routine has a normal (non-streaming) AAPCS64 interface but runs its body in streaming +// mode. Changing PSTATE.SM zeroes the vector registers, so the callee-saved d8-d15 are spilled +// before SMSTART and reloaded after SMSTOP, alongside the callee-saved x20-x28. +// +// Frame (144 bytes): x20-x28 at [sp, 0..71], d8-d15 at [sp, 80..143]. +KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 80] // AAPCS64 callee-saved; would otherwise be zeroed by the mode switch + stp d10, d11, [sp, 96] + stp d12, d13, [sp, 112] + stp d14, d15, [sp, 128] + KAI_ASM_INST(0xd503477f) // SMSTART (enables streaming mode and ZA) + ldr x8, [x0, #0x8] // width + ptrue p1.b + ldr x17, [x0, #0x38] // out + ldr x23, [x0, #0x0] // bias_ptr + ldr x16, [x0, #0x10] // height + mov x22, x8 + ldr x21, [x0, #0x18] // k_chunk_count + mov x20, x17 + ldr x15, [x0, #0x20] // in_stride + ldr x14, [x0, #0x28] // out_stride + ldr x13, [x0, #0x30] // in + ldr x12, [x0, #0x40] // pad_row +KAI_ASM_LABEL(label_1) // Bias: Full loop + whilelt p0.h, XZR, x22 + dech x22 + cmp x22, #0x0 + ld1h { z16.h }, p0/Z, [x23] + incb x23 + st1h { z16.h }, p1, [x20] + add x20, x20, x14 + bgt label_1 + incb x17 // packed data starts one vector after the bias vector written above + mov x11, x21 +KAI_ASM_LABEL(label_2) // Chunk Loop + mov x10, x16 + cmp x10, #0x8 + blt label_6 +KAI_ASM_LABEL(label_3) // Main row loop: Head + mov x9, x13 + mov x28, x17 + add x27, x9, x15 + sub x10, x10, #0x8 + add x26, x27, x15 + mov x25, x8 + add x24, x26, x15 + add x23, x24, x15 + add x22, x23, x15 + add x21, x22, x15 + add x20, x21, x15 + add x13, x20, x15 +KAI_ASM_LABEL(label_4) // Main row loop: Column loop + whilelt p0.h, XZR, x25 + decw x25, ALL, MUL #2 + ld1h { z20.h }, p0/Z, [x9] + cmp x25, #0x0 + addvl x9, x9, #1 + ld1h { z17.h }, p0/Z, [x27] + addvl x27, x27, #1 + ld1h { z19.h }, p0/Z, [x26] + addvl x26, x26, #1 + ld1h { z16.h }, p0/Z, [x24] + addvl x24, x24, #1 + ld1h { z18.h }, p0/Z, [x23] + addvl x23, x23, #1 + zip1 z24.h, z20.h, z17.h + zip2 z23.h, z20.h, z17.h + ld1h { z17.h }, p0/Z, [x22] + addvl x22, x22, #1 + ld1h { z22.h }, p0/Z, [x21] + addvl x21, x21, #1 + zip1 z21.h, z19.h, z16.h + zip2 z20.h, z19.h, z16.h + ld1h { z16.h }, p0/Z, [x20] + addvl x20, x20, #1 + zip1 z19.h, z18.h, z17.h + zip2 z18.h, z18.h, z17.h + st1h { z24.h }, p1, [x28] + st1h { z23.h }, p1, [x28, #1, MUL VL] + zip1 z17.h, z22.h, z16.h + zip2 z16.h, z22.h, z16.h + st1h { z21.h }, p1, [x28, #2, MUL VL] + st1h
{ z20.h }, p1, [x28, #3, MUL VL] + st1h { z19.h }, p1, [x28, #4, MUL VL] + st1h { z18.h }, p1, [x28, #5, MUL VL] + st1h { z17.h }, p1, [x28, #6, MUL VL] + st1h { z16.h }, p1, [x28, #7, MUL VL] + add x28, x28, x14 + bgt label_4 + cmp x10, #0x8 + addvl x17, x17, #8 + bge label_3 + cbz x10, label_10 +KAI_ASM_LABEL(label_6) // Main loop skip +KAI_ASM_LABEL(label_7) // Tail row loop: Head + mov x9, x13 + cntw x22, ALL, MUL #4 + add x27, x9, x15 + cmp x10, #0x1 + add x13, x27, x15 + mov x28, x17 + csel x13, x13, x27, GT + csel x27, x27, x12, GT // single row left: read the second row of the pair from pad_row + csel x21, x22, XZR, GT // ...and keep that pointer fixed instead of advancing it + sub x10, x10, #0x2 + mov x20, x8 +KAI_ASM_LABEL(label_8) // Tail row loop: Column loop + whilelt p0.h, XZR, x20 + decw x20, ALL, MUL #2 + ld1h { z18.h }, p0/Z, [x9] + cmp x20, #0x0 + add x9, x9, x22 + ld1h { z16.h }, p0/Z, [x27] + add x27, x27, x21 + zip1 z17.h, z18.h, z16.h + zip2 z16.h, z18.h, z16.h + st1h { z17.h }, p1, [x28] + st1h { z16.h }, p1, [x28, #1, MUL VL] + add x28, x28, x14 + bgt label_8 + cmp x10, #0x1 + addvl x17, x17, #2 + bge label_7 +KAI_ASM_LABEL(label_10) // Done + sub x11, x11, #0x1 + cbnz x11, label_2 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp d8, d9, [sp, 80] // reload only after SMSTOP: leaving streaming mode zeroes the vector registers + ldp d10, d11, [sp, 96] + ldp d12, d13, [sp, 112] + ldp d14, d15, [sp, 128] + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c index 46c626f1..d56d329b 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c @@ -3,13 +3,6 @@ // // SPDX-License-Identifier: Apache-2.0 // - -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) -#error This file must be compiled for AArch64, FEAT_SVE2.
-#else // Architectural features check. #include "kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h" #include @@ -17,11 +10,27 @@ #include "kai/kai_common.h" -#define NR 2 -#define KR 1 +enum { + NR = 2, + KR = 1, +}; + +typedef struct { // read by the assembly kernel at fixed offsets from x0 (8-byte fields); keep the field order in sync + const void* bias_ptr; // 0x00 + size_t width; // 0x08 + size_t height; // 0x10 + size_t k_chunk_count; // 0x18 + size_t in_stride; // 0x20 + size_t out_stride; // 0x28 + const void* in; // 0x30 + void* out; // 0x38 +} KernelArgs; + static const size_t kai_num_bytes_input = sizeof(uint32_t); static const size_t kai_num_bytes_output = sizeof(uint32_t); -static const size_t kai_num_bytes_bias = sizeof(uint32_t); +static const size_t kai_num_bytes_bias = sizeof(float); + +void kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(const KernelArgs* args_ptr); size_t kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(void) { return NR * kai_get_sme_vector_length_u32() / KR; @@ -54,129 +63,28 @@ size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( size_t n, size_t k_chunk_count, size_t k_chunk_length) { - const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme()); + const size_t n_nr_blocks = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme()); // NOTE(review): result of kai_roundup(n, n_step), presumably a rounded-up n rather than a block count - confirm naming return kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( - n_rounded_up, k_chunk_count, k_chunk_length); + n_nr_blocks, k_chunk_count, k_chunk_length); } void kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( - size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias, void* rhs_packed) { KAI_ASSUME(rhs != NULL); KAI_ASSUME(bias != NULL); KAI_ASSUME(rhs_packed != NULL); - size_t height = k_chunk_length; - const size_t width = n; - const void* in = rhs; - void* out = rhs_packed; - const size_t
in_stride = rhs_row_stride; - - size_t out_stride = + KernelArgs args; + args.bias_ptr = bias; + args.height = k_chunk_length; + args.width = n; + args.in = rhs; + args.out = rhs_packed; + args.k_chunk_count = k_chunk_count; + args.in_stride = rhs_stride_row; + args.out_stride = kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(k_chunk_count, k_chunk_length); - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x22, %x[out]\n" - "mov x21, %x[width]\n" - "ptrue p2.b\n" - "1:" // Bias: Full loop - "mov x20, x21\n" - "decw x21, ALL, MUL #2\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "cmp x21, #0x0\n" - "ld1w { z17.s }, p1/Z, [%x[bias]]\n" - "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" - "incb %x[bias], ALL, MUL #2\n" - "st1w { z17.s }, p2, [x22]\n" - "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" - "add x22, x22, %x[out_stride]\n" - "bgt 1b\n" - "incb %x[out], ALL, MUL #2\n" - "mov x28, %x[k_chunk_count]\n" - "2:" // Chunk Loop - "mov x27, %x[height]\n" - "cmp x27, #0x4\n" - "blt 6f\n" - "3:" // Main row loop: Head - "mov x26, %x[in]\n" - "mov x25, %x[out]\n" - "add x24, x26, %x[in_stride]\n" - "sub x27, x27, #0x4\n" - "add x23, x24, %x[in_stride]\n" - "mov x22, %x[width]\n" - "add x21, x23, %x[in_stride]\n" - "add %x[in], x21, %x[in_stride]\n" - "4:" // Main row loop: Column loop - "mov x20, x22\n" - "decw x22, ALL, MUL #2\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "cmp x22, #0x0\n" - "ld1w { z23.s }, p1/Z, [x26]\n" - "ld1w { z22.s }, p0/Z, [x26, #1, MUL VL]\n" - "addvl x26, x26, #2\n" - "ld1w { z21.s }, p1/Z, [x24]\n" - "ld1w { z20.s }, p0/Z, [x24, #1, MUL VL]\n" - "addvl x24, x24, #2\n" - "ld1w { z19.s }, p1/Z, [x23]\n" - "ld1w { z18.s }, p0/Z, [x23, #1, MUL VL]\n" - "addvl x23, x23, #2\n" - "ld1w { z17.s }, p1/Z, [x21]\n" - "ld1w { z16.s }, p0/Z, [x21, #1, MUL VL]\n" - "addvl x21, x21, #2\n" - "st1w { z23.s }, p2, [x25]\n" - "st1w { z22.s }, p2, [x25, #1, MUL
VL]\n" - "st1w { z21.s }, p2, [x25, #2, MUL VL]\n" - "st1w { z20.s }, p2, [x25, #3, MUL VL]\n" - "st1w { z19.s }, p2, [x25, #4, MUL VL]\n" - "st1w { z18.s }, p2, [x25, #5, MUL VL]\n" - "st1w { z17.s }, p2, [x25, #6, MUL VL]\n" - "st1w { z16.s }, p2, [x25, #7, MUL VL]\n" - "add x25, x25, %x[out_stride]\n" - "bgt 4b\n" - "cmp x27, #0x4\n" - "addvl %x[out], %x[out], #8\n" - "bge 3b\n" - "cbz x27, 10f\n" - "6:" // Main loop skip - "7:" // Tail row loop: Head - "mov x26, %x[in]\n" - "cntw x22, ALL, MUL #8\n" - "add %x[in], x26, %x[in_stride]\n" - "mov x25, %x[out]\n" - "sub x27, x27, #0x1\n" - "mov x21, %x[width]\n" - "8:" // Tail row loop: Column loop - "mov x20, x21\n" - "decw x21, ALL, MUL #2\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "cmp x21, #0x0\n" - "ld1w { z17.s }, p1/Z, [x26]\n" - "ld1w { z16.s }, p0/Z, [x26, #1, MUL VL]\n" - "add x26, x26, x22\n" - "st1w { z17.s }, p2, [x25]\n" - "st1w { z16.s }, p2, [x25, #1, MUL VL]\n" - "add x25, x25, %x[out_stride]\n" - "bgt 8b\n" - "cmp x27, #0x1\n" - "addvl %x[out], %x[out], #2\n" - "bge 7b\n" - "10:" // Done - "sub x28, x28, #0x1\n" - "cbnz x28, 2b\n" - ".inst 0xd503467f // SMSTOP\n" - : [bias] "+&r"(bias), [in] "+&r"(in), [out] "+&r"(out) - : [height] "r"(height), [in_stride] "r"(in_stride), [k_chunk_count] "r"(k_chunk_count), - [out_stride] "r"(out_stride), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z10", "z11", "z12", - "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", - "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); -} -#endif // Architectural features check.
+ kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(&args); // out-of-line routine: the removed clobber list (incl. z8-z15) no longer applies, so the routine itself is responsible for preserving d8-d15 +} diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h index ea16c9df..4af70ce1 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h @@ -16,7 +16,7 @@ extern "C" { /// /// The starting column index must be divisible by `n_step`. /// -/// @return Step size for column index. +/// @return The n step value, i.e. the step size for the column index. size_t kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. @@ -65,12 +65,12 @@ size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( /// @param[in] n Number of columns of the output matrix. /// @param[in] k_chunk_count Number of chunks. /// @param[in] k_chunk_length Number of rows in each chunk. -/// @param[in] rhs_row_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs_stride_row Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. /// @param[out] rhs_packed Packed RHS matrix.
void kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( - size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias, void* rhs_packed); #ifdef __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S new file mode 100644 index 00000000..9fa37cfb --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S @@ -0,0 +1,161 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) + 
+KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) + stp x20, x21, [sp, -80]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + ldr x14, [x0, #0x8] + ptrue p2.b + ldr x13, [x0, #0x38] + ldr x24, [x0, #0x0] + ldr x12, [x0, #0x10] + mov x23, x14 + ldr x22, [x0, #0x18] + mov x21, x13 + ldr x11, [x0, #0x20] + ldr x10, [x0, #0x28] + ldr x9, [x0, #0x30] +KAI_ASM_LABEL(label_1) // Bias: Full loop + mov x20, x23 + decw x23, ALL, MUL #2 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + cmp x23, #0x0 + ld1w { z17.s }, p1/Z, [x24] + ld1w { z16.s }, p0/Z, [x24, #1, MUL VL] + incb x24, ALL, MUL #2 + st1w { z17.s }, p2, [x21] + st1w { z16.s }, p2, [x21, #1, MUL VL] + add x21, x21, x10 + bgt label_1 + incb x13, ALL, MUL #2 + mov x28, x22 +KAI_ASM_LABEL(label_2) // Chunk Loop + mov x27, x12 + cmp x27, #0x4 + blt label_6 +KAI_ASM_LABEL(label_3) // Main row loop: Head + mov x26, x9 + mov x25, x13 + add x24, x26, x11 + sub x27, x27, #0x4 + add x23, x24, x11 + mov x22, x14 + add x21, x23, x11 + add x9, x21, x11 +KAI_ASM_LABEL(label_4) // Main row loop: Column loop + mov x20, x22 + decw x22, ALL, MUL #2 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + cmp x22, #0x0 + ld1w { z23.s }, p1/Z, [x26] + ld1w { z22.s }, p0/Z, [x26, #1, MUL VL] + addvl x26, x26, #2 + ld1w { z21.s }, p1/Z, [x24] + ld1w { z20.s }, p0/Z, [x24, #1, MUL VL] + addvl x24, x24, #2 + ld1w { z19.s }, p1/Z, [x23] + ld1w { z18.s }, p0/Z, [x23, #1, MUL VL] + addvl x23, x23, #2 + ld1w { z17.s }, p1/Z, [x21] + ld1w { z16.s }, p0/Z, [x21, #1, MUL VL] + addvl x21, x21, #2 + st1w { z23.s }, p2, [x25] + st1w { z22.s }, p2, [x25, #1, MUL VL] + st1w { z21.s }, p2, [x25, #2, MUL VL] + st1w { z20.s }, p2, [x25, #3, MUL VL] + st1w { z19.s }, p2, [x25, #4, MUL VL] + st1w { z18.s }, p2, [x25, #5, MUL VL] + st1w { 
z17.s }, p2, [x25, #6, MUL VL] + st1w { z16.s }, p2, [x25, #7, MUL VL] + add x25, x25, x10 + bgt label_4 + cmp x27, #0x4 + addvl x13, x13, #8 + bge label_3 + cbz x27, label_10 +KAI_ASM_LABEL(label_6) // Main loop skip +KAI_ASM_LABEL(label_7) // Tail row loop: Head + mov x26, x9 + cntw x22, ALL, MUL #8 + add x9, x26, x11 + mov x25, x13 + sub x27, x27, #0x1 + mov x21, x14 +KAI_ASM_LABEL(label_8) // Tail row loop: Column loop + mov x20, x21 + decw x21, ALL, MUL #2 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + cmp x21, #0x0 + ld1w { z17.s }, p1/Z, [x26] + ld1w { z16.s }, p0/Z, [x26, #1, MUL VL] + add x26, x26, x22 + st1w { z17.s }, p2, [x25] + st1w { z16.s }, p2, [x25, #1, MUL VL] + add x25, x25, x10 + bgt label_8 + cmp x27, #0x1 + addvl x13, x13, #2 + bge label_7 +KAI_ASM_LABEL(label_10) // Done + sub x28, x28, #0x1 + cbnz x28, label_2 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) + + KAI_ASM_END -- GitLab From 9a24e25d6462a6dca97293ed420de5d6e656303c Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Thu, 15 May 2025 20:20:03 +0200 Subject: [PATCH 2/9] Add changelog for pure assembly imatmul Signed-off-by: Emil Ohlsson --- CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9f740d3..24727706 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,17 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- Convert SME and SME2 imatmul micro-kernels to use pure assembly. 
Affects: + - kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa + - kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa + - kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa + - kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme + - kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme + - kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme + - kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme + - kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme + - kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme + ## v1.8.0 - New Advanced SIMD micro-kernels: -- GitLab From bd175e809cc3598b7b0c016ce2ea35bfe849b45e Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Mon, 26 May 2025 08:00:19 +0200 Subject: [PATCH 3/9] Split CMakeLists.txt and align dst_stride_row Signed-off-by: Emil Ohlsson --- CMakeLists.txt | 22 ++++++++++++++----- ...ai_imatmul_clamp_f16_f16p_f16p_interface.h | 2 +- ...ai_imatmul_clamp_f32_f32p_f32p_interface.h | 2 +- ...atmul_clamp_qai8_qai8p_qsi8cxp_interface.h | 2 +- test/tests/imatmul_test.cpp | 2 +- .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 2 +- 6 files changed, 21 insertions(+), 11 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1627d694..1c7f1353 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -221,22 +221,26 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c ) -set(KLEIDIAI_FILES_SME +set(KLEIDIAI_FILES_SME_ASM kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S - kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c 
- kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S +) + +set(KLEIDIAI_FILES_SME + ${KLEIDIAI_FILES_SME_ASM} + kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -245,13 +249,17 @@ set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c ) -set(KLEIDIAI_FILES_SME2 +set(KLEIDIAI_FILES_SME2_ASM kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S +) + +set(KLEIDIAI_FILES_SME2 + ${KLEIDIAI_FILES_SME2_ASM} 
kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c @@ -302,8 +310,10 @@ if(NOT MSVC) else() target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM_ASM}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME_ASM}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2_ASM}) - set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h index bbc2b318..b2eef490 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h @@ -27,7 +27,7 @@ typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_dst_size_func_t)(size_t m, /// Micro-kernel core function ("run" method) typedef void (*kai_imatmul_clamp_f16_f16p_f16p_run_imatmul_func_t)( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + void* dst, size_t dst_stride_row, float 
clamp_min, float clamp_max); /// Micro-kernel interface struct kai_imatmul_clamp_f16_f16p_f16p_ukernel { diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h index 6e629274..58440f4b 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h @@ -27,7 +27,7 @@ typedef size_t (*kai_imatmul_clamp_f32_f32p_f32p_get_dst_size_func_t)(size_t m, /// Micro-kernel core function ("run" method) typedef void (*kai_imatmul_clamp_f32_f32p_f32p_run_imatmul_func_t)( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + void* dst, size_t dst_stride_row, float clamp_min, float clamp_max); /// Micro-kernel interface struct kai_imatmul_clamp_f32_f32p_f32p_ukernel { diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h index 84ca66b1..a2527662 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h @@ -28,7 +28,7 @@ typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_size_func_t)(size_ /// Micro-kernel core function ("run" method) typedef void (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_imatmul_func_t)( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, - void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params); + void* dst, size_t dst_stride_row, 
const struct kai_matmul_requantize32_params* params); /// Micro-kernel interface struct kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel { diff --git a/test/tests/imatmul_test.cpp b/test/tests/imatmul_test.cpp index dd12fd4f..eb770811 100644 --- a/test/tests/imatmul_test.cpp +++ b/test/tests/imatmul_test.cpp @@ -76,7 +76,7 @@ struct MatMulIndirectKernel { std::function get_kr; std::function get_lhs_packed_offset; std::function get_rhs_packed_offset; - std::function get_dst_offset; + std::function get_dst_offset; std::function get_dst_size; std::function get_n_step; std::function get_packed_lhs_offset; std::function get_packed_rhs_offset; - std::function get_dst_offset; + std::function get_dst_offset; std::function get_dst_size; std::function Date: Mon, 26 May 2025 08:13:16 +0200 Subject: [PATCH 4/9] Revert accidental MSVC enabling Signed-off-by: Emil Ohlsson --- CMakeLists.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1c7f1353..127a03d4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -310,10 +310,8 @@ if(NOT MSVC) else() target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM_ASM}) - target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME_ASM}) - target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2_ASM}) - set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) -- GitLab From 88a1e8e385c0441c483bd1d075718c2b8f5b7dd1 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Mon, 19 
May 2025 11:50:44 +0200 Subject: [PATCH 5/9] Add MSVC support for imatmul kernels This change enables MSVC for imatmul kernels. In doing so there is an addition of function converting `float` to `fp16`, but represented as a `uint16_t`. Signed-off-by: Emil Ohlsson --- CHANGELOG.md | 2 +- CMakeLists.txt | 4 +- ...16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c | 9 +- ...16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S | 166 +++++----- ...2p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S | 162 +++++----- ...lx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S | 286 +++++++++--------- .../kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c | 2 +- ..._lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S | 12 +- .../kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c | 2 +- ..._lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S | 12 +- .../kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c | 3 +- ...ai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S | 12 +- ...kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S | 12 +- ...tmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S | 12 +- ...tmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S | 12 +- 15 files changed, 383 insertions(+), 325 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a1e3921e..152cf1ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release -- Convert SME and SME2 imatmul micro-kernels to use pure assembly. Affects: +- Convert SME and SME2 imatmul micro-kernels to use pure assembly, and add MSVC support. 
Affects: - kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa - kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa - kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa diff --git a/CMakeLists.txt b/CMakeLists.txt index 91bc5943..bcb7049c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -314,8 +314,10 @@ if(NOT MSVC) else() target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM_ASM}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME_ASM}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2_ASM}) - set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c index cc2e7e41..904268ff 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c @@ -19,8 +19,8 @@ typedef struct { uint64_t M; uint64_t N; uint64_t K; - float16_t min; - float16_t max; + uint16_t min; + uint16_t max; void* accumulator_buffer; uint64_t flags; } KernelArgs; @@ -30,6 +30,7 @@ static const size_t kai_nr = 2; static const size_t kai_kr = 2; void 
kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(KernelArgs* args); +uint16_t kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(float value); // Returns a constant value specific to this kernel that's relative to vector length static size_t kai_get_kernel_vec_length_constant(void) { @@ -90,8 +91,8 @@ void kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( args.M = m; args.N = n; args.K = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); - args.min = (float16_t)clamp_min; - args.max = (float16_t)clamp_max; + args.min = kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(clamp_min); + args.max = kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(clamp_max); args.accumulator_buffer = NULL; args.flags = 0; diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S index 24d69bf5..259e07f9 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S @@ -39,101 +39,105 @@ KAI_ASM_ALIGN KAI_ASM_GLOBAL(kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + KAI_ASM_GLOBAL(kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + fcvt h0, s0 + fmov w0, h0 + ret + KAI_ASM_FUNCTION_END(kai_f16_from_float_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) KAI_ASM_FUNCTION_TYPE(kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) 
KAI_ASM_FUNCTION_LABEL(kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA - mov x13, #0x0 + mov x14, #0x0 + ldr x13, [x0, #0x30] ptrue p1.b KAI_ASM_INST(0x25207810) // ptrue pn8.b ldr w11, [x0, #0x20] - ldr w10, [x0, #0x28] - mov x9, #0x0 + mov x10, #0x0 + ldr w9, [x0, #0x28] + add x13, x13, #0x1 ldr x28, [x0, #0x0] + lsr x13, x13, #0x1 KAI_ASM_LABEL(label_1) // M loop ldr x27, [x0, #0x8] KAI_ASM_LABEL(label_2) // N loop - fmov z24.h, #0.0 - ld1h { z5.h }, p1/Z, [x27] - fmov z27.h, #1.0 + fmov z23.h, #0.0 + ld1h { z18.h }, p1/Z, [x27] + fmov z2.h, #1.0 mov x26, x28 KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } inch x27, ALL, MUL #2 - zip1 z30.h, z5.h, z24.h - zip2 z20.h, z5.h, z24.h - KAI_ASM_INST(0x81be2760) // fmopa za0.s, p1/M, p1/M, z27.h, z30.h - KAI_ASM_INST(0x81b42761) // fmopa za1.s, p1/M, p1/M, z27.h, z20.h - KAI_ASM_INST(0x81be2762) // fmopa za2.s, p1/M, p1/M, z27.h, z30.h - KAI_ASM_INST(0x81b42763) // fmopa za3.s, p1/M, p1/M, z27.h, z20.h - ldr x20, [x0, #0x30] - add x20, x20, #0x1 - lsr x20, x20, #0x1 - lsr x21, x20, #0x2 - and x20, x20, #0x3 + zip1 z14.h, z18.h, z23.h + zip2 z3.h, z18.h, z23.h + KAI_ASM_INST(0x81ae2440) // fmopa za0.s, p1/M, p1/M, z2.h, z14.h + KAI_ASM_INST(0x81a32441) // fmopa za1.s, p1/M, p1/M, z2.h, z3.h + KAI_ASM_INST(0x81ae2442) // fmopa za2.s, p1/M, p1/M, z2.h, z14.h + KAI_ASM_INST(0x81a32443) // fmopa za3.s, p1/M, p1/M, z2.h, z3.h + lsr x21, x13, #0x2 + and x20, x13, #0x3 cbz x21, label_6 subs x21, x21, #0x1 - KAI_ASM_INST(0xa0402352) // ld1h { z18.h-z19.h }, pn8.b/Z, [x26] - KAI_ASM_INST(0xa0402370) // ld1h { z16.h-z17.h }, pn8.b/Z, [x27] - KAI_ASM_INST(0xa1412342) // ld1h { z2.h, z10.h }, pn8.b/Z, 
[x26, #0x2, MUL VL] - KAI_ASM_INST(0xa041237e) // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL] - KAI_ASM_INST(0xa042235c) // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL] - KAI_ASM_INST(0xa1422366) // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL] - KAI_ASM_INST(0xa1432345) // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL] + KAI_ASM_INST(0xa040a350) // ld1h { z16.h-z19.h }, pn8.b/Z, [x26] + KAI_ASM_INST(0xa041a35c) // ld1h { z28.h-z31.h }, pn8.b/Z, [x26, #0x4, MUL VL] addvl x26, x26, #8 - KAI_ASM_INST(0xa1432367) // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL] + KAI_ASM_INST(0xa040a360) // ld1h { z0.h-z3.h }, pn8.b/Z, [x27] + KAI_ASM_INST(0xa041a368) // ld1h { z8.h-z11.h }, pn8.b/Z, [x27, #0x4, MUL VL] addvl x27, x27, #8 ble label_5 KAI_ASM_LABEL(label_4) // K loop - KAI_ASM_INST(0x81b02640) // fmopa za0.s, p1/M, p1/M, z18.h, z16.h + KAI_ASM_INST(0x81a02600) // fmopa za0.s, p1/M, p1/M, z16.h, z0.h subs x21, x21, #0x1 - KAI_ASM_INST(0x81b12641) // fmopa za1.s, p1/M, p1/M, z18.h, z17.h - KAI_ASM_INST(0x81b02662) // fmopa za2.s, p1/M, p1/M, z19.h, z16.h - KAI_ASM_INST(0x81b12663) // fmopa za3.s, p1/M, p1/M, z19.h, z17.h - KAI_ASM_INST(0xa0402352) // ld1h { z18.h-z19.h }, pn8.b/Z, [x26] - KAI_ASM_INST(0x81be2440) // fmopa za0.s, p1/M, p1/M, z2.h, z30.h - KAI_ASM_INST(0xa0402370) // ld1h { z16.h-z17.h }, pn8.b/Z, [x27] - KAI_ASM_INST(0x81bf2441) // fmopa za1.s, p1/M, p1/M, z2.h, z31.h - KAI_ASM_INST(0x81be2542) // fmopa za2.s, p1/M, p1/M, z10.h, z30.h - KAI_ASM_INST(0x81bf2543) // fmopa za3.s, p1/M, p1/M, z10.h, z31.h - KAI_ASM_INST(0xa1412342) // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL] - KAI_ASM_INST(0x81a62780) // fmopa za0.s, p1/M, p1/M, z28.h, z6.h - KAI_ASM_INST(0xa041237e) // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL] - KAI_ASM_INST(0x81ae2781) // fmopa za1.s, p1/M, p1/M, z28.h, z14.h - KAI_ASM_INST(0x81a627a2) // fmopa za2.s, p1/M, p1/M, z29.h, z6.h - KAI_ASM_INST(0x81ae27a3) // fmopa za3.s, p1/M, p1/M, z29.h, z14.h - 
KAI_ASM_INST(0xa042235c) // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL] - KAI_ASM_INST(0xa1422366) // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL] - KAI_ASM_INST(0x81a724a0) // fmopa za0.s, p1/M, p1/M, z5.h, z7.h - KAI_ASM_INST(0x81af24a1) // fmopa za1.s, p1/M, p1/M, z5.h, z15.h - KAI_ASM_INST(0x81a725a2) // fmopa za2.s, p1/M, p1/M, z13.h, z7.h - KAI_ASM_INST(0x81af25a3) // fmopa za3.s, p1/M, p1/M, z13.h, z15.h - KAI_ASM_INST(0xa1432345) // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL] + KAI_ASM_INST(0x81a12601) // fmopa za1.s, p1/M, p1/M, z16.h, z1.h + KAI_ASM_INST(0x81a02622) // fmopa za2.s, p1/M, p1/M, z17.h, z0.h + KAI_ASM_INST(0x81a12623) // fmopa za3.s, p1/M, p1/M, z17.h, z1.h + KAI_ASM_INST(0x81a22640) // fmopa za0.s, p1/M, p1/M, z18.h, z2.h + KAI_ASM_INST(0x81a32641) // fmopa za1.s, p1/M, p1/M, z18.h, z3.h + KAI_ASM_INST(0x81a22662) // fmopa za2.s, p1/M, p1/M, z19.h, z2.h + KAI_ASM_INST(0x81a32663) // fmopa za3.s, p1/M, p1/M, z19.h, z3.h + KAI_ASM_INST(0xa040a350) // ld1h { z16.h-z19.h }, pn8.b/Z, [x26] + KAI_ASM_INST(0x81a82780) // fmopa za0.s, p1/M, p1/M, z28.h, z8.h + KAI_ASM_INST(0xa040a360) // ld1h { z0.h-z3.h }, pn8.b/Z, [x27] + KAI_ASM_INST(0x81a92781) // fmopa za1.s, p1/M, p1/M, z28.h, z9.h + KAI_ASM_INST(0x81a827a2) // fmopa za2.s, p1/M, p1/M, z29.h, z8.h + KAI_ASM_INST(0x81a927a3) // fmopa za3.s, p1/M, p1/M, z29.h, z9.h + KAI_ASM_INST(0x81aa27c0) // fmopa za0.s, p1/M, p1/M, z30.h, z10.h + KAI_ASM_INST(0x81ab27c1) // fmopa za1.s, p1/M, p1/M, z30.h, z11.h + KAI_ASM_INST(0x81aa27e2) // fmopa za2.s, p1/M, p1/M, z31.h, z10.h + KAI_ASM_INST(0x81ab27e3) // fmopa za3.s, p1/M, p1/M, z31.h, z11.h + KAI_ASM_INST(0xa041a35c) // ld1h { z28.h-z31.h }, pn8.b/Z, [x26, #0x4, MUL VL] addvl x26, x26, #8 - KAI_ASM_INST(0xa1432367) // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL] + KAI_ASM_INST(0xa041a368) // ld1h { z8.h-z11.h }, pn8.b/Z, [x27, #0x4, MUL VL] addvl x27, x27, #8 bgt label_4 KAI_ASM_LABEL(label_5) // K loop tail - 
KAI_ASM_INST(0x81b02640) // fmopa za0.s, p1/M, p1/M, z18.h, z16.h - KAI_ASM_INST(0x81b12641) // fmopa za1.s, p1/M, p1/M, z18.h, z17.h - KAI_ASM_INST(0x81b02662) // fmopa za2.s, p1/M, p1/M, z19.h, z16.h - KAI_ASM_INST(0x81b12663) // fmopa za3.s, p1/M, p1/M, z19.h, z17.h - KAI_ASM_INST(0x81be2440) // fmopa za0.s, p1/M, p1/M, z2.h, z30.h - KAI_ASM_INST(0x81bf2441) // fmopa za1.s, p1/M, p1/M, z2.h, z31.h - KAI_ASM_INST(0x81be2542) // fmopa za2.s, p1/M, p1/M, z10.h, z30.h - KAI_ASM_INST(0x81bf2543) // fmopa za3.s, p1/M, p1/M, z10.h, z31.h - KAI_ASM_INST(0x81a62780) // fmopa za0.s, p1/M, p1/M, z28.h, z6.h - KAI_ASM_INST(0x81ae2781) // fmopa za1.s, p1/M, p1/M, z28.h, z14.h - KAI_ASM_INST(0x81a627a2) // fmopa za2.s, p1/M, p1/M, z29.h, z6.h - KAI_ASM_INST(0x81ae27a3) // fmopa za3.s, p1/M, p1/M, z29.h, z14.h - KAI_ASM_INST(0x81a724a0) // fmopa za0.s, p1/M, p1/M, z5.h, z7.h - KAI_ASM_INST(0x81af24a1) // fmopa za1.s, p1/M, p1/M, z5.h, z15.h - KAI_ASM_INST(0x81a725a2) // fmopa za2.s, p1/M, p1/M, z13.h, z7.h - KAI_ASM_INST(0x81af25a3) // fmopa za3.s, p1/M, p1/M, z13.h, z15.h + KAI_ASM_INST(0x81a02600) // fmopa za0.s, p1/M, p1/M, z16.h, z0.h + KAI_ASM_INST(0x81a12601) // fmopa za1.s, p1/M, p1/M, z16.h, z1.h + KAI_ASM_INST(0x81a02622) // fmopa za2.s, p1/M, p1/M, z17.h, z0.h + KAI_ASM_INST(0x81a12623) // fmopa za3.s, p1/M, p1/M, z17.h, z1.h + KAI_ASM_INST(0x81a22640) // fmopa za0.s, p1/M, p1/M, z18.h, z2.h + KAI_ASM_INST(0x81a32641) // fmopa za1.s, p1/M, p1/M, z18.h, z3.h + KAI_ASM_INST(0x81a22662) // fmopa za2.s, p1/M, p1/M, z19.h, z2.h + KAI_ASM_INST(0x81a32663) // fmopa za3.s, p1/M, p1/M, z19.h, z3.h + KAI_ASM_INST(0x81a82780) // fmopa za0.s, p1/M, p1/M, z28.h, z8.h + KAI_ASM_INST(0x81a92781) // fmopa za1.s, p1/M, p1/M, z28.h, z9.h + KAI_ASM_INST(0x81a827a2) // fmopa za2.s, p1/M, p1/M, z29.h, z8.h + KAI_ASM_INST(0x81a927a3) // fmopa za3.s, p1/M, p1/M, z29.h, z9.h + KAI_ASM_INST(0x81aa27c0) // fmopa za0.s, p1/M, p1/M, z30.h, z10.h + KAI_ASM_INST(0x81ab27c1) // fmopa za1.s, p1/M, 
p1/M, z30.h, z11.h + KAI_ASM_INST(0x81aa27e2) // fmopa za2.s, p1/M, p1/M, z31.h, z10.h + KAI_ASM_INST(0x81ab27e3) // fmopa za3.s, p1/M, p1/M, z31.h, z11.h KAI_ASM_LABEL(label_6) // K oddments cbz x20, label_8 KAI_ASM_LABEL(label_7) // K oddments: Loop @@ -149,37 +153,37 @@ KAI_ASM_LABEL(label_7) // K oddments: Loop bgt label_7 KAI_ASM_LABEL(label_8) // K oddments: End ldr x25, [x0, #0x10] - sub x24, x11, x13 + sub x24, x11, x14 cntw x23, ALL, MUL #2 ld1rh { z17.h }, p1/Z, [x0, #56] ldr x22, [x0, #0x18] - whilelt p0.h, x9, x10 + whilelt p0.h, x10, x9 cmp x24, x23 ld1rh { z16.h }, p1/Z, [x0, #58] mov x12, #0x0 mov x21, #0x0 - add x25, x25, x9, LSL #1 // C += n + add x25, x25, x10, LSL #1 // C += n mov x20, #0x2 - madd x25, x13, x22, x25 // C += m * ldc + madd x25, x14, x22, x25 // C += m * ldc csel x24, x24, x23, LT KAI_ASM_LABEL(label_10) // Store to output array: Accumulator loop KAI_ASM_INST(0xc006000e) // mova { z14.b-z15.b }, za0h.b[x12, 0:1] add x12, x12, #0x4 cmp x12, x23, LSL #1 add x21, x21, #0x1 - KAI_ASM_INST(0xc120e1cc) // fcvt z12.h, { z14.s-z15.s } + KAI_ASM_INST(0xc120e1c4) // fcvt z4.h, { z14.s-z15.s } csel x12, x12, x20, LT cmp x21, x24 - KAI_ASM_INST(0x6470262c) // fclamp z12.h, z17.h, z16.h - st1h { z12.h }, p0, [x25] + KAI_ASM_INST(0x64702624) // fclamp z4.h, z17.h, z16.h + st1h { z4.h }, p0, [x25] add x25, x25, x22 blt label_10 - incw x9, ALL, MUL #2 - cmp x9, x10 + incw x10, ALL, MUL #2 + cmp x10, x9 blt label_2 - incw x13, ALL, MUL #2 - mov x9, #0x0 - cmp x13, x11 + incw x14, ALL, MUL #2 + mov x10, #0x0 + cmp x14, x11 mov x28, x26 blt label_1 KAI_ASM_INST(0xd503467f) // SMSTOP @@ -187,7 +191,11 @@ KAI_ASM_LABEL(label_10) // Store to output array: Accumulator loop ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret 
KAI_ASM_FUNCTION_END(kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S index bb60a77d..b355fbe9 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S @@ -42,94 +42,90 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa) KAI_ASM_FUNCTION_LABEL(kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA - mov x14, #0x0 + mov x15, #0x0 ptrue p0.b KAI_ASM_INST(0x25207811) // ptrue pn9.b + ldr x14, [x0, #0x30] ldr w13, [x0, #0x20] - ldr w11, [x0, #0x28] - mov x10, #0x0 + mov x11, #0x0 + ldr w10, [x0, #0x28] ldr x9, [x0, #0x0] KAI_ASM_LABEL(label_1) // M loop ldr x28, [x0, #0x8] KAI_ASM_LABEL(label_2) // N loop - KAI_ASM_INST(0x25ab4550) // whilelt pn8.s, x10, x11, VLx2 - fmov z13.s, #1.0 + fmov z22.s, #1.0 + KAI_ASM_INST(0xa040478c) // ld1w { z12.s-z13.s }, pn9.b/Z, [x28] KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } mov x27, x9 - KAI_ASM_INST(0xa040438e) // ld1w { z14.s-z15.s }, p8/Z, [x28] // Load bias + KAI_ASM_INST(0x25aa4570) // whilelt pn8.s, x11, x10, VLx2 addvl x28, x28, #2 - KAI_ASM_INST(0x808e01a0) // fmopa za0.s, p0/M, p0/M, z13.s, z14.s - KAI_ASM_INST(0x808f01a1) // fmopa za1.s, p0/M, p0/M, z13.s, z15.s - KAI_ASM_INST(0x808e01a2) // fmopa za2.s, 
p0/M, p0/M, z13.s, z14.s - KAI_ASM_INST(0x808f01a3) // fmopa za3.s, p0/M, p0/M, z13.s, z15.s - ldr x20, [x0, #0x30] - lsr x21, x20, #0x2 - and x20, x20, #0x3 + KAI_ASM_INST(0x808c02c0) // fmopa za0.s, p0/M, p0/M, z22.s, z12.s + KAI_ASM_INST(0x808d02c1) // fmopa za1.s, p0/M, p0/M, z22.s, z13.s + KAI_ASM_INST(0x808c02c2) // fmopa za2.s, p0/M, p0/M, z22.s, z12.s + KAI_ASM_INST(0x808d02c3) // fmopa za3.s, p0/M, p0/M, z22.s, z13.s + lsr x21, x14, #0x2 + and x20, x14, #0x3 cbz x21, label_6 subs x21, x21, #0x1 - KAI_ASM_INST(0xa1404772) // ld1w { z18.s, z26.s }, pn9.b/Z, [x27] - KAI_ASM_INST(0xa0404794) // ld1w { z20.s-z21.s }, pn9.b/Z, [x28] - KAI_ASM_INST(0xa1414764) // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL] - KAI_ASM_INST(0xa041478a) // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL] - KAI_ASM_INST(0xa1424773) // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL] - KAI_ASM_INST(0xa0424798) // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL] - KAI_ASM_INST(0xa043476e) // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL] + KAI_ASM_INST(0xa040c764) // ld1w { z4.s-z7.s }, pn9.b/Z, [x27] + KAI_ASM_INST(0xa041c768) // ld1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x4, MUL VL] addvl x27, x27, #8 - KAI_ASM_INST(0xa1434796) // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL] + KAI_ASM_INST(0xa040c794) // ld1w { z20.s-z23.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0xa041c78c) // ld1w { z12.s-z15.s }, pn9.b/Z, [x28, #0x4, MUL VL] addvl x28, x28, #8 ble label_5 KAI_ASM_LABEL(label_4) // K loop - KAI_ASM_INST(0x80940240) // fmopa za0.s, p0/M, p0/M, z18.s, z20.s + KAI_ASM_INST(0x80940080) // fmopa za0.s, p0/M, p0/M, z4.s, z20.s subs x21, x21, #0x1 - KAI_ASM_INST(0x80950241) // fmopa za1.s, p0/M, p0/M, z18.s, z21.s - KAI_ASM_INST(0x80940342) // fmopa za2.s, p0/M, p0/M, z26.s, z20.s - KAI_ASM_INST(0x80950343) // fmopa za3.s, p0/M, p0/M, z26.s, z21.s - KAI_ASM_INST(0xa1404772) // ld1w { z18.s, z26.s }, pn9.b/Z, [x27] - KAI_ASM_INST(0x808a0080) // fmopa za0.s, p0/M, p0/M, z4.s, 
z10.s - KAI_ASM_INST(0xa0404794) // ld1w { z20.s-z21.s }, pn9.b/Z, [x28] - KAI_ASM_INST(0x808b0081) // fmopa za1.s, p0/M, p0/M, z4.s, z11.s - KAI_ASM_INST(0x808a0182) // fmopa za2.s, p0/M, p0/M, z12.s, z10.s - KAI_ASM_INST(0x808b0183) // fmopa za3.s, p0/M, p0/M, z12.s, z11.s - KAI_ASM_INST(0xa1414764) // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL] - KAI_ASM_INST(0x80980260) // fmopa za0.s, p0/M, p0/M, z19.s, z24.s - KAI_ASM_INST(0xa041478a) // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL] - KAI_ASM_INST(0x80990261) // fmopa za1.s, p0/M, p0/M, z19.s, z25.s - KAI_ASM_INST(0x80980362) // fmopa za2.s, p0/M, p0/M, z27.s, z24.s - KAI_ASM_INST(0x80990363) // fmopa za3.s, p0/M, p0/M, z27.s, z25.s - KAI_ASM_INST(0xa1424773) // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL] - KAI_ASM_INST(0xa0424798) // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL] - KAI_ASM_INST(0x809601c0) // fmopa za0.s, p0/M, p0/M, z14.s, z22.s - KAI_ASM_INST(0x809e01c1) // fmopa za1.s, p0/M, p0/M, z14.s, z30.s - KAI_ASM_INST(0x809601e2) // fmopa za2.s, p0/M, p0/M, z15.s, z22.s - KAI_ASM_INST(0x809e01e3) // fmopa za3.s, p0/M, p0/M, z15.s, z30.s - KAI_ASM_INST(0xa043476e) // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL] + KAI_ASM_INST(0x80950081) // fmopa za1.s, p0/M, p0/M, z4.s, z21.s + KAI_ASM_INST(0x809400a2) // fmopa za2.s, p0/M, p0/M, z5.s, z20.s + KAI_ASM_INST(0x809500a3) // fmopa za3.s, p0/M, p0/M, z5.s, z21.s + KAI_ASM_INST(0x809600c0) // fmopa za0.s, p0/M, p0/M, z6.s, z22.s + KAI_ASM_INST(0x809700c1) // fmopa za1.s, p0/M, p0/M, z6.s, z23.s + KAI_ASM_INST(0x809600e2) // fmopa za2.s, p0/M, p0/M, z7.s, z22.s + KAI_ASM_INST(0x809700e3) // fmopa za3.s, p0/M, p0/M, z7.s, z23.s + KAI_ASM_INST(0xa040c764) // ld1w { z4.s-z7.s }, pn9.b/Z, [x27] + KAI_ASM_INST(0x808c0100) // fmopa za0.s, p0/M, p0/M, z8.s, z12.s + KAI_ASM_INST(0xa040c794) // ld1w { z20.s-z23.s }, pn9.b/Z, [x28] + KAI_ASM_INST(0x808d0101) // fmopa za1.s, p0/M, p0/M, z8.s, z13.s + KAI_ASM_INST(0x808c0122) // fmopa 
za2.s, p0/M, p0/M, z9.s, z12.s + KAI_ASM_INST(0x808d0123) // fmopa za3.s, p0/M, p0/M, z9.s, z13.s + KAI_ASM_INST(0x808e0140) // fmopa za0.s, p0/M, p0/M, z10.s, z14.s + KAI_ASM_INST(0x808f0141) // fmopa za1.s, p0/M, p0/M, z10.s, z15.s + KAI_ASM_INST(0x808e0162) // fmopa za2.s, p0/M, p0/M, z11.s, z14.s + KAI_ASM_INST(0x808f0163) // fmopa za3.s, p0/M, p0/M, z11.s, z15.s + KAI_ASM_INST(0xa041c768) // ld1w { z8.s-z11.s }, pn9.b/Z, [x27, #0x4, MUL VL] addvl x27, x27, #8 - KAI_ASM_INST(0xa1434796) // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL] + KAI_ASM_INST(0xa041c78c) // ld1w { z12.s-z15.s }, pn9.b/Z, [x28, #0x4, MUL VL] addvl x28, x28, #8 bgt label_4 KAI_ASM_LABEL(label_5) // K loop tail - KAI_ASM_INST(0x80940240) // fmopa za0.s, p0/M, p0/M, z18.s, z20.s - KAI_ASM_INST(0x80950241) // fmopa za1.s, p0/M, p0/M, z18.s, z21.s - KAI_ASM_INST(0x80940342) // fmopa za2.s, p0/M, p0/M, z26.s, z20.s - KAI_ASM_INST(0x80950343) // fmopa za3.s, p0/M, p0/M, z26.s, z21.s - KAI_ASM_INST(0x808a0080) // fmopa za0.s, p0/M, p0/M, z4.s, z10.s - KAI_ASM_INST(0x808b0081) // fmopa za1.s, p0/M, p0/M, z4.s, z11.s - KAI_ASM_INST(0x808a0182) // fmopa za2.s, p0/M, p0/M, z12.s, z10.s - KAI_ASM_INST(0x808b0183) // fmopa za3.s, p0/M, p0/M, z12.s, z11.s - KAI_ASM_INST(0x80980260) // fmopa za0.s, p0/M, p0/M, z19.s, z24.s - KAI_ASM_INST(0x80990261) // fmopa za1.s, p0/M, p0/M, z19.s, z25.s - KAI_ASM_INST(0x80980362) // fmopa za2.s, p0/M, p0/M, z27.s, z24.s - KAI_ASM_INST(0x80990363) // fmopa za3.s, p0/M, p0/M, z27.s, z25.s - KAI_ASM_INST(0x809601c0) // fmopa za0.s, p0/M, p0/M, z14.s, z22.s - KAI_ASM_INST(0x809e01c1) // fmopa za1.s, p0/M, p0/M, z14.s, z30.s - KAI_ASM_INST(0x809601e2) // fmopa za2.s, p0/M, p0/M, z15.s, z22.s - KAI_ASM_INST(0x809e01e3) // fmopa za3.s, p0/M, p0/M, z15.s, z30.s + KAI_ASM_INST(0x80940080) // fmopa za0.s, p0/M, p0/M, z4.s, z20.s + KAI_ASM_INST(0x80950081) // fmopa za1.s, p0/M, p0/M, z4.s, z21.s + KAI_ASM_INST(0x809400a2) // fmopa za2.s, p0/M, p0/M, z5.s, z20.s + 
KAI_ASM_INST(0x809500a3) // fmopa za3.s, p0/M, p0/M, z5.s, z21.s + KAI_ASM_INST(0x809600c0) // fmopa za0.s, p0/M, p0/M, z6.s, z22.s + KAI_ASM_INST(0x809700c1) // fmopa za1.s, p0/M, p0/M, z6.s, z23.s + KAI_ASM_INST(0x809600e2) // fmopa za2.s, p0/M, p0/M, z7.s, z22.s + KAI_ASM_INST(0x809700e3) // fmopa za3.s, p0/M, p0/M, z7.s, z23.s + KAI_ASM_INST(0x808c0100) // fmopa za0.s, p0/M, p0/M, z8.s, z12.s + KAI_ASM_INST(0x808d0101) // fmopa za1.s, p0/M, p0/M, z8.s, z13.s + KAI_ASM_INST(0x808c0122) // fmopa za2.s, p0/M, p0/M, z9.s, z12.s + KAI_ASM_INST(0x808d0123) // fmopa za3.s, p0/M, p0/M, z9.s, z13.s + KAI_ASM_INST(0x808e0140) // fmopa za0.s, p0/M, p0/M, z10.s, z14.s + KAI_ASM_INST(0x808f0141) // fmopa za1.s, p0/M, p0/M, z10.s, z15.s + KAI_ASM_INST(0x808e0162) // fmopa za2.s, p0/M, p0/M, z11.s, z14.s + KAI_ASM_INST(0x808f0163) // fmopa za3.s, p0/M, p0/M, z11.s, z15.s KAI_ASM_LABEL(label_6) // K oddments cbz x20, label_8 KAI_ASM_LABEL(label_7) // K oddments: Loop @@ -145,24 +141,24 @@ KAI_ASM_LABEL(label_7) // K oddments: Loop bgt label_7 KAI_ASM_LABEL(label_8) // K oddments: End ldr x26, [x0, #0x10] - sub x25, x13, x14 + sub x25, x13, x15 cntw x24 - ld1rw { z19.s }, p0/Z, [x0, #56] + ld1rw { z26.s }, p0/Z, [x0, #56] ldr x23, [x0, #0x18] cmp x25, x24 - ld1rw { z26.s }, p0/Z, [x0, #60] + ld1rw { z24.s }, p0/Z, [x0, #60] mov x12, #0x0 csel x22, x25, x24, LT - add x26, x26, x10, LSL #2 // C += n + add x26, x26, x11, LSL #2 // C += n lsr x21, x22, #0x2 - madd x26, x14, x23, x26 // C += m * ldc + madd x26, x15, x23, x26 // C += m * ldc and x20, x22, #0x3 cbz x21, label_11 KAI_ASM_LABEL(label_10) // Store to output array: Accumulator row 0 loop KAI_ASM_INST(0xc0860404) // mova { z4.s-z7.s }, za0h.s[x12] KAI_ASM_INST(0xc086042c) // mova { z12.s-z15.s }, za1h.s[x12] - KAI_ASM_INST(0xc1baca64) // fclamp { z4.s-z7.s }, z19.s, z26.s - KAI_ASM_INST(0xc1baca6c) // fclamp { z12.s-z15.s }, z19.s, z26.s + KAI_ASM_INST(0xc1b8cb44) // fclamp { z4.s-z7.s }, z26.s, z24.s + 
KAI_ASM_INST(0xc1b8cb4c) // fclamp { z12.s-z15.s }, z26.s, z24.s add x12, x12, #0x4 cmp x12, x21, LSL #2 KAI_ASM_INST(0xa1604344) // st1w { z4.s, z12.s }, p8, [x26] @@ -179,8 +175,8 @@ KAI_ASM_LABEL(label_11) // Store to output array: Accumulator row 0 oddments KAI_ASM_INST(0xc0860400) // mova { z0.s-z3.s }, za0h.s[x12] KAI_ASM_INST(0xc0860428) // mova { z8.s-z11.s }, za1h.s[x12] subs x20, x20, #0x1 - KAI_ASM_INST(0xc1baca60) // fclamp { z0.s-z3.s }, z19.s, z26.s - KAI_ASM_INST(0xc1baca68) // fclamp { z8.s-z11.s }, z19.s, z26.s + KAI_ASM_INST(0xc1b8cb40) // fclamp { z0.s-z3.s }, z26.s, z24.s + KAI_ASM_INST(0xc1b8cb48) // fclamp { z8.s-z11.s }, z26.s, z24.s KAI_ASM_INST(0xa1604340) // st1w { z0.s, z8.s }, p8, [x26] add x26, x26, x23 beq label_12 @@ -202,8 +198,8 @@ KAI_ASM_LABEL(label_12) // Store to output array: Accumulator row 0 oddments: E KAI_ASM_LABEL(label_13) // Store to output array: Accumulator row 1 loop KAI_ASM_INST(0xc0860454) // mova { z20.s-z23.s }, za2h.s[x12] KAI_ASM_INST(0xc086047c) // mova { z28.s-z31.s }, za3h.s[x12] - KAI_ASM_INST(0xc1baca74) // fclamp { z20.s-z23.s }, z19.s, z26.s - KAI_ASM_INST(0xc1baca7c) // fclamp { z28.s-z31.s }, z19.s, z26.s + KAI_ASM_INST(0xc1b8cb54) // fclamp { z20.s-z23.s }, z26.s, z24.s + KAI_ASM_INST(0xc1b8cb5c) // fclamp { z28.s-z31.s }, z26.s, z24.s add x12, x12, #0x4 cmp x12, x21, LSL #2 KAI_ASM_INST(0xa1604354) // st1w { z20.s, z28.s }, p8, [x26] @@ -220,8 +216,8 @@ KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments KAI_ASM_INST(0xc0860444) // mova { z4.s-z7.s }, za2h.s[x12] KAI_ASM_INST(0xc086046c) // mova { z12.s-z15.s }, za3h.s[x12] subs x20, x20, #0x1 - KAI_ASM_INST(0xc1baca64) // fclamp { z4.s-z7.s }, z19.s, z26.s - KAI_ASM_INST(0xc1baca6c) // fclamp { z12.s-z15.s }, z19.s, z26.s + KAI_ASM_INST(0xc1b8cb44) // fclamp { z4.s-z7.s }, z26.s, z24.s + KAI_ASM_INST(0xc1b8cb4c) // fclamp { z12.s-z15.s }, z26.s, z24.s KAI_ASM_INST(0xa1604344) // st1w { z4.s, z12.s }, p8, [x26] add x26, x26, 
x23 beq label_15 @@ -232,12 +228,12 @@ KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments KAI_ASM_INST(0xa1604346) // st1w { z6.s, z14.s }, p8, [x26] KAI_ASM_LABEL(label_15) // Store to output array: Accumulator row 1 oddments: End KAI_ASM_LABEL(label_16) // Store to output array: End - incw x10, ALL, MUL #2 - cmp x10, x11 + incw x11, ALL, MUL #2 + cmp x11, x10 blt label_2 - incw x14, ALL, MUL #2 - mov x10, #0x0 - cmp x14, x13 + incw x15, ALL, MUL #2 + mov x11, #0x0 + cmp x15, x13 mov x9, x27 blt label_1 KAI_ASM_INST(0xd503467f) // SMSTOP @@ -245,7 +241,11 @@ KAI_ASM_LABEL(label_16) // Store to output array: End ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa) diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S index c73aaad2..c08d2dea 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S @@ -42,158 +42,152 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) KAI_ASM_FUNCTION_LABEL(kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! 
stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA - mov x14, #0x0 + mov x15, #0x0 + ldr x14, [x0, #0x30] ptrue p1.b - KAI_ASM_INST(0x25207811) // ptrue pn9.b + KAI_ASM_INST(0x25207810) // ptrue pn8.b ldr w13, [x0, #0x20] - ldr w11, [x0, #0x28] - mov x10, #0x0 + mov x11, #0x0 + ldr w10, [x0, #0x28] + add x14, x14, #0x3 ldr x9, [x0, #0x0] + lsr x14, x14, #0x2 KAI_ASM_LABEL(label_1) // M loop ldr x28, [x0, #0x8] KAI_ASM_LABEL(label_2) // N loop - KAI_ASM_INST(0x25ab4550) // whilelt pn8.s, x10, x11, VLx2 + KAI_ASM_INST(0xa0404382) // ld1w { z2.s-z3.s }, pn8.b/Z, [x28] KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } mov x27, x9 - KAI_ASM_INST(0xa040438e) // ld1w { z14.s-z15.s }, p8/Z, [x28] // Load bias addvl x28, x28, #2 - KAI_ASM_INST(0xc09025c0) // addha za0.s, p1/M, p1/M, z14.s - KAI_ASM_INST(0xc09025e1) // addha za1.s, p1/M, p1/M, z15.s - KAI_ASM_INST(0xc09025c2) // addha za2.s, p1/M, p1/M, z14.s - KAI_ASM_INST(0xc09025e3) // addha za3.s, p1/M, p1/M, z15.s - ldr x20, [x0, #0x30] - add x20, x20, #0x3 - lsr x20, x20, #0x2 - lsr x21, x20, #0x2 - and x20, x20, #0x3 + KAI_ASM_INST(0xc0902440) // addha za0.s, p1/M, p1/M, z2.s + KAI_ASM_INST(0xc0902461) // addha za1.s, p1/M, p1/M, z3.s + KAI_ASM_INST(0xc0902442) // addha za2.s, p1/M, p1/M, z2.s + KAI_ASM_INST(0xc0902463) // addha za3.s, p1/M, p1/M, z3.s + lsr x21, x14, #0x2 + and x20, x14, #0x3 cbz x21, label_6 subs x21, x21, #0x1 - KAI_ASM_INST(0xa0400762) // ld1b { z2.b-z3.b }, pn9.b/Z, [x27] - KAI_ASM_INST(0xa1400780) // ld1b { z0.b, z8.b }, pn9.b/Z, [x28] - KAI_ASM_INST(0xa0410772) // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL] - KAI_ASM_INST(0xa0410794) // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL] - KAI_ASM_INST(0xa042077a) // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL] - 
KAI_ASM_INST(0xa0420796) // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL] - KAI_ASM_INST(0xa0430778) // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL] + KAI_ASM_INST(0xa1408362) // ld1b { z2.b, z6.b, z10.b, z14.b }, pn8.b/Z, [x27] + KAI_ASM_INST(0xa1418360) // ld1b { z0.b, z4.b, z8.b, z12.b }, pn8.b/Z, [x27, #0x4, MUL VL] addvl x27, x27, #8 - KAI_ASM_INST(0xa0430784) // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL] + KAI_ASM_INST(0xa0408390) // ld1b { z16.b-z19.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xa041839c) // ld1b { z28.b-z31.b }, pn8.b/Z, [x28, #0x4, MUL VL] addvl x28, x28, #8 ble label_5 KAI_ASM_LABEL(label_4) // K loop - KAI_ASM_INST(0xa0802440) // smopa za0.s, p1/M, p1/M, z2.b, z0.b + KAI_ASM_INST(0xa0902440) // smopa za0.s, p1/M, p1/M, z2.b, z16.b subs x21, x21, #0x1 - KAI_ASM_INST(0xa0882441) // smopa za1.s, p1/M, p1/M, z2.b, z8.b - KAI_ASM_INST(0xa0802462) // smopa za2.s, p1/M, p1/M, z3.b, z0.b - KAI_ASM_INST(0xa0882463) // smopa za3.s, p1/M, p1/M, z3.b, z8.b - KAI_ASM_INST(0xa0400762) // ld1b { z2.b-z3.b }, pn9.b/Z, [x27] - KAI_ASM_INST(0xa0942640) // smopa za0.s, p1/M, p1/M, z18.b, z20.b - KAI_ASM_INST(0xa1400780) // ld1b { z0.b, z8.b }, pn9.b/Z, [x28] - KAI_ASM_INST(0xa0952641) // smopa za1.s, p1/M, p1/M, z18.b, z21.b - KAI_ASM_INST(0xa0942662) // smopa za2.s, p1/M, p1/M, z19.b, z20.b - KAI_ASM_INST(0xa0952663) // smopa za3.s, p1/M, p1/M, z19.b, z21.b - KAI_ASM_INST(0xa0410772) // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL] - KAI_ASM_INST(0xa0962740) // smopa za0.s, p1/M, p1/M, z26.b, z22.b - KAI_ASM_INST(0xa0410794) // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL] - KAI_ASM_INST(0xa0972741) // smopa za1.s, p1/M, p1/M, z26.b, z23.b - KAI_ASM_INST(0xa0962762) // smopa za2.s, p1/M, p1/M, z27.b, z22.b - KAI_ASM_INST(0xa0972763) // smopa za3.s, p1/M, p1/M, z27.b, z23.b - KAI_ASM_INST(0xa042077a) // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL] - KAI_ASM_INST(0xa0420796) // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL] - 
KAI_ASM_INST(0xa0842700) // smopa za0.s, p1/M, p1/M, z24.b, z4.b - KAI_ASM_INST(0xa0852701) // smopa za1.s, p1/M, p1/M, z24.b, z5.b - KAI_ASM_INST(0xa0842722) // smopa za2.s, p1/M, p1/M, z25.b, z4.b - KAI_ASM_INST(0xa0852723) // smopa za3.s, p1/M, p1/M, z25.b, z5.b - KAI_ASM_INST(0xa0430778) // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL] + KAI_ASM_INST(0xa0912441) // smopa za1.s, p1/M, p1/M, z2.b, z17.b + KAI_ASM_INST(0xa09024c2) // smopa za2.s, p1/M, p1/M, z6.b, z16.b + KAI_ASM_INST(0xa09124c3) // smopa za3.s, p1/M, p1/M, z6.b, z17.b + KAI_ASM_INST(0xa0922540) // smopa za0.s, p1/M, p1/M, z10.b, z18.b + KAI_ASM_INST(0xa0932541) // smopa za1.s, p1/M, p1/M, z10.b, z19.b + KAI_ASM_INST(0xa09225c2) // smopa za2.s, p1/M, p1/M, z14.b, z18.b + KAI_ASM_INST(0xa09325c3) // smopa za3.s, p1/M, p1/M, z14.b, z19.b + KAI_ASM_INST(0xa1408362) // ld1b { z2.b, z6.b, z10.b, z14.b }, pn8.b/Z, [x27] + KAI_ASM_INST(0xa09c2400) // smopa za0.s, p1/M, p1/M, z0.b, z28.b + KAI_ASM_INST(0xa0408390) // ld1b { z16.b-z19.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xa09d2401) // smopa za1.s, p1/M, p1/M, z0.b, z29.b + KAI_ASM_INST(0xa09c2482) // smopa za2.s, p1/M, p1/M, z4.b, z28.b + KAI_ASM_INST(0xa09d2483) // smopa za3.s, p1/M, p1/M, z4.b, z29.b + KAI_ASM_INST(0xa09e2500) // smopa za0.s, p1/M, p1/M, z8.b, z30.b + KAI_ASM_INST(0xa09f2501) // smopa za1.s, p1/M, p1/M, z8.b, z31.b + KAI_ASM_INST(0xa09e2582) // smopa za2.s, p1/M, p1/M, z12.b, z30.b + KAI_ASM_INST(0xa09f2583) // smopa za3.s, p1/M, p1/M, z12.b, z31.b + KAI_ASM_INST(0xa1418360) // ld1b { z0.b, z4.b, z8.b, z12.b }, pn8.b/Z, [x27, #0x4, MUL VL] addvl x27, x27, #8 - KAI_ASM_INST(0xa0430784) // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL] + KAI_ASM_INST(0xa041839c) // ld1b { z28.b-z31.b }, pn8.b/Z, [x28, #0x4, MUL VL] addvl x28, x28, #8 bgt label_4 KAI_ASM_LABEL(label_5) // K loop tail - KAI_ASM_INST(0xa0802440) // smopa za0.s, p1/M, p1/M, z2.b, z0.b - KAI_ASM_INST(0xa0882441) // smopa za1.s, p1/M, p1/M, z2.b, z8.b - 
KAI_ASM_INST(0xa0802462) // smopa za2.s, p1/M, p1/M, z3.b, z0.b - KAI_ASM_INST(0xa0882463) // smopa za3.s, p1/M, p1/M, z3.b, z8.b - KAI_ASM_INST(0xa0942640) // smopa za0.s, p1/M, p1/M, z18.b, z20.b - KAI_ASM_INST(0xa0952641) // smopa za1.s, p1/M, p1/M, z18.b, z21.b - KAI_ASM_INST(0xa0942662) // smopa za2.s, p1/M, p1/M, z19.b, z20.b - KAI_ASM_INST(0xa0952663) // smopa za3.s, p1/M, p1/M, z19.b, z21.b - KAI_ASM_INST(0xa0962740) // smopa za0.s, p1/M, p1/M, z26.b, z22.b - KAI_ASM_INST(0xa0972741) // smopa za1.s, p1/M, p1/M, z26.b, z23.b - KAI_ASM_INST(0xa0962762) // smopa za2.s, p1/M, p1/M, z27.b, z22.b - KAI_ASM_INST(0xa0972763) // smopa za3.s, p1/M, p1/M, z27.b, z23.b - KAI_ASM_INST(0xa0842700) // smopa za0.s, p1/M, p1/M, z24.b, z4.b - KAI_ASM_INST(0xa0852701) // smopa za1.s, p1/M, p1/M, z24.b, z5.b - KAI_ASM_INST(0xa0842722) // smopa za2.s, p1/M, p1/M, z25.b, z4.b - KAI_ASM_INST(0xa0852723) // smopa za3.s, p1/M, p1/M, z25.b, z5.b + KAI_ASM_INST(0xa0902440) // smopa za0.s, p1/M, p1/M, z2.b, z16.b + KAI_ASM_INST(0xa0912441) // smopa za1.s, p1/M, p1/M, z2.b, z17.b + KAI_ASM_INST(0xa09024c2) // smopa za2.s, p1/M, p1/M, z6.b, z16.b + KAI_ASM_INST(0xa09124c3) // smopa za3.s, p1/M, p1/M, z6.b, z17.b + KAI_ASM_INST(0xa0922540) // smopa za0.s, p1/M, p1/M, z10.b, z18.b + KAI_ASM_INST(0xa0932541) // smopa za1.s, p1/M, p1/M, z10.b, z19.b + KAI_ASM_INST(0xa09225c2) // smopa za2.s, p1/M, p1/M, z14.b, z18.b + KAI_ASM_INST(0xa09325c3) // smopa za3.s, p1/M, p1/M, z14.b, z19.b + KAI_ASM_INST(0xa09c2400) // smopa za0.s, p1/M, p1/M, z0.b, z28.b + KAI_ASM_INST(0xa09d2401) // smopa za1.s, p1/M, p1/M, z0.b, z29.b + KAI_ASM_INST(0xa09c2482) // smopa za2.s, p1/M, p1/M, z4.b, z28.b + KAI_ASM_INST(0xa09d2483) // smopa za3.s, p1/M, p1/M, z4.b, z29.b + KAI_ASM_INST(0xa09e2500) // smopa za0.s, p1/M, p1/M, z8.b, z30.b + KAI_ASM_INST(0xa09f2501) // smopa za1.s, p1/M, p1/M, z8.b, z31.b + KAI_ASM_INST(0xa09e2582) // smopa za2.s, p1/M, p1/M, z12.b, z30.b + KAI_ASM_INST(0xa09f2583) // smopa za3.s, 
p1/M, p1/M, z12.b, z31.b KAI_ASM_LABEL(label_6) // K oddments cbz x20, label_8 KAI_ASM_LABEL(label_7) // K oddments: Loop - KAI_ASM_INST(0xa0400770) // ld1b { z16.b-z17.b }, pn9.b/Z, [x27] + KAI_ASM_INST(0xa0400370) // ld1b { z16.b-z17.b }, pn8.b/Z, [x27] subs x20, x20, #0x1 addvl x27, x27, #2 - KAI_ASM_INST(0xa0400788) // ld1b { z8.b-z9.b }, pn9.b/Z, [x28] + KAI_ASM_INST(0xa1400385) // ld1b { z5.b, z13.b }, pn8.b/Z, [x28] addvl x28, x28, #2 - KAI_ASM_INST(0xa0882600) // smopa za0.s, p1/M, p1/M, z16.b, z8.b - KAI_ASM_INST(0xa0892601) // smopa za1.s, p1/M, p1/M, z16.b, z9.b - KAI_ASM_INST(0xa0882622) // smopa za2.s, p1/M, p1/M, z17.b, z8.b - KAI_ASM_INST(0xa0892623) // smopa za3.s, p1/M, p1/M, z17.b, z9.b + KAI_ASM_INST(0xa0852600) // smopa za0.s, p1/M, p1/M, z16.b, z5.b + KAI_ASM_INST(0xa08d2601) // smopa za1.s, p1/M, p1/M, z16.b, z13.b + KAI_ASM_INST(0xa0852622) // smopa za2.s, p1/M, p1/M, z17.b, z5.b + KAI_ASM_INST(0xa08d2623) // smopa za3.s, p1/M, p1/M, z17.b, z13.b bgt label_7 KAI_ASM_LABEL(label_8) // K oddments: End ldr x26, [x0, #0x10] - sub x25, x13, x14 + sub x25, x13, x15 cntw x24 - ld1rw { z27.s }, p1/Z, [x0, #56] + ld1rw { z26.s }, p1/Z, [x0, #56] ldr x23, [x0, #0x18] - whilelt p0.h, x10, x11 + whilelt p0.h, x11, x10 cmp x25, x24 - ld1rw { z1.s }, p1/Z, [x0, #60] + ld1rw { z23.s }, p1/Z, [x0, #60] csel x22, x25, x24, LT ld1rw { z0.s }, p1/Z, [x0, #64] mov x12, #0x0 - add x26, x26, x10 // C += n + add x26, x26, x11 // C += n lsr x21, x22, #0x2 - ld1w { z22.s }, p1/Z, [x28] - madd x26, x14, x23, x26 // C += m * ldc - ld1w { z26.s }, p1/Z, [x28, #1, MUL VL] - and x20, x22, #0x3 + KAI_ASM_INST(0xa0404382) // ld1w { z2.s-z3.s }, pn8.b/Z, [x28] + madd x26, x15, x23, x26 // C += m * ldc addvl x28, x28, #2 + and x20, x22, #0x3 cbz x21, label_11 KAI_ASM_LABEL(label_10) // Store to output array: Accumulator row 0 loop - KAI_ASM_INST(0xc0860410) // mova { z16.s-z19.s }, za0h.s[x12] - KAI_ASM_INST(0xc086043c) // mova { z28.s-z31.s }, za1h.s[x12] + 
KAI_ASM_INST(0xc0860408) // mova { z8.s-z11.s }, za0h.s[x12] + KAI_ASM_INST(0xc0860430) // mova { z16.s-z19.s }, za1h.s[x12] + KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc132e210) // scvtf { z16.s-z19.s }, { z16.s-z19.s } - KAI_ASM_INST(0xc132e39c) // scvtf { z28.s-z31.s }, { z28.s-z31.s } - fmul z16.s, z16.s, z22.s - fmul z17.s, z17.s, z22.s + fmul z8.s, z8.s, z2.s + fmul z9.s, z9.s, z2.s add x12, x12, #0x4 - fmul z18.s, z18.s, z22.s - fmul z19.s, z19.s, z22.s + fmul z10.s, z10.s, z2.s + fmul z11.s, z11.s, z2.s cmp x12, x21, LSL #2 - fmul z28.s, z28.s, z26.s - fmul z29.s, z29.s, z26.s - fmul z30.s, z30.s, z26.s - fmul z31.s, z31.s, z26.s + fmul z16.s, z16.s, z3.s + fmul z17.s, z17.s, z3.s + fmul z18.s, z18.s, z3.s + fmul z19.s, z19.s, z3.s + KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc1b8e210) // frintn { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s KAI_ASM_INST(0xc131e210) // fcvtzs { z16.s-z19.s }, { z16.s-z19.s } - KAI_ASM_INST(0xc1b8e39c) // frintn { z28.s-z31.s }, { z28.s-z31.s } KAI_ASM_INST(0xc1a0ab10) // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s - KAI_ASM_INST(0xc131e39c) // fcvtzs { z28.s-z31.s }, { z28.s-z31.s } - KAI_ASM_INST(0xc1a0ab1c) // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s - KAI_ASM_INST(0xc1a1cf70) // sclamp { z16.s-z19.s }, z27.s, z1.s - KAI_ASM_INST(0xc1a1cf7c) // sclamp { z28.s-z31.s }, z27.s, z1.s - uzp1 z5.h, z16.h, z28.h - uzp1 z20.h, z17.h, z29.h - uzp1 z17.h, z18.h, z30.h - uzp1 z16.h, z19.h, z31.h + KAI_ASM_INST(0xc1b7cf48) // sclamp { z8.s-z11.s }, z26.s, z23.s + KAI_ASM_INST(0xc1b7cf50) // sclamp { z16.s-z19.s }, z26.s, z23.s + uzp1 z5.h, z8.h, z16.h + uzp1 z14.h, z9.h, z17.h + uzp1 z17.h, z10.h, z18.h + uzp1 z16.h, z11.h, z19.h st1b { z5.h }, p0, [x26] add x26, x26, x23 - st1b { z20.h }, p0, [x26] + st1b { z14.h }, p0, [x26] add x26, 
x26, x23 st1b { z17.h }, p0, [x26] add x26, x26, x23 @@ -202,37 +196,37 @@ KAI_ASM_LABEL(label_10) // Store to output array: Accumulator row 0 loop blt label_10 KAI_ASM_LABEL(label_11) // Store to output array: Accumulator row 0 oddments cbz x20, label_12 - KAI_ASM_INST(0xc0860404) // mova { z4.s-z7.s }, za0h.s[x12] + KAI_ASM_INST(0xc0860408) // mova { z8.s-z11.s }, za0h.s[x12] KAI_ASM_INST(0xc086042c) // mova { z12.s-z15.s }, za1h.s[x12] - KAI_ASM_INST(0xc132e084) // scvtf { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc132e18c) // scvtf { z12.s-z15.s }, { z12.s-z15.s } - fmul z4.s, z4.s, z22.s - fmul z5.s, z5.s, z22.s + fmul z8.s, z8.s, z2.s + fmul z9.s, z9.s, z2.s subs x20, x20, #0x1 - fmul z6.s, z6.s, z22.s - fmul z7.s, z7.s, z22.s - fmul z12.s, z12.s, z26.s - fmul z13.s, z13.s, z26.s - fmul z14.s, z14.s, z26.s - fmul z15.s, z15.s, z26.s - KAI_ASM_INST(0xc1b8e084) // frintn { z4.s-z7.s }, { z4.s-z7.s } - KAI_ASM_INST(0xc131e084) // fcvtzs { z4.s-z7.s }, { z4.s-z7.s } + fmul z10.s, z10.s, z2.s + fmul z11.s, z11.s, z2.s + fmul z12.s, z12.s, z3.s + fmul z13.s, z13.s, z3.s + fmul z14.s, z14.s, z3.s + fmul z15.s, z15.s, z3.s + KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc1b8e18c) // frintn { z12.s-z15.s }, { z12.s-z15.s } - KAI_ASM_INST(0xc1a0ab04) // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s + KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s KAI_ASM_INST(0xc131e18c) // fcvtzs { z12.s-z15.s }, { z12.s-z15.s } KAI_ASM_INST(0xc1a0ab0c) // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s - KAI_ASM_INST(0xc1a1cf64) // sclamp { z4.s-z7.s }, z27.s, z1.s - KAI_ASM_INST(0xc1a1cf6c) // sclamp { z12.s-z15.s }, z27.s, z1.s - uzp1 z16.h, z4.h, z12.h + KAI_ASM_INST(0xc1b7cf48) // sclamp { z8.s-z11.s }, z26.s, z23.s + KAI_ASM_INST(0xc1b7cf4c) // sclamp { z12.s-z15.s }, z26.s, z23.s + uzp1 z16.h, z8.h, 
z12.h st1b { z16.h }, p0, [x26] add x26, x26, x23 beq label_12 subs x20, x20, #0x1 - uzp1 z16.h, z5.h, z13.h + uzp1 z16.h, z9.h, z13.h st1b { z16.h }, p0, [x26] add x26, x26, x23 beq label_12 - uzp1 z16.h, z6.h, z14.h + uzp1 z16.h, z10.h, z14.h st1b { z16.h }, p0, [x26] add x26, x26, x23 KAI_ASM_LABEL(label_12) // Store to output array: Accumulator row 0 oddments: End @@ -249,24 +243,24 @@ KAI_ASM_LABEL(label_13) // Store to output array: Accumulator row 1 loop KAI_ASM_INST(0xc0860470) // mova { z16.s-z19.s }, za3h.s[x12] KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc132e210) // scvtf { z16.s-z19.s }, { z16.s-z19.s } - fmul z8.s, z8.s, z22.s - fmul z9.s, z9.s, z22.s + fmul z8.s, z8.s, z2.s + fmul z9.s, z9.s, z2.s add x12, x12, #0x4 - fmul z10.s, z10.s, z22.s - fmul z11.s, z11.s, z22.s + fmul z10.s, z10.s, z2.s + fmul z11.s, z11.s, z2.s cmp x12, x21, LSL #2 - fmul z16.s, z16.s, z26.s - fmul z17.s, z17.s, z26.s - fmul z18.s, z18.s, z26.s - fmul z19.s, z19.s, z26.s + fmul z16.s, z16.s, z3.s + fmul z17.s, z17.s, z3.s + fmul z18.s, z18.s, z3.s + fmul z19.s, z19.s, z3.s KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc1b8e210) // frintn { z16.s-z19.s }, { z16.s-z19.s } KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s KAI_ASM_INST(0xc131e210) // fcvtzs { z16.s-z19.s }, { z16.s-z19.s } KAI_ASM_INST(0xc1a0ab10) // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s - KAI_ASM_INST(0xc1a1cf68) // sclamp { z8.s-z11.s }, z27.s, z1.s - KAI_ASM_INST(0xc1a1cf70) // sclamp { z16.s-z19.s }, z27.s, z1.s + KAI_ASM_INST(0xc1b7cf48) // sclamp { z8.s-z11.s }, z26.s, z23.s + KAI_ASM_INST(0xc1b7cf50) // sclamp { z16.s-z19.s }, z26.s, z23.s uzp1 z21.h, z8.h, z16.h uzp1 z20.h, z9.h, z17.h uzp1 z17.h, z10.h, z18.h @@ -286,23 +280,23 @@ KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments KAI_ASM_INST(0xc0860464) // mova { 
z4.s-z7.s }, za3h.s[x12] KAI_ASM_INST(0xc132e18c) // scvtf { z12.s-z15.s }, { z12.s-z15.s } KAI_ASM_INST(0xc132e084) // scvtf { z4.s-z7.s }, { z4.s-z7.s } - fmul z12.s, z12.s, z22.s - fmul z13.s, z13.s, z22.s + fmul z12.s, z12.s, z2.s + fmul z13.s, z13.s, z2.s subs x20, x20, #0x1 - fmul z14.s, z14.s, z22.s - fmul z15.s, z15.s, z22.s - fmul z4.s, z4.s, z26.s - fmul z5.s, z5.s, z26.s - fmul z6.s, z6.s, z26.s - fmul z7.s, z7.s, z26.s + fmul z14.s, z14.s, z2.s + fmul z15.s, z15.s, z2.s + fmul z4.s, z4.s, z3.s + fmul z5.s, z5.s, z3.s + fmul z6.s, z6.s, z3.s + fmul z7.s, z7.s, z3.s KAI_ASM_INST(0xc1b8e18c) // frintn { z12.s-z15.s }, { z12.s-z15.s } KAI_ASM_INST(0xc131e18c) // fcvtzs { z12.s-z15.s }, { z12.s-z15.s } KAI_ASM_INST(0xc1b8e084) // frintn { z4.s-z7.s }, { z4.s-z7.s } KAI_ASM_INST(0xc1a0ab0c) // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s KAI_ASM_INST(0xc131e084) // fcvtzs { z4.s-z7.s }, { z4.s-z7.s } KAI_ASM_INST(0xc1a0ab04) // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s - KAI_ASM_INST(0xc1a1cf6c) // sclamp { z12.s-z15.s }, z27.s, z1.s - KAI_ASM_INST(0xc1a1cf64) // sclamp { z4.s-z7.s }, z27.s, z1.s + KAI_ASM_INST(0xc1b7cf4c) // sclamp { z12.s-z15.s }, z26.s, z23.s + KAI_ASM_INST(0xc1b7cf44) // sclamp { z4.s-z7.s }, z26.s, z23.s uzp1 z16.h, z12.h, z4.h st1b { z16.h }, p0, [x26] add x26, x26, x23 @@ -316,12 +310,12 @@ KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments st1b { z16.h }, p0, [x26] KAI_ASM_LABEL(label_15) // Store to output array: Accumulator row 1 oddments: End KAI_ASM_LABEL(label_16) // Store to output array: End - incw x10, ALL, MUL #2 - cmp x10, x11 + incw x11, ALL, MUL #2 + cmp x11, x10 blt label_2 - incw x14, ALL, MUL #2 - mov x10, #0x0 - cmp x14, x13 + incw x15, ALL, MUL #2 + mov x11, #0x0 + cmp x15, x13 mov x9, x27 blt label_1 KAI_ASM_INST(0xd503467f) // SMSTOP @@ -329,7 +323,11 @@ KAI_ASM_LABEL(label_16) // Store to output array: End ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 
+ ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c index 4058a06e..52b1b542 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c @@ -17,7 +17,7 @@ enum { }; void kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme( - size_t height, size_t width, void* in, size_t row_offset, void* out); + size_t height, size_t width, const void* in, size_t row_offset, void* out); static size_t kai_get_mr_lhs_imatmul_pack_x16p2vlx2_x16p_sme(void) { return MR * kai_get_sme_vector_length_u16() / KR; diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S index d82b08d7..cdd955de 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme) KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! 
stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA mov x8, #0x0 cnth x22 @@ -305,7 +309,11 @@ KAI_ASM_LABEL(label_13) // K loop: End ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_lhs_imatmul_pack_x16p2vlx2_x16p_sme) diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c index 4d2d272b..cd6fb60e 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c @@ -17,7 +17,7 @@ enum { }; void kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme( - size_t height, size_t width, void* in, size_t row_offset, void* out); + size_t height, size_t width, const void* in, size_t row_offset, void* out); static size_t kai_get_mr_lhs_imatmul_pack_x32p2vlx1_x32p_sme(void) { return MR * kai_get_sme_vector_length_u32() / KR; diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S index 673ae8a1..f2a144e6 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme) KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! 
stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA mov x16, #0x0 mov x21, x1 @@ -292,7 +296,11 @@ KAI_ASM_LABEL(label_13) // K loop: End ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_lhs_imatmul_pack_x32p2vlx1_x32p_sme) diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c index b5fbd436..c0c87037 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c @@ -16,7 +16,8 @@ enum { MAX_M_STEP = MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(int8_t)) / KR, }; -void kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme(size_t height, size_t width, void* in, size_t row_offset, void* out); +void kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme( + size_t height, size_t width, const void* in, size_t row_offset, void* out); static size_t kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void) { return MR * kai_get_sme_vector_length_u8() / KR; diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S index 18ea1f78..5040ab58 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme) KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! 
stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA mov x8, #0x0 cntb x21 @@ -306,7 +310,11 @@ KAI_ASM_LABEL(label_13) // K loop: End ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_lhs_imatmul_pack_x8p2vlx4_x8p_sme) diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S index 28933124..6713f68d 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! 
stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA ldr x2, [x0, #0x28] ptrue p2.b @@ -233,7 +237,11 @@ KAI_ASM_LABEL(label_14) // Bias: Done ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S index d6230333..2ff504ee 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! 
stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA ldr x8, [x0, #0x8] ptrue p1.b @@ -168,7 +172,11 @@ KAI_ASM_LABEL(label_10) // Done ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme) diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S index 9fa37cfb..6189678a 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! 
stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA ldr x14, [x0, #0x8] ptrue p2.b @@ -154,7 +158,11 @@ KAI_ASM_LABEL(label_10) // Done ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme) -- GitLab From 57d2994320461cb292ddcd605b40869dec126913 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Tue, 3 Jun 2025 13:39:06 +0200 Subject: [PATCH 6/9] Omit unused `n_0` Signed-off-by: Emil Ohlsson --- ...tmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c | 1 - 1 file changed, 1 deletion(-) diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c index c8e36d6a..4b253b26 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c @@ -22,7 +22,6 @@ typedef struct { int32_t min; int32_t max; int32_t result_zero_point; - const int n_0; void* accumulator_buffer; uint64_t flags; } KernelArgs; -- GitLab From 64645fe9932957056f1105d9a497c6c54429c674 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Tue, 3 Jun 2025 19:56:52 +0200 Subject: [PATCH 7/9] Re-apply arch feature check Signed-off-by: Emil Ohlsson --- ...matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c | 5 +++++ ...atmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c | 5 +++++ 
...clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c | 5 +++++ .../matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c | 6 ++++++ .../matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c | 6 ++++++ .../matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c | 6 ++++++ ..._rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c | 6 ++++++ .../pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c | 6 ++++++ .../pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c | 6 ++++++ 9 files changed, 51 insertions(+) diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c index 904268ff..51c869ec 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c @@ -4,6 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. #include "kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" #include @@ -98,3 +101,5 @@ void kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( kai_kernel_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(&args); } + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c index c3071ab7..8e64712c 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c @@ -4,6 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. #include "kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" #include @@ -97,3 +100,5 @@ void kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(&args); } + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c index 4b253b26..0db463d4 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c @@ -4,6 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. 
#include "kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" #include @@ -99,3 +102,5 @@ void kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(&args); } + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c index 52b1b542..b3bbdbee 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c @@ -3,6 +3,10 @@ // // SPDX-License-Identifier: Apache-2.0 // + +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. #include "kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" #include @@ -72,3 +76,5 @@ void kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme( } } } + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c index cd6fb60e..4f092599 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c @@ -3,6 +3,10 @@ // // SPDX-License-Identifier: Apache-2.0 // + +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. #include "kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h" #include @@ -72,3 +76,5 @@ void kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme( } } } + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c index c0c87037..53799234 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c @@ -3,6 +3,10 @@ // // SPDX-License-Identifier: Apache-2.0 // + +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. #include "kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h" #include @@ -72,3 +76,5 @@ void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( } } } + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c index 40922c37..c8b814c6 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -3,6 +3,10 @@ // // SPDX-License-Identifier: Apache-2.0 // + +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. #include "kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" #include @@ -109,3 +113,5 @@ void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(&args); } + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c index 228f3c0b..4cc50d1d 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c @@ -3,6 +3,10 @@ // // SPDX-License-Identifier: Apache-2.0 // + +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. #include "kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h" #include @@ -95,3 +99,5 @@ void kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(&args); } + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c index d56d329b..2a7c8100 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c @@ -3,6 +3,10 @@ // // SPDX-License-Identifier: Apache-2.0 // + +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. #include "kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h" #include @@ -88,3 +92,5 @@ void kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(&args); } + +#endif // Architectural features check. 
-- GitLab From 2c4115a58f4e43a709a2ce2137ee1c365f6158cc Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Wed, 4 Jun 2025 10:10:05 +0200 Subject: [PATCH 8/9] Update conv2d example for pure asm Signed-off-by: Emil Ohlsson --- .../CMakeLists.txt | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt b/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt index a9499b84..0034c3dd 100644 --- a/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt +++ b/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt @@ -8,6 +8,12 @@ cmake_minimum_required(VERSION 3.16) project(conv2d_imatmul_clamp_f16_f16_f16p_sme2) +if(MSVC) + enable_language(ASM_MARMASM) +else() + enable_language(ASM) +endif() + set(CMAKE_CXX_STANDARD 17) set(KAI_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../) set(KAI_BUILD ${KAI_PATH}/build) @@ -15,14 +21,19 @@ set(KAI_BUILD ${KAI_PATH}/build) include_directories(${KAI_PATH}) set(KAI_SOURCES + ${KAI_PATH}/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S ${KAI_PATH}/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c + ${KAI_PATH}/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S ${KAI_PATH}/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c - ${KAI_PATH}/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c) + ${KAI_PATH}/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S + ${KAI_PATH}/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c +) # Files requires to build the executable -add_executable( - conv2d_imatmul_clamp_f16_f16_f16p_sme2 conv2d_imatmul_clamp_f16_f16_f16p.cpp - ${KAI_SOURCES}) +add_executable(conv2d_imatmul_clamp_f16_f16_f16p_sme2 + conv2d_imatmul_clamp_f16_f16_f16p.cpp + ${KAI_SOURCES} +) 
target_compile_options(conv2d_imatmul_clamp_f16_f16_f16p_sme2 PRIVATE "-march=armv8.2-a+sve+sve2;-fno-tree-vectorize" -- GitLab From 37458271645be720cb6b06fd737c9fe6b970afd8 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Thu, 5 Jun 2025 13:06:02 +0200 Subject: [PATCH 9/9] Fix build and unsupported instructions Signed-off-by: Emil Ohlsson --- CMakeLists.txt | 2 ++ .../CMakeLists.txt | 6 +----- ...p_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S | 4 ++-- ..._f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S | 4 ++-- ...qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S | 6 +++--- .../kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h | 2 +- test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 12 ++++++------ 7 files changed, 17 insertions(+), 19 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f46bc7f7..4a36b210 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -324,6 +324,8 @@ else() set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set(KLEIDIAI_FILES_ASM + ${KLEIDIAI_FILES_SME_ASM} + ${KLEIDIAI_FILES_SME2_ASM} ${KLEIDIAI_FILES_NEON_DOTPROD_ASM} ${KLEIDIAI_FILES_NEON_I8MM_ASM}) list(FILTER KLEIDIAI_FILES_ASM INCLUDE REGEX "^.*\.S$") diff --git a/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt b/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt index 0034c3dd..afc8a4c5 100644 --- a/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt +++ b/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt @@ -8,11 +8,7 @@ cmake_minimum_required(VERSION 3.16) project(conv2d_imatmul_clamp_f16_f16_f16p_sme2) -if(MSVC) - enable_language(ASM_MARMASM) -else() - enable_language(ASM) -endif() +enable_language(ASM) set(CMAKE_CXX_STANDARD 17) set(KAI_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../) diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S 
b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S index 259e07f9..d6bae7c7 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S @@ -155,11 +155,11 @@ KAI_ASM_LABEL(label_8) // K oddments: End ldr x25, [x0, #0x10] sub x24, x11, x14 cntw x23, ALL, MUL #2 - ld1rh { z17.h }, p1/Z, [x0, #56] + KAI_ASM_INST(0x84dca411) // ld1rh { z17.h }, p1/Z, [x0, #56] ldr x22, [x0, #0x18] whilelt p0.h, x10, x9 cmp x24, x23 - ld1rh { z16.h }, p1/Z, [x0, #58] + KAI_ASM_INST(0x84dda410) // ld1rh { z16.h }, p1/Z, [x0, #58] mov x12, #0x0 mov x21, #0x0 add x25, x25, x10, LSL #1 // C += n diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S index b355fbe9..71bcc59d 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa_asm.S @@ -143,10 +143,10 @@ KAI_ASM_LABEL(label_8) // K oddments: End ldr x26, [x0, #0x10] sub x25, x13, x15 cntw x24 - ld1rw { z26.s }, p0/Z, [x0, #56] + KAI_ASM_INST(0x854ec01a) // ld1rw { z26.s }, p0/Z, [x0, #56] ldr x23, [x0, #0x18] cmp x25, x24 - ld1rw { z24.s }, p0/Z, [x0, #60] + KAI_ASM_INST(0x854fc018) // ld1rw { z24.s }, p0/Z, [x0, #60] mov x12, #0x0 csel x22, x25, x24, LT add x26, x26, x11, LSL #2 // C += n diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S 
b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S index c08d2dea..750c7c6b 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S @@ -143,13 +143,13 @@ KAI_ASM_LABEL(label_8) // K oddments: End ldr x26, [x0, #0x10] sub x25, x13, x15 cntw x24 - ld1rw { z26.s }, p1/Z, [x0, #56] + KAI_ASM_INST(0x854ec41a) // ld1rw { z26.s }, p1/Z, [x0, #56] ldr x23, [x0, #0x18] whilelt p0.h, x11, x10 cmp x25, x24 - ld1rw { z23.s }, p1/Z, [x0, #60] + KAI_ASM_INST(0x854fc417) // ld1rw { z23.s }, p1/Z, [x0, #60] csel x22, x25, x24, LT - ld1rw { z0.s }, p1/Z, [x0, #64] + KAI_ASM_INST(0x8550c400) // ld1rw { z0.s }, p1/Z, [x0, #64] mov x12, #0x0 add x26, x26, x11 // C += n lsr x21, x22, #0x2 diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h index a2527662..c5154668 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h @@ -22,7 +22,7 @@ typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_lhs_packed_offset_func typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_rhs_packed_offset_func_t)( size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_offset_func_t)( - size_t m_idx, size_t n_idx, size_t dst_stride); + size_t m_idx, size_t n_idx, size_t dst_stride_row); typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_size_func_t)(size_t m, size_t n); /// Micro-kernel core function ("run" 
method) diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 27daa65f..05044804 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -124,8 +124,8 @@ struct MatMulKernel { struct MatMulIndirectKernel { std::function get_m_step; std::function get_n_step; - std::function get_packed_lhs_offset; - std::function get_packed_rhs_offset; + std::function get_lhs_packed_offset; + std::function get_rhs_packed_offset; std::function get_dst_offset; std::function get_dst_size; std::function& get_indirect_gemm_variants() { variants[0].rhs_pack.pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; variants[0].matmul.get_m_step = ukernel.get_m_step; variants[0].matmul.get_n_step = ukernel.get_n_step; - variants[0].matmul.get_packed_lhs_offset = ukernel.get_lhs_packed_offset; - variants[0].matmul.get_packed_rhs_offset = ukernel.get_rhs_packed_offset; + variants[0].matmul.get_lhs_packed_offset = ukernel.get_lhs_packed_offset; + variants[0].matmul.get_rhs_packed_offset = ukernel.get_rhs_packed_offset; variants[0].matmul.get_dst_offset = ukernel.get_dst_offset; variants[0].matmul.get_dst_size = ukernel.get_dst_size; variants[0].matmul.imatmul = ukernel.run_imatmul; @@ -845,8 +845,8 @@ static Buffer matmul( const Buffer& packed_rhs, const MatMulShape& shape, const KChunk& k_chunk) { // Calculate portion offsets. 
size_t dst_offset = variant.get_dst_offset(portion.start_row(), portion.start_col(), shape.n); - size_t lhs_offset = variant.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); - size_t rhs_offset = variant.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length); + size_t lhs_offset = variant.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length); + size_t rhs_offset = variant.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length); // Allocate output buffer const size_t dst_size = variant.get_dst_size(shape.m, shape.n); -- GitLab