From c8bffb33a15051827ba4c523cccaaf79afb1b57c Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Tue, 8 Jul 2025 15:56:08 +0100 Subject: [PATCH 1/5] Add SME1 F32 GEMV kernel This SME1 GEMV kernel computes a 1x8VL block and is designed to work with the same RHS packing function as the SME1 GEMM. Signed-off-by: Jakub Sujak --- CHANGELOG.md | 2 + CMakeLists.txt | 2 + kai/ukernels/matmul/BUILD.bazel | 1 + ...l_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.c | 110 ++ ...l_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h | 114 ++ ...amp_f32_f32_f32p2vlx1b_1x8vl_sme_mla_asm.S | 1246 +++++++++++++++++ test/tests/matmul_clamp_f32_f32_f32p_test.cpp | 21 +- 7 files changed, 1491 insertions(+), 5 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla_asm.S diff --git a/CHANGELOG.md b/CHANGELOG.md index a9cc0096..b78e76f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - New Advanced SIMD micro-kernels: - Matrix multiplication (MxN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with BF16 output, optimized for FEAT_I8MM. - Matrix multiplication (1xN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with BF16 output, optimized for FEAT_DotProd. +- New SME micro-kernels: + - SME1 compatible matrix multiplication (1xN) of F32 LHS and RHS with F32 output. ## v1.11.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index a74e779a..432e770e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -243,6 +243,8 @@ set(KLEIDIAI_FILES_NEON_I8MM set(KLEIDIAI_FILES_SME_ASM kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa_asm.S + kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.c + kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla_asm.S kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 5535a57e..4f83da44 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -156,6 +156,7 @@ SME_KERNELS = [ # buildifier: keep sorted SME_KERNELS_ASM = [ "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa", + "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla", "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa", "pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", "pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme", diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.c b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.c new file mode 100644 index 00000000..fafa0f95 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.c @@ -0,0 +1,110 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h" + +#include +#include + +#include "kai/kai_common.h" + +typedef struct { + float maxval; + float minval; + const void* A_ptr; + const void* B_ptr; + size_t N; + size_t K; + void* output_ptr; + uint64_t flags; +} KernelArgs; + +static const size_t kai_m_step = 1; +static const size_t kai_nr = 2; +static const size_t kai_n_step = 8; +static const size_t kai_kr = 1; +static const size_t kai_sr = 1; + +void kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(KernelArgs* args_ptr); + +size_t kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(void) { + return kai_n_step * kai_get_sme_vector_length_u32() / kai_kr; +} + +size_t kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(void) { + return kai_nr * kai_get_sme_vector_length_u32() / kai_kr; +} + +size_t kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(void) { + return kai_sr; +} + +size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx == 0); + + return m_idx * k; +} + +static size_t kai_get_rhs_packed_stride_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(size_t k) { + return kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla() * + (kai_roundup(k, kai_kr) * sizeof(float) + sizeof(float)); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla() == 0); + + const size_t block_idx = n_idx / kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(); + return block_idx * kai_get_rhs_packed_stride_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME(m_idx == 0); + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla() == 0); + + return (m_idx * dst_stride) + (n_idx * sizeof(float)); +} + +size_t kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, + size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) { + KAI_UNUSED(dst_stride_row); + KAI_UNUSED(dst_stride_col); + KAI_UNUSED(lhs_stride); + KAI_ASSUME(m == 1); + + uint64_t flags = 2; + + KernelArgs args; + + args.maxval = clamp_max; + args.minval = clamp_min; + args.A_ptr = lhs; + args.B_ptr = rhs_packed; + args.N = n; + args.K = k; + args.output_ptr = dst; + args.flags = flags; + + kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h new file mode 100644 index 00000000..6a2123eb --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h @@ -0,0 +1,114 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme or kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme to pack the RHS + +/// -------------------------------------------------- + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(void); + +/// Gets kr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(void); + +/// Gets the offset in bytes to the data element in the LHS matrix buffer. +/// +/// @param[in] m_idx Row index. This must be 0. +/// @param[in] k Columns of unpacked LHS. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(size_t m_idx, size_t k); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of n_step +/// @param[in] k Number of rows in the unpacked RHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. Must be 0 +/// @param[in] n_idx Column index. Must be multiple of n_step +/// @param[in] dst_stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffers (LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla. +/// +/// @param[in] m Number of output rows to be computed. This must be 1. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Common dimension of the LHS and RHS operand. +/// @param[in] lhs LHS matrix buffer. +/// @param[in] rhs_packed Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. Currently, an unused parameter. +/// @param[in] dst_stride_col Column stride in bytes of the output matrix. Currently, an unused parameter. +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. +void kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, + size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla_asm.S new file mode 100644 index 00000000..2f844acb --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla_asm.S @@ -0,0 +1,1246 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + ldr x14, [x0, #0x20] + cntw x13 + cntw x20, ALL, MUL #2 + ldr x12, [x0, #0x18] + ptrue p2.b + ldr x11, [x0, #0x8] + mov x10, x14 + ldr x9, [x0, #0x10] + add x28, x12, x13 + lsl x10, x10, #0x2 + ldr x27, [x0, #0x28] + sub x28, x28, #0x1 + add x10, x10, #0x4 + ldr x26, [x0, #0x30] + udiv x28, x28, x13 + mul x10, x10, x20 +KAI_ASM_LABEL(label_1) // Column loop + cmp x28, #0x8 + bge label_36 + cmp x28, #0x6 + bgt label_31 + beq label_26 + cmp x28, #0x4 + bgt label_21 + beq label_16 + cmp x28, #0x2 + bgt label_11 + beq label_6 + mov x25, x14 + whilelt p1.s, XZR, x12 + ld1w { z24.s }, p2/Z, [x9] + cmp x25, #0x4 + mov x24, x11 + addvl x9, x9, #2 + ble label_3 +KAI_ASM_LABEL(label_2) // Width 1: Multiply loop: Main loop head + whilelt p0.s, XZR, x25 + ldnt1w { z1.s }, p2/Z, [x9] + addvl x9, x9, #2 + ld1rqw { z0.s }, p0/Z, [x24] + sub x25, x25, #0x4 + add x24, x24, #0x10 + ldnt1w { z2.s }, p2/Z, [x9] + addvl x9, x9, #2 + cmp x25, #0x4 + ldnt1w { z3.s }, p2/Z, [x9] + addvl x9, x9, #2 + fmla z24.s, z1.s, z0.s[0] + ldnt1w { z4.s }, p2/Z, [x9] + addvl x9, x9, #2 + fmla z24.s, z2.s, z0.s[1] + fmla z24.s, z3.s, z0.s[2] + fmla z24.s, z4.s, z0.s[3] + bgt label_2 +KAI_ASM_LABEL(label_3) // Width 1: Multiply loop: Single iteration only + whilelt p0.s, XZR, x25 + ldnt1w { z5.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ld1rqw { z0.s }, p0/Z, [x24] + addvl x9, x9, #2 + fmla z24.s, z5.s, z0.s[0] + ble label_4 + ldnt1w { z6.s }, p2/Z, [x9] + subs x25, x25, #0x1 + addvl x9, x9, #2 + fmla z24.s, z6.s, z0.s[1] + ble label_4 + ldnt1w { z7.s }, p2/Z, [x9] + subs x25, x25, #0x1 + addvl x9, x9, #2 + fmla z24.s, z7.s, z0.s[2] + ble label_4 + ldnt1w { z8.s }, p2/Z, [x9] + fmla z24.s, z8.s, z0.s[3] +KAI_ASM_LABEL(label_4) // Width 1: Multiply loop: multiply skip + tbz x26, #1, label_5 + add x21, x0, #0x0 + add x20, x0, #0x4 + KAI_ASM_INST(0x8540cab1) // ld1rw { z17.s }, p2/Z, [x21] + KAI_ASM_INST(0x8540ca90) // ld1rw { z16.s }, p2/Z, [x20] + fmin z24.s, p2/M, z24.s, z17.s + fmax z24.s, p2/M, z24.s, z16.s +KAI_ASM_LABEL(label_5) // Width 1: No activation + st1w { z24.s }, p1, [x27] + b label_41 +KAI_ASM_LABEL(label_6) // Width 2 + mov x25, x14 + sub x20, x12, x13 + ld1w { z24.s }, p2/Z, [x9] + whilelt p1.s, XZR, x20 + cmp x25, #0x4 + ld1w { z25.s }, p2/Z, [x9, #1, MUL VL] + mov x24, x11 + addvl x9, x9, #2 + ble label_8 +KAI_ASM_LABEL(label_7) // Width 2: Multiply loop: Main loop head + whilelt p0.s, XZR, x25 + ldnt1w { z1.s }, p2/Z, [x9] + sub x25, x25, #0x4 + ld1rqw { z0.s }, p0/Z, [x24] + cmp x25, #0x4 + add x24, x24, #0x10 + ldnt1w { z2.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z3.s }, p2/Z, [x9] + fmla z24.s, z1.s, z0.s[0] + ldnt1w { z4.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z25.s, z2.s, z0.s[0] + ldnt1w { z5.s }, p2/Z, [x9] + ldnt1w { z6.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z7.s }, p2/Z, [x9] + fmla z24.s, z3.s, z0.s[1] + ldnt1w { z8.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z25.s, z4.s, z0.s[1] + fmla z24.s, z5.s, z0.s[2] + fmla z25.s, z6.s, z0.s[2] + fmla z24.s, z7.s, z0.s[3] + fmla z25.s, z8.s, z0.s[3] + bgt label_7 +KAI_ASM_LABEL(label_8) // Width 2: Multiply loop: Single iteration only + whilelt p0.s, XZR, x25 + ldnt1w { z9.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ld1rqw { z0.s }, p0/Z, [x24] + ldnt1w { z10.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z24.s, z9.s, z0.s[0] + fmla z25.s, z10.s, z0.s[0] + ble label_9 + ldnt1w { z11.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z12.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z24.s, z11.s, z0.s[1] + fmla z25.s, z12.s, z0.s[1] + ble label_9 + ldnt1w { z13.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z14.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z24.s, z13.s, z0.s[2] + fmla z25.s, z14.s, z0.s[2] + ble label_9 + ldnt1w { z15.s }, p2/Z, [x9] + ldnt1w { z16.s }, p2/Z, [x9, #1, MUL VL] + fmla z24.s, z15.s, z0.s[3] + fmla z25.s, z16.s, z0.s[3] +KAI_ASM_LABEL(label_9) // Width 2: Multiply loop: multiply skip + tbz x26, #1, label_10 + add x21, x0, #0x0 + add x20, x0, #0x4 + KAI_ASM_INST(0x8540cab1) // ld1rw { z17.s }, p2/Z, [x21] + KAI_ASM_INST(0x8540ca90) // ld1rw { z16.s }, p2/Z, [x20] + fmin z24.s, p2/M, z24.s, z17.s + fmin z25.s, p2/M, z25.s, z17.s + fmax z24.s, p2/M, z24.s, z16.s + fmax z25.s, p2/M, z25.s, z16.s +KAI_ASM_LABEL(label_10) // Width 2: No activation + st1w { z24.s }, p2, [x27] + st1w { z25.s }, p1, [x27, #1, MUL VL] + b label_41 +KAI_ASM_LABEL(label_11) // Width 3 + mov x20, #0x2 + mov x25, x14 + ld1w { z24.s }, p2/Z, [x9] + msub x21, x13, x20, x12 + add x20, x9, x10 + ld1w { z25.s }, p2/Z, [x9, #1, MUL VL] + whilelt p1.s, XZR, x21 + cmp x25, #0x4 + ld1w { z26.s }, p2/Z, [x20] + mov x24, x11 + addvl x9, x9, #2 + addvl x20, x20, #2 + ble label_13 +KAI_ASM_LABEL(label_12) // Width 3: Multiply loop: Main loop head + whilelt p0.s, XZR, x25 + ldnt1w { z1.s }, p2/Z, [x9] + sub x25, x25, #0x4 + ld1rqw { z0.s }, p0/Z, [x24] + cmp x25, #0x4 + add x24, x24, #0x10 + ldnt1w { z2.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z3.s }, p2/Z, [x20] + addvl x20, x20, #2 + fmla z24.s, z1.s, z0.s[0] + ldnt1w { z4.s }, p2/Z, [x9] + fmla z25.s, z2.s, z0.s[0] + ldnt1w { z5.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z26.s, z3.s, z0.s[0] + ldnt1w { z6.s }, p2/Z, [x20] + addvl x20, x20, #2 + ldnt1w { z7.s }, p2/Z, [x9] + fmla z24.s, z4.s, z0.s[1] + ldnt1w { z8.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z25.s, z5.s, z0.s[1] + ldnt1w { z9.s }, p2/Z, [x20] + addvl x20, x20, #2 + fmla z26.s, z6.s, z0.s[1] + ldnt1w { z10.s }, p2/Z, [x9] + ldnt1w { z11.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z24.s, z7.s, z0.s[2] + ldnt1w { z12.s }, p2/Z, [x20] + addvl x20, x20, #2 + fmla z25.s, z8.s, z0.s[2] + fmla z26.s, z9.s, z0.s[2] + fmla z24.s, z10.s, z0.s[3] + fmla z25.s, z11.s, z0.s[3] + fmla z26.s, z12.s, z0.s[3] + bgt label_12 +KAI_ASM_LABEL(label_13) // Width 3: Multiply loop: Single iteration only + whilelt p0.s, XZR, x25 + ldnt1w { z13.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ld1rqw { z0.s }, p0/Z, [x24] + ldnt1w { z14.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z15.s }, p2/Z, [x20] + addvl x20, x20, #2 + fmla z24.s, z13.s, z0.s[0] + fmla z25.s, z14.s, z0.s[0] + fmla z26.s, z15.s, z0.s[0] + ble label_14 + ldnt1w { z16.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z17.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z18.s }, p2/Z, [x20] + addvl x20, x20, #2 + fmla z24.s, z16.s, z0.s[1] + fmla z25.s, z17.s, z0.s[1] + fmla z26.s, z18.s, z0.s[1] + ble label_14 + ldnt1w { z19.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z20.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z21.s }, p2/Z, [x20] + addvl x20, x20, #2 + fmla z24.s, z19.s, z0.s[2] + fmla z25.s, z20.s, z0.s[2] + fmla z26.s, z21.s, z0.s[2] + ble label_14 + ldnt1w { z22.s }, p2/Z, [x9] + ldnt1w { z23.s }, p2/Z, [x9, #1, MUL VL] + ldnt1w { z1.s }, p2/Z, [x20] + fmla z24.s, z22.s, z0.s[3] + fmla z25.s, z23.s, z0.s[3] + fmla z26.s, z1.s, z0.s[3] +KAI_ASM_LABEL(label_14) // Width 3: Multiply loop: multiply skip + tbz x26, #1, label_15 + add x21, x0, #0x0 + add x20, x0, #0x4 + KAI_ASM_INST(0x8540cab1) // ld1rw { z17.s }, p2/Z, [x21] + KAI_ASM_INST(0x8540ca90) // ld1rw { z16.s }, p2/Z, [x20] + fmin z24.s, p2/M, z24.s, z17.s + fmin z25.s, p2/M, z25.s, z17.s + fmin z26.s, p2/M, z26.s, z17.s + fmax z24.s, p2/M, z24.s, z16.s + fmax z25.s, p2/M, z25.s, z16.s + fmax z26.s, p2/M, z26.s, z16.s +KAI_ASM_LABEL(label_15) // Width 3: No activation + st1w { z24.s }, p2, [x27] + st1w { z25.s }, p2, [x27, #1, MUL VL] + st1w { z26.s }, p1, [x27, #2, MUL VL] + b label_41 +KAI_ASM_LABEL(label_16) // Width 4 + mov x20, #0x3 + mov x25, x14 + ld1w { z24.s }, p2/Z, [x9] + msub x21, x13, x20, x12 + add x20, x9, x10 + ld1w { z25.s }, p2/Z, [x9, #1, MUL VL] + whilelt p1.s, XZR, x21 + cmp x25, #0x4 + ld1w { z26.s }, p2/Z, [x20] + mov x24, x11 + ld1w { z27.s }, p2/Z, [x20, #1, MUL VL] + addvl x9, x9, #2 + addvl x20, x20, #2 + ble label_18 +KAI_ASM_LABEL(label_17) // Width 4: Multiply loop: Main loop head + whilelt p0.s, XZR, x25 + ldnt1w { z1.s }, p2/Z, [x9] + sub x25, x25, #0x4 + ld1rqw { z0.s }, p0/Z, [x24] + cmp x25, #0x4 + add x24, x24, #0x10 + ldnt1w { z2.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z3.s }, p2/Z, [x20] + ldnt1w { z4.s }, p2/Z, [x20, #1, MUL VL] + fmla z24.s, z1.s, z0.s[0] + addvl x20, x20, #2 + fmla z25.s, z2.s, z0.s[0] + ldnt1w { z5.s }, p2/Z, [x9] + fmla z26.s, z3.s, z0.s[0] + ldnt1w { z6.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z27.s, z4.s, z0.s[0] + ldnt1w { z7.s }, p2/Z, [x20] + ldnt1w { z8.s }, p2/Z, [x20, #1, MUL VL] + fmla z24.s, z5.s, z0.s[1] + addvl x20, x20, #2 + fmla z25.s, z6.s, z0.s[1] + ldnt1w { z9.s }, p2/Z, [x9] + fmla z26.s, z7.s, z0.s[1] + ldnt1w { z10.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z27.s, z8.s, z0.s[1] + ldnt1w { z11.s }, p2/Z, [x20] + ldnt1w { z12.s }, p2/Z, [x20, #1, MUL VL] + fmla z24.s, z9.s, z0.s[2] + addvl x20, x20, #2 + fmla z25.s, z10.s, z0.s[2] + ldnt1w { z13.s }, p2/Z, [x9] + fmla z26.s, z11.s, z0.s[2] + ldnt1w { z14.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z27.s, z12.s, z0.s[2] + ldnt1w { z15.s }, p2/Z, [x20] + ldnt1w { z16.s }, p2/Z, [x20, #1, MUL VL] + fmla z24.s, z13.s, z0.s[3] + addvl x20, x20, #2 + fmla z25.s, z14.s, z0.s[3] + fmla z26.s, z15.s, z0.s[3] + fmla z27.s, z16.s, z0.s[3] + bgt label_17 +KAI_ASM_LABEL(label_18) // Width 4: Multiply loop: Single iteration only + whilelt p0.s, XZR, x25 + ldnt1w { z17.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ld1rqw { z0.s }, p0/Z, [x24] + ldnt1w { z18.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z19.s }, p2/Z, [x20] + ldnt1w { z20.s }, p2/Z, [x20, #1, MUL VL] + fmla z24.s, z17.s, z0.s[0] + addvl x20, x20, #2 + fmla z25.s, z18.s, z0.s[0] + fmla z26.s, z19.s, z0.s[0] + fmla z27.s, z20.s, z0.s[0] + ble label_19 + ldnt1w { z21.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z22.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z23.s }, p2/Z, [x20] + ldnt1w { z1.s }, p2/Z, [x20, #1, MUL VL] + fmla z24.s, z21.s, z0.s[1] + addvl x20, x20, #2 + fmla z25.s, z22.s, z0.s[1] + fmla z26.s, z23.s, z0.s[1] + fmla z27.s, z1.s, z0.s[1] + ble label_19 + ldnt1w { z2.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z3.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z4.s }, p2/Z, [x20] + ldnt1w { z5.s }, p2/Z, [x20, #1, MUL VL] + fmla z24.s, z2.s, z0.s[2] + addvl x20, x20, #2 + fmla z25.s, z3.s, z0.s[2] + fmla z26.s, z4.s, z0.s[2] + fmla z27.s, z5.s, z0.s[2] + ble label_19 + ldnt1w { z6.s }, p2/Z, [x9] + ldnt1w { z7.s }, p2/Z, [x9, #1, MUL VL] + ldnt1w { z8.s }, p2/Z, [x20] + ldnt1w { z9.s }, p2/Z, [x20, #1, MUL VL] + fmla z24.s, z6.s, z0.s[3] + fmla z25.s, z7.s, z0.s[3] + fmla z26.s, z8.s, z0.s[3] + fmla z27.s, z9.s, z0.s[3] +KAI_ASM_LABEL(label_19) // Width 4: Multiply loop: multiply skip + tbz x26, #1, label_20 + add x21, x0, #0x0 + add x20, x0, #0x4 + KAI_ASM_INST(0x8540cab1) // ld1rw { z17.s }, p2/Z, [x21] + KAI_ASM_INST(0x8540ca90) // ld1rw { z16.s }, p2/Z, [x20] + fmin z24.s, p2/M, z24.s, z17.s + fmin z25.s, p2/M, z25.s, z17.s + fmin z26.s, p2/M, z26.s, z17.s + fmin z27.s, p2/M, z27.s, z17.s + fmax z24.s, p2/M, z24.s, z16.s + fmax z25.s, p2/M, z25.s, z16.s + fmax z26.s, p2/M, z26.s, z16.s + fmax z27.s, p2/M, z27.s, z16.s +KAI_ASM_LABEL(label_20) // Width 4: No activation + st1w { z24.s }, p2, [x27] + st1w { z25.s }, p2, [x27, #1, MUL VL] + st1w { z26.s }, p2, [x27, #2, MUL VL] + st1w { z27.s }, p1, [x27, #3, MUL VL] + b label_41 +KAI_ASM_LABEL(label_21) // Width 5 + mov x20, #0x4 + mov x25, x14 + ld1w { z24.s }, p2/Z, [x9] + msub x22, x13, x20, x12 + add x21, x9, x10 + ld1w { z25.s }, p2/Z, [x9, #1, MUL VL] + add x20, x9, x10, LSL #1 + whilelt p1.s, XZR, x22 + ld1w { z26.s }, p2/Z, [x21] + cmp x25, #0x4 + mov x24, x11 + ld1w { z27.s }, p2/Z, [x21, #1, MUL VL] + ld1w { z28.s }, p2/Z, [x20] + addvl x9, x9, #2 + addvl x21, x21, #2 + addvl x20, x20, #2 + ble label_23 +KAI_ASM_LABEL(label_22) // Width 5: Multiply loop: Main loop head + whilelt p0.s, XZR, x25 + ldnt1w { z1.s }, p2/Z, [x9] + sub x25, x25, #0x4 + ld1rqw { z0.s }, p0/Z, [x24] + cmp x25, #0x4 + add x24, x24, #0x10 + ldnt1w { z2.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z3.s }, p2/Z, [x21] + ldnt1w { z4.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z1.s, z0.s[0] + addvl x21, x21, #2 + ldnt1w { z5.s }, p2/Z, [x20] + fmla z25.s, z2.s, z0.s[0] + addvl x20, x20, #2 + fmla z26.s, z3.s, z0.s[0] + ldnt1w { z6.s }, p2/Z, [x9] + fmla z27.s, z4.s, z0.s[0] + ldnt1w { z7.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z28.s, z5.s, z0.s[0] + ldnt1w { z8.s }, p2/Z, [x21] + ldnt1w { z9.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z6.s, z0.s[1] + addvl x21, x21, #2 + ldnt1w { z10.s }, p2/Z, [x20] + fmla z25.s, z7.s, z0.s[1] + addvl x20, x20, #2 + fmla z26.s, z8.s, z0.s[1] + ldnt1w { z11.s }, p2/Z, [x9] + fmla z27.s, z9.s, z0.s[1] + ldnt1w { z12.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z28.s, z10.s, z0.s[1] + ldnt1w { z13.s }, p2/Z, [x21] + ldnt1w { z14.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z11.s, z0.s[2] + addvl x21, x21, #2 + ldnt1w { z15.s }, p2/Z, [x20] + fmla z25.s, z12.s, z0.s[2] + addvl x20, x20, #2 + fmla z26.s, z13.s, z0.s[2] + ldnt1w { z16.s }, p2/Z, [x9] + fmla z27.s, z14.s, z0.s[2] + ldnt1w { z17.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z28.s, z15.s, z0.s[2] + ldnt1w { z18.s }, p2/Z, [x21] + ldnt1w { z19.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z16.s, z0.s[3] + addvl x21, x21, #2 + ldnt1w { z20.s }, p2/Z, [x20] + fmla z25.s, z17.s, z0.s[3] + addvl x20, x20, #2 + fmla z26.s, z18.s, z0.s[3] + fmla z27.s, z19.s, z0.s[3] + fmla z28.s, z20.s, z0.s[3] + bgt label_22 +KAI_ASM_LABEL(label_23) // Width 5: Multiply loop: Single iteration only + whilelt p0.s, XZR, x25 + ldnt1w { z21.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ld1rqw { z0.s }, p0/Z, [x24] + ldnt1w { z22.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z23.s }, p2/Z, [x21] + ldnt1w { z1.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z21.s, z0.s[0] + addvl x21, x21, #2 + ldnt1w { z2.s }, p2/Z, [x20] + fmla z25.s, z22.s, z0.s[0] + addvl x20, x20, #2 + fmla z26.s, z23.s, z0.s[0] + fmla z27.s, z1.s, z0.s[0] + fmla z28.s, z2.s, z0.s[0] + ble label_24 + ldnt1w { z3.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z4.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z5.s }, p2/Z, [x21] + ldnt1w { z6.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z3.s, z0.s[1] + addvl x21, x21, #2 + ldnt1w { z7.s }, p2/Z, [x20] + fmla z25.s, z4.s, z0.s[1] + addvl x20, x20, #2 + fmla z26.s, z5.s, z0.s[1] + fmla z27.s, z6.s, z0.s[1] + fmla z28.s, z7.s, z0.s[1] + ble label_24 + ldnt1w { z8.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z9.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z10.s }, p2/Z, [x21] + ldnt1w { z11.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z8.s, z0.s[2] + addvl x21, x21, #2 + ldnt1w { z12.s }, p2/Z, [x20] + fmla z25.s, z9.s, z0.s[2] + addvl x20, x20, #2 + fmla z26.s, z10.s, z0.s[2] + fmla z27.s, z11.s, z0.s[2] + fmla z28.s, z12.s, z0.s[2] + ble label_24 + ldnt1w { z13.s }, p2/Z, [x9] + ldnt1w { z14.s }, p2/Z, [x9, #1, MUL VL] + ldnt1w { z15.s }, p2/Z, [x21] + ldnt1w { z16.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z13.s, z0.s[3] + ldnt1w { z17.s }, p2/Z, [x20] + fmla z25.s, z14.s, z0.s[3] + fmla z26.s, z15.s, z0.s[3] + fmla z27.s, z16.s, z0.s[3] + fmla z28.s, z17.s, z0.s[3] +KAI_ASM_LABEL(label_24) // Width 5: Multiply loop: multiply skip + tbz x26, #1, label_25 + add x21, x0, #0x0 + add x20, x0, #0x4 + KAI_ASM_INST(0x8540cab1) // ld1rw { z17.s }, p2/Z, [x21] + KAI_ASM_INST(0x8540ca90) // ld1rw { z16.s }, p2/Z, [x20] + fmin z24.s, p2/M, z24.s, z17.s + fmin z25.s, p2/M, z25.s, z17.s + fmin z26.s, p2/M, z26.s, z17.s + fmin z27.s, p2/M, z27.s, z17.s + fmin z28.s, p2/M, z28.s, z17.s + fmax z24.s, p2/M, z24.s, z16.s + fmax z25.s, p2/M, z25.s, z16.s + fmax z26.s, p2/M, z26.s, z16.s + fmax z27.s, p2/M, z27.s, z16.s + fmax z28.s, p2/M, z28.s, z16.s +KAI_ASM_LABEL(label_25) // Width 5: No activation + st1w { z24.s }, p2, [x27] + st1w { z25.s }, p2, [x27, #1, MUL VL] + st1w { z26.s }, p2, [x27, #2, MUL VL] + st1w { z27.s }, p2, [x27, #3, MUL VL] + st1w { z28.s }, p1, [x27, #4, MUL VL] + b label_41 +KAI_ASM_LABEL(label_26) // Width 6 + mov x20, #0x5 + mov x25, x14 + ld1w { z24.s }, p2/Z, [x9] + msub x22, x13, x20, x12 + add x21, x9, x10 + ld1w { z25.s }, p2/Z, [x9, #1, MUL VL] + add x20, x9, x10, LSL #1 + whilelt p1.s, XZR, x22 + ld1w { z26.s }, p2/Z, [x21] + cmp x25, #0x4 + mov x24, x11 + ld1w { z27.s }, p2/Z, [x21, #1, MUL VL] + ld1w { z28.s }, p2/Z, [x20] + addvl x9, x9, #2 + addvl x21, x21, #2 + ld1w { z29.s }, p2/Z, [x20, #1, MUL VL] + addvl x20, x20, #2 + ble label_28 +KAI_ASM_LABEL(label_27) // Width 6: Multiply loop: Main loop head + whilelt p0.s, XZR, x25 + ldnt1w { z1.s }, p2/Z, [x9] + sub x25, x25, #0x4 + ld1rqw { z0.s }, p0/Z, [x24] + cmp x25, #0x4 + add x24, x24, #0x10 + ldnt1w { z2.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z3.s }, p2/Z, [x21] + ldnt1w { z4.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z1.s, z0.s[0] + addvl x21, x21, #2 + ldnt1w { z5.s }, p2/Z, [x20] + fmla z25.s, z2.s, z0.s[0] + ldnt1w { z6.s }, p2/Z, [x20, #1, MUL VL] + fmla z26.s, z3.s, z0.s[0] + addvl x20, x20, #2 + fmla z27.s, z4.s, z0.s[0] + ldnt1w { z7.s }, p2/Z, [x9] + fmla z28.s, z5.s, z0.s[0] + ldnt1w { z8.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z29.s, z6.s, z0.s[0] + ldnt1w { z9.s }, p2/Z, [x21] + ldnt1w { z10.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z7.s, z0.s[1] + addvl x21, x21, #2 + ldnt1w { z11.s }, p2/Z, [x20] + fmla z25.s, z8.s, z0.s[1] + ldnt1w { z12.s }, p2/Z, [x20, #1, MUL VL] + fmla z26.s, z9.s, z0.s[1] + addvl x20, x20, #2 + fmla z27.s, z10.s, z0.s[1] + ldnt1w { z13.s }, p2/Z, [x9] + fmla z28.s, z11.s, z0.s[1] + ldnt1w { z14.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z29.s, z12.s, z0.s[1] + ldnt1w { z15.s }, p2/Z, [x21] + ldnt1w { z16.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z13.s, z0.s[2] + addvl x21, x21, #2 + ldnt1w { z17.s }, p2/Z, [x20] + fmla z25.s, z14.s, z0.s[2] + ldnt1w { z18.s }, p2/Z, [x20, #1, MUL VL] + fmla z26.s, z15.s, z0.s[2] + addvl x20, x20, #2 + fmla z27.s, z16.s, z0.s[2] + ldnt1w { z19.s }, p2/Z, [x9] + fmla z28.s, z17.s, z0.s[2] + ldnt1w { z20.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z29.s, z18.s, z0.s[2] + ldnt1w { z21.s }, p2/Z, [x21] + ldnt1w { z22.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z19.s, z0.s[3] + addvl x21, x21, #2 + ldnt1w { z23.s }, p2/Z, [x20] + fmla z25.s, z20.s, z0.s[3] + ldnt1w { z1.s }, p2/Z, [x20, #1, MUL VL] + fmla z26.s, z21.s, z0.s[3] + addvl x20, x20, #2 + fmla z27.s, z22.s, z0.s[3] + fmla z28.s, z23.s, z0.s[3] + fmla z29.s, z1.s, z0.s[3] + bgt label_27 +KAI_ASM_LABEL(label_28) // Width 6: Multiply loop: Single iteration only + whilelt p0.s, XZR, x25 + ldnt1w { z2.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ld1rqw { z0.s }, p0/Z, [x24] + ldnt1w { z3.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z4.s }, p2/Z, [x21] + ldnt1w { z5.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z2.s, z0.s[0] + addvl x21, x21, #2 + ldnt1w { z6.s }, p2/Z, [x20] + fmla z25.s, z3.s, z0.s[0] + ldnt1w { z7.s }, p2/Z, [x20, #1, MUL VL] + fmla z26.s, z4.s, z0.s[0] + addvl x20, x20, #2 + fmla z27.s, z5.s, z0.s[0] + fmla z28.s, z6.s, z0.s[0] + fmla z29.s, z7.s, z0.s[0] + ble label_29 + ldnt1w { z8.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z9.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z10.s }, p2/Z, [x21] + ldnt1w { z11.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z8.s, z0.s[1] + addvl x21, x21, #2 + ldnt1w { z12.s }, p2/Z, [x20] + fmla z25.s, z9.s, z0.s[1] + ldnt1w { z13.s }, p2/Z, [x20, #1, MUL VL] + fmla z26.s, z10.s, z0.s[1] + addvl x20, x20, #2 + fmla z27.s, z11.s, z0.s[1] + fmla z28.s, z12.s, z0.s[1] + fmla z29.s, z13.s, z0.s[1] + ble label_29 + ldnt1w { z14.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z15.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z16.s }, p2/Z, [x21] + ldnt1w { z17.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z14.s, z0.s[2] + addvl x21, x21, #2 + ldnt1w { z18.s }, p2/Z, [x20] + fmla z25.s, z15.s, z0.s[2] + ldnt1w { z19.s }, p2/Z, [x20, #1, MUL VL] + fmla z26.s, z16.s, z0.s[2] + addvl x20, x20, #2 + fmla z27.s, z17.s, z0.s[2] + fmla z28.s, z18.s, z0.s[2] + fmla z29.s, z19.s, z0.s[2] + ble label_29 + ldnt1w { z20.s }, p2/Z, [x9] + ldnt1w { z21.s }, p2/Z, [x9, #1, MUL VL] + ldnt1w { z22.s }, p2/Z, [x21] + ldnt1w { z23.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z20.s, z0.s[3] + ldnt1w { z1.s }, p2/Z, [x20] + fmla z25.s, z21.s, z0.s[3] + ldnt1w { z2.s }, p2/Z, [x20, #1, MUL VL] + fmla z26.s, z22.s, z0.s[3] + fmla z27.s, z23.s, z0.s[3] + fmla z28.s, z1.s, z0.s[3] + fmla z29.s, z2.s, z0.s[3] +KAI_ASM_LABEL(label_29) // Width 6: Multiply loop: multiply skip + tbz x26, #1, label_30 + add x21, x0, #0x0 + add x20, x0, #0x4 + KAI_ASM_INST(0x8540cab1) // ld1rw { z17.s }, p2/Z, [x21] + KAI_ASM_INST(0x8540ca90) // ld1rw { z16.s }, p2/Z, [x20] + fmin z24.s, p2/M, z24.s, z17.s + fmin z25.s, p2/M, z25.s, z17.s + fmin z26.s, p2/M, z26.s, z17.s + fmin z27.s, p2/M, z27.s, z17.s + fmin z28.s, p2/M, z28.s, z17.s + fmin z29.s, p2/M, z29.s, z17.s + fmax z24.s, p2/M, z24.s, z16.s + fmax z25.s, p2/M, z25.s, z16.s + fmax z26.s, p2/M, z26.s, z16.s + fmax z27.s, p2/M, z27.s, z16.s + fmax z28.s, p2/M, z28.s, z16.s + fmax z29.s, p2/M, z29.s, z16.s +KAI_ASM_LABEL(label_30) // Width 6: No activation + st1w { z24.s }, p2, [x27] + st1w { z25.s }, p2, [x27, #1, MUL VL] + st1w { z26.s }, p2, [x27, #2, MUL VL] + st1w { z27.s }, p2, [x27, #3, MUL VL] + st1w { z28.s }, p2, [x27, #4, MUL VL] + st1w { z29.s }, p1, [x27, #5, MUL VL] + b label_41 +KAI_ASM_LABEL(label_31) // Width 7 + mov x20, #0x6 + mov x25, x14 + ld1w { z24.s }, p2/Z, [x9] + add x23, x9, x10, LSL #1 + msub x22, x13, x20, x12 + ld1w { z25.s }, p2/Z, [x9, #1, MUL VL] + add x21, x9, x10 + add x20, x23, x10 + ld1w { z28.s }, p2/Z, [x23] + whilelt p1.s, XZR, x22 + cmp x25, #0x4 + ld1w { z26.s }, p2/Z, [x21] + mov x24, x11 + ld1w { z27.s }, p2/Z, [x21, #1, MUL VL] + addvl x9, x9, #2 + ld1w { z29.s }, p2/Z, [x23, #1, MUL VL] + addvl x21, x21, #2 + addvl x23, x23, #2 + ld1w { z30.s }, p2/Z, [x20] + addvl x20, x20, #2 + ble label_33 +KAI_ASM_LABEL(label_32) // Width 7: Multiply loop: Main loop head + whilelt p0.s, XZR, x25 + ldnt1w { z1.s }, p2/Z, [x9] + sub x25, x25, #0x4 + ld1rqw { z0.s }, p0/Z, [x24] + cmp x25, #0x4 + add x24, x24, #0x10 + ldnt1w { z2.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z3.s }, p2/Z, [x21] + ldnt1w { z4.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z1.s, z0.s[0] + addvl x21, x21, #2 + ldnt1w { z5.s }, p2/Z, [x23] + fmla z25.s, z2.s, z0.s[0] + ldnt1w { z6.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z3.s, z0.s[0] + addvl x23, x23, #2 + ldnt1w { z7.s }, p2/Z, [x20] + fmla z27.s, z4.s, z0.s[0] + addvl x20, x20, #2 + fmla z28.s, z5.s, z0.s[0] + ldnt1w { z8.s }, p2/Z, [x9] + fmla z29.s, z6.s, z0.s[0] + ldnt1w { z9.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z30.s, z7.s, z0.s[0] + ldnt1w { z10.s }, p2/Z, [x21] + ldnt1w { z11.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z8.s, z0.s[1] + addvl x21, x21, #2 + ldnt1w { z12.s }, p2/Z, [x23] + fmla z25.s, z9.s, z0.s[1] + ldnt1w { z13.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z10.s, z0.s[1] + addvl x23, x23, #2 + ldnt1w { z14.s }, p2/Z, [x20] + fmla z27.s, z11.s, z0.s[1] + addvl x20, x20, #2 + fmla z28.s, z12.s, z0.s[1] + ldnt1w { z15.s }, p2/Z, [x9] + fmla z29.s, z13.s, z0.s[1] + ldnt1w { z16.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z30.s, z14.s, z0.s[1] + ldnt1w { z17.s }, p2/Z, [x21] + ldnt1w { z18.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z15.s, z0.s[2] + addvl x21, x21, #2 + ldnt1w { z19.s }, p2/Z, [x23] + fmla z25.s, z16.s, z0.s[2] + ldnt1w { z20.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z17.s, z0.s[2] + addvl x23, x23, #2 + ldnt1w { z21.s }, p2/Z, [x20] + fmla z27.s, z18.s, z0.s[2] + addvl x20, x20, #2 + fmla z28.s, z19.s, z0.s[2] + ldnt1w { z22.s }, p2/Z, [x9] + fmla z29.s, z20.s, z0.s[2] + ldnt1w { z23.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z30.s, z21.s, z0.s[2] + ldnt1w { z1.s }, p2/Z, [x21] + ldnt1w { z2.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z22.s, z0.s[3] + addvl x21, x21, #2 + ldnt1w { z3.s }, p2/Z, [x23] + fmla z25.s, z23.s, z0.s[3] + ldnt1w { z4.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z1.s, z0.s[3] + addvl x23, x23, #2 + ldnt1w { z5.s }, p2/Z, [x20] + fmla z27.s, z2.s, z0.s[3] + addvl x20, x20, #2 + fmla z28.s, z3.s, z0.s[3] + fmla z29.s, z4.s, z0.s[3] + fmla z30.s, z5.s, z0.s[3] + bgt label_32 +KAI_ASM_LABEL(label_33) // Width 7: Multiply loop: Single iteration only + whilelt p0.s, XZR, x25 + ldnt1w { z6.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ld1rqw { z0.s }, p0/Z, [x24] + ldnt1w { z7.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z8.s }, p2/Z, [x21] + ldnt1w { z9.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z6.s, z0.s[0] + addvl x21, x21, #2 + ldnt1w { z10.s }, p2/Z, [x23] + fmla z25.s, z7.s, z0.s[0] + ldnt1w { z11.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z8.s, z0.s[0] + addvl x23, x23, #2 + ldnt1w { z12.s }, p2/Z, [x20] + fmla z27.s, z9.s, z0.s[0] + addvl x20, x20, #2 + fmla z28.s, z10.s, z0.s[0] + fmla z29.s, z11.s, z0.s[0] + fmla z30.s, z12.s, z0.s[0] + ble label_34 + ldnt1w { z13.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z14.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z15.s }, p2/Z, [x21] + ldnt1w { z16.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z13.s, z0.s[1] + addvl x21, x21, #2 + ldnt1w { z17.s }, p2/Z, [x23] + fmla z25.s, z14.s, z0.s[1] + ldnt1w { z18.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z15.s, z0.s[1] + addvl x23, x23, #2 + ldnt1w { z19.s }, p2/Z, [x20] + fmla z27.s, z16.s, z0.s[1] + addvl x20, x20, #2 + fmla z28.s, z17.s, z0.s[1] + fmla z29.s, z18.s, z0.s[1] + fmla z30.s, z19.s, z0.s[1] + ble label_34 + ldnt1w { z20.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z21.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z22.s }, p2/Z, [x21] + ldnt1w { z23.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z20.s, z0.s[2] + addvl x21, x21, #2 + ldnt1w { z1.s }, p2/Z, [x23] + fmla z25.s, z21.s, z0.s[2] + ldnt1w { z2.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z22.s, z0.s[2] + addvl x23, x23, #2 + ldnt1w { z3.s }, p2/Z, [x20] + fmla z27.s, z23.s, z0.s[2] + addvl x20, x20, #2 + fmla z28.s, z1.s, z0.s[2] + fmla z29.s, z2.s, z0.s[2] + fmla z30.s, z3.s, z0.s[2] + ble label_34 + ldnt1w { z4.s }, p2/Z, [x9] + ldnt1w { z5.s }, p2/Z, [x9, #1, MUL VL] + ldnt1w { z6.s }, p2/Z, [x21] + ldnt1w { z7.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z4.s, z0.s[3] + ldnt1w { z8.s }, p2/Z, [x23] + fmla z25.s, z5.s, z0.s[3] + ldnt1w { z9.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z6.s, z0.s[3] + ldnt1w { z10.s }, p2/Z, [x20] + fmla z27.s, z7.s, z0.s[3] + fmla z28.s, z8.s, z0.s[3] + fmla z29.s, z9.s, z0.s[3] + fmla z30.s, z10.s, z0.s[3] +KAI_ASM_LABEL(label_34) // Width 7: Multiply loop: multiply skip + tbz x26, #1, label_35 + add x21, x0, #0x0 + add x20, x0, #0x4 + KAI_ASM_INST(0x8540cab1) // ld1rw { z17.s }, p2/Z, [x21] + KAI_ASM_INST(0x8540ca90) // ld1rw { z16.s }, p2/Z, [x20] + fmin z24.s, p2/M, z24.s, z17.s + fmin z25.s, p2/M, z25.s, z17.s + fmin z26.s, p2/M, z26.s, z17.s + fmin z27.s, p2/M, z27.s, z17.s + fmin z28.s, p2/M, z28.s, z17.s + fmin z29.s, p2/M, z29.s, z17.s + fmin z30.s, p2/M, z30.s, z17.s + fmax z24.s, p2/M, z24.s, z16.s + fmax z25.s, p2/M, z25.s, z16.s + fmax z26.s, p2/M, z26.s, z16.s + fmax z27.s, p2/M, z27.s, z16.s + fmax z28.s, p2/M, z28.s, z16.s + fmax z29.s, p2/M, z29.s, z16.s + fmax z30.s, p2/M, z30.s, z16.s +KAI_ASM_LABEL(label_35) // Width 7: No activation + st1w { z24.s }, p2, [x27] + st1w { z25.s }, p2, [x27, #1, MUL VL] + st1w { z26.s }, p2, [x27, #2, MUL VL] + st1w { z27.s }, p2, [x27, #3, MUL VL] + st1w { z28.s }, p2, [x27, #4, MUL VL] + st1w { z29.s }, p2, [x27, #5, MUL VL] + st1w { z30.s }, p1, [x27, #6, MUL VL] + b label_41 +KAI_ASM_LABEL(label_36) // Width 8 + mov x20, #0x7 + mov x25, x14 + ld1w { z24.s }, p2/Z, [x9] + add x23, x9, x10, LSL #1 + msub x22, x13, x20, x12 + ld1w { z25.s }, p2/Z, [x9, #1, MUL VL] + add x21, x9, x10 + add x20, x23, x10 + ld1w { z28.s }, p2/Z, [x23] + whilelt p1.s, XZR, x22 + cmp x25, #0x4 + ld1w { z26.s }, p2/Z, [x21] + mov x24, x11 + add x22, x9, x10, LSL #2 + ld1w { z27.s }, p2/Z, [x21, #1, MUL VL] + ld1w { z29.s }, p2/Z, [x23, #1, MUL VL] + addvl x9, x9, #2 + addvl x21, x21, #2 + ld1w { z30.s }, p2/Z, [x20] + addvl x23, x23, #2 + ld1w { z31.s }, p2/Z, [x20, #1, MUL VL] + addvl x20, x20, #2 + ble label_38 +KAI_ASM_LABEL(label_37) // Width 8: Multiply loop: Main loop head + whilelt p0.s, XZR, x25 + ldnt1w { z1.s }, p2/Z, [x9] + sub x25, x25, #0x4 + ld1rqw { z0.s }, p0/Z, [x24] + cmp x25, #0x4 + add x24, x24, #0x10 + ldnt1w { z2.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z3.s }, p2/Z, [x21] + ldnt1w { z4.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z1.s, z0.s[0] + addvl x21, x21, #2 + ldnt1w { z5.s }, p2/Z, [x23] + fmla z25.s, z2.s, z0.s[0] + ldnt1w { z6.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z3.s, z0.s[0] + addvl x23, x23, #2 + ldnt1w { z7.s }, p2/Z, [x20] + fmla z27.s, z4.s, z0.s[0] + ldnt1w { z8.s }, p2/Z, [x20, #1, MUL VL] + fmla z28.s, z5.s, z0.s[0] + addvl x20, x20, #2 + fmla z29.s, z6.s, z0.s[0] + ldnt1w { z9.s }, p2/Z, [x9] + fmla z30.s, z7.s, z0.s[0] + ldnt1w { z10.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z31.s, z8.s, z0.s[0] + ldnt1w { z11.s }, p2/Z, [x21] + ldnt1w { z12.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z9.s, z0.s[1] + addvl x21, x21, #2 + ldnt1w { z13.s }, p2/Z, [x23] + fmla z25.s, z10.s, z0.s[1] + ldnt1w { z14.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z11.s, z0.s[1] + addvl x23, x23, #2 + ldnt1w { z15.s }, p2/Z, [x20] + fmla z27.s, z12.s, z0.s[1] + ldnt1w { z16.s }, p2/Z, [x20, #1, MUL VL] + fmla z28.s, z13.s, z0.s[1] + addvl x20, x20, #2 + fmla z29.s, z14.s, z0.s[1] + ldnt1w { z17.s }, p2/Z, [x9] + fmla z30.s, z15.s, z0.s[1] + ldnt1w { z18.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z31.s, z16.s, z0.s[1] + ldnt1w { z19.s }, p2/Z, [x21] + ldnt1w { z20.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z17.s, z0.s[2] + addvl x21, x21, #2 + ldnt1w { z21.s }, p2/Z, [x23] + fmla z25.s, z18.s, z0.s[2] + ldnt1w { z22.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z19.s, z0.s[2] + addvl x23, x23, #2 + ldnt1w { z23.s }, p2/Z, [x20] + fmla z27.s, z20.s, z0.s[2] + ldnt1w { z1.s }, p2/Z, [x20, #1, MUL VL] + fmla z28.s, z21.s, z0.s[2] + addvl x20, x20, #2 + fmla z29.s, z22.s, z0.s[2] + ldnt1w { z2.s }, p2/Z, [x9] + fmla z30.s, z23.s, z0.s[2] + ldnt1w { z3.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + fmla z31.s, z1.s, z0.s[2] + ldnt1w { z4.s }, p2/Z, [x21] + ldnt1w { z5.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z2.s, z0.s[3] + addvl x21, x21, #2 + ldnt1w { z6.s }, p2/Z, [x23] + fmla z25.s, z3.s, z0.s[3] + ldnt1w { z7.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z4.s, z0.s[3] + addvl x23, x23, #2 + ldnt1w { z8.s }, p2/Z, [x20] + fmla z27.s, z5.s, z0.s[3] + ldnt1w { z9.s }, p2/Z, [x20, #1, MUL VL] + fmla z28.s, z6.s, z0.s[3] + addvl x20, x20, #2 + fmla z29.s, z7.s, z0.s[3] + fmla z30.s, z8.s, z0.s[3] + fmla z31.s, z9.s, z0.s[3] + bgt label_37 +KAI_ASM_LABEL(label_38) // Width 8: Multiply loop: Single iteration only + whilelt p0.s, XZR, x25 + ldnt1w { z10.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ld1rqw { z0.s }, p0/Z, [x24] + ldnt1w { z11.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z12.s }, p2/Z, [x21] + ldnt1w { z13.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z10.s, z0.s[0] + addvl x21, x21, #2 + ldnt1w { z14.s }, p2/Z, [x23] + fmla z25.s, z11.s, z0.s[0] + ldnt1w { z15.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z12.s, z0.s[0] + addvl x23, x23, #2 + ldnt1w { z16.s }, p2/Z, [x20] + fmla z27.s, z13.s, z0.s[0] + ldnt1w { z17.s }, p2/Z, [x20, #1, MUL VL] + fmla z28.s, z14.s, z0.s[0] + addvl x20, x20, #2 + fmla z29.s, z15.s, z0.s[0] + fmla z30.s, z16.s, z0.s[0] + fmla z31.s, z17.s, z0.s[0] + ble label_39 + ldnt1w { z18.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z19.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z20.s }, p2/Z, [x21] + ldnt1w { z21.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z18.s, z0.s[1] + addvl x21, x21, #2 + ldnt1w { z22.s }, p2/Z, [x23] + fmla z25.s, z19.s, z0.s[1] + ldnt1w { z23.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z20.s, z0.s[1] + addvl x23, x23, #2 + ldnt1w { z1.s }, p2/Z, [x20] + fmla z27.s, z21.s, z0.s[1] + ldnt1w { z2.s }, p2/Z, [x20, #1, MUL VL] + fmla z28.s, z22.s, z0.s[1] + addvl x20, x20, #2 + fmla z29.s, z23.s, z0.s[1] + fmla z30.s, z1.s, z0.s[1] + fmla z31.s, z2.s, z0.s[1] + ble label_39 + ldnt1w { z3.s }, p2/Z, [x9] + subs x25, x25, #0x1 + ldnt1w { z4.s }, p2/Z, [x9, #1, MUL VL] + addvl x9, x9, #2 + ldnt1w { z5.s }, p2/Z, [x21] + ldnt1w { z6.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z3.s, z0.s[2] + addvl x21, x21, #2 + ldnt1w { z7.s }, p2/Z, [x23] + fmla z25.s, z4.s, z0.s[2] + ldnt1w { z8.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z5.s, z0.s[2] + addvl x23, x23, #2 + ldnt1w { z9.s }, p2/Z, [x20] + fmla z27.s, z6.s, z0.s[2] + ldnt1w { z10.s }, p2/Z, [x20, #1, MUL VL] + fmla z28.s, z7.s, z0.s[2] + addvl x20, x20, #2 + fmla z29.s, z8.s, z0.s[2] + fmla z30.s, z9.s, z0.s[2] + fmla z31.s, z10.s, z0.s[2] + ble label_39 + ldnt1w { z11.s }, p2/Z, [x9] + ldnt1w { z12.s }, p2/Z, [x9, #1, MUL VL] + ldnt1w { z13.s }, p2/Z, [x21] + ldnt1w { z14.s }, p2/Z, [x21, #1, MUL VL] + fmla z24.s, z11.s, z0.s[3] + ldnt1w { z15.s }, p2/Z, [x23] + fmla z25.s, z12.s, z0.s[3] + ldnt1w { z16.s }, p2/Z, [x23, #1, MUL VL] + fmla z26.s, z13.s, z0.s[3] + ldnt1w { z17.s }, p2/Z, [x20] + fmla z27.s, z14.s, z0.s[3] + ldnt1w { z18.s }, p2/Z, [x20, #1, MUL VL] + fmla z28.s, z15.s, z0.s[3] + fmla z29.s, z16.s, z0.s[3] + fmla z30.s, z17.s, z0.s[3] + fmla z31.s, z18.s, z0.s[3] +KAI_ASM_LABEL(label_39) // Width 8: Multiply loop: multiply skip + tbz x26, #1, label_40 + add x21, x0, #0x0 + add x20, x0, #0x4 + KAI_ASM_INST(0x8540cab1) // ld1rw { z17.s }, p2/Z, [x21] + KAI_ASM_INST(0x8540ca90) // ld1rw { z16.s }, p2/Z, [x20] + fmin z24.s, p2/M, z24.s, z17.s + fmin z25.s, p2/M, z25.s, z17.s + fmin z26.s, p2/M, z26.s, z17.s + fmin z27.s, p2/M, z27.s, z17.s + fmin z28.s, p2/M, z28.s, z17.s + fmin z29.s, p2/M, z29.s, z17.s + fmin z30.s, p2/M, z30.s, z17.s + fmin z31.s, p2/M, z31.s, z17.s + fmax z24.s, p2/M, z24.s, z16.s + fmax z25.s, p2/M, z25.s, z16.s + fmax z26.s, p2/M, z26.s, z16.s + fmax z27.s, p2/M, z27.s, z16.s + fmax z28.s, p2/M, z28.s, z16.s + fmax z29.s, p2/M, z29.s, z16.s + fmax z30.s, p2/M, z30.s, z16.s + fmax z31.s, p2/M, z31.s, z16.s +KAI_ASM_LABEL(label_40) // Width 8: No activation + subs x28, x28, #0x8 + st1w { z24.s }, p2, [x27] + mov x9, x22 + st1w { z25.s }, p2, [x27, #1, MUL VL] + sub x12, x12, x13, LSL #3 + st1w { z26.s }, p2, [x27, #2, MUL VL] + st1w { z27.s }, p2, [x27, #3, MUL VL] + st1w { z28.s }, p2, [x27, #4, MUL VL] + st1w { z29.s }, p2, [x27, #5, MUL VL] + st1w { z30.s }, p2, [x27, #6, MUL VL] + st1w { z31.s }, p1, [x27, #7, MUL VL] + addvl x27, x27, #8 + bgt label_1 +KAI_ASM_LABEL(label_41) // Exit + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla) + + KAI_ASM_END diff --git a/test/tests/matmul_clamp_f32_f32_f32p_test.cpp b/test/tests/matmul_clamp_f32_f32_f32p_test.cpp index 3fdef040..be0d6390 100644 --- a/test/tests/matmul_clamp_f32_f32_f32p_test.cpp +++ b/test/tests/matmul_clamp_f32_f32_f32p_test.cpp @@ -19,6 +19,7 @@ #include "kai/kai_common.h" #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h" #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h" #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" @@ -36,7 +37,7 @@ namespace kai::test { namespace { -const std::array, 2> ukernel_variants = { +const std::array, 3> ukernel_variants = { {{ {kai_get_m_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla, kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla, @@ -62,7 +63,19 @@ const std::array, 2> ukern kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla, kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla}, "matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla", - cpu_has_sme2}}}; + cpu_has_sme2}, + {{kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla}, + "matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla", + cpu_has_sme}}}; } // namespace @@ -119,15 +132,13 @@ TEST_P(MatMulTest_f32_f32_f32p, EndToEnd) // NOLINT(google-readability-avoid-un 1, n, k, nr, kr, sr, rhs_stride, ref_rhs.data(), ref_bias.data(), nullptr, imp_packed_rhs->data(), 0, nullptr); break; - case 1: // matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla + default: imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(n, k); imp_packed_rhs = std::make_unique(imp_packed_rhs_size); kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( 1, n, k, nr, kr, sr, rhs_stride, ref_rhs.data(), ref_bias.data(), nullptr, imp_packed_rhs->data(), 0, nullptr); break; - default: - KAI_ERROR("Unsupported micro-kernel"); } // Run the MatMul micro-kernel. -- GitLab From dec618ce998c589f39335b48ce64c13a41ac980f Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Wed, 9 Jul 2025 12:04:47 +0100 Subject: [PATCH 2/5] Register kernel in benchmark Signed-off-by: Jakub Sujak --- benchmark/matmul/matmul_registry.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/benchmark/matmul/matmul_registry.cpp b/benchmark/matmul/matmul_registry.cpp index 75d6a48b..8816152d 100644 --- a/benchmark/matmul/matmul_registry.cpp +++ b/benchmark/matmul/matmul_registry.cpp @@ -45,6 +45,7 @@ // matmul_clamp_f32_f32_f32p #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h" #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h" #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h" // matmul_clamp_f32_f32p_f32p @@ -136,6 +137,10 @@ inline constexpr MatMulStridedLhsInterface kai_matmul_clamp_f32_f32_f32p2vlx1b_1 .run_matmul = kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla, }; +inline constexpr MatMulStridedLhsInterface kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla_interface{ + .run_matmul = kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, +}; + inline constexpr MatMulStridedLhsInterface kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla_interface{ .run_matmul = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, }; @@ -355,6 +360,9 @@ inline const std::array matmul_benchmarks{ "kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla", kai_benchmark_matmul, kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_interface, DataType::FP32, MatMulOp::GEMV, test::cpu_has_sme2), + RegisterBenchmark( + "kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla", kai_benchmark_matmul, + kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla_interface, DataType::FP32, MatMulOp::GEMV, test::cpu_has_sme), RegisterBenchmark( "kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla", kai_benchmark_matmul, kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla_interface, DataType::FP32, MatMulOp::GEMM, -- GitLab From 95f7a464d43e043a5a9620b204672ab70fc8bb21 Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Wed, 9 Jul 2025 12:05:15 +0100 Subject: [PATCH 3/5] Change argument name in interface file Signed-off-by: Jakub Sujak --- .../kai_matmul_clamp_f32_f32_f32p_interface.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h index c402c9bb..2c65daae 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h @@ -20,7 +20,7 @@ typedef size_t (*kai_matmul_clamp_f32_f32_f32p_get_n_step_func_t)(void); typedef size_t (*kai_matmul_clamp_f32_f32_f32p_get_nr_func_t)(void); typedef size_t (*kai_matmul_clamp_f32_f32_f32p_get_kr_func_t)(void); typedef size_t (*kai_matmul_clamp_f32_f32_f32p_get_sr_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_f32_f32p_get_lhs_offset_func_t)(size_t m_idx, size_t lhs_stride); +typedef size_t (*kai_matmul_clamp_f32_f32_f32p_get_lhs_offset_func_t)(size_t m_idx, size_t k); typedef size_t (*kai_matmul_clamp_f32_f32_f32p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); typedef size_t (*kai_matmul_clamp_f32_f32_f32p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); typedef size_t (*kai_matmul_clamp_f32_f32_f32p_get_dst_size_func_t)(size_t m, size_t n); -- GitLab From 02b2fafbe39e0fd3bc1cb006bef4839a27243eba Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Wed, 9 Jul 2025 12:05:51 +0100 Subject: [PATCH 4/5] Error by default in test Signed-off-by: Jakub Sujak --- test/tests/matmul_clamp_f32_f32_f32p_test.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/tests/matmul_clamp_f32_f32_f32p_test.cpp b/test/tests/matmul_clamp_f32_f32_f32p_test.cpp index be0d6390..f54d0a38 100644 --- a/test/tests/matmul_clamp_f32_f32_f32p_test.cpp +++ b/test/tests/matmul_clamp_f32_f32_f32p_test.cpp @@ -132,13 +132,16 @@ TEST_P(MatMulTest_f32_f32_f32p, EndToEnd) // NOLINT(google-readability-avoid-un 1, n, k, nr, kr, sr, rhs_stride, ref_rhs.data(), ref_bias.data(), nullptr, imp_packed_rhs->data(), 0, nullptr); break; - default: + case 1: // matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla + case 2: // matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(n, k); imp_packed_rhs = std::make_unique(imp_packed_rhs_size); kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( 1, n, k, nr, kr, sr, rhs_stride, ref_rhs.data(), ref_bias.data(), nullptr, imp_packed_rhs->data(), 0, nullptr); break; + default: + KAI_ERROR("Unsupported micro-kernel"); } // Run the MatMul micro-kernel. -- GitLab From e862cf1d618478e3b3ea5ccffd8b7a3b6f9a106e Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Wed, 9 Jul 2025 12:12:54 +0100 Subject: [PATCH 5/5] Clarify changelog entry Signed-off-by: Jakub Sujak --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b78e76f2..2c144a9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - Matrix multiplication (MxN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with BF16 output, optimized for FEAT_I8MM. - Matrix multiplication (1xN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with BF16 output, optimized for FEAT_DotProd. - New SME micro-kernels: - - SME1 compatible matrix multiplication (1xN) of F32 LHS and RHS with F32 output. + - Matrix multiplication (1xN) of F32 LHS and RHS with F32 output, using instructions compatible with FEAT_SME. ## v1.11.0 -- GitLab