diff --git a/CHANGELOG.md b/CHANGELOG.md index ebbd0acc1ccee9f6de3f28bda07756e31e2589e3..9dcaa5f13c613e94c1e617046358faa0c38aaa87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,12 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - New Advanced SIMD micro-kernels: - Matrix multiplication (MxN) Micro-kernels of QAI8DX LHS and QSI8CX RHS with F16 output, optimized for FEAT_I8MM and FEAT_DotProd. - Matrix multiplication (1xN) Micro-kernels of QAI8DX LHS and QSI8CX RHS with F16 output, optimized for FEAT_DotProd. +- New SME micro-kernels: + - Indirect matrix multiplication (MxN) of FP16 input and output. + - Packing kernels for LHS and RHS +- New SME2 micro-kernels: + - Indirect matrix multiplication (MxN) of FP16 input and output. + - Matrix multiplication of packed indirect LHS and packed RHS ## v1.7.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index 500e3e531a7b713a2f4cd0e95421177bc387406b..f2221ecf17a783e673ca65d8758098a639c74789 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -222,11 +222,13 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME + kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c + kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -236,6 +238,7 @@ set(KLEIDIAI_FILES_SME ) set(KLEIDIAI_FILES_SME2 + kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c 
kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c @@ -334,6 +337,7 @@ if(KLEIDIAI_BUILD_TESTS) test/common/int4.cpp test/common/matmul_test_common.cpp test/common/matrix_portion.cpp + test/common/memory.cpp test/common/printer.cpp test/common/rect.cpp test/common/round.cpp @@ -383,6 +387,7 @@ if(KLEIDIAI_BUILD_TESTS) add_executable(kleidiai_test test/tests/bfloat16_test.cpp test/tests/float16_test.cpp + test/tests/imatmul_test.cpp test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_f32_f32p_test.cpp @@ -399,6 +404,7 @@ if(KLEIDIAI_BUILD_TESTS) ) set_source_files_properties( + test/tests/imatmul_test.cpp test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index f3765e554b5ef268825901e9d93a2e72414dca20..7408dfff16ef6f8fa1898942e7d61a011d333b11 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -138,12 +138,14 @@ I8MM_KERNELS_ASM = [ # buildifier: keep sorted SME_KERNELS = [ + "pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme", "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_lhs_pack_f32p2vlx1_f32_sme", "pack/kai_lhs_pack_x16p2vlx2_x16_sme", "pack/kai_lhs_pack_x8p2vlx4_x8_sme", "pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", + "pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", @@ -155,6 +157,7 @@ 
SME_KERNELS = [ # buildifier: keep sorted SME2_KERNELS = [ + "imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c new file mode 100644 index 0000000000000000000000000000000000000000..7e77125a4a58f190bc36cb868b52e6e348b539f5 --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c @@ -0,0 +1,252 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. 
+ +#include "kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_nr = 2; +static const size_t kai_kr = 2; + +size_t kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); + const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + return m_idx * indirect_k * sizeof(uint16_t); +} + +static size_t kai_get_rhs_packed_stride_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t k_chunk_count, size_t k_chunk_length) { + const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + return kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() * + (sizeof(uint16_t) + indirect_k * sizeof(uint16_t)); +} + +size_t kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); + const size_t block_idx = n_idx / kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(); + return block_idx * + kai_get_rhs_packed_stride_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + k_chunk_count, k_chunk_length); +} + +size_t kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_row_stride) { + KAI_ASSUME(m_idx % 
kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); + + return m_idx * dst_row_stride + n_idx * sizeof(uint16_t); +} + +size_t kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(size_t m, size_t n) { + return m * n * sizeof(uint16_t); +} + +void kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, float clamp_min, float clamp_max) { + typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + float16_t min; + float16_t max; + void* accumulator_buffer; + uint64_t flags; + } KernelArgs; + + KernelArgs args; + + args.A = lhs_packed; + args.B = rhs_packed; + + size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + + args.C = dst; + args.ldcb = dst_row_stride; + args.M = m; + args.N = n; + args.K = indirect_k; + args.min = (float16_t)clamp_min; + args.max = (float16_t)clamp_max; + + args.accumulator_buffer = NULL; + args.flags = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ldr w13, [%x[args], %[offsetof_M]]\n" + "mov x11, #0x0\n" + "mov x10, #0x0\n" + "ptrue p1.b\n" + ".inst 0x25207810 // ptrue pn8.b\n" + "ldr w9, [%x[args], %[offsetof_N]]\n" + "ldr x28, [%x[args], %[offsetof_A]]\n" + "1:" // M loop + "ldr x27, [%x[args], %[offsetof_B]]\n" + "2:" // N loop + "fmov z24.h, #0.0\n" + "ld1h { z5.h }, p1/Z, [x27]\n" + "fmov z27.h, #1.0\n" + "mov x26, x28\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "inch x27, ALL, MUL #2\n" + "zip1 z30.h, z5.h, z24.h\n" + "zip2 z20.h, z5.h, z24.h\n" + ".inst 0x81be2760 // fmopa za0.s, p1/M, p1/M, z27.h, z30.h\n" + ".inst 0x81b42761 // fmopa za1.s, p1/M, p1/M, z27.h, z20.h\n" + ".inst 
0x81be2762 // fmopa za2.s, p1/M, p1/M, z27.h, z30.h\n" + ".inst 0x81b42763 // fmopa za3.s, p1/M, p1/M, z27.h, z20.h\n" + "ldr x20, [%x[args], %[offsetof_K]]\n" + "add x20, x20, #0x1\n" + "lsr x20, x20, #0x1\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 6f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" + ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" + ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" + ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" + ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" + "addvl x26, x26, #8\n" + ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + "ble 5f\n" + "4:" // K loop + ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" + "subs x21, x21, #0x1\n" + ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" + ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" + ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" + ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" + ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" + ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" + ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" + ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" + ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" + ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" + ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" + ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" + ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" + ".inst 0x81ae27a3 
// fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" + ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" + ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" + ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" + ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" + ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" + ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" + "addvl x26, x26, #8\n" + ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + "bgt 4b\n" + "5:" // K loop tail + ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" + ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" + ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" + ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" + ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" + ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" + ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" + ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" + ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" + ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" + ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" + ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" + ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" + ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" + ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" + ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" + "6:" // K oddments + "cbz x20, 8f\n" + "7:" // K oddments: Loop + ".inst 0xa1402345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26]\n" + "subs x20, x20, #0x1\n" + "addvl x26, x26, #2\n" + ".inst 0xa040236e // ld1h { z14.h-z15.h }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0x81ae24a0 // 
fmopa za0.s, p1/M, p1/M, z5.h, z14.h\n" + ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" + ".inst 0x81ae25a2 // fmopa za2.s, p1/M, p1/M, z13.h, z14.h\n" + ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" + "bgt 7b\n" + "8:" // K oddments: End + "ldr x25, [%x[args], %[offsetof_C]]\n" + "sub x24, x13, x11\n" + "cntw x23, ALL, MUL #2\n" + "ld1rh { z17.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "ldr x22, [%x[args], %[offsetof_ldcb]]\n" + "whilelt p0.h, x10, x9\n" + "cmp x24, x23\n" + "ld1rh { z16.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "mov x12, #0x0\n" + "mov x21, #0x0\n" + "add x25, x25, x10, LSL #1\n" // C += n + "mov x20, #0x2\n" + "madd x25, x11, x22, x25\n" // C += m * ldc + "csel x24, x24, x23, LT\n" + "10:" // Store to output array: Accumulator loop + ".inst 0xc006000e // mova { z14.b-z15.b }, za0h.b[x12, 0:1]\n" + "add x12, x12, #0x4\n" + "cmp x12, x23, LSL #1\n" + "add x21, x21, #0x1\n" + ".inst 0xc120e1cc // fcvt z12.h, { z14.s-z15.s }\n" + "csel x12, x12, x20, LT\n" + "cmp x21, x24\n" + ".inst 0x6470262c // fclamp z12.h, z17.h, z16.h\n" + "st1h { z12.h }, p0, [x25]\n" + "add x25, x25, x22\n" + "blt 10b\n" + "incw x10, ALL, MUL #2\n" + "cmp x10, x9\n" + "blt 2b\n" + "incw x11, ALL, MUL #2\n" + "mov x10, #0x0\n" + "cmp x11, x13\n" + "mov x28, x26\n" + "blt 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), + [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), + [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), + [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), [offsetof_M] "I"(offsetof(KernelArgs, M)), + [offsetof_N] "I"(offsetof(KernelArgs, N)), [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "x10", "x11", "x12", "x13", "x20", "x21", 
"x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", + "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", + "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h new file mode 100644 index 0000000000000000000000000000000000000000..79c52a4205dd9b02c07bfc70a730ca96df8b77ee --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h @@ -0,0 +1,97 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme to pack the LHS matrix. +/// -# kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme to pack the RHS matrix. + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void); + +/// Gets the offset in bytes to the data element in the packed LHS matrix buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. 
+/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of `n_step`. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. Must be a multiple of `m_step`. +/// @param[in] n_idx Column index. Must be a multiple of `n_step`. +/// @param[in] dst_row_stride. Distance between start of two rows in the output buffer. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_row_stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * Packed LHS: @ref kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa. 
+/// * Packed RHS: @ref kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa. +/// * Output: @ref kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split +/// @param[in] lhs_packed Packed LHS matrix buffer. +/// @param[in] rhs_packed Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_row_stride Row stride in bytes of the output matrix. +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. +void kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..bbc2b318bd11a02d8b576a6af9aae3a27fb70b43 --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h @@ -0,0 +1,45 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: imatmul_clamp_f16_f16p_f16p + +/// Micro-kernel helper functions ("get" methods) +typedef size_t 
(*kai_imatmul_clamp_f16_f16p_f16p_get_m_step_func_t)(void); +typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_n_step_func_t)(void); +typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_lhs_packed_offset_func_t)( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); +typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_rhs_packed_offset_func_t)( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); +typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_imatmul_clamp_f16_f16p_f16p_run_imatmul_func_t)( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + +/// Micro-kernel interface +struct kai_imatmul_clamp_f16_f16p_f16p_ukernel { + kai_imatmul_clamp_f16_f16p_f16p_get_m_step_func_t get_m_step; + kai_imatmul_clamp_f16_f16p_f16p_get_n_step_func_t get_n_step; + kai_imatmul_clamp_f16_f16p_f16p_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_imatmul_clamp_f16_f16p_f16p_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_imatmul_clamp_f16_f16p_f16p_get_dst_offset_func_t get_dst_offset; + kai_imatmul_clamp_f16_f16p_f16p_get_dst_size_func_t get_dst_size; + kai_imatmul_clamp_f16_f16p_f16p_run_imatmul_func_t run_imatmul; +}; + +#ifdef __cplusplus +} +#endif diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..f996bd4889613749a1f5a1ca25193c62b1bdd48b --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c @@ -0,0 +1,341 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its 
affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" + +#include +#include + +#include "kai/kai_common.h" + +#define MR 2 +#define KR 2 +#define MAX_M_STEP (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR) + +static size_t kai_get_mr_lhs_imatmul_pack_x16p2vlx2_x16p_sme(void) { + return MR * kai_get_sme_vector_length_u16() / KR; +} + +size_t kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme(void) { + return kai_get_mr_lhs_imatmul_pack_x16p2vlx2_x16p_sme(); +} + +size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(m_idx % kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme() == 0); + + return m_idx * k_chunk_count * kai_roundup(k_chunk_length, KR) * sizeof(uint16_t); +} + +size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length) { + const size_t m_end = kai_roundup(m, kai_get_mr_lhs_imatmul_pack_x16p2vlx2_x16p_sme()); + return kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme(m_end, k_chunk_count, k_chunk_length); +} + +void kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, + const void* pad_ptr, void* lhs_packed) { + KAI_ASSUME(lhs_ptrs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + const size_t m_step = kai_get_mr_lhs_imatmul_pack_x16p2vlx2_x16p_sme(); + const size_t row_offset = 0; + const size_t width = k_chunk_length; + + KAI_ASSERT(m_step <= MAX_M_STEP); + const uint8_t* in[MAX_M_STEP]; + + uint8_t* out_base = lhs_packed; + for (size_t i_m = 0; i_m < m; i_m += 
m_step) { + for (size_t i_k_chunk = 0; i_k_chunk < k_chunk_count; i_k_chunk += 1) { + const size_t height = KAI_MIN(m - i_m, m_step); + void* out = out_base; + for (size_t y = 0; y < height; y += 1) { + KAI_ASSERT(i_k_chunk + (i_m + y) * k_chunk_count < m * k_chunk_count); + in[y] = *(lhs_ptrs + i_m * k_chunk_count + i_k_chunk * m_step + y); + if (in[y] != pad_ptr) { + in[y] += lhs_ptr_offset; + } + } + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x22, %x[width]\n" + "mov x21, %x[width]\n" + "cnth x20\n" + "inch x22\n" + "sub x7, x20, #0x1\n" + "sub x22, x22, #0x1\n" + "ands x7, x21, x7\n" + "cntw x8\n" + "udiv x22, x22, x20\n" // n_passes = ceildiv(width, VL) + "csel x7, x7, x20, NE\n" + "sub x13, x22, #0x1\n" + "add x7, x7, #0x1\n" + "sub x17, x8, #0x2\n" + "lsl x21, %x[height], #0x1\n" // height * 2 + "lsl x20, x8, #0x1\n" + "mov x16, #0x0\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + "cntw x9, ALL, MUL #2\n" + "cntw x28, ALL, MUL #3\n" + "ldr x27, [x11, #0x0]\n" + "lsr x13, x13, #0x1\n" // n_loops = (n_passes - 1) / 2 + "and x26, x22, #0x1\n" // odd_tail = bool(n_passes & 0x1) + "ldr x25, [x10, #0x0]\n" + "lsr x7, x7, #0x1\n" + "ptrue p12.s\n" + "ldr x24, [x11, #0x8]\n" + "whilelt p11.h, XZR, x21\n" + "whilelt p10.h, x20, x21\n" + "ldr x21, [x10, #0x8]\n" + "mov x23, %x[row_offset]\n" + "mov x22, %x[out]\n" + "whilelt p9.h, x16, %x[width]\n" + "whilelt p8.h, x16, %x[width]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "mov x12, #0x0\n" + "cbz x17, 2f\n" + "1:" // K loop: Charge: Loop + ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" + ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" + ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" + ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" + ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" + "ldr x25, [x10, #0x0]\n" + 
".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" + "add x12, x12, #0x4\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "cmp x12, x17, LSL #1\n" + "blt 1b\n" + "2:" // K loop: Charge: End + ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" + ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" + ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" + ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" + "ldr x27, [x11, #0x0]\n" + "inch x16\n" + ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "inch x23\n" + "cbz x13, 8f\n" + "mov x20, x13\n" + "3:" // K loop: Main loop + "whilelt p8.h, x16, %x[width]\n" + "mov x15, #0x0\n" + "mov x14, #0x0\n" + "cbz x17, 5f\n" + "4:" // K loop: Main loop: First: Loop + ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" + ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" + ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" + ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" + ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x253a7121 // psel 
p1.h, p12.h/Z, p9.h[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + "add x10, x10, #0x10\n" + "add x15, x15, #0x4\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x14, x14, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x14, x17\n" + "blt 4b\n" + "5:" // K loop: Main loop: First: Tail + ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" + ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" + ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" + ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" + "ldr x27, [x11, #0x0]\n" + "mov x13, #0x0\n" + ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" + "ldr x25, [x10, #0x0]\n" + "mov x12, #0x0\n" + ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + "whilelt p9.h, x16, %x[width]\n" + "inch x16\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL 
#2]\n" + "add x10, x10, #0x10\n" + "inch x23\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "whilelt p8.h, x16, %x[width]\n" + "cbz x17, 7f\n" + "6:" // K loop: Main loop: Second: Loop + ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" + ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" + ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" + ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" + ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "add x10, x10, #0x10\n" + "add x13, x13, #0x4\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x12, x12, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x12, x17\n" + "blt 6b\n" + "7:" // K loop: Main loop: Second: Tail + ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" + ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" + ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" + ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x25287120 // 
psel p0.h, p12.h/Z, p9.h[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "whilelt p9.h, x16, %x[width]\n" + "subs x20, x20, #0x1\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + "add x10, x10, #0x10\n" + "inch x16\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "inch x23\n" + "bgt 3b\n" + "8:" // K loop: Tails + "cbnz x26, 11f\n" + "mov x11, %x[in]\n" + "whilelt p8.h, x16, %x[width]\n" + "mov x13, #0x0\n" + "mov x12, #0x0\n" + "9:" // K loop: Tails: Even: First + ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25396161 // psel p1.h, p8.h/Z, p11.h[w13, #1]\n" + ".inst 0x25396140 // psel p0.h, p8.h/Z, p10.h[w13, #1]\n" + ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "ldr x21, [x11, #0x0]\n" + "cmp x12, x8\n" + "ldr x20, [x11, x8, LSL #0x3]\n" + "add x11, x11, #0x8\n" + ".inst 0xe05726a1 // ld1h { za0h.h[x13, #1] }, p1/Z, [x21, x23, LSL #1]\n" + ".inst 0xe0572289 // ld1h { za1h.h[x13, #1] }, p0/Z, [x20, x23, LSL #1]\n" + "add x13, x13, #0x2\n" + "blt 9b\n" + "whilelt p9.h, x16, %x[width]\n" + "whilelt p8.h, x16, %x[width]\n" + 
"mov x20, #0x0\n" + "mov x12, #0x0\n" + "10:" // K loop: Tails: Even: Second + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + "add x20, x20, #0x2\n" + ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 10b\n" + "whilelt p8.h, x16, %x[width]\n" + "b 13f\n" + "11:" // K loop: Tails: Odd + "mov x12, #0x0\n" + "12:" // K loop: Tails: Odd: Loop + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 12b\n" + "13:" // K loop: End + "mov %x[out], x22\n" + ".inst 0xd503467f // SMSTOP\n" + : [out] "+&r"(out) + : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", + "p7", "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", + "x24", "x25", "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", + "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", + "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + out_base += m_step * kai_roundup(k_chunk_length, KR) * sizeof(uint16_t); + } + } +} + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..a0343938bd99c5aa466595ee45344d9763e0721d --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h @@ -0,0 +1,61 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return Step size for row index +size_t kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme(void); + +/// Gets the offset in bytes to the data element in the packed LHS buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the size in bytes of the packed LHS buffer. +/// +/// @param[in] m Number of rows in the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The size in bytes of the packed LHS buffer. +size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length); + +/// Pack the LHS matrix for use with indirect matrix multiplication +/// +/// @param[in] m Number of rows of the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. 
+/// @param[in] lhs_ptrs Pointer to an array of input pointers consisting of +/// `m * k_chunk_count` pointers. +/// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs +/// array, excluding zero pointers. +/// @param[in] pad_ptr Pointer to chunk used for padding. @ref lhs_ptr_offset is +/// not applied to this pointer when used in @ref lhs_ptrs. This can +/// be NULL if there is no padding used @ref lhs_ptrs +/// @param[out] lhs_packed Packed LHS matrix. +void kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, + const void* pad_ptr, void* lhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..a9c0bb73ab711e1793a31611172e551ee368eef2 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c @@ -0,0 +1,204 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. 
+#include "kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +#define NR 2 +#define KR 2 +static const size_t kai_num_bytes_input = sizeof(uint16_t); +static const size_t kai_num_bytes_output = sizeof(uint16_t); +static const size_t kai_num_bytes_bias = sizeof(uint16_t); + +#define MAX_N_STEP (NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR)) + +size_t kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(void) { + return NR * kai_get_sme_vector_length_u16() / KR; +} + +size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx) { + KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() == 0); + + return n_idx * kai_num_bytes_input; +} + +size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx) { + return n_idx * kai_num_bytes_bias; +} + +static size_t kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t k_chunk_count, size_t k_chunk_length) { + return kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() * + (kai_num_bytes_bias + k_chunk_count * kai_roundup(k_chunk_length, KR) * kai_num_bytes_output); +} + +size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() == 0); + + const size_t block_idx = n_idx / kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(); + return block_idx * + kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(k_chunk_count, k_chunk_length); +} + +size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length) { + const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme()); + return 
kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + n_rounded_up, k_chunk_count, k_chunk_length); +} + +void kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + void* rhs_packed) { + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(rhs_packed != NULL); + + size_t height = k_chunk_length; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_row_stride; + + KAI_ASSERT(kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() <= MAX_N_STEP); + uint16_t pad_row[MAX_N_STEP]; + if (height % KR) { + memset(pad_row, 0, MAX_N_STEP * sizeof(uint16_t)); + } + + size_t out_stride = + kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(k_chunk_count, k_chunk_length); + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x21, %x[out]\n" + "mov x20, %x[width]\n" + "ptrue p1.b\n" + "1:" // Bias: Full loop + "whilelt p0.h, XZR, x20\n" + "dech x20\n" + "cmp x20, #0x0\n" + "ld1h { z16.h }, p0/Z, [%x[bias]]\n" + "incb %x[bias]\n" + "st1h { z16.h }, p1, [x21]\n" + "add x21, x21, %x[out_stride]\n" + "bgt 1b\n" + "incb %x[out]\n" + "mov x11, %x[k_chunk_count]\n" + "2:" // Chunk Loop + "mov x10, %x[height]\n" + "cmp x10, #0x8\n" + "blt 6f\n" + "3:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[out]\n" + "add x27, x9, %x[in_stride]\n" + "sub x10, x10, #0x8\n" + "add x26, x27, %x[in_stride]\n" + "mov x25, %x[width]\n" + "add x24, x26, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "4:" // Main row loop: Column loop + "whilelt p0.h, XZR, x25\n" + "decw x25, ALL, MUL #2\n" + "ld1h { z20.h }, p0/Z, [x9]\n" + "cmp x25, #0x0\n" + "addvl x9, x9, #1\n" + "ld1h { z17.h }, p0/Z, [x27]\n" + "addvl x27, 
x27, #1\n" + "ld1h { z19.h }, p0/Z, [x26]\n" + "addvl x26, x26, #1\n" + "ld1h { z16.h }, p0/Z, [x24]\n" + "addvl x24, x24, #1\n" + "ld1h { z18.h }, p0/Z, [x23]\n" + "addvl x23, x23, #1\n" + "zip1 z24.h, z20.h, z17.h\n" + "zip2 z23.h, z20.h, z17.h\n" + "ld1h { z17.h }, p0/Z, [x22]\n" + "addvl x22, x22, #1\n" + "ld1h { z22.h }, p0/Z, [x21]\n" + "addvl x21, x21, #1\n" + "zip1 z21.h, z19.h, z16.h\n" + "zip2 z20.h, z19.h, z16.h\n" + "ld1h { z16.h }, p0/Z, [x20]\n" + "addvl x20, x20, #1\n" + "zip1 z19.h, z18.h, z17.h\n" + "zip2 z18.h, z18.h, z17.h\n" + "st1h { z24.h }, p1, [x28]\n" + "st1h { z23.h }, p1, [x28, #1, MUL VL]\n" + "zip1 z17.h, z22.h, z16.h\n" + "zip2 z16.h, z22.h, z16.h\n" + "st1h { z21.h }, p1, [x28, #2, MUL VL]\n" + "st1h { z20.h }, p1, [x28, #3, MUL VL]\n" + "st1h { z19.h }, p1, [x28, #4, MUL VL]\n" + "st1h { z18.h }, p1, [x28, #5, MUL VL]\n" + "st1h { z17.h }, p1, [x28, #6, MUL VL]\n" + "st1h { z16.h }, p1, [x28, #7, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 4b\n" + "cmp x10, #0x8\n" + "addvl %x[out], %x[out], #8\n" + "bge 3b\n" + "cbz x10, 10f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x9, %x[in]\n" + "cntw x22, ALL, MUL #4\n" + "add x27, x9, %x[in_stride]\n" + "cmp x10, #0x1\n" + "add %x[in], x27, %x[in_stride]\n" + "mov x28, %x[out]\n" + "csel %x[in], %x[in], x27, GT\n" + "csel x27, x27, %x[pad_row], GT\n" + "csel x21, x22, XZR, GT\n" + "sub x10, x10, #0x2\n" + "mov x20, %x[width]\n" + "8:" // Tail row loop: Column loop + "whilelt p0.h, XZR, x20\n" + "decw x20, ALL, MUL #2\n" + "ld1h { z18.h }, p0/Z, [x9]\n" + "cmp x20, #0x0\n" + "add x9, x9, x22\n" + "ld1h { z16.h }, p0/Z, [x27]\n" + "add x27, x27, x21\n" + "zip1 z17.h, z18.h, z16.h\n" + "zip2 z16.h, z18.h, z16.h\n" + "st1h { z17.h }, p1, [x28]\n" + "st1h { z16.h }, p1, [x28, #1, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 8b\n" + "cmp x10, #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 7b\n" + "10:" // Done + "sub x11, x11, #0x1\n" + "cbnz x11, 2b\n" + 
".inst 0xd503467f // SMSTOP\n" + : [bias] "+&r"(bias), [in] "+&r"(in), [out] "+&r"(out) + : [height] "r"(height), [in_stride] "r"(in_stride), [k_chunk_count] "r"(k_chunk_count), + [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", + "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", + "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..e26bc3f5773de3954f8b73a1a930d103f860211c --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h @@ -0,0 +1,79 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return Step size for column index. +size_t kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. Must be divisible by `n_step` +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. 
+/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Column index. Must be divisible by `n_step` +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of columns. +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffer (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme. +/// * Bias: @ref kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme. +/// +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// @param[in] rhs_row_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[out] rhs_packed Packed RHS matrix. 
+void kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + void* rhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/common/data_format.cpp b/test/common/data_format.cpp index ccf9a46acd69887b50a63fd0731fb48d519b1dc0..e188710ad8bfb34d721212c49f83d424a11f9daf 100644 --- a/test/common/data_format.cpp +++ b/test/common/data_format.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,7 @@ #include #include +#include #include "kai/kai_common.h" #include "test/common/data_type.hpp" @@ -175,4 +176,18 @@ size_t DataFormat::default_size_in_bytes(size_t height, size_t width) const { return num_rows * block_stride; } +size_t DataFormat::Hash::operator()(const DataFormat& format) const { + using DT = std::underlying_type_t; + using PF = std::underlying_type_t; + return // + (std::hash
<DT>{}(static_cast<DT>(format._data_type)) << 0) ^ // + (std::hash<PF>{}(static_cast<PF>(format._pack_format)) << 1) ^ // + (std::hash<DT>{}(static_cast<DT>(format._scale_dt) << 2)) ^ // + (std::hash<DT>{}(static_cast<DT>
(format._zero_point_dt)) << 3) ^ // + (std::hash{}(format._block_height) << 4) ^ // + (std::hash{}(format._block_width) << 5) ^ // + (std::hash{}(format._subblock_height) << 6) ^ // + (std::hash{}(format._subblock_width) << 7); // +} + } // namespace kai::test diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp index 730dd86e59259e7cae48620144aceb01ce8a45c7..2d7f0b84d3577eb0ce52a71c2cd6b7df49459832 100644 --- a/test/common/data_format.hpp +++ b/test/common/data_format.hpp @@ -141,6 +141,11 @@ public: /// @return The size in bytes of the matrix. [[nodiscard]] size_t default_size_in_bytes(size_t height, size_t width) const; + /// Hash functor + struct Hash { + size_t operator()(const DataFormat& format) const; + }; + private: DataType _data_type; PackFormat _pack_format; diff --git a/test/common/matmul_test_common.cpp b/test/common/matmul_test_common.cpp index 73d41c09e28cab56db5c95df6e78e1e0c757c319..67b05d0b5588f7004c8d07959232b49ec4fb007e 100644 --- a/test/common/matmul_test_common.cpp +++ b/test/common/matmul_test_common.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -12,13 +12,20 @@ namespace kai::test { void PrintTo(const MatMulTestParams& param, std::ostream* os) { const auto& [method, shape, portion] = param; - // NOLINTBEGIN(cppcoreguidelines-pro-bounds-constant-array-index) - *os << "Method_" << method.name // - << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k // - << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) // + *os << "Method_" << method.name << "__"; + PrintTo(shape, os); + *os << "__"; + PrintTo(portion, os); +} + +void PrintTo(const MatMulShape& shape, std::ostream* os) { + *os << "M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k; +} + +void PrintTo(const MatrixPortion& portion, std::ostream* os) { + *os << 
"PortionStartRow_" << static_cast(portion.start_row() * 1000) // << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // << "__PortionHeight_" << static_cast(portion.height() * 1000) // << "__PortionWidth_" << static_cast(portion.width() * 1000); - // NOLINTEND(cppcoreguidelines-pro-bounds-constant-array-index) } } // namespace kai::test diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index 5b3e2424bae08e93799cd33cb09f620693da1725..3db8b8113c19bb65a4eb9c1601e61baeee92a244 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -22,6 +22,34 @@ struct MatMulShape { size_t m; ///< LHS height. size_t n; ///< RHS width. size_t k; ///< LHS width and RHS height. + + struct Hash { + size_t operator()(const MatMulShape& shape) const { + return // + (std::hash{}(shape.m) << 0) ^ // + (std::hash{}(shape.n) << 1) ^ // + (std::hash{}(shape.k) << 2); // + } + }; + +private: + friend bool operator==(const MatMulShape& lhs, const MatMulShape& rhs) { + return // + lhs.m == rhs.m && // + lhs.n == rhs.n && // + lhs.k == rhs.k; + } +}; + +/// Value range +template +struct Range { + T min; + T max; + + [[nodiscard]] T range() const { + return max - min; + } }; // NOLINTBEGIN(misc-non-private-member-variables-in-classes) @@ -459,4 +487,6 @@ using MatMulTestParams = std::tuple; /// Prints the test information. 
void PrintTo(const MatMulTestParams& param, std::ostream* os); +void PrintTo(const MatMulShape& shape, std::ostream* os); +void PrintTo(const MatrixPortion& portion, std::ostream* os); } // namespace kai::test diff --git a/test/common/memory.cpp b/test/common/memory.cpp new file mode 100644 index 0000000000000000000000000000000000000000..499f213113e813937e8725c9e1a9a7117533b9f5 --- /dev/null +++ b/test/common/memory.cpp @@ -0,0 +1,75 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/memory.hpp" + +#include + +#include "test/common/bfloat16.hpp" +#include "test/common/float16.hpp" + +namespace kai::test { + +double read_array(DataType type, const void* array, size_t index) { + switch (type) { + case DataType::FP32: + return read_array(array, index); + case DataType::FP16: + return static_cast(read_array(array, index)); + case DataType::BF16: + return static_cast(read_array(array, index)); + case DataType::I32: + return read_array(array, index); + case DataType::QAI8: + return read_array(array, index); + case DataType::QSU4: + return read_array(array, index); + case DataType::QSI4: + return read_array(array, index); + case DataType::UNKNOWN: + default: + KAI_ERROR("Trying to read unknown data type"); + } + return std::numeric_limits::signaling_NaN(); +} + +void write_array(DataType type, void* array, size_t index, double value) { + switch (type) { + case DataType::FP32: { + write_array(array, index, value); + return; + } + case DataType::FP16: { + write_array(array, index, static_cast(value)); + return; + } + case DataType::BF16: { + write_array(array, index, static_cast(value)); + return; + } + case DataType::I32: { + write_array(array, index, value); + return; + } + case DataType::QAI8: { + write_array(array, index, value); + return; + } + case DataType::QSU4: { + write_array(array, index, static_cast(value)); + return; + } + case DataType::QSI4: { + 
write_array(array, index, static_cast(value)); + return; + } + case DataType::UNKNOWN: + default: + KAI_ERROR("Trying to write unknown data type"); + } +} + +} // namespace kai::test diff --git a/test/common/memory.hpp b/test/common/memory.hpp index c856218f6351eefa6c30fdf715ec6f875f756037..28a24ea1e018cab51759563eb0782b74e26eaf0c 100644 --- a/test/common/memory.hpp +++ b/test/common/memory.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -12,6 +12,7 @@ #include "kai/kai_common.h" #include "test/common/bfloat16.hpp" +#include "test/common/data_type.hpp" #include "test/common/int4.hpp" namespace kai::test { @@ -50,6 +51,15 @@ T read_array(const void* array, size_t index) { } } +/// Reads the array at the specified index +/// +/// @param[in] type Array element data type +/// @param[in] array Data buffer. +/// @param[in] index Array index. +/// +/// @return Value at specified index +double read_array(DataType type, const void* array, size_t index); + /// Writes the specified value to the array. /// /// @param[in] array Data buffer. @@ -80,4 +90,12 @@ void write_array(void* array, size_t index, T value) { } } +/// Writes the specified value to the array. +/// +/// @param[in] type Array element type. +/// @param[in] array Data buffer. +/// @param[in] index Array index. +/// @param[in] value Value to be stored. 
+void write_array(DataType type, void* array, size_t index, double value); + } // namespace kai::test diff --git a/test/reference/clamp.cpp b/test/reference/clamp.cpp index 6a8e7433c7069f14df8067c2eb104dde82608afa..ab2e77e9c9b9654b6845a90b0c99e27ea756e22b 100644 --- a/test/reference/clamp.cpp +++ b/test/reference/clamp.cpp @@ -49,6 +49,20 @@ std::tuple find_clamp_range(const void* src, size_t len, float ratio) { template std::tuple find_clamp_range(const void* src, size_t len, float ratio); template std::tuple find_clamp_range(const void* src, size_t len, float ratio); +std::tuple find_clamp_range(DataType type, const void* src, size_t len, float ratio) { + auto max = std::numeric_limits::min(); + auto min = std::numeric_limits::max(); + + for (size_t i = 0; i < len; i += 1) { + const double value = read_array(type, src, i); + max = std::max(value, max); + min = std::min(value, min); + } + + const float reduction = (max - min) * (1.0F - ratio) / 2.0F; + return {min + reduction, max - reduction}; +} + template std::vector clamp(const void* src, size_t len, T min_value, T max_value) { std::vector dst(round_up_division(len * size_in_bits, 8)); @@ -63,4 +77,14 @@ std::vector clamp(const void* src, size_t len, T min_value, T max_value template std::vector clamp(const void* src, size_t len, float min_value, float max_value); template std::vector clamp(const void* src, size_t len, Float16 min_value, Float16 max_value); +std::vector clamp(DataType type, const void* src, size_t len, float min_value, float max_value) { + std::vector dst(round_up_division(len * data_type_size_in_bits(type), 8)); + + for (size_t i = 0; i < len; ++i) { + write_array(type, dst.data(), i, std::clamp(read_array(type, src, i), min_value, max_value)); + } + + return dst; +} + } // namespace kai::test diff --git a/test/reference/clamp.hpp b/test/reference/clamp.hpp index b665917e7f6fbc9d7b92c6a49142415f2703d5a2..532e7d25c0d7d0587af228f9141c3f30d4ac869f 100644 --- a/test/reference/clamp.hpp +++ 
b/test/reference/clamp.hpp @@ -11,6 +11,8 @@ #include #include +#include "test/common/data_type.hpp" + namespace kai::test { /// Finds the clamping parameters to limit the dynamic range. @@ -23,6 +25,16 @@ namespace kai::test { template std::tuple find_clamp_range(const void* src, size_t len, float ratio); +/// Finds the clamping parameters to limit the dynamic range. +/// +/// @param[in] type Array element data type. +/// @param[in] src The data buffer. +/// @param[in] len The number of values. +/// @param[in] ratio The ratio between the output dynamic range and the input dynamic range. +/// +/// @return The minimum value and the maximum value. +std::tuple find_clamp_range(DataType type, const void* src, size_t len, float ratio); + /// Clamps the matrix. /// /// @param[in] src Data buffer of the source matrix. @@ -32,4 +44,12 @@ std::tuple find_clamp_range(const void* src, size_t len, float ratio); template std::vector clamp(const void* src, size_t len, T min_value, T max_value); +/// Clamps the matrix. +/// +/// @param[in] type Array element data type. +/// @param[in] src Data buffer of the source matrix. +/// @param[in] len Number of values in the source matrix. +/// @param[in] min_value Lower bound of clamp. +/// @param[in] max_value Upper bound of clamp. 
+std::vector clamp(DataType type, const void* src, size_t len, float min_value, float max_value); } // namespace kai::test diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index 4ee2046abd1d13d67f377f7b902abd0c8f2ca045..b1378c759d687d6db45ab2e894afe36d65b48285 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -184,12 +184,42 @@ std::vector matmul( return tmp_dst; } +std::vector indirect_matmul( + const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales, + const void* lhs_zero_points, + DataType lhs_dt, // + const void* rhs, const void* rhs_scales, const void* rhs_zero_points, + DataType rhs_dt, // + const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, // + DataType dst_dt, // + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length) { + // This is inefficient, but allows code-reuse + const size_t chunk_bytes = k_chunk_length * round_up_division(data_type_size_in_bits(lhs_dt), 8); + const size_t n_chunks = m * k_chunk_count; + std::vector lhs(n_chunks * chunk_bytes); + + // Copy all chunks to the created matrix + for (size_t i = 0; i < n_chunks; i += 1) { + const uint8_t* src_pointer = static_cast(lhs_idata[i]); + if (src_pointer != lhs_padding_ptr) { + src_pointer += lhs_offset; + } + memcpy(lhs.data() + i * chunk_bytes, src_pointer, chunk_bytes); + } + + return matmul( + lhs.data(), lhs_scales, lhs_zero_points, lhs_dt, // + rhs, rhs_scales, rhs_zero_points, rhs_dt, // + bias, bias_scales, bias_zero_points, bias_dt, // + dst_dt, m, n, k_chunk_count * k_chunk_length, false, false); +} + template < typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> std::vector indirect_matmul_nt_t_quantized( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // - const void* const* 
lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales, + const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, @@ -208,7 +238,7 @@ std::vector indirect_matmul_nt_t_quantized( // Calculate the K chunk pointer. Apply offset if this is not padding const size_t k_chunk_idx = i_m * k_chunk_count + i_k_chunk; const void* k_chunk_ptr = lhs_ptrs[k_chunk_idx]; - if (k_chunk_ptr != lhs_padding) { + if (k_chunk_ptr != lhs_padding_ptr) { k_chunk_ptr = reinterpret_cast(reinterpret_cast(k_chunk_ptr) + lhs_offset); } diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp index 758d2fbbad9c2751eff831ac79eecdec87e3f8cb..343b8d343402b2b8a115876268047cbc32bdc984 100644 --- a/test/reference/matmul.hpp +++ b/test/reference/matmul.hpp @@ -65,6 +65,37 @@ std::vector matmul( size_t m, size_t n, size_t k, // bool lhs_transposed, bool rhs_transposed); +/// Indirect matrix multiplication. +/// +/// @param[in] lhs_idata The indirect LHS data matrix. +/// @param[in] lhs_scales (Optional) LHS operand quantization scales. +/// @param[in] lhs_offset The indirection LHS data matrix offset, applied to non-padding pointers +/// @param[in] lhs_padding_ptr The indirection LHS padding chunk pointer +/// @param[in] lhs_zero_points (Optional) LHS operand quantization zero point. +/// @param[in] lhs_dt LHS operand data type. +/// @param[in] rhs RHS operand data. +/// @param[in] rhs_scales (Optional) RHS operand quantization scales. +/// @param[in] rhs_zero_points (Optional) RHS operand quantization zero point. +/// @param[in] rhs_dt RHS operand data type. +/// @param[in] bias Bias operand data. +/// @param[in] bias_scales (Optional) Bias operand quantization scales. 
+/// @param[in] bias_zero_points (Optional) Bias operand quantization zero point. +/// @param[in] bias_dt Bias operand data type. +/// @param[in] dst_dt Output data type. +/// @param[in] m Output height. +/// @param[in] n Output width. +/// @param[in] k_chunk_count Number of pointers per row in lhs_idata +/// @param[in] k_chunk_length Number of elements in each LHS K chunk +/// +/// @return The result data buffer. +std::vector indirect_matmul( + const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales, + const void* lhs_zero_points, DataType lhs_dt, // + const void* rhs, const void* rhs_scales, const void* rhs_zero_points, DataType rhs_dt, // + const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, // + DataType dst_dt, // + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length); + /// Matrix multiplication with quantized input and floating-point output. /// /// The LHS matrix is non-transposed and the RHS matrix is transposed. @@ -127,7 +158,7 @@ std::vector matmul_clamp_nt_t( /// @param[in] lhs_data The LHS data matrix. /// @param[in] lhs_ptrs The indirect LHS data matrix. /// @param[in] lhs_offset The indirection LHS data matrix offset, applied to non-padding pointers -/// @parma[in] lhs_padding The indirection LHS padding chunk pointer +/// @param[in] lhs_padding_ptr The indirection LHS padding chunk pointer /// @param[in] lhs_scales The LHS quantization scales matrix. /// @param[in] lhs_zero_points The LHS quantization zero points matrix. /// @param[in] lhs_quant_width The LHS quantization block width. 
@@ -166,7 +197,7 @@ template < typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> std::vector indirect_matmul_nt_t_quantized( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // - const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales, + const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, diff --git a/test/tests/imatmul_test.cpp b/test/tests/imatmul_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f787143bb8a1bd005e3c9cbbe4e5d9b923558595 --- /dev/null +++ b/test/tests/imatmul_test.cpp @@ -0,0 +1,503 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include + +#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h" +#include "test/common/compare.hpp" +#include "test/common/cpu_info.hpp" +#include "test/common/matmul_test_common.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/memory.hpp" +#include "test/common/round.hpp" +#include "test/common/sme.hpp" +#include "test/reference/clamp.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/reorder.hpp" + +namespace kai::test { + +// Ensure static linkage for all functionality local to this test file +namespace { + 
+/// Convenience wrapper for K-chunk handling +struct KChunk { + size_t count; + size_t length; +}; + +/// Interface for indirect matmul LHS packing kernel +struct LhsPackIndirectKernel { + std::function get_m_step; + std::function get_lhs_packed_offset; + std::function get_lhs_packed_size; + std::function + pack; +}; + +/// Interface for indirect matmul RHS packing kernel +struct RhsPackIndirectKernel { + std::function get_n_step; + std::function get_rhs_offset; + std::function get_bias_offset; + std::function get_rhs_packed_offset; + std::function get_rhs_packed_size; + std::function + pack; +}; + +/// Interface for indirect matmul kernel +struct MatMulIndirectKernel { + std::function get_m_step; + std::function get_n_step; + std::function get_mr; + std::function get_nr; + std::function get_kr; + std::function get_lhs_packed_offset; + std::function get_rhs_packed_offset; + std::function get_dst_offset; + std::function get_dst_size; + std::function + imatmul; +}; + +/// Description of a Indirect Matmul kernel set +struct IndirectMatMul { + std::string_view name; + std::function is_supported; + + MatMulShape pack_shape; + struct Format { + DataFormat lhs; + DataFormat rhs; + DataFormat bias; + DataFormat out; + + struct Hash { + size_t operator()(const Format& format) const { + return // + (DataFormat::Hash{}(format.lhs) << 0) ^ // + (DataFormat::Hash{}(format.rhs) << 1) ^ // + (DataFormat::Hash{}(format.bias) << 2) ^ // + (DataFormat::Hash{}(format.out) << 3); + } + }; + + private: + friend bool operator==(const Format& lhs, const Format& rhs) { + return // + lhs.lhs == rhs.lhs && // + lhs.rhs == rhs.rhs && // + lhs.bias == rhs.bias && // + lhs.out == rhs.out; + } + } format; + + LhsPackIndirectKernel lhs; + RhsPackIndirectKernel rhs; + MatMulIndirectKernel imatmul; +}; + +/// Simple byte buffer +using Buffer = std::vector; + +/// Convenience type for test list +using IndirectMatMulArray = std::array; + +/// Test parameter bundle type +using 
IndirectMatMulTestParams = std::tuple; + +/// Test type +using IndirectMatMulTest = testing::TestWithParam; + +/// Use interface for matmul kernel +const kai_imatmul_clamp_f16_f16p_f16p_ukernel& get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() { + static kai_imatmul_clamp_f16_f16p_f16p_ukernel ukernel; + ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + ukernel.run_imatmul = kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + return ukernel; +} + +/// Retrieve the test list +const IndirectMatMulArray& get_indirect_matmul_methods() { + static IndirectMatMulArray indirect_matmul_methods{}; + + // F16 IMATMUL //////////////////////////////////////////////////////////// + indirect_matmul_methods[0].name = "indirect_matmul_f16_f16p_f16p_2vlx2vl_sme2_mopa"; + indirect_matmul_methods[0].is_supported = cpu_has_sme2; + indirect_matmul_methods[0].pack_shape.m = 2 * get_sme_vector_length(); + indirect_matmul_methods[0].pack_shape.n = 2 * get_sme_vector_length(); + indirect_matmul_methods[0].pack_shape.k = sizeof(int32_t); + indirect_matmul_methods[0].format.lhs = DataFormat(DataType::FP16); + indirect_matmul_methods[0].format.rhs = DataFormat(DataType::FP16); + indirect_matmul_methods[0].format.bias = DataFormat(DataType::FP16); + indirect_matmul_methods[0].format.out = DataFormat(DataType::FP16); + + // LHS + indirect_matmul_methods[0].lhs.get_m_step = 
kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme; + indirect_matmul_methods[0].lhs.get_lhs_packed_offset = + kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme; + indirect_matmul_methods[0].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme; + indirect_matmul_methods[0].lhs.pack = kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme; + + // RHS + indirect_matmul_methods[0].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; + indirect_matmul_methods[0].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; + indirect_matmul_methods[0].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; + indirect_matmul_methods[0].rhs.get_rhs_packed_offset = + kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; + indirect_matmul_methods[0].rhs.get_rhs_packed_size = + kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; + indirect_matmul_methods[0].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; + + // IMATMUL + const kai_imatmul_clamp_f16_f16p_f16p_ukernel& ukernel_f16 = + get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(); + indirect_matmul_methods[0].imatmul.get_m_step = ukernel_f16.get_m_step; + indirect_matmul_methods[0].imatmul.get_n_step = ukernel_f16.get_n_step; + indirect_matmul_methods[0].imatmul.get_lhs_packed_offset = ukernel_f16.get_lhs_packed_offset; + indirect_matmul_methods[0].imatmul.get_rhs_packed_offset = ukernel_f16.get_rhs_packed_offset; + indirect_matmul_methods[0].imatmul.get_dst_offset = ukernel_f16.get_dst_offset; + indirect_matmul_methods[0].imatmul.get_dst_size = ukernel_f16.get_dst_size; + indirect_matmul_methods[0].imatmul.imatmul = ukernel_f16.run_imatmul; + + return indirect_matmul_methods; +} + +/// Test reference identification +struct TestDataId { + MatMulShape shape; + MatMulShape pack_shape; + IndirectMatMul::Format format; + size_t k_chunk_length; + float 
clamp_rate; + + struct Hash { + size_t operator()(const TestDataId& test_id) const { + return // + (MatMulShape::Hash{}(test_id.shape) << 0) ^ // + (MatMulShape::Hash{}(test_id.pack_shape) << 1) ^ // + (IndirectMatMul::Format::Hash{}(test_id.format) << 2) ^ // + (std::hash{}(test_id.k_chunk_length) << 3) ^ // + (std::hash{}(test_id.clamp_rate) << 4); // + } + }; + +private: + friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) { + return // + lhs.shape == rhs.shape && // + lhs.pack_shape == rhs.pack_shape && // + lhs.format == rhs.format && // + lhs.k_chunk_length == rhs.k_chunk_length && // + lhs.clamp_rate == rhs.clamp_rate; + } +}; + +/// Test reference data +struct TestData { + Buffer lhs; ///< LHS input matrix + Buffer rhs; ///< RHS input matrix + Buffer bias; ///< Bias vector + Buffer out; ///< Reference imatmul result + Buffer indirection; ///< LHS indirection buffer + uintptr_t indirection_offset; ///< LHS indirection buffer offset + Buffer padding; ///< Padding buffer + Range clamp_range; ///< Clamp range +}; + +/// Reference data generator +/// +/// Uses test id to generate reference data, and caches it. +struct ReferenceGenerator { + /// Retrieve reference data for the provided test identification + static const TestData& get_test_reference(const TestDataId test_id) { + static std::unordered_map m_data; + if (const auto itr = m_data.find(test_id); itr != end(m_data)) { + return itr->second; + } + + return m_data[test_id] = generate_reference(test_id); + } + +private: + /// Return incremented seed value + static size_t get_seed() { + static size_t seed = 0; + return seed++; + } + + /// Generate reference data. Not intended to be called + /// directly, as this would bypass caching mechanism. 
+ static TestData generate_reference(const TestDataId& test_id) { + const auto& [chunked_shape, pack_shape, format, k_chunk_length, clamp_rate] = test_id; + + // The LHS matrix will be split into several chunks in the K dimension + const size_t k_chunk_count = chunked_shape.k; + MatMulShape shape = {chunked_shape.m, chunked_shape.n, k_chunk_count * k_chunk_length}; + + // Generate random input data + Buffer lhs = fill_matrix_random(shape.m, shape.k, format.lhs, get_seed()); + Buffer rhs = fill_matrix_random(shape.k, shape.n, format.rhs, get_seed()); + Buffer bias = fill_matrix_random(1, shape.n, format.bias, get_seed()); + + // Data types used + const DataType lhs_dt = format.lhs.data_type(); + const DataType rhs_dt = format.rhs.data_type(); + const DataType out_dt = format.out.data_type(); + const DataType bias_dt = format.bias.data_type(); + + // Create a padding chunk + const size_t k_chunk_size = round_up_division(k_chunk_length * data_type_size_in_bits(lhs_dt), 8); + const size_t row_size = k_chunk_count * k_chunk_size; + Buffer lhs_padding(k_chunk_size); + for (size_t i = 0; i < k_chunk_length; i += 1) { + static constexpr double padding_value = 0; + write_array(lhs_dt, lhs_padding.data(), i, padding_value); + } + + // Set up indirection buffer + const uintptr_t indirection_offset = reinterpret_cast(lhs.data()); + std::vector indirection(chunked_shape.m * chunked_shape.k); + for (size_t i_m = 0; i_m < chunked_shape.m; i_m += 1) { + for (size_t i_k = 0; i_k < chunked_shape.k; i_k += 1) { + const size_t idx = i_m * chunked_shape.k + i_k; + // Test padding pointers using first LHS row for shapes where M > 1 + if (chunked_shape.m > 1 && i_m == 0) { + indirection.at(idx) = lhs_padding.data(); + } else { + uintptr_t offset = i_m * row_size + i_k * k_chunk_size; + indirection.at(idx) = reinterpret_cast(offset); + } + } + } + + // Pack indirection buffer + Buffer indirection_packed = reorder_block( + reinterpret_cast(indirection.data()), chunked_shape.m, 
chunked_shape.k, pack_shape.m, + 1); + + Buffer out = indirect_matmul( // + indirection.data(), indirection_offset, lhs_padding.data(), nullptr, nullptr, lhs_dt, // LHS + rhs.data(), nullptr, nullptr, rhs_dt, // RHS + bias.data(), nullptr, nullptr, bias_dt, // Bias + out_dt, // Out + chunked_shape.m, chunked_shape.n, chunked_shape.k, k_chunk_length); + + // Calculate clamping range based on full range of values, and then clamp values + const auto [min, max] = find_clamp_range(out_dt, out.data(), shape.m * shape.n, 1.0F - clamp_rate); + Buffer out_clamped = clamp(out_dt, out.data(), shape.m * shape.n, min, max); + + // Populate reference data + TestData test_reference; + test_reference.lhs = std::move(lhs); + test_reference.rhs = std::move(rhs); + test_reference.bias = std::move(bias); + test_reference.padding = std::move(lhs_padding); + test_reference.out = std::move(out_clamped); + test_reference.indirection_offset = indirection_offset; + test_reference.indirection = std::move(indirection_packed); + test_reference.clamp_range = {min, max}; + + return test_reference; + }; +}; + +/// Perform LHS packing for indirect matmul +Buffer pack_lhs( + const LhsPackIndirectKernel& kernel, const Rect& portion, const TestData& reference, size_t m, + const KChunk& k_chunk) { + const void* const* indirection_pointer = reinterpret_cast(reference.indirection.data()); + + // Calculate size, and allocate buffer + const size_t dst_size = kernel.get_lhs_packed_size(m, k_chunk.count, k_chunk.length); + Buffer dst(dst_size); + + // Calculate portion offsets + const size_t input_offset = portion.start_row() * k_chunk.count; + const size_t dst_offset = kernel.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length); + + // Perform packing + kernel.pack( + portion.height(), k_chunk.count, k_chunk.length, // Dimensions + indirection_pointer + input_offset, // Indirection input + reference.indirection_offset, // Chunk offset + reference.padding.data(), // Padding pointer + 
dst.data() + dst_offset); + return dst; +} + +/// Perform RHS packing for indirect matmul +Buffer pack_rhs( + const RhsPackIndirectKernel& kernel, const Rect& portion, const TestData& reference, size_t n, + const KChunk& k_chunk, DataType type) { + // Calculate size, and allocate buffer + const size_t row_stride = round_up_division(n * data_type_size_in_bits(type), 8); + const size_t dst_size = kernel.get_rhs_packed_size(n, k_chunk.count, k_chunk.length); + Buffer dst(dst_size); + + // Calculate offsets + const size_t rhs_offset = kernel.get_rhs_offset(portion.start_col()); + const size_t bias_offset = kernel.get_bias_offset(portion.start_col()); + const size_t dst_offset = kernel.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length); + + // Perform actual packing + kernel.pack( + portion.width(), k_chunk.count, k_chunk.length, row_stride, // Dimensions + reference.rhs.data() + rhs_offset, // RHS input + reference.bias.data() + bias_offset, // Bias + dst.data() + dst_offset); // Output + return dst; +} + +/// Perform imatmul +/// +/// Note, this should not be aware of the reference result, so as to make it clear that +/// any produced result is strictly from the code under test +Buffer imatmul( + const MatMulIndirectKernel& kernel, const Rect& portion, const MatMulShape& shape, const KChunk& k_chunk, + const Buffer& lhs_packed, const Buffer& rhs_packed, Range clamp_range, DataType type) { + // Calculate size, and allocate buffer + const size_t dst_size = kernel.get_dst_size(shape.m, shape.n); + const size_t row_stride = round_up_division(shape.n * data_type_size_in_bits(type), 8); + Buffer dst(dst_size); + + // Calculate portion offsets + const size_t lhs_offset = kernel.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length); + const size_t rhs_offset = kernel.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length); + const size_t dst_offset = kernel.get_dst_offset(portion.start_row(), portion.start_col(), 
row_stride); + + // Call matmul kernel + kernel.imatmul( + portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions + lhs_packed.data() + lhs_offset, // LHS + rhs_packed.data() + rhs_offset, // RHS + dst.data() + dst_offset, // DST + row_stride, clamp_range.min, clamp_range.max); + + return dst; +} + +} // namespace + +/// End-to-end test for indirection matmul kernels +TEST_P(IndirectMatMulTest, Output) { + const auto& [method, shape, k_chunk_length, output_portion, clamp_rate] = GetParam(); + if (not method.is_supported()) { + GTEST_SKIP() << "CPU features are not supported by current CPU"; + } + + const KChunk k_chunk{shape.k, k_chunk_length}; + + // Retrieve reference data + const TestData& test_data = + ReferenceGenerator::get_test_reference({shape, method.pack_shape, method.format, k_chunk_length, clamp_rate}); + const Rect portion = output_portion.compute_portion(shape.m, shape.n, method.pack_shape.m, method.pack_shape.n); + + // Call packing kernels, and then imatmul kernel + Buffer lhs_packed = pack_lhs(method.lhs, portion, test_data, shape.m, k_chunk); + Buffer rhs_packed = pack_rhs(method.rhs, portion, test_data, shape.n, k_chunk, method.format.rhs.data_type()); + Buffer out = imatmul( + method.imatmul, portion, shape, k_chunk, lhs_packed, rhs_packed, test_data.clamp_range, + method.format.out.data_type()); + + // Compare the actual result with the reference result + DefaultMismatchHandler handler(0, 0.1, 0, 0.05); + const auto success = + compare(out.data(), test_data.out.data(), method.format.out.data_type(), shape.m, shape.n, portion, handler); + ASSERT_TRUE(success); +} + +/// Name generator for test case +[[maybe_unused]] static void PrintTo(const IndirectMatMulTestParams& param, std::ostream* os) { + const auto& [method, shape, k_chunk_length, portion, clamp_rate] = param; + *os << "Method_" << method.name << "__"; + PrintTo(shape, os); + *os << "__K_chunk_length_" << k_chunk_length; + *os << "__clamp_rate_" << 
static_cast(clamp_rate * 100) << "__"; + PrintTo(portion, os); +} + +/// Test parameter listing +INSTANTIATE_TEST_SUITE_P( + IndirectMatMul, IndirectMatMulTest, + testing::Combine( + testing::ValuesIn(get_indirect_matmul_methods()), // + testing::ValuesIn({ + // clang-format off + MatMulShape{ 1, 1, 1}, // + MatMulShape{ 1, 17, 4}, // + MatMulShape{ 1, 19, 24}, // + MatMulShape{ 1, 32, 4}, // + MatMulShape{ 1, 32, 32}, // + MatMulShape{ 1, 33, 200}, // + MatMulShape{ 1, 49, 21}, // + MatMulShape{ 1, 64, 4}, // + MatMulShape{ 1, 65, 4}, // + MatMulShape{ 3, 6, 6}, // + MatMulShape{ 3, 28, 25}, // + MatMulShape{ 4, 16, 4}, // + MatMulShape{ 4, 16, 27}, // + MatMulShape{ 6, 18, 31}, // + MatMulShape{ 6, 28, 1}, // + MatMulShape{ 6, 29, 24}, // + MatMulShape{ 8, 16, 16}, // + MatMulShape{ 16, 16, 4}, // + MatMulShape{ 16, 16, 16}, // + MatMulShape{ 20, 30, 40}, // + MatMulShape{ 23, 1, 43}, // + MatMulShape{ 32, 14, 1}, // + MatMulShape{ 32, 16, 27}, // + MatMulShape{ 32, 32, 3}, // + MatMulShape{ 32, 32, 4}, // + MatMulShape{ 33, 29, 24}, // + MatMulShape{ 64, 64, 3}, // + MatMulShape{ 64, 64, 4}, // + MatMulShape{ 96, 96, 3}, // + MatMulShape{ 96, 97, 3}, // + MatMulShape{ 97, 96, 3}, // + MatMulShape{123, 85, 45}, // + MatMulShape{128, 128, 3}, // + MatMulShape{130, 130, 6}, // + // clang-format on + }), + testing::ValuesIn(std::initializer_list{1, 2, 3, 4, 8, 11, 16, 32, 33, 64, 65}), // + testing::ValuesIn({ + // clang-format off + // (Start row , start col , height , width) + MatrixPortion( 0 , 0 , 1 , 1 ), // Full matrix. + MatrixPortion( 0 , 0 , 1 , 0.5 ), // Left half + MatrixPortion( 0 , 0 , 0.5 , 1 ), // Upper half + MatrixPortion( 0 , 0.5 , 1 , 0.5 ), // Right half + MatrixPortion( 0.5 , 0 , 0.5 , 1 ), // Bottom half + MatrixPortion( 0.4 , 0.4 , 0.3 , 0.3 ), // Center ninth + // clang-format on + }), + testing::ValuesIn(std::initializer_list{0.0F, 0.1F, 0.5F})), // + testing::PrintToStringParamName()); + +} // namespace kai::test