diff --git a/CHANGELOG.md b/CHANGELOG.md index 93d4e239fe8eacbbf056c87566c595256e581307..b2aaf70238765d8865c1ebd995d56c5804526347 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,11 +20,15 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme - kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme - kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme -- Convert SME and SME2 matmul micro-kernels to use pure assembly, and add MSVC support. Affects: +- Convert SME and SME2 matmul micro-kernels to pure assembly, and add MSVC support. Affects: + - kai_lhs_pack_x16p2vlx2_x16_sme + - kai_lhs_pack_x8p2vlx4_x8_sme + - kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot + - kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa - kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot - kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa - - kai_lhs_pack_x8p2vlx4_x8_sme - kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme + - kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme - New Advanced SIMD micro-kernels: - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_DotProd. - Optimized version of kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 kernel for block depth of 8 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon`) diff --git a/CMakeLists.txt b/CMakeLists.txt index af90d7597d7b190724c6eff15d04702894de0a84..e0e966e10eaedf5e967befb80514ae5ce510da63 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -229,6 +229,10 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME_ASM + kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme_asm.S + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c @@ -250,15 +254,17 @@ set(KLEIDIAI_FILES_SME_ASM set(KLEIDIAI_FILES_SME ${KLEIDIAI_FILES_SME_ASM} kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c ) set(KLEIDIAI_FILES_SME2_ASM + kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S + kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c + kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot_asm.S kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c @@ -277,8 +283,6 @@ set(KLEIDIAI_FILES_SME2_ASM set(KLEIDIAI_FILES_SME2 ${KLEIDIAI_FILES_SME2_ASM} - kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c - kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index c3caa717f6578fe8b600426900bbab77364b033c..76eafed73f0a240d4f91ec917a5754219890acae 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -142,11 +142,9 @@ I8MM_KERNELS_ASM = [ SME_KERNELS = [ "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_lhs_pack_f32p2vlx1_f32_sme", - "pack/kai_lhs_pack_x16p2vlx2_x16_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", - "pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme", "pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme", "pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme", ] @@ -156,17 +154,17 @@ SME_KERNELS_ASM = [ "pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", "pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme", "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme", + "pack/kai_lhs_pack_x16p2vlx2_x16_sme", "pack/kai_lhs_pack_x8p2vlx4_x8_sme", "pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", "pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme", "pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme", "pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", + "pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme", ] # buildifier: keep sorted SME2_KERNELS = [ - "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", - "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla", "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla", "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa", @@ -182,6 +180,8 @@ SME2_KERNELS_ASM = [ "imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", "imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa", "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", + "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", + "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot", "matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot", diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c index 9a7aaf11cf1cc18e83248a555fb6a6fc5dff657e..aa7d3a7d4905a6ffe90f906bce68004055f30c8a 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c @@ -4,10 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. @@ -19,6 +16,20 @@ #include "kai/kai_common.h" +typedef struct { + uint16_t maxval; + uint16_t minval; + const void* A_ptr; + const void* B_ptr; + size_t N; + size_t K; + void* output_ptr; + uint64_t flags; +} KernelArgs; + +void kai_kernel_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(KernelArgs* args_ptr); +uint16_t kai_f16_from_float_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(float value); + static const size_t kai_m_step = 1; static const size_t kai_nr = 2; static const size_t kai_n_step = 16; @@ -46,14 +57,14 @@ size_t kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(void) { } size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(size_t m_idx, size_t k) { - KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot() == 0); + KAI_ASSUME(m_idx == 0); return m_idx * k; } static size_t kai_get_rhs_packed_stride_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(size_t k) { return kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot() * - ((kai_roundup(k, kai_kr) * sizeof(uint16_t) + sizeof(uint16_t))); + (kai_roundup(k, kai_kr) * sizeof(uint16_t) + sizeof(uint16_t)); } size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(size_t n_idx, size_t k) { @@ -65,7 +76,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot() == 0); + KAI_ASSUME(m_idx == 0); KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot() == 0); return (m_idx * dst_stride) + (n_idx * sizeof(uint16_t)); @@ -80,837 +91,26 @@ void kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot( size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) { KAI_UNUSED(dst_stride_row); KAI_UNUSED(dst_stride_col); + KAI_UNUSED(lhs_stride); - KAI_ASSERT(m == 1); - typedef struct { - float16_t maxval; - float16_t minval; - } KernelArgs; + KAI_ASSUME(m == 1); - KernelArgs ka; - ka.maxval = (float16_t)clamp_max; - ka.minval = (float16_t)clamp_min; + uint64_t flags = 2; - size_t N = n; - size_t K = k; + KernelArgs args; - const void* A_ptr = lhs; - const void* B_ptr = rhs_packed; - void* output_ptr = dst; + args.minval = kai_f16_from_float_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(clamp_min); + args.maxval = kai_f16_from_float_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(clamp_max); - uint64_t flags = 2; + args.A_ptr = lhs; + args.B_ptr = rhs_packed; + args.N = n; + args.K = k; + args.output_ptr = dst; + args.flags = flags; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x8, #0x0\n" - "mov x16, %x[B_ptr]\n" - "cntw x15, ALL, MUL #4\n" - "mov x14, %x[output_ptr]\n" - "add x13, %x[N], x15\n" - "ptrue p1.b\n" - "sub x13, x13, #0x1\n" - ".inst 0x25207811 // ptrue pn9.b\n" - "udiv x13, x13, x15\n" - "mov x22, #0x1\n" - "add x21, x13, #0x3\n" - "and x21, x21, #0xfffffffffffffffc\n" - "mul x21, x21, x15\n" - "mul x21, x21, %x[K]\n" - "lsl x21, x21, #0x1\n" - "1:" // RHS size check loop - "cmp x21, #0x200000\n" - "blt 2f\n" - "tbnz x21, #0, 3f\n" - "lsr x21, x21, #0x1\n" - "lsl x22, x22, #0x1\n" - "b 1b\n" - "2:" // RHS do prefetch - "lsl x20, x21, #0x26\n" - "sub x22, x22, #0x1\n" - "lsl x22, x22, #0x16\n" - "orr x21, x21, x20\n" - "orr x21, x21, x22\n" - ".inst 0xf8b54a1a // rprfm pldonce, x21, [x16]\n" - "3:" // RHS prefetch exit - "add x12, %x[K], #0x1\n" - "cntw x20, ALL, MUL #2\n" - "bic x12, x12, #0x1\n" - "lsl x12, x12, #0x1\n" - "add x12, x12, #0x2\n" - "mul x12, x12, x20\n" - "4:" // Column loop - "cmp x13, #0x4\n" - "bge 22f\n" - "cmp x13, #0x2\n" - "bgt 16f\n" - "beq 10f\n" - "cntw x20, ALL, MUL #2\n" - "add x22, x16, x12\n" - "ld1h { z8.s }, p1/Z, [x16]\n" - "cmp %x[N], x20\n" - "ld1h { z9.s }, p1/Z, [x16, #1, MUL VL]\n" - "mov x11, %x[K]\n" - "csel x22, x22, x16, GT\n" - "mov x21, %x[N]\n" - "ld1h { z10.s }, p1/Z, [x22]\n" - "fcvt z8.s, p1/m, z8.h\n" - "mov x10, %x[A_ptr]\n" - "lsl x20, %x[K], #0x1\n" - "ld1h { z11.s }, p1/Z, [x22, #1, MUL VL]\n" - "fcvt z9.s, p1/m, z9.h\n" - ".inst 0x257547f0 // whilelt p8.h, XZR, x21, VLx2\n" - "cmp x11, #0x8\n" - ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" - "inch x16, ALL, MUL #2\n" - "fcvt z10.s, p1/m, z10.h\n" - "inch x22, ALL, MUL #2\n" - "fcvt z11.s, p1/m, z11.h\n" - ".inst 0xc0040d00 // mova za.d[x8, #0], { z8.d-z11.d }\n" - "ble 6f\n" - "5:" // Width 1: Multiply loop: Main loop head - "whilelt p0.h, XZR, x11\n" - ".inst 0xa0402615 // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - "ld1rqh { z4.h }, p0/Z, [x10]\n" - "sub x11, x11, #0x8\n" - "add x10, x10, #0x10\n" - ".inst 0xa04026d7 // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - "cmp x11, #0x8\n" - ".inst 0xa0402609 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026cb // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1549288 // fdot za.s[x8, 0], { z20.h-z23.h }, z4.h[0]\n" - ".inst 0xa040260d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026cf // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0402615 // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026d7 // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1549508 // fdot za.s[x8, 0], { z8.h-z11.h }, z4.h[1]\n" - ".inst 0xc1549988 // fdot za.s[x8, 0], { z12.h-z15.h }, z4.h[2]\n" - ".inst 0xc1549e88 // fdot za.s[x8, 0], { z20.h-z23.h }, z4.h[3]\n" - "bgt 5b\n" - "6:" // Width 1: Multiply loop: Single iteration only - "whilelt p0.h, XZR, x11\n" - ".inst 0xa040260d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x2\n" - "ld1rqh { z3.h }, p0/Z, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026cf // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1539188 // fdot za.s[x8, 0], { z12.h-z15.h }, z3.h[0]\n" - "ble 7f\n" - ".inst 0xa0402605 // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x2\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026c7 // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1539488 // fdot za.s[x8, 0], { z4.h-z7.h }, z3.h[1]\n" - "ble 7f\n" - ".inst 0xa040260d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x2\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026cf // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1539988 // fdot za.s[x8, 0], { z12.h-z15.h }, z3.h[2]\n" - "ble 7f\n" - ".inst 0xa0402611 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x16]\n" - ".inst 0xa04026d3 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x22]\n" - ".inst 0xc1539e08 // fdot za.s[x8, 0], { z16.h-z19.h }, z3.h[3]\n" - "7:" // Width 1: Multiply loop: multiply skip - "tbz %x[flags], #1, 8f\n" - ".inst 0xc0060c04 // mova { z4.d-z7.d }, za.d[x8, #0]\n" - "add x21, %x[args_ptr], %[offset_min]\n" - "add x20, %x[args_ptr], %[offset_max]\n" - "ld1rh { z19.h }, p1/Z, [x21]\n" - "ld1rh { z22.h }, p1/Z, [x20]\n" - ".inst 0xc120e094 // fcvt z20.h, { z4.s-z5.s }\n" - ".inst 0xc120e0d5 // fcvt z21.h, { z6.s-z7.s }\n" - ".inst 0xc176c274 // fclamp { z20.h-z21.h }, z19.h, z22.h\n" - ".inst 0xa06021d4 // st1h { z20.h-z21.h }, p8, [x14]\n" - "b 9f\n" - "8:" // Width 1: No activation - ".inst 0xc0060c04 // mova { z4.d-z7.d }, za.d[x8, #0]\n" - ".inst 0xc120e084 // fcvt z4.h, { z4.s-z5.s }\n" - ".inst 0xc120e0c5 // fcvt z5.h, { z6.s-z7.s }\n" - ".inst 0xa06021c4 // st1h { z4.h-z5.h }, p8, [x14]\n" - "9:" // Width 1: Output done - "b 28f\n" - "10:" // Width 2 - "add x24, x16, x12, LSL #1\n" - "cntw x20, ALL, MUL #6\n" - "ld1h { z24.s }, p1/Z, [x16]\n" - "add x23, x24, x12\n" - "cmp %x[N], x20\n" - "ld1h { z25.s }, p1/Z, [x16, #1, MUL VL]\n" - "add x22, x16, x12\n" - "csel x23, x23, x16, GT\n" - "ld1h { z0.s }, p1/Z, [x24]\n" - "ld1h { z26.s }, p1/Z, [x22]\n" - "fcvt z24.s, p1/m, z24.h\n" - "mov x11, %x[K]\n" - "sub x21, %x[N], x15\n" - "ld1h { z27.s }, p1/Z, [x22, #1, MUL VL]\n" - "fcvt z25.s, p1/m, z25.h\n" - "mov x10, %x[A_ptr]\n" - "lsl x20, %x[K], #0x1\n" - "ld1h { z1.s }, p1/Z, [x24, #1, MUL VL]\n" - "fcvt z0.s, p1/m, z0.h\n" - ".inst 0x257547f0 // whilelt p8.h, XZR, x21, VLx2\n" - "cmp x11, #0x8\n" - "ld1h { z2.s }, p1/Z, [x23]\n" - "fcvt z26.s, p1/m, z26.h\n" - ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" - "inch x16, ALL, MUL #2\n" - "ld1h { z3.s }, p1/Z, [x23, #1, MUL VL]\n" - "fcvt z27.s, p1/m, z27.h\n" - "inch x22, ALL, MUL #2\n" - "inch x24, ALL, MUL #2\n" - "fcvt z1.s, p1/m, z1.h\n" - "inch x23, ALL, MUL #2\n" - "fcvt z2.s, p1/m, z2.h\n" - "fcvt z3.s, p1/m, z3.h\n" - ".inst 0xc0040f00 // mova za.d[x8, #0], { z24.d-z27.d }\n" - ".inst 0xc0040c01 // mova za.d[x8, #1], { z0.d-z3.d }\n" - "ble 12f\n" - "11:" // Width 2: Multiply loop: Main loop head - "whilelt p0.h, XZR, x11\n" - ".inst 0xa0402615 // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - "ld1rqh { z4.h }, p0/Z, [x10]\n" - "sub x11, x11, #0x8\n" - "add x10, x10, #0x10\n" - ".inst 0xa04026d7 // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - "cmp x11, #0x8\n" - ".inst 0xa0402709 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04026eb // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1549288 // fdot za.s[x8, 0], { z20.h-z23.h }, z4.h[0]\n" - ".inst 0xa0402615 // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026d7 // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1549109 // fdot za.s[x8, 1], { z8.h-z11.h }, z4.h[0]\n" - ".inst 0xa040270d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04026ef // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1549688 // fdot za.s[x8, 0], { z20.h-z23.h }, z4.h[1]\n" - ".inst 0xa0402611 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026d3 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1549589 // fdot za.s[x8, 1], { z12.h-z15.h }, z4.h[1]\n" - ".inst 0xa0402719 // ldnt1h { z24.h-z25.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04026fb // ldnt1h { z26.h-z27.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1549a08 // fdot za.s[x8, 0], { z16.h-z19.h }, z4.h[2]\n" - ".inst 0xa0402615 // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026d7 // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xc1549b09 // fdot za.s[x8, 1], { z24.h-z27.h }, z4.h[2]\n" - ".inst 0xa0402709 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04026eb // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1549e88 // fdot za.s[x8, 0], { z20.h-z23.h }, z4.h[3]\n" - ".inst 0xc1549d09 // fdot za.s[x8, 1], { z8.h-z11.h }, z4.h[3]\n" - "bgt 11b\n" - "12:" // Width 2: Multiply loop: Single iteration only - "whilelt p0.h, XZR, x11\n" - ".inst 0xa0402615 // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x2\n" - "ld1rqh { z3.h }, p0/Z, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026d7 // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa040270d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04026ef // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1539288 // fdot za.s[x8, 0], { z20.h-z23.h }, z3.h[0]\n" - ".inst 0xc1539189 // fdot za.s[x8, 1], { z12.h-z15.h }, z3.h[0]\n" - "ble 13f\n" - ".inst 0xa0402611 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x2\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026d3 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0402709 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04026eb // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1539608 // fdot za.s[x8, 0], { z16.h-z19.h }, z3.h[1]\n" - ".inst 0xc1539509 // fdot za.s[x8, 1], { z8.h-z11.h }, z3.h[1]\n" - "ble 13f\n" - ".inst 0xa0402619 // ldnt1h { z24.h-z25.h }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x2\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026db // ldnt1h { z26.h-z27.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0402711 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa04026f3 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1539b08 // fdot za.s[x8, 0], { z24.h-z27.h }, z3.h[2]\n" - ".inst 0xc1539a09 // fdot za.s[x8, 1], { z16.h-z19.h }, z3.h[2]\n" - "ble 13f\n" - ".inst 0xa0402609 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x16]\n" - ".inst 0xa04026cb // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x22]\n" - ".inst 0xa0402705 // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x24]\n" - ".inst 0xa04026e7 // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x23]\n" - ".inst 0xc1539d08 // fdot za.s[x8, 0], { z8.h-z11.h }, z3.h[3]\n" - ".inst 0xc1539c89 // fdot za.s[x8, 1], { z4.h-z7.h }, z3.h[3]\n" - "13:" // Width 2: Multiply loop: multiply skip - "tbz %x[flags], #1, 14f\n" - ".inst 0xc0060c08 // mova { z8.d-z11.d }, za.d[x8, #0]\n" - "add x21, %x[args_ptr], %[offset_min]\n" - "add x20, %x[args_ptr], %[offset_max]\n" - ".inst 0xc0060c2c // mova { z12.d-z15.d }, za.d[x8, #1]\n" - "ld1rh { z6.h }, p1/Z, [x21]\n" - "ld1rh { z22.h }, p1/Z, [x20]\n" - ".inst 0xc120e112 // fcvt z18.h, { z8.s-z9.s }\n" - ".inst 0xc120e153 // fcvt z19.h, { z10.s-z11.s }\n" - ".inst 0xc120e190 // fcvt z16.h, { z12.s-z13.s }\n" - ".inst 0xc120e1d1 // fcvt z17.h, { z14.s-z15.s }\n" - ".inst 0xc176c0d2 // fclamp { z18.h-z19.h }, z6.h, z22.h\n" - ".inst 0xc176c0d0 // fclamp { z16.h-z17.h }, z6.h, z22.h\n" - ".inst 0xa06025d2 // st1h { z18.h-z19.h }, pn9.b, [x14]\n" - ".inst 0xa06121d0 // st1h { z16.h-z17.h }, p8, [x14, #0x2, MUL VL]\n" - "b 15f\n" - "14:" // Width 2: No activation - ".inst 0xc0060c1c // mova { z28.d-z31.d }, za.d[x8, #0]\n" - ".inst 0xc0060c34 // mova { z20.d-z23.d }, za.d[x8, #1]\n" - ".inst 0xc120e39a // fcvt z26.h, { z28.s-z29.s }\n" - ".inst 0xc120e3db // fcvt z27.h, { z30.s-z31.s }\n" - ".inst 0xa06025da // st1h { z26.h-z27.h }, pn9.b, [x14]\n" - ".inst 0xc120e291 // fcvt z17.h, { z20.s-z21.s }\n" - ".inst 0xc120e2d9 // fcvt z25.h, { z22.s-z23.s }\n" - ".inst 0xa16121d1 // st1h { z17.h, z25.h }, p8, [x14, #0x2, MUL VL]\n" - "15:" // Width 2: Output done - "b 28f\n" - "16:" // Width 3 - "add x26, x16, x12, LSL #2\n" - "cntw x20, ALL, MUL #10\n" - "ld1h { z28.s }, p1/Z, [x16]\n" - "add x25, x16, x12, LSL #1\n" - "add x24, x26, x12\n" - "ld1h { z29.s }, p1/Z, [x16, #1, MUL VL]\n" - "cmp %x[N], x20\n" - "add x23, x16, x12\n" - "ld1h { z4.s }, p1/Z, [x25]\n" - "add x22, x25, x12\n" - "csel x24, x24, x16, GT\n" - "ld1h { z30.s }, p1/Z, [x23]\n" - "fcvt z28.s, p1/m, z28.h\n" - "ld1h { z31.s }, p1/Z, [x23, #1, MUL VL]\n" - "fcvt z29.s, p1/m, z29.h\n" - "mov x20, #0x2\n" - "mov x11, %x[K]\n" - "ld1h { z5.s }, p1/Z, [x25, #1, MUL VL]\n" - "fcvt z4.s, p1/m, z4.h\n" - "msub x21, x15, x20, %x[N]\n" - "mov x10, %x[A_ptr]\n" - "ld1h { z6.s }, p1/Z, [x22]\n" - "fcvt z30.s, p1/m, z30.h\n" - "lsl x20, %x[K], #0x1\n" - ".inst 0x257547f0 // whilelt p8.h, XZR, x21, VLx2\n" - "ld1h { z7.s }, p1/Z, [x22, #1, MUL VL]\n" - "fcvt z31.s, p1/m, z31.h\n" - "cmp x11, #0x8\n" - ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" - "ld1h { z8.s }, p1/Z, [x26]\n" - "fcvt z5.s, p1/m, z5.h\n" - "inch x16, ALL, MUL #2\n" - "inch x23, ALL, MUL #2\n" - "ld1h { z9.s }, p1/Z, [x26, #1, MUL VL]\n" - "fcvt z6.s, p1/m, z6.h\n" - "inch x25, ALL, MUL #2\n" - "inch x22, ALL, MUL #2\n" - "ld1h { z10.s }, p1/Z, [x24]\n" - "fcvt z7.s, p1/m, z7.h\n" - "inch x26, ALL, MUL #2\n" - "ld1h { z11.s }, p1/Z, [x24, #1, MUL VL]\n" - "fcvt z8.s, p1/m, z8.h\n" - "inch x24, ALL, MUL #2\n" - ".inst 0xc0040f80 // mova za.d[x8, #0], { z28.d-z31.d }\n" - "fcvt z9.s, p1/m, z9.h\n" - "fcvt z10.s, p1/m, z10.h\n" - "fcvt z11.s, p1/m, z11.h\n" - ".inst 0xc0040c81 // mova za.d[x8, #1], { z4.d-z7.d }\n" - ".inst 0xc0040d02 // mova za.d[x8, #2], { z8.d-z11.d }\n" - "ble 18f\n" - "17:" // Width 3: Multiply loop: Main loop head - "whilelt p0.h, XZR, x11\n" - ".inst 0xa040260d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - "ld1rqh { z4.h }, p0/Z, [x10]\n" - "sub x11, x11, #0x8\n" - "add x10, x10, #0x10\n" - ".inst 0xa04026ef // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - "cmp x11, #0x8\n" - ".inst 0xa0402731 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04026d3 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0402741 // ldnt1h { z0.h-z1.h }, pn9.b/Z, [x26]\n" - ".inst 0xc1549188 // fdot za.s[x8, 0], { z12.h-z15.h }, z4.h[0]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0402703 // ldnt1h { z2.h-z3.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1549209 // fdot za.s[x8, 1], { z16.h-z19.h }, z4.h[0]\n" - ".inst 0xa0402611 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026f3 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc154900a // fdot za.s[x8, 2], { z0.h-z3.h }, z4.h[0]\n" - ".inst 0xa040272d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04026cf // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0402755 // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x26]\n" - ".inst 0xc1549608 // fdot za.s[x8, 0], { z16.h-z19.h }, z4.h[1]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0402717 // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1549589 // fdot za.s[x8, 1], { z12.h-z15.h }, z4.h[1]\n" - ".inst 0xa040260d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026ef // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc154968a // fdot za.s[x8, 2], { z20.h-z23.h }, z4.h[1]\n" - ".inst 0xa0402729 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04026cb // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0402751 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x26]\n" - ".inst 0xc1549988 // fdot za.s[x8, 0], { z12.h-z15.h }, z4.h[2]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0402713 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1549909 // fdot za.s[x8, 1], { z8.h-z11.h }, z4.h[2]\n" - ".inst 0xa0402619 // ldnt1h { z24.h-z25.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026fb // ldnt1h { z26.h-z27.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xc1549a0a // fdot za.s[x8, 2], { z16.h-z19.h }, z4.h[2]\n" - ".inst 0xa0402731 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04026d3 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa040274d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x26]\n" - ".inst 0xc1549f08 // fdot za.s[x8, 0], { z24.h-z27.h }, z4.h[3]\n" - "addvl x26, x26, #2\n" - ".inst 0xa040270f // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1549e09 // fdot za.s[x8, 1], { z16.h-z19.h }, z4.h[3]\n" - ".inst 0xc1549d8a // fdot za.s[x8, 2], { z12.h-z15.h }, z4.h[3]\n" - "bgt 17b\n" - "18:" // Width 3: Multiply loop: Single iteration only - "whilelt p0.h, XZR, x11\n" - ".inst 0xa0402605 // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x2\n" - "ld1rqh { z3.h }, p0/Z, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026e7 // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa040272d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04026cf // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0402759 // ldnt1h { z24.h-z25.h }, pn9.b/Z, [x26]\n" - ".inst 0xc1539088 // fdot za.s[x8, 0], { z4.h-z7.h }, z3.h[0]\n" - "addvl x26, x26, #2\n" - ".inst 0xa040271b // ldnt1h { z26.h-z27.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1539189 // fdot za.s[x8, 1], { z12.h-z15.h }, z3.h[0]\n" - ".inst 0xc153930a // fdot za.s[x8, 2], { z24.h-z27.h }, z3.h[0]\n" - "ble 19f\n" - ".inst 0xa0402619 // ldnt1h { z24.h-z25.h }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x2\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026fb // ldnt1h { z26.h-z27.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0402729 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04026cb // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0402751 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x26]\n" - ".inst 0xc1539708 // fdot za.s[x8, 0], { z24.h-z27.h }, z3.h[1]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0402713 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1539509 // fdot za.s[x8, 1], { z8.h-z11.h }, z3.h[1]\n" - ".inst 0xc153960a // fdot za.s[x8, 2], { z16.h-z19.h }, z3.h[1]\n" - "ble 19f\n" - ".inst 0xa040260d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x2\n" - "addvl x16, x16, #2\n" - ".inst 0xa04026ef // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0402729 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa04026cb // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x22]\n" - "addvl x22, x22, #2\n" - ".inst 0xa0402745 // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x26]\n" - ".inst 0xc1539988 // fdot za.s[x8, 0], { z12.h-z15.h }, z3.h[2]\n" - "addvl x26, x26, #2\n" - ".inst 0xa0402707 // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xc1539909 // fdot za.s[x8, 1], { z8.h-z11.h }, z3.h[2]\n" - ".inst 0xc153988a // fdot za.s[x8, 2], { z4.h-z7.h }, z3.h[2]\n" - "ble 19f\n" - ".inst 0xa0402619 // ldnt1h { z24.h-z25.h }, pn9.b/Z, [x16]\n" - ".inst 0xa04026fb // ldnt1h { z26.h-z27.h }, pn9.b/Z, [x23]\n" - ".inst 0xa040273d // ldnt1h { z28.h-z29.h }, pn9.b/Z, [x25]\n" - ".inst 0xa04026df // ldnt1h { z30.h-z31.h }, pn9.b/Z, [x22]\n" - ".inst 0xa040274d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x26]\n" - ".inst 0xc1539f08 // fdot za.s[x8, 0], { z24.h-z27.h }, z3.h[3]\n" - ".inst 0xa040270f // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x24]\n" - ".inst 0xc1539f89 // fdot za.s[x8, 1], { z28.h-z31.h }, z3.h[3]\n" - ".inst 0xc1539d8a // fdot za.s[x8, 2], { z12.h-z15.h }, z3.h[3]\n" - "19:" // Width 3: Multiply loop: multiply skip - "tbz %x[flags], #1, 20f\n" - ".inst 0xc0060c18 // mova { z24.d-z27.d }, za.d[x8, #0]\n" - "add x21, %x[args_ptr], %[offset_min]\n" - "add x20, %x[args_ptr], %[offset_max]\n" - ".inst 0xc0060c3c // mova { z28.d-z31.d }, za.d[x8, #1]\n" - "ld1rh { z19.h }, p1/Z, [x21]\n" - ".inst 0xc0060c40 // mova { z0.d-z3.d }, za.d[x8, #2]\n" - "ld1rh { z18.h }, p1/Z, [x20]\n" - ".inst 0xc120e314 // fcvt z20.h, { z24.s-z25.s }\n" - ".inst 0xc120e355 // fcvt z21.h, { z26.s-z27.s }\n" - ".inst 0xc120e38e // fcvt z14.h, { z28.s-z29.s }\n" - ".inst 0xc120e3cf // fcvt z15.h, { z30.s-z31.s }\n" - ".inst 0xc172c274 // fclamp { z20.h-z21.h }, z19.h, z18.h\n" - ".inst 0xc120e010 // fcvt z16.h, { z0.s-z1.s }\n" - ".inst 0xc120e051 // fcvt z17.h, { z2.s-z3.s }\n" - ".inst 0xc172c26e // fclamp { z14.h-z15.h }, z19.h, z18.h\n" - ".inst 0xc172c270 // fclamp { z16.h-z17.h }, z19.h, z18.h\n" - ".inst 0xa06025d4 // st1h { z20.h-z21.h }, pn9.b, [x14]\n" - ".inst 0xa06125ce // st1h { z14.h-z15.h }, pn9.b, [x14, #0x2, MUL VL]\n" - ".inst 0xa06221d0 // st1h { z16.h-z17.h }, p8, [x14, #0x4, MUL VL]\n" - "b 21f\n" - "20:" // Width 3: No activation - ".inst 0xc0060c04 // mova { z4.d-z7.d }, za.d[x8, #0]\n" - ".inst 0xc0060c28 // mova { z8.d-z11.d }, za.d[x8, #1]\n" - ".inst 0xc0060c4c // mova { z12.d-z15.d }, za.d[x8, #2]\n" - ".inst 0xc120e091 // fcvt z17.h, { z4.s-z5.s }\n" - ".inst 0xc120e0d9 // fcvt z25.h, { z6.s-z7.s }\n" - ".inst 0xa16025d1 // st1h { z17.h, z25.h }, pn9.b, [x14]\n" - ".inst 0xc120e112 // fcvt z18.h, { z8.s-z9.s }\n" - ".inst 0xc120e153 // fcvt z19.h, { z10.s-z11.s }\n" - ".inst 0xa06125d2 // st1h { z18.h-z19.h }, pn9.b, [x14, #0x2, MUL VL]\n" - ".inst 0xc120e191 // fcvt z17.h, { z12.s-z13.s }\n" - ".inst 0xc120e1d9 // fcvt z25.h, { z14.s-z15.s }\n" - ".inst 0xa16221d1 // st1h { z17.h, z25.h }, p8, [x14, #0x4, MUL VL]\n" - "21:" // Width 3: Output done - "b 28f\n" - "22:" // Width 4 - "add x9, x16, x12, LSL #2\n" - "cntw x20, ALL, MUL #14\n" - "ld1h { z12.s }, p1/Z, [x16]\n" - "add x28, x9, x12, LSL #1\n" - "add x27, x16, x12, LSL #1\n" - "ld1h { z13.s }, p1/Z, [x16, #1, MUL VL]\n" - "add x26, x28, x12\n" - "cmp %x[N], x20\n" - "ld1h { z8.s }, p1/Z, [x27]\n" - "add x25, x16, x12\n" - "add x24, x27, x12\n" - "ld1h { z9.s }, p1/Z, [x27, #1, MUL VL]\n" - "fcvt z12.s, p1/m, z12.h\n" - "add x23, x9, x12\n" - "csel x26, x26, x16, GT\n" - "ld1h { z14.s }, p1/Z, [x25]\n" - "fcvt z13.s, p1/m, z13.h\n" - "ld1h { z15.s }, p1/Z, [x25, #1, MUL VL]\n" - "fcvt z8.s, p1/m, z8.h\n" - "mov x20, #0x3\n" - "mov x11, %x[K]\n" - "ld1h { z10.s }, p1/Z, [x24]\n" - "fcvt z9.s, p1/m, z9.h\n" - "msub x21, x15, x20, %x[N]\n" - "mov x10, %x[A_ptr]\n" - "ld1h { z11.s }, p1/Z, [x24, #1, MUL VL]\n" - "fcvt z14.s, p1/m, z14.h\n" - "lsl x20, %x[K], #0x1\n" - ".inst 0x257547f0 // whilelt p8.h, XZR, x21, VLx2\n" - "ld1h { z4.s }, p1/Z, [x9]\n" - "fcvt z15.s, p1/m, z15.h\n" - "cmp x11, #0x8\n" - ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" - "ld1h { z5.s }, p1/Z, [x9, #1, MUL VL]\n" - "fcvt z10.s, p1/m, z10.h\n" - "add x22, x16, x12, LSL #3\n" - "inch x16, ALL, MUL #2\n" - "ld1h { z6.s }, p1/Z, [x23]\n" - "fcvt z11.s, p1/m, z11.h\n" - "inch x25, ALL, MUL #2\n" - "inch x27, ALL, MUL #2\n" - "ld1h { z7.s }, p1/Z, [x23, #1, MUL VL]\n" - "fcvt z4.s, p1/m, z4.h\n" - "inch x24, ALL, MUL #2\n" - "inch x9, ALL, MUL #2\n" - "ld1h { z0.s }, p1/Z, [x28]\n" - "fcvt z5.s, p1/m, z5.h\n" - "inch x23, ALL, MUL #2\n" - ".inst 0xc0040d80 // mova za.d[x8, #0], { z12.d-z15.d }\n" - "ld1h { z1.s }, p1/Z, [x28, #1, MUL VL]\n" - "fcvt z6.s, p1/m, z6.h\n" - "inch x28, ALL, MUL #2\n" - "ld1h { z2.s }, p1/Z, [x26]\n" - "fcvt z7.s, p1/m, z7.h\n" - ".inst 0xc0040d01 // mova za.d[x8, #1], { z8.d-z11.d }\n" - "ld1h { z3.s }, p1/Z, [x26, #1, MUL VL]\n" - "fcvt z0.s, p1/m, z0.h\n" - "inch x26, ALL, MUL #2\n" - "fcvt z1.s, p1/m, z1.h\n" - "fcvt z2.s, p1/m, z2.h\n" - "fcvt z3.s, p1/m, z3.h\n" - ".inst 0xc0040c82 // mova za.d[x8, #2], { z4.d-z7.d }\n" - ".inst 0xc0040c03 // mova za.d[x8, #3], { z0.d-z3.d }\n" - "ble 24f\n" - "23:" // Width 4: Multiply loop: Main loop head - "whilelt p0.h, XZR, x11\n" - ".inst 0xa040260d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - "ld1rqh { z0.h }, p0/Z, [x10]\n" - "sub x11, x11, #0x8\n" - "add x10, x10, #0x10\n" - ".inst 0xa040272f // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - "cmp x11, #0x8\n" - ".inst 0xa0402765 // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0402707 // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0402529 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x9]\n" - ".inst 0xc1509188 // fdot za.s[x8, 0], { z12.h-z15.h }, z0.h[0]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04026eb // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa040278d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x28]\n" - ".inst 0xc1509089 // fdot za.s[x8, 1], { z4.h-z7.h }, z0.h[0]\n" - "addvl x28, x28, #2\n" - ".inst 0xa040274f // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc150910a // fdot za.s[x8, 2], { z8.h-z11.h }, z0.h[0]\n" - ".inst 0xa0402609 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa040272b // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xc150918b // fdot za.s[x8, 3], { z12.h-z15.h }, z0.h[0]\n" - ".inst 0xa0402765 // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0402707 // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa040252d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x9]\n" - ".inst 0xc1509508 // fdot za.s[x8, 0], { z8.h-z11.h }, z0.h[1]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04026ef // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0402789 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x28]\n" - ".inst 0xc1509489 // fdot za.s[x8, 1], { z4.h-z7.h }, z0.h[1]\n" - "addvl x28, x28, #2\n" - ".inst 0xa040274b // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc150958a // fdot za.s[x8, 2], { z12.h-z15.h }, z0.h[1]\n" - ".inst 0xa040260d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa040272f // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xc150950b // fdot za.s[x8, 3], { z8.h-z11.h }, z0.h[1]\n" - ".inst 0xa0402765 // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0402707 // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0402529 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x9]\n" - ".inst 0xc1509988 // fdot za.s[x8, 0], { z12.h-z15.h }, z0.h[2]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04026eb // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa040278d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x28]\n" - ".inst 0xc1509889 // fdot za.s[x8, 1], { z4.h-z7.h }, z0.h[2]\n" - "addvl x28, x28, #2\n" - ".inst 0xa040274f // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc150990a // fdot za.s[x8, 2], { z8.h-z11.h }, z0.h[2]\n" - ".inst 0xa040261d // ldnt1h { z28.h-z29.h }, pn9.b/Z, [x16]\n" - "addvl x16, x16, #2\n" - ".inst 0xa040273f // ldnt1h { z30.h-z31.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xc150998b // fdot za.s[x8, 3], { z12.h-z15.h }, z0.h[2]\n" - ".inst 0xa0402769 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa040270b // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0402535 // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x9]\n" - ".inst 0xc1509f88 // fdot za.s[x8, 0], { z28.h-z31.h }, z0.h[3]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04026f7 // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0402791 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x28]\n" - ".inst 0xc1509d09 // fdot za.s[x8, 1], { z8.h-z11.h }, z0.h[3]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0402753 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc1509e8a // fdot za.s[x8, 2], { z20.h-z23.h }, z0.h[3]\n" - ".inst 0xc1509e0b // fdot za.s[x8, 3], { z16.h-z19.h }, z0.h[3]\n" - "bgt 23b\n" - "24:" // Width 4: Multiply loop: Single iteration only - "whilelt p0.h, XZR, x11\n" - ".inst 0xa0402615 // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x2\n" - "ld1rqh { z3.h }, p0/Z, [x10]\n" - "addvl x16, x16, #2\n" - ".inst 0xa0402737 // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa040276d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa040270f // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0402531 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x9]\n" - ".inst 0xc1539288 // fdot za.s[x8, 0], { z20.h-z23.h }, z3.h[0]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04026f3 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa040279d // ldnt1h { z28.h-z29.h }, pn9.b/Z, [x28]\n" - ".inst 0xc1539189 // fdot za.s[x8, 1], { z12.h-z15.h }, z3.h[0]\n" - "addvl x28, x28, #2\n" - ".inst 0xa040275f // ldnt1h { z30.h-z31.h }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc153920a // fdot za.s[x8, 2], { z16.h-z19.h }, z3.h[0]\n" - ".inst 0xc153938b // fdot za.s[x8, 3], { z28.h-z31.h }, z3.h[0]\n" - "ble 25f\n" - ".inst 0xa0402609 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x2\n" - "addvl x16, x16, #2\n" - ".inst 0xa040272b // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa0402765 // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa0402707 // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa040252d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x9]\n" - ".inst 0xc1539508 // fdot za.s[x8, 0], { z8.h-z11.h }, z3.h[1]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04026ef // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0402795 // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x28]\n" - ".inst 0xc1539489 // fdot za.s[x8, 1], { z4.h-z7.h }, z3.h[1]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0402757 // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc153958a // fdot za.s[x8, 2], { z12.h-z15.h }, z3.h[1]\n" - ".inst 0xc153968b // fdot za.s[x8, 3], { z20.h-z23.h }, z3.h[1]\n" - "ble 25f\n" - ".inst 0xa040260d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x16]\n" - "subs x11, x11, #0x2\n" - "addvl x16, x16, #2\n" - ".inst 0xa040272f // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x25]\n" - "addvl x25, x25, #2\n" - ".inst 0xa0402769 // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0xa040270b // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x24]\n" - "addvl x24, x24, #2\n" - ".inst 0xa0402535 // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x9]\n" - ".inst 0xc1539988 // fdot za.s[x8, 0], { z12.h-z15.h }, z3.h[2]\n" - "addvl x9, x9, #2\n" - ".inst 0xa04026f7 // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x23]\n" - "addvl x23, x23, #2\n" - ".inst 0xa0402791 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x28]\n" - ".inst 0xc1539909 // fdot za.s[x8, 1], { z8.h-z11.h }, z3.h[2]\n" - "addvl x28, x28, #2\n" - ".inst 0xa0402753 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x26]\n" - "addvl x26, x26, #2\n" - ".inst 0xc1539a8a // fdot za.s[x8, 2], { z20.h-z23.h }, z3.h[2]\n" - ".inst 0xc1539a0b // fdot za.s[x8, 3], { z16.h-z19.h }, z3.h[2]\n" - "ble 25f\n" - ".inst 0xa0402605 // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x16]\n" - ".inst 0xa0402727 // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x25]\n" - ".inst 0xa040276d // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x27]\n" - ".inst 0xa040270f // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x24]\n" - ".inst 0xa0402531 // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x9]\n" - ".inst 0xc1539c88 // fdot za.s[x8, 0], { z4.h-z7.h }, z3.h[3]\n" - ".inst 0xa04026f3 // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x23]\n" - ".inst 0xa0402795 // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x28]\n" - ".inst 0xc1539d89 // fdot za.s[x8, 1], { z12.h-z15.h }, z3.h[3]\n" - ".inst 0xa0402757 // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x26]\n" - ".inst 0xc1539e0a // fdot za.s[x8, 2], { z16.h-z19.h }, z3.h[3]\n" - ".inst 0xc1539e8b // fdot za.s[x8, 3], { z20.h-z23.h }, z3.h[3]\n" - "25:" // Width 4: Multiply loop: multiply skip - "tbz %x[flags], #1, 26f\n" - ".inst 0xc0060c1c // mova { z28.d-z31.d }, za.d[x8, #0]\n" - "add x21, %x[args_ptr], %[offset_min]\n" - "add x20, %x[args_ptr], %[offset_max]\n" - ".inst 0xc0060c2c // mova { z12.d-z15.d }, za.d[x8, #1]\n" - "ld1rh { z19.h }, p1/Z, [x21]\n" - ".inst 0xc0060c40 // mova { z0.d-z3.d }, za.d[x8, #2]\n" - "ld1rh { z18.h }, p1/Z, [x20]\n" - ".inst 0xc0060c64 // mova { z4.d-z7.d }, za.d[x8, #3]\n" - ".inst 0xc120e38a // fcvt z10.h, { z28.s-z29.s }\n" - ".inst 0xc120e3cb // fcvt z11.h, { z30.s-z31.s }\n" - ".inst 0xc120e18c // fcvt z12.h, { z12.s-z13.s }\n" - ".inst 0xc120e1cd // fcvt z13.h, { z14.s-z15.s }\n" - ".inst 0xc172c26a // fclamp { z10.h-z11.h }, z19.h, z18.h\n" - ".inst 0xc120e00e // fcvt z14.h, { z0.s-z1.s }\n" - ".inst 0xc120e04f // fcvt z15.h, { z2.s-z3.s }\n" - ".inst 0xc172c26c // fclamp { z12.h-z13.h }, z19.h, z18.h\n" - ".inst 0xc120e090 // fcvt z16.h, { z4.s-z5.s }\n" - ".inst 0xc120e0d1 // fcvt z17.h, { z6.s-z7.s }\n" - ".inst 0xc172c26e // fclamp { z14.h-z15.h }, z19.h, z18.h\n" - ".inst 0xc172c270 // fclamp { z16.h-z17.h }, z19.h, z18.h\n" - ".inst 0xa06025ca // st1h { z10.h-z11.h }, pn9.b, [x14]\n" - ".inst 0xa06125cc // st1h { z12.h-z13.h }, pn9.b, [x14, #0x2, MUL VL]\n" - ".inst 0xa06225ce // st1h { z14.h-z15.h }, pn9.b, [x14, #0x4, MUL VL]\n" - ".inst 0xa06321d0 // st1h { z16.h-z17.h }, p8, [x14, #0x6, MUL VL]\n" - "addvl x14, x14, #8\n" - "b 27f\n" - "26:" // Width 4: No activation - ".inst 0xc0060c0c // mova { z12.d-z15.d }, za.d[x8, #0]\n" - ".inst 0xc0060c30 // mova { z16.d-z19.d }, za.d[x8, #1]\n" - ".inst 0xc0060c5c // mova { z28.d-z31.d }, za.d[x8, #2]\n" - ".inst 0xc0060c68 // mova { z8.d-z11.d }, za.d[x8, #3]\n" - ".inst 0xc120e187 // fcvt z7.h, { z12.s-z13.s }\n" - ".inst 0xc120e1cf // fcvt z15.h, { z14.s-z15.s }\n" - ".inst 0xa16025c7 // st1h { z7.h, z15.h }, pn9.b, [x14]\n" - ".inst 0xc120e207 // fcvt z7.h, { z16.s-z17.s }\n" - ".inst 0xc120e24f // fcvt z15.h, { z18.s-z19.s }\n" - ".inst 0xa16125c7 // st1h { z7.h, z15.h }, pn9.b, [x14, #0x2, MUL VL]\n" - ".inst 0xc120e38e // fcvt z14.h, { z28.s-z29.s }\n" - ".inst 0xc120e3cf // fcvt z15.h, { z30.s-z31.s }\n" - ".inst 0xa06225ce // st1h { z14.h-z15.h }, pn9.b, [x14, #0x4, MUL VL]\n" - ".inst 0xc120e112 // fcvt z18.h, { z8.s-z9.s }\n" - ".inst 0xc120e15a // fcvt z26.h, { z10.s-z11.s }\n" - ".inst 0xa16321d2 // st1h { z18.h, z26.h }, p8, [x14, #0x6, MUL VL]\n" - "addvl x14, x14, #8\n" - "27:" // Width 4: Output done - "subs x13, x13, #0x4\n" - "mov x16, x22\n" - "sub %x[N], %x[N], x15, LSL #2\n" - "bgt 4b\n" - "28:" // Exit - ".inst 0xd503467f // SMSTOP\n" - : [N] "+&r"(N) - : [A_ptr] "r"(A_ptr), [B_ptr] "r"(B_ptr), [K] "r"(K), [args_ptr] "r"(&ka), [flags] "r"(flags), - [offset_max] "I"(offsetof(KernelArgs, maxval)), [offset_min] "I"(offsetof(KernelArgs, minval)), - [output_ptr] "r"(output_ptr) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", - "x27", "x28", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", - "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", - "z6", "z7", "z8", "z9"); + kai_kernel_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h index fb36385ea901318d5554b090f712754a4b4d61ef..b104b575c1d10dd1e55cc892a2d0b33139f1c262 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,8 @@ #include +#include "kai/kai_common.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -55,15 +57,15 @@ size_t kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(void); /// Gets the offset in bytes to the data element in the LHS matrix buffer. /// -/// @param[in] m_idx Row index. -/// @param[in] k Number of columns in unpacked LHS. +/// @param[in] m_idx Row index. This must be 0. +/// @param[in] k Columns of unpacked LHS. /// /// @return The offset in bytes to the data element. size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(size_t m_idx, size_t k); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// -/// @param[in] n_idx Column index in the unpacked RHS matrix. +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of n_step /// @param[in] k Number of rows in the unpacked RHS matrix. /// /// @return The offset in bytes to the data element. @@ -71,8 +73,8 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot /// Gets the offset in bytes to the data element in the destination matrix buffer. /// -/// @param[in] m_idx Row index. -/// @param[in] n_idx Column index. +/// @param[in] m_idx Row index. Must be 0 +/// @param[in] n_idx Column index. Must be multiple of n_step /// @param[in] dst_stride Row stride in bytes. /// /// @return The offset in bytes to the data element. @@ -100,7 +102,6 @@ size_t kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(size_t m /// @param[in] n Number of output columns to be computed. /// @param[in] k Common dimension of the LHS and RHS operand. /// @param[in] lhs LHS matrix buffer. -/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. Unused parameter. /// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. /// @param[in] dst_stride_row Row stride in bytes of the output matrix. Currently, an unused parameter. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..86641550386921b0e1042d320a43a0266594fd1b --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot_asm.S @@ -0,0 +1,882 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot) + + KAI_ASM_GLOBAL(kai_f16_from_float_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot) + +KAI_ASM_FUNCTION_TYPE(kai_f16_from_float_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot) +KAI_ASM_FUNCTION_LABEL(kai_f16_from_float_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot) + fcvt h0, s0 + fmov w0, h0 + ret + KAI_ASM_FUNCTION_END(kai_f16_from_float_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot) + + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x8, #0x0 + ldr x5, [x0, #0x18] + cntw x6, ALL, MUL #4 + ptrue p1.b + ldr x7, [x0, #0x20] + KAI_ASM_INST(0x25207811) // ptrue pn9.b + mov x22, #0x1 + ldr x21, [x0, #0x10] + add x17, x5, x6 + ldr x20, [x0, #0x28] + sub x17, x17, #0x1 + ldr x16, [x0, #0x8] + udiv x17, x17, x6 + ldr x15, [x0, #0x30] + mov x14, x21 + add x21, x17, #0x3 + mov x13, x20 + and x21, x21, #0xfffffffffffffffc + mul x21, x21, x6 + mul x21, x21, x7 + lsl x21, x21, #0x1 +KAI_ASM_LABEL(label_1) // RHS size check loop + cmp x21, #0x200, LSL #12 + blt label_2 + tbnz x21, #0, label_3 + lsr x21, x21, #0x1 + lsl x22, x22, #0x1 + b label_1 +KAI_ASM_LABEL(label_2) // RHS do prefetch + lsl x20, x21, #0x26 + sub x22, x22, #0x1 + lsl x22, x22, #0x16 + orr x21, x21, x20 + orr x21, x21, x22 + KAI_ASM_INST(0xf8b549da) // rprfm pldonce, x21, [x14] +KAI_ASM_LABEL(label_3) // RHS prefetch exit + add x12, x7, #0x1 + cntw x20, ALL, MUL #2 + bic x12, x12, #0x1 + lsl x12, x12, #0x1 + add x12, x12, #0x2 + mul x12, x12, x20 +KAI_ASM_LABEL(label_4) // Column loop + cmp x17, #0x4 + bge label_22 + cmp x17, #0x2 + bgt label_16 + beq label_10 + cntw x20, ALL, MUL #2 + add x22, x14, x12 + ld1h { z8.s }, p1/Z, [x14] + cmp x5, x20 + ld1h { z9.s }, p1/Z, [x14, #1, MUL VL] + mov x11, x7 + csel x22, x22, x14, GT + mov x21, x5 + ld1h { z10.s }, p1/Z, [x22] + fcvt z8.s, p1/m, z8.h + mov x10, x16 + lsl x20, x7, #0x1 + ld1h { z11.s }, p1/Z, [x22, #1, MUL VL] + fcvt z9.s, p1/m, z9.h + KAI_ASM_INST(0x257547f0) // whilelt p8.h, XZR, x21, VLx2 + cmp x11, #0x8 + KAI_ASM_INST(0xf8b44958) // rprfm pldmany, x20, [x10] + inch x14, ALL, MUL #2 + fcvt z10.s, p1/m, z10.h + inch x22, ALL, MUL #2 + fcvt z11.s, p1/m, z11.h + KAI_ASM_INST(0xc0040d00) // mova za.d[x8, #0], { z8.d-z11.d } + ble label_6 +KAI_ASM_LABEL(label_5) // Width 1: Multiply loop: Main loop head + whilelt p0.h, XZR, x11 + KAI_ASM_INST(0xa04025d5) // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + ld1rqh { z4.h }, p0/Z, [x10] + sub x11, x11, #0x8 + add x10, x10, #0x10 + KAI_ASM_INST(0xa04026d7) // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + cmp x11, #0x8 + KAI_ASM_INST(0xa04025c9) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026cb) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1549288) // fdot za.s[x8, 0], { z20.h-z23.h }, z4.h[0] + KAI_ASM_INST(0xa04025cd) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026cf) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa04025d5) // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026d7) // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1549508) // fdot za.s[x8, 0], { z8.h-z11.h }, z4.h[1] + KAI_ASM_INST(0xc1549988) // fdot za.s[x8, 0], { z12.h-z15.h }, z4.h[2] + KAI_ASM_INST(0xc1549e88) // fdot za.s[x8, 0], { z20.h-z23.h }, z4.h[3] + bgt label_5 +KAI_ASM_LABEL(label_6) // Width 1: Multiply loop: Single iteration only + whilelt p0.h, XZR, x11 + KAI_ASM_INST(0xa04025cd) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x14] + subs x11, x11, #0x2 + ld1rqh { z3.h }, p0/Z, [x10] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026cf) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1539188) // fdot za.s[x8, 0], { z12.h-z15.h }, z3.h[0] + ble label_7 + KAI_ASM_INST(0xa04025c5) // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x14] + subs x11, x11, #0x2 + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026c7) // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1539488) // fdot za.s[x8, 0], { z4.h-z7.h }, z3.h[1] + ble label_7 + KAI_ASM_INST(0xa04025cd) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x14] + subs x11, x11, #0x2 + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026cf) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1539988) // fdot za.s[x8, 0], { z12.h-z15.h }, z3.h[2] + ble label_7 + KAI_ASM_INST(0xa04025d1) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x14] + KAI_ASM_INST(0xa04026d3) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x22] + KAI_ASM_INST(0xc1539e08) // fdot za.s[x8, 0], { z16.h-z19.h }, z3.h[3] +KAI_ASM_LABEL(label_7) // Width 1: Multiply loop: multiply skip + tbz x15, #1, label_8 + KAI_ASM_INST(0xc0060c04) // mova { z4.d-z7.d }, za.d[x8, #0] + add x21, x0, #0x2 + add x20, x0, #0x0 + KAI_ASM_INST(0x84c0a6b3) // ld1rh { z19.h }, p1/Z, [x21] + KAI_ASM_INST(0x84c0a696) // ld1rh { z22.h }, p1/Z, [x20] + KAI_ASM_INST(0xc120e094) // fcvt z20.h, { z4.s-z5.s } + KAI_ASM_INST(0xc120e0d5) // fcvt z21.h, { z6.s-z7.s } + KAI_ASM_INST(0xc176c274) // fclamp { z20.h-z21.h }, z19.h, z22.h + KAI_ASM_INST(0xa06021b4) // st1h { z20.h-z21.h }, p8, [x13] + b label_9 +KAI_ASM_LABEL(label_8) // Width 1: No activation + KAI_ASM_INST(0xc0060c04) // mova { z4.d-z7.d }, za.d[x8, #0] + KAI_ASM_INST(0xc120e084) // fcvt z4.h, { z4.s-z5.s } + KAI_ASM_INST(0xc120e0c5) // fcvt z5.h, { z6.s-z7.s } + KAI_ASM_INST(0xa06021a4) // st1h { z4.h-z5.h }, p8, [x13] +KAI_ASM_LABEL(label_9) // Width 1: Output done + b label_28 +KAI_ASM_LABEL(label_10) // Width 2 + add x24, x14, x12, LSL #1 + cntw x20, ALL, MUL #6 + ld1h { z24.s }, p1/Z, [x14] + add x23, x24, x12 + cmp x5, x20 + ld1h { z25.s }, p1/Z, [x14, #1, MUL VL] + add x22, x14, x12 + csel x23, x23, x14, GT + ld1h { z0.s }, p1/Z, [x24] + ld1h { z26.s }, p1/Z, [x22] + fcvt z24.s, p1/m, z24.h + mov x11, x7 + sub x21, x5, x6 + ld1h { z27.s }, p1/Z, [x22, #1, MUL VL] + fcvt z25.s, p1/m, z25.h + mov x10, x16 + lsl x20, x7, #0x1 + ld1h { z1.s }, p1/Z, [x24, #1, MUL VL] + fcvt z0.s, p1/m, z0.h + KAI_ASM_INST(0x257547f0) // whilelt p8.h, XZR, x21, VLx2 + cmp x11, #0x8 + ld1h { z2.s }, p1/Z, [x23] + fcvt z26.s, p1/m, z26.h + KAI_ASM_INST(0xf8b44958) // rprfm pldmany, x20, [x10] + inch x14, ALL, MUL #2 + ld1h { z3.s }, p1/Z, [x23, #1, MUL VL] + fcvt z27.s, p1/m, z27.h + inch x22, ALL, MUL #2 + inch x24, ALL, MUL #2 + fcvt z1.s, p1/m, z1.h + inch x23, ALL, MUL #2 + fcvt z2.s, p1/m, z2.h + fcvt z3.s, p1/m, z3.h + KAI_ASM_INST(0xc0040f00) // mova za.d[x8, #0], { z24.d-z27.d } + KAI_ASM_INST(0xc0040c01) // mova za.d[x8, #1], { z0.d-z3.d } + ble label_12 +KAI_ASM_LABEL(label_11) // Width 2: Multiply loop: Main loop head + whilelt p0.h, XZR, x11 + KAI_ASM_INST(0xa04025d5) // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + ld1rqh { z4.h }, p0/Z, [x10] + sub x11, x11, #0x8 + add x10, x10, #0x10 + KAI_ASM_INST(0xa04026d7) // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + cmp x11, #0x8 + KAI_ASM_INST(0xa0402709) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04026eb) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1549288) // fdot za.s[x8, 0], { z20.h-z23.h }, z4.h[0] + KAI_ASM_INST(0xa04025d5) // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026d7) // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1549109) // fdot za.s[x8, 1], { z8.h-z11.h }, z4.h[0] + KAI_ASM_INST(0xa040270d) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04026ef) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1549688) // fdot za.s[x8, 0], { z20.h-z23.h }, z4.h[1] + KAI_ASM_INST(0xa04025d1) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026d3) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1549589) // fdot za.s[x8, 1], { z12.h-z15.h }, z4.h[1] + KAI_ASM_INST(0xa0402719) // ldnt1h { z24.h-z25.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04026fb) // ldnt1h { z26.h-z27.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1549a08) // fdot za.s[x8, 0], { z16.h-z19.h }, z4.h[2] + KAI_ASM_INST(0xa04025d5) // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026d7) // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xc1549b09) // fdot za.s[x8, 1], { z24.h-z27.h }, z4.h[2] + KAI_ASM_INST(0xa0402709) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04026eb) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1549e88) // fdot za.s[x8, 0], { z20.h-z23.h }, z4.h[3] + KAI_ASM_INST(0xc1549d09) // fdot za.s[x8, 1], { z8.h-z11.h }, z4.h[3] + bgt label_11 +KAI_ASM_LABEL(label_12) // Width 2: Multiply loop: Single iteration only + whilelt p0.h, XZR, x11 + KAI_ASM_INST(0xa04025d5) // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x14] + subs x11, x11, #0x2 + ld1rqh { z3.h }, p0/Z, [x10] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026d7) // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa040270d) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04026ef) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1539288) // fdot za.s[x8, 0], { z20.h-z23.h }, z3.h[0] + KAI_ASM_INST(0xc1539189) // fdot za.s[x8, 1], { z12.h-z15.h }, z3.h[0] + ble label_13 + KAI_ASM_INST(0xa04025d1) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x14] + subs x11, x11, #0x2 + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026d3) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0402709) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04026eb) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1539608) // fdot za.s[x8, 0], { z16.h-z19.h }, z3.h[1] + KAI_ASM_INST(0xc1539509) // fdot za.s[x8, 1], { z8.h-z11.h }, z3.h[1] + ble label_13 + KAI_ASM_INST(0xa04025d9) // ldnt1h { z24.h-z25.h }, pn9.b/Z, [x14] + subs x11, x11, #0x2 + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026db) // ldnt1h { z26.h-z27.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0402711) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa04026f3) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1539b08) // fdot za.s[x8, 0], { z24.h-z27.h }, z3.h[2] + KAI_ASM_INST(0xc1539a09) // fdot za.s[x8, 1], { z16.h-z19.h }, z3.h[2] + ble label_13 + KAI_ASM_INST(0xa04025c9) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x14] + KAI_ASM_INST(0xa04026cb) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x22] + KAI_ASM_INST(0xa0402705) // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x24] + KAI_ASM_INST(0xa04026e7) // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x23] + KAI_ASM_INST(0xc1539d08) // fdot za.s[x8, 0], { z8.h-z11.h }, z3.h[3] + KAI_ASM_INST(0xc1539c89) // fdot za.s[x8, 1], { z4.h-z7.h }, z3.h[3] +KAI_ASM_LABEL(label_13) // Width 2: Multiply loop: multiply skip + tbz x15, #1, label_14 + KAI_ASM_INST(0xc0060c08) // mova { z8.d-z11.d }, za.d[x8, #0] + add x21, x0, #0x2 + add x20, x0, #0x0 + KAI_ASM_INST(0xc0060c2c) // mova { z12.d-z15.d }, za.d[x8, #1] + KAI_ASM_INST(0x84c0a6a6) // ld1rh { z6.h }, p1/Z, [x21] + KAI_ASM_INST(0x84c0a696) // ld1rh { z22.h }, p1/Z, [x20] + KAI_ASM_INST(0xc120e112) // fcvt z18.h, { z8.s-z9.s } + KAI_ASM_INST(0xc120e153) // fcvt z19.h, { z10.s-z11.s } + KAI_ASM_INST(0xc120e190) // fcvt z16.h, { z12.s-z13.s } + KAI_ASM_INST(0xc120e1d1) // fcvt z17.h, { z14.s-z15.s } + KAI_ASM_INST(0xc176c0d2) // fclamp { z18.h-z19.h }, z6.h, z22.h + KAI_ASM_INST(0xc176c0d0) // fclamp { z16.h-z17.h }, z6.h, z22.h + KAI_ASM_INST(0xa06025b2) // st1h { z18.h-z19.h }, pn9.b, [x13] + KAI_ASM_INST(0xa06121b0) // st1h { z16.h-z17.h }, p8, [x13, #0x2, MUL VL] + b label_15 +KAI_ASM_LABEL(label_14) // Width 2: No activation + KAI_ASM_INST(0xc0060c1c) // mova { z28.d-z31.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c34) // mova { z20.d-z23.d }, za.d[x8, #1] + KAI_ASM_INST(0xc120e39a) // fcvt z26.h, { z28.s-z29.s } + KAI_ASM_INST(0xc120e3db) // fcvt z27.h, { z30.s-z31.s } + KAI_ASM_INST(0xa06025ba) // st1h { z26.h-z27.h }, pn9.b, [x13] + KAI_ASM_INST(0xc120e291) // fcvt z17.h, { z20.s-z21.s } + KAI_ASM_INST(0xc120e2d9) // fcvt z25.h, { z22.s-z23.s } + KAI_ASM_INST(0xa16121b1) // st1h { z17.h, z25.h }, p8, [x13, #0x2, MUL VL] +KAI_ASM_LABEL(label_15) // Width 2: Output done + b label_28 +KAI_ASM_LABEL(label_16) // Width 3 + add x26, x14, x12, LSL #2 + cntw x20, ALL, MUL #10 + ld1h { z28.s }, p1/Z, [x14] + add x25, x14, x12, LSL #1 + add x24, x26, x12 + ld1h { z29.s }, p1/Z, [x14, #1, MUL VL] + cmp x5, x20 + add x23, x14, x12 + ld1h { z4.s }, p1/Z, [x25] + add x22, x25, x12 + csel x24, x24, x14, GT + ld1h { z30.s }, p1/Z, [x23] + fcvt z28.s, p1/m, z28.h + ld1h { z31.s }, p1/Z, [x23, #1, MUL VL] + fcvt z29.s, p1/m, z29.h + mov x20, #0x2 + mov x11, x7 + ld1h { z5.s }, p1/Z, [x25, #1, MUL VL] + fcvt z4.s, p1/m, z4.h + msub x21, x6, x20, x5 + mov x10, x16 + ld1h { z6.s }, p1/Z, [x22] + fcvt z30.s, p1/m, z30.h + lsl x20, x7, #0x1 + KAI_ASM_INST(0x257547f0) // whilelt p8.h, XZR, x21, VLx2 + ld1h { z7.s }, p1/Z, [x22, #1, MUL VL] + fcvt z31.s, p1/m, z31.h + cmp x11, #0x8 + KAI_ASM_INST(0xf8b44958) // rprfm pldmany, x20, [x10] + ld1h { z8.s }, p1/Z, [x26] + fcvt z5.s, p1/m, z5.h + inch x14, ALL, MUL #2 + inch x23, ALL, MUL #2 + ld1h { z9.s }, p1/Z, [x26, #1, MUL VL] + fcvt z6.s, p1/m, z6.h + inch x25, ALL, MUL #2 + inch x22, ALL, MUL #2 + ld1h { z10.s }, p1/Z, [x24] + fcvt z7.s, p1/m, z7.h + inch x26, ALL, MUL #2 + ld1h { z11.s }, p1/Z, [x24, #1, MUL VL] + fcvt z8.s, p1/m, z8.h + inch x24, ALL, MUL #2 + KAI_ASM_INST(0xc0040f80) // mova za.d[x8, #0], { z28.d-z31.d } + fcvt z9.s, p1/m, z9.h + fcvt z10.s, p1/m, z10.h + fcvt z11.s, p1/m, z11.h + KAI_ASM_INST(0xc0040c81) // mova za.d[x8, #1], { z4.d-z7.d } + KAI_ASM_INST(0xc0040d02) // mova za.d[x8, #2], { z8.d-z11.d } + ble label_18 +KAI_ASM_LABEL(label_17) // Width 3: Multiply loop: Main loop head + whilelt p0.h, XZR, x11 + KAI_ASM_INST(0xa04025cd) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + ld1rqh { z4.h }, p0/Z, [x10] + sub x11, x11, #0x8 + add x10, x10, #0x10 + KAI_ASM_INST(0xa04026ef) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + cmp x11, #0x8 + KAI_ASM_INST(0xa0402731) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04026d3) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0402741) // ldnt1h { z0.h-z1.h }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1549188) // fdot za.s[x8, 0], { z12.h-z15.h }, z4.h[0] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0402703) // ldnt1h { z2.h-z3.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1549209) // fdot za.s[x8, 1], { z16.h-z19.h }, z4.h[0] + KAI_ASM_INST(0xa04025d1) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026f3) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc154900a) // fdot za.s[x8, 2], { z0.h-z3.h }, z4.h[0] + KAI_ASM_INST(0xa040272d) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04026cf) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0402755) // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1549608) // fdot za.s[x8, 0], { z16.h-z19.h }, z4.h[1] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0402717) // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1549589) // fdot za.s[x8, 1], { z12.h-z15.h }, z4.h[1] + KAI_ASM_INST(0xa04025cd) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026ef) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc154968a) // fdot za.s[x8, 2], { z20.h-z23.h }, z4.h[1] + KAI_ASM_INST(0xa0402729) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04026cb) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0402751) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1549988) // fdot za.s[x8, 0], { z12.h-z15.h }, z4.h[2] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0402713) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1549909) // fdot za.s[x8, 1], { z8.h-z11.h }, z4.h[2] + KAI_ASM_INST(0xa04025d9) // ldnt1h { z24.h-z25.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026fb) // ldnt1h { z26.h-z27.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xc1549a0a) // fdot za.s[x8, 2], { z16.h-z19.h }, z4.h[2] + KAI_ASM_INST(0xa0402731) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04026d3) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa040274d) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1549f08) // fdot za.s[x8, 0], { z24.h-z27.h }, z4.h[3] + addvl x26, x26, #2 + KAI_ASM_INST(0xa040270f) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1549e09) // fdot za.s[x8, 1], { z16.h-z19.h }, z4.h[3] + KAI_ASM_INST(0xc1549d8a) // fdot za.s[x8, 2], { z12.h-z15.h }, z4.h[3] + bgt label_17 +KAI_ASM_LABEL(label_18) // Width 3: Multiply loop: Single iteration only + whilelt p0.h, XZR, x11 + KAI_ASM_INST(0xa04025c5) // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x14] + subs x11, x11, #0x2 + ld1rqh { z3.h }, p0/Z, [x10] + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026e7) // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa040272d) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04026cf) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0402759) // ldnt1h { z24.h-z25.h }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1539088) // fdot za.s[x8, 0], { z4.h-z7.h }, z3.h[0] + addvl x26, x26, #2 + KAI_ASM_INST(0xa040271b) // ldnt1h { z26.h-z27.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1539189) // fdot za.s[x8, 1], { z12.h-z15.h }, z3.h[0] + KAI_ASM_INST(0xc153930a) // fdot za.s[x8, 2], { z24.h-z27.h }, z3.h[0] + ble label_19 + KAI_ASM_INST(0xa04025d9) // ldnt1h { z24.h-z25.h }, pn9.b/Z, [x14] + subs x11, x11, #0x2 + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026fb) // ldnt1h { z26.h-z27.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0402729) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04026cb) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0402751) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1539708) // fdot za.s[x8, 0], { z24.h-z27.h }, z3.h[1] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0402713) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1539509) // fdot za.s[x8, 1], { z8.h-z11.h }, z3.h[1] + KAI_ASM_INST(0xc153960a) // fdot za.s[x8, 2], { z16.h-z19.h }, z3.h[1] + ble label_19 + KAI_ASM_INST(0xa04025cd) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x14] + subs x11, x11, #0x2 + addvl x14, x14, #2 + KAI_ASM_INST(0xa04026ef) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0402729) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa04026cb) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x22] + addvl x22, x22, #2 + KAI_ASM_INST(0xa0402745) // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1539988) // fdot za.s[x8, 0], { z12.h-z15.h }, z3.h[2] + addvl x26, x26, #2 + KAI_ASM_INST(0xa0402707) // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xc1539909) // fdot za.s[x8, 1], { z8.h-z11.h }, z3.h[2] + KAI_ASM_INST(0xc153988a) // fdot za.s[x8, 2], { z4.h-z7.h }, z3.h[2] + ble label_19 + KAI_ASM_INST(0xa04025d9) // ldnt1h { z24.h-z25.h }, pn9.b/Z, [x14] + KAI_ASM_INST(0xa04026fb) // ldnt1h { z26.h-z27.h }, pn9.b/Z, [x23] + KAI_ASM_INST(0xa040273d) // ldnt1h { z28.h-z29.h }, pn9.b/Z, [x25] + KAI_ASM_INST(0xa04026df) // ldnt1h { z30.h-z31.h }, pn9.b/Z, [x22] + KAI_ASM_INST(0xa040274d) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1539f08) // fdot za.s[x8, 0], { z24.h-z27.h }, z3.h[3] + KAI_ASM_INST(0xa040270f) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x24] + KAI_ASM_INST(0xc1539f89) // fdot za.s[x8, 1], { z28.h-z31.h }, z3.h[3] + KAI_ASM_INST(0xc1539d8a) // fdot za.s[x8, 2], { z12.h-z15.h }, z3.h[3] +KAI_ASM_LABEL(label_19) // Width 3: Multiply loop: multiply skip + tbz x15, #1, label_20 + KAI_ASM_INST(0xc0060c18) // mova { z24.d-z27.d }, za.d[x8, #0] + add x21, x0, #0x2 + add x20, x0, #0x0 + KAI_ASM_INST(0xc0060c3c) // mova { z28.d-z31.d }, za.d[x8, #1] + KAI_ASM_INST(0x84c0a6b3) // ld1rh { z19.h }, p1/Z, [x21] + KAI_ASM_INST(0xc0060c40) // mova { z0.d-z3.d }, za.d[x8, #2] + KAI_ASM_INST(0x84c0a692) // ld1rh { z18.h }, p1/Z, [x20] + KAI_ASM_INST(0xc120e314) // fcvt z20.h, { z24.s-z25.s } + KAI_ASM_INST(0xc120e355) // fcvt z21.h, { z26.s-z27.s } + KAI_ASM_INST(0xc120e38e) // fcvt z14.h, { z28.s-z29.s } + KAI_ASM_INST(0xc120e3cf) // fcvt z15.h, { z30.s-z31.s } + KAI_ASM_INST(0xc172c274) // fclamp { z20.h-z21.h }, z19.h, z18.h + KAI_ASM_INST(0xc120e010) // fcvt z16.h, { z0.s-z1.s } + KAI_ASM_INST(0xc120e051) // fcvt z17.h, { z2.s-z3.s } + KAI_ASM_INST(0xc172c26e) // fclamp { z14.h-z15.h }, z19.h, z18.h + KAI_ASM_INST(0xc172c270) // fclamp { z16.h-z17.h }, z19.h, z18.h + KAI_ASM_INST(0xa06025b4) // st1h { z20.h-z21.h }, pn9.b, [x13] + KAI_ASM_INST(0xa06125ae) // st1h { z14.h-z15.h }, pn9.b, [x13, #0x2, MUL VL] + KAI_ASM_INST(0xa06221b0) // st1h { z16.h-z17.h }, p8, [x13, #0x4, MUL VL] + b label_21 +KAI_ASM_LABEL(label_20) // Width 3: No activation + KAI_ASM_INST(0xc0060c04) // mova { z4.d-z7.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c28) // mova { z8.d-z11.d }, za.d[x8, #1] + KAI_ASM_INST(0xc0060c4c) // mova { z12.d-z15.d }, za.d[x8, #2] + KAI_ASM_INST(0xc120e091) // fcvt z17.h, { z4.s-z5.s } + KAI_ASM_INST(0xc120e0d9) // fcvt z25.h, { z6.s-z7.s } + KAI_ASM_INST(0xa16025b1) // st1h { z17.h, z25.h }, pn9.b, [x13] + KAI_ASM_INST(0xc120e112) // fcvt z18.h, { z8.s-z9.s } + KAI_ASM_INST(0xc120e153) // fcvt z19.h, { z10.s-z11.s } + KAI_ASM_INST(0xa06125b2) // st1h { z18.h-z19.h }, pn9.b, [x13, #0x2, MUL VL] + KAI_ASM_INST(0xc120e191) // fcvt z17.h, { z12.s-z13.s } + KAI_ASM_INST(0xc120e1d9) // fcvt z25.h, { z14.s-z15.s } + KAI_ASM_INST(0xa16221b1) // st1h { z17.h, z25.h }, p8, [x13, #0x4, MUL VL] +KAI_ASM_LABEL(label_21) // Width 3: Output done + b label_28 +KAI_ASM_LABEL(label_22) // Width 4 + add x9, x14, x12, LSL #2 + cntw x20, ALL, MUL #14 + ld1h { z12.s }, p1/Z, [x14] + add x28, x9, x12, LSL #1 + add x27, x14, x12, LSL #1 + ld1h { z13.s }, p1/Z, [x14, #1, MUL VL] + add x26, x28, x12 + cmp x5, x20 + ld1h { z8.s }, p1/Z, [x27] + add x25, x14, x12 + add x24, x27, x12 + ld1h { z9.s }, p1/Z, [x27, #1, MUL VL] + fcvt z12.s, p1/m, z12.h + add x23, x9, x12 + csel x26, x26, x14, GT + ld1h { z14.s }, p1/Z, [x25] + fcvt z13.s, p1/m, z13.h + ld1h { z15.s }, p1/Z, [x25, #1, MUL VL] + fcvt z8.s, p1/m, z8.h + mov x20, #0x3 + mov x11, x7 + ld1h { z10.s }, p1/Z, [x24] + fcvt z9.s, p1/m, z9.h + msub x21, x6, x20, x5 + mov x10, x16 + ld1h { z11.s }, p1/Z, [x24, #1, MUL VL] + fcvt z14.s, p1/m, z14.h + lsl x20, x7, #0x1 + KAI_ASM_INST(0x257547f0) // whilelt p8.h, XZR, x21, VLx2 + ld1h { z4.s }, p1/Z, [x9] + fcvt z15.s, p1/m, z15.h + cmp x11, #0x8 + KAI_ASM_INST(0xf8b44958) // rprfm pldmany, x20, [x10] + ld1h { z5.s }, p1/Z, [x9, #1, MUL VL] + fcvt z10.s, p1/m, z10.h + add x22, x14, x12, LSL #3 + inch x14, ALL, MUL #2 + ld1h { z6.s }, p1/Z, [x23] + fcvt z11.s, p1/m, z11.h + inch x25, ALL, MUL #2 + inch x27, ALL, MUL #2 + ld1h { z7.s }, p1/Z, [x23, #1, MUL VL] + fcvt z4.s, p1/m, z4.h + inch x24, ALL, MUL #2 + inch x9, ALL, MUL #2 + ld1h { z0.s }, p1/Z, [x28] + fcvt z5.s, p1/m, z5.h + inch x23, ALL, MUL #2 + KAI_ASM_INST(0xc0040d80) // mova za.d[x8, #0], { z12.d-z15.d } + ld1h { z1.s }, p1/Z, [x28, #1, MUL VL] + fcvt z6.s, p1/m, z6.h + inch x28, ALL, MUL #2 + ld1h { z2.s }, p1/Z, [x26] + fcvt z7.s, p1/m, z7.h + KAI_ASM_INST(0xc0040d01) // mova za.d[x8, #1], { z8.d-z11.d } + ld1h { z3.s }, p1/Z, [x26, #1, MUL VL] + fcvt z0.s, p1/m, z0.h + inch x26, ALL, MUL #2 + fcvt z1.s, p1/m, z1.h + fcvt z2.s, p1/m, z2.h + fcvt z3.s, p1/m, z3.h + KAI_ASM_INST(0xc0040c82) // mova za.d[x8, #2], { z4.d-z7.d } + KAI_ASM_INST(0xc0040c03) // mova za.d[x8, #3], { z0.d-z3.d } + ble label_24 +KAI_ASM_LABEL(label_23) // Width 4: Multiply loop: Main loop head + whilelt p0.h, XZR, x11 + KAI_ASM_INST(0xa04025cd) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + ld1rqh { z0.h }, p0/Z, [x10] + sub x11, x11, #0x8 + add x10, x10, #0x10 + KAI_ASM_INST(0xa040272f) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + cmp x11, #0x8 + KAI_ASM_INST(0xa0402765) // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0402707) // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0402529) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc1509188) // fdot za.s[x8, 0], { z12.h-z15.h }, z0.h[0] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04026eb) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa040278d) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc1509089) // fdot za.s[x8, 1], { z4.h-z7.h }, z0.h[0] + addvl x28, x28, #2 + KAI_ASM_INST(0xa040274f) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc150910a) // fdot za.s[x8, 2], { z8.h-z11.h }, z0.h[0] + KAI_ASM_INST(0xa04025c9) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa040272b) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xc150918b) // fdot za.s[x8, 3], { z12.h-z15.h }, z0.h[0] + KAI_ASM_INST(0xa0402765) // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0402707) // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa040252d) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc1509508) // fdot za.s[x8, 0], { z8.h-z11.h }, z0.h[1] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04026ef) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0402789) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc1509489) // fdot za.s[x8, 1], { z4.h-z7.h }, z0.h[1] + addvl x28, x28, #2 + KAI_ASM_INST(0xa040274b) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc150958a) // fdot za.s[x8, 2], { z12.h-z15.h }, z0.h[1] + KAI_ASM_INST(0xa04025cd) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa040272f) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xc150950b) // fdot za.s[x8, 3], { z8.h-z11.h }, z0.h[1] + KAI_ASM_INST(0xa0402765) // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0402707) // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0402529) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc1509988) // fdot za.s[x8, 0], { z12.h-z15.h }, z0.h[2] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04026eb) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa040278d) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc1509889) // fdot za.s[x8, 1], { z4.h-z7.h }, z0.h[2] + addvl x28, x28, #2 + KAI_ASM_INST(0xa040274f) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc150990a) // fdot za.s[x8, 2], { z8.h-z11.h }, z0.h[2] + KAI_ASM_INST(0xa04025dd) // ldnt1h { z28.h-z29.h }, pn9.b/Z, [x14] + addvl x14, x14, #2 + KAI_ASM_INST(0xa040273f) // ldnt1h { z30.h-z31.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xc150998b) // fdot za.s[x8, 3], { z12.h-z15.h }, z0.h[2] + KAI_ASM_INST(0xa0402769) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa040270b) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0402535) // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc1509f88) // fdot za.s[x8, 0], { z28.h-z31.h }, z0.h[3] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04026f7) // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0402791) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc1509d09) // fdot za.s[x8, 1], { z8.h-z11.h }, z0.h[3] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0402753) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc1509e8a) // fdot za.s[x8, 2], { z20.h-z23.h }, z0.h[3] + KAI_ASM_INST(0xc1509e0b) // fdot za.s[x8, 3], { z16.h-z19.h }, z0.h[3] + bgt label_23 +KAI_ASM_LABEL(label_24) // Width 4: Multiply loop: Single iteration only + whilelt p0.h, XZR, x11 + KAI_ASM_INST(0xa04025d5) // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x14] + subs x11, x11, #0x2 + ld1rqh { z3.h }, p0/Z, [x10] + addvl x14, x14, #2 + KAI_ASM_INST(0xa0402737) // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa040276d) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa040270f) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0402531) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc1539288) // fdot za.s[x8, 0], { z20.h-z23.h }, z3.h[0] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04026f3) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa040279d) // ldnt1h { z28.h-z29.h }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc1539189) // fdot za.s[x8, 1], { z12.h-z15.h }, z3.h[0] + addvl x28, x28, #2 + KAI_ASM_INST(0xa040275f) // ldnt1h { z30.h-z31.h }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc153920a) // fdot za.s[x8, 2], { z16.h-z19.h }, z3.h[0] + KAI_ASM_INST(0xc153938b) // fdot za.s[x8, 3], { z28.h-z31.h }, z3.h[0] + ble label_25 + KAI_ASM_INST(0xa04025c9) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x14] + subs x11, x11, #0x2 + addvl x14, x14, #2 + KAI_ASM_INST(0xa040272b) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa0402765) // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa0402707) // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa040252d) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc1539508) // fdot za.s[x8, 0], { z8.h-z11.h }, z3.h[1] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04026ef) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0402795) // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc1539489) // fdot za.s[x8, 1], { z4.h-z7.h }, z3.h[1] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0402757) // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc153958a) // fdot za.s[x8, 2], { z12.h-z15.h }, z3.h[1] + KAI_ASM_INST(0xc153968b) // fdot za.s[x8, 3], { z20.h-z23.h }, z3.h[1] + ble label_25 + KAI_ASM_INST(0xa04025cd) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x14] + subs x11, x11, #0x2 + addvl x14, x14, #2 + KAI_ASM_INST(0xa040272f) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x25] + addvl x25, x25, #2 + KAI_ASM_INST(0xa0402769) // ldnt1h { z8.h-z9.h }, pn9.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0xa040270b) // ldnt1h { z10.h-z11.h }, pn9.b/Z, [x24] + addvl x24, x24, #2 + KAI_ASM_INST(0xa0402535) // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc1539988) // fdot za.s[x8, 0], { z12.h-z15.h }, z3.h[2] + addvl x9, x9, #2 + KAI_ASM_INST(0xa04026f7) // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x23] + addvl x23, x23, #2 + KAI_ASM_INST(0xa0402791) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc1539909) // fdot za.s[x8, 1], { z8.h-z11.h }, z3.h[2] + addvl x28, x28, #2 + KAI_ASM_INST(0xa0402753) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x26] + addvl x26, x26, #2 + KAI_ASM_INST(0xc1539a8a) // fdot za.s[x8, 2], { z20.h-z23.h }, z3.h[2] + KAI_ASM_INST(0xc1539a0b) // fdot za.s[x8, 3], { z16.h-z19.h }, z3.h[2] + ble label_25 + KAI_ASM_INST(0xa04025c5) // ldnt1h { z4.h-z5.h }, pn9.b/Z, [x14] + KAI_ASM_INST(0xa0402727) // ldnt1h { z6.h-z7.h }, pn9.b/Z, [x25] + KAI_ASM_INST(0xa040276d) // ldnt1h { z12.h-z13.h }, pn9.b/Z, [x27] + KAI_ASM_INST(0xa040270f) // ldnt1h { z14.h-z15.h }, pn9.b/Z, [x24] + KAI_ASM_INST(0xa0402531) // ldnt1h { z16.h-z17.h }, pn9.b/Z, [x9] + KAI_ASM_INST(0xc1539c88) // fdot za.s[x8, 0], { z4.h-z7.h }, z3.h[3] + KAI_ASM_INST(0xa04026f3) // ldnt1h { z18.h-z19.h }, pn9.b/Z, [x23] + KAI_ASM_INST(0xa0402795) // ldnt1h { z20.h-z21.h }, pn9.b/Z, [x28] + KAI_ASM_INST(0xc1539d89) // fdot za.s[x8, 1], { z12.h-z15.h }, z3.h[3] + KAI_ASM_INST(0xa0402757) // ldnt1h { z22.h-z23.h }, pn9.b/Z, [x26] + KAI_ASM_INST(0xc1539e0a) // fdot za.s[x8, 2], { z16.h-z19.h }, z3.h[3] + KAI_ASM_INST(0xc1539e8b) // fdot za.s[x8, 3], { z20.h-z23.h }, z3.h[3] +KAI_ASM_LABEL(label_25) // Width 4: Multiply loop: multiply skip + tbz x15, #1, label_26 + KAI_ASM_INST(0xc0060c1c) // mova { z28.d-z31.d }, za.d[x8, #0] + add x21, x0, #0x2 + add x20, x0, #0x0 + KAI_ASM_INST(0xc0060c2c) // mova { z12.d-z15.d }, za.d[x8, #1] + KAI_ASM_INST(0x84c0a6b3) // ld1rh { z19.h }, p1/Z, [x21] + KAI_ASM_INST(0xc0060c40) // mova { z0.d-z3.d }, za.d[x8, #2] + KAI_ASM_INST(0x84c0a692) // ld1rh { z18.h }, p1/Z, [x20] + KAI_ASM_INST(0xc0060c64) // mova { z4.d-z7.d }, za.d[x8, #3] + KAI_ASM_INST(0xc120e38a) // fcvt z10.h, { z28.s-z29.s } + KAI_ASM_INST(0xc120e3cb) // fcvt z11.h, { z30.s-z31.s } + KAI_ASM_INST(0xc120e18c) // fcvt z12.h, { z12.s-z13.s } + KAI_ASM_INST(0xc120e1cd) // fcvt z13.h, { z14.s-z15.s } + KAI_ASM_INST(0xc172c26a) // fclamp { z10.h-z11.h }, z19.h, z18.h + KAI_ASM_INST(0xc120e00e) // fcvt z14.h, { z0.s-z1.s } + KAI_ASM_INST(0xc120e04f) // fcvt z15.h, { z2.s-z3.s } + KAI_ASM_INST(0xc172c26c) // fclamp { z12.h-z13.h }, z19.h, z18.h + KAI_ASM_INST(0xc120e090) // fcvt z16.h, { z4.s-z5.s } + KAI_ASM_INST(0xc120e0d1) // fcvt z17.h, { z6.s-z7.s } + KAI_ASM_INST(0xc172c26e) // fclamp { z14.h-z15.h }, z19.h, z18.h + KAI_ASM_INST(0xc172c270) // fclamp { z16.h-z17.h }, z19.h, z18.h + KAI_ASM_INST(0xa06025aa) // st1h { z10.h-z11.h }, pn9.b, [x13] + KAI_ASM_INST(0xa06125ac) // st1h { z12.h-z13.h }, pn9.b, [x13, #0x2, MUL VL] + KAI_ASM_INST(0xa06225ae) // st1h { z14.h-z15.h }, pn9.b, [x13, #0x4, MUL VL] + KAI_ASM_INST(0xa06321b0) // st1h { z16.h-z17.h }, p8, [x13, #0x6, MUL VL] + addvl x13, x13, #8 + b label_27 +KAI_ASM_LABEL(label_26) // Width 4: No activation + KAI_ASM_INST(0xc0060c0c) // mova { z12.d-z15.d }, za.d[x8, #0] + KAI_ASM_INST(0xc0060c30) // mova { z16.d-z19.d }, za.d[x8, #1] + KAI_ASM_INST(0xc0060c5c) // mova { z28.d-z31.d }, za.d[x8, #2] + KAI_ASM_INST(0xc0060c68) // mova { z8.d-z11.d }, za.d[x8, #3] + KAI_ASM_INST(0xc120e187) // fcvt z7.h, { z12.s-z13.s } + KAI_ASM_INST(0xc120e1cf) // fcvt z15.h, { z14.s-z15.s } + KAI_ASM_INST(0xa16025a7) // st1h { z7.h, z15.h }, pn9.b, [x13] + KAI_ASM_INST(0xc120e207) // fcvt z7.h, { z16.s-z17.s } + KAI_ASM_INST(0xc120e24f) // fcvt z15.h, { z18.s-z19.s } + KAI_ASM_INST(0xa16125a7) // st1h { z7.h, z15.h }, pn9.b, [x13, #0x2, MUL VL] + KAI_ASM_INST(0xc120e38e) // fcvt z14.h, { z28.s-z29.s } + KAI_ASM_INST(0xc120e3cf) // fcvt z15.h, { z30.s-z31.s } + KAI_ASM_INST(0xa06225ae) // st1h { z14.h-z15.h }, pn9.b, [x13, #0x4, MUL VL] + KAI_ASM_INST(0xc120e112) // fcvt z18.h, { z8.s-z9.s } + KAI_ASM_INST(0xc120e15a) // fcvt z26.h, { z10.s-z11.s } + KAI_ASM_INST(0xa16321b2) // st1h { z18.h, z26.h }, p8, [x13, #0x6, MUL VL] + addvl x13, x13, #8 +KAI_ASM_LABEL(label_27) // Width 4: Output done + subs x17, x17, #0x4 + mov x14, x22 + sub x5, x5, x6, LSL #2 + bgt label_4 +KAI_ASM_LABEL(label_28) // Exit + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c index a30936cf5dc2f98752b07617d5de2467ea588bb5..a1ccaa278b89f54a0b63a98790b6aa797ad40e5a 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c @@ -4,40 +4,58 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. - #include "kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" -#include #include #include #include "kai/kai_common.h" +typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + uint16_t min; + uint16_t max; + void* accumulator_buffer; + uint64_t flags; +} KernelArgs; + static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 2; static const size_t kai_sr = 1; +void kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(KernelArgs* args); +uint16_t kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(float value); + +// Returns a constant value specific to this kernel that's relative to vector length +static size_t kai_get_kernel_vec_length_constant(void) { + const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u16() / kai_kr; + return kernel_vec_length_constant; +} + size_t kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u16() / kai_kr; + return kai_mr * kai_get_kernel_vec_length_constant(); } size_t kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u16() / kai_kr; + return kai_nr * kai_get_kernel_vec_length_constant(); } size_t kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u16() / kai_kr; + return kai_mr * kai_get_kernel_vec_length_constant(); } size_t kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u16() / kai_kr; + return kai_nr * kai_get_kernel_vec_length_constant(); } size_t kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { @@ -65,11 +83,11 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sm } size_t kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_stride) { + size_t m_idx, size_t n_idx, size_t dst_stride_row) { KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); - return m_idx * dst_stride + n_idx * sizeof(uint16_t); + return m_idx * dst_stride_row + n_idx * sizeof(uint16_t); } size_t kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(size_t m, size_t n) { @@ -80,184 +98,21 @@ void kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) { KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); - - typedef struct { - const void* A; - const void* B; - - void* C; - uint64_t ldcb; - uint64_t M, N, K; - float16_t min; - float16_t max; - - void* accumulator_buffer; - uint64_t flags; - } KernelArgs; - KernelArgs args; args.A = lhs_packed; args.B = rhs_packed; - args.C = dst; args.ldcb = dst_stride_row; args.M = m; args.N = n; args.K = k; - args.min = (float16_t)clamp_min; - args.max = (float16_t)clamp_max; - + args.min = kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(clamp_min); + args.max = kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(clamp_max); args.accumulator_buffer = NULL; args.flags = 0; - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "ldr w13, [%x[args], %[offsetof_M]]\n" - "mov x11, #0x0\n" - "mov x10, #0x0\n" - "ptrue p1.b\n" - ".inst 0x25207810 // ptrue pn8.b\n" - "ldr w9, [%x[args], %[offsetof_N]]\n" - "ldr x28, [%x[args], %[offsetof_A]]\n" - "1:" // M loop - "ldr x27, [%x[args], %[offsetof_B]]\n" - "2:" // N loop - "fmov z24.h, #0.0\n" - "ld1h { z5.h }, p1/Z, [x27]\n" - "fmov z27.h, #1.0\n" - "mov x26, x28\n" - ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" - "inch x27, ALL, MUL #2\n" - "zip1 z30.h, z5.h, z24.h\n" - "zip2 z20.h, z5.h, z24.h\n" - ".inst 0x81be2760 // fmopa za0.s, p1/M, p1/M, z27.h, z30.h\n" - ".inst 0x81b42761 // fmopa za1.s, p1/M, p1/M, z27.h, z20.h\n" - ".inst 0x81be2762 // fmopa za2.s, p1/M, p1/M, z27.h, z30.h\n" - ".inst 0x81b42763 // fmopa za3.s, p1/M, p1/M, z27.h, z20.h\n" - "ldr x20, [%x[args], %[offsetof_K]]\n" - "add x20, x20, #0x1\n" - "lsr x20, x20, #0x1\n" - "lsr x21, x20, #0x2\n" - "and x20, x20, #0x3\n" - "cbz x21, 6f\n" - "subs x21, x21, #0x1\n" - ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" - ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" - ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" - ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" - ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" - "addvl x26, x26, #8\n" - ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - "ble 5f\n" - "4:" // K loop - ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" - "subs x21, x21, #0x1\n" - ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" - ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" - ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" - ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" - ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" - ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" - ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" - ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" - ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" - ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" - ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" - ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" - ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" - ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" - ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" - ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" - ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" - ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" - ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" - ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" - ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" - ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" - "addvl x26, x26, #8\n" - ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" - "addvl x27, x27, #8\n" - "bgt 4b\n" - "5:" // K loop tail - ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" - ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" - ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" - ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" - ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" - ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" - ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" - ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" - ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" - ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" - ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" - ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" - ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" - ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" - ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" - ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" - "6:" // K oddments - "cbz x20, 8f\n" - "7:" // K oddments: Loop - ".inst 0xa1402345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26]\n" - "subs x20, x20, #0x1\n" - "addvl x26, x26, #2\n" - ".inst 0xa040236e // ld1h { z14.h-z15.h }, pn8.b/Z, [x27]\n" - "addvl x27, x27, #2\n" - ".inst 0x81ae24a0 // fmopa za0.s, p1/M, p1/M, z5.h, z14.h\n" - ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" - ".inst 0x81ae25a2 // fmopa za2.s, p1/M, p1/M, z13.h, z14.h\n" - ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" - "bgt 7b\n" - "8:" // K oddments: End - "ldr x25, [%x[args], %[offsetof_C]]\n" - "sub x24, x13, x11\n" - "cntw x23, ALL, MUL #2\n" - "ld1rh { z17.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" - "ldr x22, [%x[args], %[offsetof_ldcb]]\n" - "whilelt p0.h, x10, x9\n" - "cmp x24, x23\n" - "ld1rh { z16.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" - "mov x12, #0x0\n" - "mov x21, #0x0\n" - "add x25, x25, x10, LSL #1\n" // C += n - "mov x20, #0x2\n" - "madd x25, x11, x22, x25\n" // C += m * ldc - "csel x24, x24, x23, LT\n" - "10:" // Store to output array: Accumulator loop - ".inst 0xc006000e // mova { z14.b-z15.b }, za0h.b[x12, 0:1]\n" - "add x12, x12, #0x4\n" - "cmp x12, x23, LSL #1\n" - "add x21, x21, #0x1\n" - ".inst 0xc120e1cc // fcvt z12.h, { z14.s-z15.s }\n" - "csel x12, x12, x20, LT\n" - "cmp x21, x24\n" - ".inst 0x6470262c // fclamp z12.h, z17.h, z16.h\n" - "st1h { z12.h }, p0, [x25]\n" - "add x25, x25, x22\n" - "blt 10b\n" - "incw x10, ALL, MUL #2\n" - "cmp x10, x9\n" - "blt 2b\n" - "incw x11, ALL, MUL #2\n" - "mov x10, #0x0\n" - "cmp x11, x13\n" - "mov x28, x26\n" - "blt 1b\n" - ".inst 0xd503467f // SMSTOP\n" - : - : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), - [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), - [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), - [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), [offsetof_M] "I"(offsetof(KernelArgs, M)), - [offsetof_N] "I"(offsetof(KernelArgs, N)), [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", - "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", - "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h index 20be08a0d35799a8b1255403f5d6e2f339b296ec..bc61fe28117f7fdc685b0e2359b09c8a204fa838 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -61,7 +61,7 @@ size_t kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void); /// Gets the offset in bytes to the data element in the packed LHS matrix buffer. /// -/// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step` +/// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. /// @param[in] k Number of columns in the unpacked LHS matrix. /// /// @return The offset in bytes to the data element. @@ -69,21 +69,21 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sm /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// -/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of `n_step` -/// @param[in] k Number of rows in the unpacked RHS matrix. +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of `n_step`. +/// @param[in] k Number of columns in the unpacked LHS matrix. /// /// @return The offset in bytes to the data element. size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// -/// @param[in] m_idx Row index. Must be a multiple of `m_step` -/// @param[in] n_idx Column index. Must be a multiple of `n_step` -/// @param[in] stride Row stride in bytes. +/// @param[in] m_idx Row index. Must be a multiple of `m_step`. +/// @param[in] n_idx Column index. Must be a multiple of `n_step`. +/// @param[in] dst_stride_row Row stride in bytes. /// /// @return The offset in bytes to the data element. size_t kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_stride); + size_t m_idx, size_t n_idx, size_t dst_stride_row); /// Gets the size in bytes of the destination matrix buffer. /// @@ -104,7 +104,7 @@ size_t kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(s /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. -/// @param[in] k Common dimension of the LHS and RHS operands. +/// @param[in] k Number of columns in the unpacked LHS matrix. /// @param[in] lhs_packed Packed LHS matrix buffer. /// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..cd9e9ff4bb8dcc2de9b68a35622d8f364202a9aa --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_asm.S @@ -0,0 +1,202 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + KAI_ASM_GLOBAL(kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + fcvt h0, s0 + fmov w0, h0 + ret + KAI_ASM_FUNCTION_END(kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x14, #0x0 + ldr x13, [x0, #0x30] + ptrue p1.b + KAI_ASM_INST(0x25207810) // ptrue pn8.b + ldr w11, [x0, #0x20] + mov x10, #0x0 + ldr w9, [x0, #0x28] + add x13, x13, #0x1 + ldr x28, [x0, #0x0] + lsr x13, x13, #0x1 +KAI_ASM_LABEL(label_1) // M loop + ldr x27, [x0, #0x8] +KAI_ASM_LABEL(label_2) // N loop + fmov z23.h, #0.0 + ld1h { z18.h }, p1/Z, [x27] + fmov z2.h, #1.0 + mov x26, x28 + KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } + inch x27, ALL, MUL #2 + zip1 z14.h, z18.h, z23.h + zip2 z3.h, z18.h, z23.h + KAI_ASM_INST(0x81ae2440) // fmopa za0.s, p1/M, p1/M, z2.h, z14.h + KAI_ASM_INST(0x81a32441) // fmopa za1.s, p1/M, p1/M, z2.h, z3.h + KAI_ASM_INST(0x81ae2442) // fmopa za2.s, p1/M, p1/M, z2.h, z14.h + KAI_ASM_INST(0x81a32443) // fmopa za3.s, p1/M, p1/M, z2.h, z3.h + lsr x21, x13, #0x2 + and x20, x13, #0x3 + cbz x21, label_6 + subs x21, x21, #0x1 + KAI_ASM_INST(0xa040a350) // ld1h { z16.h-z19.h }, pn8.b/Z, [x26] + KAI_ASM_INST(0xa041a35c) // ld1h { z28.h-z31.h }, pn8.b/Z, [x26, #0x4, MUL VL] + addvl x26, x26, #8 + KAI_ASM_INST(0xa040a360) // ld1h { z0.h-z3.h }, pn8.b/Z, [x27] + KAI_ASM_INST(0xa041a368) // ld1h { z8.h-z11.h }, pn8.b/Z, [x27, #0x4, MUL VL] + addvl x27, x27, #8 + ble label_5 +KAI_ASM_LABEL(label_4) // K loop + KAI_ASM_INST(0x81a02600) // fmopa za0.s, p1/M, p1/M, z16.h, z0.h + subs x21, x21, #0x1 + KAI_ASM_INST(0x81a12601) // fmopa za1.s, p1/M, p1/M, z16.h, z1.h + KAI_ASM_INST(0x81a02622) // fmopa za2.s, p1/M, p1/M, z17.h, z0.h + KAI_ASM_INST(0x81a12623) // fmopa za3.s, p1/M, p1/M, z17.h, z1.h + KAI_ASM_INST(0x81a22640) // fmopa za0.s, p1/M, p1/M, z18.h, z2.h + KAI_ASM_INST(0x81a32641) // fmopa za1.s, p1/M, p1/M, z18.h, z3.h + KAI_ASM_INST(0x81a22662) // fmopa za2.s, p1/M, p1/M, z19.h, z2.h + KAI_ASM_INST(0x81a32663) // fmopa za3.s, p1/M, p1/M, z19.h, z3.h + KAI_ASM_INST(0xa040a350) // ld1h { z16.h-z19.h }, pn8.b/Z, [x26] + KAI_ASM_INST(0x81a82780) // fmopa za0.s, p1/M, p1/M, z28.h, z8.h + KAI_ASM_INST(0xa040a360) // ld1h { z0.h-z3.h }, pn8.b/Z, [x27] + KAI_ASM_INST(0x81a92781) // fmopa za1.s, p1/M, p1/M, z28.h, z9.h + KAI_ASM_INST(0x81a827a2) // fmopa za2.s, p1/M, p1/M, z29.h, z8.h + KAI_ASM_INST(0x81a927a3) // fmopa za3.s, p1/M, p1/M, z29.h, z9.h + KAI_ASM_INST(0x81aa27c0) // fmopa za0.s, p1/M, p1/M, z30.h, z10.h + KAI_ASM_INST(0x81ab27c1) // fmopa za1.s, p1/M, p1/M, z30.h, z11.h + KAI_ASM_INST(0x81aa27e2) // fmopa za2.s, p1/M, p1/M, z31.h, z10.h + KAI_ASM_INST(0x81ab27e3) // fmopa za3.s, p1/M, p1/M, z31.h, z11.h + KAI_ASM_INST(0xa041a35c) // ld1h { z28.h-z31.h }, pn8.b/Z, [x26, #0x4, MUL VL] + addvl x26, x26, #8 + KAI_ASM_INST(0xa041a368) // ld1h { z8.h-z11.h }, pn8.b/Z, [x27, #0x4, MUL VL] + addvl x27, x27, #8 + bgt label_4 +KAI_ASM_LABEL(label_5) // K loop tail + KAI_ASM_INST(0x81a02600) // fmopa za0.s, p1/M, p1/M, z16.h, z0.h + KAI_ASM_INST(0x81a12601) // fmopa za1.s, p1/M, p1/M, z16.h, z1.h + KAI_ASM_INST(0x81a02622) // fmopa za2.s, p1/M, p1/M, z17.h, z0.h + KAI_ASM_INST(0x81a12623) // fmopa za3.s, p1/M, p1/M, z17.h, z1.h + KAI_ASM_INST(0x81a22640) // fmopa za0.s, p1/M, p1/M, z18.h, z2.h + KAI_ASM_INST(0x81a32641) // fmopa za1.s, p1/M, p1/M, z18.h, z3.h + KAI_ASM_INST(0x81a22662) // fmopa za2.s, p1/M, p1/M, z19.h, z2.h + KAI_ASM_INST(0x81a32663) // fmopa za3.s, p1/M, p1/M, z19.h, z3.h + KAI_ASM_INST(0x81a82780) // fmopa za0.s, p1/M, p1/M, z28.h, z8.h + KAI_ASM_INST(0x81a92781) // fmopa za1.s, p1/M, p1/M, z28.h, z9.h + KAI_ASM_INST(0x81a827a2) // fmopa za2.s, p1/M, p1/M, z29.h, z8.h + KAI_ASM_INST(0x81a927a3) // fmopa za3.s, p1/M, p1/M, z29.h, z9.h + KAI_ASM_INST(0x81aa27c0) // fmopa za0.s, p1/M, p1/M, z30.h, z10.h + KAI_ASM_INST(0x81ab27c1) // fmopa za1.s, p1/M, p1/M, z30.h, z11.h + KAI_ASM_INST(0x81aa27e2) // fmopa za2.s, p1/M, p1/M, z31.h, z10.h + KAI_ASM_INST(0x81ab27e3) // fmopa za3.s, p1/M, p1/M, z31.h, z11.h +KAI_ASM_LABEL(label_6) // K oddments + cbz x20, label_8 +KAI_ASM_LABEL(label_7) // K oddments: Loop + KAI_ASM_INST(0xa1402345) // ld1h { z5.h, z13.h }, pn8.b/Z, [x26] + subs x20, x20, #0x1 + addvl x26, x26, #2 + KAI_ASM_INST(0xa040236e) // ld1h { z14.h-z15.h }, pn8.b/Z, [x27] + addvl x27, x27, #2 + KAI_ASM_INST(0x81ae24a0) // fmopa za0.s, p1/M, p1/M, z5.h, z14.h + KAI_ASM_INST(0x81af24a1) // fmopa za1.s, p1/M, p1/M, z5.h, z15.h + KAI_ASM_INST(0x81ae25a2) // fmopa za2.s, p1/M, p1/M, z13.h, z14.h + KAI_ASM_INST(0x81af25a3) // fmopa za3.s, p1/M, p1/M, z13.h, z15.h + bgt label_7 +KAI_ASM_LABEL(label_8) // K oddments: End + ldr x25, [x0, #0x10] + sub x24, x11, x14 + cntw x23, ALL, MUL #2 + KAI_ASM_INST(0x84dca411) // ld1rh { z17.h }, p1/Z, [x0, #56] + ldr x22, [x0, #0x18] + whilelt p0.h, x10, x9 + cmp x24, x23 + KAI_ASM_INST(0x84dda410) // ld1rh { z16.h }, p1/Z, [x0, #58] + mov x12, #0x0 + mov x21, #0x0 + add x25, x25, x10, LSL #1 // C += n + mov x20, #0x2 + madd x25, x14, x22, x25 // C += m * ldc + csel x24, x24, x23, LT +KAI_ASM_LABEL(label_10) // Store to output array: Accumulator loop + KAI_ASM_INST(0xc006000e) // mova { z14.b-z15.b }, za0h.b[x12, 0:1] + add x12, x12, #0x4 + cmp x12, x23, LSL #1 + add x21, x21, #0x1 + KAI_ASM_INST(0xc120e1c4) // fcvt z4.h, { z14.s-z15.s } + csel x12, x12, x20, LT + cmp x21, x24 + KAI_ASM_INST(0x64702624) // fclamp z4.h, z17.h, z16.h + st1h { z4.h }, p0, [x25] + add x25, x25, x22 + blt label_10 + incw x10, ALL, MUL #2 + cmp x10, x9 + blt label_2 + incw x14, ALL, MUL #2 + mov x10, #0x0 + cmp x14, x11 + mov x28, x26 + blt label_1 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c index d865117f190b0dc79aa14b12884634a3b83d20d7..fb6a143597434df35964f815b75899a718c79ddd 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c @@ -4,10 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. @@ -18,12 +15,34 @@ #include "kai/kai_common.h" -static const size_t kai_mr = 2; -static const size_t kai_kr = 2; -static const size_t kai_sr = 1; +enum { + MR = 2, + KR = 2, + MAX_M_STEP = (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR), + SR = 1, +}; + +typedef struct { + size_t m; + size_t k; + size_t mr; + size_t kr; + size_t sr; + size_t m_idx_start; + const void* lhs; + size_t lhs_stride; + void* lhs_packed; + size_t height; + size_t width; + const void* const* in; + size_t row_offset; + void* out; +} KernelArgs; + +void kai_kernel_lhs_pack_x16p2vlx2_x16_sme(const KernelArgs* args_ptr); static size_t kai_get_mr_lhs_pack_x16p2vlx2_x16_sme(void) { - return kai_mr * kai_get_sme_vector_length_u16() / kai_kr; + return MR * kai_get_sme_vector_length_u16() / KR; } size_t kai_get_m_step_lhs_pack_x16p2vlx2_x16_sme(size_t mr) { @@ -42,319 +61,73 @@ size_t kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme(size_t m_idx, size_t lhs_st size_t kai_get_lhs_packed_offset_lhs_pack_x16p2vlx2_x16_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { KAI_ASSUME(m_idx % kai_get_m_step_lhs_pack_x16p2vlx2_x16_sme(mr) == 0); KAI_ASSUME(mr == kai_get_mr_lhs_pack_x16p2vlx2_x16_sme()); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(kr == KR); + KAI_ASSUME(sr == SR); KAI_UNUSED(mr); KAI_UNUSED(kr); KAI_UNUSED(sr); - return m_idx * kai_roundup(k, kr) * sizeof(uint16_t); + return m_idx * kai_roundup(k, KR) * sizeof(uint16_t); } size_t kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { KAI_ASSUME(mr == kai_get_mr_lhs_pack_x16p2vlx2_x16_sme()); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(kr == KR); + KAI_ASSUME(sr == SR); KAI_UNUSED(mr); KAI_UNUSED(kr); KAI_UNUSED(sr); - return kai_roundup(m, kai_get_mr_lhs_pack_x16p2vlx2_x16_sme()) * kai_roundup(k, kai_kr) * sizeof(uint16_t); + return kai_roundup(m, kai_get_mr_lhs_pack_x16p2vlx2_x16_sme()) * kai_roundup(k, KR) * sizeof(uint16_t); } void kai_run_lhs_pack_x16p2vlx2_x16_sme( size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { KAI_ASSUME(mr == kai_get_mr_lhs_pack_x16p2vlx2_x16_sme()); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(kr == KR); + KAI_ASSUME(sr == SR); KAI_ASSUME(lhs != NULL); KAI_ASSUME(lhs_packed != NULL); - KAI_ASSUME(m_idx_start == 0); + const size_t m_step = kai_get_mr_lhs_pack_x16p2vlx2_x16_sme(); const size_t block_height = mr; const size_t width = k; const size_t row_offset = 0; - const void* in[block_height]; + KAI_ASSERT(m_step <= MAX_M_STEP); + const void* in[MAX_M_STEP]; uint8_t* lhs_packed_ptr = lhs_packed; const uint8_t* lhs_ptr = lhs; for (size_t block_y = 0; block_y < m; block_y += block_height) { const size_t height = KAI_MIN(m - block_y, block_height); - void* out = (void*)((char*)lhs_packed_ptr + block_y * kai_roundup(k, kai_kr) * sizeof(uint16_t)); + void* out = lhs_packed_ptr + block_y * kai_roundup(k, KR) * sizeof(uint16_t); for (size_t y = 0; y < height; y++) { in[y] = lhs_ptr + (block_y + y) * lhs_stride; } - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x22, %x[width]\n" - "mov x21, %x[width]\n" - "cnth x20\n" - "inch x22\n" - "sub x7, x20, #0x1\n" - "sub x22, x22, #0x1\n" - "ands x7, x21, x7\n" - "cntw x8\n" - "udiv x22, x22, x20\n" // n_passes = ceildiv(width, VL) - "csel x7, x7, x20, NE\n" - "sub x13, x22, #0x1\n" - "add x7, x7, #0x1\n" - "sub x17, x8, #0x2\n" - "lsl x21, %x[height], #0x1\n" // height * 2 - "lsl x20, x8, #0x1\n" - "mov x16, #0x0\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - "cntw x9, ALL, MUL #2\n" - "cntw x28, ALL, MUL #3\n" - "ldr x27, [x11, #0x0]\n" - "lsr x13, x13, #0x1\n" // n_loops = (n_passes - 1) / 2 - "and x26, x22, #0x1\n" // odd_tail = bool(n_passes & 0x1) - "ldr x25, [x10, #0x0]\n" - "lsr x7, x7, #0x1\n" - "ptrue p12.s\n" - "ldr x24, [x11, #0x8]\n" - "whilelt p11.h, XZR, x21\n" - "whilelt p10.h, x20, x21\n" - "ldr x21, [x10, #0x8]\n" - "mov x23, %x[row_offset]\n" - "mov x22, %x[out]\n" - "whilelt p9.h, x16, %x[width]\n" - "whilelt p8.h, x16, %x[width]\n" - "add x11, x11, #0x10\n" - "add x10, x10, #0x10\n" - "mov x12, #0x0\n" - "cbz x17, 2f\n" - "1:" // K loop: Charge: Loop - ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" - ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" - ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" - ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" - ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" - "add x12, x12, #0x4\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "cmp x12, x17, LSL #1\n" - "blt 1b\n" - "2:" // K loop: Charge: End - ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" - ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" - ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" - ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" - "ldr x27, [x11, #0x0]\n" - "inch x16\n" - ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "inch x23\n" - "cbz x13, 8f\n" - "mov x20, x13\n" - "3:" // K loop: Main loop - "whilelt p8.h, x16, %x[width]\n" - "mov x15, #0x0\n" - "mov x14, #0x0\n" - "cbz x17, 5f\n" - "4:" // K loop: Main loop: First: Loop - ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" - ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" - ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" - ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" - ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - "add x10, x10, #0x10\n" - "add x15, x15, #0x4\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x14, x14, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x14, x17\n" - "blt 4b\n" - "5:" // K loop: Main loop: First: Tail - ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" - ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" - ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" - ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" - "ldr x27, [x11, #0x0]\n" - "mov x13, #0x0\n" - ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" - "ldr x25, [x10, #0x0]\n" - "mov x12, #0x0\n" - ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - "whilelt p9.h, x16, %x[width]\n" - "inch x16\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "inch x23\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "whilelt p8.h, x16, %x[width]\n" - "cbz x17, 7f\n" - "6:" // K loop: Main loop: Second: Loop - ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" - ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" - ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" - ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" - ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "add x10, x10, #0x10\n" - "add x13, x13, #0x4\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x12, x12, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x12, x17\n" - "blt 6b\n" - "7:" // K loop: Main loop: Second: Tail - ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" - ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" - ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" - ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "whilelt p9.h, x16, %x[width]\n" - "subs x20, x20, #0x1\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "inch x16\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "inch x23\n" - "bgt 3b\n" - "8:" // K loop: Tails - "cbnz x26, 11f\n" - "mov x11, %x[in]\n" - "whilelt p8.h, x16, %x[width]\n" - "mov x13, #0x0\n" - "mov x12, #0x0\n" - "9:" // K loop: Tails: Even: First - ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25396161 // psel p1.h, p8.h/Z, p11.h[w13, #1]\n" - ".inst 0x25396140 // psel p0.h, p8.h/Z, p10.h[w13, #1]\n" - ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "ldr x21, [x11, #0x0]\n" - "cmp x12, x8\n" - "ldr x20, [x11, x8, LSL #0x3]\n" - "add x11, x11, #0x8\n" - ".inst 0xe05726a1 // ld1h { za0h.h[x13, #1] }, p1/Z, [x21, x23, LSL #1]\n" - ".inst 0xe0572289 // ld1h { za1h.h[x13, #1] }, p0/Z, [x20, x23, LSL #1]\n" - "add x13, x13, #0x2\n" - "blt 9b\n" - "whilelt p9.h, x16, %x[width]\n" - "whilelt p8.h, x16, %x[width]\n" - "mov x20, #0x0\n" - "mov x12, #0x0\n" - "10:" // K loop: Tails: Even: Second - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - "add x20, x20, #0x2\n" - ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 10b\n" - "whilelt p8.h, x16, %x[width]\n" - "b 13f\n" - "11:" // K loop: Tails: Odd - "mov x12, #0x0\n" - "12:" // K loop: Tails: Odd: Loop - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 12b\n" - "13:" // K loop: End - "mov %x[out], x22\n" - ".inst 0xd503467f // SMSTOP\n" - : [out] "+&r"(out) - : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", - "x25", "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", - "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", - "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + KernelArgs args; + args.m = m; + args.k = k; + args.mr = MR; + args.kr = KR; + args.sr = SR; + args.m_idx_start = m_idx_start; + args.lhs = lhs; + args.lhs_stride = lhs_stride; + args.lhs_packed = lhs_packed; + args.height = height; + args.width = width; + args.in = in; + args.row_offset = row_offset; + args.out = out; + + kai_kernel_lhs_pack_x16p2vlx2_x16_sme(&args); } } diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme_asm.S b/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..3aef27de9063c1ab8e4b24b8568b484df11a90fb --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme_asm.S @@ -0,0 +1,325 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(lhs_pack_x16p2vlx2_x16_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_lhs_pack_x16p2vlx2_x16_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_lhs_pack_x16p2vlx2_x16_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_lhs_pack_x16p2vlx2_x16_sme) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x4, #0x0 + ldr x5, [x0, #0x50] + cnth x24 + cntw x6 + ldr x23, [x0, #0x48] + sub x7, x24, #0x1 + sub x8, x6, #0x2 + ldr x17, [x0, #0x58] + lsl x12, x6, #0x1 + cntw x16, ALL, MUL #2 + mov x22, x5 + mov x20, x5 + ldr x21, [x0, #0x60] + inch x22 + ands x7, x20, x7 + ldr x11, [x0, #0x68] + sub x22, x22, #0x1 + csel x7, x7, x24, NE + udiv x22, x22, x24 // n_passes = ceildiv(width, VL) + add x7, x7, #0x1 + sub x10, x22, #0x1 + lsl x20, x23, #0x1 // height * 2 + mov x9, x17 + add x28, x17, x6, LSL #3 + cntw x27, ALL, MUL #3 + lsr x10, x10, #0x1 // n_loops = (n_passes - 1) / 2 + ldr x26, [x9, #0x0] + and x25, x22, #0x1 // odd_tail = bool(n_passes & 0x1) + lsr x7, x7, #0x1 + ldr x24, [x28, #0x0] + ptrue p12.s + whilelt p11.h, XZR, x20 + ldr x23, [x9, #0x8] + whilelt p10.h, x12, x20 + mov x22, x21 + ldr x21, [x28, #0x8] + whilelt p9.h, x4, x5 + whilelt p8.h, x4, x5 + add x9, x9, #0x10 + add x28, x28, #0x10 + mov x12, #0x0 + cbz x8, label_2 +KAI_ASM_LABEL(label_1) // K loop: Charge: Loop + KAI_ASM_INST(0x25286163) // psel p3.h, p8.h/Z, p11.h[w12] + KAI_ASM_INST(0x25286142) // psel p2.h, p8.h/Z, p10.h[w12] + KAI_ASM_INST(0x25686161) // psel p1.h, p8.h/Z, p11.h[w12, #2] + KAI_ASM_INST(0x25686140) // psel p0.h, p8.h/Z, p10.h[w12, #2] + KAI_ASM_INST(0xe0560f40) // ld1h { za0h.h[x12] }, p3/Z, [x26, x22, LSL #1] + ldr x26, [x9, #0x0] + KAI_ASM_INST(0xe0560b08) // ld1h { za1h.h[x12] }, p2/Z, [x24, x22, LSL #1] + ldr x24, [x28, #0x0] + KAI_ASM_INST(0xe05606e2) // ld1h { za0h.h[x12, #2] }, p1/Z, [x23, x22, LSL #1] + ldr x23, [x9, #0x8] + add x9, x9, #0x10 + KAI_ASM_INST(0xe05602aa) // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x22, LSL #1] + add x12, x12, #0x4 + ldr x21, [x28, #0x8] + add x28, x28, #0x10 + cmp x12, x8, LSL #1 + blt label_1 +KAI_ASM_LABEL(label_2) // K loop: Charge: End + KAI_ASM_INST(0x25286163) // psel p3.h, p8.h/Z, p11.h[w12] + KAI_ASM_INST(0x25286142) // psel p2.h, p8.h/Z, p10.h[w12] + KAI_ASM_INST(0x25686161) // psel p1.h, p8.h/Z, p11.h[w12, #2] + KAI_ASM_INST(0x25686140) // psel p0.h, p8.h/Z, p10.h[w12, #2] + mov x9, x17 + add x28, x17, x6, LSL #3 + KAI_ASM_INST(0xe0560f40) // ld1h { za0h.h[x12] }, p3/Z, [x26, x22, LSL #1] + ldr x26, [x9, #0x0] + inch x4 + KAI_ASM_INST(0xe0560b08) // ld1h { za1h.h[x12] }, p2/Z, [x24, x22, LSL #1] + ldr x24, [x28, #0x0] + KAI_ASM_INST(0xe05606e2) // ld1h { za0h.h[x12, #2] }, p1/Z, [x23, x22, LSL #1] + ldr x23, [x9, #0x8] + add x9, x9, #0x10 + KAI_ASM_INST(0xe05602aa) // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x22, LSL #1] + ldr x21, [x28, #0x8] + add x28, x28, #0x10 + inch x22 + cbz x10, label_8 + mov x20, x10 +KAI_ASM_LABEL(label_3) // K loop: Main loop + whilelt p8.h, x4, x5 + mov x15, #0x0 + mov x14, #0x0 + cbz x8, label_5 +KAI_ASM_LABEL(label_4) // K loop: Main loop: First: Loop + KAI_ASM_INST(0x253b6160) // psel p0.h, p8.h/Z, p11.h[w15, #1] + KAI_ASM_INST(0x253b6142) // psel p2.h, p8.h/Z, p10.h[w15, #1] + KAI_ASM_INST(0x257b6161) // psel p1.h, p8.h/Z, p11.h[w15, #3] + KAI_ASM_INST(0x257b6143) // psel p3.h, p8.h/Z, p10.h[w15, #3] + KAI_ASM_INST(0xe0566341) // ld1h { za0h.h[x15, #1] }, p0/Z, [x26, x22, LSL #1] + KAI_ASM_INST(0x252a7120) // psel p0.h, p12.h/Z, p9.h[w14] + ldr x26, [x9, #0x0] + KAI_ASM_INST(0xe0566b09) // ld1h { za1h.h[x15, #1] }, p2/Z, [x24, x22, LSL #1] + KAI_ASM_INST(0x252a7122) // psel p2.h, p12.h/Z, p9.h[w14] + ldr x24, [x28, #0x0] + KAI_ASM_INST(0xe05666e3) // ld1h { za0h.h[x15, #3] }, p1/Z, [x23, x22, LSL #1] + KAI_ASM_INST(0x253a7121) // psel p1.h, p12.h/Z, p9.h[w14, #1] + ldr x23, [x9, #0x8] + add x9, x9, #0x10 + KAI_ASM_INST(0xe0566eab) // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x22, LSL #1] + ldr x21, [x28, #0x8] + KAI_ASM_INST(0xe0bfc160) // st1w { za0v.s[x14] }, p0/Z, [x11, XZR, LSL #2] + KAI_ASM_INST(0x253a7120) // psel p0.h, p12.h/Z, p9.h[w14, #1] + KAI_ASM_INST(0xe0a6c964) // st1w { za1v.s[x14] }, p2/Z, [x11, x6, LSL #2] + add x28, x28, #0x10 + add x15, x15, #0x4 + KAI_ASM_INST(0xe0b0c561) // st1w { za0v.s[x14, #1] }, p1/Z, [x11, x16, LSL #2] + KAI_ASM_INST(0xe0bbc165) // st1w { za1v.s[x14, #1] }, p0/Z, [x11, x27, LSL #2] + add x14, x14, #0x2 + addvl x11, x11, #4 + cmp x14, x8 + blt label_4 +KAI_ASM_LABEL(label_5) // K loop: Main loop: First: Tail + KAI_ASM_INST(0x253b6160) // psel p0.h, p8.h/Z, p11.h[w15, #1] + KAI_ASM_INST(0x253b6142) // psel p2.h, p8.h/Z, p10.h[w15, #1] + KAI_ASM_INST(0x257b6161) // psel p1.h, p8.h/Z, p11.h[w15, #3] + KAI_ASM_INST(0x257b6143) // psel p3.h, p8.h/Z, p10.h[w15, #3] + mov x9, x17 + add x28, x17, x6, LSL #3 + KAI_ASM_INST(0xe0566341) // ld1h { za0h.h[x15, #1] }, p0/Z, [x26, x22, LSL #1] + KAI_ASM_INST(0x252a7120) // psel p0.h, p12.h/Z, p9.h[w14] + ldr x26, [x9, #0x0] + mov x13, #0x0 + KAI_ASM_INST(0xe0566b09) // ld1h { za1h.h[x15, #1] }, p2/Z, [x24, x22, LSL #1] + KAI_ASM_INST(0x252a7122) // psel p2.h, p12.h/Z, p9.h[w14] + ldr x24, [x28, #0x0] + mov x12, #0x0 + KAI_ASM_INST(0xe05666e3) // ld1h { za0h.h[x15, #3] }, p1/Z, [x23, x22, LSL #1] + KAI_ASM_INST(0x253a7121) // psel p1.h, p12.h/Z, p9.h[w14, #1] + ldr x23, [x9, #0x8] + add x9, x9, #0x10 + KAI_ASM_INST(0xe0566eab) // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x22, LSL #1] + ldr x21, [x28, #0x8] + KAI_ASM_INST(0xe0bfc160) // st1w { za0v.s[x14] }, p0/Z, [x11, XZR, LSL #2] + KAI_ASM_INST(0x253a7120) // psel p0.h, p12.h/Z, p9.h[w14, #1] + KAI_ASM_INST(0xe0a6c964) // st1w { za1v.s[x14] }, p2/Z, [x11, x6, LSL #2] + whilelt p9.h, x4, x5 + inch x4 + KAI_ASM_INST(0xe0b0c561) // st1w { za0v.s[x14, #1] }, p1/Z, [x11, x16, LSL #2] + add x28, x28, #0x10 + inch x22 + KAI_ASM_INST(0xe0bbc165) // st1w { za1v.s[x14, #1] }, p0/Z, [x11, x27, LSL #2] + addvl x11, x11, #4 + whilelt p8.h, x4, x5 + cbz x8, label_7 +KAI_ASM_LABEL(label_6) // K loop: Main loop: Second: Loop + KAI_ASM_INST(0x25296160) // psel p0.h, p8.h/Z, p11.h[w13] + KAI_ASM_INST(0x25296142) // psel p2.h, p8.h/Z, p10.h[w13] + KAI_ASM_INST(0x25696161) // psel p1.h, p8.h/Z, p11.h[w13, #2] + KAI_ASM_INST(0x25696143) // psel p3.h, p8.h/Z, p10.h[w13, #2] + KAI_ASM_INST(0xe0562340) // ld1h { za0h.h[x13] }, p0/Z, [x26, x22, LSL #1] + KAI_ASM_INST(0x25287120) // psel p0.h, p12.h/Z, p9.h[w12] + ldr x26, [x9, #0x0] + KAI_ASM_INST(0xe0562b08) // ld1h { za1h.h[x13] }, p2/Z, [x24, x22, LSL #1] + KAI_ASM_INST(0x25287122) // psel p2.h, p12.h/Z, p9.h[w12] + ldr x24, [x28, #0x0] + KAI_ASM_INST(0xe05626e2) // ld1h { za0h.h[x13, #2] }, p1/Z, [x23, x22, LSL #1] + KAI_ASM_INST(0x25387121) // psel p1.h, p12.h/Z, p9.h[w12, #1] + ldr x23, [x9, #0x8] + add x9, x9, #0x10 + KAI_ASM_INST(0xe0562eaa) // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x22, LSL #1] + ldr x21, [x28, #0x8] + KAI_ASM_INST(0xe0bf8168) // st1w { za2v.s[x12] }, p0/Z, [x11, XZR, LSL #2] + KAI_ASM_INST(0x25387120) // psel p0.h, p12.h/Z, p9.h[w12, #1] + KAI_ASM_INST(0xe0a6896c) // st1w { za3v.s[x12] }, p2/Z, [x11, x6, LSL #2] + add x28, x28, #0x10 + add x13, x13, #0x4 + KAI_ASM_INST(0xe0b08569) // st1w { za2v.s[x12, #1] }, p1/Z, [x11, x16, LSL #2] + KAI_ASM_INST(0xe0bb816d) // st1w { za3v.s[x12, #1] }, p0/Z, [x11, x27, LSL #2] + add x12, x12, #0x2 + addvl x11, x11, #4 + cmp x12, x8 + blt label_6 +KAI_ASM_LABEL(label_7) // K loop: Main loop: Second: Tail + KAI_ASM_INST(0x25296160) // psel p0.h, p8.h/Z, p11.h[w13] + KAI_ASM_INST(0x25296142) // psel p2.h, p8.h/Z, p10.h[w13] + KAI_ASM_INST(0x25696161) // psel p1.h, p8.h/Z, p11.h[w13, #2] + KAI_ASM_INST(0x25696143) // psel p3.h, p8.h/Z, p10.h[w13, #2] + mov x9, x17 + add x28, x17, x6, LSL #3 + KAI_ASM_INST(0xe0562340) // ld1h { za0h.h[x13] }, p0/Z, [x26, x22, LSL #1] + KAI_ASM_INST(0x25287120) // psel p0.h, p12.h/Z, p9.h[w12] + ldr x26, [x9, #0x0] + KAI_ASM_INST(0xe0562b08) // ld1h { za1h.h[x13] }, p2/Z, [x24, x22, LSL #1] + KAI_ASM_INST(0x25287122) // psel p2.h, p12.h/Z, p9.h[w12] + ldr x24, [x28, #0x0] + KAI_ASM_INST(0xe05626e2) // ld1h { za0h.h[x13, #2] }, p1/Z, [x23, x22, LSL #1] + KAI_ASM_INST(0x25387121) // psel p1.h, p12.h/Z, p9.h[w12, #1] + ldr x23, [x9, #0x8] + add x9, x9, #0x10 + KAI_ASM_INST(0xe0562eaa) // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x22, LSL #1] + ldr x21, [x28, #0x8] + KAI_ASM_INST(0xe0bf8168) // st1w { za2v.s[x12] }, p0/Z, [x11, XZR, LSL #2] + KAI_ASM_INST(0x25387120) // psel p0.h, p12.h/Z, p9.h[w12, #1] + KAI_ASM_INST(0xe0a6896c) // st1w { za3v.s[x12] }, p2/Z, [x11, x6, LSL #2] + whilelt p9.h, x4, x5 + subs x20, x20, #0x1 + KAI_ASM_INST(0xe0b08569) // st1w { za2v.s[x12, #1] }, p1/Z, [x11, x16, LSL #2] + add x28, x28, #0x10 + inch x4 + KAI_ASM_INST(0xe0bb816d) // st1w { za3v.s[x12, #1] }, p0/Z, [x11, x27, LSL #2] + addvl x11, x11, #4 + inch x22 + bgt label_3 +KAI_ASM_LABEL(label_8) // K loop: Tails + cbnz x25, label_11 + mov x9, x17 + whilelt p8.h, x4, x5 + mov x13, #0x0 + mov x12, #0x0 +KAI_ASM_LABEL(label_9) // K loop: Tails: Even: First + KAI_ASM_INST(0x25307123) // psel p3.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25396161) // psel p1.h, p8.h/Z, p11.h[w13, #1] + KAI_ASM_INST(0x25396140) // psel p0.h, p8.h/Z, p10.h[w13, #1] + KAI_ASM_INST(0xe0bf8d60) // st1w { za0v.s[x12] }, p3/Z, [x11, XZR, LSL #2] + KAI_ASM_INST(0xe0a68964) // st1w { za1v.s[x12] }, p2/Z, [x11, x6, LSL #2] + add x12, x12, #0x1 + addvl x11, x11, #2 + ldr x21, [x9, #0x0] + cmp x12, x6 + ldr x20, [x9, x6, LSL #0x3] + add x9, x9, #0x8 + KAI_ASM_INST(0xe05626a1) // ld1h { za0h.h[x13, #1] }, p1/Z, [x21, x22, LSL #1] + KAI_ASM_INST(0xe0562289) // ld1h { za1h.h[x13, #1] }, p0/Z, [x20, x22, LSL #1] + add x13, x13, #0x2 + blt label_9 + whilelt p9.h, x4, x5 + whilelt p8.h, x4, x5 + mov x20, #0x0 + mov x12, #0x0 +KAI_ASM_LABEL(label_10) // K loop: Tails: Even: Second + KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + add x20, x20, #0x2 + KAI_ASM_INST(0xe0bf8568) // st1w { za2v.s[x12] }, p1/Z, [x11, XZR, LSL #2] + KAI_ASM_INST(0xe0a6816c) // st1w { za3v.s[x12] }, p0/Z, [x11, x6, LSL #2] + add x12, x12, #0x1 + addvl x11, x11, #2 + cmp x12, x7 + blt label_10 + whilelt p8.h, x4, x5 + b label_13 +KAI_ASM_LABEL(label_11) // K loop: Tails: Odd + mov x12, #0x0 +KAI_ASM_LABEL(label_12) // K loop: Tails: Odd: Loop + KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf8560) // st1w { za0v.s[x12] }, p1/Z, [x11, XZR, LSL #2] + KAI_ASM_INST(0xe0a68164) // st1w { za1v.s[x12] }, p0/Z, [x11, x6, LSL #2] + add x12, x12, #0x1 + addvl x11, x11, #2 + cmp x12, x7 + blt label_12 +KAI_ASM_LABEL(label_13) // K loop: End + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_lhs_pack_x16p2vlx2_x16_sme) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.c index 698edaa3f061b3f65831b8c260c0576acea2d9a6..4904a36c2d3d54bba4189476f5298bbe98cef3c1 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.c @@ -1,13 +1,12 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. - #include "kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h" #include @@ -16,14 +15,31 @@ #include "kai/kai_common.h" -static const size_t kai_nr = 2; -static const size_t kai_kr = 2; -static const size_t kai_num_bytes_input = 2; -static const size_t kai_num_bytes_output = 2; -static const size_t kai_num_bytes_bias = 2; +enum { + NR = 2, + KR = 2, + MAX_N_STEP = NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR), +}; + +typedef struct { + const void* bias_ptr; + size_t width; + size_t height; + size_t in_stride; + size_t out_stride; + const void* in; + void* out; + const void* pad_row; +} KernelArgs; + +static const size_t kai_num_bytes_input = sizeof(uint16_t); +static const size_t kai_num_bytes_output = sizeof(uint16_t); +static const size_t kai_num_bytes_bias = sizeof(uint16_t); + +void kai_kernel_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(const KernelArgs* args_ptr); size_t kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(void) { - return kai_nr * kai_get_sme_vector_length_u16() / kai_kr; + return NR * kai_get_sme_vector_length_u16() / KR; } size_t kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx) { @@ -38,7 +54,7 @@ size_t kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx) { size_t kai_get_rhs_packed_stride_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t k) { return kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme() * - (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output); + (kai_num_bytes_bias + kai_roundup(k, KR) * kai_num_bytes_output); } size_t kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx, size_t k) { @@ -54,11 +70,11 @@ size_t kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n, siz } void kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride_row, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { KAI_ASSUME(num_groups == 1); KAI_ASSUME(nr == kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme()); - KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(kr == KR); KAI_ASSUME(sr == 1); KAI_ASSUME(rhs != NULL); KAI_ASSUME(bias != NULL); @@ -67,122 +83,20 @@ void kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme( KAI_ASSUME(extra_bytes == 0); KAI_ASSUME(params == NULL); - size_t height = k; - const size_t width = n; - const void* in = rhs; - void* out = rhs_packed; - const size_t in_stride = rhs_stride; - const uint16_t* pad_row = rhs; - - size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(height); - - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "mov x21, %x[out]\n" - "mov x20, %x[width]\n" - "ptrue p1.b\n" - "1:" // Bias: Full loop - "whilelt p0.h, XZR, x20\n" - "dech x20\n" - "cmp x20, #0x0\n" - "ld1h { z16.h }, p0/Z, [%x[bias]]\n" - "incb %x[bias]\n" - "st1h { z16.h }, p1, [x21]\n" - "add x21, x21, %x[out_stride]\n" - "bgt 1b\n" - "cmp %x[height], #0x8\n" - "incb %x[out]\n" - "blt 5f\n" - "2:" // Main row loop: Head - "mov x9, %x[in]\n" - "mov x28, %x[out]\n" - "add x27, x9, %x[in_stride]\n" - "sub %x[height], %x[height], #0x8\n" - "add x26, x27, %x[in_stride]\n" - "mov x25, %x[width]\n" - "add x24, x26, %x[in_stride]\n" - "add x23, x24, %x[in_stride]\n" - "add x22, x23, %x[in_stride]\n" - "add x21, x22, %x[in_stride]\n" - "add x20, x21, %x[in_stride]\n" - "add %x[in], x20, %x[in_stride]\n" - "3:" // Main row loop: Column loop - "whilelt p0.h, XZR, x25\n" - "decw x25, ALL, MUL #2\n" - "ld1h { z20.h }, p0/Z, [x9]\n" - "cmp x25, #0x0\n" - "addvl x9, x9, #1\n" - "ld1h { z17.h }, p0/Z, [x27]\n" - "addvl x27, x27, #1\n" - "ld1h { z19.h }, p0/Z, [x26]\n" - "addvl x26, x26, #1\n" - "ld1h { z16.h }, p0/Z, [x24]\n" - "addvl x24, x24, #1\n" - "ld1h { z18.h }, p0/Z, [x23]\n" - "addvl x23, x23, #1\n" - "zip1 z24.h, z20.h, z17.h\n" - "zip2 z23.h, z20.h, z17.h\n" - "ld1h { z17.h }, p0/Z, [x22]\n" - "addvl x22, x22, #1\n" - "ld1h { z22.h }, p0/Z, [x21]\n" - "addvl x21, x21, #1\n" - "zip1 z21.h, z19.h, z16.h\n" - "zip2 z20.h, z19.h, z16.h\n" - "ld1h { z16.h }, p0/Z, [x20]\n" - "addvl x20, x20, #1\n" - "zip1 z19.h, z18.h, z17.h\n" - "zip2 z18.h, z18.h, z17.h\n" - "st1h { z24.h }, p1, [x28]\n" - "st1h { z23.h }, p1, [x28, #1, MUL VL]\n" - "zip1 z17.h, z22.h, z16.h\n" - "zip2 z16.h, z22.h, z16.h\n" - "st1h { z21.h }, p1, [x28, #2, MUL VL]\n" - "st1h { z20.h }, p1, [x28, #3, MUL VL]\n" - "st1h { z19.h }, p1, [x28, #4, MUL VL]\n" - "st1h { z18.h }, p1, [x28, #5, MUL VL]\n" - "st1h { z17.h }, p1, [x28, #6, MUL VL]\n" - "st1h { z16.h }, p1, [x28, #7, MUL VL]\n" - "add x28, x28, %x[out_stride]\n" - "bgt 3b\n" - "cmp %x[height], #0x8\n" - "addvl %x[out], %x[out], #8\n" - "bge 2b\n" - "cbz %x[height], 9f\n" - "5:" // Main loop skip - "6:" // Tail row loop: Head - "mov x9, %x[in]\n" - "cmp %x[height], #0x1\n" - "add x27, x9, %x[in_stride]\n" - "mov x28, %x[out]\n" - "add %x[in], x27, %x[in_stride]\n" - "csel x27, x27, %x[pad_row], GT\n" - "sub %x[height], %x[height], #0x2\n" - "mov x20, %x[width]\n" - "7:" // Tail row loop: Column loop - "whilelt p0.h, XZR, x20\n" - "decw x20, ALL, MUL #2\n" - "ld1h { z18.h }, p0/Z, [x9]\n" - "cmp x20, #0x0\n" - "addvl x9, x9, #1\n" - "ld1h { z16.h }, p0/Z, [x27]\n" - "addvl x27, x27, #1\n" - "zip1 z17.h, z18.h, z16.h\n" - "zip2 z16.h, z18.h, z16.h\n" - "st1h { z17.h }, p1, [x28]\n" - "st1h { z16.h }, p1, [x28, #1, MUL VL]\n" - "add x28, x28, %x[out_stride]\n" - "bgt 7b\n" - "cmp %x[height], #0x1\n" - "addvl %x[out], %x[out], #2\n" - "bge 6b\n" - "9:" // Done - ".inst 0xd503467f // SMSTOP\n" - : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) - : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", "z10", "z11", - "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", - "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + KAI_ASSERT(kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme() <= MAX_N_STEP); + static const uint16_t pad_row[MAX_N_STEP] = {0}; + + KernelArgs args; + args.bias_ptr = bias; + args.height = k; + args.width = n; + args.in = rhs; + args.out = rhs_packed; + args.in_stride = rhs_stride_row; + args.out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(args.height); + args.pad_row = pad_row; + + kai_kernel_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h index 8eacaed500a4d36d7ad1d53e8469d07cf0f1f194..833d17972f38d1eff09d04c81a598ccd5fd8deff 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -33,11 +33,11 @@ size_t kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx); /// @return The offset in bytes to the data element. size_t kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx); -/// Get the row stride in bytes to the packed RHS matrix +/// Gets row stride in bytes of the packed RHS matrix. /// -/// @param[in] k Number of rows in unpacked RHS +/// @param[in] k Number of columns of the unpacked RHS matrix. /// -/// @return Row stride in bytes +/// @return Row stride in bytes. size_t kai_get_rhs_packed_stride_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t k); /// Gets the offset in bytes to the data element in the packed RHS buffer. @@ -67,19 +67,19 @@ size_t kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n, siz /// /// @param[in] num_groups Number of groups. It must be 1. /// @param[in] n Number of columns of the output matrix. -/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] k Number of rows. /// @param[in] nr Block size in N dimension. It must be `get_n_step` /// @param[in] kr Block size in K dimension. It must be 2. /// @param[in] sr Number of kr splits. It must be 1. -/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs_stride_row Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. /// @param[in] scale Scale data buffer. It must be NULL. /// @param[out] rhs_packed Packed RHS matrix. /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. -/// @param[in] params Extra packing parameters. It must be NULL. +/// @param[in] params Packing parameters. It must be NULL. void kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride_row, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); #ifdef __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..07fc40ab0e1cd20272f2a25d1bdbdc0f92f7c345 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S @@ -0,0 +1,178 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(rhs_pack_kxn_x16p2vlx2b_x16_x16_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + ldr x16, [x0, #0x8] + ptrue p1.b + ldr x15, [x0, #0x30] + ldr x23, [x0, #0x0] + ldr x22, [x0, #0x10] + mov x21, x16 + ldr x14, [x0, #0x18] + mov x20, x15 + ldr x13, [x0, #0x20] + ldr x12, [x0, #0x28] + ldr x11, [x0, #0x38] +KAI_ASM_LABEL(label_1) // Bias: Full loop + whilelt p0.h, XZR, x21 + dech x21 + cmp x21, #0x0 + ld1h { z16.h }, p0/Z, [x23] + incb x23 + st1h { z16.h }, p1, [x20] + add x20, x20, x13 + bgt label_1 + mov x10, x22 + incb x15 + cmp x10, #0x8 + blt label_5 +KAI_ASM_LABEL(label_2) // Main row loop: Head + mov x9, x12 + mov x28, x15 + add x27, x9, x14 + sub x10, x10, #0x8 + add x26, x27, x14 + mov x25, x16 + add x24, x26, x14 + add x23, x24, x14 + add x22, x23, x14 + add x21, x22, x14 + add x20, x21, x14 + add x12, x20, x14 +KAI_ASM_LABEL(label_3) // Main row loop: Column loop + whilelt p0.h, XZR, x25 + decw x25, ALL, MUL #2 + ld1h { z20.h }, p0/Z, [x9] + cmp x25, #0x0 + addvl x9, x9, #1 + ld1h { z17.h }, p0/Z, [x27] + addvl x27, x27, #1 + ld1h { z19.h }, p0/Z, [x26] + addvl x26, x26, #1 + ld1h { z16.h }, p0/Z, [x24] + addvl x24, x24, #1 + ld1h { z18.h }, p0/Z, [x23] + addvl x23, x23, #1 + zip1 z24.h, z20.h, z17.h + zip2 z23.h, z20.h, z17.h + ld1h { z17.h }, p0/Z, [x22] + addvl x22, x22, #1 + ld1h { z22.h }, p0/Z, [x21] + addvl x21, x21, #1 + zip1 z21.h, z19.h, z16.h + zip2 z20.h, z19.h, z16.h + ld1h { z16.h }, p0/Z, [x20] + addvl x20, x20, #1 + zip1 z19.h, z18.h, z17.h + zip2 z18.h, z18.h, z17.h + st1h { z24.h }, p1, [x28] + st1h { z23.h }, p1, [x28, #1, MUL VL] + zip1 z17.h, z22.h, z16.h + zip2 z16.h, z22.h, z16.h + st1h { z21.h }, p1, [x28, #2, MUL VL] + st1h { z20.h }, p1, [x28, #3, MUL VL] + st1h { z19.h }, p1, [x28, #4, MUL VL] + st1h { z18.h }, p1, [x28, #5, MUL VL] + st1h { z17.h }, p1, [x28, #6, MUL VL] + st1h { z16.h }, p1, [x28, #7, MUL VL] + add x28, x28, x13 + bgt label_3 + cmp x10, #0x8 + addvl x15, x15, #8 + bge label_2 + cbz x10, label_9 +KAI_ASM_LABEL(label_5) // Main loop skip +KAI_ASM_LABEL(label_6) // Tail row loop: Head + mov x9, x12 + cntw x22, ALL, MUL #4 + add x27, x9, x14 + cmp x10, #0x1 + add x12, x27, x14 + mov x28, x15 + csel x12, x12, x27, GT + csel x27, x27, x11, GT + csel x21, x22, XZR, GT + sub x10, x10, #0x2 + mov x20, x16 +KAI_ASM_LABEL(label_7) // Tail row loop: Column loop + whilelt p0.h, XZR, x20 + decw x20, ALL, MUL #2 + ld1h { z18.h }, p0/Z, [x9] + cmp x20, #0x0 + add x9, x9, x22 + ld1h { z16.h }, p0/Z, [x27] + add x27, x27, x21 + zip1 z17.h, z18.h, z16.h + zip2 z16.h, z18.h, z16.h + st1h { z17.h }, p1, [x28] + st1h { z16.h }, p1, [x28, #1, MUL VL] + add x28, x28, x13 + bgt label_7 + cmp x10, #0x1 + addvl x15, x15, #2 + bge label_6 +KAI_ASM_LABEL(label_9) // Done + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme) + + KAI_ASM_END