From 6d2b8fad99ae046eae8e4577b86594bb274d2420 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 19 Jun 2025 15:10:48 +0100 Subject: [PATCH 1/3] Move GEMM F32 SME1 assembly kernel to separate file Signed-off-by: Viet-Hoa Do --- CMakeLists.txt | 1 + ...32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.c | 496 +---------------- ...32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h | 5 +- ...32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S | 522 ++++++++++++++++++ 4 files changed, 547 insertions(+), 477 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index 3aac20dd..01c87fc9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -231,6 +231,7 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME_ASM + kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.c index 8c24469d..32e459ab 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.c @@ -4,9 +4,6 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. @@ -17,11 +14,29 @@ #include "kai/kai_common.h" +typedef struct { + size_t M; + size_t N; + size_t K; + size_t flags; + void* accumulator_buffer; + float min; + float max; + void* C; + size_t ldcb; + const void* B; + size_t kstride_bytes; + const void* A; +} matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_args_t; + static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 1; static const size_t kai_sr = 1; +void kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa( + const matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_args_t* args); + // Returns a constant value specific to this kernel that's relative to vector length static size_t kai_get_kernel_vec_length_constant(void) { const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u32(); @@ -83,28 +98,13 @@ size_t kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(s void kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) { - KAI_ASSUME(dst_stride_col == sizeof(float)); - - typedef struct { - const void* A; - const void* B; - uint64_t kstride_bytes; - void* C; - uint64_t ldcb; - uint64_t M, N, K, n_loops, n_tail_iters; - float min; - float max; - float* accumulator_buffer; - uint64_t flags; - } KernelArgs; + KAI_UNUSED(dst_stride_col); - KernelArgs args; + matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_args_t args; args.M = m; args.N = n; args.K = k; - args.n_loops = (k / kai_kr - 1) / 2; - args.n_tail_iters = (k / kai_kr - 1) % 2; args.A = lhs_packed; args.B = rhs_packed; @@ -119,461 +119,7 @@ void kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa( args.flags = 0; - __asm__ __volatile__( - "ldr w8, [%x[args], %[offsetof_M]]\n" - ".inst 0xd503477f // SMSTART ZA\n" - "mov x17, XZR\n" - "mov x16, XZR\n" - "ldr x21, [%x[args], %[offsetof_A]]\n" - "cntw x20, ALL, MUL #2\n" - "ptrue p6.b\n" - "ldr x15, [%x[args], %[offsetof_flags]]\n" - "cntw x14\n" - "ld1rw { z1.s }, p6/Z, [%x[args], %[offset_min]]\n" - "sub x13, x8, x17\n" - "ldr w11, [%x[args], %[offsetof_N]]\n" - "ld1rw { z0.s }, p6/Z, [%x[args], %[offset_max]]\n" - "cmp x13, x20\n" - "mov x10, x21\n" - "ldr x9, [%x[args], %[offsetof_accumulator_buffer]]\n" - "csel x13, x13, x20, LT\n" // height = min(M - m, acc_height) - "mov x28, x10\n" - "ldr x27, [%x[args], %[offsetof_accumulator_buffer]]\n" - "whilelt p5.s, XZR, x13\n" - "whilelt p4.s, x14, x13\n" - "tbz x15, #0, 2f\n" - "ptrue p11.s\n" - "ptrue p10.s\n" - "cntw x21, ALL, MUL #2\n" - "cntw x20, ALL, MUL #3\n" - "mov x12, XZR\n" - "1:" // Initialise accumulators (first block): Loop - ".inst 0x25306960 // psel p0.s, p10.s/Z, p11.s[w12]\n" - ".inst 0x25306962 // psel p2.s, p10.s/Z, p11.s[w12]\n" - ".inst 0x25306961 // psel p1.s, p10.s/Z, p11.s[w12]\n" - ".inst 0xe09f0120 // ld1w { za0h.s[x12] }, p0/Z, [x9, XZR, LSL #2]\n" - ".inst 0x25306960 // psel p0.s, p10.s/Z, p11.s[w12]\n" - ".inst 0xe08e0924 // ld1w { za1h.s[x12] }, p2/Z, [x9, x14, LSL #2]\n" - ".inst 0xe0950528 // ld1w { za2h.s[x12] }, p1/Z, [x9, x21, LSL #2]\n" - ".inst 0xe094012c // ld1w { za3h.s[x12] }, p0/Z, [x9, x20, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x9, x9, #4\n" - "cmp x12, x14\n" - "blt 1b\n" - "2:" // Initialise accumulators (first block): End - "3:" // Outer loop - "cntw x20, ALL, MUL #2\n" - "sub x26, x11, x16\n" - "ldr x23, [%x[args], %[offsetof_C]]\n" - "ldr x25, [%x[args], %[offsetof_ldcb]]\n" - "cmp x26, x20\n" - "mov x22, XZR\n" - "ldr x21, [%x[args], %[offsetof_B]]\n" - "csel x26, x26, x20, LT\n" // width = min(N - n, acc_width) - "ldr x20, [%x[args], %[offsetof_kstride_bytes]]\n" - "whilelt p3.s, x22, x26\n" - "incw x22\n" - "madd x24, x17, x25, x23\n" // cptr = C + m * ldcb - "whilelt p2.s, x22, x26\n" - "add x24, x24, x16, LSL #2\n" // cptr += n * sizeof(T) - "madd x21, x16, x20, x21\n" // bptr = B + n * k_stride_bytes - "tbnz x15, #0, 4f\n" - ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" - "cbz x21, 4f\n" - "mov p1.b, p3.b\n" - "mov p0.b, p2.b\n" - "fmov z18.s, #1.0\n" - "ld1w { z17.s }, p1/Z, [x21]\n" - "ld1w { z16.s }, p0/Z, [x21, #1, MUL VL]\n" - "addvl x21, x21, #2\n" - ".inst 0x80917640 // fmopa za0.s, p5/M, p3/M, z18.s, z17.s\n" - ".inst 0x80905641 // fmopa za1.s, p5/M, p2/M, z18.s, z16.s\n" - ".inst 0x80917242 // fmopa za2.s, p4/M, p3/M, z18.s, z17.s\n" - ".inst 0x80905243 // fmopa za3.s, p4/M, p2/M, z18.s, z16.s\n" - "4:" // Initialise accumulators: End - "ldr x20, [%x[args], %[offsetof_K]]\n" - "cmp x20, #0x4\n" - "ble 8f\n" - "cmp x20, #0x8\n" - "ld1w { z31.s }, p5/Z, [x28]\n" - "ld1w { z30.s }, p4/Z, [x28, #1, MUL VL]\n" - "ldnt1w { z29.s }, p3/Z, [x21]\n" - "ldnt1w { z28.s }, p2/Z, [x21, #1, MUL VL]\n" - "ld1w { z27.s }, p5/Z, [x28, #2, MUL VL]\n" - "ld1w { z26.s }, p4/Z, [x28, #3, MUL VL]\n" - "ldnt1w { z25.s }, p3/Z, [x21, #2, MUL VL]\n" - "ldnt1w { z24.s }, p2/Z, [x21, #3, MUL VL]\n" - "ld1w { z23.s }, p5/Z, [x28, #4, MUL VL]\n" - "ld1w { z22.s }, p4/Z, [x28, #5, MUL VL]\n" - "ldnt1w { z21.s }, p3/Z, [x21, #4, MUL VL]\n" - "ldnt1w { z20.s }, p2/Z, [x21, #5, MUL VL]\n" - "ld1w { z19.s }, p5/Z, [x28, #6, MUL VL]\n" - "ld1w { z18.s }, p4/Z, [x28, #7, MUL VL]\n" - "addvl x28, x28, #8\n" - "ldnt1w { z17.s }, p3/Z, [x21, #6, MUL VL]\n" - "ldnt1w { z16.s }, p2/Z, [x21, #7, MUL VL]\n" - "addvl x21, x21, #8\n" - "blt 7f\n" - "6:" // K loop: Main: Loop - ".inst 0x809d77e0 // fmopa za0.s, p5/M, p3/M, z31.s, z29.s\n" - "sub x20, x20, #0x4\n" - ".inst 0x809c57e1 // fmopa za1.s, p5/M, p2/M, z31.s, z28.s\n" - "cmp x20, #0x8\n" - "ld1w { z31.s }, p5/Z, [x28]\n" - ".inst 0x809d73c2 // fmopa za2.s, p4/M, p3/M, z30.s, z29.s\n" - "ldnt1w { z29.s }, p3/Z, [x21]\n" - ".inst 0x809c53c3 // fmopa za3.s, p4/M, p2/M, z30.s, z28.s\n" - "ldnt1w { z28.s }, p2/Z, [x21, #1, MUL VL]\n" - ".inst 0x80997760 // fmopa za0.s, p5/M, p3/M, z27.s, z25.s\n" - "ld1w { z30.s }, p4/Z, [x28, #1, MUL VL]\n" - ".inst 0x80985761 // fmopa za1.s, p5/M, p2/M, z27.s, z24.s\n" - "ld1w { z27.s }, p5/Z, [x28, #2, MUL VL]\n" - ".inst 0x80997342 // fmopa za2.s, p4/M, p3/M, z26.s, z25.s\n" - "ldnt1w { z25.s }, p3/Z, [x21, #2, MUL VL]\n" - ".inst 0x80985343 // fmopa za3.s, p4/M, p2/M, z26.s, z24.s\n" - "ldnt1w { z24.s }, p2/Z, [x21, #3, MUL VL]\n" - ".inst 0x809576e0 // fmopa za0.s, p5/M, p3/M, z23.s, z21.s\n" - "ld1w { z26.s }, p4/Z, [x28, #3, MUL VL]\n" - ".inst 0x809456e1 // fmopa za1.s, p5/M, p2/M, z23.s, z20.s\n" - "ld1w { z23.s }, p5/Z, [x28, #4, MUL VL]\n" - ".inst 0x809572c2 // fmopa za2.s, p4/M, p3/M, z22.s, z21.s\n" - "ldnt1w { z21.s }, p3/Z, [x21, #4, MUL VL]\n" - ".inst 0x809452c3 // fmopa za3.s, p4/M, p2/M, z22.s, z20.s\n" - "ldnt1w { z20.s }, p2/Z, [x21, #5, MUL VL]\n" - "ld1w { z22.s }, p4/Z, [x28, #5, MUL VL]\n" - ".inst 0x80917660 // fmopa za0.s, p5/M, p3/M, z19.s, z17.s\n" - ".inst 0x80905661 // fmopa za1.s, p5/M, p2/M, z19.s, z16.s\n" - "ld1w { z19.s }, p5/Z, [x28, #6, MUL VL]\n" - ".inst 0x80917242 // fmopa za2.s, p4/M, p3/M, z18.s, z17.s\n" - "ldnt1w { z17.s }, p3/Z, [x21, #6, MUL VL]\n" - ".inst 0x80905243 // fmopa za3.s, p4/M, p2/M, z18.s, z16.s\n" - "ldnt1w { z16.s }, p2/Z, [x21, #7, MUL VL]\n" - "addvl x21, x21, #8\n" - "ld1w { z18.s }, p4/Z, [x28, #7, MUL VL]\n" - "addvl x28, x28, #8\n" - "bge 6b\n" - "7:" // K loop: Main: Detached iter - ".inst 0x809d77e0 // fmopa za0.s, p5/M, p3/M, z31.s, z29.s\n" - "sub x20, x20, #0x4\n" - ".inst 0x809c57e1 // fmopa za1.s, p5/M, p2/M, z31.s, z28.s\n" - ".inst 0x809d73c2 // fmopa za2.s, p4/M, p3/M, z30.s, z29.s\n" - ".inst 0x809c53c3 // fmopa za3.s, p4/M, p2/M, z30.s, z28.s\n" - ".inst 0x80997760 // fmopa za0.s, p5/M, p3/M, z27.s, z25.s\n" - ".inst 0x80985761 // fmopa za1.s, p5/M, p2/M, z27.s, z24.s\n" - ".inst 0x80997342 // fmopa za2.s, p4/M, p3/M, z26.s, z25.s\n" - ".inst 0x80985343 // fmopa za3.s, p4/M, p2/M, z26.s, z24.s\n" - ".inst 0x809576e0 // fmopa za0.s, p5/M, p3/M, z23.s, z21.s\n" - ".inst 0x809456e1 // fmopa za1.s, p5/M, p2/M, z23.s, z20.s\n" - ".inst 0x809572c2 // fmopa za2.s, p4/M, p3/M, z22.s, z21.s\n" - ".inst 0x809452c3 // fmopa za3.s, p4/M, p2/M, z22.s, z20.s\n" - ".inst 0x80917660 // fmopa za0.s, p5/M, p3/M, z19.s, z17.s\n" - ".inst 0x80905661 // fmopa za1.s, p5/M, p2/M, z19.s, z16.s\n" - ".inst 0x80917242 // fmopa za2.s, p4/M, p3/M, z18.s, z17.s\n" - ".inst 0x80905243 // fmopa za3.s, p4/M, p2/M, z18.s, z16.s\n" - "8:" // K loop: Tail - "cbz x20, 10f\n" - "9:" // K loop: Tail: Loop - "ld1w { z19.s }, p5/Z, [x28]\n" - "sub x20, x20, #0x1\n" - "ld1w { z18.s }, p4/Z, [x28, #1, MUL VL]\n" - "cmp x20, XZR\n" - "addvl x28, x28, #2\n" - "ldnt1w { z17.s }, p3/Z, [x21]\n" - "ldnt1w { z16.s }, p2/Z, [x21, #1, MUL VL]\n" - "addvl x21, x21, #2\n" - ".inst 0x80917660 // fmopa za0.s, p5/M, p3/M, z19.s, z17.s\n" - ".inst 0x80905661 // fmopa za1.s, p5/M, p2/M, z19.s, z16.s\n" - ".inst 0x80917242 // fmopa za2.s, p4/M, p3/M, z18.s, z17.s\n" - ".inst 0x80905243 // fmopa za3.s, p4/M, p2/M, z18.s, z16.s\n" - "bgt 9b\n" - "10:" // K loop: Tail: End - "incw x16, ALL, MUL #2\n" - "add x21, x17, x14, LSL #1\n" - "cmp x16, x11\n" - "cntw x20, ALL, MUL #2\n" - "csel x17, x17, x21, LT\n" // m := (n + block_width < N) ? m : m + height - "csel x16, x16, XZR, LT\n" // n := (n + block_width < N) ? n + block_width : 0 - "sub x23, x8, x17\n" - "csel x10, x10, x28, LT\n" // aptr0 := (n + block_width < N) ? aptr0 : aptr - "whilelt p5.s, XZR, x23\n" - "whilelt p4.s, x14, x23\n" - "cmp x23, x20\n" - "mov x28, x10\n" - "csel x23, x23, x20, LT\n" - "tbnz x15, #2, 24f\n" - "tbnz x15, #1, 26f\n" - "tbz x15, #3, 14f\n" - "mov x22, XZR\n" - "mov p11.b, p3.b\n" - "subs x21, x13, x22\n" - "mov p10.b, p2.b\n" - "ptrue p9.s\n" - "ptrue p8.s\n" - "cntw x20\n" - "ble 13f\n" - "cmp x21, x14\n" - "incw x22\n" - "csel x21, x21, x14, LT\n" - "mov x12, XZR\n" - "11:" // Store accumulators: Drain to output array: Skip activation: Accumulator row 0: Loop - ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25306900 // psel p0.s, p10.s/Z, p8.s[w12]\n" - ".inst 0xe0bf0700 // st1w { za0h.s[x12] }, p1/Z, [x24, XZR, LSL #2]\n" - ".inst 0xe0b40304 // st1w { za1h.s[x12] }, p0/Z, [x24, x20, LSL #2]\n" - "add x12, x12, #0x1\n" - "add x24, x24, x25\n" - "cmp x12, x21\n" - "blt 11b\n" - "subs x21, x13, x22\n" - "ble 13f\n" - "cmp x21, x14\n" - "mov x12, XZR\n" - "csel x21, x21, x14, LT\n" - "12:" // Store accumulators: Drain to output array: Skip activation: Accumulator row 1: Loop - ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" - ".inst 0x25306900 // psel p0.s, p10.s/Z, p8.s[w12]\n" - ".inst 0xe0bf0708 // st1w { za2h.s[x12] }, p1/Z, [x24, XZR, LSL #2]\n" - ".inst 0xe0b4030c // st1w { za3h.s[x12] }, p0/Z, [x24, x20, LSL #2]\n" - "add x12, x12, #0x1\n" - "add x24, x24, x25\n" - "cmp x12, x21\n" - "blt 12b\n" - "13:" // Store accumulators: Drain to output array: Skip activation: End - "tbnz x15, #0, 28f\n" - "b 30f\n" - "14:" // Store accumulators: Drain to output array: Activate - "mov x22, XZR\n" - "subs x21, x13, x22\n" - "ble 23f\n" - "cmp x21, x14\n" - "incw x22\n" - "csel x21, x21, x14, LT\n" - "mov x12, XZR\n" - "ands x20, x21, #0xfffffffffffffffe\n" - "beq 17f\n" - ".inst 0xc0820c17 // mova z23.s, p3/M, za0h.s[x12]\n" - ".inst 0xc0820896 // mova z22.s, p2/M, za1h.s[x12]\n" - ".inst 0xc0820c35 // mova z21.s, p3/M, za0h.s[x12, #1]\n" - ".inst 0xc08208b4 // mova z20.s, p2/M, za1h.s[x12, #1]\n" - "add x12, x12, #0x2\n" - "cmp x12, x20\n" - "beq 16f\n" - "15:" // Store accumulators: Drain to output array: Accumulator row 0: Loop - "movprfx z19, z23\n fmin z19.s, p6/M, z19.s, z0.s\n" - "movprfx z18, z22\n fmin z18.s, p6/M, z18.s, z0.s\n" - ".inst 0xc0820c17 // mova z23.s, p3/M, za0h.s[x12]\n" - "movprfx z17, z21\n fmin z17.s, p6/M, z17.s, z0.s\n" - "movprfx z16, z20\n fmin z16.s, p6/M, z16.s, z0.s\n" - ".inst 0xc0820896 // mova z22.s, p2/M, za1h.s[x12]\n" - ".inst 0xc0820c35 // mova z21.s, p3/M, za0h.s[x12, #1]\n" - ".inst 0xc08208b4 // mova z20.s, p2/M, za1h.s[x12, #1]\n" - "add x12, x12, #0x2\n" - "fmax z19.s, p6/M, z19.s, z1.s\n" - "fmax z18.s, p6/M, z18.s, z1.s\n" - "cmp x12, x20\n" - "fmax z17.s, p6/M, z17.s, z1.s\n" - "fmax z16.s, p6/M, z16.s, z1.s\n" - "stnt1w { z19.s }, p3, [x24]\n" - "stnt1w { z18.s }, p2, [x24, #1, MUL VL]\n" - "add x24, x24, x25\n" - "stnt1w { z17.s }, p3, [x24]\n" - "stnt1w { z16.s }, p2, [x24, #1, MUL VL]\n" - "add x24, x24, x25\n" - "blt 15b\n" - "16:" // Store accumulators: Drain to output array: Accumulator row 0: Tail - "movprfx z19, z23\n fmin z19.s, p6/M, z19.s, z0.s\n" - "movprfx z18, z22\n fmin z18.s, p6/M, z18.s, z0.s\n" - "cmp x12, x21\n" - "movprfx z17, z21\n fmin z17.s, p6/M, z17.s, z0.s\n" - "movprfx z16, z20\n fmin z16.s, p6/M, z16.s, z0.s\n" - "fmax z19.s, p6/M, z19.s, z1.s\n" - "fmax z18.s, p6/M, z18.s, z1.s\n" - "fmax z17.s, p6/M, z17.s, z1.s\n" - "fmax z16.s, p6/M, z16.s, z1.s\n" - "stnt1w { z19.s }, p3, [x24]\n" - "stnt1w { z18.s }, p2, [x24, #1, MUL VL]\n" - "add x24, x24, x25\n" - "stnt1w { z17.s }, p3, [x24]\n" - "stnt1w { z16.s }, p2, [x24, #1, MUL VL]\n" - "add x24, x24, x25\n" - "beq 18f\n" - "17:" // Store accumulators: Drain to output array: Accumulator row 0: Tail loop - ".inst 0xc0820c11 // mova z17.s, p3/M, za0h.s[x12]\n" - ".inst 0xc0820890 // mova z16.s, p2/M, za1h.s[x12]\n" - "fmin z17.s, p6/M, z17.s, z0.s\n" - "add x12, x12, #0x1\n" - "fmin z16.s, p6/M, z16.s, z0.s\n" - "cmp x12, x21\n" - "fmax z17.s, p6/M, z17.s, z1.s\n" - "fmax z16.s, p6/M, z16.s, z1.s\n" - "stnt1w { z17.s }, p3, [x24]\n" - "stnt1w { z16.s }, p2, [x24, #1, MUL VL]\n" - "add x24, x24, x25\n" - "blt 17b\n" - "18:" // Store accumulators: Drain to output array: Accumulator row 0: End - "subs x21, x13, x22\n" - "ble 23f\n" - "cmp x21, x14\n" - "mov x12, XZR\n" - "csel x21, x21, x14, LT\n" - "ands x20, x21, #0xfffffffffffffffe\n" - "beq 21f\n" - ".inst 0xc0820d17 // mova z23.s, p3/M, za2h.s[x12]\n" - ".inst 0xc0820996 // mova z22.s, p2/M, za3h.s[x12]\n" - ".inst 0xc0820d35 // mova z21.s, p3/M, za2h.s[x12, #1]\n" - ".inst 0xc08209b4 // mova z20.s, p2/M, za3h.s[x12, #1]\n" - "add x12, x12, #0x2\n" - "cmp x12, x20\n" - "beq 20f\n" - "19:" // Store accumulators: Drain to output array: Accumulator row 1: Loop - "movprfx z19, z23\n fmin z19.s, p6/M, z19.s, z0.s\n" - "movprfx z18, z22\n fmin z18.s, p6/M, z18.s, z0.s\n" - ".inst 0xc0820d17 // mova z23.s, p3/M, za2h.s[x12]\n" - "movprfx z17, z21\n fmin z17.s, p6/M, z17.s, z0.s\n" - "movprfx z16, z20\n fmin z16.s, p6/M, z16.s, z0.s\n" - ".inst 0xc0820996 // mova z22.s, p2/M, za3h.s[x12]\n" - ".inst 0xc0820d35 // mova z21.s, p3/M, za2h.s[x12, #1]\n" - ".inst 0xc08209b4 // mova z20.s, p2/M, za3h.s[x12, #1]\n" - "add x12, x12, #0x2\n" - "fmax z19.s, p6/M, z19.s, z1.s\n" - "fmax z18.s, p6/M, z18.s, z1.s\n" - "cmp x12, x20\n" - "fmax z17.s, p6/M, z17.s, z1.s\n" - "fmax z16.s, p6/M, z16.s, z1.s\n" - "stnt1w { z19.s }, p3, [x24]\n" - "stnt1w { z18.s }, p2, [x24, #1, MUL VL]\n" - "add x24, x24, x25\n" - "stnt1w { z17.s }, p3, [x24]\n" - "stnt1w { z16.s }, p2, [x24, #1, MUL VL]\n" - "add x24, x24, x25\n" - "blt 19b\n" - "20:" // Store accumulators: Drain to output array: Accumulator row 1: Tail - "movprfx z19, z23\n fmin z19.s, p6/M, z19.s, z0.s\n" - "movprfx z18, z22\n fmin z18.s, p6/M, z18.s, z0.s\n" - "cmp x12, x21\n" - "movprfx z17, z21\n fmin z17.s, p6/M, z17.s, z0.s\n" - "movprfx z16, z20\n fmin z16.s, p6/M, z16.s, z0.s\n" - "fmax z19.s, p6/M, z19.s, z1.s\n" - "fmax z18.s, p6/M, z18.s, z1.s\n" - "fmax z17.s, p6/M, z17.s, z1.s\n" - "fmax z16.s, p6/M, z16.s, z1.s\n" - "stnt1w { z19.s }, p3, [x24]\n" - "stnt1w { z18.s }, p2, [x24, #1, MUL VL]\n" - "add x24, x24, x25\n" - "stnt1w { z17.s }, p3, [x24]\n" - "stnt1w { z16.s }, p2, [x24, #1, MUL VL]\n" - "add x24, x24, x25\n" - "beq 22f\n" - "21:" // Store accumulators: Drain to output array: Accumulator row 1: Tail loop - ".inst 0xc0820d11 // mova z17.s, p3/M, za2h.s[x12]\n" - ".inst 0xc0820990 // mova z16.s, p2/M, za3h.s[x12]\n" - "fmin z17.s, p6/M, z17.s, z0.s\n" - "add x12, x12, #0x1\n" - "fmin z16.s, p6/M, z16.s, z0.s\n" - "cmp x12, x21\n" - "fmax z17.s, p6/M, z17.s, z1.s\n" - "fmax z16.s, p6/M, z16.s, z1.s\n" - "stnt1w { z17.s }, p3, [x24]\n" - "stnt1w { z16.s }, p2, [x24, #1, MUL VL]\n" - "add x24, x24, x25\n" - "blt 21b\n" - "22:" // Store accumulators: Drain to output array: Accumulator row 1: End - "23:" // Store accumulators: Drain to output array: End - "tbnz x15, #0, 28f\n" - "b 30f\n" - "24:" // Store accumulators: Drain to, and fill from buffer - "cmp x17, x8\n" - "bge 26f\n" // If there's no next block to load, then just drain. - "ptrue p11.s\n" - "ptrue p10.s\n" - "ptrue p9.s\n" - "ptrue p8.s\n" - "cntw x21, ALL, MUL #2\n" - "cntw x20, ALL, MUL #3\n" - "mov x12, XZR\n" - "25:" // Store accumulators: Drain to, and fill from buffer: Loop - ".inst 0x25306121 // psel p1.s, p8.s/Z, p9.s[w12]\n" - ".inst 0x25306960 // psel p0.s, p10.s/Z, p11.s[w12]\n" - ".inst 0x25306122 // psel p2.s, p8.s/Z, p9.s[w12]\n" - ".inst 0xe0bf0760 // st1w { za0h.s[x12] }, p1/Z, [x27, XZR, LSL #2]\n" - ".inst 0x25306961 // psel p1.s, p10.s/Z, p11.s[w12]\n" - ".inst 0xe09f0120 // ld1w { za0h.s[x12] }, p0/Z, [x9, XZR, LSL #2]\n" - ".inst 0x25306120 // psel p0.s, p8.s/Z, p9.s[w12]\n" - ".inst 0xe0ae0b64 // st1w { za1h.s[x12] }, p2/Z, [x27, x14, LSL #2]\n" - ".inst 0x25306962 // psel p2.s, p10.s/Z, p11.s[w12]\n" - ".inst 0xe08e0524 // ld1w { za1h.s[x12] }, p1/Z, [x9, x14, LSL #2]\n" - ".inst 0x25306121 // psel p1.s, p8.s/Z, p9.s[w12]\n" - ".inst 0xe0b50368 // st1w { za2h.s[x12] }, p0/Z, [x27, x21, LSL #2]\n" - ".inst 0x25306960 // psel p0.s, p10.s/Z, p11.s[w12]\n" - ".inst 0xe0950928 // ld1w { za2h.s[x12] }, p2/Z, [x9, x21, LSL #2]\n" - ".inst 0xe0b4076c // st1w { za3h.s[x12] }, p1/Z, [x27, x20, LSL #2]\n" - ".inst 0xe094012c // ld1w { za3h.s[x12] }, p0/Z, [x9, x20, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x27, x27, #4\n" - "cmp x12, x14\n" - "addvl x9, x9, #4\n" - "blt 25b\n" - "b 30f\n" - "26:" // Store accumulators: Drain to buffer - "ptrue p9.s\n" - "ptrue p8.s\n" - "cntw x21, ALL, MUL #2\n" - "cntw x20, ALL, MUL #3\n" - "mov x12, XZR\n" - "27:" // Store accumulators: Drain to buffer: Loop - ".inst 0x25306120 // psel p0.s, p8.s/Z, p9.s[w12]\n" - ".inst 0x25306122 // psel p2.s, p8.s/Z, p9.s[w12]\n" - ".inst 0x25306121 // psel p1.s, p8.s/Z, p9.s[w12]\n" - ".inst 0xe0bf0360 // st1w { za0h.s[x12] }, p0/Z, [x27, XZR, LSL #2]\n" - ".inst 0x25306120 // psel p0.s, p8.s/Z, p9.s[w12]\n" - ".inst 0xe0ae0b64 // st1w { za1h.s[x12] }, p2/Z, [x27, x14, LSL #2]\n" - ".inst 0xe0b50768 // st1w { za2h.s[x12] }, p1/Z, [x27, x21, LSL #2]\n" - ".inst 0xe0b4036c // st1w { za3h.s[x12] }, p0/Z, [x27, x20, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x27, x27, #4\n" - "cmp x12, x14\n" - "blt 27b\n" - "b 30f\n" - "28:" // Store accumulators: Fill from buffer - "cmp x17, x8\n" - "bge 30f\n" - "ptrue p11.s\n" - "ptrue p10.s\n" - "cntw x21, ALL, MUL #2\n" - "cntw x20, ALL, MUL #3\n" - "mov x12, XZR\n" - "29:" // Store accumulators: Fill from buffer: Loop - ".inst 0x25306960 // psel p0.s, p10.s/Z, p11.s[w12]\n" - ".inst 0x25306962 // psel p2.s, p10.s/Z, p11.s[w12]\n" - ".inst 0x25306961 // psel p1.s, p10.s/Z, p11.s[w12]\n" - ".inst 0xe09f0120 // ld1w { za0h.s[x12] }, p0/Z, [x9, XZR, LSL #2]\n" - ".inst 0x25306960 // psel p0.s, p10.s/Z, p11.s[w12]\n" - ".inst 0xe08e0924 // ld1w { za1h.s[x12] }, p2/Z, [x9, x14, LSL #2]\n" - ".inst 0xe0950528 // ld1w { za2h.s[x12] }, p1/Z, [x9, x21, LSL #2]\n" - ".inst 0xe094012c // ld1w { za3h.s[x12] }, p0/Z, [x9, x20, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x9, x9, #4\n" - "cmp x12, x14\n" - "blt 29b\n" - "30:" // Store accumulators: End - "cmp x17, x8\n" - "mov x13, x23\n" - "blt 3b\n" - ".inst 0xd503467f // SMSTOP\n" - : - : [args] "r"(&args), [offset_max] "I"(offsetof(KernelArgs, max)), [offset_min] "I"(offsetof(KernelArgs, min)), - [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), - [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), - [offsetof_M] "I"(offsetof(KernelArgs, M)), [offsetof_N] "I"(offsetof(KernelArgs, N)), - [offsetof_accumulator_buffer] "I"(offsetof(KernelArgs, accumulator_buffer)), - [offsetof_flags] "I"(offsetof(KernelArgs, flags)), - [offsetof_kstride_bytes] "I"(offsetof(KernelArgs, kstride_bytes)), - [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25", - "x26", "x27", "x28", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", - "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", - "z5", "z6", "z7", "z8", "z9"); + kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h index 7130c4c3..57287c0a 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h @@ -14,8 +14,9 @@ extern "C" { /// Micro-kernel dependencies /// -/// -# kai_lhs_pack_f32p2vlx1_f32_sme to pack the LHS matrix. -/// -# kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme to pack the RHS matrix. +/// TODO: Fix names auto generated below: +/// -# kai_lhs_pack_f32pvlx1_f32_sme to pack the LHS matrix. +/// -# kai_rhs_pack_kxn_f32pvlx1b_f32_f32_sme to pack the RHS matrix. /// Gets m step value. /// diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S new file mode 100644 index 00000000..f3f9774a --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S @@ -0,0 +1,522 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x8, XZR + ldr w17, [x0, #0x0] + cntw x21, ALL, MUL #2 + ptrue p6.b + ldr x20, [x0, #0x50] + cntw x16 + ld1rw { z1.s }, p6/Z, [x0, #40] + mov x15, XZR + ldr x14, [x0, #0x18] + ld1rw { z0.s }, p6/Z, [x0, #44] + sub x13, x17, x8 + ldr w11, [x0, #0x8] + cmp x13, x21 + mov x10, x20 + ldr x9, [x0, #0x20] + csel x13, x13, x21, LT // height = min(M - m, acc_height) + mov x28, x10 + ldr x27, [x0, #0x20] + whilelt p5.s, XZR, x13 + whilelt p4.s, x16, x13 + tbz x14, #0, label_2 + ptrue p11.s + ptrue p10.s + cntw x21, ALL, MUL #2 + cntw x20, ALL, MUL #3 + mov x12, XZR +KAI_ASM_LABEL(label_1) // Initialise accumulators (first block): Loop + KAI_ASM_INST(0x25306960) // psel p0.s, p10.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306962) // psel p2.s, p10.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306961) // psel p1.s, p10.s/Z, p11.s[w12] + KAI_ASM_INST(0xe09f0120) // ld1w { za0h.s[x12] }, p0/Z, [x9, XZR, LSL #2] + KAI_ASM_INST(0x25306960) // psel p0.s, p10.s/Z, p11.s[w12] + KAI_ASM_INST(0xe0900924) // ld1w { za1h.s[x12] }, p2/Z, [x9, x16, LSL #2] + KAI_ASM_INST(0xe0950528) // ld1w { za2h.s[x12] }, p1/Z, [x9, x21, LSL #2] + KAI_ASM_INST(0xe094012c) // ld1w { za3h.s[x12] }, p0/Z, [x9, x20, LSL #2] + add x12, x12, #0x1 + addvl x9, x9, #4 + cmp x12, x16 + blt label_1 +KAI_ASM_LABEL(label_2) // Initialise accumulators (first block): End +KAI_ASM_LABEL(label_3) // Outer loop + cntw x20, ALL, MUL #2 + sub x26, x11, x15 + ldr x23, [x0, #0x30] + ldr x25, [x0, #0x38] + cmp x26, x20 + mov x22, XZR + ldr x21, [x0, #0x40] + csel x26, x26, x20, LT // width = min(N - n, acc_width) + ldr x20, [x0, #0x48] + whilelt p3.s, x22, x26 + incw x22 + madd x24, x8, x25, x23 // cptr = C + m * ldcb + whilelt p2.s, x22, x26 + add x24, x24, x15, LSL #2 // cptr += n * sizeof(T) + madd x21, x15, x20, x21 // bptr = B + n * k_stride_bytes + tbnz x14, #0, label_4 + KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } + cbz x21, label_4 + mov p1.b, p3.b + mov p0.b, p2.b + fmov z18.s, #1.0 + ld1w { z17.s }, p1/Z, [x21] + ld1w { z16.s }, p0/Z, [x21, #1, MUL VL] + addvl x21, x21, #2 + KAI_ASM_INST(0x80917640) // fmopa za0.s, p5/M, p3/M, z18.s, z17.s + KAI_ASM_INST(0x80905641) // fmopa za1.s, p5/M, p2/M, z18.s, z16.s + KAI_ASM_INST(0x80917242) // fmopa za2.s, p4/M, p3/M, z18.s, z17.s + KAI_ASM_INST(0x80905243) // fmopa za3.s, p4/M, p2/M, z18.s, z16.s +KAI_ASM_LABEL(label_4) // Initialise accumulators: End + ldr x20, [x0, #0x10] + cmp x20, #0x4 + ble label_8 + cmp x20, #0x8 + ld1w { z31.s }, p5/Z, [x28] + ld1w { z30.s }, p4/Z, [x28, #1, MUL VL] + ldnt1w { z29.s }, p3/Z, [x21] + ldnt1w { z28.s }, p2/Z, [x21, #1, MUL VL] + ld1w { z27.s }, p5/Z, [x28, #2, MUL VL] + ld1w { z26.s }, p4/Z, [x28, #3, MUL VL] + ldnt1w { z25.s }, p3/Z, [x21, #2, MUL VL] + ldnt1w { z24.s }, p2/Z, [x21, #3, MUL VL] + ld1w { z23.s }, p5/Z, [x28, #4, MUL VL] + ld1w { z22.s }, p4/Z, [x28, #5, MUL VL] + ldnt1w { z21.s }, p3/Z, [x21, #4, MUL VL] + ldnt1w { z20.s }, p2/Z, [x21, #5, MUL VL] + ld1w { z19.s }, p5/Z, [x28, #6, MUL VL] + ld1w { z18.s }, p4/Z, [x28, #7, MUL VL] + addvl x28, x28, #8 + ldnt1w { z17.s }, p3/Z, [x21, #6, MUL VL] + ldnt1w { z16.s }, p2/Z, [x21, #7, MUL VL] + addvl x21, x21, #8 + blt label_7 +KAI_ASM_LABEL(label_6) // K loop: Main: Loop + KAI_ASM_INST(0x809d77e0) // fmopa za0.s, p5/M, p3/M, z31.s, z29.s + sub x20, x20, #0x4 + KAI_ASM_INST(0x809c57e1) // fmopa za1.s, p5/M, p2/M, z31.s, z28.s + cmp x20, #0x8 + ld1w { z31.s }, p5/Z, [x28] + KAI_ASM_INST(0x809d73c2) // fmopa za2.s, p4/M, p3/M, z30.s, z29.s + ldnt1w { z29.s }, p3/Z, [x21] + KAI_ASM_INST(0x809c53c3) // fmopa za3.s, p4/M, p2/M, z30.s, z28.s + ldnt1w { z28.s }, p2/Z, [x21, #1, MUL VL] + KAI_ASM_INST(0x80997760) // fmopa za0.s, p5/M, p3/M, z27.s, z25.s + ld1w { z30.s }, p4/Z, [x28, #1, MUL VL] + KAI_ASM_INST(0x80985761) // fmopa za1.s, p5/M, p2/M, z27.s, z24.s + ld1w { z27.s }, p5/Z, [x28, #2, MUL VL] + KAI_ASM_INST(0x80997342) // fmopa za2.s, p4/M, p3/M, z26.s, z25.s + ldnt1w { z25.s }, p3/Z, [x21, #2, MUL VL] + KAI_ASM_INST(0x80985343) // fmopa za3.s, p4/M, p2/M, z26.s, z24.s + ldnt1w { z24.s }, p2/Z, [x21, #3, MUL VL] + KAI_ASM_INST(0x809576e0) // fmopa za0.s, p5/M, p3/M, z23.s, z21.s + ld1w { z26.s }, p4/Z, [x28, #3, MUL VL] + KAI_ASM_INST(0x809456e1) // fmopa za1.s, p5/M, p2/M, z23.s, z20.s + ld1w { z23.s }, p5/Z, [x28, #4, MUL VL] + KAI_ASM_INST(0x809572c2) // fmopa za2.s, p4/M, p3/M, z22.s, z21.s + ldnt1w { z21.s }, p3/Z, [x21, #4, MUL VL] + KAI_ASM_INST(0x809452c3) // fmopa za3.s, p4/M, p2/M, z22.s, z20.s + ldnt1w { z20.s }, p2/Z, [x21, #5, MUL VL] + ld1w { z22.s }, p4/Z, [x28, #5, MUL VL] + KAI_ASM_INST(0x80917660) // fmopa za0.s, p5/M, p3/M, z19.s, z17.s + KAI_ASM_INST(0x80905661) // fmopa za1.s, p5/M, p2/M, z19.s, z16.s + ld1w { z19.s }, p5/Z, [x28, #6, MUL VL] + KAI_ASM_INST(0x80917242) // fmopa za2.s, p4/M, p3/M, z18.s, z17.s + ldnt1w { z17.s }, p3/Z, [x21, #6, MUL VL] + KAI_ASM_INST(0x80905243) // fmopa za3.s, p4/M, p2/M, z18.s, z16.s + ldnt1w { z16.s }, p2/Z, [x21, #7, MUL VL] + addvl x21, x21, #8 + ld1w { z18.s }, p4/Z, [x28, #7, MUL VL] + addvl x28, x28, #8 + bge label_6 +KAI_ASM_LABEL(label_7) // K loop: Main: Detached iter + KAI_ASM_INST(0x809d77e0) // fmopa za0.s, p5/M, p3/M, z31.s, z29.s + sub x20, x20, #0x4 + KAI_ASM_INST(0x809c57e1) // fmopa za1.s, p5/M, p2/M, z31.s, z28.s + KAI_ASM_INST(0x809d73c2) // fmopa za2.s, p4/M, p3/M, z30.s, z29.s + KAI_ASM_INST(0x809c53c3) // fmopa za3.s, p4/M, p2/M, z30.s, z28.s + KAI_ASM_INST(0x80997760) // fmopa za0.s, p5/M, p3/M, z27.s, z25.s + KAI_ASM_INST(0x80985761) // fmopa za1.s, p5/M, p2/M, z27.s, z24.s + KAI_ASM_INST(0x80997342) // fmopa za2.s, p4/M, p3/M, z26.s, z25.s + KAI_ASM_INST(0x80985343) // fmopa za3.s, p4/M, p2/M, z26.s, z24.s + KAI_ASM_INST(0x809576e0) // fmopa za0.s, p5/M, p3/M, z23.s, z21.s + KAI_ASM_INST(0x809456e1) // fmopa za1.s, p5/M, p2/M, z23.s, z20.s + KAI_ASM_INST(0x809572c2) // fmopa za2.s, p4/M, p3/M, z22.s, z21.s + KAI_ASM_INST(0x809452c3) // fmopa za3.s, p4/M, p2/M, z22.s, z20.s + KAI_ASM_INST(0x80917660) // fmopa za0.s, p5/M, p3/M, z19.s, z17.s + KAI_ASM_INST(0x80905661) // fmopa za1.s, p5/M, p2/M, z19.s, z16.s + KAI_ASM_INST(0x80917242) // fmopa za2.s, p4/M, p3/M, z18.s, z17.s + KAI_ASM_INST(0x80905243) // fmopa za3.s, p4/M, p2/M, z18.s, z16.s +KAI_ASM_LABEL(label_8) // K loop: Tail + cbz x20, label_10 +KAI_ASM_LABEL(label_9) // K loop: Tail: Loop + ld1w { z19.s }, p5/Z, [x28] + sub x20, x20, #0x1 + ld1w { z18.s }, p4/Z, [x28, #1, MUL VL] + cmp x20, XZR + addvl x28, x28, #2 + ldnt1w { z17.s }, p3/Z, [x21] + ldnt1w { z16.s }, p2/Z, [x21, #1, MUL VL] + addvl x21, x21, #2 + KAI_ASM_INST(0x80917660) // fmopa za0.s, p5/M, p3/M, z19.s, z17.s + KAI_ASM_INST(0x80905661) // fmopa za1.s, p5/M, p2/M, z19.s, z16.s + KAI_ASM_INST(0x80917242) // fmopa za2.s, p4/M, p3/M, z18.s, z17.s + KAI_ASM_INST(0x80905243) // fmopa za3.s, p4/M, p2/M, z18.s, z16.s + bgt label_9 +KAI_ASM_LABEL(label_10) // K loop: Tail: End + incw x15, ALL, MUL #2 + add x21, x8, x16, LSL #1 + cmp x15, x11 + cntw x20, ALL, MUL #2 + csel x8, x8, x21, LT // m := (n + block_width < N) ? m : m + height + csel x15, x15, XZR, LT // n := (n + block_width < N) ? n + block_width : 0 + sub x23, x17, x8 + csel x10, x10, x28, LT // aptr0 := (n + block_width < N) ? aptr0 : aptr + whilelt p5.s, XZR, x23 + whilelt p4.s, x16, x23 + cmp x23, x20 + mov x28, x10 + csel x23, x23, x20, LT + tbnz x14, #2, label_24 + tbnz x14, #1, label_26 + tbz x14, #3, label_14 + mov x22, XZR + mov p11.b, p3.b + subs x21, x13, x22 + mov p10.b, p2.b + ptrue p9.s + ptrue p8.s + cntw x20 + ble label_13 + cmp x21, x16 + incw x22 + csel x21, x21, x16, LT + mov x12, XZR +KAI_ASM_LABEL(label_11) // Store accumulators: Drain to output array: Skip activation: Accumulator row 0: Loop + KAI_ASM_INST(0x25306d21) // psel p1.s, p11.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306900) // psel p0.s, p10.s/Z, p8.s[w12] + KAI_ASM_INST(0xe0bf0700) // st1w { za0h.s[x12] }, p1/Z, [x24, XZR, LSL #2] + KAI_ASM_INST(0xe0b40304) // st1w { za1h.s[x12] }, p0/Z, [x24, x20, LSL #2] + add x12, x12, #0x1 + add x24, x24, x25 + cmp x12, x21 + blt label_11 + subs x21, x13, x22 + ble label_13 + cmp x21, x16 + mov x12, XZR + csel x21, x21, x16, LT +KAI_ASM_LABEL(label_12) // Store accumulators: Drain to output array: Skip activation: Accumulator row 1: Loop + KAI_ASM_INST(0x25306d21) // psel p1.s, p11.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306900) // psel p0.s, p10.s/Z, p8.s[w12] + KAI_ASM_INST(0xe0bf0708) // st1w { za2h.s[x12] }, p1/Z, [x24, XZR, LSL #2] + KAI_ASM_INST(0xe0b4030c) // st1w { za3h.s[x12] }, p0/Z, [x24, x20, LSL #2] + add x12, x12, #0x1 + add x24, x24, x25 + cmp x12, x21 + blt label_12 +KAI_ASM_LABEL(label_13) // Store accumulators: Drain to output array: Skip activation: End + tbnz x14, #0, label_28 + b label_30 +KAI_ASM_LABEL(label_14) // Store accumulators: Drain to output array: Activate + mov x22, XZR + subs x21, x13, x22 + ble label_23 + cmp x21, x16 + incw x22 + csel x21, x21, x16, LT + mov x12, XZR + ands x20, x21, #0xfffffffffffffffe + beq label_17 + KAI_ASM_INST(0xc0820c17) // mova z23.s, p3/M, za0h.s[x12] + KAI_ASM_INST(0xc0820896) // mova z22.s, p2/M, za1h.s[x12] + KAI_ASM_INST(0xc0820c35) // mova z21.s, p3/M, za0h.s[x12, #1] + KAI_ASM_INST(0xc08208b4) // mova z20.s, p2/M, za1h.s[x12, #1] + add x12, x12, #0x2 + cmp x12, x20 + beq label_16 +KAI_ASM_LABEL(label_15) // Store accumulators: Drain to output array: Accumulator row 0: Loop + movprfx z19, z23 + fmin z19.s, p6/M, z19.s, z0.s + movprfx z18, z22 + fmin z18.s, p6/M, z18.s, z0.s + KAI_ASM_INST(0xc0820c17) // mova z23.s, p3/M, za0h.s[x12] + movprfx z17, z21 + fmin z17.s, p6/M, z17.s, z0.s + movprfx z16, z20 + fmin z16.s, p6/M, z16.s, z0.s + KAI_ASM_INST(0xc0820896) // mova z22.s, p2/M, za1h.s[x12] + KAI_ASM_INST(0xc0820c35) // mova z21.s, p3/M, za0h.s[x12, #1] + KAI_ASM_INST(0xc08208b4) // mova z20.s, p2/M, za1h.s[x12, #1] + add x12, x12, #0x2 + fmax z19.s, p6/M, z19.s, z1.s + fmax z18.s, p6/M, z18.s, z1.s + cmp x12, x20 + fmax z17.s, p6/M, z17.s, z1.s + fmax z16.s, p6/M, z16.s, z1.s + stnt1w { z19.s }, p3, [x24] + stnt1w { z18.s }, p2, [x24, #1, MUL VL] + add x24, x24, x25 + stnt1w { z17.s }, p3, [x24] + stnt1w { z16.s }, p2, [x24, #1, MUL VL] + add x24, x24, x25 + blt label_15 +KAI_ASM_LABEL(label_16) // Store accumulators: Drain to output array: Accumulator row 0: Tail + movprfx z19, z23 + fmin z19.s, p6/M, z19.s, z0.s + movprfx z18, z22 + fmin z18.s, p6/M, z18.s, z0.s + cmp x12, x21 + movprfx z17, z21 + fmin z17.s, p6/M, z17.s, z0.s + movprfx z16, z20 + fmin z16.s, p6/M, z16.s, z0.s + fmax z19.s, p6/M, z19.s, z1.s + fmax z18.s, p6/M, z18.s, z1.s + fmax z17.s, p6/M, z17.s, z1.s + fmax z16.s, p6/M, z16.s, z1.s + stnt1w { z19.s }, p3, [x24] + stnt1w { z18.s }, p2, [x24, #1, MUL VL] + add x24, x24, x25 + stnt1w { z17.s }, p3, [x24] + stnt1w { z16.s }, p2, [x24, #1, MUL VL] + add x24, x24, x25 + beq label_18 +KAI_ASM_LABEL(label_17) // Store accumulators: Drain to output array: Accumulator row 0: Tail loop + KAI_ASM_INST(0xc0820c11) // mova z17.s, p3/M, za0h.s[x12] + KAI_ASM_INST(0xc0820890) // mova z16.s, p2/M, za1h.s[x12] + fmin z17.s, p6/M, z17.s, z0.s + add x12, x12, #0x1 + fmin z16.s, p6/M, z16.s, z0.s + cmp x12, x21 + fmax z17.s, p6/M, z17.s, z1.s + fmax z16.s, p6/M, z16.s, z1.s + stnt1w { z17.s }, p3, [x24] + stnt1w { z16.s }, p2, [x24, #1, MUL VL] + add x24, x24, x25 + blt label_17 +KAI_ASM_LABEL(label_18) // Store accumulators: Drain to output array: Accumulator row 0: End + subs x21, x13, x22 + ble label_23 + cmp x21, x16 + mov x12, XZR + csel x21, x21, x16, LT + ands x20, x21, #0xfffffffffffffffe + beq label_21 + KAI_ASM_INST(0xc0820d17) // mova z23.s, p3/M, za2h.s[x12] + KAI_ASM_INST(0xc0820996) // mova z22.s, p2/M, za3h.s[x12] + KAI_ASM_INST(0xc0820d35) // mova z21.s, p3/M, za2h.s[x12, #1] + KAI_ASM_INST(0xc08209b4) // mova z20.s, p2/M, za3h.s[x12, #1] + add x12, x12, #0x2 + cmp x12, x20 + beq label_20 +KAI_ASM_LABEL(label_19) // Store accumulators: Drain to output array: Accumulator row 1: Loop + movprfx z19, z23 + fmin z19.s, p6/M, z19.s, z0.s + movprfx z18, z22 + fmin z18.s, p6/M, z18.s, z0.s + KAI_ASM_INST(0xc0820d17) // mova z23.s, p3/M, za2h.s[x12] + movprfx z17, z21 + fmin z17.s, p6/M, z17.s, z0.s + movprfx z16, z20 + fmin z16.s, p6/M, z16.s, z0.s + KAI_ASM_INST(0xc0820996) // mova z22.s, p2/M, za3h.s[x12] + KAI_ASM_INST(0xc0820d35) // mova z21.s, p3/M, za2h.s[x12, #1] + KAI_ASM_INST(0xc08209b4) // mova z20.s, p2/M, za3h.s[x12, #1] + add x12, x12, #0x2 + fmax z19.s, p6/M, z19.s, z1.s + fmax z18.s, p6/M, z18.s, z1.s + cmp x12, x20 + fmax z17.s, p6/M, z17.s, z1.s + fmax z16.s, p6/M, z16.s, z1.s + stnt1w { z19.s }, p3, [x24] + stnt1w { z18.s }, p2, [x24, #1, MUL VL] + add x24, x24, x25 + stnt1w { z17.s }, p3, [x24] + stnt1w { z16.s }, p2, [x24, #1, MUL VL] + add x24, x24, x25 + blt label_19 +KAI_ASM_LABEL(label_20) // Store accumulators: Drain to output array: Accumulator row 1: Tail + movprfx z19, z23 + fmin z19.s, p6/M, z19.s, z0.s + movprfx z18, z22 + fmin z18.s, p6/M, z18.s, z0.s + cmp x12, x21 + movprfx z17, z21 + fmin z17.s, p6/M, z17.s, z0.s + movprfx z16, z20 + fmin z16.s, p6/M, z16.s, z0.s + fmax z19.s, p6/M, z19.s, z1.s + fmax z18.s, p6/M, z18.s, z1.s + fmax z17.s, p6/M, z17.s, z1.s + fmax z16.s, p6/M, z16.s, z1.s + stnt1w { z19.s }, p3, [x24] + stnt1w { z18.s }, p2, [x24, #1, MUL VL] + add x24, x24, x25 + stnt1w { z17.s }, p3, [x24] + stnt1w { z16.s }, p2, [x24, #1, MUL VL] + add x24, x24, x25 + beq label_22 +KAI_ASM_LABEL(label_21) // Store accumulators: Drain to output array: Accumulator row 1: Tail loop + KAI_ASM_INST(0xc0820d11) // mova z17.s, p3/M, za2h.s[x12] + KAI_ASM_INST(0xc0820990) // mova z16.s, p2/M, za3h.s[x12] + fmin z17.s, p6/M, z17.s, z0.s + add x12, x12, #0x1 + fmin z16.s, p6/M, z16.s, z0.s + cmp x12, x21 + fmax z17.s, p6/M, z17.s, z1.s + fmax z16.s, p6/M, z16.s, z1.s + stnt1w { z17.s }, p3, [x24] + stnt1w { z16.s }, p2, [x24, #1, MUL VL] + add x24, x24, x25 + blt label_21 +KAI_ASM_LABEL(label_22) // Store accumulators: Drain to output array: Accumulator row 1: End +KAI_ASM_LABEL(label_23) // Store accumulators: Drain to output array: End + tbnz x14, #0, label_28 + b label_30 +KAI_ASM_LABEL(label_24) // Store accumulators: Drain to, and fill from buffer + cmp x8, x17 + bge label_26 // If there's no next block to load, then just drain. + ptrue p11.s + ptrue p10.s + ptrue p9.s + ptrue p8.s + cntw x21, ALL, MUL #2 + cntw x20, ALL, MUL #3 + mov x12, XZR +KAI_ASM_LABEL(label_25) // Store accumulators: Drain to, and fill from buffer: Loop + KAI_ASM_INST(0x25306121) // psel p1.s, p8.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306960) // psel p0.s, p10.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306122) // psel p2.s, p8.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf0760) // st1w { za0h.s[x12] }, p1/Z, [x27, XZR, LSL #2] + KAI_ASM_INST(0x25306961) // psel p1.s, p10.s/Z, p11.s[w12] + KAI_ASM_INST(0xe09f0120) // ld1w { za0h.s[x12] }, p0/Z, [x9, XZR, LSL #2] + KAI_ASM_INST(0x25306120) // psel p0.s, p8.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0b00b64) // st1w { za1h.s[x12] }, p2/Z, [x27, x16, LSL #2] + KAI_ASM_INST(0x25306962) // psel p2.s, p10.s/Z, p11.s[w12] + KAI_ASM_INST(0xe0900524) // ld1w { za1h.s[x12] }, p1/Z, [x9, x16, LSL #2] + KAI_ASM_INST(0x25306121) // psel p1.s, p8.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0b50368) // st1w { za2h.s[x12] }, p0/Z, [x27, x21, LSL #2] + KAI_ASM_INST(0x25306960) // psel p0.s, p10.s/Z, p11.s[w12] + KAI_ASM_INST(0xe0950928) // ld1w { za2h.s[x12] }, p2/Z, [x9, x21, LSL #2] + KAI_ASM_INST(0xe0b4076c) // st1w { za3h.s[x12] }, p1/Z, [x27, x20, LSL #2] + KAI_ASM_INST(0xe094012c) // ld1w { za3h.s[x12] }, p0/Z, [x9, x20, LSL #2] + add x12, x12, #0x1 + addvl x27, x27, #4 + cmp x12, x16 + addvl x9, x9, #4 + blt label_25 + b label_30 +KAI_ASM_LABEL(label_26) // Store accumulators: Drain to buffer + ptrue p9.s + ptrue p8.s + cntw x21, ALL, MUL #2 + cntw x20, ALL, MUL #3 + mov x12, XZR +KAI_ASM_LABEL(label_27) // Store accumulators: Drain to buffer: Loop + KAI_ASM_INST(0x25306120) // psel p0.s, p8.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306122) // psel p2.s, p8.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306121) // psel p1.s, p8.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf0360) // st1w { za0h.s[x12] }, p0/Z, [x27, XZR, LSL #2] + KAI_ASM_INST(0x25306120) // psel p0.s, p8.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0b00b64) // st1w { za1h.s[x12] }, p2/Z, [x27, x16, LSL #2] + KAI_ASM_INST(0xe0b50768) // st1w { za2h.s[x12] }, p1/Z, [x27, x21, LSL #2] + KAI_ASM_INST(0xe0b4036c) // st1w { za3h.s[x12] }, p0/Z, [x27, x20, LSL #2] + add x12, x12, #0x1 + addvl x27, x27, #4 + cmp x12, x16 + blt label_27 + b label_30 +KAI_ASM_LABEL(label_28) // Store accumulators: Fill from buffer + cmp x8, x17 + bge label_30 + ptrue p11.s + ptrue p10.s + cntw x21, ALL, MUL #2 + cntw x20, ALL, MUL #3 + mov x12, XZR +KAI_ASM_LABEL(label_29) // Store accumulators: Fill from buffer: Loop + KAI_ASM_INST(0x25306960) // psel p0.s, p10.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306962) // psel p2.s, p10.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306961) // psel p1.s, p10.s/Z, p11.s[w12] + KAI_ASM_INST(0xe09f0120) // ld1w { za0h.s[x12] }, p0/Z, [x9, XZR, LSL #2] + KAI_ASM_INST(0x25306960) // psel p0.s, p10.s/Z, p11.s[w12] + KAI_ASM_INST(0xe0900924) // ld1w { za1h.s[x12] }, p2/Z, [x9, x16, LSL #2] + KAI_ASM_INST(0xe0950528) // ld1w { za2h.s[x12] }, p1/Z, [x9, x21, LSL #2] + KAI_ASM_INST(0xe094012c) // ld1w { za3h.s[x12] }, p0/Z, [x9, x20, LSL #2] + add x12, x12, #0x1 + addvl x9, x9, #4 + cmp x12, x16 + blt label_29 +KAI_ASM_LABEL(label_30) // Store accumulators: End + cmp x8, x17 + mov x13, x23 + blt label_3 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa) + + KAI_ASM_END -- GitLab From 88cb32bc7d98606af2e51ba309cd0bb59d70a94d Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 19 Jun 2025 15:17:28 +0100 Subject: [PATCH 2/3] Revert all changes in the header file Signed-off-by: Viet-Hoa Do --- ..._matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h index 57287c0a..7130c4c3 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h @@ -14,9 +14,8 @@ extern "C" { /// Micro-kernel dependencies /// -/// TODO: Fix names auto generated below: -/// -# kai_lhs_pack_f32pvlx1_f32_sme to pack the LHS matrix. -/// -# kai_rhs_pack_kxn_f32pvlx1b_f32_f32_sme to pack the RHS matrix. +/// -# kai_lhs_pack_f32p2vlx1_f32_sme to pack the LHS matrix. +/// -# kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme to pack the RHS matrix. /// Gets m step value. /// -- GitLab From 62dff53263aee76ca962df901c30326eb8b23496 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 19 Jun 2025 15:23:25 +0100 Subject: [PATCH 3/3] Fix bazel build issue Signed-off-by: Viet-Hoa Do --- CMakeLists.txt | 2 +- kai/ukernels/matmul/BUILD.bazel | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 01c87fc9..d826540d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -231,6 +231,7 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME_ASM + kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme_asm.S @@ -264,7 +265,6 @@ set(KLEIDIAI_FILES_SME ${KLEIDIAI_FILES_SME_ASM} kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c - kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.c ) set(KLEIDIAI_FILES_SME2_ASM diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index a4f79cba..4e63ebab 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -150,6 +150,7 @@ SME_KERNELS = [ # buildifier: keep sorted SME_KERNELS_ASM = [ + "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa", "pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", "pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme", "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme", -- GitLab