diff --git a/CMakeLists.txt b/CMakeLists.txt index 649e158504fa3b9ba96bc4a6be6f7cdee0944aac..6325386258e4decd157eed5ec4a23f167973527c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -124,11 +124,14 @@ set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_16vlx1_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c ) set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pb_1x16vl_sme2_mla.c + kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c ) add_library(kleidiai) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 66a3c3868a97e70828efb60ef3f84d724c0e8ef2..9d1ccd129cc8d33c765a3e63cb8112f06067c2dc 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -63,6 +63,30 @@ kai_c_library( cpu_uarch = kai_cpu_sme(), ) +kai_c_library( + name = "kai_files_sme2", + srcs = [ + "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c", + ], + hdrs = [ + "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h", + ], + cpu_uarch = kai_cpu_sme(), +) + +kai_c_library( + name = "kai_files_sme", + srcs = [ + "pack/kai_lhs_pack_f16p2vlx2_f16_sme.c", + "pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c", + ], + hdrs = [ + "pack/kai_lhs_pack_f16p2vlx2_f16_sme.h", + "pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h", + ], + cpu_uarch = kai_cpu_sme(), +) + kai_c_library( name = "clamp_f32_f32_f32pb_1x16vl_sme2_mla", srcs = ["matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32pb_1x16vl_sme2_mla.c"], @@ -358,6 +382,8 @@ kai_c_library( ":clamp_f32_qsi8d32p_qsi4c32p_dotprod", ":clamp_f32_qsi8d32p_qsi4c32p_i8mm", ":clamp_f32_qsi8d32p_qsi4c32p_interface", + ":kai_files_sme", + ":kai_files_sme2", ":lhs_pack_f32p2vlx1_f32_sme", ":lhs_quant_pack_bf16p_f32_neon", ":lhs_quant_pack_qai8dxp_f32", diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c new file mode 100644 index 0000000000000000000000000000000000000000..42f47f25e4016dfeeb8baac2d022bd63538994ea --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.c @@ -0,0 +1,256 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. 
+ +#include "kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h" + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_nr = 2; +static const size_t kai_kr = 2; +static const size_t kai_sr = 1; +size_t kai_get_m_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u16(); +} + +size_t kai_get_n_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u16(); +} + +size_t kai_get_mr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u16(); +} + +size_t kai_get_nr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u16(); +} + +size_t kai_get_kr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa() == 0); + return m_idx * k * sizeof(__fp16); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa() == 0); + return n_idx * (k * sizeof(__fp16) + sizeof(__fp16)); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa() == 0); + + return m_idx * dst_stride + n_idx * sizeof(__fp16); +} + +size_t kai_get_dst_size_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(size_t m, size_t n) { + return m * n * sizeof(__fp16); +} + +void kai_run_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, __fp16 clamp_min, __fp16 clamp_max) { + KAI_ASSUME(dst_stride_col == sizeof(__fp16)); + + typedef struct { + const void* A; + const void* B; + + void* C; + uint64_t ldcb; + uint64_t M, N, K; + __fp16 min; + __fp16 max; + + void* accumulator_buffer; + uint64_t flags; + } KernelArgs; + + KernelArgs args; + + args.A = lhs_packed; + args.B = rhs_packed; + + args.C = dst; + args.ldcb = dst_stride_row; + args.M = m; + args.N = n; + args.K = k; + args.min = clamp_min; + args.max = clamp_max; + + args.accumulator_buffer = NULL; + args.flags = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ldr w13, [%x[args], %[offsetof_M]]\n" + "mov x11, #0x0\n" + "mov x10, #0x0\n" + "ptrue p1.b\n" + ".inst 0x25207810 // ptrue pn8.b\n" + "ldr w9, [%x[args], %[offsetof_N]]\n" + "ldr x28, [%x[args], %[offsetof_A]]\n" + "1:" // M loop + "ldr x27, [%x[args], %[offsetof_B]]\n" + "2:" // N loop + "fmov z24.h, #0.0\n" + "ld1h { z5.h }, p1/Z, [x27]\n" + "fmov z27.h, #1.0\n" + "mov x26, x28\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "inch x27, ALL, MUL #2\n" + "zip1 z30.h, z5.h, z24.h\n" + "zip2 z20.h, z5.h, z24.h\n" + ".inst 0x81be2760 // fmopa za0.s, p1/M, p1/M, z27.h, z30.h\n" 
+ ".inst 0x81b42761 // fmopa za1.s, p1/M, p1/M, z27.h, z20.h\n" + ".inst 0x81be2762 // fmopa za2.s, p1/M, p1/M, z27.h, z30.h\n" + ".inst 0x81b42763 // fmopa za3.s, p1/M, p1/M, z27.h, z20.h\n" + "3:" // Prepare accumulators: End + "ldr x20, [%x[args], %[offsetof_K]]\n" + "add x20, x20, #0x1\n" + "lsr x20, x20, #0x1\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 6f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" + ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" + ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" + ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" + ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" + "addvl x26, x26, #8\n" + ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + "ble 5f\n" + "4:" // K loop + ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" + "subs x21, x21, #0x1\n" + ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" + ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" + ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" + ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" + ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" + ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" + ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" + ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" + ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" + ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" + ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" + ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" + ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" + ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" + ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" + ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" + ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" + ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" + ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" + ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" + "addvl x26, x26, #8\n" + ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + "bgt 4b\n" + "5:" // K loop tail + ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" + ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" + ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" + ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" + ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" + ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" + ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" + ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" + ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" + ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" + ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" + ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, 
z14.h\n" + ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" + ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" + ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" + ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" + "6:" // K oddments + "cbz x20, 8f\n" + "7:" // K oddments: Loop + ".inst 0xa1402345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26]\n" + "subs x20, x20, #0x1\n" + "addvl x26, x26, #2\n" + ".inst 0xa040236e // ld1h { z14.h-z15.h }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0x81ae24a0 // fmopa za0.s, p1/M, p1/M, z5.h, z14.h\n" + ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" + ".inst 0x81ae25a2 // fmopa za2.s, p1/M, p1/M, z13.h, z14.h\n" + ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" + "bgt 7b\n" + "8:" // K oddments: End + "9:" // Store to output array + "ldr x25, [%x[args], %[offsetof_C]]\n" + "sub x24, x13, x11\n" + "cntw x23, ALL, MUL #2\n" + "ld1rh { z17.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "ldr x22, [%x[args], %[offsetof_ldcb]]\n" + "whilelt p0.h, x10, x9\n" + "cmp x24, x23\n" + "ld1rh { z16.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "mov x12, #0x0\n" + "mov x21, #0x0\n" + "add x25, x25, x10, LSL #1\n" // C += n + "mov x20, #0x2\n" + "madd x25, x11, x22, x25\n" // C += m * ldc + "csel x24, x24, x23, LT\n" + "10:" // Store to output array: Accumulator loop + ".inst 0xc006000e // mova { z14.b-z15.b }, za0h.b[x12, 0:1]\n" + "add x12, x12, #0x4\n" + "cmp x12, x23, LSL #1\n" + "add x21, x21, #0x1\n" + ".inst 0xc120e1cc // fcvt z12.h, { z14.s-z15.s }\n" + "csel x12, x12, x20, LT\n" + "cmp x21, x24\n" + ".inst 0x6470262c // fclamp z12.h, z17.h, z16.h\n" + "st1h { z12.h }, p0, [x25]\n" + "add x25, x25, x22\n" + "blt 10b\n" + "11:" // Store to output array: End + "12:" // End block + "incw x10, ALL, MUL #2\n" + "cmp x10, x9\n" + "blt 2b\n" + "incw x11, ALL, MUL #2\n" + "mov x10, #0x0\n" + "cmp x11, x13\n" + "mov x28, x26\n" + "blt 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), + [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), + [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), + [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), [offsetof_M] "I"(offsetof(KernelArgs, M)), + [offsetof_N] "I"(offsetof(KernelArgs, N)), [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", + "p15", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", + "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", + "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); +} + +#endif // Architectural features check. 
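// --- Reviewer note (not part of the patch): usage sketch for the getters above. ---
// A minimal, hedged sketch of how the m/n step and offset helpers of this micro-kernel are
// intended to drive blocked (for example, multi-threaded) execution. The names M, N, K,
// lhs_packed, rhs_packed, dst and the clamp bounds are caller-side assumptions; only the
// kai_* calls come from this file.
#include <stddef.h>
#include <stdint.h>

#include "kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h"

static void run_blocked_f16_matmul(
    size_t M, size_t N, size_t K, const void* lhs_packed, const void* rhs_packed, void* dst, __fp16 min, __fp16 max) {
    const size_t m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa();
    const size_t n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa();
    const size_t dst_stride_row = N * sizeof(__fp16);

    // Each (i_m, i_n) block is independent, so this loop nest can be split across threads.
    for (size_t i_m = 0; i_m < M; i_m += m_step) {
        for (size_t i_n = 0; i_n < N; i_n += n_step) {
            const size_t m_blk = (M - i_m) < m_step ? (M - i_m) : m_step;
            const size_t n_blk = (N - i_n) < n_step ? (N - i_n) : n_step;

            // Block starting indices must be multiples of m_step / n_step, as required by the getters.
            const void* lhs_blk = (const uint8_t*)lhs_packed +
                kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(i_m, K);
            const void* rhs_blk = (const uint8_t*)rhs_packed +
                kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(i_n, K);
            void* dst_blk = (uint8_t*)dst +
                kai_get_dst_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(i_m, i_n, dst_stride_row);

            kai_run_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(
                m_blk, n_blk, K, lhs_blk, rhs_blk, dst_blk, dst_stride_row, sizeof(__fp16), min, max);
        }
    }
}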
diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h new file mode 100644 index 0000000000000000000000000000000000000000..227f8d4d9a24c80c254070737c84e00fc4704fd2 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h @@ -0,0 +1,121 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_pack_f16p2vlx2_f16_sme to pack the LHS matrix. +/// -# kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme to pack the RHS matrix. + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void); + +/// Gets mr value. +/// +/// This is the packing parameter which must be used to pack the LHS matrix. +/// +/// @return The mr value. +size_t kai_get_mr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void); + +/// Gets kr value. +/// +/// This is the packing parameter which must be used to pack the LHS and RHS matrix. +/// +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the LHS and RHS matrix. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(void); + +/// Gets the offset in bytes to the data element in the packed LHS matrix buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(size_t m_idx, size_t k); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. +/// @param[in] k Number of rows in the unpacked RHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] n_idx Column index. +/// @param[in] dst_stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. 
+/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffer (packed LHS, packed RHS and output) needs to be incremented by the offset +/// calculated using the following functions: +/// +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Common dimension of the LHS and RHS operands. +/// @param[in] lhs_packed Packed LHS matrix buffer. +/// @param[in] rhs_packed Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. +/// @param[in] dst_stride_col Column stride in bytes of the output matrix. +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. +void kai_run_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, __fp16 clamp_min, __fp16 clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..9f78ccd3e8478dbc87491f43e3c574b9f173d0bf --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.c @@ -0,0 +1,351 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. 
+ +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_kr = 2; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_lhs_pack_f16p2vlx2_f16_sme(size_t mr) { + KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u16()); + KAI_UNUSED(mr); + + return kai_mr * kai_get_sme_vector_length_u16(); +} + +size_t kai_get_lhs_offset_lhs_pack_f16p2vlx2_f16_sme(size_t m_idx, size_t lhs_stride) { + KAI_ASSUME(m_idx % (kai_mr * kai_get_sme_vector_length_u16()) == 0); + + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_pack_f16p2vlx2_f16_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { + const size_t scaled_mr = kai_mr * kai_get_sme_vector_length_u16(); + KAI_ASSUME(m_idx % scaled_mr == 0); + KAI_ASSUME(mr == scaled_mr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + + KAI_UNUSED(mr); + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + return m_idx * k * sizeof(__fp16); +} + +size_t kai_get_lhs_packed_size_lhs_pack_f16p2vlx2_f16_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u16()); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + + KAI_UNUSED(mr); + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + return kai_roundup(m, kai_mr * kai_get_sme_vector_length_u16()) * kai_roundup(k, kai_kr) * sizeof(__fp16); +} + +void kai_run_lhs_pack_f16p2vlx2_f16_sme( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed) { + KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u16()); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(lhs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + KAI_ASSUME(m_idx_start == 0); + + const size_t block_height = kai_mr * kai_get_sme_vector_length_u16() / kai_kr; + const size_t width = k; + const size_t row_offset = 0; + + const void* in[block_height]; + + for (size_t block_y = 0; block_y < m; block_y += block_height) { + const size_t height = KAI_MIN(m - block_y, block_height); + void* out = lhs_packed + block_y * kai_roundup(k, kai_kr) * sizeof(__fp16); + + for (size_t y = 0; y < height; y++) { + in[y] = lhs + (block_y + y) * lhs_stride; + } + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x22, %x[width]\n" + "mov x21, %x[width]\n" + "cnth x20\n" + "inch x22\n" + "sub x7, x20, #0x1\n" + "sub x22, x22, #0x1\n" + "ands x7, x21, x7\n" + "cntw x8\n" + "udiv x22, x22, x20\n" // n_passes = ceildiv(width, VL) + "csel x7, x7, x20, NE\n" + "sub x13, x22, #0x1\n" + "add x7, x7, #0x1\n" + "sub x17, x8, #0x2\n" + "lsl x21, %x[height], #0x1\n" // height * 2 + "lsl x20, x8, #0x1\n" + "mov x16, #0x0\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + "cntw x9, ALL, MUL #2\n" + "cntw x28, ALL, MUL #3\n" + "ldr x27, [x11, #0x0]\n" + "lsr x13, x13, #0x1\n" // n_loops = (n_passes - 1) / 2 + "and x26, x22, #0x1\n" // odd_tail = bool(n_passes & 0x1) + "ldr x25, [x10, #0x0]\n" + "lsr x7, x7, #0x1\n" + "ptrue p12.s\n" + "ldr x24, [x11, #0x8]\n" + "whilelt p11.h, XZR, x21\n" + "whilelt p10.h, x20, x21\n" + "ldr x21, [x10, #0x8]\n" + "mov x23, %x[row_offset]\n" + "mov x22, %x[out]\n" + "whilelt p9.h, x16, %x[width]\n" + "whilelt p8.h, x16, %x[width]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "mov x12, #0x0\n" + "cbz x17, 2f\n" + "1:" // K loop: Charge: Loop + ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" + ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" + ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, 
#2]\n" + ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" + ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" + "add x12, x12, #0x4\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "cmp x12, x17, LSL #1\n" + "blt 1b\n" + "2:" // K loop: Charge: End + ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" + ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" + ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" + ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" + "ldr x27, [x11, #0x0]\n" + "inch x16\n" + ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "inch x23\n" + "cbz x13, 8f\n" + "mov x20, x13\n" + "3:" // K loop: Main loop + "whilelt p8.h, x16, %x[width]\n" + "mov x15, #0x0\n" + "mov x14, #0x0\n" + "cbz x17, 5f\n" + "4:" // K loop: Main loop: First: Loop + ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" + ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" + ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" + ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" + ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + "add x10, x10, #0x10\n" + "add x15, x15, #0x4\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x14, x14, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x14, x17\n" + "blt 4b\n" + "5:" // K loop: Main loop: First: Tail + ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" + ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" + ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" + ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" + "ldr x27, [x11, #0x0]\n" + "mov x13, #0x0\n" + ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 
0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" + "ldr x25, [x10, #0x0]\n" + "mov x12, #0x0\n" + ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + "whilelt p9.h, x16, %x[width]\n" + "inch x16\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + "add x10, x10, #0x10\n" + "inch x23\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "whilelt p8.h, x16, %x[width]\n" + "cbz x17, 7f\n" + "6:" // K loop: Main loop: Second: Loop + ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" + ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" + ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" + ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" + ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "add x10, x10, #0x10\n" + "add x13, x13, #0x4\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x12, x12, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x12, x17\n" + "blt 6b\n" + "7:" // K loop: Main loop: Second: Tail + ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" + ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" + ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" + ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "whilelt p9.h, x16, %x[width]\n" + "subs x20, x20, #0x1\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, 
LSL #2]\n" + "add x10, x10, #0x10\n" + "inch x16\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "inch x23\n" + "bgt 3b\n" + "8:" // K loop: Tails + "cbnz x26, 11f\n" + "mov x11, %x[in]\n" + "whilelt p8.h, x16, %x[width]\n" + "mov x13, #0x0\n" + "mov x12, #0x0\n" + "9:" // K loop: Tails: Even: First + ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25396161 // psel p1.h, p8.h/Z, p11.h[w13, #1]\n" + ".inst 0x25396140 // psel p0.h, p8.h/Z, p10.h[w13, #1]\n" + ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "ldr x21, [x11, #0x0]\n" + "cmp x12, x8\n" + "ldr x20, [x11, x8, LSL #0x3]\n" + "add x11, x11, #0x8\n" + ".inst 0xe05726a1 // ld1h { za0h.h[x13, #1] }, p1/Z, [x21, x23, LSL #1]\n" + ".inst 0xe0572289 // ld1h { za1h.h[x13, #1] }, p0/Z, [x20, x23, LSL #1]\n" + "add x13, x13, #0x2\n" + "blt 9b\n" + "whilelt p9.h, x16, %x[width]\n" + "whilelt p8.h, x16, %x[width]\n" + "mov x20, #0x0\n" + "mov x12, #0x0\n" + "10:" // K loop: Tails: Even: Second + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + "add x20, x20, #0x2\n" + ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 10b\n" + "whilelt p8.h, x16, %x[width]\n" + "b 13f\n" + "11:" // K loop: Tails: Odd + "mov x12, #0x0\n" + "12:" // K loop: Tails: Odd: Loop + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 12b\n" + "13:" // K loop: End + "mov %x[out], x22\n" + ".inst 0xd503467f // SMSTOP\n" + : [out] "+&r"(out) + : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", + "p14", "p15", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", + "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", + "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + } +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.h b/kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..5fd52cacde7e4df275b396246f8a78bbb3ec171c --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f16p2vlx2_f16_sme.h @@ -0,0 +1,77 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @param[in] mr Number of rows to be interleaved. +/// +/// @return The m step value. 
+size_t kai_get_m_step_lhs_pack_f16p2vlx2_f16_sme(size_t mr); + +/// Gets the offset in bytes to the data element in the LHS buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] lhs_stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_offset_lhs_pack_f16p2vlx2_f16_sme(size_t m_idx, size_t lhs_stride); + +/// Gets the offset in bytes to the data element in the packed LHS buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] mr Number of rows to be interleaved. +/// @param[in] kr Unused. Must be 2. +/// @param[in] sr Unused. Must be 1. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_lhs_pack_f16p2vlx2_f16_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + +/// Gets the size in bytes of the packed LHS buffer. +/// +/// @param[in] m Number of rows in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] mr Number of rows to be interleaved. +/// @param[in] kr Unused. Must be 2. +/// @param[in] sr Unused. Must be 1. +/// +/// @return The size in bytes of the packed LHS buffer. +size_t kai_get_lhs_packed_size_lhs_pack_f16p2vlx2_f16_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr); + +/// Runs the LHS packing function for matrix multiplication. +/// +/// The pointer of each buffer (LHS and packed LHS) needs to be incremented by the offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_lhs_pack_f16p2vlx2_f16_sme. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_lhs_pack_f16p2vlx2_f16_sme. +/// +/// @param[in] m Number of rows of the unpacked LHS matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] mr Block size in M dimension. It must be 2 * kai_get_sme_vector_length_u16(). +/// @param[in] kr Block size in K dimension. It must be 2. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] m_idx_start Unused. Must be 0. +/// @param[in] lhs LHS matrix data buffer. +/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. +/// @param[out] lhs_packed Packed LHS matrix. +void kai_run_lhs_pack_f16p2vlx2_f16_sme( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..b718cde6b7c34bd49b8f681198f5770a8846e003 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.c @@ -0,0 +1,190 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. 
+ +#include "kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h" + +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include "kai/kai_common.h" + +static const size_t kai_nr = 2; +static const size_t kai_kr = 2; +static const size_t kai_num_bytes_input = 2; +static const size_t kai_num_bytes_output = 2; +static const size_t kai_num_bytes_bias = 2; + +size_t kai_get_n_step_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(void) { + return kai_nr * kai_get_sme_vector_length_u16(); +} + +size_t kai_get_rhs_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n_idx) { + KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u16()) == 0); + + return n_idx * kai_num_bytes_input; +} + +size_t kai_get_bias_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n_idx) { + return n_idx * kai_num_bytes_bias; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t k) { + return kai_nr * kai_get_sme_vector_length_u16() / kai_kr * + (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u16() / kai_kr) == 0); + + return n_idx * (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output); +} + +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n, size_t k) { + return kai_get_rhs_packed_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme( + kai_roundup(n, kai_nr * kai_get_sme_vector_length_u16() / kai_kr), k); +} + +void kai_run_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME(nr == kai_nr * kai_get_sme_vector_length_u16()); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == 1); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(scale == NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params == NULL); + + size_t height = k; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_stride; + // Padding with potentially non-zero values + const uint16_t* pad_row = rhs; + + size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(height); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x21, %x[out]\n" + "mov x20, %x[width]\n" + "ptrue p1.b\n" + "1:" // Bias: Full loop + "whilelt p0.h, XZR, x20\n" + "dech x20\n" + "cmp x20, #0x0\n" + "ld1h { z16.h }, p0/Z, [%x[bias]]\n" + "incb %x[bias]\n" + "st1h { z16.h }, p1, [x21]\n" + "add x21, x21, %x[out_stride]\n" + "bgt 1b\n" + "cmp %x[height], #0x8\n" + "incb %x[out]\n" + "blt 5f\n" + "2:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[out]\n" + "add x27, x9, %x[in_stride]\n" + "sub %x[height], %x[height], #0x8\n" + "add x26, x27, %x[in_stride]\n" + "mov x25, %x[width]\n" + "add x24, x26, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "3:" // Main row loop: Column loop + "whilelt p0.h, XZR, x25\n" + "decw x25, ALL, MUL #2\n" + "ld1h { z20.h }, p0/Z, [x9]\n" + "cmp x25, #0x0\n" + "addvl x9, x9, #1\n" + "ld1h { z17.h }, p0/Z, [x27]\n" + "addvl x27, x27, #1\n" + "ld1h { z19.h }, p0/Z, [x26]\n" + "addvl x26, x26, #1\n" + "ld1h { z16.h }, p0/Z, [x24]\n" + "addvl 
x24, x24, #1\n" + "ld1h { z18.h }, p0/Z, [x23]\n" + "addvl x23, x23, #1\n" + "zip1 z24.h, z20.h, z17.h\n" + "zip2 z23.h, z20.h, z17.h\n" + "ld1h { z17.h }, p0/Z, [x22]\n" + "addvl x22, x22, #1\n" + "ld1h { z22.h }, p0/Z, [x21]\n" + "addvl x21, x21, #1\n" + "zip1 z21.h, z19.h, z16.h\n" + "zip2 z20.h, z19.h, z16.h\n" + "ld1h { z16.h }, p0/Z, [x20]\n" + "addvl x20, x20, #1\n" + "zip1 z19.h, z18.h, z17.h\n" + "zip2 z18.h, z18.h, z17.h\n" + "st1h { z24.h }, p1, [x28]\n" + "st1h { z23.h }, p1, [x28, #1, MUL VL]\n" + "zip1 z17.h, z22.h, z16.h\n" + "zip2 z16.h, z22.h, z16.h\n" + "st1h { z21.h }, p1, [x28, #2, MUL VL]\n" + "st1h { z20.h }, p1, [x28, #3, MUL VL]\n" + "st1h { z19.h }, p1, [x28, #4, MUL VL]\n" + "st1h { z18.h }, p1, [x28, #5, MUL VL]\n" + "st1h { z17.h }, p1, [x28, #6, MUL VL]\n" + "st1h { z16.h }, p1, [x28, #7, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 3b\n" + "4:" // Main row loop: Column loop skip + "cmp %x[height], #0x8\n" + "addvl %x[out], %x[out], #8\n" + "bge 2b\n" + "cbz %x[height], 9f\n" + "5:" // Main loop skip + "6:" // Tail row loop: Head + "mov x9, %x[in]\n" + "cmp %x[height], #0x1\n" + "add x27, x9, %x[in_stride]\n" + "mov x28, %x[out]\n" + "add %x[in], x27, %x[in_stride]\n" + "csel x27, x27, %x[pad_row], GT\n" + "sub %x[height], %x[height], #0x2\n" + "mov x20, %x[width]\n" + "7:" // Tail row loop: Column loop + "whilelt p0.h, XZR, x20\n" + "decw x20, ALL, MUL #2\n" + "ld1h { z18.h }, p0/Z, [x9]\n" + "cmp x20, #0x0\n" + "addvl x9, x9, #1\n" + "ld1h { z16.h }, p0/Z, [x27]\n" + "addvl x27, x27, #1\n" + "zip1 z17.h, z18.h, z16.h\n" + "zip2 z16.h, z18.h, z16.h\n" + "st1h { z17.h }, p1, [x28]\n" + "st1h { z16.h }, p1, [x28, #1, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 7b\n" + "8:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 6b\n" + "9:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) + : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", + "p15", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", + "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", + "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..a24ef6e40bb3d8c27e4eec3fa301c003feb1eadf --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h @@ -0,0 +1,80 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. 
+size_t kai_get_rhs_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. +/// @param[in] k Number of rows in the unpacked RHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n_idx, size_t k); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of columns in the unpacked RHS matrix. +/// @param[in] k Number of rows in the unpacked RHS matrix. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(size_t n, size_t k); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffer (RHS, bias and packed RHS) needs to be incremented by the offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme. +/// * Bias: @ref kai_get_bias_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme. +/// +/// @param[in] num_groups Number of groups. It must be 1. +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u16(). +/// @param[in] kr Block size in K dimension. It must be 2. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. It must be NULL. +/// @param[out] rhs_packed Packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. +/// @param[in] params Extra packing parameters. It must be NULL. +void kai_run_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus
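// --- Reviewer note (not part of the patch): end-to-end usage sketch. ---
// A hedged sketch of how the three micro-kernels added in this patch are intended to be
// chained: pack the f16 LHS, pack the f16 RHS together with the per-column bias, then run
// the SME2 MOPA matmul. The unpacked lhs/rhs/bias/dst buffers and the M/N/K sizes are
// caller-side assumptions, malloc is used only to keep the sketch self-contained, and the
// include paths depend on the build setup.
#include <stddef.h>
#include <stdlib.h>

#include "kai_lhs_pack_f16p2vlx2_f16_sme.h"
#include "kai_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa.h"
#include "kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme.h"

static void matmul_f16_sme2_example(
    size_t M, size_t N, size_t K, const __fp16* lhs, const __fp16* rhs, const __fp16* bias, __fp16* dst) {
    // The packing parameters are dictated by the matmul micro-kernel.
    const size_t mr = kai_get_mr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa();
    const size_t nr = kai_get_nr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa();
    const size_t kr = kai_get_kr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa();
    const size_t sr = kai_get_sr_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa();

    void* lhs_packed = malloc(kai_get_lhs_packed_size_lhs_pack_f16p2vlx2_f16_sme(M, K, mr, kr, sr));
    void* rhs_packed = malloc(kai_get_rhs_packed_size_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(N, K));

    // LHS is M x K row-major.
    kai_run_lhs_pack_f16p2vlx2_f16_sme(M, K, mr, kr, sr, 0, lhs, K * sizeof(__fp16), lhs_packed);

    // RHS is K x N row-major; the bias (one f16 per output column) is packed alongside it.
    kai_run_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme(
        1, N, K, nr, kr, sr, N * sizeof(__fp16), rhs, bias, NULL, rhs_packed, 0, NULL);

    // Single call over the full output; see the blocked sketch after the matmul kernel for tiled execution.
    // The clamp bounds are set to the f16 finite range so the clamp is effectively a no-op here.
    kai_run_matmul_clamp_f16_f16p2vlx1_f16p2vlx1b_2vlx2vl_sme2_mopa(
        M, N, K, lhs_packed, rhs_packed, dst, N * sizeof(__fp16), sizeof(__fp16),
        (__fp16)-65504.0F, (__fp16)65504.0F);

    free(rhs_packed);
    free(lhs_packed);
}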