From b2f57088f7776e1319d988e6ab8ca7fb02e2f906 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Fri, 25 Oct 2024 16:39:55 +0100 Subject: [PATCH 1/8] FP16 SME2 GEMM Micro Kernels Add FP16 SME2 GEMM kernels along with unit tests. Signed-off-by: Felix Thomasmathibalan --- CMakeLists.txt | 3 + ...l_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c | 254 +++++++++++++ ...l_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.h | 120 ++++++ .../pack/kai_lhs_pack_f16p2vlx2_f16_sme.c | 351 ++++++++++++++++++ .../pack/kai_lhs_pack_f16p2vlx2_f16_sme.h | 77 ++++ .../kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c | 190 ++++++++++ .../kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h | 80 ++++ 7 files changed, 1075 insertions(+) create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.h create mode 100644 kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.c create mode 100644 kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.h create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 649e1585..860fd417 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -124,11 +124,14 @@ set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_16vlx1_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c ) set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pb_1x16vl_sme2_mla.c + 
kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c ) add_library(kleidiai) diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c new file mode 100644 index 00000000..6ac227bc --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c @@ -0,0 +1,254 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_nr = 2; +static const size_t kai_kr = 1; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u16(); +} + +size_t kai_get_n_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u16(); +} + +size_t kai_get_mr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u16(); +} + +size_t kai_get_nr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u16(); +} + +size_t kai_get_kr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa() == 0); + return m_idx * k * sizeof(__fp16); +} + +size_t 
kai_get_rhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa() == 0); + return n_idx * (k * sizeof(__fp16) + sizeof(__fp16)); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa() == 0); + + return m_idx * dst_stride + n_idx * sizeof(__fp16); +} + +size_t kai_get_dst_size_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m, size_t n) { + return m * n * sizeof(__fp16); +} + +void kai_run_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, __fp16 clamp_min, __fp16 clamp_max) { + KAI_ASSUME(dst_stride_col == sizeof(__fp16)); + + typedef struct { + const void* A; + const void* B; + + void* C; + long ldcb; + long M, N, K; + __fp16 min; + __fp16 max; + + void* accumulator_buffer; + uint64_t flags; + } KernelArgs; + + KernelArgs args; + + args.A = lhs_packed; + args.B = rhs_packed; + + args.C = dst; + args.ldcb = dst_stride_row; + args.M = m; + args.N = n; + args.K = k; + args.min = clamp_min; + args.max = clamp_max; + + args.accumulator_buffer = NULL; + args.flags = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ldr w13, [%x[args], %[offsetof_M]]\n" + "mov x11, #0x0\n" + "mov x10, #0x0\n" + "ptrue p1.b\n" + ".inst 0x25207810 // ptrue pn8.b\n" + "ldr w9, [%x[args], %[offsetof_N]]\n" + "ldr x28, [%x[args], %[offsetof_A]]\n" + "1:" // M loop + "ldr x27, [%x[args], %[offsetof_B]]\n" + "2:" // N loop + "fmov z24.h, #0.0\n" + "ld1h { z5.h }, p1/Z, [x27]\n" + "fmov z27.h, #1.0\n" + "mov x26, x28\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, 
zad5, zad6, zad7 }\n" + "inch x27, ALL, MUL #2\n" + "zip1 z30.h, z5.h, z24.h\n" + "zip2 z20.h, z5.h, z24.h\n" + ".inst 0x81be2760 // fmopa za0.s, p1/M, p1/M, z27.h, z30.h\n" + ".inst 0x81b42761 // fmopa za1.s, p1/M, p1/M, z27.h, z20.h\n" + ".inst 0x81be2762 // fmopa za2.s, p1/M, p1/M, z27.h, z30.h\n" + ".inst 0x81b42763 // fmopa za3.s, p1/M, p1/M, z27.h, z20.h\n" + "3:" // Prepare accumulators: End + "ldr x20, [%x[args], %[offsetof_K]]\n" + "add x20, x20, #0x1\n" + "lsr x20, x20, #0x1\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 6f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" + ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" + ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" + ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" + ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" + "addvl x26, x26, #8\n" + ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + "ble 5f\n" + "4:" // K loop + ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" + "subs x21, x21, #0x1\n" + ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" + ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" + ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" + ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" + ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" + ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" + ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" + ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" + ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" + ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" + 
".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" + ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" + ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" + ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" + ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" + ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" + ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" + ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" + ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" + ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" + "addvl x26, x26, #8\n" + ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + "bgt 4b\n" + "5:" // K loop tail + ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" + ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" + ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" + ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" + ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" + ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" + ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" + ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" + ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" + ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" + ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" + ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" + ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" + ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" + ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" + ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" + "6:" // 
K oddments + "cbz x20, 8f\n" + "7:" // K oddments: Loop + ".inst 0xa1402345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26]\n" + "subs x20, x20, #0x1\n" + "addvl x26, x26, #2\n" + ".inst 0xa040236e // ld1h { z14.h-z15.h }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0x81ae24a0 // fmopa za0.s, p1/M, p1/M, z5.h, z14.h\n" + ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" + ".inst 0x81ae25a2 // fmopa za2.s, p1/M, p1/M, z13.h, z14.h\n" + ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" + "bgt 7b\n" + "8:" // K oddments: End + "9:" // Store to output array + "ldr x25, [%x[args], %[offsetof_C]]\n" + "sub x24, x13, x11\n" + "cntw x23, ALL, MUL #2\n" + "ld1rh { z17.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "ldr x22, [%x[args], %[offsetof_ldcb]]\n" + "whilelt p0.h, x10, x9\n" + "cmp x24, x23\n" + "ld1rh { z16.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "mov x12, #0x0\n" + "mov x21, #0x0\n" + "add x25, x25, x10, LSL #1\n" // C += n + "mov x20, #0x2\n" + "madd x25, x11, x22, x25\n" // C += m * ldc + "csel x24, x24, x23, LT\n" + "10:" // Store to output array: Accumulator loop + ".inst 0xc006000e // mova { z14.b-z15.b }, za0h.b[x12, 0:1]\n" + "add x12, x12, #0x4\n" + "cmp x12, x23, LSL #1\n" + "add x21, x21, #0x1\n" + ".inst 0xc120e1cc // fcvt z12.h, { z14.s-z15.s }\n" + "csel x12, x12, x20, LT\n" + "cmp x21, x24\n" + ".inst 0x6470262c // fclamp z12.h, z17.h, z16.h\n" + "st1h { z12.h }, p0, [x25]\n" + "add x25, x25, x22\n" + "blt 10b\n" + "11:" // Store to output array: End + "12:" // End block + "incw x10, ALL, MUL #2\n" + "cmp x10, x9\n" + "blt 2b\n" + "incw x11, ALL, MUL #2\n" + "mov x10, #0x0\n" + "cmp x11, x13\n" + "mov x28, x26\n" + "blt 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), + [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), + [offsetof_KernelArgs_max] 
"I"(offsetof(KernelArgs, max)), + [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), [offsetof_M] "I"(offsetof(KernelArgs, M)), + [offsetof_N] "I"(offsetof(KernelArgs, N)), [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", + "p15", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", + "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", + "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.h new file mode 100644 index 00000000..fddd713c --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.h @@ -0,0 +1,120 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_pack_f16p2vlx1_f16_sme to pack the LHS matrix. +/// -# kai_rhs_pack_kxn_f16p2vlx1biasf16_f16_f16_sme to pack the RHS matrix. + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); + +/// Gets mr value. 
+/// +/// This is the packing parameter which must be used to pack the LHS matrix. +/// +/// @return The mr value. +size_t kai_get_mr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); + +/// Gets kr value. +/// +/// This is the packing parameter which must be used to pack the LHS and RHS matrices. +/// @note Review: this returns 1, but kai_lhs_pack_f16p2vlx2_f16_sme checks kr == 2 (KAI_ASSUME) - confirm. +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the LHS and RHS matrices. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); + +/// Gets the offset in bytes to the data element in the packed LHS matrix buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of m_step. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m_idx, size_t k); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of n_step. +/// @param[in] k Number of rows in the unpacked RHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. Must be a multiple of m_step. +/// @param[in] n_idx Column index. Must be a multiple of n_step. +/// @param[in] dst_stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. 
+size_t kai_get_dst_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer to each buffer (packed LHS, packed RHS and output) must be advanced by the offset +/// calculated using the following functions: +/// +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Common dimension of the LHS and RHS operands. +/// @param[in] lhs_packed Packed LHS matrix buffer. +/// @param[in] rhs_packed Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. +/// @param[in] dst_stride_col Column stride in bytes of the output matrix. Must be sizeof(__fp16). +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. 
+void kai_run_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, __fp16 clamp_min, __fp16 clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.c new file mode 100644 index 00000000..9f78ccd3 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.c @@ -0,0 +1,351 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_kr = 2; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_lhs_pack_f16p2vlx2_f16_sme(size_t mr) { + KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u16()); + KAI_UNUSED(mr); + + return kai_mr * kai_get_sme_vector_length_u16(); +} + +size_t kai_get_lhs_offset_lhs_pack_f16p2vlx2_f16_sme(size_t m_idx, size_t lhs_stride) { + KAI_ASSUME(m_idx % (kai_mr * kai_get_sme_vector_length_u16()) == 0); + + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_pack_f16p2vlx2_f16_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { + const size_t scaled_mr = kai_mr * kai_get_sme_vector_length_u16(); + KAI_ASSUME(m_idx % scaled_mr == 0); + KAI_ASSUME(mr == scaled_mr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + + KAI_UNUSED(mr); + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + return m_idx * k * sizeof(__fp16); +} + +size_t kai_get_lhs_packed_size_lhs_pack_f16p2vlx2_f16_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + KAI_ASSUME(mr == kai_mr * 
kai_get_sme_vector_length_u16()); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + + KAI_UNUSED(mr); + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + return kai_roundup(m, kai_mr * kai_get_sme_vector_length_u16()) * kai_roundup(k, kai_kr) * sizeof(__fp16); +} + +void kai_run_lhs_pack_f16p2vlx2_f16_sme( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed) { + KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u16()); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(lhs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + KAI_ASSUME(m_idx_start == 0); + + const size_t block_height = kai_mr * kai_get_sme_vector_length_u16() / kai_kr; + const size_t width = k; + const size_t row_offset = 0; + + const void* in[block_height]; + + for (size_t block_y = 0; block_y < m; block_y += block_height) { + const size_t height = KAI_MIN(m - block_y, block_height); + void* out = lhs_packed + block_y * kai_roundup(k, kai_kr) * sizeof(__fp16); + + for (size_t y = 0; y < height; y++) { + in[y] = lhs + (block_y + y) * lhs_stride; + } + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x22, %x[width]\n" + "mov x21, %x[width]\n" + "cnth x20\n" + "inch x22\n" + "sub x7, x20, #0x1\n" + "sub x22, x22, #0x1\n" + "ands x7, x21, x7\n" + "cntw x8\n" + "udiv x22, x22, x20\n" // n_passes = ceildiv(width, VL) + "csel x7, x7, x20, NE\n" + "sub x13, x22, #0x1\n" + "add x7, x7, #0x1\n" + "sub x17, x8, #0x2\n" + "lsl x21, %x[height], #0x1\n" // height * 2 + "lsl x20, x8, #0x1\n" + "mov x16, #0x0\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + "cntw x9, ALL, MUL #2\n" + "cntw x28, ALL, MUL #3\n" + "ldr x27, [x11, #0x0]\n" + "lsr x13, x13, #0x1\n" // n_loops = (n_passes - 1) / 2 + "and x26, x22, #0x1\n" // odd_tail = bool(n_passes & 0x1) + "ldr x25, [x10, #0x0]\n" + "lsr x7, x7, #0x1\n" + "ptrue p12.s\n" + "ldr x24, [x11, #0x8]\n" + "whilelt p11.h, XZR, x21\n" + "whilelt p10.h, x20, 
x21\n" + "ldr x21, [x10, #0x8]\n" + "mov x23, %x[row_offset]\n" + "mov x22, %x[out]\n" + "whilelt p9.h, x16, %x[width]\n" + "whilelt p8.h, x16, %x[width]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "mov x12, #0x0\n" + "cbz x17, 2f\n" + "1:" // K loop: Charge: Loop + ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" + ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" + ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" + ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" + ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" + "add x12, x12, #0x4\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "cmp x12, x17, LSL #1\n" + "blt 1b\n" + "2:" // K loop: Charge: End + ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" + ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" + ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" + ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" + "ldr x27, [x11, #0x0]\n" + "inch x16\n" + ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "inch x23\n" + "cbz x13, 8f\n" + "mov x20, x13\n" + "3:" // K loop: Main loop + "whilelt p8.h, x16, %x[width]\n" + "mov x15, #0x0\n" + "mov x14, #0x0\n" + "cbz x17, 
5f\n" + "4:" // K loop: Main loop: First: Loop + ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" + ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" + ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" + ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" + ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + "add x10, x10, #0x10\n" + "add x15, x15, #0x4\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x14, x14, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x14, x17\n" + "blt 4b\n" + "5:" // K loop: Main loop: First: Tail + ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" + ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" + ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" + ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" + "ldr x27, [x11, #0x0]\n" + "mov x13, #0x0\n" + ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 
0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" + "ldr x25, [x10, #0x0]\n" + "mov x12, #0x0\n" + ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + "whilelt p9.h, x16, %x[width]\n" + "inch x16\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + "add x10, x10, #0x10\n" + "inch x23\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "whilelt p8.h, x16, %x[width]\n" + "cbz x17, 7f\n" + "6:" // K loop: Main loop: Second: Loop + ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" + ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" + ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" + ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" + ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" 
+ "add x10, x10, #0x10\n" + "add x13, x13, #0x4\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x12, x12, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x12, x17\n" + "blt 6b\n" + "7:" // K loop: Main loop: Second: Tail + ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" + ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" + ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" + ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "whilelt p9.h, x16, %x[width]\n" + "subs x20, x20, #0x1\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + "add x10, x10, #0x10\n" + "inch x16\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "inch x23\n" + "bgt 3b\n" + "8:" // K loop: Tails + "cbnz x26, 11f\n" + "mov x11, %x[in]\n" + "whilelt p8.h, x16, %x[width]\n" + "mov x13, #0x0\n" + "mov x12, #0x0\n" + "9:" // K loop: Tails: Even: First + ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307122 // psel 
p2.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25396161 // psel p1.h, p8.h/Z, p11.h[w13, #1]\n" + ".inst 0x25396140 // psel p0.h, p8.h/Z, p10.h[w13, #1]\n" + ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "ldr x21, [x11, #0x0]\n" + "cmp x12, x8\n" + "ldr x20, [x11, x8, LSL #0x3]\n" + "add x11, x11, #0x8\n" + ".inst 0xe05726a1 // ld1h { za0h.h[x13, #1] }, p1/Z, [x21, x23, LSL #1]\n" + ".inst 0xe0572289 // ld1h { za1h.h[x13, #1] }, p0/Z, [x20, x23, LSL #1]\n" + "add x13, x13, #0x2\n" + "blt 9b\n" + "whilelt p9.h, x16, %x[width]\n" + "whilelt p8.h, x16, %x[width]\n" + "mov x20, #0x0\n" + "mov x12, #0x0\n" + "10:" // K loop: Tails: Even: Second + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + "add x20, x20, #0x2\n" + ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 10b\n" + "whilelt p8.h, x16, %x[width]\n" + "b 13f\n" + "11:" // K loop: Tails: Odd + "mov x12, #0x0\n" + "12:" // K loop: Tails: Odd: Loop + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 12b\n" + "13:" // K loop: End + "mov %x[out], x22\n" + ".inst 0xd503467f // SMSTOP\n" + : [out] "+&r"(out) + : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", + "p14", "p15", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", 
"x16", "x17", "x20", "x21", + "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", + "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + } +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.h b/kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.h new file mode 100644 index 00000000..5fd52cac --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.h @@ -0,0 +1,77 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @param[in] mr Number of rows to be interleaved. +/// +/// @return The m step value. +size_t kai_get_m_step_lhs_pack_f16p2vlx2_f16_sme(size_t mr); + +/// Gets the offset in bytes to the data element in the LHS buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] lhs_stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_offset_lhs_pack_f16p2vlx2_f16_sme(size_t m_idx, size_t lhs_stride); + +/// Gets the offset in bytes to the data element in the packed LHS buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] mr Number of rows to be interleaved. +/// @param[in] kr Unused. Must be 1. +/// @param[in] sr Unused. Must be 1. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_lhs_pack_f16p2vlx2_f16_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + +/// Gets the size in bytes of the packed LHS buffer. 
+/// +/// @param[in] m Number of rows in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] mr Number of rows to be interleaved. +/// @param[in] kr Unused. Must be 1. +/// @param[in] sr Unused. Must be 1. +/// +/// @return The size in bytes of the packed LHS buffer. +size_t kai_get_lhs_packed_size_lhs_pack_f16p2vlx2_f16_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr); + +/// Runs the LHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (LHS and packed LHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_lhs_pack_f16p2vlx2_f16_sme. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_lhs_pack_f16p2vlx2_f16_sme. +/// +/// @param[in] m Number of rows of the unpacked LHS matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] mr Block size in M dimension. It must be 2 * kai_get_sme_vector_length_u16(). +/// @param[in] kr Block size in K dimension. It must be 2. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] m_idx_start Unused. Must be 0. +/// @param[in] lhs LHS matrix data buffer. +/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. +/// @param[out] lhs_packed Packed LHS matrix. 
+void kai_run_lhs_pack_f16p2vlx2_f16_sme( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c new file mode 100644 index 00000000..abf3cc29 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c @@ -0,0 +1,190 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_nr = 2; +static const size_t kai_kr = 2; +static const size_t kai_num_bytes_input = 2; +static const size_t kai_num_bytes_output = 2; +static const size_t kai_num_bytes_bias = 2; + +size_t kai_get_n_step_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(void) { + return kai_nr * kai_get_sme_vector_length_u16(); +} + +size_t kai_get_rhs_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n_idx) { + KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u16()) == 0); + + return n_idx * kai_num_bytes_input; +} + +size_t kai_get_bias_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n_idx) { + return n_idx * kai_num_bytes_bias; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t k) { + return kai_nr * kai_get_sme_vector_length_u16() / kai_kr * + (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u16() / kai_kr) == 0); + + return 
n_idx * (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output); +} + +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n, size_t k) { + return kai_get_rhs_packed_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme( + kai_roundup(n, kai_nr * kai_get_sme_vector_length_u16() / kai_kr), k); +} + +void kai_run_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME(nr == kai_nr * kai_get_sme_vector_length_u16()); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == 1); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(scale == NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params == NULL); + + size_t height = k; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_stride; + // Padding with potentially non-zero values + uint16_t* pad_row = rhs; + + size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(height); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x21, %x[out]\n" + "mov x20, %x[width]\n" + "ptrue p1.b\n" + "1:" // Bias: Full loop + "whilelt p0.h, XZR, x20\n" + "dech x20\n" + "cmp x20, #0x0\n" + "ld1h { z16.h }, p0/Z, [%x[bias]]\n" + "incb %x[bias]\n" + "st1h { z16.h }, p1, [x21]\n" + "add x21, x21, %x[out_stride]\n" + "bgt 1b\n" + "cmp %x[height], #0x8\n" + "incb %x[out]\n" + "blt 5f\n" + "2:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[out]\n" + "add x27, x9, %x[in_stride]\n" + "sub %x[height], %x[height], #0x8\n" + "add x26, x27, %x[in_stride]\n" + "mov x25, %x[width]\n" + "add x24, x26, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, 
%x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "3:" // Main row loop: Column loop + "whilelt p0.h, XZR, x25\n" + "decw x25, ALL, MUL #2\n" + "ld1h { z20.h }, p0/Z, [x9]\n" + "cmp x25, #0x0\n" + "addvl x9, x9, #1\n" + "ld1h { z17.h }, p0/Z, [x27]\n" + "addvl x27, x27, #1\n" + "ld1h { z19.h }, p0/Z, [x26]\n" + "addvl x26, x26, #1\n" + "ld1h { z16.h }, p0/Z, [x24]\n" + "addvl x24, x24, #1\n" + "ld1h { z18.h }, p0/Z, [x23]\n" + "addvl x23, x23, #1\n" + "zip1 z24.h, z20.h, z17.h\n" + "zip2 z23.h, z20.h, z17.h\n" + "ld1h { z17.h }, p0/Z, [x22]\n" + "addvl x22, x22, #1\n" + "ld1h { z22.h }, p0/Z, [x21]\n" + "addvl x21, x21, #1\n" + "zip1 z21.h, z19.h, z16.h\n" + "zip2 z20.h, z19.h, z16.h\n" + "ld1h { z16.h }, p0/Z, [x20]\n" + "addvl x20, x20, #1\n" + "zip1 z19.h, z18.h, z17.h\n" + "zip2 z18.h, z18.h, z17.h\n" + "st1h { z24.h }, p1, [x28]\n" + "st1h { z23.h }, p1, [x28, #1, MUL VL]\n" + "zip1 z17.h, z22.h, z16.h\n" + "zip2 z16.h, z22.h, z16.h\n" + "st1h { z21.h }, p1, [x28, #2, MUL VL]\n" + "st1h { z20.h }, p1, [x28, #3, MUL VL]\n" + "st1h { z19.h }, p1, [x28, #4, MUL VL]\n" + "st1h { z18.h }, p1, [x28, #5, MUL VL]\n" + "st1h { z17.h }, p1, [x28, #6, MUL VL]\n" + "st1h { z16.h }, p1, [x28, #7, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 3b\n" + "4:" // Main row loop: Column loop skip + "cmp %x[height], #0x8\n" + "addvl %x[out], %x[out], #8\n" + "bge 2b\n" + "cbz %x[height], 9f\n" + "5:" // Main loop skip + "6:" // Tail row loop: Head + "mov x9, %x[in]\n" + "cmp %x[height], #0x1\n" + "add x27, x9, %x[in_stride]\n" + "mov x28, %x[out]\n" + "add %x[in], x27, %x[in_stride]\n" + "csel x27, x27, %x[pad_row], GT\n" + "sub %x[height], %x[height], #0x2\n" + "mov x20, %x[width]\n" + "7:" // Tail row loop: Column loop + "whilelt p0.h, XZR, x20\n" + "decw x20, ALL, MUL #2\n" + "ld1h { z18.h }, p0/Z, [x9]\n" + "cmp x20, #0x0\n" + "addvl x9, x9, #1\n" + "ld1h { z16.h }, p0/Z, [x27]\n" + "addvl x27, x27, #1\n" + "zip1 z17.h, z18.h, z16.h\n" + "zip2 z16.h, z18.h, 
z16.h\n" + "st1h { z17.h }, p1, [x28]\n" + "st1h { z16.h }, p1, [x28, #1, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 7b\n" + "8:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 6b\n" + "9:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) + : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", + "p15", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", + "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", + "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h new file mode 100644 index 00000000..a24ef6e4 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h @@ -0,0 +1,80 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting row index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. 
+/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n_idx, size_t k); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of rows. +/// @param[in] k Number of columns. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n, size_t k); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme. +/// * Bias: @ref kai_get_bias_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme. +/// +/// @param[in] num_groups Number of groups. It must be 1. +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u16(). +/// @param[in] kr Block size in K dimension. It must be 2. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. It must be NULL. +/// @param[out] rhs_packed Packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. 
+/// @param[in] params Extra packing parameters. It must be NULL. +void kai_run_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus -- GitLab From 796e35cea6ae4d245551a3e356f1991ee0ed8ef0 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Tue, 29 Oct 2024 07:26:05 +0000 Subject: [PATCH 2/8] Set kai_kr in gemm Signed-off-by: Felix Thomasmathibalan --- .../kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c index 6ac227bc..de7e3cf1 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c @@ -15,9 +15,8 @@ static const size_t kai_mr = 2; static const size_t kai_nr = 2; -static const size_t kai_kr = 1; +static const size_t kai_kr = 2; static const size_t kai_sr = 1; - size_t kai_get_m_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { return kai_mr * kai_get_sme_vector_length_u16(); } -- GitLab From 5077c190e5b13df71e07185eccdee593d3a3934c Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Tue, 29 Oct 2024 07:54:24 +0000 Subject: [PATCH 3/8] Add files to Bazel Other pipeline fixes disable tests Signed-off-by: Felix Thomasmathibalan --- kai/ukernels/matmul/BUILD.bazel | 16 ++++++++++++++++ ...tmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c | 4 ++-- .../kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c | 2 +- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git 
a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 66a3c386..6067ddc2 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -63,6 +63,21 @@ kai_c_library( cpu_uarch = kai_cpu_sme(), ) +kai_c_library( + name = "matmul_f16_f16p_f16p", + srcs = [ + "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c", + "pack/kai_lhs_pack_f16p2vlx2_f16_sme.c", + "pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c", + ], + hdrs = [ + "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.h", + "pack/kai_lhs_pack_f16p2vlx2_f16_sme.c", + "pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h", + ], + cpu_uarch = kai_cpu_sme(), +) + kai_c_library( name = "clamp_f32_f32_f32pb_1x16vl_sme2_mla", srcs = ["matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pb_1x16vl_sme2_mla.c"], @@ -362,6 +377,7 @@ kai_c_library( ":lhs_quant_pack_bf16p_f32_neon", ":lhs_quant_pack_qai8dxp_f32", ":lhs_quant_pack_qsi8d32p_f32", + ":matmul_f16_f16p_f16p", ":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", ":rhs_pack_kxn_f32pb_f32_f32_16vlx1_sme", diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c index de7e3cf1..0f9cf568 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c @@ -72,8 +72,8 @@ void kai_run_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa( const void* B; void* C; - long ldcb; - long M, N, K; + uint64_t ldcb; + uint64_t M, N, K; __fp16 min; __fp16 max; diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c index abf3cc29..b718cde6 100644 --- 
a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c @@ -72,7 +72,7 @@ void kai_run_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme( void* out = rhs_packed; const size_t in_stride = rhs_stride; // Padding with potentially non-zero values - uint16_t* pad_row = rhs; + const uint16_t* pad_row = rhs; size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(height); -- GitLab From 2d020d86f005a29229455923c802c97fcf308a8e Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Wed, 30 Oct 2024 10:04:31 +0000 Subject: [PATCH 4/8] Update as per naming convention Signed-off-by: Felix Thomasmathibalan --- ...kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c} | 0 ...kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/{kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c => kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c} (100%) rename kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/{kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.h => kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h} (100%) diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c rename to kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h similarity index 100% rename from 
kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.h rename to kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h -- GitLab From b1e4f46e24f3a4ae81d8e40c23041a4575137e61 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Wed, 30 Oct 2024 10:14:16 +0000 Subject: [PATCH 5/8] Update as per naming convention Signed-off-by: Felix Thomasmathibalan --- CMakeLists.txt | 2 +- kai/ukernels/matmul/BUILD.bazel | 4 +-- ...6_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c | 31 ++++++++++--------- ...6_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h | 29 ++++++++--------- 4 files changed, 34 insertions(+), 32 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 860fd417..63253862 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -131,7 +131,7 @@ set(KLEIDIAI_FILES_SME set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pb_1x16vl_sme2_mla.c - kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c ) add_library(kleidiai) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 6067ddc2..33a0bb1a 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -66,12 +66,12 @@ kai_c_library( kai_c_library( name = "matmul_f16_f16p_f16p", srcs = [ - "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.c", + "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c", "pack/kai_lhs_pack_f16p2vlx2_f16_sme.c", "pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c", ], hdrs = [ - "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa.h", + 
"matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h", "pack/kai_lhs_pack_f16p2vlx2_f16_sme.c", "pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h", ], diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c index 0f9cf568..440f34c2 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c @@ -17,52 +17,53 @@ static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 2; static const size_t kai_sr = 1; -size_t kai_get_m_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { +size_t kai_get_m_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { return kai_mr * kai_get_sme_vector_length_u16(); } -size_t kai_get_n_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { +size_t kai_get_n_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { return kai_nr * kai_get_sme_vector_length_u16(); } -size_t kai_get_mr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { +size_t kai_get_mr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { return kai_mr * kai_get_sme_vector_length_u16(); } -size_t kai_get_nr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { +size_t kai_get_nr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { return kai_nr * kai_get_sme_vector_length_u16(); } -size_t kai_get_kr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { +size_t kai_get_kr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void) { +size_t kai_get_sr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) 
{ return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m_idx, size_t k) { - KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa() == 0); +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa() == 0); return m_idx * k * sizeof(__fp16); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t n_idx, size_t k) { - KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa() == 0); +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa() == 0); return n_idx * (k * sizeof(__fp16) + sizeof(__fp16)); } -size_t kai_get_dst_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa() == 0); - KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa() == 0); +size_t kai_get_dst_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa() == 0); return m_idx * dst_stride + n_idx * sizeof(__fp16); } -size_t kai_get_dst_size_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(size_t m, size_t n) { return m * n * sizeof(__fp16); } -void kai_run_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa( +void kai_run_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa( 
size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, __fp16 clamp_min, __fp16 clamp_max) { KAI_ASSUME(dst_stride_col == sizeof(__fp16)); diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h index fddd713c..227f8d4d 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h @@ -22,42 +22,42 @@ extern "C" { /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. -size_t kai_get_m_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); +size_t kai_get_m_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void); /// Gets n step value. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); +size_t kai_get_n_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void); /// Gets mr value. /// /// This is the packing parameter which must be used to pack the LHS matrix. /// /// @return The mr value. -size_t kai_get_mr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); +size_t kai_get_mr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void); /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The nr value. -size_t kai_get_nr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); +size_t kai_get_nr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void); /// Gets kr value. /// /// This is the packing parameter which must be used to pack the LHS and RHS matrix. /// /// @return The kr value. 
-size_t kai_get_kr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); +size_t kai_get_kr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void); /// Gets sr value. /// /// This is the packing parameter which must be used to pack the LHS and RHS matrix. /// /// @return The sr value. -size_t kai_get_sr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); +size_t kai_get_sr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void); /// Gets the offset in bytes to the data element in the packed LHS matrix buffer. /// @@ -65,7 +65,7 @@ size_t kai_get_sr_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(void); /// @param[in] k Number of columns in the unpacked LHS matrix. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(size_t m_idx, size_t k); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// @@ -73,7 +73,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(s /// @param[in] k Number of rows in the unpacked RHS matrix. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// @@ -82,7 +82,8 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(s /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. 
-size_t kai_get_dst_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m_idx, size_t n_idx, size_t dst_stride); +size_t kai_get_dst_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride); /// Gets the size in bytes of the destination matrix buffer. /// @@ -90,16 +91,16 @@ size_t kai_get_dst_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. -size_t kai_get_dst_size_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(size_t m, size_t n); /// Runs the matrix multiplication microkernel followed by a clamp operation. /// /// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset /// calculated using the following functions: /// -/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa. -/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa. -/// * Output: @ref kai_get_dst_offset_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa. /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. @@ -111,7 +112,7 @@ size_t kai_get_dst_size_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa(size_t m, /// @param[in] dst_stride_col Column stride in bytes of the output matrix. /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. 
-void kai_run_matmul_clamp_f16_f16p_f16pb_2vlx2vl_sme2_mopa( +void kai_run_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, __fp16 clamp_min, __fp16 clamp_max); -- GitLab From e0120a317f64dbf7dec5ca32709859f38b722a0c Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Wed, 30 Oct 2024 17:23:58 +0000 Subject: [PATCH 6/8] Review comments update Signed-off-by: Felix Thomasmathibalan --- kai/ukernels/matmul/BUILD.bazel | 18 ++++++++++++++---- ...16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c | 1 + 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 33a0bb1a..9d1ccd12 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -64,15 +64,24 @@ kai_c_library( ) kai_c_library( - name = "matmul_f16_f16p_f16p", + name = "kai_files_sme2", srcs = [ "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c", - "pack/kai_lhs_pack_f16p2vlx2_f16_sme.c", - "pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c", ], hdrs = [ "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h", + ], + cpu_uarch = kai_cpu_sme(), +) + +kai_c_library( + name = "kai_files_sme", + srcs = [ "pack/kai_lhs_pack_f16p2vlx2_f16_sme.c", + "pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c", + ], + hdrs = [ + "pack/kai_lhs_pack_f16p2vlx2_f16_sme.h", "pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h", ], cpu_uarch = kai_cpu_sme(), @@ -373,11 +382,12 @@ kai_c_library( ":clamp_f32_qsi8d32p_qsi4c32p_dotprod", ":clamp_f32_qsi8d32p_qsi4c32p_i8mm", ":clamp_f32_qsi8d32p_qsi4c32p_interface", + ":kai_files_sme", + ":kai_files_sme2", ":lhs_pack_f32p2vlx1_f32_sme", ":lhs_quant_pack_bf16p_f32_neon", ":lhs_quant_pack_qai8dxp_f32", ":lhs_quant_pack_qsi8d32p_f32", - ":matmul_f16_f16p_f16p", 
":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", ":rhs_pack_kxn_f32pb_f32_f32_16vlx1_sme", diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c index 440f34c2..3c99aa82 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c @@ -17,6 +17,7 @@ static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 2; static const size_t kai_sr = 1; + size_t kai_get_m_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { return kai_mr * kai_get_sme_vector_length_u16(); } -- GitLab From aff4caa13e02d3e12a99375916f661d56d1b1787 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Wed, 30 Oct 2024 17:54:01 +0000 Subject: [PATCH 7/8] Review comments update Signed-off-by: Felix Thomasmathibalan --- ...tmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c index 3c99aa82..30f3e9a4 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
+#include "kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h" + #include #include @@ -17,7 +19,6 @@ static const size_t kai_mr = 2; static const size_t kai_nr = 2; static const size_t kai_kr = 2; static const size_t kai_sr = 1; - size_t kai_get_m_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { return kai_mr * kai_get_sme_vector_length_u16(); } @@ -74,8 +75,8 @@ void kai_run_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa( const void* B; void* C; - uint64_t ldcb; - uint64_t M, N, K; + long ldcb; + long M, N, K; __fp16 min; __fp16 max; -- GitLab From 2301be6337f795200b9d57842eae11588b689c78 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Wed, 30 Oct 2024 18:01:44 +0000 Subject: [PATCH 8/8] Fix Pipeline Signed-off-by: Felix Thomasmathibalan --- ..._matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c index 30f3e9a4..42f47f25 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c @@ -75,8 +75,8 @@ void kai_run_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa( const void* B; void* C; - long ldcb; - long M, N, K; + uint64_t ldcb; + uint64_t M, N, K; __fp16 min; __fp16 max; -- GitLab