From e53f8c91e64c2a7444b5ab22163556caaa97b41c Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Fri, 11 Apr 2025 07:18:50 +0000 Subject: [PATCH 1/7] Add FP16 IGEMM LHS packing kernel This change adds the LHS FP16 packing kernel to be used with IGEMM. It doesn't actually change the elements, so this can be used with any 16-bit types. Signed-off-by: Emil Ohlsson Approved-by: Felix Johnny Thomasmathibalan --- CMakeLists.txt | 1 + kai/ukernels/matmul/BUILD.bazel | 1 + .../kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c | 341 ++++++++++++++++++ .../kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h | 61 ++++ 4 files changed, 404 insertions(+) create mode 100644 kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c create mode 100644 kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 8e0c7130..48e2dc37 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -205,6 +205,7 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME + kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 79bfac8b..8be1488d 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -129,6 +129,7 @@ I8MM_KERNELS_ASM = [ # buildifier: keep sorted SME_KERNELS = [ + "pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme", "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_lhs_pack_f32p2vlx1_f32_sme", diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c new file mode 100644 index 00000000..f996bd48 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c @@ -0,0 +1,341 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" + +#include +#include + +#include "kai/kai_common.h" + +#define MR 2 +#define KR 2 +#define MAX_M_STEP (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR) + +static size_t kai_get_mr_lhs_imatmul_pack_x16p2vlx2_x16p_sme(void) { + return MR * kai_get_sme_vector_length_u16() / KR; +} + +size_t kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme(void) { + return kai_get_mr_lhs_imatmul_pack_x16p2vlx2_x16p_sme(); +} + +size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(m_idx % kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme() == 0); + + return m_idx * k_chunk_count * kai_roundup(k_chunk_length, KR) * sizeof(uint16_t); +} + +size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length) { + const size_t m_end = kai_roundup(m, kai_get_mr_lhs_imatmul_pack_x16p2vlx2_x16p_sme()); + return kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme(m_end, k_chunk_count, k_chunk_length); +} + +void kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, + const void* pad_ptr, void* lhs_packed) { + KAI_ASSUME(lhs_ptrs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + const size_t m_step = kai_get_mr_lhs_imatmul_pack_x16p2vlx2_x16p_sme(); + const size_t row_offset = 0; + const size_t width = k_chunk_length; + + KAI_ASSERT(m_step <= MAX_M_STEP); + const uint8_t* in[MAX_M_STEP]; + + uint8_t* out_base = lhs_packed; + for (size_t i_m = 0; i_m < m; i_m += m_step) { + for (size_t i_k_chunk = 0; i_k_chunk < k_chunk_count; i_k_chunk += 1) { + const size_t height = KAI_MIN(m - i_m, m_step); + void* out = out_base; + for (size_t y = 0; y < height; y += 1) { + KAI_ASSERT(i_k_chunk + (i_m + y) * k_chunk_count < m * k_chunk_count); + in[y] = *(lhs_ptrs + i_m * k_chunk_count + i_k_chunk * m_step + y); + if (in[y] != pad_ptr) { + in[y] += lhs_ptr_offset; + } + } + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x22, %x[width]\n" + "mov x21, %x[width]\n" + "cnth x20\n" + "inch x22\n" + "sub x7, x20, #0x1\n" + "sub x22, x22, #0x1\n" + "ands x7, x21, x7\n" + "cntw x8\n" + "udiv x22, x22, x20\n" // n_passes = ceildiv(width, VL) + "csel x7, x7, x20, NE\n" + "sub x13, x22, #0x1\n" + "add x7, x7, #0x1\n" + "sub x17, x8, #0x2\n" + "lsl x21, %x[height], #0x1\n" // height * 2 + "lsl x20, x8, #0x1\n" + "mov x16, #0x0\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + "cntw x9, ALL, MUL #2\n" + "cntw x28, ALL, MUL #3\n" + "ldr x27, [x11, #0x0]\n" + "lsr x13, x13, #0x1\n" // n_loops = (n_passes - 1) / 2 + "and x26, x22, #0x1\n" // odd_tail = bool(n_passes & 0x1) + "ldr x25, [x10, #0x0]\n" + "lsr x7, x7, #0x1\n" + "ptrue p12.s\n" + "ldr x24, [x11, #0x8]\n" + "whilelt p11.h, XZR, x21\n" + "whilelt p10.h, x20, x21\n" + "ldr x21, [x10, #0x8]\n" + "mov x23, %x[row_offset]\n" + "mov x22, %x[out]\n" + "whilelt p9.h, x16, %x[width]\n" + "whilelt p8.h, x16, %x[width]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "mov x12, #0x0\n" + "cbz x17, 2f\n" + "1:" // K loop: Charge: Loop + ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" + ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" + ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" + ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" + ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" + "add x12, x12, #0x4\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "cmp x12, x17, LSL #1\n" + "blt 1b\n" + "2:" // K loop: Charge: End + ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" + ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" + ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" + ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" + "ldr x27, [x11, #0x0]\n" + "inch x16\n" + ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "inch x23\n" + "cbz x13, 8f\n" + "mov x20, x13\n" + "3:" // K loop: Main loop + "whilelt p8.h, x16, %x[width]\n" + "mov x15, #0x0\n" + "mov x14, #0x0\n" + "cbz x17, 5f\n" + "4:" // K loop: Main loop: First: Loop + ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" + ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" + ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" + ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" + ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + "add x10, x10, #0x10\n" + "add x15, x15, #0x4\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x14, x14, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x14, x17\n" + "blt 4b\n" + "5:" // K loop: Main loop: First: Tail + ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" + ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" + ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" + ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" + "ldr x27, [x11, #0x0]\n" + "mov x13, #0x0\n" + ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" + "ldr x25, [x10, #0x0]\n" + "mov x12, #0x0\n" + ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + "whilelt p9.h, x16, %x[width]\n" + "inch x16\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + "add x10, x10, #0x10\n" + "inch x23\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "whilelt p8.h, x16, %x[width]\n" + "cbz x17, 7f\n" + "6:" // K loop: Main loop: Second: Loop + ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" + ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" + ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" + ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" + ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "add x10, x10, #0x10\n" + "add x13, x13, #0x4\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x12, x12, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x12, x17\n" + "blt 6b\n" + "7:" // K loop: Main loop: Second: Tail + ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" + ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" + ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" + ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" + ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" + ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" + ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "whilelt p9.h, x16, %x[width]\n" + "subs x20, x20, #0x1\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + "add x10, x10, #0x10\n" + "inch x16\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "inch x23\n" + "bgt 3b\n" + "8:" // K loop: Tails + "cbnz x26, 11f\n" + "mov x11, %x[in]\n" + "whilelt p8.h, x16, %x[width]\n" + "mov x13, #0x0\n" + "mov x12, #0x0\n" + "9:" // K loop: Tails: Even: First + ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25396161 // psel p1.h, p8.h/Z, p11.h[w13, #1]\n" + ".inst 0x25396140 // psel p0.h, p8.h/Z, p10.h[w13, #1]\n" + ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "ldr x21, [x11, #0x0]\n" + "cmp x12, x8\n" + "ldr x20, [x11, x8, LSL #0x3]\n" + "add x11, x11, #0x8\n" + ".inst 0xe05726a1 // ld1h { za0h.h[x13, #1] }, p1/Z, [x21, x23, LSL #1]\n" + ".inst 0xe0572289 // ld1h { za1h.h[x13, #1] }, p0/Z, [x20, x23, LSL #1]\n" + "add x13, x13, #0x2\n" + "blt 9b\n" + "whilelt p9.h, x16, %x[width]\n" + "whilelt p8.h, x16, %x[width]\n" + "mov x20, #0x0\n" + "mov x12, #0x0\n" + "10:" // K loop: Tails: Even: Second + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + "add x20, x20, #0x2\n" + ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 10b\n" + "whilelt p8.h, x16, %x[width]\n" + "b 13f\n" + "11:" // K loop: Tails: Odd + "mov x12, #0x0\n" + "12:" // K loop: Tails: Odd: Loop + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 12b\n" + "13:" // K loop: End + "mov %x[out], x22\n" + ".inst 0xd503467f // SMSTOP\n" + : [out] "+&r"(out) + : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", + "p7", "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", + "x24", "x25", "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", + "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", + "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + out_base += m_step * kai_roundup(k_chunk_length, KR) * sizeof(uint16_t); + } + } +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h new file mode 100644 index 00000000..a0343938 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h @@ -0,0 +1,61 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return Step size for row index +size_t kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme(void); + +/// Gets the offset in bytes to the data element in the packed LHS buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the size in bytes of the packed LHS buffer. +/// +/// @param[in] m Number of rows in the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The size in bytes of the packed LHS buffer. +size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length); + +/// Pack the LHS matrix for use with indirect matrix multiplication +/// +/// @param[in] m Number of rows of the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// @param[in] lhs_ptrs Pointer to an array of input pointers consisting of +/// `m * k_chunk_count` pointers. +/// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs +/// array, excluding zero pointers. +/// @param[in] pad_ptr Pointer to chunk used for padding. @ref lhs_ptr_offset is +/// not applied to this pointer when used in @ref lhs_ptrs. This can +/// be NULL if there is no padding used @ref lhs_ptrs +/// @param[out] lhs_packed Packed LHS matrix. +void kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, + const void* pad_ptr, void* lhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus -- GitLab From eeaa1a69c29d9aedebd8a29ccb8b17ebba1cf9de Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Fri, 11 Apr 2025 07:29:30 +0000 Subject: [PATCH 2/7] Add FP16 IMATMUL kernel This change adds the FP16 IMATMUL kernel, which consumes packed data from * `kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme` * `kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme` Signed-off-by: Emil Ohlsson Reviewed-by: Emil Ohlsson Approved-by: Felix Johnny Thomasmathibalan --- CMakeLists.txt | 1 + kai/ukernels/matmul/BUILD.bazel | 1 + ...16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c | 252 ++++++++++++++++++ ...16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h | 98 +++++++ ...ai_imatmul_clamp_f16_f16p_f16p_interface.h | 45 ++++ 5 files changed, 397 insertions(+) create mode 100644 kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c create mode 100644 kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h create mode 100644 kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 48e2dc37..33bbb7c7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -220,6 +220,7 @@ set(KLEIDIAI_FILES_SME ) set(KLEIDIAI_FILES_SME2 + kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 8be1488d..9f655159 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -147,6 +147,7 @@ SME_KERNELS = [ # buildifier: keep sorted SME2_KERNELS = [ + "imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c new file mode 100644 index 00000000..7e77125a --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c @@ -0,0 +1,252 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_nr = 2; +static const size_t kai_kr = 2; + +size_t kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); + const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + return m_idx * indirect_k * sizeof(uint16_t); +} + +static size_t kai_get_rhs_packed_stride_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t k_chunk_count, size_t k_chunk_length) { + const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + return kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() * + (sizeof(uint16_t) + indirect_k * sizeof(uint16_t)); +} + +size_t kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); + const size_t block_idx = n_idx / kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(); + return block_idx * + kai_get_rhs_packed_stride_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + k_chunk_count, k_chunk_length); +} + +size_t kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_row_stride) { + KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() == 0); + + return m_idx * dst_row_stride + n_idx * sizeof(uint16_t); +} + +size_t kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(size_t m, size_t n) { + return m * n * sizeof(uint16_t); +} + +void kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, float clamp_min, float clamp_max) { + typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + float16_t min; + float16_t max; + void* accumulator_buffer; + uint64_t flags; + } KernelArgs; + + KernelArgs args; + + args.A = lhs_packed; + args.B = rhs_packed; + + size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + + args.C = dst; + args.ldcb = dst_row_stride; + args.M = m; + args.N = n; + args.K = indirect_k; + args.min = (float16_t)clamp_min; + args.max = (float16_t)clamp_max; + + args.accumulator_buffer = NULL; + args.flags = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ldr w13, [%x[args], %[offsetof_M]]\n" + "mov x11, #0x0\n" + "mov x10, #0x0\n" + "ptrue p1.b\n" + ".inst 0x25207810 // ptrue pn8.b\n" + "ldr w9, [%x[args], %[offsetof_N]]\n" + "ldr x28, [%x[args], %[offsetof_A]]\n" + "1:" // M loop + "ldr x27, [%x[args], %[offsetof_B]]\n" + "2:" // N loop + "fmov z24.h, #0.0\n" + "ld1h { z5.h }, p1/Z, [x27]\n" + "fmov z27.h, #1.0\n" + "mov x26, x28\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "inch x27, ALL, MUL #2\n" + "zip1 z30.h, z5.h, z24.h\n" + "zip2 z20.h, z5.h, z24.h\n" + ".inst 0x81be2760 // fmopa za0.s, p1/M, p1/M, z27.h, z30.h\n" + ".inst 0x81b42761 // fmopa za1.s, p1/M, p1/M, z27.h, z20.h\n" + ".inst 0x81be2762 // fmopa za2.s, p1/M, p1/M, z27.h, z30.h\n" + ".inst 0x81b42763 // fmopa za3.s, p1/M, p1/M, z27.h, z20.h\n" + "ldr x20, [%x[args], %[offsetof_K]]\n" + "add x20, x20, #0x1\n" + "lsr x20, x20, #0x1\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 6f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" + ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" + ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" + ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" + ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" + "addvl x26, x26, #8\n" + ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + "ble 5f\n" + "4:" // K loop + ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" + "subs x21, x21, #0x1\n" + ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" + ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" + ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" + ".inst 0xa0402352 // ld1h { z18.h-z19.h }, pn8.b/Z, [x26]\n" + ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" + ".inst 0xa0402370 // ld1h { z16.h-z17.h }, pn8.b/Z, [x27]\n" + ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" + ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" + ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" + ".inst 0xa1412342 // ld1h { z2.h, z10.h }, pn8.b/Z, [x26, #0x2, MUL VL]\n" + ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" + ".inst 0xa041237e // ld1h { z30.h-z31.h }, pn8.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" + ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" + ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" + ".inst 0xa042235c // ld1h { z28.h-z29.h }, pn8.b/Z, [x26, #0x4, MUL VL]\n" + ".inst 0xa1422366 // ld1h { z6.h, z14.h }, pn8.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" + ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" + ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" + ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" + ".inst 0xa1432345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26, #0x6, MUL VL]\n" + "addvl x26, x26, #8\n" + ".inst 0xa1432367 // ld1h { z7.h, z15.h }, pn8.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + "bgt 4b\n" + "5:" // K loop tail + ".inst 0x81b02640 // fmopa za0.s, p1/M, p1/M, z18.h, z16.h\n" + ".inst 0x81b12641 // fmopa za1.s, p1/M, p1/M, z18.h, z17.h\n" + ".inst 0x81b02662 // fmopa za2.s, p1/M, p1/M, z19.h, z16.h\n" + ".inst 0x81b12663 // fmopa za3.s, p1/M, p1/M, z19.h, z17.h\n" + ".inst 0x81be2440 // fmopa za0.s, p1/M, p1/M, z2.h, z30.h\n" + ".inst 0x81bf2441 // fmopa za1.s, p1/M, p1/M, z2.h, z31.h\n" + ".inst 0x81be2542 // fmopa za2.s, p1/M, p1/M, z10.h, z30.h\n" + ".inst 0x81bf2543 // fmopa za3.s, p1/M, p1/M, z10.h, z31.h\n" + ".inst 0x81a62780 // fmopa za0.s, p1/M, p1/M, z28.h, z6.h\n" + ".inst 0x81ae2781 // fmopa za1.s, p1/M, p1/M, z28.h, z14.h\n" + ".inst 0x81a627a2 // fmopa za2.s, p1/M, p1/M, z29.h, z6.h\n" + ".inst 0x81ae27a3 // fmopa za3.s, p1/M, p1/M, z29.h, z14.h\n" + ".inst 0x81a724a0 // fmopa za0.s, p1/M, p1/M, z5.h, z7.h\n" + ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" + ".inst 0x81a725a2 // fmopa za2.s, p1/M, p1/M, z13.h, z7.h\n" + ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" + "6:" // K oddments + "cbz x20, 8f\n" + "7:" // K oddments: Loop + ".inst 0xa1402345 // ld1h { z5.h, z13.h }, pn8.b/Z, [x26]\n" + "subs x20, x20, #0x1\n" + "addvl x26, x26, #2\n" + ".inst 0xa040236e // ld1h { z14.h-z15.h }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0x81ae24a0 // fmopa za0.s, p1/M, p1/M, z5.h, z14.h\n" + ".inst 0x81af24a1 // fmopa za1.s, p1/M, p1/M, z5.h, z15.h\n" + ".inst 0x81ae25a2 // fmopa za2.s, p1/M, p1/M, z13.h, z14.h\n" + ".inst 0x81af25a3 // fmopa za3.s, p1/M, p1/M, z13.h, z15.h\n" + "bgt 7b\n" + "8:" // K oddments: End + "ldr x25, [%x[args], %[offsetof_C]]\n" + "sub x24, x13, x11\n" + "cntw x23, ALL, MUL #2\n" + "ld1rh { z17.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "ldr x22, [%x[args], %[offsetof_ldcb]]\n" + "whilelt p0.h, x10, x9\n" + "cmp x24, x23\n" + "ld1rh { z16.h }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "mov x12, #0x0\n" + "mov x21, #0x0\n" + "add x25, x25, x10, LSL #1\n" // C += n + "mov x20, #0x2\n" + "madd x25, x11, x22, x25\n" // C += m * ldc + "csel x24, x24, x23, LT\n" + "10:" // Store to output array: Accumulator loop + ".inst 0xc006000e // mova { z14.b-z15.b }, za0h.b[x12, 0:1]\n" + "add x12, x12, #0x4\n" + "cmp x12, x23, LSL #1\n" + "add x21, x21, #0x1\n" + ".inst 0xc120e1cc // fcvt z12.h, { z14.s-z15.s }\n" + "csel x12, x12, x20, LT\n" + "cmp x21, x24\n" + ".inst 0x6470262c // fclamp z12.h, z17.h, z16.h\n" + "st1h { z12.h }, p0, [x25]\n" + "add x25, x25, x22\n" + "blt 10b\n" + "incw x10, ALL, MUL #2\n" + "cmp x10, x9\n" + "blt 2b\n" + "incw x11, ALL, MUL #2\n" + "mov x10, #0x0\n" + "cmp x11, x13\n" + "mov x28, x26\n" + "blt 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), + [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), + [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), + [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), [offsetof_M] "I"(offsetof(KernelArgs, M)), + [offsetof_N] "I"(offsetof(KernelArgs, N)), [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", + "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", + "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h new file mode 100644 index 00000000..a54d7c2a --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h @@ -0,0 +1,98 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme to pack the LHS matrix. +/// -# kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme to pack the RHS matrix. + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(void); + +/// Gets the offset in bytes to the data element in the packed LHS matrix buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of `n_step`. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. Must be a multiple of `m_step`. +/// @param[in] n_idx Column index. Must be a multiple of `n_step`. +/// @param[in] dst_row_stride. Distance between start of two rows in the output buffer. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_row_stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * Packed LHS: @ref kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa. +/// * Output: @ref kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split +/// @param[in] packed_lhs Packed LHS matrix buffer. +/// @param[in] packed_rhs Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_row_stride Row stride in bytes of the output matrix. +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. +void kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h new file mode 100644 index 00000000..bbc2b318 --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h @@ -0,0 +1,45 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: imatmul_clamp_f16_f16p_f16p + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_m_step_func_t)(void); +typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_n_step_func_t)(void); +typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_lhs_packed_offset_func_t)( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); +typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_rhs_packed_offset_func_t)( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); +typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_imatmul_clamp_f16_f16p_f16p_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_imatmul_clamp_f16_f16p_f16p_run_imatmul_func_t)( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + +/// Micro-kernel interface +struct kai_imatmul_clamp_f16_f16p_f16p_ukernel { + kai_imatmul_clamp_f16_f16p_f16p_get_m_step_func_t get_m_step; + kai_imatmul_clamp_f16_f16p_f16p_get_n_step_func_t get_n_step; + kai_imatmul_clamp_f16_f16p_f16p_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_imatmul_clamp_f16_f16p_f16p_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_imatmul_clamp_f16_f16p_f16p_get_dst_offset_func_t get_dst_offset; + kai_imatmul_clamp_f16_f16p_f16p_get_dst_size_func_t get_dst_size; + kai_imatmul_clamp_f16_f16p_f16p_run_imatmul_func_t run_imatmul; +}; + +#ifdef __cplusplus +} +#endif -- GitLab From 04fe4a3396982c49810ba7088a0f91435207c067 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Fri, 11 Apr 2025 07:35:34 +0000 Subject: [PATCH 3/7] Add FP16 IGEMM RHS packing kernel This change adds the RHS FP16 packing kernel to be used with IGEMM. It doesn't actually change the elements, so this can be used with any 16-bit types. Signed-off-by: Emil Ohlsson Approved-by: Felix Johnny Thomasmathibalan --- CMakeLists.txt | 1 + kai/ukernels/matmul/BUILD.bazel | 1 + ..._imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c | 204 ++++++++++++++++++ ..._imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h | 80 +++++++ 4 files changed, 286 insertions(+) create mode 100644 kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c create mode 100644 kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 33bbb7c7..4a720d7a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -211,6 +211,7 @@ set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c + kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 9f655159..8f56394f 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -136,6 +136,7 @@ SME_KERNELS = [ "pack/kai_lhs_pack_x16p2vlx2_x16_sme", "pack/kai_lhs_pack_x8p2vlx4_x8_sme", "pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", + "pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c new file mode 100644 index 00000000..a9c0bb73 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c @@ -0,0 +1,204 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. +#include "kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +#define NR 2 +#define KR 2 +static const size_t kai_num_bytes_input = sizeof(uint16_t); +static const size_t kai_num_bytes_output = sizeof(uint16_t); +static const size_t kai_num_bytes_bias = sizeof(uint16_t); + +#define MAX_N_STEP (NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR)) + +size_t kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(void) { + return NR * kai_get_sme_vector_length_u16() / KR; +} + +size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx) { + KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() == 0); + + return n_idx * kai_num_bytes_input; +} + +size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx) { + return n_idx * kai_num_bytes_bias; +} + +static size_t kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t k_chunk_count, size_t k_chunk_length) { + return kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() * + (kai_num_bytes_bias + k_chunk_count * kai_roundup(k_chunk_length, KR) * kai_num_bytes_output); +} + +size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() == 0); + + const size_t block_idx = n_idx / kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(); + return block_idx * + kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(k_chunk_count, k_chunk_length); +} + +size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length) { + const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme()); + return kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + n_rounded_up, k_chunk_count, k_chunk_length); +} + +void kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + void* rhs_packed) { + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(rhs_packed != NULL); + + size_t height = k_chunk_length; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_row_stride; + + KAI_ASSERT(kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() <= MAX_N_STEP); + uint16_t pad_row[MAX_N_STEP]; + if (height % KR) { + memset(pad_row, 0, MAX_N_STEP * sizeof(uint16_t)); + } + + size_t out_stride = + kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(k_chunk_count, k_chunk_length); + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x21, %x[out]\n" + "mov x20, %x[width]\n" + "ptrue p1.b\n" + "1:" // Bias: Full loop + "whilelt p0.h, XZR, x20\n" + "dech x20\n" + "cmp x20, #0x0\n" + "ld1h { z16.h }, p0/Z, [%x[bias]]\n" + "incb %x[bias]\n" + "st1h { z16.h }, p1, [x21]\n" + "add x21, x21, %x[out_stride]\n" + "bgt 1b\n" + "incb %x[out]\n" + "mov x11, %x[k_chunk_count]\n" + "2:" // Chunk Loop + "mov x10, %x[height]\n" + "cmp x10, #0x8\n" + "blt 6f\n" + "3:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[out]\n" + "add x27, x9, %x[in_stride]\n" + "sub x10, x10, #0x8\n" + "add x26, x27, %x[in_stride]\n" + "mov x25, %x[width]\n" + "add x24, x26, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "4:" // Main row loop: Column loop + "whilelt p0.h, XZR, x25\n" + "decw x25, ALL, MUL #2\n" + "ld1h { z20.h }, p0/Z, [x9]\n" + "cmp x25, #0x0\n" + "addvl x9, x9, #1\n" + "ld1h { z17.h }, p0/Z, [x27]\n" + "addvl x27, x27, #1\n" + "ld1h { z19.h }, p0/Z, [x26]\n" + "addvl x26, x26, #1\n" + "ld1h { z16.h }, p0/Z, [x24]\n" + "addvl x24, x24, #1\n" + "ld1h { z18.h }, p0/Z, [x23]\n" + "addvl x23, x23, #1\n" + "zip1 z24.h, z20.h, z17.h\n" + "zip2 z23.h, z20.h, z17.h\n" + "ld1h { z17.h }, p0/Z, [x22]\n" + "addvl x22, x22, #1\n" + "ld1h { z22.h }, p0/Z, [x21]\n" + "addvl x21, x21, #1\n" + "zip1 z21.h, z19.h, z16.h\n" + "zip2 z20.h, z19.h, z16.h\n" + "ld1h { z16.h }, p0/Z, [x20]\n" + "addvl x20, x20, #1\n" + "zip1 z19.h, z18.h, z17.h\n" + "zip2 z18.h, z18.h, z17.h\n" + "st1h { z24.h }, p1, [x28]\n" + "st1h { z23.h }, p1, [x28, #1, MUL VL]\n" + "zip1 z17.h, z22.h, z16.h\n" + "zip2 z16.h, z22.h, z16.h\n" + "st1h { z21.h }, p1, [x28, #2, MUL VL]\n" + "st1h { z20.h }, p1, [x28, #3, MUL VL]\n" + "st1h { z19.h }, p1, [x28, #4, MUL VL]\n" + "st1h { z18.h }, p1, [x28, #5, MUL VL]\n" + "st1h { z17.h }, p1, [x28, #6, MUL VL]\n" + "st1h { z16.h }, p1, [x28, #7, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 4b\n" + "cmp x10, #0x8\n" + "addvl %x[out], %x[out], #8\n" + "bge 3b\n" + "cbz x10, 10f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x9, %x[in]\n" + "cntw x22, ALL, MUL #4\n" + "add x27, x9, %x[in_stride]\n" + "cmp x10, #0x1\n" + "add %x[in], x27, %x[in_stride]\n" + "mov x28, %x[out]\n" + "csel %x[in], %x[in], x27, GT\n" + "csel x27, x27, %x[pad_row], GT\n" + "csel x21, x22, XZR, GT\n" + "sub x10, x10, #0x2\n" + "mov x20, %x[width]\n" + "8:" // Tail row loop: Column loop + "whilelt p0.h, XZR, x20\n" + "decw x20, ALL, MUL #2\n" + "ld1h { z18.h }, p0/Z, [x9]\n" + "cmp x20, #0x0\n" + "add x9, x9, x22\n" + "ld1h { z16.h }, p0/Z, [x27]\n" + "add x27, x27, x21\n" + "zip1 z17.h, z18.h, z16.h\n" + "zip2 z16.h, z18.h, z16.h\n" + "st1h { z17.h }, p1, [x28]\n" + "st1h { z16.h }, p1, [x28, #1, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 8b\n" + "cmp x10, #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 7b\n" + "10:" // Done + "sub x11, x11, #0x1\n" + "cbnz x11, 2b\n" + ".inst 0xd503467f // SMSTOP\n" + : [bias] "+&r"(bias), [in] "+&r"(in), [out] "+&r"(out) + : [height] "r"(height), [in_stride] "r"(in_stride), [k_chunk_count] "r"(k_chunk_count), + [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", + "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", + "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h new file mode 100644 index 00000000..10578871 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h @@ -0,0 +1,80 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return Step size for column index. +size_t kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. Must be divisible by `n_step` +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Column index. Must be divisible by `n_step` +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of columns. +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme. +/// * Bias: @ref kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme. +/// * Scale: @ref kai_get_scale_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme. +/// +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// @param[in] rhs_row_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. +/// @param[out] rhs_packed Packed RHS matrix. +void kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + void* rhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus -- GitLab From 547c2011340a3205bcb6dd454036bf4ddcaeeb16 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Wed, 16 Apr 2025 09:05:00 +0000 Subject: [PATCH 4/7] Add unit testing for FP16 IGEMM kernels This change adds unit testing for the FP16 IGEMM kernels. The code is written in a type agnostic manner, as to easily allow testing for other data types with very low effort. This required the addition of non-templated `read`/`write` functions, as to allow runtime-generic access. The big change is the `imatmul_test.cpp`, which is the main driver of the FP16 IGEMM testing. The generated test data is cached, which introduced the need for hashing functionality in several places throughout the testing code. Clamp testing is parameterized, which introduced changes needed in the clamping code. Finally, test description generated has been updated to allow re-use of existing type printing Signed-off-by: Emil Ohlsson Reviewed-by: Jakub Sujak Approved-by: Jakub Sujak --- CMakeLists.txt | 5 +- test/common/data_format.cpp | 17 +- test/common/data_format.hpp | 5 + test/common/matmul_test_common.cpp | 19 +- test/common/matmul_test_common.hpp | 30 ++ test/common/memory.cpp | 75 +++++ test/common/memory.hpp | 20 +- test/reference/clamp.cpp | 24 ++ test/reference/clamp.hpp | 20 ++ test/reference/matmul.cpp | 30 ++ test/reference/matmul.hpp | 34 ++ test/tests/imatmul_test.cpp | 500 +++++++++++++++++++++++++++++ 12 files changed, 770 insertions(+), 9 deletions(-) create mode 100644 test/common/memory.cpp create mode 100644 test/tests/imatmul_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a720d7a..f94941cd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -313,6 +313,7 @@ if(KLEIDIAI_BUILD_TESTS) test/common/int4.cpp test/common/matmul_test_common.cpp test/common/matrix_portion.cpp + test/common/memory.cpp test/common/printer.cpp test/common/rect.cpp test/common/round.cpp @@ -362,6 +363,7 @@ if(KLEIDIAI_BUILD_TESTS) add_executable(kleidiai_test test/tests/bfloat16_test.cpp test/tests/float16_test.cpp + test/tests/imatmul_test.cpp test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_f32_f32p_test.cpp @@ -376,8 +378,9 @@ if(KLEIDIAI_BUILD_TESTS) ) set_source_files_properties( - test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp + test/tests/imatmul_test.cpp test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp + test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_test.cpp PROPERTIES COMPILE_FLAGS "-Wpedantic") diff --git a/test/common/data_format.cpp b/test/common/data_format.cpp index ccf9a46a..e188710a 100644 --- a/test/common/data_format.cpp +++ b/test/common/data_format.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,7 @@ #include #include +#include #include "kai/kai_common.h" #include "test/common/data_type.hpp" @@ -175,4 +176,18 @@ size_t DataFormat::default_size_in_bytes(size_t height, size_t width) const { return num_rows * block_stride; } +size_t DataFormat::Hash::operator()(const DataFormat& format) const { + using DT = std::underlying_type_t; + using PF = std::underlying_type_t; + return // + (std::hash
{}(static_cast
(format._data_type)) << 0) ^ // + (std::hash{}(static_cast(format._zero_point_dt)) << 1) ^ // + (std::hash
{}(static_cast
(format._scale_dt) << 2)) ^ // + (std::hash
{}(static_cast
(format._zero_point_dt)) << 3) ^ // + (std::hash{}(format._block_height) << 4) ^ // + (std::hash{}(format._block_width) << 5) ^ // + (std::hash{}(format._subblock_height) << 6) ^ // + (std::hash{}(format._subblock_width) << 7); // +} + } // namespace kai::test diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp index 730dd86e..2d7f0b84 100644 --- a/test/common/data_format.hpp +++ b/test/common/data_format.hpp @@ -141,6 +141,11 @@ public: /// @return The size in bytes of the matrix. [[nodiscard]] size_t default_size_in_bytes(size_t height, size_t width) const; + /// Hash functor + struct Hash { + size_t operator()(const DataFormat& format) const; + }; + private: DataType _data_type; PackFormat _pack_format; diff --git a/test/common/matmul_test_common.cpp b/test/common/matmul_test_common.cpp index 73d41c09..67b05d0b 100644 --- a/test/common/matmul_test_common.cpp +++ b/test/common/matmul_test_common.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -12,13 +12,20 @@ namespace kai::test { void PrintTo(const MatMulTestParams& param, std::ostream* os) { const auto& [method, shape, portion] = param; - // NOLINTBEGIN(cppcoreguidelines-pro-bounds-constant-array-index) - *os << "Method_" << method.name // - << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k // - << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) // + *os << "Method_" << method.name << "__"; + PrintTo(shape, os); + *os << "__"; + PrintTo(portion, os); +} + +void PrintTo(const MatMulShape& shape, std::ostream* os) { + *os << "M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k; +} + +void PrintTo(const MatrixPortion& portion, std::ostream* os) { + *os << "PortionStartRow_" << static_cast(portion.start_row() * 1000) // << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // << "__PortionHeight_" << static_cast(portion.height() * 1000) // << "__PortionWidth_" << static_cast(portion.width() * 1000); - // NOLINTEND(cppcoreguidelines-pro-bounds-constant-array-index) } } // namespace kai::test diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index 5b3e2424..3db8b811 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -22,6 +22,34 @@ struct MatMulShape { size_t m; ///< LHS height. size_t n; ///< RHS width. size_t k; ///< LHS width and RHS height. + + struct Hash { + size_t operator()(const MatMulShape& shape) const { + return // + (std::hash{}(shape.m) << 0) ^ // + (std::hash{}(shape.n) << 1) ^ // + (std::hash{}(shape.k) << 2); // + } + }; + +private: + friend bool operator==(const MatMulShape& lhs, const MatMulShape& rhs) { + return // + lhs.m == rhs.m && // + lhs.n == rhs.n && // + lhs.k == rhs.k; + } +}; + +/// Value range +template +struct Range { + T min; + T max; + + [[nodiscard]] T range() const { + return max - min; + } }; // NOLINTBEGIN(misc-non-private-member-variables-in-classes) @@ -459,4 +487,6 @@ using MatMulTestParams = std::tuple; /// Prints the test information. void PrintTo(const MatMulTestParams& param, std::ostream* os); +void PrintTo(const MatMulShape& shape, std::ostream* os); +void PrintTo(const MatrixPortion& portion, std::ostream* os); } // namespace kai::test diff --git a/test/common/memory.cpp b/test/common/memory.cpp new file mode 100644 index 00000000..499f2131 --- /dev/null +++ b/test/common/memory.cpp @@ -0,0 +1,75 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/memory.hpp" + +#include + +#include "test/common/bfloat16.hpp" +#include "test/common/float16.hpp" + +namespace kai::test { + +double read_array(DataType type, const void* array, size_t index) { + switch (type) { + case DataType::FP32: + return read_array(array, index); + case DataType::FP16: + return static_cast(read_array(array, index)); + case DataType::BF16: + return static_cast(read_array(array, index)); + case DataType::I32: + return read_array(array, index); + case DataType::QAI8: + return read_array(array, index); + case DataType::QSU4: + return read_array(array, index); + case DataType::QSI4: + return read_array(array, index); + case DataType::UNKNOWN: + default: + KAI_ERROR("Trying to read unknown data type"); + } + return std::numeric_limits::signaling_NaN(); +} + +void write_array(DataType type, void* array, size_t index, double value) { + switch (type) { + case DataType::FP32: { + write_array(array, index, value); + return; + } + case DataType::FP16: { + write_array(array, index, static_cast(value)); + return; + } + case DataType::BF16: { + write_array(array, index, static_cast(value)); + return; + } + case DataType::I32: { + write_array(array, index, value); + return; + } + case DataType::QAI8: { + write_array(array, index, value); + return; + } + case DataType::QSU4: { + write_array(array, index, static_cast(value)); + return; + } + case DataType::QSI4: { + write_array(array, index, static_cast(value)); + return; + } + case DataType::UNKNOWN: + default: + KAI_ERROR("Trying to write unknown data type"); + } +} + +} // namespace kai::test diff --git a/test/common/memory.hpp b/test/common/memory.hpp index c856218f..28a24ea1 100644 --- a/test/common/memory.hpp +++ b/test/common/memory.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -12,6 +12,7 @@ #include "kai/kai_common.h" #include "test/common/bfloat16.hpp" +#include "test/common/data_type.hpp" #include "test/common/int4.hpp" namespace kai::test { @@ -50,6 +51,15 @@ T read_array(const void* array, size_t index) { } } +/// Reads the array at the specified index +/// +/// @param[in] type Array element data type +/// @param[in] array Data buffer. +/// @param[in] index Array index. +/// +/// @return Value at specified index +double read_array(DataType type, const void* array, size_t index); + /// Writes the specified value to the array. /// /// @param[in] array Data buffer. @@ -80,4 +90,12 @@ void write_array(void* array, size_t index, T value) { } } +/// Writes the specified value to the array. +/// +/// @param[in] type Array element type. +/// @param[in] array Data buffer. +/// @param[in] index Array index. +/// @param[in] value Value to be stored. +void write_array(DataType type, void* array, size_t index, double value); + } // namespace kai::test diff --git a/test/reference/clamp.cpp b/test/reference/clamp.cpp index 6a8e7433..ab2e77e9 100644 --- a/test/reference/clamp.cpp +++ b/test/reference/clamp.cpp @@ -49,6 +49,20 @@ std::tuple find_clamp_range(const void* src, size_t len, float ratio) { template std::tuple find_clamp_range(const void* src, size_t len, float ratio); template std::tuple find_clamp_range(const void* src, size_t len, float ratio); +std::tuple find_clamp_range(DataType type, const void* src, size_t len, float ratio) { + auto max = std::numeric_limits::min(); + auto min = std::numeric_limits::max(); + + for (size_t i = 0; i < len; i += 1) { + const double value = read_array(type, src, i); + max = std::max(value, max); + min = std::min(value, min); + } + + const float reduction = (max - min) * (1.0F - ratio) / 2.0F; + return {min + reduction, max - reduction}; +} + template std::vector clamp(const void* src, size_t len, T min_value, T max_value) { std::vector dst(round_up_division(len * size_in_bits, 8)); @@ -63,4 +77,14 @@ std::vector clamp(const void* src, size_t len, T min_value, T max_value template std::vector clamp(const void* src, size_t len, float min_value, float max_value); template std::vector clamp(const void* src, size_t len, Float16 min_value, Float16 max_value); +std::vector clamp(DataType type, const void* src, size_t len, float min_value, float max_value) { + std::vector dst(round_up_division(len * data_type_size_in_bits(type), 8)); + + for (size_t i = 0; i < len; ++i) { + write_array(type, dst.data(), i, std::clamp(read_array(type, src, i), min_value, max_value)); + } + + return dst; +} + } // namespace kai::test diff --git a/test/reference/clamp.hpp b/test/reference/clamp.hpp index b665917e..1570e129 100644 --- a/test/reference/clamp.hpp +++ b/test/reference/clamp.hpp @@ -11,6 +11,8 @@ #include #include +#include "test/common/data_type.hpp" + namespace kai::test { /// Finds the clamping parameters to limit the dynamic range. @@ -23,6 +25,16 @@ namespace kai::test { template std::tuple find_clamp_range(const void* src, size_t len, float ratio); +/// Finds the clamping parameters to limit the dynamic range. +/// +/// @param[in] type Array element data type. +/// @param[in] src The data buffer. +/// @param[in] len The number of values. +/// @param[in] ratio The ratio between the output dynamic range and the input dynamic range. +/// +/// @return The minimum value and the maximum value. +std::tuple find_clamp_range(DataType type, const void* src, size_t len, float ratio); + /// Clamps the matrix. /// /// @param[in] src Data buffer of the source matrix. @@ -32,4 +44,12 @@ std::tuple find_clamp_range(const void* src, size_t len, float ratio); template std::vector clamp(const void* src, size_t len, T min_value, T max_value); +/// Clamps the matrix. +/// +/// @param[in] type Array element data type. +/// @param[in] src Data buffer of the source matrix. +/// @param[in] len Number of values in the source matrix. +/// @param[in] min_value Lower bound of clamp. +/// @param[in] width Upper bound of clamp. +std::vector clamp(DataType type, const void* src, size_t len, float min_value, float max_value); } // namespace kai::test diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index 9bfefbd2..7ce3b01a 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -184,6 +184,36 @@ std::vector matmul( return tmp_dst; } +std::vector indirect_matmul( + const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding_ptr, + [[maybe_unused]] const void* lhs_scales, [[maybe_unused]] const void* lhs_zero_points, + DataType lhs_dt, // + const void* rhs, [[maybe_unused]] const void* rhs_scales, [[maybe_unused]] const void* rhs_zero_points, + DataType rhs_dt, // + const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, // + DataType dst_dt, // + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length) { + // This is inefficient, but allows code-reuse + const size_t chunk_bytes = k_chunk_length * round_up_division(data_type_size_in_bits(lhs_dt), 8); + const size_t n_chunks = m * k_chunk_count; + std::vector lhs(n_chunks * chunk_bytes); + + // Copy all chunks to the created matrix + for (size_t i = 0; i < n_chunks; i += 1) { + const uint8_t* src_pointer = static_cast(lhs_idata[i]); + if (src_pointer != lhs_padding_ptr) { + src_pointer += lhs_offset; + } + memcpy(lhs.data() + i * chunk_bytes, src_pointer, chunk_bytes); + } + + return matmul( + lhs.data(), lhs_scales, lhs_zero_points, lhs_dt, // + rhs, rhs_scales, rhs_zero_points, rhs_dt, // + bias, bias_scales, bias_zero_points, bias_dt, // + dst_dt, m, n, k_chunk_count * k_chunk_length, false, false); +} + template < typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp index 8d83e98c..d5118896 100644 --- a/test/reference/matmul.hpp +++ b/test/reference/matmul.hpp @@ -65,6 +65,40 @@ std::vector matmul( size_t m, size_t n, size_t k, // bool lhs_transposed, bool rhs_transposed); +/// Indirect matrix multiplication. +/// +/// @param[in] lhs_idata The indirect LHS data matrix. +/// @param[in] lhs_scales (Optional) LHS operand quantization scales. +/// @param[in] lhs_offset The indirection LHS data matrix offset, applied to non-padding pointers +/// @parma[in] lhs_padding The indirection LHS padding chunk pointer +/// @param[in] lhs_zero_points (Optional) LHS operand quantization zero point. +/// @param[in] lhs_dt LHS operand data type. +/// @param[in] rhs RHS operand data. +/// @param[in] rhs_scales (Optional) RHS operand quantization scales. +/// @param[in] rhs_zero_points (Optional) RHS operand quantization zero point. +/// @param[in] rhs_dt RHS operand data type. +/// @param[in] bias Bias operand data. +/// @param[in] bias_scales (Optional) Bias operand quantization scales. +/// @param[in] bias_zero_points (Optional) Bias operand quantization zero point. +/// @param[in] bias_dt Bias operand data type. +/// @param[in] dst Output data. +/// @param[in] dst_scales (Optional) Output quantization scales. +/// @param[in] dst_zero_points (Optional) Output quantization zero point. +/// @param[in] dst_dt Output data type. +/// @param[in] m Output height. +/// @param[in] n Output width. +/// @param[in] k_chunk_count Number pointers per row in lhs_idata +/// @param[in] k_chunk_size Number of elements in each LHS K chunk +/// +/// @return The result data buffer. +std::vector indirect_matmul( + const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales, + const void* lhs_zero_points, DataType lhs_dt, // + const void* rhs, const void* rhs_scales, const void* rhs_zero_points, DataType rhs_dt, // + const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, // + DataType dst_dt, // + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length); + /// Matrix multiplication with quantized input and floating-point output. /// /// The LHS matrix is non-transposed and the RHS matrix is transposed. diff --git a/test/tests/imatmul_test.cpp b/test/tests/imatmul_test.cpp new file mode 100644 index 00000000..bb1e54f2 --- /dev/null +++ b/test/tests/imatmul_test.cpp @@ -0,0 +1,500 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include + +#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h" +#include "test/common/compare.hpp" +#include "test/common/cpu_info.hpp" +#include "test/common/matmul_test_common.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/memory.hpp" +#include "test/common/round.hpp" +#include "test/common/sme.hpp" +#include "test/reference/clamp.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/reorder.hpp" + +namespace kai::test { + +// Ensure static linkage for all functionality local to this test file +namespace { + +/// Convenience wrapper for K-chunk handling +struct KChunk { + size_t count; + size_t length; +}; + +/// Interface for indirect matmul LHS packing kernel +struct LhsPackIndirectKernel { + std::function get_m_step; + std::function get_packed_lhs_offset; + std::function get_packed_lhs_size; + std::function + pack; +}; + +/// Interface for indirect matmul RHS packing kernel +struct RhsPackIndirectKernel { + std::function get_n_step; + std::function get_rhs_offset; + std::function get_bias_offset; + std::function get_packed_rhs_offset; + std::function get_packed_rhs_size; + std::function + pack; +}; + +/// Interface for indirect matmul kernel +struct MatMulIndirectKernel { + std::function get_m_step; + std::function get_n_step; + std::function get_mr; + std::function get_nr; + std::function get_kr; + std::function get_packed_lhs_offset; + std::function get_packed_rhs_offset; + std::function get_dst_offset; + std::function get_dst_size; + std::function + imatmul; +}; + +/// Description of a Indirect Matmul kernel set +struct IndirectMatMul { + std::string_view name; + std::function is_supported; + + MatMulShape pack_shape; + struct Format { + DataFormat lhs; + DataFormat rhs; + DataFormat bias; + DataFormat out; + + struct Hash { + size_t operator()(const Format& format) const { + return // + (DataFormat::Hash{}(format.lhs) << 0) ^ // + (DataFormat::Hash{}(format.rhs) << 1) ^ // + (DataFormat::Hash{}(format.bias) << 2) ^ // + (DataFormat::Hash{}(format.out) << 3); + } + }; + + private: + friend bool operator==(const Format& lhs, const Format& rhs) { + return // + lhs.lhs == rhs.lhs && // + lhs.rhs == rhs.rhs && // + lhs.bias == rhs.bias && // + lhs.out == rhs.out; + } + } format; + + LhsPackIndirectKernel lhs; + RhsPackIndirectKernel rhs; + MatMulIndirectKernel imatmul; +}; + +/// Simple byte buffer +using Buffer = std::vector; + +/// Convenience type for test list +using IndirectMatMulArray = std::array; + +/// Test parameter bundle type +using IndirectMatMulTestParams = std::tuple; + +/// Test type +using IndirectMatMulTest = testing::TestWithParam; + +/// Use interface for matmul kernel +const kai_imatmul_clamp_f16_f16p_f16p_ukernel& get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() { + static kai_imatmul_clamp_f16_f16p_f16p_ukernel ukernel; + ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + ukernel.run_imatmul = kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + return ukernel; +} + +/// Retreive the test list +const IndirectMatMulArray& get_indirect_matmul_methods() { + static IndirectMatMulArray indirect_matmul_methods{}; + + // F16 IMATMUL //////////////////////////////////////////////////////////// + indirect_matmul_methods[0].name = "indirect_matmul_f16_f16p_f16p"; + indirect_matmul_methods[0].is_supported = cpu_has_sme2; + indirect_matmul_methods[0].pack_shape.m = 2 * get_sme_vector_length(); + indirect_matmul_methods[0].pack_shape.n = 2 * get_sme_vector_length(); + indirect_matmul_methods[0].pack_shape.k = sizeof(int32_t) / sizeof(int8_t); + indirect_matmul_methods[0].format.lhs = DataFormat(DataType::FP16); + indirect_matmul_methods[0].format.rhs = DataFormat(DataType::FP16); + indirect_matmul_methods[0].format.bias = DataFormat(DataType::FP16); + indirect_matmul_methods[0].format.out = DataFormat(DataType::FP16); + + // LHS + indirect_matmul_methods[0].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme; + indirect_matmul_methods[0].lhs.get_packed_lhs_offset = + kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme; + indirect_matmul_methods[0].lhs.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme; + indirect_matmul_methods[0].lhs.pack = kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme; + + // RHS + indirect_matmul_methods[0].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; + indirect_matmul_methods[0].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; + indirect_matmul_methods[0].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; + indirect_matmul_methods[0].rhs.get_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; + indirect_matmul_methods[0].rhs.get_packed_rhs_size = + kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; + indirect_matmul_methods[0].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; + + // IMATMUL + const kai_imatmul_clamp_f16_f16p_f16p_ukernel& ukernel_f16 = + get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(); + indirect_matmul_methods[0].imatmul.get_m_step = ukernel_f16.get_m_step; + indirect_matmul_methods[0].imatmul.get_n_step = ukernel_f16.get_n_step; + indirect_matmul_methods[0].imatmul.get_packed_lhs_offset = ukernel_f16.get_lhs_packed_offset; + indirect_matmul_methods[0].imatmul.get_packed_rhs_offset = ukernel_f16.get_rhs_packed_offset; + indirect_matmul_methods[0].imatmul.get_dst_offset = ukernel_f16.get_dst_offset; + indirect_matmul_methods[0].imatmul.get_dst_size = ukernel_f16.get_dst_size; + indirect_matmul_methods[0].imatmul.imatmul = ukernel_f16.run_imatmul; + + return indirect_matmul_methods; +} + +/// Test reference identification +struct TestDataId { + MatMulShape shape; + MatMulShape pack_shape; + IndirectMatMul::Format format; + size_t k_chunk_length; + float clamp_rate; + + struct Hash { + size_t operator()(const TestDataId& test_id) const { + return // + (MatMulShape::Hash{}(test_id.shape) << 0) ^ // + (MatMulShape::Hash{}(test_id.pack_shape) << 1) ^ // + (IndirectMatMul::Format::Hash{}(test_id.format) << 2) ^ // + (std::hash{}(test_id.k_chunk_length) << 3) ^ // + (std::hash{}(test_id.clamp_rate) << 4); // + } + }; + +private: + friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) { + return // + lhs.shape == rhs.shape && // + lhs.pack_shape == rhs.pack_shape && // + lhs.format == rhs.format && // + lhs.k_chunk_length == rhs.k_chunk_length && // + lhs.clamp_rate == rhs.clamp_rate; + } +}; + +/// Test reference data +struct TestData { + Buffer lhs; ///< LHS input matrix + Buffer rhs; ///< RHS input matrix + Buffer bias; ///< Bias vector + Buffer out; ///< Reference imatmul result + Buffer indirection; ///< LHS indirection buffer + uintptr_t indirection_offset; ///< LHS indirection buffer offset + Buffer padding; ///< Padding buffer + Range clamp_range; ///< Clamp range +}; + +/// Reference data generator +/// +/// Uses test id to generate reference data, and caches it. +struct ReferenceGenerator { + /// Retrieve reference data for the provided test identification + static const TestData& get_test_reference(const TestDataId test_id) { + static std::unordered_map m_data; + if (const auto itr = m_data.find(test_id); itr != end(m_data)) { + return itr->second; + } + + return m_data[test_id] = generate_reference(test_id); + } + +private: + /// Return incremented seed value + static size_t get_seed() { + static size_t seed = 0; + return seed++; + } + + /// Generate reference data. Not intended to be called + /// directly, as this would bypass caching mechanism. + static TestData generate_reference(const TestDataId& test_id) { + const auto& [chunked_shape, pack_shape, format, k_chunk_length, clamp_rate] = test_id; + + // The LHS matrix will be split into several chunks in the K dimension + const size_t k_chunk_count = chunked_shape.k; + MatMulShape shape = {chunked_shape.m, chunked_shape.n, k_chunk_count * k_chunk_length}; + + // Generate random input data + Buffer lhs = fill_matrix_random(shape.m, shape.k, format.lhs, get_seed()); + Buffer rhs = fill_matrix_random(shape.k, shape.n, format.rhs, get_seed()); + Buffer bias = fill_matrix_random(1, shape.n, format.bias, get_seed()); + + // Create a padding chunk + const DataType lhs_dt = format.lhs.data_type(); + const size_t k_chunk_size = + round_up_division(k_chunk_length * data_type_size_in_bits(format.lhs.data_type()), 8); + const size_t row_size = k_chunk_count * k_chunk_size; + Buffer lhs_padding(k_chunk_size); + for (size_t i = 0; i < k_chunk_length; i += 1) { + static constexpr double padding_value = 0; + write_array(lhs_dt, lhs_padding.data(), i, padding_value); + } + + // Set up indirection buffer + const uintptr_t indirection_offset = reinterpret_cast(lhs.data()); + std::vector indirection(chunked_shape.m * chunked_shape.k); + for (size_t i_m = 0; i_m < chunked_shape.m; i_m += 1) { + for (size_t i_k = 0; i_k < chunked_shape.k; i_k += 1) { + const size_t idx = i_m * chunked_shape.k + i_k; + // Test padding pointers using first LHS row for shapes where M > 1 + if (chunked_shape.m > 1 && i_m == 0) { + indirection.at(idx) = lhs_padding.data(); + } else { + uintptr_t offset = i_m * row_size + i_k * k_chunk_size; + indirection.at(idx) = reinterpret_cast(offset); + } + } + } + + // Pack indirection buffer + Buffer indirection_packed = reorder_block( + reinterpret_cast(indirection.data()), chunked_shape.m, chunked_shape.k, pack_shape.m, + 1); + + Buffer out = indirect_matmul( // + indirection.data(), indirection_offset, lhs_padding.data(), // + nullptr, nullptr, format.lhs.data_type(), // + rhs.data(), nullptr, nullptr, format.rhs.data_type(), // + bias.data(), nullptr, nullptr, format.bias.data_type(), // + format.out.data_type(), // + chunked_shape.m, chunked_shape.n, chunked_shape.k, k_chunk_length); + + // Calculate clamping range based on full range of values, and then clamp values + const auto [min, max] = + find_clamp_range(format.out.data_type(), out.data(), shape.m * shape.n, 1.0F - clamp_rate); + Buffer out_clamped = clamp(format.out.data_type(), out.data(), shape.m * shape.n, min, max); + + // Populate reference data + TestData test_reference; + test_reference.lhs = std::move(lhs); + test_reference.rhs = std::move(rhs); + test_reference.bias = std::move(bias); + test_reference.padding = std::move(lhs_padding); + test_reference.out = std::move(out_clamped); + test_reference.indirection_offset = indirection_offset; + test_reference.indirection = std::move(indirection_packed); + test_reference.clamp_range = {min, max}; + + return test_reference; + }; +}; + +/// Perform LHS packing for indirect matmul +Buffer pack_lhs( + const LhsPackIndirectKernel& kernel, const Rect& portion, const TestData& reference, size_t m, + const KChunk& k_chunk) { + const void* const* indirection_pointer = reinterpret_cast(reference.indirection.data()); + + // Calculate size, and allocate buffer + const size_t dst_size = kernel.get_packed_lhs_size(m, k_chunk.count, k_chunk.length); + Buffer dst(dst_size); + + // Calculate portion offsets + const size_t input_offset = portion.start_row() * k_chunk.count; + const size_t dst_offset = kernel.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); + + // Perform packing + kernel.pack( + portion.height(), k_chunk.count, k_chunk.length, // Dimensions + indirection_pointer + input_offset, // Indirection input + reference.indirection_offset, // Chunk offset + reference.padding.data(), // Padding pointer + dst.data() + dst_offset); + return dst; +} + +/// Perform RHS packign for indirect matmul +Buffer pack_rhs( + const RhsPackIndirectKernel& kernel, const Rect& portion, const TestData& reference, size_t n, + const KChunk& k_chunk, DataType type) { + // Calculate size, and allocate buffer + const size_t row_stride = round_up_division(n * data_type_size_in_bits(type), 8); + const size_t dst_size = kernel.get_packed_rhs_size(n, k_chunk.count, k_chunk.length); + Buffer dst(dst_size); + + // Calculate offsets + const size_t rhs_offset = kernel.get_rhs_offset(portion.start_col()); + const size_t bias_offset = kernel.get_bias_offset(portion.start_col()); + const size_t dst_offset = kernel.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length); + + // Perform actual packing + kernel.pack( + portion.width(), k_chunk.count, k_chunk.length, row_stride, // Dimensions + reference.rhs.data() + rhs_offset, // RHS input + reference.bias.data() + bias_offset, // Bias + dst.data() + dst_offset); // Output + return dst; +} + +/// Perform imatmul +/// +/// Note, this should not be aware of reference result, as to make it clear that +/// any produced result is strictly from the code under test +Buffer imatmul( + const MatMulIndirectKernel& kernel, const Rect& portion, const MatMulShape& shape, const KChunk& k_chunk, + const Buffer& packed_lhs, const Buffer& packed_rhs, Range clamp_range, DataType type) { + // Calculate size, and allocate buffer + const size_t dst_size = kernel.get_dst_size(shape.m, shape.n); + const size_t row_stride = round_up_division(shape.n * data_type_size_in_bits(type), 8); + Buffer dst(dst_size); + + // Calculate portion offsets + const size_t lhs_offset = kernel.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); + const size_t rhs_offset = kernel.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length); + const size_t dst_offset = kernel.get_dst_offset(portion.start_row(), portion.start_col(), row_stride); + + // Call matmul kernel + kernel.imatmul( + portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions + packed_lhs.data() + lhs_offset, // LHS + packed_rhs.data() + rhs_offset, // RHS + dst.data() + dst_offset, // DST + row_stride, clamp_range.min, clamp_range.max); + + return dst; +} + +} // namespace + +/// End-to-end test for indirection matmul kernels +TEST_P(IndirectMatMulTest, Output) { + const auto& [method, shape, k_chunk_length, clamp_rate, output_portion] = GetParam(); + if (not method.is_supported()) { + GTEST_SKIP() << "CPU features are not supported by current CPU"; + } + + const KChunk k_chunk{shape.k, k_chunk_length}; + + // Retrieve reference data + const TestData& test_data = + ReferenceGenerator::get_test_reference({shape, method.pack_shape, method.format, k_chunk_length, clamp_rate}); + const Rect portion = output_portion.compute_portion(shape.m, shape.n, method.pack_shape.m, method.pack_shape.n); + + // Call packing kernels, and then imatmul kernel + Buffer packed_lhs = pack_lhs(method.lhs, portion, test_data, shape.m, k_chunk); + Buffer packed_rhs = pack_rhs(method.rhs, portion, test_data, shape.n, k_chunk, method.format.rhs.data_type()); + Buffer out = imatmul( + method.imatmul, portion, shape, k_chunk, packed_lhs, packed_rhs, test_data.clamp_range, + method.format.out.data_type()); + + // Compare the actual result with the reference result + DefaultMismatchHandler handler(0, 0.1, 0, 0.05); + const auto success = + compare(out.data(), test_data.out.data(), method.format.out.data_type(), shape.m, shape.n, portion, handler); + ASSERT_TRUE(success); +} + +/// Name generator for test case +[[maybe_unused]] static void PrintTo(const IndirectMatMulTestParams& param, std::ostream* os) { + const auto& [method, shape, k_chunk_length, clamp_rate, portion] = param; + *os << "Method_" << method.name << "__"; + PrintTo(shape, os); + *os << "__K_chunk_length_" << k_chunk_length; + *os << "__clamp_rate_" << static_cast(clamp_rate * 100) << "__"; + PrintTo(portion, os); +} + +/// Test parameter listing +INSTANTIATE_TEST_SUITE_P( + IndirectMatMul, IndirectMatMulTest, + testing::Combine( + testing::ValuesIn(get_indirect_matmul_methods()), // + testing::ValuesIn({ + // clang-format off + MatMulShape{ 1, 1, 1}, // + MatMulShape{ 1, 17, 4}, // + MatMulShape{ 1, 19, 24}, // + MatMulShape{ 1, 32, 4}, // + MatMulShape{ 1, 32, 32}, // + MatMulShape{ 1, 33, 200}, // + MatMulShape{ 1, 49, 21}, // + MatMulShape{ 1, 64, 4}, // + MatMulShape{ 1, 65, 4}, // + MatMulShape{ 3, 6, 6}, // + MatMulShape{ 3, 28, 25}, // + MatMulShape{ 4, 16, 4}, // + MatMulShape{ 4, 16, 27}, // + MatMulShape{ 6, 18, 31}, // + MatMulShape{ 6, 28, 1}, // + MatMulShape{ 6, 29, 24}, // + MatMulShape{ 8, 16, 16}, // + MatMulShape{ 16, 16, 4}, // + MatMulShape{ 16, 16, 16}, // + MatMulShape{ 20, 30, 40}, // + MatMulShape{ 23, 1, 43}, // + MatMulShape{ 32, 14, 1}, // + MatMulShape{ 32, 16, 27}, // + MatMulShape{ 32, 32, 3}, // + MatMulShape{ 32, 32, 4}, // + MatMulShape{ 33, 29, 24}, // + MatMulShape{ 64, 64, 3}, // + MatMulShape{ 64, 64, 4}, // + MatMulShape{ 96, 96, 3}, // + MatMulShape{ 96, 97, 3}, // + MatMulShape{ 97, 96, 3}, // + MatMulShape{123, 85, 45}, // + MatMulShape{128, 128, 3}, // + MatMulShape{130, 130, 6}, // + // clang-format on + }), + testing::ValuesIn(std::initializer_list{1, 2, 3, 4, 8, 11, 16, 32, 33, 64, 65}), // + testing::ValuesIn(std::initializer_list{0.0F, 0.1F, 0.5F}), // + testing::ValuesIn({ + // clang-format off + // (Start row , start col , height , width) + MatrixPortion( 0 , 0 , 1 , 1 ), // Full matrix. + MatrixPortion( 0 , 0 , 1 , 0.5 ), // Left half + MatrixPortion( 0 , 0 , 0.5 , 1 ), // Upper half + MatrixPortion( 0 , 0.5 , 1 , 0.5 ), // Right half + MatrixPortion( 0.5 , 0 , 0.5 , 1 ), // Bottom half + MatrixPortion( 0.4 , 0.4 , 0.3 , 0.3 ), // Center ninth + // clang-format on + })), // + testing::PrintToStringParamName()); + +} // namespace kai::test -- GitLab From 1ec794931fcf26042258342f9d5548f509973f08 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Wed, 16 Apr 2025 12:56:34 +0000 Subject: [PATCH 5/7] Add unoredered map inclusion Signed-off-by: Emil Ohlsson Approved-by: Jakub Sujak --- test/tests/imatmul_test.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/test/tests/imatmul_test.cpp b/test/tests/imatmul_test.cpp index bb1e54f2..b9fa5637 100644 --- a/test/tests/imatmul_test.cpp +++ b/test/tests/imatmul_test.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h" -- GitLab From d7b2752f6f0362d500d6bc451d8a1936e60eeac6 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Wed, 16 Apr 2025 13:38:08 +0000 Subject: [PATCH 6/7] Add FP16 IGEMM to CHANGELOG.md Signed-off-by: Emil Ohlsson Approved-by: Jakub Sujak --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ebbd0acc..9dcaa5f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,12 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - New Advanced SIMD micro-kernels: - Matrix multiplication (MxN) Micro-kernels of QAI8DX LHS and QSI8CX RHS with F16 output, optimized for FEAT_I8MM and FEAT_DotProd. - Matrix multiplication (1xN) Micro-kernels of QAI8DX LHS and QSI8CX RHS with F16 output, optimized for FEAT_DotProd. +- New SME micro-kernels: + - Indirect matrix multiplication (MxN) of FP16 input and output. + - Packing kernels for LHS and RHS +- New SME2 micro-kernels: + - Indirect matrix multiplication (MxN) of FP16 input and output. + - Matrix multiplication of packed indirect LHS and packed RHS ## v1.7.0 -- GitLab From 264c4d4b5b8cf8b7cda61055f79cf9f9d0a1c5c1 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Wed, 23 Apr 2025 09:08:15 +0200 Subject: [PATCH 7/7] Address review comments Signed-off-by: Emil Ohlsson --- ...16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h | 5 +- ..._imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h | 1 - test/reference/clamp.hpp | 2 +- test/reference/matmul.cpp | 10 +- test/reference/matmul.hpp | 9 +- test/tests/imatmul_test.cpp | 92 ++++++++++--------- 6 files changed, 58 insertions(+), 61 deletions(-) diff --git a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h index a54d7c2a..79c52a42 100644 --- a/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h @@ -36,7 +36,6 @@ size_t kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(vo /// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. /// @param[in] k_chunk_count Number of LHS column splits. /// @param[in] k_chunk_length Length of a LHS column split. -/// @param[in] k Number of columns in the unpacked LHS matrix. /// /// @return The offset in bytes to the data element. size_t kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( @@ -83,8 +82,8 @@ size_t kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( /// @param[in] n Number of output columns to be computed. /// @param[in] k_chunk_count Number of LHS column splits. /// @param[in] k_chunk_length Length of a LHS column split -/// @param[in] packed_lhs Packed LHS matrix buffer. -/// @param[in] packed_rhs Packed RHS matrix buffer. +/// @param[in] lhs_packed Packed LHS matrix buffer. +/// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. /// @param[in] dst_row_stride Row stride in bytes of the output matrix. /// @param[in] clamp_min Minimum value to clamp the final result. diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h index 10578871..e26bc3f5 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h @@ -69,7 +69,6 @@ size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( /// @param[in] rhs_row_stride Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. -/// @param[in] scale Scale data buffer. /// @param[out] rhs_packed Packed RHS matrix. void kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, diff --git a/test/reference/clamp.hpp b/test/reference/clamp.hpp index 1570e129..532e7d25 100644 --- a/test/reference/clamp.hpp +++ b/test/reference/clamp.hpp @@ -50,6 +50,6 @@ std::vector clamp(const void* src, size_t len, T min_value, T max_value /// @param[in] src Data buffer of the source matrix. /// @param[in] len Number of values in the source matrix. /// @param[in] min_value Lower bound of clamp. -/// @param[in] width Upper bound of clamp. +/// @param[in] max_value Upper bound of clamp. std::vector clamp(DataType type, const void* src, size_t len, float min_value, float max_value); } // namespace kai::test diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index 5e743245..b1378c75 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -185,10 +185,10 @@ std::vector matmul( } std::vector indirect_matmul( - const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding_ptr, - [[maybe_unused]] const void* lhs_scales, [[maybe_unused]] const void* lhs_zero_points, + const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales, + const void* lhs_zero_points, DataType lhs_dt, // - const void* rhs, [[maybe_unused]] const void* rhs_scales, [[maybe_unused]] const void* rhs_zero_points, + const void* rhs, const void* rhs_scales, const void* rhs_zero_points, DataType rhs_dt, // const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, // DataType dst_dt, // @@ -219,7 +219,7 @@ template < typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> std::vector indirect_matmul_nt_t_quantized( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // - const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales, + const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, @@ -238,7 +238,7 @@ std::vector indirect_matmul_nt_t_quantized( // Calculate the K chunk pointer. Apply offset if this is not padding const size_t k_chunk_idx = i_m * k_chunk_count + i_k_chunk; const void* k_chunk_ptr = lhs_ptrs[k_chunk_idx]; - if (k_chunk_ptr != lhs_padding) { + if (k_chunk_ptr != lhs_padding_ptr) { k_chunk_ptr = reinterpret_cast(reinterpret_cast(k_chunk_ptr) + lhs_offset); } diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp index d5118896..4054cee3 100644 --- a/test/reference/matmul.hpp +++ b/test/reference/matmul.hpp @@ -70,7 +70,7 @@ std::vector matmul( /// @param[in] lhs_idata The indirect LHS data matrix. /// @param[in] lhs_scales (Optional) LHS operand quantization scales. /// @param[in] lhs_offset The indirection LHS data matrix offset, applied to non-padding pointers -/// @parma[in] lhs_padding The indirection LHS padding chunk pointer +/// @param[in] lhs_padding_ptr The indirection LHS padding chunk pointer /// @param[in] lhs_zero_points (Optional) LHS operand quantization zero point. /// @param[in] lhs_dt LHS operand data type. /// @param[in] rhs RHS operand data. @@ -81,9 +81,6 @@ std::vector matmul( /// @param[in] bias_scales (Optional) Bias operand quantization scales. /// @param[in] bias_zero_points (Optional) Bias operand quantization zero point. /// @param[in] bias_dt Bias operand data type. -/// @param[in] dst Output data. -/// @param[in] dst_scales (Optional) Output quantization scales. -/// @param[in] dst_zero_points (Optional) Output quantization zero point. /// @param[in] dst_dt Output data type. /// @param[in] m Output height. /// @param[in] n Output width. @@ -161,7 +158,7 @@ std::vector matmul_clamp_nt_t( /// @param[in] lhs_data The LHS data matrix. /// @param[in] lhs_idata The indirect LHS data matrix. /// @param[in] lhs_offset The indirection LHS data matrix offset, applied to non-padding pointers -/// @parma[in] lhs_padding The indirection LHS padding chunk pointer +/// @param[in] lhs_padding_ptr The indirection LHS padding chunk pointer /// @param[in] lhs_scales The LHS quantization scales matrix. /// @param[in] lhs_zero_points The LHS quantization zero points matrix. /// @param[in] lhs_quant_width The LHS quantization block width. @@ -200,7 +197,7 @@ template < typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> std::vector indirect_matmul_nt_t_quantized( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // - const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales, + const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, diff --git a/test/tests/imatmul_test.cpp b/test/tests/imatmul_test.cpp index b9fa5637..f787143b 100644 --- a/test/tests/imatmul_test.cpp +++ b/test/tests/imatmul_test.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -42,11 +42,11 @@ struct KChunk { /// Interface for indirect matmul LHS packing kernel struct LhsPackIndirectKernel { std::function get_m_step; - std::function get_packed_lhs_offset; - std::function get_packed_lhs_size; + std::function get_lhs_packed_offset; + std::function get_lhs_packed_size; std::function + const void* zero, void* lhs_packed)> pack; }; @@ -55,8 +55,8 @@ struct RhsPackIndirectKernel { std::function get_n_step; std::function get_rhs_offset; std::function get_bias_offset; - std::function get_packed_rhs_offset; - std::function get_packed_rhs_size; + std::function get_rhs_packed_offset; + std::function get_rhs_packed_size; std::function @@ -70,8 +70,8 @@ struct MatMulIndirectKernel { std::function get_mr; std::function get_nr; std::function get_kr; - std::function get_packed_lhs_offset; - std::function get_packed_rhs_offset; + std::function get_lhs_packed_offset; + std::function get_rhs_packed_offset; std::function get_dst_offset; std::function get_dst_size; std::function; using IndirectMatMulArray = std::array; /// Test parameter bundle type -using IndirectMatMulTestParams = std::tuple; +using IndirectMatMulTestParams = std::tuple; /// Test type using IndirectMatMulTest = testing::TestWithParam; @@ -147,11 +147,11 @@ const IndirectMatMulArray& get_indirect_matmul_methods() { static IndirectMatMulArray indirect_matmul_methods{}; // F16 IMATMUL //////////////////////////////////////////////////////////// - indirect_matmul_methods[0].name = "indirect_matmul_f16_f16p_f16p"; + indirect_matmul_methods[0].name = "indirect_matmul_f16_f16p_f16p_2vlx2vl_sme2_mopa"; indirect_matmul_methods[0].is_supported = cpu_has_sme2; indirect_matmul_methods[0].pack_shape.m = 2 * get_sme_vector_length(); indirect_matmul_methods[0].pack_shape.n = 2 * get_sme_vector_length(); - indirect_matmul_methods[0].pack_shape.k = sizeof(int32_t) / sizeof(int8_t); + indirect_matmul_methods[0].pack_shape.k = sizeof(int32_t); indirect_matmul_methods[0].format.lhs = DataFormat(DataType::FP16); indirect_matmul_methods[0].format.rhs = DataFormat(DataType::FP16); indirect_matmul_methods[0].format.bias = DataFormat(DataType::FP16); @@ -159,18 +159,18 @@ const IndirectMatMulArray& get_indirect_matmul_methods() { // LHS indirect_matmul_methods[0].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme; - indirect_matmul_methods[0].lhs.get_packed_lhs_offset = + indirect_matmul_methods[0].lhs.get_lhs_packed_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme; - indirect_matmul_methods[0].lhs.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme; + indirect_matmul_methods[0].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme; indirect_matmul_methods[0].lhs.pack = kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme; // RHS indirect_matmul_methods[0].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; indirect_matmul_methods[0].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; indirect_matmul_methods[0].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; - indirect_matmul_methods[0].rhs.get_packed_rhs_offset = + indirect_matmul_methods[0].rhs.get_rhs_packed_offset = kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; - indirect_matmul_methods[0].rhs.get_packed_rhs_size = + indirect_matmul_methods[0].rhs.get_rhs_packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; indirect_matmul_methods[0].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; @@ -179,8 +179,8 @@ const IndirectMatMulArray& get_indirect_matmul_methods() { get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(); indirect_matmul_methods[0].imatmul.get_m_step = ukernel_f16.get_m_step; indirect_matmul_methods[0].imatmul.get_n_step = ukernel_f16.get_n_step; - indirect_matmul_methods[0].imatmul.get_packed_lhs_offset = ukernel_f16.get_lhs_packed_offset; - indirect_matmul_methods[0].imatmul.get_packed_rhs_offset = ukernel_f16.get_rhs_packed_offset; + indirect_matmul_methods[0].imatmul.get_lhs_packed_offset = ukernel_f16.get_lhs_packed_offset; + indirect_matmul_methods[0].imatmul.get_rhs_packed_offset = ukernel_f16.get_rhs_packed_offset; indirect_matmul_methods[0].imatmul.get_dst_offset = ukernel_f16.get_dst_offset; indirect_matmul_methods[0].imatmul.get_dst_size = ukernel_f16.get_dst_size; indirect_matmul_methods[0].imatmul.imatmul = ukernel_f16.run_imatmul; @@ -265,10 +265,14 @@ private: Buffer rhs = fill_matrix_random(shape.k, shape.n, format.rhs, get_seed()); Buffer bias = fill_matrix_random(1, shape.n, format.bias, get_seed()); - // Create a padding chunk + // Data types used const DataType lhs_dt = format.lhs.data_type(); - const size_t k_chunk_size = - round_up_division(k_chunk_length * data_type_size_in_bits(format.lhs.data_type()), 8); + const DataType rhs_dt = format.rhs.data_type(); + const DataType out_dt = format.out.data_type(); + const DataType bias_dt = format.bias.data_type(); + + // Create a padding chunk + const size_t k_chunk_size = round_up_division(k_chunk_length * data_type_size_in_bits(lhs_dt), 8); const size_t row_size = k_chunk_count * k_chunk_size; Buffer lhs_padding(k_chunk_size); for (size_t i = 0; i < k_chunk_length; i += 1) { @@ -297,18 +301,16 @@ private: reinterpret_cast(indirection.data()), chunked_shape.m, chunked_shape.k, pack_shape.m, 1); - Buffer out = indirect_matmul( // - indirection.data(), indirection_offset, lhs_padding.data(), // - nullptr, nullptr, format.lhs.data_type(), // - rhs.data(), nullptr, nullptr, format.rhs.data_type(), // - bias.data(), nullptr, nullptr, format.bias.data_type(), // - format.out.data_type(), // + Buffer out = indirect_matmul( // + indirection.data(), indirection_offset, lhs_padding.data(), nullptr, nullptr, lhs_dt, // LHS + rhs.data(), nullptr, nullptr, rhs_dt, // RHS + bias.data(), nullptr, nullptr, bias_dt, // Bias + out_dt, // Out chunked_shape.m, chunked_shape.n, chunked_shape.k, k_chunk_length); // Calculate clamping range based on full range of values, and then clamp values - const auto [min, max] = - find_clamp_range(format.out.data_type(), out.data(), shape.m * shape.n, 1.0F - clamp_rate); - Buffer out_clamped = clamp(format.out.data_type(), out.data(), shape.m * shape.n, min, max); + const auto [min, max] = find_clamp_range(out_dt, out.data(), shape.m * shape.n, 1.0F - clamp_rate); + Buffer out_clamped = clamp(out_dt, out.data(), shape.m * shape.n, min, max); // Populate reference data TestData test_reference; @@ -332,12 +334,12 @@ Buffer pack_lhs( const void* const* indirection_pointer = reinterpret_cast(reference.indirection.data()); // Calculate size, and allocate buffer - const size_t dst_size = kernel.get_packed_lhs_size(m, k_chunk.count, k_chunk.length); + const size_t dst_size = kernel.get_lhs_packed_size(m, k_chunk.count, k_chunk.length); Buffer dst(dst_size); // Calculate portion offsets const size_t input_offset = portion.start_row() * k_chunk.count; - const size_t dst_offset = kernel.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); + const size_t dst_offset = kernel.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length); // Perform packing kernel.pack( @@ -355,13 +357,13 @@ Buffer pack_rhs( const KChunk& k_chunk, DataType type) { // Calculate size, and allocate buffer const size_t row_stride = round_up_division(n * data_type_size_in_bits(type), 8); - const size_t dst_size = kernel.get_packed_rhs_size(n, k_chunk.count, k_chunk.length); + const size_t dst_size = kernel.get_rhs_packed_size(n, k_chunk.count, k_chunk.length); Buffer dst(dst_size); // Calculate offsets const size_t rhs_offset = kernel.get_rhs_offset(portion.start_col()); const size_t bias_offset = kernel.get_bias_offset(portion.start_col()); - const size_t dst_offset = kernel.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length); + const size_t dst_offset = kernel.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length); // Perform actual packing kernel.pack( @@ -378,22 +380,22 @@ Buffer pack_rhs( /// any produced result is strictly from the code under test Buffer imatmul( const MatMulIndirectKernel& kernel, const Rect& portion, const MatMulShape& shape, const KChunk& k_chunk, - const Buffer& packed_lhs, const Buffer& packed_rhs, Range clamp_range, DataType type) { + const Buffer& lhs_packed, const Buffer& rhs_packed, Range clamp_range, DataType type) { // Calculate size, and allocate buffer const size_t dst_size = kernel.get_dst_size(shape.m, shape.n); const size_t row_stride = round_up_division(shape.n * data_type_size_in_bits(type), 8); Buffer dst(dst_size); // Calculate portion offsets - const size_t lhs_offset = kernel.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); - const size_t rhs_offset = kernel.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length); + const size_t lhs_offset = kernel.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length); + const size_t rhs_offset = kernel.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length); const size_t dst_offset = kernel.get_dst_offset(portion.start_row(), portion.start_col(), row_stride); // Call matmul kernel kernel.imatmul( portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions - packed_lhs.data() + lhs_offset, // LHS - packed_rhs.data() + rhs_offset, // RHS + lhs_packed.data() + lhs_offset, // LHS + rhs_packed.data() + rhs_offset, // RHS dst.data() + dst_offset, // DST row_stride, clamp_range.min, clamp_range.max); @@ -404,7 +406,7 @@ Buffer imatmul( /// End-to-end test for indirection matmul kernels TEST_P(IndirectMatMulTest, Output) { - const auto& [method, shape, k_chunk_length, clamp_rate, output_portion] = GetParam(); + const auto& [method, shape, k_chunk_length, output_portion, clamp_rate] = GetParam(); if (not method.is_supported()) { GTEST_SKIP() << "CPU features are not supported by current CPU"; } @@ -417,10 +419,10 @@ TEST_P(IndirectMatMulTest, Output) { const Rect portion = output_portion.compute_portion(shape.m, shape.n, method.pack_shape.m, method.pack_shape.n); // Call packing kernels, and then imatmul kernel - Buffer packed_lhs = pack_lhs(method.lhs, portion, test_data, shape.m, k_chunk); - Buffer packed_rhs = pack_rhs(method.rhs, portion, test_data, shape.n, k_chunk, method.format.rhs.data_type()); + Buffer lhs_packed = pack_lhs(method.lhs, portion, test_data, shape.m, k_chunk); + Buffer rhs_packed = pack_rhs(method.rhs, portion, test_data, shape.n, k_chunk, method.format.rhs.data_type()); Buffer out = imatmul( - method.imatmul, portion, shape, k_chunk, packed_lhs, packed_rhs, test_data.clamp_range, + method.imatmul, portion, shape, k_chunk, lhs_packed, rhs_packed, test_data.clamp_range, method.format.out.data_type()); // Compare the actual result with the reference result @@ -432,7 +434,7 @@ TEST_P(IndirectMatMulTest, Output) { /// Name generator for test case [[maybe_unused]] static void PrintTo(const IndirectMatMulTestParams& param, std::ostream* os) { - const auto& [method, shape, k_chunk_length, clamp_rate, portion] = param; + const auto& [method, shape, k_chunk_length, portion, clamp_rate] = param; *os << "Method_" << method.name << "__"; PrintTo(shape, os); *os << "__K_chunk_length_" << k_chunk_length; @@ -484,7 +486,6 @@ INSTANTIATE_TEST_SUITE_P( // clang-format on }), testing::ValuesIn(std::initializer_list{1, 2, 3, 4, 8, 11, 16, 32, 33, 64, 65}), // - testing::ValuesIn(std::initializer_list{0.0F, 0.1F, 0.5F}), // testing::ValuesIn({ // clang-format off // (Start row , start col , height , width) @@ -495,7 +496,8 @@ INSTANTIATE_TEST_SUITE_P( MatrixPortion( 0.5 , 0 , 0.5 , 1 ), // Bottom half MatrixPortion( 0.4 , 0.4 , 0.3 , 0.3 ), // Center ninth // clang-format on - })), // + }), + testing::ValuesIn(std::initializer_list{0.0F, 0.1F, 0.5F})), // testing::PrintToStringParamName()); } // namespace kai::test -- GitLab