From f5d0f09502c3ebae0f740cd5ca811df5a0b683bf Mon Sep 17 00:00:00 2001 From: Michael Kozlov Date: Thu, 23 Jan 2025 16:33:13 +0000 Subject: [PATCH 01/15] Extract inline assembly kernels into external files: 1x4 Signed-off-by: Michael Kozlov --- CMakeLists.txt | 1 + ...8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S | 155 ++++++++++++++++++ ..._qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c | 128 ++++----------- 3 files changed, 186 insertions(+), 98 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index 074bc1d4..515bf8de 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -142,6 +142,7 @@ set(KLEIDIAI_FILES_NEON_DOTPROD kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S ) set(KLEIDIAI_FILES_NEON_I8MM diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S new file mode 100644 index 00000000..a7e62551 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S @@ -0,0 +1,155 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod) + stp x20, x21, [sp, -80]! 
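+ // x0 holds a pointer to the KernelArgs struct declared in the companion .c file.
+ // The loads below pick up its fields at fixed offsets: 0x00 dst, 0x08 lhs_packed,
+ // 0x10 rhs_packed, 0x18 clamp_vals, 0x20 dst_stride_row, 0x28 m, 0x30 n,
+ // 0x38 num_blocks, 0x40 num_subblocks.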
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + mov x15, #0x20 + movi v28.16b, #0xf0 + mov x21, #0x8 + ldr x14, [x0, #0x40] + ldr x13, [x0, #0x38] + ldr x20, [x0, #0x28] + ldr x12, [x0, #0x8] + ldr x11, [x0, #0x10] + ldr x10, [x0, #0x30] + mul x15, x14, x15 + ldr x9, [x0, #0x0] + ldr x28, [x0, #0x20] + ldr x27, [x0, #0x18] + mov x26, x20 + madd x15, x13, x15, x21 +KAI_ASM_LABEL(label_1) // Row loop + mov x25, x11 + mov x24, x10 + add x23, x9, x28 +KAI_ASM_LABEL(label_2) // Column loop + mov x22, x12 + movi v27.16b, #0x0 + mov x21, x13 +KAI_ASM_LABEL(label_3) // Block loop + movi v26.4s, #0x0 + mov x20, x14 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q25, [x25, #0x0] + ldr q24, [x22, #0x0] + subs x20, x20, #0x1 + ldr q23, [x25, #0x10] + ldr q22, [x25, #0x20] + ldr q21, [x25, #0x30] + ldr q20, [x22, #0x10] + add x25, x25, #0x40 + add x22, x22, #0x20 + shl v19.16b, v25.16b, #0x4 + and v25.16b, v25.16b, v28.16b + shl v18.16b, v23.16b, #0x4 + shl v17.16b, v22.16b, #0x4 + shl v16.16b, v21.16b, #0x4 + and v23.16b, v23.16b, v28.16b + KAI_ASM_INST(0x4f98e27a) // sdot v26.4s, v19.16b, v24.4b[0] + and v22.16b, v22.16b, v28.16b + and v21.16b, v21.16b, v28.16b + KAI_ASM_INST(0x4fb8e25a) // sdot v26.4s, v18.16b, v24.4b[1] + KAI_ASM_INST(0x4f98ea3a) // sdot v26.4s, v17.16b, v24.4b[2] + KAI_ASM_INST(0x4fb8ea1a) // sdot v26.4s, v16.16b, v24.4b[3] + KAI_ASM_INST(0x4f94e33a) // sdot v26.4s, v25.16b, v20.4b[0] + KAI_ASM_INST(0x4fb4e2fa) // sdot v26.4s, v23.16b, v20.4b[1] + KAI_ASM_INST(0x4f94eada) // sdot v26.4s, v22.16b, v20.4b[2] + KAI_ASM_INST(0x4fb4eaba) // sdot v26.4s, v21.16b, v20.4b[3] + bgt label_4 + ldr d16, [x25, #0x0] + scvtf v26.4s, v26.4s, #0x4 + sub x21, x21, #0x1 + add x25, x25, #0x8 + shll v16.4s, v16.4h, #0x10 + fmla v27.4s, v26.4s, v16.4s + cbnz x21, label_3 + ld1r { v21.4s }, [x22] + ldr q20, [x25, #0x0] + add x22, x22, #0x4 + add x20, x27, #0x4 + ld1r { v19.4s }, [x22] + ldr q18, [x25, #0x10] + cmp x24, #0x4 + add x25, x25, #0x20 + ld1r { v17.4s }, [x27] + ld1r { v16.4s }, [x20] + scvtf v21.4s, v21.4s + fmla v27.4s, v20.4s, v21.s[0] + fmul v27.4s, v27.4s, v19.4s + fadd v27.4s, v27.4s, v18.4s + fmax v27.4s, v27.4s, v17.4s + fmin v27.4s, v27.4s, v16.4s + blt label_5 + str q27, [x9, #0x0] + b label_8 +KAI_ASM_LABEL(label_5) // Partial output + mov x20, x9 + tbz x24, #1, label_6 + st1 { v27.d }[0], [x20], #0x8 + tbz x24, #0, label_7 + st1 { v27.s }[2], [x20] + b label_7 +KAI_ASM_LABEL(label_6) // Output block 0: partial_1_0 + st1 { v27.s }[0], [x20] +KAI_ASM_LABEL(label_7) // Output block 0: Done +KAI_ASM_LABEL(label_8) // Stores done + subs x24, x24, #0x4 + add x9, x9, #0x10 + bgt label_2 + subs x26, x26, #0x1 + add x12, x12, x15 + mov x9, x23 + bgt label_1 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c index 6ad79c6a..c517f5a2 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c @@ -14,6 +14,20 @@ #include 
"kai/kai_common.h" +typedef struct { + float* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; + size_t num_subblocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod(KernelArgs* args_ptr); + // Compute args static const size_t kai_m_step = 1; static const size_t kai_n_step = 4; @@ -57,7 +71,7 @@ inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { inline static size_t kai_get_lhs_packed_stride(size_t k) { const size_t k_internal = kai_get_k_roundedup(k); size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); - // Since the LHS matrix is asymmetric with per-row quantization, we must include + // Since the LHS matrix is asymmetric with per-row quantization, we must include the // the number of bytes to hold the zero point value lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs; @@ -143,111 +157,29 @@ void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod( float scalar_min, // float scalar_max) { KAI_ASSUME(dst_stride_col == sizeof(float)); - KAI_ASSUME((bl % kai_bl) == 0); KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); if (m == 0) { return; } - - const size_t num_subblocks = bl / kai_bl; + const size_t num_subblocks = bl / 32; const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); const float clamp_vals[2] = {scalar_min, scalar_max}; - __asm__ __volatile__( - "mov x27, #0x20\n" - "mov x20, #0x8\n" - "movi v28.16b, #0xf0\n" - "mov x26, %x[m]\n" - "mul x27, %x[num_subblocks], x27\n" - "madd x27, %x[num_blocks], x27, x20\n" - "1:" // Row loop - "mov x25, %x[rhs_packed]\n" - "mov x24, %x[n]\n" - "add x23, %x[dst], %x[dst_stride_row]\n" - "2:" // Column loop - "mov x22, %x[lhs_packed]\n" - "movi v27.16b, #0x0\n" - "mov x21, %x[num_blocks]\n" - "3:" // Block loop - "movi v26.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "4:" // Sub block loop - "ldr q25, [x25, #0x0]\n" - "ldr q24, [x22, #0x0]\n" - "subs x20, x20, #0x1\n" - "ldr q23, [x25, #0x10]\n" - "ldr q22, [x25, #0x20]\n" - "ldr q21, [x25, #0x30]\n" - "ldr q20, [x22, #0x10]\n" - "add x25, x25, #0x40\n" - "add x22, x22, #0x20\n" - "shl v19.16b, v25.16b, #0x4\n" - "and v25.16b, v25.16b, v28.16b\n" - "shl v18.16b, v23.16b, #0x4\n" - "shl v17.16b, v22.16b, #0x4\n" - "shl v16.16b, v21.16b, #0x4\n" - "and v23.16b, v23.16b, v28.16b\n" - ".inst 0x4f98e27a // sdot v26.4s, v19.16b, v24.4b[0]\n" - "and v22.16b, v22.16b, v28.16b\n" - "and v21.16b, v21.16b, v28.16b\n" - ".inst 0x4fb8e25a // sdot v26.4s, v18.16b, v24.4b[1]\n" - ".inst 0x4f98ea3a // sdot v26.4s, v17.16b, v24.4b[2]\n" - ".inst 0x4fb8ea1a // sdot v26.4s, v16.16b, v24.4b[3]\n" - ".inst 0x4f94e33a // sdot v26.4s, v25.16b, v20.4b[0]\n" - ".inst 0x4fb4e2fa // sdot v26.4s, v23.16b, v20.4b[1]\n" - ".inst 0x4f94eada // sdot v26.4s, v22.16b, v20.4b[2]\n" - ".inst 0x4fb4eaba // sdot v26.4s, v21.16b, v20.4b[3]\n" - "bgt 4b\n" - "ldr d16, [x25, #0x0]\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "sub x21, x21, #0x1\n" - "add x25, x25, #0x8\n" - "shll v16.4s, v16.4h, #0x10\n" - "fmla v27.4s, v26.4s, v16.4s\n" - "cbnz x21, 3b\n" - "ld1r { v21.4s }, [x22]\n" - "ldr q20, [x25, #0x0]\n" - "add x22, x22, #0x4\n" - "add x20, %x[clamp_vals], #0x4\n" - "ld1r { v19.4s }, [x22]\n" - "ldr q18, [x25, #0x10]\n" - "cmp x24, #0x4\n" - "add x25, x25, #0x20\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v16.4s }, [x20]\n" - "scvtf v21.4s, v21.4s\n" - "fmla v27.4s, 
v20.4s, v21.s[0]\n" - "fmul v27.4s, v27.4s, v19.4s\n" - "fadd v27.4s, v27.4s, v18.4s\n" - "fmax v27.4s, v27.4s, v17.4s\n" - "fmin v27.4s, v27.4s, v16.4s\n" - "blt 5f\n" - "str q27, [%x[dst], #0x0]\n" - "b 8f\n" - "5:" // Partial output - "mov x20, %x[dst]\n" - "tbz x24, #1, 6f\n" - "st1 { v27.d }[0], [x20], #0x8\n" - "tbz x24, #0, 7f\n" - "st1 { v27.s }[2], [x20]\n" - "b 7f\n" - "6:" // Output block 0: partial_1_0 - "st1 { v27.s }[0], [x20]\n" - "7:" // Output block 0: Done - "8:" // Stores done - "subs x24, x24, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 2b\n" - "subs x26, x26, #0x1\n" - "add %x[lhs_packed], %x[lhs_packed], x27\n" - "mov %x[dst], x23\n" - "bgt 1b\n" - : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) - : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), - [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27"); + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + args.num_subblocks = num_subblocks; + + kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod(&args); } #endif // Architectural features check. -- GitLab From ad6307de8e9c62c5514d0f67505d8e9b231338f7 Mon Sep 17 00:00:00 2001 From: Michael Kozlov Date: Thu, 23 Jan 2025 16:33:48 +0000 Subject: [PATCH 02/15] Extract inline assembly kernels into external files: 1x4x32 Signed-off-by: Michael Kozlov --- CMakeLists.txt | 1 + ...p1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S | 158 +++++++++++ ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c | 258 ++++++++---------- ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 57 ++-- 4 files changed, 296 insertions(+), 178 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index 515bf8de..ecbb56d2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -143,6 +143,7 @@ set(KLEIDIAI_FILES_NEON_DOTPROD kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S ) set(KLEIDIAI_FILES_NEON_I8MM diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S new file mode 100644 index 00000000..3bb24142 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S @@ -0,0 +1,158 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define 
KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod) + stp x20, x21, [sp, -80]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + mov x15, #0x20 + movi v31.16b, #0xf0 + mov x21, #0x8 + ldr x14, [x0, #0x40] + ldr x13, [x0, #0x38] + ldr x20, [x0, #0x28] + ldr x12, [x0, #0x8] + ldr x11, [x0, #0x10] + ldr x10, [x0, #0x30] + mul x15, x14, x15 + ldr x9, [x0, #0x0] + ldr x28, [x0, #0x20] + ldr x27, [x0, #0x18] + mov x26, x20 + madd x15, x13, x15, x21 +KAI_ASM_LABEL(label_1) // Row loop + mov x25, x11 + mov x24, x10 + add x23, x9, x28 +KAI_ASM_LABEL(label_2) // Column loop + mov x22, x12 + movi v30.16b, #0x0 + mov x21, x13 +KAI_ASM_LABEL(label_3) // Block loop + movi v29.4s, #0x0 + movi v28.4s, #0x0 + mov x20, x14 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q27, [x25, #0x0] + ldr q26, [x25, #0x10] + subs x20, x20, #0x1 + ld1r { v25.2d }, [x22], #0x8 + ldr q24, [x25, #0x20] + ldr q23, [x25, #0x30] + add x25, x25, #0x40 + ld1r { v22.2d }, [x22], #0x8 + ld1r { v21.2d }, [x22], #0x8 + shl v20.16b, v27.16b, #0x4 + shl v19.16b, v26.16b, #0x4 + ld1r { v18.2d }, [x22], #0x8 + shl v17.16b, v24.16b, #0x4 + and v27.16b, v27.16b, v31.16b + shl v16.16b, v23.16b, #0x4 + and v26.16b, v26.16b, v31.16b + KAI_ASM_INST(0x4e99969d) // sdot v29.4s, v20.16b, v25.16b + KAI_ASM_INST(0x4e99967c) // sdot v28.4s, v19.16b, v25.16b + and v24.16b, v24.16b, v31.16b + and v23.16b, v23.16b, v31.16b + KAI_ASM_INST(0x4e96963d) // sdot v29.4s, v17.16b, v22.16b + KAI_ASM_INST(0x4e96961c) // sdot v28.4s, v16.16b, v22.16b + KAI_ASM_INST(0x4e95977d) // sdot v29.4s, v27.16b, v21.16b + KAI_ASM_INST(0x4e95975c) // sdot v28.4s, v26.16b, v21.16b + KAI_ASM_INST(0x4e92971d) // sdot v29.4s, v24.16b, v18.16b + KAI_ASM_INST(0x4e9296fc) // sdot v28.4s, v23.16b, v18.16b + bgt label_4 + ldr d16, [x25, #0x0] + addp v29.4s, v29.4s, v28.4s + sub x21, x21, #0x1 + add x25, x25, #0x8 + shll v16.4s, v16.4h, #0x10 + scvtf v29.4s, v29.4s, #0x4 + fmla v30.4s, v29.4s, v16.4s + cbnz x21, label_3 + ld1r { v21.4s }, [x22] + ldr q20, [x25, #0x0] + add x22, x22, #0x4 + add x20, x27, #0x4 + ld1r { v19.4s }, [x22] + ldr q18, [x25, #0x10] + cmp x24, #0x4 + add x25, x25, #0x20 + ld1r { v17.4s }, [x27] + ld1r { v16.4s }, [x20] + scvtf v21.4s, v21.4s + fmla v30.4s, v20.4s, v21.s[0] + fmul v30.4s, v30.4s, v19.4s + fadd 
v30.4s, v30.4s, v18.4s + fmax v30.4s, v30.4s, v17.4s + fmin v30.4s, v30.4s, v16.4s + blt label_5 + str q30, [x9, #0x0] + b label_8 +KAI_ASM_LABEL(label_5) // Partial output + mov x20, x9 + tbz x24, #1, label_6 + st1 { v30.d }[0], [x20], #0x8 + tbz x24, #0, label_7 + st1 { v30.s }[2], [x20] + b label_7 +KAI_ASM_LABEL(label_6) // Output block 0: partial_1_0 + st1 { v30.s }[0], [x20] +KAI_ASM_LABEL(label_7) // Output block 0: Done +KAI_ASM_LABEL(label_8) // Stores done + subs x24, x24, #0x4 + add x9, x9, #0x10 + bgt label_2 + subs x26, x26, #0x1 + add x12, x12, x15 + mov x9, x23 + bgt label_1 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c index 34fd180a..b87e155e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c @@ -3,58 +3,95 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__ARM_FEATURE_DOTPROD) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) #error "Dotprod extension required to compile this micro-kernel" -#else +#else // Architectural features check. + #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" -#include #include #include #include "kai/kai_common.h" +typedef struct { + float* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; + size_t num_subblocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(KernelArgs* args_ptr); + +// Compute args static const size_t kai_m_step = 1; static const size_t kai_n_step = 4; +// Packing args static const size_t kai_mr = 1; static const size_t kai_nr = 4; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_bl_multiple_of = 32; -static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); -static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); -static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); -static const size_t kai_num_bytes_sum_rhs = sizeof(float); -static const size_t kai_num_bytes_bias = sizeof(float); - -inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - return kai_roundup(k, bl) / bl; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
+// asymmetric))
+static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
+static const size_t kai_num_bytes_multiplier_rhs = 2;
+static const size_t kai_num_bytes_rsum_rhs = 4;
+// DST format args
+static const size_t kai_num_bytes_dst_value = 4;
+// Extra args
+static const size_t kai_num_bytes_bias = 4;
+static const size_t kai_k_multiple_of = 32;
+static const size_t kai_bl = 32;
+
+inline static size_t kai_get_k_roundedup(size_t k) {
+ return kai_roundup(k, kai_k_multiple_of);
 }
-inline static size_t kai_k_roundedup(size_t k) {
- // Since we pack a float and int32 value at the end of the row,
- // we must make sure that k is a multiple of 4 for alignment
- size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
- return kai_roundup(k, kr_sr_roundedup4);
+inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
+ KAI_ASSUME((bl % kai_bl) == 0);
+ size_t num_bytes_per_block_rhs = (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
+ return num_bytes_per_block_rhs;
 }
-inline static size_t kai_lhs_packed_stride(size_t k) {
- const size_t k_internal = kai_k_roundedup(k);
+inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
+ KAI_ASSUME((bl % kai_bl) == 0);
+
+ return kai_roundup(k, bl) / bl;
+}
- KAI_ASSERT((k_internal % 2) == 0);
+inline static size_t kai_get_lhs_packed_stride(size_t k) {
+ const size_t k_internal = kai_get_k_roundedup(k);
+ size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
+ // Since the LHS matrix is asymmetric with per-row quantization, we must include the
+ // number of bytes to hold the zero point value
+ lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs;
- return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+ return lhs_packed_stride;
 }
-inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) {
- KAI_ASSERT((bl % kai_kr) == 0);
- KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
+ KAI_ASSUME((bl % kai_bl) == 0);
+
+ const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
+ const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
- const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
- const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs;
+ size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row);
+ // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include
+ // the number of bytes for the reduction sum
+ rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs;
+ // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
+ rhs_packed_stride += kai_nr * kai_num_bytes_bias;
- return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
+ return rhs_packed_stride;
 }
 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) {
@@ -82,148 +119,67 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(vo
 }
 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k) {
- KAI_ASSERT((m_idx % kai_m_step) == 0);
+ KAI_ASSUME((m_idx % kai_m_step) == 0);
- return (m_idx / kai_mr) * kai_lhs_packed_stride(k);
+ return (m_idx / kai_mr) * 
kai_get_lhs_packed_stride(k); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( size_t n_idx, size_t k, size_t bl) { - KAI_ASSERT((k % bl) == 0); - KAI_ASSERT((n_idx % kai_n_step) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); - return (n_idx / kai_nr) * kai_rhs_packed_stride(k, bl); + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl); } size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSERT((m_idx % kai_m_step) == 0); - KAI_ASSERT((n_idx % kai_n_step) == 0); + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); - return (n_idx * sizeof(float)) + m_idx * dst_stride; + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(size_t m, size_t n) { - return m * n * sizeof(float); + return m * n * kai_num_bytes_dst_value; } void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( - size_t m, size_t n, size_t k, size_t bl, const void* restrict lhs_packed, const void* restrict rhs_packed, - float* restrict dst, // NOLINT(readability-non-const-parameter) - size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { - KAI_ASSERT((k % bl) == 0); - KAI_ASSERT((bl % kai_kr) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(dst_stride_col == sizeof(float)); + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); if (m == 0) { return; } + const size_t num_subblocks = bl / 32; + const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + args.num_subblocks = num_subblocks; + + kai_kernel_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(&args); +} - size_t num_subblocks = bl / kai_bl_multiple_of; - size_t num_blocks = kai_num_blocks_per_row(k, bl); - - float clamp_vals[2] = {scalar_min, scalar_max}; - - __asm__ __volatile__( - "mov x27, #0x20\n" - "mov x21, #0x3d800000\n" - "movi v0.16b, #0xf0\n" - "mov x20, #0x8\n" - "mov x26, %x[m]\n" - "mul x27, %x[num_subblocks], x27\n" - "dup v31.4s, w21\n" - "madd x27, %x[num_blocks], x27, x20\n" - "1:" // Row loop - "mov x25, %x[rhs_packed]\n" - "mov x24, %x[n]\n" - "add x23, %x[dst], %x[dst_stride_row]\n" - "2:" // Column loop - "mov x22, %x[lhs_packed]\n" - "movi v30.16b, #0x0\n" - "mov x21, %x[num_blocks]\n" - "3:" // Block loop - "movi v29.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "4:" // Sub block loop - "ldr q27, [x25, #0x0]\n" - "ldr q26, [x25, #0x10]\n" - "subs x20, x20, #0x1\n" - "ld1r { v25.2d }, [x22], #0x8\n" - "ldr q24, [x25, #0x20]\n" - "ldr q23, [x25, #0x30]\n" - "add x25, x25, #0x40\n" - "ld1r { v22.2d }, [x22], #0x8\n" - "ld1r { v21.2d }, [x22], #0x8\n" - "shl v20.16b, v27.16b, #0x4\n" - "shl v19.16b, v26.16b, #0x4\n" 
- "ld1r { v18.2d }, [x22], #0x8\n" - "shl v17.16b, v24.16b, #0x4\n" - "and v27.16b, v27.16b, v0.16b\n" - "shl v16.16b, v23.16b, #0x4\n" - "and v26.16b, v26.16b, v0.16b\n" - ".inst 0x4e99969d // sdot v29.4s, v20.16b, v25.16b\n" - ".inst 0x4e99967c // sdot v28.4s, v19.16b, v25.16b\n" - "and v24.16b, v24.16b, v0.16b\n" - "and v23.16b, v23.16b, v0.16b\n" - ".inst 0x4e96963d // sdot v29.4s, v17.16b, v22.16b\n" - ".inst 0x4e96961c // sdot v28.4s, v16.16b, v22.16b\n" - ".inst 0x4e95977d // sdot v29.4s, v27.16b, v21.16b\n" - ".inst 0x4e95975c // sdot v28.4s, v26.16b, v21.16b\n" - ".inst 0x4e92971d // sdot v29.4s, v24.16b, v18.16b\n" - ".inst 0x4e9296fc // sdot v28.4s, v23.16b, v18.16b\n" - "bgt 4b\n" - "ldr d16, [x25, #0x0]\n" - "addp v29.4s, v29.4s, v28.4s\n" - "sub x21, x21, #0x1\n" - "add x25, x25, #0x8\n" - "shll v16.4s, v16.4h, #0x10\n" - "scvtf v29.4s, v29.4s\n" - "fmul v16.4s, v16.4s, v31.4s\n" - "fmla v30.4s, v29.4s, v16.4s\n" - "cbnz x21, 3b\n" - "ld1r { v21.4s }, [x22]\n" - "ldr q20, [x25, #0x0]\n" - "add x22, x22, #0x4\n" - "add x20, %x[clamp_vals], #0x4\n" - "ld1r { v19.4s }, [x22]\n" - "ldr q18, [x25, #0x10]\n" - "cmp x24, #0x4\n" - "add x25, x25, #0x20\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v16.4s }, [x20]\n" - "scvtf v21.4s, v21.4s\n" - "fmla v30.4s, v20.4s, v21.s[0]\n" - "fmul v30.4s, v30.4s, v19.4s\n" - "fadd v30.4s, v30.4s, v18.4s\n" - "fmax v30.4s, v30.4s, v17.4s\n" - "fmin v30.4s, v30.4s, v16.4s\n" - "blt 5f\n" - "str q30, [%x[dst], #0x0]\n" - "b 8f\n" - "5:" // Partial output - "mov x20, %x[dst]\n" - "tbz x24, #1, 6f\n" - "st1 { v30.d }[0], [x20], #0x8\n" - "tbz x24, #0, 7f\n" - "st1 { v30.s }[2], [x20]\n" - "b 7f\n" - "6:" // Output block 0: partial_1_0 - "st1 { v30.s }[0], [x20]\n" - "7:" // Output block 0: Done - "8:" // Stores done - "subs x24, x24, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 2b\n" - "subs x26, x26, #0x1\n" - "add %x[lhs_packed], %x[lhs_packed], x27\n" - "mov %x[dst], x23\n" - "bgt 1b\n" - : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) - : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), - [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", - "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27"); -} -#endif // Architectural feature check +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
index b15be244..1685c30b 100644
--- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
@@ -1,21 +1,22 @@
-
 //
-// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates
 //
 // SPDX-License-Identifier: Apache-2.0
 //
+
 #pragma once
 
 #include
 
 #ifdef __cplusplus
 extern "C" {
-#endif
+#endif // __cplusplus
 
 /// Micro-kernel dependencies
 ///
-/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix
-/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix
+/// -# @ref kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step.
+/// -# @ref kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS NxK matrix.
+/// -# @ref kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS KxN matrix.
 
 /// --------------------------------------------------
 
@@ -38,7 +39,7 @@
 /// @return the mr value
 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
 
-/// Gets the nr value, which must be used to pack the RHS matrix
+/// Gets the nr value, which must be used to pack the RHS matrix.
 ///
 /// @return the nr value
 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
@@ -54,13 +55,14 @@ size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(vo
 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
 
 /// Gets the offset in bytes for the packed LHS matrix,
-/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values.
+/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values.
 ///
 /// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
 ///
-/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 1
+/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be 1.
 /// @param[in] k Total number of columns in the LHS matrix (not packed).
 ///              It must be a multiple of the block length (bl).
 ///
 /// @return the offset in bytes to the packed LHS matrix
 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
@@ -68,9 +70,9 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_
 size_t k); //
 
 /// Gets the offset in bytes for the packed RHS matrix,
-/// which contains the packed Signed 4-bit quantized symmetric per-channel (qsi4c32) values.
+/// which contains the packed Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) values.
 ///
-/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4.
+/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step.
 /// @param[in] k The common dimension between the LHS and RHS matrix (K).
 ///              It must be a multiple of the block length (bl).
 /// @param[in] bl Block length. It must be a multiple of 32.
@@ -83,8 +85,8 @@
 
 /// Gets the offset in bytes for the DST matrix
 ///
-/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 1.
-/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 4.
+/// @param[in] m_idx Row index in the DST matrix. It must be 1.
+/// @param[in] n_idx Column index in the DST matrix. It must be multiple of n_step.
 /// @param[in] dst_stride The number of bytes in each row of the DST matrix
 ///
 /// @return the DST offset in bytes
@@ -105,25 +107,26 @@
 /// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
 ///
-/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
-/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4c32) and packed.
-/// Output tile: (rows x cols) = 1 x 4
-/// Accumulation performed in a single for loop: 32
-/// Extension used: dotprod
+/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed.
+/// RHS matrix: Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) and packed.
+/// Output tile: (rows x cols) = m_step x n_step.
+///
+/// Note: Please refer to the get functions for m_step and n_step for the exact values.
+///
+/// Features used: dotprod
 ///
-/// @param[in] m The number of output rows written.
+/// @param[in] m The number of output rows written. It must be 1.
 /// @param[in] n The number of output columns written.
 /// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
-/// @param[in] bl Block length. It must be a multiple of 32.
-/// @param[in] lhs_packed The LHS packed matrix.
-///                       When the activation are dynamically quantized, you can obtain this matrix
-///                       by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
-///                       both the dynamic quantization to 8-bit and activation packing in a single step.
-/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref
-///                       kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0
+///                  It must be a multiple of the block length (bl).
+/// @param[in] bl Block length. It must be a multiple of 32.
+/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the
+///                       top of this file.
+/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the
+///                       top of this file.
 /// @param[out] dst The DST matrix.
 /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
-/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float).
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes.
 /// @param[in] scalar_min Min value used to clamp the final result.
 /// @param[in] scalar_max Max value used to clamp the final result.
void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( @@ -141,4 +144,4 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( #ifdef __cplusplus } -#endif +#endif // __cplusplus -- GitLab From d0ddf791ec0937c73107e82b698cafc786336bff Mon Sep 17 00:00:00 2001 From: Michael Kozlov Date: Thu, 23 Jan 2025 16:34:30 +0000 Subject: [PATCH 03/15] Extract inline assembly kernels into external files: 1x8x32 Signed-off-by: Michael Kozlov --- CMakeLists.txt | 1 + ...p1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S | 205 ++++++++++++ ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c | 307 ++++++------------ ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h | 56 ++-- 4 files changed, 342 insertions(+), 227 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index ecbb56d2..27c7e211 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -144,6 +144,7 @@ set(KLEIDIAI_FILES_NEON_DOTPROD kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S ) set(KLEIDIAI_FILES_NEON_I8MM diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S new file mode 100644 index 00000000..baa49e99 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S @@ -0,0 +1,205 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod) 
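+// As in the 1x4 variant, x0 carries a pointer to the KernelArgs struct declared in
+// the companion .c file; the prologue reads the same field offsets (0x00 dst through
+// 0x40 num_subblocks) before entering the row loop.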
+KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod) + stp x20, x21, [sp, -80]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + mov x15, #0x20 + movi v7.16b, #0xf0 + mov x21, #0x8 + ldr x14, [x0, #0x40] + ldr x13, [x0, #0x38] + ldr x20, [x0, #0x28] + ldr x12, [x0, #0x8] + ldr x11, [x0, #0x10] + ldr x10, [x0, #0x30] + mul x15, x14, x15 + ldr x9, [x0, #0x0] + ldr x28, [x0, #0x20] + ldr x27, [x0, #0x18] + mov x26, x20 + madd x15, x13, x15, x21 +KAI_ASM_LABEL(label_1) // Row loop + mov x25, x11 + mov x24, x10 + add x23, x9, x28 +KAI_ASM_LABEL(label_2) // Column loop + mov x22, x12 + movi v6.16b, #0x0 + movi v5.16b, #0x0 + mov x21, x13 +KAI_ASM_LABEL(label_3) // Block loop + movi v4.4s, #0x0 + movi v3.4s, #0x0 + mov x20, x14 + movi v2.4s, #0x0 + movi v1.4s, #0x0 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q0, [x25, #0x0] + ldr q31, [x25, #0x10] + subs x20, x20, #0x1 + ldr q30, [x25, #0x20] + ldr q29, [x25, #0x30] + ld1r { v28.2d }, [x22], #0x8 + ldr q27, [x25, #0x40] + ldr q26, [x25, #0x50] + ldr q25, [x25, #0x60] + shl v24.16b, v0.16b, #0x4 + shl v18.16b, v31.16b, #0x4 + ldr q23, [x25, #0x70] + shl v17.16b, v30.16b, #0x4 + shl v16.16b, v29.16b, #0x4 + add x25, x25, #0x80 + ld1r { v22.2d }, [x22], #0x8 + shl v21.16b, v27.16b, #0x4 + and v0.16b, v0.16b, v7.16b + ld1r { v20.2d }, [x22], #0x8 + ld1r { v19.2d }, [x22], #0x8 + KAI_ASM_INST(0x4e9c9704) // sdot v4.4s, v24.16b, v28.16b + KAI_ASM_INST(0x4e9c9643) // sdot v3.4s, v18.16b, v28.16b + shl v18.16b, v26.16b, #0x4 + KAI_ASM_INST(0x4e9c9622) // sdot v2.4s, v17.16b, v28.16b + KAI_ASM_INST(0x4e9c9601) // sdot v1.4s, v16.16b, v28.16b + shl v17.16b, v25.16b, #0x4 + shl v16.16b, v23.16b, #0x4 + and v31.16b, v31.16b, v7.16b + and v30.16b, v30.16b, v7.16b + and v29.16b, v29.16b, v7.16b + KAI_ASM_INST(0x4e9696a4) // sdot v4.4s, v21.16b, v22.16b + KAI_ASM_INST(0x4e969643) // sdot v3.4s, v18.16b, v22.16b + and v27.16b, v27.16b, v7.16b + KAI_ASM_INST(0x4e969622) // sdot v2.4s, v17.16b, v22.16b + KAI_ASM_INST(0x4e969601) // sdot v1.4s, v16.16b, v22.16b + and v26.16b, v26.16b, v7.16b + and v25.16b, v25.16b, v7.16b + and v23.16b, v23.16b, v7.16b + KAI_ASM_INST(0x4e949404) // sdot v4.4s, v0.16b, v20.16b + KAI_ASM_INST(0x4e9497e3) // sdot v3.4s, v31.16b, v20.16b + KAI_ASM_INST(0x4e9497c2) // sdot v2.4s, v30.16b, v20.16b + KAI_ASM_INST(0x4e9497a1) // sdot v1.4s, v29.16b, v20.16b + KAI_ASM_INST(0x4e939764) // sdot v4.4s, v27.16b, v19.16b + KAI_ASM_INST(0x4e939743) // sdot v3.4s, v26.16b, v19.16b + KAI_ASM_INST(0x4e939722) // sdot v2.4s, v25.16b, v19.16b + KAI_ASM_INST(0x4e9396e1) // sdot v1.4s, v23.16b, v19.16b + bgt label_4 + ldr q16, [x25, #0x0] + addp v4.4s, v4.4s, v3.4s + addp v2.4s, v2.4s, v1.4s + sub x21, x21, #0x1 + add x25, x25, #0x10 + shll v17.4s, v16.4h, #0x10 + shll2 v16.4s, v16.8h, #0x10 + scvtf v4.4s, v4.4s, #0x4 + scvtf v2.4s, v2.4s, #0x4 + fmla v6.4s, v4.4s, v17.4s + fmla v5.4s, v2.4s, v16.4s + cbnz x21, label_3 + ld1r { v23.4s }, [x22] + ldr q22, [x25, #0x0] + add x22, x22, #0x4 + add x20, x27, #0x4 + ldr q21, [x25, #0x10] + ld1r { v20.4s }, [x22] + cmp x24, #0x8 + ldr q19, [x25, #0x20] + ldr q18, [x25, #0x30] + add x25, x25, #0x40 + ld1r { v17.4s }, [x27] + ld1r { v16.4s }, [x20] + scvtf v23.4s, v23.4s + fmla v6.4s, v22.4s, v23.s[0] + fmla v5.4s, v21.4s, v23.s[0] + fmul v6.4s, v6.4s, v20.4s + fadd v6.4s, v6.4s, v19.4s + fmul v5.4s, v5.4s, v20.4s + fadd v5.4s, v5.4s, v18.4s + fmax v6.4s, v6.4s, v17.4s + fmax v5.4s, v5.4s, v17.4s + fmin v6.4s, v6.4s, 
v16.4s + fmin v5.4s, v5.4s, v16.4s + blt label_5 + str q6, [x9, #0x0] + str q5, [x9, #0x10] + b label_10 +KAI_ASM_LABEL(label_5) // Partial output + mov x20, x9 + tbz x24, #2, label_7 + st1 { v6.4s }, [x20], #0x10 + tbz x24, #1, label_6 + st1 { v5.d }[0], [x20], #0x8 + tbz x24, #0, label_9 + st1 { v5.s }[2], [x20] + b label_9 +KAI_ASM_LABEL(label_6) // Output block 0: partial_1_4 + tbz x24, #0, label_9 + st1 { v5.s }[0], [x20] + b label_9 +KAI_ASM_LABEL(label_7) // Output block 0: partial_2_0 + tbz x24, #1, label_8 + st1 { v6.d }[0], [x20], #0x8 + tbz x24, #0, label_9 + st1 { v6.s }[2], [x20] + b label_9 +KAI_ASM_LABEL(label_8) // Output block 0: partial_1_0 + st1 { v6.s }[0], [x20] +KAI_ASM_LABEL(label_9) // Output block 0: Done +KAI_ASM_LABEL(label_10) // Stores done + subs x24, x24, #0x8 + add x9, x9, #0x20 + bgt label_2 + subs x26, x26, #0x1 + add x12, x12, x15 + mov x9, x23 + bgt label_1 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c index 45b6fa60..59eec89d 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c @@ -3,58 +3,95 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__ARM_FEATURE_DOTPROD) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) #error "Dotprod extension required to compile this micro-kernel" -#else +#else // Architectural features check. + #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h" -#include #include #include #include "kai/kai_common.h" +typedef struct { + float* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; + size_t num_subblocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(KernelArgs* args_ptr); + +// Compute args static const size_t kai_m_step = 1; static const size_t kai_n_step = 8; +// Packing args static const size_t kai_mr = 1; static const size_t kai_nr = 8; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_bl_multiple_of = 32; -static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); -static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); -static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); -static const size_t kai_num_bytes_sum_rhs = sizeof(float); -static const size_t kai_num_bytes_bias = sizeof(float); - -inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - return kai_roundup(k, bl) / bl; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
+// asymmetric))
+static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
+static const size_t kai_num_bytes_multiplier_rhs = 2;
+static const size_t kai_num_bytes_rsum_rhs = 4;
+// DST format args
+static const size_t kai_num_bytes_dst_value = 4;
+// Extra args
+static const size_t kai_num_bytes_bias = 4;
+static const size_t kai_k_multiple_of = 32;
+static const size_t kai_bl = 32;
+
+inline static size_t kai_get_k_roundedup(size_t k) {
+ return kai_roundup(k, kai_k_multiple_of);
 }
-inline static size_t kai_k_roundedup(size_t k) {
- // Since we pack a float and int32 value at the end of the row,
- // we must make sure that k is a multiple of 4 for alignment
- size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
- return kai_roundup(k, kr_sr_roundedup4);
+inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
+ KAI_ASSUME((bl % kai_bl) == 0);
+ size_t num_bytes_per_block_rhs = (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
+ return num_bytes_per_block_rhs;
 }
-inline static size_t kai_lhs_packed_stride(size_t k) {
- const size_t k_internal = kai_k_roundedup(k);
+inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
+ KAI_ASSUME((bl % kai_bl) == 0);
+
+ return kai_roundup(k, bl) / bl;
+}
- KAI_ASSERT((k_internal % 2) == 0);
+inline static size_t kai_get_lhs_packed_stride(size_t k) {
+ const size_t k_internal = kai_get_k_roundedup(k);
+ size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
+ // Since the LHS matrix is asymmetric with per-row quantization, we must include the
+ // number of bytes to hold the zero point value
+ lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs;
- return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+ return lhs_packed_stride;
 }
-inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) {
- KAI_ASSERT((bl % kai_kr) == 0);
- KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
+ KAI_ASSUME((bl % kai_bl) == 0);
+
+ const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
+ const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
- const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
- const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs;
+ size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row);
+ // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include
+ // the number of bytes for the reduction sum
+ rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs;
+ // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
+ rhs_packed_stride += kai_nr * kai_num_bytes_bias;
- return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
+ return rhs_packed_stride;
 }
 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) {
@@ -82,197 +119,67 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(vo
 }
 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k) {
- KAI_ASSERT((m_idx % kai_m_step) == 0);
+ KAI_ASSUME((m_idx % kai_m_step) == 0);
- return (m_idx / kai_mr) * kai_lhs_packed_stride(k);
+ return (m_idx / kai_mr) * 
kai_get_lhs_packed_stride(k); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( size_t n_idx, size_t k, size_t bl) { - KAI_ASSERT((k % bl) == 0); - KAI_ASSERT((n_idx % kai_n_step) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); - return (n_idx / kai_nr) * kai_rhs_packed_stride(k, bl); + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl); } size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSERT((m_idx % kai_m_step) == 0); - KAI_ASSERT((n_idx % kai_n_step) == 0); + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); - return (n_idx * sizeof(float)) + m_idx * dst_stride; + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(size_t m, size_t n) { - return m * n * sizeof(float); + return m * n * kai_num_bytes_dst_value; } void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( - size_t m, size_t n, size_t k, size_t bl, const void* restrict lhs_packed, const void* restrict rhs_packed, - float* restrict dst, // NOLINT(readability-non-const-parameter) - size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { - KAI_ASSERT((k % bl) == 0); - KAI_ASSERT((bl % kai_kr) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(dst_stride_col == sizeof(float)); + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); if (m == 0) { return; } - - size_t num_subblocks = bl / kai_bl_multiple_of; - size_t num_blocks = kai_num_blocks_per_row(k, bl); - - float clamp_vals[2] = {scalar_min, scalar_max}; - - __asm__ __volatile__( - "mov x27, #0x20\n" - "mov x21, #0x3d800000\n" - "movi v8.16b, #0xf0\n" - "mov x20, #0x8\n" - "mov x26, %x[m]\n" - "mul x27, %x[num_subblocks], x27\n" - "dup v7.4s, w21\n" - "madd x27, %x[num_blocks], x27, x20\n" - "1:" // Row loop - "mov x25, %x[rhs_packed]\n" - "mov x24, %x[n]\n" - "add x23, %x[dst], %x[dst_stride_row]\n" - "2:" // Column loop - "mov x22, %x[lhs_packed]\n" - "movi v6.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "mov x21, %x[num_blocks]\n" - "3:" // Block loop - "movi v4.4s, #0x0\n" - "movi v3.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "movi v2.4s, #0x0\n" - "movi v1.4s, #0x0\n" - "4:" // Sub block loop - "ldr q0, [x25, #0x0]\n" - "ldr q31, [x25, #0x10]\n" - "subs x20, x20, #0x1\n" - "ldr q30, [x25, #0x20]\n" - "ldr q29, [x25, #0x30]\n" - "ld1r { v28.2d }, [x22], #0x8\n" - "ldr q27, [x25, #0x40]\n" - "ldr q26, [x25, #0x50]\n" - "ldr q25, [x25, #0x60]\n" - "shl v24.16b, v0.16b, #0x4\n" - "shl v18.16b, v31.16b, #0x4\n" - "ldr q23, [x25, #0x70]\n" - "shl v17.16b, v30.16b, #0x4\n" - "shl v16.16b, v29.16b, #0x4\n" - "add x25, x25, #0x80\n" - "ld1r { v22.2d }, [x22], #0x8\n" - "shl v21.16b, v27.16b, #0x4\n" - "and v0.16b, v0.16b, v8.16b\n" - "ld1r { v20.2d }, [x22], #0x8\n" - "ld1r { v19.2d }, [x22], #0x8\n" - ".inst 0x4e9c9704 // sdot v4.4s, v24.16b, v28.16b\n" - ".inst 0x4e9c9643 // sdot v3.4s, v18.16b, v28.16b\n" - "shl v18.16b, v26.16b, #0x4\n" - ".inst 0x4e9c9622 // sdot 
v2.4s, v17.16b, v28.16b\n" - ".inst 0x4e9c9601 // sdot v1.4s, v16.16b, v28.16b\n" - "shl v17.16b, v25.16b, #0x4\n" - "shl v16.16b, v23.16b, #0x4\n" - "and v31.16b, v31.16b, v8.16b\n" - "and v30.16b, v30.16b, v8.16b\n" - "and v29.16b, v29.16b, v8.16b\n" - ".inst 0x4e9696a4 // sdot v4.4s, v21.16b, v22.16b\n" - ".inst 0x4e969643 // sdot v3.4s, v18.16b, v22.16b\n" - "and v27.16b, v27.16b, v8.16b\n" - ".inst 0x4e969622 // sdot v2.4s, v17.16b, v22.16b\n" - ".inst 0x4e969601 // sdot v1.4s, v16.16b, v22.16b\n" - "and v26.16b, v26.16b, v8.16b\n" - "and v25.16b, v25.16b, v8.16b\n" - "and v23.16b, v23.16b, v8.16b\n" - ".inst 0x4e949404 // sdot v4.4s, v0.16b, v20.16b\n" - ".inst 0x4e9497e3 // sdot v3.4s, v31.16b, v20.16b\n" - ".inst 0x4e9497c2 // sdot v2.4s, v30.16b, v20.16b\n" - ".inst 0x4e9497a1 // sdot v1.4s, v29.16b, v20.16b\n" - ".inst 0x4e939764 // sdot v4.4s, v27.16b, v19.16b\n" - ".inst 0x4e939743 // sdot v3.4s, v26.16b, v19.16b\n" - ".inst 0x4e939722 // sdot v2.4s, v25.16b, v19.16b\n" - ".inst 0x4e9396e1 // sdot v1.4s, v23.16b, v19.16b\n" - "bgt 4b\n" - "ldr q16, [x25, #0x0]\n" - "addp v4.4s, v4.4s, v3.4s\n" - "addp v2.4s, v2.4s, v1.4s\n" - "sub x21, x21, #0x1\n" - "add x25, x25, #0x10\n" - "shll v17.4s, v16.4h, #0x10\n" - "shll2 v16.4s, v16.8h, #0x10\n" - "scvtf v4.4s, v4.4s\n" - "scvtf v2.4s, v2.4s\n" - "fmul v17.4s, v17.4s, v7.4s\n" - "fmul v16.4s, v16.4s, v7.4s\n" - "fmla v6.4s, v4.4s, v17.4s\n" - "fmla v5.4s, v2.4s, v16.4s\n" - "cbnz x21, 3b\n" - "ld1r { v23.4s }, [x22]\n" - "ldr q22, [x25, #0x0]\n" - "add x22, x22, #0x4\n" - "add x20, %x[clamp_vals], #0x4\n" - "ldr q21, [x25, #0x10]\n" - "ld1r { v20.4s }, [x22]\n" - "cmp x24, #0x8\n" - "ldr q19, [x25, #0x20]\n" - "ldr q18, [x25, #0x30]\n" - "add x25, x25, #0x40\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v16.4s }, [x20]\n" - "scvtf v23.4s, v23.4s\n" - "fmla v6.4s, v22.4s, v23.s[0]\n" - "fmla v5.4s, v21.4s, v23.s[0]\n" - "fmul v6.4s, v6.4s, v20.4s\n" - "fadd v6.4s, v6.4s, v19.4s\n" - "fmul v5.4s, v5.4s, v20.4s\n" - "fadd v5.4s, v5.4s, v18.4s\n" - "fmax v6.4s, v6.4s, v17.4s\n" - "fmax v5.4s, v5.4s, v17.4s\n" - "fmin v6.4s, v6.4s, v16.4s\n" - "fmin v5.4s, v5.4s, v16.4s\n" - "blt 5f\n" - "str q6, [%x[dst], #0x0]\n" - "str q5, [%x[dst], #0x10]\n" - "b 10f\n" - "5:" // Partial output - "mov x20, %x[dst]\n" - "tbz x24, #2, 7f\n" - "st1 { v6.4s }, [x20], #0x10\n" - "tbz x24, #1, 6f\n" - "st1 { v5.d }[0], [x20], #0x8\n" - "tbz x24, #0, 9f\n" - "st1 { v5.s }[2], [x20]\n" - "b 9f\n" - "6:" // Output block 0: partial_1_4 - "tbz x24, #0, 9f\n" - "st1 { v5.s }[0], [x20]\n" - "b 9f\n" - "7:" // Output block 0: partial_2_0 - "tbz x24, #1, 8f\n" - "st1 { v6.d }[0], [x20], #0x8\n" - "tbz x24, #0, 9f\n" - "st1 { v6.s }[2], [x20]\n" - "b 9f\n" - "8:" // Output block 0: partial_1_0 - "st1 { v6.s }[0], [x20]\n" - "9:" // Output block 0: Done - "10:" // Stores done - "subs x24, x24, #0x8\n" - "add %x[dst], %x[dst], #0x20\n" - "bgt 2b\n" - "subs x26, x26, #0x1\n" - "add %x[lhs_packed], %x[lhs_packed], x27\n" - "mov %x[dst], x23\n" - "bgt 1b\n" - : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) - : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), - [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v16", "v17", "v18", "v19", "v20", - "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", - "x24", "x25", "x26", "x27"); + const size_t 
num_subblocks = bl / 32; + const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + args.num_subblocks = num_subblocks; + + kai_kernel_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(&args); } -#endif // Architectural feature check + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h index b4ff155f..225add8a 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h @@ -1,21 +1,22 @@ - // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // + #pragma once #include #ifdef __cplusplus extern "C" { -#endif +#endif // __cplusplus /// Micro-kernel dependencies /// -/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix +/// -# @ref kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS NxK matrix. +/// -# @ref kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS KxN matrix. /// -------------------------------------------------- @@ -38,7 +39,7 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotpro /// @return the mr value size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void); -/// Gets the nr value, which must be used to pack the RHS matrix +/// Gets the nr value, which must be used to pack the RHS matrix. /// /// @return the nr value size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void); @@ -54,13 +55,14 @@ size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(vo size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void); /// Gets the offset in bytes for the packed LHS matrix, -/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values. +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. /// /// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. /// -/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 1 +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be 1. /// @param[in] k Total number of columns in the LHS matrix (not packed). /// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. 
/// /// @return the offset in bytes to the packed LHS matrix size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( @@ -68,9 +70,9 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_ size_t k); // /// Gets the offset in bytes for the packed RHS matrix, -/// which contains the packed Signed 4-bit quantized symmetric per-channel (qsi4c32) values. +/// which contains the packed Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) values. /// -/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8. +/// @param[in] n_idx Column index in the RHS matrix (not packed). It must be a multiple of n_step. /// @param[in] k The common dimension between the LHS and RHS matrix (K). /// It must be a multiple of the block length (bl). /// @param[in] bl Block length. It must be a multiple of 32. @@ -83,8 +85,8 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_ /// Gets the offset in bytes for the DST matrix /// -/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 1. -/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 8. +/// @param[in] m_idx Row index in the DST matrix. It must be 1. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step. /// @param[in] dst_stride The number of bytes in each row of the DST matrix /// /// @return the DST offset in bytes @@ -105,26 +107,26 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotp /// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. /// -/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed -/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4c32) and packed. -/// Output tile: (rows x cols) = 1 x 8 -/// Accumulation performed in a single for loop: 32 -/// Extension used: dotprod +/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed. +/// RHS matrix: Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) and packed. +/// Output tile: (rows x cols) = m_step x n_step. +/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: dotprod /// -/// @param[in] m The number of output rows written. +/// @param[in] m The number of output rows written. It must be 1. /// @param[in] n The number of output columns written. /// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. /// It must be a multiple of the block length (bl). -/// @param[in] bl Block length. It must be a multiple of 32. -/// @param[in] lhs_packed The LHS packed matrix. -/// When the activation are dynamically quantized, you can obtain this matrix -/// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs -/// both the dynamic quantization to 8-bit and activation packing in a single step. -/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref -/// kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 +/// @param[in] bl Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix.
The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. -/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes. /// @param[in] scalar_min Min value used to clamp the final result. /// @param[in] scalar_max Max value used to clamp the final result. void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( @@ -142,4 +144,4 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( #ifdef __cplusplus } -#endif +#endif // __cplusplus -- GitLab From d73bfe32a8a6284fc4356e77057049b553f88fde Mon Sep 17 00:00:00 2001 From: Michael Kozlov Date: Thu, 23 Jan 2025 16:35:00 +0000 Subject: [PATCH 04/15] Extract inline assembly kernels into external files: 16x4 Signed-off-by: Michael Kozlov --- CMakeLists.txt | 1 + ...dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S | 804 ++++++++++++++++++ ...qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c | 768 +---------------- 3 files changed, 835 insertions(+), 738 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index 27c7e211..21dcf7e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -145,6 +145,7 @@ set(KLEIDIAI_FILES_NEON_DOTPROD kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S ) set(KLEIDIAI_FILES_NEON_I8MM diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S new file mode 100644 index 00000000..dbb7a11d --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S @@ -0,0 +1,804 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define 
KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x5, #0x80 + mov x21, #0x20 + sub SP, SP, #0x100 + ldr x20, [x0, #0x28] + ldr x6, [x0, #0x40] + ldr x7, [x0, #0x38] + ldr x8, [x0, #0x8] + ldr x17, [x0, #0x10] + ldr x16, [x0, #0x30] + mov x15, x20 + mul x5, x6, x5 + ldr x14, [x0, #0x0] + ldr x13, [x0, #0x20] + ldr x12, [x0, #0x18] + cmp x15, #0x10 + madd x5, x7, x5, x21 + blt label_15 +KAI_ASM_LABEL(label_1) // Row loop + mov x11, x17 + mov x10, x16 + add x9, x14, x13, LSL #4 +KAI_ASM_LABEL(label_2) // Column loop + mov x27, x8 + movi v6.4s, #0x0 + mov x24, x7 + str q6, [SP, #0x0] + str q6, [SP, #0x10] + str q6, [SP, #0x20] + add x23, x27, x5 + add x22, x23, x5 + str q6, [SP, #0x30] + add x21, x22, x5 + str q6, [SP, #0x40] + str q6, [SP, #0x50] + str q6, [SP, #0x60] + str q6, [SP, #0x70] + str q6, [SP, #0x80] + str q6, [SP, #0x90] + str q6, [SP, #0xa0] + str q6, [SP, #0xb0] + str q6, [SP, #0xc0] + str q6, [SP, #0xd0] + str q6, [SP, #0xe0] + str q6, [SP, #0xf0] +KAI_ASM_LABEL(label_3) // Block loop + movi v2.4s, #0x0 + movi v17.4s, #0x0 + mov x20, x6 + movi v12.4s, #0x0 + movi v9.4s, #0x0 + movi v14.4s, #0x0 + movi v11.4s, #0x0 + movi v13.4s, #0x0 + movi v15.4s, #0x0 + movi v23.4s, #0x0 + movi v29.4s, #0x0 + movi v0.4s, #0x0 + movi v4.4s, #0x0 + movi v16.4s, #0x0 + movi v21.4s, #0x0 + movi v10.4s, #0x0 + movi v3.4s, #0x0 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q6, [x11, #0x0] + ldr q1, [x27, #0x0] + movi v25.16b, #0xf0 + subs x20, x20, #0x1 + ldr q5, [x23, #0x0] + ldr q30, [x22, #0x0] + ldr q24, [x21, #0x0] + ldr q18, [x11, #0x10] + ldr q27, [x27, #0x10] + ldr q20, [x23, #0x10] + shl v31.16b, v6.16b, #0x4 + and v6.16b, v6.16b, v25.16b + ldr q19, [x22, #0x10] + ldr q26, [x21, #0x10] + ldr q7, [x11, #0x20] + ldr q8, [x27, #0x20] + shl v22.16b, v18.16b, #0x4 + and v18.16b, v18.16b, v25.16b + ldr q28, [x23, #0x20] + KAI_ASM_INST(0x4f81e3e2) // sdot v2.4s, v31.16b, v1.4b[0] + KAI_ASM_INST(0x4fa1e3f1) // sdot v17.4s, v31.16b, v1.4b[1] + KAI_ASM_INST(0x4f81ebec) // sdot v12.4s, v31.16b, v1.4b[2] + KAI_ASM_INST(0x4fa1ebe9) // sdot v9.4s, v31.16b, v1.4b[3] + ldr q1, [x22, #0x20] + KAI_ASM_INST(0x4f85e3ee) // sdot v14.4s, v31.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e3eb) // sdot v11.4s, v31.16b, v5.4b[1] + KAI_ASM_INST(0x4f85ebed) // sdot v13.4s, v31.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5ebef) // sdot v15.4s, v31.16b, v5.4b[3] + ldr q5, [x21, #0x20] + KAI_ASM_INST(0x4f9ee3f7) // sdot v23.4s, v31.16b, v30.4b[0] + KAI_ASM_INST(0x4fbee3fd) // sdot v29.4s, v31.16b, v30.4b[1] + KAI_ASM_INST(0x4f9eebe0) // sdot v0.4s, v31.16b, v30.4b[2] + KAI_ASM_INST(0x4fbeebe4) // sdot v4.4s, v31.16b, v30.4b[3] + ldr q30, [x11, #0x30] + add x11, x11, #0x40 + KAI_ASM_INST(0x4f98e3f0) // sdot v16.4s, v31.16b, v24.4b[0] + KAI_ASM_INST(0x4fb8e3f5) // sdot v21.4s, v31.16b, v24.4b[1] + KAI_ASM_INST(0x4f98ebea) // sdot 
v10.4s, v31.16b, v24.4b[2] + KAI_ASM_INST(0x4fb8ebe3) // sdot v3.4s, v31.16b, v24.4b[3] + ldr q24, [x27, #0x30] + ldr q31, [x23, #0x30] + KAI_ASM_INST(0x4f9be2c2) // sdot v2.4s, v22.16b, v27.4b[0] + KAI_ASM_INST(0x4fbbe2d1) // sdot v17.4s, v22.16b, v27.4b[1] + KAI_ASM_INST(0x4f9beacc) // sdot v12.4s, v22.16b, v27.4b[2] + KAI_ASM_INST(0x4fbbeac9) // sdot v9.4s, v22.16b, v27.4b[3] + ldr q27, [x22, #0x30] + KAI_ASM_INST(0x4f94e2ce) // sdot v14.4s, v22.16b, v20.4b[0] + KAI_ASM_INST(0x4fb4e2cb) // sdot v11.4s, v22.16b, v20.4b[1] + KAI_ASM_INST(0x4f94eacd) // sdot v13.4s, v22.16b, v20.4b[2] + KAI_ASM_INST(0x4fb4eacf) // sdot v15.4s, v22.16b, v20.4b[3] + ldr q20, [x21, #0x30] + KAI_ASM_INST(0x4f93e2d7) // sdot v23.4s, v22.16b, v19.4b[0] + KAI_ASM_INST(0x4fb3e2dd) // sdot v29.4s, v22.16b, v19.4b[1] + KAI_ASM_INST(0x4f93eac0) // sdot v0.4s, v22.16b, v19.4b[2] + KAI_ASM_INST(0x4fb3eac4) // sdot v4.4s, v22.16b, v19.4b[3] + ldr q19, [x27, #0x40] + KAI_ASM_INST(0x4f9ae2d0) // sdot v16.4s, v22.16b, v26.4b[0] + KAI_ASM_INST(0x4fbae2d5) // sdot v21.4s, v22.16b, v26.4b[1] + KAI_ASM_INST(0x4f9aeaca) // sdot v10.4s, v22.16b, v26.4b[2] + KAI_ASM_INST(0x4fbaeac3) // sdot v3.4s, v22.16b, v26.4b[3] + ldr q22, [x23, #0x40] + shl v26.16b, v7.16b, #0x4 + and v7.16b, v7.16b, v25.16b + KAI_ASM_INST(0x4f88e342) // sdot v2.4s, v26.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e351) // sdot v17.4s, v26.16b, v8.4b[1] + KAI_ASM_INST(0x4f88eb4c) // sdot v12.4s, v26.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8eb49) // sdot v9.4s, v26.16b, v8.4b[3] + ldr q8, [x22, #0x40] + KAI_ASM_INST(0x4f9ce34e) // sdot v14.4s, v26.16b, v28.4b[0] + KAI_ASM_INST(0x4fbce34b) // sdot v11.4s, v26.16b, v28.4b[1] + KAI_ASM_INST(0x4f9ceb4d) // sdot v13.4s, v26.16b, v28.4b[2] + KAI_ASM_INST(0x4fbceb4f) // sdot v15.4s, v26.16b, v28.4b[3] + ldr q28, [x21, #0x40] + KAI_ASM_INST(0x4f81e357) // sdot v23.4s, v26.16b, v1.4b[0] + KAI_ASM_INST(0x4fa1e35d) // sdot v29.4s, v26.16b, v1.4b[1] + KAI_ASM_INST(0x4f81eb40) // sdot v0.4s, v26.16b, v1.4b[2] + KAI_ASM_INST(0x4fa1eb44) // sdot v4.4s, v26.16b, v1.4b[3] + ldr q1, [x27, #0x50] + KAI_ASM_INST(0x4f85e350) // sdot v16.4s, v26.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e355) // sdot v21.4s, v26.16b, v5.4b[1] + KAI_ASM_INST(0x4f85eb4a) // sdot v10.4s, v26.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5eb43) // sdot v3.4s, v26.16b, v5.4b[3] + ldr q5, [x23, #0x50] + shl v26.16b, v30.16b, #0x4 + and v30.16b, v30.16b, v25.16b + ldr q25, [x22, #0x50] + KAI_ASM_INST(0x4f98e342) // sdot v2.4s, v26.16b, v24.4b[0] + KAI_ASM_INST(0x4fb8e351) // sdot v17.4s, v26.16b, v24.4b[1] + KAI_ASM_INST(0x4f98eb4c) // sdot v12.4s, v26.16b, v24.4b[2] + KAI_ASM_INST(0x4fb8eb49) // sdot v9.4s, v26.16b, v24.4b[3] + ldr q24, [x21, #0x50] + KAI_ASM_INST(0x4f9fe34e) // sdot v14.4s, v26.16b, v31.4b[0] + KAI_ASM_INST(0x4fbfe34b) // sdot v11.4s, v26.16b, v31.4b[1] + KAI_ASM_INST(0x4f9feb4d) // sdot v13.4s, v26.16b, v31.4b[2] + KAI_ASM_INST(0x4fbfeb4f) // sdot v15.4s, v26.16b, v31.4b[3] + ldr q31, [x27, #0x60] + KAI_ASM_INST(0x4f9be357) // sdot v23.4s, v26.16b, v27.4b[0] + KAI_ASM_INST(0x4fbbe35d) // sdot v29.4s, v26.16b, v27.4b[1] + KAI_ASM_INST(0x4f9beb40) // sdot v0.4s, v26.16b, v27.4b[2] + KAI_ASM_INST(0x4fbbeb44) // sdot v4.4s, v26.16b, v27.4b[3] + ldr q27, [x23, #0x60] + KAI_ASM_INST(0x4f94e350) // sdot v16.4s, v26.16b, v20.4b[0] + KAI_ASM_INST(0x4fb4e355) // sdot v21.4s, v26.16b, v20.4b[1] + KAI_ASM_INST(0x4f94eb4a) // sdot v10.4s, v26.16b, v20.4b[2] + KAI_ASM_INST(0x4fb4eb43) // sdot v3.4s, v26.16b, v20.4b[3] + ldr q26, [x22, #0x60] + ldr q20, [x21, #0x60] + 
KAI_ASM_INST(0x4f93e0c2) // sdot v2.4s, v6.16b, v19.4b[0] + KAI_ASM_INST(0x4fb3e0d1) // sdot v17.4s, v6.16b, v19.4b[1] + KAI_ASM_INST(0x4f93e8cc) // sdot v12.4s, v6.16b, v19.4b[2] + KAI_ASM_INST(0x4fb3e8c9) // sdot v9.4s, v6.16b, v19.4b[3] + ldr q19, [x27, #0x70] + add x27, x27, #0x80 + KAI_ASM_INST(0x4f96e0ce) // sdot v14.4s, v6.16b, v22.4b[0] + KAI_ASM_INST(0x4fb6e0cb) // sdot v11.4s, v6.16b, v22.4b[1] + KAI_ASM_INST(0x4f96e8cd) // sdot v13.4s, v6.16b, v22.4b[2] + KAI_ASM_INST(0x4fb6e8cf) // sdot v15.4s, v6.16b, v22.4b[3] + ldr q22, [x23, #0x70] + add x23, x23, #0x80 + KAI_ASM_INST(0x4f88e0d7) // sdot v23.4s, v6.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e0dd) // sdot v29.4s, v6.16b, v8.4b[1] + KAI_ASM_INST(0x4f88e8c0) // sdot v0.4s, v6.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8e8c4) // sdot v4.4s, v6.16b, v8.4b[3] + ldr q8, [x22, #0x70] + add x22, x22, #0x80 + KAI_ASM_INST(0x4f9ce0d0) // sdot v16.4s, v6.16b, v28.4b[0] + KAI_ASM_INST(0x4fbce0d5) // sdot v21.4s, v6.16b, v28.4b[1] + KAI_ASM_INST(0x4f9ce8ca) // sdot v10.4s, v6.16b, v28.4b[2] + KAI_ASM_INST(0x4fbce8c3) // sdot v3.4s, v6.16b, v28.4b[3] + ldr q28, [x21, #0x70] + add x21, x21, #0x80 + KAI_ASM_INST(0x4f81e242) // sdot v2.4s, v18.16b, v1.4b[0] + KAI_ASM_INST(0x4fa1e251) // sdot v17.4s, v18.16b, v1.4b[1] + KAI_ASM_INST(0x4f81ea4c) // sdot v12.4s, v18.16b, v1.4b[2] + KAI_ASM_INST(0x4fa1ea49) // sdot v9.4s, v18.16b, v1.4b[3] + KAI_ASM_INST(0x4f85e24e) // sdot v14.4s, v18.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e24b) // sdot v11.4s, v18.16b, v5.4b[1] + KAI_ASM_INST(0x4f85ea4d) // sdot v13.4s, v18.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5ea4f) // sdot v15.4s, v18.16b, v5.4b[3] + KAI_ASM_INST(0x4f99e257) // sdot v23.4s, v18.16b, v25.4b[0] + KAI_ASM_INST(0x4fb9e25d) // sdot v29.4s, v18.16b, v25.4b[1] + KAI_ASM_INST(0x4f99ea40) // sdot v0.4s, v18.16b, v25.4b[2] + KAI_ASM_INST(0x4fb9ea44) // sdot v4.4s, v18.16b, v25.4b[3] + KAI_ASM_INST(0x4f98e250) // sdot v16.4s, v18.16b, v24.4b[0] + KAI_ASM_INST(0x4fb8e255) // sdot v21.4s, v18.16b, v24.4b[1] + KAI_ASM_INST(0x4f98ea4a) // sdot v10.4s, v18.16b, v24.4b[2] + KAI_ASM_INST(0x4fb8ea43) // sdot v3.4s, v18.16b, v24.4b[3] + KAI_ASM_INST(0x4f9fe0e2) // sdot v2.4s, v7.16b, v31.4b[0] + KAI_ASM_INST(0x4fbfe0f1) // sdot v17.4s, v7.16b, v31.4b[1] + KAI_ASM_INST(0x4f9fe8ec) // sdot v12.4s, v7.16b, v31.4b[2] + KAI_ASM_INST(0x4fbfe8e9) // sdot v9.4s, v7.16b, v31.4b[3] + KAI_ASM_INST(0x4f9be0ee) // sdot v14.4s, v7.16b, v27.4b[0] + KAI_ASM_INST(0x4fbbe0eb) // sdot v11.4s, v7.16b, v27.4b[1] + KAI_ASM_INST(0x4f9be8ed) // sdot v13.4s, v7.16b, v27.4b[2] + KAI_ASM_INST(0x4fbbe8ef) // sdot v15.4s, v7.16b, v27.4b[3] + KAI_ASM_INST(0x4f9ae0f7) // sdot v23.4s, v7.16b, v26.4b[0] + KAI_ASM_INST(0x4fbae0fd) // sdot v29.4s, v7.16b, v26.4b[1] + KAI_ASM_INST(0x4f9ae8e0) // sdot v0.4s, v7.16b, v26.4b[2] + KAI_ASM_INST(0x4fbae8e4) // sdot v4.4s, v7.16b, v26.4b[3] + KAI_ASM_INST(0x4f94e0f0) // sdot v16.4s, v7.16b, v20.4b[0] + KAI_ASM_INST(0x4fb4e0f5) // sdot v21.4s, v7.16b, v20.4b[1] + KAI_ASM_INST(0x4f94e8ea) // sdot v10.4s, v7.16b, v20.4b[2] + KAI_ASM_INST(0x4fb4e8e3) // sdot v3.4s, v7.16b, v20.4b[3] + KAI_ASM_INST(0x4f93e3c2) // sdot v2.4s, v30.16b, v19.4b[0] + KAI_ASM_INST(0x4fb3e3d1) // sdot v17.4s, v30.16b, v19.4b[1] + KAI_ASM_INST(0x4f93ebcc) // sdot v12.4s, v30.16b, v19.4b[2] + KAI_ASM_INST(0x4fb3ebc9) // sdot v9.4s, v30.16b, v19.4b[3] + KAI_ASM_INST(0x4f96e3ce) // sdot v14.4s, v30.16b, v22.4b[0] + KAI_ASM_INST(0x4fb6e3cb) // sdot v11.4s, v30.16b, v22.4b[1] + KAI_ASM_INST(0x4f96ebcd) // sdot v13.4s, v30.16b, v22.4b[2] + 
KAI_ASM_INST(0x4fb6ebcf) // sdot v15.4s, v30.16b, v22.4b[3] + KAI_ASM_INST(0x4f88e3d7) // sdot v23.4s, v30.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e3dd) // sdot v29.4s, v30.16b, v8.4b[1] + KAI_ASM_INST(0x4f88ebc0) // sdot v0.4s, v30.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8ebc4) // sdot v4.4s, v30.16b, v8.4b[3] + KAI_ASM_INST(0x4f9ce3d0) // sdot v16.4s, v30.16b, v28.4b[0] + KAI_ASM_INST(0x4fbce3d5) // sdot v21.4s, v30.16b, v28.4b[1] + KAI_ASM_INST(0x4f9cebca) // sdot v10.4s, v30.16b, v28.4b[2] + KAI_ASM_INST(0x4fbcebc3) // sdot v3.4s, v30.16b, v28.4b[3] + bgt label_4 + ldr d7, [x11, #0x0] + ldr q31, [SP, #0x0] + scvtf v2.4s, v2.4s, #0x4 + scvtf v17.4s, v17.4s, #0x4 + scvtf v12.4s, v12.4s, #0x4 + scvtf v9.4s, v9.4s, #0x4 + add x11, x11, #0x8 + shll v7.4s, v7.4h, #0x10 + fmla v31.4s, v2.4s, v7.4s + str q31, [SP, #0x0] + ldr q2, [SP, #0x10] + fmla v2.4s, v17.4s, v7.4s + str q2, [SP, #0x10] + ldr q2, [SP, #0x20] + fmla v2.4s, v12.4s, v7.4s + str q2, [SP, #0x20] + ldr q2, [SP, #0x30] + fmla v2.4s, v9.4s, v7.4s + str q2, [SP, #0x30] + ldr q28, [SP, #0x40] + scvtf v14.4s, v14.4s, #0x4 + scvtf v11.4s, v11.4s, #0x4 + scvtf v13.4s, v13.4s, #0x4 + scvtf v15.4s, v15.4s, #0x4 + fmla v28.4s, v14.4s, v7.4s + str q28, [SP, #0x40] + ldr q1, [SP, #0x50] + fmla v1.4s, v11.4s, v7.4s + str q1, [SP, #0x50] + ldr q11, [SP, #0x60] + fmla v11.4s, v13.4s, v7.4s + str q11, [SP, #0x60] + ldr q14, [SP, #0x70] + fmla v14.4s, v15.4s, v7.4s + str q14, [SP, #0x70] + ldr q19, [SP, #0x80] + scvtf v23.4s, v23.4s, #0x4 + scvtf v29.4s, v29.4s, #0x4 + scvtf v0.4s, v0.4s, #0x4 + scvtf v4.4s, v4.4s, #0x4 + fmla v19.4s, v23.4s, v7.4s + str q19, [SP, #0x80] + ldr q15, [SP, #0x90] + fmla v15.4s, v29.4s, v7.4s + str q15, [SP, #0x90] + ldr q25, [SP, #0xa0] + fmla v25.4s, v0.4s, v7.4s + str q25, [SP, #0xa0] + ldr q12, [SP, #0xb0] + fmla v12.4s, v4.4s, v7.4s + str q12, [SP, #0xb0] + ldr q2, [SP, #0xc0] + scvtf v16.4s, v16.4s, #0x4 + scvtf v21.4s, v21.4s, #0x4 + scvtf v10.4s, v10.4s, #0x4 + scvtf v3.4s, v3.4s, #0x4 + fmla v2.4s, v16.4s, v7.4s + str q2, [SP, #0xc0] + ldr q16, [SP, #0xd0] + fmla v16.4s, v21.4s, v7.4s + str q16, [SP, #0xd0] + ldr q16, [SP, #0xe0] + fmla v16.4s, v10.4s, v7.4s + str q16, [SP, #0xe0] + ldr q16, [SP, #0xf0] + fmla v16.4s, v3.4s, v7.4s + str q16, [SP, #0xf0] + subs x24, x24, #0x1 + bgt label_3 + ld1 { v11.4s }, [x27] + ld1 { v10.4s }, [x23] + add x27, x27, #0x10 + add x23, x23, #0x10 + ld1 { v9.4s }, [x22] + ld1 { v8.4s }, [x21] + add x22, x22, #0x10 + add x21, x21, #0x10 + ldr q31, [SP, #0x0] + ldr q30, [SP, #0x10] + add x20, x12, #0x4 + cmp x10, #0x4 + ldr q29, [SP, #0x20] + ldr q28, [SP, #0x30] + scvtf v11.4s, v11.4s + scvtf v10.4s, v10.4s + ldr q27, [SP, #0x40] + ldr q26, [SP, #0x50] + scvtf v9.4s, v9.4s + scvtf v8.4s, v8.4s + ldr q25, [SP, #0x60] + ldr q24, [SP, #0x70] + ldr q23, [SP, #0x80] + ldr q22, [SP, #0x90] + ldr q21, [SP, #0xa0] + ldr q20, [SP, #0xb0] + ldr q19, [SP, #0xc0] + ldr q18, [SP, #0xd0] + ldr q17, [SP, #0xe0] + ldr q16, [SP, #0xf0] + ldr q7, [x11, #0x0] + ldr q6, [x27, #0x0] + ldr q5, [x23, #0x0] + ldr q4, [x22, #0x0] + ldr q3, [x21, #0x0] + ldr q2, [x11, #0x10] + add x11, x11, #0x20 + ld1r { v1.4s }, [x12] + ld1r { v0.4s }, [x20] + fmla v31.4s, v7.4s, v11.s[0] + fmla v30.4s, v7.4s, v11.s[1] + fmla v29.4s, v7.4s, v11.s[2] + fmla v28.4s, v7.4s, v11.s[3] + fmla v27.4s, v7.4s, v10.s[0] + fmla v26.4s, v7.4s, v10.s[1] + fmla v25.4s, v7.4s, v10.s[2] + fmla v24.4s, v7.4s, v10.s[3] + fmla v23.4s, v7.4s, v9.s[0] + fmla v22.4s, v7.4s, v9.s[1] + fmul v31.4s, v31.4s, v6.s[0] + fmla v21.4s, v7.4s, v9.s[2] + fmla 
v20.4s, v7.4s, v9.s[3] + fmul v30.4s, v30.4s, v6.s[1] + fmla v19.4s, v7.4s, v8.s[0] + fmla v18.4s, v7.4s, v8.s[1] + fmul v29.4s, v29.4s, v6.s[2] + fmla v17.4s, v7.4s, v8.s[2] + fmla v16.4s, v7.4s, v8.s[3] + fmul v28.4s, v28.4s, v6.s[3] + fmul v27.4s, v27.4s, v5.s[0] + fmul v26.4s, v26.4s, v5.s[1] + fmul v25.4s, v25.4s, v5.s[2] + fmul v24.4s, v24.4s, v5.s[3] + fmul v23.4s, v23.4s, v4.s[0] + fmul v22.4s, v22.4s, v4.s[1] + fmul v21.4s, v21.4s, v4.s[2] + fmul v20.4s, v20.4s, v4.s[3] + fmul v19.4s, v19.4s, v3.s[0] + fmul v18.4s, v18.4s, v3.s[1] + fmul v17.4s, v17.4s, v3.s[2] + fmul v16.4s, v16.4s, v3.s[3] + fadd v31.4s, v31.4s, v2.4s + fadd v30.4s, v30.4s, v2.4s + fadd v29.4s, v29.4s, v2.4s + fadd v28.4s, v28.4s, v2.4s + fadd v27.4s, v27.4s, v2.4s + fadd v26.4s, v26.4s, v2.4s + fadd v25.4s, v25.4s, v2.4s + fadd v24.4s, v24.4s, v2.4s + fadd v23.4s, v23.4s, v2.4s + fadd v22.4s, v22.4s, v2.4s + fadd v21.4s, v21.4s, v2.4s + fadd v20.4s, v20.4s, v2.4s + fadd v19.4s, v19.4s, v2.4s + fadd v18.4s, v18.4s, v2.4s + fadd v17.4s, v17.4s, v2.4s + fadd v16.4s, v16.4s, v2.4s + fmax v31.4s, v31.4s, v1.4s + fmax v30.4s, v30.4s, v1.4s + fmax v29.4s, v29.4s, v1.4s + fmax v28.4s, v28.4s, v1.4s + fmax v27.4s, v27.4s, v1.4s + fmax v26.4s, v26.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + fmax v24.4s, v24.4s, v1.4s + fmax v23.4s, v23.4s, v1.4s + fmax v22.4s, v22.4s, v1.4s + fmax v21.4s, v21.4s, v1.4s + fmax v20.4s, v20.4s, v1.4s + fmax v19.4s, v19.4s, v1.4s + fmax v18.4s, v18.4s, v1.4s + fmax v17.4s, v17.4s, v1.4s + fmax v16.4s, v16.4s, v1.4s + fmin v31.4s, v31.4s, v0.4s + fmin v30.4s, v30.4s, v0.4s + fmin v29.4s, v29.4s, v0.4s + fmin v28.4s, v28.4s, v0.4s + fmin v27.4s, v27.4s, v0.4s + fmin v26.4s, v26.4s, v0.4s + fmin v25.4s, v25.4s, v0.4s + fmin v24.4s, v24.4s, v0.4s + fmin v23.4s, v23.4s, v0.4s + fmin v22.4s, v22.4s, v0.4s + fmin v21.4s, v21.4s, v0.4s + fmin v20.4s, v20.4s, v0.4s + fmin v19.4s, v19.4s, v0.4s + fmin v18.4s, v18.4s, v0.4s + fmin v17.4s, v17.4s, v0.4s + fmin v16.4s, v16.4s, v0.4s + blt label_9 + mov x20, x14 + str q31, [x20, #0x0] + add x20, x20, x13 + str q30, [x20, #0x0] + add x20, x20, x13 + str q29, [x20, #0x0] + add x20, x20, x13 + str q28, [x20, #0x0] + add x20, x20, x13 + str q27, [x20, #0x0] + add x20, x20, x13 + str q26, [x20, #0x0] + add x20, x20, x13 + str q25, [x20, #0x0] + add x20, x20, x13 + str q24, [x20, #0x0] + add x20, x20, x13 + str q23, [x20, #0x0] + add x20, x20, x13 + str q22, [x20, #0x0] + add x20, x20, x13 + str q21, [x20, #0x0] + add x20, x20, x13 + str q20, [x20, #0x0] + add x20, x20, x13 + str q19, [x20, #0x0] + add x20, x20, x13 + str q18, [x20, #0x0] + add x20, x20, x13 + str q17, [x20, #0x0] + add x20, x20, x13 + str q16, [x20, #0x0] + b label_14 +KAI_ASM_LABEL(label_9) // Partial output + mov x28, x14 + add x26, x28, x13, LSL #2 + add x25, x26, x13, LSL #1 + add x24, x26, x13 + add x23, x25, x13 + add x22, x28, x13, LSL #1 + add x21, x28, x13 + add x20, x22, x13 + add x27, x23, x13 + tbz x10, #1, label_10 + st1 { v24.d }[0], [x23], #0x8 + st1 { v25.d }[0], [x25], #0x8 + st1 { v26.d }[0], [x24], #0x8 + st1 { v27.d }[0], [x26], #0x8 + st1 { v28.d }[0], [x20], #0x8 + st1 { v29.d }[0], [x22], #0x8 + st1 { v30.d }[0], [x21], #0x8 + st1 { v31.d }[0], [x28], #0x8 + tbz x10, #0, label_11 + st1 { v24.s }[2], [x23] + st1 { v25.s }[2], [x25] + st1 { v26.s }[2], [x24] + st1 { v27.s }[2], [x26] + st1 { v28.s }[2], [x20] + st1 { v29.s }[2], [x22] + st1 { v30.s }[2], [x21] + st1 { v31.s }[2], [x28] + b label_11 +KAI_ASM_LABEL(label_10) // Output block 0: partial_1_0 + st1 { v24.s }[0], 
[x23] + st1 { v25.s }[0], [x25] + st1 { v26.s }[0], [x24] + st1 { v27.s }[0], [x26] + st1 { v28.s }[0], [x20] + st1 { v29.s }[0], [x22] + st1 { v30.s }[0], [x21] + st1 { v31.s }[0], [x28] +KAI_ASM_LABEL(label_11) // Output block 0: Done + add x26, x27, x13, LSL #2 + add x25, x27, x13, LSL #1 + add x24, x26, x13, LSL #1 + add x23, x27, x13 + add x22, x25, x13 + add x21, x26, x13 + add x20, x24, x13 + tbz x10, #1, label_12 + st1 { v16.d }[0], [x20], #0x8 + st1 { v17.d }[0], [x24], #0x8 + st1 { v18.d }[0], [x21], #0x8 + st1 { v19.d }[0], [x26], #0x8 + st1 { v20.d }[0], [x22], #0x8 + st1 { v21.d }[0], [x25], #0x8 + st1 { v22.d }[0], [x23], #0x8 + st1 { v23.d }[0], [x27], #0x8 + tbz x10, #0, label_13 + st1 { v16.s }[2], [x20] + st1 { v17.s }[2], [x24] + st1 { v18.s }[2], [x21] + st1 { v19.s }[2], [x26] + st1 { v20.s }[2], [x22] + st1 { v21.s }[2], [x25] + st1 { v22.s }[2], [x23] + st1 { v23.s }[2], [x27] + b label_13 +KAI_ASM_LABEL(label_12) // Output block 1: partial_1_0 + st1 { v16.s }[0], [x20] + st1 { v17.s }[0], [x24] + st1 { v18.s }[0], [x21] + st1 { v19.s }[0], [x26] + st1 { v20.s }[0], [x22] + st1 { v21.s }[0], [x25] + st1 { v22.s }[0], [x23] + st1 { v23.s }[0], [x27] +KAI_ASM_LABEL(label_13) // Output block 1: Done +KAI_ASM_LABEL(label_14) // Output stage exit + subs x10, x10, #0x4 + add x14, x14, #0x10 + bgt label_2 + mov x20, #0x4 + sub x15, x15, #0x10 + cmp x15, #0x10 + mov x14, x9 + madd x8, x20, x5, x8 + bge label_1 +KAI_ASM_LABEL(label_15) // Row loop skip + cbz x15, label_25 +KAI_ASM_LABEL(label_16) // Row tail: Row loop + mov x26, x17 + mov x25, x16 + add x24, x14, x13, LSL #2 +KAI_ASM_LABEL(label_17) // Row tail: Column loop + movi v16.4s, #0x0 + mov x27, x8 + mov x21, x7 + str q16, [SP, #0x0] + str q16, [SP, #0x10] + str q16, [SP, #0x20] + str q16, [SP, #0x30] +KAI_ASM_LABEL(label_18) // Row tail: Block loop + movi v2.4s, #0x0 + movi v17.4s, #0x0 + mov x20, x6 + movi v12.4s, #0x0 + movi v9.4s, #0x0 +KAI_ASM_LABEL(label_19) // Row tail: Sub block loop + ldr q0, [x26, #0x0] + ldr q31, [x27, #0x0] + movi v30.16b, #0xf0 + subs x20, x20, #0x1 + ldr q29, [x26, #0x10] + ldr q28, [x27, #0x10] + ldr q27, [x26, #0x20] + ldr q26, [x27, #0x20] + ldr q25, [x26, #0x30] + ldr q24, [x27, #0x30] + shl v23.16b, v0.16b, #0x4 + and v0.16b, v0.16b, v30.16b + ldr q22, [x27, #0x40] + ldr q21, [x27, #0x50] + shl v20.16b, v29.16b, #0x4 + and v29.16b, v29.16b, v30.16b + ldr q7, [x27, #0x60] + ldr q18, [x27, #0x70] + shl v19.16b, v27.16b, #0x4 + and v27.16b, v27.16b, v30.16b + KAI_ASM_INST(0x4f9fe2e2) // sdot v2.4s, v23.16b, v31.4b[0] + KAI_ASM_INST(0x4fbfe2f1) // sdot v17.4s, v23.16b, v31.4b[1] + shl v16.16b, v25.16b, #0x4 + add x26, x26, #0x40 + KAI_ASM_INST(0x4f9feaec) // sdot v12.4s, v23.16b, v31.4b[2] + KAI_ASM_INST(0x4fbfeae9) // sdot v9.4s, v23.16b, v31.4b[3] + and v25.16b, v25.16b, v30.16b + add x27, x27, #0x80 + KAI_ASM_INST(0x4f9ce282) // sdot v2.4s, v20.16b, v28.4b[0] + KAI_ASM_INST(0x4fbce291) // sdot v17.4s, v20.16b, v28.4b[1] + KAI_ASM_INST(0x4f9cea8c) // sdot v12.4s, v20.16b, v28.4b[2] + KAI_ASM_INST(0x4fbcea89) // sdot v9.4s, v20.16b, v28.4b[3] + KAI_ASM_INST(0x4f9ae262) // sdot v2.4s, v19.16b, v26.4b[0] + KAI_ASM_INST(0x4fbae271) // sdot v17.4s, v19.16b, v26.4b[1] + KAI_ASM_INST(0x4f9aea6c) // sdot v12.4s, v19.16b, v26.4b[2] + KAI_ASM_INST(0x4fbaea69) // sdot v9.4s, v19.16b, v26.4b[3] + KAI_ASM_INST(0x4f98e202) // sdot v2.4s, v16.16b, v24.4b[0] + KAI_ASM_INST(0x4fb8e211) // sdot v17.4s, v16.16b, v24.4b[1] + KAI_ASM_INST(0x4f98ea0c) // sdot v12.4s, v16.16b, v24.4b[2] + 
KAI_ASM_INST(0x4fb8ea09) // sdot v9.4s, v16.16b, v24.4b[3] + KAI_ASM_INST(0x4f96e002) // sdot v2.4s, v0.16b, v22.4b[0] + KAI_ASM_INST(0x4fb6e011) // sdot v17.4s, v0.16b, v22.4b[1] + KAI_ASM_INST(0x4f96e80c) // sdot v12.4s, v0.16b, v22.4b[2] + KAI_ASM_INST(0x4fb6e809) // sdot v9.4s, v0.16b, v22.4b[3] + KAI_ASM_INST(0x4f95e3a2) // sdot v2.4s, v29.16b, v21.4b[0] + KAI_ASM_INST(0x4fb5e3b1) // sdot v17.4s, v29.16b, v21.4b[1] + KAI_ASM_INST(0x4f95ebac) // sdot v12.4s, v29.16b, v21.4b[2] + KAI_ASM_INST(0x4fb5eba9) // sdot v9.4s, v29.16b, v21.4b[3] + KAI_ASM_INST(0x4f87e362) // sdot v2.4s, v27.16b, v7.4b[0] + KAI_ASM_INST(0x4fa7e371) // sdot v17.4s, v27.16b, v7.4b[1] + KAI_ASM_INST(0x4f87eb6c) // sdot v12.4s, v27.16b, v7.4b[2] + KAI_ASM_INST(0x4fa7eb69) // sdot v9.4s, v27.16b, v7.4b[3] + KAI_ASM_INST(0x4f92e322) // sdot v2.4s, v25.16b, v18.4b[0] + KAI_ASM_INST(0x4fb2e331) // sdot v17.4s, v25.16b, v18.4b[1] + KAI_ASM_INST(0x4f92eb2c) // sdot v12.4s, v25.16b, v18.4b[2] + KAI_ASM_INST(0x4fb2eb29) // sdot v9.4s, v25.16b, v18.4b[3] + bgt label_19 + ldr d7, [x26, #0x0] + ldr q16, [SP, #0x0] + scvtf v2.4s, v2.4s, #0x4 + scvtf v17.4s, v17.4s, #0x4 + scvtf v12.4s, v12.4s, #0x4 + scvtf v9.4s, v9.4s, #0x4 + add x26, x26, #0x8 + shll v7.4s, v7.4h, #0x10 + fmla v16.4s, v2.4s, v7.4s + str q16, [SP, #0x0] + ldr q16, [SP, #0x10] + fmla v16.4s, v17.4s, v7.4s + str q16, [SP, #0x10] + ldr q16, [SP, #0x20] + fmla v16.4s, v12.4s, v7.4s + str q16, [SP, #0x20] + ldr q16, [SP, #0x30] + fmla v16.4s, v9.4s, v7.4s + str q16, [SP, #0x30] + subs x21, x21, #0x1 + bgt label_18 + ld1 { v21.4s }, [x27] + ldr q31, [SP, #0x0] + add x27, x27, #0x10 + add x20, x12, #0x4 + ldr q30, [SP, #0x10] + ldr q29, [SP, #0x20] + cmp x25, #0x4 + ldr q28, [SP, #0x30] + ldr q20, [x26, #0x0] + ldr q19, [x27, #0x0] + ldr q18, [x26, #0x10] + scvtf v21.4s, v21.4s + add x26, x26, #0x20 + ld1r { v17.4s }, [x12] + ld1r { v16.4s }, [x20] + fmla v31.4s, v20.4s, v21.s[0] + fmla v30.4s, v20.4s, v21.s[1] + fmla v29.4s, v20.4s, v21.s[2] + fmla v28.4s, v20.4s, v21.s[3] + fmul v31.4s, v31.4s, v19.s[0] + fmul v30.4s, v30.4s, v19.s[1] + fadd v31.4s, v31.4s, v18.4s + fmul v29.4s, v29.4s, v19.s[2] + fmul v28.4s, v28.4s, v19.s[3] + fadd v30.4s, v30.4s, v18.4s + fmax v31.4s, v31.4s, v17.4s + fadd v29.4s, v29.4s, v18.4s + fadd v28.4s, v28.4s, v18.4s + fmax v30.4s, v30.4s, v17.4s + fmin v31.4s, v31.4s, v16.4s + fmax v29.4s, v29.4s, v17.4s + fmax v28.4s, v28.4s, v17.4s + fmin v30.4s, v30.4s, v16.4s + fmin v29.4s, v29.4s, v16.4s + fmin v28.4s, v28.4s, v16.4s + blt label_21 + mov x20, x14 + cmp x15, #0x1 + str q31, [x20, #0x0] + add x20, x20, x13 + ble label_24 + cmp x15, #0x2 + str q30, [x20, #0x0] + add x20, x20, x13 + ble label_24 + cmp x15, #0x3 + str q29, [x20, #0x0] + add x20, x20, x13 + ble label_24 + str q28, [x20, #0x0] + b label_24 +KAI_ASM_LABEL(label_21) // Row tail: Partial output + mov x23, x14 + cmp x15, #0x1 + add x22, x23, x13 + csel x22, x22, x23, GT + cmp x15, #0x2 + add x21, x23, x13, LSL #1 + csel x21, x21, x22, GT + cmp x15, #0x3 + add x20, x21, x13 + csel x20, x20, x21, GT + tbz x25, #1, label_22 + st1 { v28.d }[0], [x20], #0x8 + st1 { v29.d }[0], [x21], #0x8 + st1 { v30.d }[0], [x22], #0x8 + st1 { v31.d }[0], [x23], #0x8 + tbz x25, #0, label_23 + st1 { v28.s }[2], [x20] + st1 { v29.s }[2], [x21] + st1 { v30.s }[2], [x22] + st1 { v31.s }[2], [x23] + b label_23 +KAI_ASM_LABEL(label_22) // Row tail: Output block 0: partial_1_0 + st1 { v28.s }[0], [x20] + st1 { v29.s }[0], [x21] + st1 { v30.s }[0], [x22] + st1 { v31.s }[0], [x23] +KAI_ASM_LABEL(label_23) 
// Row tail: Output block 0: Done +KAI_ASM_LABEL(label_24) // Row tail: Output stage exit + subs x25, x25, #0x4 + add x14, x14, #0x10 + bgt label_17 + subs x15, x15, #0x4 + add x8, x8, x5 + mov x14, x24 + bgt label_16 +KAI_ASM_LABEL(label_25) // Row tail: Row loop skip + add SP, SP, #0x100 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d10, d11, [sp, 72] + ldp d12, d13, [sp, 88] + ldp d14, d15, [sp, 104] + ldp d8, d9, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c index 5165137e..b4faf0ec 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c @@ -18,6 +18,20 @@ #include "kai/kai_common.h" +typedef struct { + float* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; + size_t num_subblocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod(KernelArgs* args_ptr); + // Compute args static const size_t kai_m_step = 16; static const size_t kai_n_step = 4; @@ -116,6 +130,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_ne size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod( size_t n_idx, size_t k, size_t bl) { + KAI_ASSUME((k % bl) == 0); KAI_ASSUME((n_idx % kai_n_step) == 0); return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl); @@ -146,752 +161,29 @@ void kai_run_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod( float scalar_min, // float scalar_max) { KAI_ASSUME(dst_stride_col == sizeof(float)); + KAI_ASSUME((k % bl) == 0); KAI_ASSUME((bl % kai_bl) == 0); if (m == 0) { return; } - - const size_t num_subblocks = bl / kai_bl; + const size_t num_subblocks = bl / 32; const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); const float clamp_vals[2] = {scalar_min, scalar_max}; - __asm__ __volatile__( - "mov x13, #0x80\n" - "mov x12, %x[m]\n" - "mov x20, #0x20\n" - "sub SP, SP, #0x100\n" - "mul x13, %x[num_subblocks], x13\n" - "cmp x12, #0x10\n" - "madd x13, %x[num_blocks], x13, x20\n" - "blt 15f\n" - "1:" // Row loop - "mov x11, %x[rhs_packed]\n" - "mov x10, %x[n]\n" - "add x9, %x[dst], %x[dst_stride_row], LSL #4\n" - "2:" // Column loop - "mov x27, %x[lhs_packed]\n" - "movi v6.4s, #0x0\n" - "mov x24, %x[num_blocks]\n" - "str q6, [SP, #0x0]\n" - "str q6, [SP, #0x10]\n" - "str q6, [SP, #0x20]\n" - "add x23, x27, x13\n" - "add x22, x23, x13\n" - "str q6, [SP, #0x30]\n" - "add x21, x22, x13\n" - "str q6, [SP, #0x40]\n" - "str q6, [SP, #0x50]\n" - "str q6, [SP, #0x60]\n" - "str q6, [SP, #0x70]\n" - "str q6, [SP, #0x80]\n" - "str q6, [SP, #0x90]\n" - "str q6, [SP, #0xa0]\n" - "str q6, [SP, #0xb0]\n" - "str q6, [SP, #0xc0]\n" - "str q6, [SP, #0xd0]\n" - "str q6, [SP, #0xe0]\n" - "str q6, [SP, #0xf0]\n" - "3:" // Block loop - "movi v2.4s, #0x0\n" - "movi v17.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "movi v12.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "movi v14.4s, #0x0\n" - 
"movi v11.4s, #0x0\n" - "movi v13.4s, #0x0\n" - "movi v15.4s, #0x0\n" - "movi v23.4s, #0x0\n" - "movi v29.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "movi v4.4s, #0x0\n" - "movi v16.4s, #0x0\n" - "movi v21.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "movi v3.4s, #0x0\n" - "4:" // Sub block loop - "ldr q6, [x11, #0x0]\n" - "ldr q1, [x27, #0x0]\n" - "movi v25.16b, #0xf0\n" - "subs x20, x20, #0x1\n" - "ldr q5, [x23, #0x0]\n" - "ldr q30, [x22, #0x0]\n" - "ldr q24, [x21, #0x0]\n" - "ldr q18, [x11, #0x10]\n" - "ldr q27, [x27, #0x10]\n" - "ldr q20, [x23, #0x10]\n" - "shl v31.16b, v6.16b, #0x4\n" - "and v6.16b, v6.16b, v25.16b\n" - "ldr q19, [x22, #0x10]\n" - "ldr q26, [x21, #0x10]\n" - "ldr q7, [x11, #0x20]\n" - "ldr q8, [x27, #0x20]\n" - "shl v22.16b, v18.16b, #0x4\n" - "and v18.16b, v18.16b, v25.16b\n" - "ldr q28, [x23, #0x20]\n" - ".inst 0x4f81e3e2 // sdot v2.4s, v31.16b, v1.4b[0]\n" - ".inst 0x4fa1e3f1 // sdot v17.4s, v31.16b, v1.4b[1]\n" - ".inst 0x4f81ebec // sdot v12.4s, v31.16b, v1.4b[2]\n" - ".inst 0x4fa1ebe9 // sdot v9.4s, v31.16b, v1.4b[3]\n" - "ldr q1, [x22, #0x20]\n" - ".inst 0x4f85e3ee // sdot v14.4s, v31.16b, v5.4b[0]\n" - ".inst 0x4fa5e3eb // sdot v11.4s, v31.16b, v5.4b[1]\n" - ".inst 0x4f85ebed // sdot v13.4s, v31.16b, v5.4b[2]\n" - ".inst 0x4fa5ebef // sdot v15.4s, v31.16b, v5.4b[3]\n" - "ldr q5, [x21, #0x20]\n" - ".inst 0x4f9ee3f7 // sdot v23.4s, v31.16b, v30.4b[0]\n" - ".inst 0x4fbee3fd // sdot v29.4s, v31.16b, v30.4b[1]\n" - ".inst 0x4f9eebe0 // sdot v0.4s, v31.16b, v30.4b[2]\n" - ".inst 0x4fbeebe4 // sdot v4.4s, v31.16b, v30.4b[3]\n" - "ldr q30, [x11, #0x30]\n" - "add x11, x11, #0x40\n" - ".inst 0x4f98e3f0 // sdot v16.4s, v31.16b, v24.4b[0]\n" - ".inst 0x4fb8e3f5 // sdot v21.4s, v31.16b, v24.4b[1]\n" - ".inst 0x4f98ebea // sdot v10.4s, v31.16b, v24.4b[2]\n" - ".inst 0x4fb8ebe3 // sdot v3.4s, v31.16b, v24.4b[3]\n" - "ldr q24, [x27, #0x30]\n" - "ldr q31, [x23, #0x30]\n" - ".inst 0x4f9be2c2 // sdot v2.4s, v22.16b, v27.4b[0]\n" - ".inst 0x4fbbe2d1 // sdot v17.4s, v22.16b, v27.4b[1]\n" - ".inst 0x4f9beacc // sdot v12.4s, v22.16b, v27.4b[2]\n" - ".inst 0x4fbbeac9 // sdot v9.4s, v22.16b, v27.4b[3]\n" - "ldr q27, [x22, #0x30]\n" - ".inst 0x4f94e2ce // sdot v14.4s, v22.16b, v20.4b[0]\n" - ".inst 0x4fb4e2cb // sdot v11.4s, v22.16b, v20.4b[1]\n" - ".inst 0x4f94eacd // sdot v13.4s, v22.16b, v20.4b[2]\n" - ".inst 0x4fb4eacf // sdot v15.4s, v22.16b, v20.4b[3]\n" - "ldr q20, [x21, #0x30]\n" - ".inst 0x4f93e2d7 // sdot v23.4s, v22.16b, v19.4b[0]\n" - ".inst 0x4fb3e2dd // sdot v29.4s, v22.16b, v19.4b[1]\n" - ".inst 0x4f93eac0 // sdot v0.4s, v22.16b, v19.4b[2]\n" - ".inst 0x4fb3eac4 // sdot v4.4s, v22.16b, v19.4b[3]\n" - "ldr q19, [x27, #0x40]\n" - ".inst 0x4f9ae2d0 // sdot v16.4s, v22.16b, v26.4b[0]\n" - ".inst 0x4fbae2d5 // sdot v21.4s, v22.16b, v26.4b[1]\n" - ".inst 0x4f9aeaca // sdot v10.4s, v22.16b, v26.4b[2]\n" - ".inst 0x4fbaeac3 // sdot v3.4s, v22.16b, v26.4b[3]\n" - "ldr q22, [x23, #0x40]\n" - "shl v26.16b, v7.16b, #0x4\n" - "and v7.16b, v7.16b, v25.16b\n" - ".inst 0x4f88e342 // sdot v2.4s, v26.16b, v8.4b[0]\n" - ".inst 0x4fa8e351 // sdot v17.4s, v26.16b, v8.4b[1]\n" - ".inst 0x4f88eb4c // sdot v12.4s, v26.16b, v8.4b[2]\n" - ".inst 0x4fa8eb49 // sdot v9.4s, v26.16b, v8.4b[3]\n" - "ldr q8, [x22, #0x40]\n" - ".inst 0x4f9ce34e // sdot v14.4s, v26.16b, v28.4b[0]\n" - ".inst 0x4fbce34b // sdot v11.4s, v26.16b, v28.4b[1]\n" - ".inst 0x4f9ceb4d // sdot v13.4s, v26.16b, v28.4b[2]\n" - ".inst 0x4fbceb4f // sdot v15.4s, v26.16b, v28.4b[3]\n" - "ldr q28, [x21, #0x40]\n" - ".inst 0x4f81e357 // sdot v23.4s, 
v26.16b, v1.4b[0]\n" - ".inst 0x4fa1e35d // sdot v29.4s, v26.16b, v1.4b[1]\n" - ".inst 0x4f81eb40 // sdot v0.4s, v26.16b, v1.4b[2]\n" - ".inst 0x4fa1eb44 // sdot v4.4s, v26.16b, v1.4b[3]\n" - "ldr q1, [x27, #0x50]\n" - ".inst 0x4f85e350 // sdot v16.4s, v26.16b, v5.4b[0]\n" - ".inst 0x4fa5e355 // sdot v21.4s, v26.16b, v5.4b[1]\n" - ".inst 0x4f85eb4a // sdot v10.4s, v26.16b, v5.4b[2]\n" - ".inst 0x4fa5eb43 // sdot v3.4s, v26.16b, v5.4b[3]\n" - "ldr q5, [x23, #0x50]\n" - "shl v26.16b, v30.16b, #0x4\n" - "and v30.16b, v30.16b, v25.16b\n" - "ldr q25, [x22, #0x50]\n" - ".inst 0x4f98e342 // sdot v2.4s, v26.16b, v24.4b[0]\n" - ".inst 0x4fb8e351 // sdot v17.4s, v26.16b, v24.4b[1]\n" - ".inst 0x4f98eb4c // sdot v12.4s, v26.16b, v24.4b[2]\n" - ".inst 0x4fb8eb49 // sdot v9.4s, v26.16b, v24.4b[3]\n" - "ldr q24, [x21, #0x50]\n" - ".inst 0x4f9fe34e // sdot v14.4s, v26.16b, v31.4b[0]\n" - ".inst 0x4fbfe34b // sdot v11.4s, v26.16b, v31.4b[1]\n" - ".inst 0x4f9feb4d // sdot v13.4s, v26.16b, v31.4b[2]\n" - ".inst 0x4fbfeb4f // sdot v15.4s, v26.16b, v31.4b[3]\n" - "ldr q31, [x27, #0x60]\n" - ".inst 0x4f9be357 // sdot v23.4s, v26.16b, v27.4b[0]\n" - ".inst 0x4fbbe35d // sdot v29.4s, v26.16b, v27.4b[1]\n" - ".inst 0x4f9beb40 // sdot v0.4s, v26.16b, v27.4b[2]\n" - ".inst 0x4fbbeb44 // sdot v4.4s, v26.16b, v27.4b[3]\n" - "ldr q27, [x23, #0x60]\n" - ".inst 0x4f94e350 // sdot v16.4s, v26.16b, v20.4b[0]\n" - ".inst 0x4fb4e355 // sdot v21.4s, v26.16b, v20.4b[1]\n" - ".inst 0x4f94eb4a // sdot v10.4s, v26.16b, v20.4b[2]\n" - ".inst 0x4fb4eb43 // sdot v3.4s, v26.16b, v20.4b[3]\n" - "ldr q26, [x22, #0x60]\n" - "ldr q20, [x21, #0x60]\n" - ".inst 0x4f93e0c2 // sdot v2.4s, v6.16b, v19.4b[0]\n" - ".inst 0x4fb3e0d1 // sdot v17.4s, v6.16b, v19.4b[1]\n" - ".inst 0x4f93e8cc // sdot v12.4s, v6.16b, v19.4b[2]\n" - ".inst 0x4fb3e8c9 // sdot v9.4s, v6.16b, v19.4b[3]\n" - "ldr q19, [x27, #0x70]\n" - "add x27, x27, #0x80\n" - ".inst 0x4f96e0ce // sdot v14.4s, v6.16b, v22.4b[0]\n" - ".inst 0x4fb6e0cb // sdot v11.4s, v6.16b, v22.4b[1]\n" - ".inst 0x4f96e8cd // sdot v13.4s, v6.16b, v22.4b[2]\n" - ".inst 0x4fb6e8cf // sdot v15.4s, v6.16b, v22.4b[3]\n" - "ldr q22, [x23, #0x70]\n" - "add x23, x23, #0x80\n" - ".inst 0x4f88e0d7 // sdot v23.4s, v6.16b, v8.4b[0]\n" - ".inst 0x4fa8e0dd // sdot v29.4s, v6.16b, v8.4b[1]\n" - ".inst 0x4f88e8c0 // sdot v0.4s, v6.16b, v8.4b[2]\n" - ".inst 0x4fa8e8c4 // sdot v4.4s, v6.16b, v8.4b[3]\n" - "ldr q8, [x22, #0x70]\n" - "add x22, x22, #0x80\n" - ".inst 0x4f9ce0d0 // sdot v16.4s, v6.16b, v28.4b[0]\n" - ".inst 0x4fbce0d5 // sdot v21.4s, v6.16b, v28.4b[1]\n" - ".inst 0x4f9ce8ca // sdot v10.4s, v6.16b, v28.4b[2]\n" - ".inst 0x4fbce8c3 // sdot v3.4s, v6.16b, v28.4b[3]\n" - "ldr q28, [x21, #0x70]\n" - "add x21, x21, #0x80\n" - ".inst 0x4f81e242 // sdot v2.4s, v18.16b, v1.4b[0]\n" - ".inst 0x4fa1e251 // sdot v17.4s, v18.16b, v1.4b[1]\n" - ".inst 0x4f81ea4c // sdot v12.4s, v18.16b, v1.4b[2]\n" - ".inst 0x4fa1ea49 // sdot v9.4s, v18.16b, v1.4b[3]\n" - ".inst 0x4f85e24e // sdot v14.4s, v18.16b, v5.4b[0]\n" - ".inst 0x4fa5e24b // sdot v11.4s, v18.16b, v5.4b[1]\n" - ".inst 0x4f85ea4d // sdot v13.4s, v18.16b, v5.4b[2]\n" - ".inst 0x4fa5ea4f // sdot v15.4s, v18.16b, v5.4b[3]\n" - ".inst 0x4f99e257 // sdot v23.4s, v18.16b, v25.4b[0]\n" - ".inst 0x4fb9e25d // sdot v29.4s, v18.16b, v25.4b[1]\n" - ".inst 0x4f99ea40 // sdot v0.4s, v18.16b, v25.4b[2]\n" - ".inst 0x4fb9ea44 // sdot v4.4s, v18.16b, v25.4b[3]\n" - ".inst 0x4f98e250 // sdot v16.4s, v18.16b, v24.4b[0]\n" - ".inst 0x4fb8e255 // sdot v21.4s, v18.16b, v24.4b[1]\n" - 
".inst 0x4f98ea4a // sdot v10.4s, v18.16b, v24.4b[2]\n" - ".inst 0x4fb8ea43 // sdot v3.4s, v18.16b, v24.4b[3]\n" - ".inst 0x4f9fe0e2 // sdot v2.4s, v7.16b, v31.4b[0]\n" - ".inst 0x4fbfe0f1 // sdot v17.4s, v7.16b, v31.4b[1]\n" - ".inst 0x4f9fe8ec // sdot v12.4s, v7.16b, v31.4b[2]\n" - ".inst 0x4fbfe8e9 // sdot v9.4s, v7.16b, v31.4b[3]\n" - ".inst 0x4f9be0ee // sdot v14.4s, v7.16b, v27.4b[0]\n" - ".inst 0x4fbbe0eb // sdot v11.4s, v7.16b, v27.4b[1]\n" - ".inst 0x4f9be8ed // sdot v13.4s, v7.16b, v27.4b[2]\n" - ".inst 0x4fbbe8ef // sdot v15.4s, v7.16b, v27.4b[3]\n" - ".inst 0x4f9ae0f7 // sdot v23.4s, v7.16b, v26.4b[0]\n" - ".inst 0x4fbae0fd // sdot v29.4s, v7.16b, v26.4b[1]\n" - ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" - ".inst 0x4fbae8e4 // sdot v4.4s, v7.16b, v26.4b[3]\n" - ".inst 0x4f94e0f0 // sdot v16.4s, v7.16b, v20.4b[0]\n" - ".inst 0x4fb4e0f5 // sdot v21.4s, v7.16b, v20.4b[1]\n" - ".inst 0x4f94e8ea // sdot v10.4s, v7.16b, v20.4b[2]\n" - ".inst 0x4fb4e8e3 // sdot v3.4s, v7.16b, v20.4b[3]\n" - ".inst 0x4f93e3c2 // sdot v2.4s, v30.16b, v19.4b[0]\n" - ".inst 0x4fb3e3d1 // sdot v17.4s, v30.16b, v19.4b[1]\n" - ".inst 0x4f93ebcc // sdot v12.4s, v30.16b, v19.4b[2]\n" - ".inst 0x4fb3ebc9 // sdot v9.4s, v30.16b, v19.4b[3]\n" - ".inst 0x4f96e3ce // sdot v14.4s, v30.16b, v22.4b[0]\n" - ".inst 0x4fb6e3cb // sdot v11.4s, v30.16b, v22.4b[1]\n" - ".inst 0x4f96ebcd // sdot v13.4s, v30.16b, v22.4b[2]\n" - ".inst 0x4fb6ebcf // sdot v15.4s, v30.16b, v22.4b[3]\n" - ".inst 0x4f88e3d7 // sdot v23.4s, v30.16b, v8.4b[0]\n" - ".inst 0x4fa8e3dd // sdot v29.4s, v30.16b, v8.4b[1]\n" - ".inst 0x4f88ebc0 // sdot v0.4s, v30.16b, v8.4b[2]\n" - ".inst 0x4fa8ebc4 // sdot v4.4s, v30.16b, v8.4b[3]\n" - ".inst 0x4f9ce3d0 // sdot v16.4s, v30.16b, v28.4b[0]\n" - ".inst 0x4fbce3d5 // sdot v21.4s, v30.16b, v28.4b[1]\n" - ".inst 0x4f9cebca // sdot v10.4s, v30.16b, v28.4b[2]\n" - ".inst 0x4fbcebc3 // sdot v3.4s, v30.16b, v28.4b[3]\n" - "bgt 4b\n" - "ldr d7, [x11, #0x0]\n" - "ldr q31, [SP, #0x0]\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v12.4s, v12.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "add x11, x11, #0x8\n" - "shll v7.4s, v7.4h, #0x10\n" - "fmla v31.4s, v2.4s, v7.4s\n" - "str q31, [SP, #0x0]\n" - "ldr q2, [SP, #0x10]\n" - "fmla v2.4s, v17.4s, v7.4s\n" - "str q2, [SP, #0x10]\n" - "ldr q2, [SP, #0x20]\n" - "fmla v2.4s, v12.4s, v7.4s\n" - "str q2, [SP, #0x20]\n" - "ldr q2, [SP, #0x30]\n" - "fmla v2.4s, v9.4s, v7.4s\n" - "str q2, [SP, #0x30]\n" - "ldr q28, [SP, #0x40]\n" - "scvtf v14.4s, v14.4s, #0x4\n" - "scvtf v11.4s, v11.4s, #0x4\n" - "scvtf v13.4s, v13.4s, #0x4\n" - "scvtf v15.4s, v15.4s, #0x4\n" - "fmla v28.4s, v14.4s, v7.4s\n" - "str q28, [SP, #0x40]\n" - "ldr q1, [SP, #0x50]\n" - "fmla v1.4s, v11.4s, v7.4s\n" - "str q1, [SP, #0x50]\n" - "ldr q11, [SP, #0x60]\n" - "fmla v11.4s, v13.4s, v7.4s\n" - "str q11, [SP, #0x60]\n" - "ldr q14, [SP, #0x70]\n" - "fmla v14.4s, v15.4s, v7.4s\n" - "str q14, [SP, #0x70]\n" - "ldr q19, [SP, #0x80]\n" - "scvtf v23.4s, v23.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "scvtf v4.4s, v4.4s, #0x4\n" - "fmla v19.4s, v23.4s, v7.4s\n" - "str q19, [SP, #0x80]\n" - "ldr q15, [SP, #0x90]\n" - "fmla v15.4s, v29.4s, v7.4s\n" - "str q15, [SP, #0x90]\n" - "ldr q25, [SP, #0xa0]\n" - "fmla v25.4s, v0.4s, v7.4s\n" - "str q25, [SP, #0xa0]\n" - "ldr q12, [SP, #0xb0]\n" - "fmla v12.4s, v4.4s, v7.4s\n" - "str q12, [SP, #0xb0]\n" - "ldr q2, [SP, #0xc0]\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "scvtf v21.4s, v21.4s, #0x4\n" - "scvtf 
v10.4s, v10.4s, #0x4\n" - "scvtf v3.4s, v3.4s, #0x4\n" - "fmla v2.4s, v16.4s, v7.4s\n" - "str q2, [SP, #0xc0]\n" - "ldr q16, [SP, #0xd0]\n" - "fmla v16.4s, v21.4s, v7.4s\n" - "str q16, [SP, #0xd0]\n" - "ldr q16, [SP, #0xe0]\n" - "fmla v16.4s, v10.4s, v7.4s\n" - "str q16, [SP, #0xe0]\n" - "ldr q16, [SP, #0xf0]\n" - "fmla v16.4s, v3.4s, v7.4s\n" - "str q16, [SP, #0xf0]\n" - "subs x24, x24, #0x1\n" - "bgt 3b\n" - "ld1 { v11.4s }, [x27]\n" - "ld1 { v10.4s }, [x23]\n" - "add x27, x27, #0x10\n" - "add x23, x23, #0x10\n" - "ld1 { v9.4s }, [x22]\n" - "ld1 { v8.4s }, [x21]\n" - "add x22, x22, #0x10\n" - "add x21, x21, #0x10\n" - "ldr q31, [SP, #0x0]\n" - "ldr q30, [SP, #0x10]\n" - "add x20, %x[clamp_vals], #0x4\n" - "cmp x10, #0x4\n" - "ldr q29, [SP, #0x20]\n" - "ldr q28, [SP, #0x30]\n" - "scvtf v11.4s, v11.4s\n" - "scvtf v10.4s, v10.4s\n" - "ldr q27, [SP, #0x40]\n" - "ldr q26, [SP, #0x50]\n" - "scvtf v9.4s, v9.4s\n" - "scvtf v8.4s, v8.4s\n" - "ldr q25, [SP, #0x60]\n" - "ldr q24, [SP, #0x70]\n" - "ldr q23, [SP, #0x80]\n" - "ldr q22, [SP, #0x90]\n" - "ldr q21, [SP, #0xa0]\n" - "ldr q20, [SP, #0xb0]\n" - "ldr q19, [SP, #0xc0]\n" - "ldr q18, [SP, #0xd0]\n" - "ldr q17, [SP, #0xe0]\n" - "ldr q16, [SP, #0xf0]\n" - "ldr q7, [x11, #0x0]\n" - "ldr q6, [x27, #0x0]\n" - "ldr q5, [x23, #0x0]\n" - "ldr q4, [x22, #0x0]\n" - "ldr q3, [x21, #0x0]\n" - "ldr q2, [x11, #0x10]\n" - "add x11, x11, #0x20\n" - "ld1r { v1.4s }, [%x[clamp_vals]]\n" - "ld1r { v0.4s }, [x20]\n" - "fmla v31.4s, v7.4s, v11.s[0]\n" - "fmla v30.4s, v7.4s, v11.s[1]\n" - "fmla v29.4s, v7.4s, v11.s[2]\n" - "fmla v28.4s, v7.4s, v11.s[3]\n" - "fmla v27.4s, v7.4s, v10.s[0]\n" - "fmla v26.4s, v7.4s, v10.s[1]\n" - "fmla v25.4s, v7.4s, v10.s[2]\n" - "fmla v24.4s, v7.4s, v10.s[3]\n" - "fmla v23.4s, v7.4s, v9.s[0]\n" - "fmla v22.4s, v7.4s, v9.s[1]\n" - "fmul v31.4s, v31.4s, v6.s[0]\n" - "fmla v21.4s, v7.4s, v9.s[2]\n" - "fmla v20.4s, v7.4s, v9.s[3]\n" - "fmul v30.4s, v30.4s, v6.s[1]\n" - "fmla v19.4s, v7.4s, v8.s[0]\n" - "fmla v18.4s, v7.4s, v8.s[1]\n" - "fmul v29.4s, v29.4s, v6.s[2]\n" - "fmla v17.4s, v7.4s, v8.s[2]\n" - "fmla v16.4s, v7.4s, v8.s[3]\n" - "fmul v28.4s, v28.4s, v6.s[3]\n" - "fmul v27.4s, v27.4s, v5.s[0]\n" - "fmul v26.4s, v26.4s, v5.s[1]\n" - "fmul v25.4s, v25.4s, v5.s[2]\n" - "fmul v24.4s, v24.4s, v5.s[3]\n" - "fmul v23.4s, v23.4s, v4.s[0]\n" - "fmul v22.4s, v22.4s, v4.s[1]\n" - "fmul v21.4s, v21.4s, v4.s[2]\n" - "fmul v20.4s, v20.4s, v4.s[3]\n" - "fmul v19.4s, v19.4s, v3.s[0]\n" - "fmul v18.4s, v18.4s, v3.s[1]\n" - "fmul v17.4s, v17.4s, v3.s[2]\n" - "fmul v16.4s, v16.4s, v3.s[3]\n" - "fadd v31.4s, v31.4s, v2.4s\n" - "fadd v30.4s, v30.4s, v2.4s\n" - "fadd v29.4s, v29.4s, v2.4s\n" - "fadd v28.4s, v28.4s, v2.4s\n" - "fadd v27.4s, v27.4s, v2.4s\n" - "fadd v26.4s, v26.4s, v2.4s\n" - "fadd v25.4s, v25.4s, v2.4s\n" - "fadd v24.4s, v24.4s, v2.4s\n" - "fadd v23.4s, v23.4s, v2.4s\n" - "fadd v22.4s, v22.4s, v2.4s\n" - "fadd v21.4s, v21.4s, v2.4s\n" - "fadd v20.4s, v20.4s, v2.4s\n" - "fadd v19.4s, v19.4s, v2.4s\n" - "fadd v18.4s, v18.4s, v2.4s\n" - "fadd v17.4s, v17.4s, v2.4s\n" - "fadd v16.4s, v16.4s, v2.4s\n" - "fmax v31.4s, v31.4s, v1.4s\n" - "fmax v30.4s, v30.4s, v1.4s\n" - "fmax v29.4s, v29.4s, v1.4s\n" - "fmax v28.4s, v28.4s, v1.4s\n" - "fmax v27.4s, v27.4s, v1.4s\n" - "fmax v26.4s, v26.4s, v1.4s\n" - "fmax v25.4s, v25.4s, v1.4s\n" - "fmax v24.4s, v24.4s, v1.4s\n" - "fmax v23.4s, v23.4s, v1.4s\n" - "fmax v22.4s, v22.4s, v1.4s\n" - "fmax v21.4s, v21.4s, v1.4s\n" - "fmax v20.4s, v20.4s, v1.4s\n" - "fmax v19.4s, v19.4s, v1.4s\n" - "fmax v18.4s, 
v18.4s, v1.4s\n" - "fmax v17.4s, v17.4s, v1.4s\n" - "fmax v16.4s, v16.4s, v1.4s\n" - "fmin v31.4s, v31.4s, v0.4s\n" - "fmin v30.4s, v30.4s, v0.4s\n" - "fmin v29.4s, v29.4s, v0.4s\n" - "fmin v28.4s, v28.4s, v0.4s\n" - "fmin v27.4s, v27.4s, v0.4s\n" - "fmin v26.4s, v26.4s, v0.4s\n" - "fmin v25.4s, v25.4s, v0.4s\n" - "fmin v24.4s, v24.4s, v0.4s\n" - "fmin v23.4s, v23.4s, v0.4s\n" - "fmin v22.4s, v22.4s, v0.4s\n" - "fmin v21.4s, v21.4s, v0.4s\n" - "fmin v20.4s, v20.4s, v0.4s\n" - "fmin v19.4s, v19.4s, v0.4s\n" - "fmin v18.4s, v18.4s, v0.4s\n" - "fmin v17.4s, v17.4s, v0.4s\n" - "fmin v16.4s, v16.4s, v0.4s\n" - "blt 9f\n" - "mov x20, %x[dst]\n" - "str q31, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q30, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q29, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q28, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q27, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q26, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q24, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q21, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q20, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q17, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q16, [x20, #0x0]\n" - "b 14f\n" - "9:" // Partial output - "mov x28, %x[dst]\n" - "add x26, x28, %x[dst_stride_row], LSL #2\n" - "add x25, x26, %x[dst_stride_row], LSL #1\n" - "add x24, x26, %x[dst_stride_row]\n" - "add x23, x25, %x[dst_stride_row]\n" - "add x22, x28, %x[dst_stride_row], LSL #1\n" - "add x21, x28, %x[dst_stride_row]\n" - "add x20, x22, %x[dst_stride_row]\n" - "add x27, x23, %x[dst_stride_row]\n" - "tbz x10, #1, 10f\n" - "st1 { v24.d }[0], [x23], #0x8\n" - "st1 { v25.d }[0], [x25], #0x8\n" - "st1 { v26.d }[0], [x24], #0x8\n" - "st1 { v27.d }[0], [x26], #0x8\n" - "st1 { v28.d }[0], [x20], #0x8\n" - "st1 { v29.d }[0], [x22], #0x8\n" - "st1 { v30.d }[0], [x21], #0x8\n" - "st1 { v31.d }[0], [x28], #0x8\n" - "tbz x10, #0, 11f\n" - "st1 { v24.s }[2], [x23]\n" - "st1 { v25.s }[2], [x25]\n" - "st1 { v26.s }[2], [x24]\n" - "st1 { v27.s }[2], [x26]\n" - "st1 { v28.s }[2], [x20]\n" - "st1 { v29.s }[2], [x22]\n" - "st1 { v30.s }[2], [x21]\n" - "st1 { v31.s }[2], [x28]\n" - "b 11f\n" - "10:" // Output block 0: partial_1_0 - "st1 { v24.s }[0], [x23]\n" - "st1 { v25.s }[0], [x25]\n" - "st1 { v26.s }[0], [x24]\n" - "st1 { v27.s }[0], [x26]\n" - "st1 { v28.s }[0], [x20]\n" - "st1 { v29.s }[0], [x22]\n" - "st1 { v30.s }[0], [x21]\n" - "st1 { v31.s }[0], [x28]\n" - "11:" // Output block 0: Done - "add x26, x27, %x[dst_stride_row], LSL #2\n" - "add x25, x27, %x[dst_stride_row], LSL #1\n" - "add x24, x26, %x[dst_stride_row], LSL #1\n" - "add x23, x27, %x[dst_stride_row]\n" - "add x22, x25, %x[dst_stride_row]\n" - "add x21, x26, %x[dst_stride_row]\n" - "add x20, x24, %x[dst_stride_row]\n" - "tbz x10, #1, 12f\n" - "st1 { v16.d }[0], [x20], #0x8\n" - "st1 { v17.d }[0], [x24], #0x8\n" - "st1 { v18.d }[0], [x21], #0x8\n" - "st1 { v19.d }[0], [x26], #0x8\n" - "st1 { v20.d }[0], [x22], #0x8\n" - "st1 { v21.d }[0], [x25], #0x8\n" - "st1 { v22.d }[0], [x23], #0x8\n" - "st1 { v23.d }[0], 
[x27], #0x8\n" - "tbz x10, #0, 13f\n" - "st1 { v16.s }[2], [x20]\n" - "st1 { v17.s }[2], [x24]\n" - "st1 { v18.s }[2], [x21]\n" - "st1 { v19.s }[2], [x26]\n" - "st1 { v20.s }[2], [x22]\n" - "st1 { v21.s }[2], [x25]\n" - "st1 { v22.s }[2], [x23]\n" - "st1 { v23.s }[2], [x27]\n" - "b 13f\n" - "12:" // Output block 1: partial_1_0 - "st1 { v16.s }[0], [x20]\n" - "st1 { v17.s }[0], [x24]\n" - "st1 { v18.s }[0], [x21]\n" - "st1 { v19.s }[0], [x26]\n" - "st1 { v20.s }[0], [x22]\n" - "st1 { v21.s }[0], [x25]\n" - "st1 { v22.s }[0], [x23]\n" - "st1 { v23.s }[0], [x27]\n" - "13:" // Output block 1: Done - "14:" // Output stage exit - "subs x10, x10, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 2b\n" - "mov x20, #0x4\n" - "sub x12, x12, #0x10\n" - "cmp x12, #0x10\n" - "mov %x[dst], x9\n" - "madd %x[lhs_packed], x20, x13, %x[lhs_packed]\n" - "bge 1b\n" - "15:" // Row loop skip - "cbz x12, 25f\n" - "16:" // Row tail: Row loop - "mov x26, %x[rhs_packed]\n" - "mov x25, %x[n]\n" - "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" - "17:" // Row tail: Column loop - "movi v16.4s, #0x0\n" - "mov x27, %x[lhs_packed]\n" - "mov x21, %x[num_blocks]\n" - "str q16, [SP, #0x0]\n" - "str q16, [SP, #0x10]\n" - "str q16, [SP, #0x20]\n" - "str q16, [SP, #0x30]\n" - "18:" // Row tail: Block loop - "movi v2.4s, #0x0\n" - "movi v17.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "movi v12.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "19:" // Row tail: Sub block loop - "ldr q0, [x26, #0x0]\n" - "ldr q31, [x27, #0x0]\n" - "movi v30.16b, #0xf0\n" - "subs x20, x20, #0x1\n" - "ldr q29, [x26, #0x10]\n" - "ldr q28, [x27, #0x10]\n" - "ldr q27, [x26, #0x20]\n" - "ldr q26, [x27, #0x20]\n" - "ldr q25, [x26, #0x30]\n" - "ldr q24, [x27, #0x30]\n" - "shl v23.16b, v0.16b, #0x4\n" - "and v0.16b, v0.16b, v30.16b\n" - "ldr q22, [x27, #0x40]\n" - "ldr q21, [x27, #0x50]\n" - "shl v20.16b, v29.16b, #0x4\n" - "and v29.16b, v29.16b, v30.16b\n" - "ldr q7, [x27, #0x60]\n" - "ldr q18, [x27, #0x70]\n" - "shl v19.16b, v27.16b, #0x4\n" - "and v27.16b, v27.16b, v30.16b\n" - ".inst 0x4f9fe2e2 // sdot v2.4s, v23.16b, v31.4b[0]\n" - ".inst 0x4fbfe2f1 // sdot v17.4s, v23.16b, v31.4b[1]\n" - "shl v16.16b, v25.16b, #0x4\n" - "add x26, x26, #0x40\n" - ".inst 0x4f9feaec // sdot v12.4s, v23.16b, v31.4b[2]\n" - ".inst 0x4fbfeae9 // sdot v9.4s, v23.16b, v31.4b[3]\n" - "and v25.16b, v25.16b, v30.16b\n" - "add x27, x27, #0x80\n" - ".inst 0x4f9ce282 // sdot v2.4s, v20.16b, v28.4b[0]\n" - ".inst 0x4fbce291 // sdot v17.4s, v20.16b, v28.4b[1]\n" - ".inst 0x4f9cea8c // sdot v12.4s, v20.16b, v28.4b[2]\n" - ".inst 0x4fbcea89 // sdot v9.4s, v20.16b, v28.4b[3]\n" - ".inst 0x4f9ae262 // sdot v2.4s, v19.16b, v26.4b[0]\n" - ".inst 0x4fbae271 // sdot v17.4s, v19.16b, v26.4b[1]\n" - ".inst 0x4f9aea6c // sdot v12.4s, v19.16b, v26.4b[2]\n" - ".inst 0x4fbaea69 // sdot v9.4s, v19.16b, v26.4b[3]\n" - ".inst 0x4f98e202 // sdot v2.4s, v16.16b, v24.4b[0]\n" - ".inst 0x4fb8e211 // sdot v17.4s, v16.16b, v24.4b[1]\n" - ".inst 0x4f98ea0c // sdot v12.4s, v16.16b, v24.4b[2]\n" - ".inst 0x4fb8ea09 // sdot v9.4s, v16.16b, v24.4b[3]\n" - ".inst 0x4f96e002 // sdot v2.4s, v0.16b, v22.4b[0]\n" - ".inst 0x4fb6e011 // sdot v17.4s, v0.16b, v22.4b[1]\n" - ".inst 0x4f96e80c // sdot v12.4s, v0.16b, v22.4b[2]\n" - ".inst 0x4fb6e809 // sdot v9.4s, v0.16b, v22.4b[3]\n" - ".inst 0x4f95e3a2 // sdot v2.4s, v29.16b, v21.4b[0]\n" - ".inst 0x4fb5e3b1 // sdot v17.4s, v29.16b, v21.4b[1]\n" - ".inst 0x4f95ebac // sdot v12.4s, v29.16b, v21.4b[2]\n" - ".inst 0x4fb5eba9 // sdot v9.4s, v29.16b, v21.4b[3]\n" - ".inst 
0x4f87e362 // sdot v2.4s, v27.16b, v7.4b[0]\n" - ".inst 0x4fa7e371 // sdot v17.4s, v27.16b, v7.4b[1]\n" - ".inst 0x4f87eb6c // sdot v12.4s, v27.16b, v7.4b[2]\n" - ".inst 0x4fa7eb69 // sdot v9.4s, v27.16b, v7.4b[3]\n" - ".inst 0x4f92e322 // sdot v2.4s, v25.16b, v18.4b[0]\n" - ".inst 0x4fb2e331 // sdot v17.4s, v25.16b, v18.4b[1]\n" - ".inst 0x4f92eb2c // sdot v12.4s, v25.16b, v18.4b[2]\n" - ".inst 0x4fb2eb29 // sdot v9.4s, v25.16b, v18.4b[3]\n" - "bgt 19b\n" - "ldr d7, [x26, #0x0]\n" - "ldr q16, [SP, #0x0]\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v12.4s, v12.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "add x26, x26, #0x8\n" - "shll v7.4s, v7.4h, #0x10\n" - "fmla v16.4s, v2.4s, v7.4s\n" - "str q16, [SP, #0x0]\n" - "ldr q16, [SP, #0x10]\n" - "fmla v16.4s, v17.4s, v7.4s\n" - "str q16, [SP, #0x10]\n" - "ldr q16, [SP, #0x20]\n" - "fmla v16.4s, v12.4s, v7.4s\n" - "str q16, [SP, #0x20]\n" - "ldr q16, [SP, #0x30]\n" - "fmla v16.4s, v9.4s, v7.4s\n" - "str q16, [SP, #0x30]\n" - "subs x21, x21, #0x1\n" - "bgt 18b\n" - "ld1 { v21.4s }, [x27]\n" - "ldr q31, [SP, #0x0]\n" - "add x27, x27, #0x10\n" - "add x20, %x[clamp_vals], #0x4\n" - "ldr q30, [SP, #0x10]\n" - "ldr q29, [SP, #0x20]\n" - "cmp x25, #0x4\n" - "ldr q28, [SP, #0x30]\n" - "ldr q20, [x26, #0x0]\n" - "ldr q19, [x27, #0x0]\n" - "ldr q18, [x26, #0x10]\n" - "scvtf v21.4s, v21.4s\n" - "add x26, x26, #0x20\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v16.4s }, [x20]\n" - "fmla v31.4s, v20.4s, v21.s[0]\n" - "fmla v30.4s, v20.4s, v21.s[1]\n" - "fmla v29.4s, v20.4s, v21.s[2]\n" - "fmla v28.4s, v20.4s, v21.s[3]\n" - "fmul v31.4s, v31.4s, v19.s[0]\n" - "fmul v30.4s, v30.4s, v19.s[1]\n" - "fadd v31.4s, v31.4s, v18.4s\n" - "fmul v29.4s, v29.4s, v19.s[2]\n" - "fmul v28.4s, v28.4s, v19.s[3]\n" - "fadd v30.4s, v30.4s, v18.4s\n" - "fmax v31.4s, v31.4s, v17.4s\n" - "fadd v29.4s, v29.4s, v18.4s\n" - "fadd v28.4s, v28.4s, v18.4s\n" - "fmax v30.4s, v30.4s, v17.4s\n" - "fmin v31.4s, v31.4s, v16.4s\n" - "fmax v29.4s, v29.4s, v17.4s\n" - "fmax v28.4s, v28.4s, v17.4s\n" - "fmin v30.4s, v30.4s, v16.4s\n" - "fmin v29.4s, v29.4s, v16.4s\n" - "fmin v28.4s, v28.4s, v16.4s\n" - "blt 21f\n" - "mov x20, %x[dst]\n" - "cmp x12, #0x1\n" - "str q31, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 24f\n" - "cmp x12, #0x2\n" - "str q30, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 24f\n" - "cmp x12, #0x3\n" - "str q29, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 24f\n" - "str q28, [x20, #0x0]\n" - "b 24f\n" - "21:" // Row tail: Partial output - "mov x23, %x[dst]\n" - "cmp x12, #0x1\n" - "add x22, x23, %x[dst_stride_row]\n" - "csel x22, x22, x23, GT\n" - "cmp x12, #0x2\n" - "add x21, x23, %x[dst_stride_row], LSL #1\n" - "csel x21, x21, x22, GT\n" - "cmp x12, #0x3\n" - "add x20, x21, %x[dst_stride_row]\n" - "csel x20, x20, x21, GT\n" - "tbz x25, #1, 22f\n" - "st1 { v28.d }[0], [x20], #0x8\n" - "st1 { v29.d }[0], [x21], #0x8\n" - "st1 { v30.d }[0], [x22], #0x8\n" - "st1 { v31.d }[0], [x23], #0x8\n" - "tbz x25, #0, 23f\n" - "st1 { v28.s }[2], [x20]\n" - "st1 { v29.s }[2], [x21]\n" - "st1 { v30.s }[2], [x22]\n" - "st1 { v31.s }[2], [x23]\n" - "b 23f\n" - "22:" // Row tail: Output block 0: partial_1_0 - "st1 { v28.s }[0], [x20]\n" - "st1 { v29.s }[0], [x21]\n" - "st1 { v30.s }[0], [x22]\n" - "st1 { v31.s }[0], [x23]\n" - "23:" // Row tail: Output block 0: Done - "24:" // Row tail: Output stage exit - "subs x25, x25, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 17b\n" - "subs x12, x12, #0x4\n" - 
"add %x[lhs_packed], %x[lhs_packed], x13\n" - "mov %x[dst], x24\n" - "bgt 16b\n" - "25:" // Row tail: Row loop skip - "add SP, SP, #0x100\n" - : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) - : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), - [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v2", "v20", - "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v31", "v4", "v5", "v6", "v7", - "v8", "v9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9"); + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + args.num_subblocks = num_subblocks; + + kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod(&args); } #endif // Architectural features check. -- GitLab From 1e11beb2ba4430788db0343e97678a61aa3d0307 Mon Sep 17 00:00:00 2001 From: Michael Kozlov Date: Thu, 23 Jan 2025 16:35:20 +0000 Subject: [PATCH 05/15] Extract inline assembly kernels into external files: 8x4x32 Signed-off-by: Michael Kozlov --- CMakeLists.txt | 1 + ...8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S | 486 +++++++++++++++ ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c | 584 ++++-------------- ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h | 54 +- 4 files changed, 622 insertions(+), 503 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index 21dcf7e9..de1f6758 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -159,6 +159,7 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S ) set(KLEIDIAI_FILES_SME diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S new file mode 100644 index 00000000..5132ba54 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S @@ -0,0 +1,486 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) 
+ #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x6, #0x80 + movi v15.16b, #0xf0 + mov x21, #0x20 + ldr x20, [x0, #0x28] + ldr x7, [x0, #0x40] + ldr x8, [x0, #0x38] + ldr x17, [x0, #0x8] + ldr x16, [x0, #0x10] + ldr x15, [x0, #0x30] + mov x14, x20 + mul x6, x7, x6 + ldr x13, [x0, #0x0] + ldr x12, [x0, #0x20] + ldr x11, [x0, #0x18] + cmp x14, #0x8 + madd x6, x8, x6, x21 + blt label_11 +KAI_ASM_LABEL(label_1) // Row loop + mov x10, x16 + mov x9, x15 + add x28, x13, x12, LSL #3 +KAI_ASM_LABEL(label_2) // Column loop + mov x23, x17 + movi v23.16b, #0x0 + movi v16.16b, #0x0 + mov x22, x8 + movi v13.16b, #0x0 + movi v11.16b, #0x0 + movi v14.16b, #0x0 + movi v5.16b, #0x0 + movi v10.16b, #0x0 + movi v26.16b, #0x0 + add x21, x23, x6 +KAI_ASM_LABEL(label_3) // Block loop + movi v8.4s, #0x0 + movi v19.4s, #0x0 + mov x20, x7 + movi v6.4s, #0x0 + movi v7.4s, #0x0 + movi v2.4s, #0x0 + movi v4.4s, #0x0 + movi v3.4s, #0x0 + movi v0.4s, #0x0 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q1, [x10, #0x0] + ldr q21, [x10, #0x10] + subs x20, x20, #0x1 + ldr q9, [x23, #0x0] + ldr q28, [x23, #0x10] + ldr q20, [x21, #0x0] + ldr q31, [x21, #0x10] + ldr q30, [x10, #0x20] + ldr q24, [x10, #0x30] + shl v27.16b, v1.16b, #0x4 + shl v17.16b, v21.16b, #0x4 + ldr q22, [x23, #0x20] + ldr q12, [x23, #0x30] + and v1.16b, v1.16b, v15.16b + and v21.16b, v21.16b, v15.16b + ldr q18, [x21, #0x20] + ldr q25, [x21, #0x30] + add x10, x10, #0x40 + ldr q29, [x23, #0x40] + KAI_ASM_INST(0x4e9ba528) // smmla v8.4s, v9.16b, v27.16b + KAI_ASM_INST(0x4e91a533) // smmla v19.4s, v9.16b, v17.16b + ldr q9, [x23, #0x50] + KAI_ASM_INST(0x4e9ba786) // smmla v6.4s, v28.16b, v27.16b + KAI_ASM_INST(0x4e91a787) // smmla v7.4s, v28.16b, v17.16b + ldr q28, [x21, #0x40] + KAI_ASM_INST(0x4e9ba682) // smmla v2.4s, v20.16b, v27.16b + KAI_ASM_INST(0x4e91a684) // smmla v4.4s, v20.16b, v17.16b + ldr q20, [x21, #0x50] + KAI_ASM_INST(0x4e9ba7e3) // smmla v3.4s, v31.16b, v27.16b + ldr q27, [x23, #0x60] + KAI_ASM_INST(0x4e91a7e0) // smmla v0.4s, v31.16b, v17.16b + ldr q17, [x23, #0x70] + shl v31.16b, v30.16b, #0x4 + and v30.16b, v30.16b, v15.16b + add x23, x23, #0x80 + KAI_ASM_INST(0x4e9fa6c8) // smmla v8.4s, v22.16b, v31.16b + KAI_ASM_INST(0x4e9fa586) // smmla v6.4s, v12.16b, v31.16b + KAI_ASM_INST(0x4e9fa642) // smmla v2.4s, v18.16b, v31.16b + KAI_ASM_INST(0x4e9fa723) // smmla v3.4s, v25.16b, v31.16b + ldr q31, [x21, #0x60] + KAI_ASM_INST(0x4e81a7a8) // smmla v8.4s, v29.16b, v1.16b + KAI_ASM_INST(0x4e81a526) // smmla v6.4s, v9.16b, v1.16b + 
KAI_ASM_INST(0x4e81a782) // smmla v2.4s, v28.16b, v1.16b + KAI_ASM_INST(0x4e81a683) // smmla v3.4s, v20.16b, v1.16b + ldr q1, [x21, #0x70] + add x21, x21, #0x80 + KAI_ASM_INST(0x4e9ea768) // smmla v8.4s, v27.16b, v30.16b + KAI_ASM_INST(0x4e9ea626) // smmla v6.4s, v17.16b, v30.16b + KAI_ASM_INST(0x4e9ea7e2) // smmla v2.4s, v31.16b, v30.16b + KAI_ASM_INST(0x4e9ea423) // smmla v3.4s, v1.16b, v30.16b + shl v30.16b, v24.16b, #0x4 + and v24.16b, v24.16b, v15.16b + KAI_ASM_INST(0x4e9ea6d3) // smmla v19.4s, v22.16b, v30.16b + KAI_ASM_INST(0x4e9ea587) // smmla v7.4s, v12.16b, v30.16b + KAI_ASM_INST(0x4e9ea644) // smmla v4.4s, v18.16b, v30.16b + KAI_ASM_INST(0x4e9ea720) // smmla v0.4s, v25.16b, v30.16b + KAI_ASM_INST(0x4e95a7b3) // smmla v19.4s, v29.16b, v21.16b + KAI_ASM_INST(0x4e95a527) // smmla v7.4s, v9.16b, v21.16b + KAI_ASM_INST(0x4e95a784) // smmla v4.4s, v28.16b, v21.16b + KAI_ASM_INST(0x4e95a680) // smmla v0.4s, v20.16b, v21.16b + KAI_ASM_INST(0x4e98a773) // smmla v19.4s, v27.16b, v24.16b + KAI_ASM_INST(0x4e98a627) // smmla v7.4s, v17.16b, v24.16b + KAI_ASM_INST(0x4e98a7e4) // smmla v4.4s, v31.16b, v24.16b + KAI_ASM_INST(0x4e98a420) // smmla v0.4s, v1.16b, v24.16b + bgt label_4 + ldr d20, [x10, #0x0] + uzp1 v18.2d, v8.2d, v19.2d + uzp2 v9.2d, v8.2d, v19.2d + add x10, x10, #0x8 + uzp1 v17.2d, v6.2d, v7.2d + uzp2 v12.2d, v6.2d, v7.2d + shll v20.4s, v20.4h, #0x10 + scvtf v18.4s, v18.4s, #0x4 + scvtf v9.4s, v9.4s, #0x4 + scvtf v17.4s, v17.4s, #0x4 + scvtf v12.4s, v12.4s, #0x4 + fmla v23.4s, v18.4s, v20.4s + fmla v16.4s, v9.4s, v20.4s + fmla v13.4s, v17.4s, v20.4s + fmla v11.4s, v12.4s, v20.4s + uzp1 v19.2d, v2.2d, v4.2d + uzp2 v18.2d, v2.2d, v4.2d + uzp1 v17.2d, v3.2d, v0.2d + uzp2 v25.2d, v3.2d, v0.2d + scvtf v19.4s, v19.4s, #0x4 + scvtf v18.4s, v18.4s, #0x4 + scvtf v17.4s, v17.4s, #0x4 + scvtf v25.4s, v25.4s, #0x4 + fmla v14.4s, v19.4s, v20.4s + fmla v5.4s, v18.4s, v20.4s + fmla v10.4s, v17.4s, v20.4s + fmla v26.4s, v25.4s, v20.4s + subs x22, x22, #0x1 + bgt label_3 + ld1 { v24.4s }, [x23] + ld1 { v22.4s }, [x21] + add x23, x23, #0x10 + add x21, x21, #0x10 + ldr q21, [x10, #0x0] + ldr q20, [x23, #0x0] + add x20, x11, #0x4 + cmp x9, #0x4 + ldr q19, [x21, #0x0] + ldr q18, [x10, #0x10] + add x10, x10, #0x20 + ld1r { v17.4s }, [x11] + ld1r { v9.4s }, [x20] + scvtf v24.4s, v24.4s + scvtf v22.4s, v22.4s + fmla v23.4s, v21.4s, v24.s[0] + fmla v16.4s, v21.4s, v24.s[1] + fmla v13.4s, v21.4s, v24.s[2] + fmla v11.4s, v21.4s, v24.s[3] + fmla v14.4s, v21.4s, v22.s[0] + fmla v5.4s, v21.4s, v22.s[1] + fmla v10.4s, v21.4s, v22.s[2] + fmla v26.4s, v21.4s, v22.s[3] + fmul v23.4s, v23.4s, v20.s[0] + fmul v16.4s, v16.4s, v20.s[1] + fmul v13.4s, v13.4s, v20.s[2] + fmul v11.4s, v11.4s, v20.s[3] + fmul v14.4s, v14.4s, v19.s[0] + fmul v5.4s, v5.4s, v19.s[1] + fadd v23.4s, v23.4s, v18.4s + fmul v10.4s, v10.4s, v19.s[2] + fmul v26.4s, v26.4s, v19.s[3] + fadd v16.4s, v16.4s, v18.4s + fadd v13.4s, v13.4s, v18.4s + fadd v11.4s, v11.4s, v18.4s + fadd v14.4s, v14.4s, v18.4s + fadd v5.4s, v5.4s, v18.4s + fadd v10.4s, v10.4s, v18.4s + fadd v26.4s, v26.4s, v18.4s + fmax v23.4s, v23.4s, v17.4s + fmax v16.4s, v16.4s, v17.4s + fmax v13.4s, v13.4s, v17.4s + fmax v11.4s, v11.4s, v17.4s + fmax v14.4s, v14.4s, v17.4s + fmax v5.4s, v5.4s, v17.4s + fmax v10.4s, v10.4s, v17.4s + fmax v26.4s, v26.4s, v17.4s + fmin v23.4s, v23.4s, v9.4s + fmin v16.4s, v16.4s, v9.4s + fmin v13.4s, v13.4s, v9.4s + fmin v11.4s, v11.4s, v9.4s + fmin v14.4s, v14.4s, v9.4s + fmin v5.4s, v5.4s, v9.4s + fmin v10.4s, v10.4s, v9.4s + fmin v26.4s, v26.4s, v9.4s + 
blt label_7 + mov x20, x13 + str q23, [x20, #0x0] + add x20, x20, x12 + str q16, [x20, #0x0] + add x20, x20, x12 + str q13, [x20, #0x0] + add x20, x20, x12 + str q11, [x20, #0x0] + add x20, x20, x12 + str q14, [x20, #0x0] + add x20, x20, x12 + str q5, [x20, #0x0] + add x20, x20, x12 + str q10, [x20, #0x0] + add x20, x20, x12 + str q26, [x20, #0x0] + b label_10 +KAI_ASM_LABEL(label_7) // Partial output + mov x27, x13 + add x26, x27, x12, LSL #2 + add x25, x26, x12, LSL #1 + add x24, x26, x12 + add x23, x25, x12 + add x22, x27, x12, LSL #1 + add x21, x27, x12 + add x20, x22, x12 + tbz x9, #1, label_8 + st1 { v26.d }[0], [x23], #0x8 + st1 { v10.d }[0], [x25], #0x8 + st1 { v5.d }[0], [x24], #0x8 + st1 { v14.d }[0], [x26], #0x8 + st1 { v11.d }[0], [x20], #0x8 + st1 { v13.d }[0], [x22], #0x8 + st1 { v16.d }[0], [x21], #0x8 + st1 { v23.d }[0], [x27], #0x8 + tbz x9, #0, label_9 + st1 { v26.s }[2], [x23] + st1 { v10.s }[2], [x25] + st1 { v5.s }[2], [x24] + st1 { v14.s }[2], [x26] + st1 { v11.s }[2], [x20] + st1 { v13.s }[2], [x22] + st1 { v16.s }[2], [x21] + st1 { v23.s }[2], [x27] + b label_9 +KAI_ASM_LABEL(label_8) // Output block 0: partial_1_0 + st1 { v26.s }[0], [x23] + st1 { v10.s }[0], [x25] + st1 { v5.s }[0], [x24] + st1 { v14.s }[0], [x26] + st1 { v11.s }[0], [x20] + st1 { v13.s }[0], [x22] + st1 { v16.s }[0], [x21] + st1 { v23.s }[0], [x27] +KAI_ASM_LABEL(label_9) // Output block 0: Done +KAI_ASM_LABEL(label_10) // Output stage exit + subs x9, x9, #0x4 + add x13, x13, #0x10 + bgt label_2 + mov x20, #0x2 + sub x14, x14, #0x8 + cmp x14, #0x8 + mov x13, x28 + madd x17, x20, x6, x17 + bge label_1 +KAI_ASM_LABEL(label_11) // Row loop skip + cbz x14, label_21 +KAI_ASM_LABEL(label_12) // Row tail: Row loop + mov x26, x16 + mov x25, x15 + add x24, x13, x12, LSL #2 +KAI_ASM_LABEL(label_13) // Row tail: Column loop + movi v23.16b, #0x0 + movi v16.16b, #0x0 + mov x23, x17 + mov x21, x8 + movi v13.16b, #0x0 + movi v11.16b, #0x0 +KAI_ASM_LABEL(label_14) // Row tail: Block loop + movi v8.4s, #0x0 + movi v19.4s, #0x0 + mov x20, x7 + movi v6.4s, #0x0 + movi v7.4s, #0x0 +KAI_ASM_LABEL(label_15) // Row tail: Sub block loop + ldr q0, [x26, #0x0] + ldr q31, [x26, #0x10] + subs x20, x20, #0x1 + ldr q30, [x23, #0x0] + ldr q29, [x23, #0x10] + ldr q28, [x26, #0x20] + ldr q27, [x26, #0x30] + add x26, x26, #0x40 + ldr q26, [x23, #0x20] + ldr q25, [x23, #0x30] + shl v24.16b, v0.16b, #0x4 + shl v22.16b, v31.16b, #0x4 + ldr q21, [x23, #0x40] + ldr q20, [x23, #0x50] + and v0.16b, v0.16b, v15.16b + and v31.16b, v31.16b, v15.16b + ldr q3, [x23, #0x60] + ldr q18, [x23, #0x70] + shl v17.16b, v28.16b, #0x4 + shl v12.16b, v27.16b, #0x4 + KAI_ASM_INST(0x4e98a7c8) // smmla v8.4s, v30.16b, v24.16b + KAI_ASM_INST(0x4e96a7d3) // smmla v19.4s, v30.16b, v22.16b + and v28.16b, v28.16b, v15.16b + add x23, x23, #0x80 + KAI_ASM_INST(0x4e98a7a6) // smmla v6.4s, v29.16b, v24.16b + KAI_ASM_INST(0x4e96a7a7) // smmla v7.4s, v29.16b, v22.16b + and v27.16b, v27.16b, v15.16b + KAI_ASM_INST(0x4e91a748) // smmla v8.4s, v26.16b, v17.16b + KAI_ASM_INST(0x4e8ca753) // smmla v19.4s, v26.16b, v12.16b + KAI_ASM_INST(0x4e91a726) // smmla v6.4s, v25.16b, v17.16b + KAI_ASM_INST(0x4e8ca727) // smmla v7.4s, v25.16b, v12.16b + KAI_ASM_INST(0x4e80a6a8) // smmla v8.4s, v21.16b, v0.16b + KAI_ASM_INST(0x4e9fa6b3) // smmla v19.4s, v21.16b, v31.16b + KAI_ASM_INST(0x4e80a686) // smmla v6.4s, v20.16b, v0.16b + KAI_ASM_INST(0x4e9fa687) // smmla v7.4s, v20.16b, v31.16b + KAI_ASM_INST(0x4e9ca468) // smmla v8.4s, v3.16b, v28.16b + KAI_ASM_INST(0x4e9ba473) // smmla 
v19.4s, v3.16b, v27.16b + KAI_ASM_INST(0x4e9ca646) // smmla v6.4s, v18.16b, v28.16b + KAI_ASM_INST(0x4e9ba647) // smmla v7.4s, v18.16b, v27.16b + bgt label_15 + ldr d12, [x26, #0x0] + uzp1 v20.2d, v8.2d, v19.2d + uzp2 v19.2d, v8.2d, v19.2d + add x26, x26, #0x8 + uzp1 v18.2d, v6.2d, v7.2d + uzp2 v17.2d, v6.2d, v7.2d + shll v12.4s, v12.4h, #0x10 + scvtf v20.4s, v20.4s, #0x4 + scvtf v19.4s, v19.4s, #0x4 + scvtf v18.4s, v18.4s, #0x4 + scvtf v17.4s, v17.4s, #0x4 + fmla v23.4s, v20.4s, v12.4s + fmla v16.4s, v19.4s, v12.4s + fmla v13.4s, v18.4s, v12.4s + fmla v11.4s, v17.4s, v12.4s + subs x21, x21, #0x1 + bgt label_14 + ld1 { v21.4s }, [x23] + ldr q20, [x26, #0x0] + add x23, x23, #0x10 + add x20, x11, #0x4 + ldr q19, [x23, #0x0] + ldr q18, [x26, #0x10] + cmp x25, #0x4 + add x26, x26, #0x20 + ld1r { v17.4s }, [x11] + ld1r { v29.4s }, [x20] + scvtf v21.4s, v21.4s + fmla v23.4s, v20.4s, v21.s[0] + fmla v16.4s, v20.4s, v21.s[1] + fmla v13.4s, v20.4s, v21.s[2] + fmla v11.4s, v20.4s, v21.s[3] + fmul v23.4s, v23.4s, v19.s[0] + fmul v16.4s, v16.4s, v19.s[1] + fmul v13.4s, v13.4s, v19.s[2] + fadd v23.4s, v23.4s, v18.4s + fmul v11.4s, v11.4s, v19.s[3] + fadd v16.4s, v16.4s, v18.4s + fadd v13.4s, v13.4s, v18.4s + fadd v11.4s, v11.4s, v18.4s + fmax v23.4s, v23.4s, v17.4s + fmax v16.4s, v16.4s, v17.4s + fmax v13.4s, v13.4s, v17.4s + fmax v11.4s, v11.4s, v17.4s + fmin v23.4s, v23.4s, v29.4s + fmin v16.4s, v16.4s, v29.4s + fmin v13.4s, v13.4s, v29.4s + fmin v11.4s, v11.4s, v29.4s + blt label_17 + mov x20, x13 + cmp x14, #0x1 + str q23, [x20, #0x0] + add x20, x20, x12 + ble label_20 + cmp x14, #0x2 + str q16, [x20, #0x0] + add x20, x20, x12 + ble label_20 + cmp x14, #0x3 + str q13, [x20, #0x0] + add x20, x20, x12 + ble label_20 + str q11, [x20, #0x0] + b label_20 +KAI_ASM_LABEL(label_17) // Row tail: Partial output + mov x23, x13 + cmp x14, #0x1 + add x22, x23, x12 + csel x22, x22, x23, GT + cmp x14, #0x2 + add x21, x23, x12, LSL #1 + csel x21, x21, x22, GT + cmp x14, #0x3 + add x20, x21, x12 + csel x20, x20, x21, GT + tbz x25, #1, label_18 + st1 { v11.d }[0], [x20], #0x8 + st1 { v13.d }[0], [x21], #0x8 + st1 { v16.d }[0], [x22], #0x8 + st1 { v23.d }[0], [x23], #0x8 + tbz x25, #0, label_19 + st1 { v11.s }[2], [x20] + st1 { v13.s }[2], [x21] + st1 { v16.s }[2], [x22] + st1 { v23.s }[2], [x23] + b label_19 +KAI_ASM_LABEL(label_18) // Row tail: Output block 0: partial_1_0 + st1 { v11.s }[0], [x20] + st1 { v13.s }[0], [x21] + st1 { v16.s }[0], [x22] + st1 { v23.s }[0], [x23] +KAI_ASM_LABEL(label_19) // Row tail: Output block 0: Done +KAI_ASM_LABEL(label_20) // Row tail: Output stage exit + subs x25, x25, #0x4 + add x13, x13, #0x10 + bgt label_13 + subs x14, x14, #0x4 + add x17, x17, x6 + mov x13, x24 + bgt label_12 +KAI_ASM_LABEL(label_21) // Row tail: Row loop skip + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d10, d11, [sp, 72] + ldp d12, d13, [sp, 88] + ldp d14, d15, [sp, 104] + ldp d8, d9, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c index 1805f20e..b1561b99 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c +++ 
b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -3,62 +3,95 @@ // // SPDX-License-Identifier: Apache-2.0 // - -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__ARM_FEATURE_MATMUL_INT8) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) #error "I8mm extension required to compile this micro-kernel" -#else +#else // Architectural features check. + #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" -#include <arm_neon.h> #include <stddef.h> #include <stdint.h> #include "kai/kai_common.h" +typedef struct { + float* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; + size_t num_subblocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(KernelArgs* args_ptr); + +// Compute args static const size_t kai_m_step = 8; static const size_t kai_n_step = 4; +// Packing args static const size_t kai_mr = 4; static const size_t kai_nr = 4; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_bl_multiple_of = 32; -static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); -static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); -static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); -static const size_t kai_num_bytes_sum_rhs = sizeof(float); -static const size_t kai_num_bytes_bias = sizeof(float); - -inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - return kai_roundup(k, bl) / bl; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num.
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 2; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 4; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; +static const size_t kai_bl = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); } -inline static size_t kai_k_roundedup(size_t k) { - // Since we pack a float and int32 value at the end of the row, - // we must make sure that k is a multiple of 4 for alignment - size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); - return kai_roundup(k, kr_sr_roundedup4); +inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + size_t num_bytes_per_block_rhs = (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs; + return num_bytes_per_block_rhs; } -inline static size_t kai_lhs_packed_stride(size_t k) { - const size_t k_internal = kai_k_roundedup(k); +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + return kai_roundup(k, bl) / bl; +} - KAI_ASSERT((k_internal % 2) == 0); +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include + // the number of bytes to hold the zero point value + lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs; - return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); + return lhs_packed_stride; } -inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { - KAI_ASSERT((bl % kai_kr) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); +inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); - const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; + size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row); + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; - return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); + return rhs_packed_stride; } size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) { @@ -86,470 +119,67 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) } size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k) { - KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSUME((m_idx % kai_m_step) == 0); - return (m_idx / kai_mr) * kai_lhs_packed_stride(k); + return (m_idx / kai_mr) *
kai_get_lhs_packed_stride(k); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( size_t n_idx, size_t k, size_t bl) { - KAI_ASSERT((k % bl) == 0); - KAI_ASSERT((n_idx % kai_n_step) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); - return (n_idx / kai_nr) * kai_rhs_packed_stride(k, bl); + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl); } size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSERT((m_idx % kai_m_step) == 0); - KAI_ASSERT((n_idx % kai_n_step) == 0); + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); - return (n_idx * sizeof(float)) + m_idx * dst_stride; + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(size_t m, size_t n) { - return m * n * sizeof(float); + return m * n * kai_num_bytes_dst_value; } void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( - size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, - float* dst, // NOLINT(readability-non-const-parameter) - size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { - KAI_ASSERT((k % bl) == 0); - KAI_ASSERT((bl % kai_kr) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(dst_stride_col == sizeof(float)); + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); if (m == 0) { return; } - - size_t num_subblocks = bl / kai_bl_multiple_of; - size_t num_blocks = kai_num_blocks_per_row(k, bl); - - float clamp_vals[2] = {scalar_min, scalar_max}; - - __asm__ __volatile__( - "mov x12, #0x80\n" - "mov x11, %x[m]\n" - "movi v15.16b, #0xf0\n" - "mov x21, #0x3d800000\n" - "mov x20, #0x20\n" - "mul x12, %x[num_subblocks], x12\n" - "cmp x11, #0x8\n" - "dup v24.4s, w21\n" - "madd x12, %x[num_blocks], x12, x20\n" - "blt 11f\n" - "1:" // Row loop - "mov x10, %x[rhs_packed]\n" - "mov x9, %x[n]\n" - "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" - "2:" // Column loop - "mov x23, %x[lhs_packed]\n" - "movi v12.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "mov x22, %x[num_blocks]\n" - "movi v22.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v0.16b, #0x0\n" - "movi v30.16b, #0x0\n" - "movi v8.16b, #0x0\n" - "add x21, x23, x12\n" - "3:" // Block loop - "movi v6.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "movi v4.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "movi v31.4s, #0x0\n" - "movi v3.4s, #0x0\n" - "movi v7.4s, #0x0\n" - "movi v23.4s, #0x0\n" - "4:" // Sub block loop - "ldr q2, [x10, #0x0]\n" - "ldr q20, [x10, #0x10]\n" - "subs x20, x20, #0x1\n" - "ldr q25, [x23, #0x0]\n" - "ldr q11, [x23, #0x10]\n" - "ldr q9, [x21, #0x0]\n" - "ldr q19, [x21, #0x10]\n" - "ldr q1, [x10, #0x20]\n" - "ldr q29, [x10, #0x30]\n" - "shl v27.16b, v2.16b, #0x4\n" - "shl v21.16b, v20.16b, #0x4\n" - "ldr q17, [x23, #0x20]\n" - "ldr q26, [x23, #0x30]\n" - "and v2.16b, v2.16b, v15.16b\n" - "and v20.16b, v20.16b, v15.16b\n" - "ldr q28, [x21, #0x20]\n" - "ldr q16, [x21, #0x30]\n" - "add x10, x10, 
#0x40\n" - ".inst 0x4e9ba726 // smmla v6.4s, v25.16b, v27.16b\n" - ".inst 0x4e95a72a // smmla v10.4s, v25.16b, v21.16b\n" - "ldr q25, [x23, #0x40]\n" - ".inst 0x4e9ba564 // smmla v4.4s, v11.16b, v27.16b\n" - ".inst 0x4e95a572 // smmla v18.4s, v11.16b, v21.16b\n" - "ldr q11, [x23, #0x50]\n" - ".inst 0x4e9ba53f // smmla v31.4s, v9.16b, v27.16b\n" - ".inst 0x4e95a523 // smmla v3.4s, v9.16b, v21.16b\n" - "ldr q9, [x21, #0x40]\n" - ".inst 0x4e9ba667 // smmla v7.4s, v19.16b, v27.16b\n" - "ldr q27, [x21, #0x50]\n" - ".inst 0x4e95a677 // smmla v23.4s, v19.16b, v21.16b\n" - "ldr q21, [x23, #0x60]\n" - "shl v19.16b, v1.16b, #0x4\n" - "and v1.16b, v1.16b, v15.16b\n" - ".inst 0x4e93a626 // smmla v6.4s, v17.16b, v19.16b\n" - ".inst 0x4e93a744 // smmla v4.4s, v26.16b, v19.16b\n" - ".inst 0x4e93a79f // smmla v31.4s, v28.16b, v19.16b\n" - ".inst 0x4e93a607 // smmla v7.4s, v16.16b, v19.16b\n" - "ldr q19, [x23, #0x70]\n" - "add x23, x23, #0x80\n" - ".inst 0x4e82a726 // smmla v6.4s, v25.16b, v2.16b\n" - ".inst 0x4e82a564 // smmla v4.4s, v11.16b, v2.16b\n" - ".inst 0x4e82a53f // smmla v31.4s, v9.16b, v2.16b\n" - ".inst 0x4e82a767 // smmla v7.4s, v27.16b, v2.16b\n" - "shl v2.16b, v29.16b, #0x4\n" - "and v29.16b, v29.16b, v15.16b\n" - ".inst 0x4e82a62a // smmla v10.4s, v17.16b, v2.16b\n" - "ldr q17, [x21, #0x60]\n" - ".inst 0x4e82a752 // smmla v18.4s, v26.16b, v2.16b\n" - "ldr q26, [x21, #0x70]\n" - "add x21, x21, #0x80\n" - ".inst 0x4e82a783 // smmla v3.4s, v28.16b, v2.16b\n" - ".inst 0x4e82a617 // smmla v23.4s, v16.16b, v2.16b\n" - ".inst 0x4e81a6a6 // smmla v6.4s, v21.16b, v1.16b\n" - ".inst 0x4e81a664 // smmla v4.4s, v19.16b, v1.16b\n" - ".inst 0x4e81a63f // smmla v31.4s, v17.16b, v1.16b\n" - ".inst 0x4e94a72a // smmla v10.4s, v25.16b, v20.16b\n" - ".inst 0x4e94a572 // smmla v18.4s, v11.16b, v20.16b\n" - ".inst 0x4e81a747 // smmla v7.4s, v26.16b, v1.16b\n" - ".inst 0x4e94a523 // smmla v3.4s, v9.16b, v20.16b\n" - ".inst 0x4e94a777 // smmla v23.4s, v27.16b, v20.16b\n" - ".inst 0x4e9da6aa // smmla v10.4s, v21.16b, v29.16b\n" - ".inst 0x4e9da672 // smmla v18.4s, v19.16b, v29.16b\n" - ".inst 0x4e9da623 // smmla v3.4s, v17.16b, v29.16b\n" - ".inst 0x4e9da757 // smmla v23.4s, v26.16b, v29.16b\n" - "bgt 4b\n" - "ldr d20, [x10, #0x0]\n" - "uzp1 v21.2d, v6.2d, v10.2d\n" - "uzp2 v19.2d, v6.2d, v10.2d\n" - "add x10, x10, #0x8\n" - "uzp1 v17.2d, v4.2d, v18.2d\n" - "uzp2 v16.2d, v4.2d, v18.2d\n" - "shll v20.4s, v20.4h, #0x10\n" - "scvtf v21.4s, v21.4s\n" - "scvtf v19.4s, v19.4s\n" - "scvtf v17.4s, v17.4s\n" - "scvtf v16.4s, v16.4s\n" - "fmul v20.4s, v20.4s, v24.4s\n" - "fmla v12.4s, v21.4s, v20.4s\n" - "fmla v13.4s, v19.4s, v20.4s\n" - "fmla v22.4s, v17.4s, v20.4s\n" - "fmla v14.4s, v16.4s, v20.4s\n" - "uzp1 v19.2d, v31.2d, v3.2d\n" - "uzp2 v18.2d, v31.2d, v3.2d\n" - "uzp1 v17.2d, v7.2d, v23.2d\n" - "uzp2 v16.2d, v7.2d, v23.2d\n" - "scvtf v19.4s, v19.4s\n" - "scvtf v18.4s, v18.4s\n" - "scvtf v17.4s, v17.4s\n" - "scvtf v16.4s, v16.4s\n" - "fmla v5.4s, v19.4s, v20.4s\n" - "fmla v0.4s, v18.4s, v20.4s\n" - "fmla v30.4s, v17.4s, v20.4s\n" - "fmla v8.4s, v16.4s, v20.4s\n" - "subs x22, x22, #0x1\n" - "bgt 3b\n" - "ld1 { v23.4s }, [x23]\n" - "ld1 { v1.4s }, [x21]\n" - "add x23, x23, #0x10\n" - "add x21, x21, #0x10\n" - "ldr q21, [x10, #0x0]\n" - "ldr q20, [x23, #0x0]\n" - "add x20, %x[clamp_vals], #0x4\n" - "cmp x9, #0x4\n" - "ldr q19, [x21, #0x0]\n" - "ldr q18, [x10, #0x10]\n" - "add x10, x10, #0x20\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v16.4s }, [x20]\n" - "scvtf v23.4s, v23.4s\n" - "scvtf v1.4s, v1.4s\n" - 
"fmla v12.4s, v21.4s, v23.s[0]\n" - "fmla v13.4s, v21.4s, v23.s[1]\n" - "fmla v22.4s, v21.4s, v23.s[2]\n" - "fmla v14.4s, v21.4s, v23.s[3]\n" - "fmla v5.4s, v21.4s, v1.s[0]\n" - "fmla v0.4s, v21.4s, v1.s[1]\n" - "fmla v30.4s, v21.4s, v1.s[2]\n" - "fmla v8.4s, v21.4s, v1.s[3]\n" - "fmul v12.4s, v12.4s, v20.s[0]\n" - "fmul v13.4s, v13.4s, v20.s[1]\n" - "fmul v22.4s, v22.4s, v20.s[2]\n" - "fmul v14.4s, v14.4s, v20.s[3]\n" - "fmul v5.4s, v5.4s, v19.s[0]\n" - "fmul v0.4s, v0.4s, v19.s[1]\n" - "fadd v12.4s, v12.4s, v18.4s\n" - "fmul v30.4s, v30.4s, v19.s[2]\n" - "fmul v8.4s, v8.4s, v19.s[3]\n" - "fadd v13.4s, v13.4s, v18.4s\n" - "fadd v22.4s, v22.4s, v18.4s\n" - "fadd v14.4s, v14.4s, v18.4s\n" - "fadd v5.4s, v5.4s, v18.4s\n" - "fadd v0.4s, v0.4s, v18.4s\n" - "fadd v30.4s, v30.4s, v18.4s\n" - "fadd v8.4s, v8.4s, v18.4s\n" - "fmax v12.4s, v12.4s, v17.4s\n" - "fmax v13.4s, v13.4s, v17.4s\n" - "fmax v22.4s, v22.4s, v17.4s\n" - "fmax v14.4s, v14.4s, v17.4s\n" - "fmax v5.4s, v5.4s, v17.4s\n" - "fmax v0.4s, v0.4s, v17.4s\n" - "fmax v30.4s, v30.4s, v17.4s\n" - "fmax v8.4s, v8.4s, v17.4s\n" - "fmin v12.4s, v12.4s, v16.4s\n" - "fmin v13.4s, v13.4s, v16.4s\n" - "fmin v22.4s, v22.4s, v16.4s\n" - "fmin v14.4s, v14.4s, v16.4s\n" - "fmin v5.4s, v5.4s, v16.4s\n" - "fmin v0.4s, v0.4s, v16.4s\n" - "fmin v30.4s, v30.4s, v16.4s\n" - "fmin v8.4s, v8.4s, v16.4s\n" - "blt 7f\n" - "mov x20, %x[dst]\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q13, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q14, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q5, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q0, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q30, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q8, [x20, #0x0]\n" - "b 10f\n" - "7:" // Partial output - "mov x27, %x[dst]\n" - "add x26, x27, %x[dst_stride_row], LSL #2\n" - "add x25, x26, %x[dst_stride_row], LSL #1\n" - "add x24, x26, %x[dst_stride_row]\n" - "add x23, x25, %x[dst_stride_row]\n" - "add x22, x27, %x[dst_stride_row], LSL #1\n" - "add x21, x27, %x[dst_stride_row]\n" - "add x20, x22, %x[dst_stride_row]\n" - "tbz x9, #1, 8f\n" - "st1 { v8.d }[0], [x23], #0x8\n" - "st1 { v30.d }[0], [x25], #0x8\n" - "st1 { v0.d }[0], [x24], #0x8\n" - "st1 { v5.d }[0], [x26], #0x8\n" - "st1 { v14.d }[0], [x20], #0x8\n" - "st1 { v22.d }[0], [x22], #0x8\n" - "st1 { v13.d }[0], [x21], #0x8\n" - "st1 { v12.d }[0], [x27], #0x8\n" - "tbz x9, #0, 9f\n" - "st1 { v8.s }[2], [x23]\n" - "st1 { v30.s }[2], [x25]\n" - "st1 { v0.s }[2], [x24]\n" - "st1 { v5.s }[2], [x26]\n" - "st1 { v14.s }[2], [x20]\n" - "st1 { v22.s }[2], [x22]\n" - "st1 { v13.s }[2], [x21]\n" - "st1 { v12.s }[2], [x27]\n" - "b 9f\n" - "8:" // Output block 0: partial_1_0 - "st1 { v8.s }[0], [x23]\n" - "st1 { v30.s }[0], [x25]\n" - "st1 { v0.s }[0], [x24]\n" - "st1 { v5.s }[0], [x26]\n" - "st1 { v14.s }[0], [x20]\n" - "st1 { v22.s }[0], [x22]\n" - "st1 { v13.s }[0], [x21]\n" - "st1 { v12.s }[0], [x27]\n" - "9:" // Output block 0: Done - "10:" // Output stage exit - "subs x9, x9, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 2b\n" - "mov x20, #0x2\n" - "sub x11, x11, #0x8\n" - "cmp x11, #0x8\n" - "mov %x[dst], x28\n" - "madd %x[lhs_packed], x20, x12, %x[lhs_packed]\n" - "bge 1b\n" - "11:" // Row loop skip - "cbz x11, 21f\n" - "12:" // Row tail: Row loop - "mov x26, %x[rhs_packed]\n" - "mov x25, %x[n]\n" - "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" - 
"13:" // Row tail: Column loop - "movi v12.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "mov x23, %x[lhs_packed]\n" - "mov x21, %x[num_blocks]\n" - "movi v22.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "14:" // Row tail: Block loop - "movi v6.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "movi v4.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "15:" // Row tail: Sub block loop - "ldr q0, [x26, #0x0]\n" - "ldr q31, [x26, #0x10]\n" - "subs x20, x20, #0x1\n" - "ldr q11, [x23, #0x0]\n" - "ldr q30, [x23, #0x10]\n" - "ldr q29, [x26, #0x20]\n" - "ldr q28, [x26, #0x30]\n" - "add x26, x26, #0x40\n" - "ldr q27, [x23, #0x20]\n" - "ldr q26, [x23, #0x30]\n" - "shl v25.16b, v0.16b, #0x4\n" - "shl v23.16b, v31.16b, #0x4\n" - "ldr q1, [x23, #0x40]\n" - "ldr q21, [x23, #0x50]\n" - "and v0.16b, v0.16b, v15.16b\n" - "and v31.16b, v31.16b, v15.16b\n" - "ldr q20, [x23, #0x60]\n" - "ldr q19, [x23, #0x70]\n" - "shl v17.16b, v29.16b, #0x4\n" - "shl v16.16b, v28.16b, #0x4\n" - ".inst 0x4e99a566 // smmla v6.4s, v11.16b, v25.16b\n" - ".inst 0x4e97a56a // smmla v10.4s, v11.16b, v23.16b\n" - "and v29.16b, v29.16b, v15.16b\n" - "add x23, x23, #0x80\n" - ".inst 0x4e99a7c4 // smmla v4.4s, v30.16b, v25.16b\n" - ".inst 0x4e97a7d2 // smmla v18.4s, v30.16b, v23.16b\n" - "and v28.16b, v28.16b, v15.16b\n" - ".inst 0x4e91a766 // smmla v6.4s, v27.16b, v17.16b\n" - ".inst 0x4e90a76a // smmla v10.4s, v27.16b, v16.16b\n" - ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" - ".inst 0x4e90a752 // smmla v18.4s, v26.16b, v16.16b\n" - ".inst 0x4e80a426 // smmla v6.4s, v1.16b, v0.16b\n" - ".inst 0x4e9fa42a // smmla v10.4s, v1.16b, v31.16b\n" - ".inst 0x4e80a6a4 // smmla v4.4s, v21.16b, v0.16b\n" - ".inst 0x4e9fa6b2 // smmla v18.4s, v21.16b, v31.16b\n" - ".inst 0x4e9da686 // smmla v6.4s, v20.16b, v29.16b\n" - ".inst 0x4e9ca68a // smmla v10.4s, v20.16b, v28.16b\n" - ".inst 0x4e9da664 // smmla v4.4s, v19.16b, v29.16b\n" - ".inst 0x4e9ca672 // smmla v18.4s, v19.16b, v28.16b\n" - "bgt 15b\n" - "ldr d16, [x26, #0x0]\n" - "uzp1 v21.2d, v6.2d, v10.2d\n" - "uzp2 v20.2d, v6.2d, v10.2d\n" - "add x26, x26, #0x8\n" - "uzp1 v19.2d, v4.2d, v18.2d\n" - "uzp2 v17.2d, v4.2d, v18.2d\n" - "shll v16.4s, v16.4h, #0x10\n" - "scvtf v21.4s, v21.4s\n" - "scvtf v20.4s, v20.4s\n" - "scvtf v19.4s, v19.4s\n" - "scvtf v17.4s, v17.4s\n" - "fmul v16.4s, v16.4s, v24.4s\n" - "fmla v12.4s, v21.4s, v16.4s\n" - "fmla v13.4s, v20.4s, v16.4s\n" - "fmla v22.4s, v19.4s, v16.4s\n" - "fmla v14.4s, v17.4s, v16.4s\n" - "subs x21, x21, #0x1\n" - "bgt 14b\n" - "ld1 { v21.4s }, [x23]\n" - "ldr q20, [x26, #0x0]\n" - "add x23, x23, #0x10\n" - "add x20, %x[clamp_vals], #0x4\n" - "ldr q19, [x23, #0x0]\n" - "ldr q18, [x26, #0x10]\n" - "cmp x25, #0x4\n" - "add x26, x26, #0x20\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v16.4s }, [x20]\n" - "scvtf v21.4s, v21.4s\n" - "fmla v12.4s, v20.4s, v21.s[0]\n" - "fmla v13.4s, v20.4s, v21.s[1]\n" - "fmla v22.4s, v20.4s, v21.s[2]\n" - "fmla v14.4s, v20.4s, v21.s[3]\n" - "fmul v12.4s, v12.4s, v19.s[0]\n" - "fmul v13.4s, v13.4s, v19.s[1]\n" - "fmul v22.4s, v22.4s, v19.s[2]\n" - "fadd v12.4s, v12.4s, v18.4s\n" - "fmul v14.4s, v14.4s, v19.s[3]\n" - "fadd v13.4s, v13.4s, v18.4s\n" - "fadd v22.4s, v22.4s, v18.4s\n" - "fadd v14.4s, v14.4s, v18.4s\n" - "fmax v12.4s, v12.4s, v17.4s\n" - "fmax v13.4s, v13.4s, v17.4s\n" - "fmax v22.4s, v22.4s, v17.4s\n" - "fmax v14.4s, v14.4s, v17.4s\n" - "fmin v12.4s, v12.4s, v16.4s\n" - "fmin v13.4s, v13.4s, v16.4s\n" - "fmin v22.4s, v22.4s, v16.4s\n" - "fmin v14.4s, v14.4s, v16.4s\n" - "blt 17f\n" - "mov x20, 
%x[dst]\n" - "cmp x11, #0x1\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 20f\n" - "cmp x11, #0x2\n" - "str q13, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 20f\n" - "cmp x11, #0x3\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 20f\n" - "str q14, [x20, #0x0]\n" - "b 20f\n" - "17:" // Row tail: Partial output - "mov x23, %x[dst]\n" - "cmp x11, #0x1\n" - "add x22, x23, %x[dst_stride_row]\n" - "csel x22, x22, x23, GT\n" - "cmp x11, #0x2\n" - "add x21, x23, %x[dst_stride_row], LSL #1\n" - "csel x21, x21, x22, GT\n" - "cmp x11, #0x3\n" - "add x20, x21, %x[dst_stride_row]\n" - "csel x20, x20, x21, GT\n" - "tbz x25, #1, 18f\n" - "st1 { v14.d }[0], [x20], #0x8\n" - "st1 { v22.d }[0], [x21], #0x8\n" - "st1 { v13.d }[0], [x22], #0x8\n" - "st1 { v12.d }[0], [x23], #0x8\n" - "tbz x25, #0, 19f\n" - "st1 { v14.s }[2], [x20]\n" - "st1 { v22.s }[2], [x21]\n" - "st1 { v13.s }[2], [x22]\n" - "st1 { v12.s }[2], [x23]\n" - "b 19f\n" - "18:" // Row tail: Output block 0: partial_1_0 - "st1 { v14.s }[0], [x20]\n" - "st1 { v22.s }[0], [x21]\n" - "st1 { v13.s }[0], [x22]\n" - "st1 { v12.s }[0], [x23]\n" - "19:" // Row tail: Output block 0: Done - "20:" // Row tail: Output stage exit - "subs x25, x25, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 13b\n" - "subs x11, x11, #0x4\n" - "add %x[lhs_packed], %x[lhs_packed], x12\n" - "mov %x[dst], x24\n" - "bgt 12b\n" - "21:" // Row tail: Row loop skip - : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) - : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), - [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); + const size_t num_subblocks = bl / 32; + const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + args.num_subblocks = num_subblocks; + + kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(&args); } -#endif // Architectural feature check + +#endif // Architectural features check. 
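The wrapper above illustrates the pattern used throughout this series: the C entry point gathers its parameters into a KernelArgs structure, and the external assembly reads the fields at fixed offsets from x0 (dst at #0x0, lhs_packed at #0x8, rhs_packed at #0x10, clamp_vals at #0x18, dst_stride_row at #0x20, m at #0x28, n at #0x30, num_blocks at #0x38, num_subblocks at #0x40), so the struct layout and the .S file must stay in sync. A minimal caller sketch follows, assuming the LHS/RHS buffers were already produced by the packing micro-kernels listed in the header; the helper name and the -FLT_MAX/FLT_MAX clamp bounds are illustrative assumptions, not part of the patch:

    #include <float.h>
    #include <stddef.h>

    #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"

    // Hypothetical helper: lhs_packed/rhs_packed are assumed to come from
    // kai_lhs_quant_pack_qai8dxp_f32 and one of the qsi4c32p RHS packing kernels.
    static void run_matmul_f32_qa8_qs4(size_t m, size_t n, size_t k, size_t bl,
                                       const void* lhs_packed, const void* rhs_packed,
                                       float* dst) {
        // Densely packed F32 output rows.
        const size_t dst_stride_row = n * sizeof(float);

        kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(
            m, n, k, bl, lhs_packed, rhs_packed, dst,
            dst_stride_row, sizeof(float),  // dst_stride_col must be sizeof(float)
            -FLT_MAX, FLT_MAX);             // clamp bounds: effectively no clamping
    }

The same calling pattern applies to the dotprod kernels earlier in the series; only the entry-point name, the output tile shape, and the register set saved in the .S prologue differ (the i8mm kernels also save d8-d15, hence the larger 144-byte frame).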
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h index 5f9059d0..81d48e04 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h @@ -1,21 +1,22 @@ - // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // + #pragma once #include <stddef.h> #ifdef __cplusplus extern "C" { -#endif +#endif // __cplusplus /// Micro-kernel dependencies /// -/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix +/// -# @ref kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS NxK matrix. +/// -# @ref kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS KxN matrix. /// -------------------------------------------------- @@ -38,7 +39,7 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(v /// @return the mr value size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); -/// Gets the nr value, which must be used to pack the RHS matrix +/// Gets the nr value, which must be used to pack the RHS matrix. /// /// @return the nr value size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); @@ -54,13 +55,14 @@ size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); /// Gets the offset in bytes for the packed LHS matrix, -/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values. +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. /// /// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. /// -/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 8 +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. /// @param[in] k Total number of columns in the LHS matrix (not packed). /// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. /// /// @return the offset in bytes to the packed LHS matrix size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( @@ -68,9 +70,9 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_ size_t k); // /// Gets the offset in bytes for the packed RHS matrix, -/// which contains the packed Signed 4-bit quantized symmetric per-channel (qsi4c32) values. +/// which contains the packed Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) values. /// -/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. +/// @param[in] n_idx Column index in the RHS matrix (not packed). It must be a multiple of n_step.
/// @param[in] k The common dimension between the LHS and RHS matrix (K). /// It must be a multiple of the block length (bl). /// @param[in] bl Block length. It must be a multiple of 32. @@ -83,8 +85,8 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_ /// Gets the offset in bytes for the DST matrix /// -/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 8. -/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step. /// @param[in] dst_stride The number of bytes in each row of the DST matrix /// /// @return the DST offset in bytes @@ -105,26 +107,26 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm /// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. /// -/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed -/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4c32) and packed. -/// Output tile: (rows x cols) = 8 x 4 -/// Accumulation performed in a single for loop: 32 -/// Extension used: i8mm +/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed. +/// RHS matrix: Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) and packed. +/// Output tile: (rows x cols) = m_step x n_step. +/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: i8mm /// /// @param[in] m The number of output rows written. /// @param[in] n The number of output columns written. /// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. /// It must be a multiple of the block length (bl). -/// @param[in] bl Block length. It must be a multiple of 32. -/// @param[in] lhs_packed The LHS packed matrix. -/// When the activation are dynamically quantized, you can obtain this matrix -/// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs -/// both the dynamic quantization to 8-bit and activation packing in a single step. -/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref -/// kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 +/// @param[in] bl Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. -/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes. /// @param[in] scalar_min Min value used to clamp the final result. /// @param[in] scalar_max Max value used to clamp the final result.
void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( @@ -142,4 +144,4 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( #ifdef __cplusplus } -#endif +#endif // __cplusplus -- GitLab From fccdc6f0b79d72e5dcec9ab0f419750582b44d1c Mon Sep 17 00:00:00 2001 From: Michael Kozlov Date: Thu, 23 Jan 2025 16:35:50 +0000 Subject: [PATCH 06/15] Extract inline assembly kernels into external files: 4x8x32 Signed-off-by: Michael Kozlov --- CMakeLists.txt | 1 + ...8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S | 340 ++++++++++++++ ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c | 438 +++++------------- ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h | 54 +-- 4 files changed, 476 insertions(+), 357 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index de1f6758..fb893901 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -160,6 +160,7 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S ) set(KLEIDIAI_FILES_SME diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S new file mode 100644 index 00000000..fd8afe71 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S @@ -0,0 +1,340 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm) + stp x20, x21, [sp, -144]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x16, #0x80 + movi v9.16b, #0xf0 + mov x21, #0x20 + ldr x15, [x0, #0x40] + ldr x20, [x0, #0x28] + ldr x14, [x0, #0x38] + ldr x13, [x0, #0x8] + ldr x12, [x0, #0x10] + ldr x11, [x0, #0x30] + mul x16, x15, x16 + mov x10, x20 + ldr x9, [x0, #0x0] + ldr x28, [x0, #0x20] + ldr x27, [x0, #0x18] + madd x16, x14, x16, x21 + cbz x10, label_12 +KAI_ASM_LABEL(label_1) // Row loop + mov x26, x12 + mov x25, x11 + add x24, x9, x28, LSL #2 +KAI_ASM_LABEL(label_2) // Column loop + movi v12.16b, #0x0 + movi v5.16b, #0x0 + mov x22, x13 + mov x21, x14 + movi v11.16b, #0x0 + movi v13.16b, #0x0 + movi v21.16b, #0x0 + movi v27.16b, #0x0 + movi v7.16b, #0x0 + movi v4.16b, #0x0 +KAI_ASM_LABEL(label_3) // Block loop + movi v10.4s, #0x0 + movi v26.4s, #0x0 + mov x20, x15 + movi v18.4s, #0x0 + movi v2.4s, #0x0 + movi v3.4s, #0x0 + movi v16.4s, #0x0 + movi v8.4s, #0x0 + movi v19.4s, #0x0 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q14, [x26, #0x0] + ldr q28, [x26, #0x10] + subs x20, x20, #0x1 + ldr q23, [x26, #0x20] + ldr q1, [x26, #0x30] + ldr q22, [x22, #0x0] + ldr q15, [x22, #0x10] + ldr q25, [x26, #0x40] + ldr q17, [x26, #0x50] + shl v31.16b, v14.16b, #0x4 + shl v0.16b, v28.16b, #0x4 + ldr q6, [x26, #0x60] + ldr q24, [x26, #0x70] + shl v20.16b, v23.16b, #0x4 + shl v29.16b, v1.16b, #0x4 + ldr q30, [x22, #0x20] + and v14.16b, v14.16b, v9.16b + and v28.16b, v28.16b, v9.16b + add x26, x26, #0x80 + KAI_ASM_INST(0x4e9fa6ca) // smmla v10.4s, v22.16b, v31.16b + KAI_ASM_INST(0x4e80a6d2) // smmla v18.4s, v22.16b, v0.16b + and v23.16b, v23.16b, v9.16b + KAI_ASM_INST(0x4e94a6da) // smmla v26.4s, v22.16b, v20.16b + KAI_ASM_INST(0x4e9da6c2) // smmla v2.4s, v22.16b, v29.16b + ldr q22, [x22, #0x30] + and v1.16b, v1.16b, v9.16b + KAI_ASM_INST(0x4e9fa5e3) // smmla v3.4s, v15.16b, v31.16b + ldr q31, [x22, #0x40] + KAI_ASM_INST(0x4e80a5e8) // smmla v8.4s, v15.16b, v0.16b + ldr q0, [x22, #0x50] + KAI_ASM_INST(0x4e94a5f0) // smmla v16.4s, v15.16b, v20.16b + ldr q20, [x22, #0x60] + KAI_ASM_INST(0x4e9da5f3) // smmla v19.4s, v15.16b, v29.16b + ldr q29, [x22, #0x70] + shl v15.16b, v25.16b, #0x4 + and v25.16b, v25.16b, v9.16b + add x22, x22, #0x80 + KAI_ASM_INST(0x4e8fa7ca) // smmla v10.4s, v30.16b, v15.16b + KAI_ASM_INST(0x4e8fa6c3) // smmla v3.4s, v22.16b, v15.16b + shl v15.16b, v17.16b, #0x4 + and v17.16b, v17.16b, v9.16b + KAI_ASM_INST(0x4e8fa7d2) // smmla v18.4s, v30.16b, v15.16b + KAI_ASM_INST(0x4e8fa6c8) // smmla v8.4s, v22.16b, v15.16b + shl v15.16b, v6.16b, #0x4 + and v6.16b, v6.16b, v9.16b + KAI_ASM_INST(0x4e8ea7ea) // smmla v10.4s, v31.16b, v14.16b + KAI_ASM_INST(0x4e8ea403) // smmla v3.4s, v0.16b, v14.16b + shl v14.16b, v24.16b, #0x4 + and v24.16b, v24.16b, v9.16b + KAI_ASM_INST(0x4e8fa7da) // smmla v26.4s, v30.16b, v15.16b + KAI_ASM_INST(0x4e8fa6d0) // smmla v16.4s, v22.16b, v15.16b + KAI_ASM_INST(0x4e9ca7f2) // smmla v18.4s, v31.16b, v28.16b + KAI_ASM_INST(0x4e9ca408) // smmla v8.4s, v0.16b, v28.16b + KAI_ASM_INST(0x4e8ea7c2) // smmla v2.4s, v30.16b, v14.16b + KAI_ASM_INST(0x4e8ea6d3) // smmla v19.4s, v22.16b, v14.16b + KAI_ASM_INST(0x4e99a68a) // smmla v10.4s, v20.16b, v25.16b + KAI_ASM_INST(0x4e99a7a3) // smmla v3.4s, v29.16b, v25.16b + KAI_ASM_INST(0x4e97a7fa) // smmla v26.4s, v31.16b, v23.16b + KAI_ASM_INST(0x4e97a410) // smmla v16.4s, v0.16b, v23.16b + KAI_ASM_INST(0x4e91a692) // smmla v18.4s, v20.16b, v17.16b + 
KAI_ASM_INST(0x4e91a7a8) // smmla v8.4s, v29.16b, v17.16b + KAI_ASM_INST(0x4e81a7e2) // smmla v2.4s, v31.16b, v1.16b + KAI_ASM_INST(0x4e81a413) // smmla v19.4s, v0.16b, v1.16b + KAI_ASM_INST(0x4e86a69a) // smmla v26.4s, v20.16b, v6.16b + KAI_ASM_INST(0x4e86a7b0) // smmla v16.4s, v29.16b, v6.16b + KAI_ASM_INST(0x4e98a682) // smmla v2.4s, v20.16b, v24.16b + KAI_ASM_INST(0x4e98a7b3) // smmla v19.4s, v29.16b, v24.16b + bgt label_4 + ldr q0, [x26, #0x0] + uzp1 v14.2d, v10.2d, v18.2d + uzp2 v30.2d, v10.2d, v18.2d + add x26, x26, #0x10 + uzp1 v25.2d, v26.2d, v2.2d + uzp2 v22.2d, v26.2d, v2.2d + uzp1 v24.2d, v3.2d, v8.2d + uzp2 v20.2d, v3.2d, v8.2d + uzp1 v28.2d, v16.2d, v19.2d + uzp2 v18.2d, v16.2d, v19.2d + shll v17.4s, v0.4h, #0x10 + shll2 v16.4s, v0.8h, #0x10 + scvtf v14.4s, v14.4s, #0x4 + scvtf v25.4s, v25.4s, #0x4 + scvtf v30.4s, v30.4s, #0x4 + scvtf v22.4s, v22.4s, #0x4 + scvtf v24.4s, v24.4s, #0x4 + scvtf v28.4s, v28.4s, #0x4 + scvtf v20.4s, v20.4s, #0x4 + scvtf v18.4s, v18.4s, #0x4 + fmla v12.4s, v14.4s, v17.4s + fmla v5.4s, v25.4s, v16.4s + fmla v11.4s, v30.4s, v17.4s + fmla v13.4s, v22.4s, v16.4s + fmla v21.4s, v24.4s, v17.4s + fmla v27.4s, v28.4s, v16.4s + fmla v7.4s, v20.4s, v17.4s + fmla v4.4s, v18.4s, v16.4s + subs x21, x21, #0x1 + bgt label_3 + ld1 { v23.4s }, [x22] + ldr q22, [x26, #0x0] + add x22, x22, #0x10 + add x20, x27, #0x4 + ldr q6, [x26, #0x10] + ldr q20, [x22, #0x0] + cmp x25, #0x8 + ldr q19, [x26, #0x20] + ldr q18, [x26, #0x30] + add x26, x26, #0x40 + ld1r { v17.4s }, [x27] + ld1r { v16.4s }, [x20] + scvtf v23.4s, v23.4s + fmla v12.4s, v22.4s, v23.s[0] + fmla v5.4s, v6.4s, v23.s[0] + fmla v11.4s, v22.4s, v23.s[1] + fmla v13.4s, v6.4s, v23.s[1] + fmla v21.4s, v22.4s, v23.s[2] + fmla v27.4s, v6.4s, v23.s[2] + fmla v7.4s, v22.4s, v23.s[3] + fmla v4.4s, v6.4s, v23.s[3] + fmul v12.4s, v12.4s, v20.s[0] + fmul v5.4s, v5.4s, v20.s[0] + fmul v11.4s, v11.4s, v20.s[1] + fmul v13.4s, v13.4s, v20.s[1] + fmul v21.4s, v21.4s, v20.s[2] + fmul v27.4s, v27.4s, v20.s[2] + fmul v7.4s, v7.4s, v20.s[3] + fmul v4.4s, v4.4s, v20.s[3] + fadd v12.4s, v12.4s, v19.4s + fadd v5.4s, v5.4s, v18.4s + fadd v11.4s, v11.4s, v19.4s + fadd v13.4s, v13.4s, v18.4s + fadd v21.4s, v21.4s, v19.4s + fadd v27.4s, v27.4s, v18.4s + fadd v7.4s, v7.4s, v19.4s + fadd v4.4s, v4.4s, v18.4s + fmax v12.4s, v12.4s, v17.4s + fmax v5.4s, v5.4s, v17.4s + fmax v11.4s, v11.4s, v17.4s + fmax v13.4s, v13.4s, v17.4s + fmax v21.4s, v21.4s, v17.4s + fmax v27.4s, v27.4s, v17.4s + fmax v7.4s, v7.4s, v17.4s + fmax v4.4s, v4.4s, v17.4s + fmin v12.4s, v12.4s, v16.4s + fmin v5.4s, v5.4s, v16.4s + fmin v11.4s, v11.4s, v16.4s + fmin v13.4s, v13.4s, v16.4s + fmin v21.4s, v21.4s, v16.4s + fmin v27.4s, v27.4s, v16.4s + fmin v7.4s, v7.4s, v16.4s + fmin v4.4s, v4.4s, v16.4s + blt label_6 + mov x20, x9 + cmp x10, #0x1 + str q12, [x20, #0x0] + str q5, [x20, #0x10] + add x20, x20, x28 + ble label_11 + cmp x10, #0x2 + str q11, [x20, #0x0] + str q13, [x20, #0x10] + add x20, x20, x28 + ble label_11 + cmp x10, #0x3 + str q21, [x20, #0x0] + str q27, [x20, #0x10] + add x20, x20, x28 + ble label_11 + str q7, [x20, #0x0] + str q4, [x20, #0x10] + b label_11 +KAI_ASM_LABEL(label_6) // Partial output + mov x23, x9 + cmp x10, #0x1 + add x22, x23, x28 + csel x22, x22, x23, GT + cmp x10, #0x2 + add x21, x23, x28, LSL #1 + csel x21, x21, x22, GT + cmp x10, #0x3 + add x20, x21, x28 + csel x20, x20, x21, GT + tbz x25, #2, label_8 + st1 { v7.4s }, [x20], #0x10 + st1 { v21.4s }, [x21], #0x10 + st1 { v11.4s }, [x22], #0x10 + st1 { v12.4s }, [x23], #0x10 + tbz x25, #1, 
label_7 + st1 { v4.d }[0], [x20], #0x8 + st1 { v27.d }[0], [x21], #0x8 + st1 { v13.d }[0], [x22], #0x8 + st1 { v5.d }[0], [x23], #0x8 + tbz x25, #0, label_10 + st1 { v4.s }[2], [x20] + st1 { v27.s }[2], [x21] + st1 { v13.s }[2], [x22] + st1 { v5.s }[2], [x23] + b label_10 +KAI_ASM_LABEL(label_7) // Output block 0: partial_1_4 + tbz x25, #0, label_10 + st1 { v4.s }[0], [x20] + st1 { v27.s }[0], [x21] + st1 { v13.s }[0], [x22] + st1 { v5.s }[0], [x23] + b label_10 +KAI_ASM_LABEL(label_8) // Output block 0: partial_2_0 + tbz x25, #1, label_9 + st1 { v7.d }[0], [x20], #0x8 + st1 { v21.d }[0], [x21], #0x8 + st1 { v11.d }[0], [x22], #0x8 + st1 { v12.d }[0], [x23], #0x8 + tbz x25, #0, label_10 + st1 { v7.s }[2], [x20] + st1 { v21.s }[2], [x21] + st1 { v11.s }[2], [x22] + st1 { v12.s }[2], [x23] + b label_10 +KAI_ASM_LABEL(label_9) // Output block 0: partial_1_0 + st1 { v7.s }[0], [x20] + st1 { v21.s }[0], [x21] + st1 { v11.s }[0], [x22] + st1 { v12.s }[0], [x23] +KAI_ASM_LABEL(label_10) // Output block 0: Done +KAI_ASM_LABEL(label_11) // Output stage exit + subs x25, x25, #0x8 + add x9, x9, #0x20 + bgt label_2 + subs x10, x10, #0x4 + add x13, x13, x16 + mov x9, x24 + bgt label_1 +KAI_ASM_LABEL(label_12) // Row loop skip + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d10, d11, [sp, 72] + ldp d12, d13, [sp, 88] + ldp d14, d15, [sp, 104] + ldp d8, d9, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c index 41795b88..777fc994 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c @@ -3,62 +3,95 @@ // // SPDX-License-Identifier: Apache-2.0 // - -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__ARM_FEATURE_MATMUL_INT8) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) #error "I8mm extension required to compile this micro-kernel" -#else +#else // Architectural features check. 
+
 #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h"

-#include <arm_neon.h>
 #include <stddef.h>
 #include <stdint.h>

 #include "kai/kai_common.h"

+typedef struct {
+    float* dst;
+    const void* lhs_packed;
+    const void* rhs_packed;
+    const float* clamp_vals;
+    size_t dst_stride_row;
+    size_t m;
+    size_t n;
+    size_t num_blocks;
+    size_t num_subblocks;
+} KernelArgs;
+
+void kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(KernelArgs* args_ptr);
+
+// Compute args
 static const size_t kai_m_step = 4;
 static const size_t kai_n_step = 8;

+// Packing args
 static const size_t kai_mr = 4;
 static const size_t kai_nr = 8;
 static const size_t kai_kr = 16;
 static const size_t kai_sr = 2;
-static const size_t kai_bl_multiple_of = 32;
-static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
-static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t);
-static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
-static const size_t kai_num_bytes_sum_rhs = sizeof(float);
-static const size_t kai_num_bytes_bias = sizeof(float);
-
-inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
-    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
-    return kai_roundup(k, bl) / bl;
+// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
+static const size_t kai_num_bytes_qvalue_lhs = 1;
+static const size_t kai_num_bytes_multiplier_lhs = 4;
+static const size_t kai_num_bytes_zp_lhs = 4;
+// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
+// asymmetric))
+static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
+static const size_t kai_num_bytes_multiplier_rhs = 2;
+static const size_t kai_num_bytes_rsum_rhs = 4;
+// DST format args
+static const size_t kai_num_bytes_dst_value = 4;
+// Extra args
+static const size_t kai_num_bytes_bias = 4;
+static const size_t kai_k_multiple_of = 32;
+static const size_t kai_bl = 32;
+
+inline static size_t kai_get_k_roundedup(size_t k) {
+    return kai_roundup(k, kai_k_multiple_of);
 }

-inline static size_t kai_k_roundedup(size_t k) {
-    // Since we pack a float and int32 value at the end of the row,
-    // we must make sure that k is a multiple of 4 for alignment
-    size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
-    return kai_roundup(k, kr_sr_roundedup4);
+inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
+    KAI_ASSUME((bl % kai_bl) == 0);
+    size_t num_bytes_per_block_rhs = (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
+    return num_bytes_per_block_rhs;
 }

-inline static size_t kai_lhs_packed_stride(size_t k) {
-    const size_t k_internal = kai_k_roundedup(k);
+inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSUME((bl % kai_bl) == 0);
+
+    return kai_roundup(k, bl) / bl;
+}

-    KAI_ASSERT((k_internal % 2) == 0);
+inline static size_t kai_get_lhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_get_k_roundedup(k);
+    size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
+    // Since the LHS matrix is asymmetric with per-row quantization, we must include
+    // the number of bytes to hold the zero point value.
+    lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs;

-    return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+    return lhs_packed_stride;
 }

-inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) {
-    KAI_ASSERT((bl % kai_kr) == 0);
-    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); - const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; + size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row); + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; - return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); + return rhs_packed_stride; } size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) { @@ -86,324 +119,67 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) } size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(size_t m_idx, size_t k) { - KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSUME((m_idx % kai_m_step) == 0); - return (m_idx / kai_mr) * kai_lhs_packed_stride(k); + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( size_t n_idx, size_t k, size_t bl) { - KAI_ASSERT((k % bl) == 0); - KAI_ASSERT((n_idx % kai_n_step) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); - return (n_idx / kai_nr) * kai_rhs_packed_stride(k, bl); + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl); } size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSERT((m_idx % kai_m_step) == 0); - KAI_ASSERT((n_idx % kai_n_step) == 0); + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); - return (n_idx * sizeof(float)) + m_idx * dst_stride; + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(size_t m, size_t n) { - return m * n * sizeof(float); + return m * n * kai_num_bytes_dst_value; } void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( - size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, - float* dst, // NOLINT(readability-non-const-parameter) - size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { - KAI_ASSERT((k % bl) == 0); - KAI_ASSERT((bl % kai_kr) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(dst_stride_col == sizeof(float)); + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); if (m == 0) { return; } - - size_t num_subblocks = bl / kai_bl_multiple_of; - size_t num_blocks = kai_num_blocks_per_row(k, bl); - - float clamp_vals[2] = {scalar_min, scalar_max}; - - __asm__ __volatile__( - 
"mov x28, #0x80\n" - "mov x21, #0x3d800000\n" - "movi v17.16b, #0xf0\n" - "mov x20, #0x20\n" - "mov x27, %x[m]\n" - "mul x28, %x[num_subblocks], x28\n" - "dup v14.4s, w21\n" - "madd x28, %x[num_blocks], x28, x20\n" - "cbz x27, 12f\n" - "1:" // Row loop - "mov x26, %x[rhs_packed]\n" - "mov x25, %x[n]\n" - "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" - "2:" // Column loop - "movi v1.16b, #0x0\n" - "movi v12.16b, #0x0\n" - "mov x22, %x[lhs_packed]\n" - "mov x21, %x[num_blocks]\n" - "movi v11.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "movi v18.16b, #0x0\n" - "movi v27.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "movi v4.16b, #0x0\n" - "3:" // Block loop - "movi v21.4s, #0x0\n" - "movi v30.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "movi v24.4s, #0x0\n" - "movi v23.4s, #0x0\n" - "movi v7.4s, #0x0\n" - "movi v3.4s, #0x0\n" - "movi v2.4s, #0x0\n" - "movi v8.4s, #0x0\n" - "4:" // Sub block loop - "ldr q6, [x26, #0x0]\n" - "ldr q0, [x26, #0x10]\n" - "subs x20, x20, #0x1\n" - "ldr q10, [x26, #0x20]\n" - "ldr q26, [x26, #0x30]\n" - "ldr q22, [x22, #0x0]\n" - "ldr q20, [x22, #0x10]\n" - "ldr q31, [x26, #0x40]\n" - "ldr q15, [x26, #0x50]\n" - "shl v29.16b, v6.16b, #0x4\n" - "shl v9.16b, v0.16b, #0x4\n" - "ldr q25, [x26, #0x60]\n" - "ldr q16, [x26, #0x70]\n" - "shl v5.16b, v10.16b, #0x4\n" - "shl v19.16b, v26.16b, #0x4\n" - "and v6.16b, v6.16b, v17.16b\n" - "and v0.16b, v0.16b, v17.16b\n" - "add x26, x26, #0x80\n" - ".inst 0x4e9da6d5 // smmla v21.4s, v22.16b, v29.16b\n" - ".inst 0x4e89a6d8 // smmla v24.4s, v22.16b, v9.16b\n" - ".inst 0x4e9da687 // smmla v7.4s, v20.16b, v29.16b\n" - "ldr q29, [x22, #0x20]\n" - "and v10.16b, v10.16b, v17.16b\n" - ".inst 0x4e85a6de // smmla v30.4s, v22.16b, v5.16b\n" - ".inst 0x4e93a6d7 // smmla v23.4s, v22.16b, v19.16b\n" - "ldr q22, [x22, #0x30]\n" - "and v26.16b, v26.16b, v17.16b\n" - ".inst 0x4e89a682 // smmla v2.4s, v20.16b, v9.16b\n" - "ldr q9, [x22, #0x40]\n" - ".inst 0x4e85a683 // smmla v3.4s, v20.16b, v5.16b\n" - "ldr q5, [x22, #0x50]\n" - ".inst 0x4e93a688 // smmla v8.4s, v20.16b, v19.16b\n" - "ldr q19, [x22, #0x60]\n" - "shl v20.16b, v31.16b, #0x4\n" - "and v31.16b, v31.16b, v17.16b\n" - ".inst 0x4e94a7b5 // smmla v21.4s, v29.16b, v20.16b\n" - ".inst 0x4e94a6c7 // smmla v7.4s, v22.16b, v20.16b\n" - "ldr q20, [x22, #0x70]\n" - "add x22, x22, #0x80\n" - ".inst 0x4e86a535 // smmla v21.4s, v9.16b, v6.16b\n" - ".inst 0x4e86a4a7 // smmla v7.4s, v5.16b, v6.16b\n" - "shl v6.16b, v15.16b, #0x4\n" - "and v15.16b, v15.16b, v17.16b\n" - ".inst 0x4e86a7b8 // smmla v24.4s, v29.16b, v6.16b\n" - ".inst 0x4e86a6c2 // smmla v2.4s, v22.16b, v6.16b\n" - "shl v6.16b, v25.16b, #0x4\n" - "and v25.16b, v25.16b, v17.16b\n" - ".inst 0x4e9fa675 // smmla v21.4s, v19.16b, v31.16b\n" - ".inst 0x4e9fa687 // smmla v7.4s, v20.16b, v31.16b\n" - "shl v31.16b, v16.16b, #0x4\n" - "and v16.16b, v16.16b, v17.16b\n" - ".inst 0x4e86a7be // smmla v30.4s, v29.16b, v6.16b\n" - ".inst 0x4e86a6c3 // smmla v3.4s, v22.16b, v6.16b\n" - ".inst 0x4e80a538 // smmla v24.4s, v9.16b, v0.16b\n" - ".inst 0x4e80a4a2 // smmla v2.4s, v5.16b, v0.16b\n" - ".inst 0x4e9fa7b7 // smmla v23.4s, v29.16b, v31.16b\n" - ".inst 0x4e9fa6c8 // smmla v8.4s, v22.16b, v31.16b\n" - ".inst 0x4e8aa53e // smmla v30.4s, v9.16b, v10.16b\n" - ".inst 0x4e8aa4a3 // smmla v3.4s, v5.16b, v10.16b\n" - ".inst 0x4e8fa678 // smmla v24.4s, v19.16b, v15.16b\n" - ".inst 0x4e8fa682 // smmla v2.4s, v20.16b, v15.16b\n" - ".inst 0x4e9aa537 // smmla v23.4s, v9.16b, v26.16b\n" - ".inst 0x4e9aa4a8 // smmla v8.4s, v5.16b, v26.16b\n" - ".inst 0x4e99a67e // smmla 
v30.4s, v19.16b, v25.16b\n" - ".inst 0x4e99a683 // smmla v3.4s, v20.16b, v25.16b\n" - ".inst 0x4e90a677 // smmla v23.4s, v19.16b, v16.16b\n" - ".inst 0x4e90a688 // smmla v8.4s, v20.16b, v16.16b\n" - "bgt 4b\n" - "ldr q29, [x26, #0x0]\n" - "uzp1 v26.2d, v21.2d, v24.2d\n" - "uzp2 v25.2d, v21.2d, v24.2d\n" - "add x26, x26, #0x10\n" - "uzp1 v24.2d, v30.2d, v23.2d\n" - "uzp2 v23.2d, v30.2d, v23.2d\n" - "uzp1 v22.2d, v7.2d, v2.2d\n" - "uzp2 v21.2d, v7.2d, v2.2d\n" - "shll v20.4s, v29.4h, #0x10\n" - "shll2 v19.4s, v29.8h, #0x10\n" - "uzp1 v0.2d, v3.2d, v8.2d\n" - "uzp2 v8.2d, v3.2d, v8.2d\n" - "scvtf v26.4s, v26.4s\n" - "scvtf v24.4s, v24.4s\n" - "fmul v20.4s, v20.4s, v14.4s\n" - "fmul v19.4s, v19.4s, v14.4s\n" - "scvtf v25.4s, v25.4s\n" - "scvtf v23.4s, v23.4s\n" - "scvtf v22.4s, v22.4s\n" - "scvtf v0.4s, v0.4s\n" - "scvtf v21.4s, v21.4s\n" - "scvtf v8.4s, v8.4s\n" - "fmla v1.4s, v26.4s, v20.4s\n" - "fmla v12.4s, v24.4s, v19.4s\n" - "fmla v11.4s, v25.4s, v20.4s\n" - "fmla v13.4s, v23.4s, v19.4s\n" - "fmla v18.4s, v22.4s, v20.4s\n" - "fmla v27.4s, v0.4s, v19.4s\n" - "fmla v28.4s, v21.4s, v20.4s\n" - "fmla v4.4s, v8.4s, v19.4s\n" - "subs x21, x21, #0x1\n" - "bgt 3b\n" - "ld1 { v23.4s }, [x22]\n" - "ldr q22, [x26, #0x0]\n" - "add x22, x22, #0x10\n" - "add x20, %x[clamp_vals], #0x4\n" - "ldr q9, [x26, #0x10]\n" - "ldr q20, [x22, #0x0]\n" - "cmp x25, #0x8\n" - "ldr q19, [x26, #0x20]\n" - "ldr q21, [x26, #0x30]\n" - "add x26, x26, #0x40\n" - "ld1r { v10.4s }, [%x[clamp_vals]]\n" - "ld1r { v30.4s }, [x20]\n" - "scvtf v23.4s, v23.4s\n" - "fmla v1.4s, v22.4s, v23.s[0]\n" - "fmla v12.4s, v9.4s, v23.s[0]\n" - "fmla v11.4s, v22.4s, v23.s[1]\n" - "fmla v13.4s, v9.4s, v23.s[1]\n" - "fmla v18.4s, v22.4s, v23.s[2]\n" - "fmla v27.4s, v9.4s, v23.s[2]\n" - "fmla v28.4s, v22.4s, v23.s[3]\n" - "fmla v4.4s, v9.4s, v23.s[3]\n" - "fmul v1.4s, v1.4s, v20.s[0]\n" - "fmul v12.4s, v12.4s, v20.s[0]\n" - "fmul v11.4s, v11.4s, v20.s[1]\n" - "fmul v13.4s, v13.4s, v20.s[1]\n" - "fmul v18.4s, v18.4s, v20.s[2]\n" - "fmul v27.4s, v27.4s, v20.s[2]\n" - "fmul v28.4s, v28.4s, v20.s[3]\n" - "fmul v4.4s, v4.4s, v20.s[3]\n" - "fadd v1.4s, v1.4s, v19.4s\n" - "fadd v12.4s, v12.4s, v21.4s\n" - "fadd v11.4s, v11.4s, v19.4s\n" - "fadd v13.4s, v13.4s, v21.4s\n" - "fadd v18.4s, v18.4s, v19.4s\n" - "fadd v27.4s, v27.4s, v21.4s\n" - "fadd v28.4s, v28.4s, v19.4s\n" - "fadd v4.4s, v4.4s, v21.4s\n" - "fmax v1.4s, v1.4s, v10.4s\n" - "fmax v12.4s, v12.4s, v10.4s\n" - "fmax v11.4s, v11.4s, v10.4s\n" - "fmax v13.4s, v13.4s, v10.4s\n" - "fmax v18.4s, v18.4s, v10.4s\n" - "fmax v27.4s, v27.4s, v10.4s\n" - "fmax v28.4s, v28.4s, v10.4s\n" - "fmax v4.4s, v4.4s, v10.4s\n" - "fmin v1.4s, v1.4s, v30.4s\n" - "fmin v12.4s, v12.4s, v30.4s\n" - "fmin v11.4s, v11.4s, v30.4s\n" - "fmin v13.4s, v13.4s, v30.4s\n" - "fmin v18.4s, v18.4s, v30.4s\n" - "fmin v27.4s, v27.4s, v30.4s\n" - "fmin v28.4s, v28.4s, v30.4s\n" - "fmin v4.4s, v4.4s, v30.4s\n" - "blt 6f\n" - "mov x20, %x[dst]\n" - "cmp x27, #0x1\n" - "str q1, [x20, #0x0]\n" - "str q12, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 11f\n" - "cmp x27, #0x2\n" - "str q11, [x20, #0x0]\n" - "str q13, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 11f\n" - "cmp x27, #0x3\n" - "str q18, [x20, #0x0]\n" - "str q27, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 11f\n" - "str q28, [x20, #0x0]\n" - "str q4, [x20, #0x10]\n" - "b 11f\n" - "6:" // Partial output - "mov x23, %x[dst]\n" - "cmp x27, #0x1\n" - "add x22, x23, %x[dst_stride_row]\n" - "csel x22, x22, x23, GT\n" - "cmp x27, 
#0x2\n" - "add x21, x23, %x[dst_stride_row], LSL #1\n" - "csel x21, x21, x22, GT\n" - "cmp x27, #0x3\n" - "add x20, x21, %x[dst_stride_row]\n" - "csel x20, x20, x21, GT\n" - "tbz x25, #2, 8f\n" - "st1 { v28.4s }, [x20], #0x10\n" - "st1 { v18.4s }, [x21], #0x10\n" - "st1 { v11.4s }, [x22], #0x10\n" - "st1 { v1.4s }, [x23], #0x10\n" - "tbz x25, #1, 7f\n" - "st1 { v4.d }[0], [x20], #0x8\n" - "st1 { v27.d }[0], [x21], #0x8\n" - "st1 { v13.d }[0], [x22], #0x8\n" - "st1 { v12.d }[0], [x23], #0x8\n" - "tbz x25, #0, 10f\n" - "st1 { v4.s }[2], [x20]\n" - "st1 { v27.s }[2], [x21]\n" - "st1 { v13.s }[2], [x22]\n" - "st1 { v12.s }[2], [x23]\n" - "b 10f\n" - "7:" // Output block 0: partial_1_4 - "tbz x25, #0, 10f\n" - "st1 { v4.s }[0], [x20]\n" - "st1 { v27.s }[0], [x21]\n" - "st1 { v13.s }[0], [x22]\n" - "st1 { v12.s }[0], [x23]\n" - "b 10f\n" - "8:" // Output block 0: partial_2_0 - "tbz x25, #1, 9f\n" - "st1 { v28.d }[0], [x20], #0x8\n" - "st1 { v18.d }[0], [x21], #0x8\n" - "st1 { v11.d }[0], [x22], #0x8\n" - "st1 { v1.d }[0], [x23], #0x8\n" - "tbz x25, #0, 10f\n" - "st1 { v28.s }[2], [x20]\n" - "st1 { v18.s }[2], [x21]\n" - "st1 { v11.s }[2], [x22]\n" - "st1 { v1.s }[2], [x23]\n" - "b 10f\n" - "9:" // Output block 0: partial_1_0 - "st1 { v28.s }[0], [x20]\n" - "st1 { v18.s }[0], [x21]\n" - "st1 { v11.s }[0], [x22]\n" - "st1 { v1.s }[0], [x23]\n" - "10:" // Output block 0: Done - "11:" // Output stage exit - "subs x25, x25, #0x8\n" - "add %x[dst], %x[dst], #0x20\n" - "bgt 2b\n" - "subs x27, x27, #0x4\n" - "add %x[lhs_packed], %x[lhs_packed], x28\n" - "mov %x[dst], x24\n" - "bgt 1b\n" - "12:" // Row loop skip - : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) - : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), - [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); + const size_t num_subblocks = bl / 32; + const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + args.num_subblocks = num_subblocks; + + kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(&args); } -#endif // Architectural feature check + +#endif // Architectural features check. 
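Note on the rewritten wrapper above: instead of embedding the kernel, it now fills a KernelArgs structure and dispatches to the assembly entry point extracted into the .S file. The sketch below shows how a caller might drive the resulting public API. It assumes the packed buffers were produced by the packing micro-kernels named in the header (kai_lhs_quant_pack_qai8dxp_f32 for the LHS, kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 for the RHS), and that k is a multiple of bl; the run_tiled helper, the tile-by-tile traversal, and the wide-open clamp bounds are illustrative assumptions, not part of this patch.

#include <float.h>
#include <stddef.h>

#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h"

// Hypothetical driver: walks the output in m_step x n_step tiles and calls the
// micro-kernel once per tile, using the offset helpers to locate the packed data.
static void run_tiled(size_t m, size_t n, size_t k, size_t bl,
                      const void* lhs_packed, const void* rhs_packed, float* dst) {
    const size_t m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();
    const size_t n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();
    const size_t dst_stride_row = n * sizeof(float);  // assumes a dense row-major DST

    for (size_t m_idx = 0; m_idx < m; m_idx += m_step) {
        for (size_t n_idx = 0; n_idx < n; n_idx += n_step) {
            const size_t lhs_offset =
                kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(m_idx, k);
            const size_t rhs_offset =
                kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(n_idx, k, bl);
            const size_t dst_offset =
                kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
                    m_idx, n_idx, dst_stride_row);

            // Clamp the tile size at the matrix edges; the kernel handles partial tiles.
            const size_t m_tile = (m - m_idx) < m_step ? (m - m_idx) : m_step;
            const size_t n_tile = (n - n_idx) < n_step ? (n - n_idx) : n_step;

            kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
                m_tile, n_tile, k, bl,
                (const void*)((const char*)lhs_packed + lhs_offset),
                (const void*)((const char*)rhs_packed + rhs_offset),
                (float*)((char*)dst + dst_offset),
                dst_stride_row,
                sizeof(float),       // dst_stride_col must be sizeof(float)
                -FLT_MAX, FLT_MAX);  // no effective clamp; real callers pass their bounds
        }
    }
}

A single call with the full m and n would also work, since the kernel loops over rows and columns internally; the offset helpers mainly make it possible to split the work, for example across threads. The same pattern applies to the other kernels in this series, with only the function-name suffix changing.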
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h
index 3811b253..acd7d784 100644
--- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h
@@ -1,21 +1,22 @@
-
 //
-// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates
 //
 // SPDX-License-Identifier: Apache-2.0
 //
+
 #pragma once

 #include <stddef.h>

 #ifdef __cplusplus
 extern "C" {
-#endif
+#endif  // __cplusplus

 /// Micro-kernel dependencies
 ///
-/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix
-/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix
+/// -# @ref kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step.
+/// -# @ref kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS NxK matrix.
+/// -# @ref kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS KxN matrix.

 /// --------------------------------------------------

@@ -38,7 +39,7 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(v
 /// @return the mr value
 size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void);

-/// Gets the nr value, which must be used to pack the RHS matrix
+/// Gets the nr value, which must be used to pack the RHS matrix.
 ///
 /// @return the nr value
 size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void);

@@ -54,13 +55,14 @@ size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void)
 size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void);

 /// Gets the offset in bytes for the packed LHS matrix,
-/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values.
+/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values.
 ///
 /// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
 ///
-/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 4
+/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step.
 /// @param[in] k Total number of columns in the LHS matrix (not packed).
 ///            It must be a multiple of the block length (bl).
+/// @param[in] bl Block length. It must be a multiple of 32.
 ///
 /// @return the offset in bytes to the packed LHS matrix
 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
@@ -68,9 +70,9 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_
     size_t k);
 //
 /// Gets the offset in bytes for the packed RHS matrix,
-/// which contains the packed Signed 4-bit quantized symmetric per-channel (qsi4c32) values.
+/// which contains the packed Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) values.
 ///
-/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8.
+/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step.
 /// @param[in] k The common dimension between the LHS and RHS matrix (K).
 ///            It must be a multiple of the block length (bl).
 /// @param[in] bl Block length. It must be a multiple of 32.
@@ -83,8 +85,8 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_

 /// Gets the offset in bytes for the DST matrix
 ///
-/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 4.
-/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 8.
+/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step.
 /// @param[in] dst_stride The number of bytes in each row of the DST matrix
 ///
 /// @return the DST offset in bytes
@@ -105,26 +107,26 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm

 /// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
 ///
-/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
-/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4c32) and packed.
-/// Output tile: (rows x cols) = 4 x 8
-/// Accumulation performed in a single for loop: 32
-/// Extension used: i8mm
+/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed.
+/// RHS matrix: Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) and packed.
+/// Output tile: (rows x cols) = m_step x n_step.
+///
+/// Note: Please refer to the get functions for m_step and n_step for the exact values.
+///
+/// Features used: i8mm
 ///
 /// @param[in] m The number of output rows written.
 /// @param[in] n The number of output columns written.
 /// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
 ///              It must be a multiple of the block length (bl).
-/// @param[in] bl Block length. It must be a multiple of 32.
-/// @param[in] lhs_packed The LHS packed matrix.
-///                       When the activation are dynamically quantized, you can obtain this matrix
-///                       by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
-///                       both the dynamic quantization to 8-bit and activation packing in a single step.
-/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref
-///                       kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0
+/// @param[in] bl Block length. It must be a multiple of 32.
+/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the
+///                       top of this file.
+/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the
+///                       top of this file.
 /// @param[out] dst The DST matrix.
 /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
-/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float).
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes.
 /// @param[in] scalar_min Min value used to clamp the final result.
 /// @param[in] scalar_max Max value used to clamp the final result.
void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( @@ -142,4 +144,4 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( #ifdef __cplusplus } -#endif +#endif // __cplusplus -- GitLab From dbd0a5deecba4a20d841ade06a3f598b824e67d2 Mon Sep 17 00:00:00 2001 From: Michael Kozlov Date: Thu, 23 Jan 2025 16:36:18 +0000 Subject: [PATCH 07/15] Extract inline assembly kernels into external files: 16x4x32 Signed-off-by: Michael Kozlov --- CMakeLists.txt | 2 + ...dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S | 744 +++++++++ ..._qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S | 670 ++++++++ ...qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c | 1445 ++--------------- ...qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h | 54 +- ...p4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c | 185 +++ ...p4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.h | 147 ++ 7 files changed, 1883 insertions(+), 1364 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.h diff --git a/CMakeLists.txt b/CMakeLists.txt index fb893901..94059adc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -161,6 +161,8 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S ) set(KLEIDIAI_FILES_SME diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S new file mode 100644 index 00000000..9fca8668 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S @@ -0,0 +1,744 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + 
#define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x5, #0x80 + mov x21, #0x20 + sub SP, SP, #0x100 + ldr x20, [x0, #0x28] + ldr x6, [x0, #0x40] + ldr x7, [x0, #0x38] + ldr x8, [x0, #0x8] + ldr x17, [x0, #0x10] + ldr x16, [x0, #0x30] + mov x15, x20 + mul x5, x6, x5 + ldr x14, [x0, #0x0] + ldr x13, [x0, #0x20] + ldr x12, [x0, #0x18] + cmp x15, #0x10 + madd x5, x7, x5, x21 + blt label_15 +KAI_ASM_LABEL(label_1) // Row loop + mov x11, x17 + mov x10, x16 + add x9, x14, x13, LSL #4 +KAI_ASM_LABEL(label_2) // Column loop + mov x27, x8 + movi v29.4s, #0x0 + mov x24, x7 + str q29, [SP, #0x0] + str q29, [SP, #0x10] + str q29, [SP, #0x20] + add x23, x27, x5 + add x22, x23, x5 + str q29, [SP, #0x30] + add x21, x22, x5 + str q29, [SP, #0x40] + str q29, [SP, #0x50] + str q29, [SP, #0x60] + str q29, [SP, #0x70] + str q29, [SP, #0x80] + str q29, [SP, #0x90] + str q29, [SP, #0xa0] + str q29, [SP, #0xb0] + str q29, [SP, #0xc0] + str q29, [SP, #0xd0] + str q29, [SP, #0xe0] + str q29, [SP, #0xf0] +KAI_ASM_LABEL(label_3) // Block loop + movi v7.4s, #0x0 + movi v13.4s, #0x0 + mov x20, x6 + movi v29.4s, #0x0 + movi v12.4s, #0x0 + movi v28.4s, #0x0 + movi v15.4s, #0x0 + movi v2.4s, #0x0 + movi v22.4s, #0x0 + movi v30.4s, #0x0 + movi v26.4s, #0x0 + movi v6.4s, #0x0 + movi v10.4s, #0x0 + movi v9.4s, #0x0 + movi v18.4s, #0x0 + movi v0.4s, #0x0 + movi v14.4s, #0x0 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q4, [x11, #0x0] + ldr q3, [x11, #0x10] + movi v31.16b, #0xf0 + subs x20, x20, #0x1 + ldr q27, [x27, #0x0] + ldr q1, [x27, #0x10] + ldr q19, [x23, #0x0] + ldr q17, [x23, #0x10] + ldr q21, [x22, #0x0] + ldr q23, [x22, #0x10] + shl v25.16b, v4.16b, #0x4 + shl v20.16b, v3.16b, #0x4 + ldr q5, [x21, #0x0] + ldr q16, [x21, #0x10] + and v4.16b, v4.16b, v31.16b + and v3.16b, v3.16b, v31.16b + ldr q8, [x11, #0x20] + ldr q11, [x11, #0x30] + add x11, x11, #0x40 + ldr q24, [x27, #0x20] + KAI_ASM_INST(0x4e99a767) // smmla v7.4s, v27.16b, v25.16b + KAI_ASM_INST(0x4e94a76d) // smmla v13.4s, v27.16b, v20.16b + ldr q27, [x27, #0x30] + KAI_ASM_INST(0x4e99a43d) // smmla v29.4s, v1.16b, v25.16b + KAI_ASM_INST(0x4e94a42c) // smmla v12.4s, v1.16b, v20.16b + ldr q1, [x23, #0x20] + KAI_ASM_INST(0x4e99a67c) // smmla v28.4s, v19.16b, v25.16b + KAI_ASM_INST(0x4e94a66f) // smmla v15.4s, v19.16b, v20.16b + ldr q19, [x23, #0x30] + KAI_ASM_INST(0x4e99a622) // smmla v2.4s, v17.16b, v25.16b + KAI_ASM_INST(0x4e94a636) // smmla v22.4s, v17.16b, v20.16b + ldr q17, [x22, #0x20] + KAI_ASM_INST(0x4e99a6be) // smmla v30.4s, v21.16b, v25.16b + KAI_ASM_INST(0x4e94a6ba) // smmla v26.4s, v21.16b, v20.16b + ldr q21, [x22, #0x30] + KAI_ASM_INST(0x4e99a6e6) // smmla v6.4s, 
v23.16b, v25.16b + KAI_ASM_INST(0x4e94a6ea) // smmla v10.4s, v23.16b, v20.16b + ldr q23, [x21, #0x20] + KAI_ASM_INST(0x4e99a4a9) // smmla v9.4s, v5.16b, v25.16b + KAI_ASM_INST(0x4e94a4b2) // smmla v18.4s, v5.16b, v20.16b + ldr q5, [x21, #0x30] + KAI_ASM_INST(0x4e99a600) // smmla v0.4s, v16.16b, v25.16b + ldr q25, [x27, #0x40] + KAI_ASM_INST(0x4e94a60e) // smmla v14.4s, v16.16b, v20.16b + ldr q16, [x27, #0x50] + shl v20.16b, v8.16b, #0x4 + and v8.16b, v8.16b, v31.16b + KAI_ASM_INST(0x4e94a707) // smmla v7.4s, v24.16b, v20.16b + KAI_ASM_INST(0x4e94a77d) // smmla v29.4s, v27.16b, v20.16b + KAI_ASM_INST(0x4e94a43c) // smmla v28.4s, v1.16b, v20.16b + KAI_ASM_INST(0x4e94a662) // smmla v2.4s, v19.16b, v20.16b + KAI_ASM_INST(0x4e94a63e) // smmla v30.4s, v17.16b, v20.16b + KAI_ASM_INST(0x4e94a6a6) // smmla v6.4s, v21.16b, v20.16b + KAI_ASM_INST(0x4e94a6e9) // smmla v9.4s, v23.16b, v20.16b + KAI_ASM_INST(0x4e94a4a0) // smmla v0.4s, v5.16b, v20.16b + shl v20.16b, v11.16b, #0x4 + KAI_ASM_INST(0x4e84a727) // smmla v7.4s, v25.16b, v4.16b + KAI_ASM_INST(0x4e84a61d) // smmla v29.4s, v16.16b, v4.16b + and v11.16b, v11.16b, v31.16b + ldr q31, [x23, #0x40] + KAI_ASM_INST(0x4e94a70d) // smmla v13.4s, v24.16b, v20.16b + ldr q24, [x23, #0x50] + KAI_ASM_INST(0x4e94a76c) // smmla v12.4s, v27.16b, v20.16b + ldr q27, [x22, #0x40] + KAI_ASM_INST(0x4e94a42f) // smmla v15.4s, v1.16b, v20.16b + ldr q1, [x22, #0x50] + KAI_ASM_INST(0x4e94a676) // smmla v22.4s, v19.16b, v20.16b + ldr q19, [x21, #0x40] + KAI_ASM_INST(0x4e94a63a) // smmla v26.4s, v17.16b, v20.16b + ldr q17, [x21, #0x50] + KAI_ASM_INST(0x4e94a6aa) // smmla v10.4s, v21.16b, v20.16b + ldr q21, [x27, #0x60] + KAI_ASM_INST(0x4e94a6f2) // smmla v18.4s, v23.16b, v20.16b + ldr q23, [x27, #0x70] + KAI_ASM_INST(0x4e94a4ae) // smmla v14.4s, v5.16b, v20.16b + ldr q20, [x23, #0x60] + KAI_ASM_INST(0x4e83a72d) // smmla v13.4s, v25.16b, v3.16b + ldr q5, [x23, #0x70] + ldr q25, [x22, #0x60] + KAI_ASM_INST(0x4e83a60c) // smmla v12.4s, v16.16b, v3.16b + KAI_ASM_INST(0x4e84a7fc) // smmla v28.4s, v31.16b, v4.16b + ldr q16, [x22, #0x70] + KAI_ASM_INST(0x4e83a7ef) // smmla v15.4s, v31.16b, v3.16b + ldr q31, [x21, #0x60] + KAI_ASM_INST(0x4e84a702) // smmla v2.4s, v24.16b, v4.16b + KAI_ASM_INST(0x4e83a716) // smmla v22.4s, v24.16b, v3.16b + ldr q24, [x21, #0x70] + KAI_ASM_INST(0x4e84a77e) // smmla v30.4s, v27.16b, v4.16b + add x27, x27, #0x80 + KAI_ASM_INST(0x4e83a77a) // smmla v26.4s, v27.16b, v3.16b + KAI_ASM_INST(0x4e84a426) // smmla v6.4s, v1.16b, v4.16b + add x23, x23, #0x80 + add x22, x22, #0x80 + KAI_ASM_INST(0x4e83a42a) // smmla v10.4s, v1.16b, v3.16b + KAI_ASM_INST(0x4e84a669) // smmla v9.4s, v19.16b, v4.16b + add x21, x21, #0x80 + KAI_ASM_INST(0x4e83a672) // smmla v18.4s, v19.16b, v3.16b + KAI_ASM_INST(0x4e84a620) // smmla v0.4s, v17.16b, v4.16b + KAI_ASM_INST(0x4e83a62e) // smmla v14.4s, v17.16b, v3.16b + KAI_ASM_INST(0x4e88a6a7) // smmla v7.4s, v21.16b, v8.16b + KAI_ASM_INST(0x4e8ba6ad) // smmla v13.4s, v21.16b, v11.16b + KAI_ASM_INST(0x4e88a6fd) // smmla v29.4s, v23.16b, v8.16b + KAI_ASM_INST(0x4e8ba6ec) // smmla v12.4s, v23.16b, v11.16b + KAI_ASM_INST(0x4e88a69c) // smmla v28.4s, v20.16b, v8.16b + KAI_ASM_INST(0x4e8ba68f) // smmla v15.4s, v20.16b, v11.16b + KAI_ASM_INST(0x4e88a4a2) // smmla v2.4s, v5.16b, v8.16b + KAI_ASM_INST(0x4e8ba4b6) // smmla v22.4s, v5.16b, v11.16b + KAI_ASM_INST(0x4e88a73e) // smmla v30.4s, v25.16b, v8.16b + KAI_ASM_INST(0x4e8ba73a) // smmla v26.4s, v25.16b, v11.16b + KAI_ASM_INST(0x4e88a606) // smmla v6.4s, v16.16b, v8.16b + 
KAI_ASM_INST(0x4e8ba60a) // smmla v10.4s, v16.16b, v11.16b + KAI_ASM_INST(0x4e88a7e9) // smmla v9.4s, v31.16b, v8.16b + KAI_ASM_INST(0x4e8ba7f2) // smmla v18.4s, v31.16b, v11.16b + KAI_ASM_INST(0x4e88a700) // smmla v0.4s, v24.16b, v8.16b + KAI_ASM_INST(0x4e8ba70e) // smmla v14.4s, v24.16b, v11.16b + bgt label_4 + ldr d4, [x11, #0x0] + ldr q23, [SP, #0x0] + uzp1 v16.2d, v7.2d, v13.2d + uzp2 v19.2d, v7.2d, v13.2d + uzp1 v20.2d, v29.2d, v12.2d + uzp2 v17.2d, v29.2d, v12.2d + add x11, x11, #0x8 + shll v24.4s, v4.4h, #0x10 + scvtf v16.4s, v16.4s, #0x4 + scvtf v19.4s, v19.4s, #0x4 + scvtf v20.4s, v20.4s, #0x4 + scvtf v17.4s, v17.4s, #0x4 + fmla v23.4s, v16.4s, v24.4s + str q23, [SP, #0x0] + ldr q16, [SP, #0x10] + fmla v16.4s, v19.4s, v24.4s + str q16, [SP, #0x10] + ldr q16, [SP, #0x20] + fmla v16.4s, v20.4s, v24.4s + str q16, [SP, #0x20] + ldr q16, [SP, #0x30] + fmla v16.4s, v17.4s, v24.4s + str q16, [SP, #0x30] + ldr q1, [SP, #0x40] + uzp1 v16.2d, v28.2d, v15.2d + uzp2 v19.2d, v28.2d, v15.2d + uzp1 v5.2d, v2.2d, v22.2d + uzp2 v17.2d, v2.2d, v22.2d + scvtf v16.4s, v16.4s, #0x4 + scvtf v19.4s, v19.4s, #0x4 + scvtf v5.4s, v5.4s, #0x4 + scvtf v17.4s, v17.4s, #0x4 + fmla v1.4s, v16.4s, v24.4s + str q1, [SP, #0x40] + ldr q16, [SP, #0x50] + fmla v16.4s, v19.4s, v24.4s + str q16, [SP, #0x50] + ldr q16, [SP, #0x60] + fmla v16.4s, v5.4s, v24.4s + str q16, [SP, #0x60] + ldr q16, [SP, #0x70] + fmla v16.4s, v17.4s, v24.4s + str q16, [SP, #0x70] + ldr q1, [SP, #0x80] + uzp1 v16.2d, v30.2d, v26.2d + uzp2 v19.2d, v30.2d, v26.2d + uzp1 v30.2d, v6.2d, v10.2d + uzp2 v17.2d, v6.2d, v10.2d + scvtf v16.4s, v16.4s, #0x4 + scvtf v19.4s, v19.4s, #0x4 + scvtf v30.4s, v30.4s, #0x4 + scvtf v17.4s, v17.4s, #0x4 + fmla v1.4s, v16.4s, v24.4s + str q1, [SP, #0x80] + ldr q16, [SP, #0x90] + fmla v16.4s, v19.4s, v24.4s + str q16, [SP, #0x90] + ldr q16, [SP, #0xa0] + fmla v16.4s, v30.4s, v24.4s + str q16, [SP, #0xa0] + ldr q16, [SP, #0xb0] + fmla v16.4s, v17.4s, v24.4s + str q16, [SP, #0xb0] + ldr q31, [SP, #0xc0] + uzp1 v16.2d, v9.2d, v18.2d + uzp2 v19.2d, v9.2d, v18.2d + uzp1 v21.2d, v0.2d, v14.2d + uzp2 v17.2d, v0.2d, v14.2d + scvtf v16.4s, v16.4s, #0x4 + scvtf v19.4s, v19.4s, #0x4 + scvtf v21.4s, v21.4s, #0x4 + scvtf v17.4s, v17.4s, #0x4 + fmla v31.4s, v16.4s, v24.4s + str q31, [SP, #0xc0] + ldr q16, [SP, #0xd0] + fmla v16.4s, v19.4s, v24.4s + str q16, [SP, #0xd0] + ldr q16, [SP, #0xe0] + fmla v16.4s, v21.4s, v24.4s + str q16, [SP, #0xe0] + ldr q16, [SP, #0xf0] + fmla v16.4s, v17.4s, v24.4s + str q16, [SP, #0xf0] + subs x24, x24, #0x1 + bgt label_3 + ld1 { v11.4s }, [x27] + ld1 { v10.4s }, [x23] + add x27, x27, #0x10 + add x23, x23, #0x10 + ld1 { v9.4s }, [x22] + ld1 { v8.4s }, [x21] + add x22, x22, #0x10 + add x21, x21, #0x10 + ldr q31, [SP, #0x0] + ldr q30, [SP, #0x10] + add x20, x12, #0x4 + cmp x10, #0x4 + ldr q29, [SP, #0x20] + ldr q28, [SP, #0x30] + scvtf v11.4s, v11.4s + scvtf v10.4s, v10.4s + ldr q27, [SP, #0x40] + ldr q26, [SP, #0x50] + scvtf v9.4s, v9.4s + scvtf v8.4s, v8.4s + ldr q25, [SP, #0x60] + ldr q24, [SP, #0x70] + ldr q23, [SP, #0x80] + ldr q22, [SP, #0x90] + ldr q21, [SP, #0xa0] + ldr q20, [SP, #0xb0] + ldr q19, [SP, #0xc0] + ldr q18, [SP, #0xd0] + ldr q17, [SP, #0xe0] + ldr q16, [SP, #0xf0] + ldr q7, [x11, #0x0] + ldr q6, [x27, #0x0] + ldr q5, [x23, #0x0] + ldr q4, [x22, #0x0] + ldr q3, [x21, #0x0] + ldr q2, [x11, #0x10] + add x11, x11, #0x20 + ld1r { v1.4s }, [x12] + ld1r { v0.4s }, [x20] + fmla v31.4s, v7.4s, v11.s[0] + fmla v30.4s, v7.4s, v11.s[1] + fmla v29.4s, v7.4s, v11.s[2] + fmla v28.4s, v7.4s, 
v11.s[3] + fmla v27.4s, v7.4s, v10.s[0] + fmla v26.4s, v7.4s, v10.s[1] + fmla v25.4s, v7.4s, v10.s[2] + fmla v24.4s, v7.4s, v10.s[3] + fmla v23.4s, v7.4s, v9.s[0] + fmla v22.4s, v7.4s, v9.s[1] + fmul v31.4s, v31.4s, v6.s[0] + fmla v21.4s, v7.4s, v9.s[2] + fmla v20.4s, v7.4s, v9.s[3] + fmul v30.4s, v30.4s, v6.s[1] + fmla v19.4s, v7.4s, v8.s[0] + fmla v18.4s, v7.4s, v8.s[1] + fmul v29.4s, v29.4s, v6.s[2] + fmla v17.4s, v7.4s, v8.s[2] + fmla v16.4s, v7.4s, v8.s[3] + fmul v28.4s, v28.4s, v6.s[3] + fmul v27.4s, v27.4s, v5.s[0] + fmul v26.4s, v26.4s, v5.s[1] + fmul v25.4s, v25.4s, v5.s[2] + fmul v24.4s, v24.4s, v5.s[3] + fmul v23.4s, v23.4s, v4.s[0] + fmul v22.4s, v22.4s, v4.s[1] + fmul v21.4s, v21.4s, v4.s[2] + fmul v20.4s, v20.4s, v4.s[3] + fmul v19.4s, v19.4s, v3.s[0] + fmul v18.4s, v18.4s, v3.s[1] + fmul v17.4s, v17.4s, v3.s[2] + fmul v16.4s, v16.4s, v3.s[3] + fadd v31.4s, v31.4s, v2.4s + fadd v30.4s, v30.4s, v2.4s + fadd v29.4s, v29.4s, v2.4s + fadd v28.4s, v28.4s, v2.4s + fadd v27.4s, v27.4s, v2.4s + fadd v26.4s, v26.4s, v2.4s + fadd v25.4s, v25.4s, v2.4s + fadd v24.4s, v24.4s, v2.4s + fadd v23.4s, v23.4s, v2.4s + fadd v22.4s, v22.4s, v2.4s + fadd v21.4s, v21.4s, v2.4s + fadd v20.4s, v20.4s, v2.4s + fadd v19.4s, v19.4s, v2.4s + fadd v18.4s, v18.4s, v2.4s + fadd v17.4s, v17.4s, v2.4s + fadd v16.4s, v16.4s, v2.4s + fmax v31.4s, v31.4s, v1.4s + fmax v30.4s, v30.4s, v1.4s + fmax v29.4s, v29.4s, v1.4s + fmax v28.4s, v28.4s, v1.4s + fmax v27.4s, v27.4s, v1.4s + fmax v26.4s, v26.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + fmax v24.4s, v24.4s, v1.4s + fmax v23.4s, v23.4s, v1.4s + fmax v22.4s, v22.4s, v1.4s + fmax v21.4s, v21.4s, v1.4s + fmax v20.4s, v20.4s, v1.4s + fmax v19.4s, v19.4s, v1.4s + fmax v18.4s, v18.4s, v1.4s + fmax v17.4s, v17.4s, v1.4s + fmax v16.4s, v16.4s, v1.4s + fmin v31.4s, v31.4s, v0.4s + fmin v30.4s, v30.4s, v0.4s + fmin v29.4s, v29.4s, v0.4s + fmin v28.4s, v28.4s, v0.4s + fmin v27.4s, v27.4s, v0.4s + fmin v26.4s, v26.4s, v0.4s + fmin v25.4s, v25.4s, v0.4s + fmin v24.4s, v24.4s, v0.4s + fmin v23.4s, v23.4s, v0.4s + fmin v22.4s, v22.4s, v0.4s + fmin v21.4s, v21.4s, v0.4s + fmin v20.4s, v20.4s, v0.4s + fmin v19.4s, v19.4s, v0.4s + fmin v18.4s, v18.4s, v0.4s + fmin v17.4s, v17.4s, v0.4s + fmin v16.4s, v16.4s, v0.4s + blt label_9 + mov x20, x14 + str q31, [x20, #0x0] + add x20, x20, x13 + str q30, [x20, #0x0] + add x20, x20, x13 + str q29, [x20, #0x0] + add x20, x20, x13 + str q28, [x20, #0x0] + add x20, x20, x13 + str q27, [x20, #0x0] + add x20, x20, x13 + str q26, [x20, #0x0] + add x20, x20, x13 + str q25, [x20, #0x0] + add x20, x20, x13 + str q24, [x20, #0x0] + add x20, x20, x13 + str q23, [x20, #0x0] + add x20, x20, x13 + str q22, [x20, #0x0] + add x20, x20, x13 + str q21, [x20, #0x0] + add x20, x20, x13 + str q20, [x20, #0x0] + add x20, x20, x13 + str q19, [x20, #0x0] + add x20, x20, x13 + str q18, [x20, #0x0] + add x20, x20, x13 + str q17, [x20, #0x0] + add x20, x20, x13 + str q16, [x20, #0x0] + b label_14 +KAI_ASM_LABEL(label_9) // Partial output + mov x28, x14 + add x26, x28, x13, LSL #2 + add x25, x26, x13, LSL #1 + add x24, x26, x13 + add x23, x25, x13 + add x22, x28, x13, LSL #1 + add x21, x28, x13 + add x20, x22, x13 + add x27, x23, x13 + tbz x10, #1, label_10 + st1 { v24.d }[0], [x23], #0x8 + st1 { v25.d }[0], [x25], #0x8 + st1 { v26.d }[0], [x24], #0x8 + st1 { v27.d }[0], [x26], #0x8 + st1 { v28.d }[0], [x20], #0x8 + st1 { v29.d }[0], [x22], #0x8 + st1 { v30.d }[0], [x21], #0x8 + st1 { v31.d }[0], [x28], #0x8 + tbz x10, #0, label_11 + st1 { v24.s }[2], [x23] + st1 { 
v25.s }[2], [x25] + st1 { v26.s }[2], [x24] + st1 { v27.s }[2], [x26] + st1 { v28.s }[2], [x20] + st1 { v29.s }[2], [x22] + st1 { v30.s }[2], [x21] + st1 { v31.s }[2], [x28] + b label_11 +KAI_ASM_LABEL(label_10) // Output block 0: partial_1_0 + st1 { v24.s }[0], [x23] + st1 { v25.s }[0], [x25] + st1 { v26.s }[0], [x24] + st1 { v27.s }[0], [x26] + st1 { v28.s }[0], [x20] + st1 { v29.s }[0], [x22] + st1 { v30.s }[0], [x21] + st1 { v31.s }[0], [x28] +KAI_ASM_LABEL(label_11) // Output block 0: Done + add x26, x27, x13, LSL #2 + add x25, x27, x13, LSL #1 + add x24, x26, x13, LSL #1 + add x23, x27, x13 + add x22, x25, x13 + add x21, x26, x13 + add x20, x24, x13 + tbz x10, #1, label_12 + st1 { v16.d }[0], [x20], #0x8 + st1 { v17.d }[0], [x24], #0x8 + st1 { v18.d }[0], [x21], #0x8 + st1 { v19.d }[0], [x26], #0x8 + st1 { v20.d }[0], [x22], #0x8 + st1 { v21.d }[0], [x25], #0x8 + st1 { v22.d }[0], [x23], #0x8 + st1 { v23.d }[0], [x27], #0x8 + tbz x10, #0, label_13 + st1 { v16.s }[2], [x20] + st1 { v17.s }[2], [x24] + st1 { v18.s }[2], [x21] + st1 { v19.s }[2], [x26] + st1 { v20.s }[2], [x22] + st1 { v21.s }[2], [x25] + st1 { v22.s }[2], [x23] + st1 { v23.s }[2], [x27] + b label_13 +KAI_ASM_LABEL(label_12) // Output block 1: partial_1_0 + st1 { v16.s }[0], [x20] + st1 { v17.s }[0], [x24] + st1 { v18.s }[0], [x21] + st1 { v19.s }[0], [x26] + st1 { v20.s }[0], [x22] + st1 { v21.s }[0], [x25] + st1 { v22.s }[0], [x23] + st1 { v23.s }[0], [x27] +KAI_ASM_LABEL(label_13) // Output block 1: Done +KAI_ASM_LABEL(label_14) // Output stage exit + subs x10, x10, #0x4 + add x14, x14, #0x10 + bgt label_2 + mov x20, #0x4 + sub x15, x15, #0x10 + cmp x15, #0x10 + mov x14, x9 + madd x8, x20, x5, x8 + bge label_1 +KAI_ASM_LABEL(label_15) // Row loop skip + cbz x15, label_25 +KAI_ASM_LABEL(label_16) // Row tail: Row loop + mov x26, x17 + mov x25, x16 + add x24, x14, x13, LSL #2 +KAI_ASM_LABEL(label_17) // Row tail: Column loop + movi v16.4s, #0x0 + mov x27, x8 + mov x21, x7 + str q16, [SP, #0x0] + str q16, [SP, #0x10] + str q16, [SP, #0x20] + str q16, [SP, #0x30] +KAI_ASM_LABEL(label_18) // Row tail: Block loop + movi v7.4s, #0x0 + movi v13.4s, #0x0 + mov x20, x6 + movi v29.4s, #0x0 + movi v12.4s, #0x0 +KAI_ASM_LABEL(label_19) // Row tail: Sub block loop + ldr q0, [x26, #0x0] + ldr q31, [x26, #0x10] + movi v30.16b, #0xf0 + subs x20, x20, #0x1 + ldr q18, [x27, #0x0] + ldr q28, [x27, #0x10] + ldr q27, [x26, #0x20] + ldr q26, [x26, #0x30] + add x26, x26, #0x40 + ldr q25, [x27, #0x20] + ldr q24, [x27, #0x30] + shl v23.16b, v0.16b, #0x4 + shl v22.16b, v31.16b, #0x4 + ldr q21, [x27, #0x40] + ldr q20, [x27, #0x50] + and v0.16b, v0.16b, v30.16b + and v31.16b, v31.16b, v30.16b + ldr q19, [x27, #0x60] + ldr q14, [x27, #0x70] + shl v17.16b, v27.16b, #0x4 + shl v16.16b, v26.16b, #0x4 + KAI_ASM_INST(0x4e97a647) // smmla v7.4s, v18.16b, v23.16b + KAI_ASM_INST(0x4e96a64d) // smmla v13.4s, v18.16b, v22.16b + and v27.16b, v27.16b, v30.16b + add x27, x27, #0x80 + KAI_ASM_INST(0x4e97a79d) // smmla v29.4s, v28.16b, v23.16b + KAI_ASM_INST(0x4e96a78c) // smmla v12.4s, v28.16b, v22.16b + and v26.16b, v26.16b, v30.16b + KAI_ASM_INST(0x4e91a727) // smmla v7.4s, v25.16b, v17.16b + KAI_ASM_INST(0x4e90a72d) // smmla v13.4s, v25.16b, v16.16b + KAI_ASM_INST(0x4e91a71d) // smmla v29.4s, v24.16b, v17.16b + KAI_ASM_INST(0x4e90a70c) // smmla v12.4s, v24.16b, v16.16b + KAI_ASM_INST(0x4e80a6a7) // smmla v7.4s, v21.16b, v0.16b + KAI_ASM_INST(0x4e9fa6ad) // smmla v13.4s, v21.16b, v31.16b + KAI_ASM_INST(0x4e80a69d) // smmla v29.4s, v20.16b, v0.16b + 
KAI_ASM_INST(0x4e9fa68c) // smmla v12.4s, v20.16b, v31.16b + KAI_ASM_INST(0x4e9ba667) // smmla v7.4s, v19.16b, v27.16b + KAI_ASM_INST(0x4e9aa66d) // smmla v13.4s, v19.16b, v26.16b + KAI_ASM_INST(0x4e9ba5dd) // smmla v29.4s, v14.16b, v27.16b + KAI_ASM_INST(0x4e9aa5cc) // smmla v12.4s, v14.16b, v26.16b + bgt label_19 + ldr d17, [x26, #0x0] + ldr q21, [SP, #0x0] + uzp1 v16.2d, v7.2d, v13.2d + uzp2 v20.2d, v7.2d, v13.2d + uzp1 v19.2d, v29.2d, v12.2d + uzp2 v18.2d, v29.2d, v12.2d + add x26, x26, #0x8 + shll v17.4s, v17.4h, #0x10 + scvtf v16.4s, v16.4s, #0x4 + scvtf v20.4s, v20.4s, #0x4 + scvtf v19.4s, v19.4s, #0x4 + scvtf v18.4s, v18.4s, #0x4 + fmla v21.4s, v16.4s, v17.4s + str q21, [SP, #0x0] + ldr q16, [SP, #0x10] + fmla v16.4s, v20.4s, v17.4s + str q16, [SP, #0x10] + ldr q16, [SP, #0x20] + fmla v16.4s, v19.4s, v17.4s + str q16, [SP, #0x20] + ldr q16, [SP, #0x30] + fmla v16.4s, v18.4s, v17.4s + str q16, [SP, #0x30] + subs x21, x21, #0x1 + bgt label_18 + ld1 { v21.4s }, [x27] + ldr q31, [SP, #0x0] + add x27, x27, #0x10 + add x20, x12, #0x4 + ldr q30, [SP, #0x10] + ldr q29, [SP, #0x20] + cmp x25, #0x4 + ldr q28, [SP, #0x30] + ldr q20, [x26, #0x0] + ldr q19, [x27, #0x0] + ldr q18, [x26, #0x10] + scvtf v21.4s, v21.4s + add x26, x26, #0x20 + ld1r { v17.4s }, [x12] + ld1r { v16.4s }, [x20] + fmla v31.4s, v20.4s, v21.s[0] + fmla v30.4s, v20.4s, v21.s[1] + fmla v29.4s, v20.4s, v21.s[2] + fmla v28.4s, v20.4s, v21.s[3] + fmul v31.4s, v31.4s, v19.s[0] + fmul v30.4s, v30.4s, v19.s[1] + fadd v31.4s, v31.4s, v18.4s + fmul v29.4s, v29.4s, v19.s[2] + fmul v28.4s, v28.4s, v19.s[3] + fadd v30.4s, v30.4s, v18.4s + fmax v31.4s, v31.4s, v17.4s + fadd v29.4s, v29.4s, v18.4s + fadd v28.4s, v28.4s, v18.4s + fmax v30.4s, v30.4s, v17.4s + fmin v31.4s, v31.4s, v16.4s + fmax v29.4s, v29.4s, v17.4s + fmax v28.4s, v28.4s, v17.4s + fmin v30.4s, v30.4s, v16.4s + fmin v29.4s, v29.4s, v16.4s + fmin v28.4s, v28.4s, v16.4s + blt label_21 + mov x20, x14 + cmp x15, #0x1 + str q31, [x20, #0x0] + add x20, x20, x13 + ble label_24 + cmp x15, #0x2 + str q30, [x20, #0x0] + add x20, x20, x13 + ble label_24 + cmp x15, #0x3 + str q29, [x20, #0x0] + add x20, x20, x13 + ble label_24 + str q28, [x20, #0x0] + b label_24 +KAI_ASM_LABEL(label_21) // Row tail: Partial output + mov x23, x14 + cmp x15, #0x1 + add x22, x23, x13 + csel x22, x22, x23, GT + cmp x15, #0x2 + add x21, x23, x13, LSL #1 + csel x21, x21, x22, GT + cmp x15, #0x3 + add x20, x21, x13 + csel x20, x20, x21, GT + tbz x25, #1, label_22 + st1 { v28.d }[0], [x20], #0x8 + st1 { v29.d }[0], [x21], #0x8 + st1 { v30.d }[0], [x22], #0x8 + st1 { v31.d }[0], [x23], #0x8 + tbz x25, #0, label_23 + st1 { v28.s }[2], [x20] + st1 { v29.s }[2], [x21] + st1 { v30.s }[2], [x22] + st1 { v31.s }[2], [x23] + b label_23 +KAI_ASM_LABEL(label_22) // Row tail: Output block 0: partial_1_0 + st1 { v28.s }[0], [x20] + st1 { v29.s }[0], [x21] + st1 { v30.s }[0], [x22] + st1 { v31.s }[0], [x23] +KAI_ASM_LABEL(label_23) // Row tail: Output block 0: Done +KAI_ASM_LABEL(label_24) // Row tail: Output stage exit + subs x25, x25, #0x4 + add x14, x14, #0x10 + bgt label_17 + subs x15, x15, #0x4 + add x8, x8, x5 + mov x14, x24 + bgt label_16 +KAI_ASM_LABEL(label_25) // Row tail: Row loop skip + add SP, SP, #0x100 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d10, d11, [sp, 72] + ldp d12, d13, [sp, 88] + ldp d14, d15, [sp, 104] + ldp d8, d9, [sp, 120] + ldp x20, x21, [sp], 144 + ret + 
KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S new file mode 100644 index 00000000..b2194615 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S @@ -0,0 +1,670 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm) + stp x20, x21, [sp, -144]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x6, #0x80 + mov x21, #0x20 + ldr x20, [x0, #0x28] + ldr x7, [x0, #0x38] + ldr x8, [x0, #0x8] + ldr x17, [x0, #0x10] + ldr x16, [x0, #0x30] + ldr x15, [x0, #0x0] + mov x14, x20 + ldr x13, [x0, #0x20] + madd x6, x7, x6, x21 + ldr x12, [x0, #0x18] + cmp x14, #0x10 + blt label_14 +KAI_ASM_LABEL(label_1) // Row loop + mov x11, x17 + mov x10, x16 + add x9, x15, x13, LSL #4 +KAI_ASM_LABEL(label_2) // Column loop + mov x27, x8 + movi v31.16b, #0x0 + movi v30.16b, #0x0 + mov x20, x7 + movi v29.16b, #0x0 + movi v28.16b, #0x0 + movi v27.16b, #0x0 + movi v26.16b, #0x0 + add x23, x27, x6 + add x22, x23, x6 + movi v25.16b, #0x0 + movi v24.16b, #0x0 + add x21, x22, x6 + movi v23.16b, #0x0 + movi v22.16b, #0x0 + movi v21.16b, #0x0 + movi v20.16b, #0x0 + movi v19.16b, #0x0 + movi v18.16b, #0x0 + movi v17.16b, #0x0 + movi v16.16b, #0x0 +KAI_ASM_LABEL(label_3) // Block loop + ldr q11, [x11, #0x0] + ldr q4, [x11, #0x10] + movi v2.4s, #0x0 + movi v9.4s, #0x0 + ldr q12, [x27, #0x0] + ldr q0, [x27, #0x10] + movi v7.4s, #0x0 + movi v5.4s, #0x0 + ldr q15, [x11, #0x20] + ldr q13, [x11, #0x30] + movi v10.16b, #0xf0 + add x11, x11, #0x40 + ldr q8, [x27, #0x20] + ldr q6, [x27, #0x30] + shl v14.16b, v11.16b, #0x4 + shl v3.16b, v4.16b, #0x4 + ldr q1, [x27, #0x40] + and v11.16b, v11.16b, v10.16b + and v4.16b, v4.16b, v10.16b + KAI_ASM_INST(0x4e8ea582) // smmla v2.4s, v12.16b, v14.16b + KAI_ASM_INST(0x4e83a589) // smmla v9.4s, v12.16b, v3.16b + shl v12.16b, v15.16b, #0x4 + KAI_ASM_INST(0x4e8ea407) // smmla v7.4s, v0.16b, v14.16b + KAI_ASM_INST(0x4e83a405) // smmla v5.4s, v0.16b, v3.16b + shl v0.16b, v13.16b, #0x4 + and v15.16b, v15.16b, v10.16b + and v13.16b, v13.16b, v10.16b + ldr q10, [x27, #0x50] + KAI_ASM_INST(0x4e8ca502) // smmla v2.4s, v8.16b, v12.16b + KAI_ASM_INST(0x4e80a509) // smmla v9.4s, v8.16b, v0.16b + ldr q8, [x27, #0x60] + KAI_ASM_INST(0x4e8ca4c7) // smmla v7.4s, v6.16b, v12.16b + KAI_ASM_INST(0x4e80a4c5) // smmla v5.4s, v6.16b, v0.16b + ldr q6, [x27, #0x70] + add x27, x27, #0x80 + KAI_ASM_INST(0x4e8ba422) // smmla v2.4s, v1.16b, v11.16b + KAI_ASM_INST(0x4e84a429) // smmla v9.4s, v1.16b, v4.16b + ldr d1, [x11, #0x0] + add x11, x11, #0x8 + KAI_ASM_INST(0x4e8ba547) // smmla v7.4s, v10.16b, v11.16b + KAI_ASM_INST(0x4e84a545) // smmla v5.4s, v10.16b, v4.16b + KAI_ASM_INST(0x4e8fa502) // smmla v2.4s, v8.16b, v15.16b + shll v1.4s, v1.4h, #0x10 + KAI_ASM_INST(0x4e8da509) // smmla v9.4s, v8.16b, v13.16b + KAI_ASM_INST(0x4e8fa4c7) // smmla v7.4s, v6.16b, v15.16b + KAI_ASM_INST(0x4e8da4c5) // smmla v5.4s, v6.16b, v13.16b + uzp1 v6.2d, v2.2d, v9.2d + uzp2 v8.2d, v2.2d, v9.2d + scvtf v6.4s, v6.4s, #0x4 + uzp1 v9.2d, v7.2d, v5.2d + uzp2 v2.2d, v7.2d, v5.2d + scvtf v8.4s, v8.4s, #0x4 + fmla v31.4s, v6.4s, v1.4s + scvtf v9.4s, v9.4s, #0x4 + scvtf v2.4s, v2.4s, #0x4 + fmla v30.4s, v8.4s, v1.4s + fmla v29.4s, v9.4s, v1.4s + fmla v28.4s, v2.4s, v1.4s + ldr q9, [x23, #0x0] + ldr q7, [x23, #0x10] + movi v8.4s, #0x0 + movi v2.4s, #0x0 + ldr q5, [x23, #0x20] + ldr q10, [x23, #0x30] + movi v6.4s, #0x0 + KAI_ASM_INST(0x4e8ea528) // smmla v8.4s, v9.16b, v14.16b + KAI_ASM_INST(0x4e83a522) // smmla v2.4s, v9.16b, v3.16b + ldr q9, [x23, #0x40] + KAI_ASM_INST(0x4e8ea4e6) // smmla v6.4s, v7.16b, v14.16b + KAI_ASM_INST(0x4e8ca4a8) // smmla v8.4s, v5.16b, v12.16b + KAI_ASM_INST(0x4e80a4a2) // smmla v2.4s, v5.16b, v0.16b + ldr q5, [x23, 
#0x50] + KAI_ASM_INST(0x4e8ca546) // smmla v6.4s, v10.16b, v12.16b + KAI_ASM_INST(0x4e8ba528) // smmla v8.4s, v9.16b, v11.16b + KAI_ASM_INST(0x4e84a522) // smmla v2.4s, v9.16b, v4.16b + ldr q9, [x23, #0x60] + KAI_ASM_INST(0x4e8ba4a6) // smmla v6.4s, v5.16b, v11.16b + KAI_ASM_INST(0x4e8fa528) // smmla v8.4s, v9.16b, v15.16b + KAI_ASM_INST(0x4e8da522) // smmla v2.4s, v9.16b, v13.16b + movi v9.4s, #0x0 + KAI_ASM_INST(0x4e83a4e9) // smmla v9.4s, v7.16b, v3.16b + ldr q7, [x23, #0x70] + add x23, x23, #0x80 + KAI_ASM_INST(0x4e8fa4e6) // smmla v6.4s, v7.16b, v15.16b + KAI_ASM_INST(0x4e80a549) // smmla v9.4s, v10.16b, v0.16b + uzp1 v10.2d, v8.2d, v2.2d + uzp2 v2.2d, v8.2d, v2.2d + scvtf v10.4s, v10.4s, #0x4 + KAI_ASM_INST(0x4e84a4a9) // smmla v9.4s, v5.16b, v4.16b + scvtf v2.4s, v2.4s, #0x4 + fmla v27.4s, v10.4s, v1.4s + KAI_ASM_INST(0x4e8da4e9) // smmla v9.4s, v7.16b, v13.16b + fmla v26.4s, v2.4s, v1.4s + uzp1 v2.2d, v6.2d, v9.2d + uzp2 v10.2d, v6.2d, v9.2d + scvtf v2.4s, v2.4s, #0x4 + scvtf v10.4s, v10.4s, #0x4 + fmla v25.4s, v2.4s, v1.4s + fmla v24.4s, v10.4s, v1.4s + ldr q8, [x22, #0x0] + ldr q7, [x22, #0x10] + movi v9.4s, #0x0 + movi v6.4s, #0x0 + ldr q2, [x22, #0x20] + ldr q5, [x22, #0x30] + movi v10.4s, #0x0 + KAI_ASM_INST(0x4e8ea509) // smmla v9.4s, v8.16b, v14.16b + KAI_ASM_INST(0x4e83a506) // smmla v6.4s, v8.16b, v3.16b + ldr q8, [x22, #0x40] + KAI_ASM_INST(0x4e8ea4ea) // smmla v10.4s, v7.16b, v14.16b + KAI_ASM_INST(0x4e8ca449) // smmla v9.4s, v2.16b, v12.16b + KAI_ASM_INST(0x4e80a446) // smmla v6.4s, v2.16b, v0.16b + ldr q2, [x22, #0x50] + KAI_ASM_INST(0x4e8ca4aa) // smmla v10.4s, v5.16b, v12.16b + KAI_ASM_INST(0x4e8ba509) // smmla v9.4s, v8.16b, v11.16b + KAI_ASM_INST(0x4e84a506) // smmla v6.4s, v8.16b, v4.16b + ldr q8, [x22, #0x60] + KAI_ASM_INST(0x4e8ba44a) // smmla v10.4s, v2.16b, v11.16b + KAI_ASM_INST(0x4e8fa509) // smmla v9.4s, v8.16b, v15.16b + KAI_ASM_INST(0x4e8da506) // smmla v6.4s, v8.16b, v13.16b + movi v8.4s, #0x0 + KAI_ASM_INST(0x4e83a4e8) // smmla v8.4s, v7.16b, v3.16b + ldr q7, [x22, #0x70] + add x22, x22, #0x80 + KAI_ASM_INST(0x4e8fa4ea) // smmla v10.4s, v7.16b, v15.16b + KAI_ASM_INST(0x4e80a4a8) // smmla v8.4s, v5.16b, v0.16b + uzp1 v5.2d, v9.2d, v6.2d + uzp2 v9.2d, v9.2d, v6.2d + scvtf v5.4s, v5.4s, #0x4 + KAI_ASM_INST(0x4e84a448) // smmla v8.4s, v2.16b, v4.16b + scvtf v9.4s, v9.4s, #0x4 + fmla v23.4s, v5.4s, v1.4s + KAI_ASM_INST(0x4e8da4e8) // smmla v8.4s, v7.16b, v13.16b + fmla v22.4s, v9.4s, v1.4s + uzp1 v2.2d, v10.2d, v8.2d + uzp2 v10.2d, v10.2d, v8.2d + scvtf v2.4s, v2.4s, #0x4 + scvtf v10.4s, v10.4s, #0x4 + fmla v21.4s, v2.4s, v1.4s + fmla v20.4s, v10.4s, v1.4s + ldr q2, [x21, #0x0] + ldr q10, [x21, #0x10] + movi v6.4s, #0x0 + movi v9.4s, #0x0 + ldr q5, [x21, #0x20] + ldr q8, [x21, #0x30] + movi v7.4s, #0x0 + KAI_ASM_INST(0x4e8ea446) // smmla v6.4s, v2.16b, v14.16b + KAI_ASM_INST(0x4e83a449) // smmla v9.4s, v2.16b, v3.16b + ldr q2, [x21, #0x40] + KAI_ASM_INST(0x4e8ea547) // smmla v7.4s, v10.16b, v14.16b + ldr q14, [x21, #0x50] + KAI_ASM_INST(0x4e8ca4a6) // smmla v6.4s, v5.16b, v12.16b + KAI_ASM_INST(0x4e80a4a9) // smmla v9.4s, v5.16b, v0.16b + ldr q5, [x21, #0x60] + KAI_ASM_INST(0x4e8ca507) // smmla v7.4s, v8.16b, v12.16b + ldr q12, [x21, #0x70] + add x21, x21, #0x80 + KAI_ASM_INST(0x4e8ba446) // smmla v6.4s, v2.16b, v11.16b + KAI_ASM_INST(0x4e84a449) // smmla v9.4s, v2.16b, v4.16b + movi v2.4s, #0x0 + KAI_ASM_INST(0x4e83a542) // smmla v2.4s, v10.16b, v3.16b + KAI_ASM_INST(0x4e8ba5c7) // smmla v7.4s, v14.16b, v11.16b + KAI_ASM_INST(0x4e8fa4a6) // smmla v6.4s, 
v5.16b, v15.16b + KAI_ASM_INST(0x4e80a502) // smmla v2.4s, v8.16b, v0.16b + KAI_ASM_INST(0x4e8da4a9) // smmla v9.4s, v5.16b, v13.16b + KAI_ASM_INST(0x4e8fa587) // smmla v7.4s, v12.16b, v15.16b + KAI_ASM_INST(0x4e84a5c2) // smmla v2.4s, v14.16b, v4.16b + uzp1 v11.2d, v6.2d, v9.2d + uzp2 v14.2d, v6.2d, v9.2d + scvtf v11.4s, v11.4s, #0x4 + KAI_ASM_INST(0x4e8da582) // smmla v2.4s, v12.16b, v13.16b + scvtf v14.4s, v14.4s, #0x4 + fmla v19.4s, v11.4s, v1.4s + uzp1 v9.2d, v7.2d, v2.2d + uzp2 v0.2d, v7.2d, v2.2d + fmla v18.4s, v14.4s, v1.4s + scvtf v9.4s, v9.4s, #0x4 + scvtf v0.4s, v0.4s, #0x4 + fmla v17.4s, v9.4s, v1.4s + fmla v16.4s, v0.4s, v1.4s + subs x20, x20, #0x1 + bgt label_3 + ld1 { v11.4s }, [x27] + ld1 { v10.4s }, [x23] + add x27, x27, #0x10 + add x23, x23, #0x10 + ld1 { v9.4s }, [x22] + ld1 { v8.4s }, [x21] + add x22, x22, #0x10 + add x21, x21, #0x10 + ldr q7, [x11, #0x0] + ldr q6, [x27, #0x0] + add x20, x12, #0x4 + cmp x10, #0x4 + ldr q5, [x23, #0x0] + ldr q4, [x22, #0x0] + scvtf v11.4s, v11.4s + scvtf v10.4s, v10.4s + ldr q3, [x21, #0x0] + ldr q2, [x11, #0x10] + scvtf v9.4s, v9.4s + scvtf v8.4s, v8.4s + ld1r { v1.4s }, [x12] + ld1r { v0.4s }, [x20] + add x11, x11, #0x20 + fmla v31.4s, v7.4s, v11.s[0] + fmla v30.4s, v7.4s, v11.s[1] + fmla v29.4s, v7.4s, v11.s[2] + fmla v28.4s, v7.4s, v11.s[3] + fmla v27.4s, v7.4s, v10.s[0] + fmla v26.4s, v7.4s, v10.s[1] + fmla v25.4s, v7.4s, v10.s[2] + fmla v24.4s, v7.4s, v10.s[3] + fmla v23.4s, v7.4s, v9.s[0] + fmul v31.4s, v31.4s, v6.s[0] + fmla v22.4s, v7.4s, v9.s[1] + fmla v21.4s, v7.4s, v9.s[2] + fmul v30.4s, v30.4s, v6.s[1] + fmla v20.4s, v7.4s, v9.s[3] + fmla v19.4s, v7.4s, v8.s[0] + fmul v29.4s, v29.4s, v6.s[2] + fmla v18.4s, v7.4s, v8.s[1] + fmla v17.4s, v7.4s, v8.s[2] + fmul v28.4s, v28.4s, v6.s[3] + fmla v16.4s, v7.4s, v8.s[3] + fmul v27.4s, v27.4s, v5.s[0] + fmul v26.4s, v26.4s, v5.s[1] + fmul v25.4s, v25.4s, v5.s[2] + fmul v24.4s, v24.4s, v5.s[3] + fmul v23.4s, v23.4s, v4.s[0] + fmul v22.4s, v22.4s, v4.s[1] + fmul v21.4s, v21.4s, v4.s[2] + fmul v20.4s, v20.4s, v4.s[3] + fmul v19.4s, v19.4s, v3.s[0] + fmul v18.4s, v18.4s, v3.s[1] + fmul v17.4s, v17.4s, v3.s[2] + fmul v16.4s, v16.4s, v3.s[3] + fadd v31.4s, v31.4s, v2.4s + fadd v30.4s, v30.4s, v2.4s + fadd v29.4s, v29.4s, v2.4s + fadd v28.4s, v28.4s, v2.4s + fadd v27.4s, v27.4s, v2.4s + fadd v26.4s, v26.4s, v2.4s + fadd v25.4s, v25.4s, v2.4s + fadd v24.4s, v24.4s, v2.4s + fadd v23.4s, v23.4s, v2.4s + fadd v22.4s, v22.4s, v2.4s + fadd v21.4s, v21.4s, v2.4s + fadd v20.4s, v20.4s, v2.4s + fadd v19.4s, v19.4s, v2.4s + fadd v18.4s, v18.4s, v2.4s + fadd v17.4s, v17.4s, v2.4s + fadd v16.4s, v16.4s, v2.4s + fmax v31.4s, v31.4s, v1.4s + fmax v30.4s, v30.4s, v1.4s + fmax v29.4s, v29.4s, v1.4s + fmax v28.4s, v28.4s, v1.4s + fmax v27.4s, v27.4s, v1.4s + fmax v26.4s, v26.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + fmax v24.4s, v24.4s, v1.4s + fmax v23.4s, v23.4s, v1.4s + fmax v22.4s, v22.4s, v1.4s + fmax v21.4s, v21.4s, v1.4s + fmax v20.4s, v20.4s, v1.4s + fmax v19.4s, v19.4s, v1.4s + fmax v18.4s, v18.4s, v1.4s + fmax v17.4s, v17.4s, v1.4s + fmax v16.4s, v16.4s, v1.4s + fmin v31.4s, v31.4s, v0.4s + fmin v30.4s, v30.4s, v0.4s + fmin v29.4s, v29.4s, v0.4s + fmin v28.4s, v28.4s, v0.4s + fmin v27.4s, v27.4s, v0.4s + fmin v26.4s, v26.4s, v0.4s + fmin v25.4s, v25.4s, v0.4s + fmin v24.4s, v24.4s, v0.4s + fmin v23.4s, v23.4s, v0.4s + fmin v22.4s, v22.4s, v0.4s + fmin v21.4s, v21.4s, v0.4s + fmin v20.4s, v20.4s, v0.4s + fmin v19.4s, v19.4s, v0.4s + fmin v18.4s, v18.4s, v0.4s + fmin v17.4s, v17.4s, v0.4s + fmin 
v16.4s, v16.4s, v0.4s + blt label_8 + mov x20, x15 + str q31, [x20, #0x0] + add x20, x20, x13 + str q30, [x20, #0x0] + add x20, x20, x13 + str q29, [x20, #0x0] + add x20, x20, x13 + str q28, [x20, #0x0] + add x20, x20, x13 + str q27, [x20, #0x0] + add x20, x20, x13 + str q26, [x20, #0x0] + add x20, x20, x13 + str q25, [x20, #0x0] + add x20, x20, x13 + str q24, [x20, #0x0] + add x20, x20, x13 + str q23, [x20, #0x0] + add x20, x20, x13 + str q22, [x20, #0x0] + add x20, x20, x13 + str q21, [x20, #0x0] + add x20, x20, x13 + str q20, [x20, #0x0] + add x20, x20, x13 + str q19, [x20, #0x0] + add x20, x20, x13 + str q18, [x20, #0x0] + add x20, x20, x13 + str q17, [x20, #0x0] + add x20, x20, x13 + str q16, [x20, #0x0] + b label_13 +KAI_ASM_LABEL(label_8) // Partial output + mov x28, x15 + add x26, x28, x13, LSL #2 + add x25, x26, x13, LSL #1 + add x24, x26, x13 + add x23, x25, x13 + add x22, x28, x13, LSL #1 + add x21, x28, x13 + add x20, x22, x13 + add x27, x23, x13 + tbz x10, #1, label_9 + st1 { v24.d }[0], [x23], #0x8 + st1 { v25.d }[0], [x25], #0x8 + st1 { v26.d }[0], [x24], #0x8 + st1 { v27.d }[0], [x26], #0x8 + st1 { v28.d }[0], [x20], #0x8 + st1 { v29.d }[0], [x22], #0x8 + st1 { v30.d }[0], [x21], #0x8 + st1 { v31.d }[0], [x28], #0x8 + tbz x10, #0, label_10 + st1 { v24.s }[2], [x23] + st1 { v25.s }[2], [x25] + st1 { v26.s }[2], [x24] + st1 { v27.s }[2], [x26] + st1 { v28.s }[2], [x20] + st1 { v29.s }[2], [x22] + st1 { v30.s }[2], [x21] + st1 { v31.s }[2], [x28] + b label_10 +KAI_ASM_LABEL(label_9) // Output block 0: partial_1_0 + st1 { v24.s }[0], [x23] + st1 { v25.s }[0], [x25] + st1 { v26.s }[0], [x24] + st1 { v27.s }[0], [x26] + st1 { v28.s }[0], [x20] + st1 { v29.s }[0], [x22] + st1 { v30.s }[0], [x21] + st1 { v31.s }[0], [x28] +KAI_ASM_LABEL(label_10) // Output block 0: Done + add x26, x27, x13, LSL #2 + add x25, x27, x13, LSL #1 + add x24, x26, x13, LSL #1 + add x23, x27, x13 + add x22, x25, x13 + add x21, x26, x13 + add x20, x24, x13 + tbz x10, #1, label_11 + st1 { v16.d }[0], [x20], #0x8 + st1 { v17.d }[0], [x24], #0x8 + st1 { v18.d }[0], [x21], #0x8 + st1 { v19.d }[0], [x26], #0x8 + st1 { v20.d }[0], [x22], #0x8 + st1 { v21.d }[0], [x25], #0x8 + st1 { v22.d }[0], [x23], #0x8 + st1 { v23.d }[0], [x27], #0x8 + tbz x10, #0, label_12 + st1 { v16.s }[2], [x20] + st1 { v17.s }[2], [x24] + st1 { v18.s }[2], [x21] + st1 { v19.s }[2], [x26] + st1 { v20.s }[2], [x22] + st1 { v21.s }[2], [x25] + st1 { v22.s }[2], [x23] + st1 { v23.s }[2], [x27] + b label_12 +KAI_ASM_LABEL(label_11) // Output block 1: partial_1_0 + st1 { v16.s }[0], [x20] + st1 { v17.s }[0], [x24] + st1 { v18.s }[0], [x21] + st1 { v19.s }[0], [x26] + st1 { v20.s }[0], [x22] + st1 { v21.s }[0], [x25] + st1 { v22.s }[0], [x23] + st1 { v23.s }[0], [x27] +KAI_ASM_LABEL(label_12) // Output block 1: Done +KAI_ASM_LABEL(label_13) // Output stage exit + subs x10, x10, #0x4 + add x15, x15, #0x10 + bgt label_2 + mov x20, #0x4 + sub x14, x14, #0x10 + cmp x14, #0x10 + mov x15, x9 + madd x8, x20, x6, x8 + bge label_1 +KAI_ASM_LABEL(label_14) // Row loop skip + cbz x14, label_23 +KAI_ASM_LABEL(label_15) // Row tail: Row loop + mov x26, x17 + mov x25, x16 + add x24, x15, x13, LSL #2 +KAI_ASM_LABEL(label_16) // Row tail: Column loop + movi v31.16b, #0x0 + movi v30.16b, #0x0 + mov x27, x8 + mov x20, x7 + movi v29.16b, #0x0 + movi v28.16b, #0x0 +KAI_ASM_LABEL(label_17) // Row tail: Block loop + ldr q9, [x26, #0x0] + ldr q8, [x26, #0x10] + movi v7.4s, #0x0 + movi v6.4s, #0x0 + ldr q5, [x27, #0x0] + ldr q4, [x27, #0x10] + movi v3.4s, #0x0 + movi 
v2.4s, #0x0 + ldr q1, [x26, #0x20] + ldr q0, [x26, #0x30] + movi v27.16b, #0xf0 + add x26, x26, #0x40 + ldr q26, [x27, #0x20] + ldr q25, [x27, #0x30] + shl v24.16b, v9.16b, #0x4 + shl v20.16b, v8.16b, #0x4 + ldr q23, [x27, #0x40] + ldr q22, [x27, #0x50] + and v9.16b, v9.16b, v27.16b + and v8.16b, v8.16b, v27.16b + ldr q21, [x27, #0x60] + ldr q19, [x27, #0x70] + shl v18.16b, v1.16b, #0x4 + shl v17.16b, v0.16b, #0x4 + ldr d16, [x26, #0x0] + KAI_ASM_INST(0x4e98a4a7) // smmla v7.4s, v5.16b, v24.16b + KAI_ASM_INST(0x4e94a4a6) // smmla v6.4s, v5.16b, v20.16b + and v1.16b, v1.16b, v27.16b + KAI_ASM_INST(0x4e98a483) // smmla v3.4s, v4.16b, v24.16b + KAI_ASM_INST(0x4e94a482) // smmla v2.4s, v4.16b, v20.16b + and v0.16b, v0.16b, v27.16b + add x26, x26, #0x8 + add x27, x27, #0x80 + shll v20.4s, v16.4h, #0x10 + KAI_ASM_INST(0x4e92a747) // smmla v7.4s, v26.16b, v18.16b + KAI_ASM_INST(0x4e91a746) // smmla v6.4s, v26.16b, v17.16b + KAI_ASM_INST(0x4e92a723) // smmla v3.4s, v25.16b, v18.16b + KAI_ASM_INST(0x4e91a722) // smmla v2.4s, v25.16b, v17.16b + KAI_ASM_INST(0x4e89a6e7) // smmla v7.4s, v23.16b, v9.16b + KAI_ASM_INST(0x4e88a6e6) // smmla v6.4s, v23.16b, v8.16b + KAI_ASM_INST(0x4e89a6c3) // smmla v3.4s, v22.16b, v9.16b + KAI_ASM_INST(0x4e88a6c2) // smmla v2.4s, v22.16b, v8.16b + KAI_ASM_INST(0x4e81a6a7) // smmla v7.4s, v21.16b, v1.16b + KAI_ASM_INST(0x4e80a6a6) // smmla v6.4s, v21.16b, v0.16b + KAI_ASM_INST(0x4e81a663) // smmla v3.4s, v19.16b, v1.16b + KAI_ASM_INST(0x4e80a662) // smmla v2.4s, v19.16b, v0.16b + uzp1 v19.2d, v7.2d, v6.2d + uzp2 v18.2d, v7.2d, v6.2d + scvtf v19.4s, v19.4s, #0x4 + uzp1 v17.2d, v3.2d, v2.2d + uzp2 v16.2d, v3.2d, v2.2d + scvtf v18.4s, v18.4s, #0x4 + fmla v31.4s, v19.4s, v20.4s + scvtf v17.4s, v17.4s, #0x4 + scvtf v16.4s, v16.4s, #0x4 + fmla v30.4s, v18.4s, v20.4s + fmla v29.4s, v17.4s, v20.4s + fmla v28.4s, v16.4s, v20.4s + subs x20, x20, #0x1 + bgt label_17 + ld1 { v21.4s }, [x27] + ldr q20, [x26, #0x0] + add x27, x27, #0x10 + add x20, x12, #0x4 + ldr q19, [x27, #0x0] + ldr q18, [x26, #0x10] + cmp x25, #0x4 + add x26, x26, #0x20 + ld1r { v17.4s }, [x12] + ld1r { v16.4s }, [x20] + scvtf v21.4s, v21.4s + fmla v31.4s, v20.4s, v21.s[0] + fmla v30.4s, v20.4s, v21.s[1] + fmla v29.4s, v20.4s, v21.s[2] + fmla v28.4s, v20.4s, v21.s[3] + fmul v31.4s, v31.4s, v19.s[0] + fmul v30.4s, v30.4s, v19.s[1] + fmul v29.4s, v29.4s, v19.s[2] + fadd v31.4s, v31.4s, v18.4s + fmul v28.4s, v28.4s, v19.s[3] + fadd v30.4s, v30.4s, v18.4s + fadd v29.4s, v29.4s, v18.4s + fadd v28.4s, v28.4s, v18.4s + fmax v31.4s, v31.4s, v17.4s + fmax v30.4s, v30.4s, v17.4s + fmax v29.4s, v29.4s, v17.4s + fmax v28.4s, v28.4s, v17.4s + fmin v31.4s, v31.4s, v16.4s + fmin v30.4s, v30.4s, v16.4s + fmin v29.4s, v29.4s, v16.4s + fmin v28.4s, v28.4s, v16.4s + blt label_19 + mov x20, x15 + cmp x14, #0x1 + str q31, [x20, #0x0] + add x20, x20, x13 + ble label_22 + cmp x14, #0x2 + str q30, [x20, #0x0] + add x20, x20, x13 + ble label_22 + cmp x14, #0x3 + str q29, [x20, #0x0] + add x20, x20, x13 + ble label_22 + str q28, [x20, #0x0] + b label_22 +KAI_ASM_LABEL(label_19) // Row tail: Partial output + mov x23, x15 + cmp x14, #0x1 + add x22, x23, x13 + csel x22, x22, x23, GT + cmp x14, #0x2 + add x21, x23, x13, LSL #1 + csel x21, x21, x22, GT + cmp x14, #0x3 + add x20, x21, x13 + csel x20, x20, x21, GT + tbz x25, #1, label_20 + st1 { v28.d }[0], [x20], #0x8 + st1 { v29.d }[0], [x21], #0x8 + st1 { v30.d }[0], [x22], #0x8 + st1 { v31.d }[0], [x23], #0x8 + tbz x25, #0, label_21 + st1 { v28.s }[2], [x20] + st1 { v29.s }[2], [x21] + st1 { 
v30.s }[2], [x22] + st1 { v31.s }[2], [x23] + b label_21 +KAI_ASM_LABEL(label_20) // Row tail: Output block 0: partial_1_0 + st1 { v28.s }[0], [x20] + st1 { v29.s }[0], [x21] + st1 { v30.s }[0], [x22] + st1 { v31.s }[0], [x23] +KAI_ASM_LABEL(label_21) // Row tail: Output block 0: Done +KAI_ASM_LABEL(label_22) // Row tail: Output stage exit + subs x25, x25, #0x4 + add x15, x15, #0x10 + bgt label_16 + subs x14, x14, #0x4 + add x8, x8, x6 + mov x15, x24 + bgt label_15 +KAI_ASM_LABEL(label_23) // Row tail: Row loop skip + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d10, d11, [sp, 72] + ldp d12, d13, [sp, 88] + ldp d14, d15, [sp, 104] + ldp d8, d9, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c index 0342801e..b3b153f8 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c @@ -3,62 +3,95 @@ // // SPDX-License-Identifier: Apache-2.0 // - -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__ARM_FEATURE_MATMUL_INT8) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) #error "I8mm extension required to compile this micro-kernel" -#else +#else // Architectural features check. + #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h" -#include #include #include #include "kai/kai_common.h" +typedef struct { + float* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; + size_t num_subblocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(KernelArgs* args_ptr); + +// Compute args static const size_t kai_m_step = 16; static const size_t kai_n_step = 4; +// Packing args static const size_t kai_mr = 4; static const size_t kai_nr = 4; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_bl_multiple_of = 32; -static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); -static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); -static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); -static const size_t kai_num_bytes_sum_rhs = sizeof(float); -static const size_t kai_num_bytes_bias = sizeof(float); - -inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - return kai_roundup(k, bl) / bl; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 2; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 4; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; +static const size_t kai_bl = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); } -inline static size_t kai_k_roundedup(size_t k) { - // Since we pack a float and int32 value at the end of the row, - // we must make sure that k is a multiple of 4 for alignment - size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); - return kai_roundup(k, kr_sr_roundedup4); +inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + size_t num_bytes_per_block_rhs = (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs; + return num_bytes_per_block_rhs; } -inline static size_t kai_lhs_packed_stride(size_t k) { - const size_t k_internal = kai_k_roundedup(k); +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + return kai_roundup(k, bl) / bl; +} - KAI_ASSERT((k_internal % 2) == 0); +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include + // the number of bytes to hold the zero point value + lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs; - return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); + return lhs_packed_stride; } -inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { - KAI_ASSERT((bl % kai_kr) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); +inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); - const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; + size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row); + // Since the LHS matrix is asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted by the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; - return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); + return rhs_packed_stride; } size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void) { @@ -86,1331 +119,67 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void } size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(size_t m_idx, size_t k) { - KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSUME((m_idx % kai_m_step) == 0); - return (m_idx / kai_mr) * kai_lhs_packed_stride(k); + return (m_idx / kai_mr) * 
kai_get_lhs_packed_stride(k); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( size_t n_idx, size_t k, size_t bl) { - KAI_ASSERT((n_idx % kai_nr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); - return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl); } size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSERT((m_idx % kai_m_step) == 0); - KAI_ASSERT((n_idx % kai_n_step) == 0); + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); - return (n_idx * sizeof(float)) + m_idx * dst_stride; + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(size_t m, size_t n) { - return m * n * sizeof(float); + return m * n * kai_num_bytes_dst_value; } void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( - size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, - float* dst, // NOLINT(readability-non-const-parameter) - size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { - KAI_ASSERT((k % bl) == 0); - KAI_ASSERT((bl % kai_kr) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(dst_stride_col == sizeof(float)); + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); if (m == 0) { return; } - - size_t num_subblocks = bl / kai_bl_multiple_of; - size_t num_blocks = kai_num_blocks_per_row(k, bl); - - float clamp_vals[2] = {scalar_min, scalar_max}; - - if (bl == 32) { - __asm__ __volatile__( - "mov x13, %x[m]\n" - "mov x12, #0x80\n" - "mov x20, #0x20\n" - "cmp x13, #0x10\n" - "madd x12, %x[num_blocks], x12, x20\n" - "blt 14f\n" - "1:" // Row loop - "mov x11, %x[rhs_packed]\n" - "mov x10, %x[n]\n" - "add x9, %x[dst], %x[dst_stride_row], LSL #4\n" - "2:" // Column loop - "mov x27, %x[lhs_packed]\n" - "movi v31.16b, #0x0\n" - "movi v30.16b, #0x0\n" - "mov x20, %x[num_blocks]\n" - "movi v29.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "movi v27.16b, #0x0\n" - "movi v26.16b, #0x0\n" - "add x23, x27, x12\n" - "add x22, x23, x12\n" - "movi v25.16b, #0x0\n" - "movi v24.16b, #0x0\n" - "add x21, x22, x12\n" - "movi v23.16b, #0x0\n" - "movi v22.16b, #0x0\n" - "movi v21.16b, #0x0\n" - "movi v20.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "movi v18.16b, #0x0\n" - "movi v17.16b, #0x0\n" - "movi v16.16b, #0x0\n" - "3:" // Block loop - "ldr q11, [x11, #0x0]\n" - "ldr q4, [x11, #0x10]\n" - "movi v2.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "ldr q12, [x27, #0x0]\n" - "ldr q0, [x27, #0x10]\n" - "movi v7.4s, #0x0\n" - "movi v5.4s, #0x0\n" - "ldr q15, [x11, #0x20]\n" - "ldr q13, [x11, #0x30]\n" - "movi v10.16b, #0xf0\n" - "add x11, x11, #0x40\n" - "ldr q8, [x27, #0x20]\n" - "ldr q6, [x27, #0x30]\n" - "shl v14.16b, v11.16b, #0x4\n" - "shl v3.16b, v4.16b, #0x4\n" - "ldr q1, [x27, #0x40]\n" - "and v11.16b, v11.16b, v10.16b\n" - "and v4.16b, v4.16b, v10.16b\n" - ".inst 0x4e8ea582 // smmla v2.4s, v12.16b, v14.16b\n" - ".inst 0x4e83a589 // smmla v9.4s, v12.16b, 
v3.16b\n" - "shl v12.16b, v15.16b, #0x4\n" - ".inst 0x4e8ea407 // smmla v7.4s, v0.16b, v14.16b\n" - ".inst 0x4e83a405 // smmla v5.4s, v0.16b, v3.16b\n" - "shl v0.16b, v13.16b, #0x4\n" - "and v15.16b, v15.16b, v10.16b\n" - "and v13.16b, v13.16b, v10.16b\n" - "ldr q10, [x27, #0x50]\n" - ".inst 0x4e8ca502 // smmla v2.4s, v8.16b, v12.16b\n" - ".inst 0x4e80a509 // smmla v9.4s, v8.16b, v0.16b\n" - "ldr q8, [x27, #0x60]\n" - ".inst 0x4e8ca4c7 // smmla v7.4s, v6.16b, v12.16b\n" - ".inst 0x4e80a4c5 // smmla v5.4s, v6.16b, v0.16b\n" - "ldr q6, [x27, #0x70]\n" - "add x27, x27, #0x80\n" - ".inst 0x4e8ba422 // smmla v2.4s, v1.16b, v11.16b\n" - ".inst 0x4e84a429 // smmla v9.4s, v1.16b, v4.16b\n" - "ldr d1, [x11, #0x0]\n" - "add x11, x11, #0x8\n" - ".inst 0x4e8ba547 // smmla v7.4s, v10.16b, v11.16b\n" - ".inst 0x4e84a545 // smmla v5.4s, v10.16b, v4.16b\n" - ".inst 0x4e8fa502 // smmla v2.4s, v8.16b, v15.16b\n" - "shll v1.4s, v1.4h, #0x10\n" - ".inst 0x4e8da509 // smmla v9.4s, v8.16b, v13.16b\n" - ".inst 0x4e8fa4c7 // smmla v7.4s, v6.16b, v15.16b\n" - ".inst 0x4e8da4c5 // smmla v5.4s, v6.16b, v13.16b\n" - "uzp1 v6.2d, v2.2d, v9.2d\n" - "uzp2 v8.2d, v2.2d, v9.2d\n" - "scvtf v6.4s, v6.4s, #0x4\n" - "uzp1 v9.2d, v7.2d, v5.2d\n" - "uzp2 v2.2d, v7.2d, v5.2d\n" - "scvtf v8.4s, v8.4s, #0x4\n" - "fmla v31.4s, v6.4s, v1.4s\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v30.4s, v8.4s, v1.4s\n" - "fmla v29.4s, v9.4s, v1.4s\n" - "fmla v28.4s, v2.4s, v1.4s\n" - "ldr q9, [x23, #0x0]\n" - "ldr q7, [x23, #0x10]\n" - "movi v8.4s, #0x0\n" - "movi v2.4s, #0x0\n" - "ldr q5, [x23, #0x20]\n" - "ldr q10, [x23, #0x30]\n" - "movi v6.4s, #0x0\n" - ".inst 0x4e8ea528 // smmla v8.4s, v9.16b, v14.16b\n" - ".inst 0x4e83a522 // smmla v2.4s, v9.16b, v3.16b\n" - "ldr q9, [x23, #0x40]\n" - ".inst 0x4e8ea4e6 // smmla v6.4s, v7.16b, v14.16b\n" - ".inst 0x4e8ca4a8 // smmla v8.4s, v5.16b, v12.16b\n" - ".inst 0x4e80a4a2 // smmla v2.4s, v5.16b, v0.16b\n" - "ldr q5, [x23, #0x50]\n" - ".inst 0x4e8ca546 // smmla v6.4s, v10.16b, v12.16b\n" - ".inst 0x4e8ba528 // smmla v8.4s, v9.16b, v11.16b\n" - ".inst 0x4e84a522 // smmla v2.4s, v9.16b, v4.16b\n" - "ldr q9, [x23, #0x60]\n" - ".inst 0x4e8ba4a6 // smmla v6.4s, v5.16b, v11.16b\n" - ".inst 0x4e8fa528 // smmla v8.4s, v9.16b, v15.16b\n" - ".inst 0x4e8da522 // smmla v2.4s, v9.16b, v13.16b\n" - "movi v9.4s, #0x0\n" - ".inst 0x4e83a4e9 // smmla v9.4s, v7.16b, v3.16b\n" - "ldr q7, [x23, #0x70]\n" - "add x23, x23, #0x80\n" - ".inst 0x4e8fa4e6 // smmla v6.4s, v7.16b, v15.16b\n" - ".inst 0x4e80a549 // smmla v9.4s, v10.16b, v0.16b\n" - "uzp1 v10.2d, v8.2d, v2.2d\n" - "uzp2 v2.2d, v8.2d, v2.2d\n" - "scvtf v10.4s, v10.4s, #0x4\n" - ".inst 0x4e84a4a9 // smmla v9.4s, v5.16b, v4.16b\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v27.4s, v10.4s, v1.4s\n" - ".inst 0x4e8da4e9 // smmla v9.4s, v7.16b, v13.16b\n" - "fmla v26.4s, v2.4s, v1.4s\n" - "uzp1 v2.2d, v6.2d, v9.2d\n" - "uzp2 v10.2d, v6.2d, v9.2d\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "fmla v25.4s, v2.4s, v1.4s\n" - "fmla v24.4s, v10.4s, v1.4s\n" - "ldr q8, [x22, #0x0]\n" - "ldr q7, [x22, #0x10]\n" - "movi v9.4s, #0x0\n" - "movi v6.4s, #0x0\n" - "ldr q2, [x22, #0x20]\n" - "ldr q5, [x22, #0x30]\n" - "movi v10.4s, #0x0\n" - ".inst 0x4e8ea509 // smmla v9.4s, v8.16b, v14.16b\n" - ".inst 0x4e83a506 // smmla v6.4s, v8.16b, v3.16b\n" - "ldr q8, [x22, #0x40]\n" - ".inst 0x4e8ea4ea // smmla v10.4s, v7.16b, v14.16b\n" - ".inst 0x4e8ca449 // smmla v9.4s, v2.16b, v12.16b\n" - ".inst 0x4e80a446 // smmla v6.4s, v2.16b, v0.16b\n" - 
"ldr q2, [x22, #0x50]\n" - ".inst 0x4e8ca4aa // smmla v10.4s, v5.16b, v12.16b\n" - ".inst 0x4e8ba509 // smmla v9.4s, v8.16b, v11.16b\n" - ".inst 0x4e84a506 // smmla v6.4s, v8.16b, v4.16b\n" - "ldr q8, [x22, #0x60]\n" - ".inst 0x4e8ba44a // smmla v10.4s, v2.16b, v11.16b\n" - ".inst 0x4e8fa509 // smmla v9.4s, v8.16b, v15.16b\n" - ".inst 0x4e8da506 // smmla v6.4s, v8.16b, v13.16b\n" - "movi v8.4s, #0x0\n" - ".inst 0x4e83a4e8 // smmla v8.4s, v7.16b, v3.16b\n" - "ldr q7, [x22, #0x70]\n" - "add x22, x22, #0x80\n" - ".inst 0x4e8fa4ea // smmla v10.4s, v7.16b, v15.16b\n" - ".inst 0x4e80a4a8 // smmla v8.4s, v5.16b, v0.16b\n" - "uzp1 v5.2d, v9.2d, v6.2d\n" - "uzp2 v9.2d, v9.2d, v6.2d\n" - "scvtf v5.4s, v5.4s, #0x4\n" - ".inst 0x4e84a448 // smmla v8.4s, v2.16b, v4.16b\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "fmla v23.4s, v5.4s, v1.4s\n" - ".inst 0x4e8da4e8 // smmla v8.4s, v7.16b, v13.16b\n" - "fmla v22.4s, v9.4s, v1.4s\n" - "uzp1 v2.2d, v10.2d, v8.2d\n" - "uzp2 v10.2d, v10.2d, v8.2d\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "fmla v21.4s, v2.4s, v1.4s\n" - "fmla v20.4s, v10.4s, v1.4s\n" - "ldr q2, [x21, #0x0]\n" - "ldr q10, [x21, #0x10]\n" - "movi v6.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "ldr q5, [x21, #0x20]\n" - "ldr q8, [x21, #0x30]\n" - "movi v7.4s, #0x0\n" - ".inst 0x4e8ea446 // smmla v6.4s, v2.16b, v14.16b\n" - ".inst 0x4e83a449 // smmla v9.4s, v2.16b, v3.16b\n" - "ldr q2, [x21, #0x40]\n" - ".inst 0x4e8ea547 // smmla v7.4s, v10.16b, v14.16b\n" - "ldr q14, [x21, #0x50]\n" - ".inst 0x4e8ca4a6 // smmla v6.4s, v5.16b, v12.16b\n" - ".inst 0x4e80a4a9 // smmla v9.4s, v5.16b, v0.16b\n" - "ldr q5, [x21, #0x60]\n" - ".inst 0x4e8ca507 // smmla v7.4s, v8.16b, v12.16b\n" - "ldr q12, [x21, #0x70]\n" - "add x21, x21, #0x80\n" - ".inst 0x4e8ba446 // smmla v6.4s, v2.16b, v11.16b\n" - ".inst 0x4e84a449 // smmla v9.4s, v2.16b, v4.16b\n" - "movi v2.4s, #0x0\n" - ".inst 0x4e83a542 // smmla v2.4s, v10.16b, v3.16b\n" - ".inst 0x4e8ba5c7 // smmla v7.4s, v14.16b, v11.16b\n" - ".inst 0x4e8fa4a6 // smmla v6.4s, v5.16b, v15.16b\n" - ".inst 0x4e80a502 // smmla v2.4s, v8.16b, v0.16b\n" - ".inst 0x4e8da4a9 // smmla v9.4s, v5.16b, v13.16b\n" - ".inst 0x4e8fa587 // smmla v7.4s, v12.16b, v15.16b\n" - ".inst 0x4e84a5c2 // smmla v2.4s, v14.16b, v4.16b\n" - "uzp1 v11.2d, v6.2d, v9.2d\n" - "uzp2 v14.2d, v6.2d, v9.2d\n" - "scvtf v11.4s, v11.4s, #0x4\n" - ".inst 0x4e8da582 // smmla v2.4s, v12.16b, v13.16b\n" - "scvtf v14.4s, v14.4s, #0x4\n" - "fmla v19.4s, v11.4s, v1.4s\n" - "uzp1 v9.2d, v7.2d, v2.2d\n" - "uzp2 v0.2d, v7.2d, v2.2d\n" - "fmla v18.4s, v14.4s, v1.4s\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "fmla v17.4s, v9.4s, v1.4s\n" - "fmla v16.4s, v0.4s, v1.4s\n" - "subs x20, x20, #0x1\n" - "bgt 3b\n" - "ld1 { v11.4s }, [x27]\n" - "ld1 { v10.4s }, [x23]\n" - "add x27, x27, #0x10\n" - "add x23, x23, #0x10\n" - "ld1 { v9.4s }, [x22]\n" - "ld1 { v8.4s }, [x21]\n" - "add x22, x22, #0x10\n" - "add x21, x21, #0x10\n" - "ldr q7, [x11, #0x0]\n" - "ldr q6, [x27, #0x0]\n" - "add x20, %x[clamp_vals], #0x4\n" - "cmp x10, #0x4\n" - "ldr q5, [x23, #0x0]\n" - "ldr q4, [x22, #0x0]\n" - "scvtf v11.4s, v11.4s\n" - "scvtf v10.4s, v10.4s\n" - "ldr q3, [x21, #0x0]\n" - "ldr q2, [x11, #0x10]\n" - "scvtf v9.4s, v9.4s\n" - "scvtf v8.4s, v8.4s\n" - "ld1r { v1.4s }, [%x[clamp_vals]]\n" - "ld1r { v0.4s }, [x20]\n" - "add x11, x11, #0x20\n" - "fmla v31.4s, v7.4s, v11.s[0]\n" - "fmla v30.4s, v7.4s, v11.s[1]\n" - "fmla v29.4s, v7.4s, v11.s[2]\n" - "fmla v28.4s, v7.4s, v11.s[3]\n" - "fmla v27.4s, v7.4s, v10.s[0]\n" - 
"fmla v26.4s, v7.4s, v10.s[1]\n" - "fmla v25.4s, v7.4s, v10.s[2]\n" - "fmla v24.4s, v7.4s, v10.s[3]\n" - "fmla v23.4s, v7.4s, v9.s[0]\n" - "fmul v31.4s, v31.4s, v6.s[0]\n" - "fmla v22.4s, v7.4s, v9.s[1]\n" - "fmla v21.4s, v7.4s, v9.s[2]\n" - "fmul v30.4s, v30.4s, v6.s[1]\n" - "fmla v20.4s, v7.4s, v9.s[3]\n" - "fmla v19.4s, v7.4s, v8.s[0]\n" - "fmul v29.4s, v29.4s, v6.s[2]\n" - "fmla v18.4s, v7.4s, v8.s[1]\n" - "fmla v17.4s, v7.4s, v8.s[2]\n" - "fmul v28.4s, v28.4s, v6.s[3]\n" - "fmla v16.4s, v7.4s, v8.s[3]\n" - "fmul v27.4s, v27.4s, v5.s[0]\n" - "fmul v26.4s, v26.4s, v5.s[1]\n" - "fmul v25.4s, v25.4s, v5.s[2]\n" - "fmul v24.4s, v24.4s, v5.s[3]\n" - "fmul v23.4s, v23.4s, v4.s[0]\n" - "fmul v22.4s, v22.4s, v4.s[1]\n" - "fmul v21.4s, v21.4s, v4.s[2]\n" - "fmul v20.4s, v20.4s, v4.s[3]\n" - "fmul v19.4s, v19.4s, v3.s[0]\n" - "fmul v18.4s, v18.4s, v3.s[1]\n" - "fmul v17.4s, v17.4s, v3.s[2]\n" - "fmul v16.4s, v16.4s, v3.s[3]\n" - "fadd v31.4s, v31.4s, v2.4s\n" - "fadd v30.4s, v30.4s, v2.4s\n" - "fadd v29.4s, v29.4s, v2.4s\n" - "fadd v28.4s, v28.4s, v2.4s\n" - "fadd v27.4s, v27.4s, v2.4s\n" - "fadd v26.4s, v26.4s, v2.4s\n" - "fadd v25.4s, v25.4s, v2.4s\n" - "fadd v24.4s, v24.4s, v2.4s\n" - "fadd v23.4s, v23.4s, v2.4s\n" - "fadd v22.4s, v22.4s, v2.4s\n" - "fadd v21.4s, v21.4s, v2.4s\n" - "fadd v20.4s, v20.4s, v2.4s\n" - "fadd v19.4s, v19.4s, v2.4s\n" - "fadd v18.4s, v18.4s, v2.4s\n" - "fadd v17.4s, v17.4s, v2.4s\n" - "fadd v16.4s, v16.4s, v2.4s\n" - "fmax v31.4s, v31.4s, v1.4s\n" - "fmax v30.4s, v30.4s, v1.4s\n" - "fmax v29.4s, v29.4s, v1.4s\n" - "fmax v28.4s, v28.4s, v1.4s\n" - "fmax v27.4s, v27.4s, v1.4s\n" - "fmax v26.4s, v26.4s, v1.4s\n" - "fmax v25.4s, v25.4s, v1.4s\n" - "fmax v24.4s, v24.4s, v1.4s\n" - "fmax v23.4s, v23.4s, v1.4s\n" - "fmax v22.4s, v22.4s, v1.4s\n" - "fmax v21.4s, v21.4s, v1.4s\n" - "fmax v20.4s, v20.4s, v1.4s\n" - "fmax v19.4s, v19.4s, v1.4s\n" - "fmax v18.4s, v18.4s, v1.4s\n" - "fmax v17.4s, v17.4s, v1.4s\n" - "fmax v16.4s, v16.4s, v1.4s\n" - "fmin v31.4s, v31.4s, v0.4s\n" - "fmin v30.4s, v30.4s, v0.4s\n" - "fmin v29.4s, v29.4s, v0.4s\n" - "fmin v28.4s, v28.4s, v0.4s\n" - "fmin v27.4s, v27.4s, v0.4s\n" - "fmin v26.4s, v26.4s, v0.4s\n" - "fmin v25.4s, v25.4s, v0.4s\n" - "fmin v24.4s, v24.4s, v0.4s\n" - "fmin v23.4s, v23.4s, v0.4s\n" - "fmin v22.4s, v22.4s, v0.4s\n" - "fmin v21.4s, v21.4s, v0.4s\n" - "fmin v20.4s, v20.4s, v0.4s\n" - "fmin v19.4s, v19.4s, v0.4s\n" - "fmin v18.4s, v18.4s, v0.4s\n" - "fmin v17.4s, v17.4s, v0.4s\n" - "fmin v16.4s, v16.4s, v0.4s\n" - "blt 8f\n" - "mov x20, %x[dst]\n" - "str q31, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q30, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q29, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q28, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q27, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q26, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q24, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q21, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q20, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q17, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str 
q16, [x20, #0x0]\n" - "b 13f\n" - "8:" // Partial output - "mov x28, %x[dst]\n" - "add x26, x28, %x[dst_stride_row], LSL #2\n" - "add x25, x26, %x[dst_stride_row], LSL #1\n" - "add x24, x26, %x[dst_stride_row]\n" - "add x23, x25, %x[dst_stride_row]\n" - "add x22, x28, %x[dst_stride_row], LSL #1\n" - "add x21, x28, %x[dst_stride_row]\n" - "add x20, x22, %x[dst_stride_row]\n" - "add x27, x23, %x[dst_stride_row]\n" - "tbz x10, #1, 9f\n" - "st1 { v24.d }[0], [x23], #0x8\n" - "st1 { v25.d }[0], [x25], #0x8\n" - "st1 { v26.d }[0], [x24], #0x8\n" - "st1 { v27.d }[0], [x26], #0x8\n" - "st1 { v28.d }[0], [x20], #0x8\n" - "st1 { v29.d }[0], [x22], #0x8\n" - "st1 { v30.d }[0], [x21], #0x8\n" - "st1 { v31.d }[0], [x28], #0x8\n" - "tbz x10, #0, 10f\n" - "st1 { v24.s }[2], [x23]\n" - "st1 { v25.s }[2], [x25]\n" - "st1 { v26.s }[2], [x24]\n" - "st1 { v27.s }[2], [x26]\n" - "st1 { v28.s }[2], [x20]\n" - "st1 { v29.s }[2], [x22]\n" - "st1 { v30.s }[2], [x21]\n" - "st1 { v31.s }[2], [x28]\n" - "b 10f\n" - "9:" // Output block 0: partial_1_0 - "st1 { v24.s }[0], [x23]\n" - "st1 { v25.s }[0], [x25]\n" - "st1 { v26.s }[0], [x24]\n" - "st1 { v27.s }[0], [x26]\n" - "st1 { v28.s }[0], [x20]\n" - "st1 { v29.s }[0], [x22]\n" - "st1 { v30.s }[0], [x21]\n" - "st1 { v31.s }[0], [x28]\n" - "10:" // Output block 0: Done - "add x26, x27, %x[dst_stride_row], LSL #2\n" - "add x25, x27, %x[dst_stride_row], LSL #1\n" - "add x24, x26, %x[dst_stride_row], LSL #1\n" - "add x23, x27, %x[dst_stride_row]\n" - "add x22, x25, %x[dst_stride_row]\n" - "add x21, x26, %x[dst_stride_row]\n" - "add x20, x24, %x[dst_stride_row]\n" - "tbz x10, #1, 11f\n" - "st1 { v16.d }[0], [x20], #0x8\n" - "st1 { v17.d }[0], [x24], #0x8\n" - "st1 { v18.d }[0], [x21], #0x8\n" - "st1 { v19.d }[0], [x26], #0x8\n" - "st1 { v20.d }[0], [x22], #0x8\n" - "st1 { v21.d }[0], [x25], #0x8\n" - "st1 { v22.d }[0], [x23], #0x8\n" - "st1 { v23.d }[0], [x27], #0x8\n" - "tbz x10, #0, 12f\n" - "st1 { v16.s }[2], [x20]\n" - "st1 { v17.s }[2], [x24]\n" - "st1 { v18.s }[2], [x21]\n" - "st1 { v19.s }[2], [x26]\n" - "st1 { v20.s }[2], [x22]\n" - "st1 { v21.s }[2], [x25]\n" - "st1 { v22.s }[2], [x23]\n" - "st1 { v23.s }[2], [x27]\n" - "b 12f\n" - "11:" // Output block 1: partial_1_0 - "st1 { v16.s }[0], [x20]\n" - "st1 { v17.s }[0], [x24]\n" - "st1 { v18.s }[0], [x21]\n" - "st1 { v19.s }[0], [x26]\n" - "st1 { v20.s }[0], [x22]\n" - "st1 { v21.s }[0], [x25]\n" - "st1 { v22.s }[0], [x23]\n" - "st1 { v23.s }[0], [x27]\n" - "12:" // Output block 1: Done - "13:" // Output stage exit - "subs x10, x10, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 2b\n" - "mov x20, #0x4\n" - "sub x13, x13, #0x10\n" - "cmp x13, #0x10\n" - "mov %x[dst], x9\n" - "madd %x[lhs_packed], x20, x12, %x[lhs_packed]\n" - "bge 1b\n" - "14:" // Row loop skip - "cbz x13, 23f\n" - "15:" // Row tail: Row loop - "mov x26, %x[rhs_packed]\n" - "mov x25, %x[n]\n" - "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" - "16:" // Row tail: Column loop - "movi v31.16b, #0x0\n" - "movi v30.16b, #0x0\n" - "mov x27, %x[lhs_packed]\n" - "mov x20, %x[num_blocks]\n" - "movi v29.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "17:" // Row tail: Block loop - "ldr q9, [x26, #0x0]\n" - "ldr q8, [x26, #0x10]\n" - "movi v7.4s, #0x0\n" - "movi v6.4s, #0x0\n" - "ldr q5, [x27, #0x0]\n" - "ldr q4, [x27, #0x10]\n" - "movi v3.4s, #0x0\n" - "movi v2.4s, #0x0\n" - "ldr q1, [x26, #0x20]\n" - "ldr q0, [x26, #0x30]\n" - "movi v27.16b, #0xf0\n" - "add x26, x26, #0x40\n" - "ldr q26, [x27, #0x20]\n" - "ldr q25, [x27, #0x30]\n" - "shl v24.16b, v9.16b, #0x4\n" 
- "shl v20.16b, v8.16b, #0x4\n" - "ldr q23, [x27, #0x40]\n" - "ldr q22, [x27, #0x50]\n" - "and v9.16b, v9.16b, v27.16b\n" - "and v8.16b, v8.16b, v27.16b\n" - "ldr q21, [x27, #0x60]\n" - "ldr q19, [x27, #0x70]\n" - "shl v18.16b, v1.16b, #0x4\n" - "shl v17.16b, v0.16b, #0x4\n" - "ldr d16, [x26, #0x0]\n" - ".inst 0x4e98a4a7 // smmla v7.4s, v5.16b, v24.16b\n" - ".inst 0x4e94a4a6 // smmla v6.4s, v5.16b, v20.16b\n" - "and v1.16b, v1.16b, v27.16b\n" - ".inst 0x4e98a483 // smmla v3.4s, v4.16b, v24.16b\n" - ".inst 0x4e94a482 // smmla v2.4s, v4.16b, v20.16b\n" - "and v0.16b, v0.16b, v27.16b\n" - "add x26, x26, #0x8\n" - "add x27, x27, #0x80\n" - "shll v20.4s, v16.4h, #0x10\n" - ".inst 0x4e92a747 // smmla v7.4s, v26.16b, v18.16b\n" - ".inst 0x4e91a746 // smmla v6.4s, v26.16b, v17.16b\n" - ".inst 0x4e92a723 // smmla v3.4s, v25.16b, v18.16b\n" - ".inst 0x4e91a722 // smmla v2.4s, v25.16b, v17.16b\n" - ".inst 0x4e89a6e7 // smmla v7.4s, v23.16b, v9.16b\n" - ".inst 0x4e88a6e6 // smmla v6.4s, v23.16b, v8.16b\n" - ".inst 0x4e89a6c3 // smmla v3.4s, v22.16b, v9.16b\n" - ".inst 0x4e88a6c2 // smmla v2.4s, v22.16b, v8.16b\n" - ".inst 0x4e81a6a7 // smmla v7.4s, v21.16b, v1.16b\n" - ".inst 0x4e80a6a6 // smmla v6.4s, v21.16b, v0.16b\n" - ".inst 0x4e81a663 // smmla v3.4s, v19.16b, v1.16b\n" - ".inst 0x4e80a662 // smmla v2.4s, v19.16b, v0.16b\n" - "uzp1 v19.2d, v7.2d, v6.2d\n" - "uzp2 v18.2d, v7.2d, v6.2d\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "uzp1 v17.2d, v3.2d, v2.2d\n" - "uzp2 v16.2d, v3.2d, v2.2d\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v31.4s, v19.4s, v20.4s\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "fmla v30.4s, v18.4s, v20.4s\n" - "fmla v29.4s, v17.4s, v20.4s\n" - "fmla v28.4s, v16.4s, v20.4s\n" - "subs x20, x20, #0x1\n" - "bgt 17b\n" - "ld1 { v21.4s }, [x27]\n" - "ldr q20, [x26, #0x0]\n" - "add x27, x27, #0x10\n" - "add x20, %x[clamp_vals], #0x4\n" - "ldr q19, [x27, #0x0]\n" - "ldr q18, [x26, #0x10]\n" - "cmp x25, #0x4\n" - "add x26, x26, #0x20\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v16.4s }, [x20]\n" - "scvtf v21.4s, v21.4s\n" - "fmla v31.4s, v20.4s, v21.s[0]\n" - "fmla v30.4s, v20.4s, v21.s[1]\n" - "fmla v29.4s, v20.4s, v21.s[2]\n" - "fmla v28.4s, v20.4s, v21.s[3]\n" - "fmul v31.4s, v31.4s, v19.s[0]\n" - "fmul v30.4s, v30.4s, v19.s[1]\n" - "fmul v29.4s, v29.4s, v19.s[2]\n" - "fadd v31.4s, v31.4s, v18.4s\n" - "fmul v28.4s, v28.4s, v19.s[3]\n" - "fadd v30.4s, v30.4s, v18.4s\n" - "fadd v29.4s, v29.4s, v18.4s\n" - "fadd v28.4s, v28.4s, v18.4s\n" - "fmax v31.4s, v31.4s, v17.4s\n" - "fmax v30.4s, v30.4s, v17.4s\n" - "fmax v29.4s, v29.4s, v17.4s\n" - "fmax v28.4s, v28.4s, v17.4s\n" - "fmin v31.4s, v31.4s, v16.4s\n" - "fmin v30.4s, v30.4s, v16.4s\n" - "fmin v29.4s, v29.4s, v16.4s\n" - "fmin v28.4s, v28.4s, v16.4s\n" - "blt 19f\n" - "mov x20, %x[dst]\n" - "cmp x13, #0x1\n" - "str q31, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 22f\n" - "cmp x13, #0x2\n" - "str q30, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 22f\n" - "cmp x13, #0x3\n" - "str q29, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 22f\n" - "str q28, [x20, #0x0]\n" - "b 22f\n" - "19:" // Row tail: Partial output - "mov x23, %x[dst]\n" - "cmp x13, #0x1\n" - "add x22, x23, %x[dst_stride_row]\n" - "csel x22, x22, x23, GT\n" - "cmp x13, #0x2\n" - "add x21, x23, %x[dst_stride_row], LSL #1\n" - "csel x21, x21, x22, GT\n" - "cmp x13, #0x3\n" - "add x20, x21, %x[dst_stride_row]\n" - "csel x20, x20, x21, GT\n" - "tbz x25, #1, 20f\n" - "st1 { v28.d }[0], [x20], #0x8\n" - 
"st1 { v29.d }[0], [x21], #0x8\n" - "st1 { v30.d }[0], [x22], #0x8\n" - "st1 { v31.d }[0], [x23], #0x8\n" - "tbz x25, #0, 21f\n" - "st1 { v28.s }[2], [x20]\n" - "st1 { v29.s }[2], [x21]\n" - "st1 { v30.s }[2], [x22]\n" - "st1 { v31.s }[2], [x23]\n" - "b 21f\n" - "20:" // Row tail: Output block 0: partial_1_0 - "st1 { v28.s }[0], [x20]\n" - "st1 { v29.s }[0], [x21]\n" - "st1 { v30.s }[0], [x22]\n" - "st1 { v31.s }[0], [x23]\n" - "21:" // Row tail: Output block 0: Done - "22:" // Row tail: Output stage exit - "subs x25, x25, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 16b\n" - "subs x13, x13, #0x4\n" - "add %x[lhs_packed], %x[lhs_packed], x12\n" - "mov %x[dst], x24\n" - "bgt 15b\n" - "23:" // Row tail: Row loop skip - : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) - : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), - [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v2", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v31", "v4", "v5", - "v6", "v7", "v8", "v9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", - "x27", "x28", "x9"); - } else { - __asm__ __volatile__( - "mov x13, #0x80\n" - "mov x12, %x[m]\n" - "mov x20, #0x20\n" - "sub SP, SP, #0x100\n" - "mul x13, %x[num_subblocks], x13\n" - "cmp x12, #0x10\n" - "madd x13, %x[num_blocks], x13, x20\n" - "blt 15f\n" - "1:" // Row loop - "mov x11, %x[rhs_packed]\n" - "mov x10, %x[n]\n" - "add x9, %x[dst], %x[dst_stride_row], LSL #4\n" - "2:" // Column loop - "mov x27, %x[lhs_packed]\n" - "movi v29.4s, #0x0\n" - "mov x24, %x[num_blocks]\n" - "str q29, [SP, #0x0]\n" - "str q29, [SP, #0x10]\n" - "str q29, [SP, #0x20]\n" - "add x23, x27, x13\n" - "add x22, x23, x13\n" - "str q29, [SP, #0x30]\n" - "add x21, x22, x13\n" - "str q29, [SP, #0x40]\n" - "str q29, [SP, #0x50]\n" - "str q29, [SP, #0x60]\n" - "str q29, [SP, #0x70]\n" - "str q29, [SP, #0x80]\n" - "str q29, [SP, #0x90]\n" - "str q29, [SP, #0xa0]\n" - "str q29, [SP, #0xb0]\n" - "str q29, [SP, #0xc0]\n" - "str q29, [SP, #0xd0]\n" - "str q29, [SP, #0xe0]\n" - "str q29, [SP, #0xf0]\n" - "3:" // Block loop - "movi v7.4s, #0x0\n" - "movi v13.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "movi v29.4s, #0x0\n" - "movi v12.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "movi v15.4s, #0x0\n" - "movi v2.4s, #0x0\n" - "movi v22.4s, #0x0\n" - "movi v30.4s, #0x0\n" - "movi v26.4s, #0x0\n" - "movi v6.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "movi v14.4s, #0x0\n" - "4:" // Sub block loop - "ldr q4, [x11, #0x0]\n" - "ldr q3, [x11, #0x10]\n" - "movi v31.16b, #0xf0\n" - "subs x20, x20, #0x1\n" - "ldr q27, [x27, #0x0]\n" - "ldr q1, [x27, #0x10]\n" - "ldr q19, [x23, #0x0]\n" - "ldr q17, [x23, #0x10]\n" - "ldr q21, [x22, #0x0]\n" - "ldr q23, [x22, #0x10]\n" - "shl v25.16b, v4.16b, #0x4\n" - "shl v20.16b, v3.16b, #0x4\n" - "ldr q5, [x21, #0x0]\n" - "ldr q16, [x21, #0x10]\n" - "and v4.16b, v4.16b, v31.16b\n" - "and v3.16b, v3.16b, v31.16b\n" - "ldr q8, [x11, #0x20]\n" - "ldr q11, [x11, #0x30]\n" - "add x11, x11, #0x40\n" - "ldr q24, [x27, #0x20]\n" - ".inst 0x4e99a767 // smmla v7.4s, v27.16b, v25.16b\n" - ".inst 0x4e94a76d // smmla v13.4s, v27.16b, v20.16b\n" - "ldr q27, [x27, #0x30]\n" - ".inst 0x4e99a43d // smmla v29.4s, v1.16b, v25.16b\n" - ".inst 0x4e94a42c // smmla v12.4s, v1.16b, v20.16b\n" - "ldr q1, [x23, #0x20]\n" - ".inst 
0x4e99a67c // smmla v28.4s, v19.16b, v25.16b\n" - ".inst 0x4e94a66f // smmla v15.4s, v19.16b, v20.16b\n" - "ldr q19, [x23, #0x30]\n" - ".inst 0x4e99a622 // smmla v2.4s, v17.16b, v25.16b\n" - ".inst 0x4e94a636 // smmla v22.4s, v17.16b, v20.16b\n" - "ldr q17, [x22, #0x20]\n" - ".inst 0x4e99a6be // smmla v30.4s, v21.16b, v25.16b\n" - ".inst 0x4e94a6ba // smmla v26.4s, v21.16b, v20.16b\n" - "ldr q21, [x22, #0x30]\n" - ".inst 0x4e99a6e6 // smmla v6.4s, v23.16b, v25.16b\n" - ".inst 0x4e94a6ea // smmla v10.4s, v23.16b, v20.16b\n" - "ldr q23, [x21, #0x20]\n" - ".inst 0x4e99a4a9 // smmla v9.4s, v5.16b, v25.16b\n" - ".inst 0x4e94a4b2 // smmla v18.4s, v5.16b, v20.16b\n" - "ldr q5, [x21, #0x30]\n" - ".inst 0x4e99a600 // smmla v0.4s, v16.16b, v25.16b\n" - "ldr q25, [x27, #0x40]\n" - ".inst 0x4e94a60e // smmla v14.4s, v16.16b, v20.16b\n" - "ldr q16, [x27, #0x50]\n" - "shl v20.16b, v8.16b, #0x4\n" - "and v8.16b, v8.16b, v31.16b\n" - ".inst 0x4e94a707 // smmla v7.4s, v24.16b, v20.16b\n" - ".inst 0x4e94a77d // smmla v29.4s, v27.16b, v20.16b\n" - ".inst 0x4e94a43c // smmla v28.4s, v1.16b, v20.16b\n" - ".inst 0x4e94a662 // smmla v2.4s, v19.16b, v20.16b\n" - ".inst 0x4e94a63e // smmla v30.4s, v17.16b, v20.16b\n" - ".inst 0x4e94a6a6 // smmla v6.4s, v21.16b, v20.16b\n" - ".inst 0x4e94a6e9 // smmla v9.4s, v23.16b, v20.16b\n" - ".inst 0x4e94a4a0 // smmla v0.4s, v5.16b, v20.16b\n" - "shl v20.16b, v11.16b, #0x4\n" - ".inst 0x4e84a727 // smmla v7.4s, v25.16b, v4.16b\n" - ".inst 0x4e84a61d // smmla v29.4s, v16.16b, v4.16b\n" - "and v11.16b, v11.16b, v31.16b\n" - "ldr q31, [x23, #0x40]\n" - ".inst 0x4e94a70d // smmla v13.4s, v24.16b, v20.16b\n" - "ldr q24, [x23, #0x50]\n" - ".inst 0x4e94a76c // smmla v12.4s, v27.16b, v20.16b\n" - "ldr q27, [x22, #0x40]\n" - ".inst 0x4e94a42f // smmla v15.4s, v1.16b, v20.16b\n" - "ldr q1, [x22, #0x50]\n" - ".inst 0x4e94a676 // smmla v22.4s, v19.16b, v20.16b\n" - "ldr q19, [x21, #0x40]\n" - ".inst 0x4e94a63a // smmla v26.4s, v17.16b, v20.16b\n" - "ldr q17, [x21, #0x50]\n" - ".inst 0x4e94a6aa // smmla v10.4s, v21.16b, v20.16b\n" - "ldr q21, [x27, #0x60]\n" - ".inst 0x4e94a6f2 // smmla v18.4s, v23.16b, v20.16b\n" - "ldr q23, [x27, #0x70]\n" - ".inst 0x4e94a4ae // smmla v14.4s, v5.16b, v20.16b\n" - "ldr q20, [x23, #0x60]\n" - ".inst 0x4e83a72d // smmla v13.4s, v25.16b, v3.16b\n" - "ldr q5, [x23, #0x70]\n" - "ldr q25, [x22, #0x60]\n" - ".inst 0x4e83a60c // smmla v12.4s, v16.16b, v3.16b\n" - ".inst 0x4e84a7fc // smmla v28.4s, v31.16b, v4.16b\n" - "ldr q16, [x22, #0x70]\n" - ".inst 0x4e83a7ef // smmla v15.4s, v31.16b, v3.16b\n" - "ldr q31, [x21, #0x60]\n" - ".inst 0x4e84a702 // smmla v2.4s, v24.16b, v4.16b\n" - ".inst 0x4e83a716 // smmla v22.4s, v24.16b, v3.16b\n" - "ldr q24, [x21, #0x70]\n" - ".inst 0x4e84a77e // smmla v30.4s, v27.16b, v4.16b\n" - "add x27, x27, #0x80\n" - ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" - ".inst 0x4e84a426 // smmla v6.4s, v1.16b, v4.16b\n" - "add x23, x23, #0x80\n" - "add x22, x22, #0x80\n" - ".inst 0x4e83a42a // smmla v10.4s, v1.16b, v3.16b\n" - ".inst 0x4e84a669 // smmla v9.4s, v19.16b, v4.16b\n" - "add x21, x21, #0x80\n" - ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" - ".inst 0x4e84a620 // smmla v0.4s, v17.16b, v4.16b\n" - ".inst 0x4e83a62e // smmla v14.4s, v17.16b, v3.16b\n" - ".inst 0x4e88a6a7 // smmla v7.4s, v21.16b, v8.16b\n" - ".inst 0x4e8ba6ad // smmla v13.4s, v21.16b, v11.16b\n" - ".inst 0x4e88a6fd // smmla v29.4s, v23.16b, v8.16b\n" - ".inst 0x4e8ba6ec // smmla v12.4s, v23.16b, v11.16b\n" - ".inst 0x4e88a69c // smmla v28.4s, 
v20.16b, v8.16b\n" - ".inst 0x4e8ba68f // smmla v15.4s, v20.16b, v11.16b\n" - ".inst 0x4e88a4a2 // smmla v2.4s, v5.16b, v8.16b\n" - ".inst 0x4e8ba4b6 // smmla v22.4s, v5.16b, v11.16b\n" - ".inst 0x4e88a73e // smmla v30.4s, v25.16b, v8.16b\n" - ".inst 0x4e8ba73a // smmla v26.4s, v25.16b, v11.16b\n" - ".inst 0x4e88a606 // smmla v6.4s, v16.16b, v8.16b\n" - ".inst 0x4e8ba60a // smmla v10.4s, v16.16b, v11.16b\n" - ".inst 0x4e88a7e9 // smmla v9.4s, v31.16b, v8.16b\n" - ".inst 0x4e8ba7f2 // smmla v18.4s, v31.16b, v11.16b\n" - ".inst 0x4e88a700 // smmla v0.4s, v24.16b, v8.16b\n" - ".inst 0x4e8ba70e // smmla v14.4s, v24.16b, v11.16b\n" - "bgt 4b\n" - "ldr d4, [x11, #0x0]\n" - "ldr q23, [SP, #0x0]\n" - "uzp1 v16.2d, v7.2d, v13.2d\n" - "uzp2 v19.2d, v7.2d, v13.2d\n" - "uzp1 v20.2d, v29.2d, v12.2d\n" - "uzp2 v17.2d, v29.2d, v12.2d\n" - "add x11, x11, #0x8\n" - "shll v24.4s, v4.4h, #0x10\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "fmla v23.4s, v16.4s, v24.4s\n" - "str q23, [SP, #0x0]\n" - "ldr q16, [SP, #0x10]\n" - "fmla v16.4s, v19.4s, v24.4s\n" - "str q16, [SP, #0x10]\n" - "ldr q16, [SP, #0x20]\n" - "fmla v16.4s, v20.4s, v24.4s\n" - "str q16, [SP, #0x20]\n" - "ldr q16, [SP, #0x30]\n" - "fmla v16.4s, v17.4s, v24.4s\n" - "str q16, [SP, #0x30]\n" - "ldr q1, [SP, #0x40]\n" - "uzp1 v16.2d, v28.2d, v15.2d\n" - "uzp2 v19.2d, v28.2d, v15.2d\n" - "uzp1 v5.2d, v2.2d, v22.2d\n" - "uzp2 v17.2d, v2.2d, v22.2d\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "scvtf v5.4s, v5.4s, #0x4\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "fmla v1.4s, v16.4s, v24.4s\n" - "str q1, [SP, #0x40]\n" - "ldr q16, [SP, #0x50]\n" - "fmla v16.4s, v19.4s, v24.4s\n" - "str q16, [SP, #0x50]\n" - "ldr q16, [SP, #0x60]\n" - "fmla v16.4s, v5.4s, v24.4s\n" - "str q16, [SP, #0x60]\n" - "ldr q16, [SP, #0x70]\n" - "fmla v16.4s, v17.4s, v24.4s\n" - "str q16, [SP, #0x70]\n" - "ldr q1, [SP, #0x80]\n" - "uzp1 v16.2d, v30.2d, v26.2d\n" - "uzp2 v19.2d, v30.2d, v26.2d\n" - "uzp1 v30.2d, v6.2d, v10.2d\n" - "uzp2 v17.2d, v6.2d, v10.2d\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "scvtf v30.4s, v30.4s, #0x4\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "fmla v1.4s, v16.4s, v24.4s\n" - "str q1, [SP, #0x80]\n" - "ldr q16, [SP, #0x90]\n" - "fmla v16.4s, v19.4s, v24.4s\n" - "str q16, [SP, #0x90]\n" - "ldr q16, [SP, #0xa0]\n" - "fmla v16.4s, v30.4s, v24.4s\n" - "str q16, [SP, #0xa0]\n" - "ldr q16, [SP, #0xb0]\n" - "fmla v16.4s, v17.4s, v24.4s\n" - "str q16, [SP, #0xb0]\n" - "ldr q31, [SP, #0xc0]\n" - "uzp1 v16.2d, v9.2d, v18.2d\n" - "uzp2 v19.2d, v9.2d, v18.2d\n" - "uzp1 v21.2d, v0.2d, v14.2d\n" - "uzp2 v17.2d, v0.2d, v14.2d\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "scvtf v21.4s, v21.4s, #0x4\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "fmla v31.4s, v16.4s, v24.4s\n" - "str q31, [SP, #0xc0]\n" - "ldr q16, [SP, #0xd0]\n" - "fmla v16.4s, v19.4s, v24.4s\n" - "str q16, [SP, #0xd0]\n" - "ldr q16, [SP, #0xe0]\n" - "fmla v16.4s, v21.4s, v24.4s\n" - "str q16, [SP, #0xe0]\n" - "ldr q16, [SP, #0xf0]\n" - "fmla v16.4s, v17.4s, v24.4s\n" - "str q16, [SP, #0xf0]\n" - "subs x24, x24, #0x1\n" - "bgt 3b\n" - "ld1 { v11.4s }, [x27]\n" - "ld1 { v10.4s }, [x23]\n" - "add x27, x27, #0x10\n" - "add x23, x23, #0x10\n" - "ld1 { v9.4s }, [x22]\n" - "ld1 { v8.4s }, [x21]\n" - "add x22, x22, #0x10\n" - "add x21, x21, #0x10\n" - "ldr q31, [SP, #0x0]\n" - "ldr q30, [SP, #0x10]\n" - "add x20, %x[clamp_vals], #0x4\n" - "cmp x10, #0x4\n" - 
"ldr q29, [SP, #0x20]\n" - "ldr q28, [SP, #0x30]\n" - "scvtf v11.4s, v11.4s\n" - "scvtf v10.4s, v10.4s\n" - "ldr q27, [SP, #0x40]\n" - "ldr q26, [SP, #0x50]\n" - "scvtf v9.4s, v9.4s\n" - "scvtf v8.4s, v8.4s\n" - "ldr q25, [SP, #0x60]\n" - "ldr q24, [SP, #0x70]\n" - "ldr q23, [SP, #0x80]\n" - "ldr q22, [SP, #0x90]\n" - "ldr q21, [SP, #0xa0]\n" - "ldr q20, [SP, #0xb0]\n" - "ldr q19, [SP, #0xc0]\n" - "ldr q18, [SP, #0xd0]\n" - "ldr q17, [SP, #0xe0]\n" - "ldr q16, [SP, #0xf0]\n" - "ldr q7, [x11, #0x0]\n" - "ldr q6, [x27, #0x0]\n" - "ldr q5, [x23, #0x0]\n" - "ldr q4, [x22, #0x0]\n" - "ldr q3, [x21, #0x0]\n" - "ldr q2, [x11, #0x10]\n" - "add x11, x11, #0x20\n" - "ld1r { v1.4s }, [%x[clamp_vals]]\n" - "ld1r { v0.4s }, [x20]\n" - "fmla v31.4s, v7.4s, v11.s[0]\n" - "fmla v30.4s, v7.4s, v11.s[1]\n" - "fmla v29.4s, v7.4s, v11.s[2]\n" - "fmla v28.4s, v7.4s, v11.s[3]\n" - "fmla v27.4s, v7.4s, v10.s[0]\n" - "fmla v26.4s, v7.4s, v10.s[1]\n" - "fmla v25.4s, v7.4s, v10.s[2]\n" - "fmla v24.4s, v7.4s, v10.s[3]\n" - "fmla v23.4s, v7.4s, v9.s[0]\n" - "fmla v22.4s, v7.4s, v9.s[1]\n" - "fmul v31.4s, v31.4s, v6.s[0]\n" - "fmla v21.4s, v7.4s, v9.s[2]\n" - "fmla v20.4s, v7.4s, v9.s[3]\n" - "fmul v30.4s, v30.4s, v6.s[1]\n" - "fmla v19.4s, v7.4s, v8.s[0]\n" - "fmla v18.4s, v7.4s, v8.s[1]\n" - "fmul v29.4s, v29.4s, v6.s[2]\n" - "fmla v17.4s, v7.4s, v8.s[2]\n" - "fmla v16.4s, v7.4s, v8.s[3]\n" - "fmul v28.4s, v28.4s, v6.s[3]\n" - "fmul v27.4s, v27.4s, v5.s[0]\n" - "fmul v26.4s, v26.4s, v5.s[1]\n" - "fmul v25.4s, v25.4s, v5.s[2]\n" - "fmul v24.4s, v24.4s, v5.s[3]\n" - "fmul v23.4s, v23.4s, v4.s[0]\n" - "fmul v22.4s, v22.4s, v4.s[1]\n" - "fmul v21.4s, v21.4s, v4.s[2]\n" - "fmul v20.4s, v20.4s, v4.s[3]\n" - "fmul v19.4s, v19.4s, v3.s[0]\n" - "fmul v18.4s, v18.4s, v3.s[1]\n" - "fmul v17.4s, v17.4s, v3.s[2]\n" - "fmul v16.4s, v16.4s, v3.s[3]\n" - "fadd v31.4s, v31.4s, v2.4s\n" - "fadd v30.4s, v30.4s, v2.4s\n" - "fadd v29.4s, v29.4s, v2.4s\n" - "fadd v28.4s, v28.4s, v2.4s\n" - "fadd v27.4s, v27.4s, v2.4s\n" - "fadd v26.4s, v26.4s, v2.4s\n" - "fadd v25.4s, v25.4s, v2.4s\n" - "fadd v24.4s, v24.4s, v2.4s\n" - "fadd v23.4s, v23.4s, v2.4s\n" - "fadd v22.4s, v22.4s, v2.4s\n" - "fadd v21.4s, v21.4s, v2.4s\n" - "fadd v20.4s, v20.4s, v2.4s\n" - "fadd v19.4s, v19.4s, v2.4s\n" - "fadd v18.4s, v18.4s, v2.4s\n" - "fadd v17.4s, v17.4s, v2.4s\n" - "fadd v16.4s, v16.4s, v2.4s\n" - "fmax v31.4s, v31.4s, v1.4s\n" - "fmax v30.4s, v30.4s, v1.4s\n" - "fmax v29.4s, v29.4s, v1.4s\n" - "fmax v28.4s, v28.4s, v1.4s\n" - "fmax v27.4s, v27.4s, v1.4s\n" - "fmax v26.4s, v26.4s, v1.4s\n" - "fmax v25.4s, v25.4s, v1.4s\n" - "fmax v24.4s, v24.4s, v1.4s\n" - "fmax v23.4s, v23.4s, v1.4s\n" - "fmax v22.4s, v22.4s, v1.4s\n" - "fmax v21.4s, v21.4s, v1.4s\n" - "fmax v20.4s, v20.4s, v1.4s\n" - "fmax v19.4s, v19.4s, v1.4s\n" - "fmax v18.4s, v18.4s, v1.4s\n" - "fmax v17.4s, v17.4s, v1.4s\n" - "fmax v16.4s, v16.4s, v1.4s\n" - "fmin v31.4s, v31.4s, v0.4s\n" - "fmin v30.4s, v30.4s, v0.4s\n" - "fmin v29.4s, v29.4s, v0.4s\n" - "fmin v28.4s, v28.4s, v0.4s\n" - "fmin v27.4s, v27.4s, v0.4s\n" - "fmin v26.4s, v26.4s, v0.4s\n" - "fmin v25.4s, v25.4s, v0.4s\n" - "fmin v24.4s, v24.4s, v0.4s\n" - "fmin v23.4s, v23.4s, v0.4s\n" - "fmin v22.4s, v22.4s, v0.4s\n" - "fmin v21.4s, v21.4s, v0.4s\n" - "fmin v20.4s, v20.4s, v0.4s\n" - "fmin v19.4s, v19.4s, v0.4s\n" - "fmin v18.4s, v18.4s, v0.4s\n" - "fmin v17.4s, v17.4s, v0.4s\n" - "fmin v16.4s, v16.4s, v0.4s\n" - "blt 9f\n" - "mov x20, %x[dst]\n" - "str q31, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q30, [x20, #0x0]\n" 
- "add x20, x20, %x[dst_stride_row]\n" - "str q29, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q28, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q27, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q26, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q24, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q21, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q20, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q17, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q16, [x20, #0x0]\n" - "b 14f\n" - "9:" // Partial output - "mov x28, %x[dst]\n" - "add x26, x28, %x[dst_stride_row], LSL #2\n" - "add x25, x26, %x[dst_stride_row], LSL #1\n" - "add x24, x26, %x[dst_stride_row]\n" - "add x23, x25, %x[dst_stride_row]\n" - "add x22, x28, %x[dst_stride_row], LSL #1\n" - "add x21, x28, %x[dst_stride_row]\n" - "add x20, x22, %x[dst_stride_row]\n" - "add x27, x23, %x[dst_stride_row]\n" - "tbz x10, #1, 10f\n" - "st1 { v24.d }[0], [x23], #0x8\n" - "st1 { v25.d }[0], [x25], #0x8\n" - "st1 { v26.d }[0], [x24], #0x8\n" - "st1 { v27.d }[0], [x26], #0x8\n" - "st1 { v28.d }[0], [x20], #0x8\n" - "st1 { v29.d }[0], [x22], #0x8\n" - "st1 { v30.d }[0], [x21], #0x8\n" - "st1 { v31.d }[0], [x28], #0x8\n" - "tbz x10, #0, 11f\n" - "st1 { v24.s }[2], [x23]\n" - "st1 { v25.s }[2], [x25]\n" - "st1 { v26.s }[2], [x24]\n" - "st1 { v27.s }[2], [x26]\n" - "st1 { v28.s }[2], [x20]\n" - "st1 { v29.s }[2], [x22]\n" - "st1 { v30.s }[2], [x21]\n" - "st1 { v31.s }[2], [x28]\n" - "b 11f\n" - "10:" // Output block 0: partial_1_0 - "st1 { v24.s }[0], [x23]\n" - "st1 { v25.s }[0], [x25]\n" - "st1 { v26.s }[0], [x24]\n" - "st1 { v27.s }[0], [x26]\n" - "st1 { v28.s }[0], [x20]\n" - "st1 { v29.s }[0], [x22]\n" - "st1 { v30.s }[0], [x21]\n" - "st1 { v31.s }[0], [x28]\n" - "11:" // Output block 0: Done - "add x26, x27, %x[dst_stride_row], LSL #2\n" - "add x25, x27, %x[dst_stride_row], LSL #1\n" - "add x24, x26, %x[dst_stride_row], LSL #1\n" - "add x23, x27, %x[dst_stride_row]\n" - "add x22, x25, %x[dst_stride_row]\n" - "add x21, x26, %x[dst_stride_row]\n" - "add x20, x24, %x[dst_stride_row]\n" - "tbz x10, #1, 12f\n" - "st1 { v16.d }[0], [x20], #0x8\n" - "st1 { v17.d }[0], [x24], #0x8\n" - "st1 { v18.d }[0], [x21], #0x8\n" - "st1 { v19.d }[0], [x26], #0x8\n" - "st1 { v20.d }[0], [x22], #0x8\n" - "st1 { v21.d }[0], [x25], #0x8\n" - "st1 { v22.d }[0], [x23], #0x8\n" - "st1 { v23.d }[0], [x27], #0x8\n" - "tbz x10, #0, 13f\n" - "st1 { v16.s }[2], [x20]\n" - "st1 { v17.s }[2], [x24]\n" - "st1 { v18.s }[2], [x21]\n" - "st1 { v19.s }[2], [x26]\n" - "st1 { v20.s }[2], [x22]\n" - "st1 { v21.s }[2], [x25]\n" - "st1 { v22.s }[2], [x23]\n" - "st1 { v23.s }[2], [x27]\n" - "b 13f\n" - "12:" // Output block 1: partial_1_0 - "st1 { v16.s }[0], [x20]\n" - "st1 { v17.s }[0], [x24]\n" - "st1 { v18.s }[0], [x21]\n" - "st1 { v19.s }[0], [x26]\n" - "st1 { v20.s }[0], [x22]\n" - "st1 { v21.s }[0], [x25]\n" - "st1 { v22.s }[0], [x23]\n" - "st1 { v23.s }[0], [x27]\n" - "13:" // Output block 1: Done - "14:" // Output stage exit - "subs x10, x10, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 2b\n" - "mov x20, #0x4\n" - "sub x12, x12, 
#0x10\n" - "cmp x12, #0x10\n" - "mov %x[dst], x9\n" - "madd %x[lhs_packed], x20, x13, %x[lhs_packed]\n" - "bge 1b\n" - "15:" // Row loop skip - "cbz x12, 25f\n" - "16:" // Row tail: Row loop - "mov x26, %x[rhs_packed]\n" - "mov x25, %x[n]\n" - "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" - "17:" // Row tail: Column loop - "movi v16.4s, #0x0\n" - "mov x27, %x[lhs_packed]\n" - "mov x21, %x[num_blocks]\n" - "str q16, [SP, #0x0]\n" - "str q16, [SP, #0x10]\n" - "str q16, [SP, #0x20]\n" - "str q16, [SP, #0x30]\n" - "18:" // Row tail: Block loop - "movi v7.4s, #0x0\n" - "movi v13.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "movi v29.4s, #0x0\n" - "movi v12.4s, #0x0\n" - "19:" // Row tail: Sub block loop - "ldr q0, [x26, #0x0]\n" - "ldr q31, [x26, #0x10]\n" - "movi v30.16b, #0xf0\n" - "subs x20, x20, #0x1\n" - "ldr q18, [x27, #0x0]\n" - "ldr q28, [x27, #0x10]\n" - "ldr q27, [x26, #0x20]\n" - "ldr q26, [x26, #0x30]\n" - "add x26, x26, #0x40\n" - "ldr q25, [x27, #0x20]\n" - "ldr q24, [x27, #0x30]\n" - "shl v23.16b, v0.16b, #0x4\n" - "shl v22.16b, v31.16b, #0x4\n" - "ldr q21, [x27, #0x40]\n" - "ldr q20, [x27, #0x50]\n" - "and v0.16b, v0.16b, v30.16b\n" - "and v31.16b, v31.16b, v30.16b\n" - "ldr q19, [x27, #0x60]\n" - "ldr q14, [x27, #0x70]\n" - "shl v17.16b, v27.16b, #0x4\n" - "shl v16.16b, v26.16b, #0x4\n" - ".inst 0x4e97a647 // smmla v7.4s, v18.16b, v23.16b\n" - ".inst 0x4e96a64d // smmla v13.4s, v18.16b, v22.16b\n" - "and v27.16b, v27.16b, v30.16b\n" - "add x27, x27, #0x80\n" - ".inst 0x4e97a79d // smmla v29.4s, v28.16b, v23.16b\n" - ".inst 0x4e96a78c // smmla v12.4s, v28.16b, v22.16b\n" - "and v26.16b, v26.16b, v30.16b\n" - ".inst 0x4e91a727 // smmla v7.4s, v25.16b, v17.16b\n" - ".inst 0x4e90a72d // smmla v13.4s, v25.16b, v16.16b\n" - ".inst 0x4e91a71d // smmla v29.4s, v24.16b, v17.16b\n" - ".inst 0x4e90a70c // smmla v12.4s, v24.16b, v16.16b\n" - ".inst 0x4e80a6a7 // smmla v7.4s, v21.16b, v0.16b\n" - ".inst 0x4e9fa6ad // smmla v13.4s, v21.16b, v31.16b\n" - ".inst 0x4e80a69d // smmla v29.4s, v20.16b, v0.16b\n" - ".inst 0x4e9fa68c // smmla v12.4s, v20.16b, v31.16b\n" - ".inst 0x4e9ba667 // smmla v7.4s, v19.16b, v27.16b\n" - ".inst 0x4e9aa66d // smmla v13.4s, v19.16b, v26.16b\n" - ".inst 0x4e9ba5dd // smmla v29.4s, v14.16b, v27.16b\n" - ".inst 0x4e9aa5cc // smmla v12.4s, v14.16b, v26.16b\n" - "bgt 19b\n" - "ldr d17, [x26, #0x0]\n" - "ldr q21, [SP, #0x0]\n" - "uzp1 v16.2d, v7.2d, v13.2d\n" - "uzp2 v20.2d, v7.2d, v13.2d\n" - "uzp1 v19.2d, v29.2d, v12.2d\n" - "uzp2 v18.2d, v29.2d, v12.2d\n" - "add x26, x26, #0x8\n" - "shll v17.4s, v17.4h, #0x10\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v21.4s, v16.4s, v17.4s\n" - "str q21, [SP, #0x0]\n" - "ldr q16, [SP, #0x10]\n" - "fmla v16.4s, v20.4s, v17.4s\n" - "str q16, [SP, #0x10]\n" - "ldr q16, [SP, #0x20]\n" - "fmla v16.4s, v19.4s, v17.4s\n" - "str q16, [SP, #0x20]\n" - "ldr q16, [SP, #0x30]\n" - "fmla v16.4s, v18.4s, v17.4s\n" - "str q16, [SP, #0x30]\n" - "subs x21, x21, #0x1\n" - "bgt 18b\n" - "ld1 { v21.4s }, [x27]\n" - "ldr q31, [SP, #0x0]\n" - "add x27, x27, #0x10\n" - "add x20, %x[clamp_vals], #0x4\n" - "ldr q30, [SP, #0x10]\n" - "ldr q29, [SP, #0x20]\n" - "cmp x25, #0x4\n" - "ldr q28, [SP, #0x30]\n" - "ldr q20, [x26, #0x0]\n" - "ldr q19, [x27, #0x0]\n" - "ldr q18, [x26, #0x10]\n" - "scvtf v21.4s, v21.4s\n" - "add x26, x26, #0x20\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v16.4s }, [x20]\n" - "fmla v31.4s, v20.4s, v21.s[0]\n" - "fmla 
v30.4s, v20.4s, v21.s[1]\n" - "fmla v29.4s, v20.4s, v21.s[2]\n" - "fmla v28.4s, v20.4s, v21.s[3]\n" - "fmul v31.4s, v31.4s, v19.s[0]\n" - "fmul v30.4s, v30.4s, v19.s[1]\n" - "fadd v31.4s, v31.4s, v18.4s\n" - "fmul v29.4s, v29.4s, v19.s[2]\n" - "fmul v28.4s, v28.4s, v19.s[3]\n" - "fadd v30.4s, v30.4s, v18.4s\n" - "fmax v31.4s, v31.4s, v17.4s\n" - "fadd v29.4s, v29.4s, v18.4s\n" - "fadd v28.4s, v28.4s, v18.4s\n" - "fmax v30.4s, v30.4s, v17.4s\n" - "fmin v31.4s, v31.4s, v16.4s\n" - "fmax v29.4s, v29.4s, v17.4s\n" - "fmax v28.4s, v28.4s, v17.4s\n" - "fmin v30.4s, v30.4s, v16.4s\n" - "fmin v29.4s, v29.4s, v16.4s\n" - "fmin v28.4s, v28.4s, v16.4s\n" - "blt 21f\n" - "mov x20, %x[dst]\n" - "cmp x12, #0x1\n" - "str q31, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 24f\n" - "cmp x12, #0x2\n" - "str q30, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 24f\n" - "cmp x12, #0x3\n" - "str q29, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 24f\n" - "str q28, [x20, #0x0]\n" - "b 24f\n" - "21:" // Row tail: Partial output - "mov x23, %x[dst]\n" - "cmp x12, #0x1\n" - "add x22, x23, %x[dst_stride_row]\n" - "csel x22, x22, x23, GT\n" - "cmp x12, #0x2\n" - "add x21, x23, %x[dst_stride_row], LSL #1\n" - "csel x21, x21, x22, GT\n" - "cmp x12, #0x3\n" - "add x20, x21, %x[dst_stride_row]\n" - "csel x20, x20, x21, GT\n" - "tbz x25, #1, 22f\n" - "st1 { v28.d }[0], [x20], #0x8\n" - "st1 { v29.d }[0], [x21], #0x8\n" - "st1 { v30.d }[0], [x22], #0x8\n" - "st1 { v31.d }[0], [x23], #0x8\n" - "tbz x25, #0, 23f\n" - "st1 { v28.s }[2], [x20]\n" - "st1 { v29.s }[2], [x21]\n" - "st1 { v30.s }[2], [x22]\n" - "st1 { v31.s }[2], [x23]\n" - "b 23f\n" - "22:" // Row tail: Output block 0: partial_1_0 - "st1 { v28.s }[0], [x20]\n" - "st1 { v29.s }[0], [x21]\n" - "st1 { v30.s }[0], [x22]\n" - "st1 { v31.s }[0], [x23]\n" - "23:" // Row tail: Output block 0: Done - "24:" // Row tail: Output stage exit - "subs x25, x25, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 17b\n" - "subs x12, x12, #0x4\n" - "add %x[lhs_packed], %x[lhs_packed], x13\n" - "mov %x[dst], x24\n" - "bgt 16b\n" - "25:" // Row tail: Row loop skip - "add SP, SP, #0x100\n" - : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) - : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), - [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v2", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v31", "v4", "v5", - "v6", "v7", "v8", "v9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", - "x27", "x28", "x9"); - } + const size_t num_subblocks = bl / 32; + const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + args.num_subblocks = num_subblocks; + + kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(&args); } -#endif // Architectural feature check + +#endif // Architectural features check. 
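The conversion above is the whole of the refactoring pattern: the C entry point now only validates its arguments, fills a KernelArgs structure, and hands it to the external assembly routine, so the public API and its offset helpers are unchanged. A minimal caller-side sketch of that API (not part of the patch: the run_matmul_tiles helper name, the buffer variables, and the effectively-no-op FLT_MAX clamp bounds are illustrative assumptions; the packed buffers are assumed to come from the packing micro-kernels listed in the header):

    #include <float.h>
    #include <stddef.h>

    #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"

    // Drive the micro-kernel tile by tile. Starting indices must be multiples
    // of m_step/n_step; trailing partial tiles are passed through as smaller
    // m/n values, which the micro-kernel supports.
    static void run_matmul_tiles(
        const void* lhs_packed, const void* rhs_packed, float* dst,
        size_t m, size_t n, size_t k, size_t bl, size_t dst_stride_row) {
        const size_t m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm();
        const size_t n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm();

        for (size_t m_idx = 0; m_idx < m; m_idx += m_step) {
            for (size_t n_idx = 0; n_idx < n; n_idx += n_step) {
                // Byte offsets into the packed operands and the destination.
                const size_t lhs_off = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(m_idx, k);
                const size_t rhs_off = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(n_idx, k, bl);
                const size_t dst_off = kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(m_idx, n_idx, dst_stride_row);
                const size_t m_tile = (m - m_idx) < m_step ? (m - m_idx) : m_step;
                const size_t n_tile = (n - n_idx) < n_step ? (n - n_idx) : n_step;

                kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(
                    m_tile, n_tile, k, bl,
                    (const char*)lhs_packed + lhs_off,
                    (const char*)rhs_packed + rhs_off,
                    (float*)((char*)dst + dst_off),
                    dst_stride_row, sizeof(float),
                    -FLT_MAX, FLT_MAX);  // clamp disabled for the example
            }
        }
    }

Keeping the argument marshalling in C also means the compiler never sees GCC-style inline assembly; only the .S files need assembling, which is how the series adds MSVC support for these micro-kernels.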
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h index f5c9cd29..37186f89 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h @@ -1,21 +1,22 @@ - // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // + #pragma once #include #ifdef __cplusplus extern "C" { -#endif +#endif // __cplusplus /// Micro-kernel dependencies /// -/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix +/// -# @ref kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS NxK matrix. +/// -# @ref kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS KxN matrix. /// -------------------------------------------------- @@ -38,7 +39,7 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( /// @return the mr value size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void); -/// Gets the nr value, which must be used to pack the RHS matrix +/// Gets the nr value, which must be used to pack the RHS matrix. /// /// @return the nr value size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void); @@ -54,13 +55,14 @@ size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void); /// Gets the offset in bytes for the packed LHS matrix, -/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values. +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. /// /// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. /// -/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 16 +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. /// @param[in] k Total number of columns in the LHS matrix (not packed). /// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. /// /// @return the offset in bytes to the packed LHS matrix size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( @@ -68,9 +70,9 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32 size_t k); // /// Gets the offset in bytes for the packed RHS matrix, -/// which contains the packed Signed 4-bit quantized symmetric per-channel (qsi4c32) values. +/// which contains the packed Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) values. /// -/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. 
 /// @param[in] k The common dimension between the LHS and RHS matrix (K).
 /// It must be a multiple of the block length (bl).
 /// @param[in] bl Block length. It must be a multiple of 32.
@@ -83,8 +85,8 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32
 /// Gets the offset in bytes for the DST matrix
 ///
-/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 16.
-/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 4.
+/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step.
 /// @param[in] dst_stride The number of bytes in in each row of the DST matrix
 ///
 /// @return the DST offset in bytes
@@ -105,26 +107,26 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8m
 /// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
 ///
-/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
-/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4c32) and packed.
-/// Output tile: (rows x cols) = 8 x 4
-/// Accumulation performed in a single for loop: 32
-/// Extension used: i8mm
+/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed.
+/// RHS matrix: Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) and packed.
+/// Output tile: (rows x cols) = m_step x n_step.
+///
+/// Note: Please refer to the get functions for m_step and n_step for the exact values.
+///
+/// Features used: i8mm
 ///
 /// @param[in] m The number of output rows written.
 /// @param[in] n The number of output columns written.
 /// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
 /// It must be a multiple of the block length (bl).
-/// @param[in] bl Block length. It must be a multiple of 32.
-/// @param[in] lhs_packed The LHS packed matrix.
-/// When the activation are dynamically quantized, you can obtain this matrix
-/// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
-/// both the dynamic quantization to 8-bit and activation packing in a single step.
-/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref
-/// kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0
+/// @param[in] bl Block length. It must be a multiple of 32.
+/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the
+/// top of this file.
+/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the
+/// top of this file.
 /// @param[out] dst The DST matrix.
 /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
-/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float).
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes.
 /// @param[in] scalar_min Min value used to clamp the final result.
 /// @param[in] scalar_max Max value used to clamp the final result.
 void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(
@@ -142,4 +144,4 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(
 #ifdef __cplusplus
 }
-#endif
+#endif // __cplusplus
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c
new file mode 100644
index 00000000..fa7add7e
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c
@@ -0,0 +1,185 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8)
+#error "I8mm extension required to compile this micro-kernel"
+#else // Architectural features check.
+
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+typedef struct {
+    float* dst;
+    const void* lhs_packed;
+    const void* rhs_packed;
+    const float* clamp_vals;
+    size_t dst_stride_row;
+    size_t m;
+    size_t n;
+    size_t num_blocks;
+    size_t num_subblocks;
+} KernelArgs;
+
+void kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(KernelArgs* args_ptr);
+
+// Compute args
+static const size_t kai_m_step = 16;
+static const size_t kai_n_step = 4;
+// Packing args
+static const size_t kai_mr = 4;
+static const size_t kai_nr = 4;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
+static const size_t kai_num_bytes_qvalue_lhs = 1;
+static const size_t kai_num_bytes_multiplier_lhs = 4;
+static const size_t kai_num_bytes_zp_lhs = 4;
+// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
+// asymmetric))
+static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
+static const size_t kai_num_bytes_multiplier_rhs = 2;
+static const size_t kai_num_bytes_rsum_rhs = 4;
+// DST format args
+static const size_t kai_num_bytes_dst_value = 4;
+// Extra args
+static const size_t kai_num_bytes_bias = 4;
+static const size_t kai_k_multiple_of = 32;
+static const size_t kai_bl = 32;
+
+inline static size_t kai_get_k_roundedup(size_t k) {
+    return kai_roundup(k, kai_k_multiple_of);
+}
+
+inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
+    KAI_ASSUME((bl % kai_bl) == 0);
+    size_t num_bytes_per_block_rhs = (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
+    return num_bytes_per_block_rhs;
+}
+
+inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSUME((bl % kai_bl) == 0);
+
+    return kai_roundup(k, bl) / bl;
+}
+
+inline static size_t kai_get_lhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_get_k_roundedup(k);
+    size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
+    // Since the LHS matrix is asymmetric with per-row quantization, we must include
+    // the number of bytes to hold the zero point value
+    lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs;
+
+    return lhs_packed_stride;
+}
+
+inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
+    KAI_ASSUME((bl % kai_bl) == 0);
+
+    const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
+
+    size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row);
+    // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include
+    // the number of bytes for the reduction sum
+    rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs;
+    // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
+    rhs_packed_stride += kai_nr * kai_num_bytes_bias;
+
+    return rhs_packed_stride;
+}
+
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void) {
+    return kai_m_step;
+}
+
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void) {
+    return kai_n_step;
+}
+
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void) {
+    return kai_mr;
+}
+
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void) {
+    return kai_nr;
+}
+
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void) {
+    return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void) {
+    return kai_sr;
+}
+
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
+    size_t m_idx, size_t k) {
+    KAI_ASSUME((m_idx % kai_m_step) == 0);
+
+    return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k);
+}
+
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
+    size_t n_idx, size_t k, size_t bl) {
+    KAI_ASSUME((k % bl) == 0);
+    KAI_ASSUME((n_idx % kai_n_step) == 0);
+
+    return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl);
+}
+
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
+    size_t m_idx, size_t n_idx, size_t dst_stride) {
+    KAI_ASSUME((m_idx % kai_m_step) == 0);
+    KAI_ASSUME((n_idx % kai_n_step) == 0);
+
+    return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
+}
+
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(size_t m, size_t n) {
+    return m * n * kai_num_bytes_dst_value;
+}
+
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
+    size_t m, //
+    size_t n, //
+    size_t k, //
+    size_t bl, //
+    const void* restrict lhs_packed, //
+    const void* restrict rhs_packed, //
+    float* restrict dst, // NOLINT(readability-non-const-parameter)
+    size_t dst_stride_row, //
+    size_t dst_stride_col, //
+    float scalar_min, //
+    float scalar_max) {
+    KAI_ASSUME(dst_stride_col == sizeof(float));
+    KAI_ASSUME((k % bl) == 0);
+    KAI_ASSUME((bl % kai_bl) == 0);
+
+    if (m == 0) {
+        return;
+    }
+    const size_t num_subblocks = bl / 32;
+    const size_t num_blocks = kai_get_num_blocks_per_row(k, bl);
+    const float clamp_vals[2] = {scalar_min, scalar_max};
+
+    KernelArgs args;
+
+    args.dst = dst;
+    args.lhs_packed = lhs_packed;
+    args.rhs_packed = rhs_packed;
+    args.clamp_vals = clamp_vals;
+    args.dst_stride_row = dst_stride_row;
+    args.m = m;
+    args.n = n;
+    args.num_blocks = num_blocks;
+    args.num_subblocks = num_subblocks;
+
+    kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(&args);
+}
+
+#endif // Architectural features check.
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.h
new file mode 100644
index 00000000..4a11aa0b
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.h
@@ -0,0 +1,147 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/// Micro-kernel dependencies
+///
+/// -# @ref kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step.
+/// -# @ref kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS NxK matrix.
+/// -# @ref kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS KxN matrix.
+
+/// --------------------------------------------------
+
+/// Gets the m step value.
+/// The micro-kernel can process any M values. However, the starting M index to
+/// be processed must be a multiple of m step.
+///
+/// @return the m step value
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void);
+
+/// Gets the n step value.
+/// The micro-kernel can process any N values. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @return the n step
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void);
+
+/// Gets the mr value, which must be used to pack the LHS matrix
+///
+/// @return the mr value
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void);
+
+/// Gets the nr value, which must be used to pack the RHS matrix.
+///
+/// @return the nr value
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void);
+
+/// Gets the kr value, which must be used to pack the LHS and RHS matrices
+///
+/// @return the kr value
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void);
+
+/// Gets the sr value, which must be used to pack the LHS and RHS matrices
+///
+/// @return the sr value
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void);
+
+/// Gets the offset in bytes for the packed LHS matrix,
+/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values.
+///
+/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step.
+/// @param[in] k Total number of columns in the LHS matrix (not packed).
+/// It must be a multiple of the block length (bl).
+/// @param[in] bl Block length. It must be a multiple of 32.
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
+    size_t m_idx, //
+    size_t k); //
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) values.
+///
+/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step.
+/// @param[in] k The common dimension between the LHS and RHS matrix (K).
+/// It must be a multiple of the block length (bl).
+/// @param[in] bl Block length. It must be a multiple of 32.
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
+    size_t n_idx, //
+    size_t k, //
+    size_t bl); //
+
+/// Gets the offset in bytes for the DST matrix
+///
+/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
+    size_t m_idx, //
+    size_t n_idx, //
+    size_t dst_stride); //
+
+/// Gets the size in bytes for the destination (DST) matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
+    size_t m, //
+    size_t n); //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed.
+/// RHS matrix: Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) and packed.
+/// Output tile: (rows x cols) = m_step x n_step.
+///
+/// Note: Please refer to the get functions for m_step and n_step for the exact values.
+///
+/// Features used: i8mm
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] n The number of output columns written.
+/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
+/// It must be a multiple of the block length (bl).
+/// @param[in] bl Block length. It must be a multiple of 32.
+/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the
+/// top of this file.
+/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the
+/// top of this file.
+/// @param[out] dst The DST matrix.
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes.
+/// @param[in] scalar_min Min value used to clamp the final result.
+/// @param[in] scalar_max Max value used to clamp the final result.
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
+    size_t m, //
+    size_t n, //
+    size_t k, //
+    size_t bl, //
+    const void* lhs_packed, //
+    const void* rhs_packed, //
+    float* dst, //
+    size_t dst_stride_row, //
+    size_t dst_stride_col, //
+    float scalar_min, //
+    float scalar_max); //
+
+#ifdef __cplusplus
+}
+#endif // __cplusplus
--
GitLab

From 6eda4ced7401704d76c39ba4529b3bb61dd1904f Mon Sep 17 00:00:00 2001
From: Michael Kozlov
Date: Fri, 24 Jan 2025 11:18:11 +0000
Subject: [PATCH 08/15] Update build scripts for new assembly file kernels

- Add new assembly kernels to Bazel build
- Move assembly kernels out of subfolder and update CMake build

Signed-off-by: Michael Kozlov
---
 CMakeLists.txt | 16 +++++------
 kai/ukernels/matmul/BUILD.bazel | 28 +++++++++++++++++++
 ...8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S | 0
 ...p1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S | 0
 ...p1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S | 0
 ...dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S | 0
 ...dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S | 0
 ..._qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S | 0
 ...8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S | 0
 ...8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S | 0
 10 files changed, 36 insertions(+), 8 deletions(-)
 rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/{asm => }/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S (100%)
 rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/{asm => }/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S (100%)
 rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/{asm => }/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S (100%)
 rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/{asm => }/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S (100%)
 rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/{asm => }/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S (100%)
 rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/{asm => }/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S (100%)
 rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/{asm => }/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S (100%)
 rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/{asm => }/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S (100%)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 94059adc..6a09a912 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -142,10 +142,10 @@ set(KLEIDIAI_FILES_NEON_DOTPROD
 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c
kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S ) set(KLEIDIAI_FILES_NEON_I8MM @@ -159,10 +159,10 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S ) set(KLEIDIAI_FILES_SME diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 42a3f441..294b52ea 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -86,6 +86,13 @@ DOTPROD_KERNELS = [ "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod", ] +DOTPROD_KERNELS_ASM = [ + "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod", + "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod", + "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod", + "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod", +] + # buildifier: keep sorted I8MM_KERNELS = [ 
"matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm", @@ -100,6 +107,13 @@ I8MM_KERNELS = [ "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm", ] +I8MM_KERNELS_ASM = [ + "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm", + "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm", + "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm", + "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", +] + # buildifier: keep sorted SME_KERNELS = [ "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", @@ -183,6 +197,12 @@ kai_c_library( textual_hdrs = [ukernel + ".h" for ukernel in DOTPROD_KERNELS], ) +kai_c_library( + name = "dotprod_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in DOTPROD_KERNELS_ASM], + cpu_uarch = kai_cpu_dotprod(), +) + kai_c_library( name = "i8mm_impl", srcs = [ukernel + ".c" for ukernel in I8MM_KERNELS], @@ -190,6 +210,12 @@ kai_c_library( textual_hdrs = [ukernel + ".h" for ukernel in I8MM_KERNELS], ) +kai_c_library( + name = "i8mm_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in I8MM_KERNELS_ASM], + cpu_uarch = kai_cpu_i8mm(), +) + kai_c_library( name = "sme_impl", srcs = [ukernel + ".c" for ukernel in SME_KERNELS], @@ -210,9 +236,11 @@ kai_c_library( deps = [ ":bf16_impl", ":dotprod_impl", + ":dotprod_impl_asm", ":fp16_bf16_impl", ":fp16_impl", ":i8mm_impl", + ":i8mm_impl_asm", ":interface", ":neon_impl", ":neon_impl_asm", diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S 
b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/asm/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S -- GitLab From 678f5b5a6d17d091aae6f420d07569f97fa447f7 Mon Sep 17 00:00:00 2001 From: Michael Kozlov Date: Fri, 24 Jan 2025 11:45:02 +0000 Subject: [PATCH 09/15] Remove subblocks arg Signed-off-by: Michael Kozlov --- ...l_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c | 2 -- 1 file changed, 2 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c index fa7add7e..503e207c 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c @@ -163,7 +163,6 @@ void 
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm( if (m == 0) { return; } - const size_t num_subblocks = bl / 32; const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); const float clamp_vals[2] = {scalar_min, scalar_max}; @@ -177,7 +176,6 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm( args.m = m; args.n = n; args.num_blocks = num_blocks; - args.num_subblocks = num_subblocks; kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(&args); } -- GitLab From d16442e2da067fa909fe6a40c329c1419bf258d9 Mon Sep 17 00:00:00 2001 From: Michael Kozlov Date: Fri, 24 Jan 2025 12:16:44 +0000 Subject: [PATCH 10/15] Update examples build script Signed-off-by: Michael Kozlov --- examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt index 76fdc7f8..28bbd67c 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt @@ -6,6 +6,8 @@ cmake_minimum_required(VERSION 3.16) +enable_language(ASM) + set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(KLEIDIAI_PATH ../../) @@ -26,11 +28,17 @@ add_executable(matmul_clamp_f32_qai8dxp_qsi4c32p ${MATMUL_PACK_PATH}/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_qai8dxp_f32.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S ) target_compile_options(matmul_clamp_f32_qai8dxp_qsi4c32p -- GitLab From 9015fb051b9b11711ec221d9c327b7010a3664d6 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Mon, 27 Jan 2025 14:30:51 +0000 Subject: [PATCH 11/15] Update Changelog Signed-off-by: Anitha Raj --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bd9fbaf..f50c025a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release - Update FP16 example to use NHWC input +- Add MSVC support for kai_matmul_clamp_f32_qai8dxp_qsi4c32p micro-kernels. - Fixes: - Fix compilation warnings detected by `-Wcast-qual -Wmissing-prototypes -Wstrict-prototypes -Woverlength-strings` compiler options. - Support compiling the project with the above compilation options enabled. 
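Before the bl = 32 rework in the next patch, it helps to pin down the packed-stride arithmetic that the opt32 source earlier in the series implements. A standalone sketch with an illustrative problem size (k = 256 is an assumption for the example; the constants mirror kai_mr, kai_nr, and the per-value byte counts from that file, and the numbers in the comments are just the worked-out results):

    #include <stdio.h>
    #include <stddef.h>

    int main(void) {
        const size_t mr = 4, nr = 4;    // kai_mr / kai_nr from the opt32 source
        const size_t k = 256, bl = 32;  // illustrative; k must be a multiple of bl

        // LHS packed stride: mr rows of k int8 values, each row followed by a
        // 4-byte float scale and a 4-byte zero point (asymmetric, per-row).
        const size_t lhs_stride = mr * (k * 1 + 4) + mr * 4;  // 4*(256+4) + 16 = 1056

        // RHS packed stride: per 32-value block, 16 bytes of packed 4-bit values
        // (two per byte) plus a 2-byte fp16 scale; per row, a 4-byte reduction
        // sum and a 4-byte bias are appended.
        const size_t num_blocks = k / bl;           // 256/32 = 8
        const size_t bytes_per_block = bl / 2 + 2;  // 16 + 2 = 18
        const size_t rhs_stride = nr * (bytes_per_block * num_blocks) + nr * 4 + nr * 4;  // 576 + 16 + 16 = 608

        printf("lhs_packed_stride = %zu, rhs_packed_stride = %zu\n", lhs_stride, rhs_stride);
        return 0;
    }

The per-block fp16 scale and the per-row reduction sum and bias are why the RHS stride is larger than the nr * k / 2 bytes the quantized values alone would occupy.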
-- GitLab From 392e0d5c090ed7875f90ff0e0a336809d2d8410c Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Mon, 27 Jan 2025 17:42:59 +0000 Subject: [PATCH 12/15] Update 16x4 kernel with optimizations for bl=32 Remove the additional opt32 micro-kernel file. Add the generic and the bl=32 optimized micro-kernel assemblies to the same file for 16x4 kernels. Update CMakeLists Signed-off-by: Anitha Raj --- CMakeLists.txt | 1 - ..._qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h | 4 +- ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 2 - ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h | 2 - ...qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c | 11 +- ...qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h | 6 +- ...dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S | 696 ++++++++++++++++++ ...qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c | 7 +- ...qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h | 2 - ...dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S | 636 ++++++++++++++++ ...p4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c | 183 ----- ...p4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.h | 147 ---- ..._qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S | 670 ----------------- ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h | 2 - ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h | 2 - 15 files changed, 1349 insertions(+), 1022 deletions(-) delete mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c delete mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.h delete mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index 6a09a912..19675fbf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -162,7 +162,6 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S ) set(KLEIDIAI_FILES_SME diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h index 0d7ca485..540f54fc 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -61,8 +61,6 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod(void) /// /// @param[in] m_idx Row index in the LHS matrix (not packed). It must be 1. /// @param[in] k Total number of columns in the LHS matrix (not packed). -/// It must be a multiple of the block length (bl). -/// @param[in] bl Block length. It must be a multiple of 32. 
/// /// @return the offset in bytes to the packed LHS matrix size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod( diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h index 1685c30b..a865c16d 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h @@ -61,8 +61,6 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(vo /// /// @param[in] m_idx Row index in the LHS matrix (not packed). It must be 1. /// @param[in] k Total number of columns in the LHS matrix (not packed). -/// It must be a multiple of the block length (bl). -/// @param[in] bl Block length. It must be a multiple of 32. /// /// @return the offset in bytes to the packed LHS matrix size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h index 225add8a..f23e05cf 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h @@ -61,8 +61,6 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(vo /// /// @param[in] m_idx Row index in the LHS matrix (not packed). It must be 1. /// @param[in] k Total number of columns in the LHS matrix (not packed). -/// It must be a multiple of the block length (bl). -/// @param[in] bl Block length. It must be a multiple of 32. /// /// @return the offset in bytes to the packed LHS matrix size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c index b4faf0ec..18930d7f 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c @@ -3,10 +3,6 @@ // // SPDX-License-Identifier: Apache-2.0 // - -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - #if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) #error "Dotprod extension required to compile this micro-kernel" #else // Architectural features check. 
@@ -31,6 +27,7 @@ typedef struct { } KernelArgs; void kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod(KernelArgs* args_ptr); +void kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_opt32_neon_dotprod(KernelArgs* args_ptr); // Compute args static const size_t kai_m_step = 16; @@ -183,7 +180,11 @@ void kai_run_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod( args.num_blocks = num_blocks; args.num_subblocks = num_subblocks; - kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod(&args); + if (bl == 32) { + kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_opt32_neon_dotprod(&args); + } else { + kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod(&args); + } } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h index 0898f30d..8c257101 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -72,6 +72,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_ne /// /// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. /// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// It must be a multiple of the block length (bl). /// @param[in] bl Block length. It must be a multiple of 32. /// /// @return the offset in bytes to the packed RHS matrix @@ -108,13 +109,14 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotpro /// RHS matrix: Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) and packed. /// Output tile: (rows x cols) = m_step x n_step. /// -/// Note: Please, refer to the get functions for m_step and n_step for the exact values. +/// Note: Please refer to the get functions for m_step and n_step for the exact values. /// /// Features used: dotprod /// /// @param[in] m The number of output rows written. /// @param[in] n The number of output columns written. /// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// It must be a multiple of the block length (bl). /// @param[in] bl Block length. Block length. It must be a multiple of 32. /// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the /// top of this file. 
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S index dbb7a11d..1a448c5f 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S @@ -802,3 +802,699 @@ KAI_ASM_LABEL(label_25) // Row tail: Row loop skip KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod) KAI_ASM_END + +// Optimized kernel for bl = 32 + + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_opt32_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_opt32_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_opt32_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_opt32_neon_dotprod) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x6, #0x80 + mov x21, #0x20 + ldr x20, [x0, #0x28] + ldr x7, [x0, #0x38] + ldr x8, [x0, #0x8] + ldr x17, [x0, #0x10] + ldr x16, [x0, #0x30] + ldr x15, [x0, #0x0] + mov x14, x20 + ldr x13, [x0, #0x20] + madd x6, x7, x6, x21 + ldr x12, [x0, #0x18] + cmp x14, #0x10 + blt label_opt_14 +KAI_ASM_LABEL(label_opt_1) // Row loop + mov x11, x17 + mov x10, x16 + add x9, x15, x13, LSL #4 +KAI_ASM_LABEL(label_opt_2) // Column loop + mov x27, x8 + movi v31.16b, #0x0 + movi v30.16b, #0x0 + mov x20, x7 + movi v29.16b, #0x0 + movi v28.16b, #0x0 + movi v27.16b, #0x0 + movi v26.16b, #0x0 + add x23, x27, x6 + add x22, x23, x6 + movi v25.16b, #0x0 + movi v24.16b, #0x0 + add x21, x22, x6 + movi v23.16b, #0x0 + movi v22.16b, #0x0 + movi v21.16b, #0x0 + movi v20.16b, #0x0 + movi v19.16b, #0x0 + movi v18.16b, #0x0 + movi v17.16b, #0x0 + movi v16.16b, #0x0 +KAI_ASM_LABEL(label_opt_3) // Block loop + ldr q3, [x11, #0x0] + ldr q2, [x27, #0x0] + movi v5.4s, #0x0 + movi v4.4s, #0x0 + ldr q0, [x11, #0x10] + ldr q1, [x27, #0x10] + movi v6.4s, #0x0 + movi v11.4s, #0x0 + ldr q15, [x11, #0x20] + ldr q14, [x27, #0x20] + movi v7.16b, #0xf0 + ldr q13, [x11, #0x30] + ldr q8, [x27, #0x30] + shl v12.16b, v3.16b, #0x4 + add x11, x11, #0x40 + ldr q9, [x27, #0x40] + ldr q10, [x27, #0x50] + and v3.16b, v3.16b, v7.16b + KAI_ASM_INST(0x4f82e185) // sdot v5.4s, v12.16b, v2.4b[0] + KAI_ASM_INST(0x4fa2e184) // sdot v4.4s, v12.16b, v2.4b[1] + KAI_ASM_INST(0x4f82e986) // sdot v6.4s, v12.16b, v2.4b[2] + KAI_ASM_INST(0x4fa2e98b) // sdot v11.4s, v12.16b, v2.4b[3] + shl v2.16b, v0.16b, #0x4 + and v0.16b, v0.16b, v7.16b + KAI_ASM_INST(0x4f81e045) // sdot v5.4s, v2.16b, v1.4b[0] + KAI_ASM_INST(0x4fa1e044) // sdot v4.4s, v2.16b, v1.4b[1] + KAI_ASM_INST(0x4f81e846) // sdot v6.4s, v2.16b, v1.4b[2] + KAI_ASM_INST(0x4fa1e84b) // sdot v11.4s, v2.16b, v1.4b[3] + shl v1.16b, v15.16b, #0x4 + and v15.16b, v15.16b, v7.16b + KAI_ASM_INST(0x4f8ee025) // sdot v5.4s, v1.16b, v14.4b[0] + KAI_ASM_INST(0x4faee024) // sdot v4.4s, v1.16b, v14.4b[1] + KAI_ASM_INST(0x4f8ee826) // sdot v6.4s, v1.16b, v14.4b[2] + KAI_ASM_INST(0x4faee82b) // sdot v11.4s, v1.16b, v14.4b[3] + shl v14.16b, v13.16b, #0x4 + and v13.16b, 
v13.16b, v7.16b + ldr q7, [x27, #0x60] + KAI_ASM_INST(0x4f88e1c5) // sdot v5.4s, v14.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e1c4) // sdot v4.4s, v14.16b, v8.4b[1] + KAI_ASM_INST(0x4f88e9c6) // sdot v6.4s, v14.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8e9cb) // sdot v11.4s, v14.16b, v8.4b[3] + ldr q8, [x27, #0x70] + add x27, x27, #0x80 + KAI_ASM_INST(0x4f89e065) // sdot v5.4s, v3.16b, v9.4b[0] + KAI_ASM_INST(0x4fa9e064) // sdot v4.4s, v3.16b, v9.4b[1] + KAI_ASM_INST(0x4f89e866) // sdot v6.4s, v3.16b, v9.4b[2] + KAI_ASM_INST(0x4fa9e86b) // sdot v11.4s, v3.16b, v9.4b[3] + ldr d9, [x11, #0x0] + add x11, x11, #0x8 + KAI_ASM_INST(0x4f8ae005) // sdot v5.4s, v0.16b, v10.4b[0] + KAI_ASM_INST(0x4faae004) // sdot v4.4s, v0.16b, v10.4b[1] + shll v9.4s, v9.4h, #0x10 + KAI_ASM_INST(0x4f8ae806) // sdot v6.4s, v0.16b, v10.4b[2] + KAI_ASM_INST(0x4faae80b) // sdot v11.4s, v0.16b, v10.4b[3] + KAI_ASM_INST(0x4f87e1e5) // sdot v5.4s, v15.16b, v7.4b[0] + KAI_ASM_INST(0x4fa7e1e4) // sdot v4.4s, v15.16b, v7.4b[1] + KAI_ASM_INST(0x4f87e9e6) // sdot v6.4s, v15.16b, v7.4b[2] + KAI_ASM_INST(0x4fa7e9eb) // sdot v11.4s, v15.16b, v7.4b[3] + KAI_ASM_INST(0x4f88e1a5) // sdot v5.4s, v13.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e1a4) // sdot v4.4s, v13.16b, v8.4b[1] + KAI_ASM_INST(0x4f88e9a6) // sdot v6.4s, v13.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8e9ab) // sdot v11.4s, v13.16b, v8.4b[3] + scvtf v5.4s, v5.4s, #0x4 + scvtf v4.4s, v4.4s, #0x4 + fmla v31.4s, v5.4s, v9.4s + scvtf v6.4s, v6.4s, #0x4 + scvtf v11.4s, v11.4s, #0x4 + fmla v30.4s, v4.4s, v9.4s + fmla v29.4s, v6.4s, v9.4s + fmla v28.4s, v11.4s, v9.4s + ldr q8, [x23, #0x0] + ldr q5, [x23, #0x10] + movi v11.4s, #0x0 + movi v10.4s, #0x0 + ldr q7, [x23, #0x20] + movi v4.4s, #0x0 + movi v6.4s, #0x0 + KAI_ASM_INST(0x4f88e18b) // sdot v11.4s, v12.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e18a) // sdot v10.4s, v12.16b, v8.4b[1] + KAI_ASM_INST(0x4f88e984) // sdot v4.4s, v12.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8e986) // sdot v6.4s, v12.16b, v8.4b[3] + ldr q8, [x23, #0x30] + KAI_ASM_INST(0x4f85e04b) // sdot v11.4s, v2.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e04a) // sdot v10.4s, v2.16b, v5.4b[1] + KAI_ASM_INST(0x4f85e844) // sdot v4.4s, v2.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5e846) // sdot v6.4s, v2.16b, v5.4b[3] + ldr q5, [x23, #0x40] + KAI_ASM_INST(0x4f87e02b) // sdot v11.4s, v1.16b, v7.4b[0] + KAI_ASM_INST(0x4fa7e02a) // sdot v10.4s, v1.16b, v7.4b[1] + KAI_ASM_INST(0x4f87e824) // sdot v4.4s, v1.16b, v7.4b[2] + KAI_ASM_INST(0x4fa7e826) // sdot v6.4s, v1.16b, v7.4b[3] + ldr q7, [x23, #0x50] + KAI_ASM_INST(0x4f88e1cb) // sdot v11.4s, v14.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e1ca) // sdot v10.4s, v14.16b, v8.4b[1] + KAI_ASM_INST(0x4f88e9c4) // sdot v4.4s, v14.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8e9c6) // sdot v6.4s, v14.16b, v8.4b[3] + ldr q8, [x23, #0x60] + KAI_ASM_INST(0x4f85e06b) // sdot v11.4s, v3.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e06a) // sdot v10.4s, v3.16b, v5.4b[1] + KAI_ASM_INST(0x4f85e864) // sdot v4.4s, v3.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5e866) // sdot v6.4s, v3.16b, v5.4b[3] + ldr q5, [x23, #0x70] + add x23, x23, #0x80 + KAI_ASM_INST(0x4f87e00b) // sdot v11.4s, v0.16b, v7.4b[0] + KAI_ASM_INST(0x4fa7e00a) // sdot v10.4s, v0.16b, v7.4b[1] + KAI_ASM_INST(0x4f87e804) // sdot v4.4s, v0.16b, v7.4b[2] + KAI_ASM_INST(0x4fa7e806) // sdot v6.4s, v0.16b, v7.4b[3] + KAI_ASM_INST(0x4f88e1eb) // sdot v11.4s, v15.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e1ea) // sdot v10.4s, v15.16b, v8.4b[1] + KAI_ASM_INST(0x4f88e9e4) // sdot v4.4s, v15.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8e9e6) // sdot v6.4s, v15.16b, v8.4b[3] + 
KAI_ASM_INST(0x4f85e1ab) // sdot v11.4s, v13.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e1aa) // sdot v10.4s, v13.16b, v5.4b[1] + KAI_ASM_INST(0x4f85e9a4) // sdot v4.4s, v13.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5e9a6) // sdot v6.4s, v13.16b, v5.4b[3] + scvtf v11.4s, v11.4s, #0x4 + scvtf v10.4s, v10.4s, #0x4 + scvtf v4.4s, v4.4s, #0x4 + fmla v27.4s, v11.4s, v9.4s + scvtf v6.4s, v6.4s, #0x4 + fmla v26.4s, v10.4s, v9.4s + fmla v25.4s, v4.4s, v9.4s + fmla v24.4s, v6.4s, v9.4s + ldr q5, [x22, #0x0] + ldr q4, [x22, #0x10] + movi v11.4s, #0x0 + movi v10.4s, #0x0 + ldr q6, [x22, #0x20] + movi v8.4s, #0x0 + movi v7.4s, #0x0 + KAI_ASM_INST(0x4f85e18b) // sdot v11.4s, v12.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e18a) // sdot v10.4s, v12.16b, v5.4b[1] + KAI_ASM_INST(0x4f85e988) // sdot v8.4s, v12.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5e987) // sdot v7.4s, v12.16b, v5.4b[3] + ldr q5, [x22, #0x30] + KAI_ASM_INST(0x4f84e04b) // sdot v11.4s, v2.16b, v4.4b[0] + KAI_ASM_INST(0x4fa4e04a) // sdot v10.4s, v2.16b, v4.4b[1] + KAI_ASM_INST(0x4f84e848) // sdot v8.4s, v2.16b, v4.4b[2] + KAI_ASM_INST(0x4fa4e847) // sdot v7.4s, v2.16b, v4.4b[3] + ldr q4, [x22, #0x40] + KAI_ASM_INST(0x4f86e02b) // sdot v11.4s, v1.16b, v6.4b[0] + KAI_ASM_INST(0x4fa6e02a) // sdot v10.4s, v1.16b, v6.4b[1] + KAI_ASM_INST(0x4f86e828) // sdot v8.4s, v1.16b, v6.4b[2] + KAI_ASM_INST(0x4fa6e827) // sdot v7.4s, v1.16b, v6.4b[3] + ldr q6, [x22, #0x50] + KAI_ASM_INST(0x4f85e1cb) // sdot v11.4s, v14.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e1ca) // sdot v10.4s, v14.16b, v5.4b[1] + KAI_ASM_INST(0x4f85e9c8) // sdot v8.4s, v14.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5e9c7) // sdot v7.4s, v14.16b, v5.4b[3] + ldr q5, [x22, #0x60] + KAI_ASM_INST(0x4f84e06b) // sdot v11.4s, v3.16b, v4.4b[0] + KAI_ASM_INST(0x4fa4e06a) // sdot v10.4s, v3.16b, v4.4b[1] + KAI_ASM_INST(0x4f84e868) // sdot v8.4s, v3.16b, v4.4b[2] + KAI_ASM_INST(0x4fa4e867) // sdot v7.4s, v3.16b, v4.4b[3] + ldr q4, [x22, #0x70] + add x22, x22, #0x80 + KAI_ASM_INST(0x4f86e00b) // sdot v11.4s, v0.16b, v6.4b[0] + KAI_ASM_INST(0x4fa6e00a) // sdot v10.4s, v0.16b, v6.4b[1] + KAI_ASM_INST(0x4f86e808) // sdot v8.4s, v0.16b, v6.4b[2] + KAI_ASM_INST(0x4fa6e807) // sdot v7.4s, v0.16b, v6.4b[3] + KAI_ASM_INST(0x4f85e1eb) // sdot v11.4s, v15.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e1ea) // sdot v10.4s, v15.16b, v5.4b[1] + KAI_ASM_INST(0x4f85e9e8) // sdot v8.4s, v15.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5e9e7) // sdot v7.4s, v15.16b, v5.4b[3] + KAI_ASM_INST(0x4f84e1ab) // sdot v11.4s, v13.16b, v4.4b[0] + KAI_ASM_INST(0x4fa4e1aa) // sdot v10.4s, v13.16b, v4.4b[1] + KAI_ASM_INST(0x4f84e9a8) // sdot v8.4s, v13.16b, v4.4b[2] + KAI_ASM_INST(0x4fa4e9a7) // sdot v7.4s, v13.16b, v4.4b[3] + scvtf v11.4s, v11.4s, #0x4 + scvtf v10.4s, v10.4s, #0x4 + scvtf v8.4s, v8.4s, #0x4 + fmla v23.4s, v11.4s, v9.4s + scvtf v7.4s, v7.4s, #0x4 + fmla v22.4s, v10.4s, v9.4s + fmla v21.4s, v8.4s, v9.4s + fmla v20.4s, v7.4s, v9.4s + ldr q5, [x21, #0x0] + ldr q4, [x21, #0x10] + movi v8.4s, #0x0 + movi v11.4s, #0x0 + ldr q7, [x21, #0x20] + movi v10.4s, #0x0 + movi v6.4s, #0x0 + KAI_ASM_INST(0x4f85e188) // sdot v8.4s, v12.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e18b) // sdot v11.4s, v12.16b, v5.4b[1] + KAI_ASM_INST(0x4f85e98a) // sdot v10.4s, v12.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5e986) // sdot v6.4s, v12.16b, v5.4b[3] + ldr q5, [x21, #0x30] + ldr q12, [x21, #0x40] + KAI_ASM_INST(0x4f84e048) // sdot v8.4s, v2.16b, v4.4b[0] + KAI_ASM_INST(0x4fa4e04b) // sdot v11.4s, v2.16b, v4.4b[1] + KAI_ASM_INST(0x4f84e84a) // sdot v10.4s, v2.16b, v4.4b[2] + KAI_ASM_INST(0x4fa4e846) // sdot v6.4s, 
v2.16b, v4.4b[3] + ldr q4, [x21, #0x50] + ldr q2, [x21, #0x60] + KAI_ASM_INST(0x4f87e028) // sdot v8.4s, v1.16b, v7.4b[0] + KAI_ASM_INST(0x4fa7e02b) // sdot v11.4s, v1.16b, v7.4b[1] + KAI_ASM_INST(0x4f87e82a) // sdot v10.4s, v1.16b, v7.4b[2] + KAI_ASM_INST(0x4fa7e826) // sdot v6.4s, v1.16b, v7.4b[3] + ldr q1, [x21, #0x70] + add x21, x21, #0x80 + KAI_ASM_INST(0x4f85e1c8) // sdot v8.4s, v14.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e1cb) // sdot v11.4s, v14.16b, v5.4b[1] + KAI_ASM_INST(0x4f85e9ca) // sdot v10.4s, v14.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5e9c6) // sdot v6.4s, v14.16b, v5.4b[3] + KAI_ASM_INST(0x4f8ce068) // sdot v8.4s, v3.16b, v12.4b[0] + KAI_ASM_INST(0x4face06b) // sdot v11.4s, v3.16b, v12.4b[1] + KAI_ASM_INST(0x4f8ce86a) // sdot v10.4s, v3.16b, v12.4b[2] + KAI_ASM_INST(0x4face866) // sdot v6.4s, v3.16b, v12.4b[3] + KAI_ASM_INST(0x4f84e008) // sdot v8.4s, v0.16b, v4.4b[0] + KAI_ASM_INST(0x4fa4e00b) // sdot v11.4s, v0.16b, v4.4b[1] + KAI_ASM_INST(0x4f84e80a) // sdot v10.4s, v0.16b, v4.4b[2] + KAI_ASM_INST(0x4fa4e806) // sdot v6.4s, v0.16b, v4.4b[3] + KAI_ASM_INST(0x4f82e1e8) // sdot v8.4s, v15.16b, v2.4b[0] + KAI_ASM_INST(0x4fa2e1eb) // sdot v11.4s, v15.16b, v2.4b[1] + KAI_ASM_INST(0x4f82e9ea) // sdot v10.4s, v15.16b, v2.4b[2] + KAI_ASM_INST(0x4fa2e9e6) // sdot v6.4s, v15.16b, v2.4b[3] + KAI_ASM_INST(0x4f81e1a8) // sdot v8.4s, v13.16b, v1.4b[0] + KAI_ASM_INST(0x4fa1e1ab) // sdot v11.4s, v13.16b, v1.4b[1] + KAI_ASM_INST(0x4f81e9aa) // sdot v10.4s, v13.16b, v1.4b[2] + KAI_ASM_INST(0x4fa1e9a6) // sdot v6.4s, v13.16b, v1.4b[3] + scvtf v8.4s, v8.4s, #0x4 + scvtf v11.4s, v11.4s, #0x4 + scvtf v10.4s, v10.4s, #0x4 + fmla v19.4s, v8.4s, v9.4s + scvtf v6.4s, v6.4s, #0x4 + fmla v18.4s, v11.4s, v9.4s + fmla v17.4s, v10.4s, v9.4s + fmla v16.4s, v6.4s, v9.4s + subs x20, x20, #0x1 + bgt label_opt_3 + ld1 { v11.4s }, [x27] + ld1 { v10.4s }, [x23] + add x27, x27, #0x10 + add x23, x23, #0x10 + ld1 { v9.4s }, [x22] + ld1 { v8.4s }, [x21] + add x22, x22, #0x10 + add x21, x21, #0x10 + ldr q7, [x11, #0x0] + ldr q6, [x27, #0x0] + add x20, x12, #0x4 + cmp x10, #0x4 + ldr q5, [x23, #0x0] + ldr q4, [x22, #0x0] + scvtf v11.4s, v11.4s + scvtf v10.4s, v10.4s + ldr q3, [x21, #0x0] + ldr q2, [x11, #0x10] + scvtf v9.4s, v9.4s + scvtf v8.4s, v8.4s + ld1r { v1.4s }, [x12] + ld1r { v0.4s }, [x20] + add x11, x11, #0x20 + fmla v31.4s, v7.4s, v11.s[0] + fmla v30.4s, v7.4s, v11.s[1] + fmla v29.4s, v7.4s, v11.s[2] + fmla v28.4s, v7.4s, v11.s[3] + fmla v27.4s, v7.4s, v10.s[0] + fmla v26.4s, v7.4s, v10.s[1] + fmla v25.4s, v7.4s, v10.s[2] + fmla v24.4s, v7.4s, v10.s[3] + fmla v23.4s, v7.4s, v9.s[0] + fmul v31.4s, v31.4s, v6.s[0] + fmla v22.4s, v7.4s, v9.s[1] + fmla v21.4s, v7.4s, v9.s[2] + fmul v30.4s, v30.4s, v6.s[1] + fmla v20.4s, v7.4s, v9.s[3] + fmla v19.4s, v7.4s, v8.s[0] + fmul v29.4s, v29.4s, v6.s[2] + fmla v18.4s, v7.4s, v8.s[1] + fmla v17.4s, v7.4s, v8.s[2] + fmul v28.4s, v28.4s, v6.s[3] + fmla v16.4s, v7.4s, v8.s[3] + fmul v27.4s, v27.4s, v5.s[0] + fmul v26.4s, v26.4s, v5.s[1] + fmul v25.4s, v25.4s, v5.s[2] + fmul v24.4s, v24.4s, v5.s[3] + fmul v23.4s, v23.4s, v4.s[0] + fmul v22.4s, v22.4s, v4.s[1] + fmul v21.4s, v21.4s, v4.s[2] + fmul v20.4s, v20.4s, v4.s[3] + fmul v19.4s, v19.4s, v3.s[0] + fmul v18.4s, v18.4s, v3.s[1] + fmul v17.4s, v17.4s, v3.s[2] + fmul v16.4s, v16.4s, v3.s[3] + fadd v31.4s, v31.4s, v2.4s + fadd v30.4s, v30.4s, v2.4s + fadd v29.4s, v29.4s, v2.4s + fadd v28.4s, v28.4s, v2.4s + fadd v27.4s, v27.4s, v2.4s + fadd v26.4s, v26.4s, v2.4s + fadd v25.4s, v25.4s, v2.4s + fadd v24.4s, v24.4s, v2.4s + fadd 
v23.4s, v23.4s, v2.4s + fadd v22.4s, v22.4s, v2.4s + fadd v21.4s, v21.4s, v2.4s + fadd v20.4s, v20.4s, v2.4s + fadd v19.4s, v19.4s, v2.4s + fadd v18.4s, v18.4s, v2.4s + fadd v17.4s, v17.4s, v2.4s + fadd v16.4s, v16.4s, v2.4s + fmax v31.4s, v31.4s, v1.4s + fmax v30.4s, v30.4s, v1.4s + fmax v29.4s, v29.4s, v1.4s + fmax v28.4s, v28.4s, v1.4s + fmax v27.4s, v27.4s, v1.4s + fmax v26.4s, v26.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + fmax v24.4s, v24.4s, v1.4s + fmax v23.4s, v23.4s, v1.4s + fmax v22.4s, v22.4s, v1.4s + fmax v21.4s, v21.4s, v1.4s + fmax v20.4s, v20.4s, v1.4s + fmax v19.4s, v19.4s, v1.4s + fmax v18.4s, v18.4s, v1.4s + fmax v17.4s, v17.4s, v1.4s + fmax v16.4s, v16.4s, v1.4s + fmin v31.4s, v31.4s, v0.4s + fmin v30.4s, v30.4s, v0.4s + fmin v29.4s, v29.4s, v0.4s + fmin v28.4s, v28.4s, v0.4s + fmin v27.4s, v27.4s, v0.4s + fmin v26.4s, v26.4s, v0.4s + fmin v25.4s, v25.4s, v0.4s + fmin v24.4s, v24.4s, v0.4s + fmin v23.4s, v23.4s, v0.4s + fmin v22.4s, v22.4s, v0.4s + fmin v21.4s, v21.4s, v0.4s + fmin v20.4s, v20.4s, v0.4s + fmin v19.4s, v19.4s, v0.4s + fmin v18.4s, v18.4s, v0.4s + fmin v17.4s, v17.4s, v0.4s + fmin v16.4s, v16.4s, v0.4s + blt label_opt_8 + mov x20, x15 + str q31, [x20, #0x0] + add x20, x20, x13 + str q30, [x20, #0x0] + add x20, x20, x13 + str q29, [x20, #0x0] + add x20, x20, x13 + str q28, [x20, #0x0] + add x20, x20, x13 + str q27, [x20, #0x0] + add x20, x20, x13 + str q26, [x20, #0x0] + add x20, x20, x13 + str q25, [x20, #0x0] + add x20, x20, x13 + str q24, [x20, #0x0] + add x20, x20, x13 + str q23, [x20, #0x0] + add x20, x20, x13 + str q22, [x20, #0x0] + add x20, x20, x13 + str q21, [x20, #0x0] + add x20, x20, x13 + str q20, [x20, #0x0] + add x20, x20, x13 + str q19, [x20, #0x0] + add x20, x20, x13 + str q18, [x20, #0x0] + add x20, x20, x13 + str q17, [x20, #0x0] + add x20, x20, x13 + str q16, [x20, #0x0] + b label_opt_13 +KAI_ASM_LABEL(label_opt_8) // Partial output + mov x28, x15 + add x26, x28, x13, LSL #2 + add x25, x26, x13, LSL #1 + add x24, x26, x13 + add x23, x25, x13 + add x22, x28, x13, LSL #1 + add x21, x28, x13 + add x20, x22, x13 + add x27, x23, x13 + tbz x10, #1, label_opt_9 + st1 { v24.d }[0], [x23], #0x8 + st1 { v25.d }[0], [x25], #0x8 + st1 { v26.d }[0], [x24], #0x8 + st1 { v27.d }[0], [x26], #0x8 + st1 { v28.d }[0], [x20], #0x8 + st1 { v29.d }[0], [x22], #0x8 + st1 { v30.d }[0], [x21], #0x8 + st1 { v31.d }[0], [x28], #0x8 + tbz x10, #0, label_opt_10 + st1 { v24.s }[2], [x23] + st1 { v25.s }[2], [x25] + st1 { v26.s }[2], [x24] + st1 { v27.s }[2], [x26] + st1 { v28.s }[2], [x20] + st1 { v29.s }[2], [x22] + st1 { v30.s }[2], [x21] + st1 { v31.s }[2], [x28] + b label_opt_10 +KAI_ASM_LABEL(label_opt_9) // Output block 0: partial_1_0 + st1 { v24.s }[0], [x23] + st1 { v25.s }[0], [x25] + st1 { v26.s }[0], [x24] + st1 { v27.s }[0], [x26] + st1 { v28.s }[0], [x20] + st1 { v29.s }[0], [x22] + st1 { v30.s }[0], [x21] + st1 { v31.s }[0], [x28] +KAI_ASM_LABEL(label_opt_10) // Output block 0: Done + add x26, x27, x13, LSL #2 + add x25, x27, x13, LSL #1 + add x24, x26, x13, LSL #1 + add x23, x27, x13 + add x22, x25, x13 + add x21, x26, x13 + add x20, x24, x13 + tbz x10, #1, label_opt_11 + st1 { v16.d }[0], [x20], #0x8 + st1 { v17.d }[0], [x24], #0x8 + st1 { v18.d }[0], [x21], #0x8 + st1 { v19.d }[0], [x26], #0x8 + st1 { v20.d }[0], [x22], #0x8 + st1 { v21.d }[0], [x25], #0x8 + st1 { v22.d }[0], [x23], #0x8 + st1 { v23.d }[0], [x27], #0x8 + tbz x10, #0, label_opt_12 + st1 { v16.s }[2], [x20] + st1 { v17.s }[2], [x24] + st1 { v18.s }[2], [x21] + st1 { v19.s }[2], [x26] + 
st1 { v20.s }[2], [x22] + st1 { v21.s }[2], [x25] + st1 { v22.s }[2], [x23] + st1 { v23.s }[2], [x27] + b label_opt_12 +KAI_ASM_LABEL(label_opt_11) // Output block 1: partial_1_0 + st1 { v16.s }[0], [x20] + st1 { v17.s }[0], [x24] + st1 { v18.s }[0], [x21] + st1 { v19.s }[0], [x26] + st1 { v20.s }[0], [x22] + st1 { v21.s }[0], [x25] + st1 { v22.s }[0], [x23] + st1 { v23.s }[0], [x27] +KAI_ASM_LABEL(label_opt_12) // Output block 1: Done +KAI_ASM_LABEL(label_opt_13) // Output stage exit + subs x10, x10, #0x4 + add x15, x15, #0x10 + bgt label_opt_2 + mov x20, #0x4 + sub x14, x14, #0x10 + cmp x14, #0x10 + mov x15, x9 + madd x8, x20, x6, x8 + bge label_opt_1 +KAI_ASM_LABEL(label_opt_14) // Row loop skip + cbz x14, label_opt_23 +KAI_ASM_LABEL(label_opt_15) // Row tail: Row loop + mov x26, x17 + mov x25, x16 + add x24, x15, x13, LSL #2 +KAI_ASM_LABEL(label_opt_16) // Row tail: Column loop + movi v31.16b, #0x0 + movi v30.16b, #0x0 + mov x27, x8 + mov x20, x7 + movi v29.16b, #0x0 + movi v28.16b, #0x0 +KAI_ASM_LABEL(label_opt_17) // Row tail: Block loop + ldr q9, [x26, #0x0] + ldr q8, [x27, #0x0] + movi v7.4s, #0x0 + movi v6.4s, #0x0 + ldr q5, [x26, #0x10] + ldr q4, [x27, #0x10] + movi v3.4s, #0x0 + movi v2.4s, #0x0 + ldr q1, [x26, #0x20] + ldr q0, [x27, #0x20] + movi v27.16b, #0xf0 + ldr q26, [x26, #0x30] + ldr q25, [x27, #0x30] + shl v24.16b, v9.16b, #0x4 + add x26, x26, #0x40 + ldr q23, [x27, #0x40] + ldr q22, [x27, #0x50] + shl v21.16b, v5.16b, #0x4 + and v9.16b, v9.16b, v27.16b + ldr q20, [x27, #0x60] + ldr q19, [x27, #0x70] + shl v18.16b, v1.16b, #0x4 + and v5.16b, v5.16b, v27.16b + ldr d16, [x26, #0x0] + KAI_ASM_INST(0x4f88e307) // sdot v7.4s, v24.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e306) // sdot v6.4s, v24.16b, v8.4b[1] + shl v17.16b, v26.16b, #0x4 + KAI_ASM_INST(0x4f88eb03) // sdot v3.4s, v24.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8eb02) // sdot v2.4s, v24.16b, v8.4b[3] + and v1.16b, v1.16b, v27.16b + add x26, x26, #0x8 + and v26.16b, v26.16b, v27.16b + add x27, x27, #0x80 + shll v16.4s, v16.4h, #0x10 + KAI_ASM_INST(0x4f84e2a7) // sdot v7.4s, v21.16b, v4.4b[0] + KAI_ASM_INST(0x4fa4e2a6) // sdot v6.4s, v21.16b, v4.4b[1] + KAI_ASM_INST(0x4f84eaa3) // sdot v3.4s, v21.16b, v4.4b[2] + KAI_ASM_INST(0x4fa4eaa2) // sdot v2.4s, v21.16b, v4.4b[3] + KAI_ASM_INST(0x4f80e247) // sdot v7.4s, v18.16b, v0.4b[0] + KAI_ASM_INST(0x4fa0e246) // sdot v6.4s, v18.16b, v0.4b[1] + KAI_ASM_INST(0x4f80ea43) // sdot v3.4s, v18.16b, v0.4b[2] + KAI_ASM_INST(0x4fa0ea42) // sdot v2.4s, v18.16b, v0.4b[3] + KAI_ASM_INST(0x4f99e227) // sdot v7.4s, v17.16b, v25.4b[0] + KAI_ASM_INST(0x4fb9e226) // sdot v6.4s, v17.16b, v25.4b[1] + KAI_ASM_INST(0x4f99ea23) // sdot v3.4s, v17.16b, v25.4b[2] + KAI_ASM_INST(0x4fb9ea22) // sdot v2.4s, v17.16b, v25.4b[3] + KAI_ASM_INST(0x4f97e127) // sdot v7.4s, v9.16b, v23.4b[0] + KAI_ASM_INST(0x4fb7e126) // sdot v6.4s, v9.16b, v23.4b[1] + KAI_ASM_INST(0x4f97e923) // sdot v3.4s, v9.16b, v23.4b[2] + KAI_ASM_INST(0x4fb7e922) // sdot v2.4s, v9.16b, v23.4b[3] + KAI_ASM_INST(0x4f96e0a7) // sdot v7.4s, v5.16b, v22.4b[0] + KAI_ASM_INST(0x4fb6e0a6) // sdot v6.4s, v5.16b, v22.4b[1] + KAI_ASM_INST(0x4f96e8a3) // sdot v3.4s, v5.16b, v22.4b[2] + KAI_ASM_INST(0x4fb6e8a2) // sdot v2.4s, v5.16b, v22.4b[3] + KAI_ASM_INST(0x4f94e027) // sdot v7.4s, v1.16b, v20.4b[0] + KAI_ASM_INST(0x4fb4e026) // sdot v6.4s, v1.16b, v20.4b[1] + KAI_ASM_INST(0x4f94e823) // sdot v3.4s, v1.16b, v20.4b[2] + KAI_ASM_INST(0x4fb4e822) // sdot v2.4s, v1.16b, v20.4b[3] + KAI_ASM_INST(0x4f93e347) // sdot v7.4s, v26.16b, v19.4b[0] + 
KAI_ASM_INST(0x4fb3e346) // sdot v6.4s, v26.16b, v19.4b[1] + KAI_ASM_INST(0x4f93eb43) // sdot v3.4s, v26.16b, v19.4b[2] + KAI_ASM_INST(0x4fb3eb42) // sdot v2.4s, v26.16b, v19.4b[3] + scvtf v7.4s, v7.4s, #0x4 + scvtf v6.4s, v6.4s, #0x4 + scvtf v3.4s, v3.4s, #0x4 + fmla v31.4s, v7.4s, v16.4s + scvtf v2.4s, v2.4s, #0x4 + fmla v30.4s, v6.4s, v16.4s + fmla v29.4s, v3.4s, v16.4s + fmla v28.4s, v2.4s, v16.4s + subs x20, x20, #0x1 + bgt label_opt_17 + ld1 { v21.4s }, [x27] + ldr q20, [x26, #0x0] + add x27, x27, #0x10 + add x20, x12, #0x4 + ldr q19, [x27, #0x0] + ldr q18, [x26, #0x10] + cmp x25, #0x4 + add x26, x26, #0x20 + ld1r { v17.4s }, [x12] + ld1r { v16.4s }, [x20] + scvtf v21.4s, v21.4s + fmla v31.4s, v20.4s, v21.s[0] + fmla v30.4s, v20.4s, v21.s[1] + fmla v29.4s, v20.4s, v21.s[2] + fmla v28.4s, v20.4s, v21.s[3] + fmul v31.4s, v31.4s, v19.s[0] + fmul v30.4s, v30.4s, v19.s[1] + fmul v29.4s, v29.4s, v19.s[2] + fadd v31.4s, v31.4s, v18.4s + fmul v28.4s, v28.4s, v19.s[3] + fadd v30.4s, v30.4s, v18.4s + fadd v29.4s, v29.4s, v18.4s + fadd v28.4s, v28.4s, v18.4s + fmax v31.4s, v31.4s, v17.4s + fmax v30.4s, v30.4s, v17.4s + fmax v29.4s, v29.4s, v17.4s + fmax v28.4s, v28.4s, v17.4s + fmin v31.4s, v31.4s, v16.4s + fmin v30.4s, v30.4s, v16.4s + fmin v29.4s, v29.4s, v16.4s + fmin v28.4s, v28.4s, v16.4s + blt label_opt_19 + mov x20, x15 + cmp x14, #0x1 + str q31, [x20, #0x0] + add x20, x20, x13 + ble label_opt_22 + cmp x14, #0x2 + str q30, [x20, #0x0] + add x20, x20, x13 + ble label_opt_22 + cmp x14, #0x3 + str q29, [x20, #0x0] + add x20, x20, x13 + ble label_opt_22 + str q28, [x20, #0x0] + b label_opt_22 +KAI_ASM_LABEL(label_opt_19) // Row tail: Partial output + mov x23, x15 + cmp x14, #0x1 + add x22, x23, x13 + csel x22, x22, x23, GT + cmp x14, #0x2 + add x21, x23, x13, LSL #1 + csel x21, x21, x22, GT + cmp x14, #0x3 + add x20, x21, x13 + csel x20, x20, x21, GT + tbz x25, #1, label_opt_20 + st1 { v28.d }[0], [x20], #0x8 + st1 { v29.d }[0], [x21], #0x8 + st1 { v30.d }[0], [x22], #0x8 + st1 { v31.d }[0], [x23], #0x8 + tbz x25, #0, label_opt_21 + st1 { v28.s }[2], [x20] + st1 { v29.s }[2], [x21] + st1 { v30.s }[2], [x22] + st1 { v31.s }[2], [x23] + b label_opt_21 +KAI_ASM_LABEL(label_opt_20) // Row tail: Output block 0: partial_1_0 + st1 { v28.s }[0], [x20] + st1 { v29.s }[0], [x21] + st1 { v30.s }[0], [x22] + st1 { v31.s }[0], [x23] +KAI_ASM_LABEL(label_opt_21) // Row tail: Output block 0: Done +KAI_ASM_LABEL(label_opt_22) // Row tail: Output stage exit + subs x25, x25, #0x4 + add x15, x15, #0x10 + bgt label_opt_16 + subs x14, x14, #0x4 + add x8, x8, x6 + mov x15, x24 + bgt label_opt_15 +KAI_ASM_LABEL(label_opt_23) // Row tail: Row loop skip + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d10, d11, [sp, 72] + ldp d12, d13, [sp, 88] + ldp d14, d15, [sp, 104] + ldp d8, d9, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_opt32_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c index b3b153f8..3a98c339 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c @@ 
-27,6 +27,7 @@ typedef struct {
 } KernelArgs;
 
 void kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(KernelArgs* args_ptr);
+void kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(KernelArgs* args_ptr);
 
 // Compute args
 static const size_t kai_m_step = 16;
@@ -179,7 +180,11 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(
     args.num_blocks = num_blocks;
     args.num_subblocks = num_subblocks;
 
-    kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(&args);
+    if (bl == 32) {
+        kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(&args);
+    } else {
+        kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(&args);
+    }
 }
 
 #endif  // Architectural features check.
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h
index 37186f89..0263a79d 100644
--- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h
@@ -61,8 +61,6 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void
 ///
 /// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step.
 /// @param[in] k Total number of columns in the LHS matrix (not packed).
-/// It must be a multiple of the block length (bl).
-/// @param[in] bl Block length. It must be a multiple of 32.
 ///
 /// @return the offset in bytes to the packed LHS matrix
 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S
index 9fca8668..cc3c75bc 100644
--- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S
@@ -742,3 +742,639 @@ KAI_ASM_LABEL(label_25) // Row tail: Row loop skip
     KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm)
 
     KAI_ASM_END
+
+// Optimized kernel for bl = 32
+
+    KAI_ASM_CODE(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm)
+    KAI_ASM_ALIGN
+
+    KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm)
+
+KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm)
+KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm)
+    stp x20, x21, [sp, -144]!
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x6, #0x80 + mov x21, #0x20 + ldr x20, [x0, #0x28] + ldr x7, [x0, #0x38] + ldr x8, [x0, #0x8] + ldr x17, [x0, #0x10] + ldr x16, [x0, #0x30] + ldr x15, [x0, #0x0] + mov x14, x20 + ldr x13, [x0, #0x20] + madd x6, x7, x6, x21 + ldr x12, [x0, #0x18] + cmp x14, #0x10 + blt label_opt_14 +KAI_ASM_LABEL(label_opt_1) // Row loop + mov x11, x17 + mov x10, x16 + add x9, x15, x13, LSL #4 +KAI_ASM_LABEL(label_opt_2) // Column loop + mov x27, x8 + movi v31.16b, #0x0 + movi v30.16b, #0x0 + mov x20, x7 + movi v29.16b, #0x0 + movi v28.16b, #0x0 + movi v27.16b, #0x0 + movi v26.16b, #0x0 + add x23, x27, x6 + add x22, x23, x6 + movi v25.16b, #0x0 + movi v24.16b, #0x0 + add x21, x22, x6 + movi v23.16b, #0x0 + movi v22.16b, #0x0 + movi v21.16b, #0x0 + movi v20.16b, #0x0 + movi v19.16b, #0x0 + movi v18.16b, #0x0 + movi v17.16b, #0x0 + movi v16.16b, #0x0 +KAI_ASM_LABEL(label_opt_3) // Block loop + ldr q11, [x11, #0x0] + ldr q4, [x11, #0x10] + movi v2.4s, #0x0 + movi v9.4s, #0x0 + ldr q12, [x27, #0x0] + ldr q0, [x27, #0x10] + movi v7.4s, #0x0 + movi v5.4s, #0x0 + ldr q15, [x11, #0x20] + ldr q13, [x11, #0x30] + movi v10.16b, #0xf0 + add x11, x11, #0x40 + ldr q8, [x27, #0x20] + ldr q6, [x27, #0x30] + shl v14.16b, v11.16b, #0x4 + shl v3.16b, v4.16b, #0x4 + ldr q1, [x27, #0x40] + and v11.16b, v11.16b, v10.16b + and v4.16b, v4.16b, v10.16b + KAI_ASM_INST(0x4e8ea582) // smmla v2.4s, v12.16b, v14.16b + KAI_ASM_INST(0x4e83a589) // smmla v9.4s, v12.16b, v3.16b + shl v12.16b, v15.16b, #0x4 + KAI_ASM_INST(0x4e8ea407) // smmla v7.4s, v0.16b, v14.16b + KAI_ASM_INST(0x4e83a405) // smmla v5.4s, v0.16b, v3.16b + shl v0.16b, v13.16b, #0x4 + and v15.16b, v15.16b, v10.16b + and v13.16b, v13.16b, v10.16b + ldr q10, [x27, #0x50] + KAI_ASM_INST(0x4e8ca502) // smmla v2.4s, v8.16b, v12.16b + KAI_ASM_INST(0x4e80a509) // smmla v9.4s, v8.16b, v0.16b + ldr q8, [x27, #0x60] + KAI_ASM_INST(0x4e8ca4c7) // smmla v7.4s, v6.16b, v12.16b + KAI_ASM_INST(0x4e80a4c5) // smmla v5.4s, v6.16b, v0.16b + ldr q6, [x27, #0x70] + add x27, x27, #0x80 + KAI_ASM_INST(0x4e8ba422) // smmla v2.4s, v1.16b, v11.16b + KAI_ASM_INST(0x4e84a429) // smmla v9.4s, v1.16b, v4.16b + ldr d1, [x11, #0x0] + add x11, x11, #0x8 + KAI_ASM_INST(0x4e8ba547) // smmla v7.4s, v10.16b, v11.16b + KAI_ASM_INST(0x4e84a545) // smmla v5.4s, v10.16b, v4.16b + KAI_ASM_INST(0x4e8fa502) // smmla v2.4s, v8.16b, v15.16b + shll v1.4s, v1.4h, #0x10 + KAI_ASM_INST(0x4e8da509) // smmla v9.4s, v8.16b, v13.16b + KAI_ASM_INST(0x4e8fa4c7) // smmla v7.4s, v6.16b, v15.16b + KAI_ASM_INST(0x4e8da4c5) // smmla v5.4s, v6.16b, v13.16b + uzp1 v6.2d, v2.2d, v9.2d + uzp2 v8.2d, v2.2d, v9.2d + scvtf v6.4s, v6.4s, #0x4 + uzp1 v9.2d, v7.2d, v5.2d + uzp2 v2.2d, v7.2d, v5.2d + scvtf v8.4s, v8.4s, #0x4 + fmla v31.4s, v6.4s, v1.4s + scvtf v9.4s, v9.4s, #0x4 + scvtf v2.4s, v2.4s, #0x4 + fmla v30.4s, v8.4s, v1.4s + fmla v29.4s, v9.4s, v1.4s + fmla v28.4s, v2.4s, v1.4s + ldr q9, [x23, #0x0] + ldr q7, [x23, #0x10] + movi v8.4s, #0x0 + movi v2.4s, #0x0 + ldr q5, [x23, #0x20] + ldr q10, [x23, #0x30] + movi v6.4s, #0x0 + KAI_ASM_INST(0x4e8ea528) // smmla v8.4s, v9.16b, v14.16b + KAI_ASM_INST(0x4e83a522) // smmla v2.4s, v9.16b, v3.16b + ldr q9, [x23, #0x40] + KAI_ASM_INST(0x4e8ea4e6) // smmla v6.4s, v7.16b, v14.16b + KAI_ASM_INST(0x4e8ca4a8) // smmla v8.4s, v5.16b, v12.16b + KAI_ASM_INST(0x4e80a4a2) // smmla v2.4s, v5.16b, v0.16b + 
ldr q5, [x23, #0x50] + KAI_ASM_INST(0x4e8ca546) // smmla v6.4s, v10.16b, v12.16b + KAI_ASM_INST(0x4e8ba528) // smmla v8.4s, v9.16b, v11.16b + KAI_ASM_INST(0x4e84a522) // smmla v2.4s, v9.16b, v4.16b + ldr q9, [x23, #0x60] + KAI_ASM_INST(0x4e8ba4a6) // smmla v6.4s, v5.16b, v11.16b + KAI_ASM_INST(0x4e8fa528) // smmla v8.4s, v9.16b, v15.16b + KAI_ASM_INST(0x4e8da522) // smmla v2.4s, v9.16b, v13.16b + movi v9.4s, #0x0 + KAI_ASM_INST(0x4e83a4e9) // smmla v9.4s, v7.16b, v3.16b + ldr q7, [x23, #0x70] + add x23, x23, #0x80 + KAI_ASM_INST(0x4e8fa4e6) // smmla v6.4s, v7.16b, v15.16b + KAI_ASM_INST(0x4e80a549) // smmla v9.4s, v10.16b, v0.16b + uzp1 v10.2d, v8.2d, v2.2d + uzp2 v2.2d, v8.2d, v2.2d + scvtf v10.4s, v10.4s, #0x4 + KAI_ASM_INST(0x4e84a4a9) // smmla v9.4s, v5.16b, v4.16b + scvtf v2.4s, v2.4s, #0x4 + fmla v27.4s, v10.4s, v1.4s + KAI_ASM_INST(0x4e8da4e9) // smmla v9.4s, v7.16b, v13.16b + fmla v26.4s, v2.4s, v1.4s + uzp1 v2.2d, v6.2d, v9.2d + uzp2 v10.2d, v6.2d, v9.2d + scvtf v2.4s, v2.4s, #0x4 + scvtf v10.4s, v10.4s, #0x4 + fmla v25.4s, v2.4s, v1.4s + fmla v24.4s, v10.4s, v1.4s + ldr q8, [x22, #0x0] + ldr q7, [x22, #0x10] + movi v9.4s, #0x0 + movi v6.4s, #0x0 + ldr q2, [x22, #0x20] + ldr q5, [x22, #0x30] + movi v10.4s, #0x0 + KAI_ASM_INST(0x4e8ea509) // smmla v9.4s, v8.16b, v14.16b + KAI_ASM_INST(0x4e83a506) // smmla v6.4s, v8.16b, v3.16b + ldr q8, [x22, #0x40] + KAI_ASM_INST(0x4e8ea4ea) // smmla v10.4s, v7.16b, v14.16b + KAI_ASM_INST(0x4e8ca449) // smmla v9.4s, v2.16b, v12.16b + KAI_ASM_INST(0x4e80a446) // smmla v6.4s, v2.16b, v0.16b + ldr q2, [x22, #0x50] + KAI_ASM_INST(0x4e8ca4aa) // smmla v10.4s, v5.16b, v12.16b + KAI_ASM_INST(0x4e8ba509) // smmla v9.4s, v8.16b, v11.16b + KAI_ASM_INST(0x4e84a506) // smmla v6.4s, v8.16b, v4.16b + ldr q8, [x22, #0x60] + KAI_ASM_INST(0x4e8ba44a) // smmla v10.4s, v2.16b, v11.16b + KAI_ASM_INST(0x4e8fa509) // smmla v9.4s, v8.16b, v15.16b + KAI_ASM_INST(0x4e8da506) // smmla v6.4s, v8.16b, v13.16b + movi v8.4s, #0x0 + KAI_ASM_INST(0x4e83a4e8) // smmla v8.4s, v7.16b, v3.16b + ldr q7, [x22, #0x70] + add x22, x22, #0x80 + KAI_ASM_INST(0x4e8fa4ea) // smmla v10.4s, v7.16b, v15.16b + KAI_ASM_INST(0x4e80a4a8) // smmla v8.4s, v5.16b, v0.16b + uzp1 v5.2d, v9.2d, v6.2d + uzp2 v9.2d, v9.2d, v6.2d + scvtf v5.4s, v5.4s, #0x4 + KAI_ASM_INST(0x4e84a448) // smmla v8.4s, v2.16b, v4.16b + scvtf v9.4s, v9.4s, #0x4 + fmla v23.4s, v5.4s, v1.4s + KAI_ASM_INST(0x4e8da4e8) // smmla v8.4s, v7.16b, v13.16b + fmla v22.4s, v9.4s, v1.4s + uzp1 v2.2d, v10.2d, v8.2d + uzp2 v10.2d, v10.2d, v8.2d + scvtf v2.4s, v2.4s, #0x4 + scvtf v10.4s, v10.4s, #0x4 + fmla v21.4s, v2.4s, v1.4s + fmla v20.4s, v10.4s, v1.4s + ldr q2, [x21, #0x0] + ldr q10, [x21, #0x10] + movi v6.4s, #0x0 + movi v9.4s, #0x0 + ldr q5, [x21, #0x20] + ldr q8, [x21, #0x30] + movi v7.4s, #0x0 + KAI_ASM_INST(0x4e8ea446) // smmla v6.4s, v2.16b, v14.16b + KAI_ASM_INST(0x4e83a449) // smmla v9.4s, v2.16b, v3.16b + ldr q2, [x21, #0x40] + KAI_ASM_INST(0x4e8ea547) // smmla v7.4s, v10.16b, v14.16b + ldr q14, [x21, #0x50] + KAI_ASM_INST(0x4e8ca4a6) // smmla v6.4s, v5.16b, v12.16b + KAI_ASM_INST(0x4e80a4a9) // smmla v9.4s, v5.16b, v0.16b + ldr q5, [x21, #0x60] + KAI_ASM_INST(0x4e8ca507) // smmla v7.4s, v8.16b, v12.16b + ldr q12, [x21, #0x70] + add x21, x21, #0x80 + KAI_ASM_INST(0x4e8ba446) // smmla v6.4s, v2.16b, v11.16b + KAI_ASM_INST(0x4e84a449) // smmla v9.4s, v2.16b, v4.16b + movi v2.4s, #0x0 + KAI_ASM_INST(0x4e83a542) // smmla v2.4s, v10.16b, v3.16b + KAI_ASM_INST(0x4e8ba5c7) // smmla v7.4s, v14.16b, v11.16b + KAI_ASM_INST(0x4e8fa4a6) // 
smmla v6.4s, v5.16b, v15.16b + KAI_ASM_INST(0x4e80a502) // smmla v2.4s, v8.16b, v0.16b + KAI_ASM_INST(0x4e8da4a9) // smmla v9.4s, v5.16b, v13.16b + KAI_ASM_INST(0x4e8fa587) // smmla v7.4s, v12.16b, v15.16b + KAI_ASM_INST(0x4e84a5c2) // smmla v2.4s, v14.16b, v4.16b + uzp1 v11.2d, v6.2d, v9.2d + uzp2 v14.2d, v6.2d, v9.2d + scvtf v11.4s, v11.4s, #0x4 + KAI_ASM_INST(0x4e8da582) // smmla v2.4s, v12.16b, v13.16b + scvtf v14.4s, v14.4s, #0x4 + fmla v19.4s, v11.4s, v1.4s + uzp1 v9.2d, v7.2d, v2.2d + uzp2 v0.2d, v7.2d, v2.2d + fmla v18.4s, v14.4s, v1.4s + scvtf v9.4s, v9.4s, #0x4 + scvtf v0.4s, v0.4s, #0x4 + fmla v17.4s, v9.4s, v1.4s + fmla v16.4s, v0.4s, v1.4s + subs x20, x20, #0x1 + bgt label_opt_3 + ld1 { v11.4s }, [x27] + ld1 { v10.4s }, [x23] + add x27, x27, #0x10 + add x23, x23, #0x10 + ld1 { v9.4s }, [x22] + ld1 { v8.4s }, [x21] + add x22, x22, #0x10 + add x21, x21, #0x10 + ldr q7, [x11, #0x0] + ldr q6, [x27, #0x0] + add x20, x12, #0x4 + cmp x10, #0x4 + ldr q5, [x23, #0x0] + ldr q4, [x22, #0x0] + scvtf v11.4s, v11.4s + scvtf v10.4s, v10.4s + ldr q3, [x21, #0x0] + ldr q2, [x11, #0x10] + scvtf v9.4s, v9.4s + scvtf v8.4s, v8.4s + ld1r { v1.4s }, [x12] + ld1r { v0.4s }, [x20] + add x11, x11, #0x20 + fmla v31.4s, v7.4s, v11.s[0] + fmla v30.4s, v7.4s, v11.s[1] + fmla v29.4s, v7.4s, v11.s[2] + fmla v28.4s, v7.4s, v11.s[3] + fmla v27.4s, v7.4s, v10.s[0] + fmla v26.4s, v7.4s, v10.s[1] + fmla v25.4s, v7.4s, v10.s[2] + fmla v24.4s, v7.4s, v10.s[3] + fmla v23.4s, v7.4s, v9.s[0] + fmul v31.4s, v31.4s, v6.s[0] + fmla v22.4s, v7.4s, v9.s[1] + fmla v21.4s, v7.4s, v9.s[2] + fmul v30.4s, v30.4s, v6.s[1] + fmla v20.4s, v7.4s, v9.s[3] + fmla v19.4s, v7.4s, v8.s[0] + fmul v29.4s, v29.4s, v6.s[2] + fmla v18.4s, v7.4s, v8.s[1] + fmla v17.4s, v7.4s, v8.s[2] + fmul v28.4s, v28.4s, v6.s[3] + fmla v16.4s, v7.4s, v8.s[3] + fmul v27.4s, v27.4s, v5.s[0] + fmul v26.4s, v26.4s, v5.s[1] + fmul v25.4s, v25.4s, v5.s[2] + fmul v24.4s, v24.4s, v5.s[3] + fmul v23.4s, v23.4s, v4.s[0] + fmul v22.4s, v22.4s, v4.s[1] + fmul v21.4s, v21.4s, v4.s[2] + fmul v20.4s, v20.4s, v4.s[3] + fmul v19.4s, v19.4s, v3.s[0] + fmul v18.4s, v18.4s, v3.s[1] + fmul v17.4s, v17.4s, v3.s[2] + fmul v16.4s, v16.4s, v3.s[3] + fadd v31.4s, v31.4s, v2.4s + fadd v30.4s, v30.4s, v2.4s + fadd v29.4s, v29.4s, v2.4s + fadd v28.4s, v28.4s, v2.4s + fadd v27.4s, v27.4s, v2.4s + fadd v26.4s, v26.4s, v2.4s + fadd v25.4s, v25.4s, v2.4s + fadd v24.4s, v24.4s, v2.4s + fadd v23.4s, v23.4s, v2.4s + fadd v22.4s, v22.4s, v2.4s + fadd v21.4s, v21.4s, v2.4s + fadd v20.4s, v20.4s, v2.4s + fadd v19.4s, v19.4s, v2.4s + fadd v18.4s, v18.4s, v2.4s + fadd v17.4s, v17.4s, v2.4s + fadd v16.4s, v16.4s, v2.4s + fmax v31.4s, v31.4s, v1.4s + fmax v30.4s, v30.4s, v1.4s + fmax v29.4s, v29.4s, v1.4s + fmax v28.4s, v28.4s, v1.4s + fmax v27.4s, v27.4s, v1.4s + fmax v26.4s, v26.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + fmax v24.4s, v24.4s, v1.4s + fmax v23.4s, v23.4s, v1.4s + fmax v22.4s, v22.4s, v1.4s + fmax v21.4s, v21.4s, v1.4s + fmax v20.4s, v20.4s, v1.4s + fmax v19.4s, v19.4s, v1.4s + fmax v18.4s, v18.4s, v1.4s + fmax v17.4s, v17.4s, v1.4s + fmax v16.4s, v16.4s, v1.4s + fmin v31.4s, v31.4s, v0.4s + fmin v30.4s, v30.4s, v0.4s + fmin v29.4s, v29.4s, v0.4s + fmin v28.4s, v28.4s, v0.4s + fmin v27.4s, v27.4s, v0.4s + fmin v26.4s, v26.4s, v0.4s + fmin v25.4s, v25.4s, v0.4s + fmin v24.4s, v24.4s, v0.4s + fmin v23.4s, v23.4s, v0.4s + fmin v22.4s, v22.4s, v0.4s + fmin v21.4s, v21.4s, v0.4s + fmin v20.4s, v20.4s, v0.4s + fmin v19.4s, v19.4s, v0.4s + fmin v18.4s, v18.4s, v0.4s + fmin v17.4s, v17.4s, 
v0.4s + fmin v16.4s, v16.4s, v0.4s + blt label_opt_8 + mov x20, x15 + str q31, [x20, #0x0] + add x20, x20, x13 + str q30, [x20, #0x0] + add x20, x20, x13 + str q29, [x20, #0x0] + add x20, x20, x13 + str q28, [x20, #0x0] + add x20, x20, x13 + str q27, [x20, #0x0] + add x20, x20, x13 + str q26, [x20, #0x0] + add x20, x20, x13 + str q25, [x20, #0x0] + add x20, x20, x13 + str q24, [x20, #0x0] + add x20, x20, x13 + str q23, [x20, #0x0] + add x20, x20, x13 + str q22, [x20, #0x0] + add x20, x20, x13 + str q21, [x20, #0x0] + add x20, x20, x13 + str q20, [x20, #0x0] + add x20, x20, x13 + str q19, [x20, #0x0] + add x20, x20, x13 + str q18, [x20, #0x0] + add x20, x20, x13 + str q17, [x20, #0x0] + add x20, x20, x13 + str q16, [x20, #0x0] + b label_opt_13 +KAI_ASM_LABEL(label_opt_8) // Partial output + mov x28, x15 + add x26, x28, x13, LSL #2 + add x25, x26, x13, LSL #1 + add x24, x26, x13 + add x23, x25, x13 + add x22, x28, x13, LSL #1 + add x21, x28, x13 + add x20, x22, x13 + add x27, x23, x13 + tbz x10, #1, label_opt_9 + st1 { v24.d }[0], [x23], #0x8 + st1 { v25.d }[0], [x25], #0x8 + st1 { v26.d }[0], [x24], #0x8 + st1 { v27.d }[0], [x26], #0x8 + st1 { v28.d }[0], [x20], #0x8 + st1 { v29.d }[0], [x22], #0x8 + st1 { v30.d }[0], [x21], #0x8 + st1 { v31.d }[0], [x28], #0x8 + tbz x10, #0, label_opt_10 + st1 { v24.s }[2], [x23] + st1 { v25.s }[2], [x25] + st1 { v26.s }[2], [x24] + st1 { v27.s }[2], [x26] + st1 { v28.s }[2], [x20] + st1 { v29.s }[2], [x22] + st1 { v30.s }[2], [x21] + st1 { v31.s }[2], [x28] + b label_opt_10 +KAI_ASM_LABEL(label_opt_9) // Output block 0: partial_1_0 + st1 { v24.s }[0], [x23] + st1 { v25.s }[0], [x25] + st1 { v26.s }[0], [x24] + st1 { v27.s }[0], [x26] + st1 { v28.s }[0], [x20] + st1 { v29.s }[0], [x22] + st1 { v30.s }[0], [x21] + st1 { v31.s }[0], [x28] +KAI_ASM_LABEL(label_opt_10) // Output block 0: Done + add x26, x27, x13, LSL #2 + add x25, x27, x13, LSL #1 + add x24, x26, x13, LSL #1 + add x23, x27, x13 + add x22, x25, x13 + add x21, x26, x13 + add x20, x24, x13 + tbz x10, #1, label_opt_11 + st1 { v16.d }[0], [x20], #0x8 + st1 { v17.d }[0], [x24], #0x8 + st1 { v18.d }[0], [x21], #0x8 + st1 { v19.d }[0], [x26], #0x8 + st1 { v20.d }[0], [x22], #0x8 + st1 { v21.d }[0], [x25], #0x8 + st1 { v22.d }[0], [x23], #0x8 + st1 { v23.d }[0], [x27], #0x8 + tbz x10, #0, label_opt_12 + st1 { v16.s }[2], [x20] + st1 { v17.s }[2], [x24] + st1 { v18.s }[2], [x21] + st1 { v19.s }[2], [x26] + st1 { v20.s }[2], [x22] + st1 { v21.s }[2], [x25] + st1 { v22.s }[2], [x23] + st1 { v23.s }[2], [x27] + b label_opt_12 +KAI_ASM_LABEL(label_opt_11) // Output block 1: partial_1_0 + st1 { v16.s }[0], [x20] + st1 { v17.s }[0], [x24] + st1 { v18.s }[0], [x21] + st1 { v19.s }[0], [x26] + st1 { v20.s }[0], [x22] + st1 { v21.s }[0], [x25] + st1 { v22.s }[0], [x23] + st1 { v23.s }[0], [x27] +KAI_ASM_LABEL(label_opt_12) // Output block 1: Done +KAI_ASM_LABEL(label_opt_13) // Output stage exit + subs x10, x10, #0x4 + add x15, x15, #0x10 + bgt label_opt_2 + mov x20, #0x4 + sub x14, x14, #0x10 + cmp x14, #0x10 + mov x15, x9 + madd x8, x20, x6, x8 + bge label_opt_1 +KAI_ASM_LABEL(label_opt_14) // Row loop skip + cbz x14, label_opt_23 +KAI_ASM_LABEL(label_opt_15) // Row tail: Row loop + mov x26, x17 + mov x25, x16 + add x24, x15, x13, LSL #2 +KAI_ASM_LABEL(label_opt_16) // Row tail: Column loop + movi v31.16b, #0x0 + movi v30.16b, #0x0 + mov x27, x8 + mov x20, x7 + movi v29.16b, #0x0 + movi v28.16b, #0x0 +KAI_ASM_LABEL(label_opt_17) // Row tail: Block loop + ldr q9, [x26, #0x0] + ldr q8, [x26, #0x10] + movi v7.4s, 
#0x0 + movi v6.4s, #0x0 + ldr q5, [x27, #0x0] + ldr q4, [x27, #0x10] + movi v3.4s, #0x0 + movi v2.4s, #0x0 + ldr q1, [x26, #0x20] + ldr q0, [x26, #0x30] + movi v27.16b, #0xf0 + add x26, x26, #0x40 + ldr q26, [x27, #0x20] + ldr q25, [x27, #0x30] + shl v24.16b, v9.16b, #0x4 + shl v20.16b, v8.16b, #0x4 + ldr q23, [x27, #0x40] + ldr q22, [x27, #0x50] + and v9.16b, v9.16b, v27.16b + and v8.16b, v8.16b, v27.16b + ldr q21, [x27, #0x60] + ldr q19, [x27, #0x70] + shl v18.16b, v1.16b, #0x4 + shl v17.16b, v0.16b, #0x4 + ldr d16, [x26, #0x0] + KAI_ASM_INST(0x4e98a4a7) // smmla v7.4s, v5.16b, v24.16b + KAI_ASM_INST(0x4e94a4a6) // smmla v6.4s, v5.16b, v20.16b + and v1.16b, v1.16b, v27.16b + KAI_ASM_INST(0x4e98a483) // smmla v3.4s, v4.16b, v24.16b + KAI_ASM_INST(0x4e94a482) // smmla v2.4s, v4.16b, v20.16b + and v0.16b, v0.16b, v27.16b + add x26, x26, #0x8 + add x27, x27, #0x80 + shll v20.4s, v16.4h, #0x10 + KAI_ASM_INST(0x4e92a747) // smmla v7.4s, v26.16b, v18.16b + KAI_ASM_INST(0x4e91a746) // smmla v6.4s, v26.16b, v17.16b + KAI_ASM_INST(0x4e92a723) // smmla v3.4s, v25.16b, v18.16b + KAI_ASM_INST(0x4e91a722) // smmla v2.4s, v25.16b, v17.16b + KAI_ASM_INST(0x4e89a6e7) // smmla v7.4s, v23.16b, v9.16b + KAI_ASM_INST(0x4e88a6e6) // smmla v6.4s, v23.16b, v8.16b + KAI_ASM_INST(0x4e89a6c3) // smmla v3.4s, v22.16b, v9.16b + KAI_ASM_INST(0x4e88a6c2) // smmla v2.4s, v22.16b, v8.16b + KAI_ASM_INST(0x4e81a6a7) // smmla v7.4s, v21.16b, v1.16b + KAI_ASM_INST(0x4e80a6a6) // smmla v6.4s, v21.16b, v0.16b + KAI_ASM_INST(0x4e81a663) // smmla v3.4s, v19.16b, v1.16b + KAI_ASM_INST(0x4e80a662) // smmla v2.4s, v19.16b, v0.16b + uzp1 v19.2d, v7.2d, v6.2d + uzp2 v18.2d, v7.2d, v6.2d + scvtf v19.4s, v19.4s, #0x4 + uzp1 v17.2d, v3.2d, v2.2d + uzp2 v16.2d, v3.2d, v2.2d + scvtf v18.4s, v18.4s, #0x4 + fmla v31.4s, v19.4s, v20.4s + scvtf v17.4s, v17.4s, #0x4 + scvtf v16.4s, v16.4s, #0x4 + fmla v30.4s, v18.4s, v20.4s + fmla v29.4s, v17.4s, v20.4s + fmla v28.4s, v16.4s, v20.4s + subs x20, x20, #0x1 + bgt label_opt_17 + ld1 { v21.4s }, [x27] + ldr q20, [x26, #0x0] + add x27, x27, #0x10 + add x20, x12, #0x4 + ldr q19, [x27, #0x0] + ldr q18, [x26, #0x10] + cmp x25, #0x4 + add x26, x26, #0x20 + ld1r { v17.4s }, [x12] + ld1r { v16.4s }, [x20] + scvtf v21.4s, v21.4s + fmla v31.4s, v20.4s, v21.s[0] + fmla v30.4s, v20.4s, v21.s[1] + fmla v29.4s, v20.4s, v21.s[2] + fmla v28.4s, v20.4s, v21.s[3] + fmul v31.4s, v31.4s, v19.s[0] + fmul v30.4s, v30.4s, v19.s[1] + fmul v29.4s, v29.4s, v19.s[2] + fadd v31.4s, v31.4s, v18.4s + fmul v28.4s, v28.4s, v19.s[3] + fadd v30.4s, v30.4s, v18.4s + fadd v29.4s, v29.4s, v18.4s + fadd v28.4s, v28.4s, v18.4s + fmax v31.4s, v31.4s, v17.4s + fmax v30.4s, v30.4s, v17.4s + fmax v29.4s, v29.4s, v17.4s + fmax v28.4s, v28.4s, v17.4s + fmin v31.4s, v31.4s, v16.4s + fmin v30.4s, v30.4s, v16.4s + fmin v29.4s, v29.4s, v16.4s + fmin v28.4s, v28.4s, v16.4s + blt label_opt_19 + mov x20, x15 + cmp x14, #0x1 + str q31, [x20, #0x0] + add x20, x20, x13 + ble label_opt_22 + cmp x14, #0x2 + str q30, [x20, #0x0] + add x20, x20, x13 + ble label_opt_22 + cmp x14, #0x3 + str q29, [x20, #0x0] + add x20, x20, x13 + ble label_opt_22 + str q28, [x20, #0x0] + b label_opt_22 +KAI_ASM_LABEL(label_opt_19) // Row tail: Partial output + mov x23, x15 + cmp x14, #0x1 + add x22, x23, x13 + csel x22, x22, x23, GT + cmp x14, #0x2 + add x21, x23, x13, LSL #1 + csel x21, x21, x22, GT + cmp x14, #0x3 + add x20, x21, x13 + csel x20, x20, x21, GT + tbz x25, #1, label_opt_20 + st1 { v28.d }[0], [x20], #0x8 + st1 { v29.d }[0], [x21], #0x8 + st1 { v30.d }[0], 
[x22], #0x8
+    st1 { v31.d }[0], [x23], #0x8
+    tbz x25, #0, label_opt_21
+    st1 { v28.s }[2], [x20]
+    st1 { v29.s }[2], [x21]
+    st1 { v30.s }[2], [x22]
+    st1 { v31.s }[2], [x23]
+    b label_opt_21
+KAI_ASM_LABEL(label_opt_20) // Row tail: Output block 0: partial_1_0
+    st1 { v28.s }[0], [x20]
+    st1 { v29.s }[0], [x21]
+    st1 { v30.s }[0], [x22]
+    st1 { v31.s }[0], [x23]
+KAI_ASM_LABEL(label_opt_21) // Row tail: Output block 0: Done
+KAI_ASM_LABEL(label_opt_22) // Row tail: Output stage exit
+    subs x25, x25, #0x4
+    add x15, x15, #0x10
+    bgt label_opt_16
+    subs x14, x14, #0x4
+    add x8, x8, x6
+    mov x15, x24
+    bgt label_opt_15
+KAI_ASM_LABEL(label_opt_23) // Row tail: Row loop skip
+    ldp x22, x23, [sp, 16]
+    ldp x24, x25, [sp, 32]
+    ldp x26, x27, [sp, 48]
+    ldr x28, [sp, 64]
+    ldp d10, d11, [sp, 72]
+    ldp d12, d13, [sp, 88]
+    ldp d14, d15, [sp, 104]
+    ldp d8, d9, [sp, 120]
+    ldp x20, x21, [sp], 144
+    ret
+    KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm)
+
+    KAI_ASM_END
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c
deleted file mode 100644
index 503e207c..00000000
--- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.c
+++ /dev/null
@@ -1,183 +0,0 @@
-//
-// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
-//
-// SPDX-License-Identifier: Apache-2.0
-//
-#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8)
-#error "I8mm extension required to compile this micro-kernel"
-#else  // Architectural features check.
-
-#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.h"
-
-#include <stddef.h>
-#include <stdint.h>
-
-#include "kai/kai_common.h"
-
-typedef struct {
-    float* dst;
-    const void* lhs_packed;
-    const void* rhs_packed;
-    const float* clamp_vals;
-    size_t dst_stride_row;
-    size_t m;
-    size_t n;
-    size_t num_blocks;
-} KernelArgs;
-
-void kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(KernelArgs* args_ptr);
-
-// Compute args
-static const size_t kai_m_step = 16;
-static const size_t kai_n_step = 4;
-// Packing args
-static const size_t kai_mr = 4;
-static const size_t kai_nr = 4;
-static const size_t kai_kr = 16;
-static const size_t kai_sr = 2;
-// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
-static const size_t kai_num_bytes_qvalue_lhs = 1;
-static const size_t kai_num_bytes_multiplier_lhs = 4;
-static const size_t kai_num_bytes_zp_lhs = 4;
-// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
-// asymmetric))
-static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
-static const size_t kai_num_bytes_multiplier_rhs = 2;
-static const size_t kai_num_bytes_rsum_rhs = 4;
-// DST format args
-static const size_t kai_num_bytes_dst_value = 4;
-// Extra args
-static const size_t kai_num_bytes_bias = 4;
-static const size_t kai_k_multiple_of = 32;
-static const size_t kai_bl = 32;
-
-inline static size_t kai_get_k_roundedup(size_t k) {
-    return kai_roundup(k, kai_k_multiple_of);
-}
-
-inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
-    KAI_ASSUME((bl % kai_bl) == 0);
-    size_t num_bytes_per_block_rhs = (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
-    return num_bytes_per_block_rhs;
-}
-
-inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
-    KAI_ASSUME((bl % kai_bl) == 0);
-
-    return kai_roundup(k, bl) / bl;
-}
-
-inline static size_t kai_get_lhs_packed_stride(size_t k) {
-    const size_t k_internal = kai_get_k_roundedup(k);
-    size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
-    // Since the LHS matrix is asymmetric with per-row quantization, we must include the
-    // the number of bytes to hold the zero point value
-    lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs;
-
-    return lhs_packed_stride;
-}
-
-inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
-    KAI_ASSUME((bl % kai_bl) == 0);
-
-    const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
-    const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
-
-    size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row);
-    // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include
-    // the number of bytes for the reduction sum
-    rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs;
-    // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
-    rhs_packed_stride += kai_nr * kai_num_bytes_bias;
-
-    return rhs_packed_stride;
-}
-
-size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void) {
-    return kai_m_step;
-}
-
-size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void) {
-    return kai_n_step;
-}
-
-size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void) {
-    return kai_mr;
-}
-
-size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void) {
-    return kai_nr;
-}
-
-size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void) {
-    return kai_kr;
-}
-
-size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void) {
-    return kai_sr;
-}
-
-size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
-    size_t m_idx, size_t k) {
-    KAI_ASSUME((m_idx % kai_m_step) == 0);
-
-    return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k);
-}
-
-size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
-    size_t n_idx, size_t k, size_t bl) {
-    KAI_ASSUME((k % bl) == 0);
-    KAI_ASSUME((n_idx % kai_n_step) == 0);
-
-    return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl);
-}
-
-size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
-    size_t m_idx, size_t n_idx, size_t dst_stride) {
-    KAI_ASSUME((m_idx % kai_m_step) == 0);
-    KAI_ASSUME((n_idx % kai_n_step) == 0);
-
-    return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
-}
-
-size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(size_t m, size_t n) {
-    return m * n * kai_num_bytes_dst_value;
-}
-
-void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
-    size_t m, //
-    size_t n, //
-    size_t k, //
-    size_t bl, //
-    const void* restrict lhs_packed, //
-    const void* restrict rhs_packed, //
-    float* restrict dst, // NOLINT(readability-non-const-parameter)
-    size_t dst_stride_row, //
-    size_t dst_stride_col, //
-    float scalar_min, //
-    float scalar_max) {
-    KAI_ASSUME(dst_stride_col == sizeof(float));
-    KAI_ASSUME((k % bl) == 0);
-    KAI_ASSUME((bl % kai_bl) == 0);
-
-    if (m == 0) {
-        return;
-    }
-    const size_t num_blocks = kai_get_num_blocks_per_row(k, bl);
-    const float clamp_vals[2] = {scalar_min, scalar_max};
-
-    KernelArgs args;
-
-    args.dst = dst;
-    args.lhs_packed = lhs_packed;
-    args.rhs_packed = rhs_packed;
-    args.clamp_vals = clamp_vals;
-    args.dst_stride_row = dst_stride_row;
-    args.m = m;
-    args.n = n;
-    args.num_blocks = num_blocks;
-
-    kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(&args);
-}
-
-#endif  // Architectural features check.
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.h
deleted file mode 100644
index 4a11aa0b..00000000
--- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm.h
+++ /dev/null
@@ -1,147 +0,0 @@
-//
-// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
-//
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <stddef.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-/// Micro-kernel dependencies
-///
-/// -# @ref kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step.
-/// -# @ref kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS NxK matrix.
-/// -# @ref kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS KxN matrix.
-
-/// --------------------------------------------------
-
-/// Gets the m step value.
-/// The micro-kernel can process any M values. However, the starting M index to
-/// be processed must be a multiple of m step.
-///
-/// @return the m step value
-size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void);
-
-/// Gets the n step value.
-/// The micro-kernel can process any N values. However, the starting N index to
-/// be processed must be a multiple of n step.
-///
-/// @return the n step
-size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void);
-
-/// Gets the mr value, which must be used to pack the LHS matrix
-///
-/// @return the mr value
-size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void);
-
-/// Gets the nr value, which must be used to pack the RHS matrix.
-///
-/// @return the nr value
-size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void);
-
-/// Gets the kr value, which must be used to pack the LHS and RHS matrices
-///
-/// @return the kr value
-size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void);
-
-/// Gets the sr value, which must be used to pack the LHS and RHS matrices
-///
-/// @return the sr value
-size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(void);
-
-/// Gets the offset in bytes for the packed LHS matrix,
-/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values.
-///
-/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
-///
-/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step.
-/// @param[in] k Total number of columns in the LHS matrix (not packed).
-/// It must be a multiple of the block length (bl).
-/// @param[in] bl Block length. It must be a multiple of 32.
-///
-/// @return the offset in bytes to the packed LHS matrix
-size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
-    size_t m_idx, //
-    size_t k); //
-
-/// Gets the offset in bytes for the packed RHS matrix,
-/// which contains the packed Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) values.
-///
-/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step.
-/// @param[in] k The common dimension between the LHS and RHS matrix (K).
-/// It must be a multiple of the block length (bl).
-/// @param[in] bl Block length. It must be a multiple of 32.
-///
-/// @return the offset in bytes to the packed RHS matrix
-size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
-    size_t n_idx, //
-    size_t k, //
-    size_t bl); //
-
-/// Gets the offset in bytes for the DST matrix
-///
-/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step.
-/// @param[in] n_idx Column index in the DST matrix. It must be multiple of n_step.
-/// @param[in] dst_stride The number of bytes in in each row of the DST matrix
-///
-/// @return the DST offset in bytes
-size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
-    size_t m_idx, //
-    size_t n_idx, //
-    size_t dst_stride); //
-
-/// Gets the size in bytes for the destination (DST) matrix.
-///
-/// @param[in] m Number of rows in the destination (DST) matrix.
-/// @param[in] n Number of columns in the destination (DST) matrix.
-///
-/// @return the destination (DST) matrix size in bytes
-size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
-    size_t m, //
-    size_t n); //
-
-/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
-///
-/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed.
-/// RHS matrix: Quantized Symmetric Signed 4-bit with per-block (32) quantization (qsi4c32) and packed.
-/// Output tile: (rows x cols) = m_step x n_step.
-///
-/// Note: Please refer to the get functions for m_step and n_step for the exact values.
-///
-/// Features used: i8mm
-///
-/// @param[in] m The number of output rows written.
-/// @param[in] n The number of output columns written.
-/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
-/// It must be a multiple of the block length (bl).
-/// @param[in] bl Block length. Block length. It must be a multiple of 32.
-/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the
-/// top of this file.
-/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the
-/// top of this file.
-/// @param[out] dst The DST matrix.
-/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
-/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes.
-/// @param[in] scalar_min Min value used to clamp the final result.
-/// @param[in] scalar_max Max value used to clamp the final result.
-void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm(
-    size_t m, //
-    size_t n, //
-    size_t k, //
-    size_t bl, //
-    const void* lhs_packed, //
-    const void* rhs_packed, //
-    float* dst, //
-    size_t dst_stride_row, //
-    size_t dst_stride_col, //
-    float scalar_min, //
-    float scalar_max); //
-
-#ifdef __cplusplus
-}
-#endif  // __cplusplus
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S
deleted file mode 100644
index b2194615..00000000
--- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm_asm.S
+++ /dev/null
@@ -1,670 +0,0 @@
-//
-// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
-//
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#if defined(_MSC_VER)
-    #define KAI_ASM_GLOBAL(name) GLOBAL name
-    #define KAI_ASM_FUNCTION_TYPE(name)
-    #define KAI_ASM_FUNCTION_LABEL(name) name PROC
-    #define KAI_ASM_FUNCTION_END(name) ENDP
-
-    #define KAI_ASM_CODE(name) AREA name, CODE, READONLY
-    #define KAI_ASM_ALIGN
-    #define KAI_ASM_LABEL(name) name
-    #define KAI_ASM_INST(hex) DCD hex
-    #define KAI_ASM_END END
-#else
-    #if defined(__APPLE__)
-        #define KAI_ASM_GLOBAL(name) .globl _##name
-        #define KAI_ASM_FUNCTION_TYPE(name)
-        #define KAI_ASM_FUNCTION_LABEL(name) _##name:
-        #define KAI_ASM_FUNCTION_END(name)
-    #else
-        #define KAI_ASM_GLOBAL(name) .global name
-        #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function
-        #define KAI_ASM_FUNCTION_LABEL(name) name:
-        #define KAI_ASM_FUNCTION_END(name) .size name, .-name
-    #endif
-
-    #define KAI_ASM_CODE(name) .text
-    #define KAI_ASM_ALIGN .p2align 4,,11
-    #define KAI_ASM_LABEL(name) name:
-    #define KAI_ASM_INST(hex) .inst hex
-    #define KAI_ASM_END
-#endif
-
-    KAI_ASM_CODE(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm)
-    KAI_ASM_ALIGN
-
-    KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm)
-
-KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm)
-KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm)
-    stp x20, x21, [sp, -144]!
- stp x22, x23, [sp, 16] - stp x24, x25, [sp, 32] - stp x26, x27, [sp, 48] - str x28, [sp, 64] - stp d10, d11, [sp, 72] - stp d12, d13, [sp, 88] - stp d14, d15, [sp, 104] - stp d8, d9, [sp, 120] - mov x6, #0x80 - mov x21, #0x20 - ldr x20, [x0, #0x28] - ldr x7, [x0, #0x38] - ldr x8, [x0, #0x8] - ldr x17, [x0, #0x10] - ldr x16, [x0, #0x30] - ldr x15, [x0, #0x0] - mov x14, x20 - ldr x13, [x0, #0x20] - madd x6, x7, x6, x21 - ldr x12, [x0, #0x18] - cmp x14, #0x10 - blt label_14 -KAI_ASM_LABEL(label_1) // Row loop - mov x11, x17 - mov x10, x16 - add x9, x15, x13, LSL #4 -KAI_ASM_LABEL(label_2) // Column loop - mov x27, x8 - movi v31.16b, #0x0 - movi v30.16b, #0x0 - mov x20, x7 - movi v29.16b, #0x0 - movi v28.16b, #0x0 - movi v27.16b, #0x0 - movi v26.16b, #0x0 - add x23, x27, x6 - add x22, x23, x6 - movi v25.16b, #0x0 - movi v24.16b, #0x0 - add x21, x22, x6 - movi v23.16b, #0x0 - movi v22.16b, #0x0 - movi v21.16b, #0x0 - movi v20.16b, #0x0 - movi v19.16b, #0x0 - movi v18.16b, #0x0 - movi v17.16b, #0x0 - movi v16.16b, #0x0 -KAI_ASM_LABEL(label_3) // Block loop - ldr q11, [x11, #0x0] - ldr q4, [x11, #0x10] - movi v2.4s, #0x0 - movi v9.4s, #0x0 - ldr q12, [x27, #0x0] - ldr q0, [x27, #0x10] - movi v7.4s, #0x0 - movi v5.4s, #0x0 - ldr q15, [x11, #0x20] - ldr q13, [x11, #0x30] - movi v10.16b, #0xf0 - add x11, x11, #0x40 - ldr q8, [x27, #0x20] - ldr q6, [x27, #0x30] - shl v14.16b, v11.16b, #0x4 - shl v3.16b, v4.16b, #0x4 - ldr q1, [x27, #0x40] - and v11.16b, v11.16b, v10.16b - and v4.16b, v4.16b, v10.16b - KAI_ASM_INST(0x4e8ea582) // smmla v2.4s, v12.16b, v14.16b - KAI_ASM_INST(0x4e83a589) // smmla v9.4s, v12.16b, v3.16b - shl v12.16b, v15.16b, #0x4 - KAI_ASM_INST(0x4e8ea407) // smmla v7.4s, v0.16b, v14.16b - KAI_ASM_INST(0x4e83a405) // smmla v5.4s, v0.16b, v3.16b - shl v0.16b, v13.16b, #0x4 - and v15.16b, v15.16b, v10.16b - and v13.16b, v13.16b, v10.16b - ldr q10, [x27, #0x50] - KAI_ASM_INST(0x4e8ca502) // smmla v2.4s, v8.16b, v12.16b - KAI_ASM_INST(0x4e80a509) // smmla v9.4s, v8.16b, v0.16b - ldr q8, [x27, #0x60] - KAI_ASM_INST(0x4e8ca4c7) // smmla v7.4s, v6.16b, v12.16b - KAI_ASM_INST(0x4e80a4c5) // smmla v5.4s, v6.16b, v0.16b - ldr q6, [x27, #0x70] - add x27, x27, #0x80 - KAI_ASM_INST(0x4e8ba422) // smmla v2.4s, v1.16b, v11.16b - KAI_ASM_INST(0x4e84a429) // smmla v9.4s, v1.16b, v4.16b - ldr d1, [x11, #0x0] - add x11, x11, #0x8 - KAI_ASM_INST(0x4e8ba547) // smmla v7.4s, v10.16b, v11.16b - KAI_ASM_INST(0x4e84a545) // smmla v5.4s, v10.16b, v4.16b - KAI_ASM_INST(0x4e8fa502) // smmla v2.4s, v8.16b, v15.16b - shll v1.4s, v1.4h, #0x10 - KAI_ASM_INST(0x4e8da509) // smmla v9.4s, v8.16b, v13.16b - KAI_ASM_INST(0x4e8fa4c7) // smmla v7.4s, v6.16b, v15.16b - KAI_ASM_INST(0x4e8da4c5) // smmla v5.4s, v6.16b, v13.16b - uzp1 v6.2d, v2.2d, v9.2d - uzp2 v8.2d, v2.2d, v9.2d - scvtf v6.4s, v6.4s, #0x4 - uzp1 v9.2d, v7.2d, v5.2d - uzp2 v2.2d, v7.2d, v5.2d - scvtf v8.4s, v8.4s, #0x4 - fmla v31.4s, v6.4s, v1.4s - scvtf v9.4s, v9.4s, #0x4 - scvtf v2.4s, v2.4s, #0x4 - fmla v30.4s, v8.4s, v1.4s - fmla v29.4s, v9.4s, v1.4s - fmla v28.4s, v2.4s, v1.4s - ldr q9, [x23, #0x0] - ldr q7, [x23, #0x10] - movi v8.4s, #0x0 - movi v2.4s, #0x0 - ldr q5, [x23, #0x20] - ldr q10, [x23, #0x30] - movi v6.4s, #0x0 - KAI_ASM_INST(0x4e8ea528) // smmla v8.4s, v9.16b, v14.16b - KAI_ASM_INST(0x4e83a522) // smmla v2.4s, v9.16b, v3.16b - ldr q9, [x23, #0x40] - KAI_ASM_INST(0x4e8ea4e6) // smmla v6.4s, v7.16b, v14.16b - KAI_ASM_INST(0x4e8ca4a8) // smmla v8.4s, v5.16b, v12.16b - KAI_ASM_INST(0x4e80a4a2) // smmla v2.4s, v5.16b, v0.16b - ldr q5, [x23, 
#0x50] - KAI_ASM_INST(0x4e8ca546) // smmla v6.4s, v10.16b, v12.16b - KAI_ASM_INST(0x4e8ba528) // smmla v8.4s, v9.16b, v11.16b - KAI_ASM_INST(0x4e84a522) // smmla v2.4s, v9.16b, v4.16b - ldr q9, [x23, #0x60] - KAI_ASM_INST(0x4e8ba4a6) // smmla v6.4s, v5.16b, v11.16b - KAI_ASM_INST(0x4e8fa528) // smmla v8.4s, v9.16b, v15.16b - KAI_ASM_INST(0x4e8da522) // smmla v2.4s, v9.16b, v13.16b - movi v9.4s, #0x0 - KAI_ASM_INST(0x4e83a4e9) // smmla v9.4s, v7.16b, v3.16b - ldr q7, [x23, #0x70] - add x23, x23, #0x80 - KAI_ASM_INST(0x4e8fa4e6) // smmla v6.4s, v7.16b, v15.16b - KAI_ASM_INST(0x4e80a549) // smmla v9.4s, v10.16b, v0.16b - uzp1 v10.2d, v8.2d, v2.2d - uzp2 v2.2d, v8.2d, v2.2d - scvtf v10.4s, v10.4s, #0x4 - KAI_ASM_INST(0x4e84a4a9) // smmla v9.4s, v5.16b, v4.16b - scvtf v2.4s, v2.4s, #0x4 - fmla v27.4s, v10.4s, v1.4s - KAI_ASM_INST(0x4e8da4e9) // smmla v9.4s, v7.16b, v13.16b - fmla v26.4s, v2.4s, v1.4s - uzp1 v2.2d, v6.2d, v9.2d - uzp2 v10.2d, v6.2d, v9.2d - scvtf v2.4s, v2.4s, #0x4 - scvtf v10.4s, v10.4s, #0x4 - fmla v25.4s, v2.4s, v1.4s - fmla v24.4s, v10.4s, v1.4s - ldr q8, [x22, #0x0] - ldr q7, [x22, #0x10] - movi v9.4s, #0x0 - movi v6.4s, #0x0 - ldr q2, [x22, #0x20] - ldr q5, [x22, #0x30] - movi v10.4s, #0x0 - KAI_ASM_INST(0x4e8ea509) // smmla v9.4s, v8.16b, v14.16b - KAI_ASM_INST(0x4e83a506) // smmla v6.4s, v8.16b, v3.16b - ldr q8, [x22, #0x40] - KAI_ASM_INST(0x4e8ea4ea) // smmla v10.4s, v7.16b, v14.16b - KAI_ASM_INST(0x4e8ca449) // smmla v9.4s, v2.16b, v12.16b - KAI_ASM_INST(0x4e80a446) // smmla v6.4s, v2.16b, v0.16b - ldr q2, [x22, #0x50] - KAI_ASM_INST(0x4e8ca4aa) // smmla v10.4s, v5.16b, v12.16b - KAI_ASM_INST(0x4e8ba509) // smmla v9.4s, v8.16b, v11.16b - KAI_ASM_INST(0x4e84a506) // smmla v6.4s, v8.16b, v4.16b - ldr q8, [x22, #0x60] - KAI_ASM_INST(0x4e8ba44a) // smmla v10.4s, v2.16b, v11.16b - KAI_ASM_INST(0x4e8fa509) // smmla v9.4s, v8.16b, v15.16b - KAI_ASM_INST(0x4e8da506) // smmla v6.4s, v8.16b, v13.16b - movi v8.4s, #0x0 - KAI_ASM_INST(0x4e83a4e8) // smmla v8.4s, v7.16b, v3.16b - ldr q7, [x22, #0x70] - add x22, x22, #0x80 - KAI_ASM_INST(0x4e8fa4ea) // smmla v10.4s, v7.16b, v15.16b - KAI_ASM_INST(0x4e80a4a8) // smmla v8.4s, v5.16b, v0.16b - uzp1 v5.2d, v9.2d, v6.2d - uzp2 v9.2d, v9.2d, v6.2d - scvtf v5.4s, v5.4s, #0x4 - KAI_ASM_INST(0x4e84a448) // smmla v8.4s, v2.16b, v4.16b - scvtf v9.4s, v9.4s, #0x4 - fmla v23.4s, v5.4s, v1.4s - KAI_ASM_INST(0x4e8da4e8) // smmla v8.4s, v7.16b, v13.16b - fmla v22.4s, v9.4s, v1.4s - uzp1 v2.2d, v10.2d, v8.2d - uzp2 v10.2d, v10.2d, v8.2d - scvtf v2.4s, v2.4s, #0x4 - scvtf v10.4s, v10.4s, #0x4 - fmla v21.4s, v2.4s, v1.4s - fmla v20.4s, v10.4s, v1.4s - ldr q2, [x21, #0x0] - ldr q10, [x21, #0x10] - movi v6.4s, #0x0 - movi v9.4s, #0x0 - ldr q5, [x21, #0x20] - ldr q8, [x21, #0x30] - movi v7.4s, #0x0 - KAI_ASM_INST(0x4e8ea446) // smmla v6.4s, v2.16b, v14.16b - KAI_ASM_INST(0x4e83a449) // smmla v9.4s, v2.16b, v3.16b - ldr q2, [x21, #0x40] - KAI_ASM_INST(0x4e8ea547) // smmla v7.4s, v10.16b, v14.16b - ldr q14, [x21, #0x50] - KAI_ASM_INST(0x4e8ca4a6) // smmla v6.4s, v5.16b, v12.16b - KAI_ASM_INST(0x4e80a4a9) // smmla v9.4s, v5.16b, v0.16b - ldr q5, [x21, #0x60] - KAI_ASM_INST(0x4e8ca507) // smmla v7.4s, v8.16b, v12.16b - ldr q12, [x21, #0x70] - add x21, x21, #0x80 - KAI_ASM_INST(0x4e8ba446) // smmla v6.4s, v2.16b, v11.16b - KAI_ASM_INST(0x4e84a449) // smmla v9.4s, v2.16b, v4.16b - movi v2.4s, #0x0 - KAI_ASM_INST(0x4e83a542) // smmla v2.4s, v10.16b, v3.16b - KAI_ASM_INST(0x4e8ba5c7) // smmla v7.4s, v14.16b, v11.16b - KAI_ASM_INST(0x4e8fa4a6) // smmla v6.4s, 
v5.16b, v15.16b - KAI_ASM_INST(0x4e80a502) // smmla v2.4s, v8.16b, v0.16b - KAI_ASM_INST(0x4e8da4a9) // smmla v9.4s, v5.16b, v13.16b - KAI_ASM_INST(0x4e8fa587) // smmla v7.4s, v12.16b, v15.16b - KAI_ASM_INST(0x4e84a5c2) // smmla v2.4s, v14.16b, v4.16b - uzp1 v11.2d, v6.2d, v9.2d - uzp2 v14.2d, v6.2d, v9.2d - scvtf v11.4s, v11.4s, #0x4 - KAI_ASM_INST(0x4e8da582) // smmla v2.4s, v12.16b, v13.16b - scvtf v14.4s, v14.4s, #0x4 - fmla v19.4s, v11.4s, v1.4s - uzp1 v9.2d, v7.2d, v2.2d - uzp2 v0.2d, v7.2d, v2.2d - fmla v18.4s, v14.4s, v1.4s - scvtf v9.4s, v9.4s, #0x4 - scvtf v0.4s, v0.4s, #0x4 - fmla v17.4s, v9.4s, v1.4s - fmla v16.4s, v0.4s, v1.4s - subs x20, x20, #0x1 - bgt label_3 - ld1 { v11.4s }, [x27] - ld1 { v10.4s }, [x23] - add x27, x27, #0x10 - add x23, x23, #0x10 - ld1 { v9.4s }, [x22] - ld1 { v8.4s }, [x21] - add x22, x22, #0x10 - add x21, x21, #0x10 - ldr q7, [x11, #0x0] - ldr q6, [x27, #0x0] - add x20, x12, #0x4 - cmp x10, #0x4 - ldr q5, [x23, #0x0] - ldr q4, [x22, #0x0] - scvtf v11.4s, v11.4s - scvtf v10.4s, v10.4s - ldr q3, [x21, #0x0] - ldr q2, [x11, #0x10] - scvtf v9.4s, v9.4s - scvtf v8.4s, v8.4s - ld1r { v1.4s }, [x12] - ld1r { v0.4s }, [x20] - add x11, x11, #0x20 - fmla v31.4s, v7.4s, v11.s[0] - fmla v30.4s, v7.4s, v11.s[1] - fmla v29.4s, v7.4s, v11.s[2] - fmla v28.4s, v7.4s, v11.s[3] - fmla v27.4s, v7.4s, v10.s[0] - fmla v26.4s, v7.4s, v10.s[1] - fmla v25.4s, v7.4s, v10.s[2] - fmla v24.4s, v7.4s, v10.s[3] - fmla v23.4s, v7.4s, v9.s[0] - fmul v31.4s, v31.4s, v6.s[0] - fmla v22.4s, v7.4s, v9.s[1] - fmla v21.4s, v7.4s, v9.s[2] - fmul v30.4s, v30.4s, v6.s[1] - fmla v20.4s, v7.4s, v9.s[3] - fmla v19.4s, v7.4s, v8.s[0] - fmul v29.4s, v29.4s, v6.s[2] - fmla v18.4s, v7.4s, v8.s[1] - fmla v17.4s, v7.4s, v8.s[2] - fmul v28.4s, v28.4s, v6.s[3] - fmla v16.4s, v7.4s, v8.s[3] - fmul v27.4s, v27.4s, v5.s[0] - fmul v26.4s, v26.4s, v5.s[1] - fmul v25.4s, v25.4s, v5.s[2] - fmul v24.4s, v24.4s, v5.s[3] - fmul v23.4s, v23.4s, v4.s[0] - fmul v22.4s, v22.4s, v4.s[1] - fmul v21.4s, v21.4s, v4.s[2] - fmul v20.4s, v20.4s, v4.s[3] - fmul v19.4s, v19.4s, v3.s[0] - fmul v18.4s, v18.4s, v3.s[1] - fmul v17.4s, v17.4s, v3.s[2] - fmul v16.4s, v16.4s, v3.s[3] - fadd v31.4s, v31.4s, v2.4s - fadd v30.4s, v30.4s, v2.4s - fadd v29.4s, v29.4s, v2.4s - fadd v28.4s, v28.4s, v2.4s - fadd v27.4s, v27.4s, v2.4s - fadd v26.4s, v26.4s, v2.4s - fadd v25.4s, v25.4s, v2.4s - fadd v24.4s, v24.4s, v2.4s - fadd v23.4s, v23.4s, v2.4s - fadd v22.4s, v22.4s, v2.4s - fadd v21.4s, v21.4s, v2.4s - fadd v20.4s, v20.4s, v2.4s - fadd v19.4s, v19.4s, v2.4s - fadd v18.4s, v18.4s, v2.4s - fadd v17.4s, v17.4s, v2.4s - fadd v16.4s, v16.4s, v2.4s - fmax v31.4s, v31.4s, v1.4s - fmax v30.4s, v30.4s, v1.4s - fmax v29.4s, v29.4s, v1.4s - fmax v28.4s, v28.4s, v1.4s - fmax v27.4s, v27.4s, v1.4s - fmax v26.4s, v26.4s, v1.4s - fmax v25.4s, v25.4s, v1.4s - fmax v24.4s, v24.4s, v1.4s - fmax v23.4s, v23.4s, v1.4s - fmax v22.4s, v22.4s, v1.4s - fmax v21.4s, v21.4s, v1.4s - fmax v20.4s, v20.4s, v1.4s - fmax v19.4s, v19.4s, v1.4s - fmax v18.4s, v18.4s, v1.4s - fmax v17.4s, v17.4s, v1.4s - fmax v16.4s, v16.4s, v1.4s - fmin v31.4s, v31.4s, v0.4s - fmin v30.4s, v30.4s, v0.4s - fmin v29.4s, v29.4s, v0.4s - fmin v28.4s, v28.4s, v0.4s - fmin v27.4s, v27.4s, v0.4s - fmin v26.4s, v26.4s, v0.4s - fmin v25.4s, v25.4s, v0.4s - fmin v24.4s, v24.4s, v0.4s - fmin v23.4s, v23.4s, v0.4s - fmin v22.4s, v22.4s, v0.4s - fmin v21.4s, v21.4s, v0.4s - fmin v20.4s, v20.4s, v0.4s - fmin v19.4s, v19.4s, v0.4s - fmin v18.4s, v18.4s, v0.4s - fmin v17.4s, v17.4s, v0.4s - fmin 
v16.4s, v16.4s, v0.4s - blt label_8 - mov x20, x15 - str q31, [x20, #0x0] - add x20, x20, x13 - str q30, [x20, #0x0] - add x20, x20, x13 - str q29, [x20, #0x0] - add x20, x20, x13 - str q28, [x20, #0x0] - add x20, x20, x13 - str q27, [x20, #0x0] - add x20, x20, x13 - str q26, [x20, #0x0] - add x20, x20, x13 - str q25, [x20, #0x0] - add x20, x20, x13 - str q24, [x20, #0x0] - add x20, x20, x13 - str q23, [x20, #0x0] - add x20, x20, x13 - str q22, [x20, #0x0] - add x20, x20, x13 - str q21, [x20, #0x0] - add x20, x20, x13 - str q20, [x20, #0x0] - add x20, x20, x13 - str q19, [x20, #0x0] - add x20, x20, x13 - str q18, [x20, #0x0] - add x20, x20, x13 - str q17, [x20, #0x0] - add x20, x20, x13 - str q16, [x20, #0x0] - b label_13 -KAI_ASM_LABEL(label_8) // Partial output - mov x28, x15 - add x26, x28, x13, LSL #2 - add x25, x26, x13, LSL #1 - add x24, x26, x13 - add x23, x25, x13 - add x22, x28, x13, LSL #1 - add x21, x28, x13 - add x20, x22, x13 - add x27, x23, x13 - tbz x10, #1, label_9 - st1 { v24.d }[0], [x23], #0x8 - st1 { v25.d }[0], [x25], #0x8 - st1 { v26.d }[0], [x24], #0x8 - st1 { v27.d }[0], [x26], #0x8 - st1 { v28.d }[0], [x20], #0x8 - st1 { v29.d }[0], [x22], #0x8 - st1 { v30.d }[0], [x21], #0x8 - st1 { v31.d }[0], [x28], #0x8 - tbz x10, #0, label_10 - st1 { v24.s }[2], [x23] - st1 { v25.s }[2], [x25] - st1 { v26.s }[2], [x24] - st1 { v27.s }[2], [x26] - st1 { v28.s }[2], [x20] - st1 { v29.s }[2], [x22] - st1 { v30.s }[2], [x21] - st1 { v31.s }[2], [x28] - b label_10 -KAI_ASM_LABEL(label_9) // Output block 0: partial_1_0 - st1 { v24.s }[0], [x23] - st1 { v25.s }[0], [x25] - st1 { v26.s }[0], [x24] - st1 { v27.s }[0], [x26] - st1 { v28.s }[0], [x20] - st1 { v29.s }[0], [x22] - st1 { v30.s }[0], [x21] - st1 { v31.s }[0], [x28] -KAI_ASM_LABEL(label_10) // Output block 0: Done - add x26, x27, x13, LSL #2 - add x25, x27, x13, LSL #1 - add x24, x26, x13, LSL #1 - add x23, x27, x13 - add x22, x25, x13 - add x21, x26, x13 - add x20, x24, x13 - tbz x10, #1, label_11 - st1 { v16.d }[0], [x20], #0x8 - st1 { v17.d }[0], [x24], #0x8 - st1 { v18.d }[0], [x21], #0x8 - st1 { v19.d }[0], [x26], #0x8 - st1 { v20.d }[0], [x22], #0x8 - st1 { v21.d }[0], [x25], #0x8 - st1 { v22.d }[0], [x23], #0x8 - st1 { v23.d }[0], [x27], #0x8 - tbz x10, #0, label_12 - st1 { v16.s }[2], [x20] - st1 { v17.s }[2], [x24] - st1 { v18.s }[2], [x21] - st1 { v19.s }[2], [x26] - st1 { v20.s }[2], [x22] - st1 { v21.s }[2], [x25] - st1 { v22.s }[2], [x23] - st1 { v23.s }[2], [x27] - b label_12 -KAI_ASM_LABEL(label_11) // Output block 1: partial_1_0 - st1 { v16.s }[0], [x20] - st1 { v17.s }[0], [x24] - st1 { v18.s }[0], [x21] - st1 { v19.s }[0], [x26] - st1 { v20.s }[0], [x22] - st1 { v21.s }[0], [x25] - st1 { v22.s }[0], [x23] - st1 { v23.s }[0], [x27] -KAI_ASM_LABEL(label_12) // Output block 1: Done -KAI_ASM_LABEL(label_13) // Output stage exit - subs x10, x10, #0x4 - add x15, x15, #0x10 - bgt label_2 - mov x20, #0x4 - sub x14, x14, #0x10 - cmp x14, #0x10 - mov x15, x9 - madd x8, x20, x6, x8 - bge label_1 -KAI_ASM_LABEL(label_14) // Row loop skip - cbz x14, label_23 -KAI_ASM_LABEL(label_15) // Row tail: Row loop - mov x26, x17 - mov x25, x16 - add x24, x15, x13, LSL #2 -KAI_ASM_LABEL(label_16) // Row tail: Column loop - movi v31.16b, #0x0 - movi v30.16b, #0x0 - mov x27, x8 - mov x20, x7 - movi v29.16b, #0x0 - movi v28.16b, #0x0 -KAI_ASM_LABEL(label_17) // Row tail: Block loop - ldr q9, [x26, #0x0] - ldr q8, [x26, #0x10] - movi v7.4s, #0x0 - movi v6.4s, #0x0 - ldr q5, [x27, #0x0] - ldr q4, [x27, #0x10] - movi v3.4s, #0x0 - movi 
v2.4s, #0x0 - ldr q1, [x26, #0x20] - ldr q0, [x26, #0x30] - movi v27.16b, #0xf0 - add x26, x26, #0x40 - ldr q26, [x27, #0x20] - ldr q25, [x27, #0x30] - shl v24.16b, v9.16b, #0x4 - shl v20.16b, v8.16b, #0x4 - ldr q23, [x27, #0x40] - ldr q22, [x27, #0x50] - and v9.16b, v9.16b, v27.16b - and v8.16b, v8.16b, v27.16b - ldr q21, [x27, #0x60] - ldr q19, [x27, #0x70] - shl v18.16b, v1.16b, #0x4 - shl v17.16b, v0.16b, #0x4 - ldr d16, [x26, #0x0] - KAI_ASM_INST(0x4e98a4a7) // smmla v7.4s, v5.16b, v24.16b - KAI_ASM_INST(0x4e94a4a6) // smmla v6.4s, v5.16b, v20.16b - and v1.16b, v1.16b, v27.16b - KAI_ASM_INST(0x4e98a483) // smmla v3.4s, v4.16b, v24.16b - KAI_ASM_INST(0x4e94a482) // smmla v2.4s, v4.16b, v20.16b - and v0.16b, v0.16b, v27.16b - add x26, x26, #0x8 - add x27, x27, #0x80 - shll v20.4s, v16.4h, #0x10 - KAI_ASM_INST(0x4e92a747) // smmla v7.4s, v26.16b, v18.16b - KAI_ASM_INST(0x4e91a746) // smmla v6.4s, v26.16b, v17.16b - KAI_ASM_INST(0x4e92a723) // smmla v3.4s, v25.16b, v18.16b - KAI_ASM_INST(0x4e91a722) // smmla v2.4s, v25.16b, v17.16b - KAI_ASM_INST(0x4e89a6e7) // smmla v7.4s, v23.16b, v9.16b - KAI_ASM_INST(0x4e88a6e6) // smmla v6.4s, v23.16b, v8.16b - KAI_ASM_INST(0x4e89a6c3) // smmla v3.4s, v22.16b, v9.16b - KAI_ASM_INST(0x4e88a6c2) // smmla v2.4s, v22.16b, v8.16b - KAI_ASM_INST(0x4e81a6a7) // smmla v7.4s, v21.16b, v1.16b - KAI_ASM_INST(0x4e80a6a6) // smmla v6.4s, v21.16b, v0.16b - KAI_ASM_INST(0x4e81a663) // smmla v3.4s, v19.16b, v1.16b - KAI_ASM_INST(0x4e80a662) // smmla v2.4s, v19.16b, v0.16b - uzp1 v19.2d, v7.2d, v6.2d - uzp2 v18.2d, v7.2d, v6.2d - scvtf v19.4s, v19.4s, #0x4 - uzp1 v17.2d, v3.2d, v2.2d - uzp2 v16.2d, v3.2d, v2.2d - scvtf v18.4s, v18.4s, #0x4 - fmla v31.4s, v19.4s, v20.4s - scvtf v17.4s, v17.4s, #0x4 - scvtf v16.4s, v16.4s, #0x4 - fmla v30.4s, v18.4s, v20.4s - fmla v29.4s, v17.4s, v20.4s - fmla v28.4s, v16.4s, v20.4s - subs x20, x20, #0x1 - bgt label_17 - ld1 { v21.4s }, [x27] - ldr q20, [x26, #0x0] - add x27, x27, #0x10 - add x20, x12, #0x4 - ldr q19, [x27, #0x0] - ldr q18, [x26, #0x10] - cmp x25, #0x4 - add x26, x26, #0x20 - ld1r { v17.4s }, [x12] - ld1r { v16.4s }, [x20] - scvtf v21.4s, v21.4s - fmla v31.4s, v20.4s, v21.s[0] - fmla v30.4s, v20.4s, v21.s[1] - fmla v29.4s, v20.4s, v21.s[2] - fmla v28.4s, v20.4s, v21.s[3] - fmul v31.4s, v31.4s, v19.s[0] - fmul v30.4s, v30.4s, v19.s[1] - fmul v29.4s, v29.4s, v19.s[2] - fadd v31.4s, v31.4s, v18.4s - fmul v28.4s, v28.4s, v19.s[3] - fadd v30.4s, v30.4s, v18.4s - fadd v29.4s, v29.4s, v18.4s - fadd v28.4s, v28.4s, v18.4s - fmax v31.4s, v31.4s, v17.4s - fmax v30.4s, v30.4s, v17.4s - fmax v29.4s, v29.4s, v17.4s - fmax v28.4s, v28.4s, v17.4s - fmin v31.4s, v31.4s, v16.4s - fmin v30.4s, v30.4s, v16.4s - fmin v29.4s, v29.4s, v16.4s - fmin v28.4s, v28.4s, v16.4s - blt label_19 - mov x20, x15 - cmp x14, #0x1 - str q31, [x20, #0x0] - add x20, x20, x13 - ble label_22 - cmp x14, #0x2 - str q30, [x20, #0x0] - add x20, x20, x13 - ble label_22 - cmp x14, #0x3 - str q29, [x20, #0x0] - add x20, x20, x13 - ble label_22 - str q28, [x20, #0x0] - b label_22 -KAI_ASM_LABEL(label_19) // Row tail: Partial output - mov x23, x15 - cmp x14, #0x1 - add x22, x23, x13 - csel x22, x22, x23, GT - cmp x14, #0x2 - add x21, x23, x13, LSL #1 - csel x21, x21, x22, GT - cmp x14, #0x3 - add x20, x21, x13 - csel x20, x20, x21, GT - tbz x25, #1, label_20 - st1 { v28.d }[0], [x20], #0x8 - st1 { v29.d }[0], [x21], #0x8 - st1 { v30.d }[0], [x22], #0x8 - st1 { v31.d }[0], [x23], #0x8 - tbz x25, #0, label_21 - st1 { v28.s }[2], [x20] - st1 { v29.s }[2], [x21] - st1 { 
v30.s }[2], [x22] - st1 { v31.s }[2], [x23] - b label_21 -KAI_ASM_LABEL(label_20) // Row tail: Output block 0: partial_1_0 - st1 { v28.s }[0], [x20] - st1 { v29.s }[0], [x21] - st1 { v30.s }[0], [x22] - st1 { v31.s }[0], [x23] -KAI_ASM_LABEL(label_21) // Row tail: Output block 0: Done -KAI_ASM_LABEL(label_22) // Row tail: Output stage exit - subs x25, x25, #0x4 - add x15, x15, #0x10 - bgt label_16 - subs x14, x14, #0x4 - add x8, x8, x6 - mov x15, x24 - bgt label_15 -KAI_ASM_LABEL(label_23) // Row tail: Row loop skip - ldp x22, x23, [sp, 16] - ldp x24, x25, [sp, 32] - ldp x26, x27, [sp, 48] - ldr x28, [sp, 64] - ldp d10, d11, [sp, 72] - ldp d12, d13, [sp, 88] - ldp d14, d15, [sp, 104] - ldp d8, d9, [sp, 120] - ldp x20, x21, [sp], 144 - ret - KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm) - - KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h index 81d48e04..f29ed2de 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h @@ -61,8 +61,6 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) /// /// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. /// @param[in] k Total number of columns in the LHS matrix (not packed). -/// It must be a multiple of the block length (bl). -/// @param[in] bl Block length. It must be a multiple of 32. /// /// @return the offset in bytes to the packed LHS matrix size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h index acd7d784..15be9bd4 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h @@ -61,8 +61,6 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) /// /// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. /// @param[in] k Total number of columns in the LHS matrix (not packed). -/// It must be a multiple of the block length (bl). -/// @param[in] bl Block length. It must be a multiple of 32. 
/// /// @return the offset in bytes to the packed LHS matrix size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( -- GitLab From 77b9826a6e9b39ca20ec75293427c6fbdc04fc1e Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Mon, 27 Jan 2025 19:01:22 +0000 Subject: [PATCH 13/15] Update Bazel Build Signed-off-by: Anitha Raj --- kai/ukernels/matmul/BUILD.bazel | 1 - 1 file changed, 1 deletion(-) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 294b52ea..97f12218 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -110,7 +110,6 @@ I8MM_KERNELS = [ I8MM_KERNELS_ASM = [ "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm", - "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_opt32_neon_i8mm", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", ] -- GitLab From fc87e1efded6bf7042fd1fd9fcf879cee1279c65 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Tue, 28 Jan 2025 09:50:56 +0000 Subject: [PATCH 14/15] Use kai_bl instead of number 32 Signed-off-by: Anitha Raj --- ...i_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c | 2 +- ...atmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c | 2 +- ...atmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c | 2 +- ..._matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c | 2 +- ..._matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c | 2 +- ...i_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c | 2 +- ...i_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c index c517f5a2..60b2132e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c @@ -163,7 +163,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod( if (m == 0) { return; } - const size_t num_subblocks = bl / 32; + const size_t num_subblocks = bl / kai_bl; const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); const float clamp_vals[2] = {scalar_min, scalar_max}; diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c index b87e155e..07b559e4 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c @@ -163,7 +163,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( if (m == 0) { return; } - const size_t num_subblocks = bl / 32; + const size_t num_subblocks = bl / kai_bl; const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); const float clamp_vals[2] = {scalar_min, scalar_max}; diff --git 
a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c index 59eec89d..b1a36df2 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c @@ -163,7 +163,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( if (m == 0) { return; } - const size_t num_subblocks = bl / 32; + const size_t num_subblocks = bl / kai_bl; const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); const float clamp_vals[2] = {scalar_min, scalar_max}; diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c index 18930d7f..2336485c 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c @@ -164,7 +164,7 @@ void kai_run_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod( if (m == 0) { return; } - const size_t num_subblocks = bl / 32; + const size_t num_subblocks = bl / kai_bl; const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); const float clamp_vals[2] = {scalar_min, scalar_max}; diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c index 3a98c339..5fd96361 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c @@ -164,7 +164,7 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( if (m == 0) { return; } - const size_t num_subblocks = bl / 32; + const size_t num_subblocks = bl / kai_bl; const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); const float clamp_vals[2] = {scalar_min, scalar_max}; diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c index b1561b99..72bb1910 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -163,7 +163,7 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( if (m == 0) { return; } - const size_t num_subblocks = bl / 32; + const size_t num_subblocks = bl / kai_bl; const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); const float clamp_vals[2] = {scalar_min, scalar_max}; diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c 
b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c index 777fc994..6f97a11a 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c @@ -163,7 +163,7 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( if (m == 0) { return; } - const size_t num_subblocks = bl / 32; + const size_t num_subblocks = bl / kai_bl; const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); const float clamp_vals[2] = {scalar_min, scalar_max}; -- GitLab From bff56fe98b568e01bc372d3aab69ace6ae65a05c Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Tue, 28 Jan 2025 11:01:24 +0000 Subject: [PATCH 15/15] Update architectural feature guards Signed-off-by: Anitha Raj --- ...i_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c | 2 +- ...atmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c | 2 +- ...atmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c | 2 +- ..._matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c | 2 +- ..._matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c | 2 +- ...i_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c | 2 +- ...i_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c index 60b2132e..c7fac421 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(_M_ARM64) #error "Dotprod extension required to compile this micro-kernel" #else // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c index 07b559e4..3fc3cfc6 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(_M_ARM64) #error "Dotprod extension required to compile this micro-kernel" #else // Architectural features check. 
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c index b1a36df2..4738d9aa 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(_M_ARM64) #error "Dotprod extension required to compile this micro-kernel" #else // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c index 2336485c..b442cda1 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(_M_ARM64) #error "Dotprod extension required to compile this micro-kernel" #else // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c index 5fd96361..0fc8e35c 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) && !defined(_M_ARM64) #error "I8mm extension required to compile this micro-kernel" #else // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c index 72bb1910..2dcca508 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) && !defined(_M_ARM64) #error "I8mm extension required to compile this micro-kernel" #else // Architectural features check. 
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c index 6f97a11a..c2140be0 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) && !defined(_M_ARM64) #error "I8mm extension required to compile this micro-kernel" #else // Architectural features check. -- GitLab
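The extracted kernels share one calling convention: the C wrapper passes a single pointer to an argument block, and the asm entry point loads each field at a fixed offset from x0 (the loads at the top of the deleted opt32 kernel above cover offsets #0x0 through #0x38). A hypothetical C-side layout consistent with how those registers are subsequently used in the loops — the field names are illustrative, not the ones in the sources:

struct kai_example_matmul_args {
    float*       dst;            /* [x0, #0x00]: output base                */
    const void*  lhs_packed;     /* [x0, #0x08]: packed LHS base            */
    const void*  rhs_packed;     /* [x0, #0x10]: packed RHS base            */
    const float* clamp_vals;     /* [x0, #0x18]: {scalar_min, scalar_max}   */
    size_t       dst_stride_row; /* [x0, #0x20]: DST row stride in bytes    */
    size_t       m;              /* [x0, #0x28]: rows (compared against 16) */
    size_t       n;              /* [x0, #0x30]: columns (stepped by 4)     */
    size_t       num_blocks;     /* [x0, #0x38]: quantization blocks per row */
};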
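Both the dotprod and i8mm bodies above decode two signed 4-bit weights per byte without a widening pass: shl #0x4 moves the low nibble into the top of the byte, the 0xf0 mask keeps the high nibble in place, and the shared x16 scaling of both products is divided back out by the fixed-point convert (scvtf ..., #0x4) before the block scale is applied. A minimal intrinsics sketch of the same idea, using plain vdot rather than the lane-indexed form for brevity; the names are illustrative:

#include <arm_neon.h>

/* Sketch: accumulate one packed-int4 weight vector against two int8
 * activation vectors. Requires __ARM_FEATURE_DOTPROD. */
static inline int32x4_t example_dot_packed_int4(
    int32x4_t acc, uint8x16_t packed_w, int8x16_t lhs_lo, int8x16_t lhs_hi) {
    const int8x16_t hi_mask = vreinterpretq_s8_u8(vdupq_n_u8(0xF0));
    const int8x16_t w = vreinterpretq_s8_u8(packed_w);
    const int8x16_t w_lo = vshlq_n_s8(w, 4);     /* low nibble << 4   */
    const int8x16_t w_hi = vandq_s8(w, hi_mask); /* high nibble as-is */
    acc = vdotq_s32(acc, w_lo, lhs_lo);
    return vdotq_s32(acc, w_hi, lhs_hi);
}

/* Both nibbles enter the dot product scaled by 16; the kernels undo this
 * with the fixed-point convert, e.g. vcvtq_n_f32_s32(acc, 4). */

The per-block scales are stored as 16-bit halves and widened to f32 by shifting into the top half-word (shll ..., #0x10) before the fmla against these accumulators.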
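In the kai_bl patch, the substituted identifier presumably resolves to a file-local constant for the 32-element sub-block length these micro-kernels operate on; its declaration is not shown in this series, but a sketch of the assumed definition and the resulting code:

static const size_t kai_bl = 32; /* assumed definition; the name is taken from the diff */

/* each hunk above then reads: */
const size_t num_subblocks = bl / kai_bl; /* previously: bl / 32 */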
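The guard change in the final patch exists because MSVC identifies an Arm64 target through its own predefined macro: _M_ARM64 is defined, while __aarch64__ and ACLE feature macros such as __ARM_FEATURE_DOTPROD and __ARM_FEATURE_MATMUL_INT8 are not. Without the third clause the #error would fire on an MSVC Arm64 build even though the extracted assembly kernels are built for it via the armasm path of the KAI_ASM_* macros. The resulting guard, as applied by the patch:

#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(_M_ARM64)
#error "Dotprod extension required to compile this micro-kernel"
#else  // Architectural features check.
/* ... micro-kernel implementation ... */
#endif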