From b754463d36363d3af4dad073726bdb499b1a3f3f Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Thu, 22 May 2025 15:32:13 +0200 Subject: [PATCH 1/4] Split INT8 SME kernels into separate assembly files. Move the assembly blocks of the following kernels into their own files: - lhs_pack_x8p2vlx4_x8_sme - rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme - matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot - matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa Signed-off-by: Jens Elofsson --- CMakeLists.txt | 28 +- kai/ukernels/matmul/BUILD.bazel | 32 +- ...qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c | 900 +---------------- ..._qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S | 907 ++++++++++++++++++ ...8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c | 360 +------ ...8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h | 26 +- ...lx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S | 336 +++++++ .../pack/kai_lhs_pack_x8p2vlx4_x8_sme.c | 364 ++----- .../pack/kai_lhs_pack_x8p2vlx4_x8_sme.h | 8 +- .../pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S | 318 ++++++ ...ack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c | 260 ++--- ...ack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h | 28 +- ...kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S | 234 +++++ 13 files changed, 2068 insertions(+), 1733 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S create mode 100644 kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index ea7b9926..f0b9f0d4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -221,25 +221,39 @@ set(KLEIDIAI_FILES_NEON_I8MM 
kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c ) +set(KLEIDIAI_FILES_SME_ASM + kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S +) + set(KLEIDIAI_FILES_SME + ${KLEIDIAI_FILES_SME_ASM} kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c ) +set(KLEIDIAI_FILES_SME2_ASM + kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c + kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S + kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c + 
kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S +) + set(KLEIDIAI_FILES_SME2 + ${KLEIDIAI_FILES_SME2_ASM} kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c @@ -253,8 +267,6 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c - kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c ) @@ -293,14 +305,20 @@ if(NOT MSVC) else() target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM_ASM}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME_ASM}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2_ASM}) set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + 
set_source_files_properties(${KLEIDIAI_FILES_SME_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SME2_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set(KLEIDIAI_FILES_ASM ${KLEIDIAI_FILES_NEON_DOTPROD_ASM} - ${KLEIDIAI_FILES_NEON_I8MM_ASM}) + ${KLEIDIAI_FILES_NEON_I8MM_ASM} + ${KLEIDIAI_FILES_SME_ASM} + ${KLEIDIAI_FILES_SME2_ASM}) list(FILTER KLEIDIAI_FILES_ASM INCLUDE REGEX "^.*\.S$") set_source_files_properties(${KLEIDIAI_FILES_ASM} PROPERTIES LANGUAGE ASM_MARMASM) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 2da2800c..40ed5216 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -136,6 +136,12 @@ I8MM_KERNELS_ASM = [ "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", ] +# buildifier: keep sorted +SME_KERNELS_ASM = [ + "pack/kai_lhs_pack_x8p2vlx4_x8_sme", + "pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", +] + # buildifier: keep sorted SME_KERNELS = [ "pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", @@ -144,19 +150,23 @@ SME_KERNELS = [ "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_lhs_pack_f32p2vlx1_f32_sme", "pack/kai_lhs_pack_x16p2vlx2_x16_sme", - "pack/kai_lhs_pack_x8p2vlx4_x8_sme", "pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", "pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme", "pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", - "pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", "pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme", "pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme", "pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme", ] +# buildifier: keep sorted +SME2_KERNELS_ASM = [ + 
"matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot", + "matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", +] + # buildifier: keep sorted SME2_KERNELS = [ "imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", @@ -172,8 +182,6 @@ SME2_KERNELS = [ "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot", "matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", - "matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot", - "matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", ] kai_c_library( @@ -265,6 +273,13 @@ kai_c_library( textual_hdrs = [ukernel + ".h" for ukernel in I8MM_KERNELS_ASM], ) +kai_c_library( + name = "sme_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in SME_KERNELS_ASM] + [ukernel + ".c" for ukernel in SME_KERNELS_ASM], + cpu_uarch = kai_cpu_sme(), + textual_hdrs = [ukernel + ".h" for ukernel in SME_KERNELS_ASM], +) + kai_c_library( name = "sme_impl", srcs = [ukernel + ".c" for ukernel in SME_KERNELS], @@ -272,6 +287,13 @@ kai_c_library( textual_hdrs = [ukernel + ".h" for ukernel in SME_KERNELS], ) +kai_c_library( + name = "sme2_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in SME2_KERNELS_ASM] + [ukernel + ".c" for ukernel in SME2_KERNELS_ASM], + cpu_uarch = kai_cpu_sme2(), + textual_hdrs = [ukernel + ".h" for ukernel in SME2_KERNELS_ASM], +) + kai_c_library( name = "sme2_impl", srcs = [ukernel + ".c" for ukernel in SME2_KERNELS], @@ -297,6 +319,8 @@ kai_c_library( ":neon_impl_asm", ":scalar_impl", ":sme2_impl", + ":sme2_impl_asm", ":sme_impl", + ":sme_impl_asm", ], ) diff --git 
a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c index e638e69e..6a5f9536 100644 --- a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c @@ -7,7 +7,7 @@ // Do not flag up inline assembly blocks #pragma GCC diagnostic ignored "-Woverlength-strings" -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. @@ -18,6 +18,20 @@ #include "kai/kai_common.h" +typedef struct { + int32_t c_offset; + int32_t maxval; + int32_t minval; + const void* A_ptr; + const void* B_ptr; + size_t N; + size_t K; + void* output_ptr; + uint64_t flags; +} KernelArgs; + +void kai_kernel_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(KernelArgs* args_ptr); + static const size_t kai_m_step = 1; static const size_t kai_nr = 2; static const size_t kai_n_step = 16; @@ -79,881 +93,23 @@ void kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot( size_t dst_stride_col, const struct kai_matmul_requantize32_params* params) { KAI_UNUSED(dst_stride_row); KAI_UNUSED(dst_stride_col); - KAI_ASSUME(m == 1); - typedef struct { - int32_t c_offset; - int32_t maxval; - int32_t minval; - } KernelArgs; - - KernelArgs k_args; - k_args.maxval = params->max_value; - k_args.minval = params->min_value; - k_args.c_offset = params->output_zero_point; - - size_t N = n; - size_t K = k; + KAI_ASSUME(m == 1); - const void* A_ptr = lhs; - const void* B_ptr = rhs_packed; - void* output_ptr = dst; + uint64_t flags = 2; - uint64_t flags = 0; + KernelArgs args; + args.c_offset = 
params->output_zero_point; + args.maxval = params->max_value; + args.minval = params->min_value; + args.A_ptr = lhs; + args.B_ptr = rhs_packed; + args.N = n; + args.K = k; + args.output_ptr = dst; + args.flags = flags; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x8, #0x0\n" - "mov x16, %x[B_ptr]\n" - "cntw x15, ALL, MUL #4\n" - "mov x14, %x[output_ptr]\n" - "add x13, %x[N], x15\n" - "ptrue p2.b\n" - "sub x13, x13, #0x1\n" - ".inst 0x25207810 // ptrue pn8.b\n" - "udiv x13, x13, x15\n" - "mov x22, #0x1\n" - "add x21, x13, #0x3\n" - "and x21, x21, #0xfffffffffffffffc\n" - "mul x21, x21, x15\n" - "mul x21, x21, %x[K]\n" - "1:" // RHS size check loop - "cmp x21, #0x200000\n" - "blt 2f\n" - "tbnz x21, #0, 3f\n" - "lsr x21, x21, #0x1\n" - "lsl x22, x22, #0x1\n" - "b 1b\n" - "2:" // RHS do prefetch - "lsl x20, x21, #0x26\n" - "sub x22, x22, #0x1\n" - "lsl x22, x22, #0x16\n" - "orr x21, x21, x20\n" - "orr x21, x21, x22\n" - ".inst 0xf8b54a1a // rprfm pldonce, x21, [x16]\n" - "3:" // RHS prefetch exit - "add x12, %x[K], #0x3\n" - "cntw x20, ALL, MUL #2\n" - "mov z25.s, #0x0\n" - "mov z27.b, #0x1\n" - "bic x12, x12, #0x3\n" - "bic %x[flags], %x[flags], #0x80000000\n" - "add x12, x12, #0x8\n" - "mul x12, x12, x20\n" - "4:" // Column loop - "cmp x13, #0x4\n" - "bge 25f\n" - "cmp x13, #0x2\n" - "bgt 18f\n" - "beq 11f\n" - "cntw x20, ALL, MUL #2\n" - "add x23, x16, x12\n" - ".inst 0xa0404210 // ld1w { z16.s-z17.s }, pn8.b/Z, [x16]\n" - "cmp %x[N], x20\n" - "mov x11, %x[K]\n" - "csel x23, x23, x16, GT\n" - "mov x21, %x[N]\n" - ".inst 0xa04042f2 // ld1w { z18.s-z19.s }, pn8.b/Z, [x23]\n" - "mov x10, %x[A_ptr]\n" - "mov x20, %x[K]\n" - "whilelt p1.b, XZR, x21\n" - "cmp x11, #0x10\n" - ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" - "addvl x16, x16, #2\n" - "addvl x23, x23, #2\n" - ".inst 0xc0040e00 // mova za.d[x8, #0], { z16.d-z19.d }\n" - "ble 7f\n" - "5:" // Width 1: Multiply loop: Main loop head - "whilelt p0.b, XZR, x11\n" - ".inst 0xa040021d // 
ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - "ld1rqb { z13.b }, p0/Z, [x10]\n" - "add x10, x10, #0x10\n" - ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002f7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc15d93a0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[0]\n" - ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc15d96a0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[1]\n" - ".inst 0xc15d9a20 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[2]\n" - ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" - "tbnz %x[flags], #31, 6f\n" - "sdot z25.s, z13.b, z27.b\n" - "6:" // Width 1: Multiply loop: unique 1: skip row sum - "sub x11, x11, #0x10\n" - "cmp x11, #0x10\n" - "bgt 5b\n" - "7:" // Width 1: Multiply loop: Single iteration only - "whilelt p0.b, XZR, x11\n" - ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" - "subs x11, x11, #0x4\n" - "ld1rqb { z13.b }, p0/Z, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002e7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc15d90a0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[0]\n" - "ble 8f\n" - ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" - "subs x11, x11, #0x4\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc15d97a0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[1]\n" - "ble 8f\n" - ".inst 0xa0400209 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x16]\n" - "subs x11, x11, #0x4\n" 
- "addvl x16, x16, #2\n" - ".inst 0xa04002eb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc15d9920 // sdot za.s[x8, 0], { z8.b-z11.b }, z13.b[2]\n" - "ble 8f\n" - ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" - "8:" // Width 1: Multiply loop: multiply skip - "tbnz %x[flags], #31, 9f\n" - "9:" // Width 1: Multiply loop: unique 2: skip row sum - ".inst 0xc0060c08 // mova { z8.d-z11.d }, za.d[x8, #0]\n" - ".inst 0xa040421e // ld1w { z30.s-z31.s }, pn8.b/Z, [x16]\n" - "add x22, %x[k_args], %[c_offset]\n" - "add x21, %x[k_args], %[minval]\n" - ".inst 0xa04042f8 // ld1w { z24.s-z25.s }, pn8.b/Z, [x23]\n" - "add x20, %x[k_args], %[maxval]\n" - "ld1rw { z2.s }, p2/Z, [x22]\n" - "ld1rw { z13.s }, p2/Z, [x21]\n" - ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" - "ld1rw { z20.s }, p2/Z, [x20]\n" - "fmul z8.s, z8.s, z30.s\n" - "fmul z9.s, z9.s, z31.s\n" - "fmul z10.s, z10.s, z24.s\n" - "fmul z11.s, z11.s, z25.s\n" - ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc1a2ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z2.s\n" - ".inst 0xc1b4cda8 // sclamp { z8.s-z11.s }, z13.s, z20.s\n" - "uzp1 z8.h, z8.h, z9.h\n" - "uzp1 z0.h, z10.h, z11.h\n" - "uzp1 z8.b, z8.b, z0.b\n" - "st1b { z8.b }, p1, [x14]\n" - "b 32f\n" - "11:" // Width 2 - "add x24, x16, x12, LSL #1\n" - "cntw x20, ALL, MUL #6\n" - ".inst 0xa0404214 // ld1w { z20.s-z21.s }, pn8.b/Z, [x16]\n" - "add x22, x24, x12\n" - "cmp %x[N], x20\n" - ".inst 0xa040430c // ld1w { z12.s-z13.s }, pn8.b/Z, [x24]\n" - "add x23, x16, x12\n" - "csel x22, x22, x16, GT\n" - ".inst 0xa04042f6 // ld1w { z22.s-z23.s }, pn8.b/Z, [x23]\n" - "mov x11, %x[K]\n" - "sub x21, %x[N], x15\n" - ".inst 0xa04042ce // ld1w { 
z14.s-z15.s }, pn8.b/Z, [x22]\n" - "mov x10, %x[A_ptr]\n" - "mov x20, %x[K]\n" - "whilelt p1.b, XZR, x21\n" - "cmp x11, #0x10\n" - ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xc0040e80 // mova za.d[x8, #0], { z20.d-z23.d }\n" - "addvl x23, x23, #2\n" - "addvl x24, x24, #2\n" - ".inst 0xc0040d81 // mova za.d[x8, #1], { z12.d-z15.d }\n" - "addvl x22, x22, #2\n" - "ble 14f\n" - "12:" // Width 2: Multiply loop: Main loop head - "whilelt p0.b, XZR, x11\n" - ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - "ld1rqb { z13.b }, p0/Z, [x10]\n" - "add x10, x10, #0x10\n" - ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0400305 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc15d9220 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[0]\n" - ".inst 0xa0400209 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002eb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc15d90a1 // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[0]\n" - ".inst 0xa0400311 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04002d3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc15d9520 // sdot za.s[x8, 0], { z8.b-z11.b }, z13.b[1]\n" - ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002e3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc15d9621 // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[1]\n" - ".inst 0xa0400305 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc15d9820 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[2]\n" - ".inst 0xa040021d // ldnt1b 
{ z28.b-z29.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc15d98a1 // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[2]\n" - ".inst 0xa0400309 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04002cb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" - ".inst 0xc15d9d21 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[3]\n" - "tbnz %x[flags], #31, 13f\n" - "sdot z25.s, z13.b, z27.b\n" - "13:" // Width 2: Multiply loop: unique 3: skip row sum - "sub x11, x11, #0x10\n" - "cmp x11, #0x10\n" - "bgt 12b\n" - "14:" // Width 2: Multiply loop: Single iteration only - "whilelt p0.b, XZR, x11\n" - ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" - "subs x11, x11, #0x4\n" - "ld1rqb { z13.b }, p0/Z, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0400301 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04002c3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc15d9220 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[0]\n" - ".inst 0xc15d9021 // sdot za.s[x8, 1], { z0.b-z3.b }, z13.b[0]\n" - "ble 15f\n" - ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" - "subs x11, x11, #0x4\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002f7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa040031d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc15d96a0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[1]\n" - ".inst 0xc15d97a1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[1]\n" - "ble 15f\n" - ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" - "subs x11, x11, 
#0x4\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002e3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0400319 // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04002db // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc15d9820 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[2]\n" - ".inst 0xc15d9b21 // sdot za.s[x8, 1], { z24.b-z27.b }, z13.b[2]\n" - "ble 15f\n" - ".inst 0xa0400219 // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002fb // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa040031d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x24]\n" - ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" - ".inst 0xc15d9f20 // sdot za.s[x8, 0], { z24.b-z27.b }, z13.b[3]\n" - ".inst 0xc15d9fa1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[3]\n" - "15:" // Width 2: Multiply loop: multiply skip - "tbnz %x[flags], #31, 16f\n" - "16:" // Width 2: Multiply loop: unique 4: skip row sum - ".inst 0xc0060c00 // mova { z0.d-z3.d }, za.d[x8, #0]\n" - ".inst 0xa0404208 // ld1w { z8.s-z9.s }, pn8.b/Z, [x16]\n" - "add x22, %x[k_args], %[c_offset]\n" - "add x21, %x[k_args], %[minval]\n" - ".inst 0xa04042fe // ld1w { z30.s-z31.s }, pn8.b/Z, [x23]\n" - "add x20, %x[k_args], %[maxval]\n" - ".inst 0xc0060c24 // mova { z4.d-z7.d }, za.d[x8, #1]\n" - "add x16, x16, x12, LSL #1\n" - "ld1rw { z14.s }, p2/Z, [x22]\n" - "add x23, x23, x12, LSL #1\n" - "ld1rw { z11.s }, p2/Z, [x21]\n" - ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" - "ld1rw { z10.s }, p2/Z, [x20]\n" - "fmul z0.s, z0.s, z8.s\n" - "fmul z1.s, z1.s, z9.s\n" - ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" - "fmul z2.s, z2.s, z30.s\n" - "fmul z3.s, z3.s, z31.s\n" - ".inst 0xc1b8e000 // frintn { z0.s-z3.s }, { z0.s-z3.s }\n" - ".inst 0xc131e000 // fcvtzs { z0.s-z3.s }, { z0.s-z3.s }\n" - ".inst 0xc1aeab00 // add { z0.s-z3.s }, { z0.s-z3.s }, z14.s\n" - ".inst 
0xc1aacd60 // sclamp { z0.s-z3.s }, z11.s, z10.s\n" - "uzp1 z0.h, z0.h, z1.h\n" - "uzp1 z16.h, z2.h, z3.h\n" - "uzp1 z0.b, z0.b, z16.b\n" - "st1b { z0.b }, p2, [x14]\n" - ".inst 0xa1404217 // ld1w { z23.s, z31.s }, pn8.b/Z, [x16]\n" - ".inst 0xa14042f6 // ld1w { z22.s, z30.s }, pn8.b/Z, [x23]\n" - "fmul z4.s, z4.s, z23.s\n" - "fmul z5.s, z5.s, z31.s\n" - "fmul z6.s, z6.s, z22.s\n" - "fmul z7.s, z7.s, z30.s\n" - ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc1aeab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z14.s\n" - ".inst 0xc1aacd64 // sclamp { z4.s-z7.s }, z11.s, z10.s\n" - "uzp1 z4.h, z4.h, z5.h\n" - "uzp1 z2.h, z6.h, z7.h\n" - "uzp1 z4.b, z4.b, z2.b\n" - "st1b { z4.b }, p1, [x14, #1, MUL VL]\n" - "b 32f\n" - "18:" // Width 3 - "add x26, x16, x12, LSL #2\n" - "cntw x20, ALL, MUL #10\n" - ".inst 0xa0404210 // ld1w { z16.s-z17.s }, pn8.b/Z, [x16]\n" - "add x25, x16, x12, LSL #1\n" - "add x24, x26, x12\n" - ".inst 0xa040435c // ld1w { z28.s-z29.s }, pn8.b/Z, [x26]\n" - "cmp %x[N], x20\n" - "add x23, x16, x12\n" - ".inst 0xa040432c // ld1w { z12.s-z13.s }, pn8.b/Z, [x25]\n" - "add x22, x25, x12\n" - "csel x24, x24, x16, GT\n" - ".inst 0xa04042f2 // ld1w { z18.s-z19.s }, pn8.b/Z, [x23]\n" - "mov x20, #0x2\n" - ".inst 0xa04042ce // ld1w { z14.s-z15.s }, pn8.b/Z, [x22]\n" - "mov x11, %x[K]\n" - ".inst 0xa040431e // ld1w { z30.s-z31.s }, pn8.b/Z, [x24]\n" - "msub x21, x15, x20, %x[N]\n" - "mov x10, %x[A_ptr]\n" - "mov x20, %x[K]\n" - "whilelt p1.b, XZR, x21\n" - ".inst 0xc0040e00 // mova za.d[x8, #0], { z16.d-z19.d }\n" - "cmp x11, #0x10\n" - ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" - ".inst 0xc0040d81 // mova za.d[x8, #1], { z12.d-z15.d }\n" - "addvl x16, x16, #2\n" - "addvl x23, x23, #2\n" - ".inst 0xc0040f82 // mova za.d[x8, #2], { z28.d-z31.d }\n" - "addvl x25, x25, #2\n" - "addvl x22, x22, #2\n" - "addvl x26, x26, #2\n" - "addvl x24, x24, #2\n" - "ble 21f\n" - "19:" // Width 
3: Multiply loop: Main loop head - "whilelt p0.b, XZR, x11\n" - ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - "ld1rqb { z13.b }, p0/Z, [x10]\n" - "add x10, x10, #0x10\n" - ".inst 0xa04002e3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0400329 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04002cb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400351 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x26]\n" - ".inst 0xc15d9020 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[0]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc15d9121 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[0]\n" - ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc15d9222 // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[0]\n" - ".inst 0xa0400321 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04002c3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400345 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x26]\n" - ".inst 0xc15d97a0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[1]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0400307 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc15d9421 // sdot za.s[x8, 1], { z0.b-z3.b }, z13.b[1]\n" - ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002f7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc15d94a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[1]\n" - ".inst 0xa040033d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400345 // ldnt1b { 
z4.b-z5.b }, pn8.b/Z, [x26]\n" - ".inst 0xc15d9aa0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[2]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0400307 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc15d9ba1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[2]\n" - ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc15d98a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[2]\n" - ".inst 0xa0400335 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04002d7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400351 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x26]\n" - ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc15d9ea1 // sdot za.s[x8, 1], { z20.b-z23.b }, z13.b[3]\n" - ".inst 0xc15d9e22 // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[3]\n" - "tbnz %x[flags], #31, 20f\n" - "sdot z25.s, z13.b, z27.b\n" - "20:" // Width 3: Multiply loop: unique 5: skip row sum - "sub x11, x11, #0x10\n" - "cmp x11, #0x10\n" - "bgt 19b\n" - "21:" // Width 3: Multiply loop: Single iteration only - "whilelt p0.b, XZR, x11\n" - ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" - "subs x11, x11, #0x4\n" - "ld1rqb { z13.b }, p0/Z, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0400325 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa040035d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x26]\n" - ".inst 0xc15d9220 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[0]\n" - "addvl x26, x26, #2\n" - ".inst 0xa040031f // ldnt1b { z30.b-z31.b }, 
pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc15d90a1 // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[0]\n" - ".inst 0xc15d93a2 // sdot za.s[x8, 2], { z28.b-z31.b }, z13.b[0]\n" - "ble 22f\n" - ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" - "subs x11, x11, #0x4\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0400329 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04002cb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400345 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x26]\n" - ".inst 0xc15d9620 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[1]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0400307 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc15d9521 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[1]\n" - ".inst 0xc15d94a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[1]\n" - "ble 22f\n" - ".inst 0xa0400209 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x16]\n" - "subs x11, x11, #0x4\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002eb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa040033d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400359 // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x26]\n" - ".inst 0xc15d9920 // sdot za.s[x8, 0], { z8.b-z11.b }, z13.b[2]\n" - "addvl x26, x26, #2\n" - ".inst 0xa040031b // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc15d9ba1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[2]\n" - ".inst 0xc15d9b22 // sdot za.s[x8, 2], { z24.b-z27.b }, z13.b[2]\n" - "ble 22f\n" - ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0400339 // ldnt1b { z24.b-z25.b }, pn8.b/Z, 
[x25]\n" - ".inst 0xa04002db // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x22]\n" - ".inst 0xa0400355 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x26]\n" - ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" - ".inst 0xa0400317 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x24]\n" - ".inst 0xc15d9f21 // sdot za.s[x8, 1], { z24.b-z27.b }, z13.b[3]\n" - ".inst 0xc15d9ea2 // sdot za.s[x8, 2], { z20.b-z23.b }, z13.b[3]\n" - "22:" // Width 3: Multiply loop: multiply skip - "tbnz %x[flags], #31, 23f\n" - "23:" // Width 3: Multiply loop: unique 6: skip row sum - ".inst 0xc0060c18 // mova { z24.d-z27.d }, za.d[x8, #0]\n" - ".inst 0xa0404202 // ld1w { z2.s-z3.s }, pn8.b/Z, [x16]\n" - "add x22, %x[k_args], %[c_offset]\n" - "add x21, %x[k_args], %[minval]\n" - ".inst 0xa04042e6 // ld1w { z6.s-z7.s }, pn8.b/Z, [x23]\n" - "add x20, %x[k_args], %[maxval]\n" - ".inst 0xc0060c28 // mova { z8.d-z11.d }, za.d[x8, #1]\n" - "add x16, x16, x12, LSL #1\n" - "ld1rw { z0.s }, p2/Z, [x22]\n" - "add x23, x23, x12, LSL #1\n" - ".inst 0xc0060c5c // mova { z28.d-z31.d }, za.d[x8, #2]\n" - "ld1rw { z19.s }, p2/Z, [x21]\n" - ".inst 0xc132e318 // scvtf { z24.s-z27.s }, { z24.s-z27.s }\n" - "ld1rw { z18.s }, p2/Z, [x20]\n" - "fmul z24.s, z24.s, z2.s\n" - "fmul z25.s, z25.s, z3.s\n" - ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" - "fmul z26.s, z26.s, z6.s\n" - "fmul z27.s, z27.s, z7.s\n" - ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" - ".inst 0xc1b8e318 // frintn { z24.s-z27.s }, { z24.s-z27.s }\n" - ".inst 0xc131e318 // fcvtzs { z24.s-z27.s }, { z24.s-z27.s }\n" - ".inst 0xc1a0ab18 // add { z24.s-z27.s }, { z24.s-z27.s }, z0.s\n" - ".inst 0xc1b2ce78 // sclamp { z24.s-z27.s }, z19.s, z18.s\n" - "uzp1 z24.h, z24.h, z25.h\n" - "uzp1 z16.h, z26.h, z27.h\n" - "uzp1 z24.b, z24.b, z16.b\n" - "st1b { z24.b }, p2, [x14]\n" - ".inst 0xa1404207 // ld1w { z7.s, z15.s }, pn8.b/Z, [x16]\n" - "add x16, x16, x12, LSL #1\n" - ".inst 0xa14042f1 // ld1w { z17.s, z25.s }, pn8.b/Z, [x23]\n" - 
"add x23, x23, x12, LSL #1\n" - "fmul z8.s, z8.s, z7.s\n" - "fmul z9.s, z9.s, z15.s\n" - "fmul z10.s, z10.s, z17.s\n" - "fmul z11.s, z11.s, z25.s\n" - ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc1a0ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s\n" - ".inst 0xc1b2ce68 // sclamp { z8.s-z11.s }, z19.s, z18.s\n" - "uzp1 z8.h, z8.h, z9.h\n" - "uzp1 z16.h, z10.h, z11.h\n" - "uzp1 z8.b, z8.b, z16.b\n" - "st1b { z8.b }, p2, [x14, #1, MUL VL]\n" - ".inst 0xa1404207 // ld1w { z7.s, z15.s }, pn8.b/Z, [x16]\n" - ".inst 0xa14042f1 // ld1w { z17.s, z25.s }, pn8.b/Z, [x23]\n" - "fmul z28.s, z28.s, z7.s\n" - "fmul z29.s, z29.s, z15.s\n" - "fmul z30.s, z30.s, z17.s\n" - "fmul z31.s, z31.s, z25.s\n" - ".inst 0xc1b8e39c // frintn { z28.s-z31.s }, { z28.s-z31.s }\n" - ".inst 0xc131e39c // fcvtzs { z28.s-z31.s }, { z28.s-z31.s }\n" - ".inst 0xc1a0ab1c // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s\n" - ".inst 0xc1b2ce7c // sclamp { z28.s-z31.s }, z19.s, z18.s\n" - "uzp1 z28.h, z28.h, z29.h\n" - "uzp1 z16.h, z30.h, z31.h\n" - "uzp1 z28.b, z28.b, z16.b\n" - "st1b { z28.b }, p1, [x14, #2, MUL VL]\n" - "b 32f\n" - "25:" // Width 4 - "add x9, x16, x12, LSL #2\n" - "cntw x20, ALL, MUL #14\n" - ".inst 0xa040420c // ld1w { z12.s-z13.s }, pn8.b/Z, [x16]\n" - "add x28, x9, x12, LSL #1\n" - "add x27, x16, x12, LSL #1\n" - ".inst 0xa0404124 // ld1w { z4.s-z5.s }, pn8.b/Z, [x9]\n" - "add x26, x28, x12\n" - "cmp %x[N], x20\n" - ".inst 0xa0404368 // ld1w { z8.s-z9.s }, pn8.b/Z, [x27]\n" - "add x25, x16, x12\n" - "add x24, x27, x12\n" - ".inst 0xa0404380 // ld1w { z0.s-z1.s }, pn8.b/Z, [x28]\n" - "add x22, x9, x12\n" - "csel x26, x26, x16, GT\n" - ".inst 0xa040432e // ld1w { z14.s-z15.s }, pn8.b/Z, [x25]\n" - "mov x20, #0x3\n" - ".inst 0xa040430a // ld1w { z10.s-z11.s }, pn8.b/Z, [x24]\n" - "mov x11, %x[K]\n" - ".inst 0xa04042c6 // ld1w { z6.s-z7.s }, pn8.b/Z, [x22]\n" - "msub x21, x15, x20, %x[N]\n" - "mov 
x10, %x[A_ptr]\n" - ".inst 0xa0404342 // ld1w { z2.s-z3.s }, pn8.b/Z, [x26]\n" - "mov x20, %x[K]\n" - "whilelt p1.b, XZR, x21\n" - ".inst 0xc0040d80 // mova za.d[x8, #0], { z12.d-z15.d }\n" - "cmp x11, #0x10\n" - ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" - ".inst 0xc0040d01 // mova za.d[x8, #1], { z8.d-z11.d }\n" - "add x23, x16, x12, LSL #3\n" - "addvl x16, x16, #2\n" - ".inst 0xc0040c82 // mova za.d[x8, #2], { z4.d-z7.d }\n" - "addvl x25, x25, #2\n" - "addvl x27, x27, #2\n" - ".inst 0xc0040c03 // mova za.d[x8, #3], { z0.d-z3.d }\n" - "addvl x24, x24, #2\n" - "addvl x9, x9, #2\n" - "addvl x22, x22, #2\n" - "addvl x28, x28, #2\n" - "addvl x26, x26, #2\n" - "ble 28f\n" - "26:" // Width 4: Multiply loop: Main loop head - "whilelt p0.b, XZR, x11\n" - ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - "ld1rqb { z13.b }, p0/Z, [x10]\n" - "add x10, x10, #0x10\n" - ".inst 0xa0400323 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa0400369 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa040030b // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0400125 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x9]\n" - ".inst 0xc15d9020 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[0]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400395 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x28]\n" - ".inst 0xc15d9121 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[0]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0400357 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc15d90a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[0]\n" - ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa0400323 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xc15d92a3 // sdot za.s[x8, 3], { z20.b-z23.b }, z13.b[0]\n" - ".inst 
0xa0400365 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0400307 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0400131 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x9]\n" - ".inst 0xc15d9420 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[1]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04002d3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400381 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x28]\n" - ".inst 0xc15d94a1 // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[1]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0400343 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc15d9622 // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[1]\n" - ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa0400337 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xc15d9423 // sdot za.s[x8, 3], { z0.b-z3.b }, z13.b[1]\n" - ".inst 0xa0400371 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0400125 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x9]\n" - ".inst 0xc15d9aa0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[2]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400381 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x28]\n" - ".inst 0xc15d9a21 // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[2]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0400343 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc15d98a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[2]\n" - ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa0400327 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xc15d9823 // sdot za.s[x8, 3], { z0.b-z3.b }, z13.b[2]\n" - ".inst 0xa040037d // ldnt1b { z28.b-z29.b 
}, pn8.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa040031f // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0400135 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x9]\n" - ".inst 0xc15d9ca0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[3]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04002d7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400385 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x28]\n" - ".inst 0xc15d9fa1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[3]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0400347 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc15d9ea2 // sdot za.s[x8, 2], { z20.b-z23.b }, z13.b[3]\n" - ".inst 0xc15d9ca3 // sdot za.s[x8, 3], { z4.b-z7.b }, z13.b[3]\n" - "tbnz %x[flags], #31, 27f\n" - "sdot z25.s, z13.b, z27.b\n" - "27:" // Width 4: Multiply loop: unique 7: skip row sum - "sub x11, x11, #0x10\n" - "cmp x11, #0x10\n" - "bgt 26b\n" - "28:" // Width 4: Multiply loop: Single iteration only - "whilelt p0.b, XZR, x11\n" - ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" - "subs x11, x11, #0x4\n" - "ld1rqb { z13.b }, p0/Z, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xa0400327 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa0400371 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa040013d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x9]\n" - ".inst 0xc15d90a0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[0]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400389 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x28]\n" - ".inst 0xc15d9221 // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[0]\n" - "addvl x28, x28, #2\n" - ".inst 0xa040034b // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc15d93a2 // sdot za.s[x8, 2], { z28.b-z31.b 
}, z13.b[0]\n" - ".inst 0xc15d9123 // sdot za.s[x8, 3], { z8.b-z11.b }, z13.b[0]\n" - "ble 29f\n" - ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" - "subs x11, x11, #0x4\n" - "addvl x16, x16, #2\n" - ".inst 0xa0400327 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa0400371 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0400121 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x9]\n" - ".inst 0xc15d94a0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[1]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04002c3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400385 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x28]\n" - ".inst 0xc15d9621 // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[1]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0400347 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc15d9422 // sdot za.s[x8, 2], { z0.b-z3.b }, z13.b[1]\n" - ".inst 0xc15d94a3 // sdot za.s[x8, 3], { z4.b-z7.b }, z13.b[1]\n" - "ble 29f\n" - ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" - "subs x11, x11, #0x4\n" - "addvl x16, x16, #2\n" - ".inst 0xa0400337 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa0400369 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa040030b // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0400125 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x9]\n" - ".inst 0xc15d9aa0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[2]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0400391 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x28]\n" - ".inst 0xc15d9921 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[2]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0400353 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - 
".inst 0xc15d98a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[2]\n" - ".inst 0xc15d9a23 // sdot za.s[x8, 3], { z16.b-z19.b }, z13.b[2]\n" - "ble 29f\n" - ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa0400327 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa0400361 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x27]\n" - ".inst 0xa0400303 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x24]\n" - ".inst 0xa0400131 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x9]\n" - ".inst 0xc15d9ca0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[3]\n" - ".inst 0xa04002d3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x22]\n" - ".inst 0xa0400385 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x28]\n" - ".inst 0xc15d9c21 // sdot za.s[x8, 1], { z0.b-z3.b }, z13.b[3]\n" - ".inst 0xa0400347 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x26]\n" - ".inst 0xc15d9e22 // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[3]\n" - ".inst 0xc15d9ca3 // sdot za.s[x8, 3], { z4.b-z7.b }, z13.b[3]\n" - "29:" // Width 4: Multiply loop: multiply skip - "tbnz %x[flags], #31, 30f\n" - "sdot z25.s, z13.b, z27.b\n" - "30:" // Width 4: Multiply loop: unique 8: skip row sum - ".inst 0xc0060c04 // mova { z4.d-z7.d }, za.d[x8, #0]\n" - ".inst 0xa0404202 // ld1w { z2.s-z3.s }, pn8.b/Z, [x16]\n" - "add x22, %x[k_args], %[c_offset]\n" - "add x21, %x[k_args], %[minval]\n" - ".inst 0xa040432c // ld1w { z12.s-z13.s }, pn8.b/Z, [x25]\n" - "add x20, %x[k_args], %[maxval]\n" - ".inst 0xc0060c3c // mova { z28.d-z31.d }, za.d[x8, #1]\n" - "add x16, x16, x12, LSL #1\n" - "ld1rw { z0.s }, p2/Z, [x22]\n" - "add x25, x25, x12, LSL #1\n" - ".inst 0xc0060c54 // mova { z20.d-z23.d }, za.d[x8, #2]\n" - "ld1rw { z1.s }, p2/Z, [x21]\n" - ".inst 0xc0060c68 // mova { z8.d-z11.d }, za.d[x8, #3]\n" - ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" - "ld1rw { z17.s }, p2/Z, [x20]\n" - "fmul z4.s, z4.s, z2.s\n" - "fmul z5.s, z5.s, z3.s\n" - ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" - "fmul z6.s, z6.s, 
z12.s\n" - "fmul z7.s, z7.s, z13.s\n" - ".inst 0xc132e294 // scvtf { z20.s-z23.s }, { z20.s-z23.s }\n" - ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" - ".inst 0xc1b1cc24 // sclamp { z4.s-z7.s }, z1.s, z17.s\n" - "uzp1 z4.h, z4.h, z5.h\n" - "uzp1 z16.h, z6.h, z7.h\n" - "uzp1 z4.b, z4.b, z16.b\n" - "st1b { z4.b }, p2, [x14]\n" - ".inst 0xa1404212 // ld1w { z18.s, z26.s }, pn8.b/Z, [x16]\n" - "add x16, x16, x12, LSL #1\n" - ".inst 0xa0404324 // ld1w { z4.s-z5.s }, pn8.b/Z, [x25]\n" - "add x25, x25, x12, LSL #1\n" - "fmul z28.s, z28.s, z18.s\n" - "fmul z29.s, z29.s, z26.s\n" - "fmul z30.s, z30.s, z4.s\n" - "fmul z31.s, z31.s, z5.s\n" - ".inst 0xc1b8e39c // frintn { z28.s-z31.s }, { z28.s-z31.s }\n" - ".inst 0xc131e39c // fcvtzs { z28.s-z31.s }, { z28.s-z31.s }\n" - ".inst 0xc1a0ab1c // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s\n" - ".inst 0xc1b1cc3c // sclamp { z28.s-z31.s }, z1.s, z17.s\n" - "uzp1 z28.h, z28.h, z29.h\n" - "uzp1 z16.h, z30.h, z31.h\n" - "uzp1 z28.b, z28.b, z16.b\n" - "st1b { z28.b }, p2, [x14, #1, MUL VL]\n" - ".inst 0xa1404207 // ld1w { z7.s, z15.s }, pn8.b/Z, [x16]\n" - "add x16, x16, x12, LSL #1\n" - ".inst 0xa1404324 // ld1w { z4.s, z12.s }, pn8.b/Z, [x25]\n" - "add x25, x25, x12, LSL #1\n" - "fmul z20.s, z20.s, z7.s\n" - "fmul z21.s, z21.s, z15.s\n" - "fmul z22.s, z22.s, z4.s\n" - "fmul z23.s, z23.s, z12.s\n" - ".inst 0xc1b8e294 // frintn { z20.s-z23.s }, { z20.s-z23.s }\n" - ".inst 0xc131e294 // fcvtzs { z20.s-z23.s }, { z20.s-z23.s }\n" - ".inst 0xc1a0ab14 // add { z20.s-z23.s }, { z20.s-z23.s }, z0.s\n" - ".inst 0xc1b1cc34 // sclamp { z20.s-z23.s }, z1.s, z17.s\n" - "uzp1 z20.h, z20.h, z21.h\n" - "uzp1 z16.h, z22.h, z23.h\n" - "uzp1 z20.b, z20.b, z16.b\n" - "st1b { z20.b }, p2, [x14, #2, MUL VL]\n" - ".inst 0xa1404206 // ld1w { z6.s, z14.s }, 
pn8.b/Z, [x16]\n" - ".inst 0xa1404327 // ld1w { z7.s, z15.s }, pn8.b/Z, [x25]\n" - "fmul z8.s, z8.s, z6.s\n" - "fmul z9.s, z9.s, z14.s\n" - "fmul z10.s, z10.s, z7.s\n" - "fmul z11.s, z11.s, z15.s\n" - ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc1a0ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s\n" - ".inst 0xc1b1cc28 // sclamp { z8.s-z11.s }, z1.s, z17.s\n" - "uzp1 z8.h, z8.h, z9.h\n" - "uzp1 z16.h, z10.h, z11.h\n" - "uzp1 z8.b, z8.b, z16.b\n" - "st1b { z8.b }, p1, [x14, #3, MUL VL]\n" - "addvl x14, x14, #4\n" - "subs x13, x13, #0x4\n" - "mov x16, x23\n" - "sub %x[N], %x[N], x15, LSL #2\n" - "bgt 4b\n" - "32:" // Exit - ".inst 0xd503467f // SMSTOP\n" - : [N] "+&r"(N), [flags] "+&r"(flags) - : [A_ptr] "r"(A_ptr), [B_ptr] "r"(B_ptr), [K] "r"(K), [c_offset] "I"(offsetof(KernelArgs, c_offset)), - [k_args] "r"(&k_args), [maxval] "I"(offsetof(KernelArgs, maxval)), [minval] "I"(offsetof(KernelArgs, minval)), - [output_ptr] "r"(output_ptr) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", - "x27", "x28", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", - "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", - "z6", "z7", "z8", "z9"); + kai_kernel_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(&args); } #endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S new file mode 100644 index 00000000..2cfe4ce6 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S @@ -0,0 +1,907 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot) + stp x20, x21, [sp, -80]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x8, #0x0 + ldr x5, [x0, #0x20] + cntw x6, ALL, MUL #4 + ptrue p2.b + ldr x21, [x0, #0x18] + KAI_ASM_INST(0x25207810) // ptrue pn8.b + mov x22, #0x1 + ldr x7, [x0, #0x28] + add x17, x5, x6 + ldr x20, [x0, #0x30] + sub x17, x17, #0x1 + ldr x16, [x0, #0x10] + mov x15, x21 + udiv x17, x17, x6 + ldr x14, [x0, #0x38] + add x21, x17, #0x3 + mov x13, x20 + and x21, x21, #0xfffffffffffffffc + mul x21, x21, x6 + mul x21, x21, x7 +KAI_ASM_LABEL(label_1) // RHS size check loop + cmp x21, #0x200, LSL #12 + blt label_2 + tbnz x21, #0, label_3 + lsr x21, x21, #0x1 + lsl x22, x22, #0x1 + b label_1 +KAI_ASM_LABEL(label_2) // RHS do prefetch + lsl x20, x21, #0x26 + sub x22, x22, #0x1 + lsl x22, x22, #0x16 + orr x21, x21, x20 + orr x21, x21, x22 + KAI_ASM_INST(0xf8b549fa) // rprfm pldonce, x21, [x15] +KAI_ASM_LABEL(label_3) // RHS prefetch exit + add x12, x7, #0x3 + cntw x20, ALL, MUL #2 + mov z25.s, #0x0 + mov z27.b, #0x1 + bic x12, x12, #0x3 + bic x14, x14, #0x80000000 + add x12, x12, #0x8 + mul x12, x12, x20 +KAI_ASM_LABEL(label_4) // Column loop + cmp x17, #0x4 + bge label_25 + cmp x17, #0x2 + bgt label_18 + beq label_11 + cntw x20, ALL, MUL #2 + add x23, x15, x12 + KAI_ASM_INST(0xa04041f0) // ld1w { z16.s-z17.s }, pn8.b/Z, [x15] + cmp x5, x20 + mov x11, x7 + csel x23, x23, x15, GT + mov x21, x5 + KAI_ASM_INST(0xa04042f2) // ld1w { z18.s-z19.s }, pn8.b/Z, [x23] + mov x10, x16 + mov x20, x7 + whilelt p1.b, XZR, x21 + cmp x11, #0x10 + KAI_ASM_INST(0xf8b44958) // rprfm pldmany, x20, [x10] + addvl x15, x15, #2 + addvl x23, x23, #2 + KAI_ASM_INST(0xc0040e00) // mova za.d[x8, #0], { z16.d-z19.d } + ble label_7 +KAI_ASM_LABEL(label_5) // Width 1: Multiply loop: Main loop head + whilelt p0.b, XZR, x11 + KAI_ASM_INST(0xa04001fd) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + ld1rqb { z13.b }, p0/Z, [x10] + add x10, x10, #0x10 + 
KAI_ASM_INST(0xa04002ff) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa04001f5) // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002f7) // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc15d93a0) // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[0] + KAI_ASM_INST(0xa04001f1) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002f3) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa04001fd) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002ff) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc15d96a0) // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[1] + KAI_ASM_INST(0xc15d9a20) // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[2] + KAI_ASM_INST(0xc15d9fa0) // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3] + tbnz x14, #31, label_6 + sdot z25.s, z13.b, z27.b +KAI_ASM_LABEL(label_6) // Width 1: Multiply loop: unique 1: skip row sum + sub x11, x11, #0x10 + cmp x11, #0x10 + bgt label_5 +KAI_ASM_LABEL(label_7) // Width 1: Multiply loop: Single iteration only + whilelt p0.b, XZR, x11 + KAI_ASM_INST(0xa04001e5) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x15] + subs x11, x11, #0x4 + ld1rqb { z13.b }, p0/Z, [x10] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002e7) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc15d90a0) // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[0] + ble label_8 + KAI_ASM_INST(0xa04001fd) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x15] + subs x11, x11, #0x4 + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002ff) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc15d97a0) // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[1] + ble label_8 + KAI_ASM_INST(0xa04001e9) // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x15] + subs x11, x11, #0x4 + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002eb) // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x23] + addvl x23, 
x23, #2 + KAI_ASM_INST(0xc15d9920) // sdot za.s[x8, 0], { z8.b-z11.b }, z13.b[2] + ble label_8 + KAI_ASM_INST(0xa04001fd) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002ff) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc15d9fa0) // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3] +KAI_ASM_LABEL(label_8) // Width 1: Multiply loop: multiply skip + tbnz x14, #31, label_9 +KAI_ASM_LABEL(label_9) // Width 1: Multiply loop: unique 2: skip row sum + KAI_ASM_INST(0xc0060c08) // mova { z8.d-z11.d }, za.d[x8, #0] + KAI_ASM_INST(0xa04041fe) // ld1w { z30.s-z31.s }, pn8.b/Z, [x15] + add x22, x0, #0x0 + add x21, x0, #0x8 + KAI_ASM_INST(0xa04042f8) // ld1w { z24.s-z25.s }, pn8.b/Z, [x23] + add x20, x0, #0x4 + ld1rw { z2.s }, p2/Z, [x22] + ld1rw { z13.s }, p2/Z, [x21] + KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } + ld1rw { z20.s }, p2/Z, [x20] + fmul z8.s, z8.s, z30.s + fmul z9.s, z9.s, z31.s + fmul z10.s, z10.s, z24.s + fmul z11.s, z11.s, z25.s + KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc1a2ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z2.s + KAI_ASM_INST(0xc1b4cda8) // sclamp { z8.s-z11.s }, z13.s, z20.s + uzp1 z8.h, z8.h, z9.h + uzp1 z0.h, z10.h, z11.h + uzp1 z8.b, z8.b, z0.b + st1b { z8.b }, p1, [x13] + b label_32 +KAI_ASM_LABEL(label_11) // Width 2 + add x24, x15, x12, LSL #1 + cntw x20, ALL, MUL #6 + KAI_ASM_INST(0xa04041f4) // ld1w { z20.s-z21.s }, pn8.b/Z, [x15] + add x22, x24, x12 + cmp x5, x20 + KAI_ASM_INST(0xa040430c) // ld1w { z12.s-z13.s }, pn8.b/Z, [x24] + add x23, x15, x12 + csel x22, x22, x15, GT + KAI_ASM_INST(0xa04042f6) // ld1w { z22.s-z23.s }, pn8.b/Z, [x23] + mov x11, x7 + sub x21, x5, x6 + KAI_ASM_INST(0xa04042ce) // ld1w { z14.s-z15.s }, pn8.b/Z, [x22] + mov x10, x16 + mov x20, x7 + whilelt p1.b, XZR, x21 + cmp x11, #0x10 + KAI_ASM_INST(0xf8b44958) // rprfm 
pldmany, x20, [x10] + addvl x15, x15, #2 + KAI_ASM_INST(0xc0040e80) // mova za.d[x8, #0], { z20.d-z23.d } + addvl x23, x23, #2 + addvl x24, x24, #2 + KAI_ASM_INST(0xc0040d81) // mova za.d[x8, #1], { z12.d-z15.d } + addvl x22, x22, #2 + ble label_14 +KAI_ASM_LABEL(label_12) // Width 2: Multiply loop: Main loop head + whilelt p0.b, XZR, x11 + KAI_ASM_INST(0xa04001f1) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + ld1rqb { z13.b }, p0/Z, [x10] + add x10, x10, #0x10 + KAI_ASM_INST(0xa04002f3) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0400305) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04002c7) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc15d9220) // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[0] + KAI_ASM_INST(0xa04001e9) // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002eb) // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc15d90a1) // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[0] + KAI_ASM_INST(0xa0400311) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04002d3) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc15d9520) // sdot za.s[x8, 0], { z8.b-z11.b }, z13.b[1] + KAI_ASM_INST(0xa04001e1) // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002e3) // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc15d9621) // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[1] + KAI_ASM_INST(0xa0400305) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04002c7) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc15d9820) // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[2] + KAI_ASM_INST(0xa04001fd) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002ff) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23] + addvl x23, 
x23, #2 + KAI_ASM_INST(0xc15d98a1) // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[2] + KAI_ASM_INST(0xa0400309) // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04002cb) // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc15d9fa0) // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3] + KAI_ASM_INST(0xc15d9d21) // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[3] + tbnz x14, #31, label_13 + sdot z25.s, z13.b, z27.b +KAI_ASM_LABEL(label_13) // Width 2: Multiply loop: unique 3: skip row sum + sub x11, x11, #0x10 + cmp x11, #0x10 + bgt label_12 +KAI_ASM_LABEL(label_14) // Width 2: Multiply loop: Single iteration only + whilelt p0.b, XZR, x11 + KAI_ASM_INST(0xa04001f1) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x15] + subs x11, x11, #0x4 + ld1rqb { z13.b }, p0/Z, [x10] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002f3) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0400301) // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04002c3) // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc15d9220) // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[0] + KAI_ASM_INST(0xc15d9021) // sdot za.s[x8, 1], { z0.b-z3.b }, z13.b[0] + ble label_15 + KAI_ASM_INST(0xa04001f5) // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x15] + subs x11, x11, #0x4 + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002f7) // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa040031d) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04002df) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc15d96a0) // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[1] + KAI_ASM_INST(0xc15d97a1) // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[1] + ble label_15 + KAI_ASM_INST(0xa04001e1) // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x15] + subs x11, x11, #0x4 + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002e3) // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 
+ KAI_ASM_INST(0xa0400319) // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04002db) // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc15d9820) // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[2] + KAI_ASM_INST(0xc15d9b21) // sdot za.s[x8, 1], { z24.b-z27.b }, z13.b[2] + ble label_15 + KAI_ASM_INST(0xa04001f9) // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002fb) // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa040031d) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x24] + KAI_ASM_INST(0xa04002df) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22] + KAI_ASM_INST(0xc15d9f20) // sdot za.s[x8, 0], { z24.b-z27.b }, z13.b[3] + KAI_ASM_INST(0xc15d9fa1) // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[3] +KAI_ASM_LABEL(label_15) // Width 2: Multiply loop: multiply skip + tbnz x14, #31, label_16 +KAI_ASM_LABEL(label_16) // Width 2: Multiply loop: unique 4: skip row sum + KAI_ASM_INST(0xc0060c00) // mova { z0.d-z3.d }, za.d[x8, #0] + KAI_ASM_INST(0xa04041e8) // ld1w { z8.s-z9.s }, pn8.b/Z, [x15] + add x22, x0, #0x0 + add x21, x0, #0x8 + KAI_ASM_INST(0xa04042fe) // ld1w { z30.s-z31.s }, pn8.b/Z, [x23] + add x20, x0, #0x4 + KAI_ASM_INST(0xc0060c24) // mova { z4.d-z7.d }, za.d[x8, #1] + add x15, x15, x12, LSL #1 + ld1rw { z14.s }, p2/Z, [x22] + add x23, x23, x12, LSL #1 + ld1rw { z11.s }, p2/Z, [x21] + KAI_ASM_INST(0xc132e000) // scvtf { z0.s-z3.s }, { z0.s-z3.s } + ld1rw { z10.s }, p2/Z, [x20] + fmul z0.s, z0.s, z8.s + fmul z1.s, z1.s, z9.s + KAI_ASM_INST(0xc132e084) // scvtf { z4.s-z7.s }, { z4.s-z7.s } + fmul z2.s, z2.s, z30.s + fmul z3.s, z3.s, z31.s + KAI_ASM_INST(0xc1b8e000) // frintn { z0.s-z3.s }, { z0.s-z3.s } + KAI_ASM_INST(0xc131e000) // fcvtzs { z0.s-z3.s }, { z0.s-z3.s } + KAI_ASM_INST(0xc1aeab00) // add { z0.s-z3.s }, { z0.s-z3.s }, z14.s + KAI_ASM_INST(0xc1aacd60) // sclamp { z0.s-z3.s }, z11.s, z10.s + uzp1 z0.h, z0.h, z1.h + uzp1 z16.h, z2.h, z3.h + uzp1 z0.b, 
z0.b, z16.b + st1b { z0.b }, p2, [x13] + KAI_ASM_INST(0xa14041f7) // ld1w { z23.s, z31.s }, pn8.b/Z, [x15] + KAI_ASM_INST(0xa14042f6) // ld1w { z22.s, z30.s }, pn8.b/Z, [x23] + fmul z4.s, z4.s, z23.s + fmul z5.s, z5.s, z31.s + fmul z6.s, z6.s, z22.s + fmul z7.s, z7.s, z30.s + KAI_ASM_INST(0xc1b8e084) // frintn { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc131e084) // fcvtzs { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc1aeab04) // add { z4.s-z7.s }, { z4.s-z7.s }, z14.s + KAI_ASM_INST(0xc1aacd64) // sclamp { z4.s-z7.s }, z11.s, z10.s + uzp1 z4.h, z4.h, z5.h + uzp1 z2.h, z6.h, z7.h + uzp1 z4.b, z4.b, z2.b + st1b { z4.b }, p1, [x13, #1, MUL VL] + b label_32 +KAI_ASM_LABEL(label_18) // Width 3 + add x26, x15, x12, LSL #2 + cntw x20, ALL, MUL #10 + KAI_ASM_INST(0xa04041f0) // ld1w { z16.s-z17.s }, pn8.b/Z, [x15] + add x25, x15, x12, LSL #1 + add x24, x26, x12 + KAI_ASM_INST(0xa040435c) // ld1w { z28.s-z29.s }, pn8.b/Z, [x26] + cmp x5, x20 + add x23, x15, x12 + KAI_ASM_INST(0xa040432c) // ld1w { z12.s-z13.s }, pn8.b/Z, [x25] + add x22, x25, x12 + csel x24, x24, x15, GT + KAI_ASM_INST(0xa04042f2) // ld1w { z18.s-z19.s }, pn8.b/Z, [x23] + mov x20, #0x2 + KAI_ASM_INST(0xa04042ce) // ld1w { z14.s-z15.s }, pn8.b/Z, [x22] + mov x11, x7 + KAI_ASM_INST(0xa040431e) // ld1w { z30.s-z31.s }, pn8.b/Z, [x24] + msub x21, x6, x20, x5 + mov x10, x16 + mov x20, x7 + whilelt p1.b, XZR, x21 + KAI_ASM_INST(0xc0040e00) // mova za.d[x8, #0], { z16.d-z19.d } + cmp x11, #0x10 + KAI_ASM_INST(0xf8b44958) // rprfm pldmany, x20, [x10] + KAI_ASM_INST(0xc0040d81) // mova za.d[x8, #1], { z12.d-z15.d } + addvl x15, x15, #2 + addvl x23, x23, #2 + KAI_ASM_INST(0xc0040f82) // mova za.d[x8, #2], { z28.d-z31.d } + addvl x25, x25, #2 + addvl x22, x22, #2 + addvl x26, x26, #2 + addvl x24, x24, #2 + ble label_21 +KAI_ASM_LABEL(label_19) // Width 3: Multiply loop: Main loop head + whilelt p0.b, XZR, x11 + KAI_ASM_INST(0xa04001e1) // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + ld1rqb { z13.b 
}, p0/Z, [x10] + add x10, x10, #0x10 + KAI_ASM_INST(0xa04002e3) // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0400329) // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04002cb) // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400351) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x26] + KAI_ASM_INST(0xc15d9020) // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[0] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0400313) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc15d9121) // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[0] + KAI_ASM_INST(0xa04001fd) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002ff) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc15d9222) // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[0] + KAI_ASM_INST(0xa0400321) // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04002c3) // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400345) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x26] + KAI_ASM_INST(0xc15d97a0) // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[1] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0400307) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc15d9421) // sdot za.s[x8, 1], { z0.b-z3.b }, z13.b[1] + KAI_ASM_INST(0xa04001f5) // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002f7) // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc15d94a2) // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[1] + KAI_ASM_INST(0xa040033d) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04002df) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400345) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x26] + KAI_ASM_INST(0xc15d9aa0) // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[2] + addvl x26, x26, #2 + 
KAI_ASM_INST(0xa0400307) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc15d9ba1) // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[2] + KAI_ASM_INST(0xa04001fd) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002ff) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc15d98a2) // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[2] + KAI_ASM_INST(0xa0400335) // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04002d7) // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400351) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x26] + KAI_ASM_INST(0xc15d9fa0) // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0400313) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc15d9ea1) // sdot za.s[x8, 1], { z20.b-z23.b }, z13.b[3] + KAI_ASM_INST(0xc15d9e22) // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[3] + tbnz x14, #31, label_20 + sdot z25.s, z13.b, z27.b +KAI_ASM_LABEL(label_20) // Width 3: Multiply loop: unique 5: skip row sum + sub x11, x11, #0x10 + cmp x11, #0x10 + bgt label_19 +KAI_ASM_LABEL(label_21) // Width 3: Multiply loop: Single iteration only + whilelt p0.b, XZR, x11 + KAI_ASM_INST(0xa04001f1) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x15] + subs x11, x11, #0x4 + ld1rqb { z13.b }, p0/Z, [x10] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002f3) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0400325) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04002c7) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa040035d) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x26] + KAI_ASM_INST(0xc15d9220) // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[0] + addvl x26, x26, #2 + KAI_ASM_INST(0xa040031f) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc15d90a1) // sdot za.s[x8, 1], { z4.b-z7.b 
}, z13.b[0] + KAI_ASM_INST(0xc15d93a2) // sdot za.s[x8, 2], { z28.b-z31.b }, z13.b[0] + ble label_22 + KAI_ASM_INST(0xa04001f1) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x15] + subs x11, x11, #0x4 + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002f3) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0400329) // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04002cb) // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400345) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x26] + KAI_ASM_INST(0xc15d9620) // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[1] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0400307) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc15d9521) // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[1] + KAI_ASM_INST(0xc15d94a2) // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[1] + ble label_22 + KAI_ASM_INST(0xa04001e9) // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x15] + subs x11, x11, #0x4 + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002eb) // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa040033d) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04002df) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400359) // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x26] + KAI_ASM_INST(0xc15d9920) // sdot za.s[x8, 0], { z8.b-z11.b }, z13.b[2] + addvl x26, x26, #2 + KAI_ASM_INST(0xa040031b) // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc15d9ba1) // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[2] + KAI_ASM_INST(0xc15d9b22) // sdot za.s[x8, 2], { z24.b-z27.b }, z13.b[2] + ble label_22 + KAI_ASM_INST(0xa04001fd) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa04002ff) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0400339) // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x25] + KAI_ASM_INST(0xa04002db) // ldnt1b { z26.b-z27.b }, pn8.b/Z, 
[x22] + KAI_ASM_INST(0xa0400355) // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x26] + KAI_ASM_INST(0xc15d9fa0) // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3] + KAI_ASM_INST(0xa0400317) // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x24] + KAI_ASM_INST(0xc15d9f21) // sdot za.s[x8, 1], { z24.b-z27.b }, z13.b[3] + KAI_ASM_INST(0xc15d9ea2) // sdot za.s[x8, 2], { z20.b-z23.b }, z13.b[3] +KAI_ASM_LABEL(label_22) // Width 3: Multiply loop: multiply skip + tbnz x14, #31, label_23 + sdot z25.s, z13.b, z27.b // Row-sum step; executed only when bit 31 of x14 is clear +KAI_ASM_LABEL(label_23) // Width 3: Multiply loop: unique 6: skip row sum + KAI_ASM_INST(0xc0060c18) // mova { z24.d-z27.d }, za.d[x8, #0] + KAI_ASM_INST(0xa04041e2) // ld1w { z2.s-z3.s }, pn8.b/Z, [x15] + add x22, x0, #0x0 + add x21, x0, #0x8 + KAI_ASM_INST(0xa04042e6) // ld1w { z6.s-z7.s }, pn8.b/Z, [x23] + add x20, x0, #0x4 + KAI_ASM_INST(0xc0060c28) // mova { z8.d-z11.d }, za.d[x8, #1] + add x15, x15, x12, LSL #1 + ld1rw { z0.s }, p2/Z, [x22] + add x23, x23, x12, LSL #1 + KAI_ASM_INST(0xc0060c5c) // mova { z28.d-z31.d }, za.d[x8, #2] + ld1rw { z19.s }, p2/Z, [x21] + KAI_ASM_INST(0xc132e318) // scvtf { z24.s-z27.s }, { z24.s-z27.s } + ld1rw { z18.s }, p2/Z, [x20] + fmul z24.s, z24.s, z2.s + fmul z25.s, z25.s, z3.s + KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } + fmul z26.s, z26.s, z6.s + fmul z27.s, z27.s, z7.s + KAI_ASM_INST(0xc132e39c) // scvtf { z28.s-z31.s }, { z28.s-z31.s } + KAI_ASM_INST(0xc1b8e318) // frintn { z24.s-z27.s }, { z24.s-z27.s } + KAI_ASM_INST(0xc131e318) // fcvtzs { z24.s-z27.s }, { z24.s-z27.s } + KAI_ASM_INST(0xc1a0ab18) // add { z24.s-z27.s }, { z24.s-z27.s }, z0.s + KAI_ASM_INST(0xc1b2ce78) // sclamp { z24.s-z27.s }, z19.s, z18.s + uzp1 z24.h, z24.h, z25.h + uzp1 z16.h, z26.h, z27.h + uzp1 z24.b, z24.b, z16.b + st1b { z24.b }, p2, [x13] + KAI_ASM_INST(0xa14041e7) // ld1w { z7.s, z15.s }, pn8.b/Z, [x15] + add x15, x15, x12, LSL #1 + KAI_ASM_INST(0xa14042f1) // ld1w { z17.s, z25.s }, pn8.b/Z, [x23] + add x23, x23, x12, LSL #1 + fmul z8.s, z8.s, z7.s + fmul z9.s, z9.s, 
z15.s + fmul z10.s, z10.s, z17.s + fmul z11.s, z11.s, z25.s + KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s + KAI_ASM_INST(0xc1b2ce68) // sclamp { z8.s-z11.s }, z19.s, z18.s + uzp1 z8.h, z8.h, z9.h + uzp1 z16.h, z10.h, z11.h + uzp1 z8.b, z8.b, z16.b + st1b { z8.b }, p2, [x13, #1, MUL VL] + KAI_ASM_INST(0xa14041e7) // ld1w { z7.s, z15.s }, pn8.b/Z, [x15] + KAI_ASM_INST(0xa14042f1) // ld1w { z17.s, z25.s }, pn8.b/Z, [x23] + fmul z28.s, z28.s, z7.s + fmul z29.s, z29.s, z15.s + fmul z30.s, z30.s, z17.s + fmul z31.s, z31.s, z25.s + KAI_ASM_INST(0xc1b8e39c) // frintn { z28.s-z31.s }, { z28.s-z31.s } + KAI_ASM_INST(0xc131e39c) // fcvtzs { z28.s-z31.s }, { z28.s-z31.s } + KAI_ASM_INST(0xc1a0ab1c) // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s + KAI_ASM_INST(0xc1b2ce7c) // sclamp { z28.s-z31.s }, z19.s, z18.s + uzp1 z28.h, z28.h, z29.h + uzp1 z16.h, z30.h, z31.h + uzp1 z28.b, z28.b, z16.b + st1b { z28.b }, p1, [x13, #2, MUL VL] + b label_32 +KAI_ASM_LABEL(label_25) // Width 4 + add x9, x15, x12, LSL #2 + cntw x20, ALL, MUL #14 + KAI_ASM_INST(0xa04041ec) // ld1w { z12.s-z13.s }, pn8.b/Z, [x15] + add x28, x9, x12, LSL #1 + add x27, x15, x12, LSL #1 + KAI_ASM_INST(0xa0404124) // ld1w { z4.s-z5.s }, pn8.b/Z, [x9] + add x26, x28, x12 + cmp x5, x20 + KAI_ASM_INST(0xa0404368) // ld1w { z8.s-z9.s }, pn8.b/Z, [x27] + add x25, x15, x12 + add x24, x27, x12 + KAI_ASM_INST(0xa0404380) // ld1w { z0.s-z1.s }, pn8.b/Z, [x28] + add x22, x9, x12 + csel x26, x26, x15, GT + KAI_ASM_INST(0xa040432e) // ld1w { z14.s-z15.s }, pn8.b/Z, [x25] + mov x20, #0x3 + KAI_ASM_INST(0xa040430a) // ld1w { z10.s-z11.s }, pn8.b/Z, [x24] + mov x11, x7 + KAI_ASM_INST(0xa04042c6) // ld1w { z6.s-z7.s }, pn8.b/Z, [x22] + msub x21, x6, x20, x5 + mov x10, x16 + KAI_ASM_INST(0xa0404342) // ld1w { z2.s-z3.s }, pn8.b/Z, [x26] + mov x20, x7 + whilelt p1.b, XZR, x21 + 
KAI_ASM_INST(0xc0040d80) // mova za.d[x8, #0], { z12.d-z15.d } + cmp x11, #0x10 + KAI_ASM_INST(0xf8b44958) // rprfm pldmany, x20, [x10] + KAI_ASM_INST(0xc0040d01) // mova za.d[x8, #1], { z8.d-z11.d } + add x23, x15, x12, LSL #3 + addvl x15, x15, #2 + KAI_ASM_INST(0xc0040c82) // mova za.d[x8, #2], { z4.d-z7.d } + addvl x25, x25, #2 + addvl x27, x27, #2 + KAI_ASM_INST(0xc0040c03) // mova za.d[x8, #3], { z0.d-z3.d } + addvl x24, x24, #2 + addvl x9, x9, #2 + addvl x22, x22, #2 + addvl x28, x28, #2 + addvl x26, x26, #2 + ble label_28 +KAI_ASM_LABEL(label_26) // Width 4: Multiply loop: Main loop head + whilelt p0.b, XZR, x11 + KAI_ASM_INST(0xa04001e1) // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + ld1rqb { z13.b }, p0/Z, [x10] + add x10, x10, #0x10 + KAI_ASM_INST(0xa0400323) // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa0400369) // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa040030b) // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0400125) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x9] + KAI_ASM_INST(0xc15d9020) // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[0] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04002c7) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400395) // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xc15d9121) // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[0] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0400357) // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc15d90a2) // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[0] + KAI_ASM_INST(0xa04001e1) // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa0400323) // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xc15d92a3) // sdot za.s[x8, 3], { z20.b-z23.b }, z13.b[0] + KAI_ASM_INST(0xa0400365) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0400307) // ldnt1b { z6.b-z7.b }, 
pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0400131) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x9] + KAI_ASM_INST(0xc15d9420) // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[1] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04002d3) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400381) // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xc15d94a1) // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[1] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0400343) // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc15d9622) // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[1] + KAI_ASM_INST(0xa04001f5) // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa0400337) // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xc15d9423) // sdot za.s[x8, 3], { z0.b-z3.b }, z13.b[1] + KAI_ASM_INST(0xa0400371) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0400313) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0400125) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x9] + KAI_ASM_INST(0xc15d9aa0) // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[2] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04002c7) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400381) // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xc15d9a21) // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[2] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0400343) // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc15d98a2) // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[2] + KAI_ASM_INST(0xa04001e5) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa0400327) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xc15d9823) // sdot za.s[x8, 3], { z0.b-z3.b }, z13.b[2] + KAI_ASM_INST(0xa040037d) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa040031f) // ldnt1b { z30.b-z31.b }, 
pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0400135) // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x9] + KAI_ASM_INST(0xc15d9ca0) // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[3] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04002d7) // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400385) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xc15d9fa1) // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[3] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0400347) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc15d9ea2) // sdot za.s[x8, 2], { z20.b-z23.b }, z13.b[3] + KAI_ASM_INST(0xc15d9ca3) // sdot za.s[x8, 3], { z4.b-z7.b }, z13.b[3] + tbnz x14, #31, label_27 + sdot z25.s, z13.b, z27.b +KAI_ASM_LABEL(label_27) // Width 4: Multiply loop: unique 7: skip row sum + sub x11, x11, #0x10 + cmp x11, #0x10 + bgt label_26 +KAI_ASM_LABEL(label_28) // Width 4: Multiply loop: Single iteration only + whilelt p0.b, XZR, x11 + KAI_ASM_INST(0xa04001e5) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x15] + subs x11, x11, #0x4 + ld1rqb { z13.b }, p0/Z, [x10] + addvl x15, x15, #2 + KAI_ASM_INST(0xa0400327) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa0400371) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0400313) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa040013d) // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x9] + KAI_ASM_INST(0xc15d90a0) // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[0] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04002df) // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400389) // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xc15d9221) // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[0] + addvl x28, x28, #2 + KAI_ASM_INST(0xa040034b) // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc15d93a2) // sdot za.s[x8, 2], { z28.b-z31.b }, z13.b[0] + KAI_ASM_INST(0xc15d9123) // sdot za.s[x8, 3], { 
z8.b-z11.b }, z13.b[0] + ble label_29 + KAI_ASM_INST(0xa04001e5) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x15] + subs x11, x11, #0x4 + addvl x15, x15, #2 + KAI_ASM_INST(0xa0400327) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa0400371) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0400313) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0400121) // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x9] + KAI_ASM_INST(0xc15d94a0) // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[1] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04002c3) // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400385) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xc15d9621) // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[1] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0400347) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc15d9422) // sdot za.s[x8, 2], { z0.b-z3.b }, z13.b[1] + KAI_ASM_INST(0xc15d94a3) // sdot za.s[x8, 3], { z4.b-z7.b }, z13.b[1] + ble label_29 + KAI_ASM_INST(0xa04001f5) // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x15] + subs x11, x11, #0x4 + addvl x15, x15, #2 + KAI_ASM_INST(0xa0400337) // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa0400369) // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa040030b) // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0400125) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x9] + KAI_ASM_INST(0xc15d9aa0) // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[2] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04002c7) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0400391) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xc15d9921) // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[2] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0400353) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc15d98a2) // sdot za.s[x8, 2], { 
z4.b-z7.b }, z13.b[2] + KAI_ASM_INST(0xc15d9a23) // sdot za.s[x8, 3], { z16.b-z19.b }, z13.b[2] + ble label_29 + KAI_ASM_INST(0xa04001e5) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x15] + addvl x15, x15, #2 + KAI_ASM_INST(0xa0400327) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa0400361) // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x27] + KAI_ASM_INST(0xa0400303) // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x24] + KAI_ASM_INST(0xa0400131) // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x9] + KAI_ASM_INST(0xc15d9ca0) // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[3] + KAI_ASM_INST(0xa04002d3) // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x22] + KAI_ASM_INST(0xa0400385) // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xc15d9c21) // sdot za.s[x8, 1], { z0.b-z3.b }, z13.b[3] + KAI_ASM_INST(0xa0400347) // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x26] + KAI_ASM_INST(0xc15d9e22) // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[3] + KAI_ASM_INST(0xc15d9ca3) // sdot za.s[x8, 3], { z4.b-z7.b }, z13.b[3] +KAI_ASM_LABEL(label_29) // Width 4: Multiply loop: multiply skip + tbnz x14, #31, label_30 + sdot z25.s, z13.b, z27.b +KAI_ASM_LABEL(label_30) // Width 4: Multiply loop: unique 8: skip row sum + KAI_ASM_INST(0xc0060c04) // mova { z4.d-z7.d }, za.d[x8, #0] + KAI_ASM_INST(0xa04041e2) // ld1w { z2.s-z3.s }, pn8.b/Z, [x15] + add x22, x0, #0x0 + add x21, x0, #0x8 + KAI_ASM_INST(0xa040432c) // ld1w { z12.s-z13.s }, pn8.b/Z, [x25] + add x20, x0, #0x4 + KAI_ASM_INST(0xc0060c3c) // mova { z28.d-z31.d }, za.d[x8, #1] + add x15, x15, x12, LSL #1 + ld1rw { z0.s }, p2/Z, [x22] + add x25, x25, x12, LSL #1 + KAI_ASM_INST(0xc0060c54) // mova { z20.d-z23.d }, za.d[x8, #2] + ld1rw { z1.s }, p2/Z, [x21] + KAI_ASM_INST(0xc0060c68) // mova { z8.d-z11.d }, za.d[x8, #3] + KAI_ASM_INST(0xc132e084) // scvtf { z4.s-z7.s }, { z4.s-z7.s } + ld1rw { z17.s }, p2/Z, [x20] + fmul z4.s, z4.s, z2.s + fmul z5.s, z5.s, z3.s + KAI_ASM_INST(0xc132e39c) // scvtf { z28.s-z31.s }, { z28.s-z31.s } + fmul z6.s, z6.s, z12.s + fmul z7.s, z7.s, 
z13.s + KAI_ASM_INST(0xc132e294) // scvtf { z20.s-z23.s }, { z20.s-z23.s } + KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc1b8e084) // frintn { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc131e084) // fcvtzs { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc1a0ab04) // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s + KAI_ASM_INST(0xc1b1cc24) // sclamp { z4.s-z7.s }, z1.s, z17.s + uzp1 z4.h, z4.h, z5.h + uzp1 z16.h, z6.h, z7.h + uzp1 z4.b, z4.b, z16.b + st1b { z4.b }, p2, [x13] + KAI_ASM_INST(0xa14041f2) // ld1w { z18.s, z26.s }, pn8.b/Z, [x15] + add x15, x15, x12, LSL #1 + KAI_ASM_INST(0xa0404324) // ld1w { z4.s-z5.s }, pn8.b/Z, [x25] + add x25, x25, x12, LSL #1 + fmul z28.s, z28.s, z18.s + fmul z29.s, z29.s, z26.s + fmul z30.s, z30.s, z4.s + fmul z31.s, z31.s, z5.s + KAI_ASM_INST(0xc1b8e39c) // frintn { z28.s-z31.s }, { z28.s-z31.s } + KAI_ASM_INST(0xc131e39c) // fcvtzs { z28.s-z31.s }, { z28.s-z31.s } + KAI_ASM_INST(0xc1a0ab1c) // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s + KAI_ASM_INST(0xc1b1cc3c) // sclamp { z28.s-z31.s }, z1.s, z17.s + uzp1 z28.h, z28.h, z29.h + uzp1 z16.h, z30.h, z31.h + uzp1 z28.b, z28.b, z16.b + st1b { z28.b }, p2, [x13, #1, MUL VL] + KAI_ASM_INST(0xa14041e7) // ld1w { z7.s, z15.s }, pn8.b/Z, [x15] + add x15, x15, x12, LSL #1 + KAI_ASM_INST(0xa1404324) // ld1w { z4.s, z12.s }, pn8.b/Z, [x25] + add x25, x25, x12, LSL #1 + fmul z20.s, z20.s, z7.s + fmul z21.s, z21.s, z15.s + fmul z22.s, z22.s, z4.s + fmul z23.s, z23.s, z12.s + KAI_ASM_INST(0xc1b8e294) // frintn { z20.s-z23.s }, { z20.s-z23.s } + KAI_ASM_INST(0xc131e294) // fcvtzs { z20.s-z23.s }, { z20.s-z23.s } + KAI_ASM_INST(0xc1a0ab14) // add { z20.s-z23.s }, { z20.s-z23.s }, z0.s + KAI_ASM_INST(0xc1b1cc34) // sclamp { z20.s-z23.s }, z1.s, z17.s + uzp1 z20.h, z20.h, z21.h + uzp1 z16.h, z22.h, z23.h + uzp1 z20.b, z20.b, z16.b + st1b { z20.b }, p2, [x13, #2, MUL VL] + KAI_ASM_INST(0xa14041e6) // ld1w { z6.s, z14.s }, pn8.b/Z, [x15] + KAI_ASM_INST(0xa1404327) // 
ld1w { z7.s, z15.s }, pn8.b/Z, [x25] + fmul z8.s, z8.s, z6.s + fmul z9.s, z9.s, z14.s + fmul z10.s, z10.s, z7.s + fmul z11.s, z11.s, z15.s + KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s + KAI_ASM_INST(0xc1b1cc28) // sclamp { z8.s-z11.s }, z1.s, z17.s + uzp1 z8.h, z8.h, z9.h + uzp1 z16.h, z10.h, z11.h + uzp1 z8.b, z8.b, z16.b + st1b { z8.b }, p1, [x13, #3, MUL VL] + addvl x13, x13, #4 + subs x17, x17, #0x4 + mov x15, x23 + sub x5, x5, x6, LSL #2 + bgt label_4 +KAI_ASM_LABEL(label_32) // Exit + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c index 5f6bd2b2..bc686223 100644 --- a/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c @@ -4,13 +4,9 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
- #include "kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" #include @@ -18,25 +14,49 @@ #include "kai/kai_common.h" +typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + int32_t min; + int32_t max; + int32_t result_zero_point; + const int n_0; + void* accumulator_buffer; + uint64_t flags; +} KernelArgs; + static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 4; static const size_t kai_sr = 1; +void kai_kernel_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(KernelArgs* args); + +// Returns a constant value specific to this kernel that's relative to vector length +static size_t kai_get_kernel_vec_length_constant(void) { + const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u8() / kai_kr; + return kernel_vec_length_constant; +} + size_t kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u32(); + return kai_mr * kai_get_kernel_vec_length_constant(); } size_t kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_kernel_vec_length_constant(); } size_t kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u32(); + return kai_mr * kai_get_kernel_vec_length_constant(); } size_t kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_kernel_vec_length_constant(); } size_t kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { @@ -52,17 +72,23 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vl return m_idx * kai_roundup(k, kai_kr) * sizeof(int8_t); } +static size_t 
kai_get_rhs_packed_stride_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t k) { + return kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() * + (sizeof(int32_t) + kai_roundup(k, kai_kr) * sizeof(int8_t) + sizeof(float)); +} + size_t kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); - return n_idx * (sizeof(int32_t) + kai_roundup(k, kai_kr) * sizeof(int8_t) + sizeof(float)); + const size_t block_idx = n_idx / kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(); + return block_idx * kai_get_rhs_packed_stride_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(k); } size_t kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_stride) { + size_t m_idx, size_t n_idx, size_t dst_stride_row) { KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); - return m_idx * dst_stride + n_idx * sizeof(int8_t); + return m_idx * dst_stride_row + n_idx * sizeof(int8_t); } size_t kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t m, size_t n) { @@ -73,27 +99,10 @@ void kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, const struct kai_matmul_requantize32_params* params) { KAI_ASSUME(dst_stride_col == sizeof(int8_t)); - - typedef struct { - const void* A; - const void* B; - - void* C; - uint64_t ldcb; - uint64_t M, N, K; - int32_t min; - int32_t max; - int32_t result_zero_point; - - void* accumulator_buffer; - uint64_t flags; - } 
KernelArgs; - KernelArgs args; args.A = lhs_packed; args.B = rhs_packed; - args.C = dst; args.ldcb = dst_stride_row; args.M = m; @@ -102,301 +111,10 @@ void kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( args.min = params->min_value; args.max = params->max_value; args.result_zero_point = params->output_zero_point; - args.accumulator_buffer = NULL; args.flags = 0; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "ldr w14, [%x[args], %[offsetof_M]]\n" - "mov x13, #0x0\n" - "mov x11, #0x0\n" - "ptrue p1.b\n" - ".inst 0x25207811 // ptrue pn9.b\n" - "ldr w10, [%x[args], %[offsetof_N]]\n" - "ldr x9, [%x[args], %[offsetof_A]]\n" - "1:" // M loop - "ldr x28, [%x[args], %[offsetof_B]]\n" - "2:" // N loop - ".inst 0x25aa4570 // whilelt pn8.s, x11, x10, VLx2\n" - ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" - "mov x27, x9\n" - ".inst 0xa040438e // ld1w { z14.s-z15.s }, p8/Z, [x28]\n" // Load bias - "addvl x28, x28, #2\n" - ".inst 0xc09025c0 // addha za0.s, p1/M, p1/M, z14.s\n" - ".inst 0xc09025e1 // addha za1.s, p1/M, p1/M, z15.s\n" - ".inst 0xc09025c2 // addha za2.s, p1/M, p1/M, z14.s\n" - ".inst 0xc09025e3 // addha za3.s, p1/M, p1/M, z15.s\n" - "ldr x20, [%x[args], %[offsetof_K]]\n" - "add x20, x20, #0x3\n" - "lsr x20, x20, #0x2\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 6f\n" - "subs x21, x21, #0x1\n" - ".inst 0xa0400762 // ld1b { z2.b-z3.b }, pn9.b/Z, [x27]\n" - ".inst 0xa1400780 // ld1b { z0.b, z8.b }, pn9.b/Z, [x28]\n" - ".inst 0xa0410772 // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0xa0410794 // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL]\n" - ".inst 0xa042077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa0420796 // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL]\n" - ".inst 0xa0430778 // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - ".inst 0xa0430784 // ld1b { z4.b-z5.b }, pn9.b/Z, 
[x28, #0x6, MUL VL]\n" - "addvl x28, x28, #8\n" - "ble 5f\n" - "4:" // K loop - ".inst 0xa0802440 // smopa za0.s, p1/M, p1/M, z2.b, z0.b\n" - "subs x21, x21, #0x1\n" - ".inst 0xa0882441 // smopa za1.s, p1/M, p1/M, z2.b, z8.b\n" - ".inst 0xa0802462 // smopa za2.s, p1/M, p1/M, z3.b, z0.b\n" - ".inst 0xa0882463 // smopa za3.s, p1/M, p1/M, z3.b, z8.b\n" - ".inst 0xa0400762 // ld1b { z2.b-z3.b }, pn9.b/Z, [x27]\n" - ".inst 0xa0942640 // smopa za0.s, p1/M, p1/M, z18.b, z20.b\n" - ".inst 0xa1400780 // ld1b { z0.b, z8.b }, pn9.b/Z, [x28]\n" - ".inst 0xa0952641 // smopa za1.s, p1/M, p1/M, z18.b, z21.b\n" - ".inst 0xa0942662 // smopa za2.s, p1/M, p1/M, z19.b, z20.b\n" - ".inst 0xa0952663 // smopa za3.s, p1/M, p1/M, z19.b, z21.b\n" - ".inst 0xa0410772 // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0xa0962740 // smopa za0.s, p1/M, p1/M, z26.b, z22.b\n" - ".inst 0xa0410794 // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL]\n" - ".inst 0xa0972741 // smopa za1.s, p1/M, p1/M, z26.b, z23.b\n" - ".inst 0xa0962762 // smopa za2.s, p1/M, p1/M, z27.b, z22.b\n" - ".inst 0xa0972763 // smopa za3.s, p1/M, p1/M, z27.b, z23.b\n" - ".inst 0xa042077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa0420796 // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL]\n" - ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" - ".inst 0xa0852701 // smopa za1.s, p1/M, p1/M, z24.b, z5.b\n" - ".inst 0xa0842722 // smopa za2.s, p1/M, p1/M, z25.b, z4.b\n" - ".inst 0xa0852723 // smopa za3.s, p1/M, p1/M, z25.b, z5.b\n" - ".inst 0xa0430778 // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - ".inst 0xa0430784 // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL]\n" - "addvl x28, x28, #8\n" - "bgt 4b\n" - "5:" // K loop tail - ".inst 0xa0802440 // smopa za0.s, p1/M, p1/M, z2.b, z0.b\n" - ".inst 0xa0882441 // smopa za1.s, p1/M, p1/M, z2.b, z8.b\n" - ".inst 0xa0802462 // smopa za2.s, p1/M, p1/M, z3.b, z0.b\n" - ".inst 0xa0882463 // smopa 
za3.s, p1/M, p1/M, z3.b, z8.b\n" - ".inst 0xa0942640 // smopa za0.s, p1/M, p1/M, z18.b, z20.b\n" - ".inst 0xa0952641 // smopa za1.s, p1/M, p1/M, z18.b, z21.b\n" - ".inst 0xa0942662 // smopa za2.s, p1/M, p1/M, z19.b, z20.b\n" - ".inst 0xa0952663 // smopa za3.s, p1/M, p1/M, z19.b, z21.b\n" - ".inst 0xa0962740 // smopa za0.s, p1/M, p1/M, z26.b, z22.b\n" - ".inst 0xa0972741 // smopa za1.s, p1/M, p1/M, z26.b, z23.b\n" - ".inst 0xa0962762 // smopa za2.s, p1/M, p1/M, z27.b, z22.b\n" - ".inst 0xa0972763 // smopa za3.s, p1/M, p1/M, z27.b, z23.b\n" - ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" - ".inst 0xa0852701 // smopa za1.s, p1/M, p1/M, z24.b, z5.b\n" - ".inst 0xa0842722 // smopa za2.s, p1/M, p1/M, z25.b, z4.b\n" - ".inst 0xa0852723 // smopa za3.s, p1/M, p1/M, z25.b, z5.b\n" - "6:" // K oddments - "cbz x20, 8f\n" - "7:" // K oddments: Loop - ".inst 0xa0400770 // ld1b { z16.b-z17.b }, pn9.b/Z, [x27]\n" - "subs x20, x20, #0x1\n" - "addvl x27, x27, #2\n" - ".inst 0xa0400788 // ld1b { z8.b-z9.b }, pn9.b/Z, [x28]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0882600 // smopa za0.s, p1/M, p1/M, z16.b, z8.b\n" - ".inst 0xa0892601 // smopa za1.s, p1/M, p1/M, z16.b, z9.b\n" - ".inst 0xa0882622 // smopa za2.s, p1/M, p1/M, z17.b, z8.b\n" - ".inst 0xa0892623 // smopa za3.s, p1/M, p1/M, z17.b, z9.b\n" - "bgt 7b\n" - "8:" // K oddments: End - "ldr x26, [%x[args], %[offsetof_C]]\n" - "sub x25, x14, x13\n" - "cntw x24\n" - "ld1rw { z27.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" - "ldr x23, [%x[args], %[offsetof_ldcb]]\n" - "whilelt p0.h, x11, x10\n" - "cmp x25, x24\n" - "ld1rw { z1.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" - "csel x22, x25, x24, LT\n" - "ld1rw { z0.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_result_zero_point]]\n" - "mov x12, #0x0\n" - "add x26, x26, x11\n" // C += n - "lsr x21, x22, #0x2\n" - "ld1w { z22.s }, p1/Z, [x28]\n" - "madd x26, x13, x23, x26\n" // C += m * ldc - "ld1w { z26.s }, p1/Z, [x28, #1, MUL VL]\n" - "and x20, x22, 
#0x3\n" - "addvl x28, x28, #2\n" - "cbz x21, 11f\n" - "10:" // Store to output array: Accumulator row 0 loop - ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" - ".inst 0xc086043c // mova { z28.s-z31.s }, za1h.s[x12]\n" - ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" - "fmul z16.s, z16.s, z22.s\n" - "fmul z17.s, z17.s, z22.s\n" - "add x12, x12, #0x4\n" - "fmul z18.s, z18.s, z22.s\n" - "fmul z19.s, z19.s, z22.s\n" - "cmp x12, x21, LSL #2\n" - "fmul z28.s, z28.s, z26.s\n" - "fmul z29.s, z29.s, z26.s\n" - "fmul z30.s, z30.s, z26.s\n" - "fmul z31.s, z31.s, z26.s\n" - ".inst 0xc1b8e210 // frintn { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc131e210 // fcvtzs { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc1b8e39c // frintn { z28.s-z31.s }, { z28.s-z31.s }\n" - ".inst 0xc1a0ab10 // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s\n" - ".inst 0xc131e39c // fcvtzs { z28.s-z31.s }, { z28.s-z31.s }\n" - ".inst 0xc1a0ab1c // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s\n" - ".inst 0xc1a1cf70 // sclamp { z16.s-z19.s }, z27.s, z1.s\n" - ".inst 0xc1a1cf7c // sclamp { z28.s-z31.s }, z27.s, z1.s\n" - "uzp1 z5.h, z16.h, z28.h\n" - "uzp1 z20.h, z17.h, z29.h\n" - "uzp1 z17.h, z18.h, z30.h\n" - "uzp1 z16.h, z19.h, z31.h\n" - "st1b { z5.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z20.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z17.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "blt 10b\n" - "11:" // Store to output array: Accumulator row 0 oddments - "cbz x20, 12f\n" - ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" - ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" - ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" - "fmul z4.s, z4.s, z22.s\n" - "fmul z5.s, z5.s, z22.s\n" - "subs x20, x20, #0x1\n" - "fmul z6.s, z6.s, z22.s\n" - "fmul z7.s, 
z7.s, z22.s\n" - "fmul z12.s, z12.s, z26.s\n" - "fmul z13.s, z13.s, z26.s\n" - "fmul z14.s, z14.s, z26.s\n" - "fmul z15.s, z15.s, z26.s\n" - ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc1b8e18c // frintn { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" - ".inst 0xc131e18c // fcvtzs { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc1a0ab0c // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s\n" - ".inst 0xc1a1cf64 // sclamp { z4.s-z7.s }, z27.s, z1.s\n" - ".inst 0xc1a1cf6c // sclamp { z12.s-z15.s }, z27.s, z1.s\n" - "uzp1 z16.h, z4.h, z12.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "beq 12f\n" - "subs x20, x20, #0x1\n" - "uzp1 z16.h, z5.h, z13.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "beq 12f\n" - "uzp1 z16.h, z6.h, z14.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "12:" // Store to output array: Accumulator row 0 oddments: End - "subs x25, x25, x22\n" - "beq 16f\n" - "cmp x25, x24\n" - "mov x12, #0x0\n" - "csel x20, x25, x24, LT\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 14f\n" - "13:" // Store to output array: Accumulator row 1 loop - ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" - ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" - ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" - "fmul z8.s, z8.s, z22.s\n" - "fmul z9.s, z9.s, z22.s\n" - "add x12, x12, #0x4\n" - "fmul z10.s, z10.s, z22.s\n" - "fmul z11.s, z11.s, z22.s\n" - "cmp x12, x21, LSL #2\n" - "fmul z16.s, z16.s, z26.s\n" - "fmul z17.s, z17.s, z26.s\n" - "fmul z18.s, z18.s, z26.s\n" - "fmul z19.s, z19.s, z26.s\n" - ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" - ".inst 0xc1b8e210 // frintn { z16.s-z19.s }, { z16.s-z19.s }\n" 
- ".inst 0xc1a0ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s\n" - ".inst 0xc131e210 // fcvtzs { z16.s-z19.s }, { z16.s-z19.s }\n" - ".inst 0xc1a0ab10 // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s\n" - ".inst 0xc1a1cf68 // sclamp { z8.s-z11.s }, z27.s, z1.s\n" - ".inst 0xc1a1cf70 // sclamp { z16.s-z19.s }, z27.s, z1.s\n" - "uzp1 z21.h, z8.h, z16.h\n" - "uzp1 z20.h, z9.h, z17.h\n" - "uzp1 z17.h, z10.h, z18.h\n" - "uzp1 z16.h, z11.h, z19.h\n" - "st1b { z21.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z20.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z17.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "blt 13b\n" - "14:" // Store to output array: Accumulator row 1 oddments - "cbz x20, 15f\n" - ".inst 0xc086044c // mova { z12.s-z15.s }, za2h.s[x12]\n" - ".inst 0xc0860464 // mova { z4.s-z7.s }, za3h.s[x12]\n" - ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" - "fmul z12.s, z12.s, z22.s\n" - "fmul z13.s, z13.s, z22.s\n" - "subs x20, x20, #0x1\n" - "fmul z14.s, z14.s, z22.s\n" - "fmul z15.s, z15.s, z22.s\n" - "fmul z4.s, z4.s, z26.s\n" - "fmul z5.s, z5.s, z26.s\n" - "fmul z6.s, z6.s, z26.s\n" - "fmul z7.s, z7.s, z26.s\n" - ".inst 0xc1b8e18c // frintn { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc131e18c // fcvtzs { z12.s-z15.s }, { z12.s-z15.s }\n" - ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc1a0ab0c // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s\n" - ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" - ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" - ".inst 0xc1a1cf6c // sclamp { z12.s-z15.s }, z27.s, z1.s\n" - ".inst 0xc1a1cf64 // sclamp { z4.s-z7.s }, z27.s, z1.s\n" - "uzp1 z16.h, z12.h, z4.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" - "beq 15f\n" - "subs x20, x20, #0x1\n" - "uzp1 z16.h, z13.h, z5.h\n" - "st1b { z16.h }, p0, [x26]\n" - "add x26, x26, x23\n" 
- "beq 15f\n" - "uzp1 z16.h, z14.h, z6.h\n" - "st1b { z16.h }, p0, [x26]\n" - "15:" // Store to output array: Accumulator row 1 oddments: End - "16:" // Store to output array: End - "incw x11, ALL, MUL #2\n" - "cmp x11, x10\n" - "blt 2b\n" - "incw x13, ALL, MUL #2\n" - "mov x11, #0x0\n" - "cmp x13, x14\n" - "mov x9, x27\n" - "blt 1b\n" - ".inst 0xd503467f // SMSTOP\n" - : - : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), - [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), - [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), - [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), - [offsetof_KernelArgs_result_zero_point] "I"(offsetof(KernelArgs, result_zero_point)), - [offsetof_M] "I"(offsetof(KernelArgs, M)), [offsetof_N] "I"(offsetof(KernelArgs, N)), - [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) - : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", - "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", - "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", - "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + kai_kernel_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(&args); } #endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h index e8e2868b..f3df3789 100644 --- a/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -63,7 +63,7 @@ size_t kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( /// Gets the offset in bytes to the data element in the packed LHS matrix buffer. /// -/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. /// @param[in] k Number of columns in the unpacked LHS matrix. /// /// @return The offset in bytes to the data element. @@ -71,21 +71,21 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vl /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// -/// @param[in] n_idx Column index in the unpacked RHS matrix. -/// @param[in] k Number of rows in the unpacked RHS matrix. +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of `n_step`. +/// @param[in] k Number of columns in the unpacked LHS matrix. /// /// @return The offset in bytes to the data element. size_t kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// -/// @param[in] m_idx Row index. 
-/// @param[in] n_idx Column index. -/// @param[in] dst_stride Row stride in bytes. +/// @param[in] m_idx Row index. Must be a multiple of `m_step`. +/// @param[in] n_idx Column index. Must be a multiple of `n_step`. +/// @param[in] dst_stride_row Row stride in bytes. /// /// @return The offset in bytes to the data element. size_t kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_stride); + size_t m_idx, size_t n_idx, size_t dst_stride_row); /// Gets the size in bytes of the destination matrix buffer. /// @@ -106,13 +106,13 @@ size_t kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2 /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. -/// @param[in] k Common dimension of the LHS and RHS operands. -/// @param[in] packed_lhs Packed LHS matrix buffer. -/// @param[in] packed_rhs Packed RHS matrix buffer. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] lhs_packed Packed LHS matrix buffer. +/// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. /// @param[in] dst_stride_row Row stride in bytes of the output matrix. -/// @param[in] dst_stride_col Column stride in bytes of the output matrix. -/// @param[in] params Requantization and clamp parmaters. +/// @param[in] dst_stride_col Column stride in bytes of the output matrix. Must be 1. +/// @param[in] params Requantization and clamp parameters. 
void kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, const struct kai_matmul_requantize32_params* params); diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S new file mode 100644 index 00000000..5784542e --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S @@ -0,0 +1,347 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) + KAI_ASM_ALIGN + + 
KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) + +// Kernel entry point: x0 holds the address of the KernelArgs block passed by the C wrapper. +// It is called from C as an ordinary AAPCS64 function, so every callee-saved register the +// body touches (x20-x28 and d8-d15) is spilled in the prologue and reloaded before ret. +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) + stp x20, x21, [sp, -144]! // 144-byte frame: x20-x28 at 0..71, d8-d15 at 72..135 + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] // z8-z15 are overwritten below; keep the callee-saved d8-d15 + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x14, #0x0 + ptrue p1.b + KAI_ASM_INST(0x25207811) // ptrue pn9.b + ldr w13, [x0, #0x20] + ldr w11, [x0, #0x28] + mov x10, #0x0 + ldr x9, [x0, #0x0] +KAI_ASM_LABEL(label_1) // M loop + ldr x28, [x0, #0x8] +KAI_ASM_LABEL(label_2) // N loop + KAI_ASM_INST(0x25ab4550) // whilelt pn8.s, x10, x11, VLx2 + KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } + mov x27, x9 + KAI_ASM_INST(0xa040438e) // ld1w { z14.s-z15.s }, p8/Z, [x28] // Load bias + addvl x28, x28, #2 + KAI_ASM_INST(0xc09025c0) // addha za0.s, p1/M, p1/M, z14.s + KAI_ASM_INST(0xc09025e1) // addha za1.s, p1/M, p1/M, z15.s + KAI_ASM_INST(0xc09025c2) // addha za2.s, p1/M, p1/M, z14.s + KAI_ASM_INST(0xc09025e3) // addha za3.s, p1/M, p1/M, z15.s + ldr x20, [x0, #0x30] + add x20, x20, #0x3 + lsr x20, x20, #0x2 + lsr x21, x20, #0x2 + and x20, x20, #0x3 + cbz x21, label_6 + subs x21, x21, #0x1 + KAI_ASM_INST(0xa0400762) // ld1b { z2.b-z3.b }, pn9.b/Z, [x27] + KAI_ASM_INST(0xa1400780) // ld1b { z0.b, z8.b }, pn9.b/Z, [x28] + KAI_ASM_INST(0xa0410772) // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL] + KAI_ASM_INST(0xa0410794) // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL] + KAI_ASM_INST(0xa042077a) // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL] + KAI_ASM_INST(0xa0420796) // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL] + KAI_ASM_INST(0xa0430778) // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL] + addvl x27, x27, #8 + KAI_ASM_INST(0xa0430784) // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL] + addvl x28, x28, #8 
+ ble label_5 +KAI_ASM_LABEL(label_4) // K loop + KAI_ASM_INST(0xa0802440) // smopa za0.s, p1/M, p1/M, z2.b, z0.b + subs x21, x21, #0x1 + KAI_ASM_INST(0xa0882441) // smopa za1.s, p1/M, p1/M, z2.b, z8.b + KAI_ASM_INST(0xa0802462) // smopa za2.s, p1/M, p1/M, z3.b, z0.b + KAI_ASM_INST(0xa0882463) // smopa za3.s, p1/M, p1/M, z3.b, z8.b + KAI_ASM_INST(0xa0400762) // ld1b { z2.b-z3.b }, pn9.b/Z, [x27] + KAI_ASM_INST(0xa0942640) // smopa za0.s, p1/M, p1/M, z18.b, z20.b + KAI_ASM_INST(0xa1400780) // ld1b { z0.b, z8.b }, pn9.b/Z, [x28] + KAI_ASM_INST(0xa0952641) // smopa za1.s, p1/M, p1/M, z18.b, z21.b + KAI_ASM_INST(0xa0942662) // smopa za2.s, p1/M, p1/M, z19.b, z20.b + KAI_ASM_INST(0xa0952663) // smopa za3.s, p1/M, p1/M, z19.b, z21.b + KAI_ASM_INST(0xa0410772) // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL] + KAI_ASM_INST(0xa0962740) // smopa za0.s, p1/M, p1/M, z26.b, z22.b + KAI_ASM_INST(0xa0410794) // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL] + KAI_ASM_INST(0xa0972741) // smopa za1.s, p1/M, p1/M, z26.b, z23.b + KAI_ASM_INST(0xa0962762) // smopa za2.s, p1/M, p1/M, z27.b, z22.b + KAI_ASM_INST(0xa0972763) // smopa za3.s, p1/M, p1/M, z27.b, z23.b + KAI_ASM_INST(0xa042077a) // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL] + KAI_ASM_INST(0xa0420796) // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL] + KAI_ASM_INST(0xa0842700) // smopa za0.s, p1/M, p1/M, z24.b, z4.b + KAI_ASM_INST(0xa0852701) // smopa za1.s, p1/M, p1/M, z24.b, z5.b + KAI_ASM_INST(0xa0842722) // smopa za2.s, p1/M, p1/M, z25.b, z4.b + KAI_ASM_INST(0xa0852723) // smopa za3.s, p1/M, p1/M, z25.b, z5.b + KAI_ASM_INST(0xa0430778) // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL] + addvl x27, x27, #8 + KAI_ASM_INST(0xa0430784) // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL] + addvl x28, x28, #8 + bgt label_4 +KAI_ASM_LABEL(label_5) // K loop tail + KAI_ASM_INST(0xa0802440) // smopa za0.s, p1/M, p1/M, z2.b, z0.b + KAI_ASM_INST(0xa0882441) // smopa za1.s, p1/M, p1/M, z2.b, z8.b + 
KAI_ASM_INST(0xa0802462) // smopa za2.s, p1/M, p1/M, z3.b, z0.b + KAI_ASM_INST(0xa0882463) // smopa za3.s, p1/M, p1/M, z3.b, z8.b + KAI_ASM_INST(0xa0942640) // smopa za0.s, p1/M, p1/M, z18.b, z20.b + KAI_ASM_INST(0xa0952641) // smopa za1.s, p1/M, p1/M, z18.b, z21.b + KAI_ASM_INST(0xa0942662) // smopa za2.s, p1/M, p1/M, z19.b, z20.b + KAI_ASM_INST(0xa0952663) // smopa za3.s, p1/M, p1/M, z19.b, z21.b + KAI_ASM_INST(0xa0962740) // smopa za0.s, p1/M, p1/M, z26.b, z22.b + KAI_ASM_INST(0xa0972741) // smopa za1.s, p1/M, p1/M, z26.b, z23.b + KAI_ASM_INST(0xa0962762) // smopa za2.s, p1/M, p1/M, z27.b, z22.b + KAI_ASM_INST(0xa0972763) // smopa za3.s, p1/M, p1/M, z27.b, z23.b + KAI_ASM_INST(0xa0842700) // smopa za0.s, p1/M, p1/M, z24.b, z4.b + KAI_ASM_INST(0xa0852701) // smopa za1.s, p1/M, p1/M, z24.b, z5.b + KAI_ASM_INST(0xa0842722) // smopa za2.s, p1/M, p1/M, z25.b, z4.b + KAI_ASM_INST(0xa0852723) // smopa za3.s, p1/M, p1/M, z25.b, z5.b +KAI_ASM_LABEL(label_6) // K oddments + cbz x20, label_8 +KAI_ASM_LABEL(label_7) // K oddments: Loop + KAI_ASM_INST(0xa0400770) // ld1b { z16.b-z17.b }, pn9.b/Z, [x27] + subs x20, x20, #0x1 + addvl x27, x27, #2 + KAI_ASM_INST(0xa0400788) // ld1b { z8.b-z9.b }, pn9.b/Z, [x28] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0882600) // smopa za0.s, p1/M, p1/M, z16.b, z8.b + KAI_ASM_INST(0xa0892601) // smopa za1.s, p1/M, p1/M, z16.b, z9.b + KAI_ASM_INST(0xa0882622) // smopa za2.s, p1/M, p1/M, z17.b, z8.b + KAI_ASM_INST(0xa0892623) // smopa za3.s, p1/M, p1/M, z17.b, z9.b + bgt label_7 +KAI_ASM_LABEL(label_8) // K oddments: End + ldr x26, [x0, #0x10] + sub x25, x13, x14 + cntw x24 + ld1rw { z27.s }, p1/Z, [x0, #56] + ldr x23, [x0, #0x18] + whilelt p0.h, x10, x11 + cmp x25, x24 + ld1rw { z1.s }, p1/Z, [x0, #60] + csel x22, x25, x24, LT + ld1rw { z0.s }, p1/Z, [x0, #64] + mov x12, #0x0 + add x26, x26, x10 // C += n + lsr x21, x22, #0x2 + ld1w { z22.s }, p1/Z, [x28] + madd x26, x14, x23, x26 // C += m * ldc + ld1w { z26.s }, p1/Z, [x28, #1, MUL VL] + and 
x20, x22, #0x3 + addvl x28, x28, #2 + cbz x21, label_11 +KAI_ASM_LABEL(label_10) // Store to output array: Accumulator row 0 loop + KAI_ASM_INST(0xc0860410) // mova { z16.s-z19.s }, za0h.s[x12] + KAI_ASM_INST(0xc086043c) // mova { z28.s-z31.s }, za1h.s[x12] + KAI_ASM_INST(0xc132e210) // scvtf { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc132e39c) // scvtf { z28.s-z31.s }, { z28.s-z31.s } + fmul z16.s, z16.s, z22.s + fmul z17.s, z17.s, z22.s + add x12, x12, #0x4 + fmul z18.s, z18.s, z22.s + fmul z19.s, z19.s, z22.s + cmp x12, x21, LSL #2 + fmul z28.s, z28.s, z26.s + fmul z29.s, z29.s, z26.s + fmul z30.s, z30.s, z26.s + fmul z31.s, z31.s, z26.s + KAI_ASM_INST(0xc1b8e210) // frintn { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc131e210) // fcvtzs { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc1b8e39c) // frintn { z28.s-z31.s }, { z28.s-z31.s } + KAI_ASM_INST(0xc1a0ab10) // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s + KAI_ASM_INST(0xc131e39c) // fcvtzs { z28.s-z31.s }, { z28.s-z31.s } + KAI_ASM_INST(0xc1a0ab1c) // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s + KAI_ASM_INST(0xc1a1cf70) // sclamp { z16.s-z19.s }, z27.s, z1.s + KAI_ASM_INST(0xc1a1cf7c) // sclamp { z28.s-z31.s }, z27.s, z1.s + uzp1 z5.h, z16.h, z28.h + uzp1 z20.h, z17.h, z29.h + uzp1 z17.h, z18.h, z30.h + uzp1 z16.h, z19.h, z31.h + st1b { z5.h }, p0, [x26] + add x26, x26, x23 + st1b { z20.h }, p0, [x26] + add x26, x26, x23 + st1b { z17.h }, p0, [x26] + add x26, x26, x23 + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + blt label_10 +KAI_ASM_LABEL(label_11) // Store to output array: Accumulator row 0 oddments + cbz x20, label_12 + KAI_ASM_INST(0xc0860404) // mova { z4.s-z7.s }, za0h.s[x12] + KAI_ASM_INST(0xc086042c) // mova { z12.s-z15.s }, za1h.s[x12] + KAI_ASM_INST(0xc132e084) // scvtf { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc132e18c) // scvtf { z12.s-z15.s }, { z12.s-z15.s } + fmul z4.s, z4.s, z22.s + fmul z5.s, z5.s, z22.s + subs x20, x20, #0x1 + fmul z6.s, z6.s, z22.s + fmul z7.s, 
z7.s, z22.s + fmul z12.s, z12.s, z26.s + fmul z13.s, z13.s, z26.s + fmul z14.s, z14.s, z26.s + fmul z15.s, z15.s, z26.s + KAI_ASM_INST(0xc1b8e084) // frintn { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc131e084) // fcvtzs { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc1b8e18c) // frintn { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc1a0ab04) // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s + KAI_ASM_INST(0xc131e18c) // fcvtzs { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc1a0ab0c) // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s + KAI_ASM_INST(0xc1a1cf64) // sclamp { z4.s-z7.s }, z27.s, z1.s + KAI_ASM_INST(0xc1a1cf6c) // sclamp { z12.s-z15.s }, z27.s, z1.s + uzp1 z16.h, z4.h, z12.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + beq label_12 + subs x20, x20, #0x1 + uzp1 z16.h, z5.h, z13.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + beq label_12 + uzp1 z16.h, z6.h, z14.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 +KAI_ASM_LABEL(label_12) // Store to output array: Accumulator row 0 oddments: End + subs x25, x25, x22 + beq label_16 + cmp x25, x24 + mov x12, #0x0 + csel x20, x25, x24, LT + lsr x21, x20, #0x2 + and x20, x20, #0x3 + cbz x21, label_14 +KAI_ASM_LABEL(label_13) // Store to output array: Accumulator row 1 loop + KAI_ASM_INST(0xc0860448) // mova { z8.s-z11.s }, za2h.s[x12] + KAI_ASM_INST(0xc0860470) // mova { z16.s-z19.s }, za3h.s[x12] + KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc132e210) // scvtf { z16.s-z19.s }, { z16.s-z19.s } + fmul z8.s, z8.s, z22.s + fmul z9.s, z9.s, z22.s + add x12, x12, #0x4 + fmul z10.s, z10.s, z22.s + fmul z11.s, z11.s, z22.s + cmp x12, x21, LSL #2 + fmul z16.s, z16.s, z26.s + fmul z17.s, z17.s, z26.s + fmul z18.s, z18.s, z26.s + fmul z19.s, z19.s, z26.s + KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc1b8e210) // frintn { z16.s-z19.s }, { z16.s-z19.s } + 
KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s + KAI_ASM_INST(0xc131e210) // fcvtzs { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc1a0ab10) // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s + KAI_ASM_INST(0xc1a1cf68) // sclamp { z8.s-z11.s }, z27.s, z1.s + KAI_ASM_INST(0xc1a1cf70) // sclamp { z16.s-z19.s }, z27.s, z1.s + uzp1 z21.h, z8.h, z16.h + uzp1 z20.h, z9.h, z17.h + uzp1 z17.h, z10.h, z18.h + uzp1 z16.h, z11.h, z19.h + st1b { z21.h }, p0, [x26] + add x26, x26, x23 + st1b { z20.h }, p0, [x26] + add x26, x26, x23 + st1b { z17.h }, p0, [x26] + add x26, x26, x23 + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + blt label_13 +KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments + cbz x20, label_15 + KAI_ASM_INST(0xc086044c) // mova { z12.s-z15.s }, za2h.s[x12] + KAI_ASM_INST(0xc0860464) // mova { z4.s-z7.s }, za3h.s[x12] + KAI_ASM_INST(0xc132e18c) // scvtf { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc132e084) // scvtf { z4.s-z7.s }, { z4.s-z7.s } + fmul z12.s, z12.s, z22.s + fmul z13.s, z13.s, z22.s + subs x20, x20, #0x1 + fmul z14.s, z14.s, z22.s + fmul z15.s, z15.s, z22.s + fmul z4.s, z4.s, z26.s + fmul z5.s, z5.s, z26.s + fmul z6.s, z6.s, z26.s + fmul z7.s, z7.s, z26.s + KAI_ASM_INST(0xc1b8e18c) // frintn { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc131e18c) // fcvtzs { z12.s-z15.s }, { z12.s-z15.s } + KAI_ASM_INST(0xc1b8e084) // frintn { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc1a0ab0c) // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s + KAI_ASM_INST(0xc131e084) // fcvtzs { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc1a0ab04) // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s + KAI_ASM_INST(0xc1a1cf6c) // sclamp { z12.s-z15.s }, z27.s, z1.s + KAI_ASM_INST(0xc1a1cf64) // sclamp { z4.s-z7.s }, z27.s, z1.s + uzp1 z16.h, z12.h, z4.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + beq label_15 + subs x20, x20, #0x1 + uzp1 z16.h, z13.h, z5.h + st1b { z16.h }, p0, [x26] + add x26, x26, x23 + beq label_15 + uzp1 
z16.h, z14.h, z6.h + st1b { z16.h }, p0, [x26] +KAI_ASM_LABEL(label_15) // Store to output array: Accumulator row 1 oddments: End +KAI_ASM_LABEL(label_16) // Store to output array: End + incw x10, ALL, MUL #2 + cmp x10, x11 + blt label_2 + incw x14, ALL, MUL #2 + mov x10, #0x0 + cmp x14, x13 + mov x9, x27 + blt label_1 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp d8, d9, [sp, 72] // reload callee-saved d8-d15 only once streaming mode has been left + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c index 511d3ada..584e9761 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c @@ -4,10 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
@@ -18,345 +15,120 @@ #include "kai/kai_common.h" -static const size_t kai_mr = 2; -static const size_t kai_kr = 4; -static const size_t kai_sr = 1; - -static inline size_t kai_get_m_step(void) { - return (kai_mr * kai_get_sme_vector_length_u8()) / kai_kr; +typedef struct { + size_t m; + size_t k; + size_t mr; + size_t kr; + size_t sr; + size_t m_idx_start; + const void* lhs; + size_t lhs_stride; + void* lhs_packed; + size_t height; + size_t width; + const void* const* in; + size_t row_offset; + void* out; +} KernelArgs; + +void kai_kernel_lhs_pack_x8p2vlx4_x8_sme(const KernelArgs* args_ptr); + +enum { + MR = 2, + KR = 4, + MAX_M_STEP = (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(int8_t)) / KR), + SR = 1, +}; + +static size_t kai_get_mr_lhs_pack_x8p2vlx4_x8_sme(void) { + return MR * kai_get_sme_vector_length_u8() / KR; } size_t kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme(size_t mr) { - KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u8() / kai_kr); + KAI_ASSUME(mr == kai_get_mr_lhs_pack_x8p2vlx4_x8_sme()); KAI_UNUSED(mr); - return (kai_mr * kai_get_sme_vector_length_u8()) / kai_kr; + return kai_get_mr_lhs_pack_x8p2vlx4_x8_sme(); } size_t kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme(size_t m_idx, size_t lhs_stride) { - KAI_ASSUME(m_idx % (kai_get_m_step()) == 0); + KAI_ASSUME(m_idx % kai_get_mr_lhs_pack_x8p2vlx4_x8_sme() == 0); return m_idx * lhs_stride; } size_t kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { - const size_t scaled_mr = kai_get_m_step(); - KAI_ASSUME(m_idx % scaled_mr == 0); - KAI_ASSUME(mr == scaled_mr); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(m_idx % kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme(mr) == 0); + KAI_ASSUME(mr == kai_get_mr_lhs_pack_x8p2vlx4_x8_sme()); + KAI_ASSUME(kr == KR); + KAI_ASSUME(sr == SR); KAI_UNUSED(mr); KAI_UNUSED(kr); KAI_UNUSED(sr); - return m_idx * kai_roundup(k, kai_kr) * sizeof(int8_t); + return m_idx * kai_roundup(k, kr) * 
sizeof(int8_t); } size_t kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { - KAI_ASSUME(mr == kai_get_m_step()); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(mr == kai_get_mr_lhs_pack_x8p2vlx4_x8_sme()); + KAI_ASSUME(kr == KR); + KAI_ASSUME(sr == SR); KAI_UNUSED(mr); KAI_UNUSED(kr); KAI_UNUSED(sr); - return (kai_roundup(m, kai_get_m_step()) * kai_roundup(k, kai_kr) * sizeof(int8_t)); + return kai_roundup(m, kai_get_mr_lhs_pack_x8p2vlx4_x8_sme()) * kai_roundup(k, KR) * sizeof(int8_t); } void kai_run_lhs_pack_x8p2vlx4_x8_sme( size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { - KAI_ASSUME(mr == kai_get_m_step()); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(mr == kai_get_mr_lhs_pack_x8p2vlx4_x8_sme()); + KAI_ASSUME(kr == KR); + KAI_ASSUME(sr == SR); KAI_ASSUME(lhs != NULL); KAI_ASSUME(lhs_packed != NULL); KAI_ASSUME(m_idx_start == 0); - const size_t block_height = kai_get_m_step(); + const size_t m_step = kai_get_mr_lhs_pack_x8p2vlx4_x8_sme(); + const size_t block_height = mr; const size_t width = k; const size_t row_offset = 0; - const void* in[block_height]; - const uint8_t* lhs_ptr = lhs; - uint8_t* lhs_packed_ptr = lhs_packed; + KAI_ASSERT(m_step <= MAX_M_STEP); + const void* in[MAX_M_STEP]; + uint8_t* lhs_packed_ptr = lhs_packed; + const uint8_t* lhs_ptr = lhs; for (size_t block_y = 0; block_y < m; block_y += block_height) { const size_t height = KAI_MIN(m - block_y, block_height); - void* out = (void*)((char*)lhs_packed_ptr + block_y * kai_roundup(k, kai_kr) * sizeof(int8_t)); + void* out = lhs_packed_ptr + block_y * kai_roundup(k, KR) * sizeof(int8_t); for (size_t y = 0; y < height; y++) { in[y] = lhs_ptr + (block_y + y) * lhs_stride; } - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x23, %x[width]\n" - "mov x21, %x[width]\n" - "cntb x20\n" - "incb x23\n" - "sub 
x7, x20, #0x1\n" - "cntw x8\n" - "sub x23, x23, #0x1\n" - "ands x7, x21, x7\n" - "udiv x23, x23, x20\n" // n_passes = ceildiv(width, VL) - "csel x7, x7, x20, NE\n" - "lsl x22, %x[height], #0x1\n" // height * 2 - "lsl x21, x8, #0x1\n" - "sub x20, x23, #0x1\n" - "add x7, x7, #0x3\n" - "sub x17, x8, #0x2\n" - "whilelt p9.b, XZR, x22\n" - "whilelt p8.b, x21, x22\n" - "mov x16, #0x0\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - "cntw x9, ALL, MUL #2\n" - "cntw x28, ALL, MUL #3\n" - "ldr x27, [x11, #0x0]\n" - "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 - "and x26, x23, #0x1\n" // odd_tail = bool(n_passes & 0x1) - "ldr x25, [x10, #0x0]\n" - "lsr x7, x7, #0x2\n" - "ptrue p11.s\n" - "ldr x24, [x11, #0x8]\n" - "zip1 p10.b, p9.b, p8.b\n" - "mov x23, %x[row_offset]\n" - "ldr x21, [x10, #0x8]\n" - "mov x22, %x[out]\n" - "whilelt p9.b, x16, %x[width]\n" - "whilelt p8.b, x16, %x[width]\n" - "add x11, x11, #0x10\n" - "add x10, x10, #0x10\n" - "mov x12, #0x0\n" - "cbz x17, 2f\n" - "1:" // K loop: Charge: Loop - ".inst 0x25246143 // psel p3.b, p8.b/Z, p10.b[w12]\n" - ".inst 0x252c6142 // psel p2.b, p8.b/Z, p10.b[w12, #1]\n" - ".inst 0x25646141 // psel p1.b, p8.b/Z, p10.b[w12, #4]\n" - ".inst 0x256c6140 // psel p0.b, p8.b/Z, p10.b[w12, #5]\n" - ".inst 0xe0170f60 // ld1b { za0h.b[x12] }, p3/Z, [x27, x23]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0170b21 // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x23]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0170704 // ld1b { za0h.b[x12, #4] }, p1/Z, [x24, x23]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01702a5 // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x23]\n" - "add x12, x12, #0x8\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "cmp x12, x17, LSL #2\n" - "blt 1b\n" - "2:" // K loop: Charge: End - ".inst 0x25246143 // psel p3.b, p8.b/Z, p10.b[w12]\n" - ".inst 0x252c6142 // psel p2.b, p8.b/Z, p10.b[w12, #1]\n" - ".inst 0x25646141 // psel p1.b, p8.b/Z, p10.b[w12, #4]\n" - ".inst 0x256c6140 // 
psel p0.b, p8.b/Z, p10.b[w12, #5]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0170f60 // ld1b { za0h.b[x12] }, p3/Z, [x27, x23]\n" - "ldr x27, [x11, #0x0]\n" - "incb x16\n" - ".inst 0xe0170b21 // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x23]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0170704 // ld1b { za0h.b[x12, #4] }, p1/Z, [x24, x23]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01702a5 // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "incb x23\n" - "cbz x20, 8f\n" - "mov x20, x20\n" - "3:" // K loop: Main loop - "whilelt p8.b, x16, %x[width]\n" - "mov x15, #0x0\n" - "mov x14, #0x0\n" - "cbz x17, 5f\n" - "4:" // K loop: Main loop: First: Loop - ".inst 0x25376143 // psel p3.b, p8.b/Z, p10.b[w15, #2]\n" - ".inst 0x253f6142 // psel p2.b, p8.b/Z, p10.b[w15, #3]\n" - ".inst 0x25776141 // psel p1.b, p8.b/Z, p10.b[w15, #6]\n" - ".inst 0x257f6140 // psel p0.b, p8.b/Z, p10.b[w15, #7]\n" - ".inst 0xe0176f62 // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x23]\n" - ".inst 0x25266d23 // psel p3.b, p11.b/Z, p9.b[w14]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0176b23 // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x23]\n" - ".inst 0x25266d22 // psel p2.b, p11.b/Z, p9.b[w14]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0176706 // ld1b { za0h.b[x15, #6] }, p1/Z, [x24, x23]\n" - ".inst 0x252e6d21 // psel p1.b, p11.b/Z, p9.b[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01762a7 // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0x252e6d20 // psel p0.b, p11.b/Z, p9.b[w14, #1]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0bfcec0 // st1w { za0v.s[x14] }, p3/Z, [x22, XZR, LSL #2]\n" - "add x15, x15, #0x8\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x14, 
x14, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x14, x17\n" - "blt 4b\n" - "5:" // K loop: Main loop: First: Tail - ".inst 0x25376143 // psel p3.b, p8.b/Z, p10.b[w15, #2]\n" - ".inst 0x253f6142 // psel p2.b, p8.b/Z, p10.b[w15, #3]\n" - ".inst 0x25776141 // psel p1.b, p8.b/Z, p10.b[w15, #6]\n" - ".inst 0x257f6140 // psel p0.b, p8.b/Z, p10.b[w15, #7]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0176f62 // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x23]\n" - ".inst 0x25266d23 // psel p3.b, p11.b/Z, p9.b[w14]\n" - "ldr x27, [x11, #0x0]\n" - "mov x13, #0x0\n" - ".inst 0xe0176b23 // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x23]\n" - ".inst 0x25266d22 // psel p2.b, p11.b/Z, p9.b[w14]\n" - "ldr x25, [x10, #0x0]\n" - "mov x12, #0x0\n" - ".inst 0xe0176706 // ld1b { za0h.b[x15, #6] }, p1/Z, [x24, x23]\n" - ".inst 0x252e6d21 // psel p1.b, p11.b/Z, p9.b[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01762a7 // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0x252e6d20 // psel p0.b, p11.b/Z, p9.b[w14, #1]\n" - "whilelt p9.b, x16, %x[width]\n" - ".inst 0xe0bfcec0 // st1w { za0v.s[x14] }, p3/Z, [x22, XZR, LSL #2]\n" - "incb x16\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - "incb x23\n" - "whilelt p8.b, x16, %x[width]\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "cbz x17, 7f\n" - "6:" // K loop: Main loop: Second: Loop - ".inst 0x25256143 // psel p3.b, p8.b/Z, p10.b[w13]\n" - ".inst 0x252d6142 // psel p2.b, p8.b/Z, p10.b[w13, #1]\n" - ".inst 0x25656141 // psel p1.b, p8.b/Z, p10.b[w13, #4]\n" - ".inst 0x256d6140 // psel p0.b, p8.b/Z, p10.b[w13, #5]\n" - ".inst 0xe0172f60 // ld1b { za0h.b[x13] }, p3/Z, [x27, x23]\n" - ".inst 0x25246d23 // psel p3.b, p11.b/Z, p9.b[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 
0xe0172b21 // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x23]\n" - ".inst 0x25246d22 // psel p2.b, p11.b/Z, p9.b[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0172704 // ld1b { za0h.b[x13, #4] }, p1/Z, [x24, x23]\n" - ".inst 0x252c6d21 // psel p1.b, p11.b/Z, p9.b[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01722a5 // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0x252c6d20 // psel p0.b, p11.b/Z, p9.b[w12, #1]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0bf8ec8 // st1w { za2v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - "add x13, x13, #0x8\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x12, x12, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x12, x17\n" - "blt 6b\n" - "7:" // K loop: Main loop: Second: Tail - ".inst 0x25256143 // psel p3.b, p8.b/Z, p10.b[w13]\n" - ".inst 0x252d6142 // psel p2.b, p8.b/Z, p10.b[w13, #1]\n" - ".inst 0x25656141 // psel p1.b, p8.b/Z, p10.b[w13, #4]\n" - ".inst 0x256d6140 // psel p0.b, p8.b/Z, p10.b[w13, #5]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0172f60 // ld1b { za0h.b[x13] }, p3/Z, [x27, x23]\n" - ".inst 0x25246d23 // psel p3.b, p11.b/Z, p9.b[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0172b21 // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x23]\n" - ".inst 0x25246d22 // psel p2.b, p11.b/Z, p9.b[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0172704 // ld1b { za0h.b[x13, #4] }, p1/Z, [x24, x23]\n" - ".inst 0x252c6d21 // psel p1.b, p11.b/Z, p9.b[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe01722a5 // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x23]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0x252c6d20 // psel p0.b, p11.b/Z, p9.b[w12, #1]\n" - "whilelt p9.b, x16, %x[width]\n" - ".inst 0xe0bf8ec8 // st1w { za2v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - "subs 
x20, x20, #0x1\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "incb x16\n" - "incb x23\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "bgt 3b\n" - "8:" // K loop: Tails - "cbnz x26, 11f\n" - "mov x11, %x[in]\n" - "whilelt p8.b, x16, %x[width]\n" - "mov x13, #0x0\n" - "mov x12, #0x0\n" - "9:" // K loop: Tails: Even: First - ".inst 0x25306d23 // psel p3.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25306d22 // psel p2.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25356141 // psel p1.b, p8.b/Z, p10.b[w13, #2]\n" - ".inst 0x253d6140 // psel p0.b, p8.b/Z, p10.b[w13, #3]\n" - ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "ldr x21, [x11, #0x0]\n" - "cmp x12, x8\n" - "ldr x20, [x11, x8, LSL #0x3]\n" - "add x11, x11, #0x8\n" - ".inst 0xe01726a2 // ld1b { za0h.b[x13, #2] }, p1/Z, [x21, x23]\n" - ".inst 0xe0172283 // ld1b { za0h.b[x13, #3] }, p0/Z, [x20, x23]\n" - "add x13, x13, #0x4\n" - "blt 9b\n" - "whilelt p9.b, x16, %x[width]\n" - "whilelt p8.b, x16, %x[width]\n" - "mov x20, #0x0\n" - "mov x12, #0x0\n" - "10:" // K loop: Tails: Even: Second - ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25306d20 // psel p0.s, p11.s/Z, p9.s[w12]\n" - "add x20, x20, #0x4\n" - ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 10b\n" - "whilelt p8.b, x16, %x[width]\n" - "b 13f\n" - "11:" // K loop: Tails: Odd - "mov x12, #0x0\n" - "12:" // K loop: Tails: Odd: Loop - ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25306d20 // psel p0.s, p11.s/Z, p9.s[w12]\n" - ".inst 
0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 12b\n" - "13:" // K loop: End - "mov %x[out], x22\n" - ".inst 0xd503467f // SMSTOP\n" - : [out] "+&r"(out) - : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", - "p14", "p15", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", - "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", - "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", - "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + KernelArgs args; + args.m = m; + args.k = k; + args.mr = mr; + args.kr = kr; + args.sr = sr; + args.m_idx_start = m_idx_start; + args.lhs = lhs; + args.lhs_stride = lhs_stride; + args.lhs_packed = lhs_packed; + args.height = height; + args.width = width; + args.in = in; + args.row_offset = row_offset; + args.out = out; + + kai_kernel_lhs_pack_x8p2vlx4_x8_sme(&args); } } diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h index 3b95a0a9..a8157bfa 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -34,7 +34,7 @@ size_t kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme(size_t m_idx, size_t lhs_stri /// @param[in] m_idx Row index in the unpacked LHS matrix. /// @param[in] k Number of columns in the unpacked LHS matrix. 
/// @param[in] mr Number of rows to be interleaved. -/// @param[in] kr Unused. Must be 1. +/// @param[in] kr Unused. Must be 4. /// @param[in] sr Unused. Must be 1. /// /// @return The offset in bytes to the data element. @@ -45,7 +45,7 @@ size_t kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme(size_t m_idx, size_t k /// @param[in] m Number of rows in the unpacked LHS matrix. /// @param[in] k Number of columns in the unpacked LHS matrix. /// @param[in] mr Number of rows to be interleaved. -/// @param[in] kr Unused. Must be 1. +/// @param[in] kr Unused. Must be 4. /// @param[in] sr Unused. Must be 1. /// /// @return The size in bytes of the packed LHS buffer. @@ -61,7 +61,7 @@ size_t kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme(size_t m, size_t k, size /// /// @param[in] m Number of rows of the unpacked LHS matrix. /// @param[in] k Common dimension between the LHS and RHS matrix. -/// @param[in] mr Block size in M dimension. It must be 2 * kai_get_sme_vector_length_u8(). +/// @param[in] mr Block size in M dimension. It must be kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme(). /// @param[in] kr Block size in K dimension. It must be 4. /// @param[in] sr Number of kr splits. It must be 1. /// @param[in] m_idx_start Unused. Must be 0. 
diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S new file mode 100644 index 00000000..f5635ab5 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S @@ -0,0 +1,318 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(lhs_pack_x8p2vlx4_x8_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_lhs_pack_x8p2vlx4_x8_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_pack_x8p2vlx4_x8_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_pack_x8p2vlx4_x8_sme) + stp x20, x21, [sp, -80]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x5, #0x0 + ldr x6, [x0, #0x50] + cntb x25 + cntw x7 + ldr x21, [x0, #0x48] + sub x8, x25, #0x1 + lsl x24, x7, #0x1 + ldr x17, [x0, #0x58] + sub x16, x7, #0x2 + cntw x11, ALL, MUL #2 + mov x23, x6 + mov x20, x6 + ldr x22, [x0, #0x60] + incb x23 + ands x8, x20, x8 + ldr x10, [x0, #0x68] + sub x23, x23, #0x1 + csel x8, x8, x25, NE + udiv x23, x23, x25 // n_passes = ceildiv(width, VL) + lsl x21, x21, #0x1 // height * 2 + sub x20, x23, #0x1 + add x8, x8, #0x3 + whilelt p9.b, XZR, x21 + whilelt p8.b, x24, x21 + mov x9, x17 + add x28, x17, x7, LSL #3 + cntw x27, ALL, MUL #3 + lsr x20, x20, #0x1 // n_loops = (n_passes - 1) / 2 + ldr x26, [x9, #0x0] + and x25, x23, #0x1 // odd_tail = bool(n_passes & 0x1) + lsr x8, x8, #0x2 + ldr x24, [x28, #0x0] + ptrue p11.s + zip1 p10.b, p9.b, p8.b + ldr x23, [x9, #0x8] + mov x22, x22 + whilelt p9.b, x5, x6 + ldr x21, [x28, #0x8] + whilelt p8.b, x5, x6 + add x9, x9, #0x10 + add x28, x28, #0x10 + mov x12, #0x0 + cbz x16, label_2 +KAI_ASM_LABEL(label_1) // K loop: Charge: Loop + KAI_ASM_INST(0x25246143) // psel p3.b, p8.b/Z, p10.b[w12] + KAI_ASM_INST(0x252c6142) // psel p2.b, p8.b/Z, p10.b[w12, #1] + KAI_ASM_INST(0x25646141) // psel p1.b, p8.b/Z, p10.b[w12, #4] + KAI_ASM_INST(0x256c6140) // psel p0.b, p8.b/Z, p10.b[w12, #5] + KAI_ASM_INST(0xe0160f40) // ld1b { za0h.b[x12] }, p3/Z, [x26, x22] + ldr x26, [x9, #0x0] + KAI_ASM_INST(0xe0160b01) // ld1b { za0h.b[x12, #1] }, p2/Z, [x24, x22] + ldr x24, [x28, #0x0] + KAI_ASM_INST(0xe01606e4) // ld1b { za0h.b[x12, #4] }, p1/Z, [x23, x22] + ldr x23, [x9, #0x8] + add x9, x9, #0x10 + KAI_ASM_INST(0xe01602a5) // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x22] + add x12, x12, #0x8 + ldr x21, [x28, #0x8] + add x28, x28, #0x10 + cmp x12, x16, LSL #2 + blt label_1 +KAI_ASM_LABEL(label_2) // K loop: Charge: End + KAI_ASM_INST(0x25246143) // psel p3.b, p8.b/Z, p10.b[w12] + 
KAI_ASM_INST(0x252c6142) // psel p2.b, p8.b/Z, p10.b[w12, #1] + KAI_ASM_INST(0x25646141) // psel p1.b, p8.b/Z, p10.b[w12, #4] + KAI_ASM_INST(0x256c6140) // psel p0.b, p8.b/Z, p10.b[w12, #5] + mov x9, x17 + add x28, x17, x7, LSL #3 + KAI_ASM_INST(0xe0160f40) // ld1b { za0h.b[x12] }, p3/Z, [x26, x22] + ldr x26, [x9, #0x0] + incb x5 + KAI_ASM_INST(0xe0160b01) // ld1b { za0h.b[x12, #1] }, p2/Z, [x24, x22] + ldr x24, [x28, #0x0] + KAI_ASM_INST(0xe01606e4) // ld1b { za0h.b[x12, #4] }, p1/Z, [x23, x22] + ldr x23, [x9, #0x8] + add x9, x9, #0x10 + KAI_ASM_INST(0xe01602a5) // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x22] + ldr x21, [x28, #0x8] + add x28, x28, #0x10 + incb x22 + cbz x20, label_8 + mov x20, x20 +KAI_ASM_LABEL(label_3) // K loop: Main loop + whilelt p8.b, x5, x6 + mov x15, #0x0 + mov x14, #0x0 + cbz x16, label_5 +KAI_ASM_LABEL(label_4) // K loop: Main loop: First: Loop + KAI_ASM_INST(0x25376143) // psel p3.b, p8.b/Z, p10.b[w15, #2] + KAI_ASM_INST(0x253f6142) // psel p2.b, p8.b/Z, p10.b[w15, #3] + KAI_ASM_INST(0x25776141) // psel p1.b, p8.b/Z, p10.b[w15, #6] + KAI_ASM_INST(0x257f6140) // psel p0.b, p8.b/Z, p10.b[w15, #7] + KAI_ASM_INST(0xe0166f42) // ld1b { za0h.b[x15, #2] }, p3/Z, [x26, x22] + KAI_ASM_INST(0x25266d23) // psel p3.b, p11.b/Z, p9.b[w14] + ldr x26, [x9, #0x0] + KAI_ASM_INST(0xe0166b03) // ld1b { za0h.b[x15, #3] }, p2/Z, [x24, x22] + KAI_ASM_INST(0x25266d22) // psel p2.b, p11.b/Z, p9.b[w14] + ldr x24, [x28, #0x0] + KAI_ASM_INST(0xe01666e6) // ld1b { za0h.b[x15, #6] }, p1/Z, [x23, x22] + KAI_ASM_INST(0x252e6d21) // psel p1.b, p11.b/Z, p9.b[w14, #1] + ldr x23, [x9, #0x8] + add x9, x9, #0x10 + KAI_ASM_INST(0xe01662a7) // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x22] + ldr x21, [x28, #0x8] + KAI_ASM_INST(0x252e6d20) // psel p0.b, p11.b/Z, p9.b[w14, #1] + add x28, x28, #0x10 + KAI_ASM_INST(0xe0bfcd40) // st1w { za0v.s[x14] }, p3/Z, [x10, XZR, LSL #2] + add x15, x15, #0x8 + KAI_ASM_INST(0xe0a7c944) // st1w { za1v.s[x14] }, p2/Z, [x10, x7, LSL #2] + 
KAI_ASM_INST(0xe0abc541) // st1w { za0v.s[x14, #1] }, p1/Z, [x10, x11, LSL #2] + KAI_ASM_INST(0xe0bbc145) // st1w { za1v.s[x14, #1] }, p0/Z, [x10, x27, LSL #2] + add x14, x14, #0x2 + addvl x10, x10, #4 + cmp x14, x16 + blt label_4 +KAI_ASM_LABEL(label_5) // K loop: Main loop: First: Tail + KAI_ASM_INST(0x25376143) // psel p3.b, p8.b/Z, p10.b[w15, #2] + KAI_ASM_INST(0x253f6142) // psel p2.b, p8.b/Z, p10.b[w15, #3] + KAI_ASM_INST(0x25776141) // psel p1.b, p8.b/Z, p10.b[w15, #6] + KAI_ASM_INST(0x257f6140) // psel p0.b, p8.b/Z, p10.b[w15, #7] + mov x9, x17 + add x28, x17, x7, LSL #3 + KAI_ASM_INST(0xe0166f42) // ld1b { za0h.b[x15, #2] }, p3/Z, [x26, x22] + KAI_ASM_INST(0x25266d23) // psel p3.b, p11.b/Z, p9.b[w14] + ldr x26, [x9, #0x0] + mov x13, #0x0 + KAI_ASM_INST(0xe0166b03) // ld1b { za0h.b[x15, #3] }, p2/Z, [x24, x22] + KAI_ASM_INST(0x25266d22) // psel p2.b, p11.b/Z, p9.b[w14] + ldr x24, [x28, #0x0] + mov x12, #0x0 + KAI_ASM_INST(0xe01666e6) // ld1b { za0h.b[x15, #6] }, p1/Z, [x23, x22] + KAI_ASM_INST(0x252e6d21) // psel p1.b, p11.b/Z, p9.b[w14, #1] + ldr x23, [x9, #0x8] + add x9, x9, #0x10 + KAI_ASM_INST(0xe01662a7) // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x22] + ldr x21, [x28, #0x8] + KAI_ASM_INST(0x252e6d20) // psel p0.b, p11.b/Z, p9.b[w14, #1] + whilelt p9.b, x5, x6 + KAI_ASM_INST(0xe0bfcd40) // st1w { za0v.s[x14] }, p3/Z, [x10, XZR, LSL #2] + incb x5 + add x28, x28, #0x10 + KAI_ASM_INST(0xe0a7c944) // st1w { za1v.s[x14] }, p2/Z, [x10, x7, LSL #2] + incb x22 + whilelt p8.b, x5, x6 + KAI_ASM_INST(0xe0abc541) // st1w { za0v.s[x14, #1] }, p1/Z, [x10, x11, LSL #2] + KAI_ASM_INST(0xe0bbc145) // st1w { za1v.s[x14, #1] }, p0/Z, [x10, x27, LSL #2] + addvl x10, x10, #4 + cbz x16, label_7 +KAI_ASM_LABEL(label_6) // K loop: Main loop: Second: Loop + KAI_ASM_INST(0x25256143) // psel p3.b, p8.b/Z, p10.b[w13] + KAI_ASM_INST(0x252d6142) // psel p2.b, p8.b/Z, p10.b[w13, #1] + KAI_ASM_INST(0x25656141) // psel p1.b, p8.b/Z, p10.b[w13, #4] + KAI_ASM_INST(0x256d6140) // psel p0.b, 
p8.b/Z, p10.b[w13, #5] + KAI_ASM_INST(0xe0162f40) // ld1b { za0h.b[x13] }, p3/Z, [x26, x22] + KAI_ASM_INST(0x25246d23) // psel p3.b, p11.b/Z, p9.b[w12] + ldr x26, [x9, #0x0] + KAI_ASM_INST(0xe0162b01) // ld1b { za0h.b[x13, #1] }, p2/Z, [x24, x22] + KAI_ASM_INST(0x25246d22) // psel p2.b, p11.b/Z, p9.b[w12] + ldr x24, [x28, #0x0] + KAI_ASM_INST(0xe01626e4) // ld1b { za0h.b[x13, #4] }, p1/Z, [x23, x22] + KAI_ASM_INST(0x252c6d21) // psel p1.b, p11.b/Z, p9.b[w12, #1] + ldr x23, [x9, #0x8] + add x9, x9, #0x10 + KAI_ASM_INST(0xe01622a5) // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x22] + ldr x21, [x28, #0x8] + KAI_ASM_INST(0x252c6d20) // psel p0.b, p11.b/Z, p9.b[w12, #1] + add x28, x28, #0x10 + KAI_ASM_INST(0xe0bf8d48) // st1w { za2v.s[x12] }, p3/Z, [x10, XZR, LSL #2] + add x13, x13, #0x8 + KAI_ASM_INST(0xe0a7894c) // st1w { za3v.s[x12] }, p2/Z, [x10, x7, LSL #2] + KAI_ASM_INST(0xe0ab8549) // st1w { za2v.s[x12, #1] }, p1/Z, [x10, x11, LSL #2] + KAI_ASM_INST(0xe0bb814d) // st1w { za3v.s[x12, #1] }, p0/Z, [x10, x27, LSL #2] + add x12, x12, #0x2 + addvl x10, x10, #4 + cmp x12, x16 + blt label_6 +KAI_ASM_LABEL(label_7) // K loop: Main loop: Second: Tail + KAI_ASM_INST(0x25256143) // psel p3.b, p8.b/Z, p10.b[w13] + KAI_ASM_INST(0x252d6142) // psel p2.b, p8.b/Z, p10.b[w13, #1] + KAI_ASM_INST(0x25656141) // psel p1.b, p8.b/Z, p10.b[w13, #4] + KAI_ASM_INST(0x256d6140) // psel p0.b, p8.b/Z, p10.b[w13, #5] + mov x9, x17 + add x28, x17, x7, LSL #3 + KAI_ASM_INST(0xe0162f40) // ld1b { za0h.b[x13] }, p3/Z, [x26, x22] + KAI_ASM_INST(0x25246d23) // psel p3.b, p11.b/Z, p9.b[w12] + ldr x26, [x9, #0x0] + KAI_ASM_INST(0xe0162b01) // ld1b { za0h.b[x13, #1] }, p2/Z, [x24, x22] + KAI_ASM_INST(0x25246d22) // psel p2.b, p11.b/Z, p9.b[w12] + ldr x24, [x28, #0x0] + KAI_ASM_INST(0xe01626e4) // ld1b { za0h.b[x13, #4] }, p1/Z, [x23, x22] + KAI_ASM_INST(0x252c6d21) // psel p1.b, p11.b/Z, p9.b[w12, #1] + ldr x23, [x9, #0x8] + add x9, x9, #0x10 + KAI_ASM_INST(0xe01622a5) // ld1b { za0h.b[x13, #5] }, p0/Z, 
[x21, x22] + ldr x21, [x28, #0x8] + KAI_ASM_INST(0x252c6d20) // psel p0.b, p11.b/Z, p9.b[w12, #1] + whilelt p9.b, x5, x6 + KAI_ASM_INST(0xe0bf8d48) // st1w { za2v.s[x12] }, p3/Z, [x10, XZR, LSL #2] + subs x20, x20, #0x1 + add x28, x28, #0x10 + KAI_ASM_INST(0xe0a7894c) // st1w { za3v.s[x12] }, p2/Z, [x10, x7, LSL #2] + incb x5 + incb x22 + KAI_ASM_INST(0xe0ab8549) // st1w { za2v.s[x12, #1] }, p1/Z, [x10, x11, LSL #2] + KAI_ASM_INST(0xe0bb814d) // st1w { za3v.s[x12, #1] }, p0/Z, [x10, x27, LSL #2] + addvl x10, x10, #4 + bgt label_3 +KAI_ASM_LABEL(label_8) // K loop: Tails + cbnz x25, label_11 + mov x9, x17 + whilelt p8.b, x5, x6 + mov x13, #0x0 + mov x12, #0x0 +KAI_ASM_LABEL(label_9) // K loop: Tails: Even: First + KAI_ASM_INST(0x25306d23) // psel p3.s, p11.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306d22) // psel p2.s, p11.s/Z, p9.s[w12] + KAI_ASM_INST(0x25356141) // psel p1.b, p8.b/Z, p10.b[w13, #2] + KAI_ASM_INST(0x253d6140) // psel p0.b, p8.b/Z, p10.b[w13, #3] + KAI_ASM_INST(0xe0bf8d40) // st1w { za0v.s[x12] }, p3/Z, [x10, XZR, LSL #2] + KAI_ASM_INST(0xe0a78944) // st1w { za1v.s[x12] }, p2/Z, [x10, x7, LSL #2] + add x12, x12, #0x1 + addvl x10, x10, #2 + ldr x21, [x9, #0x0] + cmp x12, x7 + ldr x20, [x9, x7, LSL #0x3] + add x9, x9, #0x8 + KAI_ASM_INST(0xe01626a2) // ld1b { za0h.b[x13, #2] }, p1/Z, [x21, x22] + KAI_ASM_INST(0xe0162283) // ld1b { za0h.b[x13, #3] }, p0/Z, [x20, x22] + add x13, x13, #0x4 + blt label_9 + whilelt p9.b, x5, x6 + whilelt p8.b, x5, x6 + mov x20, #0x0 + mov x12, #0x0 +KAI_ASM_LABEL(label_10) // K loop: Tails: Even: Second + KAI_ASM_INST(0x25306d21) // psel p1.s, p11.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306d20) // psel p0.s, p11.s/Z, p9.s[w12] + add x20, x20, #0x4 + KAI_ASM_INST(0xe0bf8548) // st1w { za2v.s[x12] }, p1/Z, [x10, XZR, LSL #2] + KAI_ASM_INST(0xe0a7814c) // st1w { za3v.s[x12] }, p0/Z, [x10, x7, LSL #2] + add x12, x12, #0x1 + addvl x10, x10, #2 + cmp x12, x8 + blt label_10 + whilelt p8.b, x5, x6 + b label_13 +KAI_ASM_LABEL(label_11) // K 
loop: Tails: Odd + mov x12, #0x0 +KAI_ASM_LABEL(label_12) // K loop: Tails: Odd: Loop + KAI_ASM_INST(0x25306d21) // psel p1.s, p11.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306d20) // psel p0.s, p11.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf8540) // st1w { za0v.s[x12] }, p1/Z, [x10, XZR, LSL #2] + KAI_ASM_INST(0xe0a78144) // st1w { za1v.s[x12] }, p0/Z, [x10, x7, LSL #2] + add x12, x12, #0x1 + addvl x10, x10, #2 + cmp x12, x8 + blt label_12 +KAI_ASM_LABEL(label_13) // K loop: End + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_lhs_pack_x8p2vlx4_x8_sme) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c index 3241c578..ff30f7c9 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -1,13 +1,12 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
- #include "kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" #include @@ -16,19 +15,39 @@ #include "kai/kai_common.h" -static const size_t kai_nr = 2; -static const size_t kai_kr = 4; -static const size_t kai_num_bytes_input = 1; -static const size_t kai_num_bytes_output = 1; -static const size_t kai_num_bytes_bias = 4; -static const size_t kai_num_bytes_scale = 4; +enum { + NR = 2, + KR = 4, + MAX_N_STEP = NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint8_t)) / KR), +}; + +typedef struct { + const void* bias_ptr; + const void* scale_ptr; + int32_t input_zero_point; + float scale_multiplier; + size_t width; + size_t height; + size_t in_stride; + size_t out_stride; + const void* in; + void* out; + const void* pad_row; +} KernelArgs; + +static const size_t kai_num_bytes_input = sizeof(uint8_t); +static const size_t kai_num_bytes_output = sizeof(uint8_t); +static const size_t kai_num_bytes_bias = sizeof(int32_t); +static const size_t kai_num_bytes_scale = sizeof(float); + +void kai_kernel_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(const KernelArgs* args_ptr); size_t kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void) { - return kai_nr * kai_get_sme_vector_length_u8() / kai_kr; + return NR * kai_get_sme_vector_length_u8() / KR; } size_t kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { - KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u8() / kai_kr) == 0); + KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0); return n_idx * kai_num_bytes_input; } @@ -43,220 +62,53 @@ size_t kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t k) { return kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() * - (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output + kai_num_bytes_scale); + (kai_num_bytes_bias + kai_roundup(k, KR) * kai_num_bytes_output + 
kai_num_bytes_scale); } size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0); - return (n_idx / kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme()) * - kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k); + const size_t block_idx = n_idx / kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(); + return block_idx * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k); } size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n, size_t k) { - return kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( - kai_roundup(n, kai_nr * kai_get_sme_vector_length_u8() / kai_kr), k); + const size_t n_nr_blocks = kai_roundup(n, kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme()); + return kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(n_nr_blocks, k); } void kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride_row, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qsi8cx_params* params) { KAI_ASSUME(num_groups == 1); - KAI_ASSUME(nr == kai_nr * kai_get_sme_vector_length_u8() / kai_kr); - KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(nr == kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme()); + KAI_ASSUME(kr == KR); KAI_ASSUME(sr == 1); KAI_ASSUME(rhs != NULL); KAI_ASSUME(bias != NULL); + KAI_ASSUME(scale != NULL); KAI_ASSUME(rhs_packed != NULL); KAI_ASSUME(extra_bytes == 0); KAI_ASSUME(params != NULL); - size_t height = k; - const size_t width = n; - const void* in = rhs; - void* out = rhs_packed; - const size_t in_stride = 
rhs_stride; - uint8_t pad_row[nr]; - - if (height % 4) { - memset(pad_row, 0, nr * sizeof(uint8_t)); - } - - size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(height); - const int32_t lhs_zero_point = params->lhs_zero_point; - const float scale_multiplier = params->scale_multiplier; - - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "cmp %x[height], #0x8\n" - "mov x11, %x[out]\n" - "ptrue p2.b\n" - "mov x10, %x[height]\n" - "incb %x[out], ALL, MUL #2\n" - "blt 4f\n" - "1:" // Main row loop: Head - "mov x9, %x[in]\n" - "mov x28, %x[out]\n" - "add x27, x9, %x[in_stride]\n" - "sub %x[height], %x[height], #0x8\n" - "add x26, x27, %x[in_stride]\n" - "mov x24, %x[width]\n" - "add x25, x26, %x[in_stride]\n" - "add x23, x25, %x[in_stride]\n" - "add x22, x23, %x[in_stride]\n" - "add x21, x22, %x[in_stride]\n" - "add x20, x21, %x[in_stride]\n" - "add %x[in], x20, %x[in_stride]\n" - "2:" // Main row loop: Column loop - "whilelt p0.b, XZR, x24\n" - "decw x24, ALL, MUL #2\n" - "ld1b { z18.b }, p0/Z, [x9]\n" - "cmp x24, #0x0\n" - "incd x9, ALL, MUL #4\n" - "ld1b { z22.b }, p0/Z, [x27]\n" - "incd x27, ALL, MUL #4\n" - "ld1b { z17.b }, p0/Z, [x26]\n" - "incd x26, ALL, MUL #4\n" - "ld1b { z16.b }, p0/Z, [x25]\n" - "incd x25, ALL, MUL #4\n" - "ld1b { z20.b }, p0/Z, [x23]\n" - "incd x23, ALL, MUL #4\n" - "ld1b { z19.b }, p0/Z, [x22]\n" - "zip1 z21.b, z18.b, z17.b\n" - "incd x22, ALL, MUL #4\n" - "ld1b { z18.b }, p0/Z, [x21]\n" - "zip1 z17.b, z22.b, z16.b\n" - "incd x21, ALL, MUL #4\n" - "ld1b { z16.b }, p0/Z, [x20]\n" - "incd x20, ALL, MUL #4\n" - "zip1 z20.b, z20.b, z18.b\n" - "zip1 z16.b, z19.b, z16.b\n" - "zip1 z19.b, z21.b, z17.b\n" - "zip2 z18.b, z21.b, z17.b\n" - "zip1 z17.b, z20.b, z16.b\n" - "zip2 z16.b, z20.b, z16.b\n" - "st1b { z19.b }, p2, [x28]\n" - "st1b { z18.b }, p2, [x28, #1, MUL VL]\n" - "st1b { z17.b }, p2, [x28, #2, MUL VL]\n" - "st1b { z16.b }, p2, [x28, #3, MUL VL]\n" - "add x28, x28, %x[out_stride]\n" - 
"bgt 2b\n" - "cmp %x[height], #0x8\n" - "addvl %x[out], %x[out], #4\n" - "bge 1b\n" - "cbz %x[height], 8f\n" - "4:" // Main loop skip - "5:" // Tail row loop: Head - "mov x9, %x[in]\n" - "cntw x24, ALL, MUL #2\n" - "add x27, x9, %x[in_stride]\n" - "cmp %x[height], #0x3\n" - "add x26, x27, %x[in_stride]\n" - "csel x23, x24, XZR, GT\n" - "add x25, x26, %x[in_stride]\n" - "csel x26, x26, %x[pad_row], GE\n" - "add %x[in], x25, %x[in_stride]\n" - "csel x25, x25, %x[pad_row], GT\n" - "csel x22, x24, XZR, GE\n" - "cmp %x[height], #0x1\n" - "mov x28, %x[out]\n" - "csel x27, x27, %x[pad_row], GT\n" - "csel x21, x24, XZR, GT\n" - "sub %x[height], %x[height], #0x4\n" - "mov x20, %x[width]\n" - "6:" // Tail row loop: Column loop - "whilelt p0.b, XZR, x20\n" - "decw x20, ALL, MUL #2\n" - "ld1b { z18.b }, p0/Z, [x9]\n" - "cmp x20, #0x0\n" - "add x9, x9, x24\n" - "ld1b { z19.b }, p0/Z, [x27]\n" - "add x27, x27, x21\n" - "ld1b { z17.b }, p0/Z, [x26]\n" - "add x26, x26, x22\n" - "ld1b { z16.b }, p0/Z, [x25]\n" - "add x25, x25, x23\n" - "zip1 z18.b, z18.b, z17.b\n" - "zip1 z16.b, z19.b, z16.b\n" - "zip1 z17.b, z18.b, z16.b\n" - "zip2 z16.b, z18.b, z16.b\n" - "st1b { z17.b }, p2, [x28]\n" - "st1b { z16.b }, p2, [x28, #1, MUL VL]\n" - "add x28, x28, %x[out_stride]\n" - "bgt 6b\n" - "cmp %x[height], #0x1\n" - "addvl %x[out], %x[out], #2\n" - "bge 5b\n" - "8:" // Done - "mov x22, %x[out]\n" - "mov x21, %x[width]\n" - "dup z18.s, %w[scale_multiplier]\n" - "cbz %x[scale], 10f\n" - "9:" // Scale: Full loop - "mov x20, x21\n" - "decw x21, ALL, MUL #2\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "ld1w { z17.s }, p1/Z, [%x[scale]]\n" - "cmp x21, #0x0\n" - "ld1w { z16.s }, p0/Z, [%x[scale], #1, MUL VL]\n" - "incb %x[scale], ALL, MUL #2\n" - "fmul z17.s, z17.s, z18.s\n" - "fmul z16.s, z16.s, z18.s\n" - "st1w { z17.s }, p2, [x22]\n" - "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" - "add x22, x22, %x[out_stride]\n" - "bgt 9b\n" - "10:" // Scale: Done - "cbz %x[width], 
13f\n" - "cbz x10, 13f\n" - "dup z21.s, %w[lhs_zero_point]\n" - "add x25, x10, #0x3\n" - "cntw x24, ALL, MUL #2\n" - "mov z20.b, #0x1\n" - "lsr x25, x25, #0x2\n" - "mov x23, %x[width]\n" - "addvl x22, x11, #2\n" - "neg z21.s, p2/M, z21.s\n" - "11:" // Bias: N loop - "mov x21, x22\n" - "mov x20, x25\n" - "mov z19.s, #0x0\n" - "mov z18.s, #0x0\n" - "12:" // Bias: K loop - "ld1b { z17.b }, p2/Z, [x21]\n" - "subs x20, x20, #0x1\n" - "ld1b { z16.b }, p2/Z, [x21, #1, MUL VL]\n" - "addvl x21, x21, #2\n" - "sdot z19.s, z17.b, z20.b\n" - "sdot z18.s, z16.b, z20.b\n" - "bgt 12b\n" - "mov x20, x23\n" - "add x22, x22, %x[out_stride]\n" - "whilelt p1.s, XZR, x20\n" - "decw x20\n" - "whilelt p0.s, XZR, x20\n" - "ld1w { z17.s }, p1/Z, [%x[bias]]\n" - "subs x23, x23, x24\n" - "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" - "addvl %x[bias], %x[bias], #2\n" - "mla z17.s, p2/M, z19.s, z21.s\n" - "mla z16.s, p2/M, z18.s, z21.s\n" - "st1w { z17.s }, p2, [x11]\n" - "st1w { z16.s }, p2, [x11, #1, MUL VL]\n" - "add x11, x11, %x[out_stride]\n" - "bgt 11b\n" - "13:" // Bias: Done - ".inst 0xd503467f // SMSTOP\n" - : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out), [scale] "+&r"(scale) - : [in_stride] "r"(in_stride), [lhs_zero_point] "r"(lhs_zero_point), [out_stride] "r"(out_stride), - [pad_row] "r"(pad_row), [scale_multiplier] "r"(scale_multiplier), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", - "p15", "x9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", - "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", - "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + KAI_ASSERT(kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() <= MAX_N_STEP); + static const uint8_t pad_row[MAX_N_STEP] = {0}; + + KernelArgs args; + 
args.bias_ptr = bias; + args.scale_ptr = scale; + args.height = k; + args.width = n; + args.in = rhs; + args.out = rhs_packed; + args.in_stride = rhs_stride_row; + args.out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(args.height); + args.input_zero_point = params->lhs_zero_point; + args.scale_multiplier = params->scale_multiplier; + args.pad_row = pad_row; + + kai_kernel_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h index 6107a14d..30baec30 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -16,14 +16,14 @@ extern "C" { /// Gets n step value. /// -/// The starting row index must be divisible by `n_step`. +/// The starting column index must be divisible by `n_step`. /// /// @return The n step value. size_t kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. /// -/// @param[in] n_idx Column index. +/// @param[in] n_idx Column index. Must be divisible by `n_step` /// /// @return The offset in bytes to the data element. size_t kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); @@ -51,7 +51,7 @@ size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(s /// Gets the offset in bytes to the data element in the packed RHS buffer. /// -/// @param[in] n_idx Column index. +/// @param[in] n_idx Column index. 
Must be divisible by `n_step` /// @param[in] k Number of rows. /// /// @return The offset in bytes to the data element. @@ -70,26 +70,26 @@ size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(siz /// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset /// calculated using the following functions: /// -/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(. -/// * Bias: @ref kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(. -/// * Scale: @ref kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(. -/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(. +/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * Bias: @ref kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * Scale: @ref kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. /// /// @param[in] num_groups Number of groups. It must be 1. /// @param[in] n Number of columns of the output matrix. -/// @param[in] k Common dimension between the LHS and RHS matrix. -/// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u8(). +/// @param[in] k Number of rows. +/// @param[in] nr Block size in N dimension. It must be `get_n_step` /// @param[in] kr Block size in K dimension. It must be 4. /// @param[in] sr Number of kr splits. It must be 1. -/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs_stride_row Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. -/// @param[in] scale Scale data buffer. It must be NULL. +/// @param[in] scale Scale data buffer. /// @param[out] rhs_packed Packed RHS matrix. 
/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. -/// @param[in] params Extra packing parameters. It must be NULL. +/// @param[in] params Extra quantization packing parameters. void kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride_row, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qsi8cx_params* params); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S new file mode 100644 index 00000000..9b120c9b --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S @@ -0,0 +1,234 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + 
#define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) + stp x20, x21, [sp, -80]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + ldr x4, [x0, #0x20] + ptrue p2.b + ldr x5, [x0, #0x40] + ldr x6, [x0, #0x0] + mov x7, x4 + ldr x8, [x0, #0x8] + cmp x7, #0x8 + ldr w17, [x0, #0x10] + mov x16, x5 + ldr w15, [x0, #0x14] + incb x5, ALL, MUL #2 + ldr x14, [x0, #0x18] + ldr x13, [x0, #0x28] + ldr x12, [x0, #0x30] + ldr x11, [x0, #0x38] + ldr x10, [x0, #0x48] + blt label_4 +KAI_ASM_LABEL(label_1) // Main row loop: Head + mov x9, x11 + mov x28, x5 + add x27, x9, x13 + sub x7, x7, #0x8 + add x26, x27, x13 + mov x24, x14 + add x25, x26, x13 + add x23, x25, x13 + add x22, x23, x13 + add x21, x22, x13 + add x20, x21, x13 + add x11, x20, x13 +KAI_ASM_LABEL(label_2) // Main row loop: Column loop + whilelt p0.b, XZR, x24 + decw x24, ALL, MUL #2 + ld1b { z18.b }, p0/Z, [x9] + cmp x24, #0x0 + incd x9, ALL, MUL #4 + ld1b { z22.b }, p0/Z, [x27] + incd x27, ALL, MUL #4 + ld1b { z17.b }, p0/Z, [x26] + incd x26, ALL, MUL #4 + ld1b { z16.b }, p0/Z, [x25] + incd x25, ALL, MUL #4 + ld1b { z20.b }, p0/Z, [x23] + incd x23, ALL, MUL #4 + ld1b { z19.b }, p0/Z, [x22] + zip1 z21.b, z18.b, z17.b + incd x22, ALL, MUL #4 + ld1b { z18.b }, p0/Z, [x21] + zip1 z17.b, z22.b, z16.b + incd x21, ALL, MUL #4 + ld1b { z16.b }, p0/Z, [x20] + incd x20, ALL, MUL #4 + zip1 z20.b, z20.b, z18.b + zip1 z16.b, z19.b, z16.b + zip1 z19.b, z21.b, z17.b + zip2 z18.b, z21.b, z17.b + zip1 z17.b, z20.b, z16.b + zip2 z16.b, z20.b, z16.b + st1b { z19.b }, p2, [x28] + st1b { 
z18.b }, p2, [x28, #1, MUL VL] + st1b { z17.b }, p2, [x28, #2, MUL VL] + st1b { z16.b }, p2, [x28, #3, MUL VL] + add x28, x28, x12 + bgt label_2 + cmp x7, #0x8 + addvl x5, x5, #4 + bge label_1 + cbz x7, label_8 +KAI_ASM_LABEL(label_4) // Main loop skip +KAI_ASM_LABEL(label_5) // Tail row loop: Head + mov x9, x11 + cmp x7, #0x3 + add x27, x9, x13 + cntw x24, ALL, MUL #2 + add x26, x27, x13 + csel x23, x24, XZR, GT + add x25, x26, x13 + csel x22, x24, XZR, GE + add x11, x25, x13 + mov x28, x5 + csel x11, x11, x25, GT + csel x25, x25, x10, GT + csel x11, x11, x26, GE + csel x26, x26, x10, GE + cmp x7, #0x1 + sub x7, x7, #0x4 + csel x11, x11, x27, GT + csel x27, x27, x10, GT + csel x21, x24, XZR, GT + mov x20, x14 +KAI_ASM_LABEL(label_6) // Tail row loop: Column loop + whilelt p0.b, XZR, x20 + decw x20, ALL, MUL #2 + ld1b { z18.b }, p0/Z, [x9] + cmp x20, #0x0 + add x9, x9, x24 + ld1b { z19.b }, p0/Z, [x27] + add x27, x27, x21 + ld1b { z17.b }, p0/Z, [x26] + add x26, x26, x22 + ld1b { z16.b }, p0/Z, [x25] + add x25, x25, x23 + zip1 z18.b, z18.b, z17.b + zip1 z16.b, z19.b, z16.b + zip1 z17.b, z18.b, z16.b + zip2 z16.b, z18.b, z16.b + st1b { z17.b }, p2, [x28] + st1b { z16.b }, p2, [x28, #1, MUL VL] + add x28, x28, x12 + bgt label_6 + cmp x7, #0x1 + addvl x5, x5, #2 + bge label_5 +KAI_ASM_LABEL(label_8) // Done + mov x22, x5 + mov x21, x14 + dup z18.s, w15 + cbz x8, label_10 +KAI_ASM_LABEL(label_9) // Scale: Full loop + mov x20, x21 + decw x21, ALL, MUL #2 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + ld1w { z17.s }, p1/Z, [x8] + cmp x21, #0x0 + ld1w { z16.s }, p0/Z, [x8, #1, MUL VL] + incb x8, ALL, MUL #2 + fmul z17.s, z17.s, z18.s + fmul z16.s, z16.s, z18.s + st1w { z17.s }, p2, [x22] + st1w { z16.s }, p2, [x22, #1, MUL VL] + add x22, x22, x12 + bgt label_9 +KAI_ASM_LABEL(label_10) // Scale: Done + cbz x14, label_13 + cbz x4, label_13 + dup z21.s, w17 + add x25, x4, #0x3 + cntw x24, ALL, MUL #2 + mov z20.b, #0x1 + lsr x25, x25, #0x2 + mov x23, x14 + 
addvl x22, x16, #2 + neg z21.s, p2/M, z21.s +KAI_ASM_LABEL(label_11) // Bias: N loop + mov x21, x22 + mov x20, x25 + mov z19.s, #0x0 + mov z18.s, #0x0 +KAI_ASM_LABEL(label_12) // Bias: K loop + ld1b { z17.b }, p2/Z, [x21] + subs x20, x20, #0x1 + ld1b { z16.b }, p2/Z, [x21, #1, MUL VL] + addvl x21, x21, #2 + sdot z19.s, z17.b, z20.b + sdot z18.s, z16.b, z20.b + bgt label_12 + mov x20, x23 + add x22, x22, x12 + whilelt p1.s, XZR, x20 + decw x20 + whilelt p0.s, XZR, x20 + ld1w { z17.s }, p1/Z, [x6] + subs x23, x23, x24 + ld1w { z16.s }, p0/Z, [x6, #1, MUL VL] + addvl x6, x6, #2 + mla z17.s, p2/M, z19.s, z21.s + mla z16.s, p2/M, z18.s, z21.s + st1w { z17.s }, p2, [x16] + st1w { z16.s }, p2, [x16, #1, MUL VL] + add x16, x16, x12 + bgt label_11 +KAI_ASM_LABEL(label_13) // Bias: Done + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) + + KAI_ASM_END -- GitLab From 2d41802c3887977d8573d4e0c7f92bc270e4c177 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 4 Jun 2025 13:17:05 +0200 Subject: [PATCH 2/4] Save/restore the correct registers in the assembly files. 
Signed-off-by: Jens Elofsson --- ..._qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S | 12 +- ...lx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S | 286 +++++++++--------- .../pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S | 12 +- ...kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S | 12 +- 4 files changed, 172 insertions(+), 150 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S index 2cfe4ce6..9acbe00e 100644 --- a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot) KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! 
stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA mov x8, #0x0 ldr x5, [x0, #0x20] @@ -900,7 +904,11 @@ KAI_ASM_LABEL(label_32) // Exit ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot) diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S index 5784542e..8d3e0e14 100644 --- a/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S @@ -42,158 +42,152 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! 
stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA - mov x14, #0x0 + mov x15, #0x0 + ldr x14, [x0, #0x30] ptrue p1.b - KAI_ASM_INST(0x25207811) // ptrue pn9.b + KAI_ASM_INST(0x25207810) // ptrue pn8.b ldr w13, [x0, #0x20] - ldr w11, [x0, #0x28] - mov x10, #0x0 + mov x11, #0x0 + ldr w10, [x0, #0x28] + add x14, x14, #0x3 ldr x9, [x0, #0x0] + lsr x14, x14, #0x2 KAI_ASM_LABEL(label_1) // M loop ldr x28, [x0, #0x8] KAI_ASM_LABEL(label_2) // N loop - KAI_ASM_INST(0x25ab4550) // whilelt pn8.s, x10, x11, VLx2 + KAI_ASM_INST(0xa0404382) // ld1w { z2.s-z3.s }, pn8.b/Z, [x28] KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } mov x27, x9 - KAI_ASM_INST(0xa040438e) // ld1w { z14.s-z15.s }, p8/Z, [x28] // Load bias addvl x28, x28, #2 - KAI_ASM_INST(0xc09025c0) // addha za0.s, p1/M, p1/M, z14.s - KAI_ASM_INST(0xc09025e1) // addha za1.s, p1/M, p1/M, z15.s - KAI_ASM_INST(0xc09025c2) // addha za2.s, p1/M, p1/M, z14.s - KAI_ASM_INST(0xc09025e3) // addha za3.s, p1/M, p1/M, z15.s - ldr x20, [x0, #0x30] - add x20, x20, #0x3 - lsr x20, x20, #0x2 - lsr x21, x20, #0x2 - and x20, x20, #0x3 + KAI_ASM_INST(0xc0902440) // addha za0.s, p1/M, p1/M, z2.s + KAI_ASM_INST(0xc0902461) // addha za1.s, p1/M, p1/M, z3.s + KAI_ASM_INST(0xc0902442) // addha za2.s, p1/M, p1/M, z2.s + KAI_ASM_INST(0xc0902463) // addha za3.s, p1/M, p1/M, z3.s + lsr x21, x14, #0x2 + and x20, x14, #0x3 cbz x21, label_6 subs x21, x21, #0x1 - KAI_ASM_INST(0xa0400762) // ld1b { z2.b-z3.b }, pn9.b/Z, [x27] - KAI_ASM_INST(0xa1400780) // ld1b { z0.b, z8.b }, pn9.b/Z, [x28] - KAI_ASM_INST(0xa0410772) // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL] - KAI_ASM_INST(0xa0410794) // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL] - KAI_ASM_INST(0xa042077a) // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL] - 
KAI_ASM_INST(0xa0420796) // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL] - KAI_ASM_INST(0xa0430778) // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL] + KAI_ASM_INST(0xa1408362) // ld1b { z2.b, z6.b, z10.b, z14.b }, pn8.b/Z, [x27] + KAI_ASM_INST(0xa1418360) // ld1b { z0.b, z4.b, z8.b, z12.b }, pn8.b/Z, [x27, #0x4, MUL VL] addvl x27, x27, #8 - KAI_ASM_INST(0xa0430784) // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL] + KAI_ASM_INST(0xa0408390) // ld1b { z16.b-z19.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xa041839c) // ld1b { z28.b-z31.b }, pn8.b/Z, [x28, #0x4, MUL VL] addvl x28, x28, #8 ble label_5 KAI_ASM_LABEL(label_4) // K loop - KAI_ASM_INST(0xa0802440) // smopa za0.s, p1/M, p1/M, z2.b, z0.b + KAI_ASM_INST(0xa0902440) // smopa za0.s, p1/M, p1/M, z2.b, z16.b subs x21, x21, #0x1 - KAI_ASM_INST(0xa0882441) // smopa za1.s, p1/M, p1/M, z2.b, z8.b - KAI_ASM_INST(0xa0802462) // smopa za2.s, p1/M, p1/M, z3.b, z0.b - KAI_ASM_INST(0xa0882463) // smopa za3.s, p1/M, p1/M, z3.b, z8.b - KAI_ASM_INST(0xa0400762) // ld1b { z2.b-z3.b }, pn9.b/Z, [x27] - KAI_ASM_INST(0xa0942640) // smopa za0.s, p1/M, p1/M, z18.b, z20.b - KAI_ASM_INST(0xa1400780) // ld1b { z0.b, z8.b }, pn9.b/Z, [x28] - KAI_ASM_INST(0xa0952641) // smopa za1.s, p1/M, p1/M, z18.b, z21.b - KAI_ASM_INST(0xa0942662) // smopa za2.s, p1/M, p1/M, z19.b, z20.b - KAI_ASM_INST(0xa0952663) // smopa za3.s, p1/M, p1/M, z19.b, z21.b - KAI_ASM_INST(0xa0410772) // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL] - KAI_ASM_INST(0xa0962740) // smopa za0.s, p1/M, p1/M, z26.b, z22.b - KAI_ASM_INST(0xa0410794) // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL] - KAI_ASM_INST(0xa0972741) // smopa za1.s, p1/M, p1/M, z26.b, z23.b - KAI_ASM_INST(0xa0962762) // smopa za2.s, p1/M, p1/M, z27.b, z22.b - KAI_ASM_INST(0xa0972763) // smopa za3.s, p1/M, p1/M, z27.b, z23.b - KAI_ASM_INST(0xa042077a) // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL] - KAI_ASM_INST(0xa0420796) // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL] - 
KAI_ASM_INST(0xa0842700) // smopa za0.s, p1/M, p1/M, z24.b, z4.b - KAI_ASM_INST(0xa0852701) // smopa za1.s, p1/M, p1/M, z24.b, z5.b - KAI_ASM_INST(0xa0842722) // smopa za2.s, p1/M, p1/M, z25.b, z4.b - KAI_ASM_INST(0xa0852723) // smopa za3.s, p1/M, p1/M, z25.b, z5.b - KAI_ASM_INST(0xa0430778) // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL] + KAI_ASM_INST(0xa0912441) // smopa za1.s, p1/M, p1/M, z2.b, z17.b + KAI_ASM_INST(0xa09024c2) // smopa za2.s, p1/M, p1/M, z6.b, z16.b + KAI_ASM_INST(0xa09124c3) // smopa za3.s, p1/M, p1/M, z6.b, z17.b + KAI_ASM_INST(0xa0922540) // smopa za0.s, p1/M, p1/M, z10.b, z18.b + KAI_ASM_INST(0xa0932541) // smopa za1.s, p1/M, p1/M, z10.b, z19.b + KAI_ASM_INST(0xa09225c2) // smopa za2.s, p1/M, p1/M, z14.b, z18.b + KAI_ASM_INST(0xa09325c3) // smopa za3.s, p1/M, p1/M, z14.b, z19.b + KAI_ASM_INST(0xa1408362) // ld1b { z2.b, z6.b, z10.b, z14.b }, pn8.b/Z, [x27] + KAI_ASM_INST(0xa09c2400) // smopa za0.s, p1/M, p1/M, z0.b, z28.b + KAI_ASM_INST(0xa0408390) // ld1b { z16.b-z19.b }, pn8.b/Z, [x28] + KAI_ASM_INST(0xa09d2401) // smopa za1.s, p1/M, p1/M, z0.b, z29.b + KAI_ASM_INST(0xa09c2482) // smopa za2.s, p1/M, p1/M, z4.b, z28.b + KAI_ASM_INST(0xa09d2483) // smopa za3.s, p1/M, p1/M, z4.b, z29.b + KAI_ASM_INST(0xa09e2500) // smopa za0.s, p1/M, p1/M, z8.b, z30.b + KAI_ASM_INST(0xa09f2501) // smopa za1.s, p1/M, p1/M, z8.b, z31.b + KAI_ASM_INST(0xa09e2582) // smopa za2.s, p1/M, p1/M, z12.b, z30.b + KAI_ASM_INST(0xa09f2583) // smopa za3.s, p1/M, p1/M, z12.b, z31.b + KAI_ASM_INST(0xa1418360) // ld1b { z0.b, z4.b, z8.b, z12.b }, pn8.b/Z, [x27, #0x4, MUL VL] addvl x27, x27, #8 - KAI_ASM_INST(0xa0430784) // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL] + KAI_ASM_INST(0xa041839c) // ld1b { z28.b-z31.b }, pn8.b/Z, [x28, #0x4, MUL VL] addvl x28, x28, #8 bgt label_4 KAI_ASM_LABEL(label_5) // K loop tail - KAI_ASM_INST(0xa0802440) // smopa za0.s, p1/M, p1/M, z2.b, z0.b - KAI_ASM_INST(0xa0882441) // smopa za1.s, p1/M, p1/M, z2.b, z8.b - 
KAI_ASM_INST(0xa0802462) // smopa za2.s, p1/M, p1/M, z3.b, z0.b - KAI_ASM_INST(0xa0882463) // smopa za3.s, p1/M, p1/M, z3.b, z8.b - KAI_ASM_INST(0xa0942640) // smopa za0.s, p1/M, p1/M, z18.b, z20.b - KAI_ASM_INST(0xa0952641) // smopa za1.s, p1/M, p1/M, z18.b, z21.b - KAI_ASM_INST(0xa0942662) // smopa za2.s, p1/M, p1/M, z19.b, z20.b - KAI_ASM_INST(0xa0952663) // smopa za3.s, p1/M, p1/M, z19.b, z21.b - KAI_ASM_INST(0xa0962740) // smopa za0.s, p1/M, p1/M, z26.b, z22.b - KAI_ASM_INST(0xa0972741) // smopa za1.s, p1/M, p1/M, z26.b, z23.b - KAI_ASM_INST(0xa0962762) // smopa za2.s, p1/M, p1/M, z27.b, z22.b - KAI_ASM_INST(0xa0972763) // smopa za3.s, p1/M, p1/M, z27.b, z23.b - KAI_ASM_INST(0xa0842700) // smopa za0.s, p1/M, p1/M, z24.b, z4.b - KAI_ASM_INST(0xa0852701) // smopa za1.s, p1/M, p1/M, z24.b, z5.b - KAI_ASM_INST(0xa0842722) // smopa za2.s, p1/M, p1/M, z25.b, z4.b - KAI_ASM_INST(0xa0852723) // smopa za3.s, p1/M, p1/M, z25.b, z5.b + KAI_ASM_INST(0xa0902440) // smopa za0.s, p1/M, p1/M, z2.b, z16.b + KAI_ASM_INST(0xa0912441) // smopa za1.s, p1/M, p1/M, z2.b, z17.b + KAI_ASM_INST(0xa09024c2) // smopa za2.s, p1/M, p1/M, z6.b, z16.b + KAI_ASM_INST(0xa09124c3) // smopa za3.s, p1/M, p1/M, z6.b, z17.b + KAI_ASM_INST(0xa0922540) // smopa za0.s, p1/M, p1/M, z10.b, z18.b + KAI_ASM_INST(0xa0932541) // smopa za1.s, p1/M, p1/M, z10.b, z19.b + KAI_ASM_INST(0xa09225c2) // smopa za2.s, p1/M, p1/M, z14.b, z18.b + KAI_ASM_INST(0xa09325c3) // smopa za3.s, p1/M, p1/M, z14.b, z19.b + KAI_ASM_INST(0xa09c2400) // smopa za0.s, p1/M, p1/M, z0.b, z28.b + KAI_ASM_INST(0xa09d2401) // smopa za1.s, p1/M, p1/M, z0.b, z29.b + KAI_ASM_INST(0xa09c2482) // smopa za2.s, p1/M, p1/M, z4.b, z28.b + KAI_ASM_INST(0xa09d2483) // smopa za3.s, p1/M, p1/M, z4.b, z29.b + KAI_ASM_INST(0xa09e2500) // smopa za0.s, p1/M, p1/M, z8.b, z30.b + KAI_ASM_INST(0xa09f2501) // smopa za1.s, p1/M, p1/M, z8.b, z31.b + KAI_ASM_INST(0xa09e2582) // smopa za2.s, p1/M, p1/M, z12.b, z30.b + KAI_ASM_INST(0xa09f2583) // smopa za3.s, 
p1/M, p1/M, z12.b, z31.b KAI_ASM_LABEL(label_6) // K oddments cbz x20, label_8 KAI_ASM_LABEL(label_7) // K oddments: Loop - KAI_ASM_INST(0xa0400770) // ld1b { z16.b-z17.b }, pn9.b/Z, [x27] + KAI_ASM_INST(0xa0400370) // ld1b { z16.b-z17.b }, pn8.b/Z, [x27] subs x20, x20, #0x1 addvl x27, x27, #2 - KAI_ASM_INST(0xa0400788) // ld1b { z8.b-z9.b }, pn9.b/Z, [x28] + KAI_ASM_INST(0xa1400385) // ld1b { z5.b, z13.b }, pn8.b/Z, [x28] addvl x28, x28, #2 - KAI_ASM_INST(0xa0882600) // smopa za0.s, p1/M, p1/M, z16.b, z8.b - KAI_ASM_INST(0xa0892601) // smopa za1.s, p1/M, p1/M, z16.b, z9.b - KAI_ASM_INST(0xa0882622) // smopa za2.s, p1/M, p1/M, z17.b, z8.b - KAI_ASM_INST(0xa0892623) // smopa za3.s, p1/M, p1/M, z17.b, z9.b + KAI_ASM_INST(0xa0852600) // smopa za0.s, p1/M, p1/M, z16.b, z5.b + KAI_ASM_INST(0xa08d2601) // smopa za1.s, p1/M, p1/M, z16.b, z13.b + KAI_ASM_INST(0xa0852622) // smopa za2.s, p1/M, p1/M, z17.b, z5.b + KAI_ASM_INST(0xa08d2623) // smopa za3.s, p1/M, p1/M, z17.b, z13.b bgt label_7 KAI_ASM_LABEL(label_8) // K oddments: End ldr x26, [x0, #0x10] - sub x25, x13, x14 + sub x25, x13, x15 cntw x24 - ld1rw { z27.s }, p1/Z, [x0, #56] + ld1rw { z26.s }, p1/Z, [x0, #56] ldr x23, [x0, #0x18] - whilelt p0.h, x10, x11 + whilelt p0.h, x11, x10 cmp x25, x24 - ld1rw { z1.s }, p1/Z, [x0, #60] + ld1rw { z23.s }, p1/Z, [x0, #60] csel x22, x25, x24, LT ld1rw { z0.s }, p1/Z, [x0, #64] mov x12, #0x0 - add x26, x26, x10 // C += n + add x26, x26, x11 // C += n lsr x21, x22, #0x2 - ld1w { z22.s }, p1/Z, [x28] - madd x26, x14, x23, x26 // C += m * ldc - ld1w { z26.s }, p1/Z, [x28, #1, MUL VL] - and x20, x22, #0x3 + KAI_ASM_INST(0xa0404382) // ld1w { z2.s-z3.s }, pn8.b/Z, [x28] + madd x26, x15, x23, x26 // C += m * ldc addvl x28, x28, #2 + and x20, x22, #0x3 cbz x21, label_11 KAI_ASM_LABEL(label_10) // Store to output array: Accumulator row 0 loop - KAI_ASM_INST(0xc0860410) // mova { z16.s-z19.s }, za0h.s[x12] - KAI_ASM_INST(0xc086043c) // mova { z28.s-z31.s }, za1h.s[x12] + 
KAI_ASM_INST(0xc0860408) // mova { z8.s-z11.s }, za0h.s[x12] + KAI_ASM_INST(0xc0860430) // mova { z16.s-z19.s }, za1h.s[x12] + KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc132e210) // scvtf { z16.s-z19.s }, { z16.s-z19.s } - KAI_ASM_INST(0xc132e39c) // scvtf { z28.s-z31.s }, { z28.s-z31.s } - fmul z16.s, z16.s, z22.s - fmul z17.s, z17.s, z22.s + fmul z8.s, z8.s, z2.s + fmul z9.s, z9.s, z2.s add x12, x12, #0x4 - fmul z18.s, z18.s, z22.s - fmul z19.s, z19.s, z22.s + fmul z10.s, z10.s, z2.s + fmul z11.s, z11.s, z2.s cmp x12, x21, LSL #2 - fmul z28.s, z28.s, z26.s - fmul z29.s, z29.s, z26.s - fmul z30.s, z30.s, z26.s - fmul z31.s, z31.s, z26.s + fmul z16.s, z16.s, z3.s + fmul z17.s, z17.s, z3.s + fmul z18.s, z18.s, z3.s + fmul z19.s, z19.s, z3.s + KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc1b8e210) // frintn { z16.s-z19.s }, { z16.s-z19.s } + KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s KAI_ASM_INST(0xc131e210) // fcvtzs { z16.s-z19.s }, { z16.s-z19.s } - KAI_ASM_INST(0xc1b8e39c) // frintn { z28.s-z31.s }, { z28.s-z31.s } KAI_ASM_INST(0xc1a0ab10) // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s - KAI_ASM_INST(0xc131e39c) // fcvtzs { z28.s-z31.s }, { z28.s-z31.s } - KAI_ASM_INST(0xc1a0ab1c) // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s - KAI_ASM_INST(0xc1a1cf70) // sclamp { z16.s-z19.s }, z27.s, z1.s - KAI_ASM_INST(0xc1a1cf7c) // sclamp { z28.s-z31.s }, z27.s, z1.s - uzp1 z5.h, z16.h, z28.h - uzp1 z20.h, z17.h, z29.h - uzp1 z17.h, z18.h, z30.h - uzp1 z16.h, z19.h, z31.h + KAI_ASM_INST(0xc1b7cf48) // sclamp { z8.s-z11.s }, z26.s, z23.s + KAI_ASM_INST(0xc1b7cf50) // sclamp { z16.s-z19.s }, z26.s, z23.s + uzp1 z5.h, z8.h, z16.h + uzp1 z14.h, z9.h, z17.h + uzp1 z17.h, z10.h, z18.h + uzp1 z16.h, z11.h, z19.h st1b { z5.h }, p0, [x26] add x26, x26, x23 - st1b { z20.h }, p0, [x26] + st1b { z14.h }, p0, [x26] add x26, 
x26, x23 st1b { z17.h }, p0, [x26] add x26, x26, x23 @@ -202,37 +196,37 @@ KAI_ASM_LABEL(label_10) // Store to output array: Accumulator row 0 loop blt label_10 KAI_ASM_LABEL(label_11) // Store to output array: Accumulator row 0 oddments cbz x20, label_12 - KAI_ASM_INST(0xc0860404) // mova { z4.s-z7.s }, za0h.s[x12] + KAI_ASM_INST(0xc0860408) // mova { z8.s-z11.s }, za0h.s[x12] KAI_ASM_INST(0xc086042c) // mova { z12.s-z15.s }, za1h.s[x12] - KAI_ASM_INST(0xc132e084) // scvtf { z4.s-z7.s }, { z4.s-z7.s } + KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc132e18c) // scvtf { z12.s-z15.s }, { z12.s-z15.s } - fmul z4.s, z4.s, z22.s - fmul z5.s, z5.s, z22.s + fmul z8.s, z8.s, z2.s + fmul z9.s, z9.s, z2.s subs x20, x20, #0x1 - fmul z6.s, z6.s, z22.s - fmul z7.s, z7.s, z22.s - fmul z12.s, z12.s, z26.s - fmul z13.s, z13.s, z26.s - fmul z14.s, z14.s, z26.s - fmul z15.s, z15.s, z26.s - KAI_ASM_INST(0xc1b8e084) // frintn { z4.s-z7.s }, { z4.s-z7.s } - KAI_ASM_INST(0xc131e084) // fcvtzs { z4.s-z7.s }, { z4.s-z7.s } + fmul z10.s, z10.s, z2.s + fmul z11.s, z11.s, z2.s + fmul z12.s, z12.s, z3.s + fmul z13.s, z13.s, z3.s + fmul z14.s, z14.s, z3.s + fmul z15.s, z15.s, z3.s + KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } + KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc1b8e18c) // frintn { z12.s-z15.s }, { z12.s-z15.s } - KAI_ASM_INST(0xc1a0ab04) // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s + KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s KAI_ASM_INST(0xc131e18c) // fcvtzs { z12.s-z15.s }, { z12.s-z15.s } KAI_ASM_INST(0xc1a0ab0c) // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s - KAI_ASM_INST(0xc1a1cf64) // sclamp { z4.s-z7.s }, z27.s, z1.s - KAI_ASM_INST(0xc1a1cf6c) // sclamp { z12.s-z15.s }, z27.s, z1.s - uzp1 z16.h, z4.h, z12.h + KAI_ASM_INST(0xc1b7cf48) // sclamp { z8.s-z11.s }, z26.s, z23.s + KAI_ASM_INST(0xc1b7cf4c) // sclamp { z12.s-z15.s }, z26.s, z23.s + uzp1 z16.h, z8.h, 
z12.h st1b { z16.h }, p0, [x26] add x26, x26, x23 beq label_12 subs x20, x20, #0x1 - uzp1 z16.h, z5.h, z13.h + uzp1 z16.h, z9.h, z13.h st1b { z16.h }, p0, [x26] add x26, x26, x23 beq label_12 - uzp1 z16.h, z6.h, z14.h + uzp1 z16.h, z10.h, z14.h st1b { z16.h }, p0, [x26] add x26, x26, x23 KAI_ASM_LABEL(label_12) // Store to output array: Accumulator row 0 oddments: End @@ -249,24 +243,24 @@ KAI_ASM_LABEL(label_13) // Store to output array: Accumulator row 1 loop KAI_ASM_INST(0xc0860470) // mova { z16.s-z19.s }, za3h.s[x12] KAI_ASM_INST(0xc132e108) // scvtf { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc132e210) // scvtf { z16.s-z19.s }, { z16.s-z19.s } - fmul z8.s, z8.s, z22.s - fmul z9.s, z9.s, z22.s + fmul z8.s, z8.s, z2.s + fmul z9.s, z9.s, z2.s add x12, x12, #0x4 - fmul z10.s, z10.s, z22.s - fmul z11.s, z11.s, z22.s + fmul z10.s, z10.s, z2.s + fmul z11.s, z11.s, z2.s cmp x12, x21, LSL #2 - fmul z16.s, z16.s, z26.s - fmul z17.s, z17.s, z26.s - fmul z18.s, z18.s, z26.s - fmul z19.s, z19.s, z26.s + fmul z16.s, z16.s, z3.s + fmul z17.s, z17.s, z3.s + fmul z18.s, z18.s, z3.s + fmul z19.s, z19.s, z3.s KAI_ASM_INST(0xc1b8e108) // frintn { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc131e108) // fcvtzs { z8.s-z11.s }, { z8.s-z11.s } KAI_ASM_INST(0xc1b8e210) // frintn { z16.s-z19.s }, { z16.s-z19.s } KAI_ASM_INST(0xc1a0ab08) // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s KAI_ASM_INST(0xc131e210) // fcvtzs { z16.s-z19.s }, { z16.s-z19.s } KAI_ASM_INST(0xc1a0ab10) // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s - KAI_ASM_INST(0xc1a1cf68) // sclamp { z8.s-z11.s }, z27.s, z1.s - KAI_ASM_INST(0xc1a1cf70) // sclamp { z16.s-z19.s }, z27.s, z1.s + KAI_ASM_INST(0xc1b7cf48) // sclamp { z8.s-z11.s }, z26.s, z23.s + KAI_ASM_INST(0xc1b7cf50) // sclamp { z16.s-z19.s }, z26.s, z23.s uzp1 z21.h, z8.h, z16.h uzp1 z20.h, z9.h, z17.h uzp1 z17.h, z10.h, z18.h @@ -286,23 +280,23 @@ KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments KAI_ASM_INST(0xc0860464) // mova { 
z4.s-z7.s }, za3h.s[x12] KAI_ASM_INST(0xc132e18c) // scvtf { z12.s-z15.s }, { z12.s-z15.s } KAI_ASM_INST(0xc132e084) // scvtf { z4.s-z7.s }, { z4.s-z7.s } - fmul z12.s, z12.s, z22.s - fmul z13.s, z13.s, z22.s + fmul z12.s, z12.s, z2.s + fmul z13.s, z13.s, z2.s subs x20, x20, #0x1 - fmul z14.s, z14.s, z22.s - fmul z15.s, z15.s, z22.s - fmul z4.s, z4.s, z26.s - fmul z5.s, z5.s, z26.s - fmul z6.s, z6.s, z26.s - fmul z7.s, z7.s, z26.s + fmul z14.s, z14.s, z2.s + fmul z15.s, z15.s, z2.s + fmul z4.s, z4.s, z3.s + fmul z5.s, z5.s, z3.s + fmul z6.s, z6.s, z3.s + fmul z7.s, z7.s, z3.s KAI_ASM_INST(0xc1b8e18c) // frintn { z12.s-z15.s }, { z12.s-z15.s } KAI_ASM_INST(0xc131e18c) // fcvtzs { z12.s-z15.s }, { z12.s-z15.s } KAI_ASM_INST(0xc1b8e084) // frintn { z4.s-z7.s }, { z4.s-z7.s } KAI_ASM_INST(0xc1a0ab0c) // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s KAI_ASM_INST(0xc131e084) // fcvtzs { z4.s-z7.s }, { z4.s-z7.s } KAI_ASM_INST(0xc1a0ab04) // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s - KAI_ASM_INST(0xc1a1cf6c) // sclamp { z12.s-z15.s }, z27.s, z1.s - KAI_ASM_INST(0xc1a1cf64) // sclamp { z4.s-z7.s }, z27.s, z1.s + KAI_ASM_INST(0xc1b7cf4c) // sclamp { z12.s-z15.s }, z26.s, z23.s + KAI_ASM_INST(0xc1b7cf44) // sclamp { z4.s-z7.s }, z26.s, z23.s uzp1 z16.h, z12.h, z4.h st1b { z16.h }, p0, [x26] add x26, x26, x23 @@ -316,12 +310,12 @@ KAI_ASM_LABEL(label_14) // Store to output array: Accumulator row 1 oddments st1b { z16.h }, p0, [x26] KAI_ASM_LABEL(label_15) // Store to output array: Accumulator row 1 oddments: End KAI_ASM_LABEL(label_16) // Store to output array: End - incw x10, ALL, MUL #2 - cmp x10, x11 + incw x11, ALL, MUL #2 + cmp x11, x10 blt label_2 - incw x14, ALL, MUL #2 - mov x10, #0x0 - cmp x14, x13 + incw x15, ALL, MUL #2 + mov x11, #0x0 + cmp x15, x13 mov x9, x27 blt label_1 KAI_ASM_INST(0xd503467f) // SMSTOP @@ -329,7 +323,11 @@ KAI_ASM_LABEL(label_16) // Store to output array: End ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 
+ ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa) diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S index f5635ab5..d6e2ddb0 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_pack_x8p2vlx4_x8_sme) KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_pack_x8p2vlx4_x8_sme) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA mov x5, #0x0 ldr x6, [x0, #0x50] @@ -311,7 +315,11 @@ KAI_ASM_LABEL(label_13) // K loop: End ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_lhs_pack_x8p2vlx4_x8_sme) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S index 9b120c9b..89e88f73 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! 
stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA ldr x4, [x0, #0x20] ptrue p2.b @@ -227,7 +231,11 @@ KAI_ASM_LABEL(label_13) // Bias: Done ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme) -- GitLab From 1f1ae8791d6f5f3a0923b9ce95a35c51739f04a5 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Tue, 10 Jun 2025 11:59:31 +0200 Subject: [PATCH 3/4] Address review comments - Add CHANGELOG entry - Sort lists in CMakeLists.txt - Use MR, KR, SR enums instead of function params Signed-off-by: Jens Elofsson --- CHANGELOG.md | 5 ++++ CMakeLists.txt | 16 ++++++------- .../pack/kai_lhs_pack_x8p2vlx4_x8_sme.c | 23 +++++++++---------- 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e430695d..93d4e239 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,11 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme - kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme - kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme +- Convert SME and SME2 matmul micro-kernels to use pure assembly, and add MSVC support. Affects: + - kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot + - kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa + - kai_lhs_pack_x8p2vlx4_x8_sme + - kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme - New Advanced SIMD micro-kernels: - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_DotProd. 
- Optimized version of kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 kernel for block depth of 8 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon`) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0df243b2..2ac295b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -229,10 +229,6 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME_ASM - kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c @@ -245,6 +241,10 @@ set(KLEIDIAI_FILES_SME_ASM kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S + kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S ) set(KLEIDIAI_FILES_SME @@ -259,10 +259,6 @@ set(KLEIDIAI_FILES_SME ) set(KLEIDIAI_FILES_SME2_ASM - kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c - kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S - kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S 
kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c @@ -273,6 +269,10 @@ set(KLEIDIAI_FILES_SME2_ASM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S + kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c + kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S + kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S ) set(KLEIDIAI_FILES_SME2 diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c index 584e9761..07183d4f 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c @@ -15,6 +15,13 @@ #include "kai/kai_common.h" +enum { + MR = 2, + KR = 4, + MAX_M_STEP = (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(int8_t)) / KR), + SR = 1, +}; + typedef struct { size_t m; size_t k; @@ -34,13 +41,6 @@ typedef struct { void kai_kernel_lhs_pack_x8p2vlx4_x8_sme(const KernelArgs* args_ptr); -enum { - MR = 2, - KR = 4, - MAX_M_STEP = (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(int8_t)) / KR), - SR = 1, -}; - static size_t 
kai_get_mr_lhs_pack_x8p2vlx4_x8_sme(void) { return MR * kai_get_sme_vector_length_u8() / KR; } @@ -68,7 +68,7 @@ size_t kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme(size_t m_idx, size_t k KAI_UNUSED(kr); KAI_UNUSED(sr); - return m_idx * kai_roundup(k, kr) * sizeof(int8_t); + return m_idx * kai_roundup(k, KR) * sizeof(int8_t); } size_t kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { @@ -91,7 +91,6 @@ void kai_run_lhs_pack_x8p2vlx4_x8_sme( KAI_ASSUME(sr == SR); KAI_ASSUME(lhs != NULL); KAI_ASSUME(lhs_packed != NULL); - KAI_ASSUME(m_idx_start == 0); const size_t m_step = kai_get_mr_lhs_pack_x8p2vlx4_x8_sme(); @@ -115,9 +114,9 @@ void kai_run_lhs_pack_x8p2vlx4_x8_sme( KernelArgs args; args.m = m; args.k = k; - args.mr = mr; - args.kr = kr; - args.sr = sr; + args.mr = MR; + args.kr = KR; + args.sr = SR; args.m_idx_start = m_idx_start; args.lhs = lhs; args.lhs_stride = lhs_stride; -- GitLab From 19cd0d553d40e3ec86b00194080687499525f8ca Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Tue, 10 Jun 2025 13:49:57 +0200 Subject: [PATCH 4/4] Address review comments - Minor formatting fixes Signed-off-by: Jens Elofsson --- CMakeLists.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2ac295b3..af90d759 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -235,14 +235,14 @@ set(KLEIDIAI_FILES_SME_ASM kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme_asm.S + kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S 
kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme_asm.S - kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S ) @@ -333,8 +333,8 @@ else() set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_SME_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_SME2_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SME_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SME2_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set(KLEIDIAI_FILES_ASM ${KLEIDIAI_FILES_SME_ASM} -- GitLab