From 851017a8394cc45f68bf72e57699b85662f03802 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Wed, 19 Jun 2024 11:24:34 +0100 Subject: [PATCH 01/12] Add FP32 matmul micro kernel Block size: 6x8 Signed-off-by: Felix Thomasmathibalan --- CMakeLists.txt | 7 + README.md | 15 + ...l_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c | 2053 +++++++++++++++++ ...l_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h | 124 + ...rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c | 253 ++ ...rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h | 80 + test/tests/matmul_test.cpp | 76 +- 7 files changed, 2603 insertions(+), 5 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 80adb858..d63efae2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,6 +85,11 @@ set(KLEIDIAI_FILES_NEON_FP16 kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c ) +set(KLEIDIAI_FILES_NEON_FP32 + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c + kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c +) + set(KLEIDIAI_FILES_NEON_DOTPROD kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c @@ -116,6 +121,7 @@ target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SCALAR}) # https://learn.microsoft.com/en-us/cpp/assembler/inline/inline-assembler?view=msvc-170 # if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") AND NOT MSVC) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_FP32}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_FP16}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM}) @@ -123,6 +129,7 @@ if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2}) set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a) + set_source_files_properties(${KLEIDIAI_FILES_NEON_FP32} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a) set_source_files_properties(${KLEIDIAI_FILES_NEON_FP16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+fp16) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm) diff --git a/README.md b/README.md index d287fa96..a139d11b 100644 --- a/README.md +++ b/README.md @@ -148,6 +148,21 @@ Some of the data types currently supported with the KleidiAI library are the fol Since the RHS matrix often contains constant values, we recommend packing the RHS matrix only once and freeing the content of the original RHS matrix.
+ + Matrix-multiplication with RHS packed + matmul_clamp_f32_f32_f32p + + DST: f32
+ LHS: f32
+ RHS: f32p
+ + + TensorFlow Lite
+ + + The packing function for the RHS matrix is listed in the header file of the GEMM micro kernel.
+ + Dynamic quantization and LHS matrix packing kai_lhs_quant_pack_qai8dxp_f32 diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c new file mode 100644 index 00000000..fe88014e --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c @@ -0,0 +1,2053 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check. + +#include +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 6; +static const size_t kai_nr = 8; +static const size_t kai_kr = 1; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void) { + return kai_mr; +} + +size_t kai_get_n_step_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void) { + return kai_nr; +} + +size_t kai_get_nr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void) { + return kai_sr; +} + +size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m_idx, size_t stride) { + KAI_ASSUME(m_idx % kai_mr == 0); + + return m_idx * stride; +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx / kai_nr * (kai_nr * sizeof(float) + kai_nr * k * sizeof(float)); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m_idx, size_t n_idx, size_t stride) { + KAI_ASSUME(m_idx % kai_mr == 0); + KAI_ASSUME(n_idx % kai_nr == 0); + + return m_idx * stride + n_idx * sizeof(float); +} + +size_t kai_get_dst_size_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla( + size_t m, size_t n, size_t k, // + const void* lhs, size_t lhs_stride, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // + float clamp_min, float clamp_max) { + KAI_ASSERT(dst_stride_col == sizeof(float)); + + typedef struct { + float maxval; + float minval; + unsigned int num_strings; + const unsigned int* string_lengths; + size_t N; + const void* B_ptr; + size_t output_offset; + size_t input_initial_col; + size_t input_offset; + void* output_ptr; + const void* bias; + } KernelArgs; + + KernelArgs ka; + + unsigned long flags = 0; + + unsigned int string_length = k; + ka.num_strings = 1; + ka.string_lengths = &string_length; + ka.N = n; + ka.B_ptr = rhs_packed; + ka.bias = NULL; + + // Direct input. + const void* input_ptr = lhs; + ka.input_offset = lhs_stride / sizeof(float); + ka.input_initial_col = 0; + + // Direct output. + ka.output_ptr = dst; + ka.output_offset = dst_stride_row / sizeof(float); + + // Clamping output. + flags |= 0x2; + ka.maxval = clamp_max; + ka.minval = clamp_min; + + printf( + "*******F32************* N: %zu , ka.output_offset: %zu, input_offset: %zu m:%zu, n:%zu \n", ka.N, + ka.output_offset, ka.input_offset, m, n); + + __asm__ __volatile__( + "1:" // Row loop + "cmp %x[m], #0x6\n" + "bge 126f\n" + "cmp %x[m], #0x4\n" + "bgt 101f\n" + "beq 76f\n" + "cmp %x[m], #0x2\n" + "bgt 51f\n" + "beq 26f\n" + "ldr x11, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x10, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x9, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "2:" // Height 1: Column loop + "cbz x10, 3f\n" + "ldr q20, [x10, #0x0]\n" + "ldr q21, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "b 10f\n" + "3:" // Height 1: no bias + "tbz %x[flags], #0, 9f\n" + "cmp x11, #0x8\n" + "bge 8f\n" + "tbz x11, #2, 5f\n" + "ld1 { v20.4s }, [x9], #0x10\n" + "tbz x11, #1, 4f\n" + "ldr d21, [x9], #0x8\n" + "mov x20, #0x18\n" + "tbz x11, #0, 7f\n" + "ld1 { v21.s }[2], [x9]\n" + "b 7f\n" + "4:" // Height 1: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x11, #0, 7f\n" + "ldr s21, [x9, #0x0]\n" + "b 7f\n" + "5:" // Height 1: Partial accumulate: partial_2_0 + "tbz x11, #1, 6f\n" + "ldr d20, [x9], #0x8\n" + "mov x20, #0x8\n" + "tbz x11, #0, 7f\n" + "ld1 { v20.s }[2], [x9]\n" + "b 7f\n" + "6:" // Height 1: Partial accumulate: partial_1_0 + "ldr s20, [x9, #0x0]\n" + "mov x20, #0x0\n" + "7:" // Height 1: Partial accumulate: Done + "sub x9, x9, x20\n" + "b 10f\n" + "8:" // Height 1: full accumulate + "ldr q20, [x9, #0x0]\n" + "ldr q21, [x9, #0x10]\n" + "b 10f\n" + "9:" // Height 1: no accumulate + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "10:" // Height 1: setup done + "mov x28, #0x0\n" + "11:" // Height 1: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 12f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "cbnz x28, 13f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "b 13f\n" + "12:" // Height 1: setup direct input + "mov x26, %x[input_ptr]\n" + "13:" // Height 1: input setup done + "cmp x27, #0x4\n" + "blt 16f\n" + "ldr q0, [x26, #0x0]\n" + "ldr q6, [x10, #0x0]\n" + "cmp x27, #0x8\n" + "ldr q7, [x10, #0x10]\n" + "ldr q8, [x10, #0x20]\n" + "ldr q9, [x10, #0x30]\n" + "ldr q10, [x10, #0x40]\n" + "ldr q11, [x10, #0x50]\n" + "ldr q12, [x10, #0x60]\n" + "ldr q13, [x10, #0x70]\n" + "blt 15f\n" + "14:" // Height 1: Multiply loop: Main loop head + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "sub x27, x27, #0x4\n" + "add x26, x26, #0x10\n" + "cmp x27, #0x8\n" + "add x10, x10, #0x80\n" + "prfm pldl1keep, [x26, #0x80]\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v20.4s, v8.4s, v0.s[1]\n" + "ldr q8, [x10, #0x20]\n" + "fmla v21.4s, v9.4s, v0.s[1]\n" + "ldr q9, [x10, #0x30]\n" + "fmla v20.4s, v10.4s, v0.s[2]\n" + "ldr q10, [x10, #0x40]\n" + "fmla v21.4s, v11.4s, v0.s[2]\n" + "ldr q11, [x10, #0x50]\n" + "fmla v20.4s, v12.4s, v0.s[3]\n" + "ldr q12, [x10, #0x60]\n" + "fmla v21.4s, v13.4s, v0.s[3]\n" + "ldr q0, [x26, #0x0]\n" + "ldr q13, [x10, #0x70]\n" + "bge 14b\n" + "15:" // Height 1: Multiply loop: Single iteration only + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "add x26, x26, #0x10\n" + "sub x27, x27, #0x4\n" + "add x10, x10, #0x80\n" + "prfm pldl1keep, [x26, #0x80]\n" + "fmla v20.4s, v8.4s, v0.s[1]\n" + "fmla v21.4s, v9.4s, v0.s[1]\n" + "fmla v20.4s, v10.4s, v0.s[2]\n" + "fmla v21.4s, v11.4s, v0.s[2]\n" + "fmla v20.4s, v12.4s, v0.s[3]\n" + "fmla v21.4s, v13.4s, v0.s[3]\n" + "16:" // Height 1: Multiply loop: Main loop skip + "cbz x27, 18f\n" + "17:" // Height 1: Multiply loop: Odd block loop + "ldr s18, [x26], #0x4\n" + "ldr q17, [x10, #0x0]\n" + "sub x27, x27, #0x1\n" + "ldr q16, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "fmla v20.4s, v17.4s, v18.s[0]\n" + "fmla v21.4s, v16.4s, v18.s[0]\n" + "cbnz x27, 17b\n" + "18:" // Height 1: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 11b\n" + "prfm pstl1keep, [x9, #0x0]\n" + "tbz %x[flags], #1, 19f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v17.4s }, [x21]\n" + "ld1r { v16.4s }, [x20]\n" + "fmin v20.4s, v20.4s, v17.4s\n" + "fmin v21.4s, v21.4s, v17.4s\n" + "fmax v20.4s, v20.4s, v16.4s\n" + "fmax v21.4s, v21.4s, v16.4s\n" + "19:" // Height 1: No activation + "cmp x11, #0x8\n" + "bge 24f\n" + "tbz x11, #2, 21f\n" + "st1 { v20.4s }, [x9], #0x10\n" + "tbz x11, #1, 20f\n" + "str d21, [x9], #0x8\n" + "tbz x11, #0, 23f\n" + "st1 { v21.s }[2], [x9]\n" + "b 23f\n" + "20:" // Height 1: Partial direct writeback: partial_1_4 + "tbz x11, #0, 23f\n" + "str s21, [x9, #0x0]\n" + "b 23f\n" + "21:" // Height 1: Partial direct writeback: partial_2_0 + "tbz x11, #1, 22f\n" + "str d20, [x9], #0x8\n" + "tbz x11, #0, 23f\n" + "st1 { v20.s }[2], [x9]\n" + "b 23f\n" + "22:" // Height 1: Partial direct writeback: partial_1_0 + "str s20, [x9, #0x0]\n" + "23:" // Height 1: Partial direct writeback: Done + "b 25f\n" + "24:" // Height 1: Full writeback + "str q20, [x9, #0x0]\n" + "str q21, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + "25:" // Height 1: Writeback done + "subs x11, x11, #0x8\n" + "bgt 2b\n" + "b 152f\n" + "26:" // Height 2 + "ldr x11, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x10, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x9, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "27:" // Height 2: Column loop + "cbz x10, 28f\n" + "ldr q20, [x10, #0x0]\n" + "ldr q21, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "mov v22.16b, v20.16b\n" + "mov v23.16b, v21.16b\n" + "b 35f\n" + "28:" // Height 2: no bias + "tbz %x[flags], #0, 34f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x11, #0x8\n" + "add x26, x9, x20, LSL #2\n" + "bge 33f\n" + "tbz x11, #2, 30f\n" + "ld1 { v20.4s }, [x9], #0x10\n" + "ld1 { v22.4s }, [x26], #0x10\n" + "tbz x11, #1, 29f\n" + "ldr d21, [x9], #0x8\n" + "ldr d23, [x26], #0x8\n" + "mov x20, #0x18\n" + "tbz x11, #0, 32f\n" + "ld1 { v21.s }[2], [x9]\n" + "ld1 { v23.s }[2], [x26]\n" + "b 32f\n" + "29:" // Height 2: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x11, #0, 32f\n" + "ldr s21, [x9, #0x0]\n" + "ldr s23, [x26, #0x0]\n" + "b 32f\n" + "30:" // Height 2: Partial accumulate: partial_2_0 + "tbz x11, #1, 31f\n" + "ldr d20, [x9], #0x8\n" + "ldr d22, [x26], #0x8\n" + "mov x20, #0x8\n" + "tbz x11, #0, 32f\n" + "ld1 { v20.s }[2], [x9]\n" + "ld1 { v22.s }[2], [x26]\n" + "b 32f\n" + "31:" // Height 2: Partial accumulate: partial_1_0 + "ldr s20, [x9, #0x0]\n" + "ldr s22, [x26, #0x0]\n" + "mov x20, #0x0\n" + "32:" // Height 2: Partial accumulate: Done + "sub x9, x9, x20\n" + "b 35f\n" + "33:" // Height 2: full accumulate + "ldr q20, [x9, #0x0]\n" + "ldr q21, [x9, #0x10]\n" + "ldr q22, [x26, #0x0]\n" + "ldr q23, [x26, #0x10]\n" + "b 35f\n" + "34:" // Height 2: no accumulate + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "35:" // Height 2: setup done + "mov x28, #0x0\n" + "36:" // Height 2: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 37f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "cbnz x28, 38f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "b 38f\n" + "37:" // Height 2: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "38:" // Height 2: input setup done + "cmp x27, #0x4\n" + "blt 41f\n" + "ldr q0, [x26, #0x0]\n" + "ldr q1, [x25, #0x0]\n" + "cmp x27, #0x8\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "ldr q8, [x10, #0x20]\n" + "ldr q9, [x10, #0x30]\n" + "ldr q10, [x10, #0x40]\n" + "ldr q11, [x10, #0x50]\n" + "ldr q12, [x10, #0x60]\n" + "ldr q13, [x10, #0x70]\n" + "blt 40f\n" + "39:" // Height 2: Multiply loop: Main loop head + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "sub x27, x27, #0x4\n" + "add x26, x26, #0x10\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "fmla v23.4s, v7.4s, v1.s[0]\n" + "add x25, x25, #0x10\n" + "cmp x27, #0x8\n" + "add x10, x10, #0x80\n" + "prfm pldl1keep, [x26, #0x80]\n" + "prfm pldl1keep, [x25, #0x80]\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v20.4s, v8.4s, v0.s[1]\n" + "fmla v22.4s, v8.4s, v1.s[1]\n" + "ldr q8, [x10, #0x20]\n" + "fmla v21.4s, v9.4s, v0.s[1]\n" + "fmla v23.4s, v9.4s, v1.s[1]\n" + "ldr q9, [x10, #0x30]\n" + "fmla v20.4s, v10.4s, v0.s[2]\n" + "fmla v22.4s, v10.4s, v1.s[2]\n" + "ldr q10, [x10, #0x40]\n" + "fmla v21.4s, v11.4s, v0.s[2]\n" + "fmla v23.4s, v11.4s, v1.s[2]\n" + "ldr q11, [x10, #0x50]\n" + "fmla v20.4s, v12.4s, v0.s[3]\n" + "fmla v22.4s, v12.4s, v1.s[3]\n" + "ldr q12, [x10, #0x60]\n" + "fmla v21.4s, v13.4s, v0.s[3]\n" + "ldr q0, [x26, #0x0]\n" + "fmla v23.4s, v13.4s, v1.s[3]\n" + "ldr q1, [x25, #0x0]\n" + "ldr q13, [x10, #0x70]\n" + "bge 39b\n" + "40:" // Height 2: Multiply loop: Single iteration only + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "add x26, x26, #0x10\n" + "add x25, x25, #0x10\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "fmla v23.4s, v7.4s, v1.s[0]\n" + "sub x27, x27, #0x4\n" + "prfm pldl1keep, [x26, #0x80]\n" + "prfm pldl1keep, [x25, #0x80]\n" + "add x10, x10, #0x80\n" + "fmla v20.4s, v8.4s, v0.s[1]\n" + "fmla v22.4s, v8.4s, v1.s[1]\n" + "fmla v21.4s, v9.4s, v0.s[1]\n" + "fmla v23.4s, v9.4s, v1.s[1]\n" + "fmla v20.4s, v10.4s, v0.s[2]\n" + "fmla v22.4s, v10.4s, v1.s[2]\n" + "fmla v21.4s, v11.4s, v0.s[2]\n" + "fmla v23.4s, v11.4s, v1.s[2]\n" + "fmla v20.4s, v12.4s, v0.s[3]\n" + "fmla v22.4s, v12.4s, v1.s[3]\n" + "fmla v21.4s, v13.4s, v0.s[3]\n" + "fmla v23.4s, v13.4s, v1.s[3]\n" + "41:" // Height 2: Multiply loop: Main loop skip + "cbz x27, 43f\n" + "42:" // Height 2: Multiply loop: Odd block loop + "ldr s19, [x26], #0x4\n" + "ldr s18, [x25], #0x4\n" + "sub x27, x27, #0x1\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "fmla v20.4s, v17.4s, v19.s[0]\n" + "fmla v22.4s, v17.4s, v18.s[0]\n" + "fmla v21.4s, v16.4s, v19.s[0]\n" + "fmla v23.4s, v16.4s, v18.s[0]\n" + "cbnz x27, 42b\n" + "43:" // Height 2: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 36b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "prfm pstl1keep, [x9, #0x0]\n" + "add x26, x9, x20, LSL #2\n" + "prfm pstl1keep, [x26, #0x0]\n" + "tbz %x[flags], #1, 44f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v17.4s }, [x21]\n" + "ld1r { v16.4s }, [x20]\n" + "fmin v20.4s, v20.4s, v17.4s\n" + "fmin v21.4s, v21.4s, v17.4s\n" + "fmin v22.4s, v22.4s, v17.4s\n" + "fmin v23.4s, v23.4s, v17.4s\n" + "fmax v20.4s, v20.4s, v16.4s\n" + "fmax v21.4s, v21.4s, v16.4s\n" + "fmax v22.4s, v22.4s, v16.4s\n" + "fmax v23.4s, v23.4s, v16.4s\n" + "44:" // Height 2: No activation + "cmp x11, #0x8\n" + "bge 49f\n" + "tbz x11, #2, 46f\n" + "st1 { v20.4s }, [x9], #0x10\n" + "st1 { v22.4s }, [x26], #0x10\n" + "tbz x11, #1, 45f\n" + "str d21, [x9], #0x8\n" + "str d23, [x26], #0x8\n" + "tbz x11, #0, 48f\n" + "st1 { v21.s }[2], [x9]\n" + "st1 { v23.s }[2], [x26]\n" + "b 48f\n" + "45:" // Height 2: Partial direct writeback: partial_1_4 + "tbz x11, #0, 48f\n" + "str s21, [x9, #0x0]\n" + "str s23, [x26, #0x0]\n" + "b 48f\n" + "46:" // Height 2: Partial direct writeback: partial_2_0 + "tbz x11, #1, 47f\n" + "str d20, [x9], #0x8\n" + "str d22, [x26], #0x8\n" + "tbz x11, #0, 48f\n" + "st1 { v20.s }[2], [x9]\n" + "st1 { v22.s }[2], [x26]\n" + "b 48f\n" + "47:" // Height 2: Partial direct writeback: partial_1_0 + "str s20, [x9, #0x0]\n" + "str s22, [x26, #0x0]\n" + "48:" // Height 2: Partial direct writeback: Done + "b 50f\n" + "49:" // Height 2: Full writeback + "str q20, [x9, #0x0]\n" + "str q21, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + "str q22, [x26, #0x0]\n" + "str q23, [x26, #0x10]\n" + "50:" // Height 2: Writeback done + "subs x11, x11, #0x8\n" + "bgt 27b\n" + "b 152f\n" + "51:" // Height 3 + "ldr x11, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x10, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x9, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "52:" // Height 3: Column loop + "cbz x10, 53f\n" + "ldr q20, [x10, #0x0]\n" + "ldr q21, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "mov v22.16b, v20.16b\n" + "mov v23.16b, v21.16b\n" + "mov v24.16b, v20.16b\n" + "mov v25.16b, v21.16b\n" + "b 60f\n" + "53:" // Height 3: no bias + "tbz %x[flags], #0, 59f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x11, #0x8\n" + "add x26, x9, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "bge 58f\n" + "tbz x11, #2, 55f\n" + "ld1 { v20.4s }, [x9], #0x10\n" + "ld1 { v22.4s }, [x26], #0x10\n" + "ld1 { v24.4s }, [x25], #0x10\n" + "tbz x11, #1, 54f\n" + "ldr d21, [x9], #0x8\n" + "ldr d23, [x26], #0x8\n" + "mov x20, #0x18\n" + "ldr d25, [x25], #0x8\n" + "tbz x11, #0, 57f\n" + "ld1 { v21.s }[2], [x9]\n" + "ld1 { v23.s }[2], [x26]\n" + "ld1 { v25.s }[2], [x25]\n" + "b 57f\n" + "54:" // Height 3: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x11, #0, 57f\n" + "ldr s21, [x9, #0x0]\n" + "ldr s23, [x26, #0x0]\n" + "ldr s25, [x25, #0x0]\n" + "b 57f\n" + "55:" // Height 3: Partial accumulate: partial_2_0 + "tbz x11, #1, 56f\n" + "ldr d20, [x9], #0x8\n" + "ldr d22, [x26], #0x8\n" + "mov x20, #0x8\n" + "ldr d24, [x25], #0x8\n" + "tbz x11, #0, 57f\n" + "ld1 { v20.s }[2], [x9]\n" + "ld1 { v22.s }[2], [x26]\n" + "ld1 { v24.s }[2], [x25]\n" + "b 57f\n" + "56:" // Height 3: Partial accumulate: partial_1_0 + "ldr s20, [x9, #0x0]\n" + "ldr s22, [x26, #0x0]\n" + "mov x20, #0x0\n" + "ldr s24, [x25, #0x0]\n" + "57:" // Height 3: Partial accumulate: Done + "sub x9, x9, x20\n" + "b 60f\n" + "58:" // Height 3: full accumulate + "ldr q20, [x9, #0x0]\n" + "ldr q21, [x9, #0x10]\n" + "ldr q22, [x26, #0x0]\n" + "ldr q23, [x26, #0x10]\n" + "ldr q24, [x25, #0x0]\n" + "ldr q25, [x25, #0x10]\n" + "b 60f\n" + "59:" // Height 3: no accumulate + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "60:" // Height 3: setup done + "mov x28, #0x0\n" + "61:" // Height 3: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 62f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "ldr x24, [x20, #0x10]\n" + "cbnz x28, 63f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "add x24, x24, x20, LSL #2\n" + "b 63f\n" + "62:" // Height 3: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "add x24, x25, x21, LSL #2\n" + "63:" // Height 3: input setup done + "cmp x27, #0x4\n" + "blt 66f\n" + "ldr q0, [x26, #0x0]\n" + "ldr q1, [x25, #0x0]\n" + "cmp x27, #0x8\n" + "ldr q2, [x24, #0x0]\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "ldr q8, [x10, #0x20]\n" + "ldr q9, [x10, #0x30]\n" + "ldr q10, [x10, #0x40]\n" + "ldr q11, [x10, #0x50]\n" + "ldr q12, [x10, #0x60]\n" + "ldr q13, [x10, #0x70]\n" + "blt 65f\n" + "64:" // Height 3: Multiply loop: Main loop head + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "sub x27, x27, #0x4\n" + "add x26, x26, #0x10\n" + "fmla v24.4s, v6.4s, v2.s[0]\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "fmla v23.4s, v7.4s, v1.s[0]\n" + "fmla v25.4s, v7.4s, v2.s[0]\n" + "cmp x27, #0x8\n" + "add x10, x10, #0x80\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v20.4s, v8.4s, v0.s[1]\n" + "fmla v22.4s, v8.4s, v1.s[1]\n" + "prfm pldl1keep, [x26, #0x80]\n" + "prfm pldl1keep, [x25, #0x80]\n" + "fmla v24.4s, v8.4s, v2.s[1]\n" + "ldr q8, [x10, #0x20]\n" + "fmla v21.4s, v9.4s, v0.s[1]\n" + "prfm pldl1keep, [x24, #0x80]\n" + "fmla v23.4s, v9.4s, v1.s[1]\n" + "fmla v25.4s, v9.4s, v2.s[1]\n" + "ldr q9, [x10, #0x30]\n" + "fmla v20.4s, v10.4s, v0.s[2]\n" + "fmla v22.4s, v10.4s, v1.s[2]\n" + "fmla v24.4s, v10.4s, v2.s[2]\n" + "ldr q10, [x10, #0x40]\n" + "fmla v21.4s, v11.4s, v0.s[2]\n" + "fmla v23.4s, v11.4s, v1.s[2]\n" + "fmla v25.4s, v11.4s, v2.s[2]\n" + "ldr q11, [x10, #0x50]\n" + "fmla v20.4s, v12.4s, v0.s[3]\n" + "fmla v22.4s, v12.4s, v1.s[3]\n" + "fmla v24.4s, v12.4s, v2.s[3]\n" + "ldr q12, [x10, #0x60]\n" + "fmla v21.4s, v13.4s, v0.s[3]\n" + "ldr q0, [x26, #0x0]\n" + "fmla v23.4s, v13.4s, v1.s[3]\n" + "ldr q1, [x25, #0x0]\n" + "fmla v25.4s, v13.4s, v2.s[3]\n" + "ldr q2, [x24, #0x0]\n" + "ldr q13, [x10, #0x70]\n" + "bge 64b\n" + "65:" // Height 3: Multiply loop: Single iteration only + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "add x26, x26, #0x10\n" + "add x25, x25, #0x10\n" + "fmla v24.4s, v6.4s, v2.s[0]\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "add x24, x24, #0x10\n" + "prfm pldl1keep, [x26, #0x80]\n" + "fmla v23.4s, v7.4s, v1.s[0]\n" + "fmla v25.4s, v7.4s, v2.s[0]\n" + "sub x27, x27, #0x4\n" + "prfm pldl1keep, [x25, #0x80]\n" + "add x10, x10, #0x80\n" + "prfm pldl1keep, [x24, #0x80]\n" + "fmla v20.4s, v8.4s, v0.s[1]\n" + "fmla v22.4s, v8.4s, v1.s[1]\n" + "fmla v24.4s, v8.4s, v2.s[1]\n" + "fmla v21.4s, v9.4s, v0.s[1]\n" + "fmla v23.4s, v9.4s, v1.s[1]\n" + "fmla v25.4s, v9.4s, v2.s[1]\n" + "fmla v20.4s, v10.4s, v0.s[2]\n" + "fmla v22.4s, v10.4s, v1.s[2]\n" + "fmla v24.4s, v10.4s, v2.s[2]\n" + "fmla v21.4s, v11.4s, v0.s[2]\n" + "fmla v23.4s, v11.4s, v1.s[2]\n" + "fmla v25.4s, v11.4s, v2.s[2]\n" + "fmla v20.4s, v12.4s, v0.s[3]\n" + "fmla v22.4s, v12.4s, v1.s[3]\n" + "fmla v24.4s, v12.4s, v2.s[3]\n" + "fmla v21.4s, v13.4s, v0.s[3]\n" + "fmla v23.4s, v13.4s, v1.s[3]\n" + "fmla v25.4s, v13.4s, v2.s[3]\n" + "66:" // Height 3: Multiply loop: Main loop skip + "cbz x27, 68f\n" + "67:" // Height 3: Multiply loop: Odd block loop + "ldr s26, [x26], #0x4\n" + "ldr s19, [x25], #0x4\n" + "sub x27, x27, #0x1\n" + "ldr s18, [x24], #0x4\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "fmla v20.4s, v17.4s, v26.s[0]\n" + "fmla v22.4s, v17.4s, v19.s[0]\n" + "fmla v24.4s, v17.4s, v18.s[0]\n" + "fmla v21.4s, v16.4s, v26.s[0]\n" + "fmla v23.4s, v16.4s, v19.s[0]\n" + "fmla v25.4s, v16.4s, v18.s[0]\n" + "cbnz x27, 67b\n" + "68:" // Height 3: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 61b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "prfm pstl1keep, [x9, #0x0]\n" + "add x26, x9, x20, LSL #2\n" + "prfm pstl1keep, [x26, #0x0]\n" + "add x25, x26, x20, LSL #2\n" + "prfm pstl1keep, [x25, #0x0]\n" + "tbz %x[flags], #1, 69f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v17.4s }, [x21]\n" + "ld1r { v16.4s }, [x20]\n" + "fmin v20.4s, v20.4s, v17.4s\n" + "fmin v21.4s, v21.4s, v17.4s\n" + "fmin v22.4s, v22.4s, v17.4s\n" + "fmin v23.4s, v23.4s, v17.4s\n" + "fmin v24.4s, v24.4s, v17.4s\n" + "fmin v25.4s, v25.4s, v17.4s\n" + "fmax v20.4s, v20.4s, v16.4s\n" + "fmax v21.4s, v21.4s, v16.4s\n" + "fmax v22.4s, v22.4s, v16.4s\n" + "fmax v23.4s, v23.4s, v16.4s\n" + "fmax v24.4s, v24.4s, v16.4s\n" + "fmax v25.4s, v25.4s, v16.4s\n" + "69:" // Height 3: No activation + "cmp x11, #0x8\n" + "bge 74f\n" + "tbz x11, #2, 71f\n" + "st1 { v20.4s }, [x9], #0x10\n" + "st1 { v22.4s }, [x26], #0x10\n" + "st1 { v24.4s }, [x25], #0x10\n" + "tbz x11, #1, 70f\n" + "str d21, [x9], #0x8\n" + "str d23, [x26], #0x8\n" + "str d25, [x25], #0x8\n" + "tbz x11, #0, 73f\n" + "st1 { v21.s }[2], [x9]\n" + "st1 { v23.s }[2], [x26]\n" + "st1 { v25.s }[2], [x25]\n" + "b 73f\n" + "70:" // Height 3: Partial direct writeback: partial_1_4 + "tbz x11, #0, 73f\n" + "str s21, [x9, #0x0]\n" + "str s23, [x26, #0x0]\n" + "str s25, [x25, #0x0]\n" + "b 73f\n" + "71:" // Height 3: Partial direct writeback: partial_2_0 + "tbz x11, #1, 72f\n" + "str d20, [x9], #0x8\n" + "str d22, [x26], #0x8\n" + "str d24, [x25], #0x8\n" + "tbz x11, #0, 73f\n" + "st1 { v20.s }[2], [x9]\n" + "st1 { v22.s }[2], [x26]\n" + "st1 { v24.s }[2], [x25]\n" + "b 73f\n" + "72:" // Height 3: Partial direct writeback: partial_1_0 + "str s20, [x9, #0x0]\n" + "str s22, [x26, #0x0]\n" + "str s24, [x25, #0x0]\n" + "73:" // Height 3: Partial direct writeback: Done + "b 75f\n" + "74:" // Height 3: Full writeback + "str q20, [x9, #0x0]\n" + "str q21, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + "str q22, [x26, #0x0]\n" + "str q23, [x26, #0x10]\n" + "str q24, [x25, #0x0]\n" + "str q25, [x25, #0x10]\n" + "75:" // Height 3: Writeback done + "subs x11, x11, #0x8\n" + "bgt 52b\n" + "b 152f\n" + "76:" // Height 4 + "ldr x11, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x10, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x9, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "77:" // Height 4: Column loop + "cbz x10, 78f\n" + "ldr q20, [x10, #0x0]\n" + "ldr q21, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "mov v22.16b, v20.16b\n" + "mov v23.16b, v21.16b\n" + "mov v24.16b, v20.16b\n" + "mov v25.16b, v21.16b\n" + "mov v26.16b, v20.16b\n" + "mov v27.16b, v21.16b\n" + "b 85f\n" + "78:" // Height 4: no bias + "tbz %x[flags], #0, 84f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x11, #0x8\n" + "add x26, x9, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "bge 83f\n" + "tbz x11, #2, 80f\n" + "ld1 { v20.4s }, [x9], #0x10\n" + "ld1 { v22.4s }, [x26], #0x10\n" + "ld1 { v24.4s }, [x25], #0x10\n" + "ld1 { v26.4s }, [x24], #0x10\n" + "tbz x11, #1, 79f\n" + "ldr d21, [x9], #0x8\n" + "ldr d23, [x26], #0x8\n" + "mov x20, #0x18\n" + "ldr d25, [x25], #0x8\n" + "ldr d27, [x24], #0x8\n" + "tbz x11, #0, 82f\n" + "ld1 { v21.s }[2], [x9]\n" + "ld1 { v23.s }[2], [x26]\n" + "ld1 { v25.s }[2], [x25]\n" + "ld1 { v27.s }[2], [x24]\n" + "b 82f\n" + "79:" // Height 4: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x11, #0, 82f\n" + "ldr s21, [x9, #0x0]\n" + "ldr s23, [x26, #0x0]\n" + "ldr s25, [x25, #0x0]\n" + "ldr s27, [x24, #0x0]\n" + "b 82f\n" + "80:" // Height 4: Partial accumulate: partial_2_0 + "tbz x11, #1, 81f\n" + "ldr d20, [x9], #0x8\n" + "ldr d22, [x26], #0x8\n" + "mov x20, #0x8\n" + "ldr d24, [x25], #0x8\n" + "ldr d26, [x24], #0x8\n" + "tbz x11, #0, 82f\n" + "ld1 { v20.s }[2], [x9]\n" + "ld1 { v22.s }[2], [x26]\n" + "ld1 { v24.s }[2], [x25]\n" + "ld1 { v26.s }[2], [x24]\n" + "b 82f\n" + "81:" // Height 4: Partial accumulate: partial_1_0 + "ldr s20, [x9, #0x0]\n" + "ldr s22, [x26, #0x0]\n" + "mov x20, #0x0\n" + "ldr s24, [x25, #0x0]\n" + "ldr s26, [x24, #0x0]\n" + "82:" // Height 4: Partial accumulate: Done + "sub x9, x9, x20\n" + "b 85f\n" + "83:" // Height 4: full accumulate + "ldr q20, [x9, #0x0]\n" + "ldr q21, [x9, #0x10]\n" + "ldr q22, [x26, #0x0]\n" + "ldr q23, [x26, #0x10]\n" + "ldr q24, [x25, #0x0]\n" + "ldr q25, [x25, #0x10]\n" + "ldr q26, [x24, #0x0]\n" + "ldr q27, [x24, #0x10]\n" + "b 85f\n" + "84:" // Height 4: no accumulate + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "85:" // Height 4: setup done + "mov x28, #0x0\n" + "86:" // Height 4: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 87f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "ldr x24, [x20, #0x10]\n" + "ldr x23, [x20, #0x18]\n" + "cbnz x28, 88f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "add x24, x24, x20, LSL #2\n" + "add x23, x23, x20, LSL #2\n" + "b 88f\n" + "87:" // Height 4: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "add x24, x25, x21, LSL #2\n" + "add x23, x24, x21, LSL #2\n" + "88:" // Height 4: input setup done + "cmp x27, #0x4\n" + "blt 91f\n" + "ldr q0, [x26, #0x0]\n" + "ldr q1, [x25, #0x0]\n" + "cmp x27, #0x8\n" + "ldr q2, [x24, #0x0]\n" + "ldr q3, [x23, #0x0]\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "ldr q8, [x10, #0x20]\n" + "ldr q9, [x10, #0x30]\n" + "ldr q10, [x10, #0x40]\n" + "ldr q11, [x10, #0x50]\n" + "ldr q12, [x10, #0x60]\n" + "ldr q13, [x10, #0x70]\n" + "blt 90f\n" + "89:" // Height 4: Multiply loop: Main loop head + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "sub x27, x27, #0x4\n" + "add x26, x26, #0x10\n" + "fmla v24.4s, v6.4s, v2.s[0]\n" + "fmla v26.4s, v6.4s, v3.s[0]\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "fmla v23.4s, v7.4s, v1.s[0]\n" + "add x23, x23, #0x10\n" + "cmp x27, #0x8\n" + "fmla v25.4s, v7.4s, v2.s[0]\n" + "fmla v27.4s, v7.4s, v3.s[0]\n" + "add x10, x10, #0x80\n" + "prfm pldl1keep, [x26, #0x80]\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v20.4s, v8.4s, v0.s[1]\n" + "fmla v22.4s, v8.4s, v1.s[1]\n" + "fmla v24.4s, v8.4s, v2.s[1]\n" + "fmla v26.4s, v8.4s, v3.s[1]\n" + "ldr q8, [x10, #0x20]\n" + "prfm pldl1keep, [x25, #0x80]\n" + "fmla v21.4s, v9.4s, v0.s[1]\n" + "fmla v23.4s, v9.4s, v1.s[1]\n" + "prfm pldl1keep, [x24, #0x80]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "fmla v25.4s, v9.4s, v2.s[1]\n" + "fmla v27.4s, v9.4s, v3.s[1]\n" + "ldr q9, [x10, #0x30]\n" + "fmla v20.4s, v10.4s, v0.s[2]\n" + "fmla v22.4s, v10.4s, v1.s[2]\n" + "fmla v24.4s, v10.4s, v2.s[2]\n" + "fmla v26.4s, v10.4s, v3.s[2]\n" + "ldr q10, [x10, #0x40]\n" + "fmla v21.4s, v11.4s, v0.s[2]\n" + "fmla v23.4s, v11.4s, v1.s[2]\n" + "fmla v25.4s, v11.4s, v2.s[2]\n" + "fmla v27.4s, v11.4s, v3.s[2]\n" + "ldr q11, [x10, #0x50]\n" + "fmla v20.4s, v12.4s, v0.s[3]\n" + "fmla v22.4s, v12.4s, v1.s[3]\n" + "fmla v24.4s, v12.4s, v2.s[3]\n" + "fmla v26.4s, v12.4s, v3.s[3]\n" + "ldr q12, [x10, #0x60]\n" + "fmla v21.4s, v13.4s, v0.s[3]\n" + "ldr q0, [x26, #0x0]\n" + "fmla v23.4s, v13.4s, v1.s[3]\n" + "ldr q1, [x25, #0x0]\n" + "fmla v25.4s, v13.4s, v2.s[3]\n" + "ldr q2, [x24, #0x0]\n" + "fmla v27.4s, v13.4s, v3.s[3]\n" + "ldr q3, [x23, #0x0]\n" + "ldr q13, [x10, #0x70]\n" + "bge 89b\n" + "90:" // Height 4: Multiply loop: Single iteration only + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "add x26, x26, #0x10\n" + "add x25, x25, #0x10\n" + "fmla v24.4s, v6.4s, v2.s[0]\n" + "fmla v26.4s, v6.4s, v3.s[0]\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "fmla v23.4s, v7.4s, v1.s[0]\n" + "sub x27, x27, #0x4\n" + "prfm pldl1keep, [x26, #0x80]\n" + "fmla v25.4s, v7.4s, v2.s[0]\n" + "fmla v27.4s, v7.4s, v3.s[0]\n" + "prfm pldl1keep, [x25, #0x80]\n" + "prfm pldl1keep, [x24, #0x80]\n" + "fmla v20.4s, v8.4s, v0.s[1]\n" + "fmla v22.4s, v8.4s, v1.s[1]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "add x10, x10, #0x80\n" + "fmla v24.4s, v8.4s, v2.s[1]\n" + "fmla v26.4s, v8.4s, v3.s[1]\n" + "fmla v21.4s, v9.4s, v0.s[1]\n" + "fmla v23.4s, v9.4s, v1.s[1]\n" + "fmla v25.4s, v9.4s, v2.s[1]\n" + "fmla v27.4s, v9.4s, v3.s[1]\n" + "fmla v20.4s, v10.4s, v0.s[2]\n" + "fmla v22.4s, v10.4s, v1.s[2]\n" + "fmla v24.4s, v10.4s, v2.s[2]\n" + "fmla v26.4s, v10.4s, v3.s[2]\n" + "fmla v21.4s, v11.4s, v0.s[2]\n" + "fmla v23.4s, v11.4s, v1.s[2]\n" + "fmla v25.4s, v11.4s, v2.s[2]\n" + "fmla v27.4s, v11.4s, v3.s[2]\n" + "fmla v20.4s, v12.4s, v0.s[3]\n" + "fmla v22.4s, v12.4s, v1.s[3]\n" + "fmla v24.4s, v12.4s, v2.s[3]\n" + "fmla v26.4s, v12.4s, v3.s[3]\n" + "fmla v21.4s, v13.4s, v0.s[3]\n" + "fmla v23.4s, v13.4s, v1.s[3]\n" + "fmla v25.4s, v13.4s, v2.s[3]\n" + "fmla v27.4s, v13.4s, v3.s[3]\n" + "91:" // Height 4: Multiply loop: Main loop skip + "cbz x27, 93f\n" + "92:" // Height 4: Multiply loop: Odd block loop + "ldr s29, [x26], #0x4\n" + "ldr s28, [x25], #0x4\n" + "sub x27, x27, #0x1\n" + "ldr s19, [x24], #0x4\n" + "ldr s18, [x23], #0x4\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "fmla v20.4s, v17.4s, v29.s[0]\n" + "fmla v22.4s, v17.4s, v28.s[0]\n" + "fmla v24.4s, v17.4s, v19.s[0]\n" + "fmla v26.4s, v17.4s, v18.s[0]\n" + "fmla v21.4s, v16.4s, v29.s[0]\n" + "fmla v23.4s, v16.4s, v28.s[0]\n" + "fmla v25.4s, v16.4s, v19.s[0]\n" + "fmla v27.4s, v16.4s, v18.s[0]\n" + "cbnz x27, 92b\n" + "93:" // Height 4: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 86b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "prfm pstl1keep, [x9, #0x0]\n" + "add x26, x9, x20, LSL #2\n" + "prfm pstl1keep, [x26, #0x0]\n" + "add x25, x26, x20, LSL #2\n" + "prfm pstl1keep, [x25, #0x0]\n" + "add x24, x25, x20, LSL #2\n" + "prfm pstl1keep, [x24, #0x0]\n" + "tbz %x[flags], #1, 94f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v17.4s }, [x21]\n" + "ld1r { v16.4s }, [x20]\n" + "fmin v20.4s, v20.4s, v17.4s\n" + "fmin v21.4s, v21.4s, v17.4s\n" + "fmin v22.4s, v22.4s, v17.4s\n" + "fmin v23.4s, v23.4s, v17.4s\n" + "fmin v24.4s, v24.4s, v17.4s\n" + "fmin v25.4s, v25.4s, v17.4s\n" + "fmin v26.4s, v26.4s, v17.4s\n" + "fmin v27.4s, v27.4s, v17.4s\n" + "fmax v20.4s, v20.4s, v16.4s\n" + "fmax v21.4s, v21.4s, v16.4s\n" + "fmax v22.4s, v22.4s, v16.4s\n" + "fmax v23.4s, v23.4s, v16.4s\n" + "fmax v24.4s, v24.4s, v16.4s\n" + "fmax v25.4s, v25.4s, v16.4s\n" + "fmax v26.4s, v26.4s, v16.4s\n" + "fmax v27.4s, v27.4s, v16.4s\n" + "94:" // Height 4: No activation + "cmp x11, #0x8\n" + "bge 99f\n" + "tbz x11, #2, 96f\n" + "st1 { v20.4s }, [x9], #0x10\n" + "st1 { v22.4s }, [x26], #0x10\n" + "st1 { v24.4s }, [x25], #0x10\n" + "st1 { v26.4s }, [x24], #0x10\n" + "tbz x11, #1, 95f\n" + "str d21, [x9], #0x8\n" + "str d23, [x26], #0x8\n" + "str d25, [x25], #0x8\n" + "str d27, [x24], #0x8\n" + "tbz x11, #0, 98f\n" + "st1 { v21.s }[2], [x9]\n" + "st1 { v23.s }[2], [x26]\n" + "st1 { v25.s }[2], [x25]\n" + "st1 { v27.s }[2], [x24]\n" + "b 98f\n" + "95:" // Height 4: Partial direct writeback: partial_1_4 + "tbz x11, #0, 98f\n" + "str s21, [x9, #0x0]\n" + "str s23, [x26, #0x0]\n" + "str s25, [x25, #0x0]\n" + "str s27, [x24, #0x0]\n" + "b 98f\n" + "96:" // Height 4: Partial direct writeback: partial_2_0 + "tbz x11, #1, 97f\n" + "str d20, [x9], #0x8\n" + "str d22, [x26], #0x8\n" + "str d24, [x25], #0x8\n" + "str d26, [x24], #0x8\n" + "tbz x11, #0, 98f\n" + "st1 { v20.s }[2], [x9]\n" + "st1 { v22.s }[2], [x26]\n" + "st1 { v24.s }[2], [x25]\n" + "st1 { v26.s }[2], [x24]\n" + "b 98f\n" + "97:" // Height 4: Partial direct writeback: partial_1_0 + "str s20, [x9, #0x0]\n" + "str s22, [x26, #0x0]\n" + "str s24, [x25, #0x0]\n" + "str s26, [x24, #0x0]\n" + "98:" // Height 4: Partial direct writeback: Done + "b 100f\n" + "99:" // Height 4: Full writeback + "str q20, [x9, #0x0]\n" + "str q21, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + "str q22, [x26, #0x0]\n" + "str q23, [x26, #0x10]\n" + "str q24, [x25, #0x0]\n" + "str q25, [x25, #0x10]\n" + "str q26, [x24, #0x0]\n" + "str q27, [x24, #0x10]\n" + "100:" // Height 4: Writeback done + "subs x11, x11, #0x8\n" + "bgt 77b\n" + "b 152f\n" + "101:" // Height 5 + "ldr x11, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x10, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x9, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "102:" // Height 5: Column loop + "cbz x10, 103f\n" + "ldr q20, [x10, #0x0]\n" + "ldr q21, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "mov v22.16b, v20.16b\n" + "mov v23.16b, v21.16b\n" + "mov v24.16b, v20.16b\n" + "mov v25.16b, v21.16b\n" + "mov v26.16b, v20.16b\n" + "mov v27.16b, v21.16b\n" + "mov v28.16b, v20.16b\n" + "mov v29.16b, v21.16b\n" + "b 110f\n" + "103:" // Height 5: no bias + "tbz %x[flags], #0, 109f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x11, #0x8\n" + "add x26, x9, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "add x23, x24, x20, LSL #2\n" + "bge 108f\n" + "tbz x11, #2, 105f\n" + "ld1 { v20.4s }, [x9], #0x10\n" + "ld1 { v22.4s }, [x26], #0x10\n" + "ld1 { v24.4s }, [x25], #0x10\n" + "ld1 { v26.4s }, [x24], #0x10\n" + "ld1 { v28.4s }, [x23], #0x10\n" + "tbz x11, #1, 104f\n" + "ldr d21, [x9], #0x8\n" + "ldr d23, [x26], #0x8\n" + "mov x20, #0x18\n" + "ldr d25, [x25], #0x8\n" + "ldr d27, [x24], #0x8\n" + "ldr d29, [x23], #0x8\n" + "tbz x11, #0, 107f\n" + "ld1 { v21.s }[2], [x9]\n" + "ld1 { v23.s }[2], [x26]\n" + "ld1 { v25.s }[2], [x25]\n" + "ld1 { v27.s }[2], [x24]\n" + "ld1 { v29.s }[2], [x23]\n" + "b 107f\n" + "104:" // Height 5: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x11, #0, 107f\n" + "ldr s21, [x9, #0x0]\n" + "ldr s23, [x26, #0x0]\n" + "ldr s25, [x25, #0x0]\n" + "ldr s27, [x24, #0x0]\n" + "ldr s29, [x23, #0x0]\n" + "b 107f\n" + "105:" // Height 5: Partial accumulate: partial_2_0 + "tbz x11, #1, 106f\n" + "ldr d20, [x9], #0x8\n" + "ldr d22, [x26], #0x8\n" + "mov x20, #0x8\n" + "ldr d24, [x25], #0x8\n" + "ldr d26, [x24], #0x8\n" + "ldr d28, [x23], #0x8\n" + "tbz x11, #0, 107f\n" + "ld1 { v20.s }[2], [x9]\n" + "ld1 { v22.s }[2], [x26]\n" + "ld1 { v24.s }[2], [x25]\n" + "ld1 { v26.s }[2], [x24]\n" + "ld1 { v28.s }[2], [x23]\n" + "b 107f\n" + "106:" // Height 5: Partial accumulate: partial_1_0 + "ldr s20, [x9, #0x0]\n" + "ldr s22, [x26, #0x0]\n" + "mov x20, #0x0\n" + "ldr s24, [x25, #0x0]\n" + "ldr s26, [x24, #0x0]\n" + "ldr s28, [x23, #0x0]\n" + "107:" // Height 5: Partial accumulate: Done + "sub x9, x9, x20\n" + "b 110f\n" + "108:" // Height 5: full accumulate + "ldr q20, [x9, #0x0]\n" + "ldr q21, [x9, #0x10]\n" + "ldr q22, [x26, #0x0]\n" + "ldr q23, [x26, #0x10]\n" + "ldr q24, [x25, #0x0]\n" + "ldr q25, [x25, #0x10]\n" + "ldr q26, [x24, #0x0]\n" + "ldr q27, [x24, #0x10]\n" + "ldr q28, [x23, #0x0]\n" + "ldr q29, [x23, #0x10]\n" + "b 110f\n" + "109:" // Height 5: no accumulate + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "110:" // Height 5: setup done + "mov x28, #0x0\n" + "111:" // Height 5: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 112f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "ldr x24, [x20, #0x10]\n" + "ldr x23, [x20, #0x18]\n" + "ldr x22, [x20, #0x20]\n" + "cbnz x28, 113f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "add x24, x24, x20, LSL #2\n" + "add x23, x23, x20, LSL #2\n" + "add x22, x22, x20, LSL #2\n" + "b 113f\n" + "112:" // Height 5: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "add x24, x25, x21, LSL #2\n" + "add x23, x24, x21, LSL #2\n" + "add x22, x23, x21, LSL #2\n" + "113:" // Height 5: input setup done + "cmp x27, #0x4\n" + "blt 116f\n" + "ldr q0, [x26, #0x0]\n" + "ldr q1, [x25, #0x0]\n" + "cmp x27, #0x8\n" + "ldr q2, [x24, #0x0]\n" + "ldr q3, [x23, #0x0]\n" + "ldr q4, [x22, #0x0]\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "ldr q8, [x10, #0x20]\n" + "ldr q9, [x10, #0x30]\n" + "ldr q10, [x10, #0x40]\n" + "ldr q11, [x10, #0x50]\n" + "ldr q12, [x10, #0x60]\n" + "ldr q13, [x10, #0x70]\n" + "blt 115f\n" + "114:" // Height 5: Multiply loop: Main loop head + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "sub x27, x27, #0x4\n" + "add x26, x26, #0x10\n" + "fmla v24.4s, v6.4s, v2.s[0]\n" + "fmla v26.4s, v6.4s, v3.s[0]\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "fmla v28.4s, v6.4s, v4.s[0]\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "add x23, x23, #0x10\n" + "add x22, x22, #0x10\n" + "fmla v23.4s, v7.4s, v1.s[0]\n" + "fmla v25.4s, v7.4s, v2.s[0]\n" + "cmp x27, #0x8\n" + "add x10, x10, #0x80\n" + "ldr q6, [x10, #0x0]\n" + "fmla v27.4s, v7.4s, v3.s[0]\n" + "fmla v29.4s, v7.4s, v4.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v20.4s, v8.4s, v0.s[1]\n" + "fmla v22.4s, v8.4s, v1.s[1]\n" + "prfm pldl1keep, [x26, #0x80]\n" + "prfm pldl1keep, [x25, #0x80]\n" + "fmla v24.4s, v8.4s, v2.s[1]\n" + "fmla v26.4s, v8.4s, v3.s[1]\n" + "prfm pldl1keep, [x24, #0x80]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "fmla v28.4s, v8.4s, v4.s[1]\n" + "ldr q8, [x10, #0x20]\n" + "fmla v21.4s, v9.4s, v0.s[1]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "fmla v23.4s, v9.4s, v1.s[1]\n" + "fmla v25.4s, v9.4s, v2.s[1]\n" + "fmla v27.4s, v9.4s, v3.s[1]\n" + "fmla v29.4s, v9.4s, v4.s[1]\n" + "ldr q9, [x10, #0x30]\n" + "fmla v20.4s, v10.4s, v0.s[2]\n" + "fmla v22.4s, v10.4s, v1.s[2]\n" + "fmla v24.4s, v10.4s, v2.s[2]\n" + "fmla v26.4s, v10.4s, v3.s[2]\n" + "fmla v28.4s, v10.4s, v4.s[2]\n" + "ldr q10, [x10, #0x40]\n" + "fmla v21.4s, v11.4s, v0.s[2]\n" + "fmla v23.4s, v11.4s, v1.s[2]\n" + "fmla v25.4s, v11.4s, v2.s[2]\n" + "fmla v27.4s, v11.4s, v3.s[2]\n" + "fmla v29.4s, v11.4s, v4.s[2]\n" + "ldr q11, [x10, #0x50]\n" + "fmla v20.4s, v12.4s, v0.s[3]\n" + "fmla v22.4s, v12.4s, v1.s[3]\n" + "fmla v24.4s, v12.4s, v2.s[3]\n" + "fmla v26.4s, v12.4s, v3.s[3]\n" + "fmla v28.4s, v12.4s, v4.s[3]\n" + "ldr q12, [x10, #0x60]\n" + "fmla v21.4s, v13.4s, v0.s[3]\n" + "ldr q0, [x26, #0x0]\n" + "fmla v23.4s, v13.4s, v1.s[3]\n" + "ldr q1, [x25, #0x0]\n" + "fmla v25.4s, v13.4s, v2.s[3]\n" + "ldr q2, [x24, #0x0]\n" + "fmla v27.4s, v13.4s, v3.s[3]\n" + "ldr q3, [x23, #0x0]\n" + "fmla v29.4s, v13.4s, v4.s[3]\n" + "ldr q4, [x22, #0x0]\n" + "ldr q13, [x10, #0x70]\n" + "bge 114b\n" + "115:" // Height 5: Multiply loop: Single iteration only + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "add x26, x26, #0x10\n" + "add x25, x25, #0x10\n" + "fmla v24.4s, v6.4s, v2.s[0]\n" + "fmla v26.4s, v6.4s, v3.s[0]\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "fmla v28.4s, v6.4s, v4.s[0]\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "add x22, x22, #0x10\n" + "sub x27, x27, #0x4\n" + "fmla v23.4s, v7.4s, v1.s[0]\n" + "fmla v25.4s, v7.4s, v2.s[0]\n" + "prfm pldl1keep, [x26, #0x80]\n" + "prfm pldl1keep, [x25, #0x80]\n" + "fmla v27.4s, v7.4s, v3.s[0]\n" + "fmla v29.4s, v7.4s, v4.s[0]\n" + "prfm pldl1keep, [x24, #0x80]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "fmla v20.4s, v8.4s, v0.s[1]\n" + "fmla v22.4s, v8.4s, v1.s[1]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "add x10, x10, #0x80\n" + "fmla v24.4s, v8.4s, v2.s[1]\n" + "fmla v26.4s, v8.4s, v3.s[1]\n" + "fmla v28.4s, v8.4s, v4.s[1]\n" + "fmla v21.4s, v9.4s, v0.s[1]\n" + "fmla v23.4s, v9.4s, v1.s[1]\n" + "fmla v25.4s, v9.4s, v2.s[1]\n" + "fmla v27.4s, v9.4s, v3.s[1]\n" + "fmla v29.4s, v9.4s, v4.s[1]\n" + "fmla v20.4s, v10.4s, v0.s[2]\n" + "fmla v22.4s, v10.4s, v1.s[2]\n" + "fmla v24.4s, v10.4s, v2.s[2]\n" + "fmla v26.4s, v10.4s, v3.s[2]\n" + "fmla v28.4s, v10.4s, v4.s[2]\n" + "fmla v21.4s, v11.4s, v0.s[2]\n" + "fmla v23.4s, v11.4s, v1.s[2]\n" + "fmla v25.4s, v11.4s, v2.s[2]\n" + "fmla v27.4s, v11.4s, v3.s[2]\n" + "fmla v29.4s, v11.4s, v4.s[2]\n" + "fmla v20.4s, v12.4s, v0.s[3]\n" + "fmla v22.4s, v12.4s, v1.s[3]\n" + "fmla v24.4s, v12.4s, v2.s[3]\n" + "fmla v26.4s, v12.4s, v3.s[3]\n" + "fmla v28.4s, v12.4s, v4.s[3]\n" + "fmla v21.4s, v13.4s, v0.s[3]\n" + "fmla v23.4s, v13.4s, v1.s[3]\n" + "fmla v25.4s, v13.4s, v2.s[3]\n" + "fmla v27.4s, v13.4s, v3.s[3]\n" + "fmla v29.4s, v13.4s, v4.s[3]\n" + "116:" // Height 5: Multiply loop: Main loop skip + "cbz x27, 118f\n" + "117:" // Height 5: Multiply loop: Odd block loop + "ldr s0, [x26], #0x4\n" + "ldr s31, [x25], #0x4\n" + "sub x27, x27, #0x1\n" + "ldr s30, [x24], #0x4\n" + "ldr s19, [x23], #0x4\n" + "ldr s18, [x22], #0x4\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "fmla v20.4s, v17.4s, v0.s[0]\n" + "fmla v22.4s, v17.4s, v31.s[0]\n" + "fmla v24.4s, v17.4s, v30.s[0]\n" + "fmla v26.4s, v17.4s, v19.s[0]\n" + "fmla v28.4s, v17.4s, v18.s[0]\n" + "fmla v21.4s, v16.4s, v0.s[0]\n" + "fmla v23.4s, v16.4s, v31.s[0]\n" + "fmla v25.4s, v16.4s, v30.s[0]\n" + "fmla v27.4s, v16.4s, v19.s[0]\n" + "fmla v29.4s, v16.4s, v18.s[0]\n" + "cbnz x27, 117b\n" + "118:" // Height 5: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 111b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "prfm pstl1keep, [x9, #0x0]\n" + "add x26, x9, x20, LSL #2\n" + "prfm pstl1keep, [x26, #0x0]\n" + "add x25, x26, x20, LSL #2\n" + "prfm pstl1keep, [x25, #0x0]\n" + "add x24, x25, x20, LSL #2\n" + "prfm pstl1keep, [x24, #0x0]\n" + "add x23, x24, x20, LSL #2\n" + "prfm pstl1keep, [x23, #0x0]\n" + "tbz %x[flags], #1, 119f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v17.4s }, [x21]\n" + "ld1r { v16.4s }, [x20]\n" + "fmin v20.4s, v20.4s, v17.4s\n" + "fmin v21.4s, v21.4s, v17.4s\n" + "fmin v22.4s, v22.4s, v17.4s\n" + "fmin v23.4s, v23.4s, v17.4s\n" + "fmin v24.4s, v24.4s, v17.4s\n" + "fmin v25.4s, v25.4s, v17.4s\n" + "fmin v26.4s, v26.4s, v17.4s\n" + "fmin v27.4s, v27.4s, v17.4s\n" + "fmin v28.4s, v28.4s, v17.4s\n" + "fmin v29.4s, v29.4s, v17.4s\n" + "fmax v20.4s, v20.4s, v16.4s\n" + "fmax v21.4s, v21.4s, v16.4s\n" + "fmax v22.4s, v22.4s, v16.4s\n" + "fmax v23.4s, v23.4s, v16.4s\n" + "fmax v24.4s, v24.4s, v16.4s\n" + "fmax v25.4s, v25.4s, v16.4s\n" + "fmax v26.4s, v26.4s, v16.4s\n" + "fmax v27.4s, v27.4s, v16.4s\n" + "fmax v28.4s, v28.4s, v16.4s\n" + "fmax v29.4s, v29.4s, v16.4s\n" + "119:" // Height 5: No activation + "cmp x11, #0x8\n" + "bge 124f\n" + "tbz x11, #2, 121f\n" + "st1 { v20.4s }, [x9], #0x10\n" + "st1 { v22.4s }, [x26], #0x10\n" + "st1 { v24.4s }, [x25], #0x10\n" + "st1 { v26.4s }, [x24], #0x10\n" + "st1 { v28.4s }, [x23], #0x10\n" + "tbz x11, #1, 120f\n" + "str d21, [x9], #0x8\n" + "str d23, [x26], #0x8\n" + "str d25, [x25], #0x8\n" + "str d27, [x24], #0x8\n" + "str d29, [x23], #0x8\n" + "tbz x11, #0, 123f\n" + "st1 { v21.s }[2], [x9]\n" + "st1 { v23.s }[2], [x26]\n" + "st1 { v25.s }[2], [x25]\n" + "st1 { v27.s }[2], [x24]\n" + "st1 { v29.s }[2], [x23]\n" + "b 123f\n" + "120:" // Height 5: Partial direct writeback: partial_1_4 + "tbz x11, #0, 123f\n" + "str s21, [x9, #0x0]\n" + "str s23, [x26, #0x0]\n" + "str s25, [x25, #0x0]\n" + "str s27, [x24, #0x0]\n" + "str s29, [x23, #0x0]\n" + "b 123f\n" + "121:" // Height 5: Partial direct writeback: partial_2_0 + "tbz x11, #1, 122f\n" + "str d20, [x9], #0x8\n" + "str d22, [x26], #0x8\n" + "str d24, [x25], #0x8\n" + "str d26, [x24], #0x8\n" + "str d28, [x23], #0x8\n" + "tbz x11, #0, 123f\n" + "st1 { v20.s }[2], [x9]\n" + "st1 { v22.s }[2], [x26]\n" + "st1 { v24.s }[2], [x25]\n" + "st1 { v26.s }[2], [x24]\n" + "st1 { v28.s }[2], [x23]\n" + "b 123f\n" + "122:" // Height 5: Partial direct writeback: partial_1_0 + "str s20, [x9, #0x0]\n" + "str s22, [x26, #0x0]\n" + "str s24, [x25, #0x0]\n" + "str s26, [x24, #0x0]\n" + "str s28, [x23, #0x0]\n" + "123:" // Height 5: Partial direct writeback: Done + "b 125f\n" + "124:" // Height 5: Full writeback + "str q20, [x9, #0x0]\n" + "str q21, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + "str q22, [x26, #0x0]\n" + "str q23, [x26, #0x10]\n" + "str q24, [x25, #0x0]\n" + "str q25, [x25, #0x10]\n" + "str q26, [x24, #0x0]\n" + "str q27, [x24, #0x10]\n" + "str q28, [x23, #0x0]\n" + "str q29, [x23, #0x10]\n" + "125:" // Height 5: Writeback done + "subs x11, x11, #0x8\n" + "bgt 102b\n" + "b 152f\n" + "126:" // Height 6 + "ldr x21, [%x[args_ptr], %[offsetof_output_offset]]\n" + "ldr x9, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "mov x20, #0x18\n" + "ldr x11, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x10, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "madd x20, x21, x20, x9\n" + "str x20, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "127:" // Height 6: Column loop + "cbz x10, 128f\n" + "ldr q20, [x10, #0x0]\n" + "ldr q21, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "mov v22.16b, v20.16b\n" + "mov v23.16b, v21.16b\n" + "mov v24.16b, v20.16b\n" + "mov v25.16b, v21.16b\n" + "mov v26.16b, v20.16b\n" + "mov v27.16b, v21.16b\n" + "mov v28.16b, v20.16b\n" + "mov v29.16b, v21.16b\n" + "mov v30.16b, v20.16b\n" + "mov v31.16b, v21.16b\n" + "b 135f\n" + "128:" // Height 6: no bias + "tbz %x[flags], #0, 134f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x11, #0x8\n" + "add x26, x9, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "add x23, x24, x20, LSL #2\n" + "add x22, x23, x20, LSL #2\n" + "bge 133f\n" + "tbz x11, #2, 130f\n" + "ld1 { v20.4s }, [x9], #0x10\n" + "ld1 { v22.4s }, [x26], #0x10\n" + "ld1 { v24.4s }, [x25], #0x10\n" + "ld1 { v26.4s }, [x24], #0x10\n" + "ld1 { v28.4s }, [x23], #0x10\n" + "ld1 { v30.4s }, [x22], #0x10\n" + "tbz x11, #1, 129f\n" + "ldr d21, [x9], #0x8\n" + "ldr d23, [x26], #0x8\n" + "mov x20, #0x18\n" + "ldr d25, [x25], #0x8\n" + "ldr d27, [x24], #0x8\n" + "ldr d29, [x23], #0x8\n" + "ldr d31, [x22], #0x8\n" + "tbz x11, #0, 132f\n" + "ld1 { v21.s }[2], [x9]\n" + "ld1 { v23.s }[2], [x26]\n" + "ld1 { v25.s }[2], [x25]\n" + "ld1 { v27.s }[2], [x24]\n" + "ld1 { v29.s }[2], [x23]\n" + "ld1 { v31.s }[2], [x22]\n" + "b 132f\n" + "129:" // Height 6: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x11, #0, 132f\n" + "ldr s21, [x9, #0x0]\n" + "ldr s23, [x26, #0x0]\n" + "ldr s25, [x25, #0x0]\n" + "ldr s27, [x24, #0x0]\n" + "ldr s29, [x23, #0x0]\n" + "ldr s31, [x22, #0x0]\n" + "b 132f\n" + "130:" // Height 6: Partial accumulate: partial_2_0 + "tbz x11, #1, 131f\n" + "ldr d20, [x9], #0x8\n" + "ldr d22, [x26], #0x8\n" + "mov x20, #0x8\n" + "ldr d24, [x25], #0x8\n" + "ldr d26, [x24], #0x8\n" + "ldr d28, [x23], #0x8\n" + "ldr d30, [x22], #0x8\n" + "tbz x11, #0, 132f\n" + "ld1 { v20.s }[2], [x9]\n" + "ld1 { v22.s }[2], [x26]\n" + "ld1 { v24.s }[2], [x25]\n" + "ld1 { v26.s }[2], [x24]\n" + "ld1 { v28.s }[2], [x23]\n" + "ld1 { v30.s }[2], [x22]\n" + "b 132f\n" + "131:" // Height 6: Partial accumulate: partial_1_0 + "ldr s20, [x9, #0x0]\n" + "ldr s22, [x26, #0x0]\n" + "mov x20, #0x0\n" + "ldr s24, [x25, #0x0]\n" + "ldr s26, [x24, #0x0]\n" + "ldr s28, [x23, #0x0]\n" + "ldr s30, [x22, #0x0]\n" + "132:" // Height 6: Partial accumulate: Done + "sub x9, x9, x20\n" + "b 135f\n" + "133:" // Height 6: full accumulate + "ldr q20, [x9, #0x0]\n" + "ldr q21, [x9, #0x10]\n" + "ldr q22, [x26, #0x0]\n" + "ldr q23, [x26, #0x10]\n" + "ldr q24, [x25, #0x0]\n" + "ldr q25, [x25, #0x10]\n" + "ldr q26, [x24, #0x0]\n" + "ldr q27, [x24, #0x10]\n" + "ldr q28, [x23, #0x0]\n" + "ldr q29, [x23, #0x10]\n" + "ldr q30, [x22, #0x0]\n" + "ldr q31, [x22, #0x10]\n" + "b 135f\n" + "134:" // Height 6: no accumulate + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "135:" // Height 6: setup done + "mov x28, #0x0\n" + "136:" // Height 6: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 137f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "ldr x24, [x20, #0x10]\n" + "ldr x23, [x20, #0x18]\n" + "ldr x22, [x20, #0x20]\n" + "ldr x21, [x20, #0x28]\n" + "cbnz x28, 138f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "add x24, x24, x20, LSL #2\n" + "add x23, x23, x20, LSL #2\n" + "add x22, x22, x20, LSL #2\n" + "add x21, x21, x20, LSL #2\n" + "b 138f\n" + "137:" // Height 6: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "add x24, x25, x21, LSL #2\n" + "add x23, x24, x21, LSL #2\n" + "add x22, x23, x21, LSL #2\n" + "add x21, x22, x21, LSL #2\n" + "138:" // Height 6: input setup done + "cmp x27, #0x4\n" + "blt 141f\n" + "ldr q0, [x26, #0x0]\n" + "ldr q1, [x25, #0x0]\n" + "cmp x27, #0x8\n" + "ldr q2, [x24, #0x0]\n" + "ldr q3, [x23, #0x0]\n" + "ldr q4, [x22, #0x0]\n" + "ldr q5, [x21, #0x0]\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "ldr q8, [x10, #0x20]\n" + "ldr q9, [x10, #0x30]\n" + "ldr q10, [x10, #0x40]\n" + "ldr q11, [x10, #0x50]\n" + "ldr q12, [x10, #0x60]\n" + "ldr q13, [x10, #0x70]\n" + "blt 140f\n" + "139:" // Height 6: Multiply loop: Main loop head + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "sub x27, x27, #0x4\n" + "add x26, x26, #0x10\n" + "fmla v24.4s, v6.4s, v2.s[0]\n" + "fmla v26.4s, v6.4s, v3.s[0]\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "fmla v28.4s, v6.4s, v4.s[0]\n" + "fmla v30.4s, v6.4s, v5.s[0]\n" + "add x23, x23, #0x10\n" + "add x22, x22, #0x10\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "fmla v23.4s, v7.4s, v1.s[0]\n" + "add x21, x21, #0x10\n" + "cmp x27, #0x8\n" + "fmla v25.4s, v7.4s, v2.s[0]\n" + "fmla v27.4s, v7.4s, v3.s[0]\n" + "add x10, x10, #0x80\n" + "prfm pldl1keep, [x26, #0x80]\n" + "ldr q6, [x10, #0x0]\n" + "fmla v29.4s, v7.4s, v4.s[0]\n" + "fmla v31.4s, v7.4s, v5.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v20.4s, v8.4s, v0.s[1]\n" + "fmla v22.4s, v8.4s, v1.s[1]\n" + "prfm pldl1keep, [x25, #0x80]\n" + "prfm pldl1keep, [x24, #0x80]\n" + "fmla v24.4s, v8.4s, v2.s[1]\n" + "fmla v26.4s, v8.4s, v3.s[1]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "fmla v28.4s, v8.4s, v4.s[1]\n" + "fmla v30.4s, v8.4s, v5.s[1]\n" + "ldr q8, [x10, #0x20]\n" + "prfm pldl1keep, [x21, #0x80]\n" + "fmla v21.4s, v9.4s, v0.s[1]\n" + "fmla v23.4s, v9.4s, v1.s[1]\n" + "fmla v25.4s, v9.4s, v2.s[1]\n" + "fmla v27.4s, v9.4s, v3.s[1]\n" + "fmla v29.4s, v9.4s, v4.s[1]\n" + "fmla v31.4s, v9.4s, v5.s[1]\n" + "ldr q9, [x10, #0x30]\n" + "fmla v20.4s, v10.4s, v0.s[2]\n" + "fmla v22.4s, v10.4s, v1.s[2]\n" + "fmla v24.4s, v10.4s, v2.s[2]\n" + "fmla v26.4s, v10.4s, v3.s[2]\n" + "fmla v28.4s, v10.4s, v4.s[2]\n" + "fmla v30.4s, v10.4s, v5.s[2]\n" + "ldr q10, [x10, #0x40]\n" + "fmla v21.4s, v11.4s, v0.s[2]\n" + "fmla v23.4s, v11.4s, v1.s[2]\n" + "fmla v25.4s, v11.4s, v2.s[2]\n" + "fmla v27.4s, v11.4s, v3.s[2]\n" + "fmla v29.4s, v11.4s, v4.s[2]\n" + "fmla v31.4s, v11.4s, v5.s[2]\n" + "ldr q11, [x10, #0x50]\n" + "fmla v20.4s, v12.4s, v0.s[3]\n" + "fmla v22.4s, v12.4s, v1.s[3]\n" + "fmla v24.4s, v12.4s, v2.s[3]\n" + "fmla v26.4s, v12.4s, v3.s[3]\n" + "fmla v28.4s, v12.4s, v4.s[3]\n" + "fmla v30.4s, v12.4s, v5.s[3]\n" + "ldr q12, [x10, #0x60]\n" + "fmla v21.4s, v13.4s, v0.s[3]\n" + "ldr q0, [x26, #0x0]\n" + "fmla v23.4s, v13.4s, v1.s[3]\n" + "ldr q1, [x25, #0x0]\n" + "fmla v25.4s, v13.4s, v2.s[3]\n" + "ldr q2, [x24, #0x0]\n" + "fmla v27.4s, v13.4s, v3.s[3]\n" + "ldr q3, [x23, #0x0]\n" + "fmla v29.4s, v13.4s, v4.s[3]\n" + "ldr q4, [x22, #0x0]\n" + "fmla v31.4s, v13.4s, v5.s[3]\n" + "ldr q5, [x21, #0x0]\n" + "ldr q13, [x10, #0x70]\n" + "bge 139b\n" + "140:" // Height 6: Multiply loop: Single iteration only + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "add x26, x26, #0x10\n" + "add x25, x25, #0x10\n" + "fmla v24.4s, v6.4s, v2.s[0]\n" + "fmla v26.4s, v6.4s, v3.s[0]\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "fmla v28.4s, v6.4s, v4.s[0]\n" + "fmla v30.4s, v6.4s, v5.s[0]\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "fmla v23.4s, v7.4s, v1.s[0]\n" + "sub x27, x27, #0x4\n" + "prfm pldl1keep, [x26, #0x80]\n" + "fmla v25.4s, v7.4s, v2.s[0]\n" + "fmla v27.4s, v7.4s, v3.s[0]\n" + "prfm pldl1keep, [x25, #0x80]\n" + "prfm pldl1keep, [x24, #0x80]\n" + "fmla v29.4s, v7.4s, v4.s[0]\n" + "fmla v31.4s, v7.4s, v5.s[0]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "fmla v20.4s, v8.4s, v0.s[1]\n" + "fmla v22.4s, v8.4s, v1.s[1]\n" + "prfm pldl1keep, [x21, #0x80]\n" + "add x10, x10, #0x80\n" + "fmla v24.4s, v8.4s, v2.s[1]\n" + "fmla v26.4s, v8.4s, v3.s[1]\n" + "fmla v28.4s, v8.4s, v4.s[1]\n" + "fmla v30.4s, v8.4s, v5.s[1]\n" + "fmla v21.4s, v9.4s, v0.s[1]\n" + "fmla v23.4s, v9.4s, v1.s[1]\n" + "fmla v25.4s, v9.4s, v2.s[1]\n" + "fmla v27.4s, v9.4s, v3.s[1]\n" + "fmla v29.4s, v9.4s, v4.s[1]\n" + "fmla v31.4s, v9.4s, v5.s[1]\n" + "fmla v20.4s, v10.4s, v0.s[2]\n" + "fmla v22.4s, v10.4s, v1.s[2]\n" + "fmla v24.4s, v10.4s, v2.s[2]\n" + "fmla v26.4s, v10.4s, v3.s[2]\n" + "fmla v28.4s, v10.4s, v4.s[2]\n" + "fmla v30.4s, v10.4s, v5.s[2]\n" + "fmla v21.4s, v11.4s, v0.s[2]\n" + "fmla v23.4s, v11.4s, v1.s[2]\n" + "fmla v25.4s, v11.4s, v2.s[2]\n" + "fmla v27.4s, v11.4s, v3.s[2]\n" + "fmla v29.4s, v11.4s, v4.s[2]\n" + "fmla v31.4s, v11.4s, v5.s[2]\n" + "fmla v20.4s, v12.4s, v0.s[3]\n" + "fmla v22.4s, v12.4s, v1.s[3]\n" + "fmla v24.4s, v12.4s, v2.s[3]\n" + "fmla v26.4s, v12.4s, v3.s[3]\n" + "fmla v28.4s, v12.4s, v4.s[3]\n" + "fmla v30.4s, v12.4s, v5.s[3]\n" + "fmla v21.4s, v13.4s, v0.s[3]\n" + "fmla v23.4s, v13.4s, v1.s[3]\n" + "fmla v25.4s, v13.4s, v2.s[3]\n" + "fmla v27.4s, v13.4s, v3.s[3]\n" + "fmla v29.4s, v13.4s, v4.s[3]\n" + "fmla v31.4s, v13.4s, v5.s[3]\n" + "141:" // Height 6: Multiply loop: Main loop skip + "cbz x27, 143f\n" + "142:" // Height 6: Multiply loop: Odd block loop + "ldr s3, [x26], #0x4\n" + "ldr s2, [x25], #0x4\n" + "sub x27, x27, #0x1\n" + "ldr s1, [x24], #0x4\n" + "ldr s0, [x23], #0x4\n" + "ldr s19, [x22], #0x4\n" + "ldr s18, [x21], #0x4\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "fmla v20.4s, v17.4s, v3.s[0]\n" + "fmla v22.4s, v17.4s, v2.s[0]\n" + "fmla v24.4s, v17.4s, v1.s[0]\n" + "fmla v26.4s, v17.4s, v0.s[0]\n" + "fmla v28.4s, v17.4s, v19.s[0]\n" + "fmla v30.4s, v17.4s, v18.s[0]\n" + "fmla v21.4s, v16.4s, v3.s[0]\n" + "fmla v23.4s, v16.4s, v2.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v27.4s, v16.4s, v0.s[0]\n" + "fmla v29.4s, v16.4s, v19.s[0]\n" + "fmla v31.4s, v16.4s, v18.s[0]\n" + "cbnz x27, 142b\n" + "143:" // Height 6: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 136b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "prfm pstl1keep, [x9, #0x0]\n" + "add x26, x9, x20, LSL #2\n" + "prfm pstl1keep, [x26, #0x0]\n" + "add x25, x26, x20, LSL #2\n" + "prfm pstl1keep, [x25, #0x0]\n" + "add x24, x25, x20, LSL #2\n" + "prfm pstl1keep, [x24, #0x0]\n" + "add x23, x24, x20, LSL #2\n" + "add x22, x23, x20, LSL #2\n" + "prfm pstl1keep, [x23, #0x0]\n" + "prfm pstl1keep, [x22, #0x0]\n" + "tbz %x[flags], #1, 144f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v17.4s }, [x21]\n" + "ld1r { v16.4s }, [x20]\n" + "fmin v20.4s, v20.4s, v17.4s\n" + "fmin v21.4s, v21.4s, v17.4s\n" + "fmin v22.4s, v22.4s, v17.4s\n" + "fmin v23.4s, v23.4s, v17.4s\n" + "fmin v24.4s, v24.4s, v17.4s\n" + "fmin v25.4s, v25.4s, v17.4s\n" + "fmin v26.4s, v26.4s, v17.4s\n" + "fmin v27.4s, v27.4s, v17.4s\n" + "fmin v28.4s, v28.4s, v17.4s\n" + "fmin v29.4s, v29.4s, v17.4s\n" + "fmin v30.4s, v30.4s, v17.4s\n" + "fmin v31.4s, v31.4s, v17.4s\n" + "fmax v20.4s, v20.4s, v16.4s\n" + "fmax v21.4s, v21.4s, v16.4s\n" + "fmax v22.4s, v22.4s, v16.4s\n" + "fmax v23.4s, v23.4s, v16.4s\n" + "fmax v24.4s, v24.4s, v16.4s\n" + "fmax v25.4s, v25.4s, v16.4s\n" + "fmax v26.4s, v26.4s, v16.4s\n" + "fmax v27.4s, v27.4s, v16.4s\n" + "fmax v28.4s, v28.4s, v16.4s\n" + "fmax v29.4s, v29.4s, v16.4s\n" + "fmax v30.4s, v30.4s, v16.4s\n" + "fmax v31.4s, v31.4s, v16.4s\n" + "144:" // Height 6: No activation + "cmp x11, #0x8\n" + "bge 149f\n" + "tbz x11, #2, 146f\n" + "st1 { v20.4s }, [x9], #0x10\n" + "st1 { v22.4s }, [x26], #0x10\n" + "st1 { v24.4s }, [x25], #0x10\n" + "st1 { v26.4s }, [x24], #0x10\n" + "st1 { v28.4s }, [x23], #0x10\n" + "st1 { v30.4s }, [x22], #0x10\n" + "tbz x11, #1, 145f\n" + "str d21, [x9], #0x8\n" + "str d23, [x26], #0x8\n" + "str d25, [x25], #0x8\n" + "str d27, [x24], #0x8\n" + "str d29, [x23], #0x8\n" + "str d31, [x22], #0x8\n" + "tbz x11, #0, 148f\n" + "st1 { v21.s }[2], [x9]\n" + "st1 { v23.s }[2], [x26]\n" + "st1 { v25.s }[2], [x25]\n" + "st1 { v27.s }[2], [x24]\n" + "st1 { v29.s }[2], [x23]\n" + "st1 { v31.s }[2], [x22]\n" + "b 148f\n" + "145:" // Height 6: Partial direct writeback: partial_1_4 + "tbz x11, #0, 148f\n" + "str s21, [x9, #0x0]\n" + "str s23, [x26, #0x0]\n" + "str s25, [x25, #0x0]\n" + "str s27, [x24, #0x0]\n" + "str s29, [x23, #0x0]\n" + "str s31, [x22, #0x0]\n" + "b 148f\n" + "146:" // Height 6: Partial direct writeback: partial_2_0 + "tbz x11, #1, 147f\n" + "str d20, [x9], #0x8\n" + "str d22, [x26], #0x8\n" + "str d24, [x25], #0x8\n" + "str d26, [x24], #0x8\n" + "str d28, [x23], #0x8\n" + "str d30, [x22], #0x8\n" + "tbz x11, #0, 148f\n" + "st1 { v20.s }[2], [x9]\n" + "st1 { v22.s }[2], [x26]\n" + "st1 { v24.s }[2], [x25]\n" + "st1 { v26.s }[2], [x24]\n" + "st1 { v28.s }[2], [x23]\n" + "st1 { v30.s }[2], [x22]\n" + "b 148f\n" + "147:" // Height 6: Partial direct writeback: partial_1_0 + "str s20, [x9, #0x0]\n" + "str s22, [x26, #0x0]\n" + "str s24, [x25, #0x0]\n" + "str s26, [x24, #0x0]\n" + "str s28, [x23, #0x0]\n" + "str s30, [x22, #0x0]\n" + "148:" // Height 6: Partial direct writeback: Done + "b 150f\n" + "149:" // Height 6: Full writeback + "str q20, [x9, #0x0]\n" + "str q21, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + "str q22, [x26, #0x0]\n" + "str q23, [x26, #0x10]\n" + "str q24, [x25, #0x0]\n" + "str q25, [x25, #0x10]\n" + "str q26, [x24, #0x0]\n" + "str q27, [x24, #0x10]\n" + "str q28, [x23, #0x0]\n" + "str q29, [x23, #0x10]\n" + "str q30, [x22, #0x0]\n" + "str q31, [x22, #0x10]\n" + "150:" // Height 6: Writeback done + "subs x11, x11, #0x8\n" + "bgt 127b\n" + "subs %x[m], %x[m], #0x6\n" + "beq 152f\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 151f\n" + "add x21, x21, #0x6\n" + "str x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "151:" // Update direct input + "mov x20, #0x18\n" + "madd %x[input_ptr], x20, x21, %x[input_ptr]\n" + "b 1b\n" + "152:" // Exit + : [input_ptr] "+&r"(input_ptr), [m] "+&r"(m) + : [args_ptr] "r"(&ka), [flags] "r"(flags), [offset_max] "I"(offsetof(KernelArgs, maxval)), + [offset_min] "I"(offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I"(offsetof(KernelArgs, B_ptr)), + [offsetof_N] "I"(offsetof(KernelArgs, N)), + [offsetof_input_initial_col] "I"(offsetof(KernelArgs, input_initial_col)), + [offsetof_input_offset] "I"(offsetof(KernelArgs, input_offset)), + [offsetof_num_strings] "I"(offsetof(KernelArgs, num_strings)), + [offsetof_output_offset] "I"(offsetof(KernelArgs, output_offset)), + [offsetof_output_ptr] "I"(offsetof(KernelArgs, output_ptr)), + [offsetof_string_lengths] "I"(offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", + "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h new file mode 100644 index 00000000..274d4453 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h @@ -0,0 +1,124 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check. + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void); + +/// Gets kr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void); + +/// Gets the offset in bytes to the data element in the LHS matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m_idx, size_t stride); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] n_idx Column index. +/// @param[in] stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m_idx, size_t n_idx, size_t stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffers (LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Common dimension of the LHS and RHS operand. +/// @param[in] lhs LHS matrix buffer. +/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. +/// @param[in] rhs_packed Packed RHS buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. +void kai_run_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla( + size_t m, size_t n, size_t k, // + const void* lhs, size_t lhs_stride, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // + float clamp_min, float clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c new file mode 100644 index 00000000..378acb9d --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c @@ -0,0 +1,253 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check. + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_nr = 8; +static const size_t kai_kr = 1; + +size_t kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(void) { + return kai_nr; +} + +size_t kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx * sizeof(uint32_t); +} + +size_t kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx) { + return n_idx * sizeof(uint32_t); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx * (sizeof(uint32_t) + k * sizeof(uint32_t)); +} + +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n, size_t k) { + return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(kai_roundup(n, kai_nr), k); +} + +void kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME(nr == kai_nr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == 1); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(scale == NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params == NULL); + + printf( + "n:%zu, k:%zu, nr:%zu, kr:%zu, sr:%zu, rhs_stride:%zu, extra_bytes:%zu \n", n, k, nr, kr, sr, rhs_stride, + extra_bytes); + + // float * rhs_f = (float*)rhs; + // float *rhs_p = (float*) rhs_packed; + + // for (size_t i = 0; i < n; i++) + // { + // for (size_t j = 0; j < k; j++) + // { + // printf("%.1f ",rhs_f[i*k + j]); + // } + // printf("\n"); + // } + + size_t height = k; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_stride; + size_t out_stride = kai_nr * height * sizeof(uint32_t) + kai_nr * sizeof(uint32_t); + + __asm__ __volatile__( + "mov x22, %x[width]\n" + "mov x21, %x[out]\n" + "cmp x22, #0x8\n" + "blt 2f\n" + "1:" // Bias: Full loop + "ldr q17, [%x[bias], #0x0]\n" + "ldr q16, [%x[bias], #0x10]\n" + "sub x22, x22, #0x8\n" + "add %x[bias], %x[bias], #0x20\n" + "cmp x22, #0x8\n" + "str q17, [x21, #0x0]\n" + "str q16, [x21, #0x10]\n" + "add x21, x21, %x[out_stride]\n" + "bge 1b\n" + "cbz x22, 3f\n" + "2:" // Bias: Tail loop + "ldr w20, [%x[bias], #0x0]\n" + "sub x22, x22, #0x1\n" + "add %x[bias], %x[bias], #0x4\n" + "cmp x22, #0x0\n" + "str x20, [x21]\n" + "add x21, x21, #0x4\n" + "bgt 2b\n" + "3:" // Bias: Done + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0x20\n" + "blt 12f\n" + "4:" // Main row loop: Head + "mov x25, %x[in]\n" + "mov x24, %x[width]\n" + "mov x23, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "add x22, x25, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "cmp x24, #0x8\n" + "add %x[in], x20, %x[in_stride]\n" + "blt 6f\n" + "5:" // Main row loop: Column loop + "ldr q23, [x25], #0x10\n" + "ldr q22, [x22], #0x10\n" + "sub x24, x24, #0x8\n" + "ldr q21, [x21], #0x10\n" + "ldr q20, [x20], #0x10\n" + "cmp x24, #0x8\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x22], #0x10\n" + "ldr q17, [x21], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q23, [x23, #0x0]\n" + "str q19, [x23, #0x10]\n" + "str q22, [x23, #0x20]\n" + "str q18, [x23, #0x30]\n" + "str q21, [x23, #0x40]\n" + "str q17, [x23, #0x50]\n" + "str q20, [x23, #0x60]\n" + "str q16, [x23, #0x70]\n" + "add x23, x23, %x[out_stride]\n" + "bge 5b\n" + "6:" // Main row loop: Column loop skip + "cbz x24, 11f\n" + "cmp x24, #0x4\n" + "movi v16.4s, #0x0\n" + "str q16, [x23, #0x0]\n" + "str q16, [x23, #0x10]\n" + "str q16, [x23, #0x20]\n" + "str q16, [x23, #0x30]\n" + "str q16, [x23, #0x40]\n" + "str q16, [x23, #0x50]\n" + "str q16, [x23, #0x60]\n" + "str q16, [x23, #0x70]\n" + "blt 8f\n" + "7:" // Main row loop: width 4 loop: loop + "ldr q19, [x25], #0x10\n" + "ldr q18, [x22], #0x10\n" + "sub x24, x24, #0x4\n" + "ldr q17, [x21], #0x10\n" + "ldr q16, [x20], #0x10\n" + "cmp x24, #0x4\n" + "str q19, [x23, #0x0]\n" + "str q18, [x23, #0x20]\n" + "str q17, [x23, #0x40]\n" + "str q16, [x23, #0x60]\n" + "add x23, x23, #0x10\n" + "bge 7b\n" + "8:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 10f\n" + "9:" // Main row loop: width 1 loop: loop + "ldr s19, [x25], #0x4\n" + "ldr s18, [x22], #0x4\n" + "sub x24, x24, #0x1\n" + "ldr s17, [x21], #0x4\n" + "ldr s16, [x20], #0x4\n" + "cmp x24, #0x1\n" + "str s19, [x23, #0x0]\n" + "str s18, [x23, #0x20]\n" + "str s17, [x23, #0x40]\n" + "str s16, [x23, #0x60]\n" + "add x23, x23, #0x4\n" + "bge 9b\n" + "10:" // Main row loop: width 1 loop: skip + "11:" // Main row loop: odd col skip + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0x80\n" + "bge 4b\n" + "cbz %x[height], 21f\n" + "12:" // Main loop skip + "13:" // Tail row loop: Head + "mov x20, %x[width]\n" + "mov x25, %x[in]\n" + "mov x23, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "cmp x20, #0x8\n" + "add %x[in], x25, %x[in_stride]\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr q17, [x25], #0x10\n" + "sub x20, x20, #0x8\n" + "ldr q16, [x25], #0x10\n" + "cmp x20, #0x8\n" + "str q17, [x23, #0x0]\n" + "str q16, [x23, #0x10]\n" + "add x23, x23, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cbz x20, 20f\n" + "cmp x20, #0x4\n" + "movi v16.4s, #0x0\n" + "str q16, [x23, #0x0]\n" + "str q16, [x23, #0x10]\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr q16, [x25], #0x10\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "str q16, [x23, #0x0]\n" + "add x23, x23, #0x10\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr s16, [x25], #0x4\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "str s16, [x23, #0x0]\n" + "add x23, x23, #0x4\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "20:" // Tail row loop: odd col skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x20\n" + "bge 13b\n" + "21:" // Done + : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) + : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [width] "r"(width) + : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24", + "x25"); + + // printf("Post pack \n"); + + // for (size_t i = 0; i < 6; i++) + // { + // for (size_t j = 0; j < kai_nr + n ; j++) + // { + // printf("%.1f ",rhs_p[i*(n + kai_nr) + j]); + // } + // printf("\n"); + // } +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h new file mode 100644 index 00000000..3749276b --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h @@ -0,0 +1,80 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting row index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx, size_t k); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of rows. +/// @param[in] k Number of columns. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n, size_t k); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon. +/// * Bias: @ref kai_get_packed_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon. +/// * Output: @ref kai_get_dst_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon. +/// +/// @param[in] num_groups Number of groups. It must be 1. +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] nr Block size in N dimension. It must be 8. +/// @param[in] kr Block size in K dimension. It must be 1. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. It must be NULL. +/// @param[out] rhs_packed Packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. +/// @param[in] params Extra packing parameters. It must be NULL. +void kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index ac534645..de4a2ee0 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -41,6 +41,9 @@ #include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" +// matmul_clamp_f32_f32_f32p +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h" namespace kai::test { // NOLINTBEGIN(misc-non-private-member-variables-in-classes) @@ -221,8 +224,8 @@ struct MatMulMethod { /// @return The size in bytes of the destination matrix buffer. std::function fn_get_dst_size; - /// Performs F16 matrix multiplication with RHS packing followed by - /// clamp operation. + /// Performs F16 or F32 matrix multiplication with RHS packing + /// followed by clamp operation. /// /// @param[in] m Size of the matrix in M dimension. /// @param[in] n Size of the matrix in N dimension. @@ -242,6 +245,14 @@ struct MatMulMethod { Float16 clamp_min, Float16 clamp_max)> fn_matmul_f16_f16_f16p; + std::function + fn_matmul_f32_f32_f32p; + /// Performs F32 matrix multiplication with LHS & RHS packing /// followed by clamp operation. /// @@ -295,7 +306,8 @@ struct MatMulMethod { } [[nodiscard]] bool has_main_kernel() const { - return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr; + return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr || + fn_matmul_f32_f32_f32p != nullptr; } void main_kernel( @@ -303,10 +315,13 @@ struct MatMulMethod { size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const { KAI_UNUSED(bias); KAI_UNUSED(rhs_stride); - if (fn_matmul_f16_f16_f16p) { fn_matmul_f16_f16_f16p( - m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), static_cast(clamp_min), + m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), clamp_min, + static_cast(clamp_max)); + } else if (fn_matmul_f32_f32_f32p) { + fn_matmul_f32_f32_f32p( + m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, static_cast(clamp_max)); } else if (fn_matmul_f32_f32p_f32p) { fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); @@ -365,6 +380,56 @@ static const std::array matmul_methods = { .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, .fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_matmul_f32_f32_f32p = nullptr, + .fn_matmul_f32_f32p_f32p = nullptr, + }, + + MatMulMethod{ + .name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla", + + .m0 = 6, + .n0 = 8, + + .lhs_transposed = false, + .rhs_transposed = false, + + .is_sme2 = false, + + .dst_format = DataFormat(DataType::FP32), + .lhs_format = DataFormat(DataType::FP32), + .packed_lhs_format = DataFormat(DataType::UNKNOWN), + .rhs_format = DataFormat(DataType::FP32), + .packed_rhs_format = DataFormat( + DataType::FP32, 8, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 8, 1), + .bias_format = DataFormat(DataType::FP32), + + .fn_get_mr = nullptr, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + + .fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_get_packed_lhs_size = nullptr, + .fn_get_packed_lhs_offset = nullptr, + .fn_pack_lhs = nullptr, + + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, + .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, + .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, + + .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, + + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + + .fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_matmul_f16_f16_f16p = nullptr, .fn_matmul_f32_f32p_f32p = nullptr, }, @@ -416,6 +481,7 @@ static const std::array matmul_methods = { .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, .fn_matmul_f16_f16_f16p = nullptr, + .fn_matmul_f32_f32_f32p = nullptr, .fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, }, }; -- GitLab From ebd9d72ae14c5357c535e2a3f54f313127829ad2 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Mon, 15 Jul 2024 11:19:33 +0100 Subject: [PATCH 02/12] Remove debug prints Signed-off-by: Felix Thomasmathibalan --- .../kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c | 5 ----- 1 file changed, 5 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c index fe88014e..e6f48471 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c @@ -11,7 +11,6 @@ #include #include #include -#include #include "kai/kai_common.h" @@ -110,10 +109,6 @@ void kai_run_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla( ka.maxval = clamp_max; ka.minval = clamp_min; - printf( - "*******F32************* N: %zu , ka.output_offset: %zu, input_offset: %zu m:%zu, n:%zu \n", ka.N, - ka.output_offset, ka.input_offset, m, n); - __asm__ __volatile__( "1:" // Row loop "cmp %x[m], #0x6\n" -- GitLab From e2a899dc04841997d6035fd89bc36d2ae27019f5 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Mon, 15 Jul 2024 11:47:39 +0100 Subject: [PATCH 03/12] Fix pre-commit job fail Signed-off-by: Felix Thomasmathibalan --- test/tests/matmul_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index de4a2ee0..c20f7e98 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -428,8 +428,8 @@ static const std::array matmul_methods = { .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, - .fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, .fn_matmul_f16_f16_f16p = nullptr, + .fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, .fn_matmul_f32_f32p_f32p = nullptr, }, -- GitLab From a5335f9b4d705182381f98ac37621a345d5cd25a Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Mon, 15 Jul 2024 17:14:08 +0100 Subject: [PATCH 04/12] Update Bazel build files to include FP32 Signed-off-by: Felix Thomasmathibalan --- kai/ukernels/matmul/BUILD.bazel | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 013a96c3..e6f552aa 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -33,6 +33,13 @@ kai_c_library( ], ) +kai_c_library( + name = "clamp_f32_f32_f32p", + srcs = ["matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c"], + hdrs = ["matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h"], + cpu_uarch = kai_cpu_neon(), +) + kai_c_library( name = "clamp_f32_f32p_f32p", srcs = ["matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c"], @@ -126,6 +133,13 @@ kai_c_library( cpu_uarch = kai_cpu_neon(), ) +kai_c_library( + name = "rhs_pack_kxn_f32pbiasf32_f32_f32_neon", + srcs = ["pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c"], + hdrs = ["pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h"], + cpu_uarch = kai_cpu_neon(), +) + kai_c_library( name = "rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", srcs = ["pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c"], @@ -151,6 +165,7 @@ kai_c_library( name = "matmul", deps = [ ":clamp_f16_f16_f16p", + ":clamp_f32_f32_f32p", ":clamp_f32_f32p_f32p", ":clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod", ":clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod", @@ -161,6 +176,7 @@ kai_c_library( ":lhs_pack_f32p2vlx1_f32_sme", ":lhs_quant_pack_qai8dxp_f32", ":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", + ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", ":rhs_pack_kxn_qsi4cxp_qsu4cxs1s0", ":rhs_pack_nxk_qsi4cxp_qsu4cxs1s0", -- GitLab From 5d6f72619d0bb4449c2e9dbdfba167c0f233be8e Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Tue, 16 Jul 2024 08:18:20 +0100 Subject: [PATCH 05/12] Fix Bazel pre-commit issue Signed-off-by: Felix Thomasmathibalan --- kai/ukernels/matmul/BUILD.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index e6f552aa..768bd29b 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -176,8 +176,8 @@ kai_c_library( ":lhs_pack_f32p2vlx1_f32_sme", ":lhs_quant_pack_qai8dxp_f32", ":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", - ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", + ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", ":rhs_pack_kxn_qsi4cxp_qsu4cxs1s0", ":rhs_pack_nxk_qsi4cxp_qsu4cxs1s0", ], -- GitLab From 5875809e7174f5a44ab4fc820471b4b039421dbf Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Tue, 16 Jul 2024 08:38:16 +0100 Subject: [PATCH 06/12] Regenerate pack functions Signed-off-by: Felix Thomasmathibalan --- ...rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c | 27 ------------------- 1 file changed, 27 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c index 378acb9d..0c6b0074 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c @@ -54,22 +54,6 @@ void kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon( KAI_ASSUME(extra_bytes == 0); KAI_ASSUME(params == NULL); - printf( - "n:%zu, k:%zu, nr:%zu, kr:%zu, sr:%zu, rhs_stride:%zu, extra_bytes:%zu \n", n, k, nr, kr, sr, rhs_stride, - extra_bytes); - - // float * rhs_f = (float*)rhs; - // float *rhs_p = (float*) rhs_packed; - - // for (size_t i = 0; i < n; i++) - // { - // for (size_t j = 0; j < k; j++) - // { - // printf("%.1f ",rhs_f[i*k + j]); - // } - // printf("\n"); - // } - size_t height = k; const size_t width = n; const void* in = rhs; @@ -237,17 +221,6 @@ void kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon( : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [width] "r"(width) : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24", "x25"); - - // printf("Post pack \n"); - - // for (size_t i = 0; i < 6; i++) - // { - // for (size_t j = 0; j < kai_nr + n ; j++) - // { - // printf("%.1f ",rhs_p[i*(n + kai_nr) + j]); - // } - // printf("\n"); - // } } #endif // Architectural features check. -- GitLab From 35d5b409a6157e95e7b271a9b8f1c6a99aee32c3 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Fri, 19 Jul 2024 11:11:40 +0100 Subject: [PATCH 07/12] Address review comment Signed-off-by: Felix Thomasmathibalan --- .../pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c | 6 +++++- .../pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c index 0c6b0074..b56ee639 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c @@ -30,10 +30,14 @@ size_t kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx return n_idx * sizeof(uint32_t); } +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t k) { + return (sizeof(uint32_t) + k * sizeof(uint32_t)); +} + size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_nr == 0); - return n_idx * (sizeof(uint32_t) + k * sizeof(uint32_t)); + return n_idx * kai_get_rhs_packed_stride_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(k); } size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n, size_t k) { diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h index 3749276b..cbbfb0d0 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h @@ -41,6 +41,13 @@ size_t kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx /// @return The offset in bytes to the data element. size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx, size_t k); +/// Gets row stride in bytes to the packed RHS buffer. +/// +/// @param[in] k Number of rows of unpacked RHS. +/// +/// @return The row stride in bytes. +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t k); + /// Gets the size in bytes of the packed RHS buffer. /// /// @param[in] n Number of rows. -- GitLab From a6996624178d511ab83feefdabe0fb59d4274a3f Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Fri, 19 Jul 2024 14:44:54 +0100 Subject: [PATCH 08/12] Address review comment Signed-off-by: Felix Thomasmathibalan --- CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d63efae2..4d875266 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,7 +85,7 @@ set(KLEIDIAI_FILES_NEON_FP16 kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c ) -set(KLEIDIAI_FILES_NEON_FP32 +set(KLEIDIAI_FILES_NEON kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c ) @@ -121,7 +121,7 @@ target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SCALAR}) # https://learn.microsoft.com/en-us/cpp/assembler/inline/inline-assembler?view=msvc-170 # if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") AND NOT MSVC) - target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_FP32}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_FP16}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM}) @@ -129,7 +129,7 @@ if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2}) set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a) - set_source_files_properties(${KLEIDIAI_FILES_NEON_FP32} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a) + set_source_files_properties(${KLEIDIAI_FILES_NEON} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a) set_source_files_properties(${KLEIDIAI_FILES_NEON_FP16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+fp16) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm) -- GitLab From 7d64c8ac0d5661aca916c8fd40b86ebe0c321dc8 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Thu, 1 Aug 2024 16:29:38 +0100 Subject: [PATCH 09/12] Revert "Address review comment" This reverts commit fd9db1ca2474b4e3d254058b9fa2b9597d71c9bd. Signed-off-by: Felix Thomasmathibalan --- .../pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c | 6 +----- .../pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h | 7 ------- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c index b56ee639..0c6b0074 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c @@ -30,14 +30,10 @@ size_t kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx return n_idx * sizeof(uint32_t); } -size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t k) { - return (sizeof(uint32_t) + k * sizeof(uint32_t)); -} - size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_nr == 0); - return n_idx * kai_get_rhs_packed_stride_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(k); + return n_idx * (sizeof(uint32_t) + k * sizeof(uint32_t)); } size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n, size_t k) { diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h index cbbfb0d0..3749276b 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h @@ -41,13 +41,6 @@ size_t kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx /// @return The offset in bytes to the data element. size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx, size_t k); -/// Gets row stride in bytes to the packed RHS buffer. -/// -/// @param[in] k Number of rows of unpacked RHS. -/// -/// @return The row stride in bytes. -size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t k); - /// Gets the size in bytes of the packed RHS buffer. /// /// @param[in] n Number of rows. -- GitLab From 1b11b46f9688ae72ae6ae06eb4ecdd8e001633c4 Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Tue, 6 Aug 2024 10:40:39 +0000 Subject: [PATCH 10/12] Add micro-kernel to compute FP16 GEMV MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Compute the general matrix-vector (GEMV) multiplication between an FP16 LHS and RHS and accumulate into FP16 output. The RHS packs FP16 weights and biases together. * Optimized for Arm® Neon™ using MLA instructions. * Add accompanying tests. Signed-off-by: Jakub Sujak Approved-by: Viet-Hoa Do --- CMakeLists.txt | 1 + kai/ukernels/matmul/BUILD.bazel | 15 +- ..._f16_f16_f16p16x1biasf16_1x16x8_neon_mla.c | 402 ++++++++++++++++++ ..._f16_f16_f16p16x1biasf16_1x16x8_neon_mla.h | 126 ++++++ test/tests/matmul_test.cpp | 51 +++ 5 files changed, 593 insertions(+), 2 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4d875266..8c272f8b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,6 +82,7 @@ set(KLEIDIAI_FILES_SCALAR set(KLEIDIAI_FILES_NEON_FP16 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c + kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c ) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 768bd29b..3487ae5c 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -24,7 +24,17 @@ kai_c_library( ) kai_c_library( - name = "clamp_f16_f16_f16p", + name = "clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla", + srcs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.c"], + hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.h"], + cpu_uarch = kai_cpu_fp16(), + deps = [ + ":clamp_f16_f16_f16p_interface", + ], +) + +kai_c_library( + name = "clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla", srcs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c"], hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h"], cpu_uarch = kai_cpu_fp16(), @@ -164,7 +174,8 @@ kai_c_library( kai_c_library( name = "matmul", deps = [ - ":clamp_f16_f16_f16p", + ":clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla", + ":clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla", ":clamp_f32_f32_f32p", ":clamp_f32_f32p_f32p", ":clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod", diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.c b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.c new file mode 100644 index 00000000..33bdaa66 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.c @@ -0,0 +1,402 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || \ + !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_FP16. +#else // Architectural features check. + +#include "kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 1; +static const size_t kai_nr = 16; +static const size_t kai_kr = 1; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(void) { + return kai_mr; +} + +size_t kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(void) { + return kai_nr; +} + +size_t kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(void) { + return kai_sr; +} + +size_t kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(size_t m_idx, size_t stride) { + KAI_ASSUME(m_idx % kai_mr == 0); + + return m_idx * stride; +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx / kai_nr * (kai_nr * sizeof(__fp16) + kai_nr * k * sizeof(__fp16)); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla( + size_t m_idx, size_t n_idx, size_t stride) { + KAI_ASSUME(m_idx % kai_mr == 0); + KAI_ASSUME(n_idx % kai_nr == 0); + + return m_idx * stride + n_idx * sizeof(__fp16); +} + +size_t kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(size_t m, size_t n) { + return m * n * sizeof(__fp16); +} + +void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla( + size_t m, size_t n, size_t k, // + const void* lhs, size_t lhs_stride, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // + __fp16 clamp_min, __fp16 clamp_max) { + KAI_ASSERT(dst_stride_col == sizeof(__fp16)); + + typedef struct { + __fp16 maxval; + __fp16 minval; + unsigned int num_strings; + const unsigned int* string_lengths; + size_t N; + const void* B_ptr; + size_t output_offset; + size_t input_initial_col; + size_t input_offset; + void* output_ptr; + const void* bias; + } KernelArgs; + + KernelArgs ka; + + unsigned long flags = 0; + + unsigned int string_length = k; + ka.num_strings = 1; + ka.string_lengths = &string_length; + ka.N = n; + ka.B_ptr = rhs_packed; + ka.bias = NULL; + + // Direct input. + const void* input_ptr = lhs; + ka.input_offset = lhs_stride / sizeof(__fp16); + ka.input_initial_col = 0; + + // Direct output. + ka.output_ptr = dst; + ka.output_offset = dst_stride_row / sizeof(__fp16); + + // Clamping output. + flags |= 0x2; + ka.maxval = clamp_max; + ka.minval = clamp_min; + + __asm__ __volatile__( + "1:" // Row loop + "ldr x21, [%x[args_ptr], %[offsetof_output_offset]]\n" + "ldr x26, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "mov x20, #0x2\n" + "ldr x25, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x24, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "madd x20, x21, x20, x26\n" + "str x20, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "2:" // Height 1: Column loop + "cbz x24, 3f\n" + "ldr q30, [x24, #0x0]\n" + "ldr q31, [x24, #0x10]\n" + "add x24, x24, #0x20\n" + "b 14f\n" + "3:" // Height 1: no bias + "tbz %x[flags], #0, 13f\n" + "cmp x25, #0x10\n" + "bge 12f\n" + "tbz x25, #3, 7f\n" + "ld1 { v30.8h }, [x26], #0x10\n" + "tbz x25, #2, 5f\n" + "ldr d31, [x26], #0x8\n" + "tbz x25, #1, 4f\n" + "ld1 { v31.s }[2], [x26], #0x4\n" + "mov x20, #0x1c\n" + "tbz x25, #0, 11f\n" + "ld1 { v31.h }[6], [x26]\n" + "b 11f\n" + "4:" // Height 1: Partial accumulate: partial_1_12 + "mov x20, #0x18\n" + "tbz x25, #0, 11f\n" + "ld1 { v31.h }[4], [x26]\n" + "b 11f\n" + "5:" // Height 1: Partial accumulate: partial_2_8 + "tbz x25, #1, 6f\n" + "ldr s31, [x26], #0x4\n" + "mov x20, #0x14\n" + "tbz x25, #0, 11f\n" + "ld1 { v31.h }[2], [x26]\n" + "b 11f\n" + "6:" // Height 1: Partial accumulate: partial_1_8 + "mov x20, #0x10\n" + "tbz x25, #0, 11f\n" + "ldr h31, [x26, #0x0]\n" + "b 11f\n" + "7:" // Height 1: Partial accumulate: partial_4_0 + "tbz x25, #2, 9f\n" + "ldr d30, [x26], #0x8\n" + "tbz x25, #1, 8f\n" + "ld1 { v30.s }[2], [x26], #0x4\n" + "mov x20, #0xc\n" + "tbz x25, #0, 11f\n" + "ld1 { v30.h }[6], [x26]\n" + "b 11f\n" + "8:" // Height 1: Partial accumulate: partial_1_4 + "mov x20, #0x8\n" + "tbz x25, #0, 11f\n" + "ld1 { v30.h }[4], [x26]\n" + "b 11f\n" + "9:" // Height 1: Partial accumulate: partial_2_0 + "tbz x25, #1, 10f\n" + "ldr s30, [x26], #0x4\n" + "mov x20, #0x4\n" + "tbz x25, #0, 11f\n" + "ld1 { v30.h }[2], [x26]\n" + "b 11f\n" + "10:" // Height 1: Partial accumulate: partial_1_0 + "ldr h30, [x26, #0x0]\n" + "mov x20, #0x0\n" + "11:" // Height 1: Partial accumulate: Done + "sub x26, x26, x20\n" + "b 14f\n" + "12:" // Height 1: full accumulate + "ldr q30, [x26, #0x0]\n" + "ldr q31, [x26, #0x10]\n" + "b 14f\n" + "13:" // Height 1: no accumulate + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "14:" // Height 1: setup done + "mov x23, #0x0\n" + "15:" // Height 1: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w22, [x20, x23, LSL #0x2]\n" + "tbz %x[flags], #3, 16f\n" + "ldr x20, [%x[input_ptr], x23, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x21, [x20, #0x0]\n" + "cbnz x23, 17f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x21, x21, x20, LSL #1\n" + "b 17f\n" + "16:" // Height 1: setup direct input + "mov x21, %x[input_ptr]\n" + "17:" // Height 1: input setup done + "cmp x22, #0x8\n" + "blt 20f\n" + "ldr q0, [x21, #0x0]\n" + "ldr q1, [x24, #0x0]\n" + "cmp x22, #0x10\n" + "ldr q2, [x24, #0x10]\n" + "ldr q3, [x24, #0x20]\n" + "ldr q4, [x24, #0x30]\n" + "ldr q5, [x24, #0x40]\n" + "ldr q6, [x24, #0x50]\n" + "ldr q7, [x24, #0x60]\n" + "ldr q8, [x24, #0x70]\n" + "ldr q9, [x24, #0x80]\n" + "ldr q10, [x24, #0x90]\n" + "ldr q11, [x24, #0xa0]\n" + "ldr q12, [x24, #0xb0]\n" + "ldr q13, [x24, #0xc0]\n" + "ldr q14, [x24, #0xd0]\n" + "ldr q15, [x24, #0xe0]\n" + "ldr q16, [x24, #0xf0]\n" + "blt 19f\n" + "18:" // Height 1: Multiply loop: Main loop head + "fmla v30.8h, v1.8h, v0.h[0]\n" + "fmla v31.8h, v2.8h, v0.h[0]\n" + "sub x22, x22, #0x8\n" + "add x21, x21, #0x10\n" + "cmp x22, #0x10\n" + "add x24, x24, #0x100\n" + "prfm pldl1keep, [x21, #0x80]\n" + "ldr q1, [x24, #0x0]\n" + "ldr q2, [x24, #0x10]\n" + "fmla v30.8h, v3.8h, v0.h[1]\n" + "ldr q3, [x24, #0x20]\n" + "fmla v31.8h, v4.8h, v0.h[1]\n" + "ldr q4, [x24, #0x30]\n" + "fmla v30.8h, v5.8h, v0.h[2]\n" + "ldr q5, [x24, #0x40]\n" + "fmla v31.8h, v6.8h, v0.h[2]\n" + "ldr q6, [x24, #0x50]\n" + "fmla v30.8h, v7.8h, v0.h[3]\n" + "ldr q7, [x24, #0x60]\n" + "fmla v31.8h, v8.8h, v0.h[3]\n" + "ldr q8, [x24, #0x70]\n" + "fmla v30.8h, v9.8h, v0.h[4]\n" + "ldr q9, [x24, #0x80]\n" + "fmla v31.8h, v10.8h, v0.h[4]\n" + "ldr q10, [x24, #0x90]\n" + "fmla v30.8h, v11.8h, v0.h[5]\n" + "ldr q11, [x24, #0xa0]\n" + "fmla v31.8h, v12.8h, v0.h[5]\n" + "ldr q12, [x24, #0xb0]\n" + "fmla v30.8h, v13.8h, v0.h[6]\n" + "ldr q13, [x24, #0xc0]\n" + "fmla v31.8h, v14.8h, v0.h[6]\n" + "ldr q14, [x24, #0xd0]\n" + "fmla v30.8h, v15.8h, v0.h[7]\n" + "ldr q15, [x24, #0xe0]\n" + "fmla v31.8h, v16.8h, v0.h[7]\n" + "ldr q0, [x21, #0x0]\n" + "ldr q16, [x24, #0xf0]\n" + "bge 18b\n" + "19:" // Height 1: Multiply loop: Single iteration only + "fmla v30.8h, v1.8h, v0.h[0]\n" + "fmla v31.8h, v2.8h, v0.h[0]\n" + "add x21, x21, #0x10\n" + "sub x22, x22, #0x8\n" + "add x24, x24, #0x100\n" + "prfm pldl1keep, [x21, #0x80]\n" + "fmla v30.8h, v3.8h, v0.h[1]\n" + "fmla v31.8h, v4.8h, v0.h[1]\n" + "fmla v30.8h, v5.8h, v0.h[2]\n" + "fmla v31.8h, v6.8h, v0.h[2]\n" + "fmla v30.8h, v7.8h, v0.h[3]\n" + "fmla v31.8h, v8.8h, v0.h[3]\n" + "fmla v30.8h, v9.8h, v0.h[4]\n" + "fmla v31.8h, v10.8h, v0.h[4]\n" + "fmla v30.8h, v11.8h, v0.h[5]\n" + "fmla v31.8h, v12.8h, v0.h[5]\n" + "fmla v30.8h, v13.8h, v0.h[6]\n" + "fmla v31.8h, v14.8h, v0.h[6]\n" + "fmla v30.8h, v15.8h, v0.h[7]\n" + "fmla v31.8h, v16.8h, v0.h[7]\n" + "20:" // Height 1: Multiply loop: Main loop skip + "cbz x22, 22f\n" + "21:" // Height 1: Multiply loop: Odd block loop + "ldr h0, [x21], #0x2\n" + "ldr q17, [x24, #0x0]\n" + "sub x22, x22, #0x1\n" + "ldr q18, [x24, #0x10]\n" + "add x24, x24, #0x20\n" + "fmla v30.8h, v17.8h, v0.h[0]\n" + "fmla v31.8h, v18.8h, v0.h[0]\n" + "cbnz x22, 21b\n" + "22:" // Height 1: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x23, x23, #0x1\n" + "cmp x23, x20\n" + "bne 15b\n" + "prfm pstl1keep, [x26, #0x0]\n" + "tbz %x[flags], #1, 23f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v17.8h }, [x21]\n" + "ld1r { v16.8h }, [x20]\n" + "fmin v30.8h, v30.8h, v17.8h\n" + "fmin v31.8h, v31.8h, v17.8h\n" + "fmax v30.8h, v30.8h, v16.8h\n" + "fmax v31.8h, v31.8h, v16.8h\n" + "23:" // Height 1: No activation + "cmp x25, #0x10\n" + "bge 32f\n" + "tbz x25, #3, 27f\n" + "st1 { v30.8h }, [x26], #0x10\n" + "tbz x25, #2, 25f\n" + "str d31, [x26], #0x8\n" + "tbz x25, #1, 24f\n" + "st1 { v31.s }[2], [x26], #0x4\n" + "tbz x25, #0, 31f\n" + "st1 { v31.h }[6], [x26]\n" + "b 31f\n" + "24:" // Height 1: Partial direct writeback: partial_1_12 + "tbz x25, #0, 31f\n" + "st1 { v31.h }[4], [x26]\n" + "b 31f\n" + "25:" // Height 1: Partial direct writeback: partial_2_8 + "tbz x25, #1, 26f\n" + "str s31, [x26], #0x4\n" + "tbz x25, #0, 31f\n" + "st1 { v31.h }[2], [x26]\n" + "b 31f\n" + "26:" // Height 1: Partial direct writeback: partial_1_8 + "tbz x25, #0, 31f\n" + "str h31, [x26, #0x0]\n" + "b 31f\n" + "27:" // Height 1: Partial direct writeback: partial_4_0 + "tbz x25, #2, 29f\n" + "str d30, [x26], #0x8\n" + "tbz x25, #1, 28f\n" + "st1 { v30.s }[2], [x26], #0x4\n" + "tbz x25, #0, 31f\n" + "st1 { v30.h }[6], [x26]\n" + "b 31f\n" + "28:" // Height 1: Partial direct writeback: partial_1_4 + "tbz x25, #0, 31f\n" + "st1 { v30.h }[4], [x26]\n" + "b 31f\n" + "29:" // Height 1: Partial direct writeback: partial_2_0 + "tbz x25, #1, 30f\n" + "str s30, [x26], #0x4\n" + "tbz x25, #0, 31f\n" + "st1 { v30.h }[2], [x26]\n" + "b 31f\n" + "30:" // Height 1: Partial direct writeback: partial_1_0 + "str h30, [x26, #0x0]\n" + "31:" // Height 1: Partial direct writeback: Done + "b 33f\n" + "32:" // Height 1: Full writeback + "str q30, [x26, #0x0]\n" + "str q31, [x26, #0x10]\n" + "add x26, x26, #0x20\n" + "33:" // Height 1: Writeback done + "subs x25, x25, #0x10\n" + "bgt 2b\n" + "subs %x[m], %x[m], #0x1\n" + "beq 35f\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 34f\n" + "add x21, x21, #0x1\n" + "str x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "34:" // Update direct input + "mov x20, #0x2\n" + "madd %x[input_ptr], x20, x21, %x[input_ptr]\n" + "b 1b\n" + "35:" // Exit + : [input_ptr] "+&r"(input_ptr), [m] "+&r"(m) + : [args_ptr] "r"(&ka), [flags] "r"(flags), [offset_max] "I"(offsetof(KernelArgs, maxval)), + [offset_min] "I"(offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I"(offsetof(KernelArgs, B_ptr)), + [offsetof_N] "I"(offsetof(KernelArgs, N)), + [offsetof_input_initial_col] "I"(offsetof(KernelArgs, input_initial_col)), + [offsetof_input_offset] "I"(offsetof(KernelArgs, input_offset)), + [offsetof_num_strings] "I"(offsetof(KernelArgs, num_strings)), + [offsetof_output_offset] "I"(offsetof(KernelArgs, output_offset)), + [offsetof_output_ptr] "I"(offsetof(KernelArgs, output_ptr)), + [offsetof_string_lengths] "I"(offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.h b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.h new file mode 100644 index 00000000..178b70a1 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.h @@ -0,0 +1,126 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || \ + !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_FP16. +#else // Architectural features check. + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(void); + +/// Gets kr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(void); + +/// Gets the offset in bytes to the data element in the LHS matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(size_t m_idx, size_t stride); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] n_idx Column index. +/// @param[in] stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla( + size_t m_idx, size_t n_idx, size_t stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffers (LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Common dimension of the LHS and RHS operand. +/// @param[in] lhs LHS matrix buffer. +/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. +/// @param[in] rhs_packed Packed RHS buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(__fp16) +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. +void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla( + size_t m, size_t n, size_t k, // + const void* lhs, size_t lhs_stride, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // + __fp16 clamp_min, __fp16 clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // Architectural features check. diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index c20f7e98..023ba7f6 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -32,6 +32,9 @@ #include "test/reference/fill.hpp" #include "test/reference/pack.hpp" +// matmul_nt_nt_fp16_fp16_fp16_1x16_neon_mla +#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla.h" + // matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" @@ -335,6 +338,54 @@ struct MatMulMethod { /// List of supported matrix multiplication methods. static const std::array matmul_methods = { + MatMulMethod{ + .name = "matmul_nt_nt_fp16_fp16_fp16_1x16_neon_mla", + + .m0 = 1, + .n0 = 16, + + .lhs_transposed = false, + .rhs_transposed = false, + + .is_sme2 = false, + + .dst_format = DataFormat(DataType::FP16), + .lhs_format = DataFormat(DataType::FP16), + .packed_lhs_format = DataFormat(DataType::UNKNOWN), + .rhs_format = DataFormat(DataType::FP16), + .packed_rhs_format = DataFormat( + DataType::FP16, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 16, 1), + .bias_format = DataFormat(DataType::FP16), + + .fn_get_mr = nullptr, + .fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla, + .fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla, + .fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla, + + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla, + .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla, + + .fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla, + .fn_get_packed_lhs_size = nullptr, + .fn_get_packed_lhs_offset = nullptr, + .fn_pack_lhs = nullptr, + + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, + .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, + .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla, + .fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, + + .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, + + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla, + + .fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_1x16x8_neon_mla, + .fn_matmul_f32_f32p_f32p = nullptr, + }, + MatMulMethod{ .name = "matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla", -- GitLab From 393f557d4153d89e88219f656a3ea345d774e750 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Mon, 15 Jul 2024 17:14:08 +0100 Subject: [PATCH 11/12] Update Bazel build files to include FP32 Signed-off-by: Felix Thomasmathibalan --- kai/ukernels/matmul/BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 3487ae5c..9c99c07b 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -187,6 +187,7 @@ kai_c_library( ":lhs_pack_f32p2vlx1_f32_sme", ":lhs_quant_pack_qai8dxp_f32", ":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", + ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", ":rhs_pack_kxn_qsi4cxp_qsu4cxs1s0", -- GitLab From 3181086fd50d265ca8214ef49e2d151bd3016cc0 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Wed, 7 Aug 2024 10:55:11 +0100 Subject: [PATCH 12/12] Renamed to match FP16 Signed-off-by: Felix Thomasmathibalan --- CMakeLists.txt | 2 +- kai/ukernels/matmul/BUILD.bazel | 4 +-- ...p_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c} | 23 +++++++++------- ...p_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h} | 26 +++++++++---------- test/tests/matmul_test.cpp | 22 ++++++++-------- 5 files changed, 40 insertions(+), 37 deletions(-) rename kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/{kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c => kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c} (98%) rename kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/{kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h => kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h} (77%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8c272f8b..d08600f1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,7 +88,7 @@ set(KLEIDIAI_FILES_NEON_FP16 set(KLEIDIAI_FILES_NEON kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c - kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c + kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c ) set(KLEIDIAI_FILES_NEON_DOTPROD diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 9c99c07b..3b2a5288 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -45,8 +45,8 @@ kai_c_library( kai_c_library( name = "clamp_f32_f32_f32p", - srcs = ["matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c"], - hdrs = ["matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h"], + srcs = ["matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c"], + hdrs = ["matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h"], cpu_uarch = kai_cpu_neon(), ) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c similarity index 98% rename from kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c rename to kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c index e6f48471..62723018 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64. #else // Architectural features check. +#include "kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h" + #include #include #include @@ -19,50 +21,51 @@ static const size_t kai_nr = 8; static const size_t kai_kr = 1; static const size_t kai_sr = 1; -size_t kai_get_m_step_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void) { +size_t kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(void) { return kai_mr; } -size_t kai_get_n_step_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void) { +size_t kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(void) { return kai_nr; } -size_t kai_get_nr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void) { +size_t kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void) { +size_t kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void) { +size_t kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(void) { return kai_sr; } -size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m_idx, size_t stride) { +size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(size_t m_idx, size_t stride) { KAI_ASSUME(m_idx % kai_mr == 0); return m_idx * stride; } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_nr == 0); return n_idx / kai_nr * (kai_nr * sizeof(float) + kai_nr * k * sizeof(float)); } -size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m_idx, size_t n_idx, size_t stride) { +size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla( + size_t m_idx, size_t n_idx, size_t stride) { KAI_ASSUME(m_idx % kai_mr == 0); KAI_ASSUME(n_idx % kai_nr == 0); return m_idx * stride + n_idx * sizeof(float); } -size_t kai_get_dst_size_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla( +void kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla( size_t m, size_t n, size_t k, // const void* lhs, size_t lhs_stride, // const void* rhs_packed, // diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h similarity index 77% rename from kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h rename to kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h index 274d4453..0a733f1e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h @@ -27,35 +27,35 @@ extern "C" { /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. -size_t kai_get_m_step_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void); +size_t kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(void); /// Gets n step value. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void); +size_t kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(void); /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The nr value. -size_t kai_get_nr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void); +size_t kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(void); /// Gets kr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The kr value. -size_t kai_get_kr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void); +size_t kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(void); /// Gets sr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The sr value. -size_t kai_get_sr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void); +size_t kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(void); /// Gets the offset in bytes to the data element in the LHS matrix buffer. /// @@ -63,7 +63,7 @@ size_t kai_get_sr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(void); /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m_idx, size_t stride); +size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(size_t m_idx, size_t stride); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// @@ -71,7 +71,7 @@ size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// @@ -80,7 +80,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(s /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m_idx, size_t n_idx, size_t stride); +size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(size_t m_idx, size_t n_idx, size_t stride); /// Gets the size in bytes of the destination matrix buffer. /// @@ -88,16 +88,16 @@ size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. -size_t kai_get_dst_size_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla(size_t m, size_t n); /// Runs the matrix multiplication microkernel followed by a clamp operation. /// /// The pointer of each buffers (LHS, packed RHS and output) needs to be added with offset /// calculated using the following functions: /// -/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla. -/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla. -/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla. +/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla. /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. @@ -110,7 +110,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla(size_t m, /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. -void kai_run_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla( +void kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla( size_t m, size_t n, size_t k, // const void* lhs, size_t lhs_stride, // const void* rhs_packed, // diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index 023ba7f6..6810910b 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -45,7 +45,7 @@ #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" // matmul_clamp_f32_f32_f32p -#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h" namespace kai::test { @@ -455,15 +455,15 @@ static const std::array matmul_methods = { .bias_format = DataFormat(DataType::FP32), .fn_get_mr = nullptr, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, .fn_get_packed_lhs_size = nullptr, .fn_get_packed_lhs_offset = nullptr, .fn_pack_lhs = nullptr, @@ -471,16 +471,16 @@ static const std::array matmul_methods = { .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, .fn_matmul_f16_f16_f16p = nullptr, - .fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32pbiasf32_6x8_neon_mla, + .fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, .fn_matmul_f32_f32p_f32p = nullptr, }, -- GitLab